# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

from numpy.testing import assert_array_equal

import torch
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor

from torch.distributed._tensor import (
    DeviceMesh,
    distribute_tensor,
    DTensor,
    init_device_mesh,
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed._tensor.placement_types import (
    DTensorSpec,
    Partial,
    Replicate,
    Shard,
    TensorMeta,
)
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


c10d_functional = torch.ops.c10d_functional


class DummyMLP(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.net1 = torch.nn.Linear(5, 1024, device=device)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(1024, 4, device=device)

    def forward(self, x):
        return self.net2(F.relu(self.net1(x)))

    def reset_parameters(self, *args, **kwargs):
        with torch.no_grad():
            self.net1.weight.fill_(0.5)
            self.net2.weight.fill_(1)
            self.net1.bias.fill_(1.5)
            self.net2.bias.fill_(1.2)


class DTensorTest(DTensorTestBase):
    @with_comms
    def test_dtensor_constructor(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3, requires_grad=True)

        spec = DTensorSpec(
            device_mesh,
            tuple(placements),
            tensor_meta=TensorMeta(
                torch.Size([self.world_size * 3, 3]),
                local_tensor.stride(),
                local_tensor.dtype,
            ),
        )

        dist_tensor = DTensor(
            local_tensor,
            spec,
            requires_grad=True,
        )
        self.assertEqual(dist_tensor.size(), torch.Size((self.world_size * 3, 3)))

        with self.assertWarnsRegex(UserWarning, "To construct"):
            DTensor(
                local_tensor,
                spec,
                requires_grad=False,
            )

    @with_comms
    def test_meta_dtensor(self):
        device_mesh = self.build_device_mesh()
        dist_specs = [[Shard(0)], [Replicate()]]
        meta_tensor = torch.randn(1024, 2048, device="meta")
        for dist_spec in dist_specs:
            # Test distribute_tensor on meta tensor
            meta_dtensor = distribute_tensor(meta_tensor, device_mesh, dist_spec)
            self.assertTrue(meta_dtensor.is_meta)
            meta_dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
            torch.nn.init.constant_(meta_dtensor, 1.2)
            value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.2)
            self.assertFalse(meta_dtensor.is_meta)
            self.assertEqual(meta_dtensor.device.type, self.device_type)
            self.assertEqual(meta_dtensor.to_local(), value_tensor)
            # Test from_local on meta tensor
            meta_dtensor = DTensor.from_local(meta_tensor, device_mesh, dist_spec)
            meta_dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
            torch.nn.init.constant_(meta_dtensor, 1.5)
            self.assertEqual(meta_dtensor.device.type, self.device_type)
            value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.5)
            self.assertEqual(meta_dtensor.to_local(), value_tensor)

    @with_comms
    def test_modules_w_meta_dtensor(self):
        model = DummyMLP("meta")
        device_mesh = self.build_device_mesh()
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model_tp = parallelize_module(model, device_mesh, parallelize_plan)
        model_tp.to_empty(device=self.device_type)
        model_tp.reset_parameters()
        optim = torch.optim.SGD(model_tp.parameters(), lr=0.1)
        model_regular = DummyMLP(self.device_type)
        model_regular_tp = parallelize_module(
            model_regular, device_mesh, parallelize_plan
        )
        optim_regular = torch.optim.SGD(model_regular_tp.parameters(), lr=0.1)
        model_regular_tp.reset_parameters()
        torch.manual_seed(0)
        inp = torch.randn(20, 5, device=self.device_type)

        output = model_tp(inp)
        output_regular = model_regular_tp(inp)
        self.assertEqual(output, output_regular)

        output.sum().backward()
        output_regular.sum().backward()

        optim.step()
        optim_regular.step()

        torch.manual_seed(1)
        inp = torch.randn(20, 5, device=self.device_type)
        self.assertEqual(model_tp(inp), model_regular_tp(inp))

    @with_comms
    def test_dtensor_stride(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard0_spec = [Shard(0)]
        local_tensor = torch.randn(4, 8)
        global_shape = torch.Size([self.world_size * 4, 8])
        dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard0_spec)
        # sharding on dim 0 does not affect the stride
        self.assertEqual(dist_tensor.stride(), (8, 1))

        shard1_spec = [Shard(1)]
        local_tensor = torch.randn(8, 4)
        global_shape = torch.Size([8, self.world_size * 4])
        dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard1_spec)
        # sharding on dim 1 affects the stride once the DTensor is initialized
        self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 1))

        # if initialized from a transposed matrix
        local_tensor = torch.randn(8, 4, 8)
        local_tensor_t = local_tensor.permute(1, 2, 0)
        global_shape = torch.Size([4, self.world_size * 8, 8])
        self.assertEqual(local_tensor_t.stride(), (8, 1, 32))
        dist_tensor = DTensor.from_local(local_tensor_t, device_mesh, shard1_spec)
        global_stride = (8 * self.world_size, 1, 32 * self.world_size)
        self.assertEqual(dist_tensor.stride(), global_stride)

    @with_comms
    def test_from_local(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        self.assertEqual(sharded_tensor.size(), torch.Size([self.world_size * 3, 3]))

        replica_spec = [Replicate()]
        ddp_tensor = DTensor.from_local(local_tensor, device_mesh, replica_spec)
        self.assertEqual(ddp_tensor.size(), local_tensor.size())

        partial_spec = [Partial()]
        partial_tensor = DTensor.from_local(local_tensor, device_mesh, partial_spec)
        self.assertEqual(partial_tensor.size(), local_tensor.size())

        # test that dist tensor works with torch.Tensor during backwards
        local_tensor_with_grad = torch.randn(3, 3, requires_grad=True)
        # do some operations on the local tensor
        local_tensor_temp = local_tensor_with_grad * 3
        # create the dist tensor from a non-leaf local tensor; the dist tensor
        # created should also be a non-leaf node
        dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, placements)
        self.assertFalse(dist_tensor.is_leaf)
        # do some random operations on the dist tensor
        output = dist_tensor * 3
        self.assertIsInstance(output, DTensor)
        # trigger .backward() on the dist tensor directly
        local_grad = torch.ones(3, 3)
        grad_output = DTensor.from_local(local_grad, device_mesh, placements)
        # run backward directly on the dist tensor
        output.backward(grad_output)
        # check that gradients flow back to the original torch.Tensor
        self.assertIsNotNone(local_tensor_with_grad.grad)
        expected_grad = torch.ones(3, 3) * 9
        self.assertEqual(local_tensor_with_grad.grad, expected_grad)

    @with_comms
    def test_from_local_uneven_sharding(self):
        mesh_shape = (self.world_size,)
        device_mesh = init_device_mesh(self.device_type, mesh_shape)

        uneven_dim0_size = self.world_size + 1
        global_tensor = torch.randn(uneven_dim0_size, 2)
        shard_placement = Shard(0)
        tensor_list, _ = shard_placement._split_tensor(
            global_tensor,
            device_mesh.size(mesh_dim=0),
            with_padding=False,
            contiguous=True,
        )

        dtensor = DTensor.from_local(
            tensor_list[self.rank],
            device_mesh,
            (Shard(0),),
            shape=global_tensor.size(),
            stride=global_tensor.stride(),
        )

        self.assertEqual(dtensor.size(), global_tensor.size())
        self.assertEqual(dtensor.stride(), global_tensor.stride())

    @with_comms
    def test_from_local_uneven_sharding_raise_error(self):
        mesh_shape = (self.world_size,)
        device_mesh = init_device_mesh(self.device_type, mesh_shape)

        uneven_dim0_size = self.world_size + 1
        global_tensor = torch.randn(uneven_dim0_size, 2)
        shard_placement = Shard(0)
        tensor_list, _ = shard_placement._split_tensor(
            global_tensor,
            device_mesh.size(mesh_dim=0),
            with_padding=False,
            contiguous=True,
        )

        with self.assertRaisesRegex(
            RuntimeError, "Please pass both shape and stride at the same time."
        ):
            dtensor = DTensor.from_local(
                tensor_list[self.rank],
                device_mesh,
                (Shard(0),),
                shape=global_tensor.size(),
            )

        with self.assertRaisesRegex(
            RuntimeError, "Please pass both shape and stride at the same time."
        ):
            dtensor = DTensor.from_local(
                tensor_list[self.rank],
                device_mesh,
                (Shard(0),),
                stride=global_tensor.stride(),
            )

    @with_comms
    def test_from_local_negative_dim(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(-1)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        self.assertEqual(sharded_tensor.placements[0].dim, 1)

    @with_comms
    def test_to_local(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = (Shard(0),)
        local_tensor_with_grad = torch.randn(
            3, 3, device=self.device_type, requires_grad=True
        )
        dist_tensor_shape = torch.Size([self.world_size * 3, 3])
        spec = DTensorSpec(
            mesh=device_mesh,
            placements=placements,
            tensor_meta=TensorMeta(
                dist_tensor_shape,
                local_tensor_with_grad.stride(),
                local_tensor_with_grad.dtype,
            ),
        )
        sharded_tensor = DTensor(
            local_tensor_with_grad,
            spec,
            requires_grad=True,
        )
        self.assertEqual(sharded_tensor.size(), dist_tensor_shape)
        self.assertEqual(sharded_tensor.to_local(), local_tensor_with_grad)

        # test that dist tensor works with torch.Tensor during backwards
        # the dist tensor created is a leaf node; do some operations on it
        temp_st = sharded_tensor * 3

        # do some operations on the local tensor of the dist tensor
        new_tensor_with_grad = torch.randn(
            3, 3, device=self.device_type, requires_grad=True
        )
        res = temp_st.to_local() + new_tensor_with_grad
        # call backward directly on torch.Tensor, and see if it works by
        # propagating through the dist tensor
        res.sum().backward()
        self.assertIsNotNone(sharded_tensor.grad)

        self.assertEqual(sharded_tensor.grad.to_local(), torch.ones(3, 3) * 3)

        # test the case where the grad stride differs from the forward input.
        res = sharded_tensor.to_local()
        model = torch.nn.ReLU()
        res.register_hook(lambda grad: grad.t())
        target = torch.randn(3, 3, device=self.device_type)
        mae_loss = torch.nn.L1Loss()
        output = mae_loss(model(res), target)
        # The manual change to the grad stride causes the subsequent copy op to fail,
        # so we need a try/except here.
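        # (The hook registered above transposes the gradient, so its stride no
        # longer matches the layout of the forward input; the except branch
        # checks the stride recorded before the failing copy.)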
        try:
            output.backward()
        except RuntimeError:
            self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])

        # test that under no_grad we directly return the local tensor
        with torch.no_grad():
            local_no_grad = sharded_tensor.to_local()
            assert local_no_grad is sharded_tensor._local_tensor

    @with_comms
    def test_to_local_grad_hint(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = (Shard(0),)
        global_tensor = torch.ones(8, 3, requires_grad=True)

        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
        comm_mode = CommDebugMode()

        with comm_mode:
            local_out = sharded_dtensor.redistribute(placements=[Replicate()]).to_local(
                grad_placements=[Partial()]
            )
            local_out.backward(torch.ones_like(local_out))

        self.assertEqual(
            comm_mode.comm_counts[c10d_functional.all_gather_into_tensor], 1
        )
        self.assertEqual(
            comm_mode.comm_counts[c10d_functional.reduce_scatter_tensor], 1
        )

        replica_grad = sharded_dtensor.grad.full_tensor()
        self.assertEqual(replica_grad, global_tensor * self.world_size)

    @with_comms
    def test_full_tensor_sync(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = (Shard(0),)
        global_tensor = torch.ones(8, 3, requires_grad=True)

        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
        full_out = sharded_dtensor.full_tensor()
        self.assertFalse(isinstance(full_out, AsyncCollectiveTensor))
        self.assertEqual(full_out, global_tensor)

    @with_comms
    def test_full_tensor_grad_hint(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = (Shard(0),)
        global_tensor = torch.ones(8, 3, requires_grad=True)

        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
        local_out = sharded_dtensor.full_tensor(grad_placements=[Partial()])
        local_out.sum().backward()

        replica_grad = sharded_dtensor.grad.full_tensor()
        self.assertEqual(replica_grad, global_tensor * self.world_size)

    @with_comms
    def test_dtensor_new_empty_strided(self):
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        local_tensor = torch.randn(8, 8, requires_grad=True, device=self.device_type)
        my_dtensor = distribute_tensor(local_tensor, device_mesh, [Shard(0)])
        new_strided_dtensor = my_dtensor.new_empty_strided(
            (8, 8), (8, 1), requires_grad=True
        )
        # test that the op produces a new dtensor and autograd works
        self.assertEqual(new_strided_dtensor.shape, my_dtensor.shape)
        new_strided_dtensor.sum().backward()
        self.assertIsNotNone(new_strided_dtensor.grad)
        self.assertIsInstance(new_strided_dtensor.grad, DTensor)

        # test that backward of new_empty_strided with sharding works correctly
        my_dtensor.to_local().sum().backward()
        local_tensor.sum().backward()
        self.assertEqual(my_dtensor.grad, new_strided_dtensor.grad)
        self.assertEqual(
            my_dtensor.grad.redistribute(placements=[Replicate()]).to_local(),
            local_tensor.grad,
        )

    @with_comms
    def test_dtensor_async_output(self):
        # Tests that if the output of some dtensor operations isn't used in any compute,
        # the output should be an AsyncCollectiveTensor (representing the fact that
        # we haven't synced the collective yet).
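        # As exercised below: view ops on the async output keep the
        # AsyncCollectiveTensor wrapper, while the first compute op triggers
        # the sync and returns a plain torch.Tensor.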
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(dt):
            dt_out_redistribute = dt.redistribute(mesh, [Replicate()], async_op=True)
            # Make sure we haven't synced yet
            # TODO: figure out why this is returning None
            # self.assertTrue(_tensor_needs_wait(dt_out_redistribute))
            dt_out_redistribute_view = dt_out_redistribute.view(
                dt_out_redistribute.shape
            )
            local_tensor = dt_out_redistribute_view.to_local()
            return local_tensor

        x = torch.ones((4, 2), device=self.device_type)
        dt = distribute_tensor(x, mesh, [Shard(0)])
        out = fn(dt)
        # Make sure we haven't synced yet
        self.assertEqual(type(out), AsyncCollectiveTensor)
        self.assertFalse(out.completed)
        out_view = out.view(-1)

        # Assert that the output is an `AsyncCollectiveTensor`
        self.assertEqual(type(out_view), AsyncCollectiveTensor)
        self.assertFalse(out.completed)

        # Use the data, requiring a sync
        ref = torch.ones((4, 2), device=self.device_type) + 1
        ref = ref.view(-1)
        out_data = out_view + 1
        self.assertEqual(type(out_data), torch.Tensor)
        self.assertEqual(out_data, ref)

        # test the async_op=False default
        sync_out = dt.redistribute(mesh, [Replicate()])
        self.assertFalse(isinstance(sync_out, AsyncCollectiveTensor))
        self.assertEqual(sync_out.to_local(), x)

    @with_comms
    def test_from_local_then_to_local(self):
        # this test ensures the end-to-end flow torch.Tensor -> dist tensor -> torch.Tensor works
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]

        # step 1. construct the local tensor
        local_tensor_with_grad = torch.randn(
            3, 3, device=self.device_type, requires_grad=True
        )
        # do some operations on the local tensor
        local_tensor_temp = local_tensor_with_grad + 8
        # step 2. create the dist tensor from the non-leaf local tensor; the dist
        # tensor created should also be a non-leaf node
        dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, placements)
        self.assertFalse(dist_tensor.is_leaf)
        # do some random operations on the dist tensor
        output = dist_tensor * 6
        self.assertIsInstance(output, DTensor)

        # step 3. do some operations on the local tensor of the dist tensor
        new_tensor_with_grad = torch.randn(
            3, 3, device=self.device_type, requires_grad=True
        )
        res = output.to_local() + new_tensor_with_grad
        # call backward directly on torch.Tensor, and see if it works by
        # propagating all the way back to the original torch.Tensor
        res.sum().backward()
        self.assertIsNotNone(local_tensor_with_grad.grad)

        expected_grad = torch.ones(3, 3) * 6
        self.assertEqual(local_tensor_with_grad.grad, expected_grad)

    @with_comms
    def test_dtensor_spec_read_only_after_set(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)

        # modifying placements should not change the dist tensor's spec
        placements[0] = Replicate()
        self.assertTrue(sharded_tensor.placements is not placements)
        self.assertNotEqual(sharded_tensor.placements, placements)

    @with_comms
    def test_dtensor_spec_hash(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        local_tensor2 = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        sharded_tensor2 = DTensor.from_local(local_tensor2, device_mesh, placements)
        # note that DTensorSpec does not hold real tensor data, so the hashes are
        # the same as long as the mesh, placements, and tensor properties are the same
        self.assertEqual(hash(sharded_tensor._spec), hash(sharded_tensor2._spec))

        # changing the placements changes the hash
        local_tensor3 = torch.ones(3, 3)
        replica_spec = [Replicate()]
        replica_tensor = DTensor.from_local(
            local_tensor3, device_mesh, replica_spec, run_check=False
        )
        self.assertNotEqual(hash(sharded_tensor._spec), hash(replica_tensor._spec))

    @with_comms
    def test_dtensor_properties(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        self.assertEqual(sharded_tensor.device.type, self.device_type)

    @with_comms
    def test_dtensor_save_load(self):
        import io

        device_mesh = self.build_device_mesh()
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        buffer = io.BytesIO()
        torch.save(sharded_tensor, buffer)
        buffer.seek(0)
        reloaded_st = torch.load(buffer)
        self.assertEqual(sharded_tensor, reloaded_st)
        # Test weights_only load
        try:
            torch.serialization.add_safe_globals(
                [DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
            )
            buffer.seek(0)
            reloaded_st = torch.load(buffer, weights_only=True)
            self.assertEqual(sharded_tensor, reloaded_st)
        finally:
            torch.serialization.clear_safe_globals()


class DTensorMeshTest(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor):
        if self.rank in mesh:
            self.assertEqual(tensor, exp_in_mesh)
        else:
            self.assertEqual(tensor, exp_out_of_mesh)

    @with_comms
    def test_dtensor_device_mesh_device_conversion(self):
        # construct a cuda device mesh
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # construct from a cpu local tensor with a cuda device mesh;
        # the dist tensor should automatically be converted to cuda
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
        self.assertEqual(dist_tensor.device.type, self.device_type)
        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

    @with_comms
    def test_dtensor_api_device_mesh_context_manager(self):
        with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh:
            placements = [Shard(0)]
            local_tensor = torch.randn(3, 3)
            sharded_tensor = DTensor.from_local(
                local_tensor, device_mesh=mesh, placements=placements
            )

        with DeviceMesh(self.device_type, list(range(self.world_size))):
            placements = [Shard(0)]
            local_tensor = torch.randn(3, 3)
            sharded_tensor = DTensor.from_local(local_tensor, placements=placements)
            replica_spec = [Replicate()]
            replica_tensor = sharded_tensor.redistribute(placements=replica_spec)
            self.assertEqual(
                replica_tensor.size(), torch.Size([3 * self.world_size, 3])
            )

        with DeviceMesh(self.device_type, torch.arange(self.world_size)):
            placements = [Shard(0)]
            global_shape = torch.Size([3 * self.world_size, 3])
            global_tensor = torch.randn(global_shape)
            sharded_tensor = distribute_tensor(global_tensor, placements=placements)
            self.assertEqual(sharded_tensor.to_local().shape, torch.Size([3, 3]))

            mesh_2d = DeviceMesh(
                self.device_type, torch.arange(self.world_size).reshape(2, 4)
            )

            with mesh_2d:
                shard_2d_spec = [Shard(0), Replicate()]
                tensor_2d = distribute_tensor(global_tensor, placements=shard_2d_spec)

                self.assertEqual(tensor_2d.to_local().shape, torch.Size([3 * 4, 3]))

            sharded_after_2d = distribute_tensor(global_tensor, placements=placements)
            self.assertEqual(sharded_after_2d.to_local().shape, torch.Size([3, 3]))

    @with_comms
    def test_dtensor_2d_mesh(self):
        mesh_tensor = torch.arange(self.world_size).reshape(2, 4)
        # construct a cuda device mesh
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # construct a dist tensor on a 2d device mesh and test that it works
        placements = [Shard(0), Shard(1)]
        local_tensor = torch.randn(3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
        self.assertEqual(
            dist_tensor.size(), torch.Size([3 * mesh.size(0), 3 * mesh.size(1)])
        )
        self.assertEqual(dist_tensor.device.type, self.device_type)
        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

        # if we shard on the same tensor dimension,
        # we should still correctly construct the global tensor size
        shard_same_dim_spec = [Shard(0), Shard(0)]
        local_tensor = torch.randn(3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, shard_same_dim_spec)
        self.assertEqual(dist_tensor.size(), torch.Size([3 * self.world_size, 3]))

    @with_comms
    def test_device_mesh_nd(self):
        # construct a cuda device mesh
        mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        # construct a dist tensor on a 3d device mesh and test that it works
        placements = [Shard(0), Shard(1), Shard(2)]
        local_tensor = torch.randn(3, 3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
        self.assertEqual(dist_tensor.size(), torch.Size([6, 6, 6]))
        self.assertEqual(dist_tensor.device.type, self.device_type)
        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

        # construct a dist tensor on a 3d device mesh with some shards on the same dim
        placements = [Shard(0), Shard(0), Shard(2)]
        local_tensor = torch.randn(3, 3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
        self.assertEqual(dist_tensor.size(), torch.Size([12, 3, 6]))
        self.assertEqual(dist_tensor.device.type, self.device_type)
        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

    @with_comms
    def test_dtensor_spec_local_shard_offset(self):
        device_mesh = DeviceMesh(
            self.device_type, torch.arange(self.world_size).reshape(2, 4)
        )
        tensor_shape = (3 * self.world_size, 3 * self.world_size)
        # sharding specs and their corresponding local shard offsets
        shard_spec_and_offsets = [
            (
                [Shard(0), Replicate()],
                (3 * (self.world_size // 2) * (self.rank // 4), 0),
            ),
            (
                [Shard(1), Replicate()],
                (0, 3 * (self.world_size // 2) * (self.rank // 4)),
            ),
            (
                [Replicate(), Shard(0)],
                (3 * (self.world_size // 4) * (self.rank % 4), 0),
            ),
            (
                [Replicate(), Shard(1)],
                (0, 3 * (self.world_size // 4) * (self.rank % 4)),
            ),
        ]

        from torch.distributed._tensor._utils import (
            compute_local_shape_and_global_offset,
        )

        # loop through all sharding specs and check local shard offsets
        logical_tensor = torch.randn(tensor_shape)
        for placements, expected_shard_offsets in shard_spec_and_offsets:
            dtensor = distribute_tensor(logical_tensor, device_mesh, placements)
            _, offset = compute_local_shape_and_global_offset(
                dtensor.shape, device_mesh, dtensor.placements
            )
            self.assertEqual(expected_shard_offsets, offset)

    @with_comms
    def test_from_local_sub_mesh(self):
        mesh = DeviceMesh(self.device_type, [0, 2])
        local_tensor = torch.ones(3, 4)

        dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
        self.assertEqual(dtensor.size(), torch.Size([6, 4]))

        self.sub_mesh_assert_equal(
            mesh.mesh,
            torch.ones(3, 4),
            torch.tensor([]),
            dtensor.to_local(),
        )

        # test a dtensor created on a sub-mesh: the operation should only
        # be applied to the local shard inside the mesh, not the whole
        # world, so only ranks 0 and 2 actually run the computation
        dtensor = dtensor + 2

        self.sub_mesh_assert_equal(
            mesh.mesh,
            torch.ones(3, 4) + 2,
            torch.tensor([]),
            dtensor.to_local(),
        )

    @with_comms
    def test_default_value_sub_mesh(self):
        mesh = DeviceMesh(self.device_type, [0, 2])

        # test scalar return value
        local_tensor1 = torch.ones(4, 3)
        local_tensor2 = torch.ones(4, 3)
        dtensor1 = DTensor.from_local(local_tensor1, mesh, [Shard(0)])
        dtensor2 = DTensor.from_local(local_tensor2, mesh, [Shard(0)])
        local_res = dtensor1.equal(dtensor2)  # equal returns a local result
        self.sub_mesh_assert_equal(
            mesh.mesh,
            True,
            True,
            local_res,
        )

        # test 0-d tensor return value
        local_tensor = torch.ones(4, 3)
        dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)]).sum()
        self.sub_mesh_assert_equal(
            mesh.mesh,
            torch.tensor(12.0),
            torch.tensor(0.0),
            dtensor.to_local(),
        )

        # test List[torch.Tensor] return value
        local_tensor = torch.ones(3, 4)
        dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
        dtensor_list = dtensor.split([2, 2], dim=1)
        self.sub_mesh_assert_equal(
            mesh.mesh,
            [torch.ones(3, 2)] * 2,
            [torch.tensor([])] * 2,
            [dt.to_local() for dt in dtensor_list],
        )

    @with_comms
    def test_redistribute_sub_mesh(self):
        mesh = DeviceMesh(self.device_type, [0, 2])

        # test redistribute on a sub-mesh
        local_tensor1 = torch.ones(4, 3)
        sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)])
        replicated_dtensor = sharded_dtensor.redistribute(placements=[Replicate()])
        self.sub_mesh_assert_equal(
            mesh.mesh, torch.ones(8, 3), torch.tensor([]), replicated_dtensor.to_local()
        )
        sharded_again = replicated_dtensor.redistribute(placements=[Shard(0)])
        self.sub_mesh_assert_equal(
            mesh.mesh, torch.ones(4, 3), torch.tensor([]), sharded_again.to_local()
        )

    @with_comms
    def test_implicit_replication(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))
        local_tensor1 = torch.ones(4, 3)
        sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)])

        from torch.distributed._tensor.experimental import implicit_replication

        with implicit_replication():
            out_dt = sharded_dtensor + torch.ones(3, device=self.device_type)
            self.assertEqual(out_dt.placements, [Shard(0)])
            self.assertEqual(out_dt.shape, (4 * self.world_size, 3))
            local_shard = out_dt.to_local()
            self.assertEqual(local_shard.shape, (4, 3))
            self.assertEqual(local_shard, torch.ones(4, 3) + torch.ones(3))

    @with_comms
    def test_auto_implicit_replication(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        local_tensor = torch.ones(self.world_size, 3, device=self.device_type)
        sharded_dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])

        # automatically turn the tensor into a replicated DTensor when ndim = 0 and numel = 1
        ndim_0_tensor = torch.tensor(1, device=self.device_type)

        def add_scalar_tensor_with_dtensor():
            return sharded_dtensor + ndim_0_tensor

        result = add_scalar_tensor_with_dtensor().to_local()
        self.assertEqual(result, local_tensor + ndim_0_tensor)
        self.assertNotWarn(
            add_scalar_tensor_with_dtensor,
            "Found a non-scalar tensor with numel=1 and ndim!=0",
        )

        # automatically turn the tensor into a replicated DTensor when ndim = 1 and numel = 1
        numel_1_tensor = torch.tensor([1], device=self.device_type)
        self.assertEqual(
            (sharded_dtensor + numel_1_tensor).to_local(), local_tensor + numel_1_tensor
        )


class TestDTensorPlacementTypes(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    def _create_tensor(self, size):
        # Keep everything deterministic.
        torch.manual_seed(0)
        tensor = torch.rand(size)
        if self.device_type == "cuda":
            return tensor.cuda()
        else:
            return tensor

    @with_comms
    def test_split_tensor_1D(self) -> None:
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        shard_placement = Shard(0)

        for size in range(8):
            tensor = self._create_tensor(size)
            splitted_tensor_list, pad_sizes = shard_placement._split_tensor(
                tensor,
                mesh.size(),
                with_padding=True,
                contiguous=True,
            )
            if size == 0:
                # when the tensor size is 0, no padding is needed on any rank.
                expected_pad_sizes = []
                assert_array_equal(expected_pad_sizes, pad_sizes)

                is_tensor_empty = [
                    False if splitted_tensor.numel() > 0 else True
                    for splitted_tensor in splitted_tensor_list
                ]
                expected_is_tensor_empty = [True] * self.world_size
                assert_array_equal(expected_is_tensor_empty, is_tensor_empty)
            else:
                expected_pad_sizes = [
                    0 if idx < size else 1
                    for idx, _ in enumerate(range(self.world_size))
                ]
                assert_array_equal(expected_pad_sizes, pad_sizes)

                from torch.distributed._tensor._collective_utils import unpad_tensor

                unpadded_list = [
                    unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
                    if pad_sizes[i] > 0
                    else tensor
                    for i, tensor in enumerate(splitted_tensor_list)
                ]
                expected_is_tensor_empty = [
                    False if idx < size else True
                    for idx, _ in enumerate(range(self.world_size))
                ]
                is_tensor_empty = [
                    False if unpadded_tensor.numel() > 0 else True
                    for unpadded_tensor in unpadded_list
                ]
                assert_array_equal(expected_is_tensor_empty, is_tensor_empty)


if __name__ == "__main__":
    run_tests()