# mypy: ignore-errors

import abc
import faulthandler
import itertools
import logging
import multiprocessing
import os
import queue
import subprocess
import sys
import tempfile
import threading
import time
import traceback
import types
import unittest
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from functools import partial, reduce, wraps
from io import StringIO
from typing import Dict, NamedTuple, Optional, Union, List, Any, Callable, Tuple
from unittest.mock import patch

import torch
import torch._dynamo.test_case
import torch.cuda.nccl
import torch.distributed as c10d
import torch.nn as nn
from torch.testing._internal.common_utils import (
    FILE_SCHEMA,
    find_free_port,
    IS_SANDCASTLE,
    retry_on_connect_failures,
    skip_but_pass_in_sandcastle,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_ROCM,
    TEST_WITH_TSAN,
    TestCase,
    run_tests,
)
from torch.testing._internal.distributed.multi_threaded_pg import (
    _install_threaded_pg,
    _uninstall_threaded_pg,
    ProcessLocalGroup,
)
import operator

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TestSkip(NamedTuple):
    exit_code: int
    message: str


TEST_SKIPS = {
    "backend_unavailable": TestSkip(
        72, "Skipped because distributed backend is not available."
    ),
    "small_worldsize": TestSkip(73, "Skipped due to small world size."),
    "odd_worldsize": TestSkip(87, "Skipped due to odd world size."),
    "no_cuda": TestSkip(74, "CUDA is not available."),
    "multi-gpu-1": TestSkip(75, "Need at least 1 CUDA device"),
    "multi-gpu-2": TestSkip(77, "Need at least 2 CUDA devices"),
    "multi-gpu-3": TestSkip(80, "Need at least 3 CUDA devices"),
    "multi-gpu-4": TestSkip(81, "Need at least 4 CUDA devices"),
    "multi-gpu-5": TestSkip(82, "Need at least 5 CUDA devices"),
    "multi-gpu-6": TestSkip(83, "Need at least 6 CUDA devices"),
    "multi-gpu-7": TestSkip(84, "Need at least 7 CUDA devices"),
    "multi-gpu-8": TestSkip(85, "Need at least 8 CUDA devices"),
    "nccl": TestSkip(76, "c10d not compiled with NCCL support"),
    "skipIfRocm": TestSkip(78, "Test skipped for ROCm"),
    "no_peer_access": TestSkip(79, "Test skipped because no GPU peer access"),
    "generic": TestSkip(
        86, "Test skipped at subprocess level, look at subprocess log for skip reason"
    ),
    "importerror": TestSkip(88, "Test skipped due to missing import"),
}


@dataclass
class DistTestCases:
    # Backends that do not support a specific collective
    skip_collective = {}
    skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc"}
    skip_collective["reduce"] = set()
    skip_collective["sendrecv anysource"] = {"nccl", "ucc"}
    skip_collective["cpu barrier"] = {"nccl", "ucc"}

    # Sets showing that something is implemented
    backend_feature = {}
    backend_feature["gpu"] = {"nccl", "gloo", "ucc"}
    backend_feature["cuda"] = {"nccl", "gloo", "ucc"}
    backend_feature["ddp"] = {"nccl", "gloo", "ucc"}
    backend_feature["subgroup"] = {"nccl", "gloo", "ucc"}
    backend_feature["plugin"] = set()


def skip_if_no_gpu(func):
    """Skips if the world size exceeds the number of GPUs, ensuring that if the
    test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not torch.cuda.is_available():
            sys.exit(TEST_SKIPS["no_cuda"].exit_code)
        world_size = int(os.environ["WORLD_SIZE"])
        if torch.cuda.device_count() < world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)

        return func(*args, **kwargs)

    return wrapper


def skip_if_small_worldsize(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if (os.environ["BACKEND"] != "mpi") and int(os.environ["WORLD_SIZE"]) <= 2:
            sys.exit(TEST_SKIPS["small_worldsize"].exit_code)

        return func(*args, **kwargs)

    return wrapper


def skip_if_odd_worldsize(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if (os.environ["BACKEND"] != "mpi") and int(os.environ["WORLD_SIZE"]) % 2 == 1:
            sys.exit(TEST_SKIPS["odd_worldsize"].exit_code)

        return func(*args, **kwargs)

    return wrapper


def require_n_gpus_for_nccl_backend(n, backend):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if backend == "nccl" and torch.cuda.device_count() < n:
                sys.exit(TEST_SKIPS[f"multi-gpu-{n}"].exit_code)
            else:
                return func(*args, **kwargs)

        return wrapper

    return decorator


def import_transformers_or_skip():
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                from transformers import (  # noqa: F401
                    AutoModelForMaskedLM,
                    BertConfig,
                )

                return func(*args, **kwargs)
            except ImportError:
                sys.exit(TEST_SKIPS["importerror"].exit_code)

        return wrapper

    return decorator


def at_least_x_gpu(x):
    return torch.cuda.is_available() and torch.cuda.device_count() >= x


def skip_if_lt_x_gpu(x):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
                return func(*args, **kwargs)
            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)

        return wrapper

    return decorator
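
# Illustrative sketch (not executed here): how the GPU-count skip decorators
# above are typically applied to a test method that runs once per rank. The
# test class and method names are hypothetical.
#
#     class MyDistributedTest(MultiProcessTestCase):
#         @skip_if_lt_x_gpu(2)
#         def test_allreduce(self):
#             ...  # body runs only when at least 2 CUDA devices are visible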


# This decorator helps avoid initializing CUDA while testing other backends
def nccl_skip_if_lt_x_gpu(backend, x):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if backend != "nccl":
                return func(*args, **kwargs)
            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
                return func(*args, **kwargs)
            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)

        return wrapper

    return decorator


def verify_ddp_error_logged(model_DDP, err_substr):
    # Verify error was logged in ddp_logging_data.
    ddp_logging_data = model_DDP._get_ddp_logging_data()
    assert "iteration" in ddp_logging_data
    assert "has_error" in ddp_logging_data
    assert "error" in ddp_logging_data
    logging_err = ddp_logging_data["error"]
    # Remove C++ stacktrace if needed.
    actual = (
        err_substr
        if err_substr.find("\nException raised from ") == -1
        else err_substr.split("\nException raised from ")[0]
    )
    assert (
        actual in logging_err
    ), f"Did not find expected {actual} in ddp logging data error: {logging_err}"


def with_nccl_blocking_wait(func):
    """
    Convenience decorator to set/unset TORCH_NCCL_BLOCKING_WAIT flag. Note that use of
    this decorator will override the setting of TORCH_NCCL_ASYNC_ERROR_HANDLING for
    the particular test. After the test, both TORCH_NCCL_BLOCKING_WAIT and
    TORCH_NCCL_ASYNC_ERROR_HANDLING will be restored to their original values.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Save and unset TORCH_NCCL_ASYNC_ERROR_HANDLING
        try:
            cached_nccl_async_error_handling: Union[str, None] = os.environ[
                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
            ]
            del os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"]
        except KeyError:
            # TORCH_NCCL_ASYNC_ERROR_HANDLING was unset
            cached_nccl_async_error_handling = None

        # Save val of TORCH_NCCL_BLOCKING_WAIT and set it.
        try:
            cached_nccl_blocking_wait: Union[str, None] = os.environ[
                "TORCH_NCCL_BLOCKING_WAIT"
            ]
        except KeyError:
            cached_nccl_blocking_wait = None
        finally:
            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"

        try:
            ret = func(*args, **kwargs)
            return ret
        finally:
            # restore old values.
            if cached_nccl_async_error_handling is not None:
                os.environ[
                    "TORCH_NCCL_ASYNC_ERROR_HANDLING"
                ] = cached_nccl_async_error_handling

            if cached_nccl_blocking_wait is not None:
                os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait

    return wrapper


def with_dist_debug_levels(levels):
    """
    Runs a test for each distributed debug level specified in levels.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            old_level = os.environ.get("TORCH_DISTRIBUTED_DEBUG", None)
            for level in levels:
                os.environ["TORCH_DISTRIBUTED_DEBUG"] = level
                c10d.set_debug_level_from_env()
                ret = func(*args, **kwargs)
                c10d.barrier()
                if old_level is not None:
                    os.environ["TORCH_DISTRIBUTED_DEBUG"] = old_level
            # Only returns the test's return value for the last level, but since
            # these are unittests the return value is not really used and earlier
            # runs would've raised had they failed.
            return ret

        return wrapper

    return decorator
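
# Illustrative sketch (not executed here): running one test body under several
# TORCH_DISTRIBUTED_DEBUG levels. The test method name is hypothetical.
#
#     @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
#     def test_ddp_logging(self):
#         ...  # executed once per debug level; c10d.barrier() runs in between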


def requires_gloo():
    return skip_but_pass_in_sandcastle_if(
        not c10d.is_gloo_available(),
        "c10d was not compiled with the Gloo backend",
    )


def requires_nccl_version(version, msg):
    if not c10d.is_nccl_available():
        return skip_but_pass_in_sandcastle(
            "c10d was not compiled with the NCCL backend",
        )
    else:
        return skip_but_pass_in_sandcastle_if(
            torch.cuda.nccl.version() < version,
            f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
        )


def requires_nccl():
    return skip_but_pass_in_sandcastle_if(
        not c10d.is_nccl_available(),
        "c10d was not compiled with the NCCL backend",
    )

def requires_ucc():
    return skip_but_pass_in_sandcastle_if(
        not c10d.is_ucc_available(),
        "c10d was not compiled with the UCC backend",
    )

def requires_mpi():
    return skip_but_pass_in_sandcastle_if(
        not c10d.is_mpi_available(),
        "c10d was not compiled with the MPI backend",
    )
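
# Illustrative sketch (not executed here): the requires_* helpers return
# unittest-style skip decorators, so they are called when applied. The test
# method name is hypothetical.
#
#     @requires_nccl()
#     @skip_if_lt_x_gpu(2)
#     def test_nccl_collective(self):
#         ...  # skipped (or passed on sandcastle) when NCCL is unavailable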


def skip_if_rocm_multiprocess(func):
    """Skips a test for ROCm"""
    func.skip_if_rocm_multiprocess = True

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not TEST_WITH_ROCM:
            return func(*args, **kwargs)
        sys.exit(TEST_SKIPS["skipIfRocm"].exit_code)

    return wrapper


def skip_if_win32():
    return skip_but_pass_in_sandcastle_if(
        sys.platform == "win32",
        "This unit test case is not supported on Windows platform",
    )


@retry_on_connect_failures
def create_tcp_store(
    addr="localhost",
    world_size=1,
    is_master=True,
    timeout=timedelta(minutes=5),
    wait_for_workers=True,
    jit_class=False,
    use_libuv=True,
):
    """
    Creates a TCP store. Retries if the chosen port is already in use.
    """
    port = find_free_port()
    if jit_class:
        timeout_millisecond = int(timeout / timedelta(milliseconds=1))
        return torch.classes.dist_c10d.TCPStore(
            addr, port, world_size, is_master, timeout_millisecond
        )
    else:
        return c10d.TCPStore(
            addr, port, world_size, is_master, wait_for_workers=wait_for_workers, use_libuv=use_libuv
        )
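
# Illustrative sketch (not executed here): create_tcp_store() picks a free port
# and retries on connection failures, so tests can build a rendezvous without
# hard-coding ports. The process-group call below is a sketch, not part of this
# module.
#
#     store = create_tcp_store()
#     c10d.init_process_group(backend="gloo", store=store, rank=0, world_size=1)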


if TEST_WITH_TSAN:
    # TSAN runs much slower.
    TIMEOUT_DEFAULT = 500
else:
    TIMEOUT_DEFAULT = int(os.getenv('DISTRIBUTED_TESTS_DEFAULT_TIMEOUT', '300'))
TIMEOUT_OVERRIDE = {"test_ddp_uneven_inputs": 400}


# https://github.com/pytorch/pytorch/issues/75665
if TEST_WITH_ROCM:
    TIMEOUT_OVERRIDE["test_join_kwargs"] = 200


def create_device(interface=None):
    if sys.platform == "win32" or interface is None:
        return c10d.ProcessGroupGloo.create_device(hostname="127.0.0.1")
    else:
        return c10d.ProcessGroupGloo.create_device(interface=interface)


def get_timeout(test_id) -> int:
    return TIMEOUT_OVERRIDE.get(test_id.split(".")[-1], TIMEOUT_DEFAULT)


@contextmanager
def captured_output():
    new_out, new_err = StringIO(), StringIO()
    old_out, old_err = sys.stdout, sys.stderr
    try:
        sys.stdout, sys.stderr = new_out, new_err
        yield sys.stdout, sys.stderr
    finally:
        sys.stdout, sys.stderr = old_out, old_err
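
# Illustrative sketch (not executed here): capturing stdout/stderr emitted by
# the code under test.
#
#     with captured_output() as (out, err):
#         print("hello")
#     assert "hello" in out.getvalue()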


def simple_sparse_reduce_tests(rank: int, world_size: int, num_inputs: int = 1):
    """
    Generate a number of basic test cases for sparse reduction.
    These cover tensors with a varying number of sparse dimensions and a varying
    number of dense dimensions. The only reduction operation we support is sum.
    """

    def generate(rank: int, world_size: int, sparse_dims: int = 1, dense_dims: int = 0):
        # First sparse dimension is [0..rank].
        # Subsequent dimensions are always 0, so we know there is
        # a non-empty intersection between any two sparse tensors.
        indices = torch.reshape(torch.arange(rank + 1), (1, rank + 1))
        shape = [world_size] + [2 for _ in range(dense_dims)]
        for _ in range(sparse_dims - 1):
            indices = torch.cat((indices, torch.zeros(1, rank + 1)))
            shape.append(world_size)
        values = torch.ones([rank + 1] + [2 for _ in range(dense_dims)])
        return torch.sparse_coo_tensor(indices, values, shape)

    def compute_sum(fn, world_size: int):
        return reduce(
            operator.add, [fn(rank, world_size) for rank in range(world_size)]
        )

    return [
        (
            [
                fn(num_inputs * rank + i, num_inputs * world_size)
                for i in range(num_inputs)
            ],
            [compute_sum(fn, num_inputs * world_size) for i in range(num_inputs)],
        )
        for fn in [
            partial(generate, sparse_dims=1),
            partial(generate, sparse_dims=2),
            partial(generate, sparse_dims=3),
            partial(generate, dense_dims=1),
            partial(generate, dense_dims=2),
            partial(generate, dense_dims=3),
        ]
    ]
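
# Illustrative sketch (not executed here): each element returned by
# simple_sparse_reduce_tests() is an (inputs, expected_outputs) pair for one
# rank, where the expected outputs are the sum over all ranks' inputs.
#
#     for inputs, outputs in simple_sparse_reduce_tests(rank=0, world_size=2):
#         for tensor, expected in zip(inputs, outputs):
#             ...  # reduce `tensor` across ranks, then compare against `expected`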


# HELPER FOR MULTIGPU TESTS
def init_multigpu_helper(world_size: int, backend: str):
    """Multi-GPU tests are designed to simulate multiple nodes, each with
    multiple GPUs. The NCCL backend requires an equal number of GPUs per
    process. On a single node, all visible GPUs are evenly divided into
    subsets, and each process uses only its own subset.
    """
    nGPUs = torch.cuda.device_count()
    visible_devices = range(nGPUs)

    # If the rank is less than or equal to the number of available GPUs,
    # then each rank can be mapped to a corresponding GPU.
    nGPUs_per_process = 1
    if world_size > nGPUs:
        nGPUs_per_process = nGPUs // world_size
    rank_to_GPU = {
        i: list(visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process])
        for i in range(world_size)
    }
    return rank_to_GPU
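
# Illustrative sketch (not executed here): on a host with 8 visible GPUs and
# world_size=4, init_multigpu_helper maps each rank to its own device, e.g.
# rank_to_GPU == {0: [0], 1: [1], 2: [2], 3: [3]} (one GPU per process).
#
#     rank_to_GPU = init_multigpu_helper(world_size=4, backend="nccl")
#     device = rank_to_GPU[rank][0]  # `rank` comes from the running test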


tmp_dir: Optional[tempfile.TemporaryDirectory] = None


def initialize_temp_directories(init_method: Optional[str] = None) -> None:
    global tmp_dir
    tmp_dir = tempfile.TemporaryDirectory()
    os.environ["TEMP_DIR"] = tmp_dir.name
    os.mkdir(os.path.join(tmp_dir.name, "barrier"))
    os.mkdir(os.path.join(tmp_dir.name, "test_dir"))
    init_dir_path = os.path.join(tmp_dir.name, "init_dir")
    os.mkdir(init_dir_path)
    # Set init method if specified.
    if init_method is not None:
        os.environ["INIT_METHOD"] = init_method
    else:
        os.environ["INIT_METHOD"] = FILE_SCHEMA + os.path.join(
            init_dir_path, "shared_init_file"
        )


def cleanup_temp_dir() -> None:
    if tmp_dir is not None:
        tmp_dir.cleanup()


# Most tests operate with this world size
DEFAULT_WORLD_SIZE = 4

# [How does MultiProcessTestCase work?]
# Each MultiProcessTestCase instance uses 1 + `world_size()` processes; by
# default `world_size()` returns 4. Let's take `test_rpc_spawn.py` as an
# example, which inherits from this class. Its `setUp()` method calls into
# `MultiProcessTestCase._spawn_processes()` which spawns `world_size()`
# subprocesses. During the spawn, the main process passes the test name to
# subprocesses, and the name is acquired from self.id(). The subprocesses
# then use the provided test function name to retrieve the function attribute
# from the test instance and run it. The main process simply waits for all
# subprocesses to join.
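
# Illustrative sketch (not executed here) of a subclass following the flow
# described above; the class and test names are hypothetical.
#
#     class MyMultiProcessTest(MultiProcessTestCase):
#         @property
#         def world_size(self):
#             return 2
#
#         def setUp(self):
#             super().setUp()
#             self._spawn_processes()
#
#         def test_something(self):
#             ...  # runs inside each spawned subprocess with self.rank set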


class MultiProcessTestCase(TestCase):
    MAIN_PROCESS_RANK = -1
    # This exit code is used to indicate that the test code had an error and
    # exited abnormally. There are certain tests that might use sys.exit() to
    # simulate failures and in those cases, we can't have an exit code of 0,
    # but we still want to ensure we didn't run into any other errors.
    TEST_ERROR_EXIT_CODE = 10

    # do not early terminate for distributed tests.
    def _should_stop_test_suite(self) -> bool:
        return False

    @property
    def world_size(self) -> int:
        return DEFAULT_WORLD_SIZE

    def join_or_run(self, fn):
        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_PROCESS_RANK:
                self._join_processes(fn)
            else:
                fn()

        return types.MethodType(wrapper, self)

    # The main process spawns N subprocesses that run the test.
    # Constructor patches current instance test method to
    # assume the role of the main process and join its subprocesses,
    # or run the underlying test function.
    def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> None:
        # methodName is the correct naming in unittest and testslide uses keyword arguments.
        # So we need to use both to 1) not break BC and 2) support testslide.
        if methodName != "runTest":
            method_name = methodName
        super().__init__(method_name)
        fn = getattr(self, method_name)
        setattr(self, method_name, self.join_or_run(fn))

    def setUp(self) -> None:
        super().setUp()
        self.skip_return_code_checks = []  # type: ignore[var-annotated]
        self.processes = []  # type: ignore[var-annotated]
        self.rank = self.MAIN_PROCESS_RANK
        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
        # pid to pipe consisting of error message from process.
        self.pid_to_pipe = {}  # type: ignore[var-annotated]

    def tearDown(self) -> None:
        super().tearDown()
        for p in self.processes:
            p.terminate()
        # Each Process instance holds a few open file descriptors. The unittest
        # runner creates a new TestCase instance for each test method and keeps
        # it alive until the end of the entire suite. We must thus reset the
        # processes to prevent an effective file descriptor leak.
        self.processes = []

    def _current_test_name(self) -> str:
        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
        return self.id().split(".")[-1]

    def _start_processes(self, proc) -> None:
        self.processes = []
        for rank in range(int(self.world_size)):
            parent_conn, child_conn = torch.multiprocessing.Pipe()
            process = proc(
                target=self.__class__._run,
                name="process " + str(rank),
                args=(rank, self._current_test_name(), self.file_name, child_conn),
                kwargs={
                    "fake_pg": getattr(self, "fake_pg", False),
                }
            )
            process.start()
            logger.info("Started process %s with pid %s", rank, process.pid)
            self.pid_to_pipe[process.pid] = parent_conn
            self.processes.append(process)

    def _spawn_processes(self) -> None:
        proc = torch.multiprocessing.get_context("spawn").Process
        self._start_processes(proc)

    class Event(Enum):
        GET_TRACEBACK = 1

    @staticmethod
    def _event_listener(parent_pipe, signal_pipe, rank: int):
        logger.info("Starting event listener thread for rank %s", rank)
        while True:
            ready_pipes = multiprocessing.connection.wait([parent_pipe, signal_pipe])

            if parent_pipe in ready_pipes:

                if parent_pipe.closed:
                    logger.info(
                        "Pipe closed for process %s, stopping event listener thread", rank
                    )
                    return

                event = parent_pipe.recv()
                logger.info("Received event %s on process %s", event, rank)

                if event == MultiProcessTestCase.Event.GET_TRACEBACK:
                    # Return traceback to the parent process.
                    with tempfile.NamedTemporaryFile(mode="r+") as tmp_file:
                        faulthandler.dump_traceback(tmp_file)
                        # Flush buffers and seek to read from the beginning
                        tmp_file.flush()
                        tmp_file.seek(0)
                        parent_pipe.send(tmp_file.read())

                        logger.info("Process %s sent traceback", rank)

            if signal_pipe in ready_pipes:
                return

    @classmethod
    def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs) -> None:
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name
        self.run_test(test_name, parent_pipe)

    def run_test(self, test_name: str, parent_pipe) -> None:
        # Start event listener thread.
        signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe(duplex=False)
        event_listener_thread = threading.Thread(
            target=MultiProcessTestCase._event_listener,
            args=(parent_pipe, signal_recv_pipe, self.rank),
            daemon=True,
        )
        event_listener_thread.start()
        if sys.platform != "win32" and sys.platform != "darwin":
            # Register signal handler to dump stack traces on FATALs.
            # Windows and MacOS do not support the signal handlers.
            torch._C._set_print_stack_traces_on_fatal_signal(True)
        # Show full C++ stacktraces when a Python error originating from C++ is raised.
        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"

        # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
        # We're retrieving a corresponding test and executing it.
        try:
            getattr(self, test_name)()
        except unittest.SkipTest as se:
            logger.info(
                "Process %s skipping test %s for following reason: %s", self.rank, test_name, str(se)
            )
            sys.exit(TEST_SKIPS["generic"].exit_code)
        except Exception as e:
            logger.error(
                "Caught exception: \n%s exiting "
                "process %s with exit code: %s",
                traceback.format_exc(), self.rank, MultiProcessTestCase.TEST_ERROR_EXIT_CODE
            )
            # Send error to parent process.
            parent_pipe.send(traceback.format_exc())
            sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
        finally:
            if signal_send_pipe is not None:
                signal_send_pipe.send(None)

            assert event_listener_thread is not None
            event_listener_thread.join()
            # Close pipe after done with test.
            parent_pipe.close()

    def _get_timedout_process_traceback(self) -> None:
        pipes = []
        for i, process in enumerate(self.processes):
            if process.exitcode is None:
                pipe = self.pid_to_pipe[process.pid]
                try:
                    pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK)
                    pipes.append((i, pipe))
                except ConnectionError as e:
                    logger.error(
                        "Encountered error while trying to get traceback for process %s: %s", i, e
                    )

        # Wait for results.
        for rank, pipe in pipes:
            try:
                # Wait for traceback
                if pipe.poll(5):
                    if pipe.closed:
                        logger.info(
                            "Pipe closed for process %s, cannot retrieve traceback", rank
                        )
                        continue

                    traceback = pipe.recv()
                    logger.error(
                        "Process %s timed out with traceback: \n\n%s", rank, traceback
                    )
                else:
                    logger.error(
                        "Could not retrieve traceback for timed out process: %s", rank
                    )
            except ConnectionError as e:
                logger.error(
                    "Encountered error while trying to get traceback for process %s: %s", rank, e
                )

    def _join_processes(self, fn) -> None:
        timeout = get_timeout(self.id())
        start_time = time.time()
        subprocess_error = False
        try:
            while True:
                # check to see if any subprocess exited with an error early.
                for (i, p) in enumerate(self.processes):
                    # This is the exit code processes exit with if they
                    # encountered an exception.
                    if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE:
                        print(
                            f"Process {i} terminated with exit code {p.exitcode}, terminating remaining processes."
                        )
                        active_children = torch.multiprocessing.active_children()
                        for ac in active_children:
                            ac.terminate()
                        subprocess_error = True
                        break
                if subprocess_error:
                    break
                # All processes have joined cleanly if they all have a valid exitcode
                if all(p.exitcode is not None for p in self.processes):
                    break
                # Check if we should time out the test. If so, we terminate each process.
                elapsed = time.time() - start_time
                if elapsed > timeout:
                    self._get_timedout_process_traceback()
                    print(
                        f"Timing out after {timeout} seconds and killing subprocesses."
                    )
                    for p in self.processes:
                        p.terminate()
                    break
                # Sleep to avoid excessive busy polling.
                time.sleep(0.1)

            elapsed_time = time.time() - start_time

            if fn in self.skip_return_code_checks:
                self._check_no_test_errors(elapsed_time)
            else:
                self._check_return_codes(elapsed_time)
        finally:
            # Close all pipes
            for pipe in self.pid_to_pipe.values():
                pipe.close()

    def _check_no_test_errors(self, elapsed_time) -> None:
        """
        Checks that we didn't have any errors thrown in the child processes.
        """
        for i, p in enumerate(self.processes):
            if p.exitcode is None:
                raise RuntimeError(
                    f"Process {i} timed out after {elapsed_time} seconds"
                )
            self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)

    def _check_return_codes(self, elapsed_time) -> None:
        """
        Checks that the return codes of all spawned processes match, and skips
        tests if they returned a return code indicating a skipping condition.
        """
        # If no processes are spawned, there is nothing to check.
        if not self.processes:
            logger.warning("Note: no subprocesses were spawned, test was likely skipped.")
            return

        first_process = self.processes[0]
        # first, we check if there are errors in actual processes
        # (via TEST_ERROR_EXIT_CODE), and raise an exception for those.
        # the reason we do this is to attempt to raise a more helpful error
        # message than "Process x terminated/timed out"
        # TODO: we should pipe the exception of the failed subprocess here.
        # Currently, the actual exception is displayed as a logging output.
        errored_processes = [
            (i, p)
            for i, p in enumerate(self.processes)
            if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE
        ]
        if errored_processes:
            error = ""
            for i, process in errored_processes:
                # Get error from pipe.
                error_message = self.pid_to_pipe[process.pid].recv()
                error += (
                    f"Process {i} exited with error code {MultiProcessTestCase.TEST_ERROR_EXIT_CODE} "
                    f"and exception:\n{error_message}\n"
                )

            raise RuntimeError(error)
        # If no process exited uncleanly, we check for timeouts, and then ensure
        # each process exited cleanly.
        for i, p in enumerate(self.processes):
            if p.exitcode is None:
                raise RuntimeError(
                    f"Process {i} terminated or timed out after {elapsed_time} seconds"
                )
            self.assertEqual(
                p.exitcode,
                first_process.exitcode,
                msg=f"Expect process {i} exit code to match Process 0 exit code of {first_process.exitcode}, but got {p.exitcode}",
            )
        for skip in TEST_SKIPS.values():
            if first_process.exitcode == skip.exit_code:
                if IS_SANDCASTLE:
                    # Don't use unittest.skip to skip the test on sandcastle
                    # since it creates tasks for skipped tests assuming there
                    # is some follow-up needed. Instead just "pass" the test
                    # with an appropriate message.
                    logger.info(
                        "Skipping %s on sandcastle for the following reason: %s", self.id(), skip.message
                    )
                    return
                else:
                    raise unittest.SkipTest(skip.message)
        self.assertEqual(
            first_process.exitcode,
            0,
            msg=f"Expected zero exit code but got {first_process.exitcode} for pid: {first_process.pid}",
        )

    @property
    def is_master(self) -> bool:
        return self.rank == 0


def run_subtests(
    cls_inst,
    subtest_config: Dict[str, List[Any]],
    test_fn: Callable,
    *test_args,
    **test_kwargs: Any,
):
    """
    Runs a test function given by ``test_fn`` as a subtest according to the
    configurations specified by ``subtest_config``. This amortizes the
    costly setup overhead (including process spawn and initializing the
    process group) over the subtests.

    Args:
        subtest_config (Dict[str, List[Any]]): A mapping from subtest
            keyword argument name to a list of its possible values.
        test_fn (Callable): A callable that runs the actual test.
        test_args: Positional arguments to pass to ``test_fn``.
        test_kwargs: Keyword arguments to pass to ``test_fn``.
    """
    # Convert the config mapping to a list to have a fixed order
    subtest_config_items: List[Tuple[str, List[Any]]] = list(subtest_config.items())
    subtest_config_keys: List[str] = [item[0] for item in subtest_config_items]
    subtest_config_values: List[List[Any]] = [item[1] for item in subtest_config_items]
    for values in itertools.product(*subtest_config_values):
        # Map keyword to chosen value
        subtest_kwargs = dict(zip(subtest_config_keys, values))
        with cls_inst.subTest(**subtest_kwargs):
            torch._dynamo.reset()
            test_fn(*test_args, **test_kwargs, **subtest_kwargs)
            torch._dynamo.reset()
        c10d.barrier()
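
# Illustrative sketch (not executed here): iterating a test body over the cross
# product of subtest configurations. The test method, keyword names, and helper
# are hypothetical.
#
#     def test_variants(self):
#         run_subtests(
#             self,
#             {"use_orig_params": [False, True], "sharding": ["full", "grad"]},
#             self._test_variants_impl,
#         )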


# Cannot use functools.cache as it requires python 3.9
EFA_PROBE_RESULT = None


def has_efa() -> bool:
    """
    If shell command `fi_info -p efa -t FI_EP_RDM` returns exit code 0 then we assume that the machine has
    Libfabric EFA interfaces and EFA software components installed,
    see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/efa-start.html.
    """
    global EFA_PROBE_RESULT
    if EFA_PROBE_RESULT is not None:
        return EFA_PROBE_RESULT

    try:
        EFA_PROBE_RESULT = (
            subprocess.run(["fi_info", "-p", "efa", "-t", "FI_EP_RDM"], check=False).returncode == 0
        )
    except FileNotFoundError:
        EFA_PROBE_RESULT = False
    return EFA_PROBE_RESULT


def tp_transports():
    """
    If the machine has Libfabric EFA interfaces and EFA software components installed it may cause
    'RuntimeError: In operator() at tensorpipe/common/ibv.h:172 "": Operation not supported' if tensorpipe
    uses InfiniBand transport, so we exclude it from tensorpipe transports,
    see https://github.com/pytorch/pytorch/issues/73885 and https://github.com/pytorch/pytorch/issues/65022
    """
    return ["shm", "uv"] if has_efa() else None
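
# Illustrative sketch (not executed here, and assuming the RPC module is
# available): tests typically pass the result of tp_transports() into the
# TensorPipe RPC backend options so EFA machines avoid the unsupported
# InfiniBand transport.
#
#     rpc_backend_options = torch.distributed.rpc.TensorPipeRpcBackendOptions(
#         _transports=tp_transports(),
#     )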


def spawn_threads_and_init_comms(
    func=None, timeout=TIMEOUT_DEFAULT, world_size=DEFAULT_WORLD_SIZE
):
    """
    Wrapper to use with a test method
    """
    if func is None:
        return partial(
            spawn_threads_and_init_comms, timeout=timeout, world_size=world_size
        )


    def _run_test_method_with_multi_threads(world_size, callback):
        world = _install_threaded_pg()
        global_store = c10d.HashStore()

        def world_is_valid():
            return world == c10d.distributed_c10d._world

        def worker(rank, world_pg, store):
            c10d.init_process_group(
                backend="threaded", rank=rank, world_size=world_size, store=store
            )
            try:
                callback()
            except BaseException as ex:
                # Exceptions are handled in MultiThreadedTestCase
                MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info()))
                ProcessLocalGroup.exception_handle(ex)  # trigger _terminate event and awaken worker threads
            finally:
                if world_is_valid():
                    c10d.destroy_process_group()

        threads = []
        for rank in range(world_size):
            t = threading.Thread(target=worker, args=(rank, world, global_store))
            t.start()
            threads.append(t)

        return threads


    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # TODO: get test name from kwargs
        torch._C._distributed_c10d._set_thread_isolation_mode(True)
        try:
            threads = _run_test_method_with_multi_threads(world_size, lambda: func(self, *args, **kwargs))
            # join and error handling
            MultiThreadedTestCase._join_threads(threads, func)
        finally:
            torch._C._distributed_c10d._set_thread_isolation_mode(False)

    return wrapper
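
# Illustrative sketch (not executed here): using the wrapper on a plain TestCase
# method so the body runs once per rank on its own thread, with a threaded
# process group already initialized. The class and test names are hypothetical.
#
#     class MyThreadedTest(TestCase):
#         @spawn_threads_and_init_comms(world_size=4)
#         def test_allreduce(self):
#             t = torch.ones(2) * c10d.get_rank()
#             c10d.all_reduce(t)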


class MultiThreadedTestCase(TestCase):
    """
    Test runner that runs all tests in-process using multiple threads with the
    threaded process group.

    Each test spawns world_size threads and runs the test method in each thread.

    Differences from the regular MultiProcess test runner:
    Must explicitly define setUp and call self._spawn_threads() to run the tests.
    Cannot use setUp / tearDown (must use perThreadSetUp / perThreadTearDown)
        to set up / tear down each thread when running each test.
    No global state possible
        How bad of a limitation is this?
    """
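
    # Illustrative sketch (not executed here) of a subclass following the usage
    # described above; the class and test names are hypothetical.
    #
    #     class MyThreadedTest(MultiThreadedTestCase):
    #         @property
    #         def world_size(self):
    #             return 4
    #
    #         def setUp(self):
    #             super().setUp()
    #             self._spawn_threads()
    #
    #         def test_broadcast(self):
    #             ...  # runs in each thread with self.rank set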
    exception_queue = queue.Queue()

    MAIN_THREAD_RANK = -1

    def join_or_run(self, fn):
        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_THREAD_RANK:
                self._join_threads(self.threads, fn)
            else:
                fn()

        return types.MethodType(wrapper, self)

    def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> None:
        # methodName is the correct naming in unittest and testslide uses keyword arguments.
        # So we need to use both to 1) not break BC and 2) support testslide.
        if methodName != "runTest":
            method_name = methodName
        super().__init__(method_name)
        fn = getattr(self, method_name)
        setattr(self, method_name, self.join_or_run(fn))

    def perThreadSetUp(self):
        # super().setUp()  # TestCase.setUp() calls torch.manual_seed()
        pass

    def perThreadTearDown(self):
        pass

    def setUp(self) -> None:
        """
        setUp only sets up things in the main thread; if you want to configure things
        in the spawned threads, use perThreadSetUp
        """
        super().setUp()
        self.rank = self.MAIN_THREAD_RANK
        self.threads = []
        # Show full C++ stacktraces when a Python error originating from C++ is raised.
        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"

    def tearDown(self):
        """
        tearDown only tears down things in the main thread; if you want to configure things
        in the spawned threads, use perThreadTearDown
        """
        super().tearDown()
        self.threads = []

    def _spawn_threads(self):
        """
        Spawn threads and run the test; call this method in the setUp of your TestCase.
        """
        torch._C._distributed_c10d._set_thread_isolation_mode(True)
        test_name = self._current_test_name
        # for each test case, we need to create thread local world, and a global store
        world = _install_threaded_pg()
        self.__class__.global_store = c10d.HashStore()

        def world_is_valid():
            return world == c10d.distributed_c10d._world

        if not world_is_valid():
            raise RuntimeError("Invalid world")

        for rank in range(self.world_size):
            t = threading.Thread(target=self.__class__._run, args=(test_name, rank, self.world_size))
            t.start()
            self.threads.append(t)

    @classmethod
    def _run(cls, test_name, rank, world_size, **kwargs):
        self = cls(test_name)
        self.rank = rank

        # precision/rel_tol is a thread-local setting since it may be overridden per test; we need
        # to make every thread have the same value. This is relevant when we use op db tests, where
        # those states need to be set, i.e. when using instantiate_device_type_tests().
        # TODO: figure out a better way to do this
        if hasattr(self, "_tls"):
            self._tls = threading.local()
            self._tls.precision = TestCase._precision
            self._tls.rel_tol = TestCase._rel_tol

        self.run_test_with_threaded_pg(test_name, rank, world_size)

    def run_test_with_threaded_pg(self, test_name, rank, world_size):
        """
        Run the current test associated with `test_name` using the threaded process group.
        """
        c10d.init_process_group(
            backend="threaded", rank=rank, world_size=world_size, store=self.__class__.global_store
        )
        self.perThreadSetUp()

        try:
            getattr(self, test_name)()
        except BaseException as ex:
            self.exception_queue.put((rank, sys.exc_info()))
            ProcessLocalGroup.exception_handle(ex)  # trigger _terminate event and awaken worker threads
        finally:
            c10d.destroy_process_group()
            self.perThreadTearDown()


    @classmethod
    def _join_threads(cls, threads, fn):
        timeout = TIMEOUT_DEFAULT
        try:
            for idx, thread in enumerate(threads):
                thread.join(max(0, timeout))
                if thread.is_alive():
                    MultiThreadedTestCase.exception_queue.put(
                        (
                            idx,
                            (
                                TimeoutError,
                                TimeoutError(
                                    f"Rank failed to join in under {timeout} seconds"
                                ),
                                None,
                            ),
                        )
                    )
            ProcessLocalGroup.reset()
            failed_ranks = []
            while not cls.exception_queue.empty():
                failure = cls.exception_queue.get()
                failed_ranks.append(failure)
        finally:
            _uninstall_threaded_pg()
            torch._C._distributed_c10d._set_thread_isolation_mode(False)

        cls._check_return_codes(failed_ranks, timeout, fn)
    @classmethod
    def _check_return_codes(cls, failed_ranks, timeout, fn):
        # Print based on exceptions raised from threads
        #   SkipTest: print info for each thread
        #   TimeoutError: raise RuntimeError for any timed out thread
        #   Normal Exception: print error for each thread that raises exception
        #   and raise a RuntimeError
        error_msg = ""
        skip_code = -1
        for rank, exc_info in failed_ranks:
            exc = exc_info[1]
            if isinstance(exc, unittest.SkipTest):
                logger.info(
                    "Thread %s skipping test %s for following reason: %s", rank, fn, str(exc)
                )
                if skip_code < 0:
                    skip_code = TEST_SKIPS["generic"].exit_code
            elif isinstance(exc, TimeoutError):
                msg = f"Thread {rank} terminated or timed out after {timeout} seconds\n"
                logger.error(msg)
                raise RuntimeError(msg)
            elif isinstance(exc, Exception):
                msg = "".join(traceback.format_exception(*exc_info))
                logger.error(
                    "Caught exception: \n%s exiting thread %s", msg, rank
                )
                error_msg += (
                    f"Thread {rank} exited with exception:\n{msg}\n"
                )
            elif isinstance(exc, SystemExit):
                if type(exc.code) == int and skip_code < 0:
                    skip_code = exc.code

        # check exceptions
        if len(error_msg) > 0:
            raise RuntimeError(error_msg)
        # check skip
        if skip_code > 0:
            for skip in TEST_SKIPS.values():
                if skip_code == skip.exit_code:
                    if IS_SANDCASTLE:
                        # "pass" the test with an appropriate message.
                        logger.info(
                            "Skipping %s on sandcastle for the following reason: %s", fn, skip.message
                        )
                        return
                    else:
                        raise unittest.SkipTest(skip.message)

    @property
    def world_size(self) -> int:
        return DEFAULT_WORLD_SIZE

    @property
    def _current_test_name(self) -> str:
        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
        return self.id().split(".")[-1]

    def assertEqualOnRank(self, x, y, msg=None, *, rank=0):
        """
        The reason why we have this util function instead of
        self.assertEqual is that all threads share one CPU RNG,
        so the assertion result is only reliable on rank 0.
        """
        if self.rank == rank:
            self.assertEqual(x, y, msg)

    def assertNotEqualOnRank(self, x, y, msg=None, *, rank=0):
        if self.rank == rank:
            self.assertNotEqual(x, y)


class SaveForwardInputsModule(nn.Module):
    def __init__(
        self,
        forward_inputs: Dict[nn.Module, torch.Tensor],
        cast_forward_inputs: bool,
    ) -> None:
        super().__init__()
        self.l = nn.Linear(100, 100)
        self.forward_inputs = forward_inputs
        self.cast_forward_inputs = cast_forward_inputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.forward_inputs[self] = x
        return self.l(x.to(self.l.weight.dtype) if self.cast_forward_inputs else x)


class SaveForwardInputsModel(nn.Module):
    def __init__(
        self,
        forward_inputs: Dict[nn.Module, torch.Tensor],
        cast_forward_inputs: bool,
    ) -> None:
        super().__init__()
        self.c1 = SaveForwardInputsModule(forward_inputs, cast_forward_inputs)
        self.c2 = SaveForwardInputsModule(forward_inputs, cast_forward_inputs)
        self.forward_inputs = forward_inputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.forward_inputs[self] = x
        return self.c2(self.c1(x))

@contextmanager
def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
    # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
    # just manually implement the most important part of the dynamo behavior to reset/clear.
    if not fake_pg:
        torch.cuda.set_device(rank)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '6789'
    if init_pg:
        if fake_pg:
            store = torch.testing._internal.distributed.fake_pg.FakeStore()
            c10d.init_process_group(
                backend="fake",
                world_size=world_size,
                rank=rank,
                store=store,
            )
        else:
            c10d.init_process_group("nccl", rank=rank, world_size=world_size)
    torch._dynamo.reset()
    torch._dynamo.utils.counters.clear()
    try:
        yield
    finally:
        torch._dynamo.reset()
        torch._dynamo.utils.counters.clear()
        if init_pg:
            c10d.destroy_process_group()
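
# Illustrative sketch (not executed here): per-rank setup inside a spawned test
# process using the context manager above. `rank`/`world_size` come from the
# running test; `MyModel` and the body are hypothetical.
#
#     with _dynamo_dist_per_rank_init(rank, world_size):
#         model = torch.nn.parallel.DistributedDataParallel(MyModel().cuda(rank))
#         compiled = torch.compile(model)
#         ...  # exercise the compiled model; dynamo state is reset on exit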


class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
    """
    Test harness for single-process dynamo distributed tests,
    initializes dist process group.

    Prefer this for simple tests, as it's easier to debug.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # _exit_stack is set up in TestCase
        cls._exit_stack.enter_context(
            patch.dict(
                os.environ,
                {
                    "MASTER_ADDR": "localhost",
                    "MASTER_PORT": "12355",
                },
            )
        )
        cls.rank = 0
        cls.device = f"cuda:{cls.rank}"
        cls.device_ids = None if "cuda" in cls.device else [cls.rank]
        c10d.init_process_group("nccl", rank=cls.rank, world_size=1)

    @classmethod
    def tearDownClass(cls):
        c10d.destroy_process_group()
        super().tearDownClass()


class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
    """
    Use this for tests that actually run on multiple GPUs.

    Decorate tests with @skip_if_lt_x_gpu(ngpu)

    Note: MultiProcTestCase spawns processes per test and is slow.
    Prefer MultiThreadedTestCase for most tests. Perhaps use this one
    sparingly for integration tests.
    """
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self) -> int:
        return torch.cuda.device_count()

    @classmethod
    def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs) -> None:
        # The rest is copypasta from MultiProcessTestCase._run
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name
        self.run_test(test_name, parent_pipe)

class MultiProcContinousTest(TestCase):
    # Class variables:
    # number of test processes
    world_size: int = 2
    # rank of the current process
    rank: int = -1  # unset state
    # Rendezvous file
    rdvz_file: Optional[str] = None

    @classmethod
    @abc.abstractmethod
    def backend_str(cls) -> str:
        """
        ProcessGroup backend str.
        To be customized by sub test classes, e.g. "nccl".
        Here we raise an error.
        """
        raise NotImplementedError("Please implement backend_str in your test class")

    @classmethod
    def opts(cls, high_priority_stream=False):
        """
        ProcessGroup init options.
        To be customized by sub test classes, e.g. ProcessGroupNCCLOpTest.
        Here we return None.
        """
        return None

    @classmethod
    def setUpClass(cls):
        """
        Class-scope test fixture. Run once for entire test class, before any test starts.
        Set up the process group.
        """
        super().setUpClass()
        if not 0 <= cls.rank < cls.world_size:
            raise RuntimeError(
                "Rank must be set and in the range of 0 to world_size. "
                f"World size: {cls.world_size} Rank: {cls.rank}"
            )
        if cls.rdvz_file:
            store = c10d.FileStore(cls.rdvz_file, cls.world_size)
        else:
            # torchrun takes care of rendezvous
            store = None
        opts = cls.opts()
        backend = cls.backend_str()
        print(f"Testing {backend=}")
        # create the process group with opts
        c10d.init_process_group(
            backend=backend,
            world_size=cls.world_size,
            rank=cls.rank,
            store=store,
            pg_options=opts,
        )
        cls.pg = c10d.distributed_c10d._get_default_group()
        print(f"Rank {cls.rank} setup complete")

    @classmethod
    def tearDownClass(cls):
        """
        Class-scope test fixture. Run once for entire test class, after all tests finish.
        Tear down the process group.
        """
        c10d.destroy_process_group()
        super().tearDownClass()
        # Clean up the rendezvous file
        if cls.rdvz_file:
            try:
                os.remove(cls.rdvz_file)
            except OSError:
                pass
        print(f"Rank {cls.rank} teardown complete")

    @classmethod
    def run_rank(
        cls,
        rank: int,
        world_size: int,
        rdvz_file: Optional[str] = None,
    ):
        """
        This is an entry point for each rank to run the tests in `MultiProcContinousTest`.
        In this entry point, we set the class variables for the test class.
        Then we run all tests.

        Note:
        - This helper only works for a subclass of `MultiProcContinousTest`.

        Example:
        - See `test_c10d_ops_nccl.py`.
        """
        # set class variables for the test class
        cls.rank = rank
        cls.world_size = world_size
        cls.rdvz_file = rdvz_file
        # Launch tests via `common_utils` infra
        run_tests()
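
# Illustrative sketch (not executed here): how a test file built on
# MultiProcContinousTest might drive `run_rank` when launched with torchrun,
# letting every rank execute the whole test class against one long-lived
# process group. `MyOpsTest` and the module layout are hypothetical; see
# `test_c10d_ops_nccl.py` for the real pattern.
#
#     if __name__ == "__main__":
#         rank = int(os.environ["RANK"])
#         world_size = int(os.environ["WORLD_SIZE"])
#         MyOpsTest.run_rank(rank, world_size)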