# Owner(s): ["oncall: distributed"]

import datetime
import os
import socket
import struct
import sys
import tempfile
import threading
import time
from datetime import timedelta
from sys import platform

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
import torch.distributed.rpc as rpc
from torch.distributed import DistError, DistNetworkError, DistStoreError
from torch.testing._internal.common_distributed import MultiThreadedTestCase
from torch.testing._internal.common_utils import instantiate_parametrized_tests


if not dist.is_available():
    print("torch.distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import torch.testing._internal.common_utils as common
from torch.testing._internal.common_distributed import (
    create_tcp_store,
    skip_if_win32,
    tp_transports,
)
from torch.testing._internal.common_utils import (
    ADDRESS_IN_USE,
    CONNECT_TIMEOUT,
    load_tests,
    retry_on_connect_failures,
    run_tests,
    TestCase,
)


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

if platform == "darwin":
    LOOPBACK = "lo0"
else:
    LOOPBACK = "lo"

DEFAULT_HOSTNAME = "localhost"

torch.backends.cuda.matmul.allow_tf32 = False


def gpus_for_rank(world_size):
    """Multigpu tests are designed to simulate the multi nodes with multi
    GPUs on each node. Nccl backend requires equal #GPUs in each process.
    On a single node, all visible GPUs are evenly
    divided to subsets, each process only uses a subset.

    Returns a list with one entry per rank; entry ``r`` is the list of device
    indices assigned to rank ``r``. Devices left over after the integer
    division are not assigned to any rank.
    """
    visible_devices = list(range(torch.cuda.device_count()))
    gpus_per_process = len(visible_devices) // world_size
    return [
        visible_devices[rank * gpus_per_process : (rank + 1) * gpus_per_process]
        for rank in range(world_size)
    ]


class StoreTestBase:
    """Mixin with store tests shared by the concrete ``*StoreTest`` classes.

    Subclasses provide ``_create_store()`` and, when the store holds extra
    bookkeeping keys, override ``num_keys_total``.
    """

    def _create_store(self, i=None):
        # ``i`` is unused and optional: every caller in this file invokes
        # ``_create_store()`` without arguments, so a required parameter would
        # surface as a TypeError instead of the intended error below.
        raise RuntimeError("not implemented")

    def _test_set_get_check(self, fs):
        """Exercise add/set/get/check/delete_key and the key count on ``fs``."""
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        fs.add("key3", 2)
        fs.set("key2", "value2")
        fs.add("key3", 3)
        fs.add("key3", 4)
        fs.add("key3", 5)
        fs.add("key3", 6)
        self.assertEqual(fs.num_keys(), self.num_keys_total)
        self.assertEqual(b"6", fs.get("key"))
        self.assertEqual(b"value0", fs.get("key0"))
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))
        self.assertEqual(b"21", fs.get("key3"))
        self.assertTrue(fs.check(["key3"]))
        self.assertFalse(fs.check(["Randomkey3"]))

        # A temporary key must not change the count once it is deleted again.
        fs.set("-key3", "7")
        self.assertEqual(b"7", fs.get("-key3"))
        fs.delete_key("-key3")
        self.assertEqual(fs.num_keys(), self.num_keys_total)

    def test_set_get_check(self):
        self._test_set_get_check(self._create_store())

    def _test_compare_set(self, store):
        """Check compare_set on a missing key, a mismatch, a match and an empty expectation."""
        missing_key_result = store.compare_set(
            "cs_key0", "wrong_old_value", "new_value0"
        )
        self.assertEqual(b"wrong_old_value", missing_key_result)

        store.set("cs_key0", "value0")
        self.assertEqual(b"value0", store.get("cs_key0"))
        old_value_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"value0", old_value_result)
        self.assertEqual(b"value0", store.get("cs_key0"))
        new_value_result = store.compare_set("cs_key0", "value0", "new_value0")
        self.assertEqual(b"new_value0", new_value_result)
        self.assertEqual(b"new_value0", store.get("cs_key0"))
        empty_old_value_result = store.compare_set("cs_key1", "", "new_value1")
        self.assertEqual(b"new_value1", empty_old_value_result)
        self.assertEqual(b"new_value1", store.get("cs_key1"))

    def test_compare_set(self):
        self._test_compare_set(self._create_store())

    def _test_simple_wait(self, fs):
        """wait() must time out on an absent key and return for a present one."""
        with self.assertRaisesRegex(RuntimeError, "[t -i]imeout"):
            fs.wait(["bad_key"], timedelta(seconds=0.25))
        fs.add("good_key", 1)
        fs.wait(["good_key"])

    def test_simple_wait(self):
        self._test_simple_wait(self._create_store())

    def _test_append(self, store):
        """append() extends an existing value and creates a missing key."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_append(self):
        self._test_append(self._create_store())

    def _test_multi_set(self, store):
        """multi_set() writes each key with the value at the same position."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_set(self):
        self._test_multi_set(self._create_store())

    def _test_multi_get(self, store):
        """multi_get() returns the values in the order the keys were given."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_multi_get(self):
        self._test_multi_get(self._create_store())

    # This is the number of keys used in test_set_get. Adding this as a class
    # property instead of hardcoding in the test since some Store
    # implementations will have differing number of keys. In the base case,
    # there will be 5 keys: key, key0, key1, key2, key3.
    @property
    def num_keys_total(self):
        return 5


class FileStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against ``dist.FileStore``."""

    def setUp(self):
        super().setUp()
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def _create_store(self):
        store = dist.FileStore(self.file.name, 1)
        store.set_timeout(timedelta(seconds=300))
        return store

    def test_init_pg_and_rpc_with_same_file(self):
        """RPC and a gloo process group can both rendezvous through one file."""
        file = tempfile.NamedTemporaryFile(delete=False)
        # Init RPC using file
        rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
        rpc_backend_options.init_method = f"file://{file.name}"
        rpc_backend_options._transports = tp_transports()
        rpc.init_rpc(
            "worker", rank=0, world_size=1, rpc_backend_options=rpc_backend_options
        )

        # Init PG using file
        dist.init_process_group(
            "gloo", rank=0, world_size=1, init_method=f"file://{file.name}"
        )
        dist.destroy_process_group()
        # unittest assertions are used instead of bare ``assert`` so the check
        # still runs when Python is started with -O.
        self.assertTrue(os.path.exists(file.name))

        rpc.shutdown()
        os.remove(file.name)

    def test_refcount(self):
        """The backing file must survive until the last FileStore on it is deleted."""
        file = tempfile.NamedTemporaryFile(delete=False)
        store = dist.FileStore(file.name, 1)
        store2 = dist.FileStore(file.name, 1)

        del store
        self.assertTrue(os.path.exists(file.name))
        del store2
        self.assertFalse(os.path.exists(file.name))

    @property
    def num_keys_total(self):
        return 6


@skip_if_win32()
class HashStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against ``dist.HashStore``."""

    def _create_store(self):
        store = dist.HashStore()
        store.set_timeout(timedelta(seconds=300))
        return store


class PrefixStoreTest(TestCase):
    """Checks that ``PrefixStore.underlying_store`` returns the wrapped store."""

    def setUp(self):
        super().setUp()
        # delete is false as FileStore will automatically clean up the file
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def test_get_underlying_store(self):
        tcp_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME, port=0, world_size=1, is_master=True
        )
        hash_store = dist.HashStore()
        file_store = dist.FileStore(self.file.name, world_size=1)
        for store in [tcp_store, hash_store, file_store]:
            with self.subTest(f"Testing getting underlying_store for {type(store)}"):
                prefix_store = dist.PrefixStore("prefix", store)
                self.assertEqual(prefix_store.underlying_store, store)


class PrefixFileStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against a ``PrefixStore`` over a ``FileStore``."""

    def setUp(self):
        super().setUp()
        self.file = tempfile.NamedTemporaryFile(delete=False)
        self.filestore = dist.FileStore(self.file.name, 1)
        self.prefix = "test_prefix"
        self.filestore.set_timeout(timedelta(seconds=300))

    def _create_store(self):
        return dist.PrefixStore(self.prefix, self.filestore)

    @property
    def num_keys_total(self):
        return 6


class TCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests plus TCP-specific tests on ``dist.TCPStore``.

    ``_use_libuv`` selects the backend requested from the store; subclasses
    flip it to run the same tests against the other backend.
    """

    _use_libuv = False

    def _create_store(self):
        store = create_tcp_store(use_libuv=self._use_libuv)
        store.set_timeout(timedelta(seconds=300))
        return store

    def _create_store_with_ws(self, addr, world_size):
        return create_tcp_store(
            addr, world_size, wait_for_workers=False, use_libuv=self._use_libuv
        )

    def test_address_already_in_use(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        err_msg_reg = f"^The server socket has failed to listen on any local .*{port}"
        with self.assertRaisesRegex(RuntimeError, err_msg_reg):
            # Use noqa to silence flake8.
            # Need to store in an unused variable here to ensure the first
            # object is not destroyed before the second object is created.
            store1 = dist.TCPStore(
                addr, port, 1, True, use_libuv=self._use_libuv
            )  # noqa: F841
            store2 = dist.TCPStore(
                addr, port, 1, True, use_libuv=self._use_libuv
            )  # noqa: F841
            self.assertEqual(store1.libuvBackend, self._use_libuv)
            self.assertEqual(store2.libuvBackend, self._use_libuv)

    @retry_on_connect_failures
    def test_multitenancy(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(
            addr, port, 1, True, multi_tenant=True, use_libuv=self._use_libuv
        )  # type: ignore[call-arg] # noqa: F841
        store2 = dist.TCPStore(
            addr, port, 1, True, multi_tenant=True, use_libuv=self._use_libuv
        )  # type: ignore[call-arg] # noqa: F841
        self.assertEqual(store1.libuvBackend, self._use_libuv)
        self.assertEqual(store2.libuvBackend, self._use_libuv)

    def test_repr(self) -> None:
        # server
        store1 = self._create_store()
        self.assertRegex(
            repr(store1),
            r"TCPStore\("
            r"client=TCPClient\(SocketImpl\(fd=\d+, addr=\[?localhost\]?:\d+, remote=\[?localhost\]?:\d+\)\), "
            r"server=TCPServer\(port=\d+\)\)",
        )

        # client
        store2 = dist.TCPStore(
            store1.host,
            store1.port,
            world_size=2,
            is_master=False,
        )
        self.assertRegex(
            repr(store2),
            r"TCPStore\("
            r"client=TCPClient\(SocketImpl\(fd=\d+, addr=\[?localhost\]?:\d+, remote=\[?localhost\]?:\d+\)\), "
            r"server=<nullptr>\)",
        )

    @skip_if_win32()
    @retry_on_connect_failures
    def test_init_pg_and_rpc_with_same_socket(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = str(port)

        # We internally use a multi-tenant TCP store. Both PG and RPC should successfully
        # initialize even when using the same socket address.

        os.environ["USE_LIBUV"] = "1" if self._use_libuv else "0"
        try:
            dist.init_process_group(
                backend="gloo",
                init_method="env://",
                rank=0,
                world_size=1,
            )

            backend_opts = rpc.TensorPipeRpcBackendOptions(
                init_method=f"tcp://{addr}:{port}", _transports=tp_transports()
            )
            rpc.init_rpc(
                name="worker0",
                rank=0,
                world_size=1,
                rpc_backend_options=backend_opts,
            )
        finally:
            # Drop the override even when initialization fails (or is retried)
            # so it cannot leak into other tests in this process.
            os.environ.pop("USE_LIBUV", None)

        self.assertNotIn("USE_LIBUV", os.environ)
        rpc.shutdown()
        dist.destroy_process_group()

    @skip_if_win32()
    def test_take_over_listen_socket(self):
        """A TCPStore master can be built on an already-bound socket fd."""
        listen_sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listen_sock.bind(("localhost", 0))
        addr, port, *_ = listen_sock.getsockname()
        # detach() hands the raw fd over; the Python socket object no longer owns it.
        listen_fd = listen_sock.detach()

        store = dist.TCPStore(
            addr,
            port,
            1,
            is_master=True,
            master_listen_fd=listen_fd,
            use_libuv=self._use_libuv,
        )

        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.set("key", "value")
        self.assertEqual(b"value", store.get("key"))

    # The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
    # the user and one additional key used for coordinate all the workers.
    @property
    def num_keys_total(self):
        return 6

    def _test_numkeys_delkeys(self, fs):
        # We start off with one init key in the store to coordinate workers
        self.assertEqual(fs.num_keys(), 1)
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        self.assertEqual(fs.num_keys(), 5)
        fs.delete_key("key")
        self.assertEqual(fs.num_keys(), 4)
        # Short timeout so the get() on the deleted key fails quickly.
        fs.set_timeout(timedelta(seconds=2))
        with self.assertRaises(RuntimeError):
            fs.get("key")
        fs.delete_key("key0")
        fs.delete_key("key3")
        self.assertEqual(fs.num_keys(), 2)
        fs.set("key4", "value2")
        self.assertEqual(fs.num_keys(), 3)
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key4"))

    def test_numkeys_delkeys(self):
        self._test_numkeys_delkeys(self._create_store())

    def _create_client(self, index, addr, port, world_size):
        """Connect a client store and check it can read, write and compare_set."""
        client_store = dist.TCPStore(
            addr,
            port,
            world_size=world_size,
            timeout=timedelta(seconds=10),
            use_libuv=self._use_libuv,
        )
        self.assertEqual(b"value", client_store.get("key"))
        client_store.set(f"new_key{index}", f"new_value{index}")
        self.assertEqual(
            f"next_value{index}".encode(),
            client_store.compare_set(
                f"new_key{index}", f"new_value{index}", f"next_value{index}"
            ),
        )

    def _multi_worker_helper(self, world_size):
        addr = DEFAULT_HOSTNAME
        server_store = self._create_store_with_ws(addr, world_size)
        self.assertEqual(server_store.libuvBackend, self._use_libuv)
        server_store.set("key", "value")
        port = server_store.port

        # With no fixed world size (None) a single client is still exercised.
        num_indices = world_size if world_size else 1
        for i in range(num_indices):
            self._create_client(i, addr, port, world_size)

    def test_multi_worker_with_fixed_world_size(self):
        self._multi_worker_helper(5)

    def test_multi_worker_with_nonfixed_world_size(self):
        self._multi_worker_helper(None)

    def test_append(self):
        store = self._create_store()
        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_multi_set(self):
        store = self._create_store()
        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_get(self):
        store = self._create_store()
        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_store_timeout_on_missing_clients(self):
        with self.assertRaisesRegex(
            DistStoreError,
            r"Timed out after \d+ seconds waiting for clients. \d+/\d+ clients joined.",
        ):
            # world_size is 2 so it should timeout
            dist.TCPStore(
                "localhost",
                0,
                2,
                True,
                timeout=timedelta(seconds=2),
                use_libuv=self._use_libuv,
            )

        # when wait_for_workers is not set, then there should be no exception raised
        dist.TCPStore(
            "localhost",
            0,
            2,
            True,
            timeout=timedelta(seconds=2),
            wait_for_workers=False,
            use_libuv=self._use_libuv,
        )


class LibUvTCPStoreTest(TCPStoreTest):
    """Re-runs ``TCPStoreTest`` with the libuv backend requested.

    The inherited ``_create_store`` / ``_create_store_with_ws`` already pass
    ``use_libuv=self._use_libuv``, so only the flag needs to change here.
    """

    _use_libuv = True

    def test_take_over_listen_socket(self):
        """
        override the take_over_listen_socket test in TCPStoreTest.
        Reason: we have not thoroughly tested libuv TCPStore initialization using
        open Socket so we decide to not support this use for now.
        TODO (xilunwu): enable this use case
        """
        listen_sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listen_sock.bind(("localhost", 0))
        addr, port, *_ = listen_sock.getsockname()
        listen_fd = listen_sock.detach()

        err_msg_reg = (
            "^The libuv TCPStore backend does not support "
            "initialization with an listen fd"
        )

        with self.assertRaisesRegex(NotImplementedError, err_msg_reg):
            dist.TCPStore(
                addr,
                port,
                1,
                is_master=True,
                master_listen_fd=listen_fd,
                use_libuv=self._use_libuv,
            )


class PrefixTCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against a ``PrefixStore`` over a TCP store."""

    def setUp(self):
        super().setUp()
        self.tcpstore = create_tcp_store()
        self.prefix = "test_prefix"
        self.tcpstore.set_timeout(timedelta(seconds=300))

    def _create_store(self):
        return dist.PrefixStore(self.prefix, self.tcpstore)

    # The PrefixTCPStore has 6 keys in test_set_get. It contains the 5 keys
    # added by the user and one additional key used for coordinate all the
    # workers.
    @property
    def num_keys_total(self):
        return 6

    def test_underlying_non_prefix_store(self):
        """Nested PrefixStores must all resolve to the same innermost TCP store."""
        store = self._create_store()
        wrapped_store = dist.PrefixStore(
            self.prefix, dist.PrefixStore(self.prefix, store)
        )
        self.assertEqual(self.tcpstore, store._underlying_non_prefix_store)
        self.assertEqual(self.tcpstore, wrapped_store._underlying_non_prefix_store)


class MyPythonStore(dist.Store):
    """Minimal dict-backed ``dist.Store`` subclass implemented in Python."""

    def __init__(self) -> None:
        super().__init__()
        self.store = {}

    def set(self, key, value):
        if not isinstance(key, (str, bytes)):
            raise AssertionError("Expected set to be called with string key")
        if type(value) is not bytes:
            raise AssertionError("Expected set to be called with bytes value")
        self.store[key] = value

    def get(self, key):
        # Missing keys read as the empty byte string rather than raising.
        value = self.store.get(key, b"")
        if type(value) is not bytes:
            raise AssertionError("Expected get to return bytes value")
        return value

    def add(self, key, value):
        # Counters are stored as their decimal text encoded to bytes.
        new = int(self.store.get(key, 0)) + value
        self.set(key, bytes(str(new).encode("utf-8")))
        return new

    def compare_set(self, key, expected, newValue):
        if type(expected) is not bytes:
            raise AssertionError("compare_set::expected not bytes")
        if type(newValue) is not bytes:
            raise AssertionError("compare_set::newValue not bytes")

        # Write when the key is absent or holds the expected value; otherwise
        # return the current value unchanged.
        val = self.store.get(key, None)
        if expected == val or val is None:
            val = self.store[key] = newValue
        return val


class PythonStoreTest(TestCase):
    def test_set_get(self):
        # If we were to inherit from StoreTestBase and try to use
        # its test_set_get function, we would exercise the Python
        # API directly, instead of going through the C++ trampoline.
        # We care about testing the C++ trampoline, so run the
        # equivalent of StoreTestBase.test_set_get from C++.
        # See `torch/csrc/distributed/c10d/init.cpp` for the definition
        # of this test function.
        dist._test_python_store(MyPythonStore())


class RendezvousTest(TestCase):
    """Error handling of ``dist.rendezvous`` that is independent of the scheme."""

    def test_unknown_handler(self):
        with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
            dist.rendezvous("invalid://")

    def test_url_with_node_params(self):
        with self.assertRaisesRegex(AssertionError, "has node-specific arguments"):
            dist.rendezvous("file://foo?rank=12&world_size=16", 12, 16)


class RendezvousEnvTest(TestCase):
    """Rendezvous through the ``env://`` scheme."""

    @retry_on_connect_failures
    def test_nominal(self):
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(common.find_free_port())

        # Single rank
        os.environ["RANK"] = "0"
        gen0 = dist.rendezvous("env://")
        store0, rank0, size0 = next(gen0)
        self.assertEqual(0, rank0)
        self.assertEqual(1, size0)

        store0.set("key0", "value0")

        # check with get
        self.assertEqual(b"value0", store0.get("key0"))


class RendezvousFileTest(TestCase):
    """Rendezvous through the ``file://`` scheme."""

    def test_common_errors(self):
        with self.assertRaisesRegex(ValueError, "path missing"):
            gen = dist.rendezvous("file://?rank=0&world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
            gen = dist.rendezvous("file:///tmp/foo?world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "size parameter missing"):
            gen = dist.rendezvous("file:///tmp/foo?rank=0")
            next(gen)

    def test_nominal(self):
        with tempfile.NamedTemporaryFile(delete=False) as file:
            url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2'
            gen0 = dist.rendezvous(url + "&rank=0")
            store0, rank0, size0 = next(gen0)
            self.assertEqual(0, rank0)
            self.assertEqual(2, size0)
            gen1 = dist.rendezvous(url + "&rank=1")
            store1, rank1, size1 = next(gen1)
            self.assertEqual(1, rank1)
            self.assertEqual(2, size1)

            # Set value on both stores
            store0.set("key0", "value0")
            store1.set("key1", "value1")

            # Cross check with get
            self.assertEqual(b"value0", store1.get("key0"))
            self.assertEqual(b"value1", store0.get("key1"))


@skip_if_win32()
class RendezvousTCPTest(TestCase):
    """Rendezvous through the ``tcp://`` scheme."""

    def create_tcp_url(self):
        """Return a ``tcp://`` URL on a free local port with ``world_size=1``."""
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()
        return f"tcp://{addr}:{port}?world_size=1"

    def test_common_errors(self):
        with self.assertRaisesRegex(ValueError, "port number missing"):
            gen = dist.rendezvous("tcp://127.0.0.1?rank=0&world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
            gen = dist.rendezvous("tcp://127.0.0.1:23456?world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "size parameter missing"):
            gen = dist.rendezvous("tcp://127.0.0.1:23456?rank=0")
            next(gen)

    def test_dns_timeout(self):
        with self.assertRaisesRegex(
            DistNetworkError, "client socket has timed out after.*dnsnotexist"
        ) as manager:
            gen = dist.rendezvous(
                "tcp://dnsnotexist:23456?world_size=2&rank=0",
                timeout=timedelta(seconds=1),
            )
            next(gen)
        self.assertIsInstance(manager.exception, DistError)

    @retry_on_connect_failures
    def test_nominal(self):
        url = self.create_tcp_url()
        gen0 = dist.rendezvous(url + "&rank=0")
        store0, rank0, size0 = next(gen0)
        self.assertEqual(0, rank0)
        self.assertEqual(1, size0)

        # Set value on the single store
        store0.set("key0", "value0")

        # check with get
        self.assertEqual(b"value0", store0.get("key0"))

    @retry_on_connect_failures(connect_errors=(CONNECT_TIMEOUT, ADDRESS_IN_USE))
    def test_tcp_store_timeout_set(self):
        url = self.create_tcp_url()
        test_store_timeout = timedelta(seconds=0.1)
        gen0 = dist.rendezvous(url + "&rank=0", timeout=timedelta(seconds=10))
        store0, rank0, size0 = next(gen0)
        store0.set_timeout(test_store_timeout)
        # this should time out in 0.1s. If the timeout passed into rendezvous was
        # not respected, it will take much longer to timeout.
        start = time.time()
        with self.assertRaisesRegex(
            DistStoreError, "wait timeout after 100ms, keys: /nonexistant key"
        ):
            store0.get("nonexistant key")

        end = time.time()
        time_diff = end - start
        self.assertGreater(10, time_diff)

    def test_tcp_store_timeout_doest_break_client(self):
        url = self.create_tcp_url()
        test_store_timeout = timedelta(seconds=0.1)
        gen0 = dist.rendezvous(url + "&rank=0", timeout=timedelta(seconds=10))
        store0, rank0, size0 = next(gen0)
        store0.set_timeout(test_store_timeout)
        # The get below should time out after the 0.1s set via set_timeout, not
        # the 10s passed to rendezvous; the store must remain usable afterwards.
        start = time.time()
        with self.assertRaisesRegex(
            DistStoreError, "wait timeout after 100ms, keys: /the_key"
        ):
            store0.get("the_key")

        store0.set("the_key", "x")

        self.assertEqual(b"x", store0.get("the_key"))

        end = time.time()
        time_diff = end - start
        self.assertGreater(10, time_diff)

    def test_tcp_store_url_with_libuv(self):
        url = self.create_tcp_url()
        gen0 = dist.rendezvous(url + "&rank=0&use_libuv=1")
        store0, rank0, size0 = next(gen0)
        self.assertTrue(store0.libuvBackend)


class DummyStore(dist.Store):
    """Store that records extended-API calls so tests can inspect them."""

    def __init__(self) -> None:
        self.appends = []
        self.multi_sets = []
        self.multi_gets = []
        # Canned results; multi_get() pops them in FIFO order.
        self.multi_get_res = []
        super().__init__()

    def append(self, key, value):
        self.appends.append((key, value))

    def multi_get(self, keys):
        self.multi_gets.append(keys)
        return self.multi_get_res.pop(0)

    def multi_set(self, keys, values):
        self.multi_sets.append((keys, values))

    def has_extended_api(self):
        return True


class TestPythonStore(TestCase):
    """Extended-API behaviour of Python-defined stores, directly and via PrefixStore."""

    def test_optional_methods_fail(self):
        class TestStore(dist.Store):
            pass

        store = TestStore()
        self.assertFalse(store.has_extended_api())
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.append("foo", "bar")
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_get(["foo", "bar"])
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_set(["foo", "bar"], [b"v", b"v"])

    def test_has_extended_api_passthrough(self):
        class TestStore(dist.Store):
            pass

        test_store = TestStore()
        store = dist.PrefixStore("p", test_store)
        self.assertFalse(store.has_extended_api())
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.append("foo", "bar")
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_get(["foo", "bar"])
        with self.assertRaisesRegex(RuntimeError, "Not implemented."):
            store.multi_set(["foo", "bar"], [b"v", b"v"])

    def test_has_extended_api_roundtrip(self):
        store = DummyStore()
        prefix = dist.PrefixStore("p", store)
        self.assertTrue(prefix.has_extended_api())

    def test_append_roundtrip(self):
        store = DummyStore()
        prefix = dist.PrefixStore("p", store)
        prefix.append("foo", "bar")
        self.assertEqual(1, len(store.appends))
        self.assertEqual(("p/foo", b"bar"), store.appends[0])

    def test_multi_get_roundtrip(self):
        store = DummyStore()
        prefix = dist.PrefixStore("p", store)
        store.multi_get_res.append([b"x", b"y"])
        res = prefix.multi_get(["foo", "bar"])
        self.assertEqual(1, len(store.multi_gets))
        self.assertEqual(["p/foo", "p/bar"], store.multi_gets[0])
        self.assertEqual([b"x", b"y"], res)

    def test_multi_set_roundtrip(self):
        store = DummyStore()
        prefix = dist.PrefixStore("p", store)
        prefix.multi_set(["foo", "bar"], [b"x", b"y"])
        self.assertEqual(1, len(store.multi_sets))
        self.assertEqual(["p/foo", "p/bar"], store.multi_sets[0][0])
        self.assertEqual([b"x", b"y"], store.multi_sets[0][1])

    def test_extended_methods_fallbacks(self):
        test_store = MyPythonStore()
        store = dist.PrefixStore("p", test_store)
        self.assertFalse(store.has_extended_api())
        store.append("foo", b"po")
        store.append("foo", b"tato")
        self.assertEqual(store.get("foo"), b"potato")

        store.multi_set(["a", "b"], [b"c", b"d"])
        self.assertEqual(store.multi_get(["a", "b", "foo"]), [b"c", b"d", b"potato"])


class TestMultiThreadedWait(MultiThreadedTestCase):
    """Checks ``wait()`` across two ranks: rank 0 waits while rank 1 sets the key."""

    # The stores are class attributes, so they are created once when the class
    # body runs and are shared by every test in this class.
    file_store = dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1)
    hash_store = dist.HashStore()

    tcp_store = create_tcp_store(use_libuv=False)
    tcp_store_uv = create_tcp_store(use_libuv=True)

    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def _test_wait(self, store):
        store.set_timeout(timedelta(seconds=2))
        if dist.get_rank() == 0:
            store.wait(["key1"])
            self.assertEqual(b"value1", store.get("key1"))
        if dist.get_rank() == 1:
            store.set("key1", "value1")

    def test_wait_hash_store(self):
        self._test_wait(self.hash_store)

    def test_wait_file_store(self):
        self._test_wait(self.file_store)

    def test_wait_prefix_file_store(self):
        store = dist.PrefixStore("pre", self.file_store)
        self._test_wait(store)

    def _test_wait_tcp_store(self, master_store):
        # Rank 0 reuses the master store; other ranks connect as clients.
        store = (
            master_store
            if dist.get_rank() == 0
            else dist.TCPStore(
                host_name=master_store.host,
                port=master_store.port,
                is_master=False,
                wait_for_workers=False,
                use_libuv=False,
            )
        )
        self._test_wait(store)

        prefix_store = dist.PrefixStore("pre", store)
        self._test_wait(prefix_store)

    def test_wait_tcp_store(self):
        self._test_wait_tcp_store(self.tcp_store)

    def test_wait_tcp_store_uv(self):
        self._test_wait_tcp_store(self.tcp_store_uv)

instantiate_parametrized_tests(TestMultiThreadedWait)


@skip_if_win32()
class TimeoutTest(TestCase):
    """Checks that a signal delivered to a waiting thread does not abort ``wait()``."""

    def tearDown(self):
        import signal

        super().tearDown()
        # Replace the test's handler so a late SIGUSR1 is ignored.
        signal.signal(signal.SIGUSR1, signal.SIG_IGN)

    def test_interrupt_doesnt_break_wait(self):
        """Rank 1 waits on "foo" while SIGUSR1 is sent to its thread; rank 0 sets it later."""
        import signal

        # One slot per rank: True on success, the caught exception on failure,
        # None if the thread never got as far as recording a result.
        rank_res = [None, None]

        def run(rank, my_store):
            try:
                if rank == 0:
                    time.sleep(4)
                    my_store.set("foo", "bar")
                else:
                    my_store.wait(["foo"], datetime.timedelta(seconds=10))
                rank_res[rank] = True
            except Exception as e:
                # Record the failure so the main thread can report it.
                rank_res[rank] = e
            time.sleep(1)

        rank0_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME,
            port=0,
            world_size=2,
            is_master=True,
            wait_for_workers=False,
        )
        rank1_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME,
            port=rank0_store.port,
            world_size=2,
            is_master=False,
            wait_for_workers=False,
        )

        ths = []
        for i in range(2):
            t = threading.Thread(
                target=run,
                args=(
                    i,
                    [rank0_store, rank1_store][i],
                ),
            )
            t.start()
            ths.append(t)

        def handler(a, b):
            pass

        signal.signal(signal.SIGUSR1, handler)
        # Give rank 1 time to enter wait() before signalling its thread.
        time.sleep(1)
        signal.pthread_kill(ths[1].ident, signal.SIGUSR1)

        for t in ths:
            t.join()
        # assertIs rather than assertTrue: a recorded exception object is
        # truthy and would otherwise let a failed rank pass.
        self.assertIs(rank_res[0], True, "rank0")
        self.assertIs(rank_res[1], True, "rank1")


class InitPgWithNonUvStore(TestCase):
    """
    This test shows how to use the legacy TCPStore (non-libuv) backend since libuv is now
    the default backend.
    """

    def tearDown(self):
        super().tearDown()
        # If an assertion failed before _run_test() reached its own
        # destroy_process_group(), tear the group down here so it does not
        # leak into later tests.
        if dist.is_initialized():
            dist.destroy_process_group()
        os.environ.pop("USE_LIBUV", None)
        os.environ.pop("MASTER_ADDR", None)
        os.environ.pop("MASTER_PORT", None)

    def test_with_url_param(self):
        port = common.find_free_port()
        dist.init_process_group(
            "gloo",
            rank=0,
            world_size=1,
            init_method=f"tcp://{DEFAULT_HOSTNAME}:{port}?use_libuv=0",
        )
        self._run_test()

    def test_with_env_var(self):
        port = common.find_free_port()
        os.environ["USE_LIBUV"] = "0"
        os.environ["MASTER_ADDR"] = DEFAULT_HOSTNAME
        os.environ["MASTER_PORT"] = str(port)
        dist.init_process_group("gloo", rank=0, world_size=1, init_method="env://")
        self._run_test()

    def _run_test(self):
        """Unwrap the world group's store and check it is a non-libuv TCPStore."""
        pg = dist.group.WORLD
        store = c10d._get_process_group_store(pg)
        self.assertIsInstance(store, dist.PrefixStore)
        # c10d does multiple levels of wrapping
        while isinstance(store, dist.PrefixStore):
            store = store.underlying_store
        self.assertIsInstance(store, dist.TCPStore)
        self.assertFalse(store.libuvBackend)
        dist.destroy_process_group()


class TestClientProtocol(TestCase):
    """Checks the first bytes a TCPStore client sends to a plain socket server."""

    def test_client_connect(self) -> None:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Close the listening socket even when the test fails.
        self.addCleanup(sock.close)
        sock.bind(("localhost", 0))
        port = sock.getsockname()[1]
        # Start listening before the client is created so the connect cannot
        # race the listener thread.
        sock.listen()

        # Failures raised in the listener thread, re-raised on the main thread.
        errors = []

        def listen() -> None:
            try:
                conn, _ = sock.accept()
                with conn:
                    # VALIDATE
                    # 0x3C85F7CE
                    self.assertEqual(conn.recv(5), b"\x00\xce\xf7\x85\x3c")

                    # PING
                    data = conn.recv(5)
                    self.assertEqual(data[0], 13)
                    nonce = struct.unpack("i", data[1:])[0]
                    self.assertEqual(nonce, os.getpid())

                    # send PING nonce response
                    conn.sendall(data[1:])
            except Exception as e:
                # An exception in a thread would not fail the test on its own.
                errors.append(e)

        thread = threading.Thread(target=listen)
        thread.start()

        # Keep a reference so the client is not destroyed before the listener
        # thread has finished.
        store = dist.TCPStore(  # noqa: F841
            host_name="localhost",
            port=port,
            world_size=2,
            is_master=False,
            timeout=timedelta(seconds=2),
            wait_for_workers=False,
        )

        thread.join()
        if errors:
            raise errors[0]


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()