# Owner(s): ["oncall: distributed"]

import io
import itertools
import sys
from contextlib import nullcontext
from copy import deepcopy
from functools import partial
from typing import Any, Dict

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed._shard.sharded_tensor import (
    init_from_local_shards,
    Shard,
    ShardedTensor,
)
from torch.distributed._state_dict_utils import (
    _all_gather_sharded_tensor,
    _gather_state_dict,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.fsdp import (
    CPUOffload,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    LocalStateDictConfig,
    MixedPrecision,
    ShardedStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp._common_utils import FSDP_PREFIX
from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM
from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap
from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    _assert_module_states,
    _broadcast_state_dict,
    _get_state_dict,
    _zero_model,
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    get_full_params,
    SkipModel,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

INNER_SHAPE = [4, 4]
OUTER_SHAPE = [4, 5]
BUFFER_SHAPE = [5, 5]

NON_ROOT_FSDP_PREFIX = "non_fsdp_lin"

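# "state_dict" and "sharded_state_dict" key entries by the original (unflattened)
# parameter names; "local_state_dict" instead exposes FSDP's flattened
# ``FlatParameter`` shards.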
_UNFLATTENED_STATE_DICT_IMPLS = ["state_dict", "sharded_state_dict"]
_FLATTENED_STATE_DICT_IMPLS = ["local_state_dict"]
_SUPPORTED_STATE_DICT_IMPLS = (
    _UNFLATTENED_STATE_DICT_IMPLS + _FLATTENED_STATE_DICT_IMPLS
)

STATE_DICT_MAPPING = {
    "state_dict": StateDictType.FULL_STATE_DICT,
    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
}


class Model(Module):
    def __init__(
        self,
        wrap_fsdp,
        register_buffers=False,
        ignore_inner=False,
        mixed_precision=False,
        process_group=None,
    ):
        super().__init__()
        self.inner = Linear(*INNER_SHAPE)
        if register_buffers:
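            # Register one persistent and one non-persistent buffer; only the
            # persistent buffer should appear in the saved state dicts.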
            self.inner.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE))
            self.inner.register_buffer(
                "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
            )
        if wrap_fsdp:
            self.inner = FSDP(
                self.inner,
                ignored_modules=([self.inner] if ignore_inner else []),
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                )
                if mixed_precision
                else None,
                process_group=process_group,
            )
        self.outer = Linear(*OUTER_SHAPE)
        if register_buffers:
            self.outer.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE))
            self.outer.register_buffer(
                "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
            )

    def forward(self, x):
        # Forward twice.
        i = self.inner(x)
        j = self.inner(x)
        return self.outer(i + j)


class TestDummyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
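        # Alias a module and a parameter below to exercise shared-module and
        # shared-parameter handling in the state dict.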
        self.net3 = self.net2
        self.random_parameter = nn.Parameter(torch.Tensor(10))
        self.shared_parameter = self.random_parameter

    def forward(self, x):
        return self.net3(self.net2(self.net1(x)))

    def get_input(self):
        return torch.rand(8, 8, device="cuda")


class TestFSDPStateDict(FSDPTest):
    @property
    def world_size(self):
        return min(torch.cuda.device_count(), 2)

    def _broadcast_state_dict(self, model, state_dict):
        # TODO (rohan-varma): remove model
        return _broadcast_state_dict(self.rank, state_dict)

    def _state_compare(self, model, model_new, assert_fn, state_generator="parameters"):
        state_base = list(getattr(model, state_generator)())
        state_new = list(getattr(model_new, state_generator)())
        # Regardless of `assert_fn`, the number of parameters should be the same
        self.assertEqual(len(state_base), len(state_new))
        assert_fn(state_base, state_new)

    def _compare_models(
        self, model, model_new, assert_fn, check_fp16=False, check_buffers=True
    ):
        assert assert_fn in (self.assertEqual, self.assertNotEqual)
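        # `summon_full_params` unshards the parameters of both models so that
        # they can be compared elementwise.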
        with FSDP.summon_full_params(model):
            with FSDP.summon_full_params(model_new):
                self._state_compare(model, model_new, assert_fn)
                if check_buffers:
                    has_buffers = any(
                        len(list(m.buffers())) for m in (model, model_new)
                    )
                    if has_buffers:
                        self._state_compare(
                            model, model_new, assert_fn, state_generator="buffers"
                        )
                if check_fp16:
                    for tensor in model_new.parameters():
                        self.assertEqual(tensor.dtype, torch.float16)

    def _get_simple_nested_model(
        self, *fsdp_args, wrap=True, checkpoint_wrap=False, **fsdp_kwargs
    ):
        if wrap:
            lin1 = nn.Linear(10, 10, bias=False).cuda()
            lin2 = nn.Linear(10, 10, bias=False).cuda()
            if checkpoint_wrap:
                lin1 = checkpoint_wrapper(lin1)
                lin2 = checkpoint_wrapper(lin2)
            seq = nn.Sequential(FSDP(lin1, *fsdp_args, **fsdp_kwargs), lin2)
            if checkpoint_wrap:
                seq = checkpoint_wrapper(seq)
            model = FSDP(seq, *fsdp_args, **fsdp_kwargs)
        else:
            model = nn.Sequential(
                nn.Linear(10, 10, bias=False).cuda(),
                nn.Linear(10, 10, bias=False).cuda(),
            )
        return model

    def _get_simple_model(self, *fsdp_args, checkpoint_wrap=False, **fsdp_kwargs):
        lin = nn.Linear(10, 10, bias=False).cuda()
        if checkpoint_wrap:
            lin = checkpoint_wrapper(lin)
        model = FSDP(lin, *fsdp_args, **fsdp_kwargs)
        return model

    def _get_multibuffer_nested_model(
        self, *fsdp_args, wrap=True, checkpoint_wrap=False, **fsdp_kwargs
    ):
        full_p = torch.float32
        lin_mp = fsdp_kwargs.pop("mixed_precision", None)
        bn_mp = (
            MixedPrecision(param_dtype=full_p, reduce_dtype=full_p, buffer_dtype=full_p)
            if lin_mp
            else None
        )
        if wrap:
            lin1 = nn.Linear(10, 10, bias=False).cuda()
            bn1 = nn.BatchNorm1d(10).cuda()
            lin2 = nn.Linear(10, 10, bias=False).cuda()
            if checkpoint_wrap:
                lin1 = checkpoint_wrapper(lin1)
                bn1 = checkpoint_wrapper(bn1)
                lin2 = checkpoint_wrapper(lin2)
            seq = nn.Sequential(
                FSDP(lin1, *fsdp_args, mixed_precision=lin_mp, **fsdp_kwargs),
                FSDP(bn1, *fsdp_args, mixed_precision=bn_mp, **fsdp_kwargs),
                lin2,
            )
            if checkpoint_wrap:
                seq = checkpoint_wrapper(seq)
            model = FSDP(seq, *fsdp_args, **fsdp_kwargs)
        else:
            model = nn.Sequential(
                nn.Linear(10, 10, bias=False).cuda(),
                nn.BatchNorm1d(10).cuda(),
                nn.Linear(10, 10, bias=False).cuda(),
            )
        return model

    def _get_non_fsdp_root_module(self, *fsdp_args, wrap=True, **fsdp_kwargs):
        class FSDPContainer(nn.Module):
            def __init__(self, fsdp_1, fsdp_2):
                super().__init__()
                self.non_fsdp_lin = nn.Linear(10, 10, bias=False).cuda()
                self.fsdp_1 = fsdp_1
                self.fsdp_2 = fsdp_2

            def forward(self, x):
                x = self.non_fsdp_lin(x)
                x = self.fsdp_1(x)
                x = self.fsdp_2(x)
                return x

        return FSDPContainer(
            self._get_simple_nested_model(*fsdp_args, wrap=wrap, **fsdp_kwargs),
            self._get_simple_nested_model(*fsdp_args, wrap=wrap, **fsdp_kwargs),
        )

    def _get_state_dict_mgr(
        self,
        model: nn.Module,
        state_dict_type: str,
        state_dict_rank0_and_offload: bool,
    ):
        _state_dict_type = STATE_DICT_MAPPING[state_dict_type]
        if state_dict_type == "state_dict":
            config = FullStateDictConfig(
                rank0_only=state_dict_rank0_and_offload,
                offload_to_cpu=state_dict_rank0_and_offload,
            )
        elif state_dict_type == "local_state_dict":
            config = LocalStateDictConfig(
                offload_to_cpu=state_dict_rank0_and_offload,
            )
        elif state_dict_type == "sharded_state_dict":
            config = ShardedStateDictConfig(
                offload_to_cpu=state_dict_rank0_and_offload,
            )
        else:
            raise ValueError("Unsupported state_dict_type")
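        # `FSDP.state_dict_type` returns a context manager that configures the
        # state_dict()/load_state_dict() behavior for all FSDP submodules of `model`.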
        return FSDP.state_dict_type(model, _state_dict_type, config)

    def _validate_state_dict_contents(
        self, model, fsdp_state_dict, state_dict_rank0_and_offload, ignore_keys=None
    ):
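        # With rank0_only + offload_to_cpu, rank 0 should hold the full state
        # dict on CPU while all other ranks should receive an empty dict.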
        if state_dict_rank0_and_offload:
            if self.rank == 0:
                self.assertNotEqual(fsdp_state_dict, {})
                for key, tensor in fsdp_state_dict.items():
                    if ignore_keys and key in ignore_keys:
                        continue
                    self.assertEqual(
                        tensor.device,
                        torch.device("cpu"),
                        f"{key} is unexpectedly on device {tensor.device}",
                    )
            else:
                # For non-FSDP roots, the non FSDP portion can still have parameters on rank 0,
                # so bypass the check for now.
                if isinstance(model, FSDP):
                    self.assertEqual(
                        fsdp_state_dict,
                        {},
                        f"Expected empty state_dict but got {fsdp_state_dict} on rank {dist.get_rank()}",
                    )

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize(
        "checkpoint_wrap",
        ["source", "dest", "both", "source_after_wrap", "both_after_wrap"],
    )
    @parametrize("rank0_only_and_offload", [False, True])
    def test_fsdp_state_dict_with_activation_checkpoint(
        self, state_dict_type, checkpoint_wrap, rank0_only_and_offload
    ):
        """Tests saving the state dict, zeroing a target model's parameters, and
        loading the state dict, where the source and target models may have a
        checkpoint wrapper."""

        def apply_ac_to_linears(model) -> None:
            non_reentrant_wrapper = partial(
                checkpoint_wrapper,
                offload_to_cpu=False,
                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            )
            apply_activation_checkpointing(
                model,
                checkpoint_wrapper_fn=non_reentrant_wrapper,
                check_fn=lambda submodule: isinstance(submodule, nn.Linear),
            )

        for model_call in [
            partial(self._get_simple_model),
            partial(self._get_simple_nested_model),
        ]:
            model = model_call(checkpoint_wrap=(checkpoint_wrap in ("source", "both")))
            if checkpoint_wrap in ("source_after_wrap", "both_after_wrap"):
                apply_ac_to_linears(model)
            with self._get_state_dict_mgr(
                model, state_dict_type, rank0_only_and_offload
            ):
                state_dict = _gather_state_dict(_get_state_dict(model, False, False))
                # Possibly wrap new model in activation checkpoint wrapper to test save/
                # load with this wrapper
                model_new = model_call(
                    checkpoint_wrap=(checkpoint_wrap in ("dest", "both"))
                )
                if checkpoint_wrap == "both_after_wrap":
                    apply_ac_to_linears(model_new)
                _zero_model(model_new)
                self._compare_models(model, model_new, self.assertNotEqual)
                if rank0_only_and_offload:
                    state_dict = self._broadcast_state_dict(model, state_dict)
                # Would fail if checkpoint_wrapper did not correctly implement state_dict pre/post hooks
                model_new.load_state_dict(state_dict, strict=True)
                self._compare_models(model, model_new, self.assertEqual)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize("rank0_only_and_offload", [False, True])
    def test_state_dict_with_manual_ac_wrapper(
        self,
        state_dict_type: str,
        rank0_only_and_offload: bool,
    ):
        """
        Tests saving and loading a state dict for a model manually wrapped with
        ``FSDP(CheckpointWrapper(module))``, where the ``CheckpointWrapper`` is
        wrapped before FSDP.

        TODO: Investigate why the test above does not cover everything in this
        test and de-duplicate afterwards.
        """
        if state_dict_type == "sharded_state_dict" and rank0_only_and_offload:
            return  # not supported
        model_ac = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
        )
        # Manually wrap FSDP without AC
        model_no_ac = deepcopy(model_ac)
        for i, layer in enumerate(model_no_ac.transformer.encoder.layers):
            model_no_ac.transformer.encoder.layers[i] = FSDP(layer)
        for i, layer in enumerate(model_no_ac.transformer.decoder.layers):
            model_no_ac.transformer.decoder.layers[i] = FSDP(layer)
        model_no_ac.transformer = FSDP(model_no_ac.transformer)

        # Manually wrap FSDP with AC as `FSDP(CheckpointWrapper(module))`
        for i, layer in enumerate(model_ac.transformer.encoder.layers):
            layer = checkpoint_wrapper(layer)
            model_ac.transformer.encoder.layers[i] = FSDP(layer)
        for i, layer in enumerate(model_ac.transformer.decoder.layers):
            layer = checkpoint_wrapper(layer)
            model_ac.transformer.decoder.layers[i] = FSDP(layer)
        model_ac.transformer = FSDP(model_ac.transformer)

        # Save, load, and compare the two models
        with self._get_state_dict_mgr(
            model_no_ac, state_dict_type, rank0_only_and_offload
        ):
            state_dict_no_ac = model_no_ac.state_dict()
        with self._get_state_dict_mgr(
            model_ac, state_dict_type, rank0_only_and_offload
        ):
            state_dict_ac = model_ac.state_dict()
        self.assertEqual(state_dict_ac.keys(), state_dict_no_ac.keys())
        if rank0_only_and_offload:
            state_dict_no_ac = self._broadcast_state_dict(model_no_ac, state_dict_no_ac)
            state_dict_ac = self._broadcast_state_dict(model_ac, state_dict_ac)
        with self._get_state_dict_mgr(
            model_no_ac, state_dict_type, rank0_only_and_offload
        ):
            model_no_ac.load_state_dict(state_dict_no_ac)
        with self._get_state_dict_mgr(
            model_ac, state_dict_type, rank0_only_and_offload
        ):
            model_ac.load_state_dict(state_dict_ac)
        self._compare_models(model_ac, model_no_ac, self.assertEqual)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_state_dict_with_shared_parameters(self, state_dict_type):
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        model_creator = partial(
            TransformerWithSharedParams.init,
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            {"auto_wrap_policy": auto_wrap_policy},
        )

        fsdp_model = model_creator()
        with self._get_state_dict_mgr(fsdp_model, state_dict_type, False):
            state_dict = fsdp_model.state_dict()

        new_model = model_creator()
        _zero_model(new_model, zero_buffers=True)
        with self._get_state_dict_mgr(new_model, state_dict_type, False):
            new_model.load_state_dict(state_dict)

    @skip_if_lt_x_gpu(2)
    @parametrize("use_orig_params", [False, True])
    def test_state_dict_rank0_offload_save_load_flow(self, use_orig_params: bool):
        """Tests saving a model checkpoint only on rank 0 and loading it only
        on rank 0 with ``sync_module_states=True`` to emulate the workflow to
        avoid redundant CPU memory usage."""
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        fsdp_kwargs = {
            "auto_wrap_policy": auto_wrap_policy,
            "use_orig_params": use_orig_params,
        }
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        # Force model parameters and buffers to be nonzero
        with FSDP.summon_full_params(fsdp_model):
            for tensor in itertools.chain(
                fsdp_model.parameters(), fsdp_model.buffers()
            ):
                if torch.count_nonzero(tensor) == 0:
                    with torch.no_grad():
                        tensor.add_(torch.ones_like(tensor))
        with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
            state_dict = deepcopy(_get_state_dict(fsdp_model))
        # Initialize a non-wrapped model on all ranks
        new_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
        )
        _zero_model(new_model, zero_buffers=True)
        # Only load the checkpoint on rank 0
        if self.rank == 0:
            new_model.load_state_dict(state_dict, strict=True)
        _assert_module_states(
            new_model,
            process_group=self.process_group,
            assert_fn=self.assertNotEqual,
        )
        # Broadcast the module states from rank 0 with `sync_module_states=True`
        new_fsdp_model = FSDP(
            new_model,
            device_id=torch.cuda.current_device(),
            auto_wrap_policy=auto_wrap_policy,
            sync_module_states=True,
        )
        # Check FSDP models are equal across ranks
        with FSDP.summon_full_params(new_fsdp_model):
            _assert_module_states(
                new_fsdp_model,
                process_group=self.process_group,
                assert_fn=self.assertEqual,
            )
        # Check FSDP models correctly loaded the checkpoint
        with FSDP.summon_full_params(fsdp_model):
            with FSDP.summon_full_params(new_fsdp_model):
                params = list(fsdp_model.parameters())
                params_new = list(new_fsdp_model.parameters())
                self.assertEqual(params, params_new)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("fp16", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    @parametrize("use_orig_params", [True, False])
    def test_basic_save_and_load_state_dict(
        self,
        state_dict_type: str,
        cpu_offload: CPUOffload,
        fp16: bool,
        state_dict_rank0_and_offload: bool,
        use_orig_params: bool,
    ):
        """
        Tests that we can save a state_dict and load it into a blank model
        under various configs such as fp16 and CPU offload, and that the
        parameters match as expected.
        """
        if (state_dict_rank0_and_offload and state_dict_type != "state_dict") or (
            use_orig_params and state_dict_type not in _UNFLATTENED_STATE_DICT_IMPLS
        ):
            return  # not supported
        device = torch.device(self.rank)
        for model_call in [
            partial(
                self._get_non_fsdp_root_module,
                cpu_offload=cpu_offload,
                use_orig_params=use_orig_params,
            ),
            partial(
                self._get_simple_nested_model,
                cpu_offload=cpu_offload,
                use_orig_params=use_orig_params,
            ),
            partial(
                self._get_simple_model,
                cpu_offload=cpu_offload,
                use_orig_params=use_orig_params,
            ),
        ]:
            model = model_call()
            if fp16:
                model.half()
            # Run a forward/backward to compute gradients to test the case
            # where there are gradients populated
            inp = torch.randn((3, 10), device=device)
            if fp16:
                inp = inp.half()
            model(inp).sum().backward()

            ctx = self._get_state_dict_mgr(
                model, state_dict_type, state_dict_rank0_and_offload
            )
            with ctx:
                fsdp_state_dict = _get_state_dict(
                    model, cpu_offload.offload_params, fp16
                )

            ignore_keys = [
                k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
            ]

            self._validate_state_dict_contents(
                model,
                fsdp_state_dict,
                state_dict_rank0_and_offload,
                ignore_keys=ignore_keys,
            )
            if fp16:
                # Verify fp16 is the type
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()
            # Run a forward/backward to compute gradients to test the case
            # where there are gradients populated
            inp = torch.randn((3, 10), device=device)
            if fp16:
                inp = inp.half()
            model_new(inp).sum().backward()

            # zero the model to ensure parameters are different.
            _zero_model(model_new, zero_buffers=True)
            self._compare_models(model, model_new, self.assertNotEqual)

            # Verify parameters are the same in the new model.
            if state_dict_rank0_and_offload:
                fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict)
            with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]):
                model_new.load_state_dict(fsdp_state_dict, strict=True)

            self._compare_models(model, model_new, self.assertEqual, check_fp16=fp16)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("mixed_precision", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    @parametrize("use_orig_params", [True, False])
    def test_buffers_save_and_load_state_dict(
        self,
        state_dict_type: str,
        cpu_offload: CPUOffload,
        mixed_precision: bool,
        state_dict_rank0_and_offload: bool,
        use_orig_params: bool,
    ):
        """
        Tests that we can save and load a state_dict for modules with persistent
        buffers, including with non-default mixed precision, different
        ``state_dict_type``s, and CPU offloading.
        """
        if (state_dict_rank0_and_offload and state_dict_type != "state_dict") or (
            use_orig_params and state_dict_type not in _UNFLATTENED_STATE_DICT_IMPLS
        ):
            return  # not supported
        mixed_precision = (
            MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )
            if mixed_precision
            else None
        )
        model_call = partial(
            self._get_multibuffer_nested_model,
            cpu_offload=cpu_offload,
            use_orig_params=use_orig_params,
            mixed_precision=mixed_precision,
        )
        model = model_call()
        ctx = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with ctx:
            fsdp_state_dict = _get_state_dict(model, cpu_offload.offload_params, False)

        self._validate_state_dict_contents(
            model, fsdp_state_dict, state_dict_rank0_and_offload
        )

        model_new = model_call()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()

        # zero the model to ensure parameters are different.
        _zero_model(model_new, zero_buffers=True)
        self._compare_models(model, model_new, self.assertNotEqual)

        # Verify parameters are the same in the new model.
        if state_dict_rank0_and_offload:
            fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict)
        with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]):
            model_new.load_state_dict(fsdp_state_dict, strict=True)

        self._compare_models(model, model_new, self.assertEqual)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize("mixed_precision", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    def test_save_and_load_after_forward_state_dict(
        self, state_dict_type, mixed_precision, state_dict_rank0_and_offload
    ):
        """
        Test that saving after some training results in params being updated as
        expected.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        torch.cuda.set_device(self.rank)
        mixed_precision = (
            MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )
            if mixed_precision
            else None
        )
        model = self._get_simple_nested_model(mixed_precision=mixed_precision)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = get_full_params(model)
        for _ in range(6):
            inp = torch.randn(1, 10, device=torch.cuda.current_device())
            output = model(*inp)
            loss = output.sum()
            expected_dtype = torch.float32 if mixed_precision is None else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)
            loss.backward()
            optim.step()

        trained_params = get_full_params(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        fsd_mgr = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with fsd_mgr:
            state_dict = model.state_dict()
            if state_dict_type == "state_dict":
                state_dict = {k: v.clone() for k, v in state_dict.items()}
            else:
                for sharded_tensor in state_dict.values():
                    shard = sharded_tensor._local_shards[0]
                    shard.tensor = shard.tensor.clone().detach_()
        self._validate_state_dict_contents(
            model, state_dict, state_dict_rank0_and_offload
        )
        _zero_model(model)

        # Ensure checkpointed params have the full param dtype
        for tensor in state_dict.values():
            self.assertEqual(tensor.dtype, torch.float32)

        # Load state_dict into zeroed model
        if state_dict_rank0_and_offload:
            state_dict = self._broadcast_state_dict(model, state_dict)

        with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]):
            model.load_state_dict(state_dict, strict=True)
        loaded_params = get_full_params(model)
        self.assertEqual(loaded_params, trained_params)

    def _initialize_model(
        self,
        wrap_fsdp: bool,
        wrap_ddp: bool = True,
        register_buffers: bool = False,
    ):
        # keep everything deterministic for input data
        torch.manual_seed(0)

        model = Model(wrap_fsdp, register_buffers=register_buffers).cuda()
        if wrap_fsdp:
            model = FSDP(model)
        elif wrap_ddp:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        return model

    @staticmethod
    def _state_dict(model: Module, state_dict_type: str):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError as e:
            raise ValueError(f"No state_dict type for {state_dict_type}") from e

        with FSDP.state_dict_type(model, enum_val):
            return model.state_dict()

    @staticmethod
    def _load_state_dict(
        model: Module, state_dict_type: str, state_dict: Dict[str, Any]
    ):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError as e:
            raise ValueError(f"No state_dict for {state_dict_type}") from e

        with FSDP.state_dict_type(model, enum_val):
            return model.load_state_dict(state_dict, strict=True)

    def _dist_train(
        self, wrap_fsdp: bool, state_dict_type: str = "", move_to_cpu: bool = False
    ):
        # TODO: Move this test to common_fsdp.
        model = self._initialize_model(wrap_fsdp)
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            blank_model = FSDP(Model(True).cuda())
            _zero_model(blank_model)
            state_dict = self._state_dict(model, state_dict_type)
            if move_to_cpu:
                for key in list(state_dict.keys()):
                    tensor = state_dict[key]
                    if isinstance(tensor, torch.Tensor):
                        state_dict[key] = tensor.cpu()
                    else:
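                        # ShardedTensor entries carry their data in local shards,
                        # so move each shard's tensor to CPU individually.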
                        shards = tensor.local_shards()
                        if shards:
                            shards[0].tensor = shards[0].tensor.cpu()

            self._load_state_dict(blank_model, state_dict_type, state_dict)
            return get_full_params(blank_model)
        else:
            return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_state_dict_save_load_flow(self, state_dict_type):
        self.run_subtests(
            {"move_to_cpu": [True, False]},
            self._test_state_dict_save_load_flow,
            state_dict_type=state_dict_type,
        )

    def _test_state_dict_save_load_flow(self, state_dict_type, move_to_cpu):
        fsdp_params = self._dist_train(
            wrap_fsdp=True,
            state_dict_type=state_dict_type,
            move_to_cpu=move_to_cpu,
        )
        ddp_params = self._dist_train(wrap_fsdp=False)
        self.assertEqual(ddp_params, fsdp_params)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_fsdp_state_dict_keys(self, state_dict_type):
        state_dict = self._state_dict(self._initialize_model(True), state_dict_type)
        if state_dict_type == "local_state_dict":
            self.assertEqual({FLAT_PARAM, f"inner.{FLAT_PARAM}"}, state_dict.keys())
        elif state_dict_type in ("state_dict", "sharded_state_dict"):
            # Keys should match local model.
            local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False)
            local_keys = local_model.state_dict().keys()
            self.assertEqual(state_dict.keys(), local_keys)
        else:
            raise NotImplementedError(f"No test for {state_dict_type}!")

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize("state_dict_rank0_and_offload", [True, False])
    @parametrize("fsdp_root", [True, False])
    def test_state_dict_load_into_local_module(
        self,
        state_dict_type,
        state_dict_rank0_and_offload,
        fsdp_root,
    ):
        """
        Tests that FSDP's state_dict can be loaded into a local model.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        if not fsdp_root:
            model = self._get_non_fsdp_root_module()
        else:
            model = self._initialize_model(wrap_fsdp=True, register_buffers=True)
        optim = SGD(model.parameters(), lr=0.1)
        if not fsdp_root:
            in_data = torch.randn(
                1, 10, requires_grad=True, device=torch.device("cuda")
            )
        else:
            in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        with FSDP.summon_full_params(model):
            fsdp_params = deepcopy(list(model.parameters()))

        # get FSDP state_dict. Note that by default we return full_state_dict.
        sd_mgr = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with sd_mgr:
            fsdp_state_dict = model.state_dict()

        ignore_keys = [k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k]
        self._validate_state_dict_contents(
            model,
            fsdp_state_dict,
            state_dict_rank0_and_offload,
            ignore_keys=ignore_keys,
        )
        # Create zeroed local model
        if not fsdp_root:
            blank_local_model = self._get_non_fsdp_root_module(wrap=False)
        else:
            blank_local_model = self._initialize_model(
                wrap_fsdp=False, wrap_ddp=False, register_buffers=True
            )

        # Nothing should be FSDP
        for mod in blank_local_model.modules():
            self.assertFalse(isinstance(mod, FSDP))

        for param in blank_local_model.parameters():
            with torch.no_grad():
                param.zero_()

        fsdp_state_dict = _gather_state_dict(fsdp_state_dict)

        # Load fsdp's full state dict into the local and verify params are as
        # expected.
        if state_dict_rank0_and_offload:
            fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict)

        blank_local_model.load_state_dict(fsdp_state_dict, strict=True)
        local_params = list(blank_local_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize("double_nest", [True])
    def test_state_dict_skip_module(self, state_dict_type, double_nest):
        torch.cuda.set_device(self.rank)

        def _create_module(wrap_fsdp=True):
            LINEAR_SKIP = "linear_skip"
            ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else nullcontext()
            with ctx:
                module = SkipModel(double_nest=double_nest)
                # Full name of linear_skip param tensors in SkipModel, as would be
                # stored in checkpoint.
                linear_skip_tensor_names = [
                    k
                    for k in dict(module.named_parameters()).keys()
                    if LINEAR_SKIP in k
                ]
                # Temporarily detach linear_skip so that it is not managed by
                # the FSDP wrap below
                linear_skip = getattr(module, LINEAR_SKIP)
                delattr(module, LINEAR_SKIP)
                # Wrap FSDP
                fsdp = wrap(module)
                # reattach
                setattr(module, LINEAR_SKIP, linear_skip)
                return fsdp, linear_skip_tensor_names

        fsdp, linear_skip_tensor_names = _create_module()
        # Run a forward pass
        inp = torch.randn((1, 10), device=torch.cuda.current_device())
        loss = fsdp(inp)
        loss.sum().backward()

        with FSDP.state_dict_type(fsdp, STATE_DICT_MAPPING[state_dict_type]):
            state_dict = fsdp.state_dict()
        if self.rank == 0 and state_dict_type != "local_state_dict":
            sd_keys = list(state_dict.keys())
            expected = list(SkipModel(double_nest=False).state_dict().keys())
            self.assertEqual(sorted(sd_keys), sorted(expected))
            # TODO: parameters in linear_skip_tensor_names should not be handled
            # by FSDP.state_dict(). Have a check once this is implemented in
            # FSDP.state_dict().

        # Check that it can be loaded into FSDP.
        new_fsdp, _ = _create_module()
        _zero_model(new_fsdp)
        for p1, p2 in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertNotEqual(p1, p2)
        with FSDP.state_dict_type(new_fsdp, STATE_DICT_MAPPING[state_dict_type]):
            if state_dict_type != "local_state_dict":
                # FlatParameter has not supported deepcopy yet.
                state_dict = deepcopy(state_dict)
            new_fsdp.load_state_dict(state_dict, strict=True)
        for p1, p2 in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertEqual(p1, p2)

        # Test that the checkpoint can be loaded into a local model.
        local, _ = _create_module(wrap_fsdp=False)
        for param in local.parameters():
            with torch.no_grad():
                param.zero_()

        with fsdp.summon_full_params(fsdp):
            for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                self.assertNotEqual(p1, p2)

        if state_dict_type == "local_state_dict":
            return
        state_dict = _gather_state_dict(state_dict)
        with fsdp.summon_full_params(fsdp):
            if self.rank == 0:
                local.load_state_dict(state_dict, strict=True)
                for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                    self.assertEqual(p1, p2)

    @skip_if_lt_x_gpu(2)
    def test_wrong_state_dict_config(self):
        model = FSDP(Model(wrap_fsdp=True).cuda())
        with self.assertRaisesRegex(RuntimeError, "Expected state_dict_config of type"):
            with model.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, LocalStateDictConfig()
            ):
                pass

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize("prefix", [True, False])
    @parametrize("ignore_inner", [True, False])
    @parametrize("mixed_precision", [True, False])
    def test_state_dict_with_ignored_modules(
        self, state_dict_type, prefix, ignore_inner, mixed_precision
    ):
        # Initialize an FSDP-wrapped model with an ignored module that includes
        # both parameters and a buffer
        model = Model(
            wrap_fsdp=True,
            register_buffers=True,
            ignore_inner=ignore_inner,
            mixed_precision=mixed_precision,
        ).cuda()
        ignored_modules = [model.outer]
        ignored_tensor_to_tensor_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        if ignore_inner:
            ignored_tensor_to_tensor_name = {
                **ignored_tensor_to_tensor_name,
                model.inner.bias: "inner.bias",
                model.inner.weight: "inner.weight",
            }
        # Note that when model.inner is not ignored this test also ensures
        # non-ignored buffers are not cloned.
        buffer_to_buffer_name = {
            model.inner.buffer: "inner.buffer",
            model.outer.buffer: "outer.buffer",
        }
        # With mixed precision, model.inner.buffer is cast to fp16 while the
        # state dict entry is restored to fp32, so skip the equality checks for it.
        if mixed_precision and not ignore_inner:
            buffer_to_buffer_name.pop(model.inner.buffer)

        fsdp_model = FSDP(
            model,
            ignored_modules=ignored_modules,
            mixed_precision=MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )
            if mixed_precision
            else None,
        )
        prefix_str = "foo." if prefix else ""
        with FSDP.state_dict_type(fsdp_model, STATE_DICT_MAPPING[state_dict_type]):
            sd1 = _gather_state_dict(fsdp_model.state_dict(prefix=prefix_str))
        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters and all buffers are not cloned
        for tensor, tensor_name in {
            **ignored_tensor_to_tensor_name,
            **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd1)
            self.assertEqual(
                tensor.data_ptr(),
                sd1[prefixed_tensor_name].data_ptr(),
                f"{prefixed_tensor_name}",
            )
        # should not apply mixed_precision to ignored buffers
        for buffer_name in buffer_to_buffer_name.values():
            prefixed_buffer_name = f"{prefix_str}{buffer_name}"
            self.assertTrue(prefixed_buffer_name in sd1)
            self.assertEqual(sd1[prefixed_buffer_name].dtype, torch.float32)
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        to_load = {k[len(prefix_str) :]: v for k, v in sd1.items()}
        nonwrapped_model.load_state_dict(to_load, strict=True)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
        # Check that if we save a state dict again, the ignored parameters and
        # buffer still have the same data pointer
        with FSDP.state_dict_type(fsdp_model, STATE_DICT_MAPPING[state_dict_type]):
            sd2 = fsdp_model.state_dict(prefix=prefix_str)
        for tensor, tensor_name in {
            **ignored_tensor_to_tensor_name,
            **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd2)
            self.assertEqual(tensor.data_ptr(), sd2[prefixed_tensor_name].data_ptr())
            self.assertEqual(
                sd1[prefixed_tensor_name].data_ptr(),
                sd2[prefixed_tensor_name].data_ptr(),
            )

    @skip_if_lt_x_gpu(2)
    def test_state_dict_type(self):
        module = SkipModel(double_nest=True)
        with enable_wrap(wrapper_cls=FSDP):
            fsdp = wrap(module)
        with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
            pass
        for module in FSDP.fsdp_modules(fsdp):
            self.assertEqual(module._state_dict_type, StateDictType.FULL_STATE_DICT)

    @skip_if_lt_x_gpu(2)
    def test_local_state_dict_with_empty_ranks(self):
        class Model(Module):
            def __init__(self) -> None:
                super().__init__()
                self.my_tensor = torch.full((1,), 3.1415926)
                self.my_parameter = nn.Parameter(self.my_tensor)

            def forward(self, x):
                return self.my_parameter

        model = FSDP(Model().cuda())
        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            out = model(None)
            out.backward()

            state_dict = deepcopy(model.state_dict())
            with torch.no_grad():
                with FSDP.summon_full_params(model):
                    self.assertEqual(model.my_parameter.item(), 3.1415926)
                    model.my_parameter.copy_(torch.full((1,), 1.75).cuda())
                    self.assertEqual(model.my_parameter.item(), 1.75)
            model.load_state_dict(state_dict)
            with FSDP.summon_full_params(model):
                self.assertEqual(model.my_parameter.item(), 3.1415926)

    @skip_if_lt_x_gpu(2)
    def test_torch_save_load(self):
        model = Model(wrap_fsdp=True).cuda()
        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            state_dict = model.state_dict()
            checkpoint = io.BytesIO()
            torch.save(state_dict, checkpoint)
            checkpoint.seek(0)
            state_dict_saved = torch.load(checkpoint)
            for k, v in state_dict_saved.items():
                if isinstance(v, ShardedTensor):
                    self.assertEqual(
                        v._local_shards[0].tensor, state_dict[k]._local_shards[0].tensor
                    )
                else:
                    self.assertEqual(v, state_dict[k])

    @skip_if_lt_x_gpu(2)
    def test_shared_module_and_shared_parameter(self):
        model = FSDP(TestDummyModel().cuda())
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
            state_dict = model.state_dict()
            self.assertEqual(
                state_dict["random_parameter"], state_dict["shared_parameter"]
            )
            self.assertEqual(state_dict["net2.0.bias"], state_dict["net3.0.bias"])
            self.assertEqual(state_dict["net2.0.weight"], state_dict["net3.0.weight"])

    @skip_if_lt_x_gpu(2)
    def test_full_state_dict_missing_unexpected_keys_cleaned(self):
        model = self._get_simple_nested_model()
        sd = model.state_dict()
        # Create a missing key
        sd.pop(next(iter(sd.keys())))
        # Create an unexpected key
        sd["unexpected"] = torch.ones(1)
        missing, unexpected = model.load_state_dict(sd, strict=False)
        assert len(missing) == 1
        assert len(unexpected) == 1
        self.assertTrue(FSDP_PREFIX not in missing[0])
        self.assertTrue(FSDP_PREFIX not in unexpected[0])

    @skip_if_lt_x_gpu(2)
    def test_sharded_load_multi_backend_pg(self):
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        fsdp_kwargs = {
            "auto_wrap_policy": auto_wrap_policy,
            "use_orig_params": True,
        }
        for load_cpu in [True, False]:
            with self.subTest(load_cpu=load_cpu):
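                # A process group with both gloo (CPU) and NCCL (CUDA) backends,
                # so that the CPU-offloaded shards in the load_cpu=True case can
                # still be communicated.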
                pg = dist.new_group(backend="cpu:gloo,cuda:nccl")
                fsdp_model = TransformerWithSharedParams.init(
                    pg,
                    FSDPInitMode.RECURSIVE,
                    CUDAInitMode.CUDA_BEFORE,
                    fsdp_kwargs,
                )
                FSDP.set_state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT)
                sharded = fsdp_model.state_dict()
                param_copy = [t.clone().detach_() for t in fsdp_model.parameters()]
                with torch.no_grad():
                    for p in fsdp_model.parameters():
                        p.zero_()

                if load_cpu:
                    # Offload to CPU to simulate CPU state_dict load
                    for k, v in sharded.items():
                        sharded[k] = v.cpu()

                fsdp_model.load_state_dict(sharded)
                for p1, p2 in zip(param_copy, fsdp_model.parameters()):
                    self.assertEqual(p1, p2, f"not equal: {p1.sum()} vs {p2.sum()}")

    @skip_if_lt_x_gpu(2)
    def test_world_size_one(self):
        my_pg = None
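        # `dist.new_group` is a collective over the default group, so every rank
        # creates all of the single-rank groups and keeps only its own.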
        for i in range(self.world_size):
            pg = dist.new_group(ranks=[i])
            if i == self.rank:
                my_pg = pg

        model = TransformerWithSharedParams.init(
            my_pg,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
        )
        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            state_dict = model.state_dict()
            model.load_state_dict(state_dict)

        dist.barrier()


class TestFSDPStateDict4GPUs(FSDPTest):
    @property
    def world_size(self):
        return torch.cuda.device_count()

    @skip_if_lt_x_gpu(4)
    def test_local_state_dict_reshard(self):
        """
        This test demonstrates resharding when using local_state_dict. Although
        we do not recommend that users use local_state_dict, there are still
        some corner cases where local_state_dict is the better solution.
        """
        model = FSDP(Model(wrap_fsdp=True)).cuda()
        optim = torch.optim.SGD(model.parameters(), lr=0.1)

        batch = torch.randn(4, 4, device=torch.cuda.current_device())
        output = model(batch)
        loss = output.sum()
        loss.backward()
        optim.step()
        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            state_dict = model.state_dict()

        rank = dist.get_rank()
        new_pg = dist.new_group(ranks=[0, 1])
        resharded_state_dict = {}
        # Mimic resharding from 4 GPUs to 2 GPUs
        for key, value in state_dict.items():
            if isinstance(value, ShardedTensor):
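                # All-gather the 4-way sharded flat parameter into a full tensor
                # and re-chunk it two ways for the new 2-rank process group.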
                full_flat_param = _all_gather_sharded_tensor(value)
                if rank < 2:
                    full_numel = full_flat_param.size()
                    chunks = full_flat_param.chunk(2)
                    flat_param = chunks[rank]
                    shard_offset = 0 if rank == 0 else chunks[0].numel()
                    local_shards = [
                        Shard.from_tensor_and_offsets(flat_param, [shard_offset], rank)
                    ]
                    sharded_tensor = init_from_local_shards(
                        local_shards, full_numel, process_group=new_pg
                    )
                    resharded_state_dict[key] = sharded_tensor
            else:
                if rank < 2:
                    resharded_state_dict[key] = value

        if rank < 2:
            model2 = FSDP(
                Model(wrap_fsdp=True, process_group=new_pg), process_group=new_pg
            ).cuda()
            with FSDP.state_dict_type(model2, StateDictType.LOCAL_STATE_DICT):
                model2.load_state_dict(resharded_state_dict)

        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
            full_state_dict1 = model.state_dict()

        if rank < 2:
            with FSDP.state_dict_type(model2, StateDictType.FULL_STATE_DICT):
                full_state_dict2 = model2.state_dict()
            self.assertEqual(full_state_dict1, full_state_dict2)


instantiate_parametrized_tests(TestFSDPStateDict)

if __name__ == "__main__":
    run_tests()