#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests the pw_rpc.console_tools.console module."""

import types
import unittest

import pw_status

from pw_protobuf_compiler import python_protos
import pw_rpc
from pw_rpc import callback_client
from pw_rpc.console_tools.console import (
    CommandHelper,
    Context,
    ClientInfo,
    alias_deprecated_command,
)


class TestCommandHelper(unittest.TestCase):
    """Tests console_tools.console.CommandHelper."""

    def setUp(self) -> None:
        self._commands = {'command_a': 'A', 'command_B': 'B'}
        self._variables = {'hello': 1, 'world': 2}
        self._helper = CommandHelper(
            self._commands, self._variables, 'The header', 'The footer'
        )

    def test_help_contents(self) -> None:
        """help() has the header first, then the footer and all names."""
        help_contents = self._helper.help()

        self.assertTrue(help_contents.startswith('The header'))
        self.assertIn('The footer', help_contents)

        # subTest makes a failure report exactly which name was missing.
        for var_name in self._variables:
            with self.subTest(variable=var_name):
                self.assertIn(var_name, help_contents)

        for cmd_name in self._commands:
            with self.subTest(command=cmd_name):
                self.assertIn(cmd_name, help_contents)

    def test_repr_is_help(self) -> None:
        """repr() of the helper is the same text as help()."""
        self.assertEqual(repr(self._helper), self._helper.help())


_PROTO = """\
syntax = "proto3";

package the.pkg;

message SomeMessage {
  uint32 magic_number = 1;

  message AnotherMessage {
    string payload = 1;
  }

}

service Service {
  rpc Unary(SomeMessage) returns (SomeMessage.AnotherMessage);
}
"""


class TestConsoleContext(unittest.TestCase):
    """Tests console_tools.console.Context."""

    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(_PROTO)

        # The RPC client has two channels (IDs 1 and 2) so tests can switch
        # to a channel other than the first one.
        self._info = ClientInfo(
            'the_client',
            object(),
            pw_rpc.Client.from_modules(
                callback_client.Impl(),
                [
                    pw_rpc.Channel(1, lambda _: None),
                    pw_rpc.Channel(2, lambda _: None),
                ],
                self._protos.modules(),
            ),
        )

    def test_sets_expected_variables(self) -> None:
        """variables() exposes helpers, Status, and each client by name."""
        variables = Context(
            [self._info], default_client=self._info.client, protos=self._protos
        ).variables()

        self.assertIn('set_target', variables)

        self.assertIsInstance(variables['help'], CommandHelper)
        self.assertIs(variables['python_help'], help)
        self.assertIs(pw_status.Status, variables['Status'])
        self.assertIs(self._info.client, variables['the_client'])

    def test_set_target_switches_between_clients(self) -> None:
        """set_target rebinds the RPC namespace to another client."""
        client_1_channel = self._info.rpc_client.channel(1).channel

        client_2_channel = pw_rpc.Channel(99, lambda _: None)
        info_2 = ClientInfo(
            'other_client',
            object(),
            pw_rpc.Client.from_modules(
                callback_client.Impl(),
                [client_2_channel],
                self._protos.modules(),
            ),
        )

        context = Context(
            [self._info, info_2],
            default_client=self._info.client,
            protos=self._protos,
        )

        # Make sure the RPC service switches from one client to the other.
        self.assertIs(
            context.variables()['the'].pkg.Service.Unary.channel,
            client_1_channel,
        )

        context.set_target(info_2.client)

        self.assertIs(
            context.variables()['the'].pkg.Service.Unary.channel,
            client_2_channel,
        )

    def test_default_client_must_be_in_clients(self) -> None:
        """Context rejects a default_client that is not a known client."""
        with self.assertRaises(ValueError):
            Context(
                [self._info],
                default_client='something else',
                protos=self._protos,
            )

    def test_set_target_invalid_channel(self) -> None:
        """set_target raises KeyError for a channel ID the client lacks."""
        context = Context(
            [self._info], default_client=self._info.client, protos=self._protos
        )

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_non_default_channel(self) -> None:
        """set_target can select a specific channel of the same client."""
        channel_1 = self._info.rpc_client.channel(1).channel
        channel_2 = self._info.rpc_client.channel(2).channel

        context = Context(
            [self._info], default_client=self._info.client, protos=self._protos
        )
        # The variables dict is fetched once; the switch must be visible
        # through it without calling variables() again.
        variables = context.variables()

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_1)

        context.set_target(self._info.client, 2)

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_2)

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_requires_client_object(self) -> None:
        """set_target takes ClientInfo.client, not the raw RPC client."""
        context = Context(
            [self._info], default_client=self._info.client, protos=self._protos
        )

        with self.assertRaises(ValueError):
            context.set_target(self._info.rpc_client)

        context.set_target(self._info.client)

    def test_derived_context(self) -> None:
        """variables()['set_target'] dispatches to a subclass override."""
        called_derived_set_target = False

        class DerivedContext(Context):
            def set_target(
                self,
                unused_selected_client,
                unused_channel_id: int | None = None,
            ) -> None:
                nonlocal called_derived_set_target
                called_derived_set_target = True

        variables = DerivedContext(
            client_info=[self._info],
            default_client=self._info.client,
            protos=self._protos,
        ).variables()
        variables['set_target'](self._info.client)
        self.assertTrue(called_derived_set_target)


class TestAliasDeprecatedCommand(unittest.TestCase):
    """Tests console_tools.console.alias_deprecated_command."""

    def test_wraps_command_to_new_package(self) -> None:
        """The alias path is created when its top-level name is absent."""
        variables = {'abc': types.SimpleNamespace(command=lambda: 123)}
        alias_deprecated_command(variables, 'xyz.one.two.three', 'abc.command')

        self.assertEqual(variables['xyz'].one.two.three(), 123)

    def test_wraps_command_to_existing_package(self) -> None:
        """The alias is added under an already-present top-level name."""
        variables = {
            'abc': types.SimpleNamespace(NewCmd=lambda: 456),
            'one': types.SimpleNamespace(),
        }
        alias_deprecated_command(variables, 'one.two.OldCmd', 'abc.NewCmd')

        self.assertEqual(variables['one'].two.OldCmd(), 456)

    def test_error_if_new_command_does_not_exist(self) -> None:
        """Aliasing to a missing command raises AttributeError."""
        variables = {
            'abc': types.SimpleNamespace(),
            'one': types.SimpleNamespace(),
        }

        with self.assertRaises(AttributeError):
            alias_deprecated_command(variables, 'one.two.OldCmd', 'abc.NewCmd')


if __name__ == '__main__':
    unittest.main()