xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_tests/tests/unit/_signal_handling_test.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2019 the gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Test of responsiveness to signals."""
15
16import logging
17import os
18import signal
19import subprocess
20import sys
21import tempfile
22import threading
23import unittest
24
25import grpc
26
27from tests.unit import _signal_client
28from tests.unit import test_common
29
# Absolute path of the signal client that every test launches in a subprocess.
if sys.executable is None:
    # NOTE(rbellevi): For compatibility with internal testing. With no Python
    # interpreter available, the executable client's path must be given as
    # the only command-line argument; only its basename is used, resolved
    # relative to this file's directory.
    if len(sys.argv) != 2:
        raise RuntimeError("Must supply path to executable client.")
    # Removing the argument keeps it away from the test runner.
    client_name = sys.argv.pop(1).rsplit("/", 1)[-1]
    _CLIENT_PATH = os.path.realpath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), client_name)
    )
else:
    _CLIENT_PATH = os.path.abspath(os.path.realpath(_signal_client.__file__))

_HOST = "localhost"

# The gevent test harness cannot run the monkeypatch code for the child process,
# so we need to instrument it manually.
if test_common.running_under_gevent():
    _GEVENT_ARG = ("--gevent",)
else:
    _GEVENT_ARG = ()
48
49
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler whose RPCs block until the calling client goes away.

    Serves ``_signal_client.UNARY_UNARY`` and ``_signal_client.UNARY_STREAM``.
    While a handler invocation is in flight it counts as a connected client,
    which lets a test wait until a client is parked inside an RPC.
    """

    def __init__(self):
        # The lock guards _connected_clients; the event is set while the
        # count is positive.
        self._connected_clients_lock = threading.RLock()
        self._connected_clients_event = threading.Event()
        self._connected_clients = 0

        self._unary_unary_handler = grpc.unary_unary_rpc_method_handler(
            self._handle_unary_unary
        )
        self._unary_stream_handler = grpc.unary_stream_rpc_method_handler(
            self._handle_unary_stream
        )

    def _on_client_connect(self):
        """Records one more in-flight RPC and signals waiters."""
        with self._connected_clients_lock:
            self._connected_clients += 1
            self._connected_clients_event.set()

    def _on_client_disconnect(self):
        """Records one fewer in-flight RPC; resets the event at zero."""
        with self._connected_clients_lock:
            self._connected_clients -= 1
            if self._connected_clients == 0:
                self._connected_clients_event.clear()

    def await_connected_client(self, timeout=None):
        """Blocks until a client connects to the server.

        Args:
          timeout: Maximum number of seconds to wait, or None (the default)
            to wait indefinitely.

        Returns:
          True if a client is connected, False if the timeout expired first.
        """
        return self._connected_clients_event.wait(timeout)

    def _await_rpc_termination(self, servicer_context):
        """Counts the current RPC as a connected client until it terminates.

        Blocks the calling handler thread until the RPC's termination
        callback has run.
        """
        rpc_ended = threading.Event()

        def on_rpc_end():
            self._on_client_disconnect()
            rpc_ended.set()

        # Connect before registering the callback so that a disconnect can
        # never be recorded ahead of its matching connect.
        self._on_client_connect()
        # add_callback returns False when the RPC has already terminated; the
        # callback will then never be invoked, so run it here instead of
        # blocking forever.
        if not servicer_context.add_callback(on_rpc_end):
            on_rpc_end()
        rpc_ended.wait()

    def _handle_unary_unary(self, request, servicer_context):
        """Handles a unary RPC.

        Blocks until the client disconnects and then echoes.
        """
        self._await_rpc_termination(servicer_context)
        return request

    def _handle_unary_stream(self, request, servicer_context):
        """Handles a server streaming RPC.

        Blocks until the client disconnects and then echoes.
        """
        self._await_rpc_termination(servicer_context)
        yield request

    def service(self, handler_call_details):
        """Returns the method handler for a known method, otherwise None."""
        if handler_call_details.method == _signal_client.UNARY_UNARY:
            return self._unary_unary_handler
        elif handler_call_details.method == _signal_client.UNARY_STREAM:
            return self._unary_stream_handler
        else:
            return None
117
118
def _read_stream(stream):
    """Returns everything written to ``stream``, independent of its position.

    The stream is rewound before reading, so afterwards its position is at
    the end of the data.
    """
    stream.seek(0)
    contents = stream.read()
    return contents
122
123
def _start_client(args, stdout, stderr):
    """Launches the signal client in a subprocess.

    Args:
      args: Command-line arguments forwarded to the client.
      stdout: Destination for the client's standard output.
      stderr: Destination for the client's standard error.

    Returns:
      The ``subprocess.Popen`` handle of the launched client.
    """
    # With an interpreter available the client is a script run by it;
    # otherwise _CLIENT_PATH names an executable client.
    if sys.executable is None:
        command = (_CLIENT_PATH,)
    else:
        command = (sys.executable, _CLIENT_PATH)
    command += tuple(args)
    return subprocess.Popen(command, stdout=stdout, stderr=stderr)
131
132
class SignalHandlingTest(unittest.TestCase):
    """Checks that clients blocked in an RPC still react to SIGINT.

    Each test launches the signal client in a subprocess, waits until its RPC
    is parked in the server handler, sends it SIGINT and inspects its exit.
    """

    def setUp(self):
        self._server = test_common.test_server()
        self._port = self._server.add_insecure_port("{}:0".format(_HOST))
        self._handler = _GenericHandler()
        self._server.add_generic_rpc_handlers((self._handler,))
        self._server.start()

    def tearDown(self):
        self._server.stop(None)

    def _interrupt_client(self, mode, raise_exception=False):
        """Runs the signal client against the server and sends it SIGINT.

        Args:
          mode: Client RPC mode argument, e.g. "unary" or "streaming".
          raise_exception: Whether to pass ``--exception`` to the client.

        Returns:
          A ``(returncode, stdout, stderr)`` tuple for the exited client.
        """
        server_target = "{}:{}".format(_HOST, self._port)
        args = ("--exception",) if raise_exception else ()
        args += (server_target, mode) + _GEVENT_ARG
        with tempfile.TemporaryFile(mode="r") as client_stdout:
            with tempfile.TemporaryFile(mode="r") as client_stderr:
                client = _start_client(args, client_stdout, client_stderr)
                try:
                    self._handler.await_connected_client()
                    client.send_signal(signal.SIGINT)
                    returncode = client.wait()
                finally:
                    # Do not leave an orphaned client running if the sequence
                    # above was aborted before the client exited.
                    if client.poll() is None:
                        client.kill()
                        client.wait()
                return (
                    returncode,
                    _read_stream(client_stdout),
                    _read_stream(client_stderr),
                )

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testUnary(self):
        """Tests that the server unary code path does not stall signal handlers."""
        returncode, stdout, stderr = self._interrupt_client("unary")
        self.assertEqual(0, returncode, msg=stderr)
        self.assertIn(_signal_client.SIGTERM_MESSAGE, stdout)

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testStreaming(self):
        """Tests that the server streaming code path does not stall signal handlers."""
        returncode, stdout, stderr = self._interrupt_client("streaming")
        self.assertEqual(0, returncode, msg=stderr)
        self.assertIn(_signal_client.SIGTERM_MESSAGE, stdout)

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testUnaryWithException(self):
        """Tests that a unary client in --exception mode exits with status 0."""
        returncode, _, stderr = self._interrupt_client(
            "unary", raise_exception=True
        )
        self.assertEqual(0, returncode, msg=stderr)

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testStreamingHandlerWithException(self):
        """Tests that a streaming client in --exception mode exits with status 0."""
        returncode, _, stderr = self._interrupt_client(
            "streaming", raise_exception=True
        )
        self.assertEqual(0, returncode, msg=stderr)
212
213
if __name__ == "__main__":
    # Install a default stderr handler on the root logger so that log
    # records emitted during the tests are visible.
    logging.basicConfig()
    # verbosity=2 prints each test's name (and docstring summary) as it runs.
    unittest.main(verbosity=2)
217