xref: /aosp_15_r20/external/pytorch/test/distributed/fsdp/test_fsdp_optim_state.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2
3import bisect
4import sys
5from copy import deepcopy
6from enum import auto, Enum
7from typing import Any, Callable, Dict, List, Optional, Tuple, Type
8
9import torch
10import torch.nn as nn
11from torch import distributed as dist
12from torch.distributed._shard.sharded_tensor import ShardedTensor
13from torch.distributed._state_dict_utils import _gather_state_dict
14from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
15    _CHECKPOINT_WRAPPED_MODULE,
16    apply_activation_checkpointing,
17)
18from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
19from torch.distributed.fsdp.api import ShardingStrategy
20from torch.distributed.fsdp.fully_sharded_data_parallel import (
21    FullOptimStateDictConfig,
22    FullStateDictConfig,
23    OptimStateKeyType,
24    ShardedOptimStateDictConfig,
25    ShardedStateDictConfig,
26    StateDictSettings,
27    StateDictType,
28)
29from torch.distributed.optim import _NamedOptimizer
30from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
31from torch.testing._internal.common_fsdp import (
32    CUDAInitMode,
33    FSDPInitMode,
34    FSDPTest,
35    TransformerWithSharedParams,
36)
37from torch.testing._internal.common_utils import (
38    instantiate_parametrized_tests,
39    parametrize,
40    run_tests,
41    TEST_WITH_DEV_DBG_ASAN,
42)
43
44
45STATE_DICT_TYPES = [StateDictType.FULL_STATE_DICT, StateDictType.SHARDED_STATE_DICT]
46
47if not dist.is_available():
48    print("Distributed not available, skipping tests", file=sys.stderr)
49    sys.exit(0)
50
51if TEST_WITH_DEV_DBG_ASAN:
52    print(
53        "Skip dev-asan as torch + multiprocessing spawn have known issues",
54        file=sys.stderr,
55    )
56    sys.exit(0)
57
58
59class _OSDCommMethod(Enum):
60    """Method for communicating the optimizer state dict for internal tests."""
61
62    BROADCAST_OBJECT_LIST = auto()
63    SCATTER_FULL_OSD = auto()
64    FLATTEN_SHARDED_OSD = auto()
65    OPTIM_STATE_DICT = auto()
66
67
68class _ModelClass(Enum):
69    """Different model type to test."""
70
71    NESTED = auto()
72    TRANSFORMER = auto()
73
74
75class Bias(torch.nn.Module):
76    """This module applies a 1D additive bias with dimension ``dim``."""
77
78    def __init__(self, dim: int) -> None:
79        super().__init__()
80        assert dim > 0
81        torch.manual_seed(0)
82        self.bias = torch.nn.Parameter(torch.randn((dim,)))
83
84    def forward(self, x):
85        return x + self.bias
86
87
88class BlockA(torch.nn.Module):
89    """
90    Used to define interesting nested structure for FSDP wrapping.
91    BlockA
92        Bias0
93            bias
94        weight
95        Bias1
96            bias
97    """
98
99    def __init__(self, in_dim: int, out_dim: int) -> None:
100        super().__init__()
101        assert all(v > 0 for v in (in_dim, out_dim))
102        torch.manual_seed(0)
103        self.bias_module0 = Bias(out_dim)
104        self.weight = torch.nn.Parameter(torch.randn((in_dim, out_dim)))
105        self.bias_module1 = Bias(out_dim)
106        self.relu = torch.nn.ReLU()
107
108    def forward(self, x):
109        x = x @ self.weight
110        x = self.bias_module0(x)
111        x = self.relu(x)  # ensure biases have different gradients
112        x = self.bias_module1(x)
113        return x
114
115
116class BlockB(torch.nn.Module):
117    """
118    Used to define interesting nested structure for FSDP wrapping.
119    BlockB
120        weight
121        Bias
122            bias
123        Bias
124            bias
125    """
126
127    def __init__(self, in_dim: int, out_dim: int) -> None:
128        super().__init__()
129        assert all(v > 0 for v in (in_dim, out_dim))
130        torch.manual_seed(0)
131        self.weight = torch.nn.Parameter(torch.randn((in_dim, out_dim)))
132        self.bias_module0 = Bias(out_dim)
133        self.bias_module1 = Bias(out_dim)
134        self.relu = torch.nn.ReLU()
135
136    def forward(self, x):
137        x = x @ self.weight
138        x = self.bias_module0(x)
139        x = self.relu(x)  # ensure biases have different gradients
140        x = self.bias_module1(x)
141        return x
142
143
144class NestedModel(torch.nn.Module):
145    def __init__(self) -> None:
146        super().__init__()
147        self.block0 = BlockB(5, 3)
148        self.block1 = BlockB(3, 7)
149        self.bias = torch.nn.Parameter(torch.randn((5,)))
150        self.block2 = torch.nn.Sequential(
151            BlockA(7, 9),
152            BlockA(9, 9),
153            BlockB(9, 5),
154        )
155        self.relu = torch.nn.ReLU()
156
157    def forward(self, x) -> torch.Tensor:
158        x = self.relu(self.block0(x))
159        x = self.relu(self.block1(x))
160        x = self.relu(self.block2(x))
161        x = x + self.bias
162        return x
163
164    def get_input(self, device):
165        BATCH_SIZE = 8
166        return (torch.randn((BATCH_SIZE, 5)).to(device),)
167
168    def get_loss(self, inp, output):
169        return output.sum()
170
171    def run_backward(self, loss):
172        loss.backward()
173
174    @staticmethod
175    def wrap(
176        model: torch.nn.Module,
177        group: Optional[dist.ProcessGroup] = None,
178        ignore_modules: bool = False,
179        fsdp_kwargs: Optional[Dict[str, Any]] = None,
180    ) -> torch.nn.Module:
181        if fsdp_kwargs is None:
182            fsdp_kwargs = {}
183        # Flatten Bias0; then flatten weight and Bias1 together into `block1`
184        model.block1.bias_module0 = FSDP(
185            model.block1.bias_module0,
186            process_group=group,
187            **fsdp_kwargs,
188        )
189        model.block1 = FSDP(model.block1, process_group=group, **fsdp_kwargs)
190        # Flatten Bias0; flatten Bias1; then flatten weight into `block2[1]`
191        model.block2[1].bias_module0 = FSDP(
192            model.block2[1].bias_module0,
193            process_group=group,
194            **fsdp_kwargs,
195        )
196        model.block2[1].bias_module1 = FSDP(
197            model.block2[1].bias_module1,
198            process_group=group,
199            **fsdp_kwargs,
200        )
201        model.block2[1] = FSDP(model.block2[1], process_group=group, **fsdp_kwargs)
202        # Flatten weight, Bias, bias into `block2[2]`
203        ignored_modules = [model.block2[2].bias_module0] if ignore_modules else None
204        model.block2[2] = FSDP(
205            model.block2[2],
206            process_group=group,
207            ignored_modules=ignored_modules,
208            **fsdp_kwargs,
209        )
210        return model
211
212    @staticmethod
213    def wrap_alt(
214        model: torch.nn.Module,
215        group: Optional[dist.ProcessGroup] = None,
216        fsdp_kwargs: Optional[Dict[str, Any]] = None,
217    ) -> torch.nn.Module:
218        if fsdp_kwargs is None:
219            fsdp_kwargs = {}
220        model.block0.bias_module0 = FSDP(
221            model.block0.bias_module0,
222            process_group=group,
223            **fsdp_kwargs,
224        )
225        model.block0 = FSDP(model.block0, process_group=group, **fsdp_kwargs)
226        return model
227
228    @staticmethod
229    def wrap_with_unmanaged_params(
230        model,
231        add_to_fsdp_module: bool,
232        group=None,
233    ) -> Tuple[torch.nn.Module, List[torch.nn.Parameter]]:
234        """Registers unmanaged parameters before wrapping with :meth:`wrap`."""
235        device = next(model.parameters()).device
236        unmanaged_param = torch.nn.Parameter(torch.randn(5, 5, device=device))
237        # Either register the parameter to a module to be wrapped with FSDP
238        # (`model.block2[2]`) or a module not to be wrapped with FSDP (`model`)
239        register_module = model.block2[2] if add_to_fsdp_module else model
240        register_module.register_parameter(
241            "unmanaged_param",
242            unmanaged_param,
243        )
244        # For simplicity, we only add a single unmanaged parameter, but should
245        # be easy to generalize if needed
246        return NestedModel.wrap(model, group), [unmanaged_param]
247
248    @staticmethod
249    def add_unmanaged_param_entry(osd, unmanaged_param, step) -> None:
250        """Adds an entry for the unmanaged parameter ``unmanaged_param``
251        assuming Adam optimizer and a single parameter group."""
252        # The unmanaged parameters should be passed to this method in
253        # `model.parameters()` order since their parameter IDs will be assigned
254        # in order of the skipped IDs
255        # Assign a parameter ID to the unmanaged parameter
256        unmanaged_param_id = -1
257        param_ids = osd["param_groups"][0]["params"]
258        for i in range(1, len(param_ids)):
259            diff = param_ids[i] - param_ids[i - 1]
260            if diff != 1:
261                assert diff > 1, f"Invalid IDs: {param_ids[i - 1]} {param_ids[i]}"
262                unmanaged_param_id = param_ids[i - 1] + 1
263                break
264        if unmanaged_param_id == -1:
265            unmanaged_param_id = len(param_ids)  # last ID skipped
266        assert unmanaged_param_id >= 0, "One parameter ID should be skipped"
267        # Add a state entry for the unmanaged parameter
268        state_device = next(iter(next(iter(osd["state"].values())).values())).device
269        osd["state"][unmanaged_param_id] = {
270            "step": torch.tensor(float(step), device=state_device),
271            "exp_avg": torch.randn(unmanaged_param.shape, device=state_device),
272            "exp_avg_sq": torch.randn(unmanaged_param.shape, device=state_device),
273        }
274        # Insert the ID into the parameter group in order
275        bisect.insort(osd["param_groups"][0]["params"], unmanaged_param_id)
276
277    # NOTE: We exclude `self.bias` from either parameter group to test the
278    # case where the optimizer input does not include all model parameters
279    def param_group0(self) -> List[torch.nn.Parameter]:
280        # Use `block1`'s parameters for the first parameter group to deviate
281        # from the `model.parameters()` order
282        return list(self.block1.parameters())
283
284    def param_group1(self) -> List[torch.nn.Parameter]:
285        # Deviate from the `model.parameters()` order further by rearranging
286        # `block2`'s parameters to be before `block0`'s parameters
287        return list(self.block2.parameters()) + list(self.block0.parameters())
288
289
290# Simple and boring model to test interface and some corner cases that do not
291# require complicated wrapping strategy.
292class TestDummyModel(torch.nn.Module):
293    def __init__(self, no_grad: bool = False):
294        super().__init__()
295        torch.manual_seed(0)
296        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
297        self.net1[0].weight.requires_grad = not no_grad
298        self.net1[0].bias.requires_grad = not no_grad
299        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
300        self.net3 = nn.Linear(32, 64)
301        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))
302
303    def forward(self, x):
304        return self.net4(self.net3(self.net2(self.net1(x))))
305
306    def get_input(self):
307        return torch.rand(8, 8, device="cuda")
308
309
310class TestFSDPOptimState(FSDPTest):
311    def __init__(self, *args, **kwargs):
312        super().__init__(*args, **kwargs)
313        self._model_class = {
314            _ModelClass.NESTED: self._init_nested_model,
315            _ModelClass.TRANSFORMER: self._init_transformer_model,
316        }
317
318    def _init_nested_model(
319        self,
320        wrap: bool,
321        wrap_alt: bool = False,  # ignored if `wrap=False`
322        device: torch.device = torch.device("cuda"),
323        group=None,
324        optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
325        use_multiple_param_groups: bool = False,
326        use_diff_optim_inputs: bool = False,
327        fsdp_kwargs: Optional[Dict[str, Any]] = None,
328    ):
329        model = NestedModel().to(device)
330        if wrap:
331            model = (
332                NestedModel.wrap_alt(model, group, fsdp_kwargs)
333                if wrap_alt
334                else NestedModel.wrap(model, group, fsdp_kwargs=fsdp_kwargs)
335            )
336        if not use_multiple_param_groups:
337            optim_input = list(model.parameters())
338        else:
339            optim_input = [
340                {"params": model.param_group0()},
341                {"params": model.param_group1(), "weight_decay": 0.9},
342            ]
343        # Use a reversed parameter order for the optimizer input on odd ranks
344        if use_diff_optim_inputs and self.rank % 2 == 1:
345            if isinstance(optim_input[0], dict):
346                for param_group in optim_input:
347                    param_group["params"] = list(reversed(param_group["params"]))
348            else:
349                optim_input = list(reversed(optim_input))
350        optim = optim_class(optim_input, lr=0.01)
351        return model, optim, optim_input
352
353    def _init_transformer_model(
354        self,
355        wrap: bool,
356        device: torch.device = torch.device("cuda"),
357        group=None,
358        optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
359        use_multiple_param_groups: bool = False,
360        use_diff_optim_inputs: bool = False,
361    ):
362        if use_multiple_param_groups or use_diff_optim_inputs:
363            # Keep these as arguments for parity with `_init_nested_model()`;
364            # these settings are not implemented since the transformer is
365            # wrapped with FSDP at the top-level, which means that there is
366            # only a single flat parameter, making these booleans vacuous
367            raise NotImplementedError
368        if group is None:
369            group = dist.distributed_c10d._get_default_group()
370        model = TransformerWithSharedParams.init(
371            group,
372            FSDPInitMode.RECURSIVE if wrap else FSDPInitMode.NO_FSDP,
373            CUDAInitMode.CUDA_BEFORE,
374            deterministic=True,
375        )
376        optim = optim_class(model.parameters(), lr=0.01)
377        return model, optim, None
378
379    def _step_model(
380        self,
381        model: torch.nn.Module,
382        optim: torch.optim.Optimizer,
383        device: torch.device = torch.device("cuda"),
384        num_iters: int = 1,
385    ) -> List[float]:
386        """Performs a forward pass, backward pass, and optimizer step
387        ``num_iters``-many times, and returns the per-iteration losses."""
388        torch.manual_seed(0)  # set seed for determinism
389        losses = []
390        module = getattr(model, "module", model)
391        for _ in range(num_iters):
392            optim.zero_grad()
393            inp = module.get_input(device)
394            output = model(*inp)
395            loss = module.get_loss(inp, output).to(device)
396            losses.append(loss.item())
397            module.run_backward(loss)
398            optim.step()
399        return losses
400
401    def _broadcast_full_osd(self, full_osd: Dict[str, Any], group=None):
402        """Broadcasts the full optimizer state dict in place of using
403        ``torch.save()`` and ``torch.load()`` so that all ranks can have it."""
404        obj_list = [full_osd]
405        dist.broadcast_object_list(
406            obj_list,
407            src=0,
408            group=group,
409        )
410        full_osd = obj_list[0]
411        return full_osd
412
413    def _are_equal_states(
414        self,
415        state1: Dict[str, Any],
416        state2: Dict[str, Any],
417    ) -> bool:
418        """Checks if ``state1`` and ``state2`` contain the same mappings."""
419        if set(state1.keys()) != set(state2.keys()):
420            return False
421        for state_name, value1 in state1.items():
422            value2 = state2[state_name]
423            if type(value1) != type(value2):
424                return False
425            if torch.is_tensor(value1):  # tensor state
426                assert torch.is_tensor(value2)
427                # Check the values on CPU to be device-agnostic
428                value1 = value1.cpu()
429                value2 = value2.cpu()
430                if value1.shape != value2.shape or not torch.all(
431                    torch.isclose(value1, value2)
432                ):
433                    return False
434            else:  # non-tensor state
435                if value1 != value2:
436                    return False
437        return True
438
439    def _check_same_state(
440        self,
441        fsdp_osd,
442        ref_osd,
443        check_same_param_keys: bool,
444    ):
445        """Checks that ``full_osd`` and ``ref_osd`` have the same "state" part.
446        If ``check_same_param_keys=True``, then checks that the parameter keys
447        match (e.g. when both should be parameter names), and does not check
448        the parameter keys otherwise."""
449        assert "state" in ref_osd
450        self.assertTrue("state" in fsdp_osd)
451        ref_osd_state = ref_osd["state"]
452        fsdp_osd_state = {
453            k: _gather_state_dict(v) for k, v in fsdp_osd["state"].items()
454        }
455
456        if check_same_param_keys:
457            # Check parameter keys are the same first for earlier erroring
458            ref_osd_param_ids = set(ref_osd_state.keys())
459            fsdp_osd_param_ids = set(fsdp_osd_state.keys())
460            self.assertTrue(
461                ref_osd_param_ids == fsdp_osd_param_ids,
462                f"Rank {self.rank}: {(ref_osd_param_ids, fsdp_osd_param_ids)}",
463            )
464            # Check state values are the same
465            for param_id, param_state in fsdp_osd_state.items():
466                for state_name, value in param_state.items():
467                    ref_value = ref_osd_state[param_id][state_name]
468                    self.assertEqual(value, ref_value)
469            return
470        # Otherwise, only require the parameter keys to be isomorphic (e.g.
471        # between IDs and names)
472        ref_osd_states = list(ref_osd_state.values())
473        fsdp_osd_states = list(fsdp_osd_state.values())
474        self.assertEqual(len(ref_osd_states), len(fsdp_osd_states))
475        # Use brute-force quadratic-time comparison since it is hard to
476        # hash a tensor by value instead of by object
477        for fsdp_osd_state in fsdp_osd_states:
478            # Check for at least one match (may be > 1 in toy edge cases, e.g.
479            # multiple biases); nonetheless, each having >= 1 match and the two
480            # lists having equal length imply that the list contents are equal
481            self.assertTrue(
482                any(
483                    self._are_equal_states(fsdp_osd_state, ref_osd_state)
484                    for ref_osd_state in ref_osd_states
485                )
486            )
487
488    def _check_same_param_groups(
489        self,
490        full_osd,
491        ref_osd,
492        check_same_param_keys: bool,
493    ):
494        """Checks that ``full_osd`` and ``ref_osd`` have the same
495        "param_groups" part. If ``check_same_param_keys=True`, then checks that
496        the parameter keys match (e.g. when both should be parameter names),
497        and does not check the parameter keys otherwise."""
498        assert "param_groups" in ref_osd
499        self.assertTrue("param_groups" in full_osd)
500        ref_osd_param_groups = ref_osd["param_groups"]
501        full_osd_param_groups = full_osd["param_groups"]
502        self.assertTrue(len(full_osd_param_groups), len(ref_osd_param_groups))
503        for full_osd_pg, ref_osd_pg in zip(
504            full_osd_param_groups,
505            ref_osd_param_groups,
506        ):
507            self.assertEqual(
508                set(full_osd_pg.keys()),
509                set(ref_osd_pg.keys()),
510            )
511            for name, full_osd_value in full_osd_pg.items():
512                if name == "params" and not check_same_param_keys:
513                    continue
514                self.assertEqual(full_osd_value, ref_osd_pg[name])
515
516    @skip_if_lt_x_gpu(2)
517    @parametrize("state_dict_type", STATE_DICT_TYPES)
518    @parametrize("use_multiple_param_groups", [False, True])
519    @parametrize("rank0_only", [False, True])
520    @parametrize("use_diff_optim_inputs", [False, True])
521    def test_optim_state_dict_nested(
522        self,
523        state_dict_type: StateDictType,
524        use_multiple_param_groups: bool,
525        rank0_only: bool,
526        use_diff_optim_inputs: bool,
527    ) -> None:
528        """
529        Tests :meth:`full_optim_state_dict` and meth:`sharded_optim_state_dict`
530        by comparing the returned dict for an FSDP-wrapped model with that of
531        an equivalent non-wrapped model.
532
533        The test checks the equivalence excluding the parameter keys since the
534        FSDP and normal optimizer state dicts key by names and IDs,
535        respectively. This means that the test can pass even if parameter keys
536        are incorrectly mapped to values. Their correct mapping is tested in
537        other tests that exercise the save/load workflow.
538        """
539        self.run_subtests(
540            {"use_optim_input": [False, True]},
541            self._test_optim_state_dict_nested,
542            state_dict_type=state_dict_type,
543            use_multiple_param_groups=use_multiple_param_groups,
544            rank0_only=rank0_only,
545            use_diff_optim_inputs=use_diff_optim_inputs,
546        )
547
548    def _test_optim_state_dict_nested(
549        self,
550        state_dict_type: StateDictType,
551        use_multiple_param_groups: bool,
552        rank0_only: bool,
553        use_diff_optim_inputs: bool,
554        use_optim_input: bool,
555    ) -> None:
556        if rank0_only and state_dict_type == StateDictType.SHARDED_STATE_DICT:
557            return  # not supported
558        NUM_ITERS = 3
559        model1, optim1, optim_input = self._init_nested_model(
560            wrap=True,
561            use_multiple_param_groups=use_multiple_param_groups,
562            use_diff_optim_inputs=use_diff_optim_inputs,
563        )
564        losses1 = self._step_model(model1, optim1, num_iters=NUM_ITERS)
565        if state_dict_type == StateDictType.FULL_STATE_DICT:
566            if use_optim_input:
567                fsdp_osd = FSDP.full_optim_state_dict(
568                    model1,
569                    optim1,
570                    optim_input,
571                    rank0_only=rank0_only,
572                )
573            else:
574                fsdp_osd = FSDP.full_optim_state_dict(
575                    model1,
576                    optim1,
577                    rank0_only=rank0_only,
578                )
579        else:
580            fsdp_osd = FSDP.sharded_optim_state_dict(model1, optim1)
581        # Non-target ranks get an empty state dict
582        if rank0_only and self.rank != 0:
583            self.assertEqual(len(fsdp_osd), 0)
584            return
585        model2, optim2, _ = self._init_nested_model(
586            wrap=False,
587            use_multiple_param_groups=use_multiple_param_groups,
588            use_diff_optim_inputs=use_diff_optim_inputs,
589        )
590        losses2 = self._step_model(model2, optim2, num_iters=NUM_ITERS)
591        ref_osd = optim2.state_dict()
592        # Check the losses to eliminate model drift as a source of error
593        for i, (l1, l2) in enumerate(zip(losses1, losses2)):
594            assert l1 == l2, f"Losses differ on iter {i}: {l1:.5f} {l2:.5f}"
595        # Do not check the parameter keys since the full/sharded optimizer state
596        # dict uses parameter names, while the non-wrapped equivalent uses
597        # parameter IDs
598        check_same_param_keys = False
599        self._check_same_param_groups(
600            fsdp_osd,
601            ref_osd,
602            check_same_param_keys=check_same_param_keys,
603        )
604        self._check_same_state(
605            fsdp_osd,
606            ref_osd,
607            check_same_param_keys=check_same_param_keys,
608        )
609
610    @skip_if_lt_x_gpu(2)
611    def test_full_optim_state_dict_keys(self):
612        """Tests that the parameter keys returned by
613        :meth:`full_optim_state_dict` match those of :meth:`state_dict` with
614        full ``state_dict_type`` for a non-FSDP-root model with nested FSDP
615        instances and ignored modules."""
616        device = torch.device("cuda")
617        model = NestedModel().to(device)
618        wrapped_model = NestedModel.wrap(model, ignore_modules=True)
619        # Add checkpointing to ensure optim_state_dict and state_dict strip out
620        # checkpointing prefixes.
621        apply_activation_checkpointing(
622            model, check_fn=lambda module: isinstance(module, torch.nn.Sequential)
623        )
624        optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
625        self._step_model(model, optim, device)
626        optim_state_dict = FSDP.full_optim_state_dict(
627            wrapped_model, optim, rank0_only=False
628        )
629        with FSDP.state_dict_type(wrapped_model, StateDictType.FULL_STATE_DICT):
630            state_dict = wrapped_model.state_dict()
631        self.assertEqual(optim_state_dict["state"].keys(), state_dict.keys())
632        # Check that checkpointing prefix was indeed stripped.
633        for key in optim_state_dict["state"]:
634            self.assertNotIn(_CHECKPOINT_WRAPPED_MODULE, key)
635
636    @skip_if_lt_x_gpu(2)
637    def test_full_optim_state_dict_nested_invalid(self):
638        """Tests that :meth:`full_optim_state_dict` raises an error when
639        nonzero ranks are missing the optimizer state for parameters on rank
640        0."""
641        device = torch.device("cuda")
642        model = NestedModel.wrap(NestedModel().to(device), None)
643        optim_input = list(model.parameters())
644        if self.rank != 0:
645            # Exclude a parameter so that nonzero ranks are missing state
646            optim_input = optim_input[:-1]
647        optim = torch.optim.Adam(optim_input, lr=1e-3)
648        self._step_model(model, optim, num_iters=3)
649        error_regex = (
650            "FSDP currently requires each rank to have at least the "
651            "optimizer states needed by rank 0's optimizer but some ranks "
652            "are missing some of those states"
653        )
654        with self.assertRaisesRegex(RuntimeError, error_regex):
655            FSDP.full_optim_state_dict(model, optim)
656
657    @skip_if_lt_x_gpu(2)
658    @parametrize("use_multiple_param_groups", [False, True])
659    @parametrize("wrap_alt", [False, True])
660    @parametrize("use_diff_optim_inputs", [False, True])
661    def test_shard_full_optim_state_dict_nested(
662        self,
663        use_multiple_param_groups: bool,
664        wrap_alt: bool,
665        use_diff_optim_inputs: bool,
666    ):
667        """Tests :meth:`shard_full_optim_state_dict` for a non-FSDP-root model
668        with nested FSDP instances."""
669        self.run_subtests(
670            {"use_optim_input": [False, True]},
671            self._test_load_optim_state,
672            model_class=_ModelClass.NESTED,
673            use_multiple_param_groups=use_multiple_param_groups,
674            halve_world_size=False,
675            osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST,
676            use_diff_optim_inputs=use_diff_optim_inputs,
677            wrap_alt=wrap_alt,
678            num_iters=3,
679        )
680
681        self._test_load_optim_state_with_optim_state_dict(
682            _ModelClass.NESTED,
683            state_dict_settings=StateDictSettings(
684                StateDictType.FULL_STATE_DICT,
685                FullStateDictConfig(),
686                FullOptimStateDictConfig(),
687            ),
688            use_multiple_param_groups=False,
689            halve_world_size=False,
690            use_diff_optim_inputs=use_diff_optim_inputs,
691            wrap_alt=wrap_alt,
692            num_iters=3,
693        )
694
695    @skip_if_lt_x_gpu(2)
696    def test_shard_full_optim_state_dict_nested_halve_world_size(self):
697        """Tests :meth:`shard_full_optim_state_dict` for a non-FSDP-root model
698        with nested FSDP instances when loading into a new process group with
699        halved world size."""
700        # To save CI costs, we test with the "harder" settings:
701        use_multiple_param_groups = True
702        use_diff_optim_inputs = True
703        wrap_alt = True
704        self.run_subtests(
705            {"use_optim_input": [False, True]},
706            self._test_load_optim_state,
707            model_class=_ModelClass.NESTED,
708            use_multiple_param_groups=use_multiple_param_groups,
709            halve_world_size=True,
710            osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST,
711            use_diff_optim_inputs=use_diff_optim_inputs,
712            wrap_alt=wrap_alt,
713            num_iters=3,
714        )
715
716        self._test_load_optim_state_with_optim_state_dict(
717            _ModelClass.NESTED,
718            state_dict_settings=StateDictSettings(
719                StateDictType.FULL_STATE_DICT,
720                FullStateDictConfig(),
721                FullOptimStateDictConfig(),
722            ),
723            use_multiple_param_groups=use_multiple_param_groups,
724            halve_world_size=True,
725            use_diff_optim_inputs=use_diff_optim_inputs,
726            wrap_alt=wrap_alt,
727            num_iters=3,
728        )
729
730    @skip_if_lt_x_gpu(2)
731    def test_shard_full_optim_state_dict_transformer(self) -> None:
732        """Tests :meth:`shard_full_optim_state_dict` for an FSDP-root
733        transformer model with shared parameters."""
734        self.run_subtests(
735            {"use_optim_input": [False, True]},
736            self._test_load_optim_state,
737            model_class=_ModelClass.TRANSFORMER,
738            use_multiple_param_groups=False,
739            halve_world_size=True,
740            osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST,
741            use_diff_optim_inputs=False,
742            num_iters=3,
743        )
744
745        self._test_load_optim_state_with_optim_state_dict(
746            _ModelClass.TRANSFORMER,
747            state_dict_settings=StateDictSettings(
748                StateDictType.FULL_STATE_DICT,
749                FullStateDictConfig(),
750                FullOptimStateDictConfig(),
751            ),
752            use_multiple_param_groups=False,
753            halve_world_size=True,
754            use_diff_optim_inputs=False,
755            num_iters=3,
756        )
757
758    @skip_if_lt_x_gpu(2)
759    @parametrize("use_multiple_param_groups", [False, True])
760    @parametrize("wrap_alt", [False, True])
761    @parametrize("use_diff_optim_inputs", [False, True])
762    def test_scatter_full_optim_state_dict_nested(
763        self,
764        use_multiple_param_groups: bool,
765        wrap_alt: bool,
766        use_diff_optim_inputs: bool,
767    ):
768        """Tests :meth:`scatter_full_optim_state_dict` for a non-FSDP-root
769        model with nested FSDP instances."""
770        self.run_subtests(
771            {"use_optim_input": [False, True]},
772            self._test_load_optim_state,
773            model_class=_ModelClass.NESTED,
774            use_multiple_param_groups=use_multiple_param_groups,
775            halve_world_size=False,
776            osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD,
777            use_diff_optim_inputs=use_diff_optim_inputs,
778            wrap_alt=wrap_alt,
779            num_iters=3,
780        )
781
782        self._test_load_optim_state_with_optim_state_dict(
783            _ModelClass.NESTED,
784            state_dict_settings=StateDictSettings(
785                StateDictType.FULL_STATE_DICT,
786                FullStateDictConfig(),
787                FullOptimStateDictConfig(rank0_only=True),
788            ),
789            use_multiple_param_groups=use_multiple_param_groups,
790            halve_world_size=False,
791            use_diff_optim_inputs=use_diff_optim_inputs,
792            wrap_alt=wrap_alt,
793            num_iters=3,
794        )
795
796    @skip_if_lt_x_gpu(2)
797    def test_scatter_full_optim_state_dict_nested_halve_world_size(self):
798        """Tests :meth:`scatter_full_optim_state_dict` for a non-FSDP-root
799        model with nested FSDP instances when loading into a new process group
800        with halved world size."""
801        # To save CI costs, we test with the "harder" settings:
802        use_multiple_param_groups = True
803        use_diff_optim_inputs = True
804        wrap_alt = True
805        self.run_subtests(
806            {"use_optim_input": [False, True]},
807            self._test_load_optim_state,
808            model_class=_ModelClass.NESTED,
809            use_multiple_param_groups=use_multiple_param_groups,
810            halve_world_size=True,
811            osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD,
812            use_diff_optim_inputs=use_diff_optim_inputs,
813            wrap_alt=wrap_alt,
814            num_iters=3,
815        )
816
817        self._test_load_optim_state_with_optim_state_dict(
818            _ModelClass.NESTED,
819            state_dict_settings=StateDictSettings(
820                StateDictType.FULL_STATE_DICT,
821                FullStateDictConfig(),
822                FullOptimStateDictConfig(rank0_only=True),
823            ),
824            use_multiple_param_groups=use_multiple_param_groups,
825            halve_world_size=True,
826            use_diff_optim_inputs=use_diff_optim_inputs,
827            wrap_alt=wrap_alt,
828            num_iters=3,
829        )
830
831    @skip_if_lt_x_gpu(2)
832    def test_scatter_full_optim_state_dict_transformer(self) -> None:
833        """Tests :meth:`scatter_full_optim_state_dict` for an FSDP-root
834        transformer model with shared parameters."""
835        self.run_subtests(
836            {"use_optim_input": [False, True]},
837            self._test_load_optim_state,
838            model_class=_ModelClass.TRANSFORMER,
839            use_multiple_param_groups=False,
840            halve_world_size=True,
841            osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD,
842            use_diff_optim_inputs=False,
843            num_iters=3,
844        )
845
846        self._test_load_optim_state_with_optim_state_dict(
847            _ModelClass.TRANSFORMER,
848            state_dict_settings=StateDictSettings(
849                StateDictType.FULL_STATE_DICT,
850                FullStateDictConfig(),
851                FullOptimStateDictConfig(rank0_only=True),
852            ),
853            use_multiple_param_groups=False,
854            halve_world_size=True,
855            use_diff_optim_inputs=False,
856            num_iters=3,
857        )
858
859    @skip_if_lt_x_gpu(2)
860    def test_flatten_sharded_optim_state_dict_nested(self) -> None:
861        """Tests :meth:`flatten_sharded_optim_state_dict` for an FSDP-root
862        nested model."""
863        self._test_load_optim_state(
864            _ModelClass.NESTED,
865            use_multiple_param_groups=False,
866            halve_world_size=False,
867            osd_comm_method=_OSDCommMethod.FLATTEN_SHARDED_OSD,
868            use_diff_optim_inputs=False,
869            use_optim_input=False,
870            wrap_alt=True,
871            num_iters=3,
872        )
873
874        self._test_load_optim_state_with_optim_state_dict(
875            _ModelClass.NESTED,
876            state_dict_settings=StateDictSettings(
877                StateDictType.SHARDED_STATE_DICT,
878                ShardedStateDictConfig(),
879                ShardedOptimStateDictConfig(),
880            ),
881            use_multiple_param_groups=False,
882            halve_world_size=False,
883            use_diff_optim_inputs=False,
884            wrap_alt=True,
885            num_iters=3,
886        )
887
888    @skip_if_lt_x_gpu(2)
889    def test_flatten_sharded_optim_state_dict_transformer(self) -> None:
890        """Tests :meth:`flatten_sharded_optim_state_dict` for an FSDP-root
891        transformer model."""
892        self._test_load_optim_state(
893            _ModelClass.TRANSFORMER,
894            use_multiple_param_groups=False,
895            halve_world_size=False,
896            osd_comm_method=_OSDCommMethod.FLATTEN_SHARDED_OSD,
897            use_diff_optim_inputs=False,
898            use_optim_input=False,
899            num_iters=3,
900        )
901
902        self._test_load_optim_state_with_optim_state_dict(
903            _ModelClass.TRANSFORMER,
904            state_dict_settings=StateDictSettings(
905                StateDictType.SHARDED_STATE_DICT,
906                ShardedStateDictConfig(),
907                ShardedOptimStateDictConfig(),
908            ),
909            use_multiple_param_groups=False,
910            halve_world_size=False,
911            use_diff_optim_inputs=False,
912            num_iters=3,
913        )
914
915    @skip_if_lt_x_gpu(2)
916    def test_use_orig_params(self) -> None:
917        """Tests :meth:`optim_state_dict` for an FSDP-root nested model."""
918        self.run_subtests(
919            {
920                "halve_world_size": [True, False],
921                "wrap_alt": [True, False],
922            },
923            self._test_load_optim_state_with_optim_state_dict,
924            model_class=_ModelClass.NESTED,
925            state_dict_settings=StateDictSettings(
926                StateDictType.FULL_STATE_DICT,
927                FullStateDictConfig(),
928                FullOptimStateDictConfig(),
929            ),
930            use_multiple_param_groups=False,
931            use_diff_optim_inputs=False,
932            num_iters=3,
933            fsdp_kwargs={"use_orig_params": True},
934        )
935
936        self.run_subtests(
937            {
938                "halve_world_size": [True, False],
939                "wrap_alt": [True, False],
940            },
941            self._test_load_optim_state_with_optim_state_dict,
942            model_class=_ModelClass.NESTED,
943            state_dict_settings=StateDictSettings(
944                StateDictType.FULL_STATE_DICT,
945                FullStateDictConfig(),
946                FullOptimStateDictConfig(rank0_only=True),
947            ),
948            use_multiple_param_groups=False,
949            use_diff_optim_inputs=False,
950            num_iters=3,
951            fsdp_kwargs={"use_orig_params": True},
952        )
953
954        self.run_subtests(
955            {
956                "wrap_alt": [True, False],
957            },
958            self._test_load_optim_state_with_optim_state_dict,
959            model_class=_ModelClass.NESTED,
960            state_dict_settings=StateDictSettings(
961                StateDictType.SHARDED_STATE_DICT,
962                ShardedStateDictConfig(),
963                ShardedOptimStateDictConfig(),
964            ),
965            use_multiple_param_groups=False,
966            # We cannot test halve_world_size with SHARDED_STATE_DICT.
967            halve_world_size=False,
968            use_diff_optim_inputs=False,
969            num_iters=3,
970            fsdp_kwargs={"use_orig_params": True},
971        )
972
973    def _test_load_optim_state(
974        self,
975        model_class: _ModelClass,
976        use_multiple_param_groups: bool,
977        halve_world_size: bool,
978        osd_comm_method: _OSDCommMethod,
979        use_diff_optim_inputs: bool,
980        use_optim_input: bool,
981        num_iters: int,
982        **new_model_kwargs,
983    ):
984        """
985        (1) Runs a model with full world size for K iterations to generate a
986        full/sharded optimizer state dict;
987        (2) initializes a model with halved world size and possibly different
988        FSDP wrapping scheme (based on ``new_model_kwargs``);
989        (3) loads the full/sharded optimizer state dict from (1) according to the
990        halved-world-size model;
991        (4) runs the halved-world-size model for K iterations; and
992        (5) checks that the sharded optimizer state dict from (3) matches the
993        halved-world-size model's local optimizer state dict, meaning that the
994        former could have equivalently been loaded into the local optimizer.
995        """
996        initializer = self._model_class[model_class]
997        if osd_comm_method == _OSDCommMethod.OPTIM_STATE_DICT:
998            osd_method = FSDP.optim_state_dict
999        elif osd_comm_method == _OSDCommMethod.FLATTEN_SHARDED_OSD:
1000            osd_method = FSDP.sharded_optim_state_dict
1001        else:
1002            osd_method = FSDP.full_optim_state_dict
1003
1004        # First, run a wrapped model with full world size for a few iterations
1005        model1, optim1, optim_input1 = initializer(
1006            wrap=True,
1007            use_multiple_param_groups=use_multiple_param_groups,
1008        )
1009        self._step_model(model1, optim1, num_iters=num_iters)
1010        fsdp_osd1 = (
1011            osd_method(model1, optim1, optim_input1)
1012            if use_optim_input
1013            else osd_method(model1, optim1)
1014        )
1015        if halve_world_size:
1016            # Create a new process group with halved world size
1017            new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0]
1018            new_group = dist.new_group(ranks=new_group_ranks)
1019            if self.rank not in new_group_ranks:
1020                return
1021        else:
1022            # Continue using the same group and hence world size
1023            new_group = dist.distributed_c10d._get_default_group()
1024        # Second, run a wrapped model with (possibly) halved world size and
1025        # (possibly) differing `optim_input` across ranks
1026        model2, optim2, optim_input2 = initializer(
1027            wrap=True,
1028            group=new_group,
1029            use_multiple_param_groups=use_multiple_param_groups,
1030            use_diff_optim_inputs=use_diff_optim_inputs,
1031            **new_model_kwargs,  # specify `wrap_alt` to change wrapping
1032        )
1033        self._step_model(model2, optim2, num_iters=num_iters)
1034        fsdp_osd2 = (
1035            osd_method(model2, optim2, optim_input2, group=new_group)
1036            if use_optim_input
1037            else osd_method(model2, optim2, group=new_group)
1038        )
1039        # Compute two sharded optim state dicts: (1) for the first model
1040        # according to the second model and (2) for the second model according
1041        # to the second model
1042        if osd_comm_method == _OSDCommMethod.BROADCAST_OBJECT_LIST:
1043            fsdp_osd1 = self._broadcast_full_osd(fsdp_osd1, group=new_group)
1044            sharded_osd1 = (
1045                FSDP.shard_full_optim_state_dict(
1046                    fsdp_osd1, model2, optim_input=optim_input2
1047                )
1048                if use_optim_input
1049                else FSDP.shard_full_optim_state_dict(fsdp_osd1, model2, optim=optim2)
1050            )
1051            fsdp_osd2 = self._broadcast_full_osd(fsdp_osd2, group=new_group)
1052            sharded_osd2 = (
1053                FSDP.shard_full_optim_state_dict(
1054                    fsdp_osd2, model2, optim_input=optim_input2
1055                )
1056                if use_optim_input
1057                else FSDP.shard_full_optim_state_dict(fsdp_osd2, model2, optim=optim2)
1058            )
1059        elif osd_comm_method == _OSDCommMethod.SCATTER_FULL_OSD:
1060            sharded_osd1 = (
1061                FSDP.scatter_full_optim_state_dict(
1062                    fsdp_osd1 if self.rank == 0 else None,
1063                    model2,
1064                    optim_input=optim_input2,
1065                    group=new_group,
1066                )
1067                if use_optim_input
1068                else FSDP.scatter_full_optim_state_dict(
1069                    fsdp_osd1 if self.rank == 0 else None,
1070                    model2,
1071                    optim=optim2,
1072                    group=new_group,
1073                )
1074            )
1075            sharded_osd2 = (
1076                FSDP.scatter_full_optim_state_dict(
1077                    fsdp_osd2 if self.rank == 0 else None,
1078                    model2,
1079                    optim_input=optim_input2,
1080                    group=new_group,
1081                )
1082                if use_optim_input
1083                else FSDP.scatter_full_optim_state_dict(
1084                    fsdp_osd2 if self.rank == 0 else None,
1085                    model2,
1086                    optim=optim2,
1087                    group=new_group,
1088                )
1089            )
1090        elif osd_comm_method == _OSDCommMethod.FLATTEN_SHARDED_OSD:
1091            sharded_osd1 = FSDP.flatten_sharded_optim_state_dict(
1092                fsdp_osd1,
1093                model2,
1094                optim=optim2,
1095            )
1096            sharded_osd2 = FSDP.flatten_sharded_optim_state_dict(
1097                fsdp_osd2,
1098                model2,
1099                optim=optim2,
1100            )
1101        elif osd_comm_method == _OSDCommMethod.OPTIM_STATE_DICT:
1102            sharded_osd1 = FSDP.optim_state_dict_to_load(model2, optim2, fsdp_osd1)
1103            sharded_osd2 = FSDP.optim_state_dict_to_load(model2, optim2, fsdp_osd2)
1104
1105        # As a sanity check, check that sharding the second model's full/sharded
1106        # optimizer state dict according to itself is equivalent to its local
1107        # optimizer's state dict
1108        local_osd2 = optim2.state_dict()
1109        check_same_param_keys = True  # should all have matching parameter IDs
1110        self._check_same_param_groups(
1111            sharded_osd2,
1112            local_osd2,
1113            check_same_param_keys=check_same_param_keys,
1114        )
1115        self._check_same_state(
1116            sharded_osd2,
1117            local_osd2,
1118            check_same_param_keys=check_same_param_keys,
1119        )
1120        # Check that sharding the first model's full/sharded optimizer state dict
1121        # according to the second model is equivalent to the second model's
1122        # local optimizer state dict
1123        self._check_same_param_groups(
1124            sharded_osd1,
1125            local_osd2,
1126            check_same_param_keys=check_same_param_keys,
1127        )
1128        self._check_same_state(
1129            sharded_osd1,
1130            local_osd2,
1131            check_same_param_keys=check_same_param_keys,
1132        )
1133        # As a sanity check, check that we can load and run a few iterations
1134        optim2.load_state_dict(sharded_osd2)
1135        self._step_model(model2, optim2, num_iters=num_iters)
1136
1137    @skip_if_lt_x_gpu(2)
1138    @parametrize("state_dict_type", STATE_DICT_TYPES)
1139    @parametrize("add_to_fsdp_module", [False, True])
1140    def test_shard_full_optim_state_dict_unmanaged_params(
1141        self,
1142        state_dict_type: StateDictType,
1143        add_to_fsdp_module: bool,
1144    ):
1145        """
1146        Tests :meth:`shard_full_optim_state_dict` when there are unmanaged
1147        parameters.
1148          - If ``add_to_fsdp_module=True``, then the unmanaged parameters are
1149          added to a module to be wrapped with FSDP, in which case there should
1150          be an error since we require that all unflattened parameter
1151          comprising a flat parameter have the same scalar state (e.g. Adam
1152          "step") but the added parameter is missing its entry.
1153          - If ``add_to_fsdp_module=False``, then the unmanaged parameters are
1154          added to a module not to be wrapped with FSDP, in which case there
1155          should be no error (emulating model parallel use cases where some
1156          parameters may be managed externally to FSDP).
1157        We do not separately test unmanaged parameters for
1158        :meth:`scatter_full_optim_state_dict` and `flatten_sharded_optim_state_dict`
1159        to save CI cost since it call into the same subroutine
1160        :meth:`_flatten_optim_state_dict`.
1161        """
1162        if state_dict_type == StateDictType.SHARDED_STATE_DICT:
1163            use_optim_input = [False]
1164        else:
1165            use_optim_input = [False, True]
1166        self.run_subtests(
1167            {"use_optim_input": use_optim_input},
1168            self._test_shard_full_optim_state_dict_unmanaged_params,
1169            state_dict_type=state_dict_type,
1170            add_to_fsdp_module=add_to_fsdp_module,
1171        )
1172
1173    def _test_shard_full_optim_state_dict_unmanaged_params(
1174        self,
1175        state_dict_type: StateDictType,
1176        add_to_fsdp_module: bool,
1177        use_optim_input: bool,
1178    ):
1179        NUM_ITERS = 1
1180        # Create a normal wrapped model
1181        model, optim, optim_input = self._init_nested_model(wrap=True)
1182        self._step_model(model, optim, num_iters=NUM_ITERS)
1183
1184        if state_dict_type == StateDictType.FULL_STATE_DICT:
1185            fsdp_osd = (
1186                FSDP.full_optim_state_dict(model, optim, optim_input, rank0_only=False)
1187                if use_optim_input
1188                else FSDP.full_optim_state_dict(model, optim, rank0_only=False)
1189            )  # save on all ranks to avoid having to broadcast from rank 0
1190        else:
1191            fsdp_osd = FSDP.sharded_optim_state_dict(model, optim)
1192        # Create a new model with the same structure but additional unmanaged
1193        # parameters, representing the model for which we want to load
1194        device = torch.device("cuda")
1195        model = NestedModel().to(device)
1196        model, unmanaged_params = NestedModel.wrap_with_unmanaged_params(
1197            model,
1198            add_to_fsdp_module,
1199        )
1200        optim_input = list(model.parameters())
1201        optim = torch.optim.Adam(optim_input, lr=1e-3)
1202        if add_to_fsdp_module:
1203            # If we add the unmanaged parameters to a module wrapped with FSDP,
1204            # then the flat parameter will be comprised of some unflattened
1205            # parameters with zero-dimensional tensor state (i.e. Adam "step")
1206            # and others without (i.e. the unmanaged parameters), which
1207            # triggers an error that we have to ensure correctness
1208            error_prefix = (
1209                "^(All unflattened parameters comprising a "
1210                "single flat parameter must have scalar state with the "
1211                "same value and dtype)"
1212            )
1213            with self.assertRaisesRegex(ValueError, error_prefix):
1214                if state_dict_type == StateDictType.FULL_STATE_DICT:
1215                    (
1216                        FSDP.shard_full_optim_state_dict(
1217                            fsdp_osd, model, optim_input=optim_input
1218                        )
1219                        if use_optim_input
1220                        else FSDP.shard_full_optim_state_dict(
1221                            fsdp_osd, model, optim=optim
1222                        )
1223                    )
1224                else:
1225                    FSDP.flatten_sharded_optim_state_dict(fsdp_osd, model, optim=optim)
1226        else:
1227            # If we add the unmanaged parameters to a module not wrapped with
1228            # FSDP, then we simply ignore them without erroring to enable
1229            # model parallelism use cases, where some parameters are managed
1230            # externally to FSDP
1231            if state_dict_type == StateDictType.FULL_STATE_DICT:
1232                flattened_osd = (
1233                    FSDP.shard_full_optim_state_dict(
1234                        fsdp_osd, model, optim_input=optim_input
1235                    )
1236                    if use_optim_input
1237                    else FSDP.shard_full_optim_state_dict(fsdp_osd, model, optim=optim)
1238                )
1239            else:
1240                flattened_osd = FSDP.flatten_sharded_optim_state_dict(
1241                    fsdp_osd, model, optim=optim
1242                )
1243            # Add entries for the unmanaged parameters to be able to load
1244            for unmanaged_param in unmanaged_params:
1245                NestedModel.add_unmanaged_param_entry(
1246                    flattened_osd,
1247                    unmanaged_param,
1248                    NUM_ITERS,
1249                )
1250            # Check that we can load the optimizer state dict
1251            optim.load_state_dict(flattened_osd)
1252
1253    @skip_if_lt_x_gpu(2)
1254    @parametrize("state_dict_type", STATE_DICT_TYPES)
1255    @parametrize("use_multiple_param_groups", [False, True])
1256    def test_rekey_optim_state_dict_to_ids(
1257        self,
1258        state_dict_type: StateDictType,
1259        use_multiple_param_groups: bool,
1260    ):
1261        """Tests :meth:`rekey_optim_state_dict` with the new keys being
1262        parameter IDs by checking that a wrapped model (i.e. with FSDP modules)
1263        can rekey its optimizer state dict to match that of an equivalent
1264        non-wrapped model (i.e. without FSDP modules)."""
1265        if state_dict_type == StateDictType.SHARDED_STATE_DICT:
1266            use_optim_input = [False]
1267        else:
1268            use_optim_input = [False, True]
1269        self.run_subtests(
1270            {"use_optim_input": use_optim_input},
1271            self._test_rekey_optim_state_dict_to_ids,
1272            state_dict_type=state_dict_type,
1273            use_multiple_param_groups=use_multiple_param_groups,
1274        )
1275
1276    @skip_if_lt_x_gpu(2)
1277    def _test_rekey_optim_state_dict_to_ids(
1278        self,
1279        state_dict_type: StateDictType,
1280        use_multiple_param_groups: bool,
1281        use_optim_input: bool,
1282    ):
1283        NUM_ITERS = 3
1284        # Run a wrapped model for a few iterations
1285        model1, optim1, optim_input1 = self._init_nested_model(
1286            wrap=True,
1287            use_multiple_param_groups=use_multiple_param_groups,
1288        )
1289        self._step_model(model1, optim1, num_iters=NUM_ITERS)
1290        if state_dict_type == StateDictType.FULL_STATE_DICT:
1291            fsdp_osd = (
1292                FSDP.full_optim_state_dict(model1, optim1, optim_input1)
1293                if use_optim_input
1294                else FSDP.full_optim_state_dict(model1, optim1)
1295            )
1296            # Broadcast instead of `torch.save()`/`torch.load()` so that all ranks
1297            # have the full state dict
1298            fsdp_osd = self._broadcast_full_osd(fsdp_osd)
1299        else:
1300            fsdp_osd = FSDP.sharded_optim_state_dict(model1, optim1)
1301        # Run a non-wrapped model for a few iterations
1302        model2, optim2, optim_input2 = self._init_nested_model(
1303            wrap=False,
1304            use_multiple_param_groups=use_multiple_param_groups,
1305        )
1306        self._step_model(model2, optim2, num_iters=NUM_ITERS)
1307        # Re-key the wrapped model's optimizer state dict using parameter IDs
1308        # according to the non-wrapped model
1309        rekeyed_osd = (
1310            FSDP.rekey_optim_state_dict(
1311                fsdp_osd,
1312                OptimStateKeyType.PARAM_ID,
1313                model2,
1314                optim_input=optim_input2,
1315            )
1316            if use_optim_input
1317            else FSDP.rekey_optim_state_dict(
1318                fsdp_osd,
1319                OptimStateKeyType.PARAM_ID,
1320                model2,
1321                optim=optim2,
1322            )
1323        )
1324        # Check that the re-keyed dict and actual dict are the same
1325        osd = optim2.state_dict()
1326        check_same_param_keys = True
1327        self._check_same_param_groups(
1328            rekeyed_osd,
1329            osd,
1330            check_same_param_keys=check_same_param_keys,
1331        )
1332        self._check_same_state(
1333            rekeyed_osd,
1334            osd,
1335            check_same_param_keys=check_same_param_keys,
1336        )
1337        # As a sanity check, check that we can load and run a few iterations
1338        if state_dict_type != StateDictType.SHARDED_STATE_DICT:
1339            optim2.load_state_dict(rekeyed_osd)
1340            self._step_model(model2, optim2, num_iters=NUM_ITERS)
1341
1342    @skip_if_lt_x_gpu(2)
1343    def test_rekey_optim_state_dict_to_names(self):
1344        """Tests :meth:`rekey_optim_state_dict` with the new keys being
1345        parameter names by checking that a non-wrapped model (i.e. without FSDP
1346        modules) can rekey its optimizer state dict to match the expected
1347        output of :meth:`full_optim_state_dict`, hence be sharded using
1348        :meth:`shard_full_optim_state_dict`, and finally match the per-rank
1349        optimizer state dict of a wrapped model (i.e. with FSDP modules)."""
1350        self.run_subtests(
1351            {"use_optim_input": [False, True]},
1352            self._test_rekey_optim_state_dict_to_names,
1353            use_multiple_param_groups=False,
1354        )
1355
1356    def _test_rekey_optim_state_dict_to_names(
1357        self,
1358        use_multiple_param_groups: bool,
1359        use_optim_input: bool,
1360    ):
1361        NUM_ITERS = 3
1362        # Run a wrapped model for a few iterations
1363        model1, optim1, optim_input1 = self._init_nested_model(
1364            wrap=True,
1365            use_multiple_param_groups=use_multiple_param_groups,
1366        )
1367        self._step_model(model1, optim1, num_iters=NUM_ITERS)
1368        # Run a non-wrapped model for a few iterations
1369        model2, optim2, optim_input2 = self._init_nested_model(
1370            wrap=False,
1371            use_multiple_param_groups=use_multiple_param_groups,
1372        )
1373        self._step_model(model2, optim2, num_iters=NUM_ITERS)
1374        # Re-key the non-wrapped model's optimizer state dict using parameter
1375        # names (still according to itself)
1376        osd2 = optim2.state_dict()
1377        rekeyed_osd = (
1378            FSDP.rekey_optim_state_dict(
1379                osd2,
1380                OptimStateKeyType.PARAM_NAME,
1381                model2,
1382                optim_input=optim_input2,
1383            )
1384            if use_optim_input
1385            else FSDP.rekey_optim_state_dict(
1386                osd2,
1387                OptimStateKeyType.PARAM_NAME,
1388                model2,
1389                optim=optim2,
1390            )
1391        )
1392        # Shard the non-wrapped model's re-keyed optimizer state dict, which
1393        # maps back to (flattened) parameter IDs
1394        sharded_osd = (
1395            FSDP.shard_full_optim_state_dict(
1396                rekeyed_osd,
1397                model1,
1398                optim_input=optim_input1,
1399            )
1400            if use_optim_input
1401            else FSDP.shard_full_optim_state_dict(
1402                rekeyed_osd,
1403                model1,
1404                optim=optim1,
1405            )
1406        )
1407        # Check that this sharded optimizer state dict matches the wrapped
1408        # model's per-rank optimizer state dict
1409        osd1 = optim1.state_dict()
1410        check_same_param_keys = True
1411        self._check_same_param_groups(
1412            sharded_osd,
1413            osd1,
1414            check_same_param_keys=check_same_param_keys,
1415        )
1416        self._check_same_state(
1417            sharded_osd,
1418            osd1,
1419            check_same_param_keys=check_same_param_keys,
1420        )
1421        # As a sanity check, check that we can load and run a few iterations
1422        optim1.load_state_dict(sharded_osd)
1423        self._step_model(model1, optim1, num_iters=NUM_ITERS)
1424
1425    @skip_if_lt_x_gpu(2)
1426    def test_optim_input_warning(self):
1427        """Tests that passing the ``optim_input`` argument into optimizer state
1428        checkpointing APIs issues a warning."""
1429
1430        def should_check_method(method_name: str):
1431            # Check every method since they all accept `optim_input`
1432            return method_name not in (
1433                "sharded_optim_state_dict",
1434                "flatten_sharded_optim_state_dict",
1435            )
1436
1437        def get_warning_context():
1438            warning_regex = "`optim_input` argument is deprecated"
1439            return self.assertWarnsRegex(
1440                expected_warning=FutureWarning, expected_regex=warning_regex
1441            )
1442
1443        self._run_on_all_optim_state_apis(
1444            should_check_method, get_warning_context, fsdp_kwargs=None
1445        )
1446
1447    def _run_on_all_optim_state_apis(
1448        self,
1449        should_check_method_fn: Callable[[str], bool],
1450        context_fn: Callable,
1451        fsdp_kwargs: Optional[Dict[str, Any]],
1452    ):
1453        """
1454        Runs through all optimizer state checkpointing APIs with a context
1455        manager instantiated by ``context_fn``. Certain APIs can be skipped
1456        via ``should_check_method_fn``, which gets passed the string name of
1457        the method.
1458        """
1459        wrapped_model, wrapped_optim, wrapped_optim_input = self._init_nested_model(
1460            wrap=True,
1461            use_multiple_param_groups=False,
1462            fsdp_kwargs=fsdp_kwargs,
1463        )
1464        self._step_model(wrapped_model, wrapped_optim, num_iters=2)
1465
1466        # Sharded optim state dict
1467        if should_check_method_fn("sharded_optim_state_dict"):
1468            with context_fn():
1469                fsdp_osd = FSDP.sharded_optim_state_dict(
1470                    wrapped_model,
1471                    wrapped_optim,
1472                )
1473        if "fsdp_osd" not in locals():
1474            fsdp_osd = {}  # may not be defined due to previous method erroring
1475        if should_check_method_fn("flatten_sharded_optim_state_dict"):
1476            with context_fn():
1477                FSDP.flatten_sharded_optim_state_dict(
1478                    fsdp_osd,
1479                    wrapped_model,
1480                    wrapped_optim,
1481                )
1482        # Full optim state dict
1483        if should_check_method_fn("full_optim_state_dict"):
1484            with context_fn():
1485                fsdp_osd = FSDP.full_optim_state_dict(
1486                    wrapped_model,
1487                    wrapped_optim,
1488                    optim_input=wrapped_optim_input,
1489                    rank0_only=False,
1490                )
1491        if should_check_method_fn("shard_full_optim_state_dict"):
1492            with context_fn():
1493                FSDP.shard_full_optim_state_dict(
1494                    fsdp_osd,
1495                    wrapped_model,
1496                    optim_input=wrapped_optim_input,
1497                )
1498        if should_check_method_fn("scatter_full_optim_state_dict"):
1499            with context_fn():
1500                FSDP.scatter_full_optim_state_dict(
1501                    fsdp_osd,
1502                    wrapped_model,
1503                    optim_input=wrapped_optim_input,
1504                )
1505        # Rekey optim state dict
1506        (
1507            nonwrapped_model,
1508            nonwrapped_optim,
1509            nonwrapped_optim_input,
1510        ) = self._init_nested_model(wrap=False, use_multiple_param_groups=False)
1511        if should_check_method_fn("rekey_optim_state_dict"):
1512            with context_fn():
1513                rekeyed_osd = FSDP.rekey_optim_state_dict(
1514                    fsdp_osd,  # from `full_optim_state_dict()`
1515                    OptimStateKeyType.PARAM_ID,
1516                    nonwrapped_model,
1517                    optim_input=nonwrapped_optim_input,
1518                )
1519        self._step_model(nonwrapped_model, nonwrapped_optim, num_iters=2)
1520        osd = nonwrapped_optim.state_dict()
1521        if should_check_method_fn("rekey_optim_state_dict"):
1522            with context_fn():
1523                FSDP.rekey_optim_state_dict(
1524                    osd,
1525                    OptimStateKeyType.PARAM_NAME,
1526                    nonwrapped_model,
1527                    optim_input=nonwrapped_optim_input,
1528                )
1529
1530    @skip_if_lt_x_gpu(2)
1531    @parametrize("state_dict_type", STATE_DICT_TYPES)
1532    def test_save_load_without_0th_param_state(self, state_dict_type: StateDictType):
1533        """
1534        Tests saving and loading an optim state dict for Adam optimizer (i.e.
1535        any optimizer with a "step" key in its state) when the first parameter
1536        does not have optimizer state (e.g. unused or frozen).
1537        """
1538
1539        class Model(nn.Module):
1540            def __init__(self) -> None:
1541                super().__init__()
1542                self.lin1 = nn.Linear(5, 5)
1543                self.lin2 = nn.Linear(5, 5)
1544                self.relu = nn.ReLU()
1545
1546            def forward(self, x: torch.Tensor) -> torch.Tensor:
1547                # Do not use `lin1`, which is the parameter passed to the
1548                # optimizer and the one checked for "step" state to see if it
1549                # is tensor or float
1550                return self.relu(self.lin2(x))
1551
1552        model = Model().cuda()
1553        model.lin1 = FSDP(model.lin1)
1554        model.lin2 = FSDP(model.lin2)
1555        fsdp_model = FSDP(model)
1556        optim = torch.optim.Adam(
1557            fsdp_model.parameters(), lr=1e-2
1558        )  # or any optimizer with "step"
1559
1560        # Run an iteration to construct optimizer state
1561        device = torch.device("cuda")
1562        inp = torch.randn((2, 5), device=device)
1563        loss = fsdp_model(inp).sum()
1564        loss.backward()
1565        optim.step()
1566
1567        # Check that save and load does not error
1568        if state_dict_type == StateDictType.FULL_STATE_DICT:
1569            fsdp_osd = FSDP.full_optim_state_dict(fsdp_model, optim, rank0_only=False)
1570            flattened_osd = FSDP.shard_full_optim_state_dict(fsdp_osd, fsdp_model)
1571        elif state_dict_type == StateDictType.SHARDED_STATE_DICT:
1572            fsdp_osd = FSDP.sharded_optim_state_dict(fsdp_model, optim)
1573            flattened_osd = FSDP.flatten_sharded_optim_state_dict(
1574                fsdp_osd, fsdp_model, optim
1575            )
1576        optim.load_state_dict(flattened_osd)
1577        # `__setstate__()` will check the 0th parameter to see if "step" is
1578        # represented as a tensor or float, so it is imperative that its state
1579        # is non-empty.
1580
1581        # Run an iteration as a sanity check
1582        inp = torch.randn((2, 5), device=device)
1583        loss = fsdp_model(inp).sum()
1584        loss.backward()
1585        optim.step()
1586
1587    @skip_if_lt_x_gpu(2)
1588    def test_compatible_with_trec(self):
1589        class DenseModel(torch.nn.Module):
1590            def __init__(self) -> None:
1591                super().__init__()
1592                self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
1593                self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
1594                self.net3 = nn.Linear(32, 64)
1595                self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))
1596
1597            def forward(self, x):
1598                return self.net4(self.net3(self.net2(self.net1(x))))
1599
1600        class FakeMPModel(torch.nn.Module):
1601            def __init__(self) -> None:
1602                super().__init__()
1603                torch.manual_seed(0)
1604                self.dense = FSDP(DenseModel().cuda(), use_orig_params=True)
1605                if dist.get_rank() == 0:
1606                    self.sparse0 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
1607                else:
1608                    self.sparse1 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
1609
1610            def forward(self, x):
1611                if dist.get_rank() == 0:
1612                    sparse = self.sparse0(x)
1613                else:
1614                    sparse = self.sparse1(x)
1615                dist.all_reduce(sparse)
1616                return self.dense(sparse)
1617
1618        models = [FakeMPModel().cuda(), FakeMPModel().cuda()]
1619        optims = [
1620            torch.optim.Adam(models[0].parameters(), lr=1e-2),
1621            _NamedOptimizer(
1622                models[1].named_parameters(),
1623                torch.optim.Adam,
1624                [{"params": models[1].parameters()}],
1625                models[1],
1626                lr=1e-2,
1627            ),
1628        ]
1629        state_dicts = []
1630
1631        # Train one batch and see if optim_state_dict are the same.
1632        batch = torch.rand(5, 8, device=torch.device("cuda"))
1633        for model, optim in zip(models, optims):
1634            # Eagerly initialize the states
1635            for param in model.parameters():
1636                if param.requires_grad:
1637                    t = torch.zeros_like(param)
1638                    param.grad = torch.autograd.Variable(t)
1639            optim.step()
1640            loss = model(batch).sum()
1641            loss.backward()
1642            optim.step()
1643            state_dicts.append(deepcopy(FSDP.optim_state_dict(model, optim)))
1644
1645        self._check_same_param_groups(
1646            state_dicts[0], state_dicts[1], check_same_param_keys=False
1647        )
1648        self._check_same_state(
1649            state_dicts[0], state_dicts[1], check_same_param_keys=True
1650        )
1651
1652        # Make optim1 has a different state.
1653        for i in range(5):
1654            batch = torch.rand(5, 8).cuda()
1655            loss = models[1](batch).sum()
1656            loss.backward()
1657            optims[1].step()
1658
1659        # Load the state back to see if load_optim_state_dict works.
1660        state_dict_to_load = FSDP.optim_state_dict_to_load(
1661            models[1], optims[1], state_dicts[1], is_named_optimizer=True
1662        )
1663        optims[1].load_state_dict(state_dict_to_load)
1664        state_dicts[1] = FSDP.optim_state_dict(models[1], optims[1])
1665
1666        self._check_same_param_groups(
1667            state_dicts[0], state_dicts[1], check_same_param_keys=False
1668        )
1669        self._check_same_state(
1670            state_dicts[0], state_dicts[1], check_same_param_keys=True
1671        )
1672
1673    @skip_if_lt_x_gpu(2)
1674    def test_optim_state_without_param_groups(self):
1675        class SimpleModel(torch.nn.Module):
1676            def __init__(self) -> None:
1677                super().__init__()
1678                torch.manual_seed(0)
1679                self.net1 = nn.Sequential(nn.Linear(2, 4), nn.ReLU())
1680
1681            def forward(self, x):
1682                return self.net1(x)
1683
1684        model = FSDP(SimpleModel().cuda())
1685        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
1686
1687        # Train one step to save original optimizer state dict and original optimizer param groups.
1688        batch = torch.rand(3, 2, device=torch.device("cuda"))
1689        for param in model.parameters():
1690            if param.requires_grad:
1691                t = torch.zeros_like(param)
1692                param.grad = torch.autograd.Variable(t)
1693        optim.step()
1694        loss = model(batch).sum()
1695        loss.backward()
1696
1697        original_osd = deepcopy(optim.state_dict())
1698        original_osd_no_param_groups = deepcopy(original_osd)
1699        # manually remove param_groups from optimizer state dict
1700        original_param_groups = deepcopy(
1701            original_osd_no_param_groups.pop("param_groups")
1702        )
1703        # passing the osd without param_groups to FSDP
1704        original_fsdp_optim_state_dict = deepcopy(
1705            FSDP.optim_state_dict(
1706                model, optim, optim_state_dict=original_osd_no_param_groups
1707            )
1708        )
1709        # check the state_dict sharded by FSDP does not contain param_groups.
1710        self.assertEqual(None, original_fsdp_optim_state_dict.get("param_groups"))
1711
1712        # train another step to make optim a different state.
1713        for param in model.parameters():
1714            if param.requires_grad:
1715                t = torch.zeros_like(param)
1716                param.grad = torch.autograd.Variable(t)
1717        optim.step()
1718        loss = model(batch).sum()
1719        loss.backward()
1720
1721        state_dict_to_load = FSDP.optim_state_dict_to_load(
1722            model, optim, original_fsdp_optim_state_dict
1723        )
1724        # manually add param_groups to state_dict_to_load before loading the optimizer state
1725        state_dict_to_load["param_groups"] = original_param_groups
1726        optim.load_state_dict(state_dict_to_load)
1727        self.assertEqual(original_osd, optim.state_dict())
1728
1729        fsdp_optim_state = FSDP.optim_state_dict(model, optim)
1730        self._check_same_state(
1731            original_fsdp_optim_state_dict, fsdp_optim_state, check_same_param_keys=True
1732        )
1733        self.assertEqual(original_param_groups, optim.state_dict()["param_groups"])
1734
1735    @skip_if_lt_x_gpu(2)
1736    def test_with_empty_optimizer_state(self):
1737        model = FSDP(TestDummyModel().cuda())
1738        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
1739        state_dict = optim.state_dict()
1740        gathered_state_dict = FSDP.optim_state_dict(model, optim)
1741        self.assertEqual(gathered_state_dict["state"], state_dict["state"])
1742
1743    def _test_load_optim_state_with_optim_state_dict(
1744        self,
1745        model_class: _ModelClass,
1746        state_dict_settings: StateDictSettings,
1747        use_multiple_param_groups: bool,
1748        halve_world_size: bool,
1749        use_diff_optim_inputs: bool,
1750        num_iters: int,
1751        **new_model_kwargs,
1752    ):
1753        """
1754        (1) Runs a model with full world size for K iterations to generate a
1755        full/sharded optimizer state dict;
1756        (2) initializes a model with halved world size and possibly different
1757        FSDP wrapping scheme (based on ``new_model_kwargs``);
1758        (3) loads the full/sharded optimizer state dict from (1) according to the
1759        halved-world-size model;
1760        (4) runs the halved-world-size model for K iterations; and
1761        (5) checks that the sharded optimizer state dict from (3) matches the
1762        halved-world-size model's local optimizer state dict, meaning that the
1763        former could have equivalently been loaded into the local optimizer.
1764        """
1765        initializer = self._model_class[model_class]
1766
1767        # First, run a wrapped model with full world size for a few iterations
1768        model1, optim1, optim_input1 = initializer(
1769            wrap=True,
1770            use_multiple_param_groups=use_multiple_param_groups,
1771        )
1772        FSDP.set_state_dict_type(
1773            model1,
1774            state_dict_settings.state_dict_type,
1775            state_dict_settings.state_dict_config,
1776            state_dict_settings.optim_state_dict_config,
1777        )
1778        self._step_model(model1, optim1, num_iters=num_iters)
1779        fsdp_osd1 = FSDP.optim_state_dict(model1, optim1)
1780        if halve_world_size:
1781            # Create a new process group with halved world size
1782            new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0]
1783            new_group = dist.new_group(ranks=new_group_ranks)
1784            if self.rank not in new_group_ranks:
1785                return
1786        else:
1787            # Continue using the same group and hence world size
1788            new_group = dist.distributed_c10d._get_default_group()
1789        # Second, run a wrapped model with (possibly) halved world size and
1790        # (possibly) differing `optim_input` across ranks
1791        model2, optim2, optim_input2 = initializer(
1792            wrap=True,
1793            group=new_group,
1794            use_multiple_param_groups=use_multiple_param_groups,
1795            use_diff_optim_inputs=use_diff_optim_inputs,
1796            **new_model_kwargs,  # specify `wrap_alt` to change wrapping
1797        )
1798        FSDP.set_state_dict_type(
1799            model2,
1800            state_dict_settings.state_dict_type,
1801            state_dict_settings.state_dict_config,
1802            state_dict_settings.optim_state_dict_config,
1803        )
1804        self._step_model(model2, optim2, num_iters=num_iters)
1805        fsdp_osd2 = FSDP.optim_state_dict(model2, optim2, group=new_group)
1806        # Compute two sharded optim state dicts: (1) for the first model
1807        # according to the second model and (2) for the second model according
1808        # to the second model
1809        sharded_osd2 = FSDP.optim_state_dict_to_load(
1810            model2, optim2, fsdp_osd2, group=new_group
1811        )
1812
1813        # As a sanity check, check that sharding the second model's full/sharded
1814        # optimizer state dict according to itself is equivalent to its local
1815        # optimizer's state dict
1816        local_osd2 = optim2.state_dict()
1817        self._check_same_param_groups(
1818            sharded_osd2,
1819            local_osd2,
1820            check_same_param_keys=True,
1821        )
1822        self._check_same_state(
1823            sharded_osd2,
1824            local_osd2,
1825            check_same_param_keys=True,
1826        )
1827        # Check that sharding the first model's full/sharded optimizer state dict
1828        # according to the second model is equivalent to the second model's
1829        # local optimizer state dict
1830        sharded_osd1 = FSDP.optim_state_dict_to_load(
1831            model2, optim2, fsdp_osd1, group=new_group
1832        )
1833        self._check_same_param_groups(
1834            sharded_osd1,
1835            local_osd2,
1836            check_same_param_keys=True,
1837        )
1838        self._check_same_state(
1839            sharded_osd1,
1840            local_osd2,
1841            check_same_param_keys=True,
1842        )
1843        # As a sanity check, check that we can load and run a few iterations
1844        optim2.load_state_dict(sharded_osd2)
1845        self._step_model(model2, optim2, num_iters=num_iters)
1846
1847    @skip_if_lt_x_gpu(2)
1848    def test_interface_arguments(self):
1849        model = FSDP(TestDummyModel().cuda())
1850        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
1851
1852        def step():
1853            loss = model(model.get_input())
1854            loss.backward(loss)
1855            optim.step()
1856
1857        step()
1858        original_osd = deepcopy(optim.state_dict())
1859        osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
1860        self._check_same_state(
1861            FSDP.optim_state_dict(model, optim), osd, check_same_param_keys=True
1862        )
1863        step()
1864        osd_to_load = FSDP.optim_state_dict_to_load(
1865            model, optim, osd, load_directly=True
1866        )
1867        self._check_same_state(
1868            optim.state_dict(), original_osd, check_same_param_keys=True
1869        )
1870
1871        # Test the default setting.
1872        osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
1873        for state in osd["state"].values():
1874            for s in state.values():
1875                self.assertFalse(isinstance(s, ShardedTensor))
1876                self.assertFalse(s.is_cuda)
1877
1878        # Test sharded state_dict without offload_to_cpu
1879        with FSDP.state_dict_type(
1880            model,
1881            StateDictType.SHARDED_STATE_DICT,
1882            ShardedStateDictConfig(),
1883            ShardedOptimStateDictConfig(offload_to_cpu=False),
1884        ):
1885            osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
1886            for state in osd["state"].values():
1887                for s in state.values():
1888                    if s.dim() == 0:
1889                        continue
1890                    self.assertTrue(isinstance(s, ShardedTensor))
1891                    if s._local_shards[0]:
1892                        self.assertTrue(s._local_shards[0].tensor.is_cuda)
1893
1894        # Test full state_dict with rank0_only
1895        with FSDP.state_dict_type(
1896            model,
1897            StateDictType.FULL_STATE_DICT,
1898            FullStateDictConfig(),
1899            FullOptimStateDictConfig(
1900                offload_to_cpu=True,
1901                rank0_only=True,
1902            ),
1903        ):
1904            osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
1905            if dist.get_rank() > 0:
1906                self.assertEqual(osd, {})
1907            else:
1908                for state in osd["state"].values():
1909                    for s in state.values():
1910                        if s.dim() == 0:
1911                            continue
1912                        self.assertFalse(s.is_cuda)
1913                        self.assertFalse(isinstance(s, ShardedTensor))
1914
1915    @skip_if_lt_x_gpu(2)
1916    def test_state_dict_with_none_tensor_state(self):
1917        def _run_test(use_orig_params, optimizer_has_tensor_state):
1918            model = FSDP(TestDummyModel().cuda(), use_orig_params=use_orig_params)
1919            optimizer_cls = (
1920                torch.optim.Adam if optimizer_has_tensor_state else torch.optim.SGD
1921            )
1922            optim = optimizer_cls(model.parameters(), lr=1e-2)
1923
1924            def step():
1925                loss = model(model.get_input())
1926                loss.backward(loss)
1927                optim.step()
1928
1929            step()
1930            original_osd = deepcopy(optim.state_dict())
1931            for state in original_osd["state"].values():
1932                # Add customized value
1933                state["value1"] = 2.74
1934                state["value2"] = None
1935
1936            osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
1937            osd_to_load = FSDP.optim_state_dict_to_load(model, optim, osd)
1938            for state in osd_to_load["state"].values():
1939                self.assertEqual(state["value1"], 2.74)
1940                self.assertEqual(state["value2"], None)
1941
1942        self.run_subtests(
1943            {
1944                "use_orig_params": [False, True],
1945                "optimizer_has_tensor_state": [False, True],
1946            },
1947            _run_test,
1948        )
1949
1950    @skip_if_lt_x_gpu(2)
1951    def test_with_no_shard(self):
1952        def _run_test(use_orig_params: bool) -> None:
1953            model = FSDP(
1954                TestDummyModel().cuda(),
1955                sharding_strategy=ShardingStrategy.NO_SHARD,
1956                use_orig_params=use_orig_params,
1957            )
1958            optim = torch.optim.Adam(model.parameters(), lr=1e-2)
1959
1960            def step():
1961                loss = model(model.get_input())
1962                loss.backward(loss)
1963                optim.step()
1964
1965            step()
1966
1967            original_osd = deepcopy(optim.state_dict())
1968
1969            osd = FSDP.optim_state_dict(model, optim)
1970            osd_to_load = FSDP.optim_state_dict_to_load(model, optim, osd)
1971            optim.load_state_dict(osd_to_load)
1972
1973            new_osd = optim.state_dict()
1974
1975            self.assertEqual(original_osd, new_osd)
1976
1977        self.run_subtests({"use_orig_params": [False, True]}, _run_test)
1978
1979    @skip_if_lt_x_gpu(2)
1980    def test_no_grad(self):
1981        model = TestDummyModel(no_grad=True).cuda()
1982        fsdp_model = FSDP(deepcopy(model), use_orig_params=True)
1983        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)
1984
1985        for i in range(5):
1986            if i % 2 == 1:
1987                fsdp_model.net1[0].weight.requires_grad = True
1988                fsdp_model.net1[0].bias.requires_grad = True
1989            else:
1990                fsdp_model.net1[0].weight.requires_grad = False
1991                fsdp_model.net1[0].bias.requires_grad = False
1992            batch = fsdp_model.get_input()
1993            loss = fsdp_model(batch).sum()
1994            loss.backward()
1995            fsdp_optim.step()
1996            orig_state_dict = deepcopy(fsdp_optim.state_dict())
1997            optim_state_dict = FSDP.optim_state_dict(fsdp_model, fsdp_optim)
1998            FSDP.optim_state_dict_to_load(
1999                fsdp_model,
2000                fsdp_optim,
2001                FSDP.optim_state_dict(fsdp_model, fsdp_optim),
2002                load_directly=True,
2003            )
2004
2005            self._check_same_state(
2006                fsdp_optim.state_dict(),
2007                orig_state_dict,
2008                check_same_param_keys=True,
2009            )
2010
2011
2012instantiate_parametrized_tests(TestFSDPOptimState)
2013
2014if __name__ == "__main__":
2015    run_tests()
2016