# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test of responsiveness to signals."""

import logging
import os
import signal
import subprocess
import sys
import tempfile
import threading
import unittest

import grpc

from tests.unit import _signal_client
from tests.unit import test_common

# Absolute path of the signal client that each test launches in a subprocess.
if sys.executable is not None:
    # An interpreter is available: the client module's source file is run
    # with it (see _start_client).
    _CLIENT_PATH = os.path.abspath(os.path.realpath(_signal_client.__file__))
else:
    # NOTE(rbellevi): For compatibility with internal testing.
    # Without an interpreter, the path to a prebuilt client executable must be
    # given as the only command-line argument. Only its basename is used; it
    # is resolved relative to the directory containing this file.
    if len(sys.argv) != 2:
        raise RuntimeError("Must supply path to executable client.")
    client_name = os.path.basename(sys.argv[1])
    del sys.argv[1]  # For compatibility with test runner.
    _CLIENT_PATH = os.path.realpath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), client_name)
    )

_HOST = "localhost"

# The gevent test harness cannot run the monkeypatch code for the child process,
# so we need to instrument it manually.
_GEVENT_ARG = ("--gevent",) if test_common.running_under_gevent() else ()

# How long (in seconds) to block on the handler's connection event before
# checking whether the client subprocess has already exited.
_CONNECT_POLL_INTERVAL_S = 0.1


class _GenericHandler(grpc.GenericRpcHandler):
    """Serves the _signal_client methods, holding each RPC open until it ends.

    Counts in-flight RPCs so tests can wait until a client is connected
    before signalling it.
    """

    def __init__(self):
        self._connected_clients_lock = threading.RLock()
        # Set while at least one RPC is in flight; cleared when the count
        # drops back to zero.
        self._connected_clients_event = threading.Event()
        self._connected_clients = 0

        self._unary_unary_handler = grpc.unary_unary_rpc_method_handler(
            self._handle_unary_unary
        )
        self._unary_stream_handler = grpc.unary_stream_rpc_method_handler(
            self._handle_unary_stream
        )

    def _on_client_connect(self):
        with self._connected_clients_lock:
            self._connected_clients += 1
            self._connected_clients_event.set()

    def _on_client_disconnect(self):
        with self._connected_clients_lock:
            self._connected_clients -= 1
            if self._connected_clients == 0:
                self._connected_clients_event.clear()

    def await_connected_client(self, timeout=None):
        """Blocks until a client connects to the server.

        Args:
          timeout: Maximum number of seconds to wait, or None (the default)
            to wait indefinitely.

        Returns:
          True if a client is connected, False if the timeout elapsed first.
        """
        return self._connected_clients_event.wait(timeout)

    def _block_until_rpc_end(self, servicer_context):
        """Marks the RPC as connected and blocks until it terminates.

        The callback registered with add_callback runs when the RPC ends. It
        decrements the connected-client count and releases this thread.
        """
        stop_event = threading.Event()

        def on_rpc_end():
            self._on_client_disconnect()
            stop_event.set()

        # Register the callback before counting the client as connected, so
        # the decrement cannot be missed.
        servicer_context.add_callback(on_rpc_end)
        self._on_client_connect()
        stop_event.wait()

    def _handle_unary_unary(self, request, servicer_context):
        """Handles a unary RPC.

        Blocks until the client disconnects and then echoes.
        """
        self._block_until_rpc_end(servicer_context)
        return request

    def _handle_unary_stream(self, request, servicer_context):
        """Handles a server streaming RPC.

        Blocks until the client disconnects and then echoes.
        """
        self._block_until_rpc_end(servicer_context)
        yield request

    def service(self, handler_call_details):
        """Returns the handler for the requested method, or None if unknown."""
        if handler_call_details.method == _signal_client.UNARY_UNARY:
            return self._unary_unary_handler
        elif handler_call_details.method == _signal_client.UNARY_STREAM:
            return self._unary_stream_handler
        else:
            return None


def _read_stream(stream):
    """Rewinds ``stream`` and returns its entire contents."""
    stream.seek(0)
    return stream.read()


def _start_client(args, stdout, stderr):
    """Launches the signal client in a subprocess.

    Args:
      args: Command-line arguments passed to the client.
      stdout: File object that receives the client's standard output.
      stderr: File object that receives the client's standard error.

    Returns:
      The subprocess.Popen handle for the client.
    """
    if sys.executable is not None:
        invocation = (sys.executable, _CLIENT_PATH) + tuple(args)
    else:
        invocation = (_CLIENT_PATH,) + tuple(args)
    return subprocess.Popen(invocation, stdout=stdout, stderr=stderr)


class SignalHandlingTest(unittest.TestCase):
    def setUp(self):
        self._server = test_common.test_server()
        self._port = self._server.add_insecure_port("{}:0".format(_HOST))
        self._handler = _GenericHandler()
        self._server.add_generic_rpc_handlers((self._handler,))
        self._server.start()

    def tearDown(self):
        self._server.stop(None)

    def _await_client_connection(self, client, client_stderr):
        """Waits for ``client`` to connect, failing fast if it exits first.

        Without this check, a client that dies before connecting would leave
        the test blocked forever on the handler's connection event.
        """
        while not self._handler.await_connected_client(
            timeout=_CONNECT_POLL_INTERVAL_S
        ):
            if client.poll() is not None:
                self.fail(
                    "Client exited with code {} before connecting:\n{}".format(
                        client.returncode, _read_stream(client_stderr)
                    )
                )

    def _run_interrupted_client(self, rpc_type, flags=()):
        """Starts a client, sends it SIGINT once connected, and waits for it.

        Args:
          rpc_type: The client's RPC mode argument ("unary" or "streaming").
          flags: Extra client flags placed before the server target.

        Returns:
          A (returncode, stdout_text, stderr_text) tuple for the client.
        """
        server_target = "{}:{}".format(_HOST, self._port)
        with tempfile.TemporaryFile(mode="r") as client_stdout:
            with tempfile.TemporaryFile(mode="r") as client_stderr:
                client = _start_client(
                    tuple(flags) + (server_target, rpc_type) + _GEVENT_ARG,
                    client_stdout,
                    client_stderr,
                )
                try:
                    self._await_client_connection(client, client_stderr)
                    client.send_signal(signal.SIGINT)
                    returncode = client.wait()
                finally:
                    # Never leave a stray client process behind if something
                    # unexpected interrupted the sequence above.
                    if client.poll() is None:
                        client.kill()
                        client.wait()
                return (
                    returncode,
                    _read_stream(client_stdout),
                    _read_stream(client_stderr),
                )

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testUnary(self):
        """Tests that the server unary code path does not stall signal handlers."""
        returncode, stdout, stderr = self._run_interrupted_client("unary")
        self.assertFalse(returncode, msg=stderr)
        self.assertIn(_signal_client.SIGTERM_MESSAGE, stdout)

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testStreaming(self):
        """Tests that the server streaming code path does not stall signal handlers."""
        returncode, stdout, stderr = self._run_interrupted_client("streaming")
        self.assertFalse(returncode, msg=stderr)
        self.assertIn(_signal_client.SIGTERM_MESSAGE, stdout)

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testUnaryWithException(self):
        """Tests that a unary client run with --exception exits with 0 on SIGINT."""
        returncode, _, stderr = self._run_interrupted_client(
            "unary", flags=("--exception",)
        )
        self.assertEqual(0, returncode, msg=stderr)

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testStreamingHandlerWithException(self):
        """Tests that a streaming client run with --exception exits with 0 on SIGINT."""
        returncode, _, stderr = self._run_interrupted_client(
            "streaming", flags=("--exception",)
        )
        self.assertEqual(0, returncode, msg=stderr)


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)