# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""

import asyncio
import sys
from typing import Any, Iterable, List, Optional, Sequence

import grpc
from grpc import _common
from grpc import _compression
from grpc import _grpcio_metadata
from grpc._cython import cygrpc

from . import _base_call
from . import _base_channel
from ._call import StreamStreamCall
from ._call import StreamUnaryCall
from ._call import UnaryStreamCall
from ._call import UnaryUnaryCall
from ._interceptor import ClientInterceptor
from ._interceptor import InterceptedStreamStreamCall
from ._interceptor import InterceptedStreamUnaryCall
from ._interceptor import InterceptedUnaryStreamCall
from ._interceptor import InterceptedUnaryUnaryCall
from ._interceptor import StreamStreamClientInterceptor
from ._interceptor import StreamUnaryClientInterceptor
from ._interceptor import UnaryStreamClientInterceptor
from ._interceptor import UnaryUnaryClientInterceptor
from ._metadata import Metadata
from ._typing import ChannelArgumentType
from ._typing import DeserializingFunction
from ._typing import MetadataType
from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseType
from ._typing import SerializingFunction
from ._utils import _timeout_to_deadline

_USER_AGENT = "grpc-python-asyncio/{}".format(_grpcio_metadata.__version__)

# Compare the full version tuple rather than only the minor component, so the
# check stays correct independently of the major version.
if sys.version_info < (3, 7):

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Returns all tasks of the event loop (pre-3.7 API)."""
        return asyncio.Task.all_tasks()  # pylint: disable=no-member

else:

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Returns all tasks of the running event loop."""
        return asyncio.all_tasks()


def _augment_channel_arguments(
    base_options: ChannelArgumentType, compression: Optional[grpc.Compression]
):
    """Builds the final channel arguments passed to the Cython channel.

    Args:
      base_options: The user-supplied channel options.
      compression: The channel-wide compression, or None.

    Returns:
      A tuple made of `base_options`, followed by the option produced by
      `_compression.create_channel_option(compression)`, followed by the
      primary user-agent option set to `_USER_AGENT`.
    """
    compression_channel_argument = _compression.create_channel_option(
        compression
    )
    user_agent_channel_argument = (
        (
            cygrpc.ChannelArgKey.primary_user_agent_string,
            _USER_AGENT,
        ),
    )
    return (
        tuple(base_options)
        + compression_channel_argument
        + user_agent_channel_argument
    )


class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores common attributes.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction
    _interceptors: Optional[Sequence[ClientInterceptor]]
    # Objects kept alive for as long as this multi-callable exists; the
    # Channel passes `[self]` here.
    _references: List[Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references

    @staticmethod
    def _init_metadata(
        metadata: Optional[MetadataType] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Metadata:
        """Based on the provided values for <metadata> or <compression> initialise the final
        metadata, as it should be used for the current call.

        A missing/empty `metadata` becomes an empty `Metadata`; a plain tuple is
        converted with `Metadata.from_tuple`. When `compression` is set, the
        compression entry is added via `_compression.augment_metadata`.
        """
        metadata = metadata or Metadata()
        # NOTE(review): sequences other than tuple (e.g. a list) are returned
        # unchanged when no compression is requested — confirm callers rely on
        # this before converting them too.
        if not isinstance(metadata, Metadata) and isinstance(metadata, tuple):
            metadata = Metadata.from_tuple(metadata)
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression)
            )
        return metadata


class UnaryUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable
):
    """Multi-callable that starts unary-unary RPCs for one channel method."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
        metadata = self._init_metadata(metadata, compression)
        if not self._interceptors:
            # The plain call receives an absolute deadline computed from the
            # relative timeout.
            call = UnaryUnaryCall(
                request,
                _timeout_to_deadline(timeout),
                metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )
        else:
            # The intercepted call receives the relative timeout as given.
            call = InterceptedUnaryUnaryCall(
                self._interceptors,
                request,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )

        return call


class UnaryStreamMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryStreamMultiCallable
):
    """Multi-callable that starts unary-stream RPCs for one channel method."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            # The plain call receives an absolute deadline.
            call = UnaryStreamCall(
                request,
                _timeout_to_deadline(timeout),
                metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )
        else:
            # The intercepted call receives the relative timeout as given.
            call = InterceptedUnaryStreamCall(
                self._interceptors,
                request,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )

        return call


class StreamUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.StreamUnaryMultiCallable
):
    """Multi-callable that starts stream-unary RPCs for one channel method."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamUnaryCall:
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            # The plain call receives an absolute deadline.
            call = StreamUnaryCall(
                request_iterator,
                _timeout_to_deadline(timeout),
                metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )
        else:
            # The intercepted call receives the relative timeout as given.
            call = InterceptedStreamUnaryCall(
                self._interceptors,
                request_iterator,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )

        return call


class StreamStreamMultiCallable(
    _BaseMultiCallable, _base_channel.StreamStreamMultiCallable
):
    """Multi-callable that starts stream-stream RPCs for one channel method."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamStreamCall:
        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            # The plain call receives an absolute deadline.
            call = StreamStreamCall(
                request_iterator,
                _timeout_to_deadline(timeout),
                metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )
        else:
            # The intercepted call receives the relative timeout as given.
            call = InterceptedStreamStreamCall(
                self._interceptors,
                request_iterator,
                timeout,
                metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )

        return call


class Channel(_base_channel.Channel):
    """Asyncio channel wrapping a `cygrpc.AioChannel`.

    Interceptors passed at construction are sorted by kind and handed to the
    matching multi-callables created by `unary_unary`, `unary_stream`,
    `stream_unary` and `stream_stream`.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(
        self,
        target: str,
        options: ChannelArgumentType,
        credentials: Optional[grpc.ChannelCredentials],
        compression: Optional[grpc.Compression],
        interceptors: Optional[Sequence[ClientInterceptor]],
    ):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: If an interceptor is not one of the four supported
            client interceptor kinds.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            for interceptor in interceptors:
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor)
                        + "{} or ".format(UnaryUnaryClientInterceptor.__name__)
                        + "{} or ".format(UnaryStreamClientInterceptor.__name__)
                        + "{} or ".format(StreamUnaryClientInterceptor.__name__)
                        + "{}. ".format(StreamStreamClientInterceptor.__name__)
                    )

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression),
            credentials,
            self._loop,
        )

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Leaving the context closes the channel without a grace period.
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Closes the channel, cancelling this channel's in-flight calls.

        Args:
          grace: If truthy, time passed as `timeout` to `asyncio.wait` while
            waiting for the tasks driving this channel's calls to finish
            before the calls are cancelled.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if "frame" in str(attribute_error):
                    continue
                raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`.
            frame = stack[0]
            candidate = frame.f_locals.get("self")
            # Filter by type instead of `if candidate:`: truth-testing would
            # call __bool__/__len__ on arbitrary objects running in unrelated
            # tasks, which may raise or be falsy. isinstance() also rejects None.
            if not isinstance(candidate, _base_call.Call):
                continue

            if hasattr(candidate, "_channel"):
                # For intercepted Call object
                if candidate._channel is not self._channel:
                    continue
            elif hasattr(candidate, "_cython_call"):
                # For normal Call object
                if candidate._cython_call._channel is not self._channel:
                    continue
            else:
                # Unidentified Call object
                raise cygrpc.InternalError(
                    f"Unrecognized call object: {candidate}"
                )

            calls.append(candidate)
            call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Closes the channel; a truthy `grace` first waits for active calls.

        Calls still running once the grace period ends (or immediately, when
        no grace is given) are cancelled.
        """
        await self._close(grace)

    def __del__(self):
        # `_channel` is missing when __init__ raised before creating it (for
        # example on an unsupported interceptor), hence the hasattr() guard.
        if hasattr(self, "_channel"):
            if not self._channel.closed():
                self._channel.close()

    def get_state(
        self, try_to_connect: bool = False
    ) -> grpc.ChannelConnectivity:
        """Returns the channel's current connectivity state.

        Args:
          try_to_connect: Forwarded to the Cython channel's connectivity check.
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Waits until the connectivity state differs from `last_observed_state`."""
        # Await OUTSIDE the assert: `python -O` strips assert statements
        # entirely, which would skip the wait and make channel_ready() loop
        # without ever yielding to the event loop.
        changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None
        )
        assert changed

    async def channel_ready(self) -> None:
        """Waits until the channel reaches READY, requesting a connection."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    # TODO(xuanwn): Implement this method after we have
    # observability for Asyncio.
    # NOTE(review): currently returns None despite the `int` annotation.
    def _get_registered_call_handle(self, method: str) -> int:
        pass

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryUnaryMultiCallable:
        return UnaryUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryStreamMultiCallable:
        return UnaryStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_stream_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamUnaryMultiCallable:
        return StreamUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamStreamMultiCallable:
        return StreamStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_stream_interceptors,
            [self],
            self._loop,
        )


def insecure_channel(
    target: str,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates an insecure asynchronous Channel to a server.

    Args:
      target: The server address
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      A Channel.
    """
    return Channel(
        target,
        () if options is None else options,
        None,
        compression,
        interceptors,
    )


def secure_channel(
    target: str,
    credentials: grpc.ChannelCredentials,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates a secure asynchronous Channel to a server.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    return Channel(
        target,
        () if options is None else options,
        credentials._credentials,
        compression,
        interceptors,
    )