# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import copy
import functools
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.testing
import torch.distributed as dist
import torch.nn as nn
from torch._C import FileCheck
from torch._inductor.utils import run_and_get_triton_code
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor,
    init_device_mesh,
    Partial,
    Replicate,
    Shard,
)
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    PrepareModuleOutput,
    RowwiseParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    with_comms,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._triton import has_triton
from torch.utils.checkpoint import checkpoint


class SimpleModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.mlp_0 = MLPModule(device)
        self.mlp_1 = MLPModule(device)

    def forward(self, input):
        return self.mlp_1(self.mlp_0(input))


def extract_graph(fx_g, _, graph_cell):
    graph_cell[0] = fx_g.code
    return fx_g


# Make a custom compiler that runs aot autograd but extracts the fw graph
fw_graph_cell = [None]
bw_graph_cell = [None]
fw_compiler = functools.partial(extract_graph, graph_cell=fw_graph_cell)
bw_compiler = functools.partial(extract_graph, graph_cell=bw_graph_cell)

from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd


aot_eager_graph = aot_autograd(
    fw_compiler=fw_compiler,
    bw_compiler=bw_compiler,
    partition_fn=min_cut_rematerialization_partition,
)


class TestDTensorCompile(torch._dynamo.test_case.TestCase):
    def setUp(self):
        super().setUp()
        fake_store = FakeStore()
        dist.init_process_group(
            "fake", store=fake_store, rank=0, world_size=self.world_size
        )

    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()

    @property
    def device_type(self) -> str:
        return "cuda" if torch.cuda.is_available() else "cpu"

    @property
    def world_size(self) -> int:
        return 2

    def test_placement_compile(self):
        def fn(x):
            a = 0
            if x.is_replicate():
                a += 1
            if x.is_shard():
                a += 2
                if x.dim < 0:
                    raise RuntimeError("dim < 0")
            if x.is_shard(0):
                a += 2
            if x.is_shard(dim=0):
                a += 2
            if x.is_shard(dim=None):
                a += 2
            if x.is_partial():
                a += 3
            return a

        compiled_fn = torch.compile(backend="aot_eager", fullgraph=True)(fn)

        for x in [Shard(0), Replicate(), Partial()]:
            opt_fn = fn(x)
            compiled_out = compiled_fn(x)
            self.assertEqual(opt_fn, compiled_out)

    def test_device_mesh_compile(self):
        def fn(x):
            # test size()
            a = x.size()
            b = x.size(0)
            c = x.size(mesh_dim=0)
            size = a + b + c
            # test get_coordinate()
            coord = x.get_coordinate()
            # test get_group()
            group = x.get_group()
            return size, coord, group

        compiled_fn = torch.compile(backend="aot_eager", fullgraph=True)(fn)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        opt_fn = fn(mesh)
        compiled_out = compiled_fn(mesh)
        self.assertEqual(opt_fn, compiled_out)

    def test_fakify_dtensor(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # pass in DTensor as inputs/outputs to the function
        def fn(x):
            return x

        x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False)
        ref = fn(x)

        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(res, ref)

    def test_dynamo_dtensor(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # test passing in DTensor as inputs/outputs and run some tensor computation
        def fn(x):
            return x * x + 2

        x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False)
        ref = fn(x)

        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(res, ref)

    def test_dtensor_attribute_access_on_intermediate(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x):
            tmp = x * 2
            if tmp.placements[0].is_shard():
                return tmp._local_tensor + 2
            else:
                return tmp._local_tensor + 3

        x = DTensor.from_local(torch.ones(4), mesh, [Shard(0)], run_check=False)
        ref = fn(x)

        opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(res, ref)

    def test_dtensor_constructor_w_graph_break(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        x = torch.randn(64, 32, requires_grad=True)
        spec = DTensorSpec(
            mesh,
            (Replicate(), Shard(0)),
            tensor_meta=TensorMeta(
                shape=torch.Size([128, 32]), stride=(32, 1), dtype=x.dtype
            ),
        )

        # test passing in DTensor as inputs/outputs and run some tensor computation
        def fn(x):
            print("graph break!")
            return DTensor(
                x,
                spec,
                requires_grad=x.requires_grad,
            )

        out = fn(x)
        out2 = torch.compile(fn, backend="eager")(x)

    def test_dtensor_constructor_w_dynamo_disable(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        x = torch.randn(32, requires_grad=True)
        spec = DTensorSpec(
            mesh,
            (Replicate(),),
            tensor_meta=TensorMeta(shape=torch.Size([32]), stride=(1,), dtype=x.dtype),
        )

        @torch._dynamo.disable(recursive=False)
        def fn(x):
            print("foo")
            return DTensor(
                x,
                spec,
                requires_grad=x.requires_grad,
            )

        out = fn(x)
        out2 = torch.compile(fn, backend="eager")(x)
        self.assertEqual(out, out2)

    def test_dtensor_noncontiguous_output(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # test passing in DTensor as inputs/outputs and run some tensor computation
        def fn(x, y, z):
            x_transposed = x.permute(0, 2, 1).contiguous()
            tmp = torch._C._nn.linear(x_transposed, y, z)
            return tmp.permute(0, 2, 1)

        x_inner = torch.randn(4, 16, 4, requires_grad=True)
        y_inner = torch.randn(4, 16, requires_grad=True)
        z_inner = torch.randn(4, requires_grad=True)
        x = DTensor.from_local(x_inner, mesh, [Shard(1)], run_check=False)
        y = DTensor.from_local(y_inner, mesh, [Shard(1)], run_check=False)
        z = DTensor.from_local(z_inner, mesh, [Replicate()], run_check=False)
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(x, y, z)
        out.contiguous().sum().backward()

    def test_dynamo_dtensor_from_local(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # create DTensor inside fn and run some compute
        def fn(x):
            dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)
            return dt.to_local() + 2

        # below is the op approach for reference
        # from torch.distributed._tensor.api import _FromTorchTensor
        # def from_local_tensor(x):
        #     return _FromTorchTensor.apply(x, mesh, [Replicate()], False)

        # _dt_lib_def = torch.library.Library("dtensor", "DEF")
        # _dt_lib_def.define("from_local(Tensor self) -> Tensor")

        # _dt_lib_impl = torch.library.Library("dtensor", "IMPL")
        # _dt_lib_impl.impl("from_local", from_local_tensor, "Autograd")

        x = torch.ones(1, requires_grad=True)
        ref = fn(x)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
        res = opt_fn(x)
        # backward should work as well
        res.sum().backward()

        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 1)

        # test if user calls from_local with mesh/placements as kwargs and that should still work
        def from_local_kwargs_fn(x):
            dt = DTensor.from_local(
                x, device_mesh=mesh, placements=[Replicate()], run_check=False
            )
            return dt.to_local() + 2

        ref = from_local_kwargs_fn(x)
        opt_kwargs_fn = torch.compile(from_local_kwargs_fn, backend=cnt, fullgraph=True)
        res = opt_kwargs_fn(x)
        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 2)

    def test_dynamo_dtensor_from_local_dynamic_shapes(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # Case 1: all dims dynamic
        def fn(x):
            dt = DTensor.from_local(
                x,
                mesh,
                [Replicate()],
                run_check=False,
                shape=x.shape,
                stride=x.stride(),
            )
            return dt.to_local() + 2

        inp = torch.randn(4, 6, requires_grad=True)
        ref = fn(inp)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        res = torch.compile(fn, backend=cnt, fullgraph=True, dynamic=True)(inp)
        res.sum().backward()

        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 1)

        # Case 2: only sizes are dynamic, strides are static
        def fn(x):
            dt = DTensor.from_local(
                x, mesh, [Replicate()], run_check=False, shape=x.shape, stride=(1,)
            )
            return dt.to_local() + 2

        inp = torch.randn(4, requires_grad=True)
        torch._dynamo.mark_dynamic(inp, 0)
        ref = fn(inp)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        res = torch.compile(fn, backend=cnt, fullgraph=True)(inp)
        res.sum().backward()

        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 1)

        # Case 3: both sizes and strides have a mix of dynamic and static dims
        def fn(x):
            dt = DTensor.from_local(
                x,
                mesh,
                [Replicate()],
                run_check=False,
                shape=(x.shape[0], x.shape[1], 2),
                stride=(x.stride()[0], 2, 1),
            )
            return dt.to_local() + 2

        inp = torch.randn(4, 6, 2, requires_grad=True)
        torch._dynamo.mark_dynamic(inp, 0)
        torch._dynamo.mark_dynamic(inp, 1)
        ref = fn(inp)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        res = torch.compile(fn, backend=cnt, fullgraph=True)(inp)
        res.sum().backward()

        self.assertEqual(res, ref)
        self.assertEqual(cnt.frame_count, 1)

    def test_dynamo_dtensor_recompile(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # test passing in DTensor as inputs/outputs and run some tensor computation
        def fn(x):
            return torch.mul(x, x)

        x = DTensor.from_local(torch.rand(2, 2), mesh, [Shard(0)], run_check=False)
        x2 = DTensor.from_local(torch.rand(2, 2), mesh, [Shard(0)], run_check=False)
        x3 = DTensor.from_local(torch.rand(2, 2), mesh, [Shard(1)], run_check=False)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True, dynamic=False)
        self.assertEqual(fn(x), opt_fn(x))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(fn(x2), opt_fn(x2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(fn(x3), opt_fn(x3))
        self.assertEqual(cnt.frame_count, 2)

    def test_dtensor_partial_placement_redistribute_unbalanced_correct_strides(self):
        # Partial -> Shard on an unbalanced tensor results in:
        # - A contiguous DTensor
        # - where the inner _local_tensor is noncontiguous
        placement = Shard(1)

        def fn(x):
            out = x.redistribute(mesh, [placement])
            return out

        # Temporarily ignore setUp(), and use rank3 graphs during tracing
        dist.destroy_process_group()
        fake_store = FakeStore()
        dist.init_process_group("fake", store=fake_store, rank=3, world_size=2)
        mesh = DeviceMesh(self.device_type, [1, 3])

        x = torch.randn(10, 257, 160, requires_grad=True)
        x_dt = DTensor.from_local(
            x,
            mesh,
            [Partial()],
            run_check=False,
            shape=(10, 257, 160),
            stride=(41120, 160, 1),
        )

        # tmp_dt has an inner, non-contiguous tensor, and is an autograd non-leaf
        tmp_dt = fn(x_dt)
        fake_mode = torch._subclasses.FakeTensorMode()
        tmp_dt_fake = fake_mode.from_tensor(tmp_dt)
        self.assertEqual(tmp_dt.shape, tmp_dt_fake.shape)
        self.assertEqual(tmp_dt.stride(), tmp_dt_fake.stride())
        self.assertEqual(tmp_dt._local_tensor.shape, tmp_dt_fake._local_tensor.shape)
        self.assertEqual(
            tmp_dt._local_tensor.stride(), tmp_dt_fake._local_tensor.stride()
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent(self):
        # Partial -> Shard on an unbalanced tensor results in:
        # - A contiguous DTensor
        # - where the inner _local_tensor is noncontiguous
        # When this tensor is a fwd graph output,
        # AOTAutograd needs to make sure we trace the backward
        # with a contiguous tangent
        placement = Shard(1)

        def fn(x):
            out = x.redistribute(mesh, [placement])
            return out

        # Temporarily ignore setUp(), and use rank3 graphs during tracing
        dist.destroy_process_group()
        fake_store = FakeStore()
        dist.init_process_group("fake", store=fake_store, rank=3, world_size=2)
        mesh = DeviceMesh(self.device_type, [1, 3])

        x = torch.randn(10, 257, 160, requires_grad=True)
        x_dt = DTensor.from_local(
            x,
            mesh,
            [Partial()],
            run_check=False,
            shape=(10, 257, 160),
            stride=(41120, 160, 1),
        )

        out_dt = torch.compile(fn)(x_dt)
        # If we don't properly contiguify our traced tangents,
        # this fails with an inductor stride assert
        out_dt.to_local().sum().backward()

    def test_dynamo_to_local_kwargs(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x):
            # note: `dt` is the DTensor created below; fn closes over it
            # rather than using its argument
            return dt.to_local(grad_placements=[Shard(0)]) + 2

        fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True)
        x = torch.ones(4)
        dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)

        out_ref = fn(dt)
        out_test = fn_opt(dt)
        self.assertEqual(out_ref, out_test)

    def test_dynamo_to_local_kwargs_forward_hook(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fw_hook(module, inp, out):
            tmp = out.to_local(grad_placements=out.placements) + 2
            return DTensor.from_local(tmp, mesh, out.placements, run_check=False)

        mod = torch.nn.Linear(4, 4)
        mod.register_forward_hook(fw_hook)
        mod.weight = torch.nn.Parameter(
            DTensor.from_local(mod.weight, mesh, [Replicate()], run_check=False)
        )
        mod.bias = torch.nn.Parameter(
            DTensor.from_local(mod.bias, mesh, [Replicate()], run_check=False)
        )
        opt_mod = torch.compile(mod, backend="aot_eager", fullgraph=True)

        x = torch.ones(4, 4)
        dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)

        out_ref = mod(dt)
        out_test = opt_mod(dt)
        self.assertEqual(out_ref, out_test)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_dtensor_different_gradient_placement(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x, y, z):
            permute = x.permute(0, 2, 1)
            permute2 = permute.contiguous()
            layer_norm = torch.nn.functional.layer_norm(permute2, (4,), y, z, 1e-05)
            out = layer_norm.permute(0, 2, 1)
            return out

        x = torch.randn(4, 2, 4, requires_grad=True, device="cuda")
        x_dt = DTensor.from_local(x, mesh, [Shard(1)], run_check=False)

        y = torch.randn(4, requires_grad=True, device="cuda")
        y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False)

        z = torch.randn(4, requires_grad=True, device="cuda")
        z_dt = DTensor.from_local(z, mesh, [Replicate()], run_check=False)

        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
        tmp_dt = opt_fn(x_dt, y_dt, z_dt)
        out_dt = torch.matmul(tmp_dt, x_dt).permute(0, 2, 1)
        out_dt.sum().backward()

    def test_dynamo_dtensor_from_local_redistribute(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # pass in tensor as inputs/outputs, create DTensor and run redistribute
        # (allgather collective) inside the fn
        def fn(x):
            dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
            return dt.redistribute(mesh, [Replicate()]).to_local() + 2

        x = torch.ones(1)
        ref = fn(x)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(res, ref)

        def redistribute_kwargs_fn(x):
            dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
            return (
                dt.redistribute(device_mesh=mesh, placements=[Replicate()]).to_local()
                + 2
            )

        x = torch.ones(1)
        ref = redistribute_kwargs_fn(x)
        opt_kwargs_fn = torch.compile(
            redistribute_kwargs_fn, backend=cnt, fullgraph=True
        )
        res = opt_kwargs_fn(x)
        self.assertEqual(res, ref)

    def test_dtensor_dont_recompile_on_same_placement_devicemesh(self):
        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")

        @torch.compile(backend=cnt)
        def fn(x):
            dt = DTensor.from_local(x, mesh, [placement], run_check=False)

        x = torch.ones(4, 4, requires_grad=True)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placement = Shard(1)
        fn(x)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placement = Shard(1)
        # no recompile, placement is unchanged
        fn(x)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placement = Partial()
        # recompile since placement is different
        fn(x)

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placement = Partial()
        # no recompile, placement is unchanged
        fn(x)

        # 2 total frames (one for Partial(), one for Shard())
        self.assertEqual(cnt.frame_count, 2)

    def test_dtensor_dynamo_device_mesh_attrs(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # pass in tensor as inputs/outputs, create DTensor and run redistribute
        # (allgather collective) inside the fn
        def fn(x_dt):
            if x_dt.device_mesh.device_type == "cuda":
                return x_dt + 1
            else:
                return x_dt + 2

        x = torch.ones(4, 4)
        x_dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        ref = fn(x_dt)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x_dt)
        self.assertEqual(ref, res)

    def test_graph_input_is_async(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x):
            return x.sin().sin()

        opt_fn = torch.compile(fn, backend=aot_eager_graph, fullgraph=True)

        x = torch.randn(4, 4, requires_grad=True)
        x_dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        x2 = x_dt.redistribute(mesh, [Replicate()], async_op=True)
        x2 = x2.to_local()
        out = opt_fn(x2)
        # The important part: we get a wait_tensor() in the graph.
        # At runtime, the input to the graph is an AsyncCollectiveTensor,
        # and inside the graph we need to issue a wait() to synchronize.
        self.assertExpectedInline(
            str(fw_graph_cell[0]).strip(),
            """\
def forward(self, primals_1):
    wait_tensor = torch.ops._c10d_functional.wait_tensor.default(primals_1)
    sin = torch.ops.aten.sin.default(wait_tensor)
    sin_1 = torch.ops.aten.sin.default(sin); sin = None
    return (sin_1, primals_1, wait_tensor)""",
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_dtensor_partial_placement_graph_output(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(x):
            return x + x

        x = torch.randn(4, 4, requires_grad=True)
        x_dt = DTensor.from_local(x, mesh, [Partial()], run_check=False)

        y = torch.randn(4, 4, requires_grad=True)
        y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False)

        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
        tmp_dt = opt_fn(x_dt)
        out_dt = torch.matmul(tmp_dt, y_dt)
        out_dt.sum().backward()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(1)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    def test_tp_compile_comm_reordering(self):
        class FakeAttention(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.wq = nn.Linear(16, 16)
                self.wk = nn.Linear(16, 16)
                self.wv = nn.Linear(16, 16)
                self.wo = nn.Linear(16, 16)

            def forward(self, x):
                xq = self.wq(x)
                xk = self.wk(x)
                xv = self.wv(x)
                # fake attention:
                xo = xq + xk + xv
                return self.wo(xo)

        class FakeTransformerBlock(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attn = FakeAttention()

            def forward(self, x):
                return self.attn(x)

        class FakeTransformer(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.block = FakeTransformerBlock()

            def forward(self, input):
                return self.block(input)

        model = FakeTransformer().to(self.device_type)

        tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

        # apply sequence parallel
        parallel_plan = {
            "attn": PrepareModuleInput(
                input_layouts=Shard(0), desired_input_layouts=Replicate()
            ),
            "attn.wq": ColwiseParallel(),
            "attn.wk": ColwiseParallel(),
            "attn.wv": ColwiseParallel(),
            "attn.wo": RowwiseParallel(output_layouts=Shard(0)),
        }

        parallelize_module(
            module=model.block,
            device_mesh=tp_mesh,
            parallelize_plan=parallel_plan,
        )

        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
        compiled_model = torch.compile(model, backend=cnt, fullgraph=True)
        inp = torch.rand(20, 16).to(self.device_type)
        out = compiled_model(inp)
        out.sum().backward()
        self.assertEqual(cnt.frame_count, 1)

        code = run_and_get_triton_code(compiled_model, inp)
        FileCheck().check(
            "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(primal"
        ).check("torch.ops._c10d_functional.wait_tensor.default(buf0").check(
            "extern_kernels.mm(buf0,"
        ).run(
            code
        )


@instantiate_parametrized_tests
class TestDTensorCompileE2E(DTensorTestBase):
    @property
    def world_size(self):
        return 4

    @with_comms
    @parametrize("is_seq_parallel", [True, False])
    def test_tp_compile_fullgraph(self, is_seq_parallel):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model = SimpleModel(self.device_type)

        colwise_style = (
            ColwiseParallel(input_layouts=Shard(0))
            if is_seq_parallel
            else ColwiseParallel()
        )
        rowwise_style = (
            RowwiseParallel(output_layouts=Shard(0))
            if is_seq_parallel
            else RowwiseParallel()
        )

        if is_seq_parallel:
            # use input preparation to test out the compile of it
            prepare_module_input = PrepareModuleInput(
                input_layouts=Shard(0),
                desired_input_layouts=Replicate(),
            )
            prepare_module_out = PrepareModuleOutput(
                output_layouts=Replicate(),
                desired_output_layouts=Shard(0),
            )
            plan = {
                "mlp_0": prepare_module_input,
                "mlp_0.net1": ColwiseParallel(),
                "mlp_0.net2": rowwise_style,
                "mlp_1.net1": colwise_style,
                "mlp_1.net2": RowwiseParallel(),
                "mlp_1": prepare_module_out,
            }
        else:
            plan = {
                "mlp_0.net1": colwise_style,
                "mlp_0.net2": rowwise_style,
                "mlp_1.net1": colwise_style,
                "mlp_1.net2": rowwise_style,
            }

        model = parallelize_module(
            model,
            mesh,
            parallelize_plan=plan,
        )
        rng_seed = self.rank if is_seq_parallel else 0
        torch.manual_seed(rng_seed)
        inp = torch.rand(20, 10, device=self.device_type)
        out = model(inp)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        compiled_mod = torch.compile(model, backend=cnt, fullgraph=True)
        compiled_out = compiled_mod(inp)
        compiled_out.sum().backward()
        self.assertEqual(compiled_out, out)
        self.assertEqual(cnt.frame_count, 1)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_2d_fsdp_tp_compile(self):
        data_parallel_size = 2
        model = SimpleModel(self.device_type)
        model_copy = copy.deepcopy(model)

        # 2-D mesh is [dp, tp]
        twod_mesh = init_device_mesh(
            "cuda",
            (data_parallel_size, self.world_size // data_parallel_size),
            mesh_dim_names=["dp", "tp"],
        )

        fsdp_pg = twod_mesh.get_group(mesh_dim=0)

        inp = torch.rand(20, 10, device=self.device_type)
        parallelize_plan = {
            "mlp_0.net1": ColwiseParallel(),
            "mlp_0.net2": RowwiseParallel(),
            "mlp_1.net1": ColwiseParallel(),
            "mlp_1.net2": RowwiseParallel(),
        }
        tp_model = parallelize_module(model, twod_mesh["tp"], parallelize_plan)
        eager_2d = FSDP(
            tp_model,
            device_id=self.rank,
            use_orig_params=True,
            device_mesh=twod_mesh["dp"],
        )
        out = eager_2d(inp)
        tp_model2 = parallelize_module(
            model_copy,
            twod_mesh["tp"],
            parallelize_plan,
        )
        fsdp_2d = FSDP(
            tp_model2,
            device_id=self.rank,
            use_orig_params=True,
            device_mesh=twod_mesh["dp"],
        )

        # TODO: once aot autograd support is ready we can just use default backend
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        compiled_2d = torch.compile(fsdp_2d, backend=cnt)
        compiled_output = compiled_2d(inp)

        self.assertEqual(out, compiled_output)
        self.assertEqual(cnt.frame_count, 1)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_2d_fsdp_tp_ac_compile(self):
        dp_degree = 2
        tp_degree = self.world_size // dp_degree
        model = SimpleModel(self.device_type)
        model_copy = copy.deepcopy(model)

        # 2-D mesh is [dp, tp]
        mesh_2d = init_device_mesh(
            "cuda", mesh_shape=(dp_degree, tp_degree), mesh_dim_names=("dp", "tp")
        )

        inp = torch.rand(20, 10, device=self.device_type)
        parallelize_plan = {
            "mlp_0.net1": ColwiseParallel(),
            "mlp_0.net2": RowwiseParallel(),
            "mlp_1.net1": ColwiseParallel(),
            "mlp_1.net2": RowwiseParallel(),
        }
        tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
        tp_model = checkpoint_wrapper(
            tp_model,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            checkpoint_fn=checkpoint,
            use_reentrant=False,
        )
        eager_2d = FSDP(tp_model, device_mesh=mesh_2d["dp"], use_orig_params=True)

        tp_model2 = parallelize_module(model_copy, mesh_2d["tp"], parallelize_plan)
        fsdp_2d = FSDP(
            tp_model2,
            device_mesh=mesh_2d["dp"],
            use_orig_params=True,
        )
        # TODO: once aot autograd support is ready we can just use default backend
        compiled_2d = torch.compile(fsdp_2d, backend="aot_eager")

        # forward pass
        out = eager_2d(inp)
        compiled_output = compiled_2d(inp)
        self.assertEqual(out, compiled_output)

        # backward pass
        out.sum().backward()
        compiled_output.sum().backward()

        # compare the gradients:
        for n, p in zip(fsdp_2d.parameters(), compiled_2d.parameters()):
            self.assertEqual(n.grad, p.grad)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_compile_dtensor_redistribute_backward(self):
        mesh = DeviceMesh(device_type="cuda", mesh=torch.arange(self.world_size))

        def fn(x, y):
            dt = DTensor.from_local(x.reshape(2, 4), mesh, [Shard(0)], run_check=False)
            dt2 = DTensor.from_local(y.reshape(4, 2), mesh, [Shard(1)], run_check=False)
            dt_out = torch.matmul(dt, dt2)
            dt_out_redistribute = dt_out.redistribute(mesh, [Replicate()])
            return dt_out_redistribute.to_local()

        opt_fn = torch.compile(fn, backend=aot_eager_graph, fullgraph=True)

        x_ref = torch.arange(8, requires_grad=True, dtype=torch.float32)
        y_ref = torch.arange(8, requires_grad=True, dtype=torch.float32)
        ref = fn(x_ref, y_ref)

        x = torch.arange(8, requires_grad=True, dtype=torch.float32)
        y = torch.arange(8, requires_grad=True, dtype=torch.float32)
        res = opt_fn(x, y)

        self.assertEqual(res, ref)

        # Now run and assert the backward + gradients
        ref.sum().backward()
        res.sum().backward()

        self.assertEqual(x_ref.grad, x.grad)
        self.assertEqual(y_ref.grad, y.grad)


if __name__ == "__main__":
    run_tests()