xref: /aosp_15_r20/external/pigweed/pw_rpc/py/tests/console_tools/console_tools_test.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1#!/usr/bin/env python3
2# Copyright 2021 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests the pw_rpc.console_tools.console module."""
16
17import types
18import unittest
19
20import pw_status
21
22from pw_protobuf_compiler import python_protos
23import pw_rpc
24from pw_rpc import callback_client
25from pw_rpc.console_tools.console import (
26    CommandHelper,
27    Context,
28    ClientInfo,
29    alias_deprecated_command,
30)
31
32
class TestCommandHelper(unittest.TestCase):
    """Tests console_tools.console.CommandHelper."""

    def setUp(self) -> None:
        # Two commands and two variables; the generated help text is expected
        # to mention each of them by name.
        self._commands = {'command_a': 'A', 'command_B': 'B'}
        self._variables = {'hello': 1, 'world': 2}
        self._helper = CommandHelper(
            self._commands,
            self._variables,
            'The header',
            'The footer',
        )

    def test_help_contents(self) -> None:
        """help() starts with the header and mentions footer, vars and cmds."""
        text = self._helper.help()

        self.assertTrue(text.startswith('The header'))
        self.assertIn('The footer', text)

        # Variables are checked first, then commands.
        for name in [*self._variables, *self._commands]:
            self.assertIn(name, text)

    def test_repr_is_help(self):
        """repr() of the helper is exactly the text help() returns."""
        expected = self._helper.help()
        self.assertEqual(repr(self._helper), expected)
55
56
# Minimal proto3 schema used as the test fixture for TestConsoleContext.
# TestConsoleContext.setUp parses it with python_protos.Library.from_strings.
# The tests then reach its single RPC as `the.pkg.Service.Unary`. That RPC
# takes SomeMessage and returns the nested SomeMessage.AnotherMessage type.
_PROTO = """\
syntax = "proto3";

package the.pkg;

message SomeMessage {
  uint32 magic_number = 1;

    message AnotherMessage {
      string payload = 1;
    }

}

service Service {
  rpc Unary(SomeMessage) returns (SomeMessage.AnotherMessage);
}
"""
75
76
class TestConsoleContext(unittest.TestCase):
    """Tests console_tools.console.Context."""

    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(_PROTO)

        # The default client exposes channels 1 and 2, so tests can switch
        # between the first channel and a non-default one.
        self._info = self._client_info(
            'the_client',
            [
                pw_rpc.Channel(1, lambda _: None),
                pw_rpc.Channel(2, lambda _: None),
            ],
        )

    def _client_info(self, name, channels):
        """Returns a ClientInfo named `name` wrapping a new pw_rpc.Client.

        The RPC client is built from this test's proto library with the given
        channels. A bare object() stands in for the user-facing client.
        """
        rpc_client = pw_rpc.Client.from_modules(
            callback_client.Impl(), channels, self._protos.modules()
        )
        return ClientInfo(name, object(), rpc_client)

    def _context(self, *extra_clients):
        """Returns a Context over self._info followed by `extra_clients`.

        The client in self._info is always the default target.
        """
        return Context(
            [self._info, *extra_clients],
            default_client=self._info.client,
            protos=self._protos,
        )

    @staticmethod
    def _unary_channel(context_variables):
        """Returns the channel that the.pkg.Service.Unary is bound to."""
        return context_variables['the'].pkg.Service.Unary.channel

    def test_sets_expected_variables(self) -> None:
        names = self._context().variables()

        self.assertIn('set_target', names)

        self.assertIsInstance(names['help'], CommandHelper)
        self.assertIs(names['python_help'], help)
        self.assertIs(pw_status.Status, names['Status'])
        self.assertIs(self._info.client, names['the_client'])

    def test_set_target_switches_between_clients(self) -> None:
        first_channel = self._info.rpc_client.channel(1).channel

        second_channel = pw_rpc.Channel(99, lambda _: None)
        other = self._client_info('other_client', [second_channel])

        context = self._context(other)

        # The RPC service must follow whichever client is currently targeted.
        self.assertIs(self._unary_channel(context.variables()), first_channel)

        context.set_target(other.client)

        self.assertIs(self._unary_channel(context.variables()), second_channel)

    def test_default_client_must_be_in_clients(self) -> None:
        with self.assertRaises(ValueError):
            Context(
                [self._info],
                default_client='something else',
                protos=self._protos,
            )

    def test_set_target_invalid_channel(self) -> None:
        context = self._context()

        # Channel 100 was never registered on the client.
        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_non_default_channel(self) -> None:
        rpcs = self._info.rpc_client
        channel_1 = rpcs.channel(1).channel
        channel_2 = rpcs.channel(2).channel

        context = self._context()
        # Fetched once: the already-fetched variables are expected to reflect
        # later set_target() calls without being fetched again.
        variables = context.variables()

        self.assertIs(self._unary_channel(variables), channel_1)

        context.set_target(self._info.client, 2)

        self.assertIs(self._unary_channel(variables), channel_2)

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_requires_client_object(self) -> None:
        context = self._context()

        # The underlying pw_rpc.Client is rejected; ClientInfo.client works.
        with self.assertRaises(ValueError):
            context.set_target(self._info.rpc_client)

        context.set_target(self._info.client)

    def test_derived_context(self) -> None:
        derived_called = False

        class DerivedContext(Context):
            def set_target(
                self,
                unused_selected_client,
                unused_channel_id: int | None = None,
            ) -> None:
                nonlocal derived_called
                derived_called = True

        # The exposed 'set_target' variable must dispatch to the override.
        context = DerivedContext(
            client_info=[self._info],
            default_client=self._info.client,
            protos=self._protos,
        )
        set_target = context.variables()['set_target']
        set_target(self._info.client)
        self.assertTrue(derived_called)
204
205
206class TestAliasDeprecatedCommand(unittest.TestCase):
207    def test_wraps_command_to_new_package(self) -> None:
208        variables = {'abc': types.SimpleNamespace(command=lambda: 123)}
209        alias_deprecated_command(variables, 'xyz.one.two.three', 'abc.command')
210
211        self.assertEqual(variables['xyz'].one.two.three(), 123)
212
213    def test_wraps_command_to_existing_package(self) -> None:
214        variables = {
215            'abc': types.SimpleNamespace(NewCmd=lambda: 456),
216            'one': types.SimpleNamespace(),
217        }
218        alias_deprecated_command(variables, 'one.two.OldCmd', 'abc.NewCmd')
219
220        self.assertEqual(variables['one'].two.OldCmd(), 456)
221
222    def test_error_if_new_command_does_not_exist(self) -> None:
223        variables = {
224            'abc': types.SimpleNamespace(),
225            'one': types.SimpleNamespace(),
226        }
227
228        with self.assertRaises(AttributeError):
229            alias_deprecated_command(variables, 'one.two.OldCmd', 'abc.NewCmd')
230
231
232if __name__ == '__main__':
233    unittest.main()
234