# mypy: ignore-errors

import abc
import faulthandler
import itertools
import logging
import multiprocessing
import os
import queue
import subprocess
import sys
import tempfile
import threading
import time
import traceback
import types
import unittest
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from functools import partial, reduce, wraps
from io import StringIO
from typing import Dict, NamedTuple, Optional, Union, List, Any, Callable, Tuple
from unittest.mock import patch

import torch
import torch._dynamo.test_case
import torch.cuda.nccl
import torch.distributed as c10d
import torch.nn as nn
from torch.testing._internal.common_utils import (
    FILE_SCHEMA,
    find_free_port,
    IS_SANDCASTLE,
    retry_on_connect_failures,
    skip_but_pass_in_sandcastle,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_ROCM,
    TEST_WITH_TSAN,
    TestCase,
    run_tests,
)
from torch.testing._internal.distributed.multi_threaded_pg import (
    _install_threaded_pg,
    _uninstall_threaded_pg,
    ProcessLocalGroup,
)
import operator

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TestSkip(NamedTuple):
    exit_code: int
    message: str


TEST_SKIPS = {
    "backend_unavailable": TestSkip(
        72, "Skipped because distributed backend is not available."
    ),
    "small_worldsize": TestSkip(73, "Skipped due to small world size."),
    "odd_worldsize": TestSkip(87, "Skipped due to odd world size."),
    "no_cuda": TestSkip(74, "CUDA is not available."),
    "multi-gpu-1": TestSkip(75, "Need at least 1 CUDA device"),
    "multi-gpu-2": TestSkip(77, "Need at least 2 CUDA devices"),
    "multi-gpu-3": TestSkip(80, "Need at least 3 CUDA devices"),
    "multi-gpu-4": TestSkip(81, "Need at least 4 CUDA devices"),
    "multi-gpu-5": TestSkip(82, "Need at least 5 CUDA devices"),
    "multi-gpu-6": TestSkip(83, "Need at least 6 CUDA devices"),
    "multi-gpu-7": TestSkip(84, "Need at least 7 CUDA devices"),
    "multi-gpu-8": TestSkip(85, "Need at least 8 CUDA devices"),
    "nccl": TestSkip(76, "c10d not compiled with NCCL support"),
    "skipIfRocm": TestSkip(78, "Test skipped for ROCm"),
    "no_peer_access": TestSkip(79, "Test skipped because no GPU peer access"),
    "generic": TestSkip(
        86, "Test skipped at subprocess level, look at subprocess log for skip reason"
    ),
    "importerror": TestSkip(88, "Test skipped due to missing import"),
}


@dataclass
class DistTestCases:
    # Backends that do not support a specific collective
    skip_collective = {}
    skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc"}
    skip_collective["reduce"] = set()
    skip_collective["sendrecv anysource"] = {"nccl", "ucc"}
    skip_collective["cpu barrier"] = {"nccl", "ucc"}

    # Sets showing that something is implemented
    backend_feature = {}
    backend_feature["gpu"] = {"nccl", "gloo", "ucc"}
    backend_feature["cuda"] = {"nccl", "gloo", "ucc"}
    backend_feature["ddp"] = {"nccl", "gloo", "ucc"}
    backend_feature["subgroup"] = {"nccl", "gloo", "ucc"}
    backend_feature["plugin"] = set()


def skip_if_no_gpu(func):
    """Skips if the world size exceeds the number of GPUs, ensuring that if the
    test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not torch.cuda.is_available():
            sys.exit(TEST_SKIPS["no_cuda"].exit_code)
        world_size = int(os.environ["WORLD_SIZE"])
        if torch.cuda.device_count() < world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)

        return func(*args, **kwargs)

    return wrapper


def skip_if_small_worldsize(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if (os.environ["BACKEND"] != "mpi") and int(os.environ["WORLD_SIZE"]) <= 2:
            sys.exit(TEST_SKIPS["small_worldsize"].exit_code)

        return func(*args, **kwargs)

    return wrapper


def skip_if_odd_worldsize(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if (os.environ["BACKEND"] != "mpi") and int(os.environ["WORLD_SIZE"]) % 2 == 1:
            sys.exit(TEST_SKIPS["odd_worldsize"].exit_code)

        return func(*args, **kwargs)

    return wrapper


def require_n_gpus_for_nccl_backend(n, backend):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if backend == "nccl" and torch.cuda.device_count() < n:
                sys.exit(TEST_SKIPS[f"multi-gpu-{n}"].exit_code)
            else:
                return func(*args, **kwargs)

        return wrapper

    return decorator


def import_transformers_or_skip():
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                from transformers import (  # noqa: F401
                    AutoModelForMaskedLM,
                    BertConfig,
                )

                return func(*args, **kwargs)
            except ImportError:
                sys.exit(TEST_SKIPS["importerror"].exit_code)

        return wrapper

    return decorator


def at_least_x_gpu(x):
    return torch.cuda.is_available() and torch.cuda.device_count() >= x


def skip_if_lt_x_gpu(x):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
                return func(*args, **kwargs)
            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)

        return wrapper

    return decorator


# This decorator helps avoid initializing CUDA while testing other backends.
def nccl_skip_if_lt_x_gpu(backend, x):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if backend != "nccl":
                return func(*args, **kwargs)
            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
                return func(*args, **kwargs)
            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)

        return wrapper

    return decorator
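

# A minimal usage sketch for these skip helpers (the test class and method are
# hypothetical, not part of this module): the decorators are stacked on a test
# method and call sys.exit() with the matching TEST_SKIPS exit code when a
# requirement is not met, which the main process later translates into a skip.
#
#   class MyDistributedTest(MultiProcessTestCase):  # defined later in this file
#       @skip_if_lt_x_gpu(2)
#       def test_allreduce_cuda(self):
#           ...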


def verify_ddp_error_logged(model_DDP, err_substr):
    # Verify error was logged in ddp_logging_data.
    ddp_logging_data = model_DDP._get_ddp_logging_data()
    assert "iteration" in ddp_logging_data
    assert "has_error" in ddp_logging_data
    assert "error" in ddp_logging_data
    logging_err = ddp_logging_data["error"]
    # Remove C++ stacktrace if needed.
    actual = (
        err_substr
        if err_substr.find("\nException raised from ") == -1
        else err_substr.split("\nException raised from ")[0]
    )
    assert (
        actual in logging_err
    ), f"Did not find expected {actual} in ddp logging data error: {logging_err}"


def with_nccl_blocking_wait(func):
    """
    Convenience decorator to set/unset TORCH_NCCL_BLOCKING_WAIT flag. Note that use of
    this decorator will override the setting of TORCH_NCCL_ASYNC_ERROR_HANDLING for
    the particular test. After the test, both TORCH_NCCL_BLOCKING_WAIT and
    TORCH_NCCL_ASYNC_ERROR_HANDLING will be restored to their original values.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Save and unset TORCH_NCCL_ASYNC_ERROR_HANDLING
        try:
            cached_nccl_async_error_handling: Union[str, None] = os.environ[
                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
            ]
            del os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"]
        except KeyError:
            # TORCH_NCCL_ASYNC_ERROR_HANDLING was unset
            cached_nccl_async_error_handling = None

        # Save val of TORCH_NCCL_BLOCKING_WAIT and set it.
        try:
            cached_nccl_blocking_wait: Union[str, None] = os.environ[
                "TORCH_NCCL_BLOCKING_WAIT"
            ]
        except KeyError:
            cached_nccl_blocking_wait = None
        finally:
            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"

        try:
            ret = func(*args, **kwargs)
            return ret
        finally:
            # restore old values.
            if cached_nccl_async_error_handling is not None:
                os.environ[
                    "TORCH_NCCL_ASYNC_ERROR_HANDLING"
                ] = cached_nccl_async_error_handling

            if cached_nccl_blocking_wait is not None:
                os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait

    return wrapper


def with_dist_debug_levels(levels):
    """
    Runs a test for each distributed debug level specified in levels.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            old_level = os.environ.get("TORCH_DISTRIBUTED_DEBUG", None)
            for level in levels:
                os.environ["TORCH_DISTRIBUTED_DEBUG"] = level
                c10d.set_debug_level_from_env()
                ret = func(*args, **kwargs)
                c10d.barrier()
                if old_level is not None:
                    os.environ["TORCH_DISTRIBUTED_DEBUG"] = old_level
            # Only returns test return for last test, but since these are
            # unittests the return value is not really used and earlier tests
            # would've raised had they failed.
            return ret

        return wrapper

    return decorator


def requires_gloo():
    return skip_but_pass_in_sandcastle_if(
        not c10d.is_gloo_available(),
        "c10d was not compiled with the Gloo backend",
    )


def requires_nccl_version(version, msg):
    if not c10d.is_nccl_available():
        return skip_but_pass_in_sandcastle(
            "c10d was not compiled with the NCCL backend",
        )
    else:
        return skip_but_pass_in_sandcastle_if(
            torch.cuda.nccl.version() < version,
            f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
        )


def requires_nccl():
    return skip_but_pass_in_sandcastle_if(
        not c10d.is_nccl_available(),
        "c10d was not compiled with the NCCL backend",
    )


def requires_ucc():
    return skip_but_pass_in_sandcastle_if(
        not c10d.is_ucc_available(),
        "c10d was not compiled with the UCC backend",
    )


def requires_mpi():
    return skip_but_pass_in_sandcastle_if(
        not c10d.is_mpi_available(),
        "c10d was not compiled with the MPI backend",
    )


def skip_if_rocm_multiprocess(func):
    """Skips a test for ROCm"""
    func.skip_if_rocm_multiprocess = True

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not TEST_WITH_ROCM:
            return func(*args, **kwargs)
        sys.exit(TEST_SKIPS["skipIfRocm"].exit_code)

    return wrapper


def skip_if_win32():
    return skip_but_pass_in_sandcastle_if(
        sys.platform == "win32",
        "This unit test case is not supported on Windows platform",
    )


@retry_on_connect_failures
def create_tcp_store(
    addr="localhost",
    world_size=1,
    is_master=True,
    timeout=timedelta(minutes=5),
    wait_for_workers=True,
    jit_class=False,
    use_libuv=True,
):
    """
    Creates a TCP store. Retries if the chosen port is already in use.
    """
    port = find_free_port()
    if jit_class:
        timeout_millisecond = int(timeout / timedelta(milliseconds=1))
        return torch.classes.dist_c10d.TCPStore(
            addr, port, world_size, is_master, timeout_millisecond
        )
    else:
        return c10d.TCPStore(
            addr, port, world_size, is_master, wait_for_workers=wait_for_workers, use_libuv=use_libuv
        )


if TEST_WITH_TSAN:
    # TSAN runs much slower.
    TIMEOUT_DEFAULT = 500
else:
    TIMEOUT_DEFAULT = int(os.getenv('DISTRIBUTED_TESTS_DEFAULT_TIMEOUT', '300'))
TIMEOUT_OVERRIDE = {"test_ddp_uneven_inputs": 400}


# https://github.com/pytorch/pytorch/issues/75665
if TEST_WITH_ROCM:
    TIMEOUT_OVERRIDE["test_join_kwargs"] = 200


def create_device(interface=None):
    if sys.platform == "win32" or interface is None:
        return c10d.ProcessGroupGloo.create_device(hostname="127.0.0.1")
    else:
        return c10d.ProcessGroupGloo.create_device(interface=interface)


def get_timeout(test_id) -> int:
    return TIMEOUT_OVERRIDE.get(test_id.split(".")[-1], TIMEOUT_DEFAULT)


@contextmanager
def captured_output():
    new_out, new_err = StringIO(), StringIO()
    old_out, old_err = sys.stdout, sys.stderr
    try:
        sys.stdout, sys.stderr = new_out, new_err
        yield sys.stdout, sys.stderr
    finally:
        sys.stdout, sys.stderr = old_out, old_err
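

# A minimal usage sketch for ``captured_output`` (the test body below is
# hypothetical): the context manager temporarily redirects stdout/stderr to
# StringIO buffers so a test can assert on printed output.
#
#   with captured_output() as (out, err):
#       print("hello from rank 0")
#   self.assertIn("hello", out.getvalue())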


def simple_sparse_reduce_tests(rank: int, world_size: int, num_inputs: int = 1):
    """
    Generate a number of basic test cases for sparse reduction.
    These cover tensors with a varying number of sparse dimensions and a varying
    number of dense dimensions. The only reduction operation we support is sum.
    """

    def generate(rank: int, world_size: int, sparse_dims: int = 1, dense_dims: int = 0):
        # First sparse dimension is [0..rank].
        # Subsequent dimensions are always 0, so we know there is
        # a non-empty intersection between any two sparse tensors.
        indices = torch.reshape(torch.arange(rank + 1), (1, rank + 1))
        shape = [world_size] + [2 for _ in range(dense_dims)]
        for _ in range(sparse_dims - 1):
            indices = torch.cat((indices, torch.zeros(1, rank + 1)))
            shape.append(world_size)
        values = torch.ones([rank + 1] + [2 for _ in range(dense_dims)])
        return torch.sparse_coo_tensor(indices, values, shape)

    def compute_sum(fn, world_size: int):
        return reduce(
            operator.add, [fn(rank, world_size) for rank in range(world_size)]
        )

    return [
        (
            [
                fn(num_inputs * rank + i, num_inputs * world_size)
                for i in range(num_inputs)
            ],
            [compute_sum(fn, num_inputs * world_size) for i in range(num_inputs)],
        )
        for fn in [
            partial(generate, sparse_dims=1),
            partial(generate, sparse_dims=2),
            partial(generate, sparse_dims=3),
            partial(generate, dense_dims=1),
            partial(generate, dense_dims=2),
            partial(generate, dense_dims=3),
        ]
    ]


# HELPER FOR MULTIGPU TESTS
def init_multigpu_helper(world_size: int, backend: str):
    """Multi-GPU tests are designed to simulate multiple nodes, each with
    multiple GPUs. The NCCL backend requires an equal number of GPUs per
    process. On a single node, all visible GPUs are evenly divided into
    subsets, and each process uses only its own subset.
    """
    nGPUs = torch.cuda.device_count()
    visible_devices = range(nGPUs)

    # If the world size is at most the number of available GPUs, each rank
    # can be mapped to a corresponding GPU.
    nGPUs_per_process = 1
    if world_size > nGPUs:
        nGPUs_per_process = nGPUs // world_size
    rank_to_GPU = {
        i: list(visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process])
        for i in range(world_size)
    }
    return rank_to_GPU


tmp_dir: Optional[tempfile.TemporaryDirectory] = None


def initialize_temp_directories(init_method: Optional[str] = None) -> None:
    global tmp_dir
    tmp_dir = tempfile.TemporaryDirectory()
    os.environ["TEMP_DIR"] = tmp_dir.name
    os.mkdir(os.path.join(tmp_dir.name, "barrier"))
    os.mkdir(os.path.join(tmp_dir.name, "test_dir"))
    init_dir_path = os.path.join(tmp_dir.name, "init_dir")
    os.mkdir(init_dir_path)
    # Set init method if specified.
    if init_method is not None:
        os.environ["INIT_METHOD"] = init_method
    else:
        os.environ["INIT_METHOD"] = FILE_SCHEMA + os.path.join(
            init_dir_path, "shared_init_file"
        )


def cleanup_temp_dir() -> None:
    if tmp_dir is not None:
        tmp_dir.cleanup()


# Most tests operate with this worldsize
DEFAULT_WORLD_SIZE = 4

# [How does MultiProcessTestCase work?]
# Each MultiProcessTestCase instance uses 1 + `world_size()` processes; by
# default `world_size()` returns 4. Take `test_rpc_spawn.py` as an example,
# which inherits from this class. Its `setUp()` method calls into
# `MultiProcessTestCase._spawn_processes()`, which spawns `world_size()`
# subprocesses. During the spawn, the main process passes the test name to
# the subprocesses, and the name is acquired from self.id(). The subprocesses
# then use the provided test function name to retrieve the function attribute
# from the test instance and run it. The main process simply waits for all
# subprocesses to join.
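#
# A minimal usage sketch (illustrative only; the subclass and test body below
# are hypothetical, not part of this module):
#
#   class MyDistTest(MultiProcessTestCase):
#       def setUp(self):
#           super().setUp()
#           self._spawn_processes()
#
#       def tearDown(self):
#           super().tearDown()
#           try:
#               os.remove(self.file_name)
#           except OSError:
#               pass
#
#       @skip_if_lt_x_gpu(2)
#       def test_allreduce(self):
#           store = c10d.FileStore(self.file_name, self.world_size)
#           c10d.init_process_group(
#               "nccl", store=store, rank=self.rank, world_size=self.world_size
#           )
#           t = torch.ones(2, device=f"cuda:{self.rank}") * (self.rank + 1)
#           c10d.all_reduce(t)
#           expected = sum(range(1, self.world_size + 1))
#           self.assertEqual(t, torch.full((2,), float(expected), device=t.device))
#           c10d.destroy_process_group()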


class MultiProcessTestCase(TestCase):
    MAIN_PROCESS_RANK = -1
    # This exit code is used to indicate that the test code had an error and
    # exited abnormally. There are certain tests that might use sys.exit() to
    # simulate failures and in those cases, we can't have an exit code of 0,
    # but we still want to ensure we didn't run into any other errors.
    TEST_ERROR_EXIT_CODE = 10

    # Do not early-terminate for distributed tests.
    def _should_stop_test_suite(self) -> bool:
        return False

    @property
    def world_size(self) -> int:
        return DEFAULT_WORLD_SIZE

    def join_or_run(self, fn):
        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_PROCESS_RANK:
                self._join_processes(fn)
            else:
                fn()

        return types.MethodType(wrapper, self)

    # The main process spawns N subprocesses that run the test.
    # The constructor patches the current instance's test method to either
    # assume the role of the main process and join its subprocesses,
    # or run the underlying test function.
    def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> None:
        # methodName is the correct naming in unittest, and testslide uses keyword arguments.
        # So we need to use both to 1) not break BC and 2) support testslide.
        if methodName != "runTest":
            method_name = methodName
        super().__init__(method_name)
        fn = getattr(self, method_name)
        setattr(self, method_name, self.join_or_run(fn))

    def setUp(self) -> None:
        super().setUp()
        self.skip_return_code_checks = []  # type: ignore[var-annotated]
        self.processes = []  # type: ignore[var-annotated]
        self.rank = self.MAIN_PROCESS_RANK
        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
        # pid to pipe consisting of error message from process.
        self.pid_to_pipe = {}  # type: ignore[var-annotated]

    def tearDown(self) -> None:
        super().tearDown()
        for p in self.processes:
            p.terminate()
        # Each Process instance holds a few open file descriptors. The unittest
        # runner creates a new TestCase instance for each test method and keeps
        # it alive until the end of the entire suite. We must thus reset the
        # processes to prevent an effective file descriptor leak.
        self.processes = []

    def _current_test_name(self) -> str:
        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
        return self.id().split(".")[-1]

    def _start_processes(self, proc) -> None:
        self.processes = []
        for rank in range(int(self.world_size)):
            parent_conn, child_conn = torch.multiprocessing.Pipe()
            process = proc(
                target=self.__class__._run,
                name="process " + str(rank),
                args=(rank, self._current_test_name(), self.file_name, child_conn),
                kwargs={
                    "fake_pg": getattr(self, "fake_pg", False),
                },
            )
            process.start()
            logger.info("Started process %s with pid %s", rank, process.pid)
            self.pid_to_pipe[process.pid] = parent_conn
            self.processes.append(process)

    def _spawn_processes(self) -> None:
        proc = torch.multiprocessing.get_context("spawn").Process
        self._start_processes(proc)

    class Event(Enum):
        GET_TRACEBACK = 1

    @staticmethod
    def _event_listener(parent_pipe, signal_pipe, rank: int):
        logger.info("Starting event listener thread for rank %s", rank)
        while True:
            ready_pipes = multiprocessing.connection.wait([parent_pipe, signal_pipe])

            if parent_pipe in ready_pipes:

                if parent_pipe.closed:
                    logger.info(
                        "Pipe closed for process %s, stopping event listener thread", rank
                    )
                    return

                event = parent_pipe.recv()
                logger.info("Received event %s on process %s", event, rank)

                if event == MultiProcessTestCase.Event.GET_TRACEBACK:
                    # Return traceback to the parent process.
                    with tempfile.NamedTemporaryFile(mode="r+") as tmp_file:
                        faulthandler.dump_traceback(tmp_file)
                        # Flush buffers and seek to read from the beginning
                        tmp_file.flush()
                        tmp_file.seek(0)
                        parent_pipe.send(tmp_file.read())

                    logger.info("Process %s sent traceback", rank)

            if signal_pipe in ready_pipes:
                return

    @classmethod
    def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs) -> None:
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name
        self.run_test(test_name, parent_pipe)

    def run_test(self, test_name: str, parent_pipe) -> None:
        # Start event listener thread.
        signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe(duplex=False)
        event_listener_thread = threading.Thread(
            target=MultiProcessTestCase._event_listener,
            args=(parent_pipe, signal_recv_pipe, self.rank),
            daemon=True,
        )
        event_listener_thread.start()
        if sys.platform != "win32" and sys.platform != "darwin":
            # Register signal handler to dump stack traces on FATALs.
            # Windows and MacOS do not support the signal handlers.
            torch._C._set_print_stack_traces_on_fatal_signal(True)
        # Show full C++ stacktraces when a Python error originating from C++ is raised.
        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"

        # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
        # We're retrieving a corresponding test and executing it.
        try:
            getattr(self, test_name)()
        except unittest.SkipTest as se:
            logger.info(
                "Process %s skipping test %s for following reason: %s", self.rank, test_name, str(se)
            )
            sys.exit(TEST_SKIPS["generic"].exit_code)
        except Exception as e:
            logger.error(
                "Caught exception: \n%s exiting "
                "process %s with exit code: %s",
                traceback.format_exc(), self.rank, MultiProcessTestCase.TEST_ERROR_EXIT_CODE
            )
            # Send error to parent process.
            parent_pipe.send(traceback.format_exc())
            sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
        finally:
            if signal_send_pipe is not None:
                signal_send_pipe.send(None)

            assert event_listener_thread is not None
            event_listener_thread.join()
            # Close pipe after done with test.
            parent_pipe.close()

    def _get_timedout_process_traceback(self) -> None:
        pipes = []
        for i, process in enumerate(self.processes):
            if process.exitcode is None:
                pipe = self.pid_to_pipe[process.pid]
                try:
                    pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK)
                    pipes.append((i, pipe))
                except ConnectionError as e:
                    logger.error(
                        "Encountered error while trying to get traceback for process %s: %s", i, e
                    )

        # Wait for results.
        for rank, pipe in pipes:
            try:
                # Wait for traceback
                if pipe.poll(5):
                    if pipe.closed:
                        logger.info(
                            "Pipe closed for process %s, cannot retrieve traceback", rank
                        )
                        continue

                    traceback = pipe.recv()
                    logger.error(
                        "Process %s timed out with traceback: \n\n%s", rank, traceback
                    )
                else:
                    logger.error(
                        "Could not retrieve traceback for timed out process: %s", rank
                    )
            except ConnectionError as e:
                logger.error(
                    "Encountered error while trying to get traceback for process %s: %s", rank, e
                )

    def _join_processes(self, fn) -> None:
        timeout = get_timeout(self.id())
        start_time = time.time()
        subprocess_error = False
        try:
            while True:
                # Check to see if any subprocess exited with an error early.
                for (i, p) in enumerate(self.processes):
                    # This is the exit code processes exit with if they
                    # encountered an exception.
                    if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE:
                        print(
                            f"Process {i} terminated with exit code {p.exitcode}, terminating remaining processes."
                        )
                        active_children = torch.multiprocessing.active_children()
                        for ac in active_children:
                            ac.terminate()
                        subprocess_error = True
                        break
                if subprocess_error:
                    break
                # All processes have joined cleanly if they all have a valid exit code.
                if all(p.exitcode is not None for p in self.processes):
                    break
                # Check if we should time out the test. If so, we terminate each process.
                elapsed = time.time() - start_time
                if elapsed > timeout:
                    self._get_timedout_process_traceback()
                    print(
                        f"Timing out after {timeout} seconds and killing subprocesses."
                    )
                    for p in self.processes:
                        p.terminate()
                    break
                # Sleep to avoid excessive busy polling.
                time.sleep(0.1)

            elapsed_time = time.time() - start_time

            if fn in self.skip_return_code_checks:
                self._check_no_test_errors(elapsed_time)
            else:
                self._check_return_codes(elapsed_time)
        finally:
            # Close all pipes
            for pipe in self.pid_to_pipe.values():
                pipe.close()

    def _check_no_test_errors(self, elapsed_time) -> None:
        """
        Checks that we didn't have any errors thrown in the child processes.
        """
        for i, p in enumerate(self.processes):
            if p.exitcode is None:
                raise RuntimeError(
                    f"Process {i} timed out after {elapsed_time} seconds"
                )
            self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)

    def _check_return_codes(self, elapsed_time) -> None:
        """
        Checks that the return codes of all spawned processes match, and skips
        tests if they returned a return code indicating a skipping condition.
        """
        # If no processes are spawned, there is nothing to check.
        if not self.processes:
            logger.warning("Note: no subprocesses were spawned, test was likely skipped.")
            return

        first_process = self.processes[0]
        # First, we check if there are errors in the actual processes
        # (via TEST_ERROR_EXIT_CODE), and raise an exception for those.
        # The reason we do this is to attempt to raise a more helpful error
        # message than "Process x terminated/timed out".
        # TODO: we should pipe the exception of the failed subprocess here.
        # Currently, the actual exception is displayed as a logging output.
        errored_processes = [
            (i, p)
            for i, p in enumerate(self.processes)
            if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE
        ]
        if errored_processes:
            error = ""
            for i, process in errored_processes:
                # Get error from pipe.
                error_message = self.pid_to_pipe[process.pid].recv()
                error += (
                    f"Process {i} exited with error code {MultiProcessTestCase.TEST_ERROR_EXIT_CODE} "
                    f"and exception:\n{error_message}\n"
                )

            raise RuntimeError(error)
        # If no process exited uncleanly, we check for timeouts, and then ensure
        # each process exited cleanly.
        for i, p in enumerate(self.processes):
            if p.exitcode is None:
                raise RuntimeError(
                    f"Process {i} terminated or timed out after {elapsed_time} seconds"
                )
            self.assertEqual(
                p.exitcode,
                first_process.exitcode,
                msg=f"Expect process {i} exit code to match Process 0 exit code of {first_process.exitcode}, but got {p.exitcode}",
            )
        for skip in TEST_SKIPS.values():
            if first_process.exitcode == skip.exit_code:
                if IS_SANDCASTLE:
                    # Don't use unittest.skip to skip the test on sandcastle
                    # since it creates tasks for skipped tests assuming there
                    # is some follow-up needed. Instead just "pass" the test
                    # with an appropriate message.
                    logger.info(
                        "Skipping %s on sandcastle for the following reason: %s", self.id(), skip.message
                    )
                    return
                else:
                    raise unittest.SkipTest(skip.message)
        self.assertEqual(
            first_process.exitcode,
            0,
            msg=f"Expected zero exit code but got {first_process.exitcode} for pid: {first_process.pid}",
        )

    @property
    def is_master(self) -> bool:
        return self.rank == 0


def run_subtests(
    cls_inst,
    subtest_config: Dict[str, List[Any]],
    test_fn: Callable,
    *test_args,
    **test_kwargs: Any,
):
    """
    Runs a test function given by ``test_fn`` as a subtest according to the
    configurations specified by ``subtest_config``. This amortizes the
    costly setup overhead (including process spawn and initializing the
    process group) over the subtests.

    Args:
        subtest_config (Dict[str, List[Any]]): A mapping from subtest
            keyword argument name to a list of its possible values.
        test_fn (Callable): A callable that runs the actual test.
        test_args: Positional arguments to pass to ``test_fn``.
        test_kwargs: Keyword arguments to pass to ``test_fn``.
    """
    # Convert the config mapping to a list to have a fixed order
    subtest_config_items: List[Tuple[str, List[Any]]] = list(subtest_config.items())
    subtest_config_keys: List[str] = [item[0] for item in subtest_config_items]
    subtest_config_values: List[List[Any]] = [item[1] for item in subtest_config_items]
    for values in itertools.product(*subtest_config_values):
        # Map keyword to chosen value
        subtest_kwargs = dict(zip(subtest_config_keys, values))
        with cls_inst.subTest(**subtest_kwargs):
            torch._dynamo.reset()
            test_fn(*test_args, **test_kwargs, **subtest_kwargs)
            torch._dynamo.reset()
        c10d.barrier()
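

# A minimal usage sketch for ``run_subtests`` (the test method and helper below
# are hypothetical): the config mapping is expanded into its cartesian product,
# and each combination is forwarded to the test body as keyword arguments
# inside ``subTest``.
#
#   def test_fsdp_flags(self):
#       run_subtests(
#           self,
#           {"use_orig_params": [False, True], "forward_prefetch": [False, True]},
#           self._test_fsdp_flags_impl,
#       )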


# Cannot use functools.cache as it requires python 3.9
EFA_PROBE_RESULT = None


def has_efa() -> bool:
    """
    If the shell command `fi_info -p efa -t FI_EP_RDM` returns exit code 0, we assume that the machine
    has Libfabric EFA interfaces and EFA software components installed;
    see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/efa-start.html.
    """
    global EFA_PROBE_RESULT
    if EFA_PROBE_RESULT is not None:
        return EFA_PROBE_RESULT

    try:
        EFA_PROBE_RESULT = (
            subprocess.run(["fi_info", "-p", "efa", "-t", "FI_EP_RDM"], check=False).returncode == 0
        )
    except FileNotFoundError:
        EFA_PROBE_RESULT = False
    return EFA_PROBE_RESULT


def tp_transports():
    """
    If the machine has Libfabric EFA interfaces and EFA software components installed, TensorPipe may fail with
    'RuntimeError: In operator() at tensorpipe/common/ibv.h:172 "": Operation not supported' if it uses the
    InfiniBand transport, so we exclude it from the TensorPipe transports;
    see https://github.com/pytorch/pytorch/issues/73885 and https://github.com/pytorch/pytorch/issues/65022.
    """
    return ["shm", "uv"] if has_efa() else None


def spawn_threads_and_init_comms(
    func=None, timeout=TIMEOUT_DEFAULT, world_size=DEFAULT_WORLD_SIZE
):
    """
    Wrapper to use with a test method
    """
    if func is None:
        return partial(
            spawn_threads_and_init_comms, timeout=timeout, world_size=world_size
        )

    def _run_test_method_with_multi_threads(world_size, callback):
        world = _install_threaded_pg()
        global_store = c10d.HashStore()

        def world_is_valid():
            return world == c10d.distributed_c10d._world

        def worker(rank, world_pg, store):
            c10d.init_process_group(
                backend="threaded", rank=rank, world_size=world_size, store=store
            )
            try:
                callback()
            except BaseException as ex:
                # Exceptions are handled in MultiThreadedTestCase
                MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info()))
                ProcessLocalGroup.exception_handle(ex)  # trigger _terminate event and awaken worker threads
            finally:
                if world_is_valid():
                    c10d.destroy_process_group()

        threads = []
        for rank in range(world_size):
            t = threading.Thread(target=worker, args=(rank, world, global_store))
            t.start()
            threads.append(t)

        return threads

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # TODO: get test name from kwargs
        torch._C._distributed_c10d._set_thread_isolation_mode(True)
        try:
            threads = _run_test_method_with_multi_threads(world_size, lambda: func(self, *args, **kwargs))
            # join and error handling
            MultiThreadedTestCase._join_threads(threads, func)
        finally:
            torch._C._distributed_c10d._set_thread_isolation_mode(False)

    return wrapper
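

# A minimal usage sketch for the decorator above (the test class and body are
# hypothetical): the decorated method runs once per rank on a worker thread,
# with a threaded process group already initialized, so collectives can be
# called directly.
#
#   class MyThreadedTest(TestCase):
#       @spawn_threads_and_init_comms(world_size=4)
#       def test_allreduce(self):
#           t = torch.ones(2) * (c10d.get_rank() + 1)
#           c10d.all_reduce(t)
#           self.assertEqual(t, torch.full((2,), 10.0))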


class MultiThreadedTestCase(TestCase):
    """
    Test runner that runs all tests in-process with multiple threads, using the
    threaded process group.

    Each test spawns world_size threads and runs the test method in each thread.

    Differences from the regular MultiProcess test runner:
        Must explicitly define setUp and call self._spawn_threads() to run the tests.
        Cannot use setUp / tearDown (must use perThreadSetUp / perThreadTearDown)
            to set up / tear down each thread when running each test.
        No global state possible
            How bad of a limitation is this?
    """
    exception_queue = queue.Queue()

    MAIN_THREAD_RANK = -1

    def join_or_run(self, fn):
        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_THREAD_RANK:
                self._join_threads(self.threads, fn)
            else:
                fn()

        return types.MethodType(wrapper, self)

    def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> None:
        # methodName is the correct naming in unittest, and testslide uses keyword arguments.
        # So we need to use both to 1) not break BC and 2) support testslide.
        if methodName != "runTest":
            method_name = methodName
        super().__init__(method_name)
        fn = getattr(self, method_name)
        setattr(self, method_name, self.join_or_run(fn))

    def perThreadSetUp(self):
        # super().setUp()  # TestCase.setUp() calls torch.manual_seed()
        pass

    def perThreadTearDown(self):
        pass

    def setUp(self) -> None:
        """
        setUp only sets up things in the main thread; if you want to configure things
        in the spawned threads, use perThreadSetUp.
        """
        super().setUp()
        self.rank = self.MAIN_THREAD_RANK
        self.threads = []
        # Show full C++ stacktraces when a Python error originating from C++ is raised.
        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"

    def tearDown(self):
        """
        tearDown only tears down things in the main thread; if you want to configure things
        in the spawned threads, use perThreadTearDown.
        """
        super().tearDown()
        self.threads = []

    def _spawn_threads(self):
        """
        Spawn threads and run the test; call this from the setUp of your TestCase.
        """
        torch._C._distributed_c10d._set_thread_isolation_mode(True)
        test_name = self._current_test_name
        # For each test case, we need to create a thread-local world and a global store.
        world = _install_threaded_pg()
        self.__class__.global_store = c10d.HashStore()

        def world_is_valid():
            return world == c10d.distributed_c10d._world

        if not world_is_valid():
            raise RuntimeError("Invalid world")

        for rank in range(self.world_size):
            t = threading.Thread(target=self.__class__._run, args=(test_name, rank, self.world_size))
            t.start()
            self.threads.append(t)

    @classmethod
    def _run(cls, test_name, rank, world_size, **kwargs):
        self = cls(test_name)
        self.rank = rank

        # precision/rel_tol is a thread-local setting since it may be overridden per test; we need to
        # make every thread have the same value. This is relevant when we use op db tests, which need
        # those states to be set, i.e. when using instantiate_device_type_tests().
        # TODO: figure out a better way to do this
        if hasattr(self, "_tls"):
            self._tls = threading.local()
            self._tls.precision = TestCase._precision
            self._tls.rel_tol = TestCase._rel_tol

        self.run_test_with_threaded_pg(test_name, rank, world_size)

    def run_test_with_threaded_pg(self, test_name, rank, world_size):
        """
        Run the current test associated with `test_name` using the threaded process group.
        """
        c10d.init_process_group(
            backend="threaded", rank=rank, world_size=world_size, store=self.__class__.global_store
        )
        self.perThreadSetUp()

        try:
            getattr(self, test_name)()
        except BaseException as ex:
            self.exception_queue.put((rank, sys.exc_info()))
            ProcessLocalGroup.exception_handle(ex)  # trigger _terminate event and awaken worker threads
        finally:
            c10d.destroy_process_group()
            self.perThreadTearDown()

    @classmethod
    def _join_threads(cls, threads, fn):
        timeout = TIMEOUT_DEFAULT
        try:
            for idx, thread in enumerate(threads):
                thread.join(max(0, timeout))
                if thread.is_alive():
                    MultiThreadedTestCase.exception_queue.put(
                        (
                            idx,
                            (
                                TimeoutError,
                                TimeoutError(
                                    f"Rank failed to join in under {timeout} seconds"
                                ),
                                None,
                            ),
                        )
                    )
            ProcessLocalGroup.reset()
            failed_ranks = []
            while not cls.exception_queue.empty():
                failure = cls.exception_queue.get()
                failed_ranks.append(failure)
        finally:
            _uninstall_threaded_pg()
            torch._C._distributed_c10d._set_thread_isolation_mode(False)

        cls._check_return_codes(failed_ranks, timeout, fn)

    @classmethod
    def _check_return_codes(cls, failed_ranks, timeout, fn):
        # Report based on the exceptions raised from threads:
        #   SkipTest: print info for each thread
        #   TimeoutError: raise RuntimeError for any timed out thread
        #   Normal Exception: print error for each thread that raises an exception
        #                     and raise a RuntimeError
        error_msg = ""
        skip_code = -1
        for rank, exc_info in failed_ranks:
            exc = exc_info[1]
            if isinstance(exc, unittest.SkipTest):
                logger.info(
                    "Thread %s skipping test %s for following reason: %s", rank, fn, str(exc)
                )
                if skip_code < 0:
                    skip_code = TEST_SKIPS["generic"].exit_code
            elif isinstance(exc, TimeoutError):
                msg = f"Thread {rank} terminated or timed out after {timeout} seconds\n"
                logger.error(msg)
                raise RuntimeError(msg)
            elif isinstance(exc, Exception):
                msg = "".join(traceback.format_exception(*exc_info))
                logger.error(
                    "Caught exception: \n%s exiting thread %s", msg, rank
                )
                error_msg += (
                    f"Thread {rank} exited with exception:\n{msg}\n"
                )
            elif isinstance(exc, SystemExit):
                if type(exc.code) == int and skip_code < 0:
                    skip_code = exc.code

        # check exceptions
        if len(error_msg) > 0:
            raise RuntimeError(error_msg)
        # check skip
        if skip_code > 0:
            for skip in TEST_SKIPS.values():
                if skip_code == skip.exit_code:
                    if IS_SANDCASTLE:
                        # Don't use unittest.skip on sandcastle; just "pass"
                        # the test with an appropriate message.
                        logger.info(
                            "Skipping %s on sandcastle for the following reason: %s", fn, skip.message
                        )
                        return
                    else:
                        raise unittest.SkipTest(skip.message)

    @property
    def world_size(self) -> int:
        return DEFAULT_WORLD_SIZE

    @property
    def _current_test_name(self) -> str:
        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
        return self.id().split(".")[-1]

    def assertEqualOnRank(self, x, y, msg=None, *, rank=0):
        """
        The reason why we have this util function instead of
        self.assertEqual is that all threads share one CPU RNG,
        so the assertion result is only reliable on rank 0.
        """
        if self.rank == rank:
            self.assertEqual(x, y, msg)

    def assertNotEqualOnRank(self, x, y, msg=None, *, rank=0):
        if self.rank == rank:
            self.assertNotEqual(x, y)


class SaveForwardInputsModule(nn.Module):
    def __init__(
        self,
        forward_inputs: Dict[nn.Module, torch.Tensor],
        cast_forward_inputs: bool,
    ) -> None:
        super().__init__()
        self.l = nn.Linear(100, 100)
        self.forward_inputs = forward_inputs
        self.cast_forward_inputs = cast_forward_inputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.forward_inputs[self] = x
        return self.l(x.to(self.l.weight.dtype) if self.cast_forward_inputs else x)


class SaveForwardInputsModel(nn.Module):
    def __init__(
        self,
        forward_inputs: Dict[nn.Module, torch.Tensor],
        cast_forward_inputs: bool,
    ) -> None:
        super().__init__()
        self.c1 = SaveForwardInputsModule(forward_inputs, cast_forward_inputs)
        self.c2 = SaveForwardInputsModule(forward_inputs, cast_forward_inputs)
        self.forward_inputs = forward_inputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.forward_inputs[self] = x
        return self.c2(self.c1(x))


@contextmanager
def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
    # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
    # just manually implement the most important part of the dynamo behavior to reset/clear.
    if not fake_pg:
        torch.cuda.set_device(rank)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '6789'
    if init_pg:
        if fake_pg:
            store = torch.testing._internal.distributed.fake_pg.FakeStore()
            c10d.init_process_group(
                backend="fake",
                world_size=world_size,
                rank=rank,
                store=store,
            )
        else:
            c10d.init_process_group("nccl", rank=rank, world_size=world_size)
    torch._dynamo.reset()
    torch._dynamo.utils.counters.clear()
    try:
        yield
    finally:
        torch._dynamo.reset()
        torch._dynamo.utils.counters.clear()
        if init_pg:
            c10d.destroy_process_group()
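

# A minimal usage sketch of the context manager above (the test body is
# hypothetical and assumes a MultiProcessTestCase subclass providing
# ``self.rank`` / ``self.world_size``):
#
#   def test_compiled_allreduce(self):
#       with _dynamo_dist_per_rank_init(self.rank, self.world_size):
#           @torch.compile
#           def allreduce(x):
#               c10d.all_reduce(x)
#               return x
#
#           allreduce(torch.ones(8, device=f"cuda:{self.rank}"))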


class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
    """
    Test harness for single-process dynamo distributed tests;
    initializes the dist process group.

    Prefer this for simple tests, as it's easier to debug.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # _exit_stack is set up in TestCase
        cls._exit_stack.enter_context(
            patch.dict(
                os.environ,
                {
                    "MASTER_ADDR": "localhost",
                    "MASTER_PORT": "12355",
                },
            )
        )
        cls.rank = 0
        cls.device = f"cuda:{cls.rank}"
        cls.device_ids = None if "cuda" in cls.device else [cls.rank]
        c10d.init_process_group("nccl", rank=cls.rank, world_size=1)

    @classmethod
    def tearDownClass(cls):
        c10d.destroy_process_group()
        super().tearDownClass()


class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
    """
    Use this for tests that actually run on multiple GPUs.

    Decorate tests with @skip_if_lt_x_gpu(ngpu)

    Note: MultiProcessTestCase spawns processes per test and is slow.
    Prefer MultiThreadedTestCase for most tests. Perhaps use this one
    sparingly for integration tests.
    """
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self) -> int:
        return torch.cuda.device_count()

    @classmethod
    def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs) -> None:
        # The rest is copypasta from MultiProcessTestCase._run
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name
        self.run_test(test_name, parent_pipe)


class MultiProcContinousTest(TestCase):
    # Class variables:
    # number of test processes
    world_size: int = 2
    # rank of the current process
    rank: int = -1  # unset state
    # Rendezvous file
    rdvz_file: Optional[str] = None

    @classmethod
    @abc.abstractmethod
    def backend_str(cls) -> str:
        """
        ProcessGroup backend str.
        To be customized by sub test classes, e.g. "nccl".
        Here we raise an error.
        """
        raise NotImplementedError("Please implement backend_str in your test class")

    @classmethod
    def opts(cls, high_priority_stream=False):
        """
        ProcessGroup init options.
        To be customized by sub test classes, e.g. ProcessGroupNCCLOpTest.
        Here we return None.
        """
        return None

    @classmethod
    def setUpClass(cls):
        """
        Class-scope test fixture. Run once for the entire test class, before any test starts.
        Sets up the process group.
        """
        super().setUpClass()
        if not 0 <= cls.rank < cls.world_size:
            raise RuntimeError(
                "Rank must be set and in the range of 0 to world_size. "
                f"World size: {cls.world_size} Rank: {cls.rank}"
            )
        if cls.rdvz_file:
            store = c10d.FileStore(cls.rdvz_file, cls.world_size)
        else:
            # torchrun takes care of rendezvous
            store = None
        opts = cls.opts()
        backend = cls.backend_str()
        print(f"Testing {backend=}")
        # Create the process group with opts.
        c10d.init_process_group(
            backend=backend,
            world_size=cls.world_size,
            rank=cls.rank,
            store=store,
            pg_options=opts,
        )
        cls.pg = c10d.distributed_c10d._get_default_group()
        print(f"Rank {cls.rank} setup complete")

    @classmethod
    def tearDownClass(cls):
        """
        Class-scope test fixture. Run once for the entire test class, after all tests finish.
        Tears down the process group.
        """
        c10d.destroy_process_group()
        super().tearDownClass()
        # Clean up the rendezvous file.
        if cls.rdvz_file:
            try:
                os.remove(cls.rdvz_file)
            except OSError:
                pass
        print(f"Rank {cls.rank} teardown complete")

    @classmethod
    def run_rank(
        cls,
        rank: int,
        world_size: int,
        rdvz_file: Optional[str] = None,
    ):
        """
        This is an entry point for each rank to run the tests in `MultiProcContinousTest`.
        In this entry point, we set the class variables for the test class.
        Then we run all tests.

        Note:
        - This helper only works for a subclass of `MultiProcContinousTest`.

        Example:
        - See `test_c10d_ops_nccl.py`.
        """
        # set class variables for the test class
        cls.rank = rank
        cls.world_size = world_size
        cls.rdvz_file = rdvz_file
        # Launch tests via `common_utils` infra
        run_tests()
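

# A minimal launch sketch for a `MultiProcContinousTest` subclass (the class,
# test body, and launch snippet below are hypothetical; see
# `test_c10d_ops_nccl.py` for the real pattern). Each rank's process, e.g. one
# launched by torchrun, calls `run_rank` once and then runs every test on the
# shared process group:
#
#   class MyOpsTest(MultiProcContinousTest):
#       @classmethod
#       def backend_str(cls) -> str:
#           return "gloo"
#
#       def test_broadcast(self):
#           t = torch.full((2,), float(self.rank))
#           c10d.broadcast(t, src=0)
#           self.assertEqual(t, torch.zeros(2))
#
#   if __name__ == "__main__":
#       rank = int(os.environ["RANK"])
#       world_size = int(os.environ["WORLD_SIZE"])
#       MyOpsTest.run_rank(rank, world_size)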