# Owner(s): ["oncall: distributed"]

import copy
import os
import pickle
import sys
import tempfile
import threading
import time
from contextlib import nullcontext
from dataclasses import dataclass
from datetime import timedelta
from itertools import product
from sys import platform
from typing import Dict, Optional

import torch
import torch.distributed as dist


if not dist.is_available():
    print("distributed package not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
import torch.distributed.distributed_c10d as c10d
import torch.nn.functional as F
import torch.testing._internal.common_utils as common
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    load_tests,
    parametrize,
    retry_on_connect_failures,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)
from torch.utils.checkpoint import checkpoint


if TEST_WITH_DEV_DBG_ASAN:
    print("Multiprocessing spawn is not compatible with dev/dbg asan", file=sys.stderr)
    sys.exit(0)

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

if platform == "darwin":
    LOOPBACK = "lo0"
else:
    LOOPBACK = "lo"

torch.backends.cuda.matmul.allow_tf32 = False


def gpus_for_rank(world_size):
    """Multigpu tests are designed to simulate multiple nodes, each with
    multiple GPUs. The NCCL backend requires an equal number of GPUs in each
    process. On a single node, all visible GPUs are evenly divided into
    subsets, and each process only uses its own subset.
    """
    visible_devices = list(range(torch.cuda.device_count()))
    gpus_per_process = torch.cuda.device_count() // world_size
    gpus_for_rank = []
    for rank in range(world_size):
        gpus_for_rank.append(
            visible_devices[rank * gpus_per_process : (rank + 1) * gpus_per_process]
        )
    return gpus_for_rank
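
# For example (hypothetical device counts): with 4 visible GPUs and
# world_size=2, gpus_for_rank(2) returns [[0, 1], [2, 3]], so the process for
# rank 0 uses GPUs 0-1 and the process for rank 1 uses GPUs 2-3.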


class AbstractTimeoutTest:
    def _test_store_timeout(self, backend, init_method, c2p):
        try:
            dist.init_process_group(
                backend=backend,
                init_method=init_method,
                world_size=1,
                rank=0,
                timeout=timedelta(seconds=1),
            )
            default_store = c10d._get_default_store()
            tik = time.time()
            with self.assertRaisesRegex(RuntimeError, "(?i)timeout"):
                default_store.get("nonexistent key")
            tok = time.time()
            dist.destroy_process_group()
            c2p.append(float(tok - tik))
        except RuntimeError as e:
            # catch "Address already in use" error and report it to the main
            # thread
            c2p.append(e)

    def _init_methods(self):
        f = tempfile.NamedTemporaryFile(delete=False)
        if sys.platform == "win32":
            yield "file:///{}".format(f.name.replace("\\", "/"))
            f.close()
        else:
            yield f"file://{f.name}"
            f.close()
            yield "tcp://127.0.0.1:%d" % common.find_free_port()

    def _test_default_store_timeout(self, backend):
        for init_method in self._init_methods():
            c2p = []
            t = threading.Thread(
                target=self._test_store_timeout, args=(backend, init_method, c2p)
            )
            t.daemon = True
            t.start()
            t.join(5)

            self.assertEqual(1, len(c2p))
            if isinstance(c2p[0], float):
                # waiting time should be 1s, use 3s to rule out false alarm
                self.assertGreater(3, c2p[0])
            elif isinstance(c2p[0], RuntimeError):
                # let @retry_on_connect_failures handle the error
                raise c2p[0]
            else:
                raise RuntimeError(f"Unexpected type {type(c2p[0])}")


class TimeoutTest(TestCase):
    @retry_on_connect_failures
    def test_store_based_barrier(self):
        f = tempfile.NamedTemporaryFile(delete=False)
        port = common.find_free_port()

        def thread_work(timeout, init_type, world_size, rank, error_list):
            # we need to create a separate store just for the store barrier test
            if init_type == "file":
                barrier_store = dist.FileStore(f.name)
            elif init_type == "tcp":
                barrier_store = dist.TCPStore(
                    "localhost",
                    port,
                    world_size,
                    is_master=rank == 0,
                    wait_for_workers=False,
                )
            elif init_type == "hash":
                barrier_store = dist.HashStore()
            try:
                # 1 missing worker will cause it to timeout
                if rank != world_size - 1:
                    c10d._store_based_barrier(
                        rank=rank,
                        store=barrier_store,
                        group_name="_",
                        rendezvous_count=world_size,
                        timeout=timeout,
                        logging_interval=timeout / 2,
                    )
            except torch.distributed.DistStoreError as e:
                self.assertTrue(isinstance(e, torch.distributed.DistError))
                error_list.append(e)

        world_size = 4
        error_list = []
        threads = []
        for init_type in ["file", "tcp", "hash"]:
            for rank in range(world_size):
                t = threading.Thread(
                    target=thread_work,
                    args=(
                        timedelta(seconds=3),
                        init_type,
                        world_size,
                        rank,
                        error_list,
                    ),
                )
                threads.append(t)
                t.start()

            for i, thread in enumerate(threads):
                thread.join()

            # we expect the world_size-1 threads to have failed
            self.assertEqual(len(error_list), world_size - 1)
            for error in error_list:
                self.assertTrue(
                    "Timed out initializing process group in store based barrier"
                    in error.args[0]
                )
            error_list = []
            threads = []
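

# The small modules below (Net, DoubleGpuNet, QuadraGpuNet, ConvNet, Task,
# ModuleForDdpCommHook, SparseGradientModule) are toy models used as fixtures
# by the DDP and communication-hook tests that follow.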


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = nn.Linear(10, 50, bias=False)
        self.fc3 = nn.Linear(50, 4, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return F.softmax(x, dim=1)


class DoubleGpuNet(nn.Module):
    def __init__(self, gpus):
        super().__init__()
        self.fc1 = nn.Linear(2, 10, bias=False).to(gpus[0])
        self.fc2 = nn.Linear(10, 50, bias=False).to(gpus[1])
        self.fc3 = nn.Linear(50, 4, bias=False).to(gpus[1])
        self.relu = nn.ReLU()
        self.no_grad_param = nn.Parameter(
            torch.tensor([2, 2]).long(), requires_grad=False
        ).to(gpus[0])

    def forward(self, x):
        dev0 = self.fc1.weight.device
        dev1 = self.fc2.weight.device
        x = self.relu(self.fc1(x.to(dev0)))
        x = self.relu(self.fc2(x.to(dev1)))
        x = self.fc3(x)
        return F.softmax(x, dim=1).to(dev0)


class QuadraGpuNet(nn.Module):
    def __init__(self, gpus):
        super().__init__()
        self.fc1 = nn.Linear(2, 10, bias=False).to(gpus[0])
        self.fc2 = nn.Linear(10, 50, bias=False).to(gpus[1])
        self.fc3 = nn.Linear(50, 4, bias=False).to(gpus[2])
        self.fc4 = nn.Linear(4, 4, bias=False).to(gpus[3])
        self.relu = nn.ReLU()
        self.no_grad_param = nn.Parameter(
            torch.tensor([2, 2]).long(), requires_grad=False
        ).to(gpus[0])

    def forward(self, x):
        dev0 = self.fc1.weight.device
        dev1 = self.fc2.weight.device
        dev2 = self.fc3.weight.device
        dev3 = self.fc4.weight.device
        x = self.relu(self.fc1(x.to(dev0)))
        x = self.relu(self.fc2(x.to(dev1)))
        x = self.relu(self.fc3(x.to(dev2)))
        x = self.fc4(x.to(dev3))
        return F.softmax(x, dim=1).to(dev0)


class ConvNet(nn.Module):
    def __init__(self, gpus, layouts, dtypes):
        super().__init__()
        self.dtypes = dtypes
        if isinstance(gpus, list):
            self.layer_gpus = gpus
        else:
            gpus = [gpus] * 4
        self.conv0 = torch.nn.Conv2d(8, 16, (2, 2)).to(
            device=gpus[0], memory_format=layouts[0], dtype=dtypes[0]
        )
        self.conv1 = torch.nn.Conv2d(16, 32, (2, 2)).to(
            device=gpus[1], memory_format=layouts[1], dtype=dtypes[1]
        )
        self.conv2 = torch.nn.Conv2d(32, 16, (2, 2)).to(
            device=gpus[2], memory_format=layouts[2], dtype=dtypes[2]
        )
        self.conv3 = torch.nn.Conv2d(16, 8, (2, 2)).to(
            device=gpus[3], memory_format=layouts[3], dtype=dtypes[3]
        )

    def forward(self, x):
        x = x.to(self.dtypes[0])
        # Could say
        # x = self.conv0(x).to(device=self.conv1.weight.device, dtype=self.dtypes[1])
        # etc. But I don't want to appeal to the weights' devices directly, because part of this test's purpose
        # is to verify weights are where expected if the model gets replicated.
        gpus = self.layer_gpus if hasattr(self, "layer_gpus") else [x.device] * 4
        x = self.conv0(x).to(device=gpus[1], dtype=self.dtypes[1])
        x = self.conv1(x).to(device=gpus[2], dtype=self.dtypes[2])
        x = self.conv2(x).to(device=gpus[3], dtype=self.dtypes[3])
        return self.conv3(x)


class Task(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.p = nn.Parameter(torch.ones(2, 2))

    def forward(self, x):
        return self.p + x


class ModuleForDdpCommHook(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.t0 = Task()

    def forward(self, x, rank):
        return self.t0(x + rank)


class SparseGradientModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.EmbeddingBag(10, 10, sparse=True)

    def forward(self, x):
        return F.softmax(self.embedding(x), dim=1)


class CommonDistributedDataParallelTest:
    def tearDown(self):
        # DistributedDataParallel test doesn't seem to call FileStore destructor
        # TODO: investigate this test and the test is known to have issues
        # Use this hack to remove files for that test
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self):
        return 2

    def _prepare_single_device_module(
        self,
        process_group,
        devices,
        device_ids,
        global_batch_size,
        gradient_as_bucket_view=False,
    ):
        model = Net()
        device = devices[0] if devices else torch.device("cuda:%d" % self.rank)
        ddp_model = DistributedDataParallel(
            copy.deepcopy(model).to(device),
            device_ids=device_ids,
            process_group=process_group,
            bucket_cap_mb=0.001,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        model.to(device)

        input = torch.randn(global_batch_size, 2).to(device)
        target = torch.randn(global_batch_size, 4).to(device)

        return model, ddp_model, input, target

    def _prepare_multi_device_module(
        self,
        process_group,
        devices,
        device_ids,
        global_batch_size,
        gradient_as_bucket_view=False,
    ):
        self.assertTrue(
            len(devices) == 2 or len(devices) == 4,
            f"unexpected devices for ddp tests {devices}",
        )
        if len(devices) == 2:
            model = DoubleGpuNet(devices)
        elif len(devices) == 4:
            model = QuadraGpuNet(devices)

        ddp_model = DistributedDataParallel(
            copy.deepcopy(model),
            device_ids=device_ids,
            process_group=process_group,
            bucket_cap_mb=0.001,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        input = torch.randn(global_batch_size, 2).cuda(devices[0])
        target = torch.randn(global_batch_size, 4)

        return model, ddp_model, input, target

    def _get_store(self):
        return dist.FileStore(self.file_name, self.world_size)

    def _get_process_group(self):
        raise NotImplementedError("To be implemented by child class")

    def _train_model(
        self, model, input_var, target, loss, run_checkpoint=False, use_reentrant=True
    ):
        model.train()
        if run_checkpoint:
            output = checkpoint(model, input_var, use_reentrant=use_reentrant)
        else:
            output = model(input_var)
        l = loss(output, target)
        l.backward()

    def _test_ddp_checkpointing(
        self,
        input_model,
        process_group,
        use_bucket_view,
        find_unused_parameters=False,
        static_graph=False,
        run_checkpoint=False,
        use_reentrant=True,
        allow_none_grads=False,
    ):
        # to reproduce the same training results
        torch.cuda.set_device(self.rank)
        torch.manual_seed(31415)
        model = copy.deepcopy(input_model).cuda()
        ddp_model = copy.deepcopy(input_model).cuda()
        ddp_model = nn.parallel.DistributedDataParallel(
            ddp_model,
            bucket_cap_mb=1,
            gradient_as_bucket_view=use_bucket_view,
            device_ids=[self.rank],
            process_group=process_group,
            find_unused_parameters=find_unused_parameters,
            static_graph=static_graph,
        )
        self.assertEqual(
            ddp_model._get_ddp_logging_data().get("static_graph", 0), static_graph
        )
        input, ddp_input, target, ddp_target = self._prepare_dummy_data()
        loss = nn.MSELoss()
        n_iters = 5
        for i in range(n_iters):
            model.zero_grad(set_to_none=False)
            ddp_model.zero_grad(set_to_none=False)
            self._train_model(
                model,
                input,
                target,
                loss,
                run_checkpoint=run_checkpoint,
                use_reentrant=use_reentrant,
            )
            self._train_model(
                ddp_model,
                ddp_input,
                ddp_target,
                loss,
                run_checkpoint=run_checkpoint,
                use_reentrant=use_reentrant,
            )
            for i, j in zip(model.parameters(), ddp_model.parameters()):
                if not allow_none_grads:
                    self.assertTrue(i.grad is not None)
                    self.assertTrue(j.grad is not None)
                self.assertEqual(i.grad, j.grad, rtol=1.3e-06, atol=5e-5)
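
    # _test_ddp_checkpointing above trains a plain local model and a DDP-wrapped
    # copy side by side for a few iterations and asserts that their gradients
    # match within tolerance; the checkpointing test cases below all funnel
    # through it.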

    # A list of tests for ddp with activation checkpointing
    # when gradient_as_bucket_view=True, False.
    # Most of the tests are referred to
    # https://github.com/facebookresearch/fairscale/blob/main/tests/nn/pipe/test_checkpoint_ddp.py
    class CheckpointOnceModule(nn.Module):
        """
        Runs checkpoint for a single layer in the model.
        """

        def __init__(self, use_reentrant=True):
            super().__init__()
            self.l1 = nn.Linear(20, 20)
            self.l2 = nn.Linear(20, 20)
            self.use_reentrant = use_reentrant

        def forward(self, inp):
            x = self.l1(inp)
            x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
            return x

    class CheckpointTwiceModule(CheckpointOnceModule):
        """
        Runs checkpoint for the same layer twice in a model. This simulates use
        cases such as pipeline parallel where the same layer can be checkpointed
        more than once.
        """

        def __init__(self, use_reentrant=True):
            super().__init__(use_reentrant=use_reentrant)

        def forward(self, inp):
            x = self.l1(inp)
            x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
            x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
            return x

    class CheckpointTwiceModuleWeightSharing(CheckpointTwiceModule):
        """
        Similar to CheckpointTwiceModule but the weights are shared.
        """

        def __init__(self, use_reentrant=True):
            super().__init__(use_reentrant=use_reentrant)
            # Share weights
            self.l1.weight = self.l2.weight

        def forward(self, inp):
            x = self.l1(inp)
            x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
            x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
            return x

    class DynamicCheckpointTwiceModule(CheckpointTwiceModule):
        def __init__(self, use_reentrant=True):
            super().__init__(use_reentrant=use_reentrant)
            self.count = 0

        def forward(self, inp):
            if self.count % 2:
                x = checkpoint(self.l1, inp, use_reentrant=self.use_reentrant)
            else:
                x = checkpoint(self.l2, inp, use_reentrant=self.use_reentrant)

            self.count += 1
            return x

    class DynamicCheckpointTwiceModuleWeightSharing(DynamicCheckpointTwiceModule):
        def __init__(self, use_reentrant=True):
            super().__init__(use_reentrant=use_reentrant)
            # Share weights
            self.l1.weight = self.l2.weight

    def _prepare_dummy_data(self):
        ddp_bs = 16
        bs = ddp_bs * self.world_size
        input = torch.rand((bs, 20), device="cuda", requires_grad=True)
        target = torch.randn((bs, 20), device="cuda")
        offset = self.rank * ddp_bs
        ddp_input = input[offset : offset + ddp_bs]
        ddp_target = target[offset : offset + ddp_bs]
        return input, ddp_input, target, ddp_target

    @skip_if_lt_x_gpu(2)
    @parametrize("use_reentrant", [True, False])
    def test_ddp_checkpointing_once(self, use_reentrant):
        """
        DDP works as expected when a layer is checkpointed only once.
        """
        process_group = self._get_process_group()
        for use_bucket_view, static_graph in product((False, True), (False, True)):
            self._test_ddp_checkpointing(
                self.CheckpointOnceModule(use_reentrant=use_reentrant),
                process_group=process_group,
                use_bucket_view=use_bucket_view,
                static_graph=static_graph,
            )
            if static_graph:
                # find_unused_parameters does not make a difference, since it is
                # ignored for static graph.
                self._test_ddp_checkpointing(
                    self.CheckpointOnceModule(),
                    process_group=process_group,
                    use_bucket_view=use_bucket_view,
                    static_graph=static_graph,
                    find_unused_parameters=True,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize("use_reentrant", [True, False])
    def test_ddp_checkpointing_unused_params(self, use_reentrant):
        """
        With the reentrant autograd checkpointing implementation, DDP fails when
        the model has unused parameters and static graph training is not
        enabled. With the non-reentrant checkpointing implementation, this
        works as expected.
        """
        process_group = self._get_process_group()
        for use_bucket_view in (True, False):
            err_ctx = (
                nullcontext()
                if not use_reentrant
                else self.assertRaisesRegex(
                    RuntimeError, "Expected to mark a variable ready only once."
                )
            )
            with err_ctx:
                model = self._test_ddp_checkpointing(
                    self.CheckpointOnceModule(use_reentrant=use_reentrant),
                    process_group=process_group,
                    use_bucket_view=use_bucket_view,
                    find_unused_parameters=True,
                )
            # test passes when static_graph is true
            model = self._test_ddp_checkpointing(
                self.CheckpointOnceModule(use_reentrant=use_reentrant),
                process_group=process_group,
                use_bucket_view=use_bucket_view,
                find_unused_parameters=True,
                static_graph=True,
            )

    @skip_if_lt_x_gpu(2)
    @parametrize("use_reentrant", [True, False])
    def test_ddp_checkpointing_twice(self, use_reentrant):
        """
        Checkpointing twice fails for a non-static graph with the reentrant
        checkpoint implementation, and succeeds with the non-reentrant
        checkpoint implementation.
        """
        process_group = self._get_process_group()
        for use_bucket_view in (True, False):
            err_ctx = (
                nullcontext()
                if not use_reentrant
                else self.assertRaisesRegex(
                    RuntimeError, "Expected to mark a variable ready only once."
                )
            )
            with err_ctx:
                model = self._test_ddp_checkpointing(
                    self.CheckpointTwiceModule(use_reentrant=use_reentrant),
                    process_group=process_group,
                    use_bucket_view=use_bucket_view,
                    static_graph=False,
                )

            with err_ctx:
                model = self._test_ddp_checkpointing(
                    self.CheckpointTwiceModule(use_reentrant=use_reentrant),
                    process_group=process_group,
                    use_bucket_view=use_bucket_view,
                    static_graph=False,
                    find_unused_parameters=True,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize("use_reentrant", [True, False])
    def test_ddp_checkpointing_twice_static_graph(self, use_reentrant):
        """
        Regardless of reentrant or non-reentrant checkpointing impl,
        checkpointing twice works with static graph enabled.
        """
        process_group = self._get_process_group()
        for use_bucket_view in (True, False):
            # Test passes when static_graph=True.
            model = self._test_ddp_checkpointing(
                self.CheckpointTwiceModule(use_reentrant=use_reentrant),
                process_group=process_group,
                use_bucket_view=use_bucket_view,
                static_graph=True,
            )

    @skip_if_lt_x_gpu(2)
    def test_ddp_checkpointing_dynamic_module(self):
        """
        A dynamic module can be checkpointed multiple times with the
        non-reentrant checkpointing implementation.
        """
        process_group = self._get_process_group()
        for use_bucket_view in (True, False):
            model = self._test_ddp_checkpointing(
                self.DynamicCheckpointTwiceModule(use_reentrant=False),
                process_group=process_group,
                use_bucket_view=use_bucket_view,
                static_graph=False,
                find_unused_parameters=True,
                # Grads can be none sometimes due to dynamic module not using
                # all params.
                allow_none_grads=True,
            )

    @skip_if_lt_x_gpu(2)
    def test_ddp_checkpointing_dynamic_weight_sharing(self):
        """
        A dynamic module can be checkpointed multiple times with weight sharing
        using the non-reentrant checkpointing implementation.
        """
        process_group = self._get_process_group()
        for use_bucket_view in (True, False):
            model = self._test_ddp_checkpointing(
                self.DynamicCheckpointTwiceModuleWeightSharing(use_reentrant=False),
                process_group=process_group,
                use_bucket_view=use_bucket_view,
                static_graph=False,
                find_unused_parameters=True,
                # Grads can be none sometimes due to dynamic module not using
                # all params.
                allow_none_grads=True,
            )

    # DDP works as expected if there is weight sharing among layers
    @skip_if_lt_x_gpu(2)
    @parametrize("use_reentrant", [True, False])
    def test_ddp_checkpointing_weight_sharing(self, use_reentrant):
        """
        Test that checkpointing with weight sharing works.
        """
        process_group = self._get_process_group()
        torch.cuda.set_device(self.rank)
        for use_bucket_view, static_graph in product((False, True), (False, True)):
            torch.manual_seed(31415)
            l1 = nn.Linear(20, 20)
            l2 = nn.Linear(20, 20)
            l1.weight = l2.weight
            model = nn.Sequential(l1, l2)
            self._test_ddp_checkpointing(
                model,
                process_group=process_group,
                use_bucket_view=use_bucket_view,
                static_graph=static_graph,
                run_checkpoint=True,
                use_reentrant=use_reentrant,
            )

    @skip_if_lt_x_gpu(2)
    def test_ddp_checkpointing_twice_weight_sharing(self):
        """
        Checkpointing should work with static graph in the case of checkpointing
        the same layer twice and having weights shared across layers.
        """
        process_group = self._get_process_group()
        torch.cuda.set_device(self.rank)
        for use_bucket_view in (True, False):
            model = self._test_ddp_checkpointing(
                self.CheckpointTwiceModuleWeightSharing(),
                process_group=process_group,
                use_bucket_view=use_bucket_view,
                static_graph=True,
            )

    def test_invalid_powerSGD_state(self):
        for start_powerSGD_iter, use_error_feedback, warm_start in product(
            [0, 1], [True, False], [True, False]
        ):
            if not use_error_feedback and not warm_start:
                continue
            with self.assertRaisesRegex(
                ValueError,
                "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
                "because PowerSGD can only be applied after the first two iterations in DDP.",
            ):
                state = powerSGD.PowerSGDState(
                    process_group=None,
                    matrix_approximation_rank=1,
                    start_powerSGD_iter=start_powerSGD_iter,
                    use_error_feedback=use_error_feedback,
                    warm_start=warm_start,
                )

    def _test_ddp_with_process_group(
        self,
        process_group,
        devices,
        device_ids,
        multi_device=False,
        gradient_as_bucket_view=False,
    ):
        """
        Note: we pass down `device_ids` all the way to DistributedDataParallel
        as part of the test. Below you find tests that either use a list of
        integers, a list of `torch.device` instances, or an empty list.
        The `devices` argument is used to control placement of the model and
        must always be specified as a list of `torch.device` instances.
        """
        local_batch_size = 1 if devices is None else len(devices)
        global_batch_size = self.world_size * local_batch_size

        if multi_device:
            model, ddp_model, input, target = self._prepare_multi_device_module(
                process_group,
                devices,
                device_ids,
                global_batch_size,
                gradient_as_bucket_view,
            )
            ddp_logging_data = ddp_model._get_ddp_logging_data()
            self.assertTrue(ddp_logging_data.get("is_multi_device_module"))
        else:
            model, ddp_model, input, target = self._prepare_single_device_module(
                process_group,
                devices,
                device_ids,
                global_batch_size,
                gradient_as_bucket_view,
            )
            ddp_logging_data = ddp_model._get_ddp_logging_data()
            self.assertFalse(ddp_logging_data.get("is_multi_device_module"))

        def step_model(model, input, target):
            model.train()
            output = model(input)
            loss = F.mse_loss(output, target.to(output.device))
            loss.backward()

        def update_parameters(model):
            for param in model.parameters():
                with torch.no_grad():
                    param -= param.grad
                    param.grad = None

        # check two model parameters over 2 iterations
        for iteration in range(2):
            # single cpu/gpu training
            step_model(model, input, target)

            # DDP training, DDP scatters subsets of input_cpu to nodes/GPUs
            step_model(
                ddp_model,
                input[
                    self.rank * local_batch_size : (self.rank + 1) * local_batch_size
                ],
                target[
                    self.rank * local_batch_size : (self.rank + 1) * local_batch_size
                ],
            )

            # Update weights and run a second iteration to shake out errors
            update_parameters(model)
            update_parameters(ddp_model)
            self.assertEqual(
                len(list(model.parameters())), len(list(ddp_model.parameters()))
            )
            for i, j in zip(model.parameters(), ddp_model.parameters()):
                self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5)

            # Shuffle the input so that DDP input is different
            torch.manual_seed(1337 + iteration)
            input = input[torch.randperm(global_batch_size)]
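
    # The helpers below build a GPU DDP model and optionally attach a gradient
    # communication hook. A hook is a callable with the signature
    # hook(state, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor],
    # as illustrated by _simple_hook further down.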

    def _gpu_model_with_ddp_comm_hook(
        self, process_group, hook=None, gradient_as_bucket_view=False, state=None
    ):
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        # Register a DDP communication hook if any.
        if hook is not None:
            gpu_model.register_comm_hook(state, hook)

        return gpu_model

    def _gpu_model_with_builtin_ddp_comm_hook(
        self, process_group, hook=None, gradient_as_bucket_view=False
    ):
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        # Register a built-in DDP communication hook if defined
        if hook is not None:
            gpu_model._register_builtin_comm_hook(hook)

        return gpu_model

    def _run_and_verify_hook(self, model, input, expected_grad):
        # Run forward
        output = model(input, self.rank)

        # Run backward
        output.mean().backward()

        [self.assertEqual(p.grad, expected_grad) for p in model.parameters()]

    def _simple_hook(
        self, state: object, bucket: dist.GradBucket
    ) -> torch.futures.Future[torch.Tensor]:
        fut = torch.futures.Future()
        fut.set_result(torch.ones_like(bucket.buffer()))

        def fut_then(fut):
            # Add ones to fut's result.
            t = fut.value()
            return t + torch.ones_like(t)

        return fut.then(fut_then)

    def _test_not_nan(self, model, x):
        y = model(x)
        self.assertFalse(y.isnan().any().item())
        y.sum().backward()
        for p in model.parameters():
            self.assertFalse(p.grad.isnan().any().item())

    @skip_if_lt_x_gpu(2)
    def test_sync_batch_norm_only_empty_input(self):
        pg = self._get_process_group()

        model = torch.nn.Sequential(
            nn.BatchNorm2d(2),
        ).to(device=self.rank)
        model = DistributedDataParallel(
            model,
            device_ids=[self.rank],
            process_group=pg,
        )
        model = nn.SyncBatchNorm.convert_sync_batchnorm(
            model,
            process_group=pg,
        )

        model.train()

        # only rank 0 receives empty inputs
        x = torch.zeros(
            (1 if self.rank != 0 else 0, 2, 11, 13),
            dtype=torch.float32,
            device=self.rank,
        )

        # input requires grad, this will trigger the collective communication
        # in the backward pass
        x.requires_grad = True
        self._test_not_nan(model, x)

        # input does not require grad
        x.requires_grad = False
        self._test_not_nan(model, x)

        # all ranks receive empty inputs
        x = torch.zeros((0, 2, 11, 13), dtype=torch.float32, device=self.rank)

        # input requires grad, this will trigger the collective communication
        # in the backward pass
        x.requires_grad = True
        self._test_not_nan(model, x)

        # input does not require grad
        x.requires_grad = False
        self._test_not_nan(model, x)

    @skip_if_lt_x_gpu(2)
    def test_sync_batch_norm_empty_input(self):
        pg = self._get_process_group()

        model = torch.nn.Sequential(
            nn.Conv2d(2, 2, 3),
            nn.BatchNorm2d(2),
            nn.Linear(28, 2),
        ).to(device=self.rank)
        model = DistributedDataParallel(
            model,
            device_ids=[self.rank],
            process_group=pg,
        )
        model = nn.SyncBatchNorm.convert_sync_batchnorm(
            model,
            process_group=pg,
        )

        model.train()
        # only rank 0 receives empty inputs
        x = torch.zeros(
            (3 if self.rank != 0 else 0, 2, 30, 30),
            dtype=torch.float32,
            device=self.rank,
        )

        self._test_not_nan(model, x)

        # all ranks receive empty inputs
        x = torch.zeros((0, 2, 30, 30), dtype=torch.float32, device=self.rank)

        self._test_not_nan(model, x)

    @dataclass
    class CustomOutput:
        o1: Optional[torch.Tensor]
        o2: Dict[str, torch.Tensor]

    class DataclassOutputModule(nn.Module):
        def __init__(self, skip_o1):
            super().__init__()
            self.seq1 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(3)])
            self.relu = nn.ReLU()
            self.seq2 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(3)])
            self.skip_o1 = skip_o1

        def forward(self, x):
            o1 = None if self.skip_o1 else self.relu(self.seq1(x))
            o2 = {"a": self.seq2(x), "b": self.relu(self.seq2(x))}
            return CommonDistributedDataParallelTest.CustomOutput(o1=o1, o2=o2)

    def _test_dataclass_output(self, skip_o1):
        net_x = torch.cat([torch.ones(4, 10) * i for i in range(self.world_size)]).to(
            self.rank
        )
        ddp_x = torch.ones(4, 10, device=self.rank) * self.rank

        # use manual_seed to make sure local models start with the same values
        torch.manual_seed(0)
        net = self.DataclassOutputModule(skip_o1=skip_o1).to(self.rank)
        ddp = DistributedDataParallel(
            copy.deepcopy(net),
            device_ids=[self.rank],
            find_unused_parameters=True,
            static_graph=False,
            process_group=self._get_process_group(),
        )

        net_out = net(net_x)
        ddp_out = ddp(ddp_x)

        net_loss = F.mse_loss(
            net_out.o1 + net_out.o2["a"] + net_out.o2["b"]
            if not skip_o1
            else net_out.o2["a"] + net_out.o2["b"],
            torch.ones_like(net_out.o2["a"], device=self.rank),
        )
        ddp_loss = F.mse_loss(
            ddp_out.o1 + ddp_out.o2["a"] + ddp_out.o2["b"]
            if not skip_o1
            else ddp_out.o2["a"] + ddp_out.o2["b"],
            torch.ones_like(ddp_out.o2["a"], device=self.rank),
        )

        net_loss.backward()
        ddp_loss.backward()

        for p1, p2 in zip(net.parameters(), ddp.parameters()):
            if torch.is_tensor(p1.grad):
                self.assertTrue(p1.grad.allclose(p2.grad))
            else:
                self.assertEqual(p1.grad, p2.grad)

    @skip_if_lt_x_gpu(2)
    def test_dataclass_output(self):
        self._test_dataclass_output(skip_o1=False)

    @skip_if_lt_x_gpu(2)
    def test_dataclass_output_unused_param(self):
        self._test_dataclass_output(skip_o1=True)


class ComputeBucketAssignmentTest(TestCase):
    def test_single_limit_single_dtype(self):
        tensors = [
            torch.empty([100], dtype=torch.float),
            torch.empty([200], dtype=torch.float),
            torch.empty([100], dtype=torch.float),
            torch.empty([50], dtype=torch.float),
        ]
        result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
            tensors, [400]
        )
        self.assertTrue(all(size_lim == 400 for size_lim in per_bucket_size_limits))
        self.assertEqual([[0], [1], [2], [3]], result)

    def test_single_limit_multi_dtype(self):
        tensors = [
            torch.empty([50], dtype=torch.float),
            torch.empty([25], dtype=torch.double),
            torch.empty([50], dtype=torch.float),
            torch.empty([25], dtype=torch.double),
            torch.empty([50], dtype=torch.float),
            torch.empty([25], dtype=torch.double),
        ]
        result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
            tensors, [400]
        )
        self.assertTrue(all(size_lim == 400 for size_lim in per_bucket_size_limits))
        self.assertEqual([[0, 2], [1, 3], [4], [5]], result)
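
    # As the expectations above suggest, _compute_bucket_assignment_by_size
    # groups tensors of the same dtype and packs them into buckets in order
    # until the requested byte limit is reached; the returned
    # per_bucket_size_limits records the limit applied to each resulting bucket.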

    def test_multi_limit_single_dtype(self):
        tensors = [
            torch.empty([10], dtype=torch.float),
            torch.empty([10], dtype=torch.float),
            torch.empty([10], dtype=torch.float),
            torch.empty([10], dtype=torch.float),
        ]
        result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
            tensors, [40, 80]
        )
        self.assertEqual(per_bucket_size_limits, [40, 80, 80])
        self.assertEqual([[0], [1, 2], [3]], result)

    def test_multi_limit_multi_dtype(self):
        tensors = [
            torch.empty([50], dtype=torch.float),
            torch.empty([25], dtype=torch.double),
            torch.empty([50], dtype=torch.float),
            torch.empty([25], dtype=torch.double),
            torch.empty([50], dtype=torch.float),
            torch.empty([25], dtype=torch.double),
        ]
        result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
            tensors, [200, 400]
        )
        self.assertEqual([[0], [1], [2, 4], [3, 5]], result)
        self.assertEqual(per_bucket_size_limits, [200, 200, 400, 400])


class AbstractCommTest:
    @property
    def op_timeout_sec(self):
        return 1

    @property
    def world_size(self):
        return 2

    @property
    def device(self):
        self.fail("test subclass didn't override device")

    def _verify_sequence_number_across_pg(self, pg, verify_pg):
        seq_num = pg._get_sequence_number_for_group()
        obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
        # We use a separate pg to verify the sequence numbers, otherwise these
        # collectives will themselves increment the sequence number.
        dist.all_gather_object(obj_list, seq_num, group=verify_pg)
        self.assertEqual(len(set(obj_list)), 1)
        return obj_list[0]

    def _test_sequence_num_incremented(self, process_group, ranks):
        # verify initial sequence numbers. Use a distinct process group for
        # verification to keep counts as expected with respect to process_group.
        verify_pg = dist.new_group(
            ranks=ranks,
            backend="gloo",
        )
        assert dist.get_world_size(process_group) == dist.get_world_size(verify_pg)

        initial_num = (
            self._verify_sequence_number_across_pg(
                pg=process_group, verify_pg=verify_pg
            )
            if not c10d._rank_not_in_group(process_group)
            else -1
        )

        # Verify sequence numbers are appropriately incremented
        for i in range(10):
            t = torch.ones(1, device=torch.cuda.current_device())
            dist.all_reduce(t, group=process_group)
            if not c10d._rank_not_in_group(process_group):
                seq_num = self._verify_sequence_number_across_pg(
                    pg=process_group,
                    verify_pg=verify_pg,
                )
                self.assertEqual(initial_num + i + 1, seq_num)

        if dist.get_world_size(process_group) > 2:
            # Test when certain ranks don't call collectives
            if dist.get_rank(process_group) not in [0, 2]:
                dist.all_reduce(t, group=process_group, async_op=True)
            # Now ranks 0 and 2 should be lagging by 1.
            if not c10d._rank_not_in_group(process_group):
                seq_num = process_group._get_sequence_number_for_group()
                rank = dist.get_rank(process_group)
                obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
                dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg)
                rank_to_seq_num = dict(obj_list)
                self.assertEqual(len(set(rank_to_seq_num.values())), 2)
                self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2])
                expected_same = {
                    rank_to_seq_num[i]
                    for i in rank_to_seq_num.keys()
                    if i not in [0, 2]
                }
                self.assertEqual(len(expected_same), 1)
                self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])

    def _test_sequence_num_incremented_default_group(self, backend_name):
        torch.cuda.set_device(self.rank)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend_name,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        self._test_sequence_num_incremented(
            c10d._get_default_group(),
            ranks=list(range(dist.get_world_size())),
        )

    def _test_sequence_num_incremented_subgroup(self, backend_name):
        torch.cuda.set_device(self.rank)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend_name,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        subgroup_ranks = [0, 1, 2]
        subgroup = dist.new_group(subgroup_ranks)
        self._test_sequence_num_incremented(subgroup, subgroup_ranks)

    def _test_sequence_num_set_default_pg(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

        default_pg = c10d._get_default_group()
        seq_num = default_pg._get_sequence_number_for_group()
        obj_list = [None for _ in range(dist.get_world_size())]
        dist.all_gather_object(obj_list, seq_num)
        self.assertEqual(len(set(obj_list)), 1)

    def _test_sequence_num_set_new_group(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

        subgroup = dist.new_group([0, 1])

        if not c10d._rank_not_in_group(subgroup):
            subgroup_seq = subgroup._get_sequence_number_for_group()
            obj_list = [None for _ in range(dist.get_world_size(subgroup))]
            dist.all_gather_object(obj_list, subgroup_seq, group=subgroup)
            self.assertEqual(len(set(obj_list)), 1)

    def _test_warn_not_in_group(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        in_group_ranks = list(filter(lambda x: x % 2 == 0, range(self.world_size)))
        group = dist.new_group(in_group_ranks)

        x = torch.zeros(2, 2).cuda(self.rank)
        xs = [torch.zeros(2, 2).cuda(self.rank) for _ in range(len(in_group_ranks))]
        if self.rank not in in_group_ranks:
            msg = ".*{}.*does not belong to.*"
            with self.assertWarnsOnceRegex(UserWarning, msg.format("all_gather")):
                dist.all_gather(xs, x, group=group)
            with self.assertWarnsOnceRegex(UserWarning, msg.format("all_reduce")):
                dist.all_reduce(x, group=group)
            with self.assertWarnsOnceRegex(UserWarning, msg.format("barrier")):
                dist.barrier(group=group)
            with self.assertWarnsOnceRegex(UserWarning, msg.format("broadcast")):
                dist.broadcast(x, src=0, group=group)
        else:
            dist.all_gather(xs, x, group=group)
            dist.all_reduce(x, group=group)
            dist.barrier(group=group)
            dist.broadcast(x, src=0, group=group)

    def _test_rank_membership(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        self.assertTrue(self.world_size > 1)

        group = dist.new_group(ranks=[1])
        self.assertEqual(dist.get_group_rank(group, 1), 0)
        with self.assertRaisesRegex(ValueError, "not part of group"):
            dist.get_group_rank(group, 0)
        with self.assertRaisesRegex(ValueError, "not registered"):
            dist.get_group_rank(DummyProcessGroup(self.rank, self.world_size), 0)

        self.assertEqual(dist.get_global_rank(group, 0), 1)
        with self.assertRaisesRegex(ValueError, "not part of group"):
            dist.get_global_rank(group, 1)
        with self.assertRaisesRegex(ValueError, "not registered"):
            dist.get_global_rank(DummyProcessGroup(self.rank, self.world_size), 0)

        self.assertEqual(dist.get_process_group_ranks(group), [1])

    def _test_tensor_dtype_mismatch(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

        tensor = torch.ones(2, 2, device=self.device) * 7
        tensor_h = tensor.half()
        tensor_list = [
            torch.zeros(2, 2, device=self.device) for _ in range(self.world_size)
        ]
        tensor_list_h = list(tensor_list)
        tensor_list_h[1] = tensor_list_h[1].half()

        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
            dist.all_gather(tensor_list_h, tensor)

        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
            dist.all_gather(tensor_list, tensor_h)

        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
            dist.all_gather_coalesced([tensor_list_h], tensor_list)
            dist.all_gather_coalesced([tensor_list], tensor_list_h)

        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
            dist.all_reduce_coalesced(tensor_list_h)

        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
            dist.reduce_scatter(tensor, tensor_list_h)

        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
            dist.reduce_scatter(tensor_h, tensor_list)

        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
            dist.all_to_all_single(tensor_h, tensor)

        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
            dist.all_to_all(tensor_list_h, tensor_list)

        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
            dist.all_to_all(tensor_list, tensor_list_h)

        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
            dist.scatter(tensor, tensor_list_h)

        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
            dist.gather(tensor_h, tensor_list)

        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
            dist.gather(tensor, tensor_list_h)

        with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
            dist.scatter(tensor_h, tensor_list)

    def _test_tensor_dtype_complex(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

        tensor = torch.rand(2, device=self.device)
        tensor_c = torch.view_as_complex(tensor)
        tensor_list = [
            torch.rand(2, device=self.device) for _ in range(self.world_size)
        ]
        tensor_list_c = list(tensor_list)
        tensor_list_c[1] = torch.view_as_complex(tensor_list_c[1])

        dist.all_gather(tensor_list, tensor)
        dist.all_gather(tensor_list, tensor_c)
        dist.all_gather(tensor_list_c, tensor)
        dist.all_gather(tensor_list_c, tensor_c)

    def _test_bool_tensors(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        device = "cuda" if backend == "nccl" else "cpu"
        # test broadcast with bool tensors
        tensor = torch.tensor([1, 0, 0, 1], dtype=torch.bool, device=device)
        zeros = torch.tensor([0, 0, 0, 0], dtype=torch.bool, device=device)
        outensor = zeros if self.rank > 0 else tensor
        dist.broadcast(outensor, src=0)
        self.assertEqual(outensor, tensor)


# Variant of AbstractCommTest that expects world size of 4
class AbstractLargeCommTest:
    @property
    def op_timeout_sec(self):
        return 1

    @property
    def world_size(self):
        return 4

    @property
    def device(self):
        raise RuntimeError("Implement me")

    def _test_new_group_local_sync(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        rank = dist.get_rank()
        ranks_in = [rank, (rank + 2) % self.world_size]
        ranks_out = [i for i in range(self.world_size) if i not in ranks_in]
        self.assertIn(rank, ranks_in)
        self.assertNotIn(rank, ranks_out)

        self.assertIsNone(
            dist.new_group(ranks=ranks_out, use_local_synchronization=True)
        )

        new_pg = dist.new_group(ranks=ranks_in, use_local_synchronization=True)
        self.assertIsInstance(new_pg, dist.ProcessGroup)

        # PTD sorts ranks before creating the PG, so [3, 1] actually gets assigned ranks [1, 0]
        ranks_in.sort()
        self.assertEqual(dist.get_group_rank(new_pg, rank), ranks_in.index(rank))
        self.assertEqual(
            ranks_in,
            dist.get_process_group_ranks(new_pg),
            f"expecting {ranks_in} but got {dist.get_process_group_ranks(new_pg)}",
        )

    def _test_new_group_local_sync_sanity_check(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        rank = dist.get_rank()

        # split the world in 2 PGs
        rank = dist.get_rank()
        pg_idx = rank // 2
        ranks_in = [pg_idx * 2, pg_idx * 2 + 1]
        new_pg = dist.new_group(ranks=ranks_in, use_local_synchronization=True)

        input_tensor = torch.tensor([pg_idx, rank], device=self.device)
        output_tensor_list = [
            torch.tensor(
                [-1, -1],
                device=self.device,
            )
            for _ in range(new_pg.size())
        ]
        dist.all_gather(output_tensor_list, input_tensor, group=new_pg)

        expected = [
            torch.tensor([pg_idx, ranks_in[0]], device=self.device),
            torch.tensor([pg_idx, ranks_in[1]], device=self.device),
        ]
        self.assertEqual(output_tensor_list, expected)

    def _test_new_group_local_sync_duplicate_pg(self, backend):
        """
        We should support users creating multiple PGs with the same set of
        members, with no conflict in group names.
        """
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        rank = dist.get_rank()

        # split the world in 2 PGs
        rank = dist.get_rank()
        pg_idx = rank // 2
        ranks_in = [pg_idx * 2, pg_idx * 2 + 1]
        new_pgs = []
        for _ in range(2):
            new_pgs.append(
                dist.new_group(ranks=ranks_in, use_local_synchronization=True)
            )

        input_tensor = torch.tensor([pg_idx, rank], device=self.device)
        for new_pg in new_pgs:
            output_tensor_list = [
                torch.tensor(
                    [-1, -1],
                    device=self.device,
                )
                for _ in range(new_pg.size())
            ]
            dist.all_gather(output_tensor_list, input_tensor, group=new_pg)

            expected = [
                torch.tensor([pg_idx, ranks_in[0]], device=self.device),
                torch.tensor([pg_idx, ranks_in[1]], device=self.device),
            ]
            self.assertEqual(output_tensor_list, expected)


class CommTest(AbstractCommTest, MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def test_debug_level(self):
        try:
            del os.environ["TORCH_DISTRIBUTED_DEBUG"]
        except KeyError:
            pass

        dist.set_debug_level_from_env()
        # Default should be off
        default_debug_mode = dist.get_debug_level()
        self.assertEqual(default_debug_mode, dist.DebugLevel.OFF)
        mapping = {
            "OFF": dist.DebugLevel.OFF,
            "off": dist.DebugLevel.OFF,
            "oFf": dist.DebugLevel.OFF,
            "INFO": dist.DebugLevel.INFO,
            "info": dist.DebugLevel.INFO,
            "INfO": dist.DebugLevel.INFO,
            "DETAIL": dist.DebugLevel.DETAIL,
            "detail": dist.DebugLevel.DETAIL,
            "DeTaIl": dist.DebugLevel.DETAIL,
        }
        invalid_debug_modes = ["foo", 0, 1, -1]

        for mode in mapping.keys():
            os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
            dist.set_debug_level_from_env()
            set_debug_mode = dist.get_debug_level()
            self.assertEqual(
                set_debug_mode,
                mapping[mode],
                f"Expected {mode} to map to {mapping[mode]} but got {set_debug_mode}",
            )

        for mode in invalid_debug_modes:
            os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
            with self.assertRaisesRegex(
                ValueError, "The value of TORCH_DISTRIBUTED_DEBUG must"
            ):
                dist.set_debug_level_from_env()


class DummyWork(dist._Work):
    def wait(self, timeout=5.0):
        if torch.cuda.is_available():
            torch.cuda.current_stream().synchronize()
        return True


class DummyProcessGroup(dist.ProcessGroup):
    def getBackendName(self):
        return "Dummy"

    def allgather(self, output_tensor_lists, input_tensor_list, opts=None):
        for output_tensor_list, input_tensor in zip(
            output_tensor_lists, input_tensor_list
        ):
            for output_tensor in output_tensor_list:
                output_tensor.copy_(input_tensor)

        return DummyWork()

    def allreduce(self, tensor_list, opts=None):
        for tensor in tensor_list:
            tensor.add_(2)

        return DummyWork()

    def barrier(self, opts=None):
        store = c10d._get_default_store()
        key = "TEST:DummyProcessGroup:barrier"
        if self.rank() == 0:
            worker_count = 0
            # By default, TCPServer lives on rank 0. So rank 0 needs to make
            # sure that it does not exit too early before other ranks finish
            # using the store.
            # Note that, _store_based_barrier does not solve this problem, as
            # all ranks need to run at least one store.add(key, 0) before
            # exiting, but there is no guarantee that rank 0 is still alive at
            # that point.
            while worker_count < self.size() - 1:
                worker_count = store.add(key, 0)
        else:
            store.add(key, 1)

        return DummyWork()

    def broadcast(self, tensor_list, opts=None):
        for tensor in tensor_list:
            tensor.add_(1)

        return DummyWork()

    def reduce_scatter(self, output_tensor_list, input_tensor_lists, opts=None):
        for output_tensor, input_tensor_list in zip(
            output_tensor_list, input_tensor_lists
        ):
            output_tensor.copy_(input_tensor_list[self.rank()])

        return DummyWork()

    def send(self, tensor_list, dst, tag=0):
        for tensor in tensor_list:
            tensor.add_(1)

        return DummyWork()

    def recv(self, tensor_list, src, tag=0):
        for tensor in tensor_list:
            tensor.add_(2)

        return DummyWork()
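

# Note: a third-party backend like DummyProcessGroup is made visible to
# torch.distributed by registering a creator function and then passing the
# backend name to init_process_group, e.g.
#   dist.Backend.register_backend("dummy", PythonProcessGroupExtensionTest.create_dummy)
#   dist.init_process_group("dummy", rank=rank, world_size=world_size)
# PythonProcessGroupExtensionTest below exercises exactly this flow.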


class PythonProcessGroupExtensionTest(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def test_get_backend_name(self):
        dpg = DummyProcessGroup(0, 1)
        self.assertEqual("Dummy", dpg.name())

    def test_backend_class_attr(self):
        dist.Backend.register_backend(
            "dummy", PythonProcessGroupExtensionTest.create_dummy
        )
        self.assertEqual(dist.Backend.DUMMY, "dummy")
        self.assertEqual(
            dist.Backend._plugins["DUMMY"].creator_fn,
            PythonProcessGroupExtensionTest.create_dummy,
        )

    def test_is_backend_available(self):
        self.assertEqual(dist.is_ucc_available(), dist.is_backend_available("ucc"))
        self.assertFalse(dist.is_backend_available("dummy"))
        dist.Backend.register_backend(
            "dummy", PythonProcessGroupExtensionTest.create_dummy
        )
        self.assertTrue(dist.is_backend_available("dummy"))

    def test_backend_config(self):
        dist.Backend.register_backend(
            "dummy", PythonProcessGroupExtensionTest.create_dummy
        )

        # Ensure backend config can be created with the following arguments
        backend_config_strings_and_expected_values = [
            (dist.Backend.GLOO, "cpu:gloo,cuda:gloo"),
            (dist.Backend.NCCL, "cuda:nccl"),
            (dist.Backend.MPI, "cpu:mpi,cuda:mpi"),
            (dist.Backend.UCC, "cpu:ucc,cuda:ucc"),
            (dist.Backend.DUMMY, "cpu:dummy,cuda:dummy"),
            ("DUMMY", "cpu:dummy,cuda:dummy"),
            ("dummy", "cpu:dummy,cuda:dummy"),
            ("cpu:dummy,cuda:dummy", "cpu:dummy,cuda:dummy"),
            ("cpu:dummy,cuda:nccl", "cpu:dummy,cuda:nccl"),
            ("cpu:gloo,cuda:dummy", "cpu:gloo,cuda:dummy"),
            ("cpu:gloo,cuda:nccl", "cpu:gloo,cuda:nccl"),
        ]

        for config_str, expected_value in backend_config_strings_and_expected_values:
            with self.subTest(config_str):
                # ensures these config strings are valid and no ValueError is raised
                config = dist.BackendConfig(config_str)
                self.assertEqual(str(config), expected_value)

        # Ensure backend config will raise ValueError with the following arguments
        invalid_backend_config_strings = [
            "cpu:gloo,cuda:nccl,",  # trailing comma
            "cpu:gloo,cuda:nccl,cpu:dummy",  # duplicate device
        ]
        for config_str in invalid_backend_config_strings:
            with self.subTest(config_str):
                with self.assertRaises(ValueError):
                    dist.BackendConfig(config_str)

    def test_init_process_group_with_multiple_backends(self):
        dist.Backend.register_backend(
            "dummy", PythonProcessGroupExtensionTest.create_dummy
        )

        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "6789"
        dist.init_process_group(
            "cpu:dummy,cuda:dummy", rank=self.rank, world_size=self.world_size
        )

        # test all_gather
        input_tensor = torch.ones(2, 2) * 7
        output_tensor_list = [torch.zeros(2, 2) for _ in range(self.world_size)]
        dist.all_gather(output_tensor_list, input_tensor)

        dist.barrier()
        dist.destroy_process_group()

    class Options:
        def __init__(self) -> None:
            pass

        def create(self):
            pass

    @staticmethod
    def create_dummy(store, group_rank, group_size, timeout):
        return DummyProcessGroup(group_rank, group_size)

    def test_collectives(self):
        dist.Backend.register_backend(
            "dummy", PythonProcessGroupExtensionTest.create_dummy
        )

        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "6789"
        dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)

        # test all_gather
        input_tensor = torch.ones(2, 2) * 7
        output_tensor_list = [torch.zeros(2, 2) for _ in range(self.world_size)]
        dist.all_gather(output_tensor_list, input_tensor)

        for tensor in output_tensor_list:
            self.assertEqual(tensor, input_tensor)

        # test all_reduce
        input_tensor = torch.ones(2, 2) * 7
        dist.all_reduce(input_tensor)
        self.assertEqual(input_tensor, torch.ones(2, 2) * 7 + 2)

        # test broadcast
        input_tensor = torch.zeros(2, 2)
        dist.broadcast(input_tensor, 0, async_op=True).wait()
        self.assertEqual(torch.ones(2, 2), input_tensor)

        # test reduce_scatter
        output_tensor = torch.zeros(2, 2)
        input_tensor_list = [torch.ones(2, 2) for _ in range(self.world_size)]
        dist.reduce_scatter(output_tensor, input_tensor_list)
        self.assertEqual(output_tensor, torch.zeros(2, 2) + 1)

        dist.barrier()
        dist.destroy_process_group()

    def test_send_recv(self):
        dist.Backend.register_backend(
            "dummy", PythonProcessGroupExtensionTest.create_dummy
        )

        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "6789"
        dist.init_process_group("dummy", rank=self.rank, world_size=self.world_size)

        # test send
        input_tensor = torch.zeros(2, 2)
        dist.send(input_tensor, (self.rank + 1) % self.world_size)
        self.assertEqual(input_tensor, torch.zeros(2, 2) + 1)

        with self.assertRaises(ValueError):
            dist.send(input_tensor, dist.get_rank())

        # test recv
        input_tensor = torch.zeros(2, 2)
        dist.recv(input_tensor, (self.rank + 1) % self.world_size)
        self.assertEqual(input_tensor, torch.zeros(2, 2) + 2)

        dist.barrier()
        # intentionally not calling into `destroy_process_group` as not all
        # user applications would explicitly call it.


instantiate_parametrized_tests(CommonDistributedDataParallelTest)


class ProcessGroupWithDispatchedCollectivesTests(MultiProcessTestCase):
    @property
    def world_size(self):
        return 1

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def test_init_process_group_optional_backend(self):
        with tempfile.NamedTemporaryFile(delete=False) as f:
            store = dist.FileStore(f.name, self.world_size)
            # creates both the gloo and nccl backends
            if dist.is_gloo_available() and dist.is_nccl_available():
                dist.init_process_group(
                    store=store,
                    rank=self.rank,
                    world_size=self.world_size,
                )
                dist.destroy_process_group()

    def test_init_process_group_for_all_backends(self):
        for backend in dist.Backend.backend_list:
            # skip if the backend is not available on the system
            if backend == dist.Backend.UNDEFINED:
                continue
            elif backend == dist.Backend.MPI:
                if not dist.is_mpi_available():
                    continue
            elif backend == dist.Backend.NCCL:
                if not dist.is_nccl_available() or not torch.cuda.is_available():
                    continue
            elif backend == dist.Backend.GLOO:
                if not dist.is_gloo_available():
                    continue
            elif backend == dist.Backend.UCC:
                if not dist.is_ucc_available():
                    continue

            with tempfile.NamedTemporaryFile(delete=False) as f:
                store = dist.FileStore(f.name, self.world_size)
                dist.init_process_group(
                    backend=backend,
                    rank=self.rank,
                    world_size=self.world_size,
                    store=store,
                )
                pg = c10d._get_default_group()
                self.assertEqual(pg.rank(), self.rank)
                self.assertEqual(pg.size(), self.world_size)
                self.assertEqual(pg.name(), str(backend))

                dist.destroy_process_group()
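
    # The underscore-prefixed helpers below are backend-agnostic building
    # blocks: they are not collected directly and are meant to be driven with
    # a concrete backend string so the dispatcher path for that backend gets
    # exercised on cpu (and cuda for nccl). See the commented sketch after
    # this class for one possible usage pattern.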
    def _call_collective_with_varying_tensors(self, backend, collective, *args):
        # call collective with varying tensors to ensure that the tensors are
        # correctly dispatched

        # TODO: this will be updated in the future to not be backend specific
        device = "cuda" if backend == "nccl" else "cpu"
        # ensure supported devices (cpu, cuda) succeed during the dispatch call
        tensor = torch.zeros(2, 2, device=torch.device(device))
        # multi-tensor collectives
        if collective == dist.barrier:
            collective()
        elif collective in (dist.all_gather, dist.gather):
            collective([tensor], tensor, *args)
        elif collective == dist.scatter:
            collective(tensor, [tensor], *args)
        elif collective in (dist.reduce_scatter, dist.all_to_all):
            # gloo does not support reduce_scatter or all_to_all
            if backend != "gloo":
                if collective == dist.reduce_scatter:
                    collective(tensor, [tensor], *args)
                else:
                    collective([tensor], [tensor], *args)
        else:
            collective(tensor, *args)

    # TODO: backend will be replaced with a non-specified backend
    def _test_collectives(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        collectives_and_args = [
            (dist.reduce, self.rank),
            (dist.broadcast, self.rank),
            (dist.all_reduce,),
            (dist.all_gather,),
            (dist.reduce_scatter,),
            (dist.barrier,),
            (dist.all_to_all,),
            (dist.scatter,),
        ]
        for collective, *args in collectives_and_args:
            with self.subTest(collective=collective, args=args):
                self._call_collective_with_varying_tensors(backend, collective, *args)

    def _test_allreduce_coalesced(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        # TODO: this will be updated in the future to not be backend specific
        device = "cuda" if backend == "nccl" else "cpu"
        tensors = [torch.ones(10, 10, device=torch.device(device))]
        dist.all_reduce_coalesced(tensors, dist.ReduceOp.SUM)
        for tensor in tensors:
            self.assertEqual(tensor, torch.ones(10, 10) * self.world_size)

    def _test_all_to_all_single(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        device = "cuda" if backend == "nccl" else "cpu"
        # test alltoall_base
        input_tensor = torch.ones(2, 2, device=torch.device(device))
        output_tensor = torch.zeros(2, 2, device=torch.device(device))
        dist.all_to_all_single(output_tensor, input_tensor)
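

# A hedged sketch of how a backend-specific suite could drive the helpers in
# ProcessGroupWithDispatchedCollectivesTests; kept as a comment so nothing new
# is collected here. The class name and the "gloo" backend are illustrative
# assumptions only:
#
#   class GlooDispatchedCollectivesTests(ProcessGroupWithDispatchedCollectivesTests):
#       def test_collectives(self):
#           self._test_collectives("gloo")
#
#       def test_allreduce_coalesced(self):
#           self._test_allreduce_coalesced("gloo")
#
#       def test_all_to_all_single(self):
#           self._test_all_to_all_single("gloo")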


class ReduceOpTest(TestCase):
    # Ref: https://github.com/pytorch/pytorch/issues/87191
    def test_op_isinstance_of_reduceop(self):
        for reduce_op in (
            c10d.ReduceOp.SUM,
            c10d.ReduceOp.AVG,
            c10d.ReduceOp.PRODUCT,
            c10d.ReduceOp.MIN,
            c10d.ReduceOp.MAX,
            c10d.ReduceOp.BAND,
            c10d.ReduceOp.BOR,
            c10d.ReduceOp.BXOR,
        ):
            self.assertTrue(isinstance(reduce_op, c10d.ReduceOp))
        for scale in (torch.tensor(1.0), 2.0):
            self.assertTrue(
                isinstance(dist._make_nccl_premul_sum(scale), c10d.ReduceOp)
            )

    # Ref: https://github.com/pytorch/pytorch/pull/87303#discussion_r1002879700
    def test_reduceop_copyable(self):
        for reduce_op in (
            c10d.ReduceOp.SUM,
            c10d.ReduceOp.AVG,
            c10d.ReduceOp.PRODUCT,
            c10d.ReduceOp.MIN,
            c10d.ReduceOp.MAX,
            c10d.ReduceOp.BAND,
            c10d.ReduceOp.BOR,
            c10d.ReduceOp.BXOR,
        ):
            self.assertEqual(copy.copy(reduce_op), reduce_op)
            self.assertEqual(copy.deepcopy(reduce_op), reduce_op)
            self.assertEqual(copy.copy(c10d.ReduceOp(reduce_op)), reduce_op)
            self.assertEqual(copy.deepcopy(c10d.ReduceOp(reduce_op)), reduce_op)

        for scale in (torch.tensor(1.0), 2.0):
            reduce_op = dist._make_nccl_premul_sum(scale)
            self.assertEqual(copy.copy(reduce_op), reduce_op)
            self.assertEqual(copy.deepcopy(reduce_op), reduce_op)

    def test_reduceop_pickle(self):
        for reduce_op in (
            c10d.ReduceOp.SUM,
            c10d.ReduceOp.AVG,
            c10d.ReduceOp.PRODUCT,
            c10d.ReduceOp.MIN,
            c10d.ReduceOp.MAX,
            c10d.ReduceOp.BAND,
            c10d.ReduceOp.BOR,
            c10d.ReduceOp.BXOR,
        ):
            pickle.loads(pickle.dumps(reduce_op))
            orig = c10d.ReduceOp(reduce_op)
            self.assertEqual(pickle.loads(pickle.dumps(orig)), orig)
        for scale in (torch.tensor(1.0), 2.0):
            reduce_op = dist._make_nccl_premul_sum(scale)
            self.assertEqual(pickle.loads(pickle.dumps(reduce_op)), reduce_op)

    # Ref: https://github.com/pytorch/pytorch/issues/90072
    def test_reduceop_equal(self):
        not_reduceop = "abc"
        for reduce_op in (
            c10d.ReduceOp.SUM,
            c10d.ReduceOp.AVG,
            c10d.ReduceOp.PRODUCT,
            c10d.ReduceOp.MIN,
            c10d.ReduceOp.MAX,
            c10d.ReduceOp.BAND,
            c10d.ReduceOp.BOR,
            c10d.ReduceOp.BXOR,
        ):
            reduce_op_obj = c10d.ReduceOp(reduce_op)
            # this calls `ReduceOp.__eq__(self, other)`
            self.assertEqual(reduce_op_obj, reduce_op_obj)
            self.assertEqual(reduce_op_obj, reduce_op)
            self.assertNotEqual(reduce_op_obj, not_reduceop)
            self.assertNotEqual(reduce_op, not_reduceop)
            # TODO(crcrpar): This should be `assertEqual` so that equality is
            # symmetric, even though comparing a `RedOpType` against a
            # `ReduceOp` is less likely than comparing a `ReduceOp` against a
            # `RedOpType`.
            # this calls `RedOpType.__eq__(self, other)`
            self.assertNotEqual(reduce_op, reduce_op_obj)

            self.assertFalse(None in (reduce_op, reduce_op_obj))
            self.assertFalse(not_reduceop in (reduce_op, reduce_op_obj))


class LocalRankTest(MultiProcessTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def testWithoutEnv(self):
        with self.assertRaisesRegex(RuntimeError, "LOCAL_RANK"):
            dist.get_node_local_rank()

    def testWithoutEnvWithFallback(self):
        self.assertEqual(dist.get_node_local_rank(fallback_rank=2), 2)

    def testNodeLocalRankOverridesFallback(self):
        os.environ["LOCAL_RANK"] = str(self.rank)
        self.assertEqual(dist.get_node_local_rank(fallback_rank=123), self.rank)

    def testNodeLocalRank(self):
        os.environ["LOCAL_RANK"] = str(self.rank)
        self.assertEqual(dist.get_node_local_rank(), self.rank)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()