1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "quiche/quic/test_tools/quic_test_utils.h"
6
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
13
14 #include "absl/base/macros.h"
15 #include "absl/strings/string_view.h"
16 #include "openssl/chacha.h"
17 #include "openssl/sha.h"
18 #include "quiche/quic/core/crypto/crypto_framer.h"
19 #include "quiche/quic/core/crypto/crypto_handshake.h"
20 #include "quiche/quic/core/crypto/crypto_utils.h"
21 #include "quiche/quic/core/crypto/null_decrypter.h"
22 #include "quiche/quic/core/crypto/null_encrypter.h"
23 #include "quiche/quic/core/crypto/quic_decrypter.h"
24 #include "quiche/quic/core/crypto/quic_encrypter.h"
25 #include "quiche/quic/core/http/quic_spdy_client_session.h"
26 #include "quiche/quic/core/quic_config.h"
27 #include "quiche/quic/core/quic_data_writer.h"
28 #include "quiche/quic/core/quic_framer.h"
29 #include "quiche/quic/core/quic_packet_creator.h"
30 #include "quiche/quic/core/quic_packets.h"
31 #include "quiche/quic/core/quic_time.h"
32 #include "quiche/quic/core/quic_types.h"
33 #include "quiche/quic/core/quic_utils.h"
34 #include "quiche/quic/core/quic_versions.h"
35 #include "quiche/quic/platform/api/quic_flags.h"
36 #include "quiche/quic/platform/api/quic_logging.h"
37 #include "quiche/quic/test_tools/crypto_test_utils.h"
38 #include "quiche/quic/test_tools/quic_config_peer.h"
39 #include "quiche/quic/test_tools/quic_connection_peer.h"
40 #include "quiche/common/quiche_buffer_allocator.h"
41 #include "quiche/common/quiche_endian.h"
42 #include "quiche/common/simple_buffer_allocator.h"
43 #include "quiche/spdy/core/spdy_frame_builder.h"
44
45 using testing::_;
46 using testing::Invoke;
47 using testing::Return;
48
49 namespace quic {
50 namespace test {
51
TestConnectionId()52 QuicConnectionId TestConnectionId() {
53 // Chosen by fair dice roll.
54 // Guaranteed to be random.
55 return TestConnectionId(42);
56 }
57
TestConnectionId(uint64_t connection_number)58 QuicConnectionId TestConnectionId(uint64_t connection_number) {
59 const uint64_t connection_id64_net =
60 quiche::QuicheEndian::HostToNet64(connection_number);
61 return QuicConnectionId(reinterpret_cast<const char*>(&connection_id64_net),
62 sizeof(connection_id64_net));
63 }
64
TestConnectionIdNineBytesLong(uint64_t connection_number)65 QuicConnectionId TestConnectionIdNineBytesLong(uint64_t connection_number) {
66 const uint64_t connection_number_net =
67 quiche::QuicheEndian::HostToNet64(connection_number);
68 char connection_id_bytes[9] = {};
69 static_assert(
70 sizeof(connection_id_bytes) == 1 + sizeof(connection_number_net),
71 "bad lengths");
72 memcpy(connection_id_bytes + 1, &connection_number_net,
73 sizeof(connection_number_net));
74 return QuicConnectionId(connection_id_bytes, sizeof(connection_id_bytes));
75 }
76
TestConnectionIdToUInt64(QuicConnectionId connection_id)77 uint64_t TestConnectionIdToUInt64(QuicConnectionId connection_id) {
78 QUICHE_DCHECK_EQ(connection_id.length(), kQuicDefaultConnectionIdLength);
79 uint64_t connection_id64_net = 0;
80 memcpy(&connection_id64_net, connection_id.data(),
81 std::min<size_t>(static_cast<size_t>(connection_id.length()),
82 sizeof(connection_id64_net)));
83 return quiche::QuicheEndian::NetToHost64(connection_id64_net);
84 }
85
// Returns a fixed 16-byte stateless reset token for tests: the consecutive
// byte values 0x90 through 0x9F.
std::vector<uint8_t> CreateStatelessResetTokenForTest() {
  std::vector<uint8_t> token;
  token.reserve(16);
  for (uint8_t value = 0x90; value <= 0x9F; ++value) {
    token.push_back(value);
  }
  return token;
}
94
// Hostname used by TestServerId() and other test fixtures.
std::string TestHostname() {
  static constexpr char kTestHostname[] = "test.example.com";
  return std::string(kTestHostname);
}
96
TestServerId()97 QuicServerId TestServerId() { return QuicServerId(TestHostname(), kTestPort); }
98
InitAckFrame(const std::vector<QuicAckBlock> & ack_blocks)99 QuicAckFrame InitAckFrame(const std::vector<QuicAckBlock>& ack_blocks) {
100 QUICHE_DCHECK_GT(ack_blocks.size(), 0u);
101
102 QuicAckFrame ack;
103 QuicPacketNumber end_of_previous_block(1);
104 for (const QuicAckBlock& block : ack_blocks) {
105 QUICHE_DCHECK_GE(block.start, end_of_previous_block);
106 QUICHE_DCHECK_GT(block.limit, block.start);
107 ack.packets.AddRange(block.start, block.limit);
108 end_of_previous_block = block.limit;
109 }
110
111 ack.largest_acked = ack.packets.Max();
112
113 return ack;
114 }
115
InitAckFrame(uint64_t largest_acked)116 QuicAckFrame InitAckFrame(uint64_t largest_acked) {
117 return InitAckFrame(QuicPacketNumber(largest_acked));
118 }
119
InitAckFrame(QuicPacketNumber largest_acked)120 QuicAckFrame InitAckFrame(QuicPacketNumber largest_acked) {
121 return InitAckFrame({{QuicPacketNumber(1), largest_acked + 1}});
122 }
123
MakeAckFrameWithAckBlocks(size_t num_ack_blocks,uint64_t least_unacked)124 QuicAckFrame MakeAckFrameWithAckBlocks(size_t num_ack_blocks,
125 uint64_t least_unacked) {
126 QuicAckFrame ack;
127 ack.largest_acked = QuicPacketNumber(2 * num_ack_blocks + least_unacked);
128 // Add enough received packets to get num_ack_blocks ack blocks.
129 for (QuicPacketNumber i = QuicPacketNumber(2);
130 i < QuicPacketNumber(2 * num_ack_blocks + 1); i += 2) {
131 ack.packets.Add(i + least_unacked);
132 }
133 return ack;
134 }
135
MakeAckFrameWithGaps(uint64_t gap_size,size_t max_num_gaps,uint64_t largest_acked)136 QuicAckFrame MakeAckFrameWithGaps(uint64_t gap_size, size_t max_num_gaps,
137 uint64_t largest_acked) {
138 QuicAckFrame ack;
139 ack.largest_acked = QuicPacketNumber(largest_acked);
140 ack.packets.Add(QuicPacketNumber(largest_acked));
141 for (size_t i = 0; i < max_num_gaps; ++i) {
142 if (largest_acked <= gap_size) {
143 break;
144 }
145 largest_acked -= gap_size;
146 ack.packets.Add(QuicPacketNumber(largest_acked));
147 }
148 return ack;
149 }
150
HeaderToEncryptionLevel(const QuicPacketHeader & header)151 EncryptionLevel HeaderToEncryptionLevel(const QuicPacketHeader& header) {
152 if (header.form == IETF_QUIC_SHORT_HEADER_PACKET) {
153 return ENCRYPTION_FORWARD_SECURE;
154 } else if (header.form == IETF_QUIC_LONG_HEADER_PACKET) {
155 if (header.long_packet_type == HANDSHAKE) {
156 return ENCRYPTION_HANDSHAKE;
157 } else if (header.long_packet_type == ZERO_RTT_PROTECTED) {
158 return ENCRYPTION_ZERO_RTT;
159 }
160 }
161 return ENCRYPTION_INITIAL;
162 }
163
BuildUnsizedDataPacket(QuicFramer * framer,const QuicPacketHeader & header,const QuicFrames & frames)164 std::unique_ptr<QuicPacket> BuildUnsizedDataPacket(
165 QuicFramer* framer, const QuicPacketHeader& header,
166 const QuicFrames& frames) {
167 const size_t max_plaintext_size =
168 framer->GetMaxPlaintextSize(kMaxOutgoingPacketSize);
169 size_t packet_size = GetPacketHeaderSize(framer->transport_version(), header);
170 for (size_t i = 0; i < frames.size(); ++i) {
171 QUICHE_DCHECK_LE(packet_size, max_plaintext_size);
172 bool first_frame = i == 0;
173 bool last_frame = i == frames.size() - 1;
174 const size_t frame_size = framer->GetSerializedFrameLength(
175 frames[i], max_plaintext_size - packet_size, first_frame, last_frame,
176 header.packet_number_length);
177 QUICHE_DCHECK(frame_size);
178 packet_size += frame_size;
179 }
180 return BuildUnsizedDataPacket(framer, header, frames, packet_size);
181 }
182
BuildUnsizedDataPacket(QuicFramer * framer,const QuicPacketHeader & header,const QuicFrames & frames,size_t packet_size)183 std::unique_ptr<QuicPacket> BuildUnsizedDataPacket(
184 QuicFramer* framer, const QuicPacketHeader& header,
185 const QuicFrames& frames, size_t packet_size) {
186 char* buffer = new char[packet_size];
187 EncryptionLevel level = HeaderToEncryptionLevel(header);
188 size_t length =
189 framer->BuildDataPacket(header, frames, buffer, packet_size, level);
190
191 if (length == 0) {
192 delete[] buffer;
193 return nullptr;
194 }
195 // Re-construct the data packet with data ownership.
196 return std::make_unique<QuicPacket>(
197 buffer, length, /* owns_buffer */ true,
198 GetIncludedDestinationConnectionIdLength(header),
199 GetIncludedSourceConnectionIdLength(header), header.version_flag,
200 header.nonce != nullptr, header.packet_number_length,
201 header.retry_token_length_length, header.retry_token.length(),
202 header.length_length);
203 }
204
Sha1Hash(absl::string_view data)205 std::string Sha1Hash(absl::string_view data) {
206 char buffer[SHA_DIGEST_LENGTH];
207 SHA1(reinterpret_cast<const uint8_t*>(data.data()), data.size(),
208 reinterpret_cast<uint8_t*>(buffer));
209 return std::string(buffer, ABSL_ARRAYSIZE(buffer));
210 }
211
ClearControlFrame(const QuicFrame & frame)212 bool ClearControlFrame(const QuicFrame& frame) {
213 DeleteFrame(&const_cast<QuicFrame&>(frame));
214 return true;
215 }
216
ClearControlFrameWithTransmissionType(const QuicFrame & frame,TransmissionType)217 bool ClearControlFrameWithTransmissionType(const QuicFrame& frame,
218 TransmissionType /*type*/) {
219 return ClearControlFrame(frame);
220 }
221
RandUint64()222 uint64_t SimpleRandom::RandUint64() {
223 uint64_t result;
224 RandBytes(&result, sizeof(result));
225 return result;
226 }
227
RandBytes(void * data,size_t len)228 void SimpleRandom::RandBytes(void* data, size_t len) {
229 uint8_t* data_bytes = reinterpret_cast<uint8_t*>(data);
230 while (len > 0) {
231 const size_t buffer_left = sizeof(buffer_) - buffer_offset_;
232 const size_t to_copy = std::min(buffer_left, len);
233 memcpy(data_bytes, buffer_ + buffer_offset_, to_copy);
234 data_bytes += to_copy;
235 buffer_offset_ += to_copy;
236 len -= to_copy;
237
238 if (buffer_offset_ == sizeof(buffer_)) {
239 FillBuffer();
240 }
241 }
242 }
243
InsecureRandBytes(void * data,size_t len)244 void SimpleRandom::InsecureRandBytes(void* data, size_t len) {
245 RandBytes(data, len);
246 }
247
InsecureRandUint64()248 uint64_t SimpleRandom::InsecureRandUint64() { return RandUint64(); }
249
FillBuffer()250 void SimpleRandom::FillBuffer() {
251 uint8_t nonce[12];
252 memcpy(nonce, buffer_, sizeof(nonce));
253 CRYPTO_chacha_20(buffer_, buffer_, sizeof(buffer_), key_, nonce, 0);
254 buffer_offset_ = 0;
255 }
256
set_seed(uint64_t seed)257 void SimpleRandom::set_seed(uint64_t seed) {
258 static_assert(sizeof(key_) == SHA256_DIGEST_LENGTH, "Key has to be 256 bits");
259 SHA256(reinterpret_cast<const uint8_t*>(&seed), sizeof(seed), key_);
260
261 memset(buffer_, 0, sizeof(buffer_));
262 FillBuffer();
263 }
264
MockFramerVisitor()265 MockFramerVisitor::MockFramerVisitor() {
266 // By default, we want to accept packets.
267 ON_CALL(*this, OnProtocolVersionMismatch(_))
268 .WillByDefault(testing::Return(false));
269
270 // By default, we want to accept packets.
271 ON_CALL(*this, OnUnauthenticatedHeader(_))
272 .WillByDefault(testing::Return(true));
273
274 ON_CALL(*this, OnUnauthenticatedPublicHeader(_))
275 .WillByDefault(testing::Return(true));
276
277 ON_CALL(*this, OnPacketHeader(_)).WillByDefault(testing::Return(true));
278
279 ON_CALL(*this, OnStreamFrame(_)).WillByDefault(testing::Return(true));
280
281 ON_CALL(*this, OnCryptoFrame(_)).WillByDefault(testing::Return(true));
282
283 ON_CALL(*this, OnStopWaitingFrame(_)).WillByDefault(testing::Return(true));
284
285 ON_CALL(*this, OnPaddingFrame(_)).WillByDefault(testing::Return(true));
286
287 ON_CALL(*this, OnPingFrame(_)).WillByDefault(testing::Return(true));
288
289 ON_CALL(*this, OnRstStreamFrame(_)).WillByDefault(testing::Return(true));
290
291 ON_CALL(*this, OnConnectionCloseFrame(_))
292 .WillByDefault(testing::Return(true));
293
294 ON_CALL(*this, OnStopSendingFrame(_)).WillByDefault(testing::Return(true));
295
296 ON_CALL(*this, OnPathChallengeFrame(_)).WillByDefault(testing::Return(true));
297
298 ON_CALL(*this, OnPathResponseFrame(_)).WillByDefault(testing::Return(true));
299
300 ON_CALL(*this, OnGoAwayFrame(_)).WillByDefault(testing::Return(true));
301 ON_CALL(*this, OnMaxStreamsFrame(_)).WillByDefault(testing::Return(true));
302 ON_CALL(*this, OnStreamsBlockedFrame(_)).WillByDefault(testing::Return(true));
303 }
304
~MockFramerVisitor()305 MockFramerVisitor::~MockFramerVisitor() {}
306
OnProtocolVersionMismatch(ParsedQuicVersion)307 bool NoOpFramerVisitor::OnProtocolVersionMismatch(
308 ParsedQuicVersion /*version*/) {
309 return false;
310 }
311
OnUnauthenticatedPublicHeader(const QuicPacketHeader &)312 bool NoOpFramerVisitor::OnUnauthenticatedPublicHeader(
313 const QuicPacketHeader& /*header*/) {
314 return true;
315 }
316
OnUnauthenticatedHeader(const QuicPacketHeader &)317 bool NoOpFramerVisitor::OnUnauthenticatedHeader(
318 const QuicPacketHeader& /*header*/) {
319 return true;
320 }
321
OnPacketHeader(const QuicPacketHeader &)322 bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader& /*header*/) {
323 return true;
324 }
325
OnCoalescedPacket(const QuicEncryptedPacket &)326 void NoOpFramerVisitor::OnCoalescedPacket(
327 const QuicEncryptedPacket& /*packet*/) {}
328
OnUndecryptablePacket(const QuicEncryptedPacket &,EncryptionLevel,bool)329 void NoOpFramerVisitor::OnUndecryptablePacket(
330 const QuicEncryptedPacket& /*packet*/, EncryptionLevel /*decryption_level*/,
331 bool /*has_decryption_key*/) {}
332
OnStreamFrame(const QuicStreamFrame &)333 bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame& /*frame*/) {
334 return true;
335 }
336
OnCryptoFrame(const QuicCryptoFrame &)337 bool NoOpFramerVisitor::OnCryptoFrame(const QuicCryptoFrame& /*frame*/) {
338 return true;
339 }
340
OnAckFrameStart(QuicPacketNumber,QuicTime::Delta)341 bool NoOpFramerVisitor::OnAckFrameStart(QuicPacketNumber /*largest_acked*/,
342 QuicTime::Delta /*ack_delay_time*/) {
343 return true;
344 }
345
OnAckRange(QuicPacketNumber,QuicPacketNumber)346 bool NoOpFramerVisitor::OnAckRange(QuicPacketNumber /*start*/,
347 QuicPacketNumber /*end*/) {
348 return true;
349 }
350
OnAckTimestamp(QuicPacketNumber,QuicTime)351 bool NoOpFramerVisitor::OnAckTimestamp(QuicPacketNumber /*packet_number*/,
352 QuicTime /*timestamp*/) {
353 return true;
354 }
355
OnAckFrameEnd(QuicPacketNumber,const std::optional<QuicEcnCounts> &)356 bool NoOpFramerVisitor::OnAckFrameEnd(
357 QuicPacketNumber /*start*/,
358 const std::optional<QuicEcnCounts>& /*ecn_counts*/) {
359 return true;
360 }
361
OnStopWaitingFrame(const QuicStopWaitingFrame &)362 bool NoOpFramerVisitor::OnStopWaitingFrame(
363 const QuicStopWaitingFrame& /*frame*/) {
364 return true;
365 }
366
OnPaddingFrame(const QuicPaddingFrame &)367 bool NoOpFramerVisitor::OnPaddingFrame(const QuicPaddingFrame& /*frame*/) {
368 return true;
369 }
370
OnPingFrame(const QuicPingFrame &)371 bool NoOpFramerVisitor::OnPingFrame(const QuicPingFrame& /*frame*/) {
372 return true;
373 }
374
OnRstStreamFrame(const QuicRstStreamFrame &)375 bool NoOpFramerVisitor::OnRstStreamFrame(const QuicRstStreamFrame& /*frame*/) {
376 return true;
377 }
378
OnConnectionCloseFrame(const QuicConnectionCloseFrame &)379 bool NoOpFramerVisitor::OnConnectionCloseFrame(
380 const QuicConnectionCloseFrame& /*frame*/) {
381 return true;
382 }
383
OnNewConnectionIdFrame(const QuicNewConnectionIdFrame &)384 bool NoOpFramerVisitor::OnNewConnectionIdFrame(
385 const QuicNewConnectionIdFrame& /*frame*/) {
386 return true;
387 }
388
OnRetireConnectionIdFrame(const QuicRetireConnectionIdFrame &)389 bool NoOpFramerVisitor::OnRetireConnectionIdFrame(
390 const QuicRetireConnectionIdFrame& /*frame*/) {
391 return true;
392 }
393
OnNewTokenFrame(const QuicNewTokenFrame &)394 bool NoOpFramerVisitor::OnNewTokenFrame(const QuicNewTokenFrame& /*frame*/) {
395 return true;
396 }
397
OnStopSendingFrame(const QuicStopSendingFrame &)398 bool NoOpFramerVisitor::OnStopSendingFrame(
399 const QuicStopSendingFrame& /*frame*/) {
400 return true;
401 }
402
OnPathChallengeFrame(const QuicPathChallengeFrame &)403 bool NoOpFramerVisitor::OnPathChallengeFrame(
404 const QuicPathChallengeFrame& /*frame*/) {
405 return true;
406 }
407
OnPathResponseFrame(const QuicPathResponseFrame &)408 bool NoOpFramerVisitor::OnPathResponseFrame(
409 const QuicPathResponseFrame& /*frame*/) {
410 return true;
411 }
412
OnGoAwayFrame(const QuicGoAwayFrame &)413 bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& /*frame*/) {
414 return true;
415 }
416
OnMaxStreamsFrame(const QuicMaxStreamsFrame &)417 bool NoOpFramerVisitor::OnMaxStreamsFrame(
418 const QuicMaxStreamsFrame& /*frame*/) {
419 return true;
420 }
421
OnStreamsBlockedFrame(const QuicStreamsBlockedFrame &)422 bool NoOpFramerVisitor::OnStreamsBlockedFrame(
423 const QuicStreamsBlockedFrame& /*frame*/) {
424 return true;
425 }
426
OnWindowUpdateFrame(const QuicWindowUpdateFrame &)427 bool NoOpFramerVisitor::OnWindowUpdateFrame(
428 const QuicWindowUpdateFrame& /*frame*/) {
429 return true;
430 }
431
OnBlockedFrame(const QuicBlockedFrame &)432 bool NoOpFramerVisitor::OnBlockedFrame(const QuicBlockedFrame& /*frame*/) {
433 return true;
434 }
435
OnMessageFrame(const QuicMessageFrame &)436 bool NoOpFramerVisitor::OnMessageFrame(const QuicMessageFrame& /*frame*/) {
437 return true;
438 }
439
OnHandshakeDoneFrame(const QuicHandshakeDoneFrame &)440 bool NoOpFramerVisitor::OnHandshakeDoneFrame(
441 const QuicHandshakeDoneFrame& /*frame*/) {
442 return true;
443 }
444
OnAckFrequencyFrame(const QuicAckFrequencyFrame &)445 bool NoOpFramerVisitor::OnAckFrequencyFrame(
446 const QuicAckFrequencyFrame& /*frame*/) {
447 return true;
448 }
449
OnResetStreamAtFrame(const QuicResetStreamAtFrame &)450 bool NoOpFramerVisitor::OnResetStreamAtFrame(
451 const QuicResetStreamAtFrame& /*frame*/) {
452 return true;
453 }
454
IsValidStatelessResetToken(const StatelessResetToken &) const455 bool NoOpFramerVisitor::IsValidStatelessResetToken(
456 const StatelessResetToken& /*token*/) const {
457 return false;
458 }
459
MockQuicConnectionVisitor()460 MockQuicConnectionVisitor::MockQuicConnectionVisitor() {
461 ON_CALL(*this, GetFlowControlSendWindowSize(_))
462 .WillByDefault(Return(std::numeric_limits<QuicByteCount>::max()));
463 }
464
~MockQuicConnectionVisitor()465 MockQuicConnectionVisitor::~MockQuicConnectionVisitor() {}
466
MockQuicConnectionHelper()467 MockQuicConnectionHelper::MockQuicConnectionHelper() {}
468
~MockQuicConnectionHelper()469 MockQuicConnectionHelper::~MockQuicConnectionHelper() {}
470
GetClock() const471 const MockClock* MockQuicConnectionHelper::GetClock() const { return &clock_; }
472
GetClock()473 MockClock* MockQuicConnectionHelper::GetClock() { return &clock_; }
474
GetRandomGenerator()475 QuicRandom* MockQuicConnectionHelper::GetRandomGenerator() {
476 return &random_generator_;
477 }
478
CreateAlarm(QuicAlarm::Delegate * delegate)479 QuicAlarm* MockAlarmFactory::CreateAlarm(QuicAlarm::Delegate* delegate) {
480 return new MockAlarmFactory::TestAlarm(
481 QuicArenaScopedPtr<QuicAlarm::Delegate>(delegate));
482 }
483
CreateAlarm(QuicArenaScopedPtr<QuicAlarm::Delegate> delegate,QuicConnectionArena * arena)484 QuicArenaScopedPtr<QuicAlarm> MockAlarmFactory::CreateAlarm(
485 QuicArenaScopedPtr<QuicAlarm::Delegate> delegate,
486 QuicConnectionArena* arena) {
487 if (arena != nullptr) {
488 return arena->New<TestAlarm>(std::move(delegate));
489 } else {
490 return QuicArenaScopedPtr<TestAlarm>(new TestAlarm(std::move(delegate)));
491 }
492 }
493
494 quiche::QuicheBufferAllocator*
GetStreamSendBufferAllocator()495 MockQuicConnectionHelper::GetStreamSendBufferAllocator() {
496 return &buffer_allocator_;
497 }
498
AdvanceTime(QuicTime::Delta delta)499 void MockQuicConnectionHelper::AdvanceTime(QuicTime::Delta delta) {
500 clock_.AdvanceTime(delta);
501 }
502
MockQuicConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)503 MockQuicConnection::MockQuicConnection(QuicConnectionHelperInterface* helper,
504 QuicAlarmFactory* alarm_factory,
505 Perspective perspective)
506 : MockQuicConnection(TestConnectionId(),
507 QuicSocketAddress(TestPeerIPAddress(), kTestPort),
508 helper, alarm_factory, perspective,
509 ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
510
MockQuicConnection(QuicSocketAddress address,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)511 MockQuicConnection::MockQuicConnection(QuicSocketAddress address,
512 QuicConnectionHelperInterface* helper,
513 QuicAlarmFactory* alarm_factory,
514 Perspective perspective)
515 : MockQuicConnection(TestConnectionId(), address, helper, alarm_factory,
516 perspective,
517 ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
518
MockQuicConnection(QuicConnectionId connection_id,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)519 MockQuicConnection::MockQuicConnection(QuicConnectionId connection_id,
520 QuicConnectionHelperInterface* helper,
521 QuicAlarmFactory* alarm_factory,
522 Perspective perspective)
523 : MockQuicConnection(connection_id,
524 QuicSocketAddress(TestPeerIPAddress(), kTestPort),
525 helper, alarm_factory, perspective,
526 ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
527
MockQuicConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)528 MockQuicConnection::MockQuicConnection(
529 QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
530 Perspective perspective, const ParsedQuicVersionVector& supported_versions)
531 : MockQuicConnection(
532 TestConnectionId(), QuicSocketAddress(TestPeerIPAddress(), kTestPort),
533 helper, alarm_factory, perspective, supported_versions) {}
534
MockQuicConnection(QuicConnectionId connection_id,QuicSocketAddress initial_peer_address,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)535 MockQuicConnection::MockQuicConnection(
536 QuicConnectionId connection_id, QuicSocketAddress initial_peer_address,
537 QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
538 Perspective perspective, const ParsedQuicVersionVector& supported_versions)
539 : QuicConnection(
540 connection_id,
541 /*initial_self_address=*/QuicSocketAddress(QuicIpAddress::Any4(), 5),
542 initial_peer_address, helper, alarm_factory,
543 new testing::NiceMock<MockPacketWriter>(),
544 /* owns_writer= */ true, perspective, supported_versions,
545 connection_id_generator_) {
546 ON_CALL(*this, OnError(_))
547 .WillByDefault(
548 Invoke(this, &PacketSavingConnection::QuicConnection_OnError));
549 ON_CALL(*this, SendCryptoData(_, _, _))
550 .WillByDefault(
551 Invoke(this, &MockQuicConnection::QuicConnection_SendCryptoData));
552
553 SetSelfAddress(QuicSocketAddress(QuicIpAddress::Any4(), 5));
554 }
555
~MockQuicConnection()556 MockQuicConnection::~MockQuicConnection() {}
557
AdvanceTime(QuicTime::Delta delta)558 void MockQuicConnection::AdvanceTime(QuicTime::Delta delta) {
559 static_cast<MockQuicConnectionHelper*>(helper())->AdvanceTime(delta);
560 }
561
OnProtocolVersionMismatch(ParsedQuicVersion)562 bool MockQuicConnection::OnProtocolVersionMismatch(
563 ParsedQuicVersion /*version*/) {
564 return false;
565 }
566
PacketSavingConnection(MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)567 PacketSavingConnection::PacketSavingConnection(MockQuicConnectionHelper* helper,
568 QuicAlarmFactory* alarm_factory,
569 Perspective perspective)
570 : MockQuicConnection(helper, alarm_factory, perspective),
571 mock_helper_(helper) {}
572
PacketSavingConnection(MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)573 PacketSavingConnection::PacketSavingConnection(
574 MockQuicConnectionHelper* helper, QuicAlarmFactory* alarm_factory,
575 Perspective perspective, const ParsedQuicVersionVector& supported_versions)
576 : MockQuicConnection(helper, alarm_factory, perspective,
577 supported_versions),
578 mock_helper_(helper) {}
579
~PacketSavingConnection()580 PacketSavingConnection::~PacketSavingConnection() {}
581
GetSerializedPacketFate(bool,EncryptionLevel)582 SerializedPacketFate PacketSavingConnection::GetSerializedPacketFate(
583 bool /*is_mtu_discovery*/, EncryptionLevel /*encryption_level*/) {
584 return SEND_TO_WRITER;
585 }
586
SendOrQueuePacket(SerializedPacket packet)587 void PacketSavingConnection::SendOrQueuePacket(SerializedPacket packet) {
588 encrypted_packets_.push_back(std::make_unique<QuicEncryptedPacket>(
589 CopyBuffer(packet), packet.encrypted_length, true));
590 MockClock& clock = *mock_helper_->GetClock();
591 clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(10));
592 // Transfer ownership of the packet to the SentPacketManager and the
593 // ack notifier to the AckNotifierManager.
594 OnPacketSent(packet.encryption_level, packet.transmission_type);
595 QuicConnectionPeer::GetSentPacketManager(this)->OnPacketSent(
596 &packet, clock.ApproximateNow(), NOT_RETRANSMISSION,
597 HAS_RETRANSMITTABLE_DATA, true, ECN_NOT_ECT);
598 }
599
GetPackets() const600 std::vector<const QuicEncryptedPacket*> PacketSavingConnection::GetPackets()
601 const {
602 std::vector<const QuicEncryptedPacket*> packets;
603 for (size_t i = num_cleared_packets_; i < encrypted_packets_.size(); ++i) {
604 packets.push_back(encrypted_packets_[i].get());
605 }
606 return packets;
607 }
608
ClearPackets()609 void PacketSavingConnection::ClearPackets() {
610 num_cleared_packets_ = encrypted_packets_.size();
611 }
612
MockQuicSession(QuicConnection * connection)613 MockQuicSession::MockQuicSession(QuicConnection* connection)
614 : MockQuicSession(connection, true) {}
615
MockQuicSession(QuicConnection * connection,bool create_mock_crypto_stream)616 MockQuicSession::MockQuicSession(QuicConnection* connection,
617 bool create_mock_crypto_stream)
618 : QuicSession(connection, nullptr, DefaultQuicConfig(),
619 connection->supported_versions(),
620 /*num_expected_unidirectional_static_streams = */ 0) {
621 if (create_mock_crypto_stream) {
622 crypto_stream_ =
623 std::make_unique<testing::NiceMock<MockQuicCryptoStream>>(this);
624 }
625 ON_CALL(*this, WritevData(_, _, _, _, _, _))
626 .WillByDefault(testing::Return(QuicConsumedData(0, false)));
627 }
628
~MockQuicSession()629 MockQuicSession::~MockQuicSession() { DeleteConnection(); }
630
GetMutableCryptoStream()631 QuicCryptoStream* MockQuicSession::GetMutableCryptoStream() {
632 return crypto_stream_.get();
633 }
634
GetCryptoStream() const635 const QuicCryptoStream* MockQuicSession::GetCryptoStream() const {
636 return crypto_stream_.get();
637 }
638
SetCryptoStream(QuicCryptoStream * crypto_stream)639 void MockQuicSession::SetCryptoStream(QuicCryptoStream* crypto_stream) {
640 crypto_stream_.reset(crypto_stream);
641 }
642
ConsumeData(QuicStreamId id,size_t write_length,QuicStreamOffset offset,StreamSendingState state,TransmissionType,std::optional<EncryptionLevel>)643 QuicConsumedData MockQuicSession::ConsumeData(
644 QuicStreamId id, size_t write_length, QuicStreamOffset offset,
645 StreamSendingState state, TransmissionType /*type*/,
646 std::optional<EncryptionLevel> /*level*/) {
647 if (write_length > 0) {
648 auto buf = std::make_unique<char[]>(write_length);
649 QuicStream* stream = GetOrCreateStream(id);
650 QUICHE_DCHECK(stream);
651 QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER);
652 stream->WriteStreamData(offset, write_length, &writer);
653 } else {
654 QUICHE_DCHECK(state != NO_FIN);
655 }
656 return QuicConsumedData(write_length, state != NO_FIN);
657 }
658
MockQuicCryptoStream(QuicSession * session)659 MockQuicCryptoStream::MockQuicCryptoStream(QuicSession* session)
660 : QuicCryptoStream(session), params_(new QuicCryptoNegotiatedParameters) {}
661
~MockQuicCryptoStream()662 MockQuicCryptoStream::~MockQuicCryptoStream() {}
663
EarlyDataReason() const664 ssl_early_data_reason_t MockQuicCryptoStream::EarlyDataReason() const {
665 return ssl_early_data_unknown;
666 }
667
one_rtt_keys_available() const668 bool MockQuicCryptoStream::one_rtt_keys_available() const { return false; }
669
670 const QuicCryptoNegotiatedParameters&
crypto_negotiated_params() const671 MockQuicCryptoStream::crypto_negotiated_params() const {
672 return *params_;
673 }
674
crypto_message_parser()675 CryptoMessageParser* MockQuicCryptoStream::crypto_message_parser() {
676 return &crypto_framer_;
677 }
678
MockQuicSpdySession(QuicConnection * connection)679 MockQuicSpdySession::MockQuicSpdySession(QuicConnection* connection)
680 : MockQuicSpdySession(connection, true) {}
681
MockQuicSpdySession(QuicConnection * connection,bool create_mock_crypto_stream)682 MockQuicSpdySession::MockQuicSpdySession(QuicConnection* connection,
683 bool create_mock_crypto_stream)
684 : QuicSpdySession(connection, nullptr, DefaultQuicConfig(),
685 connection->supported_versions()) {
686 if (create_mock_crypto_stream) {
687 crypto_stream_ = std::make_unique<MockQuicCryptoStream>(this);
688 }
689
690 ON_CALL(*this, WritevData(_, _, _, _, _, _))
691 .WillByDefault(testing::Return(QuicConsumedData(0, false)));
692
693 ON_CALL(*this, SendWindowUpdate(_, _))
694 .WillByDefault([this](QuicStreamId id, QuicStreamOffset byte_offset) {
695 return QuicSpdySession::SendWindowUpdate(id, byte_offset);
696 });
697
698 ON_CALL(*this, SendBlocked(_, _))
699 .WillByDefault([this](QuicStreamId id, QuicStreamOffset byte_offset) {
700 return QuicSpdySession::SendBlocked(id, byte_offset);
701 });
702
703 ON_CALL(*this, OnCongestionWindowChange(_)).WillByDefault(testing::Return());
704 }
705
~MockQuicSpdySession()706 MockQuicSpdySession::~MockQuicSpdySession() { DeleteConnection(); }
707
GetMutableCryptoStream()708 QuicCryptoStream* MockQuicSpdySession::GetMutableCryptoStream() {
709 return crypto_stream_.get();
710 }
711
GetCryptoStream() const712 const QuicCryptoStream* MockQuicSpdySession::GetCryptoStream() const {
713 return crypto_stream_.get();
714 }
715
SetCryptoStream(QuicCryptoStream * crypto_stream)716 void MockQuicSpdySession::SetCryptoStream(QuicCryptoStream* crypto_stream) {
717 crypto_stream_.reset(crypto_stream);
718 }
719
ConsumeData(QuicStreamId id,size_t write_length,QuicStreamOffset offset,StreamSendingState state,TransmissionType,std::optional<EncryptionLevel>)720 QuicConsumedData MockQuicSpdySession::ConsumeData(
721 QuicStreamId id, size_t write_length, QuicStreamOffset offset,
722 StreamSendingState state, TransmissionType /*type*/,
723 std::optional<EncryptionLevel> /*level*/) {
724 if (write_length > 0) {
725 auto buf = std::make_unique<char[]>(write_length);
726 QuicStream* stream = GetOrCreateStream(id);
727 QUICHE_DCHECK(stream);
728 QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER);
729 stream->WriteStreamData(offset, write_length, &writer);
730 } else {
731 QUICHE_DCHECK(state != NO_FIN);
732 }
733 return QuicConsumedData(write_length, state != NO_FIN);
734 }
735
TestQuicSpdyServerSession(QuicConnection * connection,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,const QuicCryptoServerConfig * crypto_config,QuicCompressedCertsCache * compressed_certs_cache)736 TestQuicSpdyServerSession::TestQuicSpdyServerSession(
737 QuicConnection* connection, const QuicConfig& config,
738 const ParsedQuicVersionVector& supported_versions,
739 const QuicCryptoServerConfig* crypto_config,
740 QuicCompressedCertsCache* compressed_certs_cache)
741 : QuicServerSessionBase(config, supported_versions, connection, &visitor_,
742 &helper_, crypto_config, compressed_certs_cache) {
743 ON_CALL(helper_, CanAcceptClientHello(_, _, _, _, _))
744 .WillByDefault(testing::Return(true));
745 }
746
~TestQuicSpdyServerSession()747 TestQuicSpdyServerSession::~TestQuicSpdyServerSession() { DeleteConnection(); }
748
749 std::unique_ptr<QuicCryptoServerStreamBase>
CreateQuicCryptoServerStream(const QuicCryptoServerConfig * crypto_config,QuicCompressedCertsCache * compressed_certs_cache)750 TestQuicSpdyServerSession::CreateQuicCryptoServerStream(
751 const QuicCryptoServerConfig* crypto_config,
752 QuicCompressedCertsCache* compressed_certs_cache) {
753 return CreateCryptoServerStream(crypto_config, compressed_certs_cache, this,
754 &helper_);
755 }
756
757 QuicCryptoServerStreamBase*
GetMutableCryptoStream()758 TestQuicSpdyServerSession::GetMutableCryptoStream() {
759 return QuicServerSessionBase::GetMutableCryptoStream();
760 }
761
GetCryptoStream() const762 const QuicCryptoServerStreamBase* TestQuicSpdyServerSession::GetCryptoStream()
763 const {
764 return QuicServerSessionBase::GetCryptoStream();
765 }
766
TestQuicSpdyClientSession(QuicConnection * connection,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,const QuicServerId & server_id,QuicCryptoClientConfig * crypto_config,std::optional<QuicSSLConfig> ssl_config)767 TestQuicSpdyClientSession::TestQuicSpdyClientSession(
768 QuicConnection* connection, const QuicConfig& config,
769 const ParsedQuicVersionVector& supported_versions,
770 const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config,
771 std::optional<QuicSSLConfig> ssl_config)
772 : QuicSpdyClientSessionBase(connection, nullptr, config,
773 supported_versions),
774 ssl_config_(std::move(ssl_config)) {
775 // TODO(b/153726130): Consider adding SetServerApplicationStateForResumption
776 // calls in tests and set |has_application_state| to true.
777 crypto_stream_ = std::make_unique<QuicCryptoClientStream>(
778 server_id, this, crypto_test_utils::ProofVerifyContextForTesting(),
779 crypto_config, this, /*has_application_state = */ false);
780 Initialize();
781 ON_CALL(*this, OnConfigNegotiated())
782 .WillByDefault(
783 Invoke(this, &TestQuicSpdyClientSession::RealOnConfigNegotiated));
784 }
785
~TestQuicSpdyClientSession()786 TestQuicSpdyClientSession::~TestQuicSpdyClientSession() {}
787
GetMutableCryptoStream()788 QuicCryptoClientStream* TestQuicSpdyClientSession::GetMutableCryptoStream() {
789 return crypto_stream_.get();
790 }
791
GetCryptoStream() const792 const QuicCryptoClientStream* TestQuicSpdyClientSession::GetCryptoStream()
793 const {
794 return crypto_stream_.get();
795 }
796
RealOnConfigNegotiated()797 void TestQuicSpdyClientSession::RealOnConfigNegotiated() {
798 QuicSpdyClientSessionBase::OnConfigNegotiated();
799 }
800
MockPacketWriter()801 MockPacketWriter::MockPacketWriter() {
802 ON_CALL(*this, GetMaxPacketSize(_))
803 .WillByDefault(testing::Return(kMaxOutgoingPacketSize));
804 ON_CALL(*this, IsBatchMode()).WillByDefault(testing::Return(false));
805 ON_CALL(*this, GetNextWriteLocation(_, _))
806 .WillByDefault(testing::Return(QuicPacketBuffer()));
807 ON_CALL(*this, Flush())
808 .WillByDefault(testing::Return(WriteResult(WRITE_STATUS_OK, 0)));
809 ON_CALL(*this, SupportsReleaseTime()).WillByDefault(testing::Return(false));
810 }
811
~MockPacketWriter()812 MockPacketWriter::~MockPacketWriter() {}
813
MockSendAlgorithm()814 MockSendAlgorithm::MockSendAlgorithm() {
815 ON_CALL(*this, PacingRate(_))
816 .WillByDefault(testing::Return(QuicBandwidth::Zero()));
817 ON_CALL(*this, BandwidthEstimate())
818 .WillByDefault(testing::Return(QuicBandwidth::Zero()));
819 }
820
~MockSendAlgorithm()821 MockSendAlgorithm::~MockSendAlgorithm() {}
822
MockLossAlgorithm()823 MockLossAlgorithm::MockLossAlgorithm() {}
824
~MockLossAlgorithm()825 MockLossAlgorithm::~MockLossAlgorithm() {}
826
MockAckListener()827 MockAckListener::MockAckListener() {}
828
~MockAckListener()829 MockAckListener::~MockAckListener() {}
830
MockNetworkChangeVisitor()831 MockNetworkChangeVisitor::MockNetworkChangeVisitor() {}
832
~MockNetworkChangeVisitor()833 MockNetworkChangeVisitor::~MockNetworkChangeVisitor() {}
834
TestPeerIPAddress()835 QuicIpAddress TestPeerIPAddress() { return QuicIpAddress::Loopback4(); }
836
QuicVersionMax()837 ParsedQuicVersion QuicVersionMax() { return AllSupportedVersions().front(); }
838
QuicVersionMin()839 ParsedQuicVersion QuicVersionMin() { return AllSupportedVersions().back(); }
840
DisableQuicVersionsWithTls()841 void DisableQuicVersionsWithTls() {
842 for (const ParsedQuicVersion& version : AllSupportedVersionsWithTls()) {
843 QuicDisableVersion(version);
844 }
845 }
846
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data)847 QuicEncryptedPacket* ConstructEncryptedPacket(
848 QuicConnectionId destination_connection_id,
849 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
850 uint64_t packet_number, const std::string& data) {
851 return ConstructEncryptedPacket(
852 destination_connection_id, source_connection_id, version_flag, reset_flag,
853 packet_number, data, CONNECTION_ID_PRESENT, CONNECTION_ID_ABSENT,
854 PACKET_4BYTE_PACKET_NUMBER);
855 }
856
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length)857 QuicEncryptedPacket* ConstructEncryptedPacket(
858 QuicConnectionId destination_connection_id,
859 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
860 uint64_t packet_number, const std::string& data,
861 QuicConnectionIdIncluded destination_connection_id_included,
862 QuicConnectionIdIncluded source_connection_id_included,
863 QuicPacketNumberLength packet_number_length) {
864 return ConstructEncryptedPacket(
865 destination_connection_id, source_connection_id, version_flag, reset_flag,
866 packet_number, data, destination_connection_id_included,
867 source_connection_id_included, packet_number_length, nullptr);
868 }
869
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions)870 QuicEncryptedPacket* ConstructEncryptedPacket(
871 QuicConnectionId destination_connection_id,
872 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
873 uint64_t packet_number, const std::string& data,
874 QuicConnectionIdIncluded destination_connection_id_included,
875 QuicConnectionIdIncluded source_connection_id_included,
876 QuicPacketNumberLength packet_number_length,
877 ParsedQuicVersionVector* versions) {
878 return ConstructEncryptedPacket(
879 destination_connection_id, source_connection_id, version_flag, reset_flag,
880 packet_number, data, false, destination_connection_id_included,
881 source_connection_id_included, packet_number_length, versions,
882 Perspective::IS_CLIENT);
883 }
884
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,bool full_padding,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions)885 QuicEncryptedPacket* ConstructEncryptedPacket(
886 QuicConnectionId destination_connection_id,
887 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
888 uint64_t packet_number, const std::string& data, bool full_padding,
889 QuicConnectionIdIncluded destination_connection_id_included,
890 QuicConnectionIdIncluded source_connection_id_included,
891 QuicPacketNumberLength packet_number_length,
892 ParsedQuicVersionVector* versions) {
893 return ConstructEncryptedPacket(
894 destination_connection_id, source_connection_id, version_flag, reset_flag,
895 packet_number, data, full_padding, destination_connection_id_included,
896 source_connection_id_included, packet_number_length, versions,
897 Perspective::IS_CLIENT);
898 }
899
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,bool full_padding,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions,Perspective perspective)900 QuicEncryptedPacket* ConstructEncryptedPacket(
901 QuicConnectionId destination_connection_id,
902 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
903 uint64_t packet_number, const std::string& data, bool full_padding,
904 QuicConnectionIdIncluded destination_connection_id_included,
905 QuicConnectionIdIncluded source_connection_id_included,
906 QuicPacketNumberLength packet_number_length,
907 ParsedQuicVersionVector* versions, Perspective perspective) {
908 QuicPacketHeader header;
909 header.destination_connection_id = destination_connection_id;
910 header.destination_connection_id_included =
911 destination_connection_id_included;
912 header.source_connection_id = source_connection_id;
913 header.source_connection_id_included = source_connection_id_included;
914 header.version_flag = version_flag;
915 header.reset_flag = reset_flag;
916 header.packet_number_length = packet_number_length;
917 header.packet_number = QuicPacketNumber(packet_number);
918 ParsedQuicVersionVector supported_versions = CurrentSupportedVersions();
919 if (!versions) {
920 versions = &supported_versions;
921 }
922 EXPECT_FALSE(versions->empty());
923 ParsedQuicVersion version = (*versions)[0];
924 if (QuicVersionHasLongHeaderLengths(version.transport_version) &&
925 version_flag) {
926 header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
927 header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
928 }
929
930 QuicFrames frames;
931 QuicFramer framer(*versions, QuicTime::Zero(), perspective,
932 kQuicDefaultConnectionIdLength);
933 framer.SetInitialObfuscators(destination_connection_id);
934 EncryptionLevel level =
935 header.version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE;
936 if (level != ENCRYPTION_INITIAL) {
937 framer.SetEncrypter(level, std::make_unique<TaggingEncrypter>(level));
938 }
939 if (!QuicVersionUsesCryptoFrames(version.transport_version)) {
940 QuicFrame frame(
941 QuicStreamFrame(QuicUtils::GetCryptoStreamId(version.transport_version),
942 false, 0, absl::string_view(data)));
943 frames.push_back(frame);
944 } else {
945 QuicFrame frame(new QuicCryptoFrame(level, 0, data));
946 frames.push_back(frame);
947 }
948 if (full_padding) {
949 frames.push_back(QuicFrame(QuicPaddingFrame(-1)));
950 } else {
951 // We need a minimum number of bytes of encrypted payload. This will
952 // guarantee that we have at least that much. (It ignores the overhead of
953 // the stream/crypto framing, so it overpads slightly.)
954 size_t min_plaintext_size = QuicPacketCreator::MinPlaintextPacketSize(
955 version, packet_number_length);
956 if (data.length() < min_plaintext_size) {
957 size_t padding_length = min_plaintext_size - data.length();
958 frames.push_back(QuicFrame(QuicPaddingFrame(padding_length)));
959 }
960 }
961
962 std::unique_ptr<QuicPacket> packet(
963 BuildUnsizedDataPacket(&framer, header, frames));
964 EXPECT_TRUE(packet != nullptr);
965 char* buffer = new char[kMaxOutgoingPacketSize];
966 size_t encrypted_length =
967 framer.EncryptPayload(level, QuicPacketNumber(packet_number), *packet,
968 buffer, kMaxOutgoingPacketSize);
969 EXPECT_NE(0u, encrypted_length);
970 DeleteFrames(&frames);
971 return new QuicEncryptedPacket(buffer, encrypted_length, true);
972 }
973
GetUndecryptableEarlyPacket(const ParsedQuicVersion & version,const QuicConnectionId & server_connection_id)974 std::unique_ptr<QuicEncryptedPacket> GetUndecryptableEarlyPacket(
975 const ParsedQuicVersion& version,
976 const QuicConnectionId& server_connection_id) {
977 QuicPacketHeader header;
978 header.destination_connection_id = server_connection_id;
979 header.destination_connection_id_included = CONNECTION_ID_PRESENT;
980 header.source_connection_id = EmptyQuicConnectionId();
981 header.source_connection_id_included = CONNECTION_ID_PRESENT;
982 if (!version.SupportsClientConnectionIds()) {
983 header.source_connection_id_included = CONNECTION_ID_ABSENT;
984 }
985 header.version_flag = true;
986 header.reset_flag = false;
987 header.packet_number_length = PACKET_4BYTE_PACKET_NUMBER;
988 header.packet_number = QuicPacketNumber(33);
989 header.long_packet_type = ZERO_RTT_PROTECTED;
990 if (version.HasLongHeaderLengths()) {
991 header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
992 header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
993 }
994
995 QuicFrames frames;
996 frames.push_back(QuicFrame(QuicPingFrame()));
997 frames.push_back(QuicFrame(QuicPaddingFrame(100)));
998 QuicFramer framer({version}, QuicTime::Zero(), Perspective::IS_CLIENT,
999 kQuicDefaultConnectionIdLength);
1000 framer.SetInitialObfuscators(server_connection_id);
1001
1002 framer.SetEncrypter(ENCRYPTION_ZERO_RTT,
1003 std::make_unique<TaggingEncrypter>(ENCRYPTION_ZERO_RTT));
1004 std::unique_ptr<QuicPacket> packet(
1005 BuildUnsizedDataPacket(&framer, header, frames));
1006 EXPECT_TRUE(packet != nullptr);
1007 char* buffer = new char[kMaxOutgoingPacketSize];
1008 size_t encrypted_length =
1009 framer.EncryptPayload(ENCRYPTION_ZERO_RTT, header.packet_number, *packet,
1010 buffer, kMaxOutgoingPacketSize);
1011 EXPECT_NE(0u, encrypted_length);
1012 DeleteFrames(&frames);
1013 return std::make_unique<QuicEncryptedPacket>(buffer, encrypted_length,
1014 /*owns_buffer=*/true);
1015 }
1016
ConstructReceivedPacket(const QuicEncryptedPacket & encrypted_packet,QuicTime receipt_time)1017 QuicReceivedPacket* ConstructReceivedPacket(
1018 const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time) {
1019 return ConstructReceivedPacket(encrypted_packet, receipt_time, ECN_NOT_ECT);
1020 }
1021
ConstructReceivedPacket(const QuicEncryptedPacket & encrypted_packet,QuicTime receipt_time,QuicEcnCodepoint ecn)1022 QuicReceivedPacket* ConstructReceivedPacket(
1023 const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time,
1024 QuicEcnCodepoint ecn) {
1025 char* buffer = new char[encrypted_packet.length()];
1026 memcpy(buffer, encrypted_packet.data(), encrypted_packet.length());
1027 return new QuicReceivedPacket(buffer, encrypted_packet.length(), receipt_time,
1028 true, 0, true, nullptr, 0, false, ecn);
1029 }
1030
ConstructMisFramedEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersion version,Perspective perspective)1031 QuicEncryptedPacket* ConstructMisFramedEncryptedPacket(
1032 QuicConnectionId destination_connection_id,
1033 QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
1034 uint64_t packet_number, const std::string& data,
1035 QuicConnectionIdIncluded destination_connection_id_included,
1036 QuicConnectionIdIncluded source_connection_id_included,
1037 QuicPacketNumberLength packet_number_length, ParsedQuicVersion version,
1038 Perspective perspective) {
1039 QuicPacketHeader header;
1040 header.destination_connection_id = destination_connection_id;
1041 header.destination_connection_id_included =
1042 destination_connection_id_included;
1043 header.source_connection_id = source_connection_id;
1044 header.source_connection_id_included = source_connection_id_included;
1045 header.version_flag = version_flag;
1046 header.reset_flag = reset_flag;
1047 header.packet_number_length = packet_number_length;
1048 header.packet_number = QuicPacketNumber(packet_number);
1049 if (QuicVersionHasLongHeaderLengths(version.transport_version) &&
1050 version_flag) {
1051 header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
1052 header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
1053 }
1054 QuicFrame frame(QuicStreamFrame(1, false, 0, absl::string_view(data)));
1055 QuicFrames frames;
1056 frames.push_back(frame);
1057 QuicFramer framer({version}, QuicTime::Zero(), perspective,
1058 kQuicDefaultConnectionIdLength);
1059 framer.SetInitialObfuscators(destination_connection_id);
1060 EncryptionLevel level =
1061 version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE;
1062 if (level != ENCRYPTION_INITIAL) {
1063 framer.SetEncrypter(level, std::make_unique<TaggingEncrypter>(level));
1064 }
1065 // We need a minimum of 7 bytes of encrypted payload. This will guarantee that
1066 // we have at least that much. (It ignores the overhead of the stream/crypto
1067 // framing, so it overpads slightly.)
1068 if (data.length() < 7) {
1069 size_t padding_length = 7 - data.length();
1070 frames.push_back(QuicFrame(QuicPaddingFrame(padding_length)));
1071 }
1072
1073 std::unique_ptr<QuicPacket> packet(
1074 BuildUnsizedDataPacket(&framer, header, frames));
1075 EXPECT_TRUE(packet != nullptr);
1076
1077 // Now set the frame type to 0x1F, which is an invalid frame type.
1078 reinterpret_cast<unsigned char*>(
1079 packet->mutable_data())[GetStartOfEncryptedData(
1080 framer.transport_version(),
1081 GetIncludedDestinationConnectionIdLength(header),
1082 GetIncludedSourceConnectionIdLength(header), version_flag,
1083 false /* no diversification nonce */, packet_number_length,
1084 header.retry_token_length_length, 0, header.length_length)] = 0x1F;
1085
1086 char* buffer = new char[kMaxOutgoingPacketSize];
1087 size_t encrypted_length =
1088 framer.EncryptPayload(level, QuicPacketNumber(packet_number), *packet,
1089 buffer, kMaxOutgoingPacketSize);
1090 EXPECT_NE(0u, encrypted_length);
1091 return new QuicEncryptedPacket(buffer, encrypted_length, true);
1092 }
1093
DefaultQuicConfig()1094 QuicConfig DefaultQuicConfig() {
1095 QuicConfig config;
1096 config.SetInitialMaxStreamDataBytesIncomingBidirectionalToSend(
1097 kInitialStreamFlowControlWindowForTest);
1098 config.SetInitialMaxStreamDataBytesOutgoingBidirectionalToSend(
1099 kInitialStreamFlowControlWindowForTest);
1100 config.SetInitialMaxStreamDataBytesUnidirectionalToSend(
1101 kInitialStreamFlowControlWindowForTest);
1102 config.SetInitialStreamFlowControlWindowToSend(
1103 kInitialStreamFlowControlWindowForTest);
1104 config.SetInitialSessionFlowControlWindowToSend(
1105 kInitialSessionFlowControlWindowForTest);
1106 QuicConfigPeer::SetReceivedMaxBidirectionalStreams(
1107 &config, kDefaultMaxStreamsPerConnection);
1108 // Default enable NSTP.
1109 // This is unnecessary for versions > 44
1110 if (!config.HasClientSentConnectionOption(quic::kNSTP,
1111 quic::Perspective::IS_CLIENT)) {
1112 quic::QuicTagVector connection_options;
1113 connection_options.push_back(quic::kNSTP);
1114 config.SetConnectionOptionsToSend(connection_options);
1115 }
1116 return config;
1117 }
1118
SupportedVersions(ParsedQuicVersion version)1119 ParsedQuicVersionVector SupportedVersions(ParsedQuicVersion version) {
1120 ParsedQuicVersionVector versions;
1121 versions.push_back(version);
1122 return versions;
1123 }
1124
MockQuicConnectionDebugVisitor()1125 MockQuicConnectionDebugVisitor::MockQuicConnectionDebugVisitor() {}
1126
~MockQuicConnectionDebugVisitor()1127 MockQuicConnectionDebugVisitor::~MockQuicConnectionDebugVisitor() {}
1128
MockReceivedPacketManager(QuicConnectionStats * stats)1129 MockReceivedPacketManager::MockReceivedPacketManager(QuicConnectionStats* stats)
1130 : QuicReceivedPacketManager(stats) {}
1131
~MockReceivedPacketManager()1132 MockReceivedPacketManager::~MockReceivedPacketManager() {}
1133
MockPacketCreatorDelegate()1134 MockPacketCreatorDelegate::MockPacketCreatorDelegate() {}
~MockPacketCreatorDelegate()1135 MockPacketCreatorDelegate::~MockPacketCreatorDelegate() {}
1136
MockSessionNotifier()1137 MockSessionNotifier::MockSessionNotifier() {}
~MockSessionNotifier()1138 MockSessionNotifier::~MockSessionNotifier() {}
1139
1140 // static
1141 QuicCryptoClientStream::HandshakerInterface*
GetHandshaker(QuicCryptoClientStream * stream)1142 QuicCryptoClientStreamPeer::GetHandshaker(QuicCryptoClientStream* stream) {
1143 return stream->handshaker_.get();
1144 }
1145
CreateClientSessionForTest(QuicServerId server_id,QuicTime::Delta connection_start_time,const ParsedQuicVersionVector & supported_versions,MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,QuicCryptoClientConfig * crypto_client_config,PacketSavingConnection ** client_connection,TestQuicSpdyClientSession ** client_session)1146 void CreateClientSessionForTest(
1147 QuicServerId server_id, QuicTime::Delta connection_start_time,
1148 const ParsedQuicVersionVector& supported_versions,
1149 MockQuicConnectionHelper* helper, QuicAlarmFactory* alarm_factory,
1150 QuicCryptoClientConfig* crypto_client_config,
1151 PacketSavingConnection** client_connection,
1152 TestQuicSpdyClientSession** client_session) {
1153 QUICHE_CHECK(crypto_client_config);
1154 QUICHE_CHECK(client_connection);
1155 QUICHE_CHECK(client_session);
1156 QUICHE_CHECK(!connection_start_time.IsZero())
1157 << "Connections must start at non-zero times, otherwise the "
1158 << "strike-register will be unhappy.";
1159
1160 QuicConfig config = DefaultQuicConfig();
1161 *client_connection = new PacketSavingConnection(
1162 helper, alarm_factory, Perspective::IS_CLIENT, supported_versions);
1163 *client_session = new TestQuicSpdyClientSession(*client_connection, config,
1164 supported_versions, server_id,
1165 crypto_client_config);
1166 (*client_connection)->AdvanceTime(connection_start_time);
1167 }
1168
CreateServerSessionForTest(QuicServerId,QuicTime::Delta connection_start_time,ParsedQuicVersionVector supported_versions,MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,QuicCryptoServerConfig * server_crypto_config,QuicCompressedCertsCache * compressed_certs_cache,PacketSavingConnection ** server_connection,TestQuicSpdyServerSession ** server_session)1169 void CreateServerSessionForTest(
1170 QuicServerId /*server_id*/, QuicTime::Delta connection_start_time,
1171 ParsedQuicVersionVector supported_versions,
1172 MockQuicConnectionHelper* helper, QuicAlarmFactory* alarm_factory,
1173 QuicCryptoServerConfig* server_crypto_config,
1174 QuicCompressedCertsCache* compressed_certs_cache,
1175 PacketSavingConnection** server_connection,
1176 TestQuicSpdyServerSession** server_session) {
1177 QUICHE_CHECK(server_crypto_config);
1178 QUICHE_CHECK(server_connection);
1179 QUICHE_CHECK(server_session);
1180 QUICHE_CHECK(!connection_start_time.IsZero())
1181 << "Connections must start at non-zero times, otherwise the "
1182 << "strike-register will be unhappy.";
1183
1184 *server_connection =
1185 new PacketSavingConnection(helper, alarm_factory, Perspective::IS_SERVER,
1186 ParsedVersionOfIndex(supported_versions, 0));
1187 *server_session = new TestQuicSpdyServerSession(
1188 *server_connection, DefaultQuicConfig(), supported_versions,
1189 server_crypto_config, compressed_certs_cache);
1190 (*server_session)->Initialize();
1191
1192 // We advance the clock initially because the default time is zero and the
1193 // strike register worries that we've just overflowed a uint32_t time.
1194 (*server_connection)->AdvanceTime(connection_start_time);
1195 }
1196
GetNthClientInitiatedBidirectionalStreamId(QuicTransportVersion version,int n)1197 QuicStreamId GetNthClientInitiatedBidirectionalStreamId(
1198 QuicTransportVersion version, int n) {
1199 int num = n;
1200 if (!VersionUsesHttp3(version)) {
1201 num++;
1202 }
1203 return QuicUtils::GetFirstBidirectionalStreamId(version,
1204 Perspective::IS_CLIENT) +
1205 QuicUtils::StreamIdDelta(version) * num;
1206 }
1207
GetNthServerInitiatedBidirectionalStreamId(QuicTransportVersion version,int n)1208 QuicStreamId GetNthServerInitiatedBidirectionalStreamId(
1209 QuicTransportVersion version, int n) {
1210 return QuicUtils::GetFirstBidirectionalStreamId(version,
1211 Perspective::IS_SERVER) +
1212 QuicUtils::StreamIdDelta(version) * n;
1213 }
1214
GetNthServerInitiatedUnidirectionalStreamId(QuicTransportVersion version,int n)1215 QuicStreamId GetNthServerInitiatedUnidirectionalStreamId(
1216 QuicTransportVersion version, int n) {
1217 return QuicUtils::GetFirstUnidirectionalStreamId(version,
1218 Perspective::IS_SERVER) +
1219 QuicUtils::StreamIdDelta(version) * n;
1220 }
1221
GetNthClientInitiatedUnidirectionalStreamId(QuicTransportVersion version,int n)1222 QuicStreamId GetNthClientInitiatedUnidirectionalStreamId(
1223 QuicTransportVersion version, int n) {
1224 return QuicUtils::GetFirstUnidirectionalStreamId(version,
1225 Perspective::IS_CLIENT) +
1226 QuicUtils::StreamIdDelta(version) * n;
1227 }
1228
DetermineStreamType(QuicStreamId id,ParsedQuicVersion version,Perspective perspective,bool is_incoming,StreamType default_type)1229 StreamType DetermineStreamType(QuicStreamId id, ParsedQuicVersion version,
1230 Perspective perspective, bool is_incoming,
1231 StreamType default_type) {
1232 return version.HasIetfQuicFrames()
1233 ? QuicUtils::GetStreamType(id, perspective, is_incoming, version)
1234 : default_type;
1235 }
1236
MemSliceFromString(absl::string_view data)1237 quiche::QuicheMemSlice MemSliceFromString(absl::string_view data) {
1238 if (data.empty()) {
1239 return quiche::QuicheMemSlice();
1240 }
1241
1242 static quiche::SimpleBufferAllocator* allocator =
1243 new quiche::SimpleBufferAllocator();
1244 return quiche::QuicheMemSlice(quiche::QuicheBuffer::Copy(allocator, data));
1245 }
1246
EncryptPacket(uint64_t,absl::string_view,absl::string_view plaintext,char * output,size_t * output_length,size_t max_output_length)1247 bool TaggingEncrypter::EncryptPacket(uint64_t /*packet_number*/,
1248 absl::string_view /*associated_data*/,
1249 absl::string_view plaintext, char* output,
1250 size_t* output_length,
1251 size_t max_output_length) {
1252 const size_t len = plaintext.size() + kTagSize;
1253 if (max_output_length < len) {
1254 return false;
1255 }
1256 // Memmove is safe for inplace encryption.
1257 memmove(output, plaintext.data(), plaintext.size());
1258 output += plaintext.size();
1259 memset(output, tag_, kTagSize);
1260 *output_length = len;
1261 return true;
1262 }
1263
DecryptPacket(uint64_t,absl::string_view,absl::string_view ciphertext,char * output,size_t * output_length,size_t)1264 bool TaggingDecrypter::DecryptPacket(uint64_t /*packet_number*/,
1265 absl::string_view /*associated_data*/,
1266 absl::string_view ciphertext, char* output,
1267 size_t* output_length,
1268 size_t /*max_output_length*/) {
1269 if (ciphertext.size() < kTagSize) {
1270 return false;
1271 }
1272 if (!CheckTag(ciphertext, GetTag(ciphertext))) {
1273 return false;
1274 }
1275 *output_length = ciphertext.size() - kTagSize;
1276 memcpy(output, ciphertext.data(), *output_length);
1277 return true;
1278 }
1279
CheckTag(absl::string_view ciphertext,uint8_t tag)1280 bool TaggingDecrypter::CheckTag(absl::string_view ciphertext, uint8_t tag) {
1281 for (size_t i = ciphertext.size() - kTagSize; i < ciphertext.size(); i++) {
1282 if (ciphertext.data()[i] != tag) {
1283 return false;
1284 }
1285 }
1286
1287 return true;
1288 }
1289
TestPacketWriter(ParsedQuicVersion version,MockClock * clock,Perspective perspective)1290 TestPacketWriter::TestPacketWriter(ParsedQuicVersion version, MockClock* clock,
1291 Perspective perspective)
1292 : version_(version),
1293 framer_(SupportedVersions(version_),
1294 QuicUtils::InvertPerspective(perspective)),
1295 clock_(clock) {
1296 QuicFramerPeer::SetLastSerializedServerConnectionId(framer_.framer(),
1297 TestConnectionId());
1298 framer_.framer()->SetInitialObfuscators(TestConnectionId());
1299
1300 for (int i = 0; i < 128; ++i) {
1301 PacketBuffer* p = new PacketBuffer();
1302 packet_buffer_pool_.push_back(p);
1303 packet_buffer_pool_index_[p->buffer] = p;
1304 packet_buffer_free_list_.push_back(p);
1305 }
1306 }
1307
~TestPacketWriter()1308 TestPacketWriter::~TestPacketWriter() {
1309 EXPECT_EQ(packet_buffer_pool_.size(), packet_buffer_free_list_.size())
1310 << packet_buffer_pool_.size() - packet_buffer_free_list_.size()
1311 << " out of " << packet_buffer_pool_.size()
1312 << " packet buffers have been leaked.";
1313 for (auto p : packet_buffer_pool_) {
1314 delete p;
1315 }
1316 }
1317
WritePacket(const char * buffer,size_t buf_len,const QuicIpAddress & self_address,const QuicSocketAddress & peer_address,PerPacketOptions *,const QuicPacketWriterParams & params)1318 WriteResult TestPacketWriter::WritePacket(
1319 const char* buffer, size_t buf_len, const QuicIpAddress& self_address,
1320 const QuicSocketAddress& peer_address, PerPacketOptions* /*options*/,
1321 const QuicPacketWriterParams& params) {
1322 last_write_source_address_ = self_address;
1323 last_write_peer_address_ = peer_address;
1324 // If the buffer is allocated from the pool, return it back to the pool.
1325 // Note the buffer content doesn't change.
1326 if (packet_buffer_pool_index_.find(const_cast<char*>(buffer)) !=
1327 packet_buffer_pool_index_.end()) {
1328 FreePacketBuffer(buffer);
1329 }
1330
1331 QuicEncryptedPacket packet(buffer, buf_len);
1332 ++packets_write_attempts_;
1333
1334 if (packet.length() >= sizeof(final_bytes_of_last_packet_)) {
1335 final_bytes_of_previous_packet_ = final_bytes_of_last_packet_;
1336 memcpy(&final_bytes_of_last_packet_, packet.data() + packet.length() - 4,
1337 sizeof(final_bytes_of_last_packet_));
1338 }
1339 if (framer_.framer()->version().KnowsWhichDecrypterToUse()) {
1340 framer_.framer()->InstallDecrypter(ENCRYPTION_HANDSHAKE,
1341 std::make_unique<TaggingDecrypter>());
1342 framer_.framer()->InstallDecrypter(ENCRYPTION_ZERO_RTT,
1343 std::make_unique<TaggingDecrypter>());
1344 framer_.framer()->InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
1345 std::make_unique<TaggingDecrypter>());
1346 } else if (!framer_.framer()->HasDecrypterOfEncryptionLevel(
1347 ENCRYPTION_FORWARD_SECURE) &&
1348 !framer_.framer()->HasDecrypterOfEncryptionLevel(
1349 ENCRYPTION_ZERO_RTT)) {
1350 framer_.framer()->SetAlternativeDecrypter(
1351 ENCRYPTION_FORWARD_SECURE,
1352 std::make_unique<StrictTaggingDecrypter>(ENCRYPTION_FORWARD_SECURE),
1353 false);
1354 }
1355 EXPECT_EQ(next_packet_processable_, framer_.ProcessPacket(packet))
1356 << framer_.framer()->detailed_error() << " perspective "
1357 << framer_.framer()->perspective();
1358 next_packet_processable_ = true;
1359 if (block_on_next_write_) {
1360 write_blocked_ = true;
1361 block_on_next_write_ = false;
1362 }
1363 if (next_packet_too_large_) {
1364 next_packet_too_large_ = false;
1365 return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode());
1366 }
1367 if (always_get_packet_too_large_) {
1368 return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode());
1369 }
1370 if (IsWriteBlocked()) {
1371 return WriteResult(is_write_blocked_data_buffered_
1372 ? WRITE_STATUS_BLOCKED_DATA_BUFFERED
1373 : WRITE_STATUS_BLOCKED,
1374 0);
1375 }
1376
1377 if (ShouldWriteFail()) {
1378 return WriteResult(WRITE_STATUS_ERROR, write_error_code_);
1379 }
1380
1381 last_packet_size_ = packet.length();
1382 total_bytes_written_ += packet.length();
1383 last_packet_header_ = framer_.header();
1384 if (!framer_.connection_close_frames().empty()) {
1385 ++connection_close_packets_;
1386 }
1387 if (!write_pause_time_delta_.IsZero()) {
1388 clock_->AdvanceTime(write_pause_time_delta_);
1389 }
1390 if (is_batch_mode_) {
1391 bytes_buffered_ += last_packet_size_;
1392 return WriteResult(WRITE_STATUS_OK, 0);
1393 }
1394 last_ecn_sent_ = params.ecn_codepoint;
1395 return WriteResult(WRITE_STATUS_OK, last_packet_size_);
1396 }
1397
GetNextWriteLocation(const QuicIpAddress &,const QuicSocketAddress &)1398 QuicPacketBuffer TestPacketWriter::GetNextWriteLocation(
1399 const QuicIpAddress& /*self_address*/,
1400 const QuicSocketAddress& /*peer_address*/) {
1401 return {AllocPacketBuffer(), [this](const char* p) { FreePacketBuffer(p); }};
1402 }
1403
Flush()1404 WriteResult TestPacketWriter::Flush() {
1405 flush_attempts_++;
1406 if (block_on_next_flush_) {
1407 block_on_next_flush_ = false;
1408 SetWriteBlocked();
1409 return WriteResult(WRITE_STATUS_BLOCKED, /*errno*/ -1);
1410 }
1411 if (write_should_fail_) {
1412 return WriteResult(WRITE_STATUS_ERROR, /*errno*/ -1);
1413 }
1414 int bytes_flushed = bytes_buffered_;
1415 bytes_buffered_ = 0;
1416 return WriteResult(WRITE_STATUS_OK, bytes_flushed);
1417 }
1418
AllocPacketBuffer()1419 char* TestPacketWriter::AllocPacketBuffer() {
1420 PacketBuffer* p = packet_buffer_free_list_.front();
1421 EXPECT_FALSE(p->in_use);
1422 p->in_use = true;
1423 packet_buffer_free_list_.pop_front();
1424 return p->buffer;
1425 }
1426
FreePacketBuffer(const char * buffer)1427 void TestPacketWriter::FreePacketBuffer(const char* buffer) {
1428 auto iter = packet_buffer_pool_index_.find(const_cast<char*>(buffer));
1429 ASSERT_TRUE(iter != packet_buffer_pool_index_.end());
1430 PacketBuffer* p = iter->second;
1431 ASSERT_TRUE(p->in_use);
1432 p->in_use = false;
1433 packet_buffer_free_list_.push_back(p);
1434 }
1435
WriteServerVersionNegotiationProbeResponse(char * packet_bytes,size_t * packet_length_out,const char * source_connection_id_bytes,uint8_t source_connection_id_length)1436 bool WriteServerVersionNegotiationProbeResponse(
1437 char* packet_bytes, size_t* packet_length_out,
1438 const char* source_connection_id_bytes,
1439 uint8_t source_connection_id_length) {
1440 if (packet_bytes == nullptr) {
1441 QUIC_BUG(quic_bug_10256_1) << "Invalid packet_bytes";
1442 return false;
1443 }
1444 if (packet_length_out == nullptr) {
1445 QUIC_BUG(quic_bug_10256_2) << "Invalid packet_length_out";
1446 return false;
1447 }
1448 QuicConnectionId source_connection_id(source_connection_id_bytes,
1449 source_connection_id_length);
1450 std::unique_ptr<QuicEncryptedPacket> encrypted_packet =
1451 QuicFramer::BuildVersionNegotiationPacket(
1452 source_connection_id, EmptyQuicConnectionId(),
1453 /*ietf_quic=*/true, /*use_length_prefix=*/true,
1454 ParsedQuicVersionVector{});
1455 if (!encrypted_packet) {
1456 QUIC_BUG(quic_bug_10256_3) << "Failed to create version negotiation packet";
1457 return false;
1458 }
1459 if (*packet_length_out < encrypted_packet->length()) {
1460 QUIC_BUG(quic_bug_10256_4)
1461 << "Invalid *packet_length_out " << *packet_length_out << " < "
1462 << encrypted_packet->length();
1463 return false;
1464 }
1465 *packet_length_out = encrypted_packet->length();
1466 memcpy(packet_bytes, encrypted_packet->data(), *packet_length_out);
1467 return true;
1468 }
1469
ParseClientVersionNegotiationProbePacket(const char * packet_bytes,size_t packet_length,char * destination_connection_id_bytes,uint8_t * destination_connection_id_length_out)1470 bool ParseClientVersionNegotiationProbePacket(
1471 const char* packet_bytes, size_t packet_length,
1472 char* destination_connection_id_bytes,
1473 uint8_t* destination_connection_id_length_out) {
1474 if (packet_bytes == nullptr) {
1475 QUIC_BUG(quic_bug_10256_5) << "Invalid packet_bytes";
1476 return false;
1477 }
1478 if (packet_length < kMinPacketSizeForVersionNegotiation ||
1479 packet_length > 65535) {
1480 QUIC_BUG(quic_bug_10256_6) << "Invalid packet_length";
1481 return false;
1482 }
1483 if (destination_connection_id_bytes == nullptr) {
1484 QUIC_BUG(quic_bug_10256_7) << "Invalid destination_connection_id_bytes";
1485 return false;
1486 }
1487 if (destination_connection_id_length_out == nullptr) {
1488 QUIC_BUG(quic_bug_10256_8)
1489 << "Invalid destination_connection_id_length_out";
1490 return false;
1491 }
1492
1493 QuicEncryptedPacket encrypted_packet(packet_bytes, packet_length);
1494 PacketHeaderFormat format;
1495 QuicLongHeaderType long_packet_type;
1496 bool version_present, has_length_prefix;
1497 QuicVersionLabel version_label;
1498 ParsedQuicVersion parsed_version = ParsedQuicVersion::Unsupported();
1499 QuicConnectionId destination_connection_id, source_connection_id;
1500 std::optional<absl::string_view> retry_token;
1501 std::string detailed_error;
1502 QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher(
1503 encrypted_packet,
1504 /*expected_destination_connection_id_length=*/0, &format,
1505 &long_packet_type, &version_present, &has_length_prefix, &version_label,
1506 &parsed_version, &destination_connection_id, &source_connection_id,
1507 &retry_token, &detailed_error);
1508 if (error != QUIC_NO_ERROR) {
1509 QUIC_BUG(quic_bug_10256_9) << "Failed to parse packet: " << detailed_error;
1510 return false;
1511 }
1512 if (!version_present) {
1513 QUIC_BUG(quic_bug_10256_10) << "Packet is not a long header";
1514 return false;
1515 }
1516 if (*destination_connection_id_length_out <
1517 destination_connection_id.length()) {
1518 QUIC_BUG(quic_bug_10256_11)
1519 << "destination_connection_id_length_out too small";
1520 return false;
1521 }
1522 *destination_connection_id_length_out = destination_connection_id.length();
1523 memcpy(destination_connection_id_bytes, destination_connection_id.data(),
1524 *destination_connection_id_length_out);
1525 return true;
1526 }
1527
1528 } // namespace test
1529 } // namespace quic
1530