xref: /aosp_15_r20/external/pytorch/test/distributed/checkpoint/test_state_dict.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2
3import copy
4import functools
5import sys
6from itertools import chain
7from typing import Callable, Tuple, Type, Union
8
9import torch
10import torch.distributed as dist
11import torch.nn as nn
12from torch.distributed._composable import fully_shard, replicate
13
14# importing fully_shard as FSDP2 since the original fully_shard is used in this test.
15# TODO: remove old composable fully_shard so that we don't have to import new fully_shard as FSDP2
16from torch.distributed._composable.fsdp import fully_shard as FSDP2
17from torch.distributed._shard.sharded_tensor import ShardedTensor
18from torch.distributed._tensor import DTensor, init_device_mesh
19from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
20    apply_activation_checkpointing,
21)
22from torch.distributed.checkpoint import state_dict as ptd_state_dict
23from torch.distributed.checkpoint.state_dict import (
24    _patch_model_state_dict,
25    _patch_optimizer_state_dict,
26    get_model_state_dict,
27    get_optimizer_state_dict,
28    get_state_dict,
29    set_model_state_dict,
30    set_optimizer_state_dict,
31    StateDictOptions,
32)
33from torch.distributed.fsdp import (
34    FullyShardedDataParallel as FSDP,
35    ShardingStrategy,
36    StateDictType,
37)
38from torch.distributed.fsdp.wrap import ModuleWrapPolicy
39from torch.distributed.optim import _apply_optimizer_in_backward
40from torch.nn.parallel import DistributedDataParallel as DDP
41from torch.optim import Optimizer
42from torch.testing._internal.common_dist_composable import (
43    CompositeParamModel,
44    UnitModule,
45)
46from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
47from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
48from torch.testing._internal.distributed._tensor.common_dtensor import (
49    DTensorTestBase,
50    MultiProcessTestCase,
51    with_comms,
52)
53from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin
54from torch.utils._pytree import tree_all, tree_all_only
55
56
57if not dist.is_available():
58    print("Distributed not available, skipping tests", file=sys.stderr)
59    sys.exit(0)
60
61if TEST_WITH_DEV_DBG_ASAN:
62    print(
63        "Skip dev-asan as torch + multiprocessing spawn have known issues",
64        file=sys.stderr,
65    )
66    sys.exit(0)
67
68
69class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
70    """Tests state_dict and load_state_dict"""
71
72    @property
73    def world_size(self) -> int:
74        return min(4, torch.cuda.device_count())
75
    def _test_save_load(
        self,
        init_model_optim: Callable,
        test_frozen: bool = False,
    ) -> None:
        """Train a reference model and its distributed counterpart in lockstep,
        then verify that ``get_state_dict``/``set_state_dict`` round-trips both
        the model and optimizer state.

        Args:
            init_model_optim: Factory returning
                ``(model, optim, copy_optim, dist_model, dist_optim)`` — the
                first three are the non-distributed reference, the last two are
                the distributed counterparts. ``dist_optim`` may be a list (the
                optimizer-in-backward case).
            test_frozen: If True, enable ``ignore_frozen_params`` and skip the
                load-back half of the test, since a partial state_dict cannot
                be fully restored.
        """
        options = StateDictOptions(ignore_frozen_params=test_frozen)
        # Initialize original model and distributed model.
        model, optim, copy_optim, dist_model, dist_optim = init_model_optim()

        # Train 10 steps.
        for i in range(10):
            batch = torch.rand(8, 100, device="cuda")
            model(batch).sum().backward()
            optim.step()
            dist_model(batch).sum().backward()
            if not isinstance(dist_optim, list):
                dist_optim.step()
                dist_optim.zero_grad()
            else:
                # A list of optimizers is the in-backward case; presumably the
                # step already happened during backward, so only clear grads.
                for _dist_optim in dist_optim:
                    _dist_optim.zero_grad()
            optim.zero_grad()

        # Get the state_dict, and compare the result
        msd = model.state_dict()
        osd = optim.state_dict()
        dist_msd, dist_osd = get_state_dict(
            dist_model, optimizers=dist_optim, options=options
        )
        self._verify_msd(msd, dist_msd, options)
        self._verify_osd_by_load(model, optim, copy_optim, dist_osd)
        self._verify_osd(model, optim, osd, dist_osd)

        # Initialize a completely new model to simulate checkpoint load.
        _, _, _, dist_model, dist_optim = init_model_optim()

        # Simulate DCP distributed load. We need to first get the state_dict and
        # pass them to DCP to load the saved state_dict from the storage.
        # Then finally we can call set_state_dict().
        if not isinstance(dist_optim, list):
            dist_optim = [dist_optim]
        if test_frozen:
            # We won't be able to load the partial state_dict back.
            return
        # Since we already have the state_dict saved before, no need to call DCP.
        # We can directly load them back. This assert is to ensure that optimizer
        # state storage are initialized.
        # self.assertEqual(len(curr_dist_osd[STATE]), len(dist_osd[STATE]))
        set_model_state_dict(
            dist_model,
            model_state_dict=dist_msd,
            options=options,
        )
        set_optimizer_state_dict(
            dist_model,
            optimizers=dist_optim,
            optim_state_dict=dist_osd,
            options=options,
        )

        # Check if the new state_dict are the same
        dist_msd, dist_osd = get_state_dict(
            dist_model, optimizers=dist_optim, options=options
        )
        self._verify_msd(msd, dist_msd, options)
        # TODO: Ditto
        # self._verify_osd_by_load(model, optim, copy_optim, dist_osd)
        self._verify_osd(model, optim, osd, dist_osd)

        # Test _patch_model_state_dict, and _patch_optimizer_state_dict
        _patch_model_state_dict(dist_model, options=options)
        _patch_optimizer_state_dict(dist_model, optimizers=dist_optim, options=options)
        dist_msd = dist_model.state_dict()
        dist_osd = dist_optim[0].state_dict()
        self._verify_msd(msd, dist_msd, options)
        self._verify_osd_by_load(model, optim, copy_optim, dist_osd)
        self._verify_osd(model, optim, osd, dist_osd)
153
154    def _test_fsdp(
155        self,
156        *,
157        use_orig_params: bool,
158        use_composable: bool,
159        use_dtensor: bool,
160        wrapping: Tuple[nn.Module] = (),
161        compile_model: bool = False,
162        optimizer_class: Type[Optimizer],
163    ) -> None:
164        if not use_orig_params and use_composable:
165            return
166
167        # TODO: remove this return after we complete the composable API side change for device_mesh
168        if use_composable and use_dtensor:
169            return
170
171        def init_model_optim():
172            if use_dtensor:
173                device_mesh = init_device_mesh("cuda", (self.world_size,))
174
175            orig_model = CompositeParamModel(device=torch.device("cuda"))
176            orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
177            copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
178            if wrapping:
179                strategy = set(wrapping)
180            else:
181                strategy = {UnitModule}
182            if use_composable:
183                dist_model = fully_shard(
184                    copy.deepcopy(orig_model), policy=ModuleWrapPolicy(strategy)
185                )
186            else:
187                if use_dtensor:
188                    device_mesh = init_device_mesh("cuda", (self.world_size,))
189                    dist_model = FSDP(
190                        copy.deepcopy(orig_model),
191                        auto_wrap_policy=ModuleWrapPolicy(strategy),
192                        use_orig_params=use_orig_params,
193                        device_mesh=device_mesh,
194                    )
195                else:
196                    dist_model = FSDP(
197                        copy.deepcopy(orig_model),
198                        auto_wrap_policy=ModuleWrapPolicy(strategy),
199                        use_orig_params=use_orig_params,
200                    )
201
202            if compile_model:
203                dist_model = torch.compile(dist_model)
204            dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3)
205            return orig_model, orig_optim, copy_optim, dist_model, dist_optim
206
207        self._test_save_load(init_model_optim)
208
209    @with_comms
210    @skip_if_lt_x_gpu(2)
211    def test_fsdp(self) -> None:
212        self.run_subtests(
213            {
214                "use_orig_params": [True, False],
215                "use_composable": [True, False],
216                "use_dtensor": [True, False],
217                "wrapping": [tuple(), (nn.Linear, UnitModule)],
218                "optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
219            },
220            self._test_fsdp,
221        )
222
223    @with_comms
224    @skip_if_lt_x_gpu(2)
225    def test_compiled_fsdp(self) -> None:
226        self.run_subtests(
227            {
228                "use_orig_params": [True],
229                "use_composable": [False],
230                "use_dtensor": [False],
231                "wrapping": [tuple()],
232                "optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
233            },
234            self._test_fsdp,
235        )
236
237    def _test_fsdp2(
238        self,
239        *,
240        reshard_after_forward: Union[bool, int],
241        optimizer_class: Type[Optimizer],
242        compile_model: bool,
243        foreach: bool = True,
244    ):
245        def init_model_optim():
246            orig_model = CompositeParamModel(device=torch.device("cuda"))
247            orig_optim = optimizer_class(
248                orig_model.parameters(), lr=1e-3, foreach=foreach
249            )
250            copy_optim = optimizer_class(
251                orig_model.parameters(), lr=1e-3, foreach=foreach
252            )
253
254            dist_model = FSDP2(
255                copy.deepcopy(orig_model),
256                reshard_after_forward=reshard_after_forward,
257            )
258
259            if compile_model:
260                dist_model = torch.compile(dist_model)
261            dist_optim = optimizer_class(
262                dist_model.parameters(), lr=1e-3, foreach=foreach
263            )
264
265            return orig_model, orig_optim, copy_optim, dist_model, dist_optim
266
267        self._test_save_load(init_model_optim)
268
269    @with_comms
270    @skip_if_lt_x_gpu(2)
271    def test_fsdp2(self) -> None:
272        self.run_subtests(
273            {
274                "reshard_after_forward": [True, False],
275                "optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
276                "compile_model": [True, False],
277            },
278            self._test_fsdp2,
279        )
280
281    def _test_ddp(self, use_composable: bool, optimizer_class: Type[Optimizer]) -> None:
282        def init_model_optim():
283            orig_model = CompositeParamModel(device=torch.device("cuda"))
284            orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
285            copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
286            if use_composable:
287                dist_model = replicate(copy.deepcopy(orig_model))
288            else:
289                dist_model = DDP(copy.deepcopy(orig_model))
290            dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3)
291            return orig_model, orig_optim, copy_optim, dist_model, dist_optim
292
293        self._test_save_load(init_model_optim)
294
295    @with_comms
296    @skip_if_lt_x_gpu(2)
297    def test_ddp(self) -> None:
298        self.run_subtests(
299            {
300                "use_composable": [True, False],
301                "optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
302            },
303            self._test_ddp,
304        )
305
    def _test_fsdp_ddp(
        self,
        use_composable: bool,
        optimizer_class: Type[Optimizer],
        optim_in_backward: bool = False,
        test_frozen: bool = False,
    ) -> None:
        """Verify state_dict save/load for a hybrid model: submodule ``l`` is
        wrapped with DDP/``replicate`` and the rest with FSDP/``fully_shard``.

        Args:
            use_composable: Use the composable ``replicate``/``fully_shard``
                APIs instead of the ``DDP``/``FSDP`` wrapper classes.
            optimizer_class: Optimizer to instantiate for both models.
            optim_in_backward: Apply the optimizer in backward; ``dist_optim``
                is then the list of per-parameter in-backward optimizers.
            test_frozen: Freeze ``u1``/``u2`` of the reference model to
                exercise ``ignore_frozen_params`` in ``_test_save_load``.
        """

        def init_model_optim():
            orig_model = CompositeParamModel(device=torch.device("cuda"))
            if test_frozen:
                # Freeze the two UnitModule submodules of the reference model.
                for param in chain(
                    orig_model.u1.parameters(), orig_model.u2.parameters()
                ):
                    param.requires_grad = False
            orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
            copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
            dist_model = copy.deepcopy(orig_model)
            if use_composable:
                replicate(dist_model.l)
                fully_shard(dist_model, policy=ModuleWrapPolicy({UnitModule}))
            else:
                dist_model.l = DDP(dist_model.l)
                # FSDP must not manage the DDP-wrapped submodule.
                dist_model = FSDP(
                    copy.deepcopy(orig_model),
                    auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
                    use_orig_params=optim_in_backward,
                    ignored_modules=[dist_model.l],
                )
            if optim_in_backward:
                _apply_optimizer_in_backward(
                    optimizer_class, dist_model.parameters(), {"lr": 1e-3}
                )
                # Collect the per-parameter optimizers created above.
                dist_optim = [
                    p._in_backward_optimizers[0] for p in dist_model.parameters()
                ]
            else:
                dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3)
            return orig_model, orig_optim, copy_optim, dist_model, dist_optim

        self._test_save_load(init_model_optim, test_frozen)
346
347    @with_comms
348    @skip_if_lt_x_gpu(2)
349    def test_fsdp_ddp(self) -> None:
350        self.run_subtests(
351            {
352                "use_composable": [True, False],
353                "optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
354            },
355            self._test_fsdp_ddp,
356        )
357
358    @with_comms
359    @skip_if_lt_x_gpu(2)
360    def test_frozen_parameters(self) -> None:
361        self.run_subtests(
362            {
363                "use_composable": [True],
364                "optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
365                "test_frozen": [True],
366            },
367            self._test_fsdp_ddp,
368        )
369
370    # TODO: enable use_dtensor once 2D device_mesh support is fully landed.
371    """
372    @with_comms
373    @skip_if_lt_x_gpu(2)
374    def test_use_dtensor(self) -> None:
375        self._test_fsdp_ddp(use_composable=False, use_dtensor=True)
376    """
377
378    # TODO: enable the test after FSDP + apply_optimizer_in_backward works.
379    # Disable this test as it is broken after
380    # https://github.com/pytorch/pytorch/pull/108298.
381    """
382    @with_comms
383    @skip_if_lt_x_gpu(2)
384    def test_apply_optimizer_in_backward(self) -> None:
385        self.run_subtests(
386            {"use_composable": [True, False]},
387            self._test_fsdp_ddp,
388            optim_in_backward=True,
389        )
390    """
391
392    def _test_single_gpu(self, optimizer_class: Type[Optimizer]) -> None:
393        def init_model_optim():
394            orig_model = CompositeParamModel(device=torch.device("cuda"))
395            orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
396            copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
397            model_copy = copy.deepcopy(orig_model)
398            optim_copy = optimizer_class(model_copy.parameters(), lr=1e-3)
399            return orig_model, orig_optim, copy_optim, model_copy, optim_copy
400
401        self._test_save_load(init_model_optim)
402
403    @with_comms
404    @skip_if_lt_x_gpu(1)
405    def test_single_gpu(self) -> None:
406        self.run_subtests(
407            {"optimizer_class": [torch.optim.Adam, torch.optim.AdamW]},
408            self._test_single_gpu,
409        )
410
411    @with_comms
412    @skip_if_lt_x_gpu(1)
413    def test_strict(self) -> None:
414        model = CompositeParamModel(device=torch.device("cuda"))
415
416        model_state_dict = get_model_state_dict(model)
417        key = next(iter(model_state_dict.keys()))
418        model_state_dict["abc"] = torch.zeros(10)
419        with self.assertRaisesRegex(RuntimeError, "Unexpected key"):
420            set_model_state_dict(model, model_state_dict=model_state_dict)
421        model_state_dict.pop(key)
422        incompatible_keys = set_model_state_dict(
423            model,
424            model_state_dict=model_state_dict,
425            options=StateDictOptions(strict=False),
426        )
427        self.assertEqual(incompatible_keys.missing_keys, [key])
428        self.assertEqual(incompatible_keys.unexpected_keys, ["abc"])
429        model_state_dict.pop("abc")
430        with self.assertRaisesRegex(RuntimeError, "Missing key"):
431            set_model_state_dict(model, model_state_dict=model_state_dict)
432
    def _test_cpu_offload_full_state_dict(
        self, optimizer_class: Type[Optimizer]
    ) -> None:
        """Verify ``cpu_offload`` and ``full_state_dict`` StateDictOptions on
        an FSDP-wrapped model.

        Three combinations are checked:
        1. ``cpu_offload=True``: every tensor leaf lives on CPU.
        2. ``full_state_dict=True``: no DTensor/ShardedTensor leaves remain.
        3. Both: rank 0 gets a full CPU state_dict, other ranks get ``{}``.
        """
        orig_model = CompositeParamModel(device=torch.device("cuda"))
        device_mesh = init_device_mesh("cuda", (self.world_size,))
        dist_model = FSDP(
            copy.deepcopy(orig_model),
            auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
            use_orig_params=True,
            device_mesh=device_mesh,
        )

        dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3)

        mst, ost = get_state_dict(
            dist_model,
            dist_optim,
            options=StateDictOptions(cpu_offload=True),
        )

        cpu_device = torch.device("cpu")

        def is_cpu(v):
            # Device check that also understands distributed tensor types.
            if isinstance(v, DTensor):
                return v.device == cpu_device
            elif isinstance(v, ShardedTensor):
                shards = v.local_shards()
                if not shards:
                    # A rank may own no shards; nothing to check then.
                    return True
                return shards[0].tensor.device == cpu_device
            else:
                return v.device == cpu_device

        self.assertTrue(
            tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, mst)
        )
        self.assertTrue(
            tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, ost)
        )

        mst, ost = get_state_dict(
            dist_model, dist_optim, options=StateDictOptions(full_state_dict=True)
        )

        self.assertTrue(
            tree_all(lambda v: not isinstance(v, (DTensor, ShardedTensor)), mst)
        )
        self.assertTrue(
            tree_all(lambda v: not isinstance(v, (DTensor, ShardedTensor)), ost)
        )

        mst, ost = get_state_dict(
            dist_model,
            dist_optim,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )

        # With both options set, only rank 0 materializes the state_dict.
        if self.rank == 0:
            self.assertTrue(
                tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, mst)
            )
            self.assertTrue(
                tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, ost)
            )
        else:
            self.assertEqual(mst, {})
            self.assertEqual(ost, {})
500
501    @with_comms
502    @skip_if_lt_x_gpu(2)
503    def test_cpu_offload_full_state_dict(self) -> None:
504        self.run_subtests(
505            {"optimizer_class": [torch.optim.Adam, torch.optim.AdamW]},
506            self._test_cpu_offload_full_state_dict,
507        )
508
509    @with_comms
510    @skip_if_lt_x_gpu(1)
511    def test_activation_ckpt_fqns_ddp(self) -> None:
512        """Tests that activation checkpointing prefixes are removed from module names"""
513        model = CompositeParamModel(device=torch.device("cuda"))
514        original_keys = get_model_state_dict(model).keys()
515
516        apply_activation_checkpointing(model)
517        model = DDP(model)
518        new_keys = get_model_state_dict(model).keys()
519
520        self.assertEqual(original_keys, new_keys)
521
522    @with_comms
523    @skip_if_lt_x_gpu(1)
524    def test_activation_ckpt_fqns_fsdp1(self) -> None:
525        self.run_subtests(
526            {"use_orig_params": [True, False]},
527            self._test_activation_ckpt_fqns_fsdp1,
528        )
529
530    def _test_activation_ckpt_fqns_fsdp1(self, use_orig_params: bool) -> None:
531        """Tests that activation checkpointing prefixes are removed from module names"""
532        model = CompositeParamModel(device=torch.device("cuda"))
533        original_keys = get_model_state_dict(model).keys()
534
535        apply_activation_checkpointing(model)
536        model = FSDP(model, use_orig_params=use_orig_params)
537        new_keys = get_model_state_dict(model).keys()
538
539        self.assertEqual(original_keys, new_keys)
540
541    @with_comms
542    @skip_if_lt_x_gpu(1)
543    def test_extra_state(self) -> None:
544        model = CompositeParamModel(device=torch.device("cuda"))
545
546        def get_extra_state(self):
547            return "MyState"
548
549        def set_extra_state(self, state):
550            return
551
552        UnitModule.get_extra_state = get_extra_state
553        UnitModule.set_extra_state = set_extra_state
554
555        ddp_model = DDP(copy.deepcopy(model))
556        set_model_state_dict(ddp_model, get_model_state_dict(ddp_model))
557        self.assertEqual(model.state_dict()["u1._extra_state"], "MyState")
558        self.assertEqual(model.state_dict(), get_model_state_dict(ddp_model))
559
560    @with_comms
561    @skip_if_lt_x_gpu(1)
562    def test_non_persistent_buffers(self) -> None:
563        model = CompositeParamModel(device=torch.device("cuda"))
564        model.register_buffer(
565            "dont_save_me", torch.rand(100, device="cuda"), persistent=False
566        )
567        ddp_model = DDP(copy.deepcopy(model))
568        set_model_state_dict(ddp_model, get_model_state_dict(ddp_model))
569        self.assertEqual(model.state_dict(), get_model_state_dict(ddp_model))
570
    def _test_broadcast_from_rank0(self, wrapper) -> None:
        """Verify ``broadcast_from_rank0``: only rank 0 supplies a full
        state_dict and the other ranks receive their shards via broadcast.

        Args:
            wrapper: Callable that wraps the reference model in a distributed
                wrapper (e.g. FSDP/FSDP2 partially applied with a device mesh).
        """
        model = CompositeParamModel(device=torch.device("cuda"))
        optim = torch.optim.Adam(model.parameters())
        fsdp_model = wrapper(copy.deepcopy(model))
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters())

        batch = torch.rand(8, 100, device="cuda")
        model(batch).sum().backward()
        optim.step()
        states, optim_states = get_state_dict(model, optim)

        fsdp_model(batch).sum().backward()
        fsdp_optim.step()

        def check(equal):
            # Compare the reference state_dicts against the FSDP full ones.
            fsdp_states = get_model_state_dict(
                fsdp_model,
                options=StateDictOptions(full_state_dict=True),
            )
            fsdp_optim_states = get_optimizer_state_dict(
                fsdp_model,
                fsdp_optim,
                options=StateDictOptions(full_state_dict=True),
            )
            if equal:
                self.assertEqual(states, fsdp_states)
                self.assertEqual(optim_states, fsdp_optim_states)
            else:
                self.assertNotEqual(states, fsdp_states)
                self.assertNotEqual(optim_states, fsdp_optim_states)

        check(equal=True)
        # One extra step on the FSDP model makes the two models diverge.
        fsdp_model(batch).sum().backward()
        fsdp_optim.step()
        check(equal=False)

        # Drop the states to simulate loading from rank0
        if dist.get_rank() > 0:
            load_states = {}
            load_states2 = {}
            load_optim_states = {}
        else:
            load_states = copy.deepcopy(states)
            load_states2 = copy.deepcopy(states)
            load_optim_states = copy.deepcopy(optim_states)

        set_model_state_dict(
            fsdp_model,
            model_state_dict=load_states,
            options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
        )
        set_optimizer_state_dict(
            fsdp_model,
            fsdp_optim,
            optim_state_dict=load_optim_states,
            options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
        )

        check(equal=True)
        # Verify the `strict` flag.
        load_states = load_states2
        if load_states:
            # Only rank 0 has a non-empty state_dict to drop a key from.
            key = next(iter(load_states.keys()))
            load_states.pop(key)
        with self.assertRaisesRegex(RuntimeError, "Missing key"):
            set_model_state_dict(
                fsdp_model,
                model_state_dict=load_states,
                options=StateDictOptions(
                    broadcast_from_rank0=True, full_state_dict=True
                ),
            )
643
644    @with_comms
645    @skip_if_lt_x_gpu(2)
646    def test_broadcast_from_rank0(self) -> None:
647        device_mesh = init_device_mesh("cuda", (self.world_size,))
648        self.run_subtests(
649            {
650                "wrapper": [
651                    functools.partial(FSDP2, mesh=device_mesh),
652                    functools.partial(FSDP, device_mesh=device_mesh),
653                ]
654            },
655            self._test_broadcast_from_rank0,
656        )
657
658    @with_comms
659    @skip_if_lt_x_gpu(4)
660    def test_broadcast_from_rank0_hsdp(self) -> None:
661        device_mesh = init_device_mesh("cuda", (2, self.world_size // 2))
662        self.run_subtests(
663            {
664                "wrapper": [
665                    functools.partial(
666                        FSDP,
667                        device_mesh=device_mesh,
668                        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
669                    ),
670                ]
671            },
672            self._test_broadcast_from_rank0,
673        )
674
675    @with_comms
676    @skip_if_lt_x_gpu(2)
677    def test_fsdp_root_not_initialized(self) -> None:
678        # This test verifies that FSDP root is not initialized but we should
679        # still be able to  get the state_dict without errors because
680        # fsdp_model.state_dict() will trigger the FSDP initialization.
681        device_mesh = init_device_mesh("cuda", (self.world_size,))
682        model = CompositeParamModel(device=torch.device("cuda"))
683        fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh)
684        fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
685        get_model_state_dict(fsdp_model)
686        get_optimizer_state_dict(fsdp_model, fsdp_optim)
687
    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_optim_state_dict_param_matching(self) -> None:
        """Param-group entries present only in the saved optimizer state_dict
        (here ``initial_lr``) must be restored into a freshly created optimizer
        by ``set_optimizer_state_dict``."""
        # This test verifies parameters between optim and optim_state_dict
        # "initial_lr" is added to optim_state_dict, but not to the new optim
        # We test whether "initial_lr" appear in optim after
        # set_optimizer_state_dict.
        device = "cuda"
        torch.manual_seed(0)
        model = nn.Sequential(
            *[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)]
        )
        for layer in model:
            fully_shard(layer)
        fully_shard(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        # Attaching an LR scheduler is what inserts "initial_lr" into the
        # optimizer's param_groups.
        torch.optim.lr_scheduler.LambdaLR(
            optim, lr_lambda=[lambda epoch: 0.95**epoch]
        )
        opt_state_dict = ptd_state_dict.get_optimizer_state_dict(
            model,
            optim,
            options=ptd_state_dict.StateDictOptions(
                full_state_dict=True, cpu_offload=True
            ),
        )
        # With full_state_dict, only rank 0 holds the materialized state.
        if dist.get_rank() == 0:
            self.assertTrue("initial_lr" in opt_state_dict["param_groups"][0])

        # A brand-new optimizer without a scheduler lacks "initial_lr".
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        self.assertTrue("initial_lr" not in optim.param_groups[0])

        ptd_state_dict.set_optimizer_state_dict(
            model,
            optim,
            optim_state_dict=opt_state_dict,
            options=ptd_state_dict.StateDictOptions(
                broadcast_from_rank0=True, full_state_dict=True
            ),
        )
        if dist.get_rank() == 0:
            self.assertTrue("initial_lr" in optim.param_groups[0])
730
731    @with_comms
732    @skip_if_lt_x_gpu(2)
733    def test_flattened_osd(self) -> None:
734        device_mesh = init_device_mesh("cuda", (self.world_size,))
735        model = CompositeParamModel(device=torch.device("cuda"))
736        fsdp_model = FSDP2(copy.deepcopy(model), mesh=device_mesh)
737        fsdp_optim = torch.optim.AdamW(fsdp_model.parameters())
738        batch = torch.rand(8, 100, device="cuda")
739        fsdp_model(batch).sum().backward()
740        fsdp_optim.step()
741        fsdp_optim.zero_grad()
742        osd1 = get_optimizer_state_dict(fsdp_model, fsdp_optim)
743        osd2 = get_optimizer_state_dict(
744            fsdp_model,
745            fsdp_optim,
746            options=StateDictOptions(flatten_optimizer_state_dict=True),
747        )
748        fsdp_optim2 = torch.optim.AdamW(fsdp_model.parameters())
749        set_optimizer_state_dict(
750            fsdp_model, optimizers=fsdp_optim2, optim_state_dict=osd2
751        )
752        self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict())
753        set_optimizer_state_dict(
754            fsdp_model, optimizers=fsdp_optim2, optim_state_dict=osd1
755        )
756        self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict())
757
    @with_comms
    @skip_if_lt_x_gpu(1)
    def test_deprecate_partial(self) -> None:
        """The submodule-only forms of get/set_model_state_dict must emit
        FutureWarning deprecation notices while still returning/loading the
        correct subset of state."""
        model = CompositeParamModel(device=torch.device("cuda"))

        model_state_dict1 = get_model_state_dict(model)
        model_state_dict1 = copy.deepcopy(model_state_dict1)
        # Submodule-only fetch (prefixed keys) must warn.
        with self.assertWarnsRegex(
            FutureWarning,
            "Getting submodules only model/optim state_dict is deprecated",
        ):
            model_state_dict2 = get_model_state_dict(model, submodules={model.l})
        model_state_dict2 = copy.deepcopy(model_state_dict2)
        # Same fetch without the submodule prefix must also warn.
        with self.assertWarnsRegex(
            FutureWarning,
            "Getting submodules only model/optim state_dict is deprecated",
        ):
            model_state_dict3 = get_model_state_dict(
                model,
                submodules={model.l},
                options=StateDictOptions(keep_submodule_prefixes=False),
            )
        model_state_dict3 = copy.deepcopy(model_state_dict3)
        self.assertEqual(len(model_state_dict2), 2)
        self.assertEqual(len(model_state_dict3), 2)
        # The three fetches must agree on the values of l's parameters.
        for key in model_state_dict3.keys():
            full_fqn = f"l.{key}"
            value1 = model_state_dict1[full_fqn]
            value2 = model_state_dict2[full_fqn]
            value3 = model_state_dict3[key]
            self.assertEqual(value1, value2)
            self.assertEqual(value2, value3)

        # Zero the model, then load the partial (prefixed) state_dict back
        # and confirm only l is restored.
        zeros_state_dict = {
            k: torch.zeros_like(v) for k, v in model_state_dict1.items()
        }
        model.load_state_dict(zeros_state_dict)
        set_model_state_dict(
            model,
            model_state_dict=model_state_dict2,
            options=StateDictOptions(strict=False),
        )
        self.assertEqual(model.l.weight, model_state_dict1["l.weight"])
        self.assertEqual(model.l.bias, model_state_dict1["l.bias"])

        # Loading with a module-keyed dict (deprecated form) must warn but
        # still restore l's parameters.
        model.load_state_dict(zeros_state_dict)
        with self.assertWarnsRegex(FutureWarning, "Passing model_state_dict as a "):
            set_model_state_dict(
                model,
                model_state_dict={model.l: model_state_dict3},
                options=StateDictOptions(strict=False),
            )
        self.assertEqual(model.l.weight, model_state_dict1["l.weight"])
        self.assertEqual(model.l.bias, model_state_dict1["l.bias"])
812
813    @with_comms
814    @skip_if_lt_x_gpu(1)
815    def test_deprecate_fsdp_api(self) -> None:
816        device_mesh = init_device_mesh("cuda", (self.world_size,))
817        model = CompositeParamModel(device=torch.device("cuda"))
818        fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh)
819        with self.assertWarnsRegex(
820            FutureWarning,
821            r"FSDP.state_dict_type\(\) and FSDP.set_state_dict_type\(\) are being deprecated",
822        ):
823            with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
824                fsdp_model.state_dict()
825
826        with self.assertRaisesRegex(AssertionError, "FutureWarning not triggered"):
827            with self.assertWarnsRegex(
828                FutureWarning,
829                r"FSDP.state_dict_type\(\) and FSDP.set_state_dict_type\(\) are being deprecated",
830            ):
831                get_model_state_dict(model)
832
833    @with_comms
834    @skip_if_lt_x_gpu(2)
835    def test_shared_weight(self):
836        class TiedEmbeddingModel(nn.Module):
837            def __init__(self, vocab_size, embedding_dim):
838                super().__init__()
839                self.embedding = nn.Embedding(vocab_size, embedding_dim)
840                self.decoder = nn.Linear(embedding_dim, vocab_size)
841                self.decoder.weight = self.embedding.weight  # Tying weights
842
843            def forward(self, input):
844                input = (input * 10).to(torch.int)
845                embedded = self.embedding(input)
846                output = self.decoder(embedded)
847                return output
848
849        def init_model_optim():
850            device_mesh = init_device_mesh("cuda", (self.world_size,))
851            orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda"))
852            orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
853            copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
854            dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh)
855            dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3)
856            return orig_model, orig_optim, copy_optim, dist_model, dist_optim
857
858        self._test_save_load(init_model_optim)
859
860
class TestNoComm(MultiProcessTestCase):
    """Verify the state-dict APIs work when torch.distributed is never
    initialized (single-process fallback path)."""

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @skip_if_lt_x_gpu(1)
    def test_no_dist(self) -> None:
        model = CompositeParamModel(device=torch.device("cuda"))
        optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

        # Precondition: no process group exists in this test.
        self.assertFalse(dist.is_initialized())

        offload_opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
        msd = get_model_state_dict(model, options=offload_opts)
        # cpu_offload must have moved every tensor off the GPU.
        for tensor in msd.values():
            self.assertFalse(tensor.is_cuda)
        self.assertEqual(model.state_dict(), msd)
        set_model_state_dict(model, model.state_dict())

        osd = get_optimizer_state_dict(model, optim, options=offload_opts)
        set_optimizer_state_dict(model, optim, osd)
        set_optimizer_state_dict(model, optim, optim.state_dict())
886
887
# Entry point: run every test in this module under the internal test runner.
if __name__ == "__main__":
    run_tests()
890