#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Proxy for transfer integration testing.

This module contains a proxy for transfer integration testing. It is capable
of introducing various link failures into the connection between the client and
server.
"""

import abc
import argparse
import asyncio
from enum import Enum
import logging
import random
import socket
import sys
import time
from typing import Awaitable, Callable, Iterable, NamedTuple

from google.protobuf import text_format

from pw_rpc.internal import packet_pb2
from pw_transfer import transfer_pb2
from pw_transfer.integration_test import config_pb2
from pw_hdlc import decode
from pw_transfer.chunk import Chunk

# NOTE: the logger name is kept byte-for-byte (including its historical
# spelling) because external logging configuration may refer to it by name.
_LOG = logging.getLogger('pw_transfer_intergration_test_proxy')

# This is the maximum size of the socket receive buffers. Ideally, this is set
# to the lowest allowed value to minimize buffering between the proxy and
# clients so rate limiting causes the client to block and wait for the
# integration test proxy to drain rather than allowing OS buffers to backlog
# large quantities of data.
#
# Note that the OS may choose to not strictly follow this requested buffer
# size. Still, setting this value to be relatively small does reduce buffer
# sizes significantly enough to better reflect typical inter-device
# communication.
#
# For this to be effective, clients should also configure their sockets to a
# smaller send buffer size.
_RECEIVE_BUFFER_SIZE = 2048


class EventType(Enum):
    """Kinds of events that ``EventFilter`` reports to other filters."""

    TRANSFER_START = 1
    PARAMETERS_RETRANSMIT = 2
    PARAMETERS_CONTINUE = 3
    START_ACK_CONFIRMATION = 4


class Event(NamedTuple):
    """An event type paired with the transfer chunk that triggered it."""

    type: EventType
    chunk: Chunk


class Filter(abc.ABC):
    """An abstract interface for manipulating a stream of data.

    ``Filter``s are used to implement various transforms to simulate real
    world link properties. Some examples include: data corruption,
    packet loss, packet reordering, rate limiting, latency modeling.

    A ``Filter`` implementation should implement the ``process`` method
    and call ``self.send_data()`` when it has data to send.
    """

    def __init__(self, send_data: Callable[[bytes], Awaitable[None]]):
        """Stores the downstream sink.

        Args:
          send_data: Awaitable callable that receives this filter's output.
            Another ``Filter`` can be used here because filters are callable.
        """
        self.send_data = send_data

    @abc.abstractmethod
    async def process(self, data: bytes) -> None:
        """Processes incoming data.

        Implementations of this method may send arbitrary data, or none, using
        the ``self.send_data()`` handler.
        """

    async def __call__(self, data: bytes) -> None:
        """Forwards to ``process`` so a filter can act as a ``send_data``."""
        await self.process(data)


class HdlcPacketizer(Filter):
    """A filter which aggregates data into complete HDLC packets.

    Since the proxy transport (SOCK_STREAM) has no framing and we want some
    filters to operate on whole frames, this filter can be used so that
    downstream filters see whole frames.
    """

    def __init__(self, send_data: Callable[[bytes], Awaitable[None]]):
        super().__init__(send_data)
        # The decoder persists across process() calls so it can be fed
        # successive reads from the stream.
        self.decoder = decode.FrameDecoder()

    async def process(self, data: bytes) -> None:
        """Feeds ``data`` to the decoder and forwards each frame it yields."""
        for frame in self.decoder.process(data):
            await self.send_data(frame.raw_encoded)


class DataDropper(Filter):
    """A filter which drops some data.

    DataDropper will drop data passed through ``process()`` at the
    specified ``rate``.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        rate: float,
        seed: int | None = None,
    ):
        """Initializes the dropper.

        Args:
          send_data: Downstream sink for data that is not dropped.
          name: Label used in log messages.
          rate: Probability in [0.0, 1.0] that a call to ``process`` drops
            its data.
          seed: Seed for the random number generator. When ``None``, the
            current time in nanoseconds is used; the seed is logged either
            way.
        """
        super().__init__(send_data)
        self._rate = rate
        self._name = name
        if seed is None:
            seed = time.time_ns()
        self._rng = random.Random(seed)
        _LOG.info(f'{name} DataDropper initialized with seed {seed}')

    async def process(self, data: bytes) -> None:
        """Drops ``data`` with probability ``rate``, otherwise forwards it."""
        if self._rng.uniform(0.0, 1.0) < self._rate:
            _LOG.info(f'{self._name} dropped {len(data)} bytes of data')
        else:
            await self.send_data(data)


class KeepDropQueue(Filter):
    """A filter which alternates between sending packets and dropping packets.

    A KeepDropQueue filter will alternate between keeping packets and dropping
    chunks of data based on a keep/drop queue provided during its creation. The
    queue is looped over unless a negative element is found. A negative number
    is effectively the same as a value of infinity.

    This filter is typically most practical when used with a packetizer so data
    can be dropped as distinct packets.

    Examples:

      keep_drop_queue = [3, 2]:
        Keeps 3 packets,
        Drops 2 packets,
        Keeps 3 packets,
        Drops 2 packets,
        ... [loops indefinitely]

      keep_drop_queue = [5, 99, 1, -1]:
        Keeps 5 packets,
        Drops 99 packets,
        Keeps 1 packet,
        Drops all further packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        keep_drop_queue: Iterable[int],
        only_consider_transfer_chunks: bool = False,
    ):
        """Initializes the filter.

        Args:
          send_data: Downstream sink for data that is kept.
          name: Label used in log messages.
          keep_drop_queue: Alternating keep/drop counts, starting with keep.
          only_consider_transfer_chunks: When true, data that does not decode
            as a transfer chunk is forwarded without being counted.

        Raises:
          ValueError: ``keep_drop_queue`` is empty or contains only zeros.
        """
        super().__init__(send_data)
        self._keep_drop_queue = list(keep_drop_queue)
        if not self._keep_drop_queue:
            raise ValueError('keep_drop_queue must not be empty')
        if not any(self._keep_drop_queue):
            # With only zeros, the queue-advancing loop in process() would
            # never find a usable entry and would spin forever.
            raise ValueError(
                'keep_drop_queue must contain at least one non-zero entry'
            )
        self._loop_idx = 0
        self._current_count = self._keep_drop_queue[0]
        self._keep = True
        self._name = name
        self._only_consider_transfer_chunks = only_consider_transfer_chunks

    async def process(self, data: bytes) -> None:
        """Keeps or drops ``data`` according to the current queue entry."""
        if self._only_consider_transfer_chunks:
            try:
                _extract_transfer_chunk(data)
            except Exception:
                # Not a transfer chunk: pass it through without counting it.
                await self.send_data(data)
                return

        # Move forward through the queue if needed.
        while self._current_count == 0:
            self._loop_idx += 1
            self._current_count = self._keep_drop_queue[
                self._loop_idx % len(self._keep_drop_queue)
            ]
            self._keep = not self._keep

        # Negative counts are never decremented, so they act as "forever".
        if self._current_count > 0:
            self._current_count -= 1

        if self._keep:
            await self.send_data(data)
            _LOG.info(f'{self._name} forwarded {len(data)} bytes of data')
        else:
            _LOG.info(f'{self._name} dropped {len(data)} bytes of data')


class RateLimiter(Filter):
    """A filter which limits transmission rate.

    This filter delays transmission of data by len(data)/rate.
    """

    def __init__(
        self, send_data: Callable[[bytes], Awaitable[None]], rate: float
    ):
        """Initializes the limiter.

        Args:
          send_data: Downstream sink.
          rate: Bytes per second; each write is delayed by ``len(data)/rate``
            seconds. Must be non-zero, otherwise ``process`` raises
            ``ZeroDivisionError``.
        """
        super().__init__(send_data)
        self._rate = rate

    async def process(self, data: bytes) -> None:
        """Sleeps for ``len(data)/rate`` seconds, then forwards ``data``."""
        delay = len(data) / self._rate
        await asyncio.sleep(delay)
        await self.send_data(data)


class DataTransposer(Filter):
    """A filter which occasionally transposes two chunks of data.

    This filter transposes data at the specified rate. It does this by
    holding a chunk to transpose until another chunk arrives. The filter
    will not hold a chunk longer than ``timeout`` seconds.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        rate: float,
        timeout: float,
        seed: int,
    ):
        """Initializes the transposer and starts its background task.

        Must be called while an event loop is running, because it creates an
        asyncio task.

        Args:
          send_data: Downstream sink.
          name: Label used in log messages.
          rate: Probability in [0.0, 1.0] that a chunk is held back so that
            it is sent after the next chunk.
          timeout: Maximum time, in seconds, a chunk is held.
          seed: Seed for the random number generator.
        """
        super().__init__(send_data)
        self._name = name
        self._rate = rate
        self._timeout = timeout
        self._data_queue: asyncio.Queue = asyncio.Queue()
        self._rng = random.Random(seed)
        self._transpose_task = asyncio.create_task(self._transpose_handler())

        _LOG.info(f'{name} DataTranspose initialized with seed {seed}')

    def __del__(self):
        # __del__ also runs for partially constructed objects, in which case
        # the task attribute may not exist yet.
        task = getattr(self, '_transpose_task', None)
        if task is None:
            return
        _LOG.info(f'{self._name} cleaning up transpose task.')
        task.cancel()

    async def _transpose_handler(self):
        """Async task that handles the packet transposition and timeouts"""
        held_data: bytes | None = None
        while True:
            # Only use timeout if we have data held for transposition
            timeout = None if held_data is None else self._timeout

            # Only the queue wait is guarded: a TimeoutError raised by the
            # downstream send_data() must not be mistaken for a hold timeout.
            try:
                data = await asyncio.wait_for(
                    self._data_queue.get(), timeout=timeout
                )
            except asyncio.TimeoutError:
                # A timeout is only armed while a chunk is being held.
                _LOG.info(f'{self._name} sending data in order due to timeout')
                await self.send_data(held_data)
                held_data = None
                continue

            if held_data is not None:
                # If we have held data, send it out of order.
                await self.send_data(data)
                await self.send_data(held_data)
                held_data = None
            elif self._rng.uniform(0.0, 1.0) < self._rate:
                # Otherwise decide if we should transpose the current data.
                _LOG.info(
                    f'{self._name} transposing {len(data)} bytes of data'
                )
                held_data = data
            else:
                await self.send_data(data)

    async def process(self, data: bytes) -> None:
        """Queues ``data`` for the background transpose task."""
        # Queue data for processing by the transpose task.
        await self._data_queue.put(data)


class ServerFailure(Filter):
    """A filter to simulate the server stopping sending packets.

    ServerFailure takes a list of numbers of packets to send before
    dropping all subsequent packets until a TRANSFER_START packet
    is seen. This process is repeated for each element in
    packets_before_failure. After that list is exhausted, ServerFailure
    will send all packets.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that EventFilter can decode complete packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        packets_before_failure_list: list[int],
        start_immediately: bool = False,
        only_consider_transfer_chunks: bool = False,
    ):
        """Initializes the filter.

        Args:
          send_data: Downstream sink.
          name: Label for this filter instance.
          packets_before_failure_list: Packet budgets, one consumed per
            TRANSFER_START event. The sequence is copied, not mutated.
          start_immediately: When true, the first budget is applied right
            away instead of waiting for the first TRANSFER_START event.
          only_consider_transfer_chunks: When true, data that does not decode
            as a transfer chunk is forwarded without being counted.
        """
        super().__init__(send_data)
        self._name = name
        self._relay_packets = True
        # Take a private copy: entries are popped as transfers start, and the
        # caller's sequence (a config field shared by every connection in this
        # file) must not be consumed by a single connection.
        self._packets_before_failure_list = list(packets_before_failure_list)
        self._packets_before_failure: int | None = None
        self._only_consider_transfer_chunks = only_consider_transfer_chunks
        if start_immediately:
            self.advance_packets_before_failure()

    def advance_packets_before_failure(self):
        """Loads the next packet budget, or ``None`` (no limit) if exhausted."""
        if len(self._packets_before_failure_list) > 0:
            self._packets_before_failure = (
                self._packets_before_failure_list.pop(0)
            )
        else:
            self._packets_before_failure = None

    async def process(self, data: bytes) -> None:
        """Forwards ``data`` while budget remains; otherwise drops it."""
        if self._only_consider_transfer_chunks:
            try:
                _extract_transfer_chunk(data)
            except Exception:
                # Not a transfer chunk: pass it through without counting it.
                await self.send_data(data)
                return

        if self._packets_before_failure is None:
            await self.send_data(data)
        elif self._packets_before_failure > 0:
            self._packets_before_failure -= 1
            await self.send_data(data)
        # A budget of zero means the simulated failure is active: drop.

    def handle_event(self, event: Event) -> None:
        """Moves to the next packet budget when a transfer starts."""
        if event.type is EventType.TRANSFER_START:
            self.advance_packets_before_failure()


class WindowPacketDropper(Filter):
    """A filter to allow the same packet in each window to be dropped.

    WindowPacketDropper will drop the nth packet in each window as
    specified by window_packet_to_drop. This process will happen
    indefinitely for each window.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that EventFilter can decode complete packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        window_packet_to_drop: int,
    ):
        """Initializes the filter.

        Args:
          send_data: Downstream sink.
          name: Label for this filter instance.
          window_packet_to_drop: Zero-based index, within each window, of the
            data chunk to drop.
        """
        super().__init__(send_data)
        self._name = name
        self._relay_packets = True
        self._window_packet_to_drop = window_packet_to_drop
        # Offset at which the next window begins; None once that window has
        # started.
        self._next_window_start_offset: int | None = 0
        self._window_packet = 0

    async def process(self, data: bytes) -> None:
        """Forwards ``data`` unless it is the selected chunk of its window."""
        data_chunk = None
        try:
            chunk = _extract_transfer_chunk(data)
            if chunk.type is Chunk.Type.DATA:
                data_chunk = chunk
        except Exception:
            # Invalid / non-chunk data (e.g. text logs); ignore.
            pass

        # Only count transfer data chunks as part of a window.
        if data_chunk is not None:
            if data_chunk.offset == self._next_window_start_offset:
                # If a new window has been requested, wait until the first
                # chunk matching its requested offset to begin counting window
                # chunks. Any in-flight chunks from the previous window are
                # allowed through.
                self._window_packet = 0
                self._next_window_start_offset = None

            if self._window_packet != self._window_packet_to_drop:
                await self.send_data(data)

            self._window_packet += 1
        else:
            await self.send_data(data)

    def handle_event(self, event: Event) -> None:
        """Records the start offset of a newly requested window."""
        if event.type in (
            EventType.PARAMETERS_RETRANSMIT,
            EventType.PARAMETERS_CONTINUE,
            EventType.START_ACK_CONFIRMATION,
        ):
            # A new transmission window has been requested, starting at the
            # offset specified in the chunk. The receiver may already have data
            # from the previous window in-flight, so don't immediately reset
            # the window packet counter.
            self._next_window_start_offset = event.chunk.offset


class EventFilter(Filter):
    """A filter that inspects packets and sends events to other filters.

    This filter should be instantiated in the same filter stack as an
    HdlcPacketizer so that it can decode complete packets.
    """

    def __init__(
        self,
        send_data: Callable[[bytes], Awaitable[None]],
        name: str,
        event_queue: asyncio.Queue,
    ):
        """Initializes the filter.

        Args:
          send_data: Downstream sink; all data is forwarded unmodified.
          name: Label for this filter instance.
          event_queue: Queue onto which detected ``Event``s are put.
        """
        super().__init__(send_data)
        self._name = name
        self._queue = event_queue

    async def process(self, data: bytes) -> None:
        """Emits an event for recognized chunk types, then forwards ``data``."""
        try:
            chunk = _extract_transfer_chunk(data)
        except Exception:
            # Silently ignore invalid packets. ``Exception`` (rather than a
            # bare ``except``) keeps task cancellation and interpreter exit
            # propagating.
            chunk = None

        if chunk is not None:
            event_type = None
            if chunk.type is Chunk.Type.START:
                event_type = EventType.TRANSFER_START
            elif chunk.type is Chunk.Type.START_ACK_CONFIRMATION:
                event_type = EventType.START_ACK_CONFIRMATION
            elif chunk.type is Chunk.Type.PARAMETERS_RETRANSMIT:
                event_type = EventType.PARAMETERS_RETRANSMIT
            elif chunk.type is Chunk.Type.PARAMETERS_CONTINUE:
                event_type = EventType.PARAMETERS_CONTINUE

            if event_type is not None:
                await self._queue.put(Event(event_type, chunk))

        await self.send_data(data)


def _extract_transfer_chunk(data: bytes) -> Chunk:
    """Gets a transfer Chunk from an HDLC frame containing an RPC packet.

    Raises an exception if a valid chunk does not exist.
    """

    decoder = decode.FrameDecoder()
    for frame in decoder.process(data):
        packet = packet_pb2.RpcPacket()
        packet.ParseFromString(frame.data)

        if packet.payload:
            raw_chunk = transfer_pb2.Chunk()
            raw_chunk.ParseFromString(packet.payload)
            return Chunk.from_message(raw_chunk)

        # The incoming data is expected to be HDLC-packetized, so only one
        # frame should exist.
        break

    raise ValueError("Invalid transfer chunk frame")


async def _handle_simplex_events(
    event_queue: asyncio.Queue, handlers: list[Callable[[Event], None]]
):
    """Dispatches every event from ``event_queue`` to all ``handlers``.

    Runs until cancelled.
    """
    while True:
        event = await event_queue.get()
        for handler in handlers:
            handler(event)


async def _handle_simplex_connection(
    name: str,
    filter_stack_config: list[config_pb2.FilterConfig],
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    inbound_event_queue: asyncio.Queue,
    outbound_event_queue: asyncio.Queue,
) -> None:
    """Handle a single direction of a bidirectional connection between
    server and client.

    Args:
      name: Label for this direction, used in log messages.
      filter_stack_config: Filters to apply, listed from the reader side to
        the writer side.
      reader: Stream that data is read from.
      writer: Stream that filtered data is written to.
      inbound_event_queue: Events delivered to this direction's filters.
      outbound_event_queue: Events emitted by this direction's EventFilter.
    """

    async def send(data: bytes):
        writer.write(data)
        await writer.drain()

    filter_stack: Filter = EventFilter(send, name, outbound_event_queue)

    event_handlers: list[Callable[[Event], None]] = []

    # Build the filter stack from the bottom up
    for config in reversed(filter_stack_config):
        filter_name = config.WhichOneof("filter")
        if filter_name == "hdlc_packetizer":
            filter_stack = HdlcPacketizer(filter_stack)
        elif filter_name == "data_dropper":
            data_dropper = config.data_dropper
            filter_stack = DataDropper(
                filter_stack, name, data_dropper.rate, data_dropper.seed
            )
        elif filter_name == "rate_limiter":
            filter_stack = RateLimiter(filter_stack, config.rate_limiter.rate)
        elif filter_name == "data_transposer":
            transposer = config.data_transposer
            filter_stack = DataTransposer(
                filter_stack,
                name,
                transposer.rate,
                transposer.timeout,
                transposer.seed,
            )
        elif filter_name == "server_failure":
            server_failure = config.server_failure
            filter_stack = ServerFailure(
                filter_stack,
                name,
                server_failure.packets_before_failure,
                server_failure.start_immediately,
                server_failure.only_consider_transfer_chunks,
            )
            event_handlers.append(filter_stack.handle_event)
        elif filter_name == "keep_drop_queue":
            keep_drop_queue = config.keep_drop_queue
            filter_stack = KeepDropQueue(
                filter_stack,
                name,
                keep_drop_queue.keep_drop_queue,
                keep_drop_queue.only_consider_transfer_chunks,
            )
        elif filter_name == "window_packet_dropper":
            window_packet_dropper = config.window_packet_dropper
            filter_stack = WindowPacketDropper(
                filter_stack, name, window_packet_dropper.window_packet_to_drop
            )
            event_handlers.append(filter_stack.handle_event)
        else:
            sys.exit(f'Unknown filter {filter_name}')

    event_task = asyncio.create_task(
        _handle_simplex_events(inbound_event_queue, event_handlers)
    )

    try:
        while True:
            # Arbitrarily chosen "page sized" read.
            data = await reader.read(4096)

            # An empty data indicates that the connection is closed.
            if not data:
                _LOG.info(f'{name} connection closed.')
                return

            await filter_stack.process(data)
    finally:
        # The event dispatcher loops forever; stop it when this direction
        # ends (normally, by error, or by cancellation) so it is not leaked.
        event_task.cancel()


async def _handle_connection(
    server_port: int,
    config: config_pb2.ProxyConfig,
    client_reader: asyncio.StreamReader,
    client_writer: asyncio.StreamWriter,
) -> None:
    """Handle a connection between server and client.

    Args:
      server_port: Port on localhost of the server to connect to.
      config: Proxy configuration holding the per-direction filter stacks.
      client_reader: Stream for data arriving from the client.
      client_writer: Stream for data sent to the client.
    """

    client_addr = client_writer.get_extra_info('peername')
    _LOG.info(f'New client connection from {client_addr}')

    # Open a new connection to the server for each client connection.
    try:
        server_reader, server_writer = await asyncio.open_connection(
            'localhost', server_port
        )
    except OSError as err:
        # Without a server there is nothing to proxy to; release the client
        # socket instead of leaving it open.
        _LOG.error(
            f'Failed to connect to server on port {server_port}: {err}'
        )
        client_writer.close()
        return
    _LOG.info('New connection opened to server')

    # Queues for the simplex connections to pass events to each other.
    server_event_queue: asyncio.Queue = asyncio.Queue()
    client_event_queue: asyncio.Queue = asyncio.Queue()

    # Instantiate two simplex handlers, one for each direction of the
    # connection.
    tasks = [
        asyncio.create_task(
            _handle_simplex_connection(
                "client",
                config.client_filter_stack,
                client_reader,
                server_writer,
                server_event_queue,
                client_event_queue,
            )
        ),
        asyncio.create_task(
            _handle_simplex_connection(
                "server",
                config.server_filter_stack,
                server_reader,
                client_writer,
                client_event_queue,
                server_event_queue,
            )
        ),
    ]

    try:
        done, _ = await asyncio.wait(
            tasks, return_when=asyncio.FIRST_COMPLETED
        )

        # Surface failures of the finished direction(s) rather than leaving
        # them as never-retrieved task exceptions.
        for task in done:
            if not task.cancelled() and task.exception() is not None:
                _LOG.error(
                    'Connection handler failed', exc_info=task.exception()
                )
    finally:
        # When one side terminates the connection, also terminate the other
        # side. Cancelling an already finished task is a no-op.
        for task in tasks:
            task.cancel()

        for stream in [client_writer, server_writer]:
            stream.close()


def _parse_args() -> argparse.Namespace:
    """Parses the command line (``--server-port`` and ``--client-port``)."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        '--server-port',
        type=int,
        required=True,
        help='Port of the integration test server.  The proxy will forward connections to this port',
    )
    parser.add_argument(
        '--client-port',
        type=int,
        required=True,
        help='Port on which to listen for connections from integration test client.',
    )

    return parser.parse_args()


def _init_logging(level: int) -> None:
    """Attaches a stderr handler at ``level`` to this module's logger."""
    _LOG.setLevel(logging.DEBUG)
    log_to_stderr = logging.StreamHandler()
    log_to_stderr.setLevel(level)
    log_to_stderr.setFormatter(
        logging.Formatter(
            fmt='%(asctime)s.%(msecs)03d-%(levelname)s: %(message)s',
            datefmt='%H:%M:%S',
        )
    )

    _LOG.addHandler(log_to_stderr)


async def _main(server_port: int, client_port: int) -> None:
    """Reads the proxy config from stdin and serves client connections.

    Args:
      server_port: Port of the server that connections are forwarded to.
      client_port: Port on which to listen for client connections.
    """
    _init_logging(logging.DEBUG)

    # Load config from stdin using synchronous IO
    text_config = sys.stdin.buffer.read()

    config = text_format.Parse(text_config, config_pb2.ProxyConfig())

    # Instantiate the TCP server.
    server_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    server_socket.setsockopt(
        socket.SOL_SOCKET, socket.SO_RCVBUF, _RECEIVE_BUFFER_SIZE
    )
    server_socket.bind(('', client_port))
    server = await asyncio.start_server(
        lambda reader, writer: _handle_connection(
            server_port, config, reader, writer
        ),
        limit=_RECEIVE_BUFFER_SIZE,
        sock=server_socket,
    )

    addrs = ', '.join(str(sock.getsockname()) for sock in server.sockets)
    _LOG.info(f'Listening for client connection on {addrs}')

    # Run the TCP server.
    async with server:
        await server.serve_forever()


if __name__ == '__main__':
    asyncio.run(_main(**vars(_parse_args())))