1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""
Client helpers for the gRPC Channelz service defined in
16https://github.com/grpc/grpc-proto/blob/master/grpc/channelz/v1/channelz.proto
17"""
18import ipaddress
19import logging
20from typing import Iterator, Optional
21
22import grpc
23from grpc_channelz.v1 import channelz_pb2
24from grpc_channelz.v1 import channelz_pb2_grpc
25
26import framework.rpc
27
logger = logging.getLogger(__name__)

# Short aliases for the Channelz protobuf messages used in this module.

# Public aliases: Channelz entities handed back to callers.
Channel = channelz_pb2.Channel
ChannelConnectivityState = channelz_pb2.ChannelConnectivityState
ChannelState = ChannelConnectivityState.State  # pylint: disable=no-member
Subchannel = channelz_pb2.Subchannel
Server = channelz_pb2.Server
Socket = channelz_pb2.Socket
SocketRef = channelz_pb2.SocketRef
Address = channelz_pb2.Address
Security = channelz_pb2.Security

# Module-private aliases: request/response wrappers for each Channelz RPC.
_GetTopChannelsRequest = channelz_pb2.GetTopChannelsRequest
_GetTopChannelsResponse = channelz_pb2.GetTopChannelsResponse
_GetSubchannelRequest = channelz_pb2.GetSubchannelRequest
_GetSubchannelResponse = channelz_pb2.GetSubchannelResponse
_GetServersRequest = channelz_pb2.GetServersRequest
_GetServersResponse = channelz_pb2.GetServersResponse
_GetSocketRequest = channelz_pb2.GetSocketRequest
_GetSocketResponse = channelz_pb2.GetSocketResponse
_GetServerSocketsRequest = channelz_pb2.GetServerSocketsRequest
_GetServerSocketsResponse = channelz_pb2.GetServerSocketsResponse
55
56
class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
    """Client for the gRPC Channelz service.

    Wraps the generated ``ChannelzStub`` and adds helpers that page through
    Channelz listings and resolve channel, subchannel and socket references
    into full messages. Extra ``**kwargs`` accepted by the RPC-issuing methods
    are forwarded unchanged to ``call_unary_with_deadline`` (provided by the
    base helper).
    """
    stub: channelz_pb2_grpc.ChannelzStub

    def __init__(self,
                 channel: grpc.Channel,
                 *,
                 log_target: Optional[str] = ''):
        """Create a Channelz client on top of an existing gRPC channel.

        Args:
            channel: Channel to the process whose Channelz data is queried.
            log_target: Passed through to the base helper; presumably a label
                used when logging RPCs (confirm in GrpcClientHelper).
        """
        super().__init__(channel,
                         channelz_pb2_grpc.ChannelzStub,
                         log_target=log_target)

    @staticmethod
    def is_sock_tcpip_address(address: Address) -> bool:
        """Return True if the ``address`` oneof is set to ``tcpip_address``."""
        return address.WhichOneof('address') == 'tcpip_address'

    @staticmethod
    def is_ipv4(tcpip_address: Address.TcpIpAddress) -> bool:
        """Return True if ``tcpip_address.ip_address`` is a packed IPv4 address."""
        # According to proto, tcpip_address.ip_address is either IPv4 or IPv6.
        # Correspondingly, it's either 4 bytes or 16 bytes in length.
        return len(tcpip_address.ip_address) == 4

    @classmethod
    def sock_address_to_str(cls, address: Address) -> str:
        """Format a socket address as ``host:port``.

        IPv6 hosts are wrapped in square brackets (``[::1]:8080``) so the
        port separator is unambiguous; IPv4 hosts are left bare
        (``127.0.0.1:8080``).

        Raises:
            NotImplementedError: If ``address`` is not a TCP/IP address.
        """
        if not cls.is_sock_tcpip_address(address):
            raise NotImplementedError('Only tcpip_address implemented')
        tcpip_address: Address.TcpIpAddress = address.tcpip_address
        if cls.is_ipv4(tcpip_address):
            host = str(ipaddress.IPv4Address(tcpip_address.ip_address))
        else:
            # Without brackets '::1:8080' would read as a bare IPv6 address.
            host = f'[{ipaddress.IPv6Address(tcpip_address.ip_address)}]'
        return f'{host}:{tcpip_address.port}'

    @classmethod
    def sock_addresses_pretty(cls, socket: Socket) -> str:
        """Return ``'local=<addr>, remote=<addr>'`` for logging a socket.

        Raises:
            NotImplementedError: If either address is not a TCP/IP address.
        """
        return (f'local={cls.sock_address_to_str(socket.local)}, '
                f'remote={cls.sock_address_to_str(socket.remote)}')

    @staticmethod
    def find_server_socket_matching_client(
            server_sockets: Iterator[Socket],
            client_socket: Socket) -> Optional[Socket]:
        """Find the server-side end of a client connection.

        Returns:
            The first socket in ``server_sockets`` whose remote address equals
            ``client_socket``'s local address, or None if none matches.
        """
        for server_socket in server_sockets:
            if server_socket.remote == client_socket.local:
                return server_socket
        return None

    @staticmethod
    def channel_repr(channel: Channel) -> str:
        """Return a compact ``<Channel ...>`` string with id, target and state."""
        result = f'<Channel channel_id={channel.ref.channel_id}'
        if channel.data.target:
            result += f' target={channel.data.target}'
        result += f' state={ChannelState.Name(channel.data.state.state)}>'
        return result

    @staticmethod
    def subchannel_repr(subchannel: Subchannel) -> str:
        """Return a compact ``<Subchannel ...>`` string with id, target, state."""
        result = f'<Subchannel subchannel_id={subchannel.ref.subchannel_id}'
        if subchannel.data.target:
            result += f' target={subchannel.data.target}'
        result += f' state={ChannelState.Name(subchannel.data.state.state)}>'
        return result

    def find_channels_for_target(self, target: str,
                                 **kwargs) -> Iterator[Channel]:
        """Lazily yield root channels whose ``data.target`` equals ``target``."""
        return (channel for channel in self.list_channels(**kwargs)
                if channel.data.target == target)

    def find_server_listening_on_port(self, port: int,
                                      **kwargs) -> Optional[Server]:
        """Return the first server with a TCP listen socket on ``port``.

        Issues one GetSocket RPC per listen socket inspected. Non-TCP/IP
        listen sockets are skipped.

        Returns:
            The matching Server, or None if no server listens on ``port``.
        """
        for server in self.list_servers(**kwargs):
            listen_socket_ref: SocketRef
            for listen_socket_ref in server.listen_socket:
                listen_socket = self.get_socket(listen_socket_ref.socket_id,
                                                **kwargs)
                listen_address: Address = listen_socket.local
                if (self.is_sock_tcpip_address(listen_address) and
                        listen_address.tcpip_address.port == port):
                    return server
        return None

    def list_channels(self, **kwargs) -> Iterator[Channel]:
        """
        Iterate over all pages of all root channels.

        Root channels are those which application has directly created.
        This does not include subchannels nor non-top level channels.
        Pages are fetched lazily via GetTopChannels until a response sets
        ``end``.
        """
        start: int = -1
        response: Optional[_GetTopChannelsResponse] = None
        # `start < 0` only on the first pass, before any response exists.
        while start < 0 or not response.end:
            # From proto: To request subsequent pages, the client generates this
            # value by adding 1 to the highest seen result ID.
            start += 1
            response = self.call_unary_with_deadline(
                rpc='GetTopChannels',
                req=_GetTopChannelsRequest(start_channel_id=start),
                **kwargs)
            for channel in response.channel:
                start = max(start, channel.ref.channel_id)
                yield channel

    def list_servers(self, **kwargs) -> Iterator[Server]:
        """Iterate over all pages of all servers that exist in the process.

        Pages are fetched lazily via GetServers until a response sets ``end``.
        """
        start: int = -1
        response: Optional[_GetServersResponse] = None
        # `start < 0` only on the first pass, before any response exists.
        while start < 0 or not response.end:
            # From proto: To request subsequent pages, the client generates this
            # value by adding 1 to the highest seen result ID.
            start += 1
            response = self.call_unary_with_deadline(
                rpc='GetServers',
                req=_GetServersRequest(start_server_id=start),
                **kwargs)
            for server in response.server:
                start = max(start, server.ref.server_id)
                yield server

    def list_server_sockets(self, server: Server, **kwargs) -> Iterator[Socket]:
        """List all server sockets that exist in server process.

        Iterating over the results will resolve additional pages automatically.
        Each SocketRef from GetServerSockets is resolved with a separate
        GetSocket RPC, so one RPC is issued per yielded socket.
        """
        start: int = -1
        response: Optional[_GetServerSocketsResponse] = None
        # `start < 0` only on the first pass, before any response exists.
        while start < 0 or not response.end:
            # From proto: To request subsequent pages, the client generates this
            # value by adding 1 to the highest seen result ID.
            start += 1
            response = self.call_unary_with_deadline(
                rpc='GetServerSockets',
                req=_GetServerSocketsRequest(server_id=server.ref.server_id,
                                             start_socket_id=start),
                **kwargs)
            socket_ref: SocketRef
            for socket_ref in response.socket_ref:
                start = max(start, socket_ref.socket_id)
                # Yield actual socket
                yield self.get_socket(socket_ref.socket_id, **kwargs)

    def list_channel_sockets(self, channel: Channel,
                             **kwargs) -> Iterator[Socket]:
        """List all sockets of all subchannels of a given channel."""
        for subchannel in self.list_channel_subchannels(channel, **kwargs):
            yield from self.list_subchannels_sockets(subchannel, **kwargs)

    def list_channel_subchannels(self, channel: Channel,
                                 **kwargs) -> Iterator[Subchannel]:
        """List all subchannels of a given channel (one GetSubchannel RPC each)."""
        for subchannel_ref in channel.subchannel_ref:
            yield self.get_subchannel(subchannel_ref.subchannel_id, **kwargs)

    def list_subchannels_sockets(self, subchannel: Subchannel,
                                 **kwargs) -> Iterator[Socket]:
        """List all sockets of a given subchannel (one GetSocket RPC each)."""
        for socket_ref in subchannel.socket_ref:
            yield self.get_socket(socket_ref.socket_id, **kwargs)

    def get_subchannel(self, subchannel_id, **kwargs) -> Subchannel:
        """Return a single Subchannel, otherwise raises RpcError."""
        response: _GetSubchannelResponse = self.call_unary_with_deadline(
            rpc='GetSubchannel',
            req=_GetSubchannelRequest(subchannel_id=subchannel_id),
            **kwargs)
        return response.subchannel

    def get_socket(self, socket_id, **kwargs) -> Socket:
        """Return a single Socket, otherwise raises RpcError."""
        response: _GetSocketResponse = self.call_unary_with_deadline(
            rpc='GetSocket',
            req=_GetSocketRequest(socket_id=socket_id),
            **kwargs)
        return response.socket
229