# Owner(s): ["oncall: distributed"]


import contextlib
import copy
import functools
import unittest
from unittest import mock

import torch
import torch._dynamo.testing
import torch.distributed._composable.fsdp._fsdp_param
import torch.nn.functional as F
from torch import nn
from torch._dynamo import compiled_autograd
from torch._inductor import comms
from torch._inductor.utils import is_fallback_op, run_and_get_code
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.fsdp._fsdp_common import TrainingState
from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
from torch.distributed._tensor import init_device_mesh
from torch.testing import FileCheck
from torch.testing._internal.common_distributed import at_least_x_gpu, skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)
from torch.utils._triton import has_triton


def _is_op_in_graph(graph, op):
    return any(node.target is op for node in graph.nodes)


def _is_fallback_op_in_snodes(snodes, op):
    return any(is_fallback_op(snode.node, op) for snode in snodes)


class TestFullyShardCompileCompute(FSDPTest):
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_disable_compiling_hooks(self):
        self.run_subtests(
            {
                "skip_fsdp_hooks": [False, True],
            },
            self._test_disable_compiling_hooks,
        )

    def _test_disable_compiling_hooks(
        self,
        skip_fsdp_hooks: bool,
    ):
        torch._dynamo.reset()
        trace_rules_check_count = 0
        HOOKS_FILE_NAME = "torch/distributed/_composable/fsdp/_fsdp_state.py"
        HOOK_WRAPPER_NAME = "fsdp_hook_wrapper"

        def patched_trace_rules_check(*args, **kwargs):
            nonlocal trace_rules_check_count
            f_code = args[0]
            if (
                hasattr(f_code, "co_filename")
                and f_code.co_filename.endswith(HOOKS_FILE_NAME)
                and f_code.co_name != HOOK_WRAPPER_NAME
            ):
                trace_rules_check_count += 1
            return orig_trace_rules_check(*args, **kwargs)

        original_skip_fsdp_hooks = torch._dynamo.config.skip_fsdp_hooks
        orig_trace_rules_check = torch._dynamo.trace_rules.check
        torch.distributed.barrier()
        torch._dynamo.config.skip_fsdp_hooks = skip_fsdp_hooks
        torch._dynamo.trace_rules.check = patched_trace_rules_check
        model = MLP(4)
        fully_shard(model)
        model.compile()
        model(torch.randn((4, 4), device="cuda"))
        torch.distributed.barrier()
        torch._dynamo.config.skip_fsdp_hooks = original_skip_fsdp_hooks
        torch._dynamo.trace_rules.check = orig_trace_rules_check
        if skip_fsdp_hooks:
            self.assertEqual(trace_rules_check_count, 0)
        else:
            self.assertTrue(trace_rules_check_count > 0)


class TestFullyShardCompile(FSDPTest):
    fake_pg = not at_least_x_gpu(2)

    @property
    def world_size(self) -> int:
        return 2

    def test_dynamo_trace_use_training_state(self):
        torch._dynamo.reset()
        # Construct a dummy FSDPParamGroup, since we just want to test the `use_training_state` ctx manager.
        param_group = FSDPParamGroup(
            [],  # params: List[nn.Parameter],
            (torch.nn.Linear(1, 1),),  # module: Tuple[nn.Module, ...],
            None,  # mesh_info: FSDPMeshInfo,
            None,  # post_forward_mesh_info: Optional[FSDPMeshInfo],
            None,  # device: torch.device,
            None,  # mp_policy: MixedPrecisionPolicy,
            None,  # offload_policy: OffloadPolicy,
        )

        def f(x):
            param_group._training_state = TrainingState.IDLE
            with param_group.use_training_state(TrainingState.FORWARD):
                if param_group._training_state == TrainingState.FORWARD:
                    return x + 1
                else:
                    return x

        inp = torch.zeros(1)
        self.assertEqual(param_group._training_state, TrainingState.IDLE)

        eager_out = f(inp)
        self.assertEqual(param_group._training_state, TrainingState.IDLE)
        self.assertEqual(eager_out, inp + 1)

        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        compiled_out = torch.compile(f, backend=cnt, fullgraph=True)(inp)
        self.assertEqual(param_group._training_state, TrainingState.IDLE)
        self.assertEqual(eager_out, compiled_out)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)
        self.assertEqual(len(cnt.graphs), 1)

    def test_trace_fsdp_set_(self):
        @torch.library.custom_op("mylib::add_one_out", mutates_args={"out"})
        def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None:
            torch.add(x, 1, out=out)

        def f(x):
            buf = torch.zeros(2)
            buf_view = buf.view(-1)
            torch.ops.mylib.add_one_out(x, out=buf_view)
            buf_view2 = buf.view(-1)
            torch.ops.fsdp.set_(x, buf_view2)

        ref_x = torch.zeros(2)
        x = copy.deepcopy(ref_x)
        f(ref_x)
        torch.compile(f, backend="aot_eager")(x)
        self.assertEqual(x, ref_x)

    def _reinplace_all_gather_with_optional_checks(self, fullgraph):
        def _run_with_checks(graph, orig_fn):
            self.assertTrue(
                _is_op_in_graph(
                    graph,
                    torch.ops._c10d_functional.all_gather_into_tensor.default,
                )
            )
            orig_fn(graph)
            self.assertFalse(
                _is_op_in_graph(
                    graph,
                    torch.ops._c10d_functional.all_gather_into_tensor.default,
                )
            )
            self.assertTrue(
                _is_op_in_graph(
                    graph,
                    torch.ops._c10d_functional.all_gather_into_tensor_out.default,
                )
            )

        if fullgraph:
            return mock.patch.object(
                comms,
                "reinplace_fsdp_all_gather",
                functools.partial(
                    _run_with_checks,
                    orig_fn=comms.reinplace_fsdp_all_gather,
                ),
            )
        else:
            return contextlib.nullcontext()

    def _is_fwd_graph(self, snodes):
        ag_copy_in_snode = None
        for snode in snodes:
            if is_fallback_op(snode.node, torch.ops.fsdp.all_gather_copy_in.default):
                ag_copy_in_snode = snode
                break
        self.assertTrue(ag_copy_in_snode is not None)
        if any(
            dep.name.startswith("primals_")
            for dep in ag_copy_in_snode.read_writes.reads
        ):
            return True
        else:
            return False

    def _maybe_run_decide_global_ordering_of_comms_with_checks(self, fullgraph):
        def _check_fsdp_ops_in_snodes(snodes, is_fwd_graph, expect=True):
            assert_method = self.assertTrue if expect else self.assertFalse
            common_ops = {
                torch.ops.fsdp.all_gather_copy_in.default,
                torch.ops._c10d_functional.all_gather_into_tensor_out.default,
                torch.ops.fsdp.split_with_sizes_copy.default,
            }
            bwd_only_ops = {
                torch.ops.fsdp.chunk_cat.default,
                torch.ops._c10d_functional.reduce_scatter_tensor.default,
            }
            for op in common_ops:
                assert_method(
                    _is_fallback_op_in_snodes(
                        snodes,
                        op,
                    ),
                    msg=f"{op}",
                )
            if not is_fwd_graph:
                for op in bwd_only_ops:
                    assert_method(
                        _is_fallback_op_in_snodes(
                            snodes,
                            op,
                        ),
                        msg=f"{op}",
                    )

        def _decide_global_ordering_of_comms_with_checks(
            snodes, name_to_buf, name_to_fused_node, orig_fn
        ):
            is_fwd_graph = self._is_fwd_graph(snodes)
            _check_fsdp_ops_in_snodes(snodes, is_fwd_graph, expect=True)
            new_snodes = orig_fn(snodes, name_to_buf, name_to_fused_node)
            _check_fsdp_ops_in_snodes(new_snodes, is_fwd_graph, expect=False)
            return new_snodes

        if fullgraph:
            return mock.patch.object(
                comms,
                "decide_global_ordering_of_comms",
                functools.partial(
                    _decide_global_ordering_of_comms_with_checks,
                    orig_fn=comms.decide_global_ordering_of_comms,
                ),
            )
        else:
            return contextlib.nullcontext()

    def inductor_code_check_no_compute_op(self, file_check):
        return (
            file_check.check_not(" = aten.")
            .check_not(" = extern_kernels.")
            .check_not(" = triton_")
            .check_not(" = torch.ops.")
            .check_not(" = inductor_ops.")
            .check_not(" aten.")
            .check_not(" extern_kernels.")
            .check_not(" triton_")
            .check_not(" torch.ops.")
            .check_not(" inductor_ops.")
        )

    def inductor_code_check_fsdp_all_gather(
        self,
        file_check,
        overlapped_compute_op_str,
        num_resize,
        num_set,
        last_all_gather=False,
    ):
        file_check = file_check.check("torch.ops.fsdp.all_gather_copy_in.")
        file_check = self.inductor_code_check_no_compute_op(file_check)
        file_check = file_check.check(
            "torch.ops._c10d_functional.all_gather_into_tensor_out."
        )
        # Checks that AGWait is delayed, making the AG overlap with some compute op.
        if overlapped_compute_op_str is not None:
            file_check = file_check.check(f"{overlapped_compute_op_str}")
        file_check = file_check.check_count(
            "inductor_ops.resize_storage_bytes_(", num_resize, exactly=True
        )
        file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.")
        file_check = self.inductor_code_check_no_compute_op(file_check)
        file_check = file_check.check("torch.ops.fsdp.split_with_sizes_copy.")
        file_check = self.inductor_code_check_no_compute_op(file_check)
        file_check = file_check.check_count(
            "torch.ops.aten.set_.", num_set, exactly=True
        )
        if not last_all_gather:
            # Checks that there is no compute op between this AGWait and the next AG.
            file_check = self.inductor_code_check_no_compute_op(file_check)
        return file_check

    def inductor_code_check_fsdp_reduce_scatter(
        self, file_check, overlapped_compute_op_str
    ):
        file_check = file_check.check("torch.ops.fsdp.chunk_cat.")
        file_check = self.inductor_code_check_no_compute_op(file_check)
        file_check = file_check.check(
            "torch.ops._c10d_functional.reduce_scatter_tensor."
        )
        # Checks that RSWait is delayed, making the RS overlap with some compute op.
        if overlapped_compute_op_str is not None:
            file_check = file_check.check(f"{overlapped_compute_op_str}")
        file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.")
        return file_check

    @torch._dynamo.config.patch(
        inline_inbuilt_nn_modules=True,
        skip_fsdp_hooks=False,
    )
    @torch._functorch.config.patch(recompute_views=True)
    @torch._functorch.config.patch(cse=False)
    @torch._inductor.config.patch(
        reorder_for_compute_comm_overlap=True,
        reorder_for_compute_comm_overlap_passes=[
            "sink_waits",
            "raise_comms",
            "reorder_compute_for_overlap",
        ],
    )
    def _test_traceable_fsdp(
        self, model_init_fn, input_creation_fn, backend, fullgraph
    ):
        def compiler_fn(compiled_autograd_backend):
            def _fn(gm):
                # fullgraph=True because graph breaks in the Compiled Autograd BWD graph are not supported by Traceable FSDP2 yet
                # (the main difficulty is that queue_callback does not work well when the BWD graph has a graph break).
                return torch.compile(
                    gm, backend=compiled_autograd_backend, fullgraph=True
                )

            return _fn

        def run_iters(model, optim, n_iter=10, compiled_autograd_backend=None):
            torch.manual_seed(42)
            losses = []
            for i in range(n_iter):
                inp = input_creation_fn()
                if compiled_autograd_backend is not None:
                    maybe_compiled_autograd_ctx = compiled_autograd.enable(
                        compiler_fn(compiled_autograd_backend)
                    )
                else:
                    maybe_compiled_autograd_ctx = contextlib.nullcontext()
                with maybe_compiled_autograd_ctx:
                    out = model(inp)
                    loss = out.sum()
                    losses.append(loss.item())
                    loss.backward()
                optim.step()
                optim.zero_grad(set_to_none=True)
            return losses

        def test_compiled():
            model, optim = model_init_fn()
            # FSDP2 does lazy init on the 1st run, so run one iteration in eager mode to initialize
            run_iters(model, optim, n_iter=1)

            model_compiled = torch.compile(model, backend=backend, fullgraph=fullgraph)
            res = run_iters(model_compiled, optim, compiled_autograd_backend=backend)
            return res

        def test_eager():
            model, optim = model_init_fn()
            # FSDP2 does lazy init on the 1st run, so run one iteration in eager mode to initialize
            run_iters(model, optim, n_iter=1)

            res = run_iters(model, optim)
            return res

        losses_compiled = test_compiled()
        losses_eager = test_eager()
        if not self.fake_pg:
            for loss_compiled, loss_eager in zip(losses_compiled, losses_eager):
                self.assertTrue(
                    torch.allclose(
                        torch.tensor(loss_compiled),
                        torch.tensor(loss_eager),
                        rtol=1e-5,
                        atol=1e-8,
                    ),
                    f"{loss_compiled} vs {loss_eager}",
                )

    def _create_simple_mlp_factory_fns(self):
        hidden_dim = 16

        def model_init_fn():
            torch.manual_seed(self.rank)
            fsdp_config = {}
            model = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim, device="cuda"),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim, device="cuda"),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim, device="cuda"),
            )
            fully_shard(model, reshard_after_forward=True, **fsdp_config)
            optim = torch.optim.SGD(model.parameters(), lr=1e-4)
            return model, optim

        def input_creation_fn():
            torch.manual_seed(self.rank)
            inp = torch.randn((2, hidden_dim), device="cuda", requires_grad=False)
            return inp

        return model_init_fn, input_creation_fn

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_simple_mlp_fullgraph_backend_aot_eager(self):
        self._test_traceable_fsdp(
            *self._create_simple_mlp_factory_fns(), "aot_eager", fullgraph=True
        )

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self):
        self._test_traceable_fsdp(
            *self._create_simple_mlp_factory_fns(),
            "aot_eager_decomp_partition",
            fullgraph=True,
        )

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_simple_mlp_fullgraph_backend_inductor(self):
        self._test_traceable_fsdp(
            *self._create_simple_mlp_factory_fns(), "inductor", fullgraph=True
        )

    def _create_nested_fully_shard_factory_fns(self, fullgraph):
        hidden_dim = 16

        class TestSubmodule(nn.Module):
            def __init__(self, hidden_dim):
                super().__init__()
                self.param1 = nn.Parameter(
                    torch.zeros(
                        hidden_dim, hidden_dim, dtype=torch.float, device="cuda"
                    )
                )
                self.param2 = nn.Parameter(
                    torch.zeros(hidden_dim, dtype=torch.float, device="cuda")
                )

            def forward(self, x):
                if not fullgraph:
                    torch._dynamo.graph_break()
                ret = torch.matmul(x, self.param1)
                ret = ret * self.param2
                ret = torch.relu(ret)
                return ret

        class TestModule(nn.Module):
            def __init__(self, n_layers):
                super().__init__()
                self.layers = torch.nn.ModuleList()
                for layer_id in range(n_layers):
                    self.layers.append(TestSubmodule(hidden_dim))

            def forward(self, x):
                # Intentionally reusing all layers a few times,
                # to test "multiple all-gathers for the same parameter" case.
                for layer in self.layers:
                    x = layer(x)
                for layer in self.layers:
                    x = layer(x)
                for layer in self.layers:
                    x = layer(x)
                return x

        def model_init_fn():
            torch.manual_seed(self.rank)
            fsdp_config = {}
            mesh = init_device_mesh("cuda", (self.world_size,))
            model = TestModule(n_layers=3)
            for layer_id, mod in enumerate(model.layers):
                fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
            model = fully_shard(
                model, mesh=mesh, reshard_after_forward=True, **fsdp_config
            )
            optim = torch.optim.SGD(model.parameters(), lr=1e-4)
            return model, optim

        def input_creation_fn():
            torch.manual_seed(self.rank)
            inp = torch.randn((2, hidden_dim), device="cuda", requires_grad=False)
            return inp

        return model_init_fn, input_creation_fn

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_nested_fully_shard_backend_aot_eager(self):
        for fullgraph in [True, False]:
            self._test_traceable_fsdp(
                *self._create_nested_fully_shard_factory_fns(fullgraph=fullgraph),
                "aot_eager",
                fullgraph=fullgraph,
            )

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_nested_fully_shard_backend_aot_eager_decomp_partition(self):
        for fullgraph in [True, False]:
            self._test_traceable_fsdp(
                *self._create_nested_fully_shard_factory_fns(fullgraph=fullgraph),
                "aot_eager_decomp_partition",
                fullgraph=fullgraph,
            )

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_nested_fully_shard_backend_inductor(self):
        for fullgraph in [True, False]:
            with self._reinplace_all_gather_with_optional_checks(
                fullgraph
            ), self._maybe_run_decide_global_ordering_of_comms_with_checks(fullgraph):
                _, triton_codes = run_and_get_code(
                    lambda: self._test_traceable_fsdp(
                        *self._create_nested_fully_shard_factory_fns(
                            fullgraph=fullgraph
                        ),
                        "inductor",
                        fullgraph=fullgraph,
                    )
                )
                if fullgraph:
                    self.assertTrue(
                        len(triton_codes) == 2,
                        "Expected two separate lowerings to Triton code, one from FWD graph and one from Compiled Autograd BWD graph",
                    )
                    fwd_code = triton_codes[0]
                    file_check = FileCheck().check("def call(args):")
                    for fwd_ag_block_info in [
                        dict(overlapped_compute_op_str=None, num_resize=0, num_set=2),
                        dict(
                            overlapped_compute_op_str="extern_kernels.mm(",
                            num_resize=2,
                            num_set=2,
                        ),
                        dict(
                            overlapped_compute_op_str="extern_kernels.mm(",
                            num_resize=2,
                            num_set=2,
                        ),
                        dict(
                            overlapped_compute_op_str="extern_kernels.mm(",
                            num_resize=2,
                            num_set=2,
                        ),
                        dict(
                            overlapped_compute_op_str="extern_kernels.mm(",
                            num_resize=2,
                            num_set=2,
                        ),
                        dict(
                            overlapped_compute_op_str="extern_kernels.mm(",
                            num_resize=2,
                            num_set=2,
                        ),
                        dict(
                            overlapped_compute_op_str="extern_kernels.mm(",
                            num_resize=2,
                            num_set=2,
                        ),
                        dict(
                            overlapped_compute_op_str="extern_kernels.mm(",
                            num_resize=2,
                            num_set=2,
                        ),
                        dict(
                            overlapped_compute_op_str="extern_kernels.mm(",
                            num_resize=2,
                            num_set=2,
                            last_all_gather=True,
                        ),
                    ]:
                        file_check = self.inductor_code_check_fsdp_all_gather(
                            file_check, **fwd_ag_block_info
                        )
                    file_check.run(fwd_code)

                    bwd_code = triton_codes[1]
                    file_check = FileCheck().check("def call(args):")
                    for bwd_ag_block_info in [
                        dict(overlapped_compute_op_str=None, num_resize=0, num_set=2),
                        dict(
                            overlapped_compute_op_str="extern_kernels.mm(",
                            num_resize=0,
                            num_set=2,
                        ),
                        dict(
                            overlapped_compute_op_str="extern_kernels.mm(",
                            num_resize=0,
                            num_set=2,
                            last_all_gather=True,
                        ),
                    ]:
                        file_check = self.inductor_code_check_fsdp_all_gather(
                            file_check, **bwd_ag_block_info
                        )
                    for bwd_rs_block_info in [
                        dict(overlapped_compute_op_str="extern_kernels.mm("),
                        dict(
                            overlapped_compute_op_str=None
                        ),  # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None
                        dict(overlapped_compute_op_str=None),
                    ]:
                        file_check = self.inductor_code_check_fsdp_reduce_scatter(
                            file_check, **bwd_rs_block_info
                        )
                    file_check.run(bwd_code)
                else:
                    # TODO: when fullgraph=False and there is graph break in FWD graph,
                    # there are several recompiles, need to figure out why.
                    self.assertTrue(
                        len(triton_codes) > 2,
                        "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph",
                    )

    def _create_transformer_factory_fns(self):
        seq_len = 16
        vocab_size = 8

        def model_init_fn():
            torch.manual_seed(self.rank)
            fsdp_config = {}
            mesh = init_device_mesh("cuda", (self.world_size,))
            model_args = ModelArgs(
                vocab_size=vocab_size,
                n_layers=3,
            )
            model = Transformer(model_args)
            for layer_id, mod in enumerate(model.layers):
                fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
            model = fully_shard(
                model, mesh=mesh, reshard_after_forward=True, **fsdp_config
            )
            optim = torch.optim.SGD(model.parameters(), lr=1e-4)
            return model, optim

        def input_creation_fn():
            torch.manual_seed(self.rank)
            inp = torch.randint(
                0, vocab_size, (2, seq_len), device="cuda", requires_grad=False
            )
            return inp

        return model_init_fn, input_creation_fn

    def _maybe_add_graph_break_to_sdpa(self, fullgraph):
        def _sdpa_with_graph_break(orig_fn, fullgraph, *args, **kwargs):
            if not fullgraph:
                torch._dynamo.graph_break()
            return orig_fn(*args, **kwargs)

        return mock.patch.object(
            F,
            "scaled_dot_product_attention",
            functools.partial(
                _sdpa_with_graph_break,
                F.scaled_dot_product_attention,
                fullgraph,
            ),
        )

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_transformer_backend_aot_eager(self):
        for fullgraph in [True, False]:
            with self._maybe_add_graph_break_to_sdpa(
                fullgraph
            ), self._reinplace_all_gather_with_optional_checks(fullgraph):
                self._test_traceable_fsdp(
                    *self._create_transformer_factory_fns(),
                    "aot_eager",
                    fullgraph=fullgraph,
                )

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO: native_dropout has worse accuracy after decomp, need to figure out why
    @torch._inductor.config.patch(fallback_random=True)
    def test_transformer_backend_aot_eager_decomp_partition(self):
        for fullgraph in [True, False]:
            with self._maybe_add_graph_break_to_sdpa(fullgraph):
                self._test_traceable_fsdp(
                    *self._create_transformer_factory_fns(),
                    "aot_eager_decomp_partition",
                    fullgraph=fullgraph,
                )

    @skipIfRocm
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO: native_dropout causes CUDA IMA error, need to figure out why
    @torch._inductor.config.patch(fallback_random=True)
    def test_transformer_backend_inductor(self):
        for fullgraph in [True, False]:
            with self._maybe_add_graph_break_to_sdpa(
                fullgraph
            ), self._reinplace_all_gather_with_optional_checks(
                fullgraph
            ), self._maybe_run_decide_global_ordering_of_comms_with_checks(
                fullgraph
            ):
                _, triton_codes = run_and_get_code(
                    lambda: self._test_traceable_fsdp(
                        *self._create_transformer_factory_fns(),
                        "inductor",
                        fullgraph=fullgraph,
                    )
                )
                if fullgraph:
                    self.assertTrue(
                        len(triton_codes) == 2,
                        "Expected two separate lowerings to Triton code, one from FWD graph and one from Compiled Autograd BWD graph",
                    )
                    fwd_code = triton_codes[0]
                    file_check = FileCheck().check("def call(args):")
                    for fwd_ag_block_info in [
                        dict(overlapped_compute_op_str="triton_", num_resize=0, num_set=4),
                        dict(
                            overlapped_compute_op_str="aten.native_dropout.",
                            num_resize=0,
                            num_set=12,
                        ),
                        dict(
                            overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention.",
                            num_resize=12,
                            num_set=12,
                        ),
                        dict(
                            overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention.",
                            num_resize=12,
                            num_set=12,
                            last_all_gather=True,
                        ),
                    ]:
                        file_check = self.inductor_code_check_fsdp_all_gather(
                            file_check, **fwd_ag_block_info
                        )
                    file_check.run(fwd_code)

                    bwd_code = triton_codes[1]
                    file_check = FileCheck().check("def call(args):")
                    for bwd_ag_block_info in [
                        dict(
                            overlapped_compute_op_str="extern_kernels.mm(",
                            num_resize=0,
                            num_set=12,
                        ),
                        dict(
                            overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention_backward.",
                            num_resize=0,
                            num_set=12,
                        ),
                        dict(
                            overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention_backward.",
                            num_resize=0,
                            num_set=12,
                            last_all_gather=True,
                        ),
                    ]:
                        file_check = self.inductor_code_check_fsdp_all_gather(
                            file_check, **bwd_ag_block_info
                        )
                    for bwd_rs_block_info in [
                        dict(overlapped_compute_op_str="extern_kernels.mm("),
                        dict(
                            overlapped_compute_op_str=None
                        ),  # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None
                        dict(overlapped_compute_op_str=None),
                        dict(overlapped_compute_op_str=None),
                    ]:
                        file_check = self.inductor_code_check_fsdp_reduce_scatter(
                            file_check, **bwd_rs_block_info
                        )
                    file_check.run(bwd_code)
                else:
                    # TODO: when fullgraph=False and there is graph break in FWD graph,
                    # there are several recompiles, need to figure out why.
                    self.assertTrue(
                        len(triton_codes) > 2,
                        "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph",
                    )


if __name__ == "__main__":
    run_tests()