1# Copyright 2019 the gRPC authors. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14""" Proxies a TCP connection between a single client-server pair. 15 16This proxy is not suitable for production, but should work well for cases in 17which a test needs to spy on the bytes put on the wire between a server and 18a client. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import datetime 26import select 27import socket 28import threading 29 30from tests.unit.framework.common import get_socket 31 32_TCP_PROXY_BUFFER_SIZE = 1024 33_TCP_PROXY_TIMEOUT = datetime.timedelta(milliseconds=500) 34 35 36def _init_proxy_socket(gateway_address, gateway_port): 37 proxy_socket = socket.create_connection((gateway_address, gateway_port)) 38 return proxy_socket 39 40 41class TcpProxy(object): 42 """Proxies a TCP connection between one client and one server.""" 43 44 def __init__(self, bind_address, gateway_address, gateway_port): 45 self._bind_address = bind_address 46 self._gateway_address = gateway_address 47 self._gateway_port = gateway_port 48 49 self._byte_count_lock = threading.RLock() 50 self._sent_byte_count = 0 51 self._received_byte_count = 0 52 53 self._stop_event = threading.Event() 54 55 self._port = None 56 self._listen_socket = None 57 self._proxy_socket = None 58 59 # The following three attributes are owned by the serving thread. 60 self._northbound_data = b"" 61 self._southbound_data = b"" 62 self._client_sockets = [] 63 64 self._thread = threading.Thread(target=self._run_proxy) 65 66 def start(self): 67 _, self._port, self._listen_socket = get_socket( 68 bind_address=self._bind_address 69 ) 70 self._proxy_socket = _init_proxy_socket( 71 self._gateway_address, self._gateway_port 72 ) 73 self._thread.start() 74 75 def get_port(self): 76 return self._port 77 78 def _handle_reads(self, sockets_to_read): 79 for socket_to_read in sockets_to_read: 80 if socket_to_read is self._listen_socket: 81 client_socket, client_address = socket_to_read.accept() 82 self._client_sockets.append(client_socket) 83 elif socket_to_read is self._proxy_socket: 84 data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE) 85 with self._byte_count_lock: 86 self._received_byte_count += len(data) 87 self._northbound_data += data 88 elif socket_to_read in self._client_sockets: 89 data = socket_to_read.recv(_TCP_PROXY_BUFFER_SIZE) 90 if data: 91 with self._byte_count_lock: 92 self._sent_byte_count += len(data) 93 self._southbound_data += data 94 else: 95 self._client_sockets.remove(socket_to_read) 96 else: 97 raise RuntimeError("Unidentified socket appeared in read set.") 98 99 def _handle_writes(self, sockets_to_write): 100 for socket_to_write in sockets_to_write: 101 if socket_to_write is self._proxy_socket: 102 if self._southbound_data: 103 self._proxy_socket.sendall(self._southbound_data) 104 self._southbound_data = b"" 105 elif socket_to_write in self._client_sockets: 106 if self._northbound_data: 107 socket_to_write.sendall(self._northbound_data) 108 self._northbound_data = b"" 109 110 def _run_proxy(self): 111 while not self._stop_event.is_set(): 112 expected_reads = (self._listen_socket, self._proxy_socket) + tuple( 113 self._client_sockets 114 ) 115 expected_writes = expected_reads 116 sockets_to_read, sockets_to_write, _ = select.select( 117 expected_reads, 118 expected_writes, 119 (), 120 _TCP_PROXY_TIMEOUT.total_seconds(), 121 ) 122 self._handle_reads(sockets_to_read) 123 self._handle_writes(sockets_to_write) 124 for client_socket in self._client_sockets: 125 client_socket.close() 126 127 def stop(self): 128 self._stop_event.set() 129 self._thread.join() 130 self._listen_socket.close() 131 self._proxy_socket.close() 132 133 def get_byte_count(self): 134 with self._byte_count_lock: 135 return self._sent_byte_count, self._received_byte_count 136 137 def reset_byte_count(self): 138 with self._byte_count_lock: 139 self._byte_count = 0 140 self._received_byte_count = 0 141 142 def __enter__(self): 143 self.start() 144 return self 145 146 def __exit__(self, exc_type, exc_val, exc_tb): 147 self.stop() 148