# Owner(s): ["oncall: distributed"]

import os
import sys
import unittest
from functools import partial, wraps

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
import torch.distributed._tensor as dt
import torch.distributed.distributed_c10d as c10d
from functorch import make_fx
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._triton import has_triton


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    MultiThreadedTestCase,
    requires_nccl,
    TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


def new_subgroups(group_size: int, pg_tag=None):
    """Split the world into contiguous subgroups of `group_size` ranks.

    Returns (the subgroup containing the calling rank, the list of all subgroups).
    """
    world_size = dist.get_world_size()
    subgroups = []
    cur_subgroup = None

    for subgroup_id in range(world_size // group_size):
        start_rank = subgroup_id * group_size
        end_rank = start_rank + group_size
        ranks_in_subgroup = list(range(start_rank, end_rank))
        subgroup = c10d._new_group_with_tag(
            ranks=ranks_in_subgroup,
            pg_tag=pg_tag,
        )
        subgroups.append(subgroup)

        rank = dist.get_rank()
        if rank in ranks_in_subgroup:
            cur_subgroup = subgroup

    return cur_subgroup, subgroups


class TestExpand(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_expand_1d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3], "bla")
        self.assertEqual("bla", tag)

    def test_expand_2d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]], "blu")
        self.assertEqual("blu", tag)

        with self.assertRaisesRegex(ValueError, "group sizes must be identical"):
            ft_c._expand_group([[0], [1, 2, 3]])

    def test_expand_process_group(self):
        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD)
        self.assertEqual(c10d._get_group_tag(dist.group.WORLD), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla")
        self.assertEqual("bla", tag)

        my_pg, others = new_subgroups(group_size=2)
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual(c10d._get_group_tag(my_pg), tag)
        self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
        self.assertEqual(2, group_size)

        my_pg = None
        for i in range(dist.get_world_size()):
            group = c10d._new_group_with_tag([i], pg_tag="my_pg")
            if i == dist.get_rank():
                my_pg = group
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual("my_pg", tag)
        self.assertEqual([dist.get_rank()], rankset)
        self.assertEqual(1, group_size)

        tag, rankset, group_size = ft_c._expand_group(my_pg, "bla")
        self.assertEqual("bla", tag)

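    # A hedged sketch of the _expand_group contract exercised by the tests in this
    # class (inferred from the assertions above; not an authoritative spec): any
    # group-like input is normalized into a (tag, rankset, group_size) triple, e.g.
    #   ft_c._expand_group([0, 1, 2, 3])      -> ("", [0, 1, 2, 3], 4)
    #   ft_c._expand_group([[0, 1], [2, 3]])  -> ("", [0, 1, 2, 3], 2)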
    def test_expand_device_mesh(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

    def test_expand_device_mesh_tuple(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
        with self.assertRaisesRegex(AssertionError, "Only 1D mesh"):
            tag, rankset, group_size = ft_c._expand_group(mesh)

        tag, rankset, group_size = ft_c._expand_group((mesh, 0))
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        expected_rankset = [0, 2] if dist.get_rank() in [0, 2] else [1, 3]
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group((mesh, 1))
        expected_rankset = [0, 1] if dist.get_rank() in [0, 1] else [2, 3]
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=1)), tag)
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)


class TestPgTag(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    """
    The behavior we want is as follows:

    - rankset+tag will always result in the same PG.
      Do we enforce this by failing creation of new PGs or returning existing ones?
      Return the existing one.

    - default tag gives the existing behavior.
      This means we should create duplicates.
    - _expand_group on a default-tagged pg should always resolve to it.
      This means we can't depend on empty tag + rankset.
    """
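    # A minimal illustration of the semantics described above, mirroring the
    # assertions in the tests below (a sketch only, not extra coverage):
    #   pg_a, _ = new_subgroups(group_size=2, pg_tag="blu")
    #   pg_b, _ = new_subgroups(group_size=2, pg_tag="blu")  # same rankset + tag
    #   pg_a == pg_b                                         # existing PG is returned
    #   pg_c, _ = new_subgroups(group_size=2)                # default tag
    #   pg_d, _ = new_subgroups(group_size=2)
    #   pg_c != pg_d                                         # duplicates are created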

    def test_pg_creation_with_tag(self):
        my_group, _ = new_subgroups(group_size=2, pg_tag="blu")
        my_group2, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(my_group, my_group2)

        my_group3, _ = new_subgroups(group_size=2, pg_tag="blu2")
        self.assertNotEqual(my_group, my_group3)

        my_group4, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group, my_group4)

        my_group5, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group4, my_group5)

    def test_pg_lookup_roundtrip(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="blu2")
        pg_notag0, _ = new_subgroups(group_size=2)
        pg_notag1, _ = new_subgroups(group_size=2)

        def roundtrip(pg):
            tag, rankset, _ = ft_c._expand_group(pg)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag0))
        self.assertEqual(pg_tag1, roundtrip(pg_tag1))
        self.assertEqual(pg_notag0, roundtrip(pg_notag0))
        self.assertEqual(pg_notag1, roundtrip(pg_notag1))

    def test_pg_lookup_with_tag(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="bla")
        pg_notag0, _ = new_subgroups(group_size=2)

        def roundtrip(pg, pg_tag):
            tag, rankset, _ = ft_c._expand_group(pg, pg_tag)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag1, "blu"))
        self.assertEqual(pg_tag0, roundtrip(pg_notag0, "blu"))
        # Cannot erase the tag of a PG
        self.assertEqual(pg_tag0, roundtrip(pg_tag0, ""))

    def test_find_or_create_pg(self):
        pg = c10d._find_or_create_pg_by_ranks_and_tag("blu", [0, 1, 2, 3], 2)
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(pg, pg_tag0)

    def test_find_root_pg(self):
        pg = c10d._find_pg_by_ranks_and_tag("", [0, 1, 2, 3])
        self.assertEqual(dist.group.WORLD, pg)


@instantiate_parametrized_tests
class TestTraceableCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    @parametrize("device", ["cpu", "cuda"])
    def test_broadcast(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        if dist.get_rank() == 0:
            tensor = torch.ones([4], device=device)
        else:
            tensor = torch.zeros([4], device=device)

        mesh = dt.DeviceMesh(device, torch.arange(4))
        res = ft_c.broadcast(tensor, 0, mesh)
        self.assertEqual(res, torch.ones([4], device=device))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_reduce_eager(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        tensor = torch.ones([4], device=device)
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_reduce(tensor, "sum", mesh)
        self.assertEqual(res, torch.tensor([4, 4, 4, 4], dtype=torch.float))

        # reducing over one dim of a 2x2 mesh only sums across the 2 ranks in that dim
        mesh = dt.DeviceMesh(device, torch.arange(4).view(2, 2))
        res2 = ft_c.all_reduce(tensor, "sum", (mesh, 1))
        self.assertEqual(res2, torch.tensor([2, 2, 2, 2], dtype=torch.float))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_reduce_coalesced_eager(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        t0 = torch.ones([4], device=device)
        t1 = torch.ones([6], device=device) + 2
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_reduce_coalesced([t0, t1], "sum", mesh)
        self.assertEqual(res[0], t0 * 4)
        self.assertEqual(res[1], t1 * 4)

    @parametrize("device", ["cpu", "cuda"])
    def test_all_gather_tensor(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        # testing 1d/2d mesh
        mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
        mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
        for mesh in [mesh_1d, mesh_2d]:
            dims_to_gather = [0, 1, 2]
            for dim in dims_to_gather:
                output_size = [3, 3, 3]
                output_size[dim] *= mesh.size(0)
                # each rank has its own tensor, all_gather gives a bigger tensor
                local_tensor = torch.ones([3, 3, 3], device=device)
                gathered_tensor = ft_c.all_gather_tensor(
                    local_tensor, gather_dim=dim, group=(mesh, 0)
                )
                self.assertEqual(gathered_tensor, torch.ones(output_size))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_gather_into_tensor_coalesced(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        tensors = [torch.ones([4], device=device), torch.ones([4], device=device) + 1]
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(
            torch.ones([4 * dist.get_world_size()], device=device), res[0]
        )
        self.assertEqual(
            torch.ones([4 * dist.get_world_size()], device=device) + 1, res[1]
        )

    @parametrize("device", ["cpu", "cuda"])
    def test_reduce_scatter_tensor(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        # testing 1d/2d mesh
        mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
        mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
        for mesh in [mesh_1d, mesh_2d]:
            dims_to_scatter = [0, 1]
            for dim in dims_to_scatter:
                group_size = mesh.size(0)
                input_size = [3, 3]
                output_size = [3, 3]
                output_size[dim] *= group_size
                input_tensor = torch.ones(output_size, device=device)
                res_num = 1 * group_size
                rs_tensor = ft_c.reduce_scatter_tensor(
                    input_tensor, "sum", scatter_dim=dim, group=(mesh, 0)
                )
                self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)

    @parametrize("device", ["cpu", "cuda"])
    def test_reduce_scatter_into_tensor_coalesced(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())
        tensors = [
            torch.ones([4], dtype=torch.int64, device=device),
            torch.ones([4], dtype=torch.int64, device=device) + 1,
        ]
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.reduce_scatter_tensor_coalesced(tensors, "sum", [0, 0], mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.tensor([4], device=device), res[0])
        self.assertEqual(torch.tensor([8], device=device), res[1])


class TestMetaCollectives(TestCase):
    def test_all_reduce(self):
        x = torch.rand((2, 3, 4), device="meta")
        out = ft_c.all_reduce(x, "sum", "0")
        self.assertEqual(x.size(), out.size())


class TestGradCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_all_reduce(self):
        x = torch.rand([4], requires_grad=True)
        y = torch.rand([4], requires_grad=True)
        out = ft_c.all_reduce(x, "sum", dist.group.WORLD)
        (out + y).sum().backward()
        self.assertIsNone(x.grad)


class TestMakeFx(TestCase):
    def setUp(self):
        # make_fx is not thread-safe due to patching and mutating global states,
        # so create a fake_pg.
        self.rank = 0
        self.world_size = 2
        store = FakeStore()
        dist.init_process_group(
            backend="fake",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

    def tearDown(self):
        super().tearDown()

        self.assertFalse(torch.fx._symbolic_trace.is_fx_tracing())

    def test_all_reduce_tracing(self):
        def allred(input):
            return ft_c.all_reduce(input, "sum", group=dist.group.WORLD) + 1

        graph = make_fx(allred)(torch.rand(4))
        FileCheck().check("all_reduce").check("wait_tensor").run(str(graph.graph))

        mesh = dt.DeviceMesh("cpu", torch.arange(self.world_size))

        def allred_mesh(input):
            return ft_c.all_reduce(input, "sum", mesh) + 1

        mesh_graph = make_fx(allred_mesh)(torch.rand(4))
        FileCheck().check_not("get_attr").check("wait_tensor").run(
            str(mesh_graph.graph)
        )

        def allred_mesh_dim(input):
            return ft_c.all_reduce(input, "sum", (mesh, 0)) + 1

        mesh_dim_graph = make_fx(allred_mesh_dim)(torch.rand(4))
        FileCheck().check_not("get_attr").check("wait_tensor").run(
            str(mesh_dim_graph.graph)
        )


BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = 2


def exit_if_lt_x_gpu(x):
    if torch.cuda.device_count() < x:
        sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)


def with_comms(func=None):
    """Test decorator: initialize the process group (honoring the BACKEND env var),
    run the test, then tear the group down; skips when NCCL is requested but there
    are not enough GPUs."""
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        global BACKEND

        if "BACKEND" in os.environ:
            BACKEND = os.environ["BACKEND"]
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        func(self)
        self.destroy_comms()

    return wrapper


class TestCollectivesWithNCCL(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = dist.Backend.NCCL
        BACKEND = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank)

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    @requires_nccl()
    @with_comms()
    def test_all_gather_into_tensor_coalesced(self):
        exit_if_lt_x_gpu(self.world_size)

        tensors = [
            torch.ones([4], device=f"cuda:{self.rank}"),
            torch.ones([4], device=f"cuda:{self.rank}") + 1,
        ]
        mesh = dt.DeviceMesh(f"cuda:{self.rank}", torch.arange(self.world_size))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()]), res[0])
        self.assertEqual(torch.ones([4 * dist.get_world_size()]) + 1, res[1])

    @with_comms()
    def test_all_to_all_single(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        row = self.world_size * (rank + 1) * (self.world_size + 1) / 2
        x = torch.ones(int(row), 5, device=device) * (rank + 1)
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]
        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_1d_input(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        row = self.world_size * (rank + 1) * (self.world_size + 1) / 2
        x = torch.ones(int(row), device=device) * (rank + 1)
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]
        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_split_sizes_none(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        x = torch.ones(self.world_size, self.world_size, device=device) * (rank + 1)
        y = ft_c.all_to_all_single(
            x, output_split_sizes=None, input_split_sizes=None, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.chunk(x, self.world_size)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @requires_nccl()
    @with_comms()
    def test_tracing(self):
        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        compiled_allreduce = torch.compile(allreduce, fullgraph=True)
        compiled_allreduce(torch.randn(8, device=self.device), self.process_group)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_tracing_with_fakepg(self):
        exit_if_lt_x_gpu(self.world_size)

        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        compiled_allreduce = torch.compile(allreduce, fullgraph=True)
        dist.init_process_group(
            backend="fake",
            rank=0,
            world_size=8,
            store=FakeStore(),
        )
        allreduce(torch.randn(8, device=self.device), pg=dist.group.WORLD)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @requires_nccl()
    @with_comms()
    def test_tracing_with_dce_code(self):
        if self.world_size > 2:
            return

        def func(batch, group, rank):
            ret = ft_c.permute_tensor(batch, [1, 0], group)
            if hasattr(ret, "wait"):
                ret = ret.wait()
            if rank == 0:
                return ret
            else:
                return batch * 5

        compiled_func = torch.compile(func)
        ret = compiled_func(
            torch.ones((100,), device="cuda"), self.process_group, self.rank
        )
        dist.barrier()


class TestNCCLCollectivesWithWorldSize4(TestCollectivesWithNCCL):
    @property
    def world_size(self):
        return 4

    @requires_nccl()
    @with_comms()
    def test_permute_tensor_with_sub_group(self):
        exit_if_lt_x_gpu(self.world_size)

        device = "cuda"
        mesh_dim_names = ["dp", "tp"]

        mesh_2d = dt.init_device_mesh(
            device, (2, self.world_size // 2), mesh_dim_names=mesh_dim_names
        )

        for mesh_name in mesh_dim_names:
            mesh = mesh_2d[mesh_name]
            rank = mesh.get_local_rank()

            # rank0: [0., 1.], rank1: [2., 3.]
            send_tensor = torch.arange(2, dtype=torch.float32, device=device) + 2 * rank
            recvd_tensor = ft_c.permute_tensor(send_tensor, [1, 0], group=mesh)

            # rank0: [2., 3.], rank1: [0., 1.]
            expected = torch.arange(2, dtype=torch.float32, device=device) + 2 * (
                (rank - 1 + 2) % 2
            )
            self.assertEqual(
                recvd_tensor,
                expected,
                msg=f"Expected {expected} on {self.rank=} (local_rank={rank}), "
                f"but received {recvd_tensor} instead.",
            )


@instantiate_parametrized_tests
class TestFunctionalAutograd(MultiThreadedTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_threads()

    @property
    def world_size(self):
        return 2

    @parametrize("compile", [True, False])
    def test_all_to_all_single(self, compile: bool = True) -> None:
        group = dist.group.WORLD.group_name

        t = torch.ones((self.world_size, 2), requires_grad=True)

        def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
            sizes = [1] * world_size
            t = t * 2
            assert t.requires_grad
            out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
            out = out + 0
            return out

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        out = compiled(t, self.world_size)
        self.assertEqual(out.shape, t.shape)
        self.assertEqual(out, torch.full_like(t, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)
        loss = out.sum()
        loss.backward()
        self.assertEqual(t.grad, torch.full_like(t, 2.0))

    def test_all_to_all_single_inductor(self) -> None:
        group = dist.group.WORLD.group_name

        t = torch.rand((self.world_size, 2), requires_grad=True)

        def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
            sizes = [1] * world_size
            t = t * 10
            assert t.requires_grad
            out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
            out = out + 2
            return out.sum()

        compiled = torch.compile(my_func, fullgraph=True)

        def run_with_backward():
            out = compiled(t, self.world_size)
            out.backward()

        res, codes = run_and_get_code(run_with_backward)
        for code in codes:
            FileCheck().check_count(
                "_c10d_functional.all_to_all_single.default", 1, exactly=True
            ).check_count("_c10d_functional.wait_tensor.default", 1, exactly=True).run(
                code
            )

        self.assertIsNotNone(t.grad)

    @parametrize("compile", [True, False])
    def test_all_gather_tensor(self, compile: bool) -> None:
        group = dist.group.WORLD.group_name

        def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
            assert t.requires_grad
            out = ft_c.all_gather_tensor_autograd(
                t * 1.0,
                gather_dim=dim,
                group=group,
            )
            out = out * 1.0
            return out

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        dims_to_gather = [0, 1, 2]
        for dim in dims_to_gather:
            output_size = [3, 3, 3]
            output_size[dim] *= self.world_size
            # each rank has its own tensor, all_gather gives a bigger tensor
            local_tensor = torch.ones([3, 3, 3], requires_grad=True)
            gathered_tensor = compiled(local_tensor, dim)
            self.assertEqual(gathered_tensor, torch.ones(output_size))

            gathered_tensor.sum().backward()
            self.assertEqual(
                local_tensor.grad,
                torch.full((3, 3, 3), fill_value=float(self.world_size)),
            )

    @parametrize("compile", [True, False])
    def test_reduce_scatter_tensor(self, compile: bool) -> None:
        group = dist.group.WORLD.group_name

        def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
            assert t.requires_grad
            rs_tensor = (
                ft_c.reduce_scatter_tensor_autograd(
                    t * 1.0, "sum", scatter_dim=dim, group=group
                )
                * 1.0
            )
            return rs_tensor

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        dims_to_scatter = [0, 1]
        for dim in dims_to_scatter:
            group_size = self.world_size
            input_size = [3, 3]
            output_size = [3, 3]
            output_size[dim] *= group_size
            input_tensor = torch.ones(output_size, requires_grad=True)
            rs_tensor = compiled(input_tensor, dim)
            res_num = 1 * group_size
            self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)
            rs_tensor.sum().backward()
            self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))


class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank)

    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    @requires_nccl()
    @with_comms()
    def test_all_to_all_single(self) -> None:
        group = self.process_group.group_name

        t = torch.ones((self.world_size, 2), requires_grad=True, device=self.device)

        sizes = [1] * self.world_size
        assert t.requires_grad
        out = ft_c.all_to_all_single_autograd(t * 2, sizes, sizes, group) + 0

        self.assertEqual(out.shape, t.shape)
        self.assertEqual(out, torch.full_like(t, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)
        loss = out.sum()
        loss.backward()
        self.assertEqual(t.grad, torch.full_like(t, 2.0))


if __name__ == "__main__":
    run_tests()