# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
from torch.distributed.distributed_c10d import (
    _get_default_group,
    _world,
    get_global_rank,
    get_world_size,
    init_process_group,
    is_initialized,
    is_nccl_available,
    ProcessGroup,
)
from torch.distributed.tensor._collective_utils import (
    mesh_broadcast,
    mesh_scatter,
    unpad_tensor,
)
from torch.distributed.tensor.placement_types import _Partial, Shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed.fake_pg import FakeStore


def _get_device_type(world_size):
    if (
        torch.cuda.is_available()
        and torch.cuda.device_count() >= world_size
        and is_nccl_available()
    ):
        device_type = "cuda"
    else:
        device_type = "cpu"
    return device_type


def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0):
    os.environ["MASTER_ADDR"] = addr
    os.environ["MASTER_PORT"] = port
    os.environ["WORLD_SIZE"] = f"{world_size}"
    os.environ["RANK"] = f"{rank}"


class DeviceMeshTestGlooBackend(DTensorTestBase):
    @property
    def backend(self):
        return "gloo"

    @with_comms
    def test_device_mesh_reuse_default_group(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))
        mesh_group = mesh.get_group()
        default_group = _get_default_group()
        if torch.cuda.is_available():
            self.assertNotEqual(mesh_group, default_group)
            self.assertEqual(get_world_size(mesh_group), get_world_size(default_group))
        else:
            self.assertEqual(mesh_group, default_group)


class DeviceMeshTest(DTensorTestBase):
    @property
    def world_size(self):
        return 4

    def test_init_process_group(self):
        device_type = _get_device_type(self.world_size)
        mesh_tensor = torch.arange(4).reshape(2, 2)
        self.assertTrue(not is_initialized())
        _set_env_var(world_size=self.world_size, rank=self.rank)
        DeviceMesh(device_type, mesh_tensor)
        self.assertTrue(is_initialized())
        self.destroy_pg()

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_assert_invalid_mesh_tensor(self):
        mesh = torch.arange(self.world_size).to(self.rank)
        with self.assertRaises(ValueError):
            device_mesh = DeviceMesh(self.device_type, mesh)

    @with_comms
    def test_get_group_and_get_all_groups(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )

        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]

        self.assertEqual(mesh_2d.get_group(0), mesh_2d.get_group("dp"))
        self.assertEqual(mesh_2d.get_group(1), mesh_2d.get_group("tp"))

        self.assertEqual(mesh_2d.get_group("dp"), dp_mesh.get_group())
        self.assertEqual(mesh_2d.get_group("tp"), tp_mesh.get_group())

        groups = mesh_2d.get_all_groups()
        self.assertEqual(len(groups), 2)
        self.assertTrue(tp_mesh.get_group() in groups)
        self.assertTrue(dp_mesh.get_group() in groups)

    @with_comms
    def test_get_local_rank_raises_exception(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )

        with self.assertRaisesRegex(
            RuntimeError,
            "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
        ):
            local_rank = mesh_2d.get_local_rank()

    @with_comms
    def test_get_local_rank(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )
        self.assertEqual(mesh_2d.get_local_rank("dp"), mesh_2d.get_local_rank(0))
        self.assertEqual(mesh_2d.get_local_rank("tp"), mesh_2d.get_local_rank(1))

        dp_mesh = mesh_2d["dp"]
        tp_mesh = mesh_2d["tp"]
        self.assertEqual(dp_mesh.get_local_rank(), mesh_2d.get_local_rank("dp"))
        self.assertEqual(tp_mesh.get_local_rank(), mesh_2d.get_local_rank("tp"))

        # Verify flattened mesh local rank correctness.
        flattened_mesh = mesh_2d["dp", "tp"]._flatten()
        self.assertEqual(flattened_mesh.get_local_rank(), self.rank)

    @with_comms
    def test_device_mesh_2d(self):
        mesh_tensor = torch.arange(4).reshape(2, 2)
        # construct a cuda device mesh
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_all_groups()

        expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]]
        for dim, dim_group in enumerate(dim_to_subgroups):
            self.assertTrue(dim < 2)
            dim_ranks = expected_ranks_by_dim[dim]

            dim_group_size = get_world_size(dim_group)
            self.assertIsInstance(dim_group, ProcessGroup)
            self.assertEqual(dim_group_size, 2)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            current_rank_expected_group_ranks = (
                dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1]
            )
            self.assertEqual(global_ranks, current_rank_expected_group_ranks)

    @with_comms
    def test_device_mesh_init_backend(self):
        mesh = DeviceMesh(self.device_type, [1], _init_backend=False)

        with self.assertRaisesRegex(RuntimeError, "process groups not initialized!"):
            mesh.get_group()

        # Coordinates should always be populated when _init_backend is False, since
        # whenever we call init_backend we make sure the default pg is already created.
        mesh.get_coordinate()

    def test_fake_pg_device_mesh(self):
        fake_store = FakeStore()
        init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size)
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        mesh = DeviceMesh(device_type, torch.arange(self.world_size))

        local_tensor = torch.randn(2, 8)
        global_tensor = funcol.all_gather_tensor(
            local_tensor, gather_dim=0, group=(mesh, 0)
        )
        self.assertEqual(global_tensor.shape, (self.world_size * 2, 8))

    @with_comms
    def test_from_group_with_global_pg(self):
        # Simple test: check `from_group` from a mesh pg vs. directly
        # initializing via `init_device_mesh`
        ref_global_mesh = init_device_mesh(self.device_type, (self.world_size,))
        mesh_pg = ref_global_mesh.get_group()
        global_mesh = DeviceMesh.from_group(mesh_pg, self.device_type)
        self.assertEqual(ref_global_mesh, global_mesh)
        self.assertEqual(ref_global_mesh._dim_group_infos, global_mesh._dim_group_infos)
        self.assertEqual(
            ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim
        )

    @with_comms
    def test_from_group_with_invalid_mesh(self):
        global_pg = _get_default_group()
        global_pg_size = global_pg.size()
        assert global_pg_size == 4, "Test assumes global world size of 4"
        invalid_mesh = [[0, 1], [2, 3]]  # 2D mesh when we need 1D
        regex = r"Invalid mesh \[\[0, 1\], \[2, 3\]\] for ProcessGroup with ranks \[0, 1, 2, 3\]"
        with self.assertRaisesRegex(ValueError, regex):
            DeviceMesh.from_group(global_pg, "cuda", invalid_mesh)

        device_mesh = init_device_mesh(self.device_type, (2, 2))
        groups = device_mesh.get_all_groups()
        invalid_mesh = (0, 1, 2, 3)  # 1D mesh when we need 2D
        regex = r"Expects mesh with ndim equal to number of ProcessGroups but got mesh \[0, 1, 2, 3\] and 2 ProcessGroups"
        with self.assertRaisesRegex(ValueError, regex):
            DeviceMesh.from_group(groups, self.device_type, invalid_mesh)

    def test_raises_invalid_device_type(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "Device type with GPU index is not supported",
        ):
            # test init_device_mesh with an invalid device type that contains a GPU index
            mesh_shape = (2, self.world_size // 2)
            mesh_2d = init_device_mesh(
                "cuda:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
            )

    @with_comms
    def test_set_mesh_dim_group_options(self):
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        _mesh_resources._set_mesh_dim_group_options(1, "fake", None)

        mesh_tensor = torch.arange(4).reshape(2, 2)
        mesh = DeviceMesh(device_type, mesh_tensor)
        self.assertEqual(mesh.get_group(1)._get_backend_name(), "fake")


class DeviceMeshTestNDim(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    def test_device_mesh_nd(self):
        # construct a cuda device mesh
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_all_groups()

        for dim, dim_group in enumerate(dim_to_subgroups):
            self.assertTrue(dim < mesh_tensor.ndim)
            dim_ranks = mesh_tensor.swapdims(-1, dim).reshape(-1, 2)

            dim_group_size = get_world_size(dim_group)
            self.assertIsInstance(dim_group, ProcessGroup)
            self.assertEqual(dim_group_size, 2)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            for ranks in dim_ranks:
                if self.rank in ranks:
                    self.assertEqual(global_ranks, ranks.tolist())

    @with_comms
    def test_device_mesh_hash(self):
        mesh_tensor_2d = torch.arange(8).reshape(4, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor_2d)
        mesh2 = DeviceMesh(self.device_type, mesh_tensor_2d)
        self.assertEqual(hash(mesh), hash(mesh2))
        mesh_tensor_3d = torch.arange(8).reshape(2, 2, 2)
        mesh3 = DeviceMesh(self.device_type, mesh_tensor_3d)
        self.assertNotEqual(hash(mesh), hash(mesh3))
        self.assertNotEqual(hash(mesh2), hash(mesh3))

    @with_comms
    def test_get_local_rank_3d(self):
        """
        If we have a 3D mesh and we want to apply dp, pp, tp to it,
        mesh_dim_names = ["dp", "pp", "tp"], and the mesh tensor would be:
        mesh_3d_tensor = [
            [
                [0, 1],
                [2, 3],
            ],
            [
                [4, 5],
                [6, 7],
            ],
        ]
        """
        mesh_shape = (2, 2, 2)
        mesh_3d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "pp", "tp")
        )

        # tp_rank_0: [0, 2, 4, 6], tp_rank_1: [1, 3, 5, 7]
        tp_rank = mesh_3d.get_local_rank("tp")
        expected_tp_rank = self.rank % 2
        self.assertEqual(tp_rank, expected_tp_rank)

        # pp_rank_0: [0, 1, 4, 5], pp_rank_1: [2, 3, 6, 7]
        pp_rank = mesh_3d.get_local_rank("pp")
        expected_pp_rank = 0 if self.rank % 4 <= 1 else 1
        self.assertEqual(pp_rank, expected_pp_rank)

        # dp_rank_0: [0, 1, 2, 3], dp_rank_1: [4, 5, 6, 7]
        dp_rank = mesh_3d.get_local_rank("dp")
        expected_dp_rank = self.rank // 4
        self.assertEqual(dp_rank, expected_dp_rank)

    @with_comms
    def test_device_mesh_parent_child_hash(self):
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("DP", "TP")
        )

        mesh_group_1 = torch.arange(0, self.world_size // 2)
        mesh_group_2 = torch.arange(self.world_size // 2, self.world_size)
        ep_mesh_1 = DeviceMesh(self.device_type, mesh_group_1)
        ep_mesh_2 = DeviceMesh(self.device_type, mesh_group_2)
        ep_mesh = ep_mesh_1 if self.rank < self.world_size // 2 else ep_mesh_2
        # ep_mesh is considered different from mesh_2d["TP"]
        self.assertEqual(mesh_2d["TP"]._flatten_mesh_list, ep_mesh._flatten_mesh_list)
        self.assertEqual(mesh_2d["TP"].mesh.shape, ep_mesh.mesh.shape)
        self.assertEqual(mesh_2d["TP"].device_type, ep_mesh.device_type)
        self.assertNotEqual(mesh_2d["TP"].mesh_dim_names, ep_mesh.mesh_dim_names)
        self.assertEqual(mesh_2d["TP"]._thread_id, ep_mesh._thread_id)
        self.assertNotEqual(hash(mesh_2d["TP"]), hash(ep_mesh))
        self.assertNotEqual(mesh_2d["TP"], ep_mesh)

        another_mesh_1 = DeviceMesh(self.device_type, mesh_group_1)
        another_mesh_2 = DeviceMesh(self.device_type, mesh_group_2)
        another_mesh = (
            another_mesh_1 if self.rank < self.world_size // 2 else another_mesh_2
        )
        # another_mesh is considered the same as ep_mesh
        self.assertEqual(ep_mesh._flatten_mesh_list, another_mesh._flatten_mesh_list)
        self.assertEqual(ep_mesh.mesh.shape, another_mesh.mesh.shape)
        self.assertEqual(ep_mesh.device_type, another_mesh.device_type)
        self.assertEqual(ep_mesh.mesh_dim_names, another_mesh.mesh_dim_names)
        self.assertEqual(ep_mesh._thread_id, another_mesh._thread_id)
        self.assertEqual(hash(ep_mesh), hash(another_mesh))
        self.assertEqual(ep_mesh, another_mesh)

    @with_comms
    def test_from_group_with_mesh_shape(self):
        """Tests ``from_group`` when passing ``mesh_shape`` as 2D."""
        # Consider two different logical views of the same mesh:
        # - (4, 2) ("dp", "tp") mesh
        # - (2, 2, 2) ("dp_replicate", "dp_shard", "tp") mesh
        mesh_shape = (2, 2, 2)
        mesh_dim_names = ("dp_replicate", "dp_shard", "tp")
        ref_mesh = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        dp_shard_group = ref_mesh["dp_shard"].get_group()
        dp_replicate_group = ref_mesh["dp_replicate"].get_group()

        dp_mesh = DeviceMesh.from_group(
            [dp_replicate_group, dp_shard_group],
            self.device_type,
            mesh=ref_mesh.mesh[:, :, ref_mesh.get_local_rank(2)],
            mesh_dim_names=mesh_dim_names[:2],
        )

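        # Each entry of `_dim_group_infos` describes the process group backing one
        # mesh dim; the middle element of each tuple is the list of global ranks in
        # that group, which is all the checks below compare.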
        ref_mesh_dp_dim_group_infos = ref_mesh._dim_group_infos[:2]
        for (_, ref_ranks, _), (_, ranks, _) in zip(
            ref_mesh_dp_dim_group_infos, dp_mesh._dim_group_infos
        ):
            self.assertEqual(ref_ranks, ranks)
        # Cannot check for mesh equality directly since the parent meshes are not
        # the same: the ref's parent mesh is 3D.
        self.assertEqual(dp_mesh["dp_replicate"].mesh, ref_mesh["dp_replicate"].mesh)
        for (_, ref_ranks, _), (_, ranks, _) in zip(
            dp_mesh["dp_replicate"]._dim_group_infos,
            ref_mesh["dp_replicate"]._dim_group_infos,
        ):
            self.assertEqual(ref_ranks, ranks)
        self.assertEqual(dp_mesh["dp_shard"].mesh, ref_mesh["dp_shard"].mesh)
        for (_, ref_ranks, _), (_, ranks, _) in zip(
            dp_mesh["dp_shard"]._dim_group_infos, ref_mesh["dp_shard"]._dim_group_infos
        ):
            self.assertEqual(ref_ranks, ranks)


class InitDeviceMeshTest(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    def test_init_device_mesh(self):
        mesh_shape = (2, 4)
        mesh_dim_names = ("DP", "TP")
        ref_mesh = DeviceMesh(
            self.device_type,
            torch.arange(8).view(mesh_shape),
            mesh_dim_names=mesh_dim_names,
        )

        # test init_device_mesh with mesh_dim_names
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )
        self.assertEqual(mesh_2d, ref_mesh)
        self.assertEqual(mesh_2d.mesh_dim_names, mesh_dim_names)

    @with_comms
    def test_raises_duplicate_mesh_dim_names(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "Each mesh_dim_name must be unique.",
        ):
            mesh = init_device_mesh(
                self.device_type,
                (2, 4),
                mesh_dim_names=["dp", "dp"],
            )

    @with_comms
    def test_raises_mesh_shape_mesh_dim_names_mismatch(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "mesh_shape and mesh_dim_names should have same length!",
        ):
            mesh = init_device_mesh(
                self.device_type,
                (8,),
                mesh_dim_names=["dp", "tp"],
            )


class TestDeviceMeshGetItem(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    def test_raises_no_mesh_dim_found(self):
        with self.assertRaisesRegex(
            RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names!"
        ):
            mesh = init_device_mesh(self.device_type, (2, 4))
            child_mesh = mesh["DP"]

    @with_comms
    def test_raises_invalid_mesh_dim_name(self):
        child_mesh_dim_name = ("PP",)
        with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"):
            mesh_dim_names = ("DP", "TP")
            mesh = init_device_mesh(
                self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
            )
            child_mesh = mesh[child_mesh_dim_name]

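    # For the (2, 4) mesh [[0, 1, 2, 3], [4, 5, 6, 7]] used below, slicing "TP"
    # (dim 1) yields a row of the mesh ([0..3] or [4..7], picked by rank // 4),
    # while slicing "DP" (dim 0) yields a column ([0, 4], [1, 5], ..., picked by
    # rank % 4).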
    @with_comms
    def test_get_item_2d(self):
        mesh_shape = (2, 4)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        pg_ranks_by_dim_name = {}
        for mesh_dim_name in mesh_dim_names:
            mesh_dim = mesh_dim_names.index(mesh_dim_name)
            pg_ranks_by_dim_name[mesh_dim_name] = mesh_2d.mesh.swapdims(
                -1, mesh_dim
            ).reshape(-1, mesh_2d.mesh.size(mesh_dim))

        tp_mesh = mesh_2d["TP"]
        tp_group_idx = self.rank // 4
        self.assertEqual(tp_mesh.mesh, pg_ranks_by_dim_name["TP"][tp_group_idx])

        dp_mesh = mesh_2d["DP"]
        dp_group_idx = self.rank % 4
        self.assertEqual(mesh_2d["DP"].mesh, pg_ranks_by_dim_name["DP"][dp_group_idx])

    @with_comms
    def test_get_item_1d(self):
        mesh = init_device_mesh(self.device_type, (8,), mesh_dim_names=("dp",))
        # Make sure slicing out a 1D mesh from a 1D mesh works.
        dp_mesh = mesh["dp"]
        self.assertEqual(dp_mesh, mesh)

        with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"):
            dp_mesh = mesh["dim0"]

    @with_comms
    def test_get_item_3d(self):
        mesh_shape = (2, 2, 2)
        mesh_dim_names = ("Replicate", "Shard", "TP")
        mesh_3d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        tp_group = [[0, 1], [2, 3], [4, 5], [6, 7]]
        tp_group_idx = int(self.rank / 2)
        self.assertEqual(mesh_3d["TP"].mesh.tolist(), tp_group[tp_group_idx])

        shard_group = [[0, 2], [1, 3], [4, 6], [5, 7]]
        shard_group_idx = self.rank % 2 + self.rank // 4 * 2
        self.assertEqual(mesh_3d["Shard"].mesh.tolist(), shard_group[shard_group_idx])

        replicate_group = [[0, 4], [1, 5], [2, 6], [3, 7]]
        replicate_group_idx = self.rank % 4
        self.assertEqual(
            mesh_3d["Replicate"].mesh.tolist(), replicate_group[replicate_group_idx]
        )

        # We support both syntaxes for nD slicing:
        # mesh_3d[["Replicate", "Shard"]] or mesh_3d["Replicate", "Shard"]
        hsdp_mesh_1 = mesh_3d[["Replicate", "Shard"]]
        hsdp_mesh_2 = mesh_3d["Replicate", "Shard"]
        hsdp_group = [[[0, 2], [4, 6]], [[1, 3], [5, 7]]]
        hsdp_group_idx = self.rank % 2
        self.assertEqual(hsdp_mesh_1.mesh.tolist(), hsdp_group[hsdp_group_idx])
        self.assertEqual(hsdp_mesh_2.mesh.tolist(), hsdp_group[hsdp_group_idx])
        self.assertEqual(hsdp_mesh_1, hsdp_mesh_2)

    @with_comms
    def test_cache_and_reuse_submesh_slice_result(self):
        mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp"))

        dp_mesh = mesh["dp"]
        ref_pg_count = _world.group_count

        # Calling the "dp" slice a second time should not create any new pg;
        # we just reuse the cached result, so the pg count should stay the same.
        dp_mesh_2 = mesh["dp"]
        self.assertEqual(ref_pg_count, _world.group_count)

        # Calling the "tp" slice should not create a new pg either, since the "tp"
        # slice just reuses the parent mesh pg.
        tp_mesh = mesh["tp"]
        self.assertEqual(_world.group_count, ref_pg_count)

    @with_comms
    def test_get_item_3d_noncontiguous_slicing(self):
        mesh_shape = (2, 2, 2)
        mesh_dim_names = ("dp", "pp", "cp")
        mesh_3d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        # The slicing order decides which mesh dim ends up on which dim of the sliced mesh.
        # For dp_cp_mesh, the cp mesh is the innermost dimension.
        dp_cp_mesh = mesh_3d["dp", "cp"]
        expected_mesh_tensor = (
            torch.tensor([[0, 1], [4, 5]], dtype=torch.int)
            if self.rank in (0, 1, 4, 5)
            else torch.tensor([[2, 3], [6, 7]], dtype=torch.int)
        )
        dp_local_rank = dp_cp_mesh.get_local_rank("dp")
        self.assertEqual(dp_cp_mesh.mesh, expected_mesh_tensor)
        cp_mesh = mesh_3d["cp"]
        # Check that, for the current dp_local_rank, the cp mesh tensor is the same.
        self.assertEqual(dp_cp_mesh.mesh[dp_local_rank], cp_mesh.mesh)

        with self.assertRaisesRegex(
            KeyError,
            "Invalid mesh_dim_names",
        ):
            cp_dp_mesh = mesh_3d["cp", "dp"]

    @with_comms
    def test_flatten_mesh(self):
        mesh_shape = (2, 2, 2)
        mesh_dim_names = ("dp", "cp", "tp")
        mesh_3d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        # Test flatten contiguous dims
        dp_cp_mesh = mesh_3d["dp", "cp"]
        flattened_dp_cp_mesh = dp_cp_mesh._flatten()
        self.assertEqual(dp_cp_mesh.mesh.flatten(), flattened_dp_cp_mesh.mesh)
        self.assertEqual(flattened_dp_cp_mesh.mesh_dim_names[0], "dp_cp")
        root_mesh = _mesh_resources.get_root_mesh(dp_cp_mesh)
        self.assertEqual(root_mesh, mesh_3d)
        flatten_mesh_root_dims = _mesh_resources.flatten_name_to_root_dims[root_mesh][
            "dp_cp"
        ]
        self.assertEqual(flatten_mesh_root_dims, (0, 1))

        ref_pg_count = _world.group_count
        # Calling flatten again should not create a new pg.
        flattened_dp_cp_mesh_2 = dp_cp_mesh._flatten()
        self.assertEqual(flattened_dp_cp_mesh, flattened_dp_cp_mesh_2)
        self.assertEqual(ref_pg_count, _world.group_count)

        # Test flatten non-contiguous dims
        dp_tp_mesh = mesh_3d["dp", "tp"]
        flattened_dp_tp_mesh = dp_tp_mesh._flatten()
        self.assertEqual(dp_tp_mesh.mesh.flatten(), flattened_dp_tp_mesh.mesh)
        self.assertEqual(flattened_dp_tp_mesh.mesh_dim_names[0], "dp_tp")
        root_mesh = _mesh_resources.get_root_mesh(dp_tp_mesh)
        self.assertEqual(root_mesh, mesh_3d)
        flatten_mesh_root_dims = _mesh_resources.flatten_name_to_root_dims[root_mesh][
            "dp_tp"
        ]
        self.assertEqual(flatten_mesh_root_dims, (0, 2))

        # Test flatten with a flattened mesh_dim_name
        cp_tp_mesh = mesh_3d["cp", "tp"]
        cp_tp_mesh._flatten("dummy")
        self.assertEqual(mesh_3d["dummy"].mesh_dim_names[0], "dummy")

    @with_comms
    def test_reconstruct_mesh_with_flatten_dim(self):
        mesh_3d = init_device_mesh(
            self.device_type, (2, 2, 2), mesh_dim_names=("replicate", "shard", "cp")
        )
        shard_cp_mesh = mesh_3d["shard", "cp"]._flatten()
        hsdp_mesh = mesh_3d["replicate", "shard_cp"]
        expected_mesh_tensor = torch.tensor(
            [[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int
        )
        self.assertEqual(hsdp_mesh.mesh, expected_mesh_tensor)
        self.assertEqual(shard_cp_mesh.get_group(), mesh_3d["shard_cp"].get_group())
        self.assertEqual(
            shard_cp_mesh.get_group(), mesh_3d.get_group(mesh_dim="shard_cp")
        )

        mesh_3d = init_device_mesh(
            self.device_type, (2, 2, 2), mesh_dim_names=("dp", "cp", "tp")
        )
        dp_cp_mesh = mesh_3d["dp", "cp"]._flatten()
        spmd_mesh = mesh_3d["dp_cp", "tp"]
        expected_mesh_tensor = torch.tensor(
            [[0, 1], [2, 3], [4, 5], [6, 7]], dtype=torch.int
        )
        self.assertEqual(spmd_mesh.mesh, expected_mesh_tensor)
        self.assertEqual(dp_cp_mesh.get_group(), mesh_3d["dp_cp"].get_group())
        self.assertEqual(dp_cp_mesh.get_group(), mesh_3d.get_group(mesh_dim="dp_cp"))


class TestMeshEnv(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    def test_get_root_mesh(self):
        mesh_3d = init_device_mesh(
            self.device_type, (2, 2, 2), mesh_dim_names=("dp", "cp", "tp")
        )

        dp_cp_mesh = mesh_3d["dp", "cp"]
        dp_tp_mesh = mesh_3d["dp", "tp"]
        cp_tp_mesh = mesh_3d["cp", "tp"]
        dp_mesh = mesh_3d["dp"]
        cp_mesh = mesh_3d["cp"]
        tp_mesh = mesh_3d["tp"]
        self.assertEqual(_mesh_resources.get_root_mesh(dp_cp_mesh), mesh_3d)
        self.assertEqual(_mesh_resources.get_root_mesh(dp_tp_mesh), mesh_3d)
        self.assertEqual(_mesh_resources.get_root_mesh(cp_tp_mesh), mesh_3d)
        self.assertEqual(_mesh_resources.get_root_mesh(dp_mesh), mesh_3d)
        self.assertEqual(_mesh_resources.get_root_mesh(cp_mesh), mesh_3d)
        self.assertEqual(_mesh_resources.get_root_mesh(tp_mesh), mesh_3d)

    @with_comms
    def test_get_root_mesh_dim_exist(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        self.assertEqual(_mesh_resources.get_root_mesh_dim(mesh_2d["DP"]), 0)
        self.assertEqual(_mesh_resources.get_root_mesh_dim(mesh_2d["TP"]), 1)

    @with_comms
    def test_get_root_mesh_dim_not_exist(self):
        mesh_shape = (self.world_size,)
        mesh = init_device_mesh(self.device_type, mesh_shape)

        self.assertEqual(_mesh_resources.get_root_mesh_dim(mesh), None)

    @with_comms
    def test_get_mesh_dim_by_name(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "DP"), 0)
        self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "TP"), 1)

    @with_comms
    def test_get_all_submeshes(self):
        mesh_2d = init_device_mesh(
            self.device_type, (2, 4), mesh_dim_names=("replicate", "shard")
        )
        all_submeshes = _mesh_resources._get_all_submeshes(mesh_2d, "replicate")
        self.assertEqual(len(all_submeshes), 4)
        self.assertEqual(
            all(submesh.mesh.numel() == 2 for submesh in all_submeshes), True
        )


class DeviceMeshCollectiveTest(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    def test_broadcast_1d(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
        mesh_broadcast(local_tensor, mesh, mesh_dim=0)
        self.assertEqual(local_tensor, torch.zeros(3, 3))

    @with_comms
    def test_scatter_1d(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        scatter_tensor_shape = [3, 3, 3]
        for scatter_dim in range(len(scatter_tensor_shape)):
            shard_placement = Shard(scatter_dim)
            scatter_tensor_shape[scatter_dim] *= self.world_size
            # make the random seed the same across ranks
            torch.manual_seed(0)
            global_tensor = torch.randn(scatter_tensor_shape, device=self.device_type)
            splitted_list, _ = shard_placement._split_tensor(
                global_tensor, mesh.size(), with_padding=True, contiguous=True
            )
            recv_tensor = torch.empty_like(splitted_list[mesh.get_rank()])
            # scattering on dim > 0 produces a non-contiguous tensor; verify that it works
            mesh_scatter(recv_tensor, splitted_list, mesh, mesh_dim=0)
            self.assertEqual(recv_tensor, splitted_list[mesh.get_rank()])

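    # The uneven tests below split a tensor whose shard dim does not divide evenly
    # by world_size: `_split_tensor(..., with_padding=True)` pads the short shards so
    # every rank receives an equally sized tensor, and `pad_sizes` records how much
    # padding to strip with `unpad_tensor` after the collective.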
    @with_comms
    def test_scatter_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        tensor_to_split = torch.randn(
            device_mesh.size() + 3, device_mesh.size() + 1, device=self.device_type
        )

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)

            tensor_to_scatter = tensor_to_split.clone()
            tensor_splitted_list = list(
                torch.chunk(tensor_to_split, self.world_size, dim=shard_dim)
            )
            for _ in range(self.world_size - len(tensor_splitted_list)):
                tensor_splitted_list.append(torch.tensor([], device=self.device_type))

            padded_tensor_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_scatter,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )

            scattered_tensor = torch.empty_like(padded_tensor_list[my_rank])
            mesh_scatter(scattered_tensor, padded_tensor_list, device_mesh, mesh_dim=0)

            if pad_sizes[my_rank] != 0:
                scattered_tensor = unpad_tensor(
                    scattered_tensor, shard_dim, pad_sizes[my_rank]
                )

            if scattered_tensor.numel() == 0:
                # Check numel() instead of size() when the tensor is empty after
                # unpadding, since its size could still be, e.g., [0, 8].
                self.assertEqual(
                    scattered_tensor.numel(), tensor_splitted_list[my_rank].numel()
                )
            else:
                self.assertEqual(
                    scattered_tensor.size(), tensor_splitted_list[my_rank].size()
                )
                self.assertEqual(scattered_tensor, tensor_splitted_list[my_rank])

    @with_comms
    def test_all_gather_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        tensor_to_split = torch.ones(
            device_mesh.size() + 3,
            device_mesh.size() + 1,
            device=self.device_type,
        )

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)
            tensor_padded_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_split,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )
            local_tensor = tensor_padded_list[my_rank]
            big_tensor = funcol.all_gather_tensor(
                local_tensor, gather_dim=shard_dim, group=(device_mesh, 0)
            )
            big_tensor_chunks = list(
                torch.chunk(big_tensor, device_mesh.size(), dim=shard_dim)
            )
            unpadded_list = [
                (
                    unpad_tensor(big_tensor, shard_dim, pad_sizes[i])
                    if pad_sizes[i] > 0
                    else big_tensor
                )
                for i, big_tensor in enumerate(big_tensor_chunks)
            ]
            all_gathered_tensor = torch.cat(unpadded_list, dim=shard_dim)

            self.assertEqual(all_gathered_tensor.size(), tensor_to_split.size())
            self.assertEqual(all_gathered_tensor, tensor_to_split)

    @with_comms
    def test_reduce_scatter_contiguous(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()

        # Init the tensor
        step = self.world_size * 2
        total_elem = step**2
        tensor = torch.arange(0, total_elem).view(step, -1).to(device=self.device_type)
        tensor = tensor * (my_rank + 1)

        # Get a non-contiguous tensor by slicing
        tensor_to_reduce = tensor[::2, :2]
        tensor_contiguous = tensor_to_reduce.clone().contiguous()

        # Partial to Shard to trigger reduce_scatter
        tensor_to_reduce = DTensor.from_local(
            tensor_to_reduce, device_mesh, [_Partial()]
        )
        tensor_contiguous = DTensor.from_local(
            tensor_contiguous, device_mesh, [_Partial()]
        )
        new_tensor = tensor_to_reduce.redistribute(device_mesh, [Shard(0)])
        new_tensor_contiguous = tensor_contiguous.redistribute(device_mesh, [Shard(0)])

        # Contiguous and non-contiguous tensors holding the same values should
        # produce the same reduce_scatter result.
        new_tensor_local = new_tensor._local_tensor
        new_tensor_contiguous_local = new_tensor_contiguous._local_tensor
        self.assertEqual(new_tensor_local, new_tensor_contiguous_local)
        self.assertEqual(list(new_tensor_local.size()), [1, 2])

        # Check the reduced numerical value
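        # Worked example with world_size == 8: sum_base == (1 + 8) * 8 / 2 == 36,
        # since rank r contributes a factor of (r + 1). Row i of the sliced tensor
        # holds [32 * i, 32 * i + 1] scaled by (r + 1), so after the sum-reduce and
        # Shard(0) each rank k ends up with [[1152 * k, 1152 * k + 36]].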
        sum_base = (1 + self.world_size) * self.world_size / 2
        first_elem = my_rank * sum_base * step * 2
        expected_tensor = torch.tensor(
            [[first_elem, first_elem + sum_base]],
            dtype=new_tensor_local.dtype,
            device=self.device_type,
        )
        self.assertEqual(new_tensor_local, expected_tensor)

    @with_comms
    def test_reduce_scatter_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        tensor_to_split = (
            torch.ones(
                device_mesh.size() + 3,
                device_mesh.size() + 1,
                device=self.device_type,
            )
            * self.rank
        )

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)
            tensor_to_scatter = tensor_to_split.clone()

            tensor_splitted_list = list(
                torch.chunk(tensor_to_split, self.world_size, dim=shard_dim)
            )
            for _ in range(self.world_size - len(tensor_splitted_list)):
                tensor_splitted_list.append(torch.tensor([], device=self.device_type))

            padded_tensor_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_scatter,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )

            tensor_to_reduce = torch.cat(padded_tensor_list, shard_dim)

            res_num = ((0 + self.world_size - 1) * self.world_size) / 2
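            # res_num is the sum of all ranks (0 + 1 + ... + 7 == 28 for world_size == 8),
            # since rank r contributes a tensor filled with r.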

            scattered_tensor = funcol.reduce_scatter_tensor(
                tensor_to_reduce,
                reduceOp="sum",
                scatter_dim=shard_dim,
                group=(device_mesh, 0),
            )

            # unpad scattered_tensor
            if pad_sizes[my_rank] > 0:
                scattered_tensor = unpad_tensor(
                    scattered_tensor, shard_dim, pad_sizes[my_rank]
                )

            if scattered_tensor.numel() == 0:
                # Check numel() instead of size() when the tensor is empty after
                # unpadding, since its size could still be, e.g., [0, 8].
                self.assertEqual(
                    scattered_tensor.numel(), tensor_splitted_list[my_rank].numel()
                )
            else:
                self.assertEqual(
                    scattered_tensor.size(), tensor_splitted_list[my_rank].size()
                )
                self.assertEqual(
                    scattered_tensor,
                    torch.ones_like(tensor_splitted_list[my_rank]) * res_num,
                )

    @with_comms
    def test_broadcast_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank

        # check all dim groups
        dim_to_subgroups = mesh.get_all_groups()
        for dim, dim_group in enumerate(dim_to_subgroups):
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            cloned_local_tensor = local_tensor.clone()
            mesh_broadcast(cloned_local_tensor, mesh, mesh_dim=dim)
            res_num = global_ranks[0]
            self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num)

    @with_comms
    def test_scatter_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_all_groups()
        for dim, dim_group in enumerate(dim_to_subgroups):
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            scattered_tensors = [
                torch.ones(3, 3, device=self.device_type) * global_rank
                for global_rank in global_ranks
            ]
            received_tensor = torch.empty_like(
                scattered_tensors[mesh.get_coordinate()[dim]]
            )
            mesh_scatter(received_tensor, scattered_tensors, mesh, mesh_dim=dim)
            self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank)


if __name__ == "__main__":
    run_tests()