#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
15"""Proxy for transfer integration testing.
16
17This module contains a proxy for transfer intergation testing.  It is capable
18of introducing various link failures into the connection between the client and
19server.
20"""

import abc
import argparse
import asyncio
from enum import Enum
import logging
import random
import socket
import sys
import time
from typing import Awaitable, Callable, Iterable, NamedTuple

from google.protobuf import text_format

from pw_rpc.internal import packet_pb2
from pw_transfer import transfer_pb2
from pw_transfer.integration_test import config_pb2
from pw_hdlc import decode
from pw_transfer.chunk import Chunk
_LOG = logging.getLogger('pw_transfer_integration_test_proxy')

# This is the maximum size of the socket receive buffers. Ideally, this is set
# to the lowest allowed value to minimize buffering between the proxy and
# clients so rate limiting causes the client to block and wait for the
# integration test proxy to drain rather than allowing OS buffers to backlog
# large quantities of data.
#
# Note that the OS may choose to not strictly follow this requested buffer
# size. Still, setting this value to be relatively small does reduce buffer
# sizes significantly enough to better reflect typical inter-device
# communication.
#
# For this to be effective, clients should also configure their sockets to a
# smaller send buffer size.
_RECEIVE_BUFFER_SIZE = 2048


class EventType(Enum):
    TRANSFER_START = 1
    PARAMETERS_RETRANSMIT = 2
    PARAMETERS_CONTINUE = 3
    START_ACK_CONFIRMATION = 4


class Event(NamedTuple):
    type: EventType
    chunk: Chunk


class Filter(abc.ABC):
    """An abstract interface for manipulating a stream of data.

    ``Filter``s are used to implement various transforms to simulate real
    world link properties.  Some examples include: data corruption,
    packet loss, packet reordering, rate limiting, latency modeling.

    A ``Filter`` implementation should implement the ``process`` method
    and call ``self.send_data()`` when it has data to send.
    """

    def __init__(self, send_data: Callable[[bytes], Awaitable[None]]):
        self.send_data = send_data

    @abc.abstractmethod
    async def process(self, data: bytes) -> None:
        """Processes incoming data.

        Implementations of this method may send arbitrary data, or none, using
        the ``self.send_data()`` handler.
        """

    async def __call__(self, data: bytes) -> None:
        await self.process(data)


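# Illustrative sketch (not part of the proxy's filter set): a minimal
# ``Filter`` subclass showing the expected pattern of overriding ``process()``
# and emitting output through ``self.send_data()``.
class _PassthroughExample(Filter):
    async def process(self, data: bytes) -> None:
        _LOG.debug('example filter forwarding %d bytes', len(data))
        await self.send_data(data)

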
class HdlcPacketizer(Filter):
    """A filter which aggregates data into complete HDLC packets.

    Since the proxy transport (SOCK_STREAM) has no framing and we want some
    filters to operate on whole frames, this filter can be used so that
    downstream filters see whole frames.
    """

    def __init__(self, send_data: Callable[[bytes], Awaitable[None]]):
        super().__init__(send_data)
        self.decoder = decode.FrameDecoder()

    async def process(self, data: bytes) -> None:
        for frame in self.decoder.process(data):
            await self.send_data(frame.raw_encoded)


class DataDropper(Filter):
    """A filter which drops some data.

    DataDropper will drop data passed through ``process()`` at the
    specified ``rate``.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        rate: float,
        seed: int | None = None,
    ):
        super().__init__(send_data)
        self._rate = rate
        self._name = name
        if seed is None:
            seed = time.time_ns()
        self._rng = random.Random(seed)
        _LOG.info(f'{name} DataDropper initialized with seed {seed}')

    async def process(self, data: bytes) -> None:
        if self._rng.uniform(0.0, 1.0) < self._rate:
            _LOG.info(f'{self._name} dropped {len(data)} bytes of data')
        else:
            await self.send_data(data)


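# Illustrative sketch (never called by the proxy): filters compose by passing
# the downstream stage, any ``Callable[[bytes], Awaitable[None]]``, as
# ``send_data``. Here ``sink`` stands in for the next filter, and a fixed
# seed makes the drop pattern reproducible across runs.
async def _data_dropper_example() -> None:
    received: list[bytes] = []

    async def sink(data: bytes) -> None:
        received.append(data)

    # Drops each write with probability 0.5, deterministically for this seed.
    dropper = DataDropper(sink, 'example', rate=0.5, seed=1234)
    await dropper.process(b'hello')

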
class KeepDropQueue(Filter):
    """A filter which alternates between sending packets and dropping packets.

    A KeepDropQueue filter will alternate between keeping packets and dropping
    chunks of data based on a keep/drop queue provided during its creation. The
    queue is looped over unless a negative element is found. A negative number
    is effectively the same as a value of infinity.

    This filter is typically most practical when used with a packetizer so data
    can be dropped as distinct packets.

    Examples:

      keep_drop_queue = [3, 2]:
        Keeps 3 packets,
        Drops 2 packets,
        Keeps 3 packets,
        Drops 2 packets,
        ... [loops indefinitely]

      keep_drop_queue = [5, 99, 1, -1]:
        Keeps 5 packets,
        Drops 99 packets,
        Keeps 1 packet,
        Drops all further packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        keep_drop_queue: Iterable[int],
        only_consider_transfer_chunks: bool = False,
    ):
        super().__init__(send_data)
        self._keep_drop_queue = list(keep_drop_queue)
        self._loop_idx = 0
        self._current_count = self._keep_drop_queue[0]
        self._keep = True
        self._name = name
        self._only_consider_transfer_chunks = only_consider_transfer_chunks

    async def process(self, data: bytes) -> None:
        if self._only_consider_transfer_chunks:
            try:
                _extract_transfer_chunk(data)
            except Exception:
                await self.send_data(data)
                return

        # Move forward through the queue if needed.
        while self._current_count == 0:
            self._loop_idx += 1
            self._current_count = self._keep_drop_queue[
                self._loop_idx % len(self._keep_drop_queue)
            ]
            self._keep = not self._keep

        if self._current_count > 0:
            self._current_count -= 1

        if self._keep:
            await self.send_data(data)
            _LOG.info(f'{self._name} forwarded {len(data)} bytes of data')
        else:
            _LOG.info(f'{self._name} dropped {len(data)} bytes of data')


class RateLimiter(Filter):
    """A filter which limits transmission rate.

    This filter delays transmission of data by len(data)/rate.
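
    For example, with ``rate=14400`` (bytes per second, a rough sketch of
    115200-baud serial), a 1440-byte write is delayed by 0.1 seconds.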
214    """
215
216    def __init__(
217        self, send_data: Callable[[bytes], Awaitable[None]], rate: float
218    ):
219        super().__init__(send_data)
220        self._rate = rate
221
222    async def process(self, data: bytes) -> None:
223        delay = len(data) / self._rate
224        await asyncio.sleep(delay)
225        await self.send_data(data)
226
227
228class DataTransposer(Filter):
229    """A filter which occasionally transposes two chunks of data.
230
231    This filter transposes data at the specified rate.  It does this by
232    holding a chunk to transpose until another chunk arrives. The filter
233    will not hold a chunk longer than ``timeout`` seconds.
234    """
235
236    def __init__(
237        self,
238        send_data: Callable[[bytes], Awaitable[None]],
239        name: str,
240        rate: float,
241        timeout: float,
242        seed: int,
243    ):
244        super().__init__(send_data)
245        self._name = name
246        self._rate = rate
247        self._timeout = timeout
248        self._data_queue = asyncio.Queue()
249        self._rng = random.Random(seed)
250        self._transpose_task = asyncio.create_task(self._transpose_handler())
251
252        _LOG.info(f'{name} DataTranspose initialized with seed {seed}')
253
254    def __del__(self):
255        _LOG.info(f'{self._name} cleaning up transpose task.')
256        self._transpose_task.cancel()

    async def _transpose_handler(self):
        """Async task that handles the packet transposition and timeouts."""
        held_data: bytes | None = None
        while True:
            # Only use a timeout if we have data held for transposition.
            timeout = None if held_data is None else self._timeout
            try:
                data = await asyncio.wait_for(
                    self._data_queue.get(), timeout=timeout
                )

                if held_data is not None:
                    # If we have held data, send it out of order.
                    await self.send_data(data)
                    await self.send_data(held_data)
                    held_data = None
                else:
                    # Otherwise decide if we should transpose the current data.
                    if self._rng.uniform(0.0, 1.0) < self._rate:
                        _LOG.info(
                            f'{self._name} transposing {len(data)} bytes of data'
                        )
                        held_data = data
                    else:
                        await self.send_data(data)

            except asyncio.TimeoutError:
                _LOG.info(f'{self._name} sending data in order due to timeout')
                await self.send_data(held_data)
                held_data = None

    async def process(self, data: bytes) -> None:
        # Queue data for processing by the transpose task.
        await self._data_queue.put(data)


class ServerFailure(Filter):
    """A filter to simulate the server stopping sending packets.

    ServerFailure takes a list of numbers of packets to send before
    dropping all subsequent packets until a TRANSFER_START packet
    is seen.  This process is repeated for each element in
    packets_before_failure.  After that list is exhausted, ServerFailure
    will send all packets.

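    For example (an illustrative value), ``packets_before_failure_list=[3, 5]``
    sends 3 packets, drops everything until the next TRANSFER_START, sends 5
    packets, drops again until the following TRANSFER_START, and thereafter
    sends all packets.
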
    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that EventFilter can decode complete packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        packets_before_failure_list: list[int],
        start_immediately: bool = False,
        only_consider_transfer_chunks: bool = False,
    ):
        super().__init__(send_data)
        self._name = name
        self._relay_packets = True
        self._packets_before_failure_list = packets_before_failure_list
        self._packets_before_failure = None
        self._only_consider_transfer_chunks = only_consider_transfer_chunks
        if start_immediately:
            self.advance_packets_before_failure()

    def advance_packets_before_failure(self):
        if len(self._packets_before_failure_list) > 0:
            self._packets_before_failure = (
                self._packets_before_failure_list.pop(0)
            )
        else:
            self._packets_before_failure = None

    async def process(self, data: bytes) -> None:
        if self._only_consider_transfer_chunks:
            try:
                _extract_transfer_chunk(data)
            except Exception:
                await self.send_data(data)
                return

        if self._packets_before_failure is None:
            await self.send_data(data)
        elif self._packets_before_failure > 0:
            self._packets_before_failure -= 1
            await self.send_data(data)

    def handle_event(self, event: Event) -> None:
        if event.type is EventType.TRANSFER_START:
            self.advance_packets_before_failure()


class WindowPacketDropper(Filter):
    """A filter to allow the same packet in each window to be dropped.

    WindowPacketDropper will drop the nth packet in each window as
    specified by window_packet_to_drop.  This process will happen
    indefinitely for each window.

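    For example, ``window_packet_to_drop=0`` (an illustrative value) drops the
    first data chunk of every window.
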
    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that EventFilter can decode complete packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        window_packet_to_drop: int,
    ):
        super().__init__(send_data)
        self._name = name
        self._relay_packets = True
        self._window_packet_to_drop = window_packet_to_drop
        self._next_window_start_offset: int | None = 0
        self._window_packet = 0

    async def process(self, data: bytes) -> None:
        data_chunk = None
        try:
            chunk = _extract_transfer_chunk(data)
            if chunk.type is Chunk.Type.DATA:
                data_chunk = chunk
        except Exception:
            # Invalid / non-chunk data (e.g. text logs); ignore.
            pass

        # Only count transfer data chunks as part of a window.
        if data_chunk is not None:
            if data_chunk.offset == self._next_window_start_offset:
                # If a new window has been requested, wait until the first
                # chunk matching its requested offset to begin counting window
                # chunks. Any in-flight chunks from the previous window are
                # allowed through.
                self._window_packet = 0
                self._next_window_start_offset = None

            if self._window_packet != self._window_packet_to_drop:
                await self.send_data(data)

            self._window_packet += 1
        else:
            await self.send_data(data)

    def handle_event(self, event: Event) -> None:
        if event.type in (
            EventType.PARAMETERS_RETRANSMIT,
            EventType.PARAMETERS_CONTINUE,
            EventType.START_ACK_CONFIRMATION,
        ):
            # A new transmission window has been requested, starting at the
            # offset specified in the chunk. The receiver may already have data
            # from the previous window in-flight, so don't immediately reset
            # the window packet counter.
            self._next_window_start_offset = event.chunk.offset


class EventFilter(Filter):
    """A filter that inspects packets and sends events to other filters.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that it can decode complete packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        event_queue: asyncio.Queue,
    ):
        super().__init__(send_data)
        self._name = name
        self._queue = event_queue

    async def process(self, data: bytes) -> None:
        try:
            chunk = _extract_transfer_chunk(data)
            if chunk.type is Chunk.Type.START:
                await self._queue.put(Event(EventType.TRANSFER_START, chunk))
            elif chunk.type is Chunk.Type.START_ACK_CONFIRMATION:
                await self._queue.put(
                    Event(EventType.START_ACK_CONFIRMATION, chunk)
                )
            elif chunk.type is Chunk.Type.PARAMETERS_RETRANSMIT:
                await self._queue.put(
                    Event(EventType.PARAMETERS_RETRANSMIT, chunk)
                )
            elif chunk.type is Chunk.Type.PARAMETERS_CONTINUE:
                await self._queue.put(
                    Event(EventType.PARAMETERS_CONTINUE, chunk)
                )
        except Exception:
            # Silently ignore invalid packets.
            pass

        await self.send_data(data)


def _extract_transfer_chunk(data: bytes) -> Chunk:
    """Gets a transfer Chunk from an HDLC frame containing an RPC packet.

    Raises an exception if a valid chunk does not exist.
    """

    decoder = decode.FrameDecoder()
    for frame in decoder.process(data):
        packet = packet_pb2.RpcPacket()
        packet.ParseFromString(frame.data)

        if packet.payload:
            raw_chunk = transfer_pb2.Chunk()
            raw_chunk.ParseFromString(packet.payload)
            return Chunk.from_message(raw_chunk)

        # The incoming data is expected to be HDLC-packetized, so only one
        # frame should exist.
        break

    raise ValueError("Invalid transfer chunk frame")


async def _handle_simplex_events(
    event_queue: asyncio.Queue, handlers: list[Callable[[Event], None]]
):
    while True:
        event = await event_queue.get()
        for handler in handlers:
            handler(event)


async def _handle_simplex_connection(
    name: str,
    filter_stack_config: list[config_pb2.FilterConfig],
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    inbound_event_queue: asyncio.Queue,
    outbound_event_queue: asyncio.Queue,
) -> None:
    """Handle a single direction of a bidirectional connection between
    server and client."""

    async def send(data: bytes):
        writer.write(data)
        await writer.drain()

    filter_stack = EventFilter(send, name, outbound_event_queue)

    event_handlers: list[Callable[[Event], None]] = []

    # Build the filter stack from the bottom up.
    for config in reversed(filter_stack_config):
        filter_name = config.WhichOneof("filter")
        if filter_name == "hdlc_packetizer":
            filter_stack = HdlcPacketizer(filter_stack)
        elif filter_name == "data_dropper":
            data_dropper = config.data_dropper
            filter_stack = DataDropper(
                filter_stack, name, data_dropper.rate, data_dropper.seed
            )
        elif filter_name == "rate_limiter":
            filter_stack = RateLimiter(filter_stack, config.rate_limiter.rate)
        elif filter_name == "data_transposer":
            transposer = config.data_transposer
            filter_stack = DataTransposer(
                filter_stack,
                name,
                transposer.rate,
                transposer.timeout,
                transposer.seed,
            )
        elif filter_name == "server_failure":
            server_failure = config.server_failure
            filter_stack = ServerFailure(
                filter_stack,
                name,
                server_failure.packets_before_failure,
                server_failure.start_immediately,
                server_failure.only_consider_transfer_chunks,
            )
            event_handlers.append(filter_stack.handle_event)
        elif filter_name == "keep_drop_queue":
            keep_drop_queue = config.keep_drop_queue
            filter_stack = KeepDropQueue(
                filter_stack,
                name,
                keep_drop_queue.keep_drop_queue,
                keep_drop_queue.only_consider_transfer_chunks,
            )
        elif filter_name == "window_packet_dropper":
            window_packet_dropper = config.window_packet_dropper
            filter_stack = WindowPacketDropper(
                filter_stack, name, window_packet_dropper.window_packet_to_drop
            )
            event_handlers.append(filter_stack.handle_event)
        else:
            sys.exit(f'Unknown filter {filter_name}')

    event_task = asyncio.create_task(
        _handle_simplex_events(inbound_event_queue, event_handlers)
    )

    while True:
        # Arbitrarily chosen "page sized" read.
        data = await reader.read(4096)

        # Empty data indicates that the connection is closed.
        if not data:
            _LOG.info(f'{name} connection closed.')
            return

        await filter_stack.process(data)


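# Illustrative sketch of a client-side filter stack as it would appear in the
# textproto ProxyConfig read from stdin. The field names mirror the accesses
# above; the rate and seed are arbitrary example values:
#
#   client_filter_stack: { hdlc_packetizer: {} }
#   client_filter_stack: { data_dropper: { rate: 0.01, seed: 1234 } }
#
# The first entry receives the raw socket data, so an hdlc_packetizer is
# typically listed first so that downstream filters operate on whole frames.

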
async def _handle_connection(
    server_port: int,
    config: config_pb2.ProxyConfig,
    client_reader: asyncio.StreamReader,
    client_writer: asyncio.StreamWriter,
) -> None:
    """Handle a connection between server and client."""

    client_addr = client_writer.get_extra_info('peername')
    _LOG.info(f'New client connection from {client_addr}')

    # Open a new connection to the server for each client connection.
    #
    # TODO(konkers): catch exception and close client writer
    server_reader, server_writer = await asyncio.open_connection(
        'localhost', server_port
    )
    _LOG.info('New connection opened to server')

    # Queues for the simplex connections to pass events to each other.
    server_event_queue = asyncio.Queue()
    client_event_queue = asyncio.Queue()

    # Instantiate two simplex handlers, one for each direction of the
    # connection.
    _, pending = await asyncio.wait(
        [
            asyncio.create_task(
                _handle_simplex_connection(
                    "client",
                    config.client_filter_stack,
                    client_reader,
                    server_writer,
                    server_event_queue,
                    client_event_queue,
                )
            ),
            asyncio.create_task(
                _handle_simplex_connection(
                    "server",
                    config.server_filter_stack,
                    server_reader,
                    client_writer,
                    client_event_queue,
                    server_event_queue,
                )
            ),
        ],
        return_when=asyncio.FIRST_COMPLETED,
    )

    # When one side terminates the connection, also terminate the other side.
    for task in pending:
        task.cancel()

    for stream in [client_writer, server_writer]:
        stream.close()


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        '--server-port',
        type=int,
        required=True,
        help='Port of the integration test server.  The proxy will forward '
        'connections to this port.',
    )
    parser.add_argument(
        '--client-port',
        type=int,
        required=True,
        help='Port on which to listen for connections from the integration '
        'test client.',
    )

    return parser.parse_args()


def _init_logging(level: int) -> None:
    _LOG.setLevel(logging.DEBUG)
    log_to_stderr = logging.StreamHandler()
    log_to_stderr.setLevel(level)
    log_to_stderr.setFormatter(
        logging.Formatter(
            fmt='%(asctime)s.%(msecs)03d-%(levelname)s: %(message)s',
            datefmt='%H:%M:%S',
        )
    )

    _LOG.addHandler(log_to_stderr)


async def _main(server_port: int, client_port: int) -> None:
    _init_logging(logging.DEBUG)

    # Load config from stdin using synchronous IO.
    text_config = sys.stdin.buffer.read()

    config = text_format.Parse(text_config, config_pb2.ProxyConfig())

    # Instantiate the TCP server.
    server_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    server_socket.setsockopt(
        socket.SOL_SOCKET, socket.SO_RCVBUF, _RECEIVE_BUFFER_SIZE
    )
    server_socket.bind(('', client_port))
    server = await asyncio.start_server(
        lambda reader, writer: _handle_connection(
            server_port, config, reader, writer
        ),
        limit=_RECEIVE_BUFFER_SIZE,
        sock=server_socket,
    )

    addrs = ', '.join(str(sock.getsockname()) for sock in server.sockets)
    _LOG.info(f'Listening for client connection on {addrs}')

    # Run the TCP server.
    async with server:
        await server.serve_forever()


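# Illustrative invocation (port numbers are arbitrary examples); the proxy
# reads its textproto ProxyConfig from stdin:
#
#   python proxy.py --server-port 3300 --client-port 3301 < proxy.textproto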
if __name__ == '__main__':
    asyncio.run(_main(**vars(_parse_args())))