1# Copyright 2021 The Pigweed Authors 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); you may not 4# use this file except in compliance with the License. You may obtain a copy of 5# the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12# License for the specific language governing permissions and limitations under 13# the License. 14"""Classes for handling ongoing RPC calls.""" 15 16from __future__ import annotations 17 18import enum 19from collections import deque 20import logging 21import math 22import queue 23from typing import ( 24 Any, 25 Callable, 26 Deque, 27 Iterable, 28 Iterator, 29 NamedTuple, 30 Sequence, 31 TypeVar, 32) 33 34from pw_protobuf_compiler.python_protos import proto_repr 35from pw_status import Status 36from google.protobuf.message import Message 37 38from pw_rpc.callback_client.errors import RpcTimeout, RpcError 39from pw_rpc.client import PendingRpc, PendingRpcs 40from pw_rpc.descriptors import Method 41 42_LOG = logging.getLogger(__package__) 43 44 45class UseDefault(enum.Enum): 46 """Marker for args that should use a default value, when None is valid.""" 47 48 VALUE = 0 49 50 51CallTypeT = TypeVar( 52 'CallTypeT', 53 'UnaryCall', 54 'ServerStreamingCall', 55 'ClientStreamingCall', 56 'BidirectionalStreamingCall', 57) 58 59OnNextCallback = Callable[[CallTypeT, Any], Any] 60OnCompletedCallback = Callable[[CallTypeT, Any], Any] 61OnErrorCallback = Callable[[CallTypeT, Any], Any] 62 63OptionalTimeout = UseDefault | float | None 64 65 66class UnaryResponse(NamedTuple): 67 """Result from a unary or client streaming RPC: status and response.""" 68 69 status: Status 70 response: Any 71 72 def unwrap_or_raise(self): 73 """Returns the response value or raises `ValueError` if not OK.""" 74 if not self.status.ok(): 75 raise ValueError(f'RPC returned non-OK status: {self.status}') 76 return self.response 77 78 def __repr__(self) -> str: 79 reply = proto_repr(self.response) if self.response else self.response 80 return f'({self.status}, {reply})' 81 82 83class StreamResponse(NamedTuple): 84 """Results from a server or bidirectional streaming RPC.""" 85 86 status: Status 87 responses: Sequence[Any] 88 89 def __repr__(self) -> str: 90 return ( 91 f'({self.status}, ' 92 f'[{", ".join(proto_repr(r) for r in self.responses)}])' 93 ) 94 95 96class Call: 97 """Represents an in-progress or completed RPC call.""" 98 99 def __init__( 100 self, 101 rpcs: PendingRpcs, 102 rpc: PendingRpc, 103 default_timeout_s: float | None, 104 on_next: OnNextCallback | None, 105 on_completed: OnCompletedCallback | None, 106 on_error: OnErrorCallback | None, 107 max_responses: int, 108 ) -> None: 109 self._rpcs = rpcs 110 self._rpc = rpc 111 self.default_timeout_s = default_timeout_s 112 113 self.status: Status | None = None 114 self.error: Status | None = None 115 self._callback_exception: Exception | None = None 116 self._responses: Deque = deque(maxlen=max_responses) 117 self._response_queue: queue.SimpleQueue = queue.SimpleQueue() 118 119 self.on_next = on_next or Call._default_response 120 self.on_completed = on_completed or Call._default_completion 121 self.on_error = on_error or Call._default_error 122 123 def _invoke(self, request: Message | None) -> None: 124 """Calls the RPC. This must be called immediately after __init__.""" 125 self._rpcs.send_request(self._rpc, request, self) 126 127 def _open(self) -> None: 128 self._rpcs.open(self._rpc, self) 129 130 def _default_response(self, response: Message) -> None: 131 _LOG.debug('%s received response: %s', self._rpc, response) 132 133 def _default_completion(self, status: Status) -> None: 134 _LOG.info('%s completed: %s', self._rpc, status) 135 136 def _default_error(self, error: Status) -> None: 137 _LOG.warning('%s terminated due to an error: %s', self._rpc, error) 138 139 @property 140 def call_id(self) -> int: 141 return self._rpc.call_id 142 143 @property 144 def method(self) -> Method: 145 return self._rpc.method 146 147 def completed(self) -> bool: 148 """True if the RPC call has completed, successfully or from an error.""" 149 return self.status is not None or self.error is not None 150 151 def _send_client_stream( 152 self, request_proto: Message | None, request_fields: dict 153 ) -> None: 154 """Sends a client to the server in the client stream. 155 156 Sending a client stream packet on a closed RPC raises an exception. 157 """ 158 self._check_errors() 159 160 if self.status is not None: 161 raise RpcError(self._rpc, Status.FAILED_PRECONDITION) 162 163 self._rpcs.send_client_stream( 164 self._rpc, self.method.get_request(request_proto, request_fields) 165 ) 166 167 def _finish_client_stream(self, requests: Iterable[Message]) -> None: 168 for request in requests: 169 self._send_client_stream(request, {}) 170 171 if not self.completed(): 172 self._rpcs.send_client_stream_end(self._rpc) 173 174 def _unary_wait(self, timeout_s: OptionalTimeout) -> UnaryResponse: 175 """Waits until the RPC has completed.""" 176 for _ in self._get_responses(timeout_s=timeout_s): 177 pass 178 179 assert self.status is not None and self._responses 180 return UnaryResponse(self.status, self._responses[-1]) 181 182 def _stream_wait(self, timeout_s: OptionalTimeout) -> StreamResponse: 183 """Waits until the RPC has completed.""" 184 for _ in self._get_responses(timeout_s=timeout_s): 185 pass 186 187 assert self.status is not None 188 return StreamResponse(self.status, list(self._responses)) 189 190 def _get_responses( 191 self, *, count: int | None = None, timeout_s: OptionalTimeout 192 ) -> Iterator: 193 """Returns an iterator of stream responses. 194 195 Args: 196 count: Responses to read before returning; None reads all 197 timeout_s: max time in seconds to wait between responses; 0 doesn't 198 block, None blocks indefinitely 199 """ 200 self._check_errors() 201 202 if self.completed() and self._response_queue.empty(): 203 return 204 205 if timeout_s is UseDefault.VALUE: 206 timeout_s = self.default_timeout_s 207 208 remaining = math.inf if count is None else count 209 210 try: 211 while remaining: 212 response = self._response_queue.get(True, timeout_s) 213 214 self._check_errors() 215 216 if response is None: 217 return 218 219 yield response 220 remaining -= 1 221 except queue.Empty: 222 raise RpcTimeout(self._rpc, timeout_s) 223 224 def cancel(self) -> bool: 225 """Cancels the RPC; returns whether the RPC was active.""" 226 if self.completed(): 227 return False 228 229 self.error = Status.CANCELLED 230 return self._rpcs.send_cancel(self._rpc) 231 232 def _check_errors(self) -> None: 233 if self._callback_exception: 234 raise self._callback_exception 235 236 if self.error: 237 raise RpcError(self._rpc, self.error) 238 239 def _handle_response(self, response: Any) -> None: 240 self._responses.append(response) 241 self._response_queue.put(response) 242 243 self._invoke_callback('on_next', response) 244 245 def _handle_completion(self, status: Status) -> None: 246 self.status = status 247 self._response_queue.put(None) 248 249 self._invoke_callback('on_completed', status) 250 251 def _handle_error(self, error: Status) -> None: 252 self.error = error 253 self._response_queue.put(None) 254 255 self._invoke_callback('on_error', error) 256 257 def _invoke_callback(self, callback_name: str, arg: Any) -> None: 258 """Invokes a user-provided callback function for an RPC event.""" 259 260 # Catch and log any exceptions from the user-provided callback so that 261 # exceptions don't terminate the thread handling RPC packets. 262 callback: Callable[[Call, Any], None] = getattr(self, callback_name) 263 264 try: 265 callback(self, arg) 266 except Exception as callback_exception: # pylint: disable=broad-except 267 msg = ( 268 f'The {callback_name} callback ({callback}) for ' 269 f'{self._rpc} raised an exception' 270 ) 271 _LOG.exception(msg) 272 273 self._callback_exception = RuntimeError(msg) 274 self._callback_exception.__cause__ = callback_exception 275 276 def __enter__(self) -> Call: 277 return self 278 279 def __exit__(self, exc_type, exc_value, traceback) -> None: 280 self.cancel() 281 282 def __repr__(self) -> str: 283 return f'{type(self).__name__}({self.method})' 284 285 286class UnaryCall(Call): 287 """Tracks the state of a unary RPC call.""" 288 289 @property 290 def response(self) -> Any: 291 return self._responses[-1] if self._responses else None 292 293 def wait( 294 self, timeout_s: OptionalTimeout = UseDefault.VALUE 295 ) -> UnaryResponse: 296 return self._unary_wait(timeout_s) 297 298 299class ServerStreamingCall(Call): 300 """Tracks the state of a server streaming RPC call.""" 301 302 @property 303 def responses(self) -> Sequence: 304 return self._responses 305 306 def wait( 307 self, timeout_s: OptionalTimeout = UseDefault.VALUE 308 ) -> StreamResponse: 309 return self._stream_wait(timeout_s) 310 311 def get_responses( 312 self, 313 *, 314 count: int | None = None, 315 timeout_s: OptionalTimeout = UseDefault.VALUE, 316 ) -> Iterator: 317 return self._get_responses(count=count, timeout_s=timeout_s) 318 319 def request_completion(self) -> None: 320 """Sends client completion packet to server.""" 321 if not self.completed(): 322 self._rpcs.send_client_stream_end(self._rpc) 323 324 def __iter__(self) -> Iterator: 325 return self.get_responses() 326 327 328class ClientStreamingCall(Call): 329 """Tracks the state of a client streaming RPC call.""" 330 331 @property 332 def response(self) -> Any: 333 return self._responses[-1] if self._responses else None 334 335 def send( 336 self, request_proto: Message | None = None, /, **request_fields 337 ) -> None: 338 """Sends client stream request to the server.""" 339 self._send_client_stream(request_proto, request_fields) 340 341 def finish_and_wait( 342 self, 343 requests: Iterable[Message] = (), 344 *, 345 timeout_s: OptionalTimeout = UseDefault.VALUE, 346 ) -> UnaryResponse: 347 """Ends the client stream and waits for the RPC to complete.""" 348 self._finish_client_stream(requests) 349 return self._unary_wait(timeout_s) 350 351 352class BidirectionalStreamingCall(Call): 353 """Tracks the state of a bidirectional streaming RPC call.""" 354 355 @property 356 def responses(self) -> Sequence: 357 return self._responses 358 359 def send( 360 self, request_proto: Message | None = None, /, **request_fields 361 ) -> None: 362 """Sends a message to the server in the client stream.""" 363 self._send_client_stream(request_proto, request_fields) 364 365 def finish_and_wait( 366 self, 367 requests: Iterable[Message] = (), 368 *, 369 timeout_s: OptionalTimeout = UseDefault.VALUE, 370 ) -> StreamResponse: 371 """Ends the client stream and waits for the RPC to complete.""" 372 self._finish_client_stream(requests) 373 return self._stream_wait(timeout_s) 374 375 def get_responses( 376 self, 377 *, 378 count: int | None = None, 379 timeout_s: OptionalTimeout = UseDefault.VALUE, 380 ) -> Iterator: 381 return self._get_responses(count=count, timeout_s=timeout_s) 382 383 def __iter__(self) -> Iterator: 384 return self.get_responses() 385