xref: /aosp_15_r20/external/pigweed/pw_rpc/py/pw_rpc/client.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Provides a pw_rpc client for Python."""
15
16from __future__ import annotations
17
18import abc
19from dataclasses import dataclass
20import logging
21from typing import (
22    Any,
23    Callable,
24    Collection,
25    Iterable,
26    Iterator,
27)
28
29from google.protobuf.message import DecodeError, Message
30from pw_status import Status
31
32from pw_rpc import descriptors, packets
33from pw_rpc.descriptors import Channel, Service, Method, PendingRpc
34from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
35
_LOG = logging.getLogger(__package__)

# Call IDs that mark a call as unrequested ("open"). An inbound packet using one
# of these IDs is routed to the first pending call with the same channel,
# service, and method. A pending call that was opened with OPEN_CALL_ID adopts
# the call ID of the first matching inbound packet.
LEGACY_OPEN_CALL_ID: int = 0
OPEN_CALL_ID: int = 0xFFFFFFFF  # (2**32) - 1, the largest unsigned 32-bit value

# Client-allocated call IDs wrap around at this bound.
_MAX_CALL_ID: int = 2**21
44
45
class Error(Exception):
    """Raised when the RPC client classes are used incorrectly."""
48
49
class _PendingRpcMetadata:
    """Per-RPC bookkeeping stored by PendingRpcs.

    A new instance is created for every open() call, so PendingRpcs can detect
    an already-pending RPC by object identity even if the caller reuses the
    same context object.
    """

    def __init__(self, context: object):
        # Arbitrary caller-supplied object associated with the pending RPC.
        self.context = context
53
54
class PendingRpcs:
    """Tracks pending RPCs and encodes outgoing RPC packets.

    Each pending RPC is keyed by its PendingRpc and maps to a
    _PendingRpcMetadata that wraps the caller's context object. A given RPC may
    be pending only once; it must be cancelled or completed before it can be
    opened again.
    """

    def __init__(self) -> None:
        self._pending: dict[PendingRpc, _PendingRpcMetadata] = {}
        # We skip call_id = 0 in order to avoid LEGACY_OPEN_CALL_ID.
        self._next_call_id: int = 1

    def allocate_call_id(self) -> int:
        """Returns the next client-assigned call ID.

        IDs count up from 1 and wrap modulo _MAX_CALL_ID. Zero is skipped, so
        an allocated ID never equals LEGACY_OPEN_CALL_ID. This method does not
        check whether a wrapped-around ID is still used by a pending RPC.
        """
        call_id = self._next_call_id
        self._next_call_id = (self._next_call_id + 1) % _MAX_CALL_ID
        # We skip call_id = 0 in order to avoid LEGACY_OPEN_CALL_ID.
        if self._next_call_id == 0:
            self._next_call_id = 1
        return call_id

    def request(
        self, rpc: PendingRpc, request: Message | None, context: object
    ) -> bytes:
        """Starts the provided RPC and returns the encoded packet to send.

        Raises:
          Error: the RPC is already pending
        """
        self.open(rpc, context)
        return packets.encode_request(rpc, request)

    def send_request(
        self, rpc: PendingRpc, request: Message | None, context: object
    ) -> None:
        """Starts the provided RPC and sends the request packet to the channel.

        Raises:
          Error: the RPC is already pending
        """
        rpc.channel.output(self.request(rpc, request, context))

    def open(self, rpc: PendingRpc, context: object) -> None:
        """Creates a context for an RPC, but does not invoke it.

        open() can be used to receive streaming responses to an RPC that was not
        invoked by this client. For example, a server may stream logs with a
        server streaming RPC prior to any clients invoking it.

        Raises:
          Error: the RPC is already pending
        """
        _LOG.debug('Starting %s', rpc)
        metadata = _PendingRpcMetadata(context)

        # setdefault() inserts only when the RPC is absent. Getting back a
        # different object means an entry already existed. Creating a fresh
        # metadata object per call keeps this identity check reliable even if
        # the caller reuses the same context object.
        if self._pending.setdefault(rpc, metadata) is not metadata:
            raise Error(
                f'Sent request for {rpc}, but it is already pending! '
                'Cancel the RPC before invoking it again'
            )

    def send_client_stream(self, rpc: PendingRpc, message: Message) -> None:
        """Sends a client stream message for a pending RPC.

        Raises:
          Error: the RPC is not pending
        """
        if rpc not in self._pending:
            raise Error(f'Attempt to send client stream for inactive RPC {rpc}')

        rpc.channel.output(packets.encode_client_stream(rpc, message))

    def send_client_stream_end(self, rpc: PendingRpc) -> None:
        """Signals the end of the client stream for a pending RPC.

        Raises:
          Error: the RPC is not pending
        """
        if rpc not in self._pending:
            raise Error(
                f'Attempt to send client stream end for inactive RPC {rpc}'
            )

        rpc.channel.output(packets.encode_client_stream_end(rpc))

    def cancel(self, rpc: PendingRpc) -> bytes:
        """Cancels the RPC.

        Returns:
          The CLIENT_ERROR packet to send.

        Raises:
          KeyError if the RPC is not pending
        """
        _LOG.debug('Cancelling %s', rpc)
        del self._pending[rpc]

        return packets.encode_cancel(rpc)

    def send_cancel(self, rpc: PendingRpc) -> bool:
        """Calls cancel and sends the cancel packet, if any, to the channel.

        Returns:
          True if the RPC was pending and was cancelled, False otherwise
        """
        try:
            packet = self.cancel(rpc)
        except KeyError:
            return False

        if packet:
            rpc.channel.output(packet)

        return True

    def _match_unrequested_rpcs(
        self, rpc: PendingRpc, completed: bool
    ) -> _PendingRpcMetadata | None:
        """Finds a pending RPC for an inbound packet that had no exact match.

        Two cases are handled:
          * The inbound packet uses an open call ID (OPEN_CALL_ID or
            LEGACY_OPEN_CALL_ID); it is routed to the first pending RPC on the
            same channel, service, and method.
          * A pending RPC was opened with OPEN_CALL_ID; it adopts the inbound
            packet's call ID.

        If completed is True, the matched RPC is removed from the pending set.

        Returns:
          the matched RPC's metadata, or None if nothing matched
        """
        # If the inbound packet is unrequested, route to any matching call.
        # If both the client and server calls use the open ID, they would have
        # matched in the initial lookup before this function is called.
        if rpc.call_id in (OPEN_CALL_ID, LEGACY_OPEN_CALL_ID):
            for pending, meta in self._pending.items():
                if rpc.matches_channel_service_method(pending):
                    # Safe to mutate the dict here: we return immediately
                    # without continuing the iteration.
                    if completed:
                        del self._pending[pending]

                    return meta

        # Otherwise, look for an existing open call that matches. If one is
        # found, the unrequested call adopts the inbound call's ID.
        for pending in self._pending:
            if (
                pending.call_id == OPEN_CALL_ID
                and rpc.matches_channel_service_method(pending)
            ):
                # Change the call ID in the PendingRpc object. The PendingRpc
                # MUST be removed from the self._pending dict first since it is
                # hashable.
                #
                # TODO: https://pwbug.dev/359401616 - Changing a hashable object
                # is not good, but the ClientImpl abstraction boundary makes
                # updating the call's PendingRpc instance difficult. This code
                # should be updated after the client is refactored.
                meta = self._pending.pop(pending)
                object.__setattr__(pending, 'call_id', rpc.call_id)
                if not completed:
                    self._pending[pending] = meta
                return meta

        return None

    def get_pending(
        self, rpc: PendingRpc, completed: bool
    ) -> _PendingRpcMetadata | None:
        """Gets the pending RPC's metadata.

        An exact match is tried first; otherwise unrequested (open call ID)
        RPCs are considered. If completed is True, the matched RPC is removed
        from the pending set.

        Returns:
          the RPC's metadata, or None if it is not pending
        """
        # Look up the RPC. If there is no match, check for unrequested RPCs.
        if (meta := self._pending.get(rpc)) is None:
            meta = self._match_unrequested_rpcs(rpc, completed)
        elif completed:
            del self._pending[rpc]

        return meta
199
200
class ClientImpl(abc.ABC):
    """The internal interface of the RPC client.

    This interface defines the semantics for invoking an RPC on a particular
    client. The owning Client assigns the client and rpcs attributes when it
    is constructed.
    """

    def __init__(self) -> None:
        # Both are None until a Client takes ownership of this implementation.
        self.client: Client | None = None
        self.rpcs: PendingRpcs | None = None

    @abc.abstractmethod
    def method_client(self, channel: Channel, method: Method) -> Any:
        """Returns an object that invokes a method using the given channel."""

    @abc.abstractmethod
    def handle_response(
        self,
        rpc: PendingRpc,
        context: Any,
        payload: Any,
    ) -> Any:
        """Handles a response from the RPC server.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          payload: A protobuf message
        """

    @abc.abstractmethod
    def handle_completion(
        self,
        rpc: PendingRpc,
        context: Any,
        status: Status,
    ) -> Any:
        """Handles the normal completion of an RPC.

        Called when the server ends the RPC without a SERVER_ERROR packet. The
        status reported by the server may still be an error.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          status: Status returned from the RPC
        """

    @abc.abstractmethod
    def handle_error(
        self,
        rpc: PendingRpc,
        context: Any,
        status: Status,
    ) -> Any:
        """Handles the abnormal termination of an RPC.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          status: which error occurred
        """
260
261
class ServiceClient(descriptors.ServiceAccessor):
    """Method clients for a single service, all bound to one channel.

    Each method client is produced by client_impl.method_client() and can be
    reached as an attribute (as_attrs='members') or by indexing.
    """

    def __init__(
        self, client_impl: ClientImpl, channel: Channel, service: Service
    ):
        method_clients = {}
        for service_method in service.methods:
            method_clients[service_method] = client_impl.method_client(
                channel, service_method
            )
        super().__init__(method_clients, as_attrs='members')

        self._channel = channel
        self._service = service

    def __repr__(self) -> str:
        full_name = self._service.full_name
        names = [m.name for m in self._service.methods]
        channel_id = self._channel.id
        return f'Service({full_name!r}, methods={names}, channel={channel_id})'

    def __str__(self) -> str:
        # Printing a service client shows the underlying service description.
        return str(self._service)
288
289
class Services(descriptors.ServiceAccessor[ServiceClient]):
    """Navigates the services provided by a ChannelClient.

    Holds one ServiceClient per service, all bound to the same channel. The
    services are exposed as attributes by package path (as_attrs='packages'),
    e.g. rpcs.the.package.FooService.
    """

    def __init__(
        self,
        client_impl: ClientImpl,
        channel: Channel,
        services: Collection[Service],
    ):
        super().__init__(
            {s: ServiceClient(client_impl, channel, s) for s in services},
            as_attrs='packages',
        )

        self._channel = channel
        # Iterated again by __repr__, so services must be a re-iterable
        # collection rather than a one-shot iterator.
        self._services = services

    def __repr__(self) -> str:
        return (
            f'Services(channel={self._channel.id}, '
            f'services={[s.full_name for s in self._services]})'
        )
309
310
def _decode_status(rpc: PendingRpc, packet) -> Status | None:
    """Extracts the status carried by an inbound packet.

    SERVER_STREAM packets carry no status, so None is returned for them. A
    status code that Status() rejects is logged and reported as
    Status.UNKNOWN.
    """
    if packet.type != PacketType.SERVER_STREAM:
        try:
            return Status(packet.status)
        except ValueError:
            _LOG.warning('Illegal status code %d for %s', packet.status, rpc)
            return Status.UNKNOWN
    return None
320
321
def _decode_payload(rpc: PendingRpc, packet) -> Message | None:
    """Decodes the response payload of an inbound packet, if it has one.

    Returns None for SERVER_ERROR packets and for the RESPONSE packet that
    ends a server streaming RPC, since neither carries a payload. Otherwise
    the payload is decoded as the method's response type; decoding errors
    propagate to the caller (process_packet handles DecodeError).
    """
    has_no_payload = packet.type == PacketType.SERVER_ERROR or (
        packet.type == PacketType.RESPONSE and rpc.method.server_streaming
    )
    if has_no_payload:
        return None
    return packets.decode_payload(packet, rpc.method.response_type)
331
332
@dataclass(frozen=True, eq=False)
class ChannelClient:
    """RPC services and methods bound to a particular channel.

    RPCs are invoked through service method clients. These may be accessed via
    the `rpcs` member. Service methods use a fully qualified name: package,
    service, method. Service methods may be selected as attributes or by
    indexing the rpcs member by service and method name or ID.

      # Access the service method client as an attribute
      rpc = client.channel(1).rpcs.the.package.FooService.SomeMethod

      # Access the service method client by service ID and method name
      rpc = client.channel(1).rpcs[foo_service_id]['SomeMethod']

    RPCs may also be accessed from their canonical name.

      # Access the service method client from its full name:
      rpc = client.channel(1).method('the.package.FooService/SomeMethod')

      # Using a . instead of a / is also supported:
      rpc = client.channel(1).method('the.package.FooService.SomeMethod')

    The ClientImpl class determines the type of the service method client. A
    synchronous RPC client might return a callable object, so an RPC could be
    invoked directly (e.g. rpc(field1=123, field2=b'456')).
    """

    client: Client
    channel: Channel
    rpcs: Services

    def method(self, method_name: str) -> Any:
        """Returns a method client matching the given name.

        Args:
          method_name: name as package.Service/Method or package.Service.Method.

        Raises:
          ValueError: the method name is not properly formatted
          KeyError: the method is not present
        """
        return descriptors.get_method(self.rpcs, method_name)

    def services(self) -> Iterator[ServiceClient]:
        """Iterates over the service clients in this ChannelClient."""
        return iter(self.rpcs)

    def methods(self) -> Iterator[Any]:
        """Iterates over all method clients in this ChannelClient."""
        for service_client in self.rpcs:
            yield from service_client

    def __repr__(self) -> str:
        return (
            f'ChannelClient(channel={self.channel.id}, '
            f'services={[str(s) for s in self.services()]})'
        )
390
391
def _update_for_backwards_compatibility(
    rpc: PendingRpc, packet: RpcPacket
) -> None:
    """Rewrites a legacy server streaming RESPONSE into a SERVER_STREAM packet.

    Only SERVER_STREAMING methods are affected. Before SERVER_STREAM packets
    were introduced, stream messages were sent as RESPONSE packets carrying a
    payload, so a RESPONSE with a non-empty payload is retyped in place.

    The payload field is not 'optional', so an empty payload cannot be told
    apart from a message that happens to encode to zero bytes. This only
    matters for old-protocol server streaming RPCs that deliberately send
    empty messages, which is not an issue in practice.
    """
    is_legacy_stream_message = (
        rpc.method.type is Method.Type.SERVER_STREAMING
        and packet.type == PacketType.RESPONSE
        and packet.payload
    )
    if is_legacy_stream_message:
        packet.type = PacketType.SERVER_STREAM
410
411
class Client:
    """Sends requests and handles responses for a set of channels.

    RPC invocations occur through a ChannelClient.

    Users may set an optional response_callback that is called before processing
    every response or server stream RPC packet.
    """

    @classmethod
    def from_modules(
        cls, impl: ClientImpl, channels: Iterable[Channel], modules: Iterable
    ):
        """Creates a Client supporting every service declared in the modules.

        Args:
          impl: the ClientImpl that determines how RPCs are invoked
          channels: the channels this client communicates over
          modules: modules whose DESCRIPTOR.services_by_name lists services
            (presumably generated protobuf modules)
        """
        return cls(
            impl,
            channels,
            (
                Service.from_descriptor(service)
                for module in modules
                for service in module.DESCRIPTOR.services_by_name.values()
            ),
        )

    def __init__(
        self,
        impl: ClientImpl,
        channels: Iterable[Channel],
        services: Iterable[Service],
    ):
        """Creates a client.

        Args:
          impl: the ClientImpl; its client and rpcs attributes are set here
          channels: the channels to create ChannelClients for
          services: the services this client supports (from_modules passes a
            one-shot generator)
        """
        self._impl = impl
        self._impl.client = self
        self._impl.rpcs = PendingRpcs()

        self.services = descriptors.Services(services)

        self._channels_by_id = {
            channel.id: ChannelClient(
                self, channel, Services(self._impl, channel, self.services)
            )
            for channel in channels
        }

        # Optional function called before processing every non-error RPC packet.
        self.response_callback: (
            Callable[[PendingRpc, Any, Status | None], Any] | None
        ) = None

    def channel(self, channel_id: int | None = None) -> ChannelClient:
        """Returns a ChannelClient, which is used to call RPCs on a channel.

        If no channel is provided, the first channel is used.

        Raises:
          KeyError: no channel has the given ID
          StopIteration: channel_id is None and this client has no channels
        """
        if channel_id is None:
            return next(iter(self._channels_by_id.values()))

        return self._channels_by_id[channel_id]

    def channels(self) -> Iterable[ChannelClient]:
        """Accesses the ChannelClients in this client."""
        return self._channels_by_id.values()

    def method(self, method_name: str) -> Method:
        """Returns a Method matching the given name.

        Args:
          method_name: name as package.Service/Method or package.Service.Method.

        Raises:
          ValueError: the method name is not properly formatted
          KeyError: the method is not present
        """
        return descriptors.get_method(self.services, method_name)

    def methods(self) -> Iterator[Method]:
        """Iterates over all Methods supported by this client."""
        for service in self.services:
            yield from service.methods

    def process_packet(self, pw_rpc_raw_packet_data: bytes) -> Status:
        """Processes an incoming packet.

        Args:
          pw_rpc_raw_packet_data: raw binary data for exactly one RPC packet

        Returns:
          OK - the packet was processed by this client
          DATA_LOSS - the packet could not be decoded
          INVALID_ARGUMENT - the packet is for a server, not a client
          NOT_FOUND - the packet's channel ID is not known to this client
        """
        try:
            packet = packets.decode(pw_rpc_raw_packet_data)
        except DecodeError as err:
            _LOG.warning('Failed to decode packet: %s', err)
            _LOG.debug('Raw packet: %r', pw_rpc_raw_packet_data)
            return Status.DATA_LOSS

        if packets.for_server(packet):
            return Status.INVALID_ARGUMENT

        try:
            channel_client = self._channels_by_id[packet.channel_id]
        except KeyError:
            _LOG.warning('Unrecognized channel ID %d', packet.channel_id)
            return Status.NOT_FOUND

        try:
            rpc = self._look_up_service_and_method(packet, channel_client)
        except ValueError as err:
            # Tell the server the service/method is unknown. The packet was
            # still handled by this client, so OK is returned.
            _send_client_error(channel_client, packet, Status.NOT_FOUND)
            _LOG.warning('%s', err)
            return Status.OK

        _update_for_backwards_compatibility(rpc, packet)

        if packet.type not in (
            PacketType.RESPONSE,
            PacketType.SERVER_STREAM,
            PacketType.SERVER_ERROR,
        ):
            _LOG.error('%s: unexpected PacketType %s', rpc, packet.type)
            _LOG.debug('Packet:\n%s', packet)
            return Status.OK

        # None for SERVER_STREAM packets, which do not end the RPC.
        status = _decode_status(rpc, packet)

        try:
            payload = _decode_payload(rpc, packet)
        except DecodeError as err:
            _send_client_error(channel_client, packet, Status.DATA_LOSS)
            _LOG.warning(
                'Failed to decode %s response for %s: %s',
                rpc.method.response_type.DESCRIPTOR.full_name,
                rpc.method.full_name,
                err,
            )
            _LOG.debug('Raw payload: %s', packet.payload)

            # Make this an error packet so the error handler is called. Keep
            # payload bound so later code never sees an undefined name.
            payload = None
            packet.type = PacketType.SERVER_ERROR
            status = Status.DATA_LOSS

        # If set, call the response callback with non-error packets.
        if self.response_callback and packet.type != PacketType.SERVER_ERROR:
            self.response_callback(  # pylint: disable=not-callable
                rpc, payload, status
            )

        assert self._impl.rpcs
        # A packet that carries a status ends the RPC, so it is removed.
        meta = self._impl.rpcs.get_pending(rpc, status is not None)

        if meta is None:
            _send_client_error(
                channel_client, packet, Status.FAILED_PRECONDITION
            )
            _LOG.debug('Discarding response for %s, which is not pending', rpc)
            return Status.OK

        if packet.type == PacketType.SERVER_ERROR:
            if status is None or status.ok():
                # A SERVER_ERROR must carry an error status. The packet came
                # from the server, so treat a malformed one like an illegal
                # status code instead of failing an assertion.
                _LOG.warning(
                    '%s: SERVER_ERROR packet has non-error status %s; '
                    'treating it as UNKNOWN',
                    rpc,
                    status,
                )
                status = Status.UNKNOWN
            _LOG.warning('%s: invocation failed with %s', rpc, status)
            self._impl.handle_error(rpc, meta.context, status)
            return Status.OK

        if payload is not None:
            self._impl.handle_response(rpc, meta.context, payload)
        if status is not None:
            self._impl.handle_completion(rpc, meta.context, status)

        return Status.OK

    def _look_up_service_and_method(
        self, packet: RpcPacket, channel_client: ChannelClient
    ) -> PendingRpc:
        """Builds the PendingRpc addressed by an inbound packet.

        Raises:
          ValueError: the service or method ID is not known to this client
        """
        # Protobuf is sometimes silly so the 32 bit python bindings return
        # signed values from `fixed32` fields. Let's convert back to unsigned.
        # b/239712573
        service_id = packet.service_id & 0xFFFFFFFF
        try:
            service = self.services[service_id]
        except KeyError:
            raise ValueError(f'Unrecognized service ID {service_id}') from None

        # See above, also for b/239712573
        method_id = packet.method_id & 0xFFFFFFFF
        try:
            method = service.methods[method_id]
        except KeyError:
            raise ValueError(
                f'No method ID {method_id} in service {service.name}'
            ) from None

        return PendingRpc(
            channel_client.channel, service, method, packet.call_id
        )

    def __repr__(self) -> str:
        return (
            f'pw_rpc.Client(channels={list(self._channels_by_id)}, '
            f'services={[s.full_name for s in self.services]})'
        )
613
614
def _send_client_error(
    client: ChannelClient, packet: RpcPacket, error: Status
) -> None:
    """Reports an error about an inbound packet back over its channel.

    The error packet is built with packets.encode_client_error and written to
    the ChannelClient's channel. Nothing is sent in reply to a SERVER_ERROR
    packet.
    """
    if packet.type == PacketType.SERVER_ERROR:
        return  # Never send responses to SERVER_ERRORs.

    client.channel.output(packets.encode_client_error(packet, error))
621