xref: /aosp_15_r20/external/pigweed/pw_transfer/integration_test/proxy_test.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1#!/usr/bin/env python3
2# Copyright 2022 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Unit test for proxy.py"""
16
17import abc
18import asyncio
19import unittest
20
21from pw_rpc.internal import packet_pb2
22from pw_hdlc import encode
23from pw_transfer.chunk import Chunk, ProtocolVersion
24
25import proxy
26
27
class MockRng(abc.ABC):
    """Deterministic stand-in for a transposer's ``_rng``.

    Only ``uniform()`` is provided. Each call consumes one canned value and
    scales it into the requested range.
    """

    def __init__(self, results: list[float]):
        """Creates the mock.

        Args:
            results: Canned values for ``uniform()``, presumably fractions in
                [0, 1] so results land inside the requested range. They are
                consumed from the END of the list (last element first). The
                list is copied, so the caller's list is never modified.
        """
        # Copy so that popping below does not mutate the caller's list.
        self._results = list(results)

    def uniform(self, from_val: float, to_val: float) -> float:
        """Returns the next canned value scaled into [from_val, to_val].

        Raises:
            IndexError: If every canned value has already been consumed.
        """
        fraction = self._results.pop()
        return from_val + fraction * (to_val - from_val)
37        return val
38
39
class ProxyTest(unittest.IsolatedAsyncioTestCase):
    """Tests for the packet filters defined in proxy.py."""

    @staticmethod
    def _make_sender(sent_packets: list[bytes], on_send=None):
        """Returns an async callback that records every packet it is sent.

        The filters under test are constructed with this callback as their
        downstream send function, which they await.

        Args:
            sent_packets: List that each forwarded packet is appended to.
            on_send: Optional zero-argument callable run after each packet
                is recorded, e.g. ``asyncio.Event.set``.
        """

        async def send(data: bytes) -> None:
            sent_packets.append(data)
            if on_send is not None:
                on_send()

        return send

    @staticmethod
    def _data_frame(data: bytes, **chunk_fields) -> bytes:
        """Returns an RPC-framed VERSION_TWO DATA chunk carrying ``data``.

        Extra keyword arguments (e.g. ``session_id``, ``offset``) are passed
        through to the ``Chunk`` constructor.
        """
        return _encode_rpc_frame(
            Chunk(
                ProtocolVersion.VERSION_TWO,
                Chunk.Type.DATA,
                data=data,
                **chunk_fields,
            )
        )

    async def test_transposer_simple(self):
        sent_packets: list[bytes] = []
        new_packets_event = asyncio.Event()

        transposer = proxy.DataTransposer(
            self._make_sender(sent_packets, on_send=new_packets_event.set),
            name="test",
            rate=0.5,
            timeout=100,
            seed=1234567890,
        )
        # MockRng pops from the end, so the first uniform() call gets 0.4.
        transposer._rng = MockRng([0.6, 0.4])
        await transposer.process(b'aaaaaaaaaa')
        await transposer.process(b'bbbbbbbbbb')

        expected_packets = [b'bbbbbbbbbb', b'aaaaaaaaaa']
        # Wait until at least the expected number of packets was sent. Using
        # "<" means an unexpected extra packet fails the assertion below
        # instead of hanging the test.
        while len(sent_packets) < len(expected_packets):
            try:
                # Generous timeout, in case the transposer sends later.
                await asyncio.wait_for(new_packets_event.wait(), timeout=60.0)
            except asyncio.TimeoutError:
                # asyncio.TimeoutError is only an alias of the builtin
                # TimeoutError from Python 3.11 on, so catch it explicitly.
                self.fail(
                    f'Timeout waiting for data.  Packets sent: {sent_packets}'
                )
            # Re-arm the event; otherwise wait() returns immediately from now
            # on and this loop busy-spins.
            new_packets_event.clear()

        self.assertEqual(sent_packets, expected_packets)

    async def test_transposer_timeout(self):
        sent_packets: list[bytes] = []

        transposer = proxy.DataTransposer(
            self._make_sender(sent_packets),
            name="test",
            rate=0.5,
            timeout=0.100,
            seed=1234567890,
        )
        transposer._rng = MockRng([0.4, 0.6])
        await transposer.process(b'aaaaaaaaaa')

        # Even though this should be transposed, there is no following data so
        # the transposer should time out and send this in order.
        await transposer.process(b'bbbbbbbbbb')

        # Give the transposer time to time out.
        await asyncio.sleep(0.5)

        self.assertEqual(sent_packets, [b'aaaaaaaaaa', b'bbbbbbbbbb'])

    async def test_server_failure(self):
        sent_packets: list[bytes] = []
        packets_before_failure = [1, 2, 3]
        server_failure = proxy.ServerFailure(
            self._make_sender(sent_packets),
            name="test",
            packets_before_failure_list=packets_before_failure.copy(),
            start_immediately=True,
        )

        # The filter was given a copy of the list, so this extra entry only
        # adds a final round to the test's expectations: once the filter's
        # three configured failures are used up, all five packets should get
        # through.
        packets_before_failure.append(5)

        packets = [b'1', b'2', b'3', b'4', b'5']

        for num_packets in packets_before_failure:
            sent_packets.clear()
            for packet in packets:
                await server_failure.process(packet)
            self.assertEqual(len(sent_packets), num_packets)
            # Start a new transfer before the next round.
            server_failure.handle_event(
                proxy.Event(
                    proxy.EventType.TRANSFER_START,
                    Chunk(ProtocolVersion.VERSION_TWO, Chunk.Type.START),
                )
            )

    async def test_server_failure_transfer_chunks_only(self):
        sent_packets: list[bytes] = []
        packets_before_failure = [2]
        server_failure = proxy.ServerFailure(
            self._make_sender(sent_packets),
            name="test",
            packets_before_failure_list=packets_before_failure.copy(),
            start_immediately=True,
            only_consider_transfer_chunks=True,
        )

        transfer_chunk = self._data_frame(b'1')

        packets = [
            b'1',
            b'2',
            transfer_chunk,  # 1
            b'3',
            transfer_chunk,  # 2
            b'4',
            b'5',
            transfer_chunk,  # Transfer chunks should be dropped starting here.
            transfer_chunk,
            b'6',
            b'7',
            transfer_chunk,
        ]

        for packet in packets:
            await server_failure.process(packet)

        # Only transfer chunks count towards the failure; non-transfer packets
        # keep flowing after it.
        expected_result = [
            b'1',
            b'2',
            transfer_chunk,
            b'3',
            transfer_chunk,
            b'4',
            b'5',
            b'6',
            b'7',
        ]
        self.assertEqual(sent_packets, expected_result)

    async def test_keep_drop_queue_loop(self):
        sent_packets: list[bytes] = []

        keep_drop_queue = proxy.KeepDropQueue(
            self._make_sender(sent_packets),
            name="test",
            keep_drop_queue=[2, 1, 3],
        )

        input_packets = [f'{i}'.encode() for i in range(1, 10)]

        # Keep 1-2, drop 3, keep 4-6. The expectation implies the queue then
        # repeats with keep/drop roles swapped (odd length): drop 7-8, keep 9.
        expected_sequence = [b'1', b'2', b'4', b'5', b'6', b'9']

        for packet in input_packets:
            await keep_drop_queue.process(packet)
        self.assertEqual(sent_packets, expected_sequence)

    async def test_keep_drop_queue(self):
        sent_packets: list[bytes] = []

        keep_drop_queue = proxy.KeepDropQueue(
            self._make_sender(sent_packets),
            name="test",
            keep_drop_queue=[2, 1, 1, -1],
        )

        input_packets = [f'{i}'.encode() for i in range(1, 10)]

        # Keep 1-2, drop 3, keep 4; the expectation shows that the trailing -1
        # drops every remaining packet.
        expected_sequence = [b'1', b'2', b'4']

        for packet in input_packets:
            await keep_drop_queue.process(packet)
        self.assertEqual(sent_packets, expected_sequence)

    async def test_keep_drop_queue_transfer_chunks_only(self):
        sent_packets: list[bytes] = []

        keep_drop_queue = proxy.KeepDropQueue(
            self._make_sender(sent_packets),
            name="test",
            keep_drop_queue=[2, 1, 1, -1],
            only_consider_transfer_chunks=True,
        )

        transfer_chunk = self._data_frame(b'1')

        expected_sequence = [
            b'1',
            transfer_chunk,
            b'2',
            transfer_chunk,
            b'3',
            b'4',
            b'5',
            b'6',
            b'7',
            transfer_chunk,
            b'8',
            b'9',
            b'10',
        ]
        # Only transfer chunks are subject to the keep/drop queue; every other
        # packet passes through.
        input_packets = [
            b'1',
            transfer_chunk,  # keep
            b'2',
            transfer_chunk,  # keep
            b'3',
            b'4',
            b'5',
            transfer_chunk,  # drop
            b'6',
            b'7',
            transfer_chunk,  # keep
            transfer_chunk,  # drop
            b'8',
            transfer_chunk,  # drop
            b'9',
            transfer_chunk,  # drop
            transfer_chunk,  # drop
            b'10',
        ]

        for packet in input_packets:
            await keep_drop_queue.process(packet)
        self.assertEqual(sent_packets, expected_sequence)

    async def test_window_packet_dropper(self):
        sent_packets: list[bytes] = []

        window_packet_dropper = proxy.WindowPacketDropper(
            self._make_sender(sent_packets),
            name="test",
            window_packet_to_drop=0,
        )

        packets = [
            self._data_frame(data, session_id=1)
            for data in (b'1', b'2', b'3', b'4', b'5')
        ]

        # The first packet (index 0) of every window is dropped.
        expected_packets = packets[1:]

        # Test each event twice to ensure the filter does not have issues
        # on new window boundaries.
        window_events = [
            (
                proxy.EventType.PARAMETERS_RETRANSMIT,
                Chunk.Type.PARAMETERS_RETRANSMIT,
            ),
            (
                proxy.EventType.PARAMETERS_CONTINUE,
                Chunk.Type.PARAMETERS_CONTINUE,
            ),
        ]
        events = [
            proxy.Event(
                event_type, Chunk(ProtocolVersion.VERSION_TWO, chunk_type)
            )
            for event_type, chunk_type in window_events * 2
        ]

        for event in events:
            sent_packets.clear()
            for packet in packets:
                await window_packet_dropper.process(packet)
            self.assertEqual(sent_packets, expected_packets)
            window_packet_dropper.handle_event(event)

    async def test_window_packet_dropper_extra_in_flight_packets(self):
        sent_packets: list[bytes] = []

        window_packet_dropper = proxy.WindowPacketDropper(
            self._make_sender(sent_packets),
            name="test",
            window_packet_to_drop=1,
        )

        packets = [
            self._data_frame(data, offset=offset)
            for data, offset in [
                (b'1', 0),
                (b'2', 1),
                (b'3', 2),
                (b'2', 1),  # Resent after the retransmit below.
                (b'3', 2),  # Resent after the retransmit below.
                (b'4', 3),
            ]
        ]

        # events[i] is delivered right after packets[i] is processed. The
        # retransmit at offset 1 follows b'2', while the original b'3' is
        # still in flight.
        events = [
            None,
            proxy.Event(
                proxy.EventType.PARAMETERS_RETRANSMIT,
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.PARAMETERS_RETRANSMIT,
                    offset=1,
                ),
            ),
            None,
            None,
            None,
            None,
        ]

        for packet, event in zip(packets, events):
            await window_packet_dropper.process(packet)
            if event is not None:
                window_packet_dropper.handle_event(event)

        # The first window loses its packet 1 (b'2'). The in-flight b'3' from
        # the old window passes, and the expectation implies it is not counted
        # in the new window, whose packet 1 (the resent b'3') is dropped.
        expected_packets = [packets[0], packets[2], packets[3], packets[5]]
        self.assertEqual(sent_packets, expected_packets)

    async def test_event_filter(self):
        sent_packets: list[bytes] = []

        queue: asyncio.Queue = asyncio.Queue()

        event_filter = proxy.EventFilter(
            self._make_sender(sent_packets),
            name="test",
            event_queue=queue,
        )

        # A non-transfer RPC packet; it must not generate any event.
        request = packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.REQUEST,
            channel_id=101,
            service_id=1001,
            method_id=100001,
        ).SerializeToString()

        packets = [
            request,
            _encode_rpc_frame(
                Chunk(
                    ProtocolVersion.VERSION_TWO, Chunk.Type.START, session_id=1
                )
            ),
            self._data_frame(b'3', session_id=1),
            self._data_frame(b'3', session_id=1),
            request,
            _encode_rpc_frame(
                Chunk(
                    ProtocolVersion.VERSION_TWO, Chunk.Type.START, session_id=2
                )
            ),
            self._data_frame(b'4', session_id=2),
            self._data_frame(b'5', session_id=2),
        ]

        expected_events = [
            None,  # request
            proxy.EventType.TRANSFER_START,
            None,  # data chunk
            None,  # data chunk
            None,  # request
            proxy.EventType.TRANSFER_START,
            None,  # data chunk
            None,  # data chunk
        ]

        for packet, expected_event_type in zip(packets, expected_events):
            await event_filter.process(packet)
            try:
                event_type = queue.get_nowait().type
            except asyncio.QueueEmpty:
                event_type = None
            self.assertEqual(event_type, expected_event_type)
597            self.assertEqual(event_type, expected_event_type)
598
599
def _encode_rpc_frame(chunk: Chunk) -> bytes:
    """Wraps a transfer chunk as the payload of an HDLC-framed RPC packet.

    The chunk is serialized into a SERVER_STREAM RpcPacket on channel 101
    (service 1001, method 100001). That packet is then encoded as an HDLC UI
    frame addressed to 73.

    Args:
        chunk: The transfer chunk to embed.

    Returns:
        The encoded HDLC frame bytes.
    """
    chunk_bytes = chunk.to_message().SerializeToString()
    rpc_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.SERVER_STREAM,
        channel_id=101,
        service_id=1001,
        method_id=100001,
        payload=chunk_bytes,
    )
    return encode.ui_frame(73, rpc_packet.SerializeToString())
609
610
if __name__ == '__main__':
    # Run every test case defined in this module when executed directly.
    unittest.main(module='__main__')
613