# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal utilities for gRPC Python."""

import collections
import logging
import threading
import time
from typing import Callable, Dict, List, Optional, Sequence

import grpc  # pytype: disable=pyi-error
from grpc import _common  # pytype: disable=pyi-error
from grpc._typing import DoneCallbackType

_LOGGER = logging.getLogger(__name__)

_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE = (
    'Exception calling connectivity future "done" callback!'
)


class RpcMethodHandler(
    collections.namedtuple(
        "_RpcMethodHandler",
        (
            "request_streaming",
            "response_streaming",
            "request_deserializer",
            "response_serializer",
            "unary_unary",
            "unary_stream",
            "stream_unary",
            "stream_stream",
        ),
    ),
    grpc.RpcMethodHandler,
):
    """A grpc.RpcMethodHandler whose fields are stored in a namedtuple."""


class DictionaryGenericHandler(grpc.ServiceRpcHandler):
    """A service handler that looks up method handlers in a dictionary.

    Method names given to the constructor are converted to fully qualified
    names with `_common.fully_qualified_method(service, method)`; lookups in
    `service()` use `handler_call_details.method` as the key.
    """

    _name: str
    _method_handlers: Dict[str, grpc.RpcMethodHandler]

    def __init__(
        self, service: str, method_handlers: Dict[str, grpc.RpcMethodHandler]
    ):
        """Build the handler table.

        Args:
          service: The service name, also returned by `service_name()`.
          method_handlers: Mapping from bare method name to its handler.
        """
        self._name = service
        self._method_handlers = {
            _common.fully_qualified_method(service, method): method_handler
            for method, method_handler in method_handlers.items()
        }

    def service_name(self) -> str:
        """Return the service name passed to the constructor."""
        return self._name

    def service(
        self, handler_call_details: grpc.HandlerCallDetails
    ) -> Optional[grpc.RpcMethodHandler]:
        """Return the handler registered for the call's method, or None."""
        details_method = handler_call_details.method
        return self._method_handlers.get(
            details_method
        )  # pytype: disable=attribute-error


class _ChannelReadyFuture(grpc.Future):
    """A future that matures once its channel reports READY connectivity.

    The future subscribes to the channel's connectivity in `start()`. It is
    resolved either by maturing (READY observed) or by `cancel()`. Done
    callbacks registered before resolution run exactly once, after the lock
    is released; exceptions they raise are logged, not propagated.
    """

    _condition: threading.Condition
    _channel: grpc.Channel
    _matured: bool
    _cancelled: bool
    # Pending callbacks; set to None once the future has been resolved.
    _done_callbacks: Optional[List[Callable]]

    def __init__(self, channel: grpc.Channel):
        self._condition = threading.Condition()
        self._channel = channel

        self._matured = False
        self._cancelled = False
        self._done_callbacks = []

    def _block(self, timeout: Optional[float]) -> None:
        """Wait until the future is resolved.

        Args:
          timeout: Maximum number of seconds to wait, or None to wait
            without limit.

        Raises:
          grpc.FutureCancelledError: If the future is or becomes cancelled.
          grpc.FutureTimeoutError: If the timeout elapses before the future
            matures.
        """
        # monotonic() is unaffected by wall-clock adjustments, so changing
        # the system time cannot shorten or extend the wait.
        until = None if timeout is None else time.monotonic() + timeout
        with self._condition:
            while True:
                if self._cancelled:
                    raise grpc.FutureCancelledError()
                if self._matured:
                    return
                if until is None:
                    self._condition.wait()
                else:
                    remaining = until - time.monotonic()
                    if remaining < 0:
                        raise grpc.FutureTimeoutError()
                    self._condition.wait(timeout=remaining)

    def _invoke_done_callbacks(self, done_callbacks) -> None:
        """Call each callback with this future, logging any exception."""
        for done_callback in done_callbacks:
            try:
                done_callback(self)
            except Exception:  # pylint: disable=broad-except
                _LOGGER.exception(_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE)

    def _update(self, connectivity: Optional[grpc.ChannelConnectivity]) -> None:
        """Connectivity subscription callback; matures the future on READY."""
        with self._condition:
            # Skip if already resolved: a notification may still be delivered
            # after unsubscribe, and _done_callbacks is None by then.
            if (
                self._cancelled
                or self._matured
                or connectivity is not grpc.ChannelConnectivity.READY
            ):
                return
            self._matured = True
            self._channel.unsubscribe(self._update)
            self._condition.notify_all()
            done_callbacks = tuple(self._done_callbacks)
            self._done_callbacks = None

        # Run callbacks outside the lock.
        self._invoke_done_callbacks(done_callbacks)

    def cancel(self) -> bool:
        """Cancel the future unless it has already matured.

        Returns:
          False if the future had already matured, True otherwise (including
          when it was already cancelled).
        """
        with self._condition:
            if self._matured:
                return False
            if self._cancelled:
                # Already cancelled: callbacks have run and been cleared.
                return True
            self._cancelled = True
            self._channel.unsubscribe(self._update)
            self._condition.notify_all()
            done_callbacks = tuple(self._done_callbacks)
            self._done_callbacks = None

        # Run callbacks outside the lock.
        self._invoke_done_callbacks(done_callbacks)
        return True

    def cancelled(self) -> bool:
        """Return True if the future was cancelled."""
        with self._condition:
            return self._cancelled

    def running(self) -> bool:
        """Return True while the future is neither cancelled nor matured."""
        with self._condition:
            return not self._cancelled and not self._matured

    def done(self) -> bool:
        """Return True once the future is cancelled or matured."""
        with self._condition:
            return self._cancelled or self._matured

    def result(self, timeout: Optional[float] = None) -> None:
        """Block until the channel is READY; always returns None.

        Raises:
          grpc.FutureCancelledError: If the future is cancelled.
          grpc.FutureTimeoutError: If `timeout` elapses first.
        """
        self._block(timeout)

    def exception(self, timeout: Optional[float] = None) -> None:
        """Block like `result()`; returns None once matured."""
        self._block(timeout)

    def traceback(self, timeout: Optional[float] = None) -> None:
        """Block like `result()`; returns None once matured."""
        self._block(timeout)

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        """Register `fn` to be called with this future when it resolves.

        If the future is already resolved, `fn` is called immediately in the
        calling thread, and any exception it raises propagates.
        """
        with self._condition:
            if not self._cancelled and not self._matured:
                self._done_callbacks.append(fn)
                return

        fn(self)

    def start(self) -> None:
        """Subscribe to channel connectivity, asking the channel to connect."""
        with self._condition:
            self._channel.subscribe(self._update, try_to_connect=True)

    def __del__(self):
        # Best-effort unsubscribe if the future was never resolved.
        with self._condition:
            if not self._cancelled and not self._matured:
                self._channel.unsubscribe(self._update)


def channel_ready_future(channel: grpc.Channel) -> _ChannelReadyFuture:
    """Create and start a future that matures when `channel` becomes READY."""
    ready_future = _ChannelReadyFuture(channel)
    ready_future.start()
    return ready_future


def first_version_is_lower(version1: str, version2: str) -> bool:
    """
    Compares two versions in the format '1.60.1' or '1.60.1.dev0'.

    This method will be used in all stubs generated by grpcio-tools to check whether
    the stub version is compatible with the runtime grpcio.

    The first three dot-separated components are compared numerically. If
    they are equal, the version with fewer components (i.e. without a suffix
    such as 'dev0') is considered lower.

    Args:
      version1: The first version string.
      version2: The second version string.

    Returns:
      True if version1 is lower, False otherwise. Also False when a compared
      component is not an integer, or when a version has fewer than three
      components and the comparison reaches the missing one.
    """
    version1_list = version1.split(".")
    version2_list = version2.split(".")

    try:
        for i in range(3):
            part1 = int(version1_list[i])
            part2 = int(version2_list[i])
            if part1 < part2:
                return True
            if part1 > part2:
                return False
    except (ValueError, IndexError):
        # Return False if a component can't be converted to int or a
        # version has fewer than three components.
        return False

    # The version without dev0 will be considered lower.
    return len(version1_list) < len(version2_list)