# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests classes in pw_rpc.descriptors."""

import unittest

from google.protobuf.message_factory import GetMessageClass

from pw_protobuf_compiler import python_protos
from pw_rpc import descriptors

# Proto source compiled at test time. It defines a request and response message
# and one service with an RPC of each streaming kind.
TEST_PROTO = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""


class MethodTest(unittest.TestCase):
    """Tests pw_rpc.Method."""

    def setUp(self):
        # Compile TEST_PROTO and wrap its service descriptor. Every test runs
        # against the unary method, whose request type is pw.test1.SomeMessage.
        (module,) = python_protos.compile_and_import_strings([TEST_PROTO])
        service = descriptors.Service.from_descriptor(
            module.DESCRIPTOR.services_by_name['PublicService']
        )
        self._method = service.methods['SomeUnary']

    def test_get_request_with_both_message_and_kwargs(self):
        """Passing a message and field values together raises TypeError."""
        with self.assertRaisesRegex(TypeError, r'either'):
            self._method.get_request(
                self._method.request_type(), {'magic_number': 1}
            )

    def test_get_request_neither_message_nor_kwargs(self):
        """With no message and no fields, a default request is returned."""
        self.assertEqual(
            self._method.request_type(), self._method.get_request(None, None)
        )

    def test_get_request_with_wrong_type(self):
        """A non-message request raises TypeError naming the expected type."""
        with self.assertRaisesRegex(TypeError, r'pw\.test1\.SomeMessage'):
            self._method.get_request('a str!', {})

    def test_get_request_with_different_message_type(self):
        """A message of another proto type is rejected with TypeError."""
        msg = self._method.response_type()
        with self.assertRaisesRegex(TypeError, r'pw\.test1\.SomeMessage'):
            self._method.get_request(msg, {})

    def test_get_request_with_different_copy_of_same_message_class(self):
        """A message from a separately obtained class is passed through."""
        some_message_clone = GetMessageClass(
            self._method.request_type.DESCRIPTOR
        )
        msg = some_message_clone()

        # Depending on the protobuf version/implementation, the factory class
        # may or may not be the very same Python type as request_type. The
        # guaranteed invariant is that both share one descriptor instance.
        # Checking isinstance against request_type would fail spuriously when
        # the clone is a distinct type, so assert against the clone instead.
        self.assertIsInstance(msg, some_message_clone)
        self.assertIs(msg.DESCRIPTOR, self._method.request_type.DESCRIPTOR)

        # get_request must accept the clone and return the same object.
        result = self._method.get_request(msg, {})
        self.assertIs(result, msg)


if __name__ == '__main__':
    unittest.main()