1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import logging
15import re
16from typing import Any, Dict, Optional
17
18from google.protobuf import json_format
19import google.protobuf.message
20import grpc
21
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type aliases
# Short name for the protobuf message base class, used in annotations below.
Message = google.protobuf.message.Message
26
27
class GrpcClientHelper:
    """Wraps a gRPC stub and issues unary RPCs with a deadline.

    The stub is created by calling ``stub_class(channel)``. Every request is
    logged before it is sent.
    """
    # Deadline applied when the caller does not provide one.
    DEFAULT_RPC_DEADLINE_SEC = 90
    channel: grpc.Channel
    # This is purely cosmetic to make RPC logs look like method calls.
    log_service_name: str
    # This is purely cosmetic to output the RPC target. Normally set to the
    # hostname:port of the remote service, but it doesn't have to be the
    # real target. This is done so that when RPC are routed to the proxy
    # or port forwarding, this still is set to a useful name.
    log_target: str

    def __init__(self,
                 channel: grpc.Channel,
                 stub_class: Any,
                 *,
                 log_target: Optional[str] = ''):
        """Create the stub on the given channel.

        Args:
            channel: Channel passed to ``stub_class`` and stored as
                ``self.channel``.
            stub_class: Callable invoked as ``stub_class(channel)`` to build
                the stub.
            log_target: Label printed in RPC logs; ``None`` is stored as an
                empty string.
        """
        self.channel = channel
        self.stub = stub_class(channel)
        # Derive the service name for logs by dropping a trailing "Stub"
        # from the stub's class name.
        self.log_service_name = re.sub('Stub$', '',
                                       self.stub.__class__.__name__)
        self.log_target = log_target or ''

    def call_unary_with_deadline(
            self,
            *,
            rpc: str,
            req: Message,
            deadline_sec: Optional[int] = DEFAULT_RPC_DEADLINE_SEC,
            log_level: Optional[int] = logging.DEBUG) -> Message:
        """Log the request, then invoke the stub method named ``rpc``.

        Args:
            rpc: Name of the stub attribute to call.
            req: Request message passed to the stub method.
            deadline_sec: Value passed as the call's ``timeout``; ``None``
                falls back to ``DEFAULT_RPC_DEADLINE_SEC``.
            log_level: Logging level for the request log line; ``None``
                falls back to ``logging.DEBUG``.

        Returns:
            Whatever the stub method returns.
        """
        if deadline_sec is None:
            deadline_sec = self.DEFAULT_RPC_DEADLINE_SEC

        call_kwargs = dict(wait_for_ready=True, timeout=deadline_sec)
        self._log_rpc_request(rpc, req, call_kwargs, log_level)

        # Call RPC, e.g. RpcStub(channel).RpcMethod(req, ...options)
        rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
        return rpc_callable(req, **call_kwargs)

    def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG):
        """Log the outgoing RPC formatted like a method call.

        Args:
            rpc: RPC method name.
            req: Request message; rendered via ``json_format.MessageToDict``.
            call_kwargs: Call options, rendered as ``key=value`` pairs.
            log_level: Logging level; ``None`` falls back to
                ``logging.DEBUG``.
        """
        level = logging.DEBUG if log_level is None else log_level
        # Skip converting the request to a dict when the record would be
        # discarded anyway.
        if not logger.isEnabledFor(level):
            return
        # Join in the dict's own order: a set here would make the order of
        # the options in the log line arbitrary.
        options = ', '.join(f'{k}={v}' for k, v in call_kwargs.items())
        logger.log(level, '[%s] RPC %s.%s(request=%s(%r), %s)',
                   self.log_target, self.log_service_name, rpc,
                   req.__class__.__name__, json_format.MessageToDict(req),
                   options)
73
74
class GrpcApp:
    """Holds insecure gRPC channels to one host, cached per port.

    Channels are created on first use by ``_make_channel`` and released by
    ``close()``, which also runs when leaving a ``with`` block and from
    ``__del__``.
    """
    channels: Dict[int, grpc.Channel]

    class NotFound(Exception):
        """Requested resource not found"""

        def __init__(self, message):
            self.message = message
            super().__init__(message)

    def __init__(self, rpc_host):
        """Remember the host and start with an empty channel cache.

        Args:
            rpc_host: Host part of the ``host:port`` channel targets.
        """
        self.rpc_host = rpc_host
        # Cache gRPC channels per port
        self.channels = dict()

    def _make_channel(self, port) -> grpc.Channel:
        """Return the cached channel for ``port``, creating it if absent."""
        if port not in self.channels:
            target = f'{self.rpc_host}:{port}'
            self.channels[port] = grpc.insecure_channel(target)
        return self.channels[port]

    def close(self):
        """Close every cached channel and empty the cache.

        Calling this repeatedly is harmless: once the cache is empty there
        is nothing left to close.
        """
        # __del__ may run on an instance whose __init__ never assigned
        # self.channels; treat that as "nothing to close".
        channels = getattr(self, 'channels', None)
        # Remove each channel from the cache before closing it, so that
        # _make_channel never hands out a channel that was already closed.
        while channels:
            _, channel = channels.popitem()
            channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        # Do not suppress exceptions raised inside the with-block.
        return False

    def __del__(self):
        self.close()
110