# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import Any, Dict, Optional

from google.protobuf import json_format
import google.protobuf.message
import grpc

logger = logging.getLogger(__name__)

# Type aliases
Message = google.protobuf.message.Message


class GrpcClientHelper:
    """Wraps a gRPC stub to issue unary RPCs with a deadline and logging.

    The stub is constructed as ``stub_class(channel)``; RPC methods are looked
    up on it by name in ``call_unary_with_deadline``.
    """
    DEFAULT_RPC_DEADLINE_SEC = 90
    channel: grpc.Channel
    # This is purely cosmetic to make RPC logs look like method calls.
    log_service_name: str
    # This is purely cosmetic to output the RPC target. Normally set to the
    # hostname:port of the remote service, but it doesn't have to be the
    # real target. This is done so that when RPC are routed to the proxy
    # or port forwarding, this still is set to a useful name.
    log_target: str

    def __init__(self,
                 channel: grpc.Channel,
                 stub_class: Any,
                 *,
                 log_target: Optional[str] = ''):
        """Create the stub and derive the cosmetic names used in logs.

        Args:
            channel: The gRPC channel the stub is bound to.
            stub_class: A generated stub class; called as
                ``stub_class(channel)``.
            log_target: Target name shown in RPC logs. ``None`` or empty
                becomes ``''``.
        """
        self.channel = channel
        self.stub = stub_class(channel)
        # Strip a trailing "Stub" from the class name, e.g.
        # "FooServiceStub" -> "FooService", so logs read like method calls.
        self.log_service_name = re.sub('Stub$', '',
                                       self.stub.__class__.__name__)
        self.log_target = log_target or ''

    def call_unary_with_deadline(
            self,
            *,
            rpc: str,
            req: Message,
            deadline_sec: Optional[int] = DEFAULT_RPC_DEADLINE_SEC,
            log_level: Optional[int] = logging.DEBUG) -> Message:
        """Log and invoke the unary RPC method named ``rpc`` on the stub.

        The call is made with ``wait_for_ready=True`` and
        ``timeout=deadline_sec``.

        Args:
            rpc: Name of the stub attribute to call, e.g. ``'GetStats'``.
            req: The request message.
            deadline_sec: RPC timeout in seconds; ``None`` falls back to
                ``DEFAULT_RPC_DEADLINE_SEC``.
            log_level: Logging level for the request log line; ``None``
                means ``logging.DEBUG``.

        Returns:
            Whatever the stub's callable returns (the response message).
        """
        if deadline_sec is None:
            deadline_sec = self.DEFAULT_RPC_DEADLINE_SEC

        call_kwargs = dict(wait_for_ready=True, timeout=deadline_sec)
        self._log_rpc_request(rpc, req, call_kwargs, log_level)

        # Call RPC, e.g. RpcStub(channel).RpcMethod(req, ...options)
        rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
        return rpc_callable(req, **call_kwargs)

    def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG):
        """Log the outgoing RPC in a method-call-like format.

        ``log_level`` of ``None`` is treated as ``logging.DEBUG``.
        """
        # Join in dict insertion order. A set comprehension here would make
        # the option order depend on per-process string hash randomization.
        call_options = ', '.join(f'{k}={v}' for k, v in call_kwargs.items())
        logger.log(logging.DEBUG if log_level is None else log_level,
                   '[%s] RPC %s.%s(request=%s(%r), %s)', self.log_target,
                   self.log_service_name, rpc, req.__class__.__name__,
                   json_format.MessageToDict(req), call_options)


class GrpcApp:
    """Base for clients of a remote app; caches one insecure channel per port.

    Usable as a context manager; channels are closed on exit.
    """
    channels: Dict[int, grpc.Channel]

    class NotFound(Exception):
        """Requested resource not found"""

        def __init__(self, message):
            self.message = message
            super().__init__(message)

    def __init__(self, rpc_host):
        self.rpc_host = rpc_host
        # Cache gRPC channels per port
        self.channels = dict()

    def _make_channel(self, port) -> grpc.Channel:
        """Return the cached insecure channel to ``rpc_host:port``.

        The channel is created on first use for that port.
        """
        if port not in self.channels:
            target = f'{self.rpc_host}:{port}'
            self.channels[port] = grpc.insecure_channel(target)
        return self.channels[port]

    def close(self):
        """Close all cached channels and empty the cache.

        Emptying the cache makes repeated calls (``__exit__`` followed by
        ``__del__``) harmless. It also means a later ``_make_channel`` opens a
        fresh channel instead of returning one that was already closed.
        """
        for channel in self.channels.values():
            channel.close()
        self.channels.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False

    def __del__(self):
        # If __init__ never ran to completion (e.g. a subclass raised before
        # calling it), there is no channel cache and nothing to close.
        if getattr(self, 'channels', None):
            self.close()