xref: /aosp_15_r20/external/pytorch/test/distributed/test_store.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2
3import datetime
4import os
5import socket
6import struct
7import sys
8import tempfile
9import threading
10import time
11from datetime import timedelta
12from sys import platform
13
14import torch
15import torch.distributed as dist
16import torch.distributed.distributed_c10d as c10d
17import torch.distributed.rpc as rpc
18from torch.distributed import DistError, DistNetworkError, DistStoreError
19from torch.testing._internal.common_distributed import MultiThreadedTestCase
20from torch.testing._internal.common_utils import instantiate_parametrized_tests
21
22
23if not dist.is_available():
24    print("torch.distributed not available, skipping tests", file=sys.stderr)
25    sys.exit(0)
26
27import torch.testing._internal.common_utils as common
28from torch.testing._internal.common_distributed import (
29    create_tcp_store,
30    skip_if_win32,
31    tp_transports,
32)
33from torch.testing._internal.common_utils import (
34    ADDRESS_IN_USE,
35    CONNECT_TIMEOUT,
36    load_tests,
37    retry_on_connect_failures,
38    run_tests,
39    TestCase,
40)
41
42
43# load_tests from common_utils is used to automatically filter tests for
44# sharding on sandcastle. This line silences flake warnings
45load_tests = load_tests
46
# Name of the loopback network interface: "lo0" on macOS, "lo" elsewhere.
LOOPBACK = "lo0" if platform == "darwin" else "lo"
51
52DEFAULT_HOSTNAME = "localhost"
53
54torch.backends.cuda.matmul.allow_tf32 = False
55
56
def gpus_for_rank(world_size):
    """Split the visible CUDA devices into one contiguous slice per rank.

    Multi-GPU tests simulate several nodes with several GPUs each on a single
    machine, so every rank gets the same number of devices
    (``device_count // world_size``); leftover devices are not assigned.

    Returns a list with ``world_size`` entries, each a list of device indices.
    """
    device_total = torch.cuda.device_count()
    device_ids = list(range(device_total))
    per_rank = device_total // world_size
    return [
        device_ids[rank * per_rank : (rank + 1) * per_rank]
        for rank in range(world_size)
    ]
71
72
class StoreTestBase:
    """Store-agnostic test cases shared by the concrete store test classes.

    Subclasses are expected to also derive from ``TestCase`` (the assertion
    helpers used here come from it) and to override ``_create_store``.
    """

    def _create_store(self, i=None):
        """Return the store under test; must be overridden by subclasses.

        ``i`` is unused and optional, so the argument-less calls made by the
        tests below reach the error instead of failing with ``TypeError``.
        """
        raise RuntimeError("not implemented")

    def _test_set_get_check(self, fs):
        """Exercise add/set/get/check/delete_key/num_keys on store ``fs``."""
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        fs.add("key3", 2)
        fs.set("key2", "value2")
        fs.add("key3", 3)
        fs.add("key3", 4)
        fs.add("key3", 5)
        fs.add("key3", 6)
        self.assertEqual(fs.num_keys(), self.num_keys_total)
        # add() accumulates: 1 + 2 + 3 == 6 and 1 + ... + 6 == 21.
        self.assertEqual(b"6", fs.get("key"))
        self.assertEqual(b"value0", fs.get("key0"))
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))
        self.assertEqual(b"21", fs.get("key3"))
        self.assertTrue(fs.check(["key3"]))
        self.assertFalse(fs.check(["Randomkey3"]))

        # A temporary key must not change the key count once it is deleted.
        fs.set("-key3", "7")
        self.assertEqual(b"7", fs.get("-key3"))
        fs.delete_key("-key3")
        self.assertEqual(fs.num_keys(), self.num_keys_total)

    def test_set_get_check(self):
        self._test_set_get_check(self._create_store())

    def _test_compare_set(self, store):
        """Check compare_set on missing, mismatching and matching values."""
        # Missing key with a non-empty expected value: the expected value is
        # handed back.
        missing_key_result = store.compare_set(
            "cs_key0", "wrong_old_value", "new_value0"
        )
        self.assertEqual(b"wrong_old_value", missing_key_result)

        store.set("cs_key0", "value0")
        self.assertEqual(b"value0", store.get("cs_key0"))
        # Mismatch: the current value is returned and left untouched.
        old_value_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"value0", old_value_result)
        self.assertEqual(b"value0", store.get("cs_key0"))
        # Match: the value is swapped and the new value is returned.
        new_value_result = store.compare_set("cs_key0", "value0", "new_value0")
        self.assertEqual(b"new_value0", new_value_result)
        self.assertEqual(b"new_value0", store.get("cs_key0"))
        # Missing key with an empty expected value: the key is created.
        empty_old_value_result = store.compare_set("cs_key1", "", "new_value1")
        self.assertEqual(b"new_value1", empty_old_value_result)
        self.assertEqual(b"new_value1", store.get("cs_key1"))

    def test_compare_set(self):
        self._test_compare_set(self._create_store())

    def _test_simple_wait(self, fs):
        """wait() must time out on an absent key and return on a present one."""
        with self.assertRaisesRegex(RuntimeError, "[t -i]imeout"):
            fs.wait(["bad_key"], timedelta(seconds=0.25))
        fs.add("good_key", 1)
        fs.wait(["good_key"])

    def test_simple_wait(self):
        self._test_simple_wait(self._create_store())

    def _test_append(self, store):
        """append() extends an existing value and creates a missing key."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_append(self):
        self._test_append(self._create_store())

    def _test_multi_set(self, store):
        """multi_set() writes each key with the value at the same position."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_set(self):
        self._test_multi_set(self._create_store())

    def _test_multi_get(self, store):
        """multi_get() returns the values in the order the keys were given."""
        if not store.has_extended_api():
            self.skipTest("Store doesn't support extended APIs")
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_multi_get(self):
        self._test_multi_get(self._create_store())

    # This is the number of keys used in test_set_get_check. Adding this as a
    # class property instead of hardcoding in the test since some Store
    # implementations will have differing number of keys. In the base case,
    # there will be 5 keys: key, key0, key1, key2, key3.
    @property
    def num_keys_total(self):
        return 5
179
180
class FileStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests, plus file-specific ones, on FileStore."""

    def setUp(self):
        super().setUp()
        # delete=False keeps the path on disk so FileStore can open it by name.
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def _create_store(self):
        """Return a single-worker FileStore on ``self.file`` (300s timeout)."""
        store = dist.FileStore(self.file.name, 1)
        store.set_timeout(timedelta(seconds=300))
        return store

    def test_init_pg_and_rpc_with_same_file(self):
        """RPC and a process group can both be initialized from one file."""
        file = tempfile.NamedTemporaryFile(delete=False)
        # Init RPC using file
        rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
        rpc_backend_options.init_method = f"file://{file.name}"
        rpc_backend_options._transports = tp_transports()
        rpc.init_rpc(
            "worker", rank=0, world_size=1, rpc_backend_options=rpc_backend_options
        )

        # Init PG using file
        dist.init_process_group(
            "gloo", rank=0, world_size=1, init_method=f"file://{file.name}"
        )
        dist.destroy_process_group()
        # Use a unittest assertion: a bare ``assert`` is stripped under -O.
        self.assertTrue(os.path.exists(file.name))

        rpc.shutdown()
        os.remove(file.name)

    def test_refcount(self):
        """The backing file is removed only when the last store is dropped."""
        file = tempfile.NamedTemporaryFile(delete=False)
        store = dist.FileStore(file.name, 1)
        store2 = dist.FileStore(file.name, 1)

        del store
        self.assertTrue(os.path.exists(file.name))
        del store2
        self.assertFalse(os.path.exists(file.name))

    # One more key than the base case; presumably FileStore keeps one
    # internal key of its own -- TODO confirm.
    @property
    def num_keys_total(self):
        return 6
224
225
@skip_if_win32()
class HashStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against ``dist.HashStore``."""

    def _create_store(self):
        """Return a new HashStore configured with a five minute timeout."""
        hash_store = dist.HashStore()
        hash_store.set_timeout(timedelta(minutes=5))
        return hash_store
232
233
class PrefixStoreTest(TestCase):
    """Checks that PrefixStore exposes the store it wraps."""

    def setUp(self):
        # Run the base-class setup, as the other setUp overrides in this file do.
        super().setUp()
        # delete is false as FileStore will automatically clean up the file
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def test_get_underlying_store(self):
        """``underlying_store`` returns the wrapped store for each store type."""
        tcp_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME, port=0, world_size=1, is_master=True
        )
        hash_store = dist.HashStore()
        file_store = dist.FileStore(self.file.name, world_size=1)
        for store in [tcp_store, hash_store, file_store]:
            with self.subTest(f"Testing getting underlying_store for {type(store)}"):
                prefix_store = dist.PrefixStore("prefix", store)
                self.assertEqual(prefix_store.underlying_store, store)
249
250
class PrefixFileStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests on a PrefixStore wrapping a FileStore."""

    def setUp(self):
        super().setUp()
        self.prefix = "test_prefix"
        self.file = tempfile.NamedTemporaryFile(delete=False)
        backing_store = dist.FileStore(self.file.name, 1)
        backing_store.set_timeout(timedelta(minutes=5))
        self.filestore = backing_store

    def _create_store(self):
        """Wrap the shared FileStore in a new PrefixStore."""
        return dist.PrefixStore(self.prefix, self.filestore)

    @property
    def num_keys_total(self):
        """Key count expected by the shared set/get test for this store."""
        return 6
265
266
class TCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests, plus TCP-specific ones, on TCPStore.

    ``_use_libuv`` is forwarded as ``use_libuv`` to the stores built here, so a
    subclass can re-run the whole suite with the other backend.
    """

    _use_libuv = False

    def _create_store(self):
        """Return a TCP store for the selected backend (300s timeout)."""
        store = create_tcp_store(use_libuv=self._use_libuv)
        store.set_timeout(timedelta(seconds=300))
        return store

    def _create_store_with_ws(self, addr, world_size):
        """Return a TCP store for ``world_size`` that does not wait for workers."""
        return create_tcp_store(
            addr, world_size, wait_for_workers=False, use_libuv=self._use_libuv
        )

    def test_address_already_in_use(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        err_msg_reg = f"^The server socket has failed to listen on any local .*{port}"
        with self.assertRaisesRegex(RuntimeError, err_msg_reg):
            # Use noqa to silence flake8.
            # Need to store in an unused variable here to ensure the first
            # object is not destroyed before the second object is created.
            store1 = dist.TCPStore(
                addr, port, 1, True, use_libuv=self._use_libuv
            )  # noqa: F841
            store2 = dist.TCPStore(
                addr, port, 1, True, use_libuv=self._use_libuv
            )  # noqa: F841
            # Only reached if neither constructor raised, in which case the
            # enclosing assertRaisesRegex fails the test anyway.
            self.assertEqual(store1.libuvBackend, self._use_libuv)
            self.assertEqual(store2.libuvBackend, self._use_libuv)

    @retry_on_connect_failures
    def test_multitenancy(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(
            addr, port, 1, True, multi_tenant=True, use_libuv=self._use_libuv
        )  # type: ignore[call-arg] # noqa: F841
        store2 = dist.TCPStore(
            addr, port, 1, True, multi_tenant=True, use_libuv=self._use_libuv
        )  # type: ignore[call-arg] # noqa: F841
        self.assertEqual(store1.libuvBackend, self._use_libuv)
        self.assertEqual(store2.libuvBackend, self._use_libuv)

    def test_repr(self) -> None:
        # server
        store1 = self._create_store()
        self.assertRegex(
            repr(store1),
            r"TCPStore\("
            r"client=TCPClient\(SocketImpl\(fd=\d+, addr=\[?localhost\]?:\d+, remote=\[?localhost\]?:\d+\)\), "
            r"server=TCPServer\(port=\d+\)\)",
        )

        # client
        store2 = dist.TCPStore(
            store1.host,
            store1.port,
            world_size=2,
            is_master=False,
        )
        self.assertRegex(
            repr(store2),
            r"TCPStore\("
            r"client=TCPClient\(SocketImpl\(fd=\d+, addr=\[?localhost\]?:\d+, remote=\[?localhost\]?:\d+\)\), "
            r"server=<nullptr>\)",
        )

    @skip_if_win32()
    @retry_on_connect_failures
    def test_init_pg_and_rpc_with_same_socket(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = str(port)

        # We internally use a multi-tenant TCP store. Both PG and RPC should successfully
        # initialize even when using the same socket address.

        os.environ["USE_LIBUV"] = "1" if self._use_libuv else "0"
        try:
            dist.init_process_group(
                backend="gloo",
                init_method="env://",
                rank=0,
                world_size=1,
            )

            backend_opts = rpc.TensorPipeRpcBackendOptions(
                init_method=f"tcp://{addr}:{port}", _transports=tp_transports()
            )
            rpc.init_rpc(
                name="worker0",
                rank=0,
                world_size=1,
                rpc_backend_options=backend_opts,
            )
        finally:
            # Always drop the backend override, even when initialization
            # fails, so it cannot leak into a retry or into later tests.
            os.environ.pop("USE_LIBUV", None)

        self.assertNotIn("USE_LIBUV", os.environ)
        rpc.shutdown()
        dist.destroy_process_group()

    @skip_if_win32()
    def test_take_over_listen_socket(self):
        listen_sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listen_sock.bind(("localhost", 0))
        addr, port, *_ = listen_sock.getsockname()
        # detach() hands the raw descriptor over without closing it.
        listen_fd = listen_sock.detach()

        store = dist.TCPStore(
            addr,
            port,
            1,
            is_master=True,
            master_listen_fd=listen_fd,
            use_libuv=self._use_libuv,
        )

        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.set("key", "value")
        self.assertEqual(b"value", store.get("key"))

    # The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
    # the user and one additional key used for coordinate all the workers.
    @property
    def num_keys_total(self):
        return 6

    def _test_numkeys_delkeys(self, fs):
        """Track num_keys() while keys are added to and deleted from ``fs``."""
        # We start off with one init key in the store to coordinate workers
        self.assertEqual(fs.num_keys(), 1)
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        self.assertEqual(fs.num_keys(), 5)
        fs.delete_key("key")
        self.assertEqual(fs.num_keys(), 4)
        # Short timeout so the get() of the deleted key fails quickly.
        fs.set_timeout(timedelta(seconds=2))
        with self.assertRaises(RuntimeError):
            fs.get("key")
        fs.delete_key("key0")
        fs.delete_key("key3")
        self.assertEqual(fs.num_keys(), 2)
        fs.set("key4", "value2")
        self.assertEqual(fs.num_keys(), 3)
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key4"))

    def test_numkeys_delkeys(self):
        self._test_numkeys_delkeys(self._create_store())

    def _create_client(self, index, addr, port, world_size):
        """Connect a client store and check get/set/compare_set through it."""
        client_store = dist.TCPStore(
            addr,
            port,
            world_size=world_size,
            timeout=timedelta(seconds=10),
            use_libuv=self._use_libuv,
        )
        self.assertEqual(b"value", client_store.get("key"))
        client_store.set(f"new_key{index}", f"new_value{index}")
        self.assertEqual(
            f"next_value{index}".encode(),
            client_store.compare_set(
                f"new_key{index}", f"new_value{index}", f"next_value{index}"
            ),
        )

    def _multi_worker_helper(self, world_size):
        """Start a server store and connect one client per worker to it."""
        addr = DEFAULT_HOSTNAME
        server_store = self._create_store_with_ws(addr, world_size)
        self.assertEqual(server_store.libuvBackend, self._use_libuv)
        server_store.set("key", "value")
        port = server_store.port

        # A falsy world_size (None) still exercises a single client.
        num_indices = world_size if world_size else 1
        for i in range(num_indices):
            self._create_client(i, addr, port, world_size)

    def test_multi_worker_with_fixed_world_size(self):
        self._multi_worker_helper(5)

    def test_multi_worker_with_nonfixed_world_size(self):
        self._multi_worker_helper(None)

    def test_append(self):
        store = self._create_store()
        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.set("foo", "po")
        store.append("foo", "tato")
        store.append("bar", "po")
        store.append("bar", "tato")
        self.assertEqual(b"potato", store.get("foo"))
        self.assertEqual(b"potato", store.get("bar"))

    def test_multi_set(self):
        store = self._create_store()
        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.multi_set(["foo", "bar"], ["po", "tato"])
        self.assertEqual(b"po", store.get("foo"))
        self.assertEqual(b"tato", store.get("bar"))

    def test_multi_get(self):
        store = self._create_store()
        self.assertEqual(store.libuvBackend, self._use_libuv)
        store.set("foo", "po")
        store.set("bar", "tato")
        v0, v1 = store.multi_get(["foo", "bar"])
        self.assertEqual(b"po", v0)
        self.assertEqual(b"tato", v1)

    def test_store_timeout_on_missing_clients(self):
        with self.assertRaisesRegex(
            DistStoreError,
            r"Timed out after \d+ seconds waiting for clients. \d+/\d+ clients joined.",
        ):
            # world_size is 2 so it should timeout
            dist.TCPStore(
                "localhost",
                0,
                2,
                True,
                timeout=timedelta(seconds=2),
                use_libuv=self._use_libuv,
            )

        # when wait_for_workers is not set, then there should be no exception raised
        dist.TCPStore(
            "localhost",
            0,
            2,
            True,
            timeout=timedelta(seconds=2),
            wait_for_workers=False,
            use_libuv=self._use_libuv,
        )
511
512
class LibUvTCPStoreTest(TCPStoreTest):
    """Re-runs the TCPStore tests with the libuv backend selected."""

    _use_libuv = True

    def _create_store(self):
        """Return a libuv-backed TCP store with a 300 second timeout."""
        store = create_tcp_store(use_libuv=True)
        store.set_timeout(timedelta(seconds=300))
        return store

    def _create_store_with_ws(self, addr, world_size):
        """Return a libuv-backed store that does not wait for workers."""
        return create_tcp_store(
            addr, world_size, wait_for_workers=False, use_libuv=True
        )

    def test_take_over_listen_socket(self):
        """
        override the take_over_listen_socket test in TCPStoreTest.
        Reason: we have not thoroughly tested libuv TCPStore initialization using
        open Socket so we decide to not support this use for now.
        TODO (xilunwu): enable this use case
        """
        listen_sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listen_sock.bind(("localhost", 0))
        addr, port, *_ = listen_sock.getsockname()
        listen_fd = listen_sock.detach()

        err_msg_reg = (
            "^The libuv TCPStore backend does not support "
            "initialization with an listen fd"
        )

        with self.assertRaisesRegex(NotImplementedError, err_msg_reg):
            # The constructor is expected to raise, so its result is not bound.
            dist.TCPStore(
                addr,
                port,
                1,
                is_master=True,
                master_listen_fd=listen_fd,
                use_libuv=self._use_libuv,
            )
552
553
class PrefixTCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests on a PrefixStore wrapping a TCP store."""

    def setUp(self):
        super().setUp()
        self.prefix = "test_prefix"
        backing_store = create_tcp_store()
        backing_store.set_timeout(timedelta(minutes=5))
        self.tcpstore = backing_store

    def _create_store(self):
        """Wrap the shared TCP store in a new PrefixStore."""
        return dist.PrefixStore(self.prefix, self.tcpstore)

    # The PrefixTCPStore has 6 keys in test_set_get. It contains the 5 keys
    # added by the user and one additional key used for coordinate all the
    # workers.
    @property
    def num_keys_total(self):
        return 6

    def test_underlying_non_prefix_store(self):
        """Nested PrefixStores all resolve to the innermost TCP store."""
        single = self._create_store()
        double = dist.PrefixStore(self.prefix, single)
        triple = dist.PrefixStore(self.prefix, double)
        self.assertEqual(self.tcpstore, single._underlying_non_prefix_store)
        self.assertEqual(self.tcpstore, triple._underlying_non_prefix_store)
578
579
class MyPythonStore(dist.Store):
    """Dictionary-backed store implemented in Python on top of ``dist.Store``.

    Values are kept as ``bytes``; the methods raise ``AssertionError`` when
    they are handed a value of any other type.
    """

    def __init__(self) -> None:
        super().__init__()
        self.store = {}

    def set(self, key, value):
        """Store ``value`` (exactly ``bytes``) under a ``str``/``bytes`` key."""
        key_ok = isinstance(key, (str, bytes))
        if not key_ok:
            raise AssertionError("Expected set to be called with string key")
        if type(value) is not bytes:
            raise AssertionError("Expected set to be called with bytes value")
        self.store[key] = value

    def get(self, key):
        """Return the value stored under ``key``, or ``b""`` when absent."""
        found = self.store.get(key, b"")
        if type(found) is not bytes:
            raise AssertionError("Expected get to return bytes value")
        return found

    def add(self, key, value):
        """Add ``value`` to the integer stored under ``key``; return the sum."""
        total = int(self.store.get(key, 0)) + value
        self.set(key, str(total).encode("utf-8"))
        return total

    def compare_set(self, key, expected, newValue):
        """Write ``newValue`` if ``key`` is absent or currently ``expected``.

        Returns the value held under ``key`` after the call.
        """
        if type(expected) is not bytes:
            raise AssertionError("compare_set::expected not bytes")
        if type(newValue) is not bytes:
            raise AssertionError("compare_set::newValue not bytes")

        current = self.store.get(key, None)
        if current is not None and expected != current:
            return current
        self.store[key] = newValue
        return newValue
613
614
class PythonStoreTest(TestCase):
    """Drives a Python-defined store through the C++ trampoline."""

    def test_set_get(self):
        # Inheriting from StoreTestBase and reusing its test_set_get function
        # would call the Python API directly and bypass the C++ trampoline.
        # The trampoline is what we want to test here, so the equivalent of
        # StoreTestBase.test_set_get is run from C++ instead.
        # See `torch/csrc/distributed/c10d/init.cpp` for the definition
        # of this test function.
        python_store = MyPythonStore()
        dist._test_python_store(python_store)
625
626
class RendezvousTest(TestCase):
    """Error handling of ``dist.rendezvous`` that is independent of a scheme."""

    def test_unknown_handler(self):
        unknown_scheme_url = "invalid://"
        with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
            dist.rendezvous(unknown_scheme_url)

    def test_url_with_node_params(self):
        # rank/world_size are given both in the URL and as arguments.
        url_with_params = "file://foo?rank=12&world_size=16"
        with self.assertRaisesRegex(AssertionError, "has node-specific arguments"):
            dist.rendezvous(url_with_params, 12, 16)
635
636
class RendezvousEnvTest(TestCase):
    """Checks the ``env://`` rendezvous for a single rank."""

    @retry_on_connect_failures
    def test_nominal(self):
        """A single-rank env:// rendezvous yields a usable store."""
        from unittest import mock

        # Single rank
        rendezvous_env = {
            "WORLD_SIZE": "1",
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": str(common.find_free_port()),
            "RANK": "0",
        }
        # patch.dict restores os.environ on exit, so these settings do not
        # leak into retries of this test or into tests that run afterwards.
        with mock.patch.dict(os.environ, rendezvous_env):
            gen0 = dist.rendezvous("env://")
            store0, rank0, size0 = next(gen0)
            self.assertEqual(0, rank0)
            self.assertEqual(1, size0)

            store0.set("key0", "value0")

            # check with get
            self.assertEqual(b"value0", store0.get("key0"))
655
656
class RendezvousFileTest(TestCase):
    """Checks the ``file://`` rendezvous handler."""

    def test_common_errors(self):
        """Each malformed URL is rejected with its specific ValueError."""
        with self.assertRaisesRegex(ValueError, "path missing"):
            next(dist.rendezvous("file://?rank=0&world_size=1"))
        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
            next(dist.rendezvous("file:///tmp/foo?world_size=1"))
        with self.assertRaisesRegex(ValueError, "size parameter missing"):
            next(dist.rendezvous("file:///tmp/foo?rank=0"))

    def test_nominal(self):
        """Two ranks rendezvous through one file and see each other's keys."""
        with tempfile.NamedTemporaryFile(delete=False) as file:
            url_path = file.name.replace(os.path.sep, "/")
            url = f"file:///{url_path}?world_size=2"

            store0, rank0, size0 = next(dist.rendezvous(f"{url}&rank=0"))
            self.assertEqual(0, rank0)
            self.assertEqual(2, size0)

            store1, rank1, size1 = next(dist.rendezvous(f"{url}&rank=1"))
            self.assertEqual(1, rank1)
            self.assertEqual(2, size1)

            # Each rank writes its own key ...
            store0.set("key0", "value0")
            store1.set("key1", "value1")

            # ... and reads the key written by the other rank.
            self.assertEqual(b"value0", store1.get("key0"))
            self.assertEqual(b"value1", store0.get("key1"))
688
689
@skip_if_win32()
class RendezvousTCPTest(TestCase):
    """Checks the ``tcp://`` rendezvous handler."""

    def create_tcp_url(self):
        """Return a tcp:// URL on a free local port with ``world_size=1``."""
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()
        return f"tcp://{addr}:{port}?world_size=1"

    def test_common_errors(self):
        with self.assertRaisesRegex(ValueError, "port number missing"):
            gen = dist.rendezvous("tcp://127.0.0.1?rank=0&world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
            gen = dist.rendezvous("tcp://127.0.0.1:23456?world_size=1")
            next(gen)
        with self.assertRaisesRegex(ValueError, "size parameter missing"):
            gen = dist.rendezvous("tcp://127.0.0.1:23456?rank=0")
            next(gen)

    def test_dns_timeout(self):
        with self.assertRaisesRegex(
            DistNetworkError, "client socket has timed out after.*dnsnotexist"
        ) as manager:
            gen = dist.rendezvous(
                "tcp://dnsnotexist:23456?world_size=2&rank=0",
                timeout=timedelta(seconds=1),
            )
            next(gen)
        # The network error must also be catchable as the generic DistError.
        self.assertIsInstance(manager.exception, DistError)

    @retry_on_connect_failures
    def test_nominal(self):
        url = self.create_tcp_url()
        gen0 = dist.rendezvous(url + "&rank=0")
        store0, rank0, size0 = next(gen0)
        self.assertEqual(0, rank0)
        self.assertEqual(1, size0)

        # Set value on the single store
        store0.set("key0", "value0")

        # check with get
        self.assertEqual(b"value0", store0.get("key0"))

    @retry_on_connect_failures(connect_errors=(CONNECT_TIMEOUT, ADDRESS_IN_USE))
    def test_tcp_store_timeout_set(self):
        url = self.create_tcp_url()
        test_store_timeout = timedelta(seconds=0.1)
        gen0 = dist.rendezvous(url + "&rank=0", timeout=timedelta(seconds=10))
        store0, rank0, size0 = next(gen0)
        store0.set_timeout(test_store_timeout)
        # this should time out in 0.1s. If the timeout set on the store was
        # not respected, it will take much longer to timeout.
        start = time.time()
        with self.assertRaisesRegex(
            DistStoreError, "wait timeout after 100ms, keys: /nonexistant key"
        ):
            store0.get("nonexistant key")

        end = time.time()
        time_diff = end - start
        self.assertGreater(10, time_diff)

    def test_tcp_store_timeout_doest_break_client(self):
        url = self.create_tcp_url()
        test_store_timeout = timedelta(seconds=0.1)
        gen0 = dist.rendezvous(url + "&rank=0", timeout=timedelta(seconds=10))
        store0, rank0, size0 = next(gen0)
        store0.set_timeout(test_store_timeout)
        # this should time out in 0.1s (the timeout set on the store just
        # above), well within the 10s bound checked at the end.
        start = time.time()
        with self.assertRaisesRegex(
            DistStoreError, "wait timeout after 100ms, keys: /the_key"
        ):
            store0.get("the_key")

        # The client must remain usable after the timed-out get.
        store0.set("the_key", "x")

        self.assertEqual(b"x", store0.get("the_key"))

        end = time.time()
        time_diff = end - start
        self.assertGreater(10, time_diff)

    def test_tcp_store_url_with_libuv(self):
        url = self.create_tcp_url()
        gen0 = dist.rendezvous(url + "&rank=0&use_libuv=1")
        store0, rank0, size0 = next(gen0)
        self.assertTrue(store0.libuvBackend)
780
781
class DummyStore(dist.Store):
    """Store stub that records extended-API calls instead of storing data.

    ``multi_get`` replies with the entries queued in ``multi_get_res``, in
    the order they were queued.
    """

    def __init__(self) -> None:
        super().__init__()
        self.appends = []
        self.multi_sets = []
        self.multi_gets = []
        self.multi_get_res = []

    def append(self, key, value):
        """Record the ``(key, value)`` pair."""
        call = (key, value)
        self.appends.append(call)

    def multi_get(self, keys):
        """Record ``keys`` and return the oldest queued result."""
        self.multi_gets.append(keys)
        canned = self.multi_get_res.pop(0)
        return canned

    def multi_set(self, keys, values):
        """Record the ``(keys, values)`` pair."""
        call = (keys, values)
        self.multi_sets.append(call)

    def has_extended_api(self):
        """Report that the extended API is available."""
        return True
802
803
class TestPythonStore(TestCase):
    """Extended-API behaviour of Python-defined stores and PrefixStore."""

    def test_optional_methods_fail(self):
        """A store that overrides nothing rejects every extended call."""

        class BareStore(dist.Store):
            pass

        not_implemented = "Not implemented."
        store = BareStore()
        self.assertFalse(store.has_extended_api())
        with self.assertRaisesRegex(RuntimeError, not_implemented):
            store.append("foo", "bar")
        with self.assertRaisesRegex(RuntimeError, not_implemented):
            store.multi_get(["foo", "bar"])
        with self.assertRaisesRegex(RuntimeError, not_implemented):
            store.multi_set(["foo", "bar"], [b"v", b"v"])

    def test_has_extended_api_passthrough(self):
        """A PrefixStore around such a store rejects the same calls."""

        class BareStore(dist.Store):
            pass

        not_implemented = "Not implemented."
        wrapped = BareStore()
        store = dist.PrefixStore("p", wrapped)
        self.assertFalse(store.has_extended_api())
        with self.assertRaisesRegex(RuntimeError, not_implemented):
            store.append("foo", "bar")
        with self.assertRaisesRegex(RuntimeError, not_implemented):
            store.multi_get(["foo", "bar"])
        with self.assertRaisesRegex(RuntimeError, not_implemented):
            store.multi_set(["foo", "bar"], [b"v", b"v"])

    def test_has_extended_api_roundtrip(self):
        recorder = DummyStore()
        prefixed = dist.PrefixStore("p", recorder)
        self.assertTrue(prefixed.has_extended_api())

    def test_append_roundtrip(self):
        recorder = DummyStore()
        prefixed = dist.PrefixStore("p", recorder)
        prefixed.append("foo", "bar")
        # The wrapped store sees the prefixed key and the value as bytes.
        self.assertEqual(1, len(recorder.appends))
        self.assertEqual(("p/foo", b"bar"), recorder.appends[0])

    def test_multi_get_roundtrip(self):
        recorder = DummyStore()
        prefixed = dist.PrefixStore("p", recorder)
        recorder.multi_get_res.append([b"x", b"y"])
        result = prefixed.multi_get(["foo", "bar"])
        self.assertEqual(1, len(recorder.multi_gets))
        self.assertEqual(["p/foo", "p/bar"], recorder.multi_gets[0])
        self.assertEqual([b"x", b"y"], result)

    def test_multi_set_roundtrip(self):
        recorder = DummyStore()
        prefixed = dist.PrefixStore("p", recorder)
        prefixed.multi_set(["foo", "bar"], [b"x", b"y"])
        self.assertEqual(1, len(recorder.multi_sets))
        seen_keys, seen_values = recorder.multi_sets[0][0], recorder.multi_sets[0][1]
        self.assertEqual(["p/foo", "p/bar"], seen_keys)
        self.assertEqual([b"x", b"y"], seen_values)

    def test_extended_methods_fallbacks(self):
        """Extended calls still work on a store without the extended API."""
        basic_store = MyPythonStore()
        store = dist.PrefixStore("p", basic_store)
        self.assertFalse(store.has_extended_api())
        store.append("foo", b"po")
        store.append("foo", b"tato")
        self.assertEqual(store.get("foo"), b"potato")

        store.multi_set(["a", "b"], [b"c", b"d"])
        fetched = store.multi_get(["a", "b", "foo"])
        self.assertEqual(fetched, [b"c", b"d", b"potato"])
871
872
class TestMultiThreadedWait(MultiThreadedTestCase):
    """Checks ``store.wait`` on several store types with ``world_size`` 2."""

    # The stores are built once, when the class body is evaluated, and are
    # shared by every test method below.
    file_store = dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1)
    hash_store = dist.HashStore()

    tcp_store = create_tcp_store(use_libuv=False)
    tcp_store_uv = create_tcp_store(use_libuv=True)

    @property
    def world_size(self):
        """Number of ranks used by these tests."""
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def _test_wait(self, store):
        """Rank 0 waits for ``key1`` and reads it back; rank 1 sets it."""
        store.set_timeout(timedelta(seconds=2))
        rank = dist.get_rank()
        if rank == 1:
            store.set("key1", "value1")
        elif rank == 0:
            store.wait(["key1"])
            self.assertEqual(b"value1", store.get("key1"))

    def test_wait_hash_store(self):
        self._test_wait(self.hash_store)

    def test_wait_file_store(self):
        self._test_wait(self.file_store)

    def test_wait_prefix_file_store(self):
        self._test_wait(dist.PrefixStore("pre", self.file_store))

    def _test_wait_tcp_store(self, master_store):
        """Run the wait check on a TCP store, bare and wrapped in a PrefixStore.

        Rank 0 uses ``master_store`` itself; any other rank connects to it
        with a new client TCPStore.
        """
        if dist.get_rank() == 0:
            store = master_store
        else:
            store = dist.TCPStore(
                host_name=master_store.host,
                port=master_store.port,
                is_master=False,
                wait_for_workers=False,
                use_libuv=False,
            )
        self._test_wait(store)
        self._test_wait(dist.PrefixStore("pre", store))

    def test_wait_tcp_store(self):
        self._test_wait_tcp_store(self.tcp_store)

    def test_wait_tcp_store_uv(self):
        self._test_wait_tcp_store(self.tcp_store_uv)
928
929
# NOTE(review): presumably expands parametrized test templates on the class
# into concrete test methods; no parametrize decorators are visible in the
# class body, so this may be a no-op here — TODO confirm.
instantiate_parametrized_tests(TestMultiThreadedWait)
931
932
@skip_if_win32()
class TimeoutTest(TestCase):
    """Checks that a signal sent to a thread blocked in ``store.wait`` does
    not make the wait fail."""

    def tearDown(self):
        import signal

        super().tearDown()
        # Replace the test's handler with "ignore" once the test is over;
        # presumably so a late SIGUSR1 cannot terminate the test process.
        signal.signal(signal.SIGUSR1, signal.SIG_IGN)

    def test_interrupt_doesnt_break_wait(self):
        """Send SIGUSR1 to the waiting thread and expect both ranks to finish.

        Rank 0 sets the key after 4 seconds; rank 1 waits for it for up to
        10 seconds and is signalled about 1 second after the threads start.
        """
        import signal

        # One slot per rank: True on success, the raised exception on
        # failure, None if the worker never reported.
        rank_res = [None, None]

        def run(rank, my_store):
            try:
                if rank == 0:
                    time.sleep(4)
                    my_store.set("foo", "bar")
                else:
                    my_store.wait(["foo"], datetime.timedelta(seconds=10))
                rank_res[rank] = True
            except Exception as e:
                # The original clause named an undefined `Error`, so a failure
                # raised NameError in the thread and the real cause was lost.
                rank_res[rank] = e
            time.sleep(1)

        rank0_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME,
            port=0,
            world_size=2,
            is_master=True,
            wait_for_workers=False,
        )
        rank1_store = dist.TCPStore(
            host_name=DEFAULT_HOSTNAME,
            port=rank0_store.port,
            world_size=2,
            is_master=False,
            wait_for_workers=False,
        )

        ths = []
        for rank, rank_store in enumerate((rank0_store, rank1_store)):
            t = threading.Thread(target=run, args=(rank, rank_store))
            t.start()
            ths.append(t)

        def handler(a, b):
            pass

        # Install a no-op handler so SIGUSR1 is delivered without its default
        # action being taken.
        signal.signal(signal.SIGUSR1, handler)
        time.sleep(1)
        signal.pthread_kill(ths[1].ident, signal.SIGUSR1)

        for t in ths:
            t.join()

        # Compare by identity: a recorded exception object is truthy and
        # would slip through assertTrue.
        self.assertIs(rank_res[0], True, f"rank0: {rank_res[0]!r}")
        self.assertIs(rank_res[1], True, f"rank1: {rank_res[1]!r}")
997
998
class InitPgWithNonUvStore(TestCase):
    """
    This test shows how to use the legacy TCPStore (non-libuv) backend since libuv is now
    the default backend.
    """

    def tearDown(self):
        super().tearDown()
        # Remove the variables test_with_env_var sets so they do not leak
        # into later tests.
        os.environ.pop("USE_LIBUV", None)
        os.environ.pop("MASTER_ADDR", None)
        os.environ.pop("MASTER_PORT", None)

    def test_with_url_param(self):
        """Request the non-libuv store via ``use_libuv=0`` in the init URL."""
        port = common.find_free_port()
        dist.init_process_group(
            "gloo",
            rank=0,
            world_size=1,
            init_method=f"tcp://{DEFAULT_HOSTNAME}:{port}?use_libuv=0",
        )
        self._run_test()

    def test_with_env_var(self):
        """Request the non-libuv store via the ``USE_LIBUV`` env variable."""
        port = common.find_free_port()
        os.environ["USE_LIBUV"] = "0"
        os.environ["MASTER_ADDR"] = DEFAULT_HOSTNAME
        os.environ["MASTER_PORT"] = str(port)
        dist.init_process_group("gloo", rank=0, world_size=1, init_method="env://")
        self._run_test()

    def _run_test(self):
        """Check that the default group's store is a non-libuv TCPStore.

        The process group is destroyed even when an assertion fails, so a
        failure here cannot leave the default group initialised for the
        next test.
        """
        try:
            pg = dist.group.WORLD
            store = c10d._get_process_group_store(pg)
            self.assertIsInstance(store, dist.PrefixStore)
            # c10d does multiple levels of wrapping
            while isinstance(store, dist.PrefixStore):
                store = store.underlying_store
            self.assertIsInstance(store, dist.TCPStore)
            self.assertFalse(store.libuvBackend)
        finally:
            dist.destroy_process_group()
1038        dist.destroy_process_group()
1039
1040
class TestClientProtocol(TestCase):
    """Checks the first bytes a TCPStore client sends after connecting, using
    a plain socket as the server."""

    def test_client_connect(self) -> None:
        """Accept one client and check its VALIDATE and PING messages."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Release the listening socket whether the test passes or fails.
        self.addCleanup(sock.close)
        sock.bind(("localhost", 0))
        port = sock.getsockname()[1]

        def listen() -> None:
            sock.listen()
            conn, _ = sock.accept()

            # Close the connection even if an assertion below raises.
            with conn:
                # VALIDATE
                # 0x3C85F7CE
                self.assertEqual(conn.recv(5), b"\x00\xce\xf7\x85\x3c")

                # PING
                data = conn.recv(5)
                self.assertEqual(data[0], 13)
                nonce = struct.unpack("i", data[1:])[0]
                self.assertEqual(nonce, os.getpid())

                # send PING nonce response
                conn.sendall(data[1:])

        thread = threading.Thread(target=listen)
        thread.start()

        # Keep a reference so the client object stays alive until the server
        # thread has finished.
        store = dist.TCPStore(  # noqa: F841
            host_name="localhost",
            port=port,
            world_size=2,
            is_master=False,
            timeout=timedelta(seconds=2),
            wait_for_workers=False,
        )

        thread.join()
1078        thread.join()
1079
1080
if __name__ == "__main__":
    # The tests must start with CUDA uninitialised in the main process.
    cuda_untouched = not torch.cuda._initialized
    assert (
        cuda_untouched
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()
1087