# Owner(s): ["oncall: distributed"]

import contextlib
import copy
import functools
import unittest
from typing import Iterable, List, Tuple, Type, Union

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._composable import checkpoint, replicate
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    FSDPModule,
    fully_shard,
    OffloadPolicy,
    register_fsdp_forward_method,
)
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed._tensor.debug.comm_mode import CommDebugMode
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
    apply_activation_checkpointing,
    CheckpointWrapper,
)
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    check_sharded_parity,
    FSDPTest,
    FSDPTestMultiThread,
    MLP,
    MLPStack,
    patch_all_gather,
    patch_reduce_scatter,
    test_compiled_fsdp,
)
from torch.testing._internal.common_utils import (
    get_cycles_per_ms,
    run_tests,
    skipIfRocm,
    wrapSwapTensorsTest,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
    TransformerBlock,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir

c10d_ops = torch.ops.c10d
funcol = torch.ops.c10d_functional


class TestFullyShardForwardInputs(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 2

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_root_move_forward_input_to_device(self):
        device = torch.device("cuda", 0)

        class ParamlessModule(nn.Module):
            def forward(self, x: torch.Tensor, ys: Tuple[torch.Tensor, ...]):
                # Check that FSDP moved the inputs to GPU, including recursing
                # into the tuple data structure
                assert x.device == device, f"Expects {device} but got {x.device}"
                assert (
                    ys[0].device == device
                ), f"Expects {device} but got {ys[0].device}"
                assert (
                    ys[1].device == device
                ), f"Expects {device} but got {ys[1].device}"
                y = ys[0] + ys[1]
                return x + y + 1

        model = ParamlessModule()
        fully_shard(model)
        x = torch.randn((3,))
        ys = (torch.randn((3,)), torch.randn((3,)))
        self.assertEqual(x.device, torch.device("cpu"))
        self.assertEqual(ys[0].device, torch.device("cpu"))
        self.assertEqual(ys[1].device, torch.device("cpu"))
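        # The root FSDP module is responsible for moving forward inputs to its
        # compute device, even though this module has no parameters of its own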
        model(x, ys)


class TestFullyShardRegisteredParams(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 4

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_param_registration_after_forward(self):
        """Tests the parameter registration after forward."""
        device = torch.device("cuda", 0)
        # Single FSDP group
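        # NOTE: An `int` value for `reshard_after_forward` (here 2) reshards
        # the parameters to that smaller world size after forward instead of
        # all the way back to the original sharded world size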
        for reshard_after_forward in (True, False, 2):
            torch.manual_seed(42)
            model = MLP(3, device)
            # Since the seed is per-process, not per-thread, we broadcast to
            # ensure the same parameters across ranks
            for param in model.parameters():
                dist.broadcast(param, src=0)
            ref_model = copy.deepcopy(model)
            fully_shard(model, reshard_after_forward=reshard_after_forward)  # root only
            inp = torch.randn((2, 3), device="cuda")
            self._assert_dtensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())
            model(inp)  # root does not reshard after forward
            self._assert_tensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())
            model.reshard()  # however, we can manually reshard
            self._assert_dtensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())

        # Multiple FSDP groups
        for reshard_after_forward in (True, False, 2):
            torch.manual_seed(42)
            model = nn.Sequential(MLP(3, device), MLP(3, device))
            for param in model.parameters():
                dist.broadcast(param, src=0)
            ref_model = copy.deepcopy(model)
            fully_shard(model[0].in_proj, reshard_after_forward=reshard_after_forward)
            fully_shard(model[0].out_proj, reshard_after_forward=reshard_after_forward)
            fully_shard(model, reshard_after_forward=reshard_after_forward)

            self._assert_dtensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())
            model(inp)
            non_root_params = list(model[0].in_proj.parameters()) + list(
                model[0].out_proj.parameters()
            )
            root_params = list(set(model.parameters()) - set(non_root_params))
            if reshard_after_forward is False:
                self._assert_tensor_params(non_root_params)
            else:
                self._assert_dtensor_params(non_root_params)
            self._assert_tensor_params(root_params)
            self._assert_same_params(model.parameters(), ref_model.parameters())
            for module in model.modules():
                if isinstance(module, FSDPModule):
                    module.reshard()  # however, we can manually reshard
            self._assert_dtensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_param_registration_after_backward(self):
        """Tests the parameter registration after backward."""
        device = torch.device("cuda", 0)
        # Single FSDP group
        for reshard_after_forward in (True, False, 2):
            model = MLP(8, device)
            fully_shard(model, reshard_after_forward=reshard_after_forward)  # root only
            inp = torch.randn((2, 8), device="cuda")
            self._assert_dtensor_params(model.parameters())
            model(inp).sum().backward()
            self._assert_dtensor_params(model.parameters())

        # Multiple FSDP groups
        for reshard_after_forward in (True, False, 2):
            model = MLP(8, device)
            fully_shard(model.in_proj, reshard_after_forward=reshard_after_forward)
            fully_shard(model.out_proj, reshard_after_forward=reshard_after_forward)
            fully_shard(model, reshard_after_forward=reshard_after_forward)
            self._assert_dtensor_params(model.parameters())
            model(inp).sum().backward()
            self._assert_dtensor_params(model.parameters())

    def _assert_tensor_params(self, params: Iterable[nn.Parameter]):
        # Materialize first in case `params` is a generator (e.g. from
        # `model.parameters()`); otherwise the loop below would be empty
        params = list(params)
        self.assertGreater(len(params), 0)
        for param in params:
            self.assertNotIsInstance(param, DTensor)
            self.assertIsInstance(param, torch.Tensor)

    def _assert_dtensor_params(self, params: Iterable[nn.Parameter]):
        params = list(params)
        self.assertGreater(len(params), 0)
        for param in params:
            self.assertIsInstance(param, DTensor)

    def _assert_same_params(
        self, params: Iterable[nn.Parameter], ref_params: Iterable[nn.Parameter]
    ):
        params, ref_params = list(params), list(ref_params)
        self.assertEqual(len(params), len(ref_params))
        for param, ref_param in zip(params, ref_params):
            if isinstance(param, DTensor):
                param = param.full_tensor()
            self.assertEqual(param.shape, ref_param.shape)
            self.assertEqual(param, ref_param)


class TestFullyShardCastAfterInit(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 2

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    @wrapSwapTensorsTest(True)
    def test_to_float64_after_init(self):
        """Tests that the user can cast the module to float64 after init."""
        # NOTE: Test fp64 instead of a lower precision dtype like bf16 for
        # better numerics. The important part is changing the dtype.
        torch.manual_seed(42)
        mlp_dim, device, dtype = 4, torch.device("cuda"), torch.float64
        model = MLP(mlp_dim, device=device)
        for param in model.parameters():
            dist.broadcast(param, src=0)
        ref_model = copy.deepcopy(model).to(dtype)
        replicate(ref_model)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        for module in (model.in_proj, model.out_proj, model):
            fully_shard(module)
        model.to(dtype)
        for param in model.parameters():
            self.assertEqual(param.dtype, dtype)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        check_sharded_parity(self, ref_model, model)
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((2, mlp_dim), device="cuda", dtype=dtype)
        for iter_idx in range(10):
            losses: List[torch.Tensor] = []
            for _model in (ref_model, model):
                losses.append(_model(inp).sum())
                losses[-1].backward()
            self.assertEqual(losses[0], losses[1])
            check_sharded_parity(self, ref_model, model)
            for _optim in (ref_optim, optim):
                _optim.step()
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))


class TestFullyShard1DTrainingCore(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(8, torch.cuda.device_count())

    @skip_if_lt_x_gpu(2)
    def test_train_parity_single_group(self):
        """Tests train parity with DDP for a single FSDP group."""
        self.run_subtests(
            {
                "lin_shapes": [[(16, 15), (15, 8)], [(7, 15), (15, 3)]],
            },
            self._test_train_parity_single_group,
        )

    def _test_train_parity_single_group(self, lin_shapes: List[Tuple[int, int]]):
        torch.manual_seed(42)
        model = nn.Sequential(
            nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1])
        )
        ref_model = copy.deepcopy(model).cuda()
        replicate(ref_model, device_ids=[self.rank])
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        fully_shard(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        torch.manual_seed(42 + self.rank + 1)
        inp = (torch.randn((4, lin_shapes[0][0]), device="cuda"),)
        for iter_idx in range(10):
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(_model(*inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    @test_compiled_fsdp(compile_compute_on_module=Transformer)
    def test_train_parity_multi_group(self):
        """
        Tests train parity against DDP when using multiple parameter groups for
        communication (for communication and computation overlap plus memory
        reduction).
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True, False, 2],
                "device_type": ["cuda"],
                "offload_policy": [OffloadPolicy()],
                "delay_after_forward": [False, True],
                "delay_before_all_gather": [False, True],
                "delay_before_reduce_scatter": [False, True],
                "delay_before_optim": [False, True],
            },
            self._test_train_parity_multi_group,
        )

    @skip_if_lt_x_gpu(2)
    def test_train_parity_multi_group_cpu_offload_eager(self):
        """
        Tests train parity against DDP when using multiple parameter groups for
        communication and CPU offloading.
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True],  # save CI time
                "offload_policy": [
                    CPUOffloadPolicy(pin_memory=True),
                    CPUOffloadPolicy(pin_memory=False),
                ],
                "device_type": ["cuda"],
                "delay_after_forward": [False, True],
                "delay_before_all_gather": [False, True],
                "delay_before_reduce_scatter": [False, True],
                "delay_before_optim": [False, True],
            },
            self._test_train_parity_multi_group,
        )

    def _test_train_parity_multi_group(
        self,
        reshard_after_forward: Union[bool, int],
        offload_policy: OffloadPolicy,
        device_type: str,
        delay_after_forward: bool,
        delay_before_all_gather: bool,
        delay_before_reduce_scatter: bool,
        delay_before_optim: bool,
    ):
        # Only test individual delays or all four delays to save test time
        if (
            delay_after_forward
            + delay_before_all_gather
            + delay_before_reduce_scatter
            + delay_before_optim
            in (2, 3)
        ):
            return
        assert device_type in ("cuda", "cpu"), f"{device_type}"
        torch.manual_seed(42)
        lin_dim = 32
        vocab_size = 1024
        model_args = ModelArgs(
            n_layers=3,
            n_heads=4,
            vocab_size=vocab_size,
            max_seq_len=64,
            dropout_p=0,
        )
        model = Transformer(model_args)
        ref_model = copy.deepcopy(model)
        if device_type == "cuda":
            replicate(ref_model.cuda(), device_ids=[self.rank])
        else:
            gloo_pg = dist.new_group(backend="gloo")
            replicate(ref_model, process_group=gloo_pg)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        mesh = init_device_mesh(device_type, (self.world_size,))
        fully_shard_fn = functools.partial(
            fully_shard,
            mesh=mesh,
            reshard_after_forward=reshard_after_forward,
            offload_policy=offload_policy,
        )
        for module in model.modules():
            if isinstance(module, TransformerBlock):
                fully_shard_fn(module)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        delay_in_ms = 100
        orig_all_gather = dist.all_gather_into_tensor
        orig_reduce_scatter = dist.reduce_scatter_tensor

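        # `torch.cuda._sleep` spins the current CUDA stream for the given
        # number of cycles, so these wrappers emulate a GPU-side delay right
        # before issuing the patched collective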
        def delayed_all_gather(*args, **kwargs):
            torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
            return orig_all_gather(*args, **kwargs)

        def delayed_reduce_scatter(*args, **kwargs):
            torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
            return orig_reduce_scatter(*args, **kwargs)

        torch.manual_seed(42 + self.rank + 1)
        patch_all_gather_ctx = (
            patch_all_gather(delayed_all_gather)
            if delay_before_all_gather
            else contextlib.nullcontext()
        )
        patch_reduce_scatter_ctx = (
            patch_reduce_scatter(delayed_reduce_scatter)
            if delay_before_reduce_scatter
            else contextlib.nullcontext()
        )
        with patch_all_gather_ctx, patch_reduce_scatter_ctx:
            for iter_idx in range(10):
                inp = torch.randint(0, vocab_size, (3, 64), device=device_type)
                losses: List[torch.Tensor] = []
                for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                    _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                    losses.append(_model(inp).sum())
                    if _model is model and delay_after_forward:
                        torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
                    losses[-1].backward()
                    if _model is model and delay_before_optim:
                        torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
                    _optim.step()
                self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    def test_non_root_forward_backward(self):
        """
        Tests running forward/backward through the root and then through a
        non-root. The non-root needs to synchronize streams/queue the callback.
        """
        torch.manual_seed(42)
        lin_dim = 32
        model = nn.Sequential(*[MLP(lin_dim, torch.device("cpu")) for _ in range(3)])
        ref_model = copy.deepcopy(model).cuda()
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        for mlp in model:
            fully_shard(mlp)
        fully_shard(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        torch.manual_seed(42 + self.rank)
        inp = torch.randn((8, lin_dim), device=torch.device("cuda"))

        ref_root_loss = ref_model(inp).sum()
        ref_root_loss.backward()
        for param in ref_model.parameters():
            dist.all_reduce(param.grad)
            param.grad.detach().div_(self.world_size)
        ref_optim.step()
        ref_optim.zero_grad()
        ref_nonroot_loss = ref_model[0](inp).sum()
        ref_nonroot_loss.backward()
        for param in ref_model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad)
                param.grad.detach().div_(self.world_size)
        ref_optim.step()

        root_loss = model(inp).sum()
        root_loss.backward()
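        # Sleep on the GPU so that any missing synchronization between FSDP's
        # post-backward work and the optimizer step is more likely to surface
        # as a numeric mismatch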
        torch.cuda._sleep(int(100 * get_cycles_per_ms()))
        optim.step()
        optim.zero_grad()
        nonroot_loss = model[0](inp).sum()
        nonroot_loss.backward()
        optim.step()

        self.assertEqual(ref_root_loss, root_loss)
        self.assertEqual(ref_nonroot_loss, nonroot_loss)
        self.assertEqual(ref_model(inp).sum(), model(inp).sum())

    @skip_if_lt_x_gpu(2)
    def test_multi_forward_module(self):
        """
        Tests parity with DDP when running a module that participates multiple
        times in forward.
        """
        self.run_subtests(
            {"reshard_after_forward": [True, False, 2]},
            self._test_multi_forward_module,
        )

    def _test_multi_forward_module(self, reshard_after_forward: Union[bool, int]):
        class MultiForwardModule(nn.Module):
            def __init__(self, device: torch.device):
                super().__init__()
                self.inner = nn.Linear(4, 4, device=device)
                self.outer = nn.Linear(4, 5, device=device)

            def forward(self, x):
                i = self.inner(x)
                j = self.inner(x)
                return self.outer(i + j)

        torch.manual_seed(42)
        model = MultiForwardModule(device="cuda")
        ref_model = copy.deepcopy(model)
        replicate(ref_model, device_ids=[self.rank])
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        fully_shard(model.inner)
        fully_shard(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        torch.manual_seed(42 + self.rank)
        inp = torch.randn((32, 4), device="cuda")
        for iter_idx in range(10):
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])


class TestFullyShard1DTrainingCompose(FSDPTest):
    @property
    def world_size(self) -> int:
        # Since these tests run with a larger transformer model, they may see
        # some numeric drift with >2 GPUs
        return min(torch.cuda.device_count(), 2)

    @skip_if_lt_x_gpu(2)
    @test_compiled_fsdp(compile_compute_on_module=Transformer)
    def test_train_parity_with_activation_checkpointing(self):
        """
        Tests train parity against DDP when composing with activation
        checkpointing.
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True, False, 2],
                "checkpoint_impl": ["composable", "utils", "wrapper"],
            },
            self._test_train_parity_with_activation_checkpointing,
        )

    def _test_train_parity_with_activation_checkpointing(
        self, reshard_after_forward: Union[bool, int], checkpoint_impl: str
    ):
        assert checkpoint_impl in ("composable", "utils", "wrapper")
        testing_compile = fully_shard != torch.distributed._composable.fsdp.fully_shard
        if testing_compile and checkpoint_impl == "composable":
            return
        torch.manual_seed(42)
        vocab_size = 1024
        with torch.device(torch.device("cuda")):
            model_args = ModelArgs(
                n_layers=3,
                n_heads=4,
                vocab_size=vocab_size,
                max_seq_len=64,
                dropout_p=0,
                checkpoint_activations=(checkpoint_impl == "utils"),
            )
            model = Transformer(model_args)
        ref_model = replicate(copy.deepcopy(model), device_ids=[self.rank])
        foreach = True
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)
        fully_shard_fn = functools.partial(
            fully_shard,
            reshard_after_forward=reshard_after_forward,
        )
        if checkpoint_impl == "wrapper":
            prefixes_to_ignore = (_CHECKPOINT_PREFIX,)
            apply_activation_checkpointing(
                model, check_fn=lambda m: isinstance(m, TransformerBlock)
            )
            for module in model.modules():
                # Apply to `CheckpointWrapper`, which wraps `TransformerBlock`
                if isinstance(module, CheckpointWrapper):
                    fully_shard_fn(module)
        else:
            prefixes_to_ignore = ()
            for module in model.modules():
                if isinstance(module, TransformerBlock):
                    if checkpoint_impl == "composable":
                        checkpoint(module)
                    fully_shard_fn(module)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)

        torch.manual_seed(42 + self.rank)
        # Reuse the same input across iterations to avoid loss explosion from
        # trying to learn from random inputs
        inp = torch.randint(0, vocab_size, (3, 64), device="cuda")
        check_sharded_parity(
            self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
        )
        for iter_idx in range(10):
            losses: List[torch.Tensor] = []
            for _model in (ref_model, model):
                torch.manual_seed(iter_idx + 1)  # for dropout determinism
                losses.append(_model(inp).sum())
                losses[-1].backward()
            if not testing_compile:
                check_sharded_parity(
                    self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
                )
            self.assertEqual(losses[0], losses[1])
            for _optim in (ref_optim, optim):
                _optim.step()
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            if not testing_compile:
                check_sharded_parity(
                    self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
                )


class TestFullyShardSharedParams(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    @skip_if_lt_x_gpu(2)
    def test_train_parity_with_shared_params(self):
        self.run_subtests(
            {
                "reshard_after_forward": [False, True],
                "use_activation_checkpointing": [False, True],
            },
            self._test_train_shared_params,
        )

    def _test_train_shared_params(
        self,
        reshard_after_forward: bool,
        use_activation_checkpointing: bool,
    ):
        torch.manual_seed(42)
        model_args = ModelArgs(n_layers=3, dropout_p=0.0, weight_tying=True)
        model = Transformer(model_args)
        ref_model = copy.deepcopy(model).cuda()
        replicate(ref_model, device_ids=[self.rank])
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        for module in model.modules():
            if isinstance(module, TransformerBlock):
                if use_activation_checkpointing:
                    checkpoint(module)
                fully_shard(module, reshard_after_forward=reshard_after_forward)
        fully_shard(model, reshard_after_forward=reshard_after_forward)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        torch.manual_seed(42 + self.rank + 1)
        for iter_idx in range(10):
            inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])


class TestFullyShardGradientAccumulation(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    @skip_if_lt_x_gpu(2)
    def test_gradient_accumulation(self):
        """
        Tests gradient accumulation with/without gradient reduction and
        with/without resharding after backward.
        """
        meshes = [init_device_mesh("cuda", (self.world_size,))]  # always test FSDP
        if self.world_size == 4:  # test HSDP too if enough GPUs
            shard_size, replicate_size = 2, 2
            meshes.append(init_device_mesh("cuda", (replicate_size, shard_size)))
        self.run_subtests(
            {
                "mesh": meshes,
                "reshard_after_forward": [True, False, 2],
                # "all": disable reduce-scatter for all modules
                # "root_only": disable reduce-scatter for root's linear only
                # "some_mlps": disable reduce-scatter for some MLPs
                "mode": ["all", "root_only", "some_mlps"],
                "reshard_after_backward": [False, True],
                "offload_policy": [OffloadPolicy(), CPUOffloadPolicy()],
                # For HSDP only:
                # `True`: reduce-scatter only (no all-reduce) each microbatch
                # until the last microbatch
                # `False`: neither reduce-scatter nor all-reduce each
                # microbatch until the last microbatch
                "reduce_scatter_only": [False, True],
            },
            self._test_gradient_accumulation,
        )

    def _test_gradient_accumulation(
        self,
        mesh: DeviceMesh,
        reshard_after_forward: Union[bool, int],
        mode: str,
        reshard_after_backward: bool,
        offload_policy: OffloadPolicy,
        reduce_scatter_only: bool,  # for HSDP
    ):
        if (
            (
                not reshard_after_backward
                and (reshard_after_forward is not False or mode == "some_mlps")
            )
            or (
                isinstance(offload_policy, CPUOffloadPolicy)
                and reshard_after_forward is not True
            )
            or (mesh.ndim != 2 and reduce_scatter_only)
        ):
            return  # skip since not common or applicable

        torch.manual_seed(42)
        batch_size, lin_dim, num_mlps, num_microbatches = (2, 32, 3, 3)
        if mode == "some_mlps":
            num_mlps_to_disable_reduce_scatter = 2
        modules = [nn.Linear(lin_dim, lin_dim)]
        modules.extend(MLP(lin_dim) for _ in range(num_mlps))
        model = nn.Sequential(*modules)
        ref_model = copy.deepcopy(model).cuda()
        fully_shard_fn = functools.partial(
            fully_shard,
            mesh=mesh,
            reshard_after_forward=reshard_after_forward,
            offload_policy=offload_policy,
        )
        for mlp in model[1:]:
            fully_shard_fn(mlp)
        fully_shard_fn(model)  # root gets the 1st linear
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        def set_grad_sync_flag(
            module: nn.Module, is_last_microbatch: bool, recurse: bool = True
        ):
            if reduce_scatter_only:
                module.set_requires_all_reduce(is_last_microbatch, recurse=recurse)
            else:
                module.set_requires_gradient_sync(is_last_microbatch, recurse=recurse)

        def set_backward_flags(_model: nn.Module, is_last_microbatch: bool):
            if mode == "all":
                set_grad_sync_flag(_model, is_last_microbatch)
                if not reshard_after_backward:
                    _model.set_reshard_after_backward(is_last_microbatch)
            elif mode == "some_mlps":
                for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]:
                    set_grad_sync_flag(mlp, is_last_microbatch)
                    if not reshard_after_backward:
                        mlp.set_reshard_after_backward(is_last_microbatch)
            elif mode == "root_only":
                set_grad_sync_flag(model, is_last_microbatch, recurse=False)
                if not reshard_after_backward:
                    model.set_reshard_after_backward(is_last_microbatch, recurse=False)

        torch.manual_seed(42 + self.rank + 1)
        for iter_idx in range(5):
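            # `CommDebugMode` records the collective ops issued while running
            # this iteration's microbatches so that the expected communication
            # counts can be checked below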
            with CommDebugMode() as comm_mode:
                for microbatch_idx in range(num_microbatches):
                    is_last_microbatch = microbatch_idx == num_microbatches - 1
                    set_backward_flags(model, is_last_microbatch)
                    inp = torch.randn(batch_size, lin_dim, device="cuda")
                    losses: List[torch.Tensor] = []
                    for _model in (ref_model, model):
                        losses.append(_model(inp).sum())
                        losses[-1].backward()
                    self.assertEqual(losses[0], losses[1])

            comm_counts = comm_mode.get_comm_counts()
            all_gather_count = comm_counts[c10d_ops._allgather_base_]
            reduce_scatter_count = comm_counts[c10d_ops._reduce_scatter_base_]
            all_reduce_count = comm_counts[c10d_ops.allreduce_]

            # Expect one reduce-scatter per MLP plus one for the root's linear
            # on the last microbatch
            expected_reduce_scatter_count = num_mlps + 1
            if mode == "some_mlps":
                # Expect additional reduce-scatters for non-disabled MLPs and
                # the root's linear
                expected_reduce_scatter_count += (
                    num_mlps - num_mlps_to_disable_reduce_scatter + 1
                ) * (num_microbatches - 1)
            elif mode == "root_only":
                # Expect additional reduce-scatters for all MLPs
                expected_reduce_scatter_count += (num_mlps) * (num_microbatches - 1)
            expected_all_reduce_count = (
                expected_reduce_scatter_count if mesh.ndim == 2 else 0
            )
            if reduce_scatter_only:
                # Specifically for HSDP: if we only reduce-scatter (without
                # all-reducing) until the last microbatch, expect one
                # reduce-scatter per MLP plus one for the root per microbatch
                expected_reduce_scatter_count = (num_mlps + 1) * num_microbatches
            self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
            self.assertEqual(all_reduce_count, expected_all_reduce_count)

            # Expect one all-gather per MLP plus one for the root's linear in
            # the first microbatch's forward
            expected_all_gather_count = num_mlps + 1
            if reshard_after_forward is not False:  # `True` or `2`
                # Add the number of MLPs without the +1 for the backward
                # all-gathers since the root does not reshard after forward
                expected_all_gather_count += num_mlps
                # Multiply by the number of microbatches since these
                # all-gathers run every microbatch
                expected_all_gather_count *= num_microbatches
            elif reshard_after_backward:  # `reshard_after_forward=False`
                expected_all_gather_count *= num_microbatches
            elif mode == "all":  # `reshard_after_forward/backward=False`
                # Only reshard parameters after the last microbatch's backward,
                # so there should not be any more all-gathers
                pass
            elif mode == "root_only":  # `reshard_after_forward/backward=False`
                # The MLPs should still contribute all-gathers in each
                # microbatch forward
                expected_all_gather_count += num_mlps * (num_microbatches - 1)
            self.assertEqual(all_gather_count, expected_all_gather_count)

            for param in ref_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
            check_sharded_parity(self, ref_model, model)
            for _optim in (optim, ref_optim):
                _optim.step()
                # When `set_to_none=False`, we are exercising mixing
                # gradient accumulation with and without communication
                _optim.zero_grad(set_to_none=(iter_idx % 2))

    @skip_if_lt_x_gpu(2)
    def test_1f1b_microbatching(self):
        self.run_subtests(
            {
                "use_explicit_unshard": [False, True],
                "reshard_after_backward": [False, True],
            },
            self._test_1f1b_microbatching,
        )

    def _test_1f1b_microbatching(
        self, use_explicit_unshard: bool, reshard_after_backward: bool
    ):
        torch.manual_seed(42)
        model_args = ModelArgs(dropout_p=0.0)
        model = Transformer(model_args)
        ref_model = copy.deepcopy(model).cuda()
        ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
        for module in model.modules():
            if isinstance(module, TransformerBlock):
                fully_shard(module, reshard_after_forward=False)
        fully_shard(model, reshard_after_forward=False)
        optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

        num_microbatches = 3
        local_batch_size = 2
        torch.manual_seed(42 + self.rank + 1)
        inps = [
            torch.randint(
                0, model_args.vocab_size, (local_batch_size, 16), device="cuda"
            )
            for _ in range(num_microbatches)
        ]

        # Before running the pipeline schedule, we may prefer to issue all
        # all-gathers ahead of time to increase overlap opportunities, at no
        # extra parameter memory cost since we do not reshard after forward
        if use_explicit_unshard:
            for module in model.modules():
                if isinstance(module, FSDPModule):
                    module.unshard(async_op=True)

        # Emulate the 1f1b pipeline schedule and only reduce gradients on the
        # last microbatch
        losses: List[torch.Tensor] = []
        ref_losses: List[torch.Tensor] = []
        for inp_idx, inp in enumerate(inps):
            is_last_microbatch = inp_idx == num_microbatches - 1
            model.set_requires_gradient_sync(is_last_microbatch)
            model.set_is_last_backward(is_last_microbatch)
            if not reshard_after_backward:
                model.set_reshard_after_backward(is_last_microbatch)
            losses.append(model(inp).sum())
            losses[-1].backward()
            ref_losses.append(ref_model(inp).sum())
            ref_losses[-1].backward()
        for param in ref_model.parameters():
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)

        for loss, ref_loss in zip(losses, ref_losses):
            self.assertEqual(loss, ref_loss)
        optim.step()
        ref_optim.step()
        check_sharded_parity(self, ref_model, model)


class TestFullyShard2DTraining(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    def init_global_mesh(self) -> DeviceMesh:
        # Prefer to test with >=4 GPUs, but for 2 GPUs, use 2-way TP
        dp_size = 2 if self.world_size > 2 else 1
        return init_device_mesh(
            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
        )

    @skip_if_lt_x_gpu(2)
    @skipIfRocm
    def test_train_parity_2d_mlp(self):
        global_mesh = self.init_global_mesh()
        self.run_subtests(
            {
                "reshard_after_forward": [False, True],
                "use_activation_checkpointing": [False, True],
                "mlp_dim": [3, 16, 17],
            },
            functools.partial(self._test_train_parity_2d_mlp, global_mesh),
        )

    def _test_train_parity_2d_mlp(
        self,
        global_mesh: DeviceMesh,
        reshard_after_forward: bool,
        use_activation_checkpointing: bool,
        mlp_dim: int,
    ):
        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
        dp_pg = dp_mesh.get_group()  # used for `replicate()`

        torch.manual_seed(42)
        model = MLPStack(mlp_dim)
        ref_model = copy.deepcopy(model).cuda()
        replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
        model.parallelize(
            tp_mesh,
            dp_mesh,
            use_activation_checkpointing,
            reshard_after_forward=reshard_after_forward,
        )
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)

        torch.manual_seed(42 + dp_pg.rank() + 1)
        device = torch.device("cuda")
        for iter_idx in range(10):
            inp = torch.randn((8, mlp_dim), device=device)
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    @skipIfRocm
    def test_tp_with_fsdp_offloading(self):
        global_mesh = init_device_mesh(
            "cuda", (1, self.world_size), mesh_dim_names=("dp", "tp")
        )
        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
        torch.manual_seed(42)
        mlp_dim = 16
        model = MLPStack(mlp_dim)
        ref_model = copy.deepcopy(model).cuda()
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
        # Parallelize with N-way TP and 1-way FSDP
        model.parallelize(
            tp_mesh,
            dp_mesh,
            use_activation_checkpointing=False,
            reshard_after_forward=True,
            offload_policy=CPUOffloadPolicy(),
        )
        for param in model.parameters():
            self.assertEqual(param.device.type, "cpu")
        num_mlps = sum(isinstance(module, MLP) for module in model.modules())
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)

        # NOTE: We still see the FSDP all-gather/reduce-scatter c10d ops
        # called, but they will just be no-ops without issuing any kernels.
        # We prefer to keep the no-op check at the c10d level, not in FSDP.
        inp = torch.randn((4, mlp_dim), device="cuda")  # same on all ranks
        for iter_idx in range(10):
            ref_optim.zero_grad()
            optim.zero_grad()

            with CommDebugMode() as fwd_comm_mode:
                loss = model(inp).sum()

            fwd_comm_counts = fwd_comm_mode.get_comm_counts()
            self.assertEqual(len(fwd_comm_counts), 2)
            self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps)
            self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
            ref_loss = ref_model(inp).sum()
            self.assertEqual(loss, ref_loss)

            with CommDebugMode() as bwd_comm_mode:
                loss.backward()
            bwd_comm_counts = bwd_comm_mode.get_comm_counts()
            self.assertEqual(len(bwd_comm_counts), 3)
            # First MLP's input gradient does not need to be all-reduced
            self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1)
            self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
            self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps)
            ref_loss.backward()

            optim.step()
            ref_optim.step()

    # TODO: remove this test when 2d state_dict is ready.
    @skip_if_lt_x_gpu(2)
    @skipIfRocm
    def test_raise_not_implemented_state_dict_if_2d(self):
        def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool):
            _model = Transformer.parallelize(_model, mesh["tp"], use_seq_parallel)
            for layer in _model.layers:
                fully_shard(layer, mesh=mesh["dp"])
            fully_shard(_model, mesh=mesh["dp"])
            return _model

        global_mesh = self.init_global_mesh()
        seed = 42
        torch.manual_seed(seed)
        model_args = ModelArgs(dropout_p=0.0)
        model = parallelize(Transformer(model_args), global_mesh, True)

        with self.assertRaisesRegex(NotImplementedError, "2D"):
            get_model_state_dict(model)

    # Temporarily disable the 2D state dict test while strided sharding is being developed.
    # TODO: re-enable this test once 2d state_dict is ready.
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def _temp_disable_test_train_parity_2d_transformer_checkpoint_resume(self):
        """
        Tests train parity of a 2D transformer without checkpointing against a
        2D transformer with a checkpoint save/load.
        """
        self.run_subtests(
            {
                "use_seq_parallel": [False, True],
                # If reusing, then load into the same model/optimizer instance
                # else construct new ones (requiring eager optim state init)
                "reuse_model_optim": [False, True],
                "optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
                # TODO: need to update `parallelize` before including foreach=True for testing
                "foreach": [False],
            },
            self._test_train_parity_2d_transformer_checkpoint_resume,
        )

    def _test_train_parity_2d_transformer_checkpoint_resume(
        self,
        use_seq_parallel: bool,
        reuse_model_optim: bool,
        optimizer_class: Type[torch.optim.Optimizer],
        foreach: bool,
    ):
        def train_step(
            _model: nn.Module, _optim: torch.optim.Optimizer, _inp: torch.Tensor
        ) -> torch.Tensor:
            loss = _model(_inp).sum()
            loss.backward()
            _optim.step()
            _optim.zero_grad()
            return loss

        def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool):
            _model = Transformer.parallelize(_model, mesh["tp"], use_seq_parallel)
            for layer in _model.layers:
                fully_shard(layer, mesh=mesh["dp"])
            fully_shard(_model, mesh=mesh["dp"])
            return _model

        global_mesh = self.init_global_mesh()
        # Baseline: run two iterations without checkpointing
        seed = 42
        torch.manual_seed(seed)
        model_args = ModelArgs(dropout_p=0.0)
        model_no_cp = parallelize(
            Transformer(model_args), global_mesh, use_seq_parallel
        )
        optim_no_cp = optimizer_class(
            model_no_cp.parameters(), lr=1e-2, foreach=foreach
        )

        torch.manual_seed(42 + global_mesh["dp"].get_local_rank() + 1)
        inp = torch.randint(0, model_args.vocab_size, (3, 16), device="cuda")
        loss_no_cp1 = train_step(model_no_cp, optim_no_cp, inp)
        loss_no_cp2 = train_step(model_no_cp, optim_no_cp, inp)

        # Test: run one iteration, save checkpoint, zero states or init new
        # model/optimizer, load checkpoint, and run another iteration
        torch.manual_seed(seed)
        model_cp = parallelize(Transformer(model_args), global_mesh, use_seq_parallel)
        optim_cp = optimizer_class(model_cp.parameters(), lr=1e-2, foreach=foreach)

        loss_cp1 = train_step(model_cp, optim_cp, inp)
        self.assertEqual(loss_no_cp1, loss_cp1)

        sharded_sd = {
            "model": get_model_state_dict(model_cp),
            # Use `get_optimizer_state_dict` to handle eager optim state init
            # when constructing a new optimizer instance
            "optim": get_optimizer_state_dict(model_cp, optim_cp),
        }
        dcp.save(
            state_dict=sharded_sd,
            storage_writer=dcp.FileSystemWriter(self.temp_dir),
        )
        if reuse_model_optim:
            with torch.no_grad():
                for param in model_cp.parameters():
                    param.zero_()
                optim_sd = optim_cp.state_dict()
                for param_states in optim_sd["state"].values():
                    for state_value in param_states.values():
                        if torch.is_tensor(state_value):
                            state_value.zero_()
        else:
            torch.manual_seed(seed + 1)  # different seed
            model_cp = parallelize(
                Transformer(model_args), global_mesh, use_seq_parallel
            )
            optim_cp = optimizer_class(model_cp.parameters(), lr=1e-2, foreach=foreach)
        self.assertNotEqual(loss_no_cp2, train_step(model_cp, optim_cp, inp))

        sharded_sd = {
            "model": get_model_state_dict(model_cp),
            "optim": get_optimizer_state_dict(model_cp, optim_cp),
        }
        dcp.load(
            state_dict=sharded_sd,
            storage_reader=dcp.FileSystemReader(self.temp_dir),
        )
        self.assertGreater(len(optim_cp.state_dict()["state"]), 0)

        loss_cp2 = train_step(model_cp, optim_cp, inp)
        self.assertEqual(loss_no_cp2, loss_cp2)


class TestFullyShardNDTraining(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(8, torch.cuda.device_count())

    def init_global_mesh(self) -> DeviceMesh:
        # Prefer to test with >=8 GPUs, but for 2 GPUs, use 2-way TP
        dp_size = 2 if self.world_size > 2 else 1
        pp_size = 2 if self.world_size > 4 else 1
        return init_device_mesh(
            "cuda",
            (pp_size, dp_size, self.world_size // (dp_size * pp_size)),
            mesh_dim_names=("pp", "dp", "tp"),
        )

    @skip_if_lt_x_gpu(4)
    def test_2d_mlp_with_nd_mesh(self):
        global_mesh = self.init_global_mesh()
        self.run_subtests(
            {
                "reshard_after_forward": [False, True],
                "use_activation_checkpointing": [False, True],
                "mlp_dim": [3, 16, 17],
                "foreach": [False],
            },
            functools.partial(self._test_2d_mlp_with_nd_mesh, global_mesh),
        )

    def _test_2d_mlp_with_nd_mesh(
        self,
        global_mesh: DeviceMesh,
        reshard_after_forward: bool,
        use_activation_checkpointing: bool,
        mlp_dim: int,
        foreach: bool,
    ):
        global_mesh = self.init_global_mesh()
        pp_mesh, dp_mesh, tp_mesh = (
            global_mesh["pp"],
            global_mesh["dp"],
            global_mesh["tp"],
        )
        dp_pg = dp_mesh.get_group()  # used for `replicate()`

        torch.manual_seed(42)
        model = MLPStack(mlp_dim)
        ref_model = copy.deepcopy(model).cuda()
        replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)
        model.parallelize(
            tp_mesh,
            dp_mesh,
            use_activation_checkpointing,
            reshard_after_forward=reshard_after_forward,
        )
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)

        torch.manual_seed(42 + dp_pg.rank() + 1)
        device = torch.device("cuda")
        for iter_idx in range(10):
            inp = torch.randn((8, mlp_dim), device=device)
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])

        for n, p in model.named_parameters():
            self.assertIsInstance(p, DTensor)
            self.assertEqual(p.device_mesh.ndim, 2)
            self.assertEqual(len(p.placements), 2)
            self.assertEqual(p.device_mesh.mesh_dim_names, ("dp", "tp"))


class TestFullyShardHSDPTraining(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    @skip_if_lt_x_gpu(2)
    def test_train_parity_hsdp(self):
        shard_size = 2 if self.world_size > 2 else 1
        replicate_size = self.world_size // shard_size
        global_mesh = init_device_mesh(
            "cuda", (replicate_size, shard_size), mesh_dim_names=("replicate", "shard")
        )
        self.run_subtests(
            {
                "reshard_after_forward": [False, True],
                "use_activation_checkpointing": [False, True],
                "mlp_dim": [3, 16, 17],
                "sync_gradients_at_last_batch": [True, False],
            },
            functools.partial(self._test_train_parity_hsdp, global_mesh),
        )

    def _test_train_parity_hsdp(
        self,
        global_mesh: DeviceMesh,
        reshard_after_forward: bool,
        use_activation_checkpointing: bool,
        mlp_dim: int,
        sync_gradients_at_last_batch: bool,
    ):
        torch.manual_seed(42)
        model = nn.Sequential(
            nn.LayerNorm(mlp_dim, bias=False),
            MLP(mlp_dim, dim_multiplier=3),
            MLP(mlp_dim),
            MLP(mlp_dim, dim_multiplier=3),
        )
        ref_model = copy.deepcopy(model).cuda()
        replicate(ref_model, device_ids=[self.rank])
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        for mlp in model:
            if use_activation_checkpointing:
                checkpoint(mlp)
            fully_shard(
                mlp, mesh=global_mesh, reshard_after_forward=reshard_after_forward
            )
        fully_shard(
            model, mesh=global_mesh, reshard_after_forward=reshard_after_forward
        )
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        check_sharded_parity(self, ref_model, model)
        torch.manual_seed(42 + self.rank + 1)
        device = torch.device("cuda")
        num_microbatches = 3
        for iter_idx in range(5):
            for microbatch_idx in range(num_microbatches):
                is_last_microbatch = microbatch_idx == num_microbatches - 1
                if sync_gradients_at_last_batch:
                    model.set_requires_gradient_sync(is_last_microbatch)
                inp = torch.randn((8, mlp_dim), device=device)
                losses: List[torch.Tensor] = []
                for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                    losses.append(_model(inp).sum())
                    losses[-1].backward()
                self.assertEqual(losses[0], losses[1])
            check_sharded_parity(self, ref_model, model)
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.step()
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            check_sharded_parity(self, ref_model, model)


class TestFullyShardCustomForwardMethod(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 2)

    @skip_if_lt_x_gpu(2)
    def test_register_fsdp_forward_method(self):
        """Based on https://github.com/pytorch/pytorch/issues/109385"""

        class VisionTransformer(nn.Module):
            def __init__(self):
                super().__init__()
                self.patch_proj = nn.Conv2d(3, 1024, kernel_size=14, stride=14)

            def forward_features(self, imgs: torch.Tensor) -> torch.Tensor:
                return self.patch_proj(imgs).flatten(2).transpose(1, 2)

            def forward(self, imgs: torch.Tensor) -> torch.Tensor:
                return self.forward_features(imgs).sum(dim=1)

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.vit, self.projector = VisionTransformer(), nn.Linear(1024, 256)

            def forward(self, imgs: torch.Tensor) -> torch.Tensor:
                # Run `vit.forward_features`, which is not `forward`!
                patch_embeddings = self.vit.forward_features(imgs)
                return self.projector(patch_embeddings)

        torch.manual_seed(42)
        model = Model()
        ref_model = copy.deepcopy(model).cuda()
        fully_shard(model.vit)
        fully_shard(model.projector)
        fully_shard(model)
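        # Register the non-`forward` method so that FSDP's pre/post-forward
        # logic (unshard before and reshard after) also runs around
        # `forward_features`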
        register_fsdp_forward_method(model.vit, "forward_features")

        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn(4, 3, 224, 224, device="cuda")
        ref_loss = ref_model(inp).sum()
        loss = model(inp).sum()
        self.assertEqual(ref_loss, loss)
        ref_loss.backward()
        loss.backward()
        for param in ref_model.parameters():
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
        check_sharded_parity(self, ref_model, model)


if __name__ == "__main__":
    run_tests()