# Owner(s): ["oncall: distributed"]

import copy
import functools
import itertools
import os
import sys
import unittest
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.fsdp._common_utils import clean_tensor_name
from torch.distributed.fsdp._flat_param import (
    _FSDP_SKIP_WRITEBACK_CHECK,
    _FSDP_USE_FULL_PREC_IN_EVAL,
)
from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)
from torch.utils._triton import has_triton


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
    """Tests multiple parameter groups."""

    @property
    def world_size(self) -> int:
        return 2

    def _get_param_groups(self, model: nn.Module) -> List[Dict[str, Any]]:
        """
        Constructs separate parameter groups for weights, biases, and other
        parameters.
        """
        param_groups = [
            {"params": [], "weight_decay": 0.1, "lr": 1e-2},
            {"params": [], "weight_decay": 0.01, "lr": 1e-3},
            {"params": []},
        ]
        for param_name, param in model.named_parameters():
            if "weight" in param_name:
                param_groups[0]["params"].append(param)
            elif "bias" in param_name:
                param_groups[1]["params"].append(param)
            else:
                param_groups[2]["params"].append(param)
        return param_groups

    def _get_optim(
        self,
        model: nn.Module,
        optim_class: Type[torch.optim.Optimizer],
        multi_tensor: bool,
    ) -> torch.optim.Optimizer:
        """
        Constructs an optimizer of type ``optim_class`` with three parameter
        groups, one for weights, one for biases, and one for everything else,
        where the first two groups override the default weight decay and
        learning rate.
        """
        param_groups = self._get_param_groups(model)
        return optim_class(param_groups, lr=5e-3, foreach=multi_tensor)

    def _get_ddp_transformer(self, find_unused_params: bool) -> DDP:
        """Returns a transformer with shared parameters wrapped with DDP."""
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        ddp_model = DDP(
            model,
            device_ids=[self.rank],
            find_unused_parameters=find_unused_params,
        )
        return ddp_model

    def _get_fsdp_transformer_and_optim(
        self,
        cuda_init_mode: CUDAInitMode,
        init_optim_before_wrap: bool,
        optim_class: Type[torch.optim.Optimizer],
        multi_tensor: bool,
        sharding_strategy: ShardingStrategy,
        backward_prefetch: Optional[BackwardPrefetch],
        cpu_offload: CPUOffload,
    ) -> Tuple[FSDP, torch.optim.Optimizer]:
        """
        Returns a transformer with shared parameters wrapped with FSDP and a
        corresponding optimizer.
        """
        # Each transformer layer has multiple linear layers, so this policy, in
        # combination with the parameter group construction, ensures different
        # hyperparameter settings within one `FlatParameter`
        fsdp_kwargs = {
            "auto_wrap_policy": ModuleWrapPolicy(
                {
                    TransformerEncoderLayer,
                    TransformerDecoderLayer,
                }
            ),
            "use_orig_params": True,
            "sharding_strategy": sharding_strategy,
            "backward_prefetch": backward_prefetch,
            "cpu_offload": cpu_offload,
        }
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            cuda_init_mode,
            deterministic=True,
        )
        if init_optim_before_wrap:
            fsdp_optim = self._get_optim(model, optim_class, multi_tensor)
            fsdp_model = FSDP(model, self.process_group, **fsdp_kwargs)
        else:
            fsdp_model = FSDP(model, self.process_group, **fsdp_kwargs)
            fsdp_optim = self._get_optim(fsdp_model, optim_class, multi_tensor)
        if (
            cuda_init_mode == CUDAInitMode.CUDA_AFTER
            and not fsdp_model.cpu_offload.offload_params
        ):
            fsdp_model = fsdp_model.cuda()
        return fsdp_model, fsdp_optim

    def _check_train_parity(
        self,
        ddp_model: DDP,
        ddp_optim: torch.optim.Optimizer,
        fsdp_model: FSDP,
        fsdp_optim: torch.optim.Optimizer,
        set_to_none: bool,
        num_iters: int = 10,
    ):
        """Checks training parity between DDP and FSDP."""
        device = torch.device("cuda")
        for i in range(num_iters):
            iter_losses = []
            for model, optim in ((ddp_model, ddp_optim), (fsdp_model, fsdp_optim)):
                module = model.module
                # Test two different `zero_grad()` timings
                if i % 2 == 0:
                    optim.zero_grad(set_to_none=set_to_none)  # pre-forward
                inp = module.get_input(device)
                output = model(*inp)
                loss = module.get_loss(inp, output).to(device)
                iter_losses.append(loss)
                if i % 2 == 1:
                    optim.zero_grad(set_to_none=set_to_none)  # pre-backward
                module.run_backward(loss)
                # Perform the DDP optimizer step on CPU to match FSDP if needed
                if model is ddp_model and fsdp_model.cpu_offload.offload_params:
                    model.to(torch.device("cpu"))
                optim.step()
                if model is ddp_model and fsdp_model.cpu_offload.offload_params:
                    model.to(device)
            torch.testing.assert_close(iter_losses[0], iter_losses[1])
            iter_losses.clear()
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

    def _check_ddp_fsdp_param_parity(self, ddp_model: DDP, fsdp_model: FSDP):
        with FSDP.summon_full_params(fsdp_model):
            for (n1, p1), (n2, p2) in zip(
                ddp_model.module.named_parameters(), fsdp_model.named_parameters()
            ):
                # Allow for FSDP prefixes
                self.assertEqual(n1, clean_tensor_name(n2))
                torch.testing.assert_close(p1, p2)

    def _get_sharding_strategy_from_str(
        self, sharding_strategy_str: str
    ) -> ShardingStrategy:
        if sharding_strategy_str == "no_shard":
            sharding_strategy = ShardingStrategy.NO_SHARD
        elif sharding_strategy_str == "shard_grad_op":
            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
        elif sharding_strategy_str == "full_shard":
            sharding_strategy = ShardingStrategy.FULL_SHARD
        else:
            raise ValueError(f"Invalid string: {sharding_strategy_str}")
        return sharding_strategy

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_fsdp_compile(self):
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
                "skip_fsdp_guards": [True, False],
            },
            self._test_fsdp_compile,
        )

    def _test_fsdp_compile(
        self, sharding_strategy: ShardingStrategy, skip_fsdp_guards: bool
    ):
        torch._dynamo.config.skip_fsdp_guards = skip_fsdp_guards
        fsdp_kwargs = {
            "auto_wrap_policy": ModuleWrapPolicy(
                {
                    TransformerEncoderLayer,
                    TransformerDecoderLayer,
                }
            ),
            "use_orig_params": True,
            "sharding_strategy": sharding_strategy,
            "backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
            "cpu_offload": CPUOffload(False),
        }
        base_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        ref_model = FSDP(copy.deepcopy(base_model), self.process_group, **fsdp_kwargs)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        model = FSDP(copy.deepcopy(base_model), self.process_group, **fsdp_kwargs)
        model = torch.compile(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        for i in range(10):
            losses = []
            inp = ref_model.get_input(torch.device("cuda"))
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad()
                loss = _model(*inp).sum()
                losses.append(loss)
                loss.backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy_str",
        ["no_shard", "shard_grad_op", "full_shard"],
    )
    def test_diff_hyperparams(self, sharding_strategy_str: str):
        """
        Tests FSDP parity with DDP when using multiple parameter groups with
        different hyperparameter settings.
        """
        sharding_strategy = self._get_sharding_strategy_from_str(sharding_strategy_str)
        self.run_subtests(
            {
                "cuda_init_mode": [
                    CUDAInitMode.CUDA_BEFORE,
                    CUDAInitMode.CUDA_AFTER,
                ],
                "init_optim_before_wrap": [False, True],
                "optim_class": [torch.optim.AdamW],
                "multi_tensor": [False, True],
                "set_to_none": [False, True],
                "backward_prefetch": [
                    None,
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                ],
                "skip_writeback_check": [False, True],
            },
            self._test_diff_hyperparams,
            cpu_offload=CPUOffload(offload_params=False),
            sharding_strategy=sharding_strategy,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy_str",
        ["no_shard", "shard_grad_op", "full_shard"],
    )
    def test_diff_hyperparams_cpu_offload(self, sharding_strategy_str: str):
        """
        Tests FSDP parity with DDP when using multiple parameter groups with
        different hyperparameter settings and CPU offloading enabled. This is
        separate from :meth:`test_diff_hyperparams` because CPU offloading has
        issues with certain subtest orderings (e.g., running with
        ``offload_params=False`` followed by ``True`` but not vice versa).
        """
        sharding_strategy = self._get_sharding_strategy_from_str(sharding_strategy_str)
        for skip_writeback_check in (False, True):
            self._test_diff_hyperparams(
                cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
                init_optim_before_wrap=False,
                optim_class=torch.optim.Adam,
                multi_tensor=False,
                set_to_none=False,
                backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
                cpu_offload=CPUOffload(offload_params=True),
                sharding_strategy=sharding_strategy,
                skip_writeback_check=skip_writeback_check,
            )

    def _test_diff_hyperparams(
        self,
        cuda_init_mode: CUDAInitMode,
        init_optim_before_wrap: bool,
        optim_class: Type[torch.optim.Optimizer],
        multi_tensor: bool,
        set_to_none: bool,
        backward_prefetch: Optional[BackwardPrefetch],
        cpu_offload: CPUOffload,
        sharding_strategy: ShardingStrategy,
        skip_writeback_check: bool,
    ):
        """
        Args:
            init_optim_before_wrap (bool): If ``True``, initializes the
                FSDP optimizer before wrapping the model with FSDP; otherwise,
                initializes the FSDP optimizer after wrapping the model with
                FSDP. We permit both forms of initialization to give users
                flexibility.
        """
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params:
            return  # not supported
        if skip_writeback_check:
            os.environ[_FSDP_SKIP_WRITEBACK_CHECK] = "1"
        ddp_model = self._get_ddp_transformer(find_unused_params=False)
        ddp_optim = self._get_optim(ddp_model, optim_class, multi_tensor)
        fsdp_model, fsdp_optim = self._get_fsdp_transformer_and_optim(
            cuda_init_mode=cuda_init_mode,
            init_optim_before_wrap=init_optim_before_wrap,
            optim_class=optim_class,
            multi_tensor=multi_tensor,
            sharding_strategy=sharding_strategy,
            backward_prefetch=backward_prefetch,
            cpu_offload=cpu_offload,
        )
        self._check_train_parity(
            ddp_model, ddp_optim, fsdp_model, fsdp_optim, set_to_none
        )

    @skip_if_lt_x_gpu(2)
    def test_diff_trainability(self):
        """
        Tests FSDP parity with DDP when using multiple parameter groups and
        freezing the parameters in one parameter group.
        """
        self.run_subtests(
            {
                "multi_tensor": [False, True],
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
            },
            self._test_diff_trainability,
        )

    def _test_diff_trainability(
        self,
        multi_tensor: bool,
        sharding_strategy: ShardingStrategy,
    ):
        optim_class = torch.optim.Adam
        ddp_model = self._get_ddp_transformer(find_unused_params=True)
        ddp_optim = self._get_optim(ddp_model, optim_class, multi_tensor)
        fsdp_model, fsdp_optim = self._get_fsdp_transformer_and_optim(
            cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
            init_optim_before_wrap=False,
            optim_class=optim_class,
            multi_tensor=multi_tensor,
            sharding_strategy=sharding_strategy,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            cpu_offload=None,
        )
        # Freeze all biases (which happen to be in the same parameter group)
        for param_name, param in ddp_model.named_parameters():
            if "bias" in param_name:
                param.requires_grad_(False)
        for param_name, param in fsdp_model.named_parameters():
            if "bias" in param_name:
                param.requires_grad_(False)
        self._check_train_parity(ddp_model, ddp_optim, fsdp_model, fsdp_optim, False)

    @skip_if_lt_x_gpu(2)
    def test_multiple_optimizers(self):
        """
        Tests using two optimizers where only one sets gradients to ``None``.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                ]
            },
            self._test_multiple_optimizers,
        )

    def _test_multiple_optimizers(self, sharding_strategy: ShardingStrategy):
        ddp_model = self._get_ddp_transformer(find_unused_params=True)
        ddp_param_groups = self._get_param_groups(ddp_model)
        assert len(ddp_param_groups) == 3, f"{len(ddp_param_groups)}"
        (
            fsdp_model,
            _,
        ) = self._get_fsdp_transformer_and_optim(  # ignore returned optimizer
            cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
            init_optim_before_wrap=False,
            optim_class=torch.optim.Adam,  # ignored
            multi_tensor=False,  # ignored
            sharding_strategy=sharding_strategy,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            cpu_offload=None,
        )
        fsdp_param_groups = self._get_param_groups(fsdp_model)
        assert len(fsdp_param_groups) == 3, f"{len(fsdp_param_groups)}"
        ddp_optims = []
        fsdp_optims = []
        # For the transformer model, every parameter is either a weight or a
        # bias, so we only use the first two parameter groups. Moreover, we use
        # Adam and AdamW in particular since they both use bias correction
        # dependent on the step, which is incremented even if a parameter has a
        # zero gradient but not if the gradient is `None`. This is to test that
        # we are differentiating between a zero and `None` gradient correctly.
        optim_ctors = [
            functools.partial(torch.optim.Adam, lr=5e-3),
            functools.partial(torch.optim.AdamW, lr=1e-2),
        ]

        for optim_ctor, ddp_param_group, fsdp_param_group in zip(
            optim_ctors,
            ddp_param_groups[:2],
            fsdp_param_groups[:2],
        ):
            ddp_optims.append(optim_ctor(ddp_param_group["params"]))
            fsdp_optims.append(optim_ctor(fsdp_param_group["params"]))
        device = torch.device("cuda")

        # Check that there exists a `FlatParameter` that has both a weight and
        # a bias in this rank's shard
        has_both = False
        for fsdp_module in FSDP.fsdp_modules(fsdp_model):
            handle = fsdp_module._handle
            if not handle:
                continue
            flat_param = handle.flat_param
            assert flat_param._params is not None
            has_weight = False
            has_bias = False
            for param, fqn in zip(flat_param._params, flat_param._fqns):
                if "weight" in fqn and param.numel() > 0:
                    has_weight = True
                elif "bias" in fqn and param.numel() > 0:
                    has_bias = True
            has_both |= has_weight and has_bias
        assert has_both, (
            f"Rank {self.rank} does not have a `FlatParameter` with both a "
            "weight and a bias in its shard, meaning that this test is vacuous"
        )

        # Run one iteration to generate gradients
        def run_iter():
            iter_losses = []
            for model, optims in ((ddp_model, ddp_optims), (fsdp_model, fsdp_optims)):
                module = model.module
                inp = module.get_input(device)
                output = model(*inp)
                loss = module.get_loss(inp, output).to(device)
                iter_losses.append(loss)
                module.run_backward(loss)
                for optim in optims:
                    optim.step()
            torch.testing.assert_close(iter_losses[0], iter_losses[1])
            iter_losses.clear()
            self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

        run_iter()

        # Only set the weights' gradients to None
        ddp_optims[0].zero_grad(set_to_none=True)
        fsdp_optims[0].zero_grad(set_to_none=True)
        inp = ddp_model.module.get_input(device)
        ddp_output = ddp_model(*inp)
        fsdp_output = fsdp_model(*inp)

        # Check that FSDP correctly exposes gradients even after forward
        # (namely, `None` for weights and non-`None` for biases)
        if sharding_strategy in NO_RESHARD_AFTER_FORWARD_STRATEGIES:
            # Skip the check since we do not expose the gradients after forward
            # for these strategies
            return
        for (ddp_n, ddp_p), (fsdp_n, fsdp_p) in zip(
            ddp_model.module.named_parameters(),
            fsdp_model.named_parameters(),
        ):
            self.assertEqual(ddp_n, clean_tensor_name(fsdp_n))
            if fsdp_p.numel() == 0:
                # Not in this rank's shard
                self.assertTrue(fsdp_p.grad is None)
                continue
            if ddp_p.grad is None:
                self.assertTrue(fsdp_p.grad is None)
            else:
                self.assertEqual(ddp_p.flatten(), fsdp_p.flatten())
                self.assertEqual(ddp_p.grad.flatten(), fsdp_p.grad.flatten())
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

        # Finish the iteration (backward pass and optimizer step)
        ddp_loss = ddp_model.module.get_loss(inp, ddp_output).to(device)
        fsdp_loss = fsdp_model.module.get_loss(inp, fsdp_output).to(device)
        ddp_model.module.run_backward(ddp_loss)
        fsdp_model.module.run_backward(fsdp_loss)
        for optim in itertools.chain(ddp_optims, fsdp_optims):
            optim.step()
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

        # Run one more iteration to confirm bias corrections are correct
        run_iter()
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

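
# Illustrative sketch (not part of the test suite): with ``use_orig_params=True``,
# the FSDP wrapper keeps exposing the original parameters through
# ``named_parameters()``, so per-parameter-group optimizers (e.g., different
# weight decay for weights vs. biases, as exercised above) can be built exactly
# as for a local model. The module, hyperparameters, and function name below are
# arbitrary examples; this assumes an initialized default process group and a
# CUDA device and is never called by the tests.
def _sketch_param_group_optimizer() -> torch.optim.Optimizer:
    model = FSDP(
        nn.Sequential(nn.Linear(5, 5, device="cuda"), nn.Linear(5, 3, device="cuda")),
        use_orig_params=True,
    )
    weight_params = [p for n, p in model.named_parameters() if "weight" in n]
    bias_params = [p for n, p in model.named_parameters() if "bias" in n]
    return torch.optim.AdamW(
        [
            {"params": weight_params, "weight_decay": 0.1, "lr": 1e-2},
            {"params": bias_params, "weight_decay": 0.0, "lr": 1e-3},
        ]
    )
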
class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
    """Tests the unshard/reshard flow."""

    @property
    def world_size(self) -> int:
        return 2

    def _get_fsdp_models_and_optims(
        self,
        sharding_strategy: ShardingStrategy,
        cpu_offload: CPUOffload,
    ) -> Tuple[FSDP, torch.optim.Optimizer, FSDP, torch.optim.Optimizer]:
        """
        Returns (FSDP model, optimizer) pairs for ``use_orig_params=False``
        and ``True``, in that order.
        """
        LR = 1e-2
        fsdp_kwargs = {
            "sharding_strategy": sharding_strategy,
            "cpu_offload": cpu_offload,
            "use_orig_params": False,
        }
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=True,
        )
        optim = torch.optim.Adam(fsdp_model.parameters(), foreach=False, lr=LR)
        fsdp_kwargs["use_orig_params"] = True
        fsdp_model_orig_params = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=True,
        )
        optim_orig_params = torch.optim.Adam(
            fsdp_model_orig_params.parameters(), foreach=False, lr=LR
        )
        return fsdp_model, optim, fsdp_model_orig_params, optim_orig_params

    def _check_fsdp_parameter_parity(self, fsdp1: FSDP, fsdp2: FSDP) -> None:
        """Checks that two FSDP instances have the same model parameters."""
        with FSDP.summon_full_params(fsdp1), FSDP.summon_full_params(fsdp2):
            for (n1, p1), (n2, p2) in zip(
                fsdp1.named_parameters(),
                fsdp2.named_parameters(),
            ):
                self.assertEqual(n1, n2)
                torch.testing.assert_close(p1, p2)

    def _get_fsdp_parity_subtest_config(self):
        return {
            "sharding_strategy": [
                ShardingStrategy.NO_SHARD,
                ShardingStrategy.SHARD_GRAD_OP,
                ShardingStrategy.FULL_SHARD,
            ],
        }

    @skip_if_lt_x_gpu(2)
    @parametrize("offload_params", [False, True])
    def test_multiple_forward(self, offload_params: bool):
        """
        Tests that ``use_orig_params=True`` has parity with ``False`` when
        running multiple forward passes before a backward pass.
        """
        cpu_offload = CPUOffload(offload_params=offload_params)
        self.run_subtests(
            self._get_fsdp_parity_subtest_config(),
            self._test_multiple_forward,
            cpu_offload=cpu_offload,
        )

    @skip_if_lt_x_gpu(2)
    def _test_multiple_forward(
        self,
        sharding_strategy: ShardingStrategy,
        cpu_offload: CPUOffload,
    ):
        (
            fsdp_model,
            optim,
            fsdp_model_orig_params,
            optim_orig_params,
        ) = self._get_fsdp_models_and_optims(sharding_strategy, cpu_offload)
        device = torch.device("cuda")
        for _ in range(3):
            inp1 = fsdp_model.get_input(device)
            _inp2 = fsdp_model.get_input(device)
            inp2 = tuple(
                t + torch.ones_like(t) for t in _inp2
            )  # make different from `inp1`
            # For these loss lists: elem 0 is baseline; elem 1 is test
            losses1 = []
            losses2 = []
            losses = []
            for _model, _optim in (fsdp_model, optim), (
                fsdp_model_orig_params,
                optim_orig_params,
            ):
                _optim.zero_grad()
                loss1 = _model(*inp1)
                losses1.append(loss1)
                loss2 = _model(*inp2)
                losses2.append(loss2)
                loss = (loss1 + loss2).sum()
                losses.append(loss)
                _model.run_backward(loss)
                _optim.step()
            self.assertEqual(losses1[0], losses1[1])
            self.assertEqual(losses2[0], losses2[1])
            self.assertEqual(losses[0], losses[1])
        self._check_fsdp_parameter_parity(fsdp_model, fsdp_model_orig_params)

    @skip_if_lt_x_gpu(2)
    @parametrize("offload_params", [False, True])
    def test_summon_between_two_forwards(self, offload_params: bool):
        """
        Tests that ``use_orig_params=True`` has parity with ``False`` when
        running a forward pass, :meth:`summon_full_params()`, and another
        forward pass before a backward pass.
        """
        cpu_offload = CPUOffload(offload_params=offload_params)
        self.run_subtests(
            self._get_fsdp_parity_subtest_config(),
            self._test_summon_between_two_forwards,
            cpu_offload=cpu_offload,
        )

    def _test_summon_between_two_forwards(
        self,
        sharding_strategy: ShardingStrategy,
        cpu_offload: CPUOffload,
    ):
        (
            fsdp_model,
            optim,
            fsdp_model_orig_params,
            optim_orig_params,
        ) = self._get_fsdp_models_and_optims(sharding_strategy, cpu_offload)
        device = torch.device("cuda")
        for _ in range(3):
            optim.zero_grad()
            optim_orig_params.zero_grad()

            inp1 = fsdp_model.get_input(device)
            loss1 = fsdp_model(*inp1)
            loss_orig_params1 = fsdp_model_orig_params(*inp1)
            self.assertEqual(loss1, loss_orig_params1)

            # Calls into `summon_full_params()`
            self._check_fsdp_parameter_parity(fsdp_model, fsdp_model_orig_params)

            inp2 = fsdp_model.get_input(device)
            loss2 = fsdp_model(*inp2)
            loss_orig_params2 = fsdp_model_orig_params(*inp2)
            self.assertEqual(loss2, loss_orig_params2)

            loss = (loss1 + loss2).sum()
            loss_orig_params = (loss_orig_params1 + loss_orig_params2).sum()
            fsdp_model.run_backward(loss)
            fsdp_model_orig_params.run_backward(loss_orig_params)
            optim.step()
            optim_orig_params.step()
        self._check_fsdp_parameter_parity(fsdp_model, fsdp_model_orig_params)

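
# Illustrative sketch (not part of the test suite): ``summon_full_params()``
# temporarily gathers the full, unsharded original parameters on every rank,
# which is how the parity checks above compare two differently configured FSDP
# instances. The function name is arbitrary; this assumes an initialized default
# process group and a CUDA device and is never called by the tests.
def _sketch_summon_full_params_inspection() -> None:
    model = FSDP(nn.Linear(5, 5, device="cuda"), use_orig_params=True)
    with FSDP.summon_full_params(model):
        for name, param in model.named_parameters():
            # Inside the context, each parameter has its full, unsharded shape
            print(clean_tensor_name(name), tuple(param.shape))
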
class TestFSDPUseOrigParamsParamAccess(FSDPTest):
    """Tests original parameter access."""

    @property
    def world_size(self):
        # Force a world size of 2 since the tests hard-code expectations based
        # on the FSDP sharding algorithm to check sharded parameter parity
        return 2

    @skip_if_lt_x_gpu(2)
    def test_access_params_after_forward(self):
        """
        Tests accessing the original parameters after the forward pass but
        before the backward pass. Notably, this is not supported when
        ``use_orig_params=False``. However, for ``True``, FSDP exposes the
        (flattened) sharded original parameters, making it possible.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.NO_SHARD,
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                ],
            },
            self._test_access_params_after_forward,
        )

    def _test_access_params_after_forward(
        self,
        sharding_strategy: ShardingStrategy,
    ):
        # NOTE: This test needs to be changed if the FSDP sharding algorithm
        # changes. It is still valuable until such a change to sanity check the
        # `use_orig_params=True` implementation.
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(42)
                # 5 * 5 = 25 numel -> pad to 26 -> 13 on each rank
                self.lin1 = nn.Linear(5, 5, bias=False)
                # 5 * 7 + (1) + 7 = 43 numel -> pad to 44 -> 22 on each rank,
                # where the (1) is from intra-`FlatParameter` alignment padding
                # 22 of weight on rank 0; 13 of weight, 1 alignment padding,
                # and 7 of bias on rank 1
                self.lin2 = nn.Linear(5, 7)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                z = self.lin1(x)
                z = nn.functional.relu(z)
                z = self.lin2(z)
                return z

            def get_input(self, device: torch.device) -> Tuple[torch.Tensor, ...]:
                return (torch.randn((2, 5)).to(device),)

            def get_loss(self, inp, out):
                return out.sum()

        def check_parameter_parity(
            ddp_model: DDP, fsdp_model: FSDP, between_fwd_and_bwd: bool
        ):
            assert self.rank in (
                0,
                1,
            ), f"Expects world size of 2 but got {self.world_size}"
            for (n1, p1), (n2, p2) in zip(
                ddp_model.module.named_parameters(),
                fsdp_model.named_parameters(),
            ):
                self.assertEqual(n1, clean_tensor_name(n2))
                if sharding_strategy == ShardingStrategy.NO_SHARD:
                    # For `NO_SHARD`, do nothing since the original parameters
                    # are unflattened
                    pass
                elif (
                    between_fwd_and_bwd
                    and sharding_strategy in NO_RESHARD_AFTER_FORWARD_STRATEGIES
                ):
                    # For no reshard after forward strategies, do nothing since
                    # FSDP did not use sharded views after forward
                    pass
                # Otherwise, case on the parameter (see the model definition)
                elif n1 == "lin1.weight":
                    if self.rank == 0:
                        p1 = p1.flatten()[:13]
                    elif self.rank == 1:
                        p1 = p1.flatten()[13:]
                elif n1 == "lin2.weight":
                    if self.rank == 0:
                        p1 = p1.flatten()[:22]
                    elif self.rank == 1:
                        p1 = p1.flatten()[22:]
                elif n1 == "lin2.bias":
                    if self.rank == 0:
                        p1 = torch.empty(0, device=p1.device)
                    elif self.rank == 1:
                        p1 = p1.flatten()
                torch.testing.assert_close(p1, p2)

        ddp_model = DDP(Model().cuda(), device_ids=[self.rank])
        fsdp_model = FSDP(
            Model().cuda(),
            sharding_strategy=sharding_strategy,
            auto_wrap_policy=always_wrap_policy,
            use_orig_params=True,
        )
        LR = 1e-2
        ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)
        device = torch.device("cuda")

        inp = fsdp_model.get_input(device)
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        check_parameter_parity(ddp_model, fsdp_model, True)

        ddp_loss = ddp_model.module.get_loss(inp, ddp_out)
        fsdp_loss = fsdp_model.get_loss(inp, fsdp_out)
        ddp_loss.backward()
        fsdp_loss.backward()
        ddp_optim.step()
        fsdp_optim.step()
        check_parameter_parity(ddp_model, fsdp_model, False)

        inp = fsdp_model.get_input(device)
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        check_parameter_parity(ddp_model, fsdp_model, True)

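
# Illustrative sketch (not part of the test suite): between forward and backward
# with ``use_orig_params=True`` and a strategy that reshards after forward, the
# original parameters are exposed as flattened 1D shards (possibly with zero
# elements on ranks that do not own part of a given parameter), which is what
# the parity check above accounts for. The function name is arbitrary; this
# assumes an initialized default process group and a CUDA device and is never
# called by the tests.
def _sketch_sharded_views_between_fwd_and_bwd() -> None:
    model = FSDP(
        nn.Linear(5, 7, device="cuda"),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        use_orig_params=True,
    )
    out = model(torch.randn((2, 5), device="cuda"))
    for name, param in model.named_parameters():
        # After forward and before backward, `param` is this rank's 1D shard
        print(clean_tensor_name(name), tuple(param.shape))
    out.sum().backward()
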
class TestFSDPUseOrigParamsWriteback(FSDPTest):
    """Tests parameter and gradient writeback."""

    class Model(nn.Module):
        def __init__(self, device: torch.device):
            super().__init__()
            torch.manual_seed(42)
            self.lin1 = nn.Linear(5, 5, bias=True, device=device)
            self.lin2 = nn.Linear(5, 7, bias=True, device=device)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            z = self.lin1(x)
            z = nn.functional.relu(z)
            z = self.lin2(z)
            return z

        def get_input(self, device: torch.device) -> Tuple[torch.Tensor, ...]:
            return (torch.randn((2, 5)).to(device),)

        def get_loss(self, inp, out):
            return out.sum()

    @property
    def world_size(self):
        # Force a world size of 2 since the tests hard-code behavior that
        # depends on the FSDP sharding
        return 2
    def _check_param_parity(self, ddp_model: DDP, fsdp_model: FSDP):
        with FSDP.summon_full_params(fsdp_model):
            for (n1, p1), (n2, p2) in zip(
                ddp_model.module.named_parameters(),
                fsdp_model.named_parameters(),
            ):
                self.assertEqual(n1, n2)
                torch.testing.assert_close(p1, p2)

    @skip_if_lt_x_gpu(2)
    def test_param_writeback(self):
        """Tests that changes to the original parameters are written back."""
        self.run_subtests(
            {
                "change_first_weight": [True, False],  # first vs. second `weight`
                "change_data": [True, False],  # change `.data` vs. variable itself
            },
            self._test_param_writeback,
        )

    def _test_param_writeback(self, change_first_weight: bool, change_data: bool):
        def transform_param(param: nn.Parameter) -> nn.Parameter:
            return nn.Parameter(torch.ones_like(param) * 2)

        # Check that the writeback propagates
        ddp_model = DDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            device_ids=[self.rank],
        )
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            use_orig_params=True,
        )
        ddp = ddp_model.module  # for brevity
        fsdp = fsdp_model.module
        if change_first_weight:
            if change_data:
                ddp.lin1.weight.data = transform_param(ddp.lin1.weight)
                fsdp.lin1.weight.data = transform_param(fsdp.lin1.weight)
            else:
                ddp.lin1.weight = transform_param(ddp.lin1.weight)
                fsdp.lin1.weight = transform_param(fsdp.lin1.weight)
        else:
            if change_data:
                ddp.lin2.weight.data = transform_param(ddp.lin2.weight)
                fsdp.lin2.weight.data = transform_param(fsdp.lin2.weight)
            else:
                ddp.lin2.weight = transform_param(ddp.lin2.weight)
                fsdp.lin2.weight = transform_param(fsdp.lin2.weight)
        self._check_param_parity(ddp_model, fsdp_model)  # triggers a writeback

    @skip_if_lt_x_gpu(2)
    def test_grad_writeback(self):
        """
        Tests that changes to the original parameters' gradients are written
        back.
        """
        self.run_subtests(
            {
                "change_first_weight_grad": [False, True],
                "change_data": [False, True],  # change `.data` vs. variable itself
                "set_to_none": [False, True],
            },
            self._test_grad_writeback,
        )

    def _test_grad_writeback(
        self,
        change_first_weight_grad: bool,
        change_data: bool,
        set_to_none: bool,
    ):
        if change_data and set_to_none:
            return  # not well-defined

        def transform_grad(param: nn.Parameter) -> nn.Parameter:
            return None if set_to_none else torch.ones_like(param) * 2

        ddp_model = DDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            device_ids=[self.rank],
        )
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            use_orig_params=True,
        )
        LR = 1e-2
        # TODO: If we add `summon_full_params(with_grads=True)`, then replace
        # the following. For now, we use the optimizer step as a surrogate for
        # checking that gradients were written back.
        ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)

        # Generate an initial gradient
        inp = fsdp_model.get_input(torch.device("cuda"))
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        ddp_out.sum().backward()
        fsdp_out.sum().backward()

        # Change the gradient through the original parameters
        ddp = ddp_model.module  # for brevity
        fsdp = fsdp_model.module
        if change_first_weight_grad:
            if change_data:
                ddp.lin1.weight.grad.data = transform_grad(ddp.lin1.weight)
                if fsdp.lin1.weight.grad is not None:
                    fsdp.lin1.weight.grad.data = transform_grad(fsdp.lin1.weight)
            else:
                ddp.lin1.weight.grad = transform_grad(ddp.lin1.weight)
                fsdp.lin1.weight.grad = transform_grad(fsdp.lin1.weight)
        else:
            if change_data:
                ddp.lin2.weight.grad.data = transform_grad(ddp.lin2.weight)
                if fsdp.lin2.weight.grad is not None:
                    fsdp.lin2.weight.grad.data = transform_grad(fsdp.lin2.weight)
            else:
                ddp.lin2.weight.grad = transform_grad(ddp.lin2.weight)
                fsdp.lin2.weight.grad = transform_grad(fsdp.lin2.weight)
        ddp_optim.step()
        fsdp_optim.step()
        self._check_param_parity(ddp_model, fsdp_model)  # triggers a writeback

        # Intentionally do not zero the gradient to check writeback
        inp = fsdp_model.get_input(torch.device("cuda"))
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        ddp_out.sum().backward()
        fsdp_out.sum().backward()
        ddp_optim.step()
        fsdp_optim.step()
        self._check_param_parity(ddp_model, fsdp_model)  # triggers a writeback

    @skip_if_lt_x_gpu(2)
    def test_writeback_shape_mismatch(self):
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            use_orig_params=True,
        )
        # Check that writing back with mismatched shape errors
        fsdp = fsdp_model.module  # for brevity
        assert self.rank in (0, 1), f"Expects world size of 2 but got {self.world_size}"
        with self.assertRaisesRegex(RuntimeError, "Cannot writeback"):
            # Change the gradient to a new one with 1 added to each dimension
            # to force a shape mismatch when writing back
            if self.rank == 0:
                # Change `lin1.weight.grad` since it exists on rank 0
                lin1_weight_shape = list(fsdp.lin1.weight.shape)
                for dim_index in range(len(lin1_weight_shape)):
                    lin1_weight_shape[dim_index] += 1
                fsdp.lin1.weight = nn.Parameter(
                    torch.randn(
                        torch.Size(lin1_weight_shape), device=fsdp.lin1.weight.device
                    )
                )
                fsdp.lin1.weight.grad = torch.randn(
                    torch.Size(lin1_weight_shape), device=fsdp.lin1.weight.device
                )
            elif self.rank == 1:
                # Change `lin2.weight.grad` since it exists (partially) on rank 1
                lin2_weight_shape = list(fsdp.lin2.weight.shape)
                for dim_index in range(len(lin2_weight_shape)):
                    lin2_weight_shape[dim_index] += 1
                fsdp.lin2.weight = nn.Parameter(
                    torch.randn(
                        torch.Size(lin2_weight_shape), device=fsdp.lin2.weight.device
                    )
                )
                fsdp.lin2.weight.grad = torch.randn(
                    torch.Size(lin2_weight_shape), device=fsdp.lin2.weight.device
                )
            with FSDP.summon_full_params(fsdp_model):  # triggers a writeback
                ...

    @skip_if_lt_x_gpu(2)
    def test_writeback_between_fwd_and_bwd_for_no_reshard_raises(self):
        fsdp_kwargs = {
            "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
            "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
            "use_orig_params": True,
        }
        fsdp_wrapper = functools.partial(FSDP, **fsdp_kwargs)

        # Test changing the parameter storage to no longer be a view into the
        # flat parameter
        fsdp_model = fsdp_wrapper(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
        )
        inp = fsdp_model.get_input(torch.device("cuda"))
        loss = fsdp_model(*inp).sum()
        fsdp_model.lin1.weight.data = fsdp_model.lin1.weight.clone()
        assert_msg = (
            "FSDP does not support changing the parameters between forward and backward"
        )
        with self.assertRaisesRegex(AssertionError, assert_msg):
            loss.backward()

        # Test changing the parameter variable itself
        fsdp_model = fsdp_wrapper(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
        )
        inp = fsdp_model.get_input(torch.device("cuda"))
        loss = fsdp_model(*inp).sum()
        fsdp_model.lin1._fsdp_wrapped_module.weight = nn.Parameter(
            fsdp_model.lin1.weight.clone()
        )
        with self.assertRaisesRegex(AssertionError, assert_msg):
            loss.backward()

    @skip_if_lt_x_gpu(2)
    def test_no_reshard_and_mixed_precision(self):
        """
        Tests that writeback does not falsely get triggered for a few
        configurations (exercising the sharded view skipping logic):
        - Train forward -> full-precision unshard -> train forward
        - Train forward -> eval forward
        - Train forward/backward -> eval forward -> model checkpoint
        """
        self.run_subtests(
            {"use_full_prec_in_eval": [False, True]},
            self._test_no_reshard_and_mixed_precision,
        )

    def _test_no_reshard_and_mixed_precision(self, use_full_prec_in_eval: bool):
        if use_full_prec_in_eval:
            os.environ[_FSDP_USE_FULL_PREC_IN_EVAL] = "1"
        fsdp_kwargs = {
            "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
            "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
            "mixed_precision": MixedPrecision(param_dtype=torch.float16),
            "use_orig_params": True,
        }

        # Train forward -> full-precision unshard -> train forward
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")), **fsdp_kwargs
        )
        inp = fsdp_model.get_input(torch.device("cuda"))
        fsdp_model(*inp)
        with FSDP.summon_full_params(fsdp_model):
            ...
        fsdp_model(*inp).sum()

        # Train forward -> eval forward
        fsdp_model.train()
        fsdp_model(*inp)
        fsdp_model.eval()
        fsdp_model(*inp)

        # Train forward/backward -> eval forward -> model checkpoint
        fsdp_model.train()
        fsdp_model(*inp).sum().backward()
        fsdp_model.eval()
        fsdp_model(*inp)
        with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
            sd = fsdp_model.state_dict()
            fsdp_model.load_state_dict(sd)
        fsdp_model(*inp).sum().backward()

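
# Illustrative sketch (not part of the test suite): with ``use_orig_params=True``,
# an out-of-place change to an original parameter (here via ``.data``) is
# written back into the underlying flat parameter the next time FSDP unshards,
# e.g. inside ``summon_full_params()``, which is the behavior the writeback
# tests above rely on. The function name is arbitrary; this assumes an
# initialized default process group and a CUDA device and is never called by
# the tests.
def _sketch_param_writeback() -> None:
    fsdp_model = FSDP(nn.Linear(5, 5, device="cuda"), use_orig_params=True)
    lin = fsdp_model.module  # the original `nn.Linear`
    # Overwrite this rank's (sharded) view of the weight out of place
    lin.weight.data = torch.ones_like(lin.weight) * 2
    with FSDP.summon_full_params(fsdp_model):  # triggers the writeback
        pass
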
class TestFSDPUseOrigParamsFQNs(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_named_parameters_in_forward(self):
        """
        Tests that calling ``named_parameters()`` during forward returns FQNs
        and ``Tensor`` s corresponding to the original parameters.
        """
        param_shapes = [None, None]
        assert_equal_fn = self.assertEqual

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(5, 5)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                nonlocal param_shapes
                # Allow for FSDP prefixes
                param_names = [
                    clean_tensor_name(tup[0]) for tup in self.named_parameters()
                ]
                params = [tup[1] for tup in self.named_parameters()]
                assert (
                    param_shapes[0] is not None and param_shapes[1] is not None
                ), "`param_shapes` should be set"
                assert_equal_fn(
                    param_names,
                    [
                        "lin.weight",
                        "lin.bias",
                    ],
                )
                assert_equal_fn(params[0].shape, param_shapes[0])
                assert_equal_fn(params[1].shape, param_shapes[1])
                return self.lin(x)

        model = Model().cuda()
        # Save the *unsharded* original parameter shapes and check the shapes
        # match in the forward pass
        param_shapes[0] = model.lin.weight.shape
        param_shapes[1] = model.lin.bias.shape
        fsdp_model = FSDP(model, use_orig_params=True)
        inp = torch.randn((2, 5), device=torch.device("cuda"))
        fsdp_model(inp)

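
# Illustrative sketch (not part of the test suite): the ``no_sync()`` pattern
# exercised below accumulates unsharded gradients locally for the first
# microbatches and only communicates (reduce-scatters) gradients on the final
# backward outside the context. The function name and microbatch count are
# arbitrary; this assumes an initialized default process group and a CUDA
# device and is never called by the tests.
def _sketch_no_sync_gradient_accumulation(num_microbatches: int = 4) -> None:
    model = FSDP(nn.Linear(7, 1, bias=False, device="cuda"), use_orig_params=True)
    optim = torch.optim.AdamW(model.parameters())
    inp = torch.randn((2, 7), device="cuda")
    for i in range(num_microbatches):
        if i < num_microbatches - 1:
            with model.no_sync():
                model(inp).sum().backward()  # gradients accumulate locally
        else:
            model(inp).sum().backward()  # gradients are reduced here
    optim.step()
    optim.zero_grad(set_to_none=True)
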
class TestFSDPUseOrigParamsNoSync(FSDPTest):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    def test_no_sync_correctness(self):
        """
        Tests a basic ``no_sync()`` setup by comparing ``use_orig_params=True``
        against ``use_orig_params=False``.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
            },
            self._test_no_sync_correctness,
        )

    def _test_no_sync_correctness(self, sharding_strategy: ShardingStrategy):
        model = nn.Linear(7, 1, bias=False, device="cuda")
        fsdp_kwargs = {
            "sharding_strategy": sharding_strategy,
        }
        model_use_flat_params = FSDP(
            copy.deepcopy(model), use_orig_params=False, **fsdp_kwargs
        )
        model_use_orig_params = FSDP(model, use_orig_params=True, **fsdp_kwargs)
        optim_use_flat_params = torch.optim.AdamW(
            model_use_flat_params.parameters(), foreach=True
        )
        optim_use_orig_params = torch.optim.AdamW(
            model_use_orig_params.parameters(), foreach=True
        )

        def _check_param_grad_parity(
            _baseline_model: nn.Module,
            _test_model: nn.Module,
        ):
            """
            This assumes that the model is ``nn.Linear(7, 1, bias=False)``
            (i.e. a single weight parameter with 7 elements) to be able to
            directly compare the baseline and test models. On rank 1, the
            baseline includes 1 element of padding.
            """
            self.assertEqual(len(list(_baseline_model.parameters())), 1)
            self.assertEqual(len(list(_test_model.parameters())), 1)
            for flat_param, orig_param in zip(
                _baseline_model.parameters(), _test_model.parameters()
            ):
                # Baseline is permitted to have padding
                self.assertGreaterEqual(flat_param.numel(), orig_param.numel())
                unpadded_param_numel = orig_param.numel()
                # For `NO_SHARD`, `use_orig_params=True` presents unflattened
                # parameters, while `False` presents flattened ones
                torch.testing.assert_close(
                    flat_param[:unpadded_param_numel], orig_param.flatten()
                )
                # The gradient numel differs from the parameter numel right
                # after `no_sync()` since the gradient is unsharded while the
                # parameter is sharded
                unpadded_grad_numel = orig_param.grad.numel()
                # For `use_orig_params=False`, the unsharded gradient is
                # flattened, while for `True`, it is unflattened
                torch.testing.assert_close(
                    flat_param.grad[:unpadded_grad_numel].reshape(
                        orig_param.grad.shape
                    ),
                    orig_param.grad,
                )

        inp = torch.randn((2, 7), device="cuda")
        grad = torch.randn((2, 1), device="cuda")

        # Compute some reference gradients using one forward/backward
        out_use_flat_params = model_use_flat_params(inp)
        out_use_orig_params = model_use_orig_params(inp)
        torch.testing.assert_close(out_use_flat_params, out_use_orig_params)
        out_use_flat_params.backward(grad)
        out_use_orig_params.backward(grad)
        _check_param_grad_parity(model_use_flat_params, model_use_orig_params)
        ref_grads_use_flat_params = [
            param.grad.detach().clone() for param in model_use_flat_params.parameters()
        ]
        ref_grads_use_orig_params = [
            param.grad.detach().clone()
            for param in model_use_orig_params.parameters()
            if param.grad is not None
        ]

        # Run a forward/backward in `no_sync()`
        optim_use_flat_params.zero_grad(set_to_none=True)
        optim_use_orig_params.zero_grad(set_to_none=True)
        for model in (model_use_flat_params, model_use_orig_params):
            with model.no_sync():
                out = model(inp)
                out.backward(grad)
        _check_param_grad_parity(model_use_flat_params, model_use_orig_params)

        # Run a forward/backward outside `no_sync()`
        for model in (model_use_flat_params, model_use_orig_params):
            out = model(inp)
            out.backward(grad)
        _check_param_grad_parity(model_use_flat_params, model_use_orig_params)

        # Check that, since we accumulated gradients across 2 iterations, the
        # new gradients are 2x the reference gradients
        grads_use_flat_params = [
            param.grad.detach().clone() for param in model_use_flat_params.parameters()
        ]
        grads_use_orig_params = [
            param.grad.detach().clone()
            for param in model_use_orig_params.parameters()
            if param.grad is not None
        ]
        for grad, ref_grad in zip(grads_use_flat_params, ref_grads_use_flat_params):
            torch.testing.assert_close(grad, 2 * ref_grad)
        for grad, ref_grad in zip(grads_use_orig_params, ref_grads_use_orig_params):
            torch.testing.assert_close(grad, 2 * ref_grad)

    @skip_if_lt_x_gpu(2)
    def test_no_sync_mixed_precision(self):
        """
        Tests that dtypes are as expected when using ``no_sync()`` with
        ``use_orig_params=True`` and parameter mixed precision.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ]
            },
            self._test_no_sync_mixed_precision,
        )

    def _test_no_sync_mixed_precision(self, sharding_strategy: ShardingStrategy):
        model = nn.Linear(3, 3, device="cuda")
        mixed_precision = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float32,
        )
        fsdp_kwargs = {
            "sharding_strategy": sharding_strategy,
            "mixed_precision": mixed_precision,
            "use_orig_params": True,
        }
        fsdp_model = FSDP(model, **fsdp_kwargs)
        inp = torch.randn((2, 3), device="cuda")
        with fsdp_model.no_sync():
            # For each of these `no_sync()` backward passes, check that the
            # gradients are in the low precision parameter dtype (FP16)
            fsdp_model(inp).sum().backward()
            for param in fsdp_model.parameters():
                if param.grad is not None:
                    self.assertEqual(param.grad.dtype, torch.float16)
            fsdp_model(inp).sum().backward()
            for param in fsdp_model.parameters():
                if param.grad is not None:
                    self.assertEqual(param.grad.dtype, torch.float16)
        # For the backward pass outside `no_sync()`, check that the gradients
        # are cast to the full precision in preparation for the optimizer step
        fsdp_model(inp).sum().backward()
        for param in fsdp_model.parameters():
            if param.grad is not None:
                self.assertEqual(param.grad.dtype, torch.float32)


class TestFSDPUseOrigParamsInit(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_non_uniform_requires_grad(self):
        model = nn.Sequential(
            nn.Linear(3, 3, device="cuda"),
            nn.Linear(3, 3, device="cuda"),
        )
        # Freeze biases only and flatten both weights and biases into the same
        # `FlatParameter` to exercise non-uniform `requires_grad`
        model[0].bias.requires_grad = False
        model[1].bias.requires_grad = False
        fsdp_model = FSDP(model, use_orig_params=True)
        self.assertTrue(fsdp_model[0].weight.requires_grad)
        self.assertFalse(fsdp_model[0].bias.requires_grad)
        self.assertTrue(fsdp_model[1].weight.requires_grad)
        self.assertFalse(fsdp_model[1].bias.requires_grad)


# Define this to be large enough to trigger stack corruption
NUM_SIZE0_TENSORS = 1000


class TestMultiTensorApply(TestCase):
    def test_multi_tensor_apply_size0_tensors_cpu(self):
        size0_tensors = [torch.empty(0, device="cpu") for _ in range(NUM_SIZE0_TENSORS)]
        # Check that this does not segfault
        torch._foreach_mul_(size0_tensors, 0.1)

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_multi_tensor_apply_size0_tensors_cuda(self):
        size0_tensors = [
            torch.empty(0, device="cuda") for _ in range(NUM_SIZE0_TENSORS)
        ]
        # Check that this does not segfault
        torch._foreach_mul_(size0_tensors, 0.1)


instantiate_parametrized_tests(TestFSDPUseOrigParamsMultipleParamGroups)
instantiate_parametrized_tests(TestFSDPUseOrigParamsUnshardReshard)
instantiate_parametrized_tests(TestFSDPUseOrigParamsParamAccess)
instantiate_parametrized_tests(TestFSDPUseOrigParamsFQNs)
instantiate_parametrized_tests(TestFSDPUseOrigParamsNoSync)

if __name__ == "__main__":
    run_tests()