1#!/usr/bin/env python3 2# Copyright 2022 The Pigweed Authors 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); you may not 5# use this file except in compliance with the License. You may obtain a copy of 6# the License at 7# 8# https://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 13# License for the specific language governing permissions and limitations under 14# the License. 15"""Test fixture for pw_transfer integration tests.""" 16 17from __future__ import annotations 18 19import argparse 20import asyncio 21from dataclasses import dataclass 22import logging 23import pathlib 24from pathlib import Path 25import sys 26import tempfile 27from typing import BinaryIO, Iterable, NamedTuple 28import unittest 29 30from google.protobuf import text_format 31 32from pw_protobuf_protos import status_pb2 33from pw_transfer.integration_test import config_pb2 34from python.runfiles import runfiles 35 36_LOG = logging.getLogger('pw_transfer_intergration_test_proxy') 37_LOG.level = logging.DEBUG 38_LOG.addHandler(logging.StreamHandler(sys.stdout)) 39 40 41class LogMonitor: 42 """Monitors lines read from the reader, and logs them.""" 43 44 class Error(Exception): 45 """Raised if wait_for_line reaches EOF before expected line.""" 46 47 pass 48 49 def __init__(self, prefix: str, reader: asyncio.StreamReader): 50 """Initializer. 51 52 Args: 53 prefix: Prepended to read lines before they are logged. 54 reader: StreamReader to read lines from. 55 """ 56 self._prefix = prefix 57 self._reader = reader 58 59 # Queue of messages waiting to be monitored. 60 self._queue = asyncio.Queue() 61 # Relog any messages read from the reader, and enqueue them for 62 # monitoring. 63 self._relog_and_enqueue_task = asyncio.create_task( 64 self._relog_and_enqueue() 65 ) 66 67 async def wait_for_line(self, msg: str): 68 """Wait for a line containing msg to be read from the reader.""" 69 while True: 70 line = await self._queue.get() 71 if not line: 72 raise LogMonitor.Error( 73 f"Reached EOF before getting line matching {msg}" 74 ) 75 if msg in line.decode(): 76 return 77 78 async def wait_for_eof(self): 79 """Wait for the reader to reach EOF, relogging any lines read.""" 80 # Drain the queue, since we're not monitoring it any more. 81 drain_queue = asyncio.create_task(self._drain_queue()) 82 await asyncio.gather(drain_queue, self._relog_and_enqueue_task) 83 84 async def _relog_and_enqueue(self): 85 """Reads lines from the reader, logs them, and puts them in queue.""" 86 while True: 87 line = await self._reader.readline() 88 await self._queue.put(line) 89 if line: 90 _LOG.info(f"{self._prefix} {line.decode().rstrip()}") 91 else: 92 # EOF. Note, we still put the EOF in the queue, so that the 93 # queue reader can process it appropriately. 94 return 95 96 async def _drain_queue(self): 97 while True: 98 line = await self._queue.get() 99 if not line: 100 # EOF. 101 return 102 103 104class MonitoredSubprocess: 105 """A subprocess with monitored asynchronous communication.""" 106 107 @staticmethod 108 async def create(cmd: list[str], prefix: str, stdinput: bytes): 109 """Starts the subprocess and writes stdinput to stdin. 110 111 This method returns once stdinput has been written to stdin. The 112 MonitoredSubprocess continues to log the process's stderr and stdout 113 (with the prefix) until it terminates. 114 115 Args: 116 cmd: Command line to execute. 117 prefix: Prepended to process logs. 118 stdinput: Written to stdin on process startup. 119 """ 120 self = MonitoredSubprocess() 121 self._process = await asyncio.create_subprocess_exec( 122 *cmd, 123 stdin=asyncio.subprocess.PIPE, 124 stdout=asyncio.subprocess.PIPE, 125 stderr=asyncio.subprocess.PIPE, 126 ) 127 128 self._stderr_monitor = LogMonitor( 129 f"{prefix} ERR:", self._process.stderr 130 ) 131 self._stdout_monitor = LogMonitor( 132 f"{prefix} OUT:", self._process.stdout 133 ) 134 135 self._process.stdin.write(stdinput) 136 await self._process.stdin.drain() 137 self._process.stdin.close() 138 await self._process.stdin.wait_closed() 139 return self 140 141 async def wait_for_line(self, stream: str, msg: str, timeout: float): 142 """Wait for a line containing msg to be read on the stream.""" 143 if stream == "stdout": 144 monitor = self._stdout_monitor 145 elif stream == "stderr": 146 monitor = self._stderr_monitor 147 else: 148 raise ValueError( 149 "Stream must be 'stdout' or 'stderr', got {stream}" 150 ) 151 152 await asyncio.wait_for(monitor.wait_for_line(msg), timeout) 153 154 def returncode(self): 155 return self._process.returncode 156 157 def terminate(self): 158 """Terminate the process.""" 159 self._process.terminate() 160 161 async def wait_for_termination(self, timeout: float | None): 162 """Wait for the process to terminate.""" 163 await asyncio.wait_for( 164 asyncio.gather( 165 self._process.wait(), 166 self._stdout_monitor.wait_for_eof(), 167 self._stderr_monitor.wait_for_eof(), 168 ), 169 timeout, 170 ) 171 172 async def terminate_and_wait(self, timeout: float): 173 """Terminate the process and wait for it to exit.""" 174 if self.returncode() is not None: 175 # Process already terminated 176 return 177 self.terminate() 178 await self.wait_for_termination(timeout) 179 180 181class TransferConfig(NamedTuple): 182 """A simple tuple to collect configs for test binaries.""" 183 184 server: config_pb2.ServerConfig 185 client: config_pb2.ClientConfig 186 proxy: config_pb2.ProxyConfig 187 188 189class TransferIntegrationTestHarness: 190 """A class to manage transfer integration tests""" 191 192 # Prefix for log messages coming from the harness (as opposed to the server, 193 # client, or proxy processes). Padded so that the length is the same as 194 # "SERVER OUT:". 195 _PREFIX = "HARNESS: " 196 197 @dataclass 198 class Config: 199 server_port: int = 3300 200 client_port: int = 3301 201 java_client_binary: Path | None = None 202 cpp_client_binary: Path | None = None 203 python_client_binary: Path | None = None 204 proxy_binary: Path | None = None 205 server_binary: Path | None = None 206 207 class TransferExitCodes(NamedTuple): 208 client: int 209 server: int 210 211 def __init__(self, harness_config: Config) -> None: 212 # TODO(tpudlik): This is Bazel-only. Support gn, too. 213 r = runfiles.Create() 214 215 # Set defaults. 216 self._JAVA_CLIENT_BINARY = r.Rlocation( 217 "pigweed/pw_transfer/integration_test/java_client" 218 ) 219 self._CPP_CLIENT_BINARY = r.Rlocation( 220 "pigweed/pw_transfer/integration_test/cpp_client" 221 ) 222 self._PYTHON_CLIENT_BINARY = r.Rlocation( 223 "pigweed/pw_transfer/integration_test/python_client" 224 ) 225 self._PROXY_BINARY = r.Rlocation( 226 "pigweed/pw_transfer/integration_test/proxy" 227 ) 228 self._SERVER_BINARY = r.Rlocation( 229 "pigweed/pw_transfer/integration_test/server" 230 ) 231 232 # Server/client ports are non-optional, so use those. 233 self._CLIENT_PORT = harness_config.client_port 234 self._SERVER_PORT = harness_config.server_port 235 236 self._server: MonitoredSubprocess | None = None 237 self._client: MonitoredSubprocess | None = None 238 self._proxy: MonitoredSubprocess | None = None 239 240 # If the harness configuration specifies overrides, use those. 241 if harness_config.java_client_binary is not None: 242 self._JAVA_CLIENT_BINARY = harness_config.java_client_binary 243 if harness_config.cpp_client_binary is not None: 244 self._CPP_CLIENT_BINARY = harness_config.cpp_client_binary 245 if harness_config.python_client_binary is not None: 246 self._PYTHON_CLIENT_BINARY = harness_config.python_client_binary 247 if harness_config.proxy_binary is not None: 248 self._PROXY_BINARY = harness_config.proxy_binary 249 if harness_config.server_binary is not None: 250 self._SERVER_BINARY = harness_config.server_binary 251 252 self._CLIENT_BINARY = { 253 "cpp": self._CPP_CLIENT_BINARY, 254 "java": self._JAVA_CLIENT_BINARY, 255 "python": self._PYTHON_CLIENT_BINARY, 256 } 257 258 async def _start_client( 259 self, client_type: str, config: config_pb2.ClientConfig 260 ): 261 _LOG.info(f"{self._PREFIX} Starting client with config\n{config}") 262 self._client = await MonitoredSubprocess.create( 263 [self._CLIENT_BINARY[client_type], str(self._CLIENT_PORT)], 264 "CLIENT", 265 str(config).encode('ascii'), 266 ) 267 268 async def _start_server(self, config: config_pb2.ServerConfig): 269 _LOG.info(f"{self._PREFIX} Starting server with config\n{config}") 270 self._server = await MonitoredSubprocess.create( 271 [self._SERVER_BINARY, str(self._SERVER_PORT)], 272 "SERVER", 273 str(config).encode('ascii'), 274 ) 275 276 async def _start_proxy(self, config: config_pb2.ProxyConfig): 277 _LOG.info(f"{self._PREFIX} Starting proxy with config\n{config}") 278 self._proxy = await MonitoredSubprocess.create( 279 [ 280 self._PROXY_BINARY, 281 "--server-port", 282 str(self._SERVER_PORT), 283 "--client-port", 284 str(self._CLIENT_PORT), 285 ], 286 # Extra space in "PROXY " so that it lines up with "SERVER". 287 "PROXY ", 288 str(config).encode('ascii'), 289 ) 290 291 async def perform_transfers( 292 self, 293 server_config: config_pb2.ServerConfig, 294 client_type: str, 295 client_config: config_pb2.ClientConfig, 296 proxy_config: config_pb2.ProxyConfig, 297 ) -> TransferExitCodes: 298 """Performs a pw_transfer write. 299 300 Args: 301 server_config: Server configuration. 302 client_type: Either "cpp", "java", or "python". 303 client_config: Client configuration. 304 proxy_config: Proxy configuration. 305 306 Returns: 307 Exit code of the client and server as a tuple. 308 """ 309 # Timeout for components (server, proxy) to come up or shut down after 310 # write is finished or a signal is sent. Approximately arbitrary. Should 311 # not be too long so that we catch bugs in the server that prevent it 312 # from shutting down. 313 TIMEOUT = 5 # seconds 314 315 try: 316 await self._start_proxy(proxy_config) 317 assert self._proxy is not None 318 await self._proxy.wait_for_line( 319 "stderr", "Listening for client connection", TIMEOUT 320 ) 321 322 await self._start_server(server_config) 323 assert self._server is not None 324 await self._server.wait_for_line( 325 "stderr", "Starting pw_rpc server on port", TIMEOUT 326 ) 327 328 await self._start_client(client_type, client_config) 329 assert self._client is not None 330 # No timeout: the client will only exit once the transfer 331 # completes, and this can take a long time for large payloads. 332 await self._client.wait_for_termination(None) 333 334 # Wait for the server to exit. 335 await self._server.wait_for_termination(TIMEOUT) 336 337 finally: 338 # Stop the server, if still running. (Only expected if the 339 # wait_for above timed out.) 340 if self._server is not None: 341 await self._server.terminate_and_wait(TIMEOUT) 342 # Stop the proxy. Unlike the server, we expect it to still be 343 # running at this stage. 344 if self._proxy is not None: 345 await self._proxy.terminate_and_wait(TIMEOUT) 346 347 return self.TransferExitCodes( 348 self._client.returncode(), self._server.returncode() 349 ) 350 351 352class BasicTransfer(NamedTuple): 353 id: int 354 type: config_pb2.TransferAction.TransferType.ValueType 355 data: bytes 356 357 358class TransferIntegrationTest(unittest.TestCase): 359 """A base class for transfer integration tests. 360 361 This significantly reduces the boiler plate required for building 362 integration test cases for pw_transfer. This class does not include any 363 tests itself, but instead bundles together much of the boiler plate required 364 for making an integration test for pw_transfer using this test fixture. 365 """ 366 367 HARNESS_CONFIG = TransferIntegrationTestHarness.Config() 368 369 @classmethod 370 def setUpClass(cls): 371 cls.harness = TransferIntegrationTestHarness(cls.HARNESS_CONFIG) 372 373 @staticmethod 374 def default_server_config() -> config_pb2.ServerConfig: 375 return config_pb2.ServerConfig( 376 chunk_size_bytes=216, 377 pending_bytes=64 * 1024, 378 chunk_timeout_seconds=5, 379 transfer_service_retries=4, 380 extend_window_divisor=32, 381 ) 382 383 @staticmethod 384 def default_client_config() -> config_pb2.ClientConfig: 385 return config_pb2.ClientConfig( 386 max_retries=5, 387 max_lifetime_retries=1500, 388 initial_chunk_timeout_ms=4000, 389 chunk_timeout_ms=4000, 390 ) 391 392 @staticmethod 393 def default_proxy_config() -> config_pb2.ProxyConfig: 394 return text_format.Parse( 395 """ 396 client_filter_stack: [ 397 { hdlc_packetizer: {} }, 398 { data_dropper: {rate: 0.01, seed: 1649963713563718435} } 399 ] 400 401 server_filter_stack: [ 402 { hdlc_packetizer: {} }, 403 { data_dropper: {rate: 0.01, seed: 1649963713563718436} } 404 ]""", 405 config_pb2.ProxyConfig(), 406 ) 407 408 @staticmethod 409 def default_config() -> TransferConfig: 410 """Returns a new transfer config with default options.""" 411 return TransferConfig( 412 TransferIntegrationTest.default_server_config(), 413 TransferIntegrationTest.default_client_config(), 414 TransferIntegrationTest.default_proxy_config(), 415 ) 416 417 def do_single_write( 418 self, 419 client_type: str, 420 config: TransferConfig, 421 resource_id: int, 422 data: bytes, 423 protocol_version=config_pb2.TransferAction.ProtocolVersion.LATEST, 424 permanent_resource_id=False, 425 expected_status=status_pb2.StatusCode.OK, 426 initial_offset=0, 427 offsettable_resources=False, 428 ) -> None: 429 """Performs a single client-to-server write of the provided data.""" 430 with tempfile.NamedTemporaryFile() as f_payload, tempfile.NamedTemporaryFile() as f_server_output: 431 if permanent_resource_id: 432 config.server.resources[ 433 resource_id 434 ].default_destination_path = f_server_output.name 435 else: 436 config.server.resources[resource_id].destination_paths.append( 437 f_server_output.name 438 ) 439 config.server.resources[ 440 resource_id 441 ].offsettable = offsettable_resources 442 config.client.transfer_actions.append( 443 config_pb2.TransferAction( 444 resource_id=resource_id, 445 file_path=f_payload.name, 446 transfer_type=config_pb2.TransferAction.TransferType.WRITE_TO_SERVER, 447 protocol_version=protocol_version, 448 expected_status=expected_status, 449 initial_offset=initial_offset, 450 ) 451 ) 452 453 f_payload.write(data) 454 f_payload.flush() # Ensure contents are there to read! 455 exit_codes = asyncio.run( 456 self.harness.perform_transfers( 457 config.server, client_type, config.client, config.proxy 458 ) 459 ) 460 461 self.assertEqual(exit_codes.client, 0) 462 self.assertEqual(exit_codes.server, 0) 463 if expected_status == status_pb2.StatusCode.OK: 464 bytes_output = f_server_output.read() 465 self.assertEqual( 466 bytes_output[initial_offset:], 467 data, 468 ) 469 # Ensure we didn't write data to places before offset 470 self.assertEqual( 471 bytes_output[:initial_offset], b'\x00' * initial_offset 472 ) 473 474 def do_single_read( 475 self, 476 client_type: str, 477 config: TransferConfig, 478 resource_id: int, 479 data: bytes, 480 protocol_version=config_pb2.TransferAction.ProtocolVersion.LATEST, 481 permanent_resource_id=False, 482 expected_status=status_pb2.StatusCode.OK, 483 initial_offset=0, 484 offsettable_resources=False, 485 ) -> None: 486 """Performs a single server-to-client read of the provided data.""" 487 with tempfile.NamedTemporaryFile() as f_payload, tempfile.NamedTemporaryFile() as f_client_output: 488 if permanent_resource_id: 489 config.server.resources[ 490 resource_id 491 ].default_source_path = f_payload.name 492 else: 493 config.server.resources[resource_id].source_paths.append( 494 f_payload.name 495 ) 496 config.server.resources[ 497 resource_id 498 ].offsettable = offsettable_resources 499 config.client.transfer_actions.append( 500 config_pb2.TransferAction( 501 resource_id=resource_id, 502 file_path=f_client_output.name, 503 transfer_type=config_pb2.TransferAction.TransferType.READ_FROM_SERVER, 504 protocol_version=protocol_version, 505 expected_status=expected_status, 506 initial_offset=initial_offset, 507 ) 508 ) 509 510 f_payload.write(data) 511 f_payload.flush() # Ensure contents are there to read! 512 exit_codes = asyncio.run( 513 self.harness.perform_transfers( 514 config.server, client_type, config.client, config.proxy 515 ) 516 ) 517 self.assertEqual(exit_codes.client, 0) 518 self.assertEqual(exit_codes.server, 0) 519 if expected_status == status_pb2.StatusCode.OK: 520 bytes_output = f_client_output.read() 521 self.assertEqual( 522 bytes_output, 523 data[initial_offset:], 524 ) 525 526 def do_basic_transfer_sequence( 527 self, 528 client_type: str, 529 config: TransferConfig, 530 transfers: Iterable[BasicTransfer], 531 ) -> None: 532 """Performs multiple reads/writes in a single client/server session.""" 533 534 class ReadbackSet(NamedTuple): 535 server_file: BinaryIO 536 client_file: BinaryIO 537 expected_data: bytes 538 539 transfer_results: list[ReadbackSet] = [] 540 for transfer in transfers: 541 server_file = tempfile.NamedTemporaryFile() 542 client_file = tempfile.NamedTemporaryFile() 543 544 if ( 545 transfer.type 546 == config_pb2.TransferAction.TransferType.READ_FROM_SERVER 547 ): 548 server_file.write(transfer.data) 549 server_file.flush() 550 config.server.resources[transfer.id].source_paths.append( 551 server_file.name 552 ) 553 elif ( 554 transfer.type 555 == config_pb2.TransferAction.TransferType.WRITE_TO_SERVER 556 ): 557 client_file.write(transfer.data) 558 client_file.flush() 559 config.server.resources[transfer.id].destination_paths.append( 560 server_file.name 561 ) 562 else: 563 raise ValueError('Unknown TransferType') 564 565 config.client.transfer_actions.append( 566 config_pb2.TransferAction( 567 resource_id=transfer.id, 568 file_path=client_file.name, 569 transfer_type=transfer.type, 570 ) 571 ) 572 573 transfer_results.append( 574 ReadbackSet(server_file, client_file, transfer.data) 575 ) 576 577 exit_codes = asyncio.run( 578 self.harness.perform_transfers( 579 config.server, client_type, config.client, config.proxy 580 ) 581 ) 582 583 for i, result in enumerate(transfer_results): 584 with self.subTest(i=i): 585 # Need to seek to the beginning of the file to read written 586 # data. 587 result.client_file.seek(0, 0) 588 result.server_file.seek(0, 0) 589 self.assertEqual( 590 result.client_file.read(), result.expected_data 591 ) 592 self.assertEqual( 593 result.server_file.read(), result.expected_data 594 ) 595 596 # Check exit codes at the end as they provide less useful info. 597 self.assertEqual(exit_codes.client, 0) 598 self.assertEqual(exit_codes.server, 0) 599 600 601def run_tests_for(test_class_name): 602 parser = argparse.ArgumentParser() 603 parser.add_argument( 604 '--server-port', 605 type=int, 606 help='Port of the integration test server. The proxy will forward connections to this port', 607 ) 608 parser.add_argument( 609 '--client-port', 610 type=int, 611 help='Port on which to listen for connections from integration test client.', 612 ) 613 parser.add_argument( 614 '--java-client-binary', 615 type=pathlib.Path, 616 default=None, 617 help='Path to the Java transfer client to use in tests', 618 ) 619 parser.add_argument( 620 '--cpp-client-binary', 621 type=pathlib.Path, 622 default=None, 623 help='Path to the C++ transfer client to use in tests', 624 ) 625 parser.add_argument( 626 '--python-client-binary', 627 type=pathlib.Path, 628 default=None, 629 help='Path to the Python transfer client to use in tests', 630 ) 631 parser.add_argument( 632 '--server-binary', 633 type=pathlib.Path, 634 default=None, 635 help='Path to the transfer server to use in tests', 636 ) 637 parser.add_argument( 638 '--proxy-binary', 639 type=pathlib.Path, 640 default=None, 641 help=( 642 'Path to the proxy binary to use in tests to allow interception ' 643 'of client/server data' 644 ), 645 ) 646 647 (args, passthrough_args) = parser.parse_known_args() 648 649 # Inherrit the default configuration from the class being tested, and only 650 # override provided arguments. 651 for arg in vars(args): 652 val = getattr(args, arg) 653 if val: 654 setattr(test_class_name.HARNESS_CONFIG, arg, val) 655 656 unittest_args = [sys.argv[0]] + passthrough_args 657 unittest.main(argv=unittest_args) 658