xref: /aosp_15_r20/external/pigweed/pw_transfer/integration_test/test_fixture.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1#!/usr/bin/env python3
2# Copyright 2022 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Test fixture for pw_transfer integration tests."""
16
17from __future__ import annotations
18
19import argparse
20import asyncio
21from dataclasses import dataclass
22import logging
23import pathlib
24from pathlib import Path
25import sys
26import tempfile
27from typing import BinaryIO, Iterable, NamedTuple
28import unittest
29
30from google.protobuf import text_format
31
32from pw_protobuf_protos import status_pb2
33from pw_transfer.integration_test import config_pb2
34from python.runfiles import runfiles
35
36_LOG = logging.getLogger('pw_transfer_intergration_test_proxy')
37_LOG.level = logging.DEBUG
38_LOG.addHandler(logging.StreamHandler(sys.stdout))
39
40
class LogMonitor:
    """Mirrors lines from an asyncio stream to the log and to a queue.

    Every line read from the reader is re-logged with a prefix and also
    enqueued, so callers can block until a line of interest shows up or the
    stream hits EOF.
    """

    class Error(Exception):
        """Raised if wait_for_line reaches EOF before the expected line."""

    def __init__(self, prefix: str, reader: asyncio.StreamReader):
        """Initializer.

        Args:
          prefix: Prepended to read lines before they are logged.
          reader: StreamReader to read lines from.
        """
        self._prefix = prefix
        self._reader = reader

        # Lines read so far, awaiting inspection. A falsy (empty) entry is
        # the EOF sentinel.
        self._queue: asyncio.Queue = asyncio.Queue()
        # Background task that mirrors reader output to the log and the queue.
        self._relog_and_enqueue_task = asyncio.create_task(
            self._relog_and_enqueue()
        )

    async def wait_for_line(self, msg: str):
        """Blocks until a line containing msg has been read from the reader."""
        while True:
            line = await self._queue.get()
            if not line:
                raise LogMonitor.Error(
                    f"Reached EOF before getting line matching {msg}"
                )
            if msg in line.decode():
                return

    async def wait_for_eof(self):
        """Waits for the reader to reach EOF, relogging any lines read."""
        # Nobody is monitoring the queue any longer, so discard its contents
        # while the background relogging task runs to completion.
        drain = asyncio.create_task(self._drain_queue())
        await asyncio.gather(drain, self._relog_and_enqueue_task)

    async def _relog_and_enqueue(self):
        """Reads lines from the reader, enqueues them, and logs them."""
        while line := await self._reader.readline():
            await self._queue.put(line)
            _LOG.info(f"{self._prefix} {line.decode().rstrip()}")
        # EOF: still enqueue the empty sentinel so queue readers observe it.
        await self._queue.put(line)

    async def _drain_queue(self):
        # Consume entries until the falsy EOF sentinel appears.
        while await self._queue.get():
            pass
103
104class MonitoredSubprocess:
105    """A subprocess with monitored asynchronous communication."""
106
107    @staticmethod
108    async def create(cmd: list[str], prefix: str, stdinput: bytes):
109        """Starts the subprocess and writes stdinput to stdin.
110
111        This method returns once stdinput has been written to stdin. The
112        MonitoredSubprocess continues to log the process's stderr and stdout
113        (with the prefix) until it terminates.
114
115        Args:
116          cmd: Command line to execute.
117          prefix: Prepended to process logs.
118          stdinput: Written to stdin on process startup.
119        """
120        self = MonitoredSubprocess()
121        self._process = await asyncio.create_subprocess_exec(
122            *cmd,
123            stdin=asyncio.subprocess.PIPE,
124            stdout=asyncio.subprocess.PIPE,
125            stderr=asyncio.subprocess.PIPE,
126        )
127
128        self._stderr_monitor = LogMonitor(
129            f"{prefix} ERR:", self._process.stderr
130        )
131        self._stdout_monitor = LogMonitor(
132            f"{prefix} OUT:", self._process.stdout
133        )
134
135        self._process.stdin.write(stdinput)
136        await self._process.stdin.drain()
137        self._process.stdin.close()
138        await self._process.stdin.wait_closed()
139        return self
140
141    async def wait_for_line(self, stream: str, msg: str, timeout: float):
142        """Wait for a line containing msg to be read on the stream."""
143        if stream == "stdout":
144            monitor = self._stdout_monitor
145        elif stream == "stderr":
146            monitor = self._stderr_monitor
147        else:
148            raise ValueError(
149                "Stream must be 'stdout' or 'stderr', got {stream}"
150            )
151
152        await asyncio.wait_for(monitor.wait_for_line(msg), timeout)
153
154    def returncode(self):
155        return self._process.returncode
156
157    def terminate(self):
158        """Terminate the process."""
159        self._process.terminate()
160
161    async def wait_for_termination(self, timeout: float | None):
162        """Wait for the process to terminate."""
163        await asyncio.wait_for(
164            asyncio.gather(
165                self._process.wait(),
166                self._stdout_monitor.wait_for_eof(),
167                self._stderr_monitor.wait_for_eof(),
168            ),
169            timeout,
170        )
171
172    async def terminate_and_wait(self, timeout: float):
173        """Terminate the process and wait for it to exit."""
174        if self.returncode() is not None:
175            # Process already terminated
176            return
177        self.terminate()
178        await self.wait_for_termination(timeout)
179
180
class TransferConfig(NamedTuple):
    """A simple tuple to collect configs for test binaries."""

    # Configuration fed to the transfer server binary via stdin.
    server: config_pb2.ServerConfig
    # Configuration fed to the transfer client binary via stdin.
    client: config_pb2.ClientConfig
    # Configuration fed to the proxy binary via stdin.
    proxy: config_pb2.ProxyConfig
187
188
class TransferIntegrationTestHarness:
    """A class to manage transfer integration tests"""

    # Prefix for log messages coming from the harness (as opposed to the server,
    # client, or proxy processes). Padded so that the length is the same as
    # "SERVER OUT:".
    _PREFIX = "HARNESS:   "

    @dataclass
    class Config:
        # Port the proxy forwards to the integration test server.
        server_port: int = 3300
        # Port on which the proxy listens for the integration test client.
        client_port: int = 3301
        # Optional overrides for the runfiles-provided binaries.
        java_client_binary: Path | None = None
        cpp_client_binary: Path | None = None
        python_client_binary: Path | None = None
        proxy_binary: Path | None = None
        server_binary: Path | None = None

    class TransferExitCodes(NamedTuple):
        client: int
        server: int

    def __init__(self, harness_config: Config) -> None:
        # TODO(tpudlik): This is Bazel-only. Support gn, too.
        r = runfiles.Create()

        # Set defaults.
        self._JAVA_CLIENT_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/java_client"
        )
        self._CPP_CLIENT_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/cpp_client"
        )
        self._PYTHON_CLIENT_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/python_client"
        )
        self._PROXY_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/proxy"
        )
        self._SERVER_BINARY = r.Rlocation(
            "pigweed/pw_transfer/integration_test/server"
        )

        # Server/client ports are non-optional, so use those.
        self._CLIENT_PORT = harness_config.client_port
        self._SERVER_PORT = harness_config.server_port

        # Populated lazily by the _start_* helpers during perform_transfers().
        self._server: MonitoredSubprocess | None = None
        self._client: MonitoredSubprocess | None = None
        self._proxy: MonitoredSubprocess | None = None

        # If the harness configuration specifies overrides, use those.
        if harness_config.java_client_binary is not None:
            self._JAVA_CLIENT_BINARY = harness_config.java_client_binary
        if harness_config.cpp_client_binary is not None:
            self._CPP_CLIENT_BINARY = harness_config.cpp_client_binary
        if harness_config.python_client_binary is not None:
            self._PYTHON_CLIENT_BINARY = harness_config.python_client_binary
        if harness_config.proxy_binary is not None:
            self._PROXY_BINARY = harness_config.proxy_binary
        if harness_config.server_binary is not None:
            self._SERVER_BINARY = harness_config.server_binary

        # Maps the client_type strings accepted by perform_transfers() to the
        # binary to launch.
        self._CLIENT_BINARY = {
            "cpp": self._CPP_CLIENT_BINARY,
            "java": self._JAVA_CLIENT_BINARY,
            "python": self._PYTHON_CLIENT_BINARY,
        }

    async def _start_client(
        self, client_type: str, config: config_pb2.ClientConfig
    ):
        """Launches the client binary, passing config as text proto on stdin."""
        _LOG.info(f"{self._PREFIX} Starting client with config\n{config}")
        self._client = await MonitoredSubprocess.create(
            [self._CLIENT_BINARY[client_type], str(self._CLIENT_PORT)],
            "CLIENT",
            str(config).encode('ascii'),
        )

    async def _start_server(self, config: config_pb2.ServerConfig):
        """Launches the server binary, passing config as text proto on stdin."""
        _LOG.info(f"{self._PREFIX} Starting server with config\n{config}")
        self._server = await MonitoredSubprocess.create(
            [self._SERVER_BINARY, str(self._SERVER_PORT)],
            "SERVER",
            str(config).encode('ascii'),
        )

    async def _start_proxy(self, config: config_pb2.ProxyConfig):
        """Launches the proxy binary, passing config as text proto on stdin."""
        _LOG.info(f"{self._PREFIX} Starting proxy with config\n{config}")
        self._proxy = await MonitoredSubprocess.create(
            [
                self._PROXY_BINARY,
                "--server-port",
                str(self._SERVER_PORT),
                "--client-port",
                str(self._CLIENT_PORT),
            ],
            # Extra space in "PROXY " so that it lines up with "SERVER".
            "PROXY ",
            str(config).encode('ascii'),
        )

    async def perform_transfers(
        self,
        server_config: config_pb2.ServerConfig,
        client_type: str,
        client_config: config_pb2.ClientConfig,
        proxy_config: config_pb2.ProxyConfig,
    ) -> TransferExitCodes:
        """Performs a pw_transfer write.

        Args:
          server_config: Server configuration.
          client_type: Either "cpp", "java", or "python".
          client_config: Client configuration.
          proxy_config: Proxy configuration.

        Returns:
          Exit code of the client and server as a tuple.

        Raises:
          asyncio.TimeoutError: A component failed to start or shut down
            within the harness timeout.
        """
        # Timeout for components (server, proxy) to come up or shut down after
        # write is finished or a signal is sent. Approximately arbitrary. Should
        # not be too long so that we catch bugs in the server that prevent it
        # from shutting down.
        TIMEOUT = 5  # seconds

        try:
            await self._start_proxy(proxy_config)
            assert self._proxy is not None
            await self._proxy.wait_for_line(
                "stderr", "Listening for client connection", TIMEOUT
            )

            await self._start_server(server_config)
            assert self._server is not None
            await self._server.wait_for_line(
                "stderr", "Starting pw_rpc server on port", TIMEOUT
            )

            await self._start_client(client_type, client_config)
            assert self._client is not None
            # No timeout: the client will only exit once the transfer
            # completes, and this can take a long time for large payloads.
            await self._client.wait_for_termination(None)

            # Wait for the server to exit.
            await self._server.wait_for_termination(TIMEOUT)
        finally:
            # Stop the server, if still running. (Only expected if the
            # wait_for above timed out.)
            if self._server is not None:
                await self._server.terminate_and_wait(TIMEOUT)
            # Stop the proxy. Unlike the server, we expect it to still be
            # running at this stage.
            if self._proxy is not None:
                await self._proxy.terminate_and_wait(TIMEOUT)

        # Fix: this return previously sat *inside* the finally block, which
        # silently swallowed any exception raised in the try body (timeouts,
        # startup failures) and could raise AttributeError when the client
        # never started (self._client is None). Returning after the
        # try/finally lets genuine failures propagate to the test.
        assert self._client is not None and self._server is not None
        return self.TransferExitCodes(
            self._client.returncode(), self._server.returncode()
        )
350
351
class BasicTransfer(NamedTuple):
    """Describes one transfer for do_basic_transfer_sequence()."""

    # Resource ID the transfer targets.
    id: int
    # Direction of the transfer (READ_FROM_SERVER or WRITE_TO_SERVER).
    type: config_pb2.TransferAction.TransferType.ValueType
    # Payload expected to move between client and server.
    data: bytes
356
357
class TransferIntegrationTest(unittest.TestCase):
    """A base class for transfer integration tests.

    This significantly reduces the boiler plate required for building
    integration test cases for pw_transfer. This class does not include any
    tests itself, but instead bundles together much of the boiler plate required
    for making an integration test for pw_transfer using this test fixture.
    """

    # Shared harness configuration. run_tests_for() may override individual
    # fields from the command line before setUpClass() builds the harness.
    HARNESS_CONFIG = TransferIntegrationTestHarness.Config()

    @classmethod
    def setUpClass(cls):
        # A single harness instance is shared by every test in the class.
        cls.harness = TransferIntegrationTestHarness(cls.HARNESS_CONFIG)

    @staticmethod
    def default_server_config() -> config_pb2.ServerConfig:
        """Returns a ServerConfig with this fixture's default settings."""
        return config_pb2.ServerConfig(
            chunk_size_bytes=216,
            pending_bytes=64 * 1024,
            chunk_timeout_seconds=5,
            transfer_service_retries=4,
            extend_window_divisor=32,
        )

    @staticmethod
    def default_client_config() -> config_pb2.ClientConfig:
        """Returns a ClientConfig with this fixture's default settings."""
        return config_pb2.ClientConfig(
            max_retries=5,
            max_lifetime_retries=1500,
            initial_chunk_timeout_ms=4000,
            chunk_timeout_ms=4000,
        )

    @staticmethod
    def default_proxy_config() -> config_pb2.ProxyConfig:
        """Returns a ProxyConfig whose filter stacks drop data at rate 0.01.

        Fixed seeds keep the packet-drop pattern reproducible across runs.
        """
        return text_format.Parse(
            """
                client_filter_stack: [
                    { hdlc_packetizer: {} },
                    { data_dropper: {rate: 0.01, seed: 1649963713563718435} }
                ]

                server_filter_stack: [
                    { hdlc_packetizer: {} },
                    { data_dropper: {rate: 0.01, seed: 1649963713563718436} }
            ]""",
            config_pb2.ProxyConfig(),
        )

    @staticmethod
    def default_config() -> TransferConfig:
        """Returns a new transfer config with default options."""
        return TransferConfig(
            TransferIntegrationTest.default_server_config(),
            TransferIntegrationTest.default_client_config(),
            TransferIntegrationTest.default_proxy_config(),
        )

    def do_single_write(
        self,
        client_type: str,
        config: TransferConfig,
        resource_id: int,
        data: bytes,
        protocol_version=config_pb2.TransferAction.ProtocolVersion.LATEST,
        permanent_resource_id=False,
        expected_status=status_pb2.StatusCode.OK,
        initial_offset=0,
        offsettable_resources=False,
    ) -> None:
        """Performs a single client-to-server write of the provided data.

        Args:
          client_type: Either "cpp", "java", or "python".
          config: Transfer configuration; its server and client configs are
            mutated in place to register this transfer.
          resource_id: ID of the resource to write to.
          data: Payload the client sends.
          protocol_version: pw_transfer protocol version to request.
          permanent_resource_id: If true, register the output file as the
            resource's default_destination_path rather than a one-shot
            destination path.
          expected_status: Final status passed to the TransferAction.
          initial_offset: Offset at which the write starts.
          offsettable_resources: Marks the server resource as offsettable.
        """
        with tempfile.NamedTemporaryFile() as f_payload, tempfile.NamedTemporaryFile() as f_server_output:
            if permanent_resource_id:
                config.server.resources[
                    resource_id
                ].default_destination_path = f_server_output.name
            else:
                config.server.resources[resource_id].destination_paths.append(
                    f_server_output.name
                )
            config.server.resources[
                resource_id
            ].offsettable = offsettable_resources
            config.client.transfer_actions.append(
                config_pb2.TransferAction(
                    resource_id=resource_id,
                    file_path=f_payload.name,
                    transfer_type=config_pb2.TransferAction.TransferType.WRITE_TO_SERVER,
                    protocol_version=protocol_version,
                    expected_status=expected_status,
                    initial_offset=initial_offset,
                )
            )

            f_payload.write(data)
            f_payload.flush()  # Ensure contents are there to read!
            exit_codes = asyncio.run(
                self.harness.perform_transfers(
                    config.server, client_type, config.client, config.proxy
                )
            )

            self.assertEqual(exit_codes.client, 0)
            self.assertEqual(exit_codes.server, 0)
            if expected_status == status_pb2.StatusCode.OK:
                # Everything from initial_offset onward must match the payload.
                bytes_output = f_server_output.read()
                self.assertEqual(
                    bytes_output[initial_offset:],
                    data,
                )
                # Ensure we didn't write data to places before offset
                self.assertEqual(
                    bytes_output[:initial_offset], b'\x00' * initial_offset
                )

    def do_single_read(
        self,
        client_type: str,
        config: TransferConfig,
        resource_id: int,
        data: bytes,
        protocol_version=config_pb2.TransferAction.ProtocolVersion.LATEST,
        permanent_resource_id=False,
        expected_status=status_pb2.StatusCode.OK,
        initial_offset=0,
        offsettable_resources=False,
    ) -> None:
        """Performs a single server-to-client read of the provided data.

        Args:
          client_type: Either "cpp", "java", or "python".
          config: Transfer configuration; its server and client configs are
            mutated in place to register this transfer.
          resource_id: ID of the resource to read from.
          data: Payload staged on the server side.
          protocol_version: pw_transfer protocol version to request.
          permanent_resource_id: If true, register the payload file as the
            resource's default_source_path rather than a one-shot source path.
          expected_status: Final status passed to the TransferAction.
          initial_offset: Offset at which the read starts.
          offsettable_resources: Marks the server resource as offsettable.
        """
        with tempfile.NamedTemporaryFile() as f_payload, tempfile.NamedTemporaryFile() as f_client_output:
            if permanent_resource_id:
                config.server.resources[
                    resource_id
                ].default_source_path = f_payload.name
            else:
                config.server.resources[resource_id].source_paths.append(
                    f_payload.name
                )
            config.server.resources[
                resource_id
            ].offsettable = offsettable_resources
            config.client.transfer_actions.append(
                config_pb2.TransferAction(
                    resource_id=resource_id,
                    file_path=f_client_output.name,
                    transfer_type=config_pb2.TransferAction.TransferType.READ_FROM_SERVER,
                    protocol_version=protocol_version,
                    expected_status=expected_status,
                    initial_offset=initial_offset,
                )
            )

            f_payload.write(data)
            f_payload.flush()  # Ensure contents are there to read!
            exit_codes = asyncio.run(
                self.harness.perform_transfers(
                    config.server, client_type, config.client, config.proxy
                )
            )
            self.assertEqual(exit_codes.client, 0)
            self.assertEqual(exit_codes.server, 0)
            if expected_status == status_pb2.StatusCode.OK:
                # The client only receives data from initial_offset onward.
                bytes_output = f_client_output.read()
                self.assertEqual(
                    bytes_output,
                    data[initial_offset:],
                )

    def do_basic_transfer_sequence(
        self,
        client_type: str,
        config: TransferConfig,
        transfers: Iterable[BasicTransfer],
    ) -> None:
        """Performs multiple reads/writes in a single client/server session."""

        # Pairs the temp files backing one transfer with its expected payload
        # so the readback loop below can verify both sides.
        class ReadbackSet(NamedTuple):
            server_file: BinaryIO
            client_file: BinaryIO
            expected_data: bytes

        # NOTE(review): these NamedTemporaryFiles are never explicitly closed;
        # they are only reclaimed when garbage collected after the test.
        transfer_results: list[ReadbackSet] = []
        for transfer in transfers:
            server_file = tempfile.NamedTemporaryFile()
            client_file = tempfile.NamedTemporaryFile()

            if (
                transfer.type
                == config_pb2.TransferAction.TransferType.READ_FROM_SERVER
            ):
                # Stage the payload on the server; the client reads it back.
                server_file.write(transfer.data)
                server_file.flush()
                config.server.resources[transfer.id].source_paths.append(
                    server_file.name
                )
            elif (
                transfer.type
                == config_pb2.TransferAction.TransferType.WRITE_TO_SERVER
            ):
                # Stage the payload on the client; the server writes it to
                # server_file.
                client_file.write(transfer.data)
                client_file.flush()
                config.server.resources[transfer.id].destination_paths.append(
                    server_file.name
                )
            else:
                raise ValueError('Unknown TransferType')

            config.client.transfer_actions.append(
                config_pb2.TransferAction(
                    resource_id=transfer.id,
                    file_path=client_file.name,
                    transfer_type=transfer.type,
                )
            )

            transfer_results.append(
                ReadbackSet(server_file, client_file, transfer.data)
            )

        exit_codes = asyncio.run(
            self.harness.perform_transfers(
                config.server, client_type, config.client, config.proxy
            )
        )

        # After the session, both the client-side and server-side files should
        # contain the expected payload for every transfer.
        for i, result in enumerate(transfer_results):
            with self.subTest(i=i):
                # Need to seek to the beginning of the file to read written
                # data.
                result.client_file.seek(0, 0)
                result.server_file.seek(0, 0)
                self.assertEqual(
                    result.client_file.read(), result.expected_data
                )
                self.assertEqual(
                    result.server_file.read(), result.expected_data
                )

        # Check exit codes at the end as they provide less useful info.
        self.assertEqual(exit_codes.client, 0)
        self.assertEqual(exit_codes.server, 0)
599
600
def run_tests_for(test_class_name):
    """Parses harness CLI flags onto a test class and runs unittest.

    Any recognized flags override the matching fields of the class's
    HARNESS_CONFIG; all remaining arguments are forwarded to unittest.main()
    (e.g. test names, -v).

    Args:
      test_class_name: A TransferIntegrationTest subclass whose HARNESS_CONFIG
        receives the command-line overrides.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--server-port',
        type=int,
        help='Port of the integration test server.  The proxy will forward connections to this port',
    )
    parser.add_argument(
        '--client-port',
        type=int,
        help='Port on which to listen for connections from integration test client.',
    )
    parser.add_argument(
        '--java-client-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the Java transfer client to use in tests',
    )
    parser.add_argument(
        '--cpp-client-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the C++ transfer client to use in tests',
    )
    parser.add_argument(
        '--python-client-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the Python transfer client to use in tests',
    )
    parser.add_argument(
        '--server-binary',
        type=pathlib.Path,
        default=None,
        help='Path to the transfer server to use in tests',
    )
    parser.add_argument(
        '--proxy-binary',
        type=pathlib.Path,
        default=None,
        help=(
            'Path to the proxy binary to use in tests to allow interception '
            'of client/server data'
        ),
    )

    (args, passthrough_args) = parser.parse_known_args()

    # Inherit the default configuration from the class being tested, and only
    # override provided arguments.
    for arg in vars(args):
        val = getattr(args, arg)
        # Fix: compare against None rather than truthiness so legitimate falsy
        # overrides (e.g. port 0) are not silently ignored.
        if val is not None:
            setattr(test_class_name.HARNESS_CONFIG, arg, val)

    # Forward everything argparse did not consume to unittest.
    unittest_args = [sys.argv[0]] + passthrough_args
    unittest.main(argv=unittest_args)
658