xref: /aosp_15_r20/external/pigweed/pw_rpc/py/pw_rpc/callback_client/call.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Classes for handling ongoing RPC calls."""
15
16from __future__ import annotations
17
18import enum
19from collections import deque
20import logging
21import math
22import queue
23from typing import (
24    Any,
25    Callable,
26    Deque,
27    Iterable,
28    Iterator,
29    NamedTuple,
30    Sequence,
31    TypeVar,
32)
33
34from pw_protobuf_compiler.python_protos import proto_repr
35from pw_status import Status
36from google.protobuf.message import Message
37
38from pw_rpc.callback_client.errors import RpcTimeout, RpcError
39from pw_rpc.client import PendingRpc, PendingRpcs
40from pw_rpc.descriptors import Method
41
# Module-level logger named after this package; used by the default call
# callbacks and to report exceptions raised by user callbacks.
_LOG: logging.Logger = logging.getLogger(__package__)
43
44
class UseDefault(enum.Enum):
    """Sentinel for arguments whose valid values already include None.

    Passing ``UseDefault.VALUE`` asks for the configured default, which leaves
    ``None`` free to carry its own meaning (for timeouts, "wait forever").
    """

    VALUE = 0
49
50
# Type variable ranging over the concrete call classes defined in this module.
# Callbacks receive the call object they belong to as their first argument.
CallTypeT = TypeVar(
    'CallTypeT',
    'UnaryCall',
    'ServerStreamingCall',
    'ClientStreamingCall',
    'BidirectionalStreamingCall',
)

# Called as on_next(call, response) for every response received.
OnNextCallback = Callable[[CallTypeT, Any], Any]
# Called as on_completed(call, status) when the RPC completes normally.
OnCompletedCallback = Callable[[CallTypeT, Status], Any]
# Called as on_error(call, error) when the RPC terminates with an error status.
OnErrorCallback = Callable[[CallTypeT, Status], Any]

# Timeout argument accepted by the wait/iterate methods: UseDefault.VALUE uses
# the call's default timeout, None blocks indefinitely, and a float is a limit
# in seconds (0 does not block).
OptionalTimeout = UseDefault | float | None
64
65
class UnaryResponse(NamedTuple):
    """Status and response from a completed unary or client streaming RPC."""

    status: Status
    response: Any

    def unwrap_or_raise(self):
        """Returns the response value or raises `ValueError` if not OK."""
        status = self.status
        if status.ok():
            return self.response
        raise ValueError(f'RPC returned non-OK status: {status}')

    def __repr__(self) -> str:
        # Falsy responses (such as None) are printed as-is; anything else is
        # rendered through proto_repr.
        if not self.response:
            return f'({self.status}, {self.response})'
        return f'({self.status}, {proto_repr(self.response)})'
81
82
class StreamResponse(NamedTuple):
    """Status and responses from a server or bidirectional streaming RPC."""

    status: Status
    responses: Sequence[Any]

    def __repr__(self) -> str:
        # Each response is rendered with proto_repr, then joined list-style.
        rendered = ', '.join([proto_repr(item) for item in self.responses])
        return f'({self.status}, [{rendered}])'
94
95
class Call:
    """Represents an in-progress or completed RPC call.

    Every response is appended to a bounded history (``max_responses``
    entries) and pushed onto a queue that ``_get_responses`` drains. When the
    RPC terminates (completion, error, or cancellation) a ``None`` sentinel is
    pushed onto the same queue so that any blocked waiter wakes up.
    """

    def __init__(
        self,
        rpcs: PendingRpcs,
        rpc: PendingRpc,
        default_timeout_s: float | None,
        on_next: OnNextCallback | None,
        on_completed: OnCompletedCallback | None,
        on_error: OnErrorCallback | None,
        max_responses: int,
    ) -> None:
        """Initializes call state only; this constructor sends no packets.

        Args:
          rpcs: Pending-RPC tracker used to send packets for this call.
          rpc: The pending RPC this object represents.
          default_timeout_s: Timeout used when a wait receives
              ``UseDefault.VALUE``; None blocks indefinitely.
          on_next: Called as ``on_next(call, response)`` for each response;
              None logs responses at debug level.
          on_completed: Called as ``on_completed(call, status)`` on
              completion; None logs at info level.
          on_error: Called as ``on_error(call, error)`` when the RPC
              terminates with an error; None logs a warning.
          max_responses: Number of most recent responses to retain.
        """
        self._rpcs = rpcs
        self._rpc = rpc
        self.default_timeout_s = default_timeout_s

        # Final status once the RPC completes normally; None until then.
        self.status: Status | None = None
        # Status of an error termination (including local cancellation).
        self.error: Status | None = None
        # Exception from a user callback, re-raised by _check_errors.
        self._callback_exception: Exception | None = None
        # Only the last max_responses responses are kept.
        self._responses: Deque = deque(maxlen=max_responses)
        # Consumed by _get_responses; None is the termination sentinel.
        self._response_queue: queue.SimpleQueue = queue.SimpleQueue()

        # Defaults are stored as plain (unbound) functions; _invoke_callback
        # passes the call object explicitly as the first argument.
        self.on_next = on_next or Call._default_response
        self.on_completed = on_completed or Call._default_completion
        self.on_error = on_error or Call._default_error

    def _invoke(self, request: Message | None) -> None:
        """Calls the RPC. This must be called immediately after __init__."""
        self._rpcs.send_request(self._rpc, request, self)

    def _open(self) -> None:
        """Registers this call via ``PendingRpcs.open``; no request is passed."""
        self._rpcs.open(self._rpc, self)

    def _default_response(self, response: Message) -> None:
        """Default on_next callback: logs the response at debug level."""
        _LOG.debug('%s received response: %s', self._rpc, response)

    def _default_completion(self, status: Status) -> None:
        """Default on_completed callback: logs the final status."""
        _LOG.info('%s completed: %s', self._rpc, status)

    def _default_error(self, error: Status) -> None:
        """Default on_error callback: logs the error as a warning."""
        _LOG.warning('%s terminated due to an error: %s', self._rpc, error)

    @property
    def call_id(self) -> int:
        """The call ID of the underlying pending RPC."""
        return self._rpc.call_id

    @property
    def method(self) -> Method:
        """The RPC method descriptor of the underlying pending RPC."""
        return self._rpc.method

    def completed(self) -> bool:
        """True if the RPC call has completed, successfully or from an error."""
        return self.status is not None or self.error is not None

    def _send_client_stream(
        self, request_proto: Message | None, request_fields: dict
    ) -> None:
        """Sends a request to the server in the client stream.

        The message is built by ``Method.get_request`` from ``request_proto``
        and ``request_fields``. Sending a client stream packet on a closed RPC
        raises an exception.

        Raises:
          RuntimeError: A user callback previously raised an exception.
          RpcError: The RPC terminated with an error, or it already completed
              (FAILED_PRECONDITION).
        """
        self._check_errors()

        if self.status is not None:
            raise RpcError(self._rpc, Status.FAILED_PRECONDITION)

        self._rpcs.send_client_stream(
            self._rpc, self.method.get_request(request_proto, request_fields)
        )

    def _finish_client_stream(self, requests: Iterable[Message]) -> None:
        """Sends each request, then ends the client stream if still active."""
        for request in requests:
            self._send_client_stream(request, {})

        if not self.completed():
            self._rpcs.send_client_stream_end(self._rpc)

    def _unary_wait(self, timeout_s: OptionalTimeout) -> UnaryResponse:
        """Waits until the RPC has completed; returns status and last response."""
        for _ in self._get_responses(timeout_s=timeout_s):
            pass

        assert self.status is not None and self._responses
        return UnaryResponse(self.status, self._responses[-1])

    def _stream_wait(self, timeout_s: OptionalTimeout) -> StreamResponse:
        """Waits until the RPC has completed; returns status and responses."""
        for _ in self._get_responses(timeout_s=timeout_s):
            pass

        assert self.status is not None
        return StreamResponse(self.status, list(self._responses))

    def _get_responses(
        self, *, count: int | None = None, timeout_s: OptionalTimeout
    ) -> Iterator:
        """Returns an iterator of stream responses.

        Args:
          count: Responses to read before returning; None reads all
          timeout_s: max time in seconds to wait between responses; 0 doesn't
              block, None blocks indefinitely

        Raises:
          RpcTimeout: No response or termination arrived within timeout_s.
          RpcError: The RPC terminated with an error or was cancelled.
          RuntimeError: A user callback raised an exception.
        """
        self._check_errors()

        # Nothing left to read: the RPC ended and its queue is fully drained.
        if self.completed() and self._response_queue.empty():
            return

        if timeout_s is UseDefault.VALUE:
            timeout_s = self.default_timeout_s

        remaining = math.inf if count is None else count

        while remaining:
            try:
                response = self._response_queue.get(True, timeout_s)
            except queue.Empty:
                raise RpcTimeout(self._rpc, timeout_s) from None

            # An error or callback failure may have occurred while waiting.
            self._check_errors()

            if response is None:  # Termination sentinel.
                return

            yield response
            remaining -= 1

    def cancel(self) -> bool:
        """Cancels the RPC; returns whether the RPC was active.

        Any thread blocked waiting for responses is woken and raises
        ``RpcError`` with ``Status.CANCELLED``.
        """
        if self.completed():
            return False

        self.error = Status.CANCELLED
        # Wake waiters blocked in _get_responses. Without the sentinel they
        # would block forever (e.g. timeout_s=None) since the cancelled RPC
        # receives no further packets; after dequeuing they re-check errors.
        self._response_queue.put(None)
        return self._rpcs.send_cancel(self._rpc)

    def _check_errors(self) -> None:
        """Raises any stored failure for this call.

        Raises:
          RuntimeError: A user callback raised (original chained as cause).
          RpcError: The RPC terminated with an error or was cancelled.
        """
        if self._callback_exception is not None:
            raise self._callback_exception

        if self.error is not None:
            raise RpcError(self._rpc, self.error)

    def _handle_response(self, response: Any) -> None:
        """Records a response, queues it for waiters, and calls on_next."""
        self._responses.append(response)
        self._response_queue.put(response)

        self._invoke_callback('on_next', response)

    def _handle_completion(self, status: Status) -> None:
        """Marks the RPC completed, wakes waiters, and calls on_completed."""
        self.status = status
        self._response_queue.put(None)

        self._invoke_callback('on_completed', status)

    def _handle_error(self, error: Status) -> None:
        """Marks the RPC failed, wakes waiters, and calls on_error."""
        self.error = error
        self._response_queue.put(None)

        self._invoke_callback('on_error', error)

    def _invoke_callback(self, callback_name: str, arg: Any) -> None:
        """Invokes a user-provided callback function for an RPC event.

        Exceptions from the callback are logged and stored rather than
        propagated, so that they don't terminate the thread handling RPC
        packets; the stored exception is raised by the next _check_errors.
        """
        callback: Callable[[Call, Any], None] = getattr(self, callback_name)

        try:
            callback(self, arg)
        except Exception as callback_exception:  # pylint: disable=broad-except
            msg = (
                f'The {callback_name} callback ({callback}) for '
                f'{self._rpc} raised an exception'
            )
            _LOG.exception(msg)

            self._callback_exception = RuntimeError(msg)
            self._callback_exception.__cause__ = callback_exception

    def __enter__(self) -> Call:
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Cancels the RPC if it is still active; exceptions are not suppressed.
        self.cancel()

    def __repr__(self) -> str:
        return f'{type(self).__name__}({self.method})'
284
285
class UnaryCall(Call):
    """Tracks the state of a unary RPC call."""

    @property
    def response(self) -> Any:
        """The latest response received, or None if none has arrived."""
        if not self._responses:
            return None
        return self._responses[-1]

    def wait(
        self, timeout_s: OptionalTimeout = UseDefault.VALUE
    ) -> UnaryResponse:
        """Blocks until the RPC completes; returns its status and response.

        Args:
          timeout_s: Maximum time in seconds to wait between responses;
              UseDefault.VALUE uses the call's default, None waits forever.
        """
        return self._unary_wait(timeout_s=timeout_s)
297
298
class ServerStreamingCall(Call):
    """Tracks the state of a server streaming RPC call."""

    @property
    def responses(self) -> Sequence:
        """Responses received so far (at most the call's max_responses)."""
        return self._responses

    def wait(
        self, timeout_s: OptionalTimeout = UseDefault.VALUE
    ) -> StreamResponse:
        """Blocks until the RPC completes; returns status and responses."""
        return self._stream_wait(timeout_s=timeout_s)

    def get_responses(
        self,
        *,
        count: int | None = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> Iterator:
        """Iterates over responses as they arrive.

        Args:
          count: Number of responses to read; None reads until the RPC ends.
          timeout_s: Maximum time in seconds to wait between responses.
        """
        return self._get_responses(timeout_s=timeout_s, count=count)

    def request_completion(self) -> None:
        """Sends client completion packet to server."""
        if self.completed():
            return
        self._rpcs.send_client_stream_end(self._rpc)

    def __iter__(self) -> Iterator:
        # Iterating the call reads every response with the default timeout.
        return self.get_responses()
326
327
class ClientStreamingCall(Call):
    """Tracks the state of a client streaming RPC call."""

    @property
    def response(self) -> Any:
        """The latest response received, or None if none has arrived."""
        if not self._responses:
            return None
        return self._responses[-1]

    def send(
        self, request_proto: Message | None = None, /, **request_fields
    ) -> None:
        """Sends client stream request to the server.

        Args:
          request_proto: Optional request message.
          **request_fields: Fields passed with request_proto to the method's
              get_request to build the message that is sent.
        """
        self._send_client_stream(request_proto, request_fields)

    def finish_and_wait(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryResponse:
        """Ends the client stream and waits for the RPC to complete.

        Args:
          requests: Requests to send before the client stream is ended.
          timeout_s: Maximum time in seconds to wait between responses.
        """
        self._finish_client_stream(requests)
        return self._unary_wait(timeout_s=timeout_s)
350
351
class BidirectionalStreamingCall(Call):
    """Tracks the state of a bidirectional streaming RPC call."""

    @property
    def responses(self) -> Sequence:
        """Responses received so far (at most the call's max_responses)."""
        return self._responses

    def send(
        self, request_proto: Message | None = None, /, **request_fields
    ) -> None:
        """Sends a message to the server in the client stream.

        Args:
          request_proto: Optional request message.
          **request_fields: Fields passed with request_proto to the method's
              get_request to build the message that is sent.
        """
        self._send_client_stream(request_proto, request_fields)

    def finish_and_wait(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> StreamResponse:
        """Ends the client stream and waits for the RPC to complete.

        Args:
          requests: Requests to send before the client stream is ended.
          timeout_s: Maximum time in seconds to wait between responses.
        """
        self._finish_client_stream(requests)
        return self._stream_wait(timeout_s=timeout_s)

    def get_responses(
        self,
        *,
        count: int | None = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> Iterator:
        """Iterates over responses as they arrive.

        Args:
          count: Number of responses to read; None reads until the RPC ends.
          timeout_s: Maximum time in seconds to wait between responses.
        """
        return self._get_responses(timeout_s=timeout_s, count=count)

    def __iter__(self) -> Iterator:
        # Iterating the call reads every response with the default timeout.
        return self.get_responses()
385