xref: /aosp_15_r20/external/cronet/net/third_party/quiche/src/quiche/quic/test_tools/quic_test_utils.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "quiche/quic/test_tools/quic_test_utils.h"
6 
7 #include <algorithm>
8 #include <cstddef>
9 #include <cstdint>
10 #include <memory>
11 #include <utility>
12 #include <vector>
13 
14 #include "absl/base/macros.h"
15 #include "absl/strings/string_view.h"
16 #include "openssl/chacha.h"
17 #include "openssl/sha.h"
18 #include "quiche/quic/core/crypto/crypto_framer.h"
19 #include "quiche/quic/core/crypto/crypto_handshake.h"
20 #include "quiche/quic/core/crypto/crypto_utils.h"
21 #include "quiche/quic/core/crypto/null_decrypter.h"
22 #include "quiche/quic/core/crypto/null_encrypter.h"
23 #include "quiche/quic/core/crypto/quic_decrypter.h"
24 #include "quiche/quic/core/crypto/quic_encrypter.h"
25 #include "quiche/quic/core/http/quic_spdy_client_session.h"
26 #include "quiche/quic/core/quic_config.h"
27 #include "quiche/quic/core/quic_data_writer.h"
28 #include "quiche/quic/core/quic_framer.h"
29 #include "quiche/quic/core/quic_packet_creator.h"
30 #include "quiche/quic/core/quic_packets.h"
31 #include "quiche/quic/core/quic_time.h"
32 #include "quiche/quic/core/quic_types.h"
33 #include "quiche/quic/core/quic_utils.h"
34 #include "quiche/quic/core/quic_versions.h"
35 #include "quiche/quic/platform/api/quic_flags.h"
36 #include "quiche/quic/platform/api/quic_logging.h"
37 #include "quiche/quic/test_tools/crypto_test_utils.h"
38 #include "quiche/quic/test_tools/quic_config_peer.h"
39 #include "quiche/quic/test_tools/quic_connection_peer.h"
40 #include "quiche/common/quiche_buffer_allocator.h"
41 #include "quiche/common/quiche_endian.h"
42 #include "quiche/common/simple_buffer_allocator.h"
43 #include "quiche/spdy/core/spdy_frame_builder.h"
44 
45 using testing::_;
46 using testing::Invoke;
47 using testing::Return;
48 
49 namespace quic {
50 namespace test {
51 
TestConnectionId()52 QuicConnectionId TestConnectionId() {
53   // Chosen by fair dice roll.
54   // Guaranteed to be random.
55   return TestConnectionId(42);
56 }
57 
TestConnectionId(uint64_t connection_number)58 QuicConnectionId TestConnectionId(uint64_t connection_number) {
59   const uint64_t connection_id64_net =
60       quiche::QuicheEndian::HostToNet64(connection_number);
61   return QuicConnectionId(reinterpret_cast<const char*>(&connection_id64_net),
62                           sizeof(connection_id64_net));
63 }
64 
TestConnectionIdNineBytesLong(uint64_t connection_number)65 QuicConnectionId TestConnectionIdNineBytesLong(uint64_t connection_number) {
66   const uint64_t connection_number_net =
67       quiche::QuicheEndian::HostToNet64(connection_number);
68   char connection_id_bytes[9] = {};
69   static_assert(
70       sizeof(connection_id_bytes) == 1 + sizeof(connection_number_net),
71       "bad lengths");
72   memcpy(connection_id_bytes + 1, &connection_number_net,
73          sizeof(connection_number_net));
74   return QuicConnectionId(connection_id_bytes, sizeof(connection_id_bytes));
75 }
76 
TestConnectionIdToUInt64(QuicConnectionId connection_id)77 uint64_t TestConnectionIdToUInt64(QuicConnectionId connection_id) {
78   QUICHE_DCHECK_EQ(connection_id.length(), kQuicDefaultConnectionIdLength);
79   uint64_t connection_id64_net = 0;
80   memcpy(&connection_id64_net, connection_id.data(),
81          std::min<size_t>(static_cast<size_t>(connection_id.length()),
82                           sizeof(connection_id64_net)));
83   return quiche::QuicheEndian::NetToHost64(connection_id64_net);
84 }
85 
// Returns a fixed 16-byte stateless reset token (bytes 0x90 through 0x9F) so
// tests can recognise it.
std::vector<uint8_t> CreateStatelessResetTokenForTest() {
  static constexpr uint8_t kTokenBytes[] = {
      0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97,
      0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, 0x9E, 0x9F};
  static_assert(sizeof(kTokenBytes) == 16, "reset tokens are 16 bytes");
  return std::vector<uint8_t>(std::begin(kTokenBytes), std::end(kTokenBytes));
}
94 
// Hostname that tests use for server IDs and certificates.
std::string TestHostname() {
  static constexpr char kHostname[] = "test.example.com";
  return std::string(kHostname);
}
96 
TestServerId()97 QuicServerId TestServerId() { return QuicServerId(TestHostname(), kTestPort); }
98 
InitAckFrame(const std::vector<QuicAckBlock> & ack_blocks)99 QuicAckFrame InitAckFrame(const std::vector<QuicAckBlock>& ack_blocks) {
100   QUICHE_DCHECK_GT(ack_blocks.size(), 0u);
101 
102   QuicAckFrame ack;
103   QuicPacketNumber end_of_previous_block(1);
104   for (const QuicAckBlock& block : ack_blocks) {
105     QUICHE_DCHECK_GE(block.start, end_of_previous_block);
106     QUICHE_DCHECK_GT(block.limit, block.start);
107     ack.packets.AddRange(block.start, block.limit);
108     end_of_previous_block = block.limit;
109   }
110 
111   ack.largest_acked = ack.packets.Max();
112 
113   return ack;
114 }
115 
InitAckFrame(uint64_t largest_acked)116 QuicAckFrame InitAckFrame(uint64_t largest_acked) {
117   return InitAckFrame(QuicPacketNumber(largest_acked));
118 }
119 
InitAckFrame(QuicPacketNumber largest_acked)120 QuicAckFrame InitAckFrame(QuicPacketNumber largest_acked) {
121   return InitAckFrame({{QuicPacketNumber(1), largest_acked + 1}});
122 }
123 
MakeAckFrameWithAckBlocks(size_t num_ack_blocks,uint64_t least_unacked)124 QuicAckFrame MakeAckFrameWithAckBlocks(size_t num_ack_blocks,
125                                        uint64_t least_unacked) {
126   QuicAckFrame ack;
127   ack.largest_acked = QuicPacketNumber(2 * num_ack_blocks + least_unacked);
128   // Add enough received packets to get num_ack_blocks ack blocks.
129   for (QuicPacketNumber i = QuicPacketNumber(2);
130        i < QuicPacketNumber(2 * num_ack_blocks + 1); i += 2) {
131     ack.packets.Add(i + least_unacked);
132   }
133   return ack;
134 }
135 
MakeAckFrameWithGaps(uint64_t gap_size,size_t max_num_gaps,uint64_t largest_acked)136 QuicAckFrame MakeAckFrameWithGaps(uint64_t gap_size, size_t max_num_gaps,
137                                   uint64_t largest_acked) {
138   QuicAckFrame ack;
139   ack.largest_acked = QuicPacketNumber(largest_acked);
140   ack.packets.Add(QuicPacketNumber(largest_acked));
141   for (size_t i = 0; i < max_num_gaps; ++i) {
142     if (largest_acked <= gap_size) {
143       break;
144     }
145     largest_acked -= gap_size;
146     ack.packets.Add(QuicPacketNumber(largest_acked));
147   }
148   return ack;
149 }
150 
HeaderToEncryptionLevel(const QuicPacketHeader & header)151 EncryptionLevel HeaderToEncryptionLevel(const QuicPacketHeader& header) {
152   if (header.form == IETF_QUIC_SHORT_HEADER_PACKET) {
153     return ENCRYPTION_FORWARD_SECURE;
154   } else if (header.form == IETF_QUIC_LONG_HEADER_PACKET) {
155     if (header.long_packet_type == HANDSHAKE) {
156       return ENCRYPTION_HANDSHAKE;
157     } else if (header.long_packet_type == ZERO_RTT_PROTECTED) {
158       return ENCRYPTION_ZERO_RTT;
159     }
160   }
161   return ENCRYPTION_INITIAL;
162 }
163 
BuildUnsizedDataPacket(QuicFramer * framer,const QuicPacketHeader & header,const QuicFrames & frames)164 std::unique_ptr<QuicPacket> BuildUnsizedDataPacket(
165     QuicFramer* framer, const QuicPacketHeader& header,
166     const QuicFrames& frames) {
167   const size_t max_plaintext_size =
168       framer->GetMaxPlaintextSize(kMaxOutgoingPacketSize);
169   size_t packet_size = GetPacketHeaderSize(framer->transport_version(), header);
170   for (size_t i = 0; i < frames.size(); ++i) {
171     QUICHE_DCHECK_LE(packet_size, max_plaintext_size);
172     bool first_frame = i == 0;
173     bool last_frame = i == frames.size() - 1;
174     const size_t frame_size = framer->GetSerializedFrameLength(
175         frames[i], max_plaintext_size - packet_size, first_frame, last_frame,
176         header.packet_number_length);
177     QUICHE_DCHECK(frame_size);
178     packet_size += frame_size;
179   }
180   return BuildUnsizedDataPacket(framer, header, frames, packet_size);
181 }
182 
BuildUnsizedDataPacket(QuicFramer * framer,const QuicPacketHeader & header,const QuicFrames & frames,size_t packet_size)183 std::unique_ptr<QuicPacket> BuildUnsizedDataPacket(
184     QuicFramer* framer, const QuicPacketHeader& header,
185     const QuicFrames& frames, size_t packet_size) {
186   char* buffer = new char[packet_size];
187   EncryptionLevel level = HeaderToEncryptionLevel(header);
188   size_t length =
189       framer->BuildDataPacket(header, frames, buffer, packet_size, level);
190 
191   if (length == 0) {
192     delete[] buffer;
193     return nullptr;
194   }
195   // Re-construct the data packet with data ownership.
196   return std::make_unique<QuicPacket>(
197       buffer, length, /* owns_buffer */ true,
198       GetIncludedDestinationConnectionIdLength(header),
199       GetIncludedSourceConnectionIdLength(header), header.version_flag,
200       header.nonce != nullptr, header.packet_number_length,
201       header.retry_token_length_length, header.retry_token.length(),
202       header.length_length);
203 }
204 
Sha1Hash(absl::string_view data)205 std::string Sha1Hash(absl::string_view data) {
206   char buffer[SHA_DIGEST_LENGTH];
207   SHA1(reinterpret_cast<const uint8_t*>(data.data()), data.size(),
208        reinterpret_cast<uint8_t*>(buffer));
209   return std::string(buffer, ABSL_ARRAYSIZE(buffer));
210 }
211 
ClearControlFrame(const QuicFrame & frame)212 bool ClearControlFrame(const QuicFrame& frame) {
213   DeleteFrame(&const_cast<QuicFrame&>(frame));
214   return true;
215 }
216 
ClearControlFrameWithTransmissionType(const QuicFrame & frame,TransmissionType)217 bool ClearControlFrameWithTransmissionType(const QuicFrame& frame,
218                                            TransmissionType /*type*/) {
219   return ClearControlFrame(frame);
220 }
221 
RandUint64()222 uint64_t SimpleRandom::RandUint64() {
223   uint64_t result;
224   RandBytes(&result, sizeof(result));
225   return result;
226 }
227 
RandBytes(void * data,size_t len)228 void SimpleRandom::RandBytes(void* data, size_t len) {
229   uint8_t* data_bytes = reinterpret_cast<uint8_t*>(data);
230   while (len > 0) {
231     const size_t buffer_left = sizeof(buffer_) - buffer_offset_;
232     const size_t to_copy = std::min(buffer_left, len);
233     memcpy(data_bytes, buffer_ + buffer_offset_, to_copy);
234     data_bytes += to_copy;
235     buffer_offset_ += to_copy;
236     len -= to_copy;
237 
238     if (buffer_offset_ == sizeof(buffer_)) {
239       FillBuffer();
240     }
241   }
242 }
243 
InsecureRandBytes(void * data,size_t len)244 void SimpleRandom::InsecureRandBytes(void* data, size_t len) {
245   RandBytes(data, len);
246 }
247 
InsecureRandUint64()248 uint64_t SimpleRandom::InsecureRandUint64() { return RandUint64(); }
249 
FillBuffer()250 void SimpleRandom::FillBuffer() {
251   uint8_t nonce[12];
252   memcpy(nonce, buffer_, sizeof(nonce));
253   CRYPTO_chacha_20(buffer_, buffer_, sizeof(buffer_), key_, nonce, 0);
254   buffer_offset_ = 0;
255 }
256 
set_seed(uint64_t seed)257 void SimpleRandom::set_seed(uint64_t seed) {
258   static_assert(sizeof(key_) == SHA256_DIGEST_LENGTH, "Key has to be 256 bits");
259   SHA256(reinterpret_cast<const uint8_t*>(&seed), sizeof(seed), key_);
260 
261   memset(buffer_, 0, sizeof(buffer_));
262   FillBuffer();
263 }
264 
MockFramerVisitor()265 MockFramerVisitor::MockFramerVisitor() {
266   // By default, we want to accept packets.
267   ON_CALL(*this, OnProtocolVersionMismatch(_))
268       .WillByDefault(testing::Return(false));
269 
270   // By default, we want to accept packets.
271   ON_CALL(*this, OnUnauthenticatedHeader(_))
272       .WillByDefault(testing::Return(true));
273 
274   ON_CALL(*this, OnUnauthenticatedPublicHeader(_))
275       .WillByDefault(testing::Return(true));
276 
277   ON_CALL(*this, OnPacketHeader(_)).WillByDefault(testing::Return(true));
278 
279   ON_CALL(*this, OnStreamFrame(_)).WillByDefault(testing::Return(true));
280 
281   ON_CALL(*this, OnCryptoFrame(_)).WillByDefault(testing::Return(true));
282 
283   ON_CALL(*this, OnStopWaitingFrame(_)).WillByDefault(testing::Return(true));
284 
285   ON_CALL(*this, OnPaddingFrame(_)).WillByDefault(testing::Return(true));
286 
287   ON_CALL(*this, OnPingFrame(_)).WillByDefault(testing::Return(true));
288 
289   ON_CALL(*this, OnRstStreamFrame(_)).WillByDefault(testing::Return(true));
290 
291   ON_CALL(*this, OnConnectionCloseFrame(_))
292       .WillByDefault(testing::Return(true));
293 
294   ON_CALL(*this, OnStopSendingFrame(_)).WillByDefault(testing::Return(true));
295 
296   ON_CALL(*this, OnPathChallengeFrame(_)).WillByDefault(testing::Return(true));
297 
298   ON_CALL(*this, OnPathResponseFrame(_)).WillByDefault(testing::Return(true));
299 
300   ON_CALL(*this, OnGoAwayFrame(_)).WillByDefault(testing::Return(true));
301   ON_CALL(*this, OnMaxStreamsFrame(_)).WillByDefault(testing::Return(true));
302   ON_CALL(*this, OnStreamsBlockedFrame(_)).WillByDefault(testing::Return(true));
303 }
304 
~MockFramerVisitor()305 MockFramerVisitor::~MockFramerVisitor() {}
306 
OnProtocolVersionMismatch(ParsedQuicVersion)307 bool NoOpFramerVisitor::OnProtocolVersionMismatch(
308     ParsedQuicVersion /*version*/) {
309   return false;
310 }
311 
OnUnauthenticatedPublicHeader(const QuicPacketHeader &)312 bool NoOpFramerVisitor::OnUnauthenticatedPublicHeader(
313     const QuicPacketHeader& /*header*/) {
314   return true;
315 }
316 
OnUnauthenticatedHeader(const QuicPacketHeader &)317 bool NoOpFramerVisitor::OnUnauthenticatedHeader(
318     const QuicPacketHeader& /*header*/) {
319   return true;
320 }
321 
OnPacketHeader(const QuicPacketHeader &)322 bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader& /*header*/) {
323   return true;
324 }
325 
OnCoalescedPacket(const QuicEncryptedPacket &)326 void NoOpFramerVisitor::OnCoalescedPacket(
327     const QuicEncryptedPacket& /*packet*/) {}
328 
OnUndecryptablePacket(const QuicEncryptedPacket &,EncryptionLevel,bool)329 void NoOpFramerVisitor::OnUndecryptablePacket(
330     const QuicEncryptedPacket& /*packet*/, EncryptionLevel /*decryption_level*/,
331     bool /*has_decryption_key*/) {}
332 
OnStreamFrame(const QuicStreamFrame &)333 bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame& /*frame*/) {
334   return true;
335 }
336 
OnCryptoFrame(const QuicCryptoFrame &)337 bool NoOpFramerVisitor::OnCryptoFrame(const QuicCryptoFrame& /*frame*/) {
338   return true;
339 }
340 
OnAckFrameStart(QuicPacketNumber,QuicTime::Delta)341 bool NoOpFramerVisitor::OnAckFrameStart(QuicPacketNumber /*largest_acked*/,
342                                         QuicTime::Delta /*ack_delay_time*/) {
343   return true;
344 }
345 
OnAckRange(QuicPacketNumber,QuicPacketNumber)346 bool NoOpFramerVisitor::OnAckRange(QuicPacketNumber /*start*/,
347                                    QuicPacketNumber /*end*/) {
348   return true;
349 }
350 
OnAckTimestamp(QuicPacketNumber,QuicTime)351 bool NoOpFramerVisitor::OnAckTimestamp(QuicPacketNumber /*packet_number*/,
352                                        QuicTime /*timestamp*/) {
353   return true;
354 }
355 
OnAckFrameEnd(QuicPacketNumber,const std::optional<QuicEcnCounts> &)356 bool NoOpFramerVisitor::OnAckFrameEnd(
357     QuicPacketNumber /*start*/,
358     const std::optional<QuicEcnCounts>& /*ecn_counts*/) {
359   return true;
360 }
361 
OnStopWaitingFrame(const QuicStopWaitingFrame &)362 bool NoOpFramerVisitor::OnStopWaitingFrame(
363     const QuicStopWaitingFrame& /*frame*/) {
364   return true;
365 }
366 
OnPaddingFrame(const QuicPaddingFrame &)367 bool NoOpFramerVisitor::OnPaddingFrame(const QuicPaddingFrame& /*frame*/) {
368   return true;
369 }
370 
OnPingFrame(const QuicPingFrame &)371 bool NoOpFramerVisitor::OnPingFrame(const QuicPingFrame& /*frame*/) {
372   return true;
373 }
374 
OnRstStreamFrame(const QuicRstStreamFrame &)375 bool NoOpFramerVisitor::OnRstStreamFrame(const QuicRstStreamFrame& /*frame*/) {
376   return true;
377 }
378 
OnConnectionCloseFrame(const QuicConnectionCloseFrame &)379 bool NoOpFramerVisitor::OnConnectionCloseFrame(
380     const QuicConnectionCloseFrame& /*frame*/) {
381   return true;
382 }
383 
OnNewConnectionIdFrame(const QuicNewConnectionIdFrame &)384 bool NoOpFramerVisitor::OnNewConnectionIdFrame(
385     const QuicNewConnectionIdFrame& /*frame*/) {
386   return true;
387 }
388 
OnRetireConnectionIdFrame(const QuicRetireConnectionIdFrame &)389 bool NoOpFramerVisitor::OnRetireConnectionIdFrame(
390     const QuicRetireConnectionIdFrame& /*frame*/) {
391   return true;
392 }
393 
OnNewTokenFrame(const QuicNewTokenFrame &)394 bool NoOpFramerVisitor::OnNewTokenFrame(const QuicNewTokenFrame& /*frame*/) {
395   return true;
396 }
397 
OnStopSendingFrame(const QuicStopSendingFrame &)398 bool NoOpFramerVisitor::OnStopSendingFrame(
399     const QuicStopSendingFrame& /*frame*/) {
400   return true;
401 }
402 
OnPathChallengeFrame(const QuicPathChallengeFrame &)403 bool NoOpFramerVisitor::OnPathChallengeFrame(
404     const QuicPathChallengeFrame& /*frame*/) {
405   return true;
406 }
407 
OnPathResponseFrame(const QuicPathResponseFrame &)408 bool NoOpFramerVisitor::OnPathResponseFrame(
409     const QuicPathResponseFrame& /*frame*/) {
410   return true;
411 }
412 
OnGoAwayFrame(const QuicGoAwayFrame &)413 bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& /*frame*/) {
414   return true;
415 }
416 
OnMaxStreamsFrame(const QuicMaxStreamsFrame &)417 bool NoOpFramerVisitor::OnMaxStreamsFrame(
418     const QuicMaxStreamsFrame& /*frame*/) {
419   return true;
420 }
421 
OnStreamsBlockedFrame(const QuicStreamsBlockedFrame &)422 bool NoOpFramerVisitor::OnStreamsBlockedFrame(
423     const QuicStreamsBlockedFrame& /*frame*/) {
424   return true;
425 }
426 
OnWindowUpdateFrame(const QuicWindowUpdateFrame &)427 bool NoOpFramerVisitor::OnWindowUpdateFrame(
428     const QuicWindowUpdateFrame& /*frame*/) {
429   return true;
430 }
431 
OnBlockedFrame(const QuicBlockedFrame &)432 bool NoOpFramerVisitor::OnBlockedFrame(const QuicBlockedFrame& /*frame*/) {
433   return true;
434 }
435 
OnMessageFrame(const QuicMessageFrame &)436 bool NoOpFramerVisitor::OnMessageFrame(const QuicMessageFrame& /*frame*/) {
437   return true;
438 }
439 
OnHandshakeDoneFrame(const QuicHandshakeDoneFrame &)440 bool NoOpFramerVisitor::OnHandshakeDoneFrame(
441     const QuicHandshakeDoneFrame& /*frame*/) {
442   return true;
443 }
444 
OnAckFrequencyFrame(const QuicAckFrequencyFrame &)445 bool NoOpFramerVisitor::OnAckFrequencyFrame(
446     const QuicAckFrequencyFrame& /*frame*/) {
447   return true;
448 }
449 
OnResetStreamAtFrame(const QuicResetStreamAtFrame &)450 bool NoOpFramerVisitor::OnResetStreamAtFrame(
451     const QuicResetStreamAtFrame& /*frame*/) {
452   return true;
453 }
454 
IsValidStatelessResetToken(const StatelessResetToken &) const455 bool NoOpFramerVisitor::IsValidStatelessResetToken(
456     const StatelessResetToken& /*token*/) const {
457   return false;
458 }
459 
MockQuicConnectionVisitor()460 MockQuicConnectionVisitor::MockQuicConnectionVisitor() {
461   ON_CALL(*this, GetFlowControlSendWindowSize(_))
462       .WillByDefault(Return(std::numeric_limits<QuicByteCount>::max()));
463 }
464 
~MockQuicConnectionVisitor()465 MockQuicConnectionVisitor::~MockQuicConnectionVisitor() {}
466 
MockQuicConnectionHelper()467 MockQuicConnectionHelper::MockQuicConnectionHelper() {}
468 
~MockQuicConnectionHelper()469 MockQuicConnectionHelper::~MockQuicConnectionHelper() {}
470 
GetClock() const471 const MockClock* MockQuicConnectionHelper::GetClock() const { return &clock_; }
472 
GetClock()473 MockClock* MockQuicConnectionHelper::GetClock() { return &clock_; }
474 
GetRandomGenerator()475 QuicRandom* MockQuicConnectionHelper::GetRandomGenerator() {
476   return &random_generator_;
477 }
478 
CreateAlarm(QuicAlarm::Delegate * delegate)479 QuicAlarm* MockAlarmFactory::CreateAlarm(QuicAlarm::Delegate* delegate) {
480   return new MockAlarmFactory::TestAlarm(
481       QuicArenaScopedPtr<QuicAlarm::Delegate>(delegate));
482 }
483 
CreateAlarm(QuicArenaScopedPtr<QuicAlarm::Delegate> delegate,QuicConnectionArena * arena)484 QuicArenaScopedPtr<QuicAlarm> MockAlarmFactory::CreateAlarm(
485     QuicArenaScopedPtr<QuicAlarm::Delegate> delegate,
486     QuicConnectionArena* arena) {
487   if (arena != nullptr) {
488     return arena->New<TestAlarm>(std::move(delegate));
489   } else {
490     return QuicArenaScopedPtr<TestAlarm>(new TestAlarm(std::move(delegate)));
491   }
492 }
493 
494 quiche::QuicheBufferAllocator*
GetStreamSendBufferAllocator()495 MockQuicConnectionHelper::GetStreamSendBufferAllocator() {
496   return &buffer_allocator_;
497 }
498 
AdvanceTime(QuicTime::Delta delta)499 void MockQuicConnectionHelper::AdvanceTime(QuicTime::Delta delta) {
500   clock_.AdvanceTime(delta);
501 }
502 
MockQuicConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)503 MockQuicConnection::MockQuicConnection(QuicConnectionHelperInterface* helper,
504                                        QuicAlarmFactory* alarm_factory,
505                                        Perspective perspective)
506     : MockQuicConnection(TestConnectionId(),
507                          QuicSocketAddress(TestPeerIPAddress(), kTestPort),
508                          helper, alarm_factory, perspective,
509                          ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
510 
MockQuicConnection(QuicSocketAddress address,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)511 MockQuicConnection::MockQuicConnection(QuicSocketAddress address,
512                                        QuicConnectionHelperInterface* helper,
513                                        QuicAlarmFactory* alarm_factory,
514                                        Perspective perspective)
515     : MockQuicConnection(TestConnectionId(), address, helper, alarm_factory,
516                          perspective,
517                          ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
518 
MockQuicConnection(QuicConnectionId connection_id,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)519 MockQuicConnection::MockQuicConnection(QuicConnectionId connection_id,
520                                        QuicConnectionHelperInterface* helper,
521                                        QuicAlarmFactory* alarm_factory,
522                                        Perspective perspective)
523     : MockQuicConnection(connection_id,
524                          QuicSocketAddress(TestPeerIPAddress(), kTestPort),
525                          helper, alarm_factory, perspective,
526                          ParsedVersionOfIndex(CurrentSupportedVersions(), 0)) {}
527 
MockQuicConnection(QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)528 MockQuicConnection::MockQuicConnection(
529     QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
530     Perspective perspective, const ParsedQuicVersionVector& supported_versions)
531     : MockQuicConnection(
532           TestConnectionId(), QuicSocketAddress(TestPeerIPAddress(), kTestPort),
533           helper, alarm_factory, perspective, supported_versions) {}
534 
MockQuicConnection(QuicConnectionId connection_id,QuicSocketAddress initial_peer_address,QuicConnectionHelperInterface * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)535 MockQuicConnection::MockQuicConnection(
536     QuicConnectionId connection_id, QuicSocketAddress initial_peer_address,
537     QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory,
538     Perspective perspective, const ParsedQuicVersionVector& supported_versions)
539     : QuicConnection(
540           connection_id,
541           /*initial_self_address=*/QuicSocketAddress(QuicIpAddress::Any4(), 5),
542           initial_peer_address, helper, alarm_factory,
543           new testing::NiceMock<MockPacketWriter>(),
544           /* owns_writer= */ true, perspective, supported_versions,
545           connection_id_generator_) {
546   ON_CALL(*this, OnError(_))
547       .WillByDefault(
548           Invoke(this, &PacketSavingConnection::QuicConnection_OnError));
549   ON_CALL(*this, SendCryptoData(_, _, _))
550       .WillByDefault(
551           Invoke(this, &MockQuicConnection::QuicConnection_SendCryptoData));
552 
553   SetSelfAddress(QuicSocketAddress(QuicIpAddress::Any4(), 5));
554 }
555 
~MockQuicConnection()556 MockQuicConnection::~MockQuicConnection() {}
557 
AdvanceTime(QuicTime::Delta delta)558 void MockQuicConnection::AdvanceTime(QuicTime::Delta delta) {
559   static_cast<MockQuicConnectionHelper*>(helper())->AdvanceTime(delta);
560 }
561 
OnProtocolVersionMismatch(ParsedQuicVersion)562 bool MockQuicConnection::OnProtocolVersionMismatch(
563     ParsedQuicVersion /*version*/) {
564   return false;
565 }
566 
PacketSavingConnection(MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,Perspective perspective)567 PacketSavingConnection::PacketSavingConnection(MockQuicConnectionHelper* helper,
568                                                QuicAlarmFactory* alarm_factory,
569                                                Perspective perspective)
570     : MockQuicConnection(helper, alarm_factory, perspective),
571       mock_helper_(helper) {}
572 
PacketSavingConnection(MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,Perspective perspective,const ParsedQuicVersionVector & supported_versions)573 PacketSavingConnection::PacketSavingConnection(
574     MockQuicConnectionHelper* helper, QuicAlarmFactory* alarm_factory,
575     Perspective perspective, const ParsedQuicVersionVector& supported_versions)
576     : MockQuicConnection(helper, alarm_factory, perspective,
577                          supported_versions),
578       mock_helper_(helper) {}
579 
~PacketSavingConnection()580 PacketSavingConnection::~PacketSavingConnection() {}
581 
GetSerializedPacketFate(bool,EncryptionLevel)582 SerializedPacketFate PacketSavingConnection::GetSerializedPacketFate(
583     bool /*is_mtu_discovery*/, EncryptionLevel /*encryption_level*/) {
584   return SEND_TO_WRITER;
585 }
586 
SendOrQueuePacket(SerializedPacket packet)587 void PacketSavingConnection::SendOrQueuePacket(SerializedPacket packet) {
588   encrypted_packets_.push_back(std::make_unique<QuicEncryptedPacket>(
589       CopyBuffer(packet), packet.encrypted_length, true));
590   MockClock& clock = *mock_helper_->GetClock();
591   clock.AdvanceTime(QuicTime::Delta::FromMilliseconds(10));
592   // Transfer ownership of the packet to the SentPacketManager and the
593   // ack notifier to the AckNotifierManager.
594   OnPacketSent(packet.encryption_level, packet.transmission_type);
595   QuicConnectionPeer::GetSentPacketManager(this)->OnPacketSent(
596       &packet, clock.ApproximateNow(), NOT_RETRANSMISSION,
597       HAS_RETRANSMITTABLE_DATA, true, ECN_NOT_ECT);
598 }
599 
GetPackets() const600 std::vector<const QuicEncryptedPacket*> PacketSavingConnection::GetPackets()
601     const {
602   std::vector<const QuicEncryptedPacket*> packets;
603   for (size_t i = num_cleared_packets_; i < encrypted_packets_.size(); ++i) {
604     packets.push_back(encrypted_packets_[i].get());
605   }
606   return packets;
607 }
608 
ClearPackets()609 void PacketSavingConnection::ClearPackets() {
610   num_cleared_packets_ = encrypted_packets_.size();
611 }
612 
MockQuicSession(QuicConnection * connection)613 MockQuicSession::MockQuicSession(QuicConnection* connection)
614     : MockQuicSession(connection, true) {}
615 
MockQuicSession(QuicConnection * connection,bool create_mock_crypto_stream)616 MockQuicSession::MockQuicSession(QuicConnection* connection,
617                                  bool create_mock_crypto_stream)
618     : QuicSession(connection, nullptr, DefaultQuicConfig(),
619                   connection->supported_versions(),
620                   /*num_expected_unidirectional_static_streams = */ 0) {
621   if (create_mock_crypto_stream) {
622     crypto_stream_ =
623         std::make_unique<testing::NiceMock<MockQuicCryptoStream>>(this);
624   }
625   ON_CALL(*this, WritevData(_, _, _, _, _, _))
626       .WillByDefault(testing::Return(QuicConsumedData(0, false)));
627 }
628 
~MockQuicSession()629 MockQuicSession::~MockQuicSession() { DeleteConnection(); }
630 
GetMutableCryptoStream()631 QuicCryptoStream* MockQuicSession::GetMutableCryptoStream() {
632   return crypto_stream_.get();
633 }
634 
GetCryptoStream() const635 const QuicCryptoStream* MockQuicSession::GetCryptoStream() const {
636   return crypto_stream_.get();
637 }
638 
SetCryptoStream(QuicCryptoStream * crypto_stream)639 void MockQuicSession::SetCryptoStream(QuicCryptoStream* crypto_stream) {
640   crypto_stream_.reset(crypto_stream);
641 }
642 
ConsumeData(QuicStreamId id,size_t write_length,QuicStreamOffset offset,StreamSendingState state,TransmissionType,std::optional<EncryptionLevel>)643 QuicConsumedData MockQuicSession::ConsumeData(
644     QuicStreamId id, size_t write_length, QuicStreamOffset offset,
645     StreamSendingState state, TransmissionType /*type*/,
646     std::optional<EncryptionLevel> /*level*/) {
647   if (write_length > 0) {
648     auto buf = std::make_unique<char[]>(write_length);
649     QuicStream* stream = GetOrCreateStream(id);
650     QUICHE_DCHECK(stream);
651     QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER);
652     stream->WriteStreamData(offset, write_length, &writer);
653   } else {
654     QUICHE_DCHECK(state != NO_FIN);
655   }
656   return QuicConsumedData(write_length, state != NO_FIN);
657 }
658 
MockQuicCryptoStream(QuicSession * session)659 MockQuicCryptoStream::MockQuicCryptoStream(QuicSession* session)
660     : QuicCryptoStream(session), params_(new QuicCryptoNegotiatedParameters) {}
661 
~MockQuicCryptoStream()662 MockQuicCryptoStream::~MockQuicCryptoStream() {}
663 
EarlyDataReason() const664 ssl_early_data_reason_t MockQuicCryptoStream::EarlyDataReason() const {
665   return ssl_early_data_unknown;
666 }
667 
one_rtt_keys_available() const668 bool MockQuicCryptoStream::one_rtt_keys_available() const { return false; }
669 
670 const QuicCryptoNegotiatedParameters&
crypto_negotiated_params() const671 MockQuicCryptoStream::crypto_negotiated_params() const {
672   return *params_;
673 }
674 
crypto_message_parser()675 CryptoMessageParser* MockQuicCryptoStream::crypto_message_parser() {
676   return &crypto_framer_;
677 }
678 
MockQuicSpdySession(QuicConnection * connection)679 MockQuicSpdySession::MockQuicSpdySession(QuicConnection* connection)
680     : MockQuicSpdySession(connection, true) {}
681 
MockQuicSpdySession(QuicConnection * connection,bool create_mock_crypto_stream)682 MockQuicSpdySession::MockQuicSpdySession(QuicConnection* connection,
683                                          bool create_mock_crypto_stream)
684     : QuicSpdySession(connection, nullptr, DefaultQuicConfig(),
685                       connection->supported_versions()) {
686   if (create_mock_crypto_stream) {
687     crypto_stream_ = std::make_unique<MockQuicCryptoStream>(this);
688   }
689 
690   ON_CALL(*this, WritevData(_, _, _, _, _, _))
691       .WillByDefault(testing::Return(QuicConsumedData(0, false)));
692 
693   ON_CALL(*this, SendWindowUpdate(_, _))
694       .WillByDefault([this](QuicStreamId id, QuicStreamOffset byte_offset) {
695         return QuicSpdySession::SendWindowUpdate(id, byte_offset);
696       });
697 
698   ON_CALL(*this, SendBlocked(_, _))
699       .WillByDefault([this](QuicStreamId id, QuicStreamOffset byte_offset) {
700         return QuicSpdySession::SendBlocked(id, byte_offset);
701       });
702 
703   ON_CALL(*this, OnCongestionWindowChange(_)).WillByDefault(testing::Return());
704 }
705 
~MockQuicSpdySession()706 MockQuicSpdySession::~MockQuicSpdySession() { DeleteConnection(); }
707 
GetMutableCryptoStream()708 QuicCryptoStream* MockQuicSpdySession::GetMutableCryptoStream() {
709   return crypto_stream_.get();
710 }
711 
GetCryptoStream() const712 const QuicCryptoStream* MockQuicSpdySession::GetCryptoStream() const {
713   return crypto_stream_.get();
714 }
715 
SetCryptoStream(QuicCryptoStream * crypto_stream)716 void MockQuicSpdySession::SetCryptoStream(QuicCryptoStream* crypto_stream) {
717   crypto_stream_.reset(crypto_stream);
718 }
719 
ConsumeData(QuicStreamId id,size_t write_length,QuicStreamOffset offset,StreamSendingState state,TransmissionType,std::optional<EncryptionLevel>)720 QuicConsumedData MockQuicSpdySession::ConsumeData(
721     QuicStreamId id, size_t write_length, QuicStreamOffset offset,
722     StreamSendingState state, TransmissionType /*type*/,
723     std::optional<EncryptionLevel> /*level*/) {
724   if (write_length > 0) {
725     auto buf = std::make_unique<char[]>(write_length);
726     QuicStream* stream = GetOrCreateStream(id);
727     QUICHE_DCHECK(stream);
728     QuicDataWriter writer(write_length, buf.get(), quiche::HOST_BYTE_ORDER);
729     stream->WriteStreamData(offset, write_length, &writer);
730   } else {
731     QUICHE_DCHECK(state != NO_FIN);
732   }
733   return QuicConsumedData(write_length, state != NO_FIN);
734 }
735 
TestQuicSpdyServerSession(QuicConnection * connection,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,const QuicCryptoServerConfig * crypto_config,QuicCompressedCertsCache * compressed_certs_cache)736 TestQuicSpdyServerSession::TestQuicSpdyServerSession(
737     QuicConnection* connection, const QuicConfig& config,
738     const ParsedQuicVersionVector& supported_versions,
739     const QuicCryptoServerConfig* crypto_config,
740     QuicCompressedCertsCache* compressed_certs_cache)
741     : QuicServerSessionBase(config, supported_versions, connection, &visitor_,
742                             &helper_, crypto_config, compressed_certs_cache) {
743   ON_CALL(helper_, CanAcceptClientHello(_, _, _, _, _))
744       .WillByDefault(testing::Return(true));
745 }
746 
~TestQuicSpdyServerSession()747 TestQuicSpdyServerSession::~TestQuicSpdyServerSession() { DeleteConnection(); }
748 
749 std::unique_ptr<QuicCryptoServerStreamBase>
CreateQuicCryptoServerStream(const QuicCryptoServerConfig * crypto_config,QuicCompressedCertsCache * compressed_certs_cache)750 TestQuicSpdyServerSession::CreateQuicCryptoServerStream(
751     const QuicCryptoServerConfig* crypto_config,
752     QuicCompressedCertsCache* compressed_certs_cache) {
753   return CreateCryptoServerStream(crypto_config, compressed_certs_cache, this,
754                                   &helper_);
755 }
756 
757 QuicCryptoServerStreamBase*
GetMutableCryptoStream()758 TestQuicSpdyServerSession::GetMutableCryptoStream() {
759   return QuicServerSessionBase::GetMutableCryptoStream();
760 }
761 
GetCryptoStream() const762 const QuicCryptoServerStreamBase* TestQuicSpdyServerSession::GetCryptoStream()
763     const {
764   return QuicServerSessionBase::GetCryptoStream();
765 }
766 
TestQuicSpdyClientSession(QuicConnection * connection,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,const QuicServerId & server_id,QuicCryptoClientConfig * crypto_config,std::optional<QuicSSLConfig> ssl_config)767 TestQuicSpdyClientSession::TestQuicSpdyClientSession(
768     QuicConnection* connection, const QuicConfig& config,
769     const ParsedQuicVersionVector& supported_versions,
770     const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config,
771     std::optional<QuicSSLConfig> ssl_config)
772     : QuicSpdyClientSessionBase(connection, nullptr, config,
773                                 supported_versions),
774       ssl_config_(std::move(ssl_config)) {
775   // TODO(b/153726130): Consider adding SetServerApplicationStateForResumption
776   // calls in tests and set |has_application_state| to true.
777   crypto_stream_ = std::make_unique<QuicCryptoClientStream>(
778       server_id, this, crypto_test_utils::ProofVerifyContextForTesting(),
779       crypto_config, this, /*has_application_state = */ false);
780   Initialize();
781   ON_CALL(*this, OnConfigNegotiated())
782       .WillByDefault(
783           Invoke(this, &TestQuicSpdyClientSession::RealOnConfigNegotiated));
784 }
785 
~TestQuicSpdyClientSession()786 TestQuicSpdyClientSession::~TestQuicSpdyClientSession() {}
787 
GetMutableCryptoStream()788 QuicCryptoClientStream* TestQuicSpdyClientSession::GetMutableCryptoStream() {
789   return crypto_stream_.get();
790 }
791 
GetCryptoStream() const792 const QuicCryptoClientStream* TestQuicSpdyClientSession::GetCryptoStream()
793     const {
794   return crypto_stream_.get();
795 }
796 
RealOnConfigNegotiated()797 void TestQuicSpdyClientSession::RealOnConfigNegotiated() {
798   QuicSpdyClientSessionBase::OnConfigNegotiated();
799 }
800 
MockPacketWriter()801 MockPacketWriter::MockPacketWriter() {
802   ON_CALL(*this, GetMaxPacketSize(_))
803       .WillByDefault(testing::Return(kMaxOutgoingPacketSize));
804   ON_CALL(*this, IsBatchMode()).WillByDefault(testing::Return(false));
805   ON_CALL(*this, GetNextWriteLocation(_, _))
806       .WillByDefault(testing::Return(QuicPacketBuffer()));
807   ON_CALL(*this, Flush())
808       .WillByDefault(testing::Return(WriteResult(WRITE_STATUS_OK, 0)));
809   ON_CALL(*this, SupportsReleaseTime()).WillByDefault(testing::Return(false));
810 }
811 
~MockPacketWriter()812 MockPacketWriter::~MockPacketWriter() {}
813 
MockSendAlgorithm()814 MockSendAlgorithm::MockSendAlgorithm() {
815   ON_CALL(*this, PacingRate(_))
816       .WillByDefault(testing::Return(QuicBandwidth::Zero()));
817   ON_CALL(*this, BandwidthEstimate())
818       .WillByDefault(testing::Return(QuicBandwidth::Zero()));
819 }
820 
~MockSendAlgorithm()821 MockSendAlgorithm::~MockSendAlgorithm() {}
822 
MockLossAlgorithm()823 MockLossAlgorithm::MockLossAlgorithm() {}
824 
~MockLossAlgorithm()825 MockLossAlgorithm::~MockLossAlgorithm() {}
826 
MockAckListener()827 MockAckListener::MockAckListener() {}
828 
~MockAckListener()829 MockAckListener::~MockAckListener() {}
830 
MockNetworkChangeVisitor()831 MockNetworkChangeVisitor::MockNetworkChangeVisitor() {}
832 
~MockNetworkChangeVisitor()833 MockNetworkChangeVisitor::~MockNetworkChangeVisitor() {}
834 
TestPeerIPAddress()835 QuicIpAddress TestPeerIPAddress() { return QuicIpAddress::Loopback4(); }
836 
QuicVersionMax()837 ParsedQuicVersion QuicVersionMax() { return AllSupportedVersions().front(); }
838 
QuicVersionMin()839 ParsedQuicVersion QuicVersionMin() { return AllSupportedVersions().back(); }
840 
DisableQuicVersionsWithTls()841 void DisableQuicVersionsWithTls() {
842   for (const ParsedQuicVersion& version : AllSupportedVersionsWithTls()) {
843     QuicDisableVersion(version);
844   }
845 }
846 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data)847 QuicEncryptedPacket* ConstructEncryptedPacket(
848     QuicConnectionId destination_connection_id,
849     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
850     uint64_t packet_number, const std::string& data) {
851   return ConstructEncryptedPacket(
852       destination_connection_id, source_connection_id, version_flag, reset_flag,
853       packet_number, data, CONNECTION_ID_PRESENT, CONNECTION_ID_ABSENT,
854       PACKET_4BYTE_PACKET_NUMBER);
855 }
856 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length)857 QuicEncryptedPacket* ConstructEncryptedPacket(
858     QuicConnectionId destination_connection_id,
859     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
860     uint64_t packet_number, const std::string& data,
861     QuicConnectionIdIncluded destination_connection_id_included,
862     QuicConnectionIdIncluded source_connection_id_included,
863     QuicPacketNumberLength packet_number_length) {
864   return ConstructEncryptedPacket(
865       destination_connection_id, source_connection_id, version_flag, reset_flag,
866       packet_number, data, destination_connection_id_included,
867       source_connection_id_included, packet_number_length, nullptr);
868 }
869 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions)870 QuicEncryptedPacket* ConstructEncryptedPacket(
871     QuicConnectionId destination_connection_id,
872     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
873     uint64_t packet_number, const std::string& data,
874     QuicConnectionIdIncluded destination_connection_id_included,
875     QuicConnectionIdIncluded source_connection_id_included,
876     QuicPacketNumberLength packet_number_length,
877     ParsedQuicVersionVector* versions) {
878   return ConstructEncryptedPacket(
879       destination_connection_id, source_connection_id, version_flag, reset_flag,
880       packet_number, data, false, destination_connection_id_included,
881       source_connection_id_included, packet_number_length, versions,
882       Perspective::IS_CLIENT);
883 }
884 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,bool full_padding,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions)885 QuicEncryptedPacket* ConstructEncryptedPacket(
886     QuicConnectionId destination_connection_id,
887     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
888     uint64_t packet_number, const std::string& data, bool full_padding,
889     QuicConnectionIdIncluded destination_connection_id_included,
890     QuicConnectionIdIncluded source_connection_id_included,
891     QuicPacketNumberLength packet_number_length,
892     ParsedQuicVersionVector* versions) {
893   return ConstructEncryptedPacket(
894       destination_connection_id, source_connection_id, version_flag, reset_flag,
895       packet_number, data, full_padding, destination_connection_id_included,
896       source_connection_id_included, packet_number_length, versions,
897       Perspective::IS_CLIENT);
898 }
899 
ConstructEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,bool full_padding,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersionVector * versions,Perspective perspective)900 QuicEncryptedPacket* ConstructEncryptedPacket(
901     QuicConnectionId destination_connection_id,
902     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
903     uint64_t packet_number, const std::string& data, bool full_padding,
904     QuicConnectionIdIncluded destination_connection_id_included,
905     QuicConnectionIdIncluded source_connection_id_included,
906     QuicPacketNumberLength packet_number_length,
907     ParsedQuicVersionVector* versions, Perspective perspective) {
908   QuicPacketHeader header;
909   header.destination_connection_id = destination_connection_id;
910   header.destination_connection_id_included =
911       destination_connection_id_included;
912   header.source_connection_id = source_connection_id;
913   header.source_connection_id_included = source_connection_id_included;
914   header.version_flag = version_flag;
915   header.reset_flag = reset_flag;
916   header.packet_number_length = packet_number_length;
917   header.packet_number = QuicPacketNumber(packet_number);
918   ParsedQuicVersionVector supported_versions = CurrentSupportedVersions();
919   if (!versions) {
920     versions = &supported_versions;
921   }
922   EXPECT_FALSE(versions->empty());
923   ParsedQuicVersion version = (*versions)[0];
924   if (QuicVersionHasLongHeaderLengths(version.transport_version) &&
925       version_flag) {
926     header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
927     header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
928   }
929 
930   QuicFrames frames;
931   QuicFramer framer(*versions, QuicTime::Zero(), perspective,
932                     kQuicDefaultConnectionIdLength);
933   framer.SetInitialObfuscators(destination_connection_id);
934   EncryptionLevel level =
935       header.version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE;
936   if (level != ENCRYPTION_INITIAL) {
937     framer.SetEncrypter(level, std::make_unique<TaggingEncrypter>(level));
938   }
939   if (!QuicVersionUsesCryptoFrames(version.transport_version)) {
940     QuicFrame frame(
941         QuicStreamFrame(QuicUtils::GetCryptoStreamId(version.transport_version),
942                         false, 0, absl::string_view(data)));
943     frames.push_back(frame);
944   } else {
945     QuicFrame frame(new QuicCryptoFrame(level, 0, data));
946     frames.push_back(frame);
947   }
948   if (full_padding) {
949     frames.push_back(QuicFrame(QuicPaddingFrame(-1)));
950   } else {
951     // We need a minimum number of bytes of encrypted payload. This will
952     // guarantee that we have at least that much. (It ignores the overhead of
953     // the stream/crypto framing, so it overpads slightly.)
954     size_t min_plaintext_size = QuicPacketCreator::MinPlaintextPacketSize(
955         version, packet_number_length);
956     if (data.length() < min_plaintext_size) {
957       size_t padding_length = min_plaintext_size - data.length();
958       frames.push_back(QuicFrame(QuicPaddingFrame(padding_length)));
959     }
960   }
961 
962   std::unique_ptr<QuicPacket> packet(
963       BuildUnsizedDataPacket(&framer, header, frames));
964   EXPECT_TRUE(packet != nullptr);
965   char* buffer = new char[kMaxOutgoingPacketSize];
966   size_t encrypted_length =
967       framer.EncryptPayload(level, QuicPacketNumber(packet_number), *packet,
968                             buffer, kMaxOutgoingPacketSize);
969   EXPECT_NE(0u, encrypted_length);
970   DeleteFrames(&frames);
971   return new QuicEncryptedPacket(buffer, encrypted_length, true);
972 }
973 
GetUndecryptableEarlyPacket(const ParsedQuicVersion & version,const QuicConnectionId & server_connection_id)974 std::unique_ptr<QuicEncryptedPacket> GetUndecryptableEarlyPacket(
975     const ParsedQuicVersion& version,
976     const QuicConnectionId& server_connection_id) {
977   QuicPacketHeader header;
978   header.destination_connection_id = server_connection_id;
979   header.destination_connection_id_included = CONNECTION_ID_PRESENT;
980   header.source_connection_id = EmptyQuicConnectionId();
981   header.source_connection_id_included = CONNECTION_ID_PRESENT;
982   if (!version.SupportsClientConnectionIds()) {
983     header.source_connection_id_included = CONNECTION_ID_ABSENT;
984   }
985   header.version_flag = true;
986   header.reset_flag = false;
987   header.packet_number_length = PACKET_4BYTE_PACKET_NUMBER;
988   header.packet_number = QuicPacketNumber(33);
989   header.long_packet_type = ZERO_RTT_PROTECTED;
990   if (version.HasLongHeaderLengths()) {
991     header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
992     header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
993   }
994 
995   QuicFrames frames;
996   frames.push_back(QuicFrame(QuicPingFrame()));
997   frames.push_back(QuicFrame(QuicPaddingFrame(100)));
998   QuicFramer framer({version}, QuicTime::Zero(), Perspective::IS_CLIENT,
999                     kQuicDefaultConnectionIdLength);
1000   framer.SetInitialObfuscators(server_connection_id);
1001 
1002   framer.SetEncrypter(ENCRYPTION_ZERO_RTT,
1003                       std::make_unique<TaggingEncrypter>(ENCRYPTION_ZERO_RTT));
1004   std::unique_ptr<QuicPacket> packet(
1005       BuildUnsizedDataPacket(&framer, header, frames));
1006   EXPECT_TRUE(packet != nullptr);
1007   char* buffer = new char[kMaxOutgoingPacketSize];
1008   size_t encrypted_length =
1009       framer.EncryptPayload(ENCRYPTION_ZERO_RTT, header.packet_number, *packet,
1010                             buffer, kMaxOutgoingPacketSize);
1011   EXPECT_NE(0u, encrypted_length);
1012   DeleteFrames(&frames);
1013   return std::make_unique<QuicEncryptedPacket>(buffer, encrypted_length,
1014                                                /*owns_buffer=*/true);
1015 }
1016 
ConstructReceivedPacket(const QuicEncryptedPacket & encrypted_packet,QuicTime receipt_time)1017 QuicReceivedPacket* ConstructReceivedPacket(
1018     const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time) {
1019   return ConstructReceivedPacket(encrypted_packet, receipt_time, ECN_NOT_ECT);
1020 }
1021 
ConstructReceivedPacket(const QuicEncryptedPacket & encrypted_packet,QuicTime receipt_time,QuicEcnCodepoint ecn)1022 QuicReceivedPacket* ConstructReceivedPacket(
1023     const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time,
1024     QuicEcnCodepoint ecn) {
1025   char* buffer = new char[encrypted_packet.length()];
1026   memcpy(buffer, encrypted_packet.data(), encrypted_packet.length());
1027   return new QuicReceivedPacket(buffer, encrypted_packet.length(), receipt_time,
1028                                 true, 0, true, nullptr, 0, false, ecn);
1029 }
1030 
ConstructMisFramedEncryptedPacket(QuicConnectionId destination_connection_id,QuicConnectionId source_connection_id,bool version_flag,bool reset_flag,uint64_t packet_number,const std::string & data,QuicConnectionIdIncluded destination_connection_id_included,QuicConnectionIdIncluded source_connection_id_included,QuicPacketNumberLength packet_number_length,ParsedQuicVersion version,Perspective perspective)1031 QuicEncryptedPacket* ConstructMisFramedEncryptedPacket(
1032     QuicConnectionId destination_connection_id,
1033     QuicConnectionId source_connection_id, bool version_flag, bool reset_flag,
1034     uint64_t packet_number, const std::string& data,
1035     QuicConnectionIdIncluded destination_connection_id_included,
1036     QuicConnectionIdIncluded source_connection_id_included,
1037     QuicPacketNumberLength packet_number_length, ParsedQuicVersion version,
1038     Perspective perspective) {
1039   QuicPacketHeader header;
1040   header.destination_connection_id = destination_connection_id;
1041   header.destination_connection_id_included =
1042       destination_connection_id_included;
1043   header.source_connection_id = source_connection_id;
1044   header.source_connection_id_included = source_connection_id_included;
1045   header.version_flag = version_flag;
1046   header.reset_flag = reset_flag;
1047   header.packet_number_length = packet_number_length;
1048   header.packet_number = QuicPacketNumber(packet_number);
1049   if (QuicVersionHasLongHeaderLengths(version.transport_version) &&
1050       version_flag) {
1051     header.retry_token_length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_1;
1052     header.length_length = quiche::VARIABLE_LENGTH_INTEGER_LENGTH_2;
1053   }
1054   QuicFrame frame(QuicStreamFrame(1, false, 0, absl::string_view(data)));
1055   QuicFrames frames;
1056   frames.push_back(frame);
1057   QuicFramer framer({version}, QuicTime::Zero(), perspective,
1058                     kQuicDefaultConnectionIdLength);
1059   framer.SetInitialObfuscators(destination_connection_id);
1060   EncryptionLevel level =
1061       version_flag ? ENCRYPTION_INITIAL : ENCRYPTION_FORWARD_SECURE;
1062   if (level != ENCRYPTION_INITIAL) {
1063     framer.SetEncrypter(level, std::make_unique<TaggingEncrypter>(level));
1064   }
1065   // We need a minimum of 7 bytes of encrypted payload. This will guarantee that
1066   // we have at least that much. (It ignores the overhead of the stream/crypto
1067   // framing, so it overpads slightly.)
1068   if (data.length() < 7) {
1069     size_t padding_length = 7 - data.length();
1070     frames.push_back(QuicFrame(QuicPaddingFrame(padding_length)));
1071   }
1072 
1073   std::unique_ptr<QuicPacket> packet(
1074       BuildUnsizedDataPacket(&framer, header, frames));
1075   EXPECT_TRUE(packet != nullptr);
1076 
1077   // Now set the frame type to 0x1F, which is an invalid frame type.
1078   reinterpret_cast<unsigned char*>(
1079       packet->mutable_data())[GetStartOfEncryptedData(
1080       framer.transport_version(),
1081       GetIncludedDestinationConnectionIdLength(header),
1082       GetIncludedSourceConnectionIdLength(header), version_flag,
1083       false /* no diversification nonce */, packet_number_length,
1084       header.retry_token_length_length, 0, header.length_length)] = 0x1F;
1085 
1086   char* buffer = new char[kMaxOutgoingPacketSize];
1087   size_t encrypted_length =
1088       framer.EncryptPayload(level, QuicPacketNumber(packet_number), *packet,
1089                             buffer, kMaxOutgoingPacketSize);
1090   EXPECT_NE(0u, encrypted_length);
1091   return new QuicEncryptedPacket(buffer, encrypted_length, true);
1092 }
1093 
DefaultQuicConfig()1094 QuicConfig DefaultQuicConfig() {
1095   QuicConfig config;
1096   config.SetInitialMaxStreamDataBytesIncomingBidirectionalToSend(
1097       kInitialStreamFlowControlWindowForTest);
1098   config.SetInitialMaxStreamDataBytesOutgoingBidirectionalToSend(
1099       kInitialStreamFlowControlWindowForTest);
1100   config.SetInitialMaxStreamDataBytesUnidirectionalToSend(
1101       kInitialStreamFlowControlWindowForTest);
1102   config.SetInitialStreamFlowControlWindowToSend(
1103       kInitialStreamFlowControlWindowForTest);
1104   config.SetInitialSessionFlowControlWindowToSend(
1105       kInitialSessionFlowControlWindowForTest);
1106   QuicConfigPeer::SetReceivedMaxBidirectionalStreams(
1107       &config, kDefaultMaxStreamsPerConnection);
1108   // Default enable NSTP.
1109   // This is unnecessary for versions > 44
1110   if (!config.HasClientSentConnectionOption(quic::kNSTP,
1111                                             quic::Perspective::IS_CLIENT)) {
1112     quic::QuicTagVector connection_options;
1113     connection_options.push_back(quic::kNSTP);
1114     config.SetConnectionOptionsToSend(connection_options);
1115   }
1116   return config;
1117 }
1118 
SupportedVersions(ParsedQuicVersion version)1119 ParsedQuicVersionVector SupportedVersions(ParsedQuicVersion version) {
1120   ParsedQuicVersionVector versions;
1121   versions.push_back(version);
1122   return versions;
1123 }
1124 
MockQuicConnectionDebugVisitor()1125 MockQuicConnectionDebugVisitor::MockQuicConnectionDebugVisitor() {}
1126 
~MockQuicConnectionDebugVisitor()1127 MockQuicConnectionDebugVisitor::~MockQuicConnectionDebugVisitor() {}
1128 
MockReceivedPacketManager(QuicConnectionStats * stats)1129 MockReceivedPacketManager::MockReceivedPacketManager(QuicConnectionStats* stats)
1130     : QuicReceivedPacketManager(stats) {}
1131 
~MockReceivedPacketManager()1132 MockReceivedPacketManager::~MockReceivedPacketManager() {}
1133 
MockPacketCreatorDelegate()1134 MockPacketCreatorDelegate::MockPacketCreatorDelegate() {}
~MockPacketCreatorDelegate()1135 MockPacketCreatorDelegate::~MockPacketCreatorDelegate() {}
1136 
MockSessionNotifier()1137 MockSessionNotifier::MockSessionNotifier() {}
~MockSessionNotifier()1138 MockSessionNotifier::~MockSessionNotifier() {}
1139 
1140 // static
1141 QuicCryptoClientStream::HandshakerInterface*
GetHandshaker(QuicCryptoClientStream * stream)1142 QuicCryptoClientStreamPeer::GetHandshaker(QuicCryptoClientStream* stream) {
1143   return stream->handshaker_.get();
1144 }
1145 
CreateClientSessionForTest(QuicServerId server_id,QuicTime::Delta connection_start_time,const ParsedQuicVersionVector & supported_versions,MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,QuicCryptoClientConfig * crypto_client_config,PacketSavingConnection ** client_connection,TestQuicSpdyClientSession ** client_session)1146 void CreateClientSessionForTest(
1147     QuicServerId server_id, QuicTime::Delta connection_start_time,
1148     const ParsedQuicVersionVector& supported_versions,
1149     MockQuicConnectionHelper* helper, QuicAlarmFactory* alarm_factory,
1150     QuicCryptoClientConfig* crypto_client_config,
1151     PacketSavingConnection** client_connection,
1152     TestQuicSpdyClientSession** client_session) {
1153   QUICHE_CHECK(crypto_client_config);
1154   QUICHE_CHECK(client_connection);
1155   QUICHE_CHECK(client_session);
1156   QUICHE_CHECK(!connection_start_time.IsZero())
1157       << "Connections must start at non-zero times, otherwise the "
1158       << "strike-register will be unhappy.";
1159 
1160   QuicConfig config = DefaultQuicConfig();
1161   *client_connection = new PacketSavingConnection(
1162       helper, alarm_factory, Perspective::IS_CLIENT, supported_versions);
1163   *client_session = new TestQuicSpdyClientSession(*client_connection, config,
1164                                                   supported_versions, server_id,
1165                                                   crypto_client_config);
1166   (*client_connection)->AdvanceTime(connection_start_time);
1167 }
1168 
CreateServerSessionForTest(QuicServerId,QuicTime::Delta connection_start_time,ParsedQuicVersionVector supported_versions,MockQuicConnectionHelper * helper,QuicAlarmFactory * alarm_factory,QuicCryptoServerConfig * server_crypto_config,QuicCompressedCertsCache * compressed_certs_cache,PacketSavingConnection ** server_connection,TestQuicSpdyServerSession ** server_session)1169 void CreateServerSessionForTest(
1170     QuicServerId /*server_id*/, QuicTime::Delta connection_start_time,
1171     ParsedQuicVersionVector supported_versions,
1172     MockQuicConnectionHelper* helper, QuicAlarmFactory* alarm_factory,
1173     QuicCryptoServerConfig* server_crypto_config,
1174     QuicCompressedCertsCache* compressed_certs_cache,
1175     PacketSavingConnection** server_connection,
1176     TestQuicSpdyServerSession** server_session) {
1177   QUICHE_CHECK(server_crypto_config);
1178   QUICHE_CHECK(server_connection);
1179   QUICHE_CHECK(server_session);
1180   QUICHE_CHECK(!connection_start_time.IsZero())
1181       << "Connections must start at non-zero times, otherwise the "
1182       << "strike-register will be unhappy.";
1183 
1184   *server_connection =
1185       new PacketSavingConnection(helper, alarm_factory, Perspective::IS_SERVER,
1186                                  ParsedVersionOfIndex(supported_versions, 0));
1187   *server_session = new TestQuicSpdyServerSession(
1188       *server_connection, DefaultQuicConfig(), supported_versions,
1189       server_crypto_config, compressed_certs_cache);
1190   (*server_session)->Initialize();
1191 
1192   // We advance the clock initially because the default time is zero and the
1193   // strike register worries that we've just overflowed a uint32_t time.
1194   (*server_connection)->AdvanceTime(connection_start_time);
1195 }
1196 
GetNthClientInitiatedBidirectionalStreamId(QuicTransportVersion version,int n)1197 QuicStreamId GetNthClientInitiatedBidirectionalStreamId(
1198     QuicTransportVersion version, int n) {
1199   int num = n;
1200   if (!VersionUsesHttp3(version)) {
1201     num++;
1202   }
1203   return QuicUtils::GetFirstBidirectionalStreamId(version,
1204                                                   Perspective::IS_CLIENT) +
1205          QuicUtils::StreamIdDelta(version) * num;
1206 }
1207 
GetNthServerInitiatedBidirectionalStreamId(QuicTransportVersion version,int n)1208 QuicStreamId GetNthServerInitiatedBidirectionalStreamId(
1209     QuicTransportVersion version, int n) {
1210   return QuicUtils::GetFirstBidirectionalStreamId(version,
1211                                                   Perspective::IS_SERVER) +
1212          QuicUtils::StreamIdDelta(version) * n;
1213 }
1214 
GetNthServerInitiatedUnidirectionalStreamId(QuicTransportVersion version,int n)1215 QuicStreamId GetNthServerInitiatedUnidirectionalStreamId(
1216     QuicTransportVersion version, int n) {
1217   return QuicUtils::GetFirstUnidirectionalStreamId(version,
1218                                                    Perspective::IS_SERVER) +
1219          QuicUtils::StreamIdDelta(version) * n;
1220 }
1221 
GetNthClientInitiatedUnidirectionalStreamId(QuicTransportVersion version,int n)1222 QuicStreamId GetNthClientInitiatedUnidirectionalStreamId(
1223     QuicTransportVersion version, int n) {
1224   return QuicUtils::GetFirstUnidirectionalStreamId(version,
1225                                                    Perspective::IS_CLIENT) +
1226          QuicUtils::StreamIdDelta(version) * n;
1227 }
1228 
DetermineStreamType(QuicStreamId id,ParsedQuicVersion version,Perspective perspective,bool is_incoming,StreamType default_type)1229 StreamType DetermineStreamType(QuicStreamId id, ParsedQuicVersion version,
1230                                Perspective perspective, bool is_incoming,
1231                                StreamType default_type) {
1232   return version.HasIetfQuicFrames()
1233              ? QuicUtils::GetStreamType(id, perspective, is_incoming, version)
1234              : default_type;
1235 }
1236 
MemSliceFromString(absl::string_view data)1237 quiche::QuicheMemSlice MemSliceFromString(absl::string_view data) {
1238   if (data.empty()) {
1239     return quiche::QuicheMemSlice();
1240   }
1241 
1242   static quiche::SimpleBufferAllocator* allocator =
1243       new quiche::SimpleBufferAllocator();
1244   return quiche::QuicheMemSlice(quiche::QuicheBuffer::Copy(allocator, data));
1245 }
1246 
EncryptPacket(uint64_t,absl::string_view,absl::string_view plaintext,char * output,size_t * output_length,size_t max_output_length)1247 bool TaggingEncrypter::EncryptPacket(uint64_t /*packet_number*/,
1248                                      absl::string_view /*associated_data*/,
1249                                      absl::string_view plaintext, char* output,
1250                                      size_t* output_length,
1251                                      size_t max_output_length) {
1252   const size_t len = plaintext.size() + kTagSize;
1253   if (max_output_length < len) {
1254     return false;
1255   }
1256   // Memmove is safe for inplace encryption.
1257   memmove(output, plaintext.data(), plaintext.size());
1258   output += plaintext.size();
1259   memset(output, tag_, kTagSize);
1260   *output_length = len;
1261   return true;
1262 }
1263 
DecryptPacket(uint64_t,absl::string_view,absl::string_view ciphertext,char * output,size_t * output_length,size_t)1264 bool TaggingDecrypter::DecryptPacket(uint64_t /*packet_number*/,
1265                                      absl::string_view /*associated_data*/,
1266                                      absl::string_view ciphertext, char* output,
1267                                      size_t* output_length,
1268                                      size_t /*max_output_length*/) {
1269   if (ciphertext.size() < kTagSize) {
1270     return false;
1271   }
1272   if (!CheckTag(ciphertext, GetTag(ciphertext))) {
1273     return false;
1274   }
1275   *output_length = ciphertext.size() - kTagSize;
1276   memcpy(output, ciphertext.data(), *output_length);
1277   return true;
1278 }
1279 
CheckTag(absl::string_view ciphertext,uint8_t tag)1280 bool TaggingDecrypter::CheckTag(absl::string_view ciphertext, uint8_t tag) {
1281   for (size_t i = ciphertext.size() - kTagSize; i < ciphertext.size(); i++) {
1282     if (ciphertext.data()[i] != tag) {
1283       return false;
1284     }
1285   }
1286 
1287   return true;
1288 }
1289 
TestPacketWriter(ParsedQuicVersion version,MockClock * clock,Perspective perspective)1290 TestPacketWriter::TestPacketWriter(ParsedQuicVersion version, MockClock* clock,
1291                                    Perspective perspective)
1292     : version_(version),
1293       framer_(SupportedVersions(version_),
1294               QuicUtils::InvertPerspective(perspective)),
1295       clock_(clock) {
1296   QuicFramerPeer::SetLastSerializedServerConnectionId(framer_.framer(),
1297                                                       TestConnectionId());
1298   framer_.framer()->SetInitialObfuscators(TestConnectionId());
1299 
1300   for (int i = 0; i < 128; ++i) {
1301     PacketBuffer* p = new PacketBuffer();
1302     packet_buffer_pool_.push_back(p);
1303     packet_buffer_pool_index_[p->buffer] = p;
1304     packet_buffer_free_list_.push_back(p);
1305   }
1306 }
1307 
~TestPacketWriter()1308 TestPacketWriter::~TestPacketWriter() {
1309   EXPECT_EQ(packet_buffer_pool_.size(), packet_buffer_free_list_.size())
1310       << packet_buffer_pool_.size() - packet_buffer_free_list_.size()
1311       << " out of " << packet_buffer_pool_.size()
1312       << " packet buffers have been leaked.";
1313   for (auto p : packet_buffer_pool_) {
1314     delete p;
1315   }
1316 }
1317 
WritePacket(const char * buffer,size_t buf_len,const QuicIpAddress & self_address,const QuicSocketAddress & peer_address,PerPacketOptions *,const QuicPacketWriterParams & params)1318 WriteResult TestPacketWriter::WritePacket(
1319     const char* buffer, size_t buf_len, const QuicIpAddress& self_address,
1320     const QuicSocketAddress& peer_address, PerPacketOptions* /*options*/,
1321     const QuicPacketWriterParams& params) {
1322   last_write_source_address_ = self_address;
1323   last_write_peer_address_ = peer_address;
1324   // If the buffer is allocated from the pool, return it back to the pool.
1325   // Note the buffer content doesn't change.
1326   if (packet_buffer_pool_index_.find(const_cast<char*>(buffer)) !=
1327       packet_buffer_pool_index_.end()) {
1328     FreePacketBuffer(buffer);
1329   }
1330 
1331   QuicEncryptedPacket packet(buffer, buf_len);
1332   ++packets_write_attempts_;
1333 
1334   if (packet.length() >= sizeof(final_bytes_of_last_packet_)) {
1335     final_bytes_of_previous_packet_ = final_bytes_of_last_packet_;
1336     memcpy(&final_bytes_of_last_packet_, packet.data() + packet.length() - 4,
1337            sizeof(final_bytes_of_last_packet_));
1338   }
1339   if (framer_.framer()->version().KnowsWhichDecrypterToUse()) {
1340     framer_.framer()->InstallDecrypter(ENCRYPTION_HANDSHAKE,
1341                                        std::make_unique<TaggingDecrypter>());
1342     framer_.framer()->InstallDecrypter(ENCRYPTION_ZERO_RTT,
1343                                        std::make_unique<TaggingDecrypter>());
1344     framer_.framer()->InstallDecrypter(ENCRYPTION_FORWARD_SECURE,
1345                                        std::make_unique<TaggingDecrypter>());
1346   } else if (!framer_.framer()->HasDecrypterOfEncryptionLevel(
1347                  ENCRYPTION_FORWARD_SECURE) &&
1348              !framer_.framer()->HasDecrypterOfEncryptionLevel(
1349                  ENCRYPTION_ZERO_RTT)) {
1350     framer_.framer()->SetAlternativeDecrypter(
1351         ENCRYPTION_FORWARD_SECURE,
1352         std::make_unique<StrictTaggingDecrypter>(ENCRYPTION_FORWARD_SECURE),
1353         false);
1354   }
1355   EXPECT_EQ(next_packet_processable_, framer_.ProcessPacket(packet))
1356       << framer_.framer()->detailed_error() << " perspective "
1357       << framer_.framer()->perspective();
1358   next_packet_processable_ = true;
1359   if (block_on_next_write_) {
1360     write_blocked_ = true;
1361     block_on_next_write_ = false;
1362   }
1363   if (next_packet_too_large_) {
1364     next_packet_too_large_ = false;
1365     return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode());
1366   }
1367   if (always_get_packet_too_large_) {
1368     return WriteResult(WRITE_STATUS_ERROR, *MessageTooBigErrorCode());
1369   }
1370   if (IsWriteBlocked()) {
1371     return WriteResult(is_write_blocked_data_buffered_
1372                            ? WRITE_STATUS_BLOCKED_DATA_BUFFERED
1373                            : WRITE_STATUS_BLOCKED,
1374                        0);
1375   }
1376 
1377   if (ShouldWriteFail()) {
1378     return WriteResult(WRITE_STATUS_ERROR, write_error_code_);
1379   }
1380 
1381   last_packet_size_ = packet.length();
1382   total_bytes_written_ += packet.length();
1383   last_packet_header_ = framer_.header();
1384   if (!framer_.connection_close_frames().empty()) {
1385     ++connection_close_packets_;
1386   }
1387   if (!write_pause_time_delta_.IsZero()) {
1388     clock_->AdvanceTime(write_pause_time_delta_);
1389   }
1390   if (is_batch_mode_) {
1391     bytes_buffered_ += last_packet_size_;
1392     return WriteResult(WRITE_STATUS_OK, 0);
1393   }
1394   last_ecn_sent_ = params.ecn_codepoint;
1395   return WriteResult(WRITE_STATUS_OK, last_packet_size_);
1396 }
1397 
GetNextWriteLocation(const QuicIpAddress &,const QuicSocketAddress &)1398 QuicPacketBuffer TestPacketWriter::GetNextWriteLocation(
1399     const QuicIpAddress& /*self_address*/,
1400     const QuicSocketAddress& /*peer_address*/) {
1401   return {AllocPacketBuffer(), [this](const char* p) { FreePacketBuffer(p); }};
1402 }
1403 
Flush()1404 WriteResult TestPacketWriter::Flush() {
1405   flush_attempts_++;
1406   if (block_on_next_flush_) {
1407     block_on_next_flush_ = false;
1408     SetWriteBlocked();
1409     return WriteResult(WRITE_STATUS_BLOCKED, /*errno*/ -1);
1410   }
1411   if (write_should_fail_) {
1412     return WriteResult(WRITE_STATUS_ERROR, /*errno*/ -1);
1413   }
1414   int bytes_flushed = bytes_buffered_;
1415   bytes_buffered_ = 0;
1416   return WriteResult(WRITE_STATUS_OK, bytes_flushed);
1417 }
1418 
AllocPacketBuffer()1419 char* TestPacketWriter::AllocPacketBuffer() {
1420   PacketBuffer* p = packet_buffer_free_list_.front();
1421   EXPECT_FALSE(p->in_use);
1422   p->in_use = true;
1423   packet_buffer_free_list_.pop_front();
1424   return p->buffer;
1425 }
1426 
FreePacketBuffer(const char * buffer)1427 void TestPacketWriter::FreePacketBuffer(const char* buffer) {
1428   auto iter = packet_buffer_pool_index_.find(const_cast<char*>(buffer));
1429   ASSERT_TRUE(iter != packet_buffer_pool_index_.end());
1430   PacketBuffer* p = iter->second;
1431   ASSERT_TRUE(p->in_use);
1432   p->in_use = false;
1433   packet_buffer_free_list_.push_back(p);
1434 }
1435 
WriteServerVersionNegotiationProbeResponse(char * packet_bytes,size_t * packet_length_out,const char * source_connection_id_bytes,uint8_t source_connection_id_length)1436 bool WriteServerVersionNegotiationProbeResponse(
1437     char* packet_bytes, size_t* packet_length_out,
1438     const char* source_connection_id_bytes,
1439     uint8_t source_connection_id_length) {
1440   if (packet_bytes == nullptr) {
1441     QUIC_BUG(quic_bug_10256_1) << "Invalid packet_bytes";
1442     return false;
1443   }
1444   if (packet_length_out == nullptr) {
1445     QUIC_BUG(quic_bug_10256_2) << "Invalid packet_length_out";
1446     return false;
1447   }
1448   QuicConnectionId source_connection_id(source_connection_id_bytes,
1449                                         source_connection_id_length);
1450   std::unique_ptr<QuicEncryptedPacket> encrypted_packet =
1451       QuicFramer::BuildVersionNegotiationPacket(
1452           source_connection_id, EmptyQuicConnectionId(),
1453           /*ietf_quic=*/true, /*use_length_prefix=*/true,
1454           ParsedQuicVersionVector{});
1455   if (!encrypted_packet) {
1456     QUIC_BUG(quic_bug_10256_3) << "Failed to create version negotiation packet";
1457     return false;
1458   }
1459   if (*packet_length_out < encrypted_packet->length()) {
1460     QUIC_BUG(quic_bug_10256_4)
1461         << "Invalid *packet_length_out " << *packet_length_out << " < "
1462         << encrypted_packet->length();
1463     return false;
1464   }
1465   *packet_length_out = encrypted_packet->length();
1466   memcpy(packet_bytes, encrypted_packet->data(), *packet_length_out);
1467   return true;
1468 }
1469 
ParseClientVersionNegotiationProbePacket(const char * packet_bytes,size_t packet_length,char * destination_connection_id_bytes,uint8_t * destination_connection_id_length_out)1470 bool ParseClientVersionNegotiationProbePacket(
1471     const char* packet_bytes, size_t packet_length,
1472     char* destination_connection_id_bytes,
1473     uint8_t* destination_connection_id_length_out) {
1474   if (packet_bytes == nullptr) {
1475     QUIC_BUG(quic_bug_10256_5) << "Invalid packet_bytes";
1476     return false;
1477   }
1478   if (packet_length < kMinPacketSizeForVersionNegotiation ||
1479       packet_length > 65535) {
1480     QUIC_BUG(quic_bug_10256_6) << "Invalid packet_length";
1481     return false;
1482   }
1483   if (destination_connection_id_bytes == nullptr) {
1484     QUIC_BUG(quic_bug_10256_7) << "Invalid destination_connection_id_bytes";
1485     return false;
1486   }
1487   if (destination_connection_id_length_out == nullptr) {
1488     QUIC_BUG(quic_bug_10256_8)
1489         << "Invalid destination_connection_id_length_out";
1490     return false;
1491   }
1492 
1493   QuicEncryptedPacket encrypted_packet(packet_bytes, packet_length);
1494   PacketHeaderFormat format;
1495   QuicLongHeaderType long_packet_type;
1496   bool version_present, has_length_prefix;
1497   QuicVersionLabel version_label;
1498   ParsedQuicVersion parsed_version = ParsedQuicVersion::Unsupported();
1499   QuicConnectionId destination_connection_id, source_connection_id;
1500   std::optional<absl::string_view> retry_token;
1501   std::string detailed_error;
1502   QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher(
1503       encrypted_packet,
1504       /*expected_destination_connection_id_length=*/0, &format,
1505       &long_packet_type, &version_present, &has_length_prefix, &version_label,
1506       &parsed_version, &destination_connection_id, &source_connection_id,
1507       &retry_token, &detailed_error);
1508   if (error != QUIC_NO_ERROR) {
1509     QUIC_BUG(quic_bug_10256_9) << "Failed to parse packet: " << detailed_error;
1510     return false;
1511   }
1512   if (!version_present) {
1513     QUIC_BUG(quic_bug_10256_10) << "Packet is not a long header";
1514     return false;
1515   }
1516   if (*destination_connection_id_length_out <
1517       destination_connection_id.length()) {
1518     QUIC_BUG(quic_bug_10256_11)
1519         << "destination_connection_id_length_out too small";
1520     return false;
1521   }
1522   *destination_connection_id_length_out = destination_connection_id.length();
1523   memcpy(destination_connection_id_bytes, destination_connection_id.data(),
1524          *destination_connection_id_length_out);
1525   return true;
1526 }
1527 
1528 }  // namespace test
1529 }  // namespace quic
1530