# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""This module generates the code for pw_protobuf pw_rpc services."""

import os
from typing import Iterable

from pw_protobuf.output_file import OutputFile
from pw_protobuf.proto_tree import ProtoServiceMethod
from pw_protobuf.proto_tree import build_node_tree
from pw_rpc import codegen
from pw_rpc.codegen import (
    client_call_type,
    get_id,
    CodeGenerator,
    RPC_NAMESPACE,
)

# Extension shared by the pw_protobuf message header (foo.pwpb.h) and the
# generated RPC header (foo.rpc.pwpb.h).
PWPB_H_EXTENSION = '.pwpb.h'

# Alias kept so that any code importing this name keeps working. It has always
# held the same value as PWPB_H_EXTENSION; prefer PWPB_H_EXTENSION.
PROTO_H_EXTENSION = PWPB_H_EXTENSION


def _proto_filename_to_pwpb_header(proto_file: str) -> str:
    """Returns the generated pwpb header name for a .proto file.

    For example, ``foo/bar.proto`` maps to ``foo/bar.pwpb.h``.
    """
    filename = os.path.splitext(proto_file)[0]
    return f'{filename}{PWPB_H_EXTENSION}'


def _proto_filename_to_generated_header(proto_file: str) -> str:
    """Returns the generated C++ RPC header name for a .proto file.

    For example, ``foo/bar.proto`` maps to ``foo/bar.rpc.pwpb.h``.
    """
    filename = os.path.splitext(proto_file)[0]
    return f'{filename}.rpc{PWPB_H_EXTENSION}'


def _serde(method: ProtoServiceMethod) -> str:
    """Returns the PwpbMethodSerde for this method.

    The result is a C++ expression naming ``kPwpbMethodSerde`` instantiated
    with the addresses of the request and response pwpb message tables.
    """
    return (
        f'{RPC_NAMESPACE}::internal::kPwpbMethodSerde<'
        f'&{method.request_type().pwpb_table()}, '
        f'&{method.response_type().pwpb_table()}>'
    )


def _client_call(
    method: ProtoServiceMethod, response: str | None = None
) -> str:
    """Returns the C++ client call type for a method.

    Args:
      method: The RPC method being generated.
      response: C++ type used for responses. Defaults to the method's
        generated pwpb response struct.

    Returns:
      The type named by ``client_call_type(method, "Pwpb")``, templated on
      ``<Request, Response>`` for client-streaming methods and on
      ``<Response>`` otherwise.
    """
    template_args = []

    # Only client-streaming calls are parameterized on the request type;
    # other call types take the request once, as a function argument.
    if method.client_streaming():
        template_args.append(method.request_type().pwpb_struct())

    if response is None:
        response = method.response_type().pwpb_struct()

    template_args.append(response)

    return f'{client_call_type(method, "Pwpb")}<{", ".join(template_args)}>'


def _function(
    method: ProtoServiceMethod,
    name: str | None = None,
) -> str:
    """Returns ``auto <name>``, the return type and name of a client function.

    ``name`` falls back to the method's name when it is None or empty.
    """
    return f'auto {name or method.name()}'


def _user_args(
    method: ProtoServiceMethod, response: str | None = None
) -> Iterable[str]:
    """Yields the C++ parameter declarations for a client call function.

    Args:
      method: The RPC method being generated.
      response: C++ type used for responses in the callbacks. Defaults to the
        method's generated pwpb response struct.

    Yields:
      A ``request`` parameter (non-client-streaming methods only), followed by
      the ``on_next`` (server-streaming only), ``on_completed`` and
      ``on_error`` callbacks. Each callback defaults to ``nullptr``.
    """
    if not method.client_streaming():
        yield f'const {method.request_type().pwpb_struct()}& request'

    if response is None:
        response = method.response_type().pwpb_struct()

    if method.server_streaming():
        yield f'::pw::Function<void(const {response}&)>&& on_next = nullptr'
        yield '::pw::Function<void(::pw::Status)>&& on_completed = nullptr'
    else:
        # Unary/client-streaming calls deliver the single response together
        # with the final status.
        yield (
            f'::pw::Function<void(const {response}&, ::pw::Status)>&& '
            'on_completed = nullptr'
        )

    yield '::pw::Function<void(::pw::Status)>&& on_error = nullptr'


class PwpbCodeGenerator(CodeGenerator):
    """Generates an RPC service and client using the pw_protobuf API."""

    def name(self) -> str:
        """Returns the short name of this generator."""
        return 'pwpb'

    def method_union_name(self) -> str:
        """Returns the C++ method union type used by pwpb services."""
        return 'PwpbMethodUnion'

    def includes(self, proto_file_name: str) -> Iterable[str]:
        """Yields the #include lines needed by the generated RPC header."""
        yield '#include "pw_rpc/pwpb/client_reader_writer.h"'
        yield '#include "pw_rpc/pwpb/internal/method_union.h"'
        yield '#include "pw_rpc/pwpb/server_reader_writer.h"'

        # Include the corresponding pwpb header file for this proto file, in
        # which the file's messages and enums are generated. All other files
        # imported from the .proto file are #included in there.
        pwpb_header = _proto_filename_to_pwpb_header(proto_file_name)
        yield f'#include "{pwpb_header}"'

    def service_aliases(self) -> None:
        """Emits ServerWriter/ServerReader/ServerReaderWriter type aliases."""
        self.line('template <typename Response>')
        self.line(
            'using ServerWriter = '
            f'{RPC_NAMESPACE}::PwpbServerWriter<Response>;'
        )
        self.line('template <typename Request, typename Response>')
        self.line(
            'using ServerReader = '
            f'{RPC_NAMESPACE}::PwpbServerReader<Request, Response>;'
        )
        self.line('template <typename Request, typename Response>')
        self.line(
            'using ServerReaderWriter = '
            f'{RPC_NAMESPACE}::PwpbServerReaderWriter<Request, Response>;'
        )

    def method_descriptor(self, method: ProtoServiceMethod) -> None:
        """Emits a ``GetPwpbOrRawMethodFor<...>(id, serde),`` entry."""
        impl_method = f'&Implementation::{method.name()}'

        self.line(
            f'{RPC_NAMESPACE}::internal::GetPwpbOrRawMethodFor<{impl_method}, '
            f'{method.type().cc_enum()}, '
            f'{method.request_type().pwpb_struct()}, '
            f'{method.response_type().pwpb_struct()}>('
        )
        with self.indent(4):
            self.line(f'{get_id(method)},  // Hash of "{method.name()}"')
            self.line(f'{_serde(method)}),')

    def _custom_response_template_line(
        self, method: ProtoServiceMethod
    ) -> None:
        """Emits the template header for a custom-response overload.

        The ``Response`` template parameter defaults to the method's generated
        pwpb response struct. Shared by the member and static client
        functions so that both overload families stay in sync.
        """
        # NOTE: there is no space after '='. This is kept as-is so that the
        # generated output is unchanged.
        self.line(
            'template <typename Response ='
            + f'{method.response_type().pwpb_struct()}>'
        )

    def _client_member_function(
        self,
        method: ProtoServiceMethod,
        *,
        response: str | None = None,
        name: str | None = None,
        dynamic: bool,
    ) -> None:
        """Emits one const client member function that starts an RPC.

        Args:
          method: The RPC method being generated.
          response: C++ response type. Defaults to the method's generated
            pwpb response struct.
          name: C++ function name. Defaults to the method's name.
          dynamic: If true, calls ``StartDynamic`` instead of ``Start``.
        """
        if response is None:
            response = method.response_type().pwpb_struct()

        if name is None:
            name = method.name()

        self.line(f'{_function(method, name)}(')
        self.indented_list(*_user_args(method, response), end=') const {')

        with self.indent():
            client_call = _client_call(method, response)
            base = 'Stream' if method.server_streaming() else 'Unary'
            self.line(
                f'return {RPC_NAMESPACE}::internal::'
                f'Pwpb{base}ResponseClientCall<{response}>::'
                f'template Start{"Dynamic" if dynamic else ""}'
                f'<{client_call}>('
            )

            service_client = RPC_NAMESPACE + '::internal::ServiceClient'

            # The argument order must match the parameter order of the
            # generated Start/StartDynamic call.
            args = [
                f'{service_client}::client()',
                f'{service_client}::channel_id()',
                'kServiceId',
                get_id(method),
                _serde(method),
            ]
            if method.server_streaming():
                args.append('std::move(on_next)')

            args.append('std::move(on_completed)')
            args.append('std::move(on_error)')

            if not method.client_streaming():
                args.append('request')

            self.indented_list(*args, end=');')

        self.line('}')

    def client_member_function(
        self, method: ProtoServiceMethod, *, dynamic: bool
    ) -> None:
        """Outputs client code for a single RPC method."""
        self._client_member_function(method, dynamic=dynamic)

        if dynamic:  # Skip custom response overload
            return

        # Generate functions that allow specifying a custom response struct.
        self._custom_response_template_line(method)
        self._client_member_function(
            method,
            response='Response',
            name=method.name() + 'Template',
            dynamic=dynamic,
        )

    def _client_static_function(
        self,
        method: ProtoServiceMethod,
        response: str | None = None,
        name: str | None = None,
    ) -> None:
        """Emits a static function that forwards to a Client member function.

        The generated function builds ``Client(client, channel_id)`` and calls
        its member function ``name`` with the request and callbacks.

        Args:
          method: The RPC method being generated.
          response: C++ response type. Defaults to the method's generated
            pwpb response struct.
          name: C++ function name, also used as the member function that is
            called. Defaults to the method's name.
        """
        if response is None:
            response = method.response_type().pwpb_struct()

        if name is None:
            name = method.name()

        self.line(f'static {_function(method, name)}(')
        self.indented_list(
            f'{RPC_NAMESPACE}::Client& client',
            'uint32_t channel_id',
            *_user_args(method, response),
            end=') {',
        )

        with self.indent():
            self.line(f'return Client(client, channel_id).{name}(')

            args = []

            if not method.client_streaming():
                args.append('request')

            if method.server_streaming():
                args.append('std::move(on_next)')

            self.indented_list(
                *args,
                'std::move(on_completed)',
                'std::move(on_error)',
                end=');',
            )

        self.line('}')

    def client_static_function(self, method: ProtoServiceMethod) -> None:
        """Outputs static client functions for a single RPC method.

        Emits the default overload and a ``<Name>Template`` overload whose
        response type can be customized.
        """
        self._client_static_function(method)

        self._custom_response_template_line(method)
        self._client_static_function(
            method, response='Response', name=method.name() + 'Template'
        )

    def method_info_specialization(self, method: ProtoServiceMethod) -> None:
        """Emits Request/Response aliases and a ``serde()`` accessor."""
        self.line()
        self.line(f'using Request = {method.request_type().pwpb_struct()};')
        self.line(f'using Response = {method.response_type().pwpb_struct()};')
        self.line()
        self.line(
            f'static constexpr const {RPC_NAMESPACE}::'
            'PwpbMethodSerde& serde() {'
        )
        with self.indent():
            self.line(f'return {_serde(method)};')
        self.line('}')


class StubGenerator(codegen.StubGenerator):
    """Generates pw_protobuf RPC stubs."""

    def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str:
        """Returns the C++ signature of a unary server method stub."""
        return (
            f'::pw::Status {prefix}{method.name()}( '
            f'const {method.request_type().pwpb_struct()}& request, '
            f'{method.response_type().pwpb_struct()}& response)'
        )

    def unary_stub(
        self, method: ProtoServiceMethod, output: OutputFile
    ) -> None:
        """Writes a unary stub body that returns Unimplemented.

        ``method`` is not used; the body is the same for every unary method.
        """
        output.write_line(codegen.STUB_REQUEST_TODO)
        output.write_line('static_cast<void>(request);')
        output.write_line(codegen.STUB_RESPONSE_TODO)
        output.write_line('static_cast<void>(response);')
        output.write_line('return ::pw::Status::Unimplemented();')

    def server_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the C++ signature of a server-streaming method stub."""
        return (
            f'void {prefix}{method.name()}( '
            f'const {method.request_type().pwpb_struct()}& request, '
            f'ServerWriter<{method.response_type().pwpb_struct()}>& writer)'
        )

    def client_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the C++ signature of a client-streaming method stub."""
        return (
            f'void {prefix}{method.name()}( '
            f'ServerReader<{method.request_type().pwpb_struct()}, '
            f'{method.response_type().pwpb_struct()}>& reader)'
        )

    def bidirectional_streaming_signature(
        self, method: ProtoServiceMethod, prefix: str
    ) -> str:
        """Returns the C++ signature of a bidirectional-streaming stub."""
        return (
            f'void {prefix}{method.name()}( '
            f'ServerReaderWriter<{method.request_type().pwpb_struct()}, '
            f'{method.response_type().pwpb_struct()}>& reader_writer)'
        )


def process_proto_file(proto_file) -> Iterable[OutputFile]:
    """Generates code for a single .proto file.

    Args:
      proto_file: The parsed .proto file. It is passed to ``build_node_tree``
        and its ``name`` is used to derive the output header name
        (presumably a protobuf FileDescriptorProto; confirm against callers).

    Returns:
      A one-element list containing the generated ``.rpc.pwpb.h`` output
      file.
    """
    _, package_root = build_node_tree(proto_file)
    output_filename = _proto_filename_to_generated_header(proto_file.name)

    generator = PwpbCodeGenerator(output_filename)
    codegen.generate_package(proto_file, package_root, generator)

    codegen.package_stubs(package_root, generator, StubGenerator())

    return [generator.output]