1# Copyright 2020 The Pigweed Authors 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); you may not 4# use this file except in compliance with the License. You may obtain a copy of 5# the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12# License for the specific language governing permissions and limitations under 13# the License. 14"""Provides a pw_rpc client for Python.""" 15 16from __future__ import annotations 17 18import abc 19from dataclasses import dataclass 20import logging 21from typing import ( 22 Any, 23 Callable, 24 Collection, 25 Iterable, 26 Iterator, 27) 28 29from google.protobuf.message import DecodeError, Message 30from pw_status import Status 31 32from pw_rpc import descriptors, packets 33from pw_rpc.descriptors import Channel, Service, Method, PendingRpc 34from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket 35 36_LOG = logging.getLogger(__package__) 37 38# Calls with ID of `kOpenCallId` were unrequested, and are updated to have the 39# call ID of the first matching request. 40LEGACY_OPEN_CALL_ID: int = 0 41OPEN_CALL_ID: int = (2**32) - 1 42 43_MAX_CALL_ID: int = 1 << 21 44 45 46class Error(Exception): 47 """Error from incorrectly using the RPC client classes.""" 48 49 50class _PendingRpcMetadata: 51 def __init__(self, context: object): 52 self.context = context 53 54 55class PendingRpcs: 56 """Tracks pending RPCs and encodes outgoing RPC packets.""" 57 58 def __init__(self) -> None: 59 self._pending: dict[PendingRpc, _PendingRpcMetadata] = {} 60 # We skip call_id = 0 in order to avoid LEGACY_OPEN_CALL_ID. 61 self._next_call_id: int = 1 62 63 def allocate_call_id(self) -> int: 64 call_id = self._next_call_id 65 self._next_call_id = (self._next_call_id + 1) % _MAX_CALL_ID 66 # We skip call_id = 0 in order to avoid LEGACY_OPEN_CALL_ID. 67 if self._next_call_id == 0: 68 self._next_call_id = 1 69 return call_id 70 71 def request( 72 self, rpc: PendingRpc, request: Message | None, context: object 73 ) -> bytes: 74 """Starts the provided RPC and returns the encoded packet to send.""" 75 # Ensure that every context is a unique object by wrapping it in a list. 76 self.open(rpc, context) 77 return packets.encode_request(rpc, request) 78 79 def send_request( 80 self, rpc: PendingRpc, request: Message | None, context: object 81 ) -> None: 82 """Starts the provided RPC and sends the request packet to the channel. 83 84 Returns: 85 the previous context object or None 86 """ 87 self.open(rpc, context) 88 packet = packets.encode_request(rpc, request) 89 rpc.channel.output(packet) 90 91 def open(self, rpc: PendingRpc, context: object) -> None: 92 """Creates a context for an RPC, but does not invoke it. 93 94 open() can be used to receive streaming responses to an RPC that was not 95 invoked by this client. For example, a server may stream logs with a 96 server streaming RPC prior to any clients invoking it. 97 98 Returns: 99 the previous context object or None 100 """ 101 _LOG.debug('Starting %s', rpc) 102 metadata = _PendingRpcMetadata(context) 103 104 if self._pending.setdefault(rpc, metadata) is not metadata: 105 # If the context was not added, the RPC was already pending. 106 raise Error( 107 f'Sent request for {rpc}, but it is already pending! ' 108 'Cancel the RPC before invoking it again' 109 ) 110 111 def send_client_stream(self, rpc: PendingRpc, message: Message) -> None: 112 if rpc not in self._pending: 113 raise Error(f'Attempt to send client stream for inactive RPC {rpc}') 114 115 rpc.channel.output(packets.encode_client_stream(rpc, message)) 116 117 def send_client_stream_end(self, rpc: PendingRpc) -> None: 118 if rpc not in self._pending: 119 raise Error( 120 f'Attempt to send client stream end for inactive RPC {rpc}' 121 ) 122 123 rpc.channel.output(packets.encode_client_stream_end(rpc)) 124 125 def cancel(self, rpc: PendingRpc) -> bytes: 126 """Cancels the RPC. 127 128 Returns: 129 The CLIENT_ERROR packet to send. 130 131 Raises: 132 KeyError if the RPC is not pending 133 """ 134 _LOG.debug('Cancelling %s', rpc) 135 del self._pending[rpc] 136 137 return packets.encode_cancel(rpc) 138 139 def send_cancel(self, rpc: PendingRpc) -> bool: 140 """Calls cancel and sends the cancel packet, if any, to the channel.""" 141 try: 142 packet = self.cancel(rpc) 143 except KeyError: 144 return False 145 146 if packet: 147 rpc.channel.output(packet) 148 149 return True 150 151 def _match_unrequested_rpcs( 152 self, rpc: PendingRpc, completed: bool 153 ) -> _PendingRpcMetadata | None: 154 # If the inbound packet is unrequested, route to any matching call. 155 # If both the client and server calls use the open ID, they would have 156 # matched in the initial lookup before this function is called. 157 if rpc.call_id in (OPEN_CALL_ID, LEGACY_OPEN_CALL_ID): 158 for pending, context in self._pending.items(): 159 if rpc.matches_channel_service_method(pending): 160 if completed: 161 del self._pending[pending] 162 163 return context 164 165 # Otherwise, look for an existing open call that matches. If one is 166 # found, the unrequested call adopts the inbound call's ID. 167 for pending in self._pending: 168 if ( 169 pending.call_id == OPEN_CALL_ID 170 and rpc.matches_channel_service_method(pending) 171 ): 172 # Change the call ID in the PendingRpc object. The PendingRpc 173 # MUST be removed from the self._pending dict first since it is 174 # hashable. 175 # 176 # TODO: https://pwbug.dev/359401616 - Changing a hashable object 177 # is not good, but the ClientImpl abstraction boundary makes 178 # updating the call's PendingRpc instance difficult. This code 179 # should be updated after the client is refactored. 180 context = self._pending.pop(pending) 181 object.__setattr__(pending, 'call_id', rpc.call_id) 182 if not completed: 183 self._pending[pending] = context 184 return context 185 186 return None 187 188 def get_pending( 189 self, rpc: PendingRpc, completed: bool 190 ) -> _PendingRpcMetadata | None: 191 """Gets the pending RPC's context. If status is set, clears the RPC.""" 192 # Look up the RPC. If there is no match, check for unrequested RPCs. 193 if (meta := self._pending.get(rpc)) is None: 194 meta = self._match_unrequested_rpcs(rpc, completed) 195 elif completed: 196 del self._pending[rpc] 197 198 return meta 199 200 201class ClientImpl(abc.ABC): 202 """The internal interface of the RPC client. 203 204 This interface defines the semantics for invoking an RPC on a particular 205 client. 206 """ 207 208 def __init__(self) -> None: 209 self.client: Client | None = None 210 self.rpcs: PendingRpcs | None = None 211 212 @abc.abstractmethod 213 def method_client(self, channel: Channel, method: Method) -> Any: 214 """Returns an object that invokes a method using the given channel.""" 215 216 @abc.abstractmethod 217 def handle_response( 218 self, 219 rpc: PendingRpc, 220 context: Any, 221 payload: Any, 222 ) -> Any: 223 """Handles a response from the RPC server. 224 225 Args: 226 rpc: Information about the pending RPC 227 context: Arbitrary context object associated with the pending RPC 228 payload: A protobuf message 229 """ 230 231 @abc.abstractmethod 232 def handle_completion( 233 self, 234 rpc: PendingRpc, 235 context: Any, 236 status: Status, 237 ) -> Any: 238 """Handles the successful completion of an RPC. 239 240 Args: 241 rpc: Information about the pending RPC 242 context: Arbitrary context object associated with the pending RPC 243 status: Status returned from the RPC 244 """ 245 246 @abc.abstractmethod 247 def handle_error( 248 self, 249 rpc: PendingRpc, 250 context, 251 status: Status, 252 ): 253 """Handles the abnormal termination of an RPC. 254 255 args: 256 rpc: Information about the pending RPC 257 context: Arbitrary context object associated with the pending RPC 258 status: which error occurred 259 """ 260 261 262class ServiceClient(descriptors.ServiceAccessor): 263 """Navigates the methods in a service provided by a ChannelClient.""" 264 265 def __init__( 266 self, client_impl: ClientImpl, channel: Channel, service: Service 267 ): 268 super().__init__( 269 { 270 method: client_impl.method_client(channel, method) 271 for method in service.methods 272 }, 273 as_attrs='members', 274 ) 275 276 self._channel = channel 277 self._service = service 278 279 def __repr__(self) -> str: 280 return ( 281 f'Service({self._service.full_name!r}, ' 282 f'methods={[m.name for m in self._service.methods]}, ' 283 f'channel={self._channel.id})' 284 ) 285 286 def __str__(self) -> str: 287 return str(self._service) 288 289 290class Services(descriptors.ServiceAccessor[ServiceClient]): 291 """Navigates the services provided by a ChannelClient.""" 292 293 def __init__( 294 self, client_impl, channel: Channel, services: Collection[Service] 295 ): 296 super().__init__( 297 {s: ServiceClient(client_impl, channel, s) for s in services}, 298 as_attrs='packages', 299 ) 300 301 self._channel = channel 302 self._services = services 303 304 def __repr__(self) -> str: 305 return ( 306 f'Services(channel={self._channel.id}, ' 307 f'services={[s.full_name for s in self._services]})' 308 ) 309 310 311def _decode_status(rpc: PendingRpc, packet) -> Status | None: 312 if packet.type == PacketType.SERVER_STREAM: 313 return None 314 315 try: 316 return Status(packet.status) 317 except ValueError: 318 _LOG.warning('Illegal status code %d for %s', packet.status, rpc) 319 return Status.UNKNOWN 320 321 322def _decode_payload(rpc: PendingRpc, packet) -> Message | None: 323 if packet.type == PacketType.SERVER_ERROR: 324 return None 325 326 # Server streaming RPCs do not send a payload with their RESPONSE packet. 327 if packet.type == PacketType.RESPONSE and rpc.method.server_streaming: 328 return None 329 330 return packets.decode_payload(packet, rpc.method.response_type) 331 332 333@dataclass(frozen=True, eq=False) 334class ChannelClient: 335 """RPC services and methods bound to a particular channel. 336 337 RPCs are invoked through service method clients. These may be accessed via 338 the `rpcs` member. Service methods use a fully qualified name: package, 339 service, method. Service methods may be selected as attributes or by 340 indexing the rpcs member by service and method name or ID. 341 342 # Access the service method client as an attribute 343 rpc = client.channel(1).rpcs.the.package.FooService.SomeMethod 344 345 # Access the service method client by string name 346 rpc = client.channel(1).rpcs[foo_service_id]['SomeMethod'] 347 348 RPCs may also be accessed from their canonical name. 349 350 # Access the service method client from its full name: 351 rpc = client.channel(1).method('the.package.FooService/SomeMethod') 352 353 # Using a . instead of a / is also supported: 354 rpc = client.channel(1).method('the.package.FooService.SomeMethod') 355 356 The ClientImpl class determines the type of the service method client. A 357 synchronous RPC client might return a callable object, so an RPC could be 358 invoked directly (e.g. rpc(field1=123, field2=b'456')). 359 """ 360 361 client: Client 362 channel: Channel 363 rpcs: Services 364 365 def method(self, method_name: str): 366 """Returns a method client matching the given name. 367 368 Args: 369 method_name: name as package.Service/Method or package.Service.Method. 370 371 Raises: 372 ValueError: the method name is not properly formatted 373 KeyError: the method is not present 374 """ 375 return descriptors.get_method(self.rpcs, method_name) 376 377 def services(self) -> Iterator: 378 return iter(self.rpcs) 379 380 def methods(self) -> Iterator: 381 """Iterates over all method clients in this ChannelClient.""" 382 for service_client in self.rpcs: 383 yield from service_client 384 385 def __repr__(self) -> str: 386 return ( 387 f'ChannelClient(channel={self.channel.id}, ' 388 f'services={[str(s) for s in self.services()]})' 389 ) 390 391 392def _update_for_backwards_compatibility( 393 rpc: PendingRpc, packet: RpcPacket 394) -> None: 395 """Adapts server streaming RPC packets to the updated protocol if needed.""" 396 # The protocol changes only affect server streaming RPCs. 397 if rpc.method.type is not Method.Type.SERVER_STREAMING: 398 return 399 400 # Prior to the introduction of SERVER_STREAM packets, RESPONSE packets with 401 # a payload were used instead. If a non-zero payload is present, assume this 402 # RESPONSE is equivalent to a SERVER_STREAM packet. 403 # 404 # Note that the payload field is not 'optional', so an empty payload is 405 # equivalent to a payload that happens to encode to zero bytes. This would 406 # only affect server streaming RPCs on the old protocol that intentionally 407 # send empty payloads, which will not be an issue in practice. 408 if packet.type == PacketType.RESPONSE and packet.payload: 409 packet.type = PacketType.SERVER_STREAM 410 411 412class Client: 413 """Sends requests and handles responses for a set of channels. 414 415 RPC invocations occur through a ChannelClient. 416 417 Users may set an optional response_callback that is called before processing 418 every response or server stream RPC packet. 419 """ 420 421 @classmethod 422 def from_modules( 423 cls, impl: ClientImpl, channels: Iterable[Channel], modules: Iterable 424 ): 425 return cls( 426 impl, 427 channels, 428 ( 429 Service.from_descriptor(service) 430 for module in modules 431 for service in module.DESCRIPTOR.services_by_name.values() 432 ), 433 ) 434 435 def __init__( 436 self, 437 impl: ClientImpl, 438 channels: Iterable[Channel], 439 services: Iterable[Service], 440 ): 441 self._impl = impl 442 self._impl.client = self 443 self._impl.rpcs = PendingRpcs() 444 445 self.services = descriptors.Services(services) 446 447 self._channels_by_id = { 448 channel.id: ChannelClient( 449 self, channel, Services(self._impl, channel, self.services) 450 ) 451 for channel in channels 452 } 453 454 # Optional function called before processing every non-error RPC packet. 455 self.response_callback: ( 456 Callable[[PendingRpc, Any, Status | None], Any] | None 457 ) = None 458 459 def channel(self, channel_id: int | None = None) -> ChannelClient: 460 """Returns a ChannelClient, which is used to call RPCs on a channel. 461 462 If no channel is provided, the first channel is used. 463 """ 464 if channel_id is None: 465 return next(iter(self._channels_by_id.values())) 466 467 return self._channels_by_id[channel_id] 468 469 def channels(self) -> Iterable[ChannelClient]: 470 """Accesses the ChannelClients in this client.""" 471 return self._channels_by_id.values() 472 473 def method(self, method_name: str) -> Method: 474 """Returns a Method matching the given name. 475 476 Args: 477 method_name: name as package.Service/Method or package.Service.Method. 478 479 Raises: 480 ValueError: the method name is not properly formatted 481 KeyError: the method is not present 482 """ 483 return descriptors.get_method(self.services, method_name) 484 485 def methods(self) -> Iterator[Method]: 486 """Iterates over all Methods supported by this client.""" 487 for service in self.services: 488 yield from service.methods 489 490 def process_packet(self, pw_rpc_raw_packet_data: bytes) -> Status: 491 """Processes an incoming packet. 492 493 Args: 494 pw_rpc_raw_packet_data: raw binary data for exactly one RPC packet 495 496 Returns: 497 OK - the packet was processed by this client 498 DATA_LOSS - the packet could not be decoded 499 INVALID_ARGUMENT - the packet is for a server, not a client 500 NOT_FOUND - the packet's channel ID is not known to this client 501 """ 502 try: 503 packet = packets.decode(pw_rpc_raw_packet_data) 504 except DecodeError as err: 505 _LOG.warning('Failed to decode packet: %s', err) 506 _LOG.debug('Raw packet: %r', pw_rpc_raw_packet_data) 507 return Status.DATA_LOSS 508 509 if packets.for_server(packet): 510 return Status.INVALID_ARGUMENT 511 512 try: 513 channel_client = self._channels_by_id[packet.channel_id] 514 except KeyError: 515 _LOG.warning('Unrecognized channel ID %d', packet.channel_id) 516 return Status.NOT_FOUND 517 518 try: 519 rpc = self._look_up_service_and_method(packet, channel_client) 520 except ValueError as err: 521 _send_client_error(channel_client, packet, Status.NOT_FOUND) 522 _LOG.warning('%s', err) 523 return Status.OK 524 525 _update_for_backwards_compatibility(rpc, packet) 526 527 if packet.type not in ( 528 PacketType.RESPONSE, 529 PacketType.SERVER_STREAM, 530 PacketType.SERVER_ERROR, 531 ): 532 _LOG.error('%s: unexpected PacketType %s', rpc, packet.type) 533 _LOG.debug('Packet:\n%s', packet) 534 return Status.OK 535 536 status = _decode_status(rpc, packet) 537 538 try: 539 payload = _decode_payload(rpc, packet) 540 except DecodeError as err: 541 _send_client_error(channel_client, packet, Status.DATA_LOSS) 542 _LOG.warning( 543 'Failed to decode %s response for %s: %s', 544 rpc.method.response_type.DESCRIPTOR.full_name, 545 rpc.method.full_name, 546 err, 547 ) 548 _LOG.debug('Raw payload: %s', packet.payload) 549 550 # Make this an error packet so the error handler is called. 551 packet.type = PacketType.SERVER_ERROR 552 status = Status.DATA_LOSS 553 554 # If set, call the response callback with non-error packets. 555 if self.response_callback and packet.type != PacketType.SERVER_ERROR: 556 self.response_callback( # pylint: disable=not-callable 557 rpc, payload, status 558 ) 559 560 assert self._impl.rpcs 561 meta = self._impl.rpcs.get_pending(rpc, status is not None) 562 563 if meta is None: 564 _send_client_error( 565 channel_client, packet, Status.FAILED_PRECONDITION 566 ) 567 _LOG.debug('Discarding response for %s, which is not pending', rpc) 568 return Status.OK 569 570 if packet.type == PacketType.SERVER_ERROR: 571 assert status is not None and not status.ok() 572 _LOG.warning('%s: invocation failed with %s', rpc, status) 573 self._impl.handle_error(rpc, meta.context, status) 574 return Status.OK 575 576 if payload is not None: 577 self._impl.handle_response(rpc, meta.context, payload) 578 if status is not None: 579 self._impl.handle_completion(rpc, meta.context, status) 580 581 return Status.OK 582 583 def _look_up_service_and_method( 584 self, packet: RpcPacket, channel_client: ChannelClient 585 ) -> PendingRpc: 586 # Protobuf is sometimes silly so the 32 bit python bindings return 587 # signed values from `fixed32` fields. Let's convert back to unsigned. 588 # b/239712573 589 service_id = packet.service_id & 0xFFFFFFFF 590 try: 591 service = self.services[service_id] 592 except KeyError: 593 raise ValueError(f'Unrecognized service ID {service_id}') 594 595 # See above, also for b/239712573 596 method_id = packet.method_id & 0xFFFFFFFF 597 try: 598 method = service.methods[method_id] 599 except KeyError: 600 raise ValueError( 601 f'No method ID {method_id} in service {service.name}' 602 ) 603 604 return PendingRpc( 605 channel_client.channel, service, method, packet.call_id 606 ) 607 608 def __repr__(self) -> str: 609 return ( 610 f'pw_rpc.Client(channels={list(self._channels_by_id)}, ' 611 f'services={[s.full_name for s in self.services]})' 612 ) 613 614 615def _send_client_error( 616 client: ChannelClient, packet: RpcPacket, error: Status 617) -> None: 618 # Never send responses to SERVER_ERRORs. 619 if packet.type != PacketType.SERVER_ERROR: 620 client.channel.output(packets.encode_client_error(packet, error)) 621