xref: /aosp_15_r20/external/pigweed/pw_unit_test/py/rpc_service_test.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1#!/usr/bin/env python3
2# Copyright 2021 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests using the callback client for pw_rpc."""
16
17import logging
18import unittest
19from unittest import mock
20
21from pw_hdlc import rpc
22from pw_rpc import testing
23from pw_unit_test_proto import unit_test_pb2
24from pw_unit_test import run_tests, EventHandler, TestCase
25
# Every suite (Passing, Failing, and DISABLED_Disabled) contains these cases.
# The last case is disabled inside each suite, so it is reported as disabled
# rather than executed.
_CASES = ('Zero', 'One', 'Two', 'DISABLED_Disabled')
_FILE = 'pw_unit_test/test_rpc_server.cc'

# Case names that actually run inside the enabled suites.
_ENABLED_CASES = tuple(
    name for name in _CASES if not name.startswith('DISABLED_')
)

PASSING = tuple(TestCase('Passing', name, _FILE) for name in _ENABLED_CASES)
FAILING = tuple(TestCase('Failing', name, _FILE) for name in _ENABLED_CASES)
# Cases expected to produce start/end events when no suite filter is given.
EXECUTED_TESTS = (*PASSING, *FAILING)

# Every case of the DISABLED_Disabled suite, including its own disabled case.
DISABLED_SUITE = tuple(
    TestCase('DISABLED_Disabled', name, _FILE) for name in _CASES
)

# Disabled cases inside the enabled suites, followed by the disabled suite.
ALL_DISABLED_TESTS = (
    *(
        TestCase(suite, 'DISABLED_Disabled', _FILE)
        for suite in ('Passing', 'Failing')
    ),
    *DISABLED_SUITE,
)
43
44
45class RpcIntegrationTest(unittest.TestCase):
46    """Calls RPCs on an RPC server through a socket."""
47
48    test_server_command: tuple[str, ...] = ()
49    port: int
50
51    def setUp(self) -> None:
52        self._context = rpc.HdlcRpcLocalServerAndClient(
53            self.test_server_command, self.port, [unit_test_pb2]
54        )
55        self.rpcs = self._context.client.channel(1).rpcs
56        self.handler = mock.NonCallableMagicMock(spec=EventHandler)
57
58    def tearDown(self) -> None:
59        self._context.close()
60
61    def test_run_tests_default_handler(self) -> None:
62        with self.assertLogs(logging.getLogger('pw_unit_test'), 'INFO') as logs:
63            self.assertFalse(run_tests(self.rpcs))
64
65        for test in EXECUTED_TESTS:
66            self.assertTrue(any(str(test) in log for log in logs.output), test)
67
68    def test_run_tests_calls_test_case_start(self) -> None:
69        self.assertFalse(run_tests(self.rpcs, event_handlers=[self.handler]))
70
71        self.handler.test_case_start.assert_has_calls(
72            [mock.call(case) for case in EXECUTED_TESTS], any_order=True
73        )
74
75    def test_run_tests_calls_test_case_end(self) -> None:
76        self.assertFalse(run_tests(self.rpcs, event_handlers=[self.handler]))
77
78        calls = [
79            mock.call(
80                case,
81                unit_test_pb2.SUCCESS
82                if case.suite_name == 'Passing'
83                else unit_test_pb2.FAILURE,
84            )
85            for case in EXECUTED_TESTS
86        ]
87        self.handler.test_case_end.assert_has_calls(calls, any_order=True)
88
89    def test_run_tests_calls_test_case_disabled(self) -> None:
90        self.assertFalse(run_tests(self.rpcs, event_handlers=[self.handler]))
91
92        self.handler.test_case_disabled.assert_has_calls(
93            [mock.call(case) for case in ALL_DISABLED_TESTS], any_order=True
94        )
95
96    def test_passing_tests_only(self) -> None:
97        self.assertTrue(
98            run_tests(
99                self.rpcs,
100                test_suites=['Passing'],
101                event_handlers=[self.handler],
102            )
103        )
104        calls = [mock.call(case, unit_test_pb2.SUCCESS) for case in PASSING]
105        self.handler.test_case_end.assert_has_calls(calls, any_order=True)
106
107    def test_disabled_tests_only(self) -> None:
108        self.assertTrue(
109            run_tests(
110                self.rpcs,
111                test_suites=['DISABLED_Disabled'],
112                event_handlers=[self.handler],
113            )
114        )
115
116        self.handler.test_case_start.assert_not_called()
117        self.handler.test_case_end.assert_not_called()
118        self.handler.test_case_disabled.assert_has_calls(
119            [mock.call(case) for case in DISABLED_SUITE], any_order=True
120        )
121
122    def test_failing_tests(self) -> None:
123        self.assertFalse(
124            run_tests(
125                self.rpcs,
126                test_suites=['Failing'],
127                event_handlers=[self.handler],
128            )
129        )
130        calls = [mock.call(case, unit_test_pb2.FAILURE) for case in FAILING]
131        self.handler.test_case_end.assert_has_calls(calls, any_order=True)
132
133
134def _main(
135    test_server_command: list[str], port: int, unittest_args: list[str]
136) -> None:
137    RpcIntegrationTest.test_server_command = tuple(test_server_command)
138    RpcIntegrationTest.port = port
139    unittest.main(argv=unittest_args)
140
141
142if __name__ == '__main__':
143    _main(**vars(testing.parse_test_server_args()))
144