# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python.

Defines the client-side Call objects returned when the multi-callables are
invoked (one class per RPC arity), plus `AioRpcError`, the exception raised
when an RPC terminates with a non-OK status.
"""

import asyncio
import enum
from functools import partial
import inspect
import logging
import traceback
from typing import Any, AsyncIterator, Generator, Generic, Optional, Tuple

import grpc
from grpc import _common
from grpc._cython import cygrpc

from . import _base_call
from ._metadata import Metadata
from ._typing import DeserializingFunction
from ._typing import DoneCallbackType
from ._typing import MetadatumType
from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseType
from ._typing import SerializingFunction

# NOTE(review): StreamUnaryCall and StreamStreamCall are defined in this
# module but are not listed here.
__all__ = "AioRpcError", "Call", "UnaryUnaryCall", "UnaryStreamCall"

# Details handed to the Cython call's cancel() by `Call.cancel()`.
_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!"
# Details handed to the Cython call's cancel() when an unfinished Call is
# garbage collected (see `Call.__del__`).
_GC_CANCELLATION_DETAILS = "Cancelled upon garbage collection!"
# Messages of the asyncio.InvalidStateError raised by `_write()`.
_RPC_ALREADY_FINISHED_DETAILS = "RPC already finished."
_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'
# Message of the cygrpc.UsageError raised when a single streaming RPC mixes
# the async-iterator API with the read()/write() API.
_API_STYLE_ERROR = (
    "The iterator and read/write APIs may not be mixed on a single RPC."
)

# NOTE(review): not referenced anywhere in this module.
_OK_CALL_REPRESENTATION = (
    '<{} of RPC that terminated with:\n\tstatus = {}\n\tdetails = "{}"\n>'
)

# Template used by AioRpcError's repr()/str(); filled with the class name,
# status code, details and debug error string, in that order.
_NON_OK_CALL_REPRESENTATION = (
    "<{} of RPC that terminated with:\n"
    "\tstatus = {}\n"
    '\tdetails = "{}"\n'
    '\tdebug_error_string = "{}"\n'
    ">"
)

_LOGGER = logging.getLogger(__name__)


class AioRpcError(grpc.RpcError):
    """An implementation of RpcError to be used by the asynchronous API.

    A raised AioRpcError is a snapshot of the final status of the RPC: all of
    its values are already determined. Hence, its accessors are plain methods
    and no longer need to be coroutines.
    """

    _code: grpc.StatusCode
    _details: Optional[str]
    _initial_metadata: Optional[Metadata]
    _trailing_metadata: Optional[Metadata]
    _debug_error_string: Optional[str]

    def __init__(
        self,
        code: grpc.StatusCode,
        initial_metadata: Metadata,
        trailing_metadata: Metadata,
        details: Optional[str] = None,
        debug_error_string: Optional[str] = None,
    ) -> None:
        """Constructor.

        Args:
          code: The status code with which the RPC has been finalized.
          initial_metadata: The initial metadata that could be sent by the
            Server.
          trailing_metadata: The trailing metadata that could be sent by the
            Server.
          details: Optional details explaining the reason of the error.
          debug_error_string: Optional debug error string describing the
            failure.
        """

        super().__init__()
        self._code = code
        self._details = details
        self._initial_metadata = initial_metadata
        self._trailing_metadata = trailing_metadata
        self._debug_error_string = debug_error_string

    def code(self) -> grpc.StatusCode:
        """Accesses the status code sent by the server.

        Returns:
          The `grpc.StatusCode` status code.
        """
        return self._code

    def details(self) -> Optional[str]:
        """Accesses the details sent by the server.

        Returns:
          The description of the error, or None if none was provided.
        """
        return self._details

    def initial_metadata(self) -> Metadata:
        """Accesses the initial metadata sent by the server.

        Returns:
          The initial metadata received.
        """
        return self._initial_metadata

    def trailing_metadata(self) -> Metadata:
        """Accesses the trailing metadata sent by the server.

        Returns:
          The trailing metadata received.
        """
        return self._trailing_metadata

    def debug_error_string(self) -> Optional[str]:
        """Accesses the debug error string sent by the server.

        Returns:
          The debug error string received, or None if none was provided.
        """
        return self._debug_error_string

    def _repr(self) -> str:
        """Assembles the error string for the RPC error."""
        return _NON_OK_CALL_REPRESENTATION.format(
            self.__class__.__name__,
            self._code,
            self._details,
            self._debug_error_string,
        )

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()

    def __reduce__(self):
        """Supports pickling.

        The error is rebuilt by calling the class with the constructor
        arguments in their positional order: code, initial metadata,
        trailing metadata, details and debug error string.
        """
        return (
            type(self),
            (
                self._code,
                self._initial_metadata,
                self._trailing_metadata,
                self._details,
                self._debug_error_string,
            ),
        )


def _create_rpc_error(
    initial_metadata: Metadata, status: cygrpc.AioRpcStatus
) -> AioRpcError:
    """Builds an AioRpcError from initial metadata and a Cython RPC status.

    The Cython status code is translated with
    `_common.CYGRPC_STATUS_CODE_TO_STATUS_CODE`, and both metadata values are
    wrapped with `Metadata.from_tuple`.

    NOTE(review): callers in this module pass either the Cython call's raw
    `_initial_metadata` (from `__await__`) or an already-built Metadata (from
    `_raise_for_status`); confirm `Metadata.from_tuple` accepts both.
    """
    return AioRpcError(
        _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()],
        Metadata.from_tuple(initial_metadata),
        Metadata.from_tuple(status.trailing_metadata()),
        details=status.details(),
        debug_error_string=status.debug_error_string(),
    )


class Call:
    """Base implementation of client RPC Call object.

    Implements logic around final status, metadata and cancellation by
    delegating to the wrapped Cython call object.
    """

    _loop: asyncio.AbstractEventLoop
    # NOTE(review): declared but never assigned by this class.
    _code: grpc.StatusCode
    _cython_call: cygrpc._AioCall
    _metadata: Tuple[MetadatumType, ...]
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction

    def __init__(
        self,
        cython_call: cygrpc._AioCall,
        metadata: Metadata,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Stores the Cython call plus the (de)serializers and event loop.

        The outgoing metadata is snapshotted into a tuple.
        """
        self._loop = loop
        self._cython_call = cython_call
        self._metadata = tuple(metadata)
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __del__(self) -> None:
        """Cancels the RPC if the Call is collected before it finished."""
        # The '_cython_call' object might be destructed before Call object.
        # The hasattr() check also covers instances whose __init__ never
        # reached the assignment of '_cython_call'.
        if hasattr(self, "_cython_call"):
            if not self._cython_call.done():
                self._cancel(_GC_CANCELLATION_DETAILS)

    def cancelled(self) -> bool:
        """Returns whether the underlying Cython call reports cancellation."""
        return self._cython_call.cancelled()

    def _cancel(self, details: str) -> bool:
        """Forwards the application cancellation reasoning.

        Returns:
          True if the call was not done yet and cancellation was forwarded to
          the Cython call; False if the call had already finished.
        """
        if not self._cython_call.done():
            self._cython_call.cancel(details)
            return True
        else:
            return False

    def cancel(self) -> bool:
        """Cancels the RPC on behalf of the application.

        Returns:
          True if the RPC was still running and got cancelled, else False.
        """
        return self._cancel(_LOCAL_CANCELLATION_DETAILS)

    def done(self) -> bool:
        """Returns whether the underlying Cython call has finished."""
        return self._cython_call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Registers a callback to run when the RPC finishes.

        The callback receives this Call object as its first positional
        argument.
        """
        cb = partial(callback, self)
        self._cython_call.add_done_callback(cb)

    def time_remaining(self) -> Optional[float]:
        """Returns the time remaining as reported by the Cython call."""
        return self._cython_call.time_remaining()

    async def initial_metadata(self) -> Metadata:
        """Waits for and returns the initial metadata sent by the server."""
        raw_metadata_tuple = await self._cython_call.initial_metadata()
        return Metadata.from_tuple(raw_metadata_tuple)

    async def trailing_metadata(self) -> Metadata:
        """Waits for the final status and returns its trailing metadata."""
        raw_metadata_tuple = (
            await self._cython_call.status()
        ).trailing_metadata()
        return Metadata.from_tuple(raw_metadata_tuple)

    async def code(self) -> grpc.StatusCode:
        """Waits for the final status and returns it as a grpc.StatusCode."""
        cygrpc_code = (await self._cython_call.status()).code()
        return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[cygrpc_code]

    async def details(self) -> str:
        """Waits for the final status and returns its details."""
        return (await self._cython_call.status()).details()

    async def debug_error_string(self) -> str:
        """Waits for the final status and returns its debug error string."""
        return (await self._cython_call.status()).debug_error_string()

    async def _raise_for_status(self) -> None:
        """Raises if the RPC did not finish successfully.

        Raises:
          asyncio.CancelledError: If the RPC was cancelled locally.
          AioRpcError: If the final status code is not OK.
        """
        if self._cython_call.is_locally_cancelled():
            raise asyncio.CancelledError()
        code = await self.code()
        if code != grpc.StatusCode.OK:
            raise _create_rpc_error(
                await self.initial_metadata(), await self._cython_call.status()
            )

    def _repr(self) -> str:
        """Delegates the textual representation to the Cython call."""
        return repr(self._cython_call)

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()


class _APIStyle(enum.IntEnum):
    """The API a streaming call is driven with.

    UNKNOWN means no style has been chosen yet; the two APIs may not be mixed
    on a single RPC.
    """

    UNKNOWN = 0
    ASYNC_GENERATOR = 1
    READER_WRITER = 2


class _UnaryResponseMixin(Call, Generic[ResponseType]):
    """Mixin for calls whose single response is obtained via `await call`."""

    # Task whose result is either the deserialized response or cygrpc.EOF.
    _call_response: asyncio.Task

    def _init_unary_response_mixin(self, response_task: asyncio.Task) -> None:
        """Records the task awaited by `__await__`.

        The task is expected to return either the deserialized response or
        cygrpc.EOF when the RPC did not succeed.
        """
        self._call_response = response_task

    def cancel(self) -> bool:
        """Cancels the RPC and, if that succeeded, the response task too."""
        if super().cancel():
            self._call_response.cancel()
            return True
        else:
            return False

    def __await__(self) -> Generator[Any, None, ResponseType]:
        """Wait till the ongoing RPC request finishes.

        Raises:
          asyncio.CancelledError: If the RPC was cancelled locally.
          AioRpcError: If the RPC finished with a non-OK status.
        """
        try:
            response = yield from self._call_response
        except asyncio.CancelledError:
            # Even if we caught all other CancelledError, there is still
            # this corner case. If the application cancels immediately after
            # the Call object is created, we will observe this
            # `CancelledError`.
            if not self.cancelled():
                self.cancel()
            raise

        # NOTE(lidiz) If we raise RpcError in the task and users don't
        # 'await' on it, AsyncIO will log 'Task exception was never retrieved'.
        # Instead, if we move the exception raising here, the spam stops.
        # Unfortunately, there can only be one 'yield from' in '__await__', so
        # we need to access the private instance variables of the Cython call.
        if response is cygrpc.EOF:
            if self._cython_call.is_locally_cancelled():
                raise asyncio.CancelledError()
            else:
                raise _create_rpc_error(
                    self._cython_call._initial_metadata,
                    self._cython_call._status,
                )
        else:
            return response


class _StreamResponseMixin(Call):
    """Mixin for calls that stream responses back to the client.

    Responses may be consumed with `async for` or with `read()`, but the two
    styles may not be mixed on the same call.
    """

    # Lazily created by __aiter__ and reused on subsequent calls.
    _message_aiter: Optional[AsyncIterator[ResponseType]]
    # Task that must complete before any message can be read.
    _preparation: asyncio.Task
    _response_style: _APIStyle

    def _init_stream_response_mixin(self, preparation: asyncio.Task) -> None:
        """Records the preparation task; no response style is chosen yet."""
        self._message_aiter = None
        self._preparation = preparation
        self._response_style = _APIStyle.UNKNOWN

    def _update_response_style(self, style: _APIStyle) -> None:
        """Locks in the response style on first use.

        Raises:
          cygrpc.UsageError: If a different style was already chosen.
        """
        if self._response_style is _APIStyle.UNKNOWN:
            self._response_style = style
        elif self._response_style is not style:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

    def cancel(self) -> bool:
        """Cancels the RPC and, if that succeeded, the preparation task."""
        if super().cancel():
            self._preparation.cancel()
            return True
        else:
            return False

    async def _fetch_stream_responses(self) -> AsyncIterator[ResponseType]:
        """Async generator yielding responses until EOF.

        After EOF the final status is checked, so a failed RPC surfaces as
        an exception from the iteration.
        """
        message = await self._read()
        while message is not cygrpc.EOF:
            yield message
            message = await self._read()

        # If the read operation failed, Core should explain why.
        await self._raise_for_status()

    def __aiter__(self) -> AsyncIterator[ResponseType]:
        """Returns the (shared) async iterator over the responses."""
        self._update_response_style(_APIStyle.ASYNC_GENERATOR)
        if self._message_aiter is None:
            self._message_aiter = self._fetch_stream_responses()
        return self._message_aiter

    async def _read(self) -> ResponseType:
        """Reads and deserializes one message, or returns cygrpc.EOF."""
        # Wait for the preparation task (e.g. sending the request) to finish.
        await self._preparation

        # Reads response message from Core
        try:
            raw_response = await self._cython_call.receive_serialized_message()
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

        if raw_response is cygrpc.EOF:
            return cygrpc.EOF
        else:
            return _common.deserialize(
                raw_response, self._response_deserializer
            )

    async def read(self) -> ResponseType:
        """Reads the next response using the read/write API.

        Returns:
          The next response, or cygrpc.EOF when the stream is exhausted.

        Raises:
          cygrpc.UsageError: If the async-iterator API is already in use.
          asyncio.CancelledError / AioRpcError: Via `_raise_for_status` when
            the RPC is done or the stream ended unsuccessfully.
        """
        if self.done():
            await self._raise_for_status()
            return cygrpc.EOF
        self._update_response_style(_APIStyle.READER_WRITER)

        response_message = await self._read()

        if response_message is cygrpc.EOF:
            # If the read operation failed, Core should explain why.
            await self._raise_for_status()
        return response_message


class _StreamRequestMixin(Call):
    """Mixin for calls that stream requests to the server.

    Requests come either from an iterator given at construction time
    (consumed by a background task) or from explicit `write()` calls.
    """

    # Set by `_metadata_sent_observer` once the Cython call reports that the
    # metadata was sent; writes wait on it.
    _metadata_sent: asyncio.Event
    _done_writing_flag: bool
    _async_request_poller: Optional[asyncio.Task]
    _request_style: _APIStyle

    def _init_stream_request_mixin(
        self, request_iterator: Optional[RequestIterableType]
    ) -> None:
        """Fixes the request style based on whether an iterator was given."""
        self._metadata_sent = asyncio.Event()
        self._done_writing_flag = False

        # If user passes in an async iterator, create a consumer Task.
        if request_iterator is not None:
            self._async_request_poller = self._loop.create_task(
                self._consume_request_iterator(request_iterator)
            )
            self._request_style = _APIStyle.ASYNC_GENERATOR
        else:
            self._async_request_poller = None
            self._request_style = _APIStyle.READER_WRITER

    def _raise_for_different_style(self, style: _APIStyle) -> None:
        """Raises cygrpc.UsageError unless `style` is the call's style."""
        if self._request_style is not style:
            raise cygrpc.UsageError(_API_STYLE_ERROR)

    def cancel(self) -> bool:
        """Cancels the RPC and, if that succeeded, the request consumer."""
        if super().cancel():
            if self._async_request_poller is not None:
                self._async_request_poller.cancel()
            return True
        else:
            return False

    def _metadata_sent_observer(self) -> None:
        """Callback handed to the Cython call; unblocks pending writes."""
        self._metadata_sent.set()

    async def _consume_request_iterator(
        self, request_iterator: RequestIterableType
    ) -> None:
        """Writes every request from `request_iterator`, then half-closes.

        Async iterables (async generators or anything with `__aiter__`) are
        consumed with `async for`; everything else with a plain `for`. If a
        write raises AioRpcError, consumption stops without half-closing.
        """
        try:
            if inspect.isasyncgen(request_iterator) or hasattr(
                request_iterator, "__aiter__"
            ):
                async for request in request_iterator:
                    try:
                        await self._write(request)
                    except AioRpcError as rpc_error:
                        _LOGGER.debug(
                            (
                                "Exception while consuming the"
                                " request_iterator: %s"
                            ),
                            rpc_error,
                        )
                        return
            else:
                for request in request_iterator:
                    try:
                        await self._write(request)
                    except AioRpcError as rpc_error:
                        _LOGGER.debug(
                            (
                                "Exception while consuming the"
                                " request_iterator: %s"
                            ),
                            rpc_error,
                        )
                        return

            await self._done_writing()
        except:  # pylint: disable=bare-except
            # Client iterators can raise exceptions, which we should handle by
            # cancelling the RPC and logging the client's error. No exceptions
            # should escape this function.
            # The bare except also catches BaseException subclasses such as
            # asyncio.CancelledError, so a cancellation of this task is
            # absorbed here instead of being propagated.
            _LOGGER.debug(
                "Client request_iterator raised exception:\n%s",
                traceback.format_exc(),
            )
            self.cancel()

    async def _write(self, request: RequestType) -> None:
        """Serializes and sends one request message.

        Waits until the metadata has been sent before writing.

        Raises:
          asyncio.InvalidStateError: If the RPC already finished or the
            client already half-closed the stream.
          asyncio.CancelledError / AioRpcError: Via `_raise_for_status` when
            the RPC finished while waiting for the metadata, or when sending
            failed with a cygrpc.InternalError.
        """
        if self.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        if self._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
        if not self._metadata_sent.is_set():
            await self._metadata_sent.wait()
            if self.done():
                await self._raise_for_status()

        serialized_request = _common.serialize(
            request, self._request_serializer
        )
        try:
            await self._cython_call.send_serialized_message(serialized_request)
        except cygrpc.InternalError as err:
            self._cython_call.set_internal_error(str(err))
            await self._raise_for_status()
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

    async def _done_writing(self) -> None:
        """Half-closes the request stream at most once.

        This is a no-op if the RPC is already done.
        """
        if self.done():
            # If the RPC is finished, do nothing.
            return
        if not self._done_writing_flag:
            # If the done writing is not sent before, try to send it.
            self._done_writing_flag = True
            try:
                await self._cython_call.send_receive_close()
            except asyncio.CancelledError:
                if not self.cancelled():
                    self.cancel()
                raise

    async def write(self, request: RequestType) -> None:
        """Writes one request using the read/write API.

        Raises:
          cygrpc.UsageError: If the call was created with a request iterator.
        """
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._write(request)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.

        Raises:
          cygrpc.UsageError: If the call was created with a request iterator.
        """
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._done_writing()

    async def wait_for_connection(self) -> None:
        """Waits until the metadata is sent; raises if the RPC already failed."""
        await self._metadata_sent.wait()
        if self.done():
            await self._raise_for_status()


class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
    """Object for managing unary-unary RPC calls.

    Returned when an instance of `UnaryUnaryMultiCallable` object is called.
    """

    _request: RequestType
    # Runs `_invoke`; also serves as the task awaited by `__await__`.
    _invocation_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request: RequestType,
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Creates the Cython call and schedules the invocation task."""
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )
        self._request = request
        self._context = cygrpc.build_census_context()
        self._invocation_task = loop.create_task(self._invoke())
        self._init_unary_response_mixin(self._invocation_task)

    async def _invoke(self) -> ResponseType:
        """Performs the unary-unary exchange.

        Returns:
          The deserialized response if the call is OK, otherwise cygrpc.EOF.
        """
        serialized_request = _common.serialize(
            self._request, self._request_serializer
        )

        # NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
        # because the asyncio.Task class does not cache the exception object.
        # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
        # Hence the CancelledError is deliberately not re-raised here; the
        # task finishes normally and `__await__` raises CancelledError for a
        # locally cancelled call.
        # NOTE(review): if is_ok() were True after a CancelledError,
        # 'serialized_response' would be unbound below.
        try:
            serialized_response = await self._cython_call.unary_unary(
                serialized_request, self._metadata, self._context
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()

        if self._cython_call.is_ok():
            return _common.deserialize(
                serialized_response, self._response_deserializer
            )
        else:
            return cygrpc.EOF

    async def wait_for_connection(self) -> None:
        """Waits for the invocation task; raises if the RPC failed."""
        await self._invocation_task
        if self.done():
            await self._raise_for_status()


class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
    """Object for managing unary-stream RPC calls.

    Returned when an instance of `UnaryStreamMultiCallable` object is called.
    """

    _request: RequestType
    # Sends the single request; reads wait for it (the preparation task).
    _send_unary_request_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request: RequestType,
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Creates the Cython call and schedules sending the request."""
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )
        self._request = request
        self._context = cygrpc.build_census_context()
        self._send_unary_request_task = loop.create_task(
            self._send_unary_request()
        )
        self._init_stream_response_mixin(self._send_unary_request_task)

    async def _send_unary_request(self) -> None:
        """Serializes the request and initiates the unary-stream RPC."""
        serialized_request = _common.serialize(
            self._request, self._request_serializer
        )
        try:
            await self._cython_call.initiate_unary_stream(
                serialized_request, self._metadata, self._context
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

    async def wait_for_connection(self) -> None:
        """Waits for the request to be sent; raises if the RPC failed."""
        await self._send_unary_request_task
        if self.done():
            await self._raise_for_status()


# pylint: disable=too-many-ancestors
class StreamUnaryCall(
    _StreamRequestMixin, _UnaryResponseMixin, Call, _base_call.StreamUnaryCall
):
    """Object for managing stream-unary RPC calls.

    Returned when an instance of `StreamUnaryMultiCallable` object is called.
    """

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request_iterator: Optional[RequestIterableType],
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Creates the Cython call, the request side, and the RPC task."""
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )

        self._context = cygrpc.build_census_context()
        self._init_stream_request_mixin(request_iterator)
        self._init_unary_response_mixin(loop.create_task(self._conduct_rpc()))

    async def _conduct_rpc(self) -> ResponseType:
        """Runs the stream-unary RPC.

        Returns:
          The deserialized response if the call is OK, otherwise cygrpc.EOF.
        """
        try:
            serialized_response = await self._cython_call.stream_unary(
                self._metadata, self._metadata_sent_observer, self._context
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

        if self._cython_call.is_ok():
            return _common.deserialize(
                serialized_response, self._response_deserializer
            )
        else:
            return cygrpc.EOF


class StreamStreamCall(
    _StreamRequestMixin, _StreamResponseMixin, Call, _base_call.StreamStreamCall
):
    """Object for managing stream-stream RPC calls.

    Returned when an instance of `StreamStreamMultiCallable` object is called.
    """

    # Runs `_prepare_rpc`; reads wait for it (the preparation task).
    _initializer: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        request_iterator: Optional[RequestIterableType],
        deadline: Optional[float],
        metadata: Metadata,
        credentials: Optional[grpc.CallCredentials],
        wait_for_ready: Optional[bool],
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Creates the Cython call and wires up both streaming sides."""
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata,
            request_serializer,
            response_deserializer,
            loop,
        )
        self._context = cygrpc.build_census_context()
        # NOTE(review): the task is created before _init_stream_request_mixin
        # creates `_metadata_sent`; this relies on the task not running before
        # this synchronous __init__ returns (default create_task scheduling).
        self._initializer = self._loop.create_task(self._prepare_rpc())
        self._init_stream_request_mixin(request_iterator)
        self._init_stream_response_mixin(self._initializer)

    async def _prepare_rpc(self) -> None:
        """This method prepares the RPC for receiving/sending messages.

        All other operations around the stream should only happen after the
        completion of this method.
        """
        try:
            await self._cython_call.initiate_stream_stream(
                self._metadata, self._metadata_sent_observer, self._context
            )
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            # No need to raise RpcError here, because no one will `await` this task.