xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_tests/tests/unit/_tcp_proxy.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2019 the gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14""" Proxies a TCP connection between a single client-server pair.
15
16This proxy is not suitable for production, but should work well for cases in
17which a test needs to spy on the bytes put on the wire between a server and
18a client.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import datetime
26import select
27import socket
28import threading
29
30from tests.unit.framework.common import get_socket
31
# Maximum number of bytes pulled from a socket by a single recv() call.
_TCP_PROXY_BUFFER_SIZE = 1 << 10
# How long each select() call in the serving loop blocks before the loop
# re-checks whether stop() has been requested.
_TCP_PROXY_TIMEOUT = datetime.timedelta(seconds=0.5)
34
35
def _init_proxy_socket(gateway_address, gateway_port):
    """Open the upstream TCP connection from the proxy to the gateway.

    Args:
      gateway_address: Host name or IP address of the gateway.
      gateway_port: TCP port of the gateway.

    Returns:
      A connected socket object, as returned by socket.create_connection.
    """
    gateway_endpoint = (gateway_address, gateway_port)
    return socket.create_connection(gateway_endpoint)
39
40
class TcpProxy(object):
    """Proxies a TCP connection between one client and one server.

    Clients connect to the port reported by get_port(). Bytes they send are
    forwarded over a single upstream connection to
    (gateway_address, gateway_port), and bytes coming back from the gateway
    are forwarded to a client. The proxy counts the bytes flowing in each
    direction so a test can inspect what went over the wire.

    NOTE: all gateway->client bytes share one buffer, which is flushed to
    whichever connected client select() reports writable first. The proxy is
    therefore only meaningful with a single client.
    """

    def __init__(self, bind_address, gateway_address, gateway_port):
        self._bind_address = bind_address
        self._gateway_address = gateway_address
        self._gateway_port = gateway_port

        # Guards the byte counters. The serving thread updates them; the
        # caller's thread reads and resets them.
        self._byte_count_lock = threading.RLock()
        self._sent_byte_count = 0
        self._received_byte_count = 0

        self._stop_event = threading.Event()

        self._port = None
        self._listen_socket = None
        self._proxy_socket = None

        # The following attributes are owned by the serving thread.
        # Bytes read from the gateway, waiting to be written to a client.
        self._northbound_data = b""
        # Bytes read from a client, waiting to be written to the gateway.
        self._southbound_data = b""
        self._client_sockets = []
        # Set once the gateway connection hits EOF. After that the proxy
        # socket is no longer polled for reads, because an EOF'd socket stays
        # readable forever and would make the loop spin.
        self._gateway_closed = False

        self._thread = threading.Thread(target=self._run_proxy)

    def start(self):
        """Bind the listening socket, connect upstream, and start serving.

        If connecting to the gateway fails, the listening socket is closed
        before the exception propagates, so it does not leak.
        """
        _, self._port, self._listen_socket = get_socket(
            bind_address=self._bind_address
        )
        try:
            self._proxy_socket = _init_proxy_socket(
                self._gateway_address, self._gateway_port
            )
        except Exception:
            self._listen_socket.close()
            raise
        self._thread.start()

    def get_port(self):
        """Return the port clients should connect to (None before start())."""
        return self._port

    def _handle_reads(self, sockets_to_read):
        """Service every socket that select() reported as readable.

        Handles three kinds of readable socket:
          * listening socket: accepts and records a new client connection;
          * proxy (gateway) socket: queues the bytes for the client side;
          * client socket: queues the bytes for the gateway side, or closes
            and forgets the client on EOF.

        Raises:
          RuntimeError: if a socket not owned by this proxy shows up.
        """
        for socket_to_read in sockets_to_read:
            if socket_to_read is self._listen_socket:
                client_socket, _ = socket_to_read.accept()
                self._client_sockets.append(client_socket)
            elif socket_to_read is self._proxy_socket:
                data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE)
                if data:
                    with self._byte_count_lock:
                        self._received_byte_count += len(data)
                    self._northbound_data += data
                else:
                    # The gateway hung up. Stop polling it for reads.
                    self._gateway_closed = True
            elif socket_to_read in self._client_sockets:
                data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE)
                if data:
                    with self._byte_count_lock:
                        self._sent_byte_count += len(data)
                    self._southbound_data += data
                else:
                    # The client hung up: forget it and release its
                    # descriptor (previously it was dropped without close()).
                    self._client_sockets.remove(socket_to_read)
                    socket_to_read.close()
            else:
                raise RuntimeError("Unidentified socket appeared in read set.")

    def _handle_writes(self, sockets_to_write):
        """Flush queued bytes to the sockets select() reported as writable.

        Client->gateway bytes go to the proxy socket. Gateway->client bytes
        go to the first writable client socket in sockets_to_write.
        """
        for socket_to_write in sockets_to_write:
            if socket_to_write is self._proxy_socket:
                if self._southbound_data:
                    self._proxy_socket.sendall(self._southbound_data)
                    self._southbound_data = b""
            elif socket_to_write in self._client_sockets:
                if self._northbound_data:
                    socket_to_write.sendall(self._northbound_data)
                    self._northbound_data = b""

    def _select_sets(self):
        """Return the (read, write) socket tuples for the next select() call.

        A socket is polled for writability only when data is queued for it.
        An idle connected socket is almost always writable, so polling it
        unconditionally makes select() return at once and the loop busy-spin.
        """
        reads = [self._listen_socket]
        if not self._gateway_closed:
            reads.append(self._proxy_socket)
        reads.extend(self._client_sockets)
        writes = []
        if self._southbound_data:
            writes.append(self._proxy_socket)
        if self._northbound_data:
            writes.extend(self._client_sockets)
        return tuple(reads), tuple(writes)

    def _run_proxy(self):
        """Serving-thread loop: shuttle bytes until stop() is requested.

        Each select() waits at most _TCP_PROXY_TIMEOUT, so a stop request is
        noticed within that interval. Client sockets are closed when the loop
        exits, including when it exits because of an exception.
        """
        try:
            while not self._stop_event.is_set():
                expected_reads, expected_writes = self._select_sets()
                sockets_to_read, sockets_to_write, _ = select.select(
                    expected_reads,
                    expected_writes,
                    (),
                    _TCP_PROXY_TIMEOUT.total_seconds(),
                )
                self._handle_reads(sockets_to_read)
                self._handle_writes(sockets_to_write)
        finally:
            for client_socket in self._client_sockets:
                client_socket.close()
            self._client_sockets = []

    def stop(self):
        """Stop the serving thread, then close the listening and proxy sockets."""
        self._stop_event.set()
        self._thread.join()
        self._listen_socket.close()
        self._proxy_socket.close()

    def get_byte_count(self):
        """Return (sent, received) byte counts.

        sent counts bytes read from clients (client -> gateway). received
        counts bytes read from the gateway (gateway -> client).
        """
        with self._byte_count_lock:
            return self._sent_byte_count, self._received_byte_count

    def reset_byte_count(self):
        """Reset both byte counters to zero."""
        with self._byte_count_lock:
            # Previously assigned a nonexistent `_byte_count`, leaving the
            # sent counter untouched.
            self._sent_byte_count = 0
            self._received_byte_count = 0

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
148