#!/usr/bin/python3
#
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for the example portserver."""

import asyncio
import os
import signal
import socket
import subprocess
import sys
import time
import unittest
from unittest import mock
from multiprocessing import Process

import portpicker

# On Windows, portserver.py is located in the "Scripts" folder, which isn't
# added to the import path by default.
if sys.platform == 'win32':
    sys.path.append(os.path.dirname(sys.executable))

import portserver


def setUpModule():
    """Enable verbose portserver logging once for the whole test module."""
    portserver._configure_logging(verbose=True)


def exit_immediately():
    """Child-process target: exit with status 0 via os._exit (no cleanup)."""
    os._exit(0)


class PortserverFunctionsTest(unittest.TestCase):
    """Tests for the module-level helper functions of portserver."""

    def setUp(self):
        # Pick a fresh, currently unused port for every test.
        self.port = portpicker.PickUnusedPort()

    def test_get_process_command_line(self):
        portserver._get_process_command_line(os.getpid())

    def test_get_process_start_time(self):
        self.assertGreater(portserver._get_process_start_time(os.getpid()), 0)

    def test_is_port_free(self):
        """This might be flaky unless this test is run with a portserver."""
        # The port should be free initially.
        self.assertTrue(portserver._is_port_free(self.port))

        cases = [
            (socket.AF_INET, socket.SOCK_STREAM, None),
            (socket.AF_INET6, socket.SOCK_STREAM, 1),
            (socket.AF_INET, socket.SOCK_DGRAM, None),
            (socket.AF_INET6, socket.SOCK_DGRAM, 1),
        ]

        # Using v6only=0 on Windows doesn't result in collisions
        if sys.platform != 'win32':
            cases.extend([
                (socket.AF_INET6, socket.SOCK_STREAM, 0),
                (socket.AF_INET6, socket.SOCK_DGRAM, 0),
            ])

        for (sock_family, sock_type, v6only) in cases:
            # Occupy the port on a subset of possible protocols.
            try:
                sock = socket.socket(sock_family, sock_type, 0)
            except socket.error:
                print('Kernel does not support sock_family=%d' % sock_family,
                      file=sys.stderr)
                # Skip this case, since we cannot occupy a port.
                continue

            if not hasattr(socket, 'IPPROTO_IPV6'):
                v6only = None

            # The context manager closes the socket even when bind() or the
            # assertion fails, so a failing case cannot leave the port
            # occupied for the remaining cases or tests.
            with sock:
                if v6only is not None:
                    try:
                        sock.setsockopt(socket.IPPROTO_IPV6,
                                        socket.IPV6_V6ONLY, v6only)
                    except socket.error:
                        print('Kernel does not support IPV6_V6ONLY=%d' % v6only,
                              file=sys.stderr)
                        # Don't care; just proceed with the default.
                sock.bind(('', self.port))

                # The port should be busy.
                self.assertFalse(portserver._is_port_free(self.port))

            # Now it's free again.
            self.assertTrue(portserver._is_port_free(self.port))

    def test_is_port_free_exception(self):
        with mock.patch.object(socket, 'socket') as mock_sock:
            mock_sock.side_effect = socket.error('fake socket error', 0)
            self.assertFalse(portserver._is_port_free(self.port))

    def test_should_allocate_port(self):
        self.assertFalse(portserver._should_allocate_port(0))
        self.assertFalse(portserver._should_allocate_port(1))
        # Our own (living) process must be allowed an allocation. The call
        # result is asserted, not the function object, which is always truthy.
        self.assertTrue(portserver._should_allocate_port(os.getpid()))

        p = Process(target=exit_immediately)
        p.start()
        child_pid = p.pid
        p.join()

        # This test assumes that after waitpid returns the kernel has finished
        # cleaning the process.  We also assume that the kernel will not reuse
        # the former child's pid before our next call checks for its existence.
        # Likely assumptions, but not guaranteed.
        self.assertFalse(portserver._should_allocate_port(child_pid))

    def test_parse_command_line(self):
        with mock.patch.object(
                sys, 'argv', ['program_name', '--verbose',
                              '--portserver_static_pool=1-1,3-8',
                              '--portserver_unix_socket_address=@hello-test']):
            portserver._parse_command_line()

    def test_parse_port_ranges(self):
        self.assertFalse(portserver._parse_port_ranges(''))
        self.assertCountEqual(portserver._parse_port_ranges('1-1'), {1})
        self.assertCountEqual(portserver._parse_port_ranges('1-1,3-8,375-378'),
                              {1, 3, 4, 5, 6, 7, 8, 375, 376, 377, 378})
        # Unparsable parts are logged but ignored.
        self.assertEqual({1, 2},
                         portserver._parse_port_ranges('1-2,not,numbers'))
        self.assertEqual(set(), portserver._parse_port_ranges('8080-8081x'))
        # Port ranges that go out of bounds are logged but ignored.
        self.assertEqual(set(), portserver._parse_port_ranges('0-1138'))
        self.assertEqual(set(range(19, 84 + 1)),
                         portserver._parse_port_ranges('1138-65536,19-84'))

    def test_configure_logging(self):
        """Just code coverage really."""
        portserver._configure_logging(False)
        portserver._configure_logging(True)

    # Per-process address so concurrent test runs don't collide.
    _test_socket_addr = f'@TST-{os.getpid()}'

    @mock.patch.object(
        sys, 'argv', ['PortserverFunctionsTest.test_main',
                      f'--portserver_unix_socket_address={_test_socket_addr}']
    )
    @mock.patch.object(portserver, '_parse_port_ranges')
    def test_main_no_ports(self, *unused_mocks):
        portserver._parse_port_ranges.return_value = set()
        with self.assertRaises(SystemExit):
            portserver.main()

    @unittest.skipUnless(sys.executable, 'Requires a stand alone interpreter')
    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'AF_UNIX required')
    def test_portserver_binary(self):
        """Launch python portserver.py and test it."""
        # Blindly assuming tree layout is src/tests/portserver_test.py
        # with src/portserver.py.
        portserver_py = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            'portserver.py')
        # '@' is the printable spelling of the leading NUL byte of an
        # abstract unix socket address.
        anon_addr = self._test_socket_addr.replace('@', '\0')

        def connect_once():
            """Open one connection to anon_addr and close it again.

            A brand-new socket is used per attempt and is always closed, so a
            refused connect never leaks or leaves a half-used socket behind.

            Raises:
              ConnectionRefusedError: if nothing is listening on anon_addr.
            """
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as conn:
                conn.connect(anon_addr)

        with self.assertRaises(
                ConnectionRefusedError,
                msg=f'{self._test_socket_addr} should not listen yet.'):
            connect_once()

        # Popen as a context manager closes the stderr pipe and reaps the
        # process on exit, on both the success and the failure paths.
        with subprocess.Popen(
                [sys.executable, portserver_py,
                 f'--portserver_unix_socket_address={self._test_socket_addr}'],
                stderr=subprocess.PIPE,
        ) as server:
            try:
                # Wait a few seconds for the server to start listening.
                start_time = time.monotonic()
                while True:
                    time.sleep(0.05)
                    try:
                        connect_once()
                    except ConnectionRefusedError:
                        delta = time.monotonic() - start_time
                        if delta < 4:
                            continue
                        server.kill()
                        self.fail('Failed to connect to portserver '
                                  f'{self._test_socket_addr} within '
                                  f'{delta} seconds. STDERR:\n' +
                                  server.stderr.read().decode('utf-8'))
                    else:
                        break

                ports = set()
                port = portpicker.get_port_from_port_server(
                    portserver_address=self._test_socket_addr)
                ports.add(port)
                port = portpicker.get_port_from_port_server(
                    portserver_address=self._test_socket_addr)
                ports.add(port)

                with subprocess.Popen('exit 0', shell=True) as quick_process:
                    quick_process.wait()
                # This process doesn't exist so it should be a denied alloc.
                # We use the pid from the above quick_process under the
                # assumption that most OSes try to avoid rapid pid recycling.
                denied_port = portpicker.get_port_from_port_server(
                    portserver_address=self._test_socket_addr,
                    pid=quick_process.pid)  # A now unused pid.
                self.assertIsNone(denied_port)

                self.assertEqual(len(ports), 2, msg=ports)

                # Check statistics from portserver
                server.send_signal(signal.SIGUSR1)
                # TODO implement an I/O timeout
                for line in server.stderr:
                    if b'denied-allocations ' in line:
                        denied_allocations = int(
                            line.split(b'denied-allocations ', 2)[1])
                        self.assertEqual(1, denied_allocations, msg=line)
                    elif b'total-allocations ' in line:
                        total_allocations = int(
                            line.split(b'total-allocations ', 2)[1])
                        self.assertEqual(2, total_allocations, msg=line)
                        break
                else:
                    # stderr reached EOF (the server exited) before the
                    # statistics were reported; don't let that pass silently.
                    self.fail('portserver stderr closed before it reported '
                              'total-allocations.')

                rejected_port = portpicker.get_port_from_port_server(
                    portserver_address=self._test_socket_addr,
                    pid=99999999999999999999999999999999999)  # Out of range.
                self.assertIsNone(rejected_port)

                # Done.  shutdown gracefully.
                server.send_signal(signal.SIGINT)
                server.communicate(timeout=2)
            finally:
                # Harmless if the server already exited; leaving the `with`
                # block then closes the pipe and waits for the process.
                server.kill()


class PortPoolTest(unittest.TestCase):
    """Tests for portserver._PortPool allocation logic."""

    @classmethod
    def setUpClass(cls):
        cls.port = portpicker.PickUnusedPort()

    def setUp(self):
        self.pool = portserver._PortPool()

    def test_initialization(self):
        self.assertEqual(0, self.pool.num_ports())
        self.pool.add_port_to_free_pool(self.port)
        self.assertEqual(1, self.pool.num_ports())
        self.pool.add_port_to_free_pool(1138)
        self.assertEqual(2, self.pool.num_ports())
        self.assertRaises(ValueError, self.pool.add_port_to_free_pool, 0)
        self.assertRaises(ValueError, self.pool.add_port_to_free_pool, 65536)

    @mock.patch.object(portserver, '_is_port_free')
    def test_get_port_for_process_ok(self, mock_is_port_free):
        self.pool.add_port_to_free_pool(self.port)
        mock_is_port_free.return_value = True
        self.assertEqual(self.port, self.pool.get_port_for_process(os.getpid()))
        self.assertEqual(1, self.pool.ports_checked_for_last_request)

    @mock.patch.object(portserver, '_is_port_free')
    def test_get_port_for_process_none_left(self, mock_is_port_free):
        self.pool.add_port_to_free_pool(self.port)
        self.pool.add_port_to_free_pool(22)
        mock_is_port_free.return_value = False
        self.assertEqual(2, self.pool.num_ports())
        self.assertEqual(0, self.pool.get_port_for_process(os.getpid()))
        self.assertEqual(2, self.pool.num_ports())
        self.assertEqual(2, self.pool.ports_checked_for_last_request)

    @mock.patch.object(portserver, '_is_port_free')
    @mock.patch.object(os, 'getpid')
    def test_get_port_for_process_pid_eq_port(self, mock_getpid, mock_is_port_free):
        self.pool.add_port_to_free_pool(12345)
        self.pool.add_port_to_free_pool(12344)
        mock_is_port_free.side_effect = lambda port: port == os.getpid()
        mock_getpid.return_value = 12345
        self.assertEqual(2, self.pool.num_ports())
        self.assertEqual(12345, self.pool.get_port_for_process(os.getpid()))
        self.assertEqual(2, self.pool.ports_checked_for_last_request)

    @mock.patch.object(portserver, '_is_port_free')
    @mock.patch.object(os, 'getpid')
    def test_get_port_for_process_pid_ne_port(self, mock_getpid, mock_is_port_free):
        self.pool.add_port_to_free_pool(12344)
        self.pool.add_port_to_free_pool(12345)
        mock_is_port_free.side_effect = lambda port: port != os.getpid()
        mock_getpid.return_value = 12345
        self.assertEqual(2, self.pool.num_ports())
        self.assertEqual(12344, self.pool.get_port_for_process(os.getpid()))
        self.assertEqual(2, self.pool.ports_checked_for_last_request)


@mock.patch.object(portserver, '_get_process_command_line')
@mock.patch.object(portserver, '_should_allocate_port')
@mock.patch.object(portserver._PortPool, 'get_port_for_process')
class PortServerRequestHandlerTest(unittest.TestCase):
    """Tests for portserver._PortServerRequestHandler request handling.

    The class-level patches apply to every test method (not to setUp).
    """

    def setUp(self):
        portserver._configure_logging(verbose=True)
        self.rh = portserver._PortServerRequestHandler([23, 42, 54])

    def test_stats_reporting(self, *unused_mocks):
        with mock.patch.object(portserver, 'log') as mock_logger:
            self.rh.dump_stats()
        mock_logger.info.assert_called_with('total-allocations 0')

    def test_handle_port_request_bad_data(self, *unused_mocks):
        self._test_bad_data_from_client(b'')
        self._test_bad_data_from_client(b'\n')
        self._test_bad_data_from_client(b'99Z\n')
        self._test_bad_data_from_client(b'99 8\n')
        self.assertEqual([], portserver._get_process_command_line.mock_calls)

    def _test_bad_data_from_client(self, data):
        mock_writer = mock.Mock(asyncio.StreamWriter)
        self.rh._handle_port_request(data, mock_writer)
        self.assertFalse(portserver._should_allocate_port.mock_calls)

    def test_handle_port_request_denied_allocation(self, *unused_mocks):
        portserver._should_allocate_port.return_value = False
        self.assertEqual(0, self.rh._denied_allocations)
        mock_writer = mock.Mock(asyncio.StreamWriter)
        self.rh._handle_port_request(b'5\n', mock_writer)
        self.assertEqual(1, self.rh._denied_allocations)

    def test_handle_port_request_bad_port_returned(self, *unused_mocks):
        portserver._should_allocate_port.return_value = True
        self.rh._port_pool.get_port_for_process.return_value = 0
        mock_writer = mock.Mock(asyncio.StreamWriter)
        self.rh._handle_port_request(b'6\n', mock_writer)
        self.rh._port_pool.get_port_for_process.assert_called_once_with(6)
        self.assertEqual(1, self.rh._denied_allocations)

    def test_handle_port_request_success(self, *unused_mocks):
        portserver._should_allocate_port.return_value = True
        self.rh._port_pool.get_port_for_process.return_value = 999
        mock_writer = mock.Mock(asyncio.StreamWriter)
        self.assertEqual(0, self.rh._total_allocations)
        self.rh._handle_port_request(b'8', mock_writer)
        portserver._should_allocate_port.assert_called_once_with(8)
        self.rh._port_pool.get_port_for_process.assert_called_once_with(8)
        self.assertEqual(1, self.rh._total_allocations)
        self.assertEqual(0, self.rh._denied_allocations)
        mock_writer.write.assert_called_once_with(b'999\n')


if __name__ == '__main__':
    unittest.main()