1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""
15This contains helpers for gRPC services defined in
16https://github.com/grpc/grpc/blob/master/src/proto/grpc/testing/test.proto
17"""
18import logging
19from typing import Iterable, Optional, Tuple
20
21import grpc
22from grpc_health.v1 import health_pb2
23from grpc_health.v1 import health_pb2_grpc
24
25import framework.rpc
26from src.proto.grpc.testing import empty_pb2
27from src.proto.grpc.testing import messages_pb2
28from src.proto.grpc.testing import test_pb2_grpc
29
# Type aliases
# Short names for the load balancer stats messages generated in messages_pb2.
# The request aliases carry a leading underscore; the response aliases do not
# and serve as return annotations for LoadBalancerStatsServiceClient methods.
_LoadBalancerStatsRequest = messages_pb2.LoadBalancerStatsRequest
LoadBalancerStatsResponse = messages_pb2.LoadBalancerStatsResponse
_LoadBalancerAccumulatedStatsRequest = messages_pb2.LoadBalancerAccumulatedStatsRequest
LoadBalancerAccumulatedStatsResponse = messages_pb2.LoadBalancerAccumulatedStatsResponse
35
36
class LoadBalancerStatsServiceClient(framework.rpc.grpc.GrpcClientHelper):
    """Client helper for the LoadBalancerStatsService test service.

    Wraps ``test_pb2_grpc.LoadBalancerStatsServiceStub`` and issues its RPCs
    through ``call_unary_with_deadline``, which this class inherits.
    """
    stub: test_pb2_grpc.LoadBalancerStatsServiceStub
    # Fallback timeout (seconds) for get_client_stats().
    STATS_PARTIAL_RESULTS_TIMEOUT_SEC = 1200
    # Fallback deadline (seconds) for get_client_accumulated_stats().
    STATS_ACCUMULATED_RESULTS_TIMEOUT_SEC = 600

    def __init__(self,
                 channel: grpc.Channel,
                 *,
                 log_target: Optional[str] = ''):
        """Bind the helper to ``channel`` using the stats service stub.

        Args:
            channel: gRPC channel handed to the base class initializer.
            log_target: Forwarded unchanged to the base class initializer.
        """
        stub_class = test_pb2_grpc.LoadBalancerStatsServiceStub
        super().__init__(channel, stub_class, log_target=log_target)

    def get_client_stats(
        self,
        *,
        num_rpcs: int,
        timeout_sec: Optional[int] = STATS_PARTIAL_RESULTS_TIMEOUT_SEC,
    ) -> LoadBalancerStatsResponse:
        """Call the ``GetClientStats`` RPC and return its response.

        Args:
            num_rpcs: Value placed in the request's ``num_rpcs`` field.
            timeout_sec: Used both as the request's ``timeout_sec`` field and
                as the RPC deadline. ``None`` falls back to
                ``STATS_PARTIAL_RESULTS_TIMEOUT_SEC``.

        Returns:
            Whatever ``call_unary_with_deadline`` returns for the call.
        """
        # The same value goes into the request body and the RPC deadline.
        effective_timeout = (self.STATS_PARTIAL_RESULTS_TIMEOUT_SEC
                             if timeout_sec is None else timeout_sec)
        request = _LoadBalancerStatsRequest(num_rpcs=num_rpcs,
                                            timeout_sec=effective_timeout)
        return self.call_unary_with_deadline(rpc='GetClientStats',
                                             req=request,
                                             deadline_sec=effective_timeout,
                                             log_level=logging.INFO)

    def get_client_accumulated_stats(
        self,
        *,
        timeout_sec: Optional[int] = None
    ) -> LoadBalancerAccumulatedStatsResponse:
        """Call the ``GetClientAccumulatedStats`` RPC and return its response.

        Args:
            timeout_sec: RPC deadline in seconds. ``None`` falls back to
                ``STATS_ACCUMULATED_RESULTS_TIMEOUT_SEC``.

        Returns:
            Whatever ``call_unary_with_deadline`` returns for the call.
        """
        deadline = (self.STATS_ACCUMULATED_RESULTS_TIMEOUT_SEC
                    if timeout_sec is None else timeout_sec)
        request = _LoadBalancerAccumulatedStatsRequest()
        return self.call_unary_with_deadline(rpc='GetClientAccumulatedStats',
                                             req=request,
                                             deadline_sec=deadline,
                                             log_level=logging.INFO)
79
80
class XdsUpdateClientConfigureServiceClient(
        framework.rpc.grpc.GrpcClientHelper):
    """Client helper for the XdsUpdateClientConfigureService test service.

    Wraps ``test_pb2_grpc.XdsUpdateClientConfigureServiceStub`` and sends the
    ``Configure`` RPC through the inherited ``call_unary_with_deadline``.
    """
    stub: test_pb2_grpc.XdsUpdateClientConfigureServiceStub
    # Default deadline (seconds) for the Configure RPC itself.
    CONFIGURE_TIMEOUT_SEC: int = 5

    def __init__(self,
                 channel: grpc.Channel,
                 *,
                 log_target: Optional[str] = ''):
        """Bind the helper to ``channel`` using the configure service stub.

        Args:
            channel: gRPC channel handed to the base class initializer.
            log_target: Forwarded unchanged to the base class initializer.
        """
        stub_class = test_pb2_grpc.XdsUpdateClientConfigureServiceStub
        super().__init__(channel, stub_class, log_target=log_target)

    def configure(
        self,
        *,
        rpc_types: Iterable[str],
        metadata: Optional[Iterable[Tuple[str, str, str]]] = None,
        app_timeout: Optional[int] = None,
        timeout_sec: int = CONFIGURE_TIMEOUT_SEC,
    ) -> None:
        """Build a ``ClientConfigureRequest`` and send it via ``Configure``.

        Args:
            rpc_types: ``RpcType`` enum value names; each is resolved with
                ``RpcType.Value`` and added to the request's ``types``.
            metadata: Optional entries indexed as (RpcType name, key, value);
                each becomes one ``Metadata`` message on the request. A falsy
                value adds nothing.
            app_timeout: Copied into the request's ``timeout_sec`` field only
                when truthy; otherwise the field is left untouched.
            timeout_sec: Deadline passed to ``call_unary_with_deadline``.
        """
        request_cls = messages_pb2.ClientConfigureRequest
        rpc_type_value = request_cls.RpcType.Value

        request = request_cls()
        request.types.extend([rpc_type_value(name) for name in rpc_types])

        # A falsy ``metadata`` (None, empty container) contributes no entries.
        for entry in metadata or ():
            request.metadata.append(
                request_cls.Metadata(type=rpc_type_value(entry[0]),
                                     key=entry[1],
                                     value=entry[2]))

        if app_timeout:
            request.timeout_sec = app_timeout

        # Configure's response is empty, so the result is discarded.
        self.call_unary_with_deadline(rpc='Configure',
                                      req=request,
                                      deadline_sec=timeout_sec,
                                      log_level=logging.INFO)
122
123
class XdsUpdateHealthServiceClient(framework.rpc.grpc.GrpcClientHelper):
    """Client helper for the XdsUpdateHealthService test service.

    Wraps ``test_pb2_grpc.XdsUpdateHealthServiceStub``. Both RPCs send an
    ``Empty`` request through the inherited ``call_unary_with_deadline``.
    """
    stub: test_pb2_grpc.XdsUpdateHealthServiceStub

    def __init__(self, channel: grpc.Channel, log_target: Optional[str] = ''):
        """Bind the helper to ``channel`` using the health-update stub.

        Args:
            channel: gRPC channel handed to the base class initializer.
            log_target: Forwarded unchanged to the base class initializer.
        """
        stub_class = test_pb2_grpc.XdsUpdateHealthServiceStub
        super().__init__(channel, stub_class, log_target=log_target)

    def set_serving(self):
        """Send the ``SetServing`` RPC; the result is not returned."""
        # No deadline_sec is passed, so call_unary_with_deadline picks its own.
        request = empty_pb2.Empty()
        self.call_unary_with_deadline(rpc='SetServing',
                                      req=request,
                                      log_level=logging.INFO)

    def set_not_serving(self):
        """Send the ``SetNotServing`` RPC; the result is not returned."""
        # No deadline_sec is passed, so call_unary_with_deadline picks its own.
        request = empty_pb2.Empty()
        self.call_unary_with_deadline(rpc='SetNotServing',
                                      req=request,
                                      log_level=logging.INFO)
141
142
class HealthClient(framework.rpc.grpc.GrpcClientHelper):
    """Client helper for the ``grpc_health.v1`` Health service.

    Wraps ``health_pb2_grpc.HealthStub`` and issues the ``Check`` RPC through
    the inherited ``call_unary_with_deadline``.
    """
    stub: health_pb2_grpc.HealthStub

    def __init__(self, channel: grpc.Channel, log_target: Optional[str] = ''):
        """Bind the helper to ``channel`` using the Health stub.

        Args:
            channel: gRPC channel handed to the base class initializer.
            log_target: Forwarded unchanged to the base class initializer.
        """
        stub_class = health_pb2_grpc.HealthStub
        super().__init__(channel, stub_class, log_target=log_target)

    def check_health(self):
        """Send ``Check`` with a default ``HealthCheckRequest``.

        Returns:
            Whatever ``call_unary_with_deadline`` returns for the call.
        """
        # No deadline_sec is passed, so call_unary_with_deadline picks its own.
        request = health_pb2.HealthCheckRequest()
        return self.call_unary_with_deadline(rpc='Check',
                                             req=request,
                                             log_level=logging.INFO)
156