xref: /aosp_15_r20/external/pytorch/test/distributed/_composable/fsdp/test_fully_shard_comm.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2
3import copy
4import functools
5import itertools
6import unittest
7from typing import Callable, List, Optional, Tuple, Union
8
9import torch
10import torch.distributed as dist
11import torch.nn as nn
12import torch.nn.functional as F
13from torch.distributed._composable import checkpoint, replicate
14from torch.distributed._composable.fsdp import (
15    FSDPModule,
16    fully_shard,
17    MixedPrecisionPolicy,
18    OffloadPolicy,
19)
20from torch.distributed._composable.fsdp._fsdp_collectives import (
21    _div_if_needed,
22    _get_gradient_divide_factors,
23    foreach_all_gather,
24    foreach_all_gather_copy_out,
25    foreach_reduce,
26)
27from torch.distributed._composable.fsdp._fsdp_common import FSDPMeshInfo, TrainingState
28from torch.distributed._composable.fsdp._fsdp_init import (
29    _get_post_forward_mesh_info,
30    _init_default_fully_shard_mesh,
31)
32from torch.distributed._composable.fsdp._fsdp_param import ShardedState
33from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
34from torch.distributed._tensor import DTensor
35from torch.distributed._tensor.experimental import implicit_replication
36from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
37from torch.distributed.tensor.debug import CommDebugMode
38from torch.testing._internal.common_cuda import TEST_CUDA
39from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
40from torch.testing._internal.common_fsdp import (
41    check_sharded_parity,
42    DoubleLinear,
43    FSDPTest,
44    FSDPTestMultiThread,
45    MLP,
46    patch_post_backward,
47    patch_reshard,
48    patch_unshard,
49)
50from torch.testing._internal.common_utils import run_tests
51from torch.testing._internal.distributed._tensor.common_dtensor import (
52    ModelArgs,
53    Transformer,
54    TransformerBlock,
55)
56
57
58c10d_ops = torch.ops.c10d
59
60# For recording FSDP events like unshard or post-backward
61EventType = Tuple[str, str, TrainingState]
62
63
64class TestFullyShardCollectiveOps(FSDPTestMultiThread):
65    @property
66    def world_size(self) -> int:
67        return 128
68
69    @property
70    def device(self) -> torch.device:
71        return torch.device("cuda:0")
72
73    def _get_param_sizes(self) -> List[torch.Size]:
74        # For world size 128, the fp32 all-gather and reduce-scatter testing
75        # requires ~0.22 GB
76        return [
77            torch.Size([17, 257]),
78            torch.Size([17]),
79            torch.Size([64, 312]),
80            torch.Size([64]),
81            torch.Size([64, 64]),
82            torch.Size([512, 64]),
83            torch.Size([256]),
84            torch.Size([64, 297]),
85        ]
86
87    def _init_params(self, param_sizes: List[torch.Size]) -> List[nn.Parameter]:
88        torch.manual_seed(42)
89        orig_params = [
90            nn.Parameter(torch.randn(size, device=self.device)) for size in param_sizes
91        ]
92        # Since seed is per process, not per thread, we broadcast to ensure the
93        # same original parameters across ranks
94        for orig_param in orig_params:
95            dist.broadcast(orig_param, src=0)
96        return orig_params
97
98    def _init_fsdp_param_group(
99        self, params: List[nn.Parameter], reshard_after_forward: Union[bool, int]
100    ):
101        module = nn.ParameterList([param.detach().clone() for param in params])
102        mesh_info = FSDPMeshInfo(_init_default_fully_shard_mesh(), shard_mesh_dim=0)
103        post_forward_mesh_info = _get_post_forward_mesh_info(
104            reshard_after_forward, mesh_info
105        )
106        fsdp_param_group = FSDPParamGroup(
107            list(module.parameters()),
108            (module,),
109            mesh_info,
110            post_forward_mesh_info,
111            self.device,
112            MixedPrecisionPolicy(),
113            OffloadPolicy(),
114        )
115        fsdp_param_group.lazy_init()
116        return fsdp_param_group
117
118    @unittest.skipIf(not TEST_CUDA, "no cuda")
119    def test_all_gather_fp32(self):
120        param_sizes = self._get_param_sizes()
121        default_stream = torch.cuda.current_stream()
122        stream1, stream2 = torch.cuda.Stream(), torch.cuda.Stream()
123        for async_op, streams, reshard_after_forward in itertools.product(
124            (False, True),
125            ((default_stream, default_stream), (stream1, stream2)),
126            (True, 8),
127        ):
128            all_gather_copy_in_stream, all_gather_stream = streams
129            # Save test time by only testing reshard after forward as an int
130            # for non-async and non-default streams (like in pre-backward)
131            if type(reshard_after_forward) is int and (
132                async_op or all_gather_stream is default_stream
133            ):
134                continue
135            self._test_all_gather(
136                param_sizes,
137                reshard_after_forward=reshard_after_forward,
138                async_op=async_op,
139                all_gather_copy_in_stream=all_gather_copy_in_stream,
140                all_gather_stream=all_gather_stream,
141            )
142
143    def _test_all_gather(
144        self,
145        param_sizes: List[torch.Size],
146        reshard_after_forward: Union[bool, int],
147        async_op: bool,
148        all_gather_copy_in_stream: torch.cuda.Stream,
149        all_gather_stream: torch.cuda.Stream,
150    ):
151        def all_gather(fsdp_param_group: FSDPParamGroup, group: dist.ProcessGroup):
152            all_gather_result = foreach_all_gather(
153                fsdp_param_group.fsdp_params,
154                group,
155                async_op=async_op,
156                all_gather_copy_in_stream=all_gather_copy_in_stream,
157                all_gather_stream=all_gather_stream,
158                device=self.device,
159            )
160            foreach_all_gather_copy_out(all_gather_result, fsdp_params, group)
161            # Transition to unsharded state to register unsharded parameters
162            for fsdp_param in fsdp_param_group.fsdp_params:
163                fsdp_param.init_unsharded_param()
164            fsdp_param_group._to_unsharded()
165
166        def check_all_gathered_params(
167            orig_params: List[nn.Parameter], module: nn.Module
168        ):
169            for orig_param, param in zip(orig_params, module.parameters()):
170                self.assertIsInstance(param, torch.Tensor)
171                self.assertIsInstance(param, nn.Parameter)
172                self.assertEqual(param, orig_param.to(param.dtype))
173
174        # Set up the reference parameters and construct the FSDP group
175        orig_params = self._init_params(param_sizes)
176        fsdp_param_group = self._init_fsdp_param_group(
177            orig_params, reshard_after_forward
178        )
179        fsdp_params = fsdp_param_group.fsdp_params
180        module = fsdp_param_group.modules[0]
181
182        # Sanity check that the parameter sharding is as expected
183        for orig_param, param in zip(orig_params, module.parameters()):
184            self.assertTrue(isinstance(param, DTensor))
185            self.assertEqual(param.full_tensor(), orig_param)
186
187        # Run the foreach all-gather (including copy-in and copy-out)
188        all_gather(fsdp_param_group, fsdp_param_group.mesh_info.shard_process_group)
189
190        # Check all-gather correctness
191        check_all_gathered_params(orig_params, module)
192
193        # For reshard after after forward as an int, further test emulating the
194        # pre-backward all-gather
195        if type(reshard_after_forward) is not int:
196            return
197        fsdp_param_group._to_sharded_post_forward()
198        all_gather(
199            fsdp_param_group,
200            fsdp_param_group.post_forward_mesh_info.shard_process_group,
201        )
202        check_all_gathered_params(orig_params, module)
203
204    @unittest.skipIf(not TEST_CUDA, "no cuda")
205    def test_reduce_scatter_fp32(self):
206        param_sizes = self._get_param_sizes()
207        default_stream = torch.cuda.current_stream()
208        stream = torch.cuda.Stream()
209        for reduce_scatter_stream in (default_stream, stream):
210            self._test_reduce_scatter(
211                param_sizes,
212                reduce_scatter_stream=reduce_scatter_stream,
213                reduce_scatter_dtype=torch.float32,
214            )
215
216    @unittest.skipIf(not TEST_CUDA, "no cuda")
217    def test_reduce_scatter_fp16(self):
218        param_sizes = self._get_param_sizes()
219        default_stream = torch.cuda.current_stream()
220        stream = torch.cuda.Stream()
221        for reduce_scatter_stream in (default_stream, stream):
222            self._test_reduce_scatter(
223                param_sizes,
224                reduce_scatter_stream=reduce_scatter_stream,
225                reduce_scatter_dtype=torch.float16,
226            )
227
228    def _test_reduce_scatter(
229        self,
230        param_sizes: List[torch.Size],
231        reduce_scatter_stream: torch.cuda.Stream,
232        reduce_scatter_dtype: torch.dtype,
233    ):
234        # Set up the reference parameters and construct the FSDP group
235        orig_params = self._init_params(param_sizes)
236        fsdp_param_group = self._init_fsdp_param_group(orig_params, True)
237        fsdp_params = fsdp_param_group.fsdp_params
238        fsdp_param_group.comm_ctx.lazy_init()
239
240        # Run one unshard to initialize metadata
241        fsdp_param_group.unshard()
242        fsdp_param_group.wait_for_unshard()
243        fsdp_param_group.reshard()
244
245        # Run the foreach reduce-scatter (including copy-in and view-out)
246        torch.manual_seed(42)
247        unsharded_grads = [torch.ones_like(param) * self.rank for param in orig_params]
248        group = fsdp_param_group.mesh_info.shard_process_group
249        self.assertEqual(group.size(), self.world_size)
250        all_reduce_stream = torch.cuda.Stream()
251        (
252            reduce_scatter_input,
253            reduce_scatter_event,
254            post_reduce_event,
255            _,
256        ) = foreach_reduce(
257            fsdp_params,
258            unsharded_grads,
259            group,
260            reduce_scatter_stream,
261            orig_dtype=orig_params[0].dtype,
262            reduce_dtype=reduce_scatter_dtype,
263            device=self.device,
264            reduce_scatter_reduce_op=None,
265            all_reduce_group=None,
266            all_reduce_stream=all_reduce_stream,
267            all_reduce_grads=True,
268            partial_reduce_output=None,
269        )
270        torch.cuda.current_stream().wait_event(post_reduce_event)
271
272        # Check reduce-scatter correctness
273        predivide_factor, postdivide_factor = _get_gradient_divide_factors(
274            group, None, reduce_scatter_dtype
275        )
276        reduced_grads = [grad.detach().clone() for grad in unsharded_grads]
277        for grad in reduced_grads:
278            _div_if_needed(grad, predivide_factor)
279            dist.all_reduce(
280                grad,
281                group=group,
282                op=dist.ReduceOp.AVG if predivide_factor is None else dist.ReduceOp.SUM,
283            )
284            _div_if_needed(grad, postdivide_factor)
285        for fsdp_param, reduced_grad in zip(fsdp_params, reduced_grads):
286            sharded_grad = fsdp_param.sharded_param.grad
287            self.assertIsInstance(sharded_grad, DTensor)
288            self.assertEqual(sharded_grad.full_tensor(), reduced_grad)
289
290
291class TestFullyShardCommunication(FSDPTest):
292    @property
293    def world_size(self) -> int:
294        return min(4, torch.cuda.device_count())
295
296    @skip_if_lt_x_gpu(2)
297    def test_fully_shard_communication_count(self):
298        """
299        Tests that FSDP issues the expected number of all-gathers and
300        reduce-scatters during forward and backward.
301        """
302        self.run_subtests(
303            {"reshard_after_forward": [True, False, 2]},
304            self._test_communication_count,
305        )
306
307    def _test_communication_count(
308        self,
309        reshard_after_forward: Union[bool, int],
310    ):
311        torch.manual_seed(42)
312        model_args = ModelArgs()
313        model = Transformer(model_args)
314        fully_shard_fn = functools.partial(
315            fully_shard, reshard_after_forward=reshard_after_forward
316        )
317        num_blocks = 0
318        for module in model.modules():
319            if isinstance(module, TransformerBlock):
320                fully_shard_fn(module)
321                num_blocks += 1
322        fully_shard_fn(model)
323        # We construct `num_blocks` plus 1 FSDP states/communication groups
324
325        torch.manual_seed(42 + self.rank)
326        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
327        with CommDebugMode() as fwd_comm_mode:
328            loss = model(inp)
329        fwd_comm_counts = fwd_comm_mode.get_comm_counts()
330        self.assertEqual(len(fwd_comm_counts), 1)
331        self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_blocks + 1)
332        with CommDebugMode() as bwd_comm_mode:
333            loss.sum().backward()
334        bwd_comm_counts = bwd_comm_mode.get_comm_counts()
335        if reshard_after_forward is False:
336            self.assertEqual(len(bwd_comm_counts), 1)
337        else:
338            # The root always does not reshard after forward
339            self.assertEqual(len(bwd_comm_counts), 2)
340            self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_blocks)
341        self.assertEqual(
342            bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_blocks + 1
343        )
344
345    @skip_if_lt_x_gpu(2)
346    def test_manual_reshard_with_reshard_after_forward_false(self):
347        """
348        Tests that we can manually call ``reshard`` on FSDP modules that were
349        initialized with ``reshard_after_forward=False`` and still run unshard.
350        """
351        torch.manual_seed(42)
352        model_args = ModelArgs()
353        model = Transformer(model_args)
354        for module in model.modules():
355            if isinstance(module, TransformerBlock):
356                fully_shard(module, reshard_after_forward=False)
357        model = fully_shard(model, reshard_after_forward=False)
358        num_fsdp_modules = sum(
359            isinstance(module, FSDPModule) for module in model.modules()
360        )
361
362        torch.manual_seed(42 + self.rank)
363        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
364        with CommDebugMode() as fwd_comm_mode:
365            loss = model(inp)
366        fwd_comm_counts = fwd_comm_mode.get_comm_counts()
367        self.assertEqual(len(fwd_comm_counts), 1)
368        self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_fsdp_modules)
369
370        for module in model.modules():
371            if isinstance(module, FSDPModule):
372                module.reshard()
373
374        with CommDebugMode() as bwd_comm_mode:
375            loss.sum().backward()
376        bwd_comm_counts = bwd_comm_mode.get_comm_counts()
377        self.assertEqual(len(bwd_comm_counts), 2)
378        self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_fsdp_modules)
379        self.assertEqual(
380            bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_fsdp_modules
381        )
382
383    @skip_if_lt_x_gpu(2)
384    def test_set_reduce_scatter_divide_factor(self):
385        self.run_subtests(
386            {"divide_factor": [self.world_size * 2, self.world_size]},
387            self._test_set_reduce_scatter_divide_factor,
388        )
389
390    def _test_set_reduce_scatter_divide_factor(self, divide_factor: float):
391        torch.manual_seed(42)
392        model_args = ModelArgs(dropout_p=0.0, weight_tying=False)
393        model = Transformer(model_args)
394        ref_model = copy.deepcopy(model).cuda()
395        ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
396        for module in model.modules():
397            if isinstance(module, TransformerBlock):
398                fully_shard(module, reshard_after_forward=False)
399        model = fully_shard(model, reshard_after_forward=False)
400        optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
401        model.set_reduce_scatter_divide_factor(divide_factor)
402
403        torch.manual_seed(42 + self.rank)
404        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
405
406        for iter_idx in range(10):
407            ref_loss = ref_model(inp).sum()
408            ref_loss.backward()
409            for param in ref_model.parameters():
410                param.grad.mul_(1.0 / divide_factor)
411                dist.all_reduce(param.grad)
412            loss = model(inp).sum()
413            loss.backward()
414            ref_optim.step()
415            optim.step()
416            ref_optim.zero_grad()
417            optim.zero_grad()
418            self.assertEqual(ref_loss, loss)
419            check_sharded_parity(self, ref_model, model)
420
421
422class TestFullyShardPrefetch(FSDPTest):
423    @property
424    def world_size(self) -> int:
425        return min(4, torch.cuda.device_count())
426
427    @skip_if_lt_x_gpu(2)
428    def test_fully_shard_backward_prefetch(self):
429        # Activation checkpointing should not affect the expected FSDP events
430        self.run_subtests(
431            {
432                "reshard_after_forward": [True, False, 2],
433                "checkpoint_impl": [None, "utils", "composable"],
434            },
435            self._test_backward_prefetch_forward_backward,
436        )
437        self.run_subtests(
438            {
439                "reshard_after_forward": [True, False, 2],
440                "checkpoint_impl": [None, "utils", "composable"],
441            },
442            self._test_backward_prefetch_multi_forward,
443        )
444        self._test_backward_prefetch_unused_in_backward(True)
445
446    def _test_backward_prefetch_forward_backward(
447        self, reshard_after_forward: Union[bool, int], checkpoint_impl: Optional[str]
448    ):
449        n_layers = 3
450        model, optim, inp = self._init_transformer(
451            n_layers, reshard_after_forward, checkpoint_impl
452        )
453        events: List[EventType] = []
454        unshard_with_record = self._get_unshard_with_record(
455            FSDPParamGroup.unshard, events
456        )
457        post_backward_with_record = self._get_post_backward_with_record(
458            FSDPParamGroup.post_backward, events
459        )
460        # Check the order for normal 1 forward, 1 backward, 1 optimizer step
461        with patch_unshard(unshard_with_record), patch_post_backward(
462            post_backward_with_record
463        ):
464            for iter_idx in range(3):
465                loss = model(inp)
466                expected_events = [
467                    ("unshard", "", TrainingState.FORWARD),  # root
468                    ("unshard", "layers.0", TrainingState.FORWARD),
469                    ("unshard", "layers.1", TrainingState.FORWARD),
470                    ("unshard", "layers.2", TrainingState.FORWARD),
471                ]
472                self.assertEqual(events, expected_events)
473                events.clear()
474                loss.sum().backward()
475                expected_events = [
476                    # Root does not reshard after forward so there is no
477                    # unshard event for it in backward
478                    ("unshard", "layers.2", TrainingState.PRE_BACKWARD),
479                    # Explicit backward prefetching moves the unshards early
480                    # by one module (note how swapping each unshard down one
481                    # event would give the natural event order)
482                    ("unshard", "layers.1", TrainingState.PRE_BACKWARD),
483                    ("post_backward", "layers.2", TrainingState.POST_BACKWARD),
484                    ("unshard", "layers.0", TrainingState.PRE_BACKWARD),
485                    ("post_backward", "layers.1", TrainingState.POST_BACKWARD),
486                    ("post_backward", "layers.0", TrainingState.POST_BACKWARD),
487                    ("post_backward", "", TrainingState.POST_BACKWARD),
488                ]
489                if reshard_after_forward is False:
490                    # No reshard after forward means no backward unshards
491                    expected_events = [e for e in expected_events if e[0] != "unshard"]
492                self.assertEqual(events, expected_events)
493                events.clear()
494                optim.step()
495                optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
496
497    def _test_backward_prefetch_multi_forward(
498        self, reshard_after_forward: Union[bool, int], checkpoint_impl: Optional[str]
499    ):
500        n_layers = 3
501        model, optim, inp = self._init_transformer(
502            n_layers, reshard_after_forward, checkpoint_impl
503        )
504        events: List[EventType] = []
505        unshard_with_record = self._get_unshard_with_record(
506            FSDPParamGroup.unshard, events
507        )
508        post_backward_with_record = self._get_post_backward_with_record(
509            FSDPParamGroup.post_backward, events
510        )
511        # Check the order for multiple forwards before 1 backward
512        with patch_unshard(unshard_with_record), patch_post_backward(
513            post_backward_with_record
514        ):
515            loss1 = model(inp)
516            loss2 = model(inp)
517            expected_events = [
518                ("unshard", "", TrainingState.FORWARD),  # root
519                ("unshard", "layers.0", TrainingState.FORWARD),
520                ("unshard", "layers.1", TrainingState.FORWARD),
521                ("unshard", "layers.2", TrainingState.FORWARD),
522                # Root does not reshard after forward so there is not another
523                # unshard event for it
524                ("unshard", "layers.0", TrainingState.FORWARD),
525                ("unshard", "layers.1", TrainingState.FORWARD),
526                ("unshard", "layers.2", TrainingState.FORWARD),
527            ]
528            if reshard_after_forward is False:
529                # No reshard after forward means no second set of unshards
530                expected_events = expected_events[:-3]
531            self.assertEqual(events, expected_events)
532            events.clear()
533            (loss1 + loss2).sum().backward()
534            expected_events = [
535                # Same as the single forward/backward case except the root's
536                # post-backward does not run until the end of backward in the
537                # final callback (since the input not requiring gradient means
538                # that we do not have a tensor on which to hook for
539                # post-backward)
540                ("unshard", "layers.2", TrainingState.PRE_BACKWARD),
541                ("unshard", "layers.1", TrainingState.PRE_BACKWARD),
542                ("post_backward", "layers.2", TrainingState.POST_BACKWARD),
543                ("unshard", "layers.0", TrainingState.PRE_BACKWARD),
544                ("post_backward", "layers.1", TrainingState.POST_BACKWARD),
545                ("post_backward", "layers.0", TrainingState.POST_BACKWARD),
546            ]
547            if reshard_after_forward is False:
548                # No reshard after forward means no backward unshards
549                expected_events = [e for e in expected_events if e[0] != "unshard"]
550                # However, the post-backward reshards, so the second set of
551                # unshards will run as real ops
552            expected_events += [
553                # Repeat the same pattern except with the root's post-backward
554                # at the end since the final callback runs
555                ("unshard", "layers.2", TrainingState.PRE_BACKWARD),
556                ("unshard", "layers.1", TrainingState.PRE_BACKWARD),
557                ("post_backward", "layers.2", TrainingState.POST_BACKWARD),
558                ("unshard", "layers.0", TrainingState.PRE_BACKWARD),
559                ("post_backward", "layers.1", TrainingState.POST_BACKWARD),
560                ("post_backward", "layers.0", TrainingState.POST_BACKWARD),
561                ("post_backward", "", TrainingState.POST_BACKWARD),
562            ]
563            self.assertEqual(events, expected_events)
564            events.clear()
565
566    def _test_backward_prefetch_unused_in_backward(
567        self, reshard_after_forward: Union[bool, int]
568    ):
569        """
570        Test a model with a linear module then a split into two linear modules,
571        where we run backward through one path first before the other, meaning
572        that (1) only one linear of the two split is used per backward and (2)
573        the initial shared linear is used in both backwards.
574        """
575        dim = 8
576        model = nn.Sequential(nn.Linear(dim, dim), DoubleLinear(dim))
577        fully_shard(model[0], reshard_after_forward=reshard_after_forward)
578        fully_shard(model[1].lin1, reshard_after_forward=reshard_after_forward)
579        fully_shard(model[1].lin2, reshard_after_forward=reshard_after_forward)
580        fully_shard(model, reshard_after_forward=reshard_after_forward)
581        inp = torch.randn((4, dim), device="cuda")
582        events: List[EventType] = []
583        unshard_with_record = self._get_unshard_with_record(
584            FSDPParamGroup.unshard, events
585        )
586        post_backward_with_record = self._get_post_backward_with_record(
587            FSDPParamGroup.post_backward, events
588        )
589        with patch_unshard(unshard_with_record), patch_post_backward(
590            post_backward_with_record
591        ):
592            loss1, loss2 = model(inp)
593            expected_events = [
594                # Root has no parameters, so it does not have an unshard
595                ("unshard", "0", TrainingState.FORWARD),
596                ("unshard", "1.lin1", TrainingState.FORWARD),
597                ("unshard", "1.lin2", TrainingState.FORWARD),
598            ]
599            self.assertEqual(events, expected_events)
600            events.clear()
601
602            model.set_is_last_backward(False)
603            loss2.sum().backward(retain_graph=True)
604            expected_events = [
605                ("unshard", "1.lin2", TrainingState.PRE_BACKWARD),
606                # NOTE: This `1.lin1` unshard is a mistargeted prefetch.
607                ("unshard", "1.lin1", TrainingState.PRE_BACKWARD),
608                ("post_backward", "1.lin2", TrainingState.POST_BACKWARD),
609                ("unshard", "0", TrainingState.PRE_BACKWARD),
610                ("post_backward", "0", TrainingState.POST_BACKWARD),
611            ]
612            self.assertEqual(events, expected_events)
613            events.clear()
614
615            model.set_is_last_backward(True)
616            loss1.sum().backward()
617            expected_events = [
618                # NOTE: `1.lin1` is already unsharded from the mistargeted
619                # prefetch in the first backward.
620                # Prefetch `0`
621                ("unshard", "0", TrainingState.PRE_BACKWARD),
622                ("post_backward", "1.lin1", TrainingState.POST_BACKWARD),
623                ("post_backward", "0", TrainingState.POST_BACKWARD),
624            ]
625            self.assertEqual(events, expected_events)
626            events.clear()
627
628    @skip_if_lt_x_gpu(2)
629    def test_set_modules_to_forward_prefetch(self):
630        n_layers = 4
631        reshard_after_forward = True
632        checkpoint_impl = "utils"
633        model, _, inp = self._init_transformer(
634            n_layers, reshard_after_forward, checkpoint_impl
635        )
636
637        def set_forward_prefetch(model: Transformer, num_to_prefetch: int) -> None:
638            # Use model-specific knowledge to configure forward prefetching:
639            # each transformer block (layer) prefetches for the next few
640            for i, layer in enumerate(model.layers):
641                if i >= len(model.layers) - num_to_prefetch:
642                    break
643                layers_to_prefetch = [
644                    model.layers[i + j] for j in range(1, num_to_prefetch + 1)
645                ]
646                layer.set_modules_to_forward_prefetch(layers_to_prefetch)
647
648        events: List[EventType] = []
649        unshard_with_record = self._get_unshard_with_record(
650            FSDPParamGroup.unshard, events
651        )
652        reshard_with_record = self._get_reshard_with_record(
653            FSDPParamGroup.reshard, events
654        )
655        post_backward_with_record = self._get_post_backward_with_record(
656            FSDPParamGroup.post_backward, events
657        )
658        expected_backward_events = [
659            # Default backward prefetching
660            ("unshard", "layers.3", TrainingState.PRE_BACKWARD),
661            ("unshard", "layers.2", TrainingState.PRE_BACKWARD),
662            ("reshard", "layers.3", TrainingState.POST_BACKWARD),
663            ("post_backward", "layers.3", TrainingState.POST_BACKWARD),
664            ("unshard", "layers.1", TrainingState.PRE_BACKWARD),
665            ("reshard", "layers.2", TrainingState.POST_BACKWARD),
666            ("post_backward", "layers.2", TrainingState.POST_BACKWARD),
667            ("unshard", "layers.0", TrainingState.PRE_BACKWARD),
668            ("reshard", "layers.1", TrainingState.POST_BACKWARD),
669            ("post_backward", "layers.1", TrainingState.POST_BACKWARD),
670            ("reshard", "layers.0", TrainingState.POST_BACKWARD),
671            ("post_backward", "layers.0", TrainingState.POST_BACKWARD),
672            ("reshard", "", TrainingState.POST_BACKWARD),
673            ("post_backward", "", TrainingState.POST_BACKWARD),
674        ]
675        with patch_unshard(unshard_with_record), patch_reshard(
676            reshard_with_record
677        ), patch_post_backward(post_backward_with_record):
678            set_forward_prefetch(model, num_to_prefetch=1)
679            loss = model(inp)
680            expected_forward_events = [
681                ("unshard", "", TrainingState.FORWARD),
682                # `layers.i` prefetches `layers.i+1`
683                ("unshard", "layers.0", TrainingState.FORWARD),
684                ("unshard", "layers.1", TrainingState.FORWARD),
685                ("reshard", "layers.0", TrainingState.FORWARD),
686                ("unshard", "layers.2", TrainingState.FORWARD),
687                ("reshard", "layers.1", TrainingState.FORWARD),
688                ("unshard", "layers.3", TrainingState.FORWARD),
689                ("reshard", "layers.2", TrainingState.FORWARD),
690                ("reshard", "layers.3", TrainingState.FORWARD),
691            ]
692            self.assertEqual(events, expected_forward_events)
693            events.clear()
694            loss.sum().backward()
695            self.assertEqual(events, expected_backward_events)
696            events.clear()
697
698            set_forward_prefetch(model, num_to_prefetch=2)
699            loss = model(inp)
700            expected_forward_events = [
701                ("unshard", "", TrainingState.FORWARD),
702                # `layers.i` prefetches `layers.i+1` and `layers.i+2`
703                ("unshard", "layers.0", TrainingState.FORWARD),
704                ("unshard", "layers.1", TrainingState.FORWARD),
705                ("unshard", "layers.2", TrainingState.FORWARD),
706                ("reshard", "layers.0", TrainingState.FORWARD),
707                ("unshard", "layers.3", TrainingState.FORWARD),
708                ("reshard", "layers.1", TrainingState.FORWARD),
709                ("reshard", "layers.2", TrainingState.FORWARD),
710                ("reshard", "layers.3", TrainingState.FORWARD),
711            ]
712            self.assertEqual(events, expected_forward_events)
713            events.clear()
714            loss.sum().backward()
715            self.assertEqual(events, expected_backward_events)
716            events.clear()
717
718    @skip_if_lt_x_gpu(2)
719    def test_set_modules_to_backward_prefetch(self):
720        n_layers = 4
721        reshard_after_forward = True
722        checkpoint_impl = "utils"
723        model, _, inp = self._init_transformer(
724            n_layers, reshard_after_forward, checkpoint_impl
725        )
726
727        def set_backward_prefetch(model: Transformer, num_to_prefetch: int) -> None:
728            # Use model-specific knowledge to configure backward prefetching:
729            # each transformer block (layer) prefetches for the previous few
730            for i, layer in enumerate(model.layers):
731                if i < num_to_prefetch:
732                    continue
733                layers_to_prefetch = [
734                    model.layers[i - j] for j in range(1, num_to_prefetch + 1)
735                ]
736                layer.set_modules_to_backward_prefetch(layers_to_prefetch)
737
738        events: List[EventType] = []
739        unshard_with_record = self._get_unshard_with_record(
740            FSDPParamGroup.unshard, events
741        )
742        reshard_with_record = self._get_reshard_with_record(
743            FSDPParamGroup.reshard, events
744        )
745        post_backward_with_record = self._get_post_backward_with_record(
746            FSDPParamGroup.post_backward, events
747        )
748        expected_forward_events = [
749            # Default forward prefetching
750            ("unshard", "", TrainingState.FORWARD),  # root
751            ("unshard", "layers.0", TrainingState.FORWARD),
752            ("reshard", "layers.0", TrainingState.FORWARD),
753            ("unshard", "layers.1", TrainingState.FORWARD),
754            ("reshard", "layers.1", TrainingState.FORWARD),
755            ("unshard", "layers.2", TrainingState.FORWARD),
756            ("reshard", "layers.2", TrainingState.FORWARD),
757            ("unshard", "layers.3", TrainingState.FORWARD),
758            ("reshard", "layers.3", TrainingState.FORWARD),
759        ]
760        with patch_unshard(unshard_with_record), patch_reshard(
761            reshard_with_record
762        ), patch_post_backward(post_backward_with_record):
763            set_backward_prefetch(model, num_to_prefetch=1)
764            loss = model(inp)
765            self.assertEqual(events, expected_forward_events)
766            events.clear()
767            loss.sum().backward()
768            expected_backward_events = [
769                # Root prefetches `layers.3` per default
770                ("unshard", "layers.3", TrainingState.PRE_BACKWARD),
771                # `layers.i` prefetches for `layers.i-1` (same as default)
772                ("unshard", "layers.2", TrainingState.PRE_BACKWARD),
773                ("reshard", "layers.3", TrainingState.POST_BACKWARD),
774                ("post_backward", "layers.3", TrainingState.POST_BACKWARD),
775                ("unshard", "layers.1", TrainingState.PRE_BACKWARD),
776                ("reshard", "layers.2", TrainingState.POST_BACKWARD),
777                ("post_backward", "layers.2", TrainingState.POST_BACKWARD),
778                ("unshard", "layers.0", TrainingState.PRE_BACKWARD),
779                ("reshard", "layers.1", TrainingState.POST_BACKWARD),
780                ("post_backward", "layers.1", TrainingState.POST_BACKWARD),
781                ("reshard", "layers.0", TrainingState.POST_BACKWARD),
782                ("post_backward", "layers.0", TrainingState.POST_BACKWARD),
783                ("reshard", "", TrainingState.POST_BACKWARD),
784                ("post_backward", "", TrainingState.POST_BACKWARD),
785            ]
786            self.assertEqual(events, expected_backward_events)
787            events.clear()
788
789            set_backward_prefetch(model, num_to_prefetch=2)
790            loss = model(inp)
791            self.assertEqual(events, expected_forward_events)
792            events.clear()
793            loss.sum().backward()
794            expected_backward_events = [
795                # Root prefetches `layers.3` per default
796                ("unshard", "layers.3", TrainingState.PRE_BACKWARD),
797                # `layers.i` prefetches for `layers.i-1` and `layers.i-2`
798                ("unshard", "layers.2", TrainingState.PRE_BACKWARD),
799                ("unshard", "layers.1", TrainingState.PRE_BACKWARD),
800                ("reshard", "layers.3", TrainingState.POST_BACKWARD),
801                ("post_backward", "layers.3", TrainingState.POST_BACKWARD),
802                ("unshard", "layers.0", TrainingState.PRE_BACKWARD),
803                ("reshard", "layers.2", TrainingState.POST_BACKWARD),
804                ("post_backward", "layers.2", TrainingState.POST_BACKWARD),
805                ("reshard", "layers.1", TrainingState.POST_BACKWARD),
806                ("post_backward", "layers.1", TrainingState.POST_BACKWARD),
807                ("reshard", "layers.0", TrainingState.POST_BACKWARD),
808                ("post_backward", "layers.0", TrainingState.POST_BACKWARD),
809                ("reshard", "", TrainingState.POST_BACKWARD),
810                ("post_backward", "", TrainingState.POST_BACKWARD),
811            ]
812            self.assertEqual(events, expected_backward_events)
813            events.clear()
814
815    @skip_if_lt_x_gpu(2)
816    def test_fully_shard_multi_module_backward_prefetch(self):
817        n_layers = 5
818        model_args = ModelArgs(n_layers=n_layers, checkpoint_activations=True)
819        model = Transformer(model_args)
820        for i in range(n_layers):
821            if i == 0:
822                fully_shard(model.layers[i])
823            elif i % 2 == 1:
824                fully_shard([model.layers[i], model.layers[i + 1]])
825        fully_shard([model.tok_embeddings, model.pos_embeddings])
826        fully_shard([model.norm, model.output], reshard_after_forward=False)
827        fully_shard(model)
828        optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
829
830        events: List[EventType] = []
831        unshard_with_record = self._get_unshard_with_record(
832            FSDPParamGroup.unshard, events
833        )
834        post_backward_with_record = self._get_post_backward_with_record(
835            FSDPParamGroup.post_backward, events
836        )
837        inp = torch.randint(
838            0, model_args.vocab_size, (2, model_args.max_seq_len), device="cuda"
839        )
840        with patch_unshard(unshard_with_record), patch_post_backward(
841            post_backward_with_record
842        ):
843            for iter_idx in range(3):
844                loss = model(inp)
845                expected_events = [
846                    (
847                        "unshard",
848                        "tok_embeddings, pos_embeddings",
849                        TrainingState.FORWARD,
850                    ),
851                    ("unshard", "layers.0", TrainingState.FORWARD),
852                    ("unshard", "layers.1, layers.2", TrainingState.FORWARD),
853                    ("unshard", "layers.3, layers.4", TrainingState.FORWARD),
854                    ("unshard", "norm, output", TrainingState.FORWARD),
855                ]
856                self.assertEqual(events, expected_events)
857                events.clear()
858                loss.sum().backward()
859                expected_events = [
860                    # (norm, output) does not reshard after forward, so there is
861                    # no unshard to begin backward
862                    ("unshard", "layers.3, layers.4", TrainingState.PRE_BACKWARD),
863                    ("post_backward", "norm, output", TrainingState.POST_BACKWARD),
864                    ("unshard", "layers.1, layers.2", TrainingState.PRE_BACKWARD),
865                    (
866                        "post_backward",
867                        "layers.3, layers.4",
868                        TrainingState.POST_BACKWARD,
869                    ),
870                    ("unshard", "layers.0", TrainingState.PRE_BACKWARD),
871                    (
872                        "post_backward",
873                        "layers.1, layers.2",
874                        TrainingState.POST_BACKWARD,
875                    ),
876                    (
877                        "unshard",
878                        "tok_embeddings, pos_embeddings",
879                        TrainingState.PRE_BACKWARD,
880                    ),
881                    ("post_backward", "layers.0", TrainingState.POST_BACKWARD),
882                    (
883                        "post_backward",
884                        "tok_embeddings, pos_embeddings",
885                        TrainingState.POST_BACKWARD,
886                    ),
887                ]
888                events.clear()
889                optim.step()
890                optim.zero_grad()
891
892    @skip_if_lt_x_gpu(2)
893    def test_fully_shard_multi_module_unused_module(self):
894        class ModuleWithUnusedLinear(nn.Module):
895            def __init__(self) -> None:
896                super().__init__()
897                self.unused_lin = nn.Linear(1, 1)
898                self.lin = nn.Linear(16, 16)
899
900            def forward(self, x: torch.Tensor) -> torch.Tensor:
901                return nn.functional.relu(self.lin(x))
902
903        model = nn.Sequential(
904            ModuleWithUnusedLinear(), ModuleWithUnusedLinear(), nn.Linear(16, 16)
905        )
906        fully_shard([model[0].unused_lin, model[0].lin], reshard_after_forward=True)
907        fully_shard([model[1].unused_lin, model[1].lin], reshard_after_forward=True)
908        fully_shard(model)
909        optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
910
911        events: List[EventType] = []
912        unshard_with_record = self._get_unshard_with_record(
913            FSDPParamGroup.unshard, events
914        )
915        post_backward_with_record = self._get_post_backward_with_record(
916            FSDPParamGroup.post_backward, events
917        )
918        inp = torch.randn((2, 16), device="cuda")
919        with patch_unshard(unshard_with_record), patch_post_backward(
920            post_backward_with_record
921        ):
922            for iter_idx in range(3):
923                loss = model(inp)
924                expected_events = [
925                    ("unshard", "", TrainingState.FORWARD),
926                    ("unshard", "0.unused_lin, 0.lin", TrainingState.FORWARD),
927                    ("unshard", "1.unused_lin, 1.lin", TrainingState.FORWARD),
928                ]
929                self.assertEqual(events, expected_events)
930                events.clear()
931                loss.sum().backward()
932                expected_events = [
933                    # Since both `model[0]` and `model[1]` have unused modules
934                    # that never ran forward, they do not reshard after forward
935                    # despite setting it to `True`. Check that there are no
936                    # unshards in backward.
937                    (
938                        "post_backward",
939                        "1.unused_lin, 1.lin",
940                        TrainingState.POST_BACKWARD,
941                    ),
942                    (
943                        "post_backward",
944                        "0.unused_lin, 0.lin",
945                        TrainingState.POST_BACKWARD,
946                    ),
947                    ("post_backward", "", TrainingState.POST_BACKWARD),
948                ]
949                events.clear()
950                optim.step()
951                optim.zero_grad()
952
953    def _init_transformer(
954        self,
955        n_layers: int,
956        reshard_after_forward: Union[bool, int],
957        checkpoint_impl: Optional[str],
958    ):
959        model_args = ModelArgs(
960            n_layers=n_layers, checkpoint_activations=(checkpoint_impl == "utils")
961        )
962        model = Transformer(model_args)
963        for module in model.modules():
964            if isinstance(module, TransformerBlock):
965                if checkpoint_impl == "composable":
966                    checkpoint(module)
967                fully_shard(module, reshard_after_forward=reshard_after_forward)
968        fully_shard(model, reshard_after_forward=reshard_after_forward)
969        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
970        inp = torch.randint(
971            0, model_args.vocab_size, (2, model_args.max_seq_len), device="cuda"
972        )
973        return model, optim, inp
974
975    def _get_unshard_with_record(
976        self, orig_unshard: Callable, events: List[EventType]
977    ) -> Callable:
978        def unshard_with_record(self, *args, **kwargs):
979            nonlocal events
980            if (
981                self._all_gather_result is None
982                and self._sharded_state != ShardedState.UNSHARDED
983            ):  # skip no-ops
984                events.append(("unshard", self._module_fqn, self._training_state))
985            return orig_unshard(self, *args, **kwargs)
986
987        return unshard_with_record
988
989    def _get_reshard_with_record(
990        self, orig_reshard: Callable, events: List[EventType]
991    ) -> Callable:
992        def reshard_with_record(self, *args, **kwargs):
993            nonlocal events
994            if (
995                self._training_state == TrainingState.FORWARD
996                and not self._reshard_after_forward
997            ):  # skip no-ops
998                return
999            events.append(("reshard", self._module_fqn, self._training_state))
1000            return orig_reshard(self, *args, **kwargs)
1001
1002        return reshard_with_record
1003
1004    def _get_post_backward_with_record(
1005        self, orig_post_backward: Callable, events: List[EventType]
1006    ) -> Callable:
1007        def post_backward_with_record(self, *args, **kwargs):
1008            nonlocal events
1009            ret = orig_post_backward(self, *args, **kwargs)
1010            # Use training state after running post-backward to check that the
1011            # state is transitioned to `POST_BACKWARD` as expected
1012            events.append(("post_backward", self._module_fqn, self._training_state))
1013            return ret
1014
1015        return post_backward_with_record
1016
1017
1018class TestFullyShardUnshardMultiProcess(FSDPTest):
1019    @property
1020    def world_size(self) -> int:
1021        return min(torch.cuda.device_count(), 2)
1022
1023    @skip_if_lt_x_gpu(2)
1024    def test_unshard_async(self):
1025        class ReduceModule(nn.Module):
1026            def __init__(self, dim: int, mesh: DeviceMesh):
1027                super().__init__()
1028                self.mesh = mesh
1029                self.weight = nn.Parameter(torch.randn(dim, dim))
1030
1031            def forward(self, x: torch.Tensor):
1032                y = F.relu(x @ self.weight)
1033                # NOTE: This all-reduce is not differentiable and is included
1034                # to exercise the overlap.
1035                work = dist.all_reduce(y, group=self.mesh.get_group(), async_op=True)
1036                return y, work
1037
1038        class MLPs(nn.Module):
1039            def __init__(self, dim: int):
1040                super().__init__()
1041                self.mlp1 = MLP(dim)
1042                self.mlp2 = MLP(dim)
1043                self.mlp3 = MLP(dim)
1044
1045            def forward(self, ys: List[torch.Tensor], works: List[dist.Work]):
1046                (y1, y2, y3), (work1, work2, work3) = ys, works
1047                work1.wait()
1048                z1 = self.mlp1(y1)
1049                work2.wait()
1050                z2 = self.mlp2(y2)
1051                work3.wait()
1052                z3 = self.mlp3(y3)
1053                return z1 + z2 + z3
1054
1055        class ReduceModel(nn.Module):
1056            def __init__(self, dim: int, mesh: DeviceMesh):
1057                super().__init__()
1058                self.reduce_module1 = ReduceModule(dim, mesh)
1059                self.reduce_module2 = ReduceModule(dim, mesh)
1060                self.reduce_module3 = ReduceModule(dim, mesh)
1061                self.mlps = MLPs(dim)
1062
1063            def forward(self, x: torch.Tensor):
1064                y1, work1 = self.reduce_module1(x)
1065                if isinstance(self.mlps.mlp1, FSDPModule):
1066                    self.mlps.mlp1.unshard(async_op=True)
1067                y2, work2 = self.reduce_module2(x)
1068                if isinstance(self.mlps.mlp2, FSDPModule):
1069                    self.mlps.mlp2.unshard(async_op=True)
1070                y3, work3 = self.reduce_module3(x)
1071                if isinstance(self.mlps.mlp3, FSDPModule):
1072                    self.mlps.mlp3.unshard(async_op=True)
1073                return self.mlps([y1, y2, y3], [work1, work2, work3])
1074
1075        mesh = init_device_mesh("cuda", (self.world_size,))
1076        batch_size, dim = 2, 8
1077        torch.manual_seed(42)
1078        ref_model = replicate(ReduceModel(dim, mesh).cuda())
1079        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
1080        torch.manual_seed(42)
1081        model = ReduceModel(dim, mesh)
1082        fully_shard(model.mlps.mlp1, reshard_after_forward=False)
1083        fully_shard(model.mlps.mlp2, reshard_after_forward=False)
1084        fully_shard(model.mlps.mlp3, reshard_after_forward=False)
1085        fully_shard(model.mlps)
1086        replicate(model.cuda())
1087        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
1088        torch.manual_seed(42 + self.rank + 1)
1089        inp = torch.randn((batch_size, dim), device="cuda")
1090        for _ in range(10):
1091            losses: List[torch.Tensor] = []
1092            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
1093                losses.append(_model(inp).sum())
1094                losses[-1].backward()
1095                with implicit_replication():
1096                    _optim.step()
1097                _optim.zero_grad()
1098            self.assertEqual(losses[0], losses[1])
1099
1100
1101class TestFullyShardUnshardMultiThread(FSDPTestMultiThread):
1102    @property
1103    def world_size(self) -> int:
1104        return 2
1105
1106    @unittest.skipIf(not TEST_CUDA, "no cuda")
1107    def test_unshard_no_param_group(self):
1108        # Check that we can call `unshard()` on a module with no parameter
1109        # group / no managed parameters without erroring
1110        model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
1111        for lin in model:
1112            fully_shard(lin)
1113        fully_shard(model)
1114        handle = model.unshard(async_op=True)
1115        handle.wait()
1116
1117    @unittest.skipIf(not TEST_CUDA, "no cuda")
1118    def test_unshard_without_lazy_init(self):
1119        torch.manual_seed(42)
1120        model = MLP(4)
1121        for param in model.parameters():
1122            dist.broadcast(param, src=0)
1123        ref_model = copy.deepcopy(model)
1124        fully_shard(model)
1125        model.unshard()  # no lazy init yet
1126        for ref_param, param in zip(ref_model.parameters(), model.parameters()):
1127            self.assertEqual(ref_param, param)
1128
1129
1130if __name__ == "__main__":
1131    run_tests()
1132