1#!/usr/bin/python3
2#
3# Copyright 2015 Google Inc. All Rights Reserved.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17"""Tests for the example portserver."""
18
19import asyncio
20import os
21import signal
22import socket
23import subprocess
24import sys
25import time
26import unittest
27from unittest import mock
28from multiprocessing import Process
29
30import portpicker
31
32# On Windows, portserver.py is located in the "Scripts" folder, which isn't
33# added to the import path by default
34if sys.platform == 'win32':
35    sys.path.append(os.path.join(os.path.split(sys.executable)[0]))
36
37import portserver
38
39
def setUpModule():
    """Module-level fixture: turn on verbose portserver logging for all tests."""
    verbose_logging = True
    portserver._configure_logging(verbose=verbose_logging)
42
def exit_immediately():
    """Terminate the calling process right away with exit status 0.

    Uses os._exit, which ends the process without running normal interpreter
    shutdown; it is used as a child-process target by
    test_should_allocate_port.
    """
    exit_status = 0
    os._exit(exit_status)
45
class PortserverFunctionsTest(unittest.TestCase):
    """Tests for portserver's module-level helpers and the portserver binary."""

    def setUp(self):
        # A plain instance-level fixture: each test gets its own unused port.
        # (Decorating setUp with @classmethod would instead rebind a shared
        # class attribute before every test.)
        self.port = portpicker.PickUnusedPort()

    def test_get_process_command_line(self):
        """Just code coverage: must not raise for our own pid."""
        portserver._get_process_command_line(os.getpid())

    def test_get_process_start_time(self):
        self.assertGreater(portserver._get_process_start_time(os.getpid()), 0)

    def test_is_port_free(self):
        """This might be flaky unless this test is run with a portserver."""
        # The port should be free initially.
        self.assertTrue(portserver._is_port_free(self.port))

        cases = [
            (socket.AF_INET,  socket.SOCK_STREAM, None),
            (socket.AF_INET6, socket.SOCK_STREAM, 1),
            (socket.AF_INET,  socket.SOCK_DGRAM,  None),
            (socket.AF_INET6, socket.SOCK_DGRAM,  1),
        ]

        # Using v6only=0 on Windows doesn't result in collisions
        if sys.platform != 'win32':
            cases.extend([
                (socket.AF_INET6, socket.SOCK_STREAM, 0),
                (socket.AF_INET6, socket.SOCK_DGRAM,  0),
            ])

        # IPV6_V6ONLY can only be requested where IPPROTO_IPV6 is exposed.
        can_set_v6only = hasattr(socket, 'IPPROTO_IPV6')

        for (sock_family, sock_type, v6only) in cases:
            # Occupy the port on a subset of possible protocols.
            try:
                sock = socket.socket(sock_family, sock_type, 0)
            except socket.error:
                print('Kernel does not support sock_family=%d' % sock_family,
                      file=sys.stderr)
                # Skip this case, since we cannot occupy a port.
                continue

            if not can_set_v6only:
                v6only = None

            # The with-block closes the socket even if bind() or the assertion
            # fails, so a failing case cannot keep the port occupied.
            with sock:
                if v6only is not None:
                    try:
                        sock.setsockopt(socket.IPPROTO_IPV6,
                                        socket.IPV6_V6ONLY, v6only)
                    except socket.error:
                        print('Kernel does not support IPV6_V6ONLY=%d' %
                              v6only, file=sys.stderr)
                        # Don't care; just proceed with the default.
                sock.bind(('', self.port))

                # The port should be busy.
                self.assertFalse(portserver._is_port_free(self.port))

            # Now it's free again.
            self.assertTrue(portserver._is_port_free(self.port))

    def test_is_port_free_exception(self):
        with mock.patch.object(socket, 'socket') as mock_sock:
            mock_sock.side_effect = socket.error('fake socket error', 0)
            self.assertFalse(portserver._is_port_free(self.port))

    def test_should_allocate_port(self):
        self.assertFalse(portserver._should_allocate_port(0))
        self.assertFalse(portserver._should_allocate_port(1))
        # The function must actually be called: passing the function object
        # itself to assertTrue is always truthy and would check nothing.
        self.assertTrue(portserver._should_allocate_port(os.getpid()))

        p = Process(target=exit_immediately)
        p.start()
        child_pid = p.pid
        p.join()

        # This test assumes that after waitpid returns the kernel has finished
        # cleaning the process.  We also assume that the kernel will not reuse
        # the former child's pid before our next call checks for its existence.
        # Likely assumptions, but not guaranteed.
        self.assertFalse(portserver._should_allocate_port(child_pid))

    def test_parse_command_line(self):
        with mock.patch.object(
            sys, 'argv', ['program_name', '--verbose',
                          '--portserver_static_pool=1-1,3-8',
                          '--portserver_unix_socket_address=@hello-test']):
            portserver._parse_command_line()

    def test_parse_port_ranges(self):
        self.assertFalse(portserver._parse_port_ranges(''))
        self.assertCountEqual(portserver._parse_port_ranges('1-1'), {1})
        self.assertCountEqual(portserver._parse_port_ranges('1-1,3-8,375-378'),
                              {1, 3, 4, 5, 6, 7, 8, 375, 376, 377, 378})
        # Unparsable parts are logged but ignored.
        self.assertEqual({1, 2},
                         portserver._parse_port_ranges('1-2,not,numbers'))
        self.assertEqual(set(), portserver._parse_port_ranges('8080-8081x'))
        # Port ranges that go out of bounds are logged but ignored.
        self.assertEqual(set(), portserver._parse_port_ranges('0-1138'))
        self.assertEqual(set(range(19, 84 + 1)),
                         portserver._parse_port_ranges('1138-65536,19-84'))

    def test_configure_logging(self):
        """Just code coverage really."""
        portserver._configure_logging(False)
        portserver._configure_logging(True)

    # Abstract-namespace socket address unique to this test process; it is
    # used both by the decorator below and by test_portserver_binary.
    _test_socket_addr = f'@TST-{os.getpid()}'

    @mock.patch.object(
        sys, 'argv', ['PortserverFunctionsTest.test_main',
                      f'--portserver_unix_socket_address={_test_socket_addr}']
    )
    @mock.patch.object(portserver, '_parse_port_ranges')
    def test_main_no_ports(self, *unused_mocks):
        portserver._parse_port_ranges.return_value = set()
        with self.assertRaises(SystemExit):
            portserver.main()

    @unittest.skipUnless(sys.executable, 'Requires a stand alone interpreter')
    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'AF_UNIX required')
    def test_portserver_binary(self):
        """Launch python portserver.py and test it."""
        # Blindly assuming tree layout is src/tests/portserver_test.py
        # with src/portserver.py.
        portserver_py = os.path.join(
                os.path.dirname(os.path.dirname(__file__)),
                'portserver.py')
        anon_addr = self._test_socket_addr.replace('@', '\0')

        def try_connect():
            """Make one connection attempt to the server address.

            A fresh socket is used per attempt so a socket is never reused
            after a failed connect(), and the with-block closes it on every
            path (success or exception).
            """
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as conn:
                conn.connect(anon_addr)

        with self.assertRaises(
                ConnectionRefusedError,
                msg=f'{self._test_socket_addr} should not listen yet.'):
            try_connect()

        # The Popen context manager closes the stderr pipe and reaps the
        # child once the block exits.
        with subprocess.Popen(
            [sys.executable, portserver_py,
             f'--portserver_unix_socket_address={self._test_socket_addr}'],
            stderr=subprocess.PIPE,
        ) as server:
            try:
                # Wait a few seconds for the server to start listening.
                start_time = time.monotonic()
                while True:
                    time.sleep(0.05)
                    try:
                        try_connect()
                    except ConnectionRefusedError:
                        delta = time.monotonic() - start_time
                        # Keep waiting only while the server is still alive;
                        # if it already exited, report its stderr right away.
                        if delta < 4 and server.poll() is None:
                            continue
                        server.kill()
                        self.fail('Failed to connect to portserver '
                                  f'{self._test_socket_addr} within '
                                  f'{delta} seconds. STDERR:\n' +
                                  server.stderr.read().decode(
                                      'utf-8', errors='replace'))
                    else:
                        break

                ports = set()
                port = portpicker.get_port_from_port_server(
                        portserver_address=self._test_socket_addr)
                ports.add(port)
                port = portpicker.get_port_from_port_server(
                        portserver_address=self._test_socket_addr)
                ports.add(port)

                with subprocess.Popen('exit 0', shell=True) as quick_process:
                    quick_process.wait()
                # This process doesn't exist so it should be a denied alloc.
                # We use the pid from the above quick_process under the
                # assumption that most OSes try to avoid rapid pid recycling.
                denied_port = portpicker.get_port_from_port_server(
                        portserver_address=self._test_socket_addr,
                        pid=quick_process.pid)  # A now unused pid.
                self.assertIsNone(denied_port)

                self.assertEqual(len(ports), 2, msg=ports)

                # Check statistics from portserver.  This loop expects the
                # denied-allocations line before the total-allocations line.
                server.send_signal(signal.SIGUSR1)
                saw_denied = saw_total = False
                # TODO implement an I/O timeout
                for line in server.stderr:
                    if b'denied-allocations ' in line:
                        denied_allocations = int(
                                line.split(b'denied-allocations ', 1)[1])
                        self.assertEqual(1, denied_allocations, msg=line)
                        saw_denied = True
                    elif b'total-allocations ' in line:
                        total_allocations = int(
                                line.split(b'total-allocations ', 1)[1])
                        self.assertEqual(2, total_allocations, msg=line)
                        saw_total = True
                        break
                # If stderr ended before the stats were dumped, the loop above
                # would otherwise pass without checking anything.
                self.assertTrue(
                        saw_denied,
                        'portserver never reported denied-allocations')
                self.assertTrue(
                        saw_total,
                        'portserver never reported total-allocations')

                rejected_port = portpicker.get_port_from_port_server(
                        portserver_address=self._test_socket_addr,
                        pid=99999999999999999999999999999999999)  # Out of range.
                self.assertIsNone(rejected_port)

                # Done.  shutdown gracefully.
                server.send_signal(signal.SIGINT)
                server.communicate(timeout=2)
            finally:
                server.kill()
                server.wait()
256
257
class PortPoolTest(unittest.TestCase):
    """Tests for portserver._PortPool port bookkeeping and allocation."""

    @classmethod
    def setUpClass(cls):
        # A single real unused port, shared by every test in this class.
        cls.port = portpicker.PickUnusedPort()

    def setUp(self):
        # Every test starts from a brand-new, empty pool.
        self.pool = portserver._PortPool()

    def test_initialization(self):
        pool = self.pool
        self.assertEqual(0, pool.num_ports())
        # Each valid port added grows the pool by exactly one.
        for expected_size, new_port in enumerate((self.port, 1138), start=1):
            pool.add_port_to_free_pool(new_port)
            self.assertEqual(expected_size, pool.num_ports())
        # Values outside the valid port range are rejected.
        for invalid_port in (0, 65536):
            self.assertRaises(ValueError, pool.add_port_to_free_pool,
                              invalid_port)

    @mock.patch.object(portserver, '_is_port_free')
    def test_get_port_for_process_ok(self, mock_is_port_free):
        pool = self.pool
        pool.add_port_to_free_pool(self.port)
        mock_is_port_free.return_value = True
        allocated = pool.get_port_for_process(os.getpid())
        self.assertEqual(self.port, allocated)
        self.assertEqual(1, pool.ports_checked_for_last_request)

    @mock.patch.object(portserver, '_is_port_free')
    def test_get_port_for_process_none_left(self, mock_is_port_free):
        pool = self.pool
        for managed_port in (self.port, 22):
            pool.add_port_to_free_pool(managed_port)
        # Every port reports as busy.
        mock_is_port_free.return_value = False
        self.assertEqual(2, pool.num_ports())
        # Nothing can be handed out (0), yet the pool keeps both ports and
        # reports having checked each of them.
        self.assertEqual(0, pool.get_port_for_process(os.getpid()))
        self.assertEqual(2, pool.num_ports())
        self.assertEqual(2, pool.ports_checked_for_last_request)

    @mock.patch.object(portserver, '_is_port_free')
    @mock.patch.object(os, 'getpid')
    def test_get_port_for_process_pid_eq_port(self, mock_getpid, mock_is_port_free):
        pool = self.pool
        for managed_port in (12345, 12344):
            pool.add_port_to_free_pool(managed_port)

        def free_only_if_equal_to_pid(candidate):
            # os.getpid is mocked below, so this compares against 12345.
            return candidate == os.getpid()

        mock_is_port_free.side_effect = free_only_if_equal_to_pid
        mock_getpid.return_value = 12345
        self.assertEqual(2, pool.num_ports())
        self.assertEqual(12345, pool.get_port_for_process(os.getpid()))
        self.assertEqual(2, pool.ports_checked_for_last_request)

    @mock.patch.object(portserver, '_is_port_free')
    @mock.patch.object(os, 'getpid')
    def test_get_port_for_process_pid_ne_port(self, mock_getpid, mock_is_port_free):
        pool = self.pool
        for managed_port in (12344, 12345):
            pool.add_port_to_free_pool(managed_port)

        def free_unless_equal_to_pid(candidate):
            # os.getpid is mocked below, so this compares against 12345.
            return candidate != os.getpid()

        mock_is_port_free.side_effect = free_unless_equal_to_pid
        mock_getpid.return_value = 12345
        self.assertEqual(2, pool.num_ports())
        self.assertEqual(12344, pool.get_port_for_process(os.getpid()))
        self.assertEqual(2, pool.ports_checked_for_last_request)
314
315
@mock.patch.object(portserver, '_get_process_command_line')
@mock.patch.object(portserver, '_should_allocate_port')
@mock.patch.object(portserver._PortPool, 'get_port_for_process')
class PortServerRequestHandlerTest(unittest.TestCase):
    """Tests for portserver._PortServerRequestHandler request handling.

    The class-level mock.patch decorators wrap each test_* method, so every
    test runs with fresh mocks for the command-line lookup, the allocation
    policy and _PortPool.get_port_for_process.
    """

    def setUp(self):
        portserver._configure_logging(verbose=True)
        initial_ports = [23, 42, 54]
        self.rh = portserver._PortServerRequestHandler(initial_ports)

    def test_stats_reporting(self, *unused_mocks):
        with mock.patch.object(portserver, 'log') as fake_log:
            self.rh.dump_stats()
        # assert_called_with inspects the most recent log.info call.
        fake_log.info.assert_called_with('total-allocations 0')

    def test_handle_port_request_bad_data(self, *unused_mocks):
        malformed_requests = (b'', b'\n', b'99Z\n', b'99 8\n')
        for request in malformed_requests:
            self._test_bad_data_from_client(request)
        # None of the malformed requests got far enough to look up a pid.
        self.assertEqual([], portserver._get_process_command_line.mock_calls)

    def _test_bad_data_from_client(self, data):
        """Feed one malformed request; no allocation decision may be made."""
        fake_writer = mock.Mock(asyncio.StreamWriter)
        self.rh._handle_port_request(data, fake_writer)
        self.assertFalse(portserver._should_allocate_port.mock_calls)

    def test_handle_port_request_denied_allocation(self, *unused_mocks):
        portserver._should_allocate_port.return_value = False
        handler = self.rh
        self.assertEqual(0, handler._denied_allocations)
        fake_writer = mock.Mock(asyncio.StreamWriter)
        handler._handle_port_request(b'5\n', fake_writer)
        self.assertEqual(1, handler._denied_allocations)

    def test_handle_port_request_bad_port_returned(self, *unused_mocks):
        portserver._should_allocate_port.return_value = True
        handler = self.rh
        pool_lookup = handler._port_pool.get_port_for_process
        # A returned port of 0 stands for "nothing available".
        pool_lookup.return_value = 0
        fake_writer = mock.Mock(asyncio.StreamWriter)
        handler._handle_port_request(b'6\n', fake_writer)
        pool_lookup.assert_called_once_with(6)
        self.assertEqual(1, handler._denied_allocations)

    def test_handle_port_request_success(self, *unused_mocks):
        portserver._should_allocate_port.return_value = True
        handler = self.rh
        pool_lookup = handler._port_pool.get_port_for_process
        pool_lookup.return_value = 999
        fake_writer = mock.Mock(asyncio.StreamWriter)
        self.assertEqual(0, handler._total_allocations)
        handler._handle_port_request(b'8', fake_writer)
        portserver._should_allocate_port.assert_called_once_with(8)
        pool_lookup.assert_called_once_with(8)
        self.assertEqual(1, handler._total_allocations)
        self.assertEqual(0, handler._denied_allocations)
        # The granted port is written back as a newline-terminated number.
        fake_writer.write.assert_called_once_with(b'999\n')
367
368
if __name__ == '__main__':
    # Running this file directly executes every TestCase defined above.
    unittest.main(module='__main__')
371