# Owner(s): ["oncall: distributed"]

import contextlib
import copy

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    StateDictOptions,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp.wrap import always_wrap_policy
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
from torch.utils._pytree import tree_all_only


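# These tests exercise parity and interop between FSDP2 (`fully_shard`) and the
# distributed checkpoint (DCP) state-dict APIs: comparing `model.state_dict()`
# with `get_model_state_dict`, and round-tripping checkpoints between FSDP1,
# FSDP2, TP, and FSDP2 + TP models.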
class TestFullyShardWithDistributedStateDict(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    def _get_base_model(self, mlp_dim: int = 2):
        base_model = nn.Sequential(
            MLP(mlp_dim),
            nn.Sequential(MLP(mlp_dim), nn.Linear(mlp_dim, mlp_dim)),
            MLP(mlp_dim),
        )
        return base_model

    @skip_if_lt_x_gpu(2)
    def test_1d_fsdp_get_model_state_dict(self):
        self.run_subtests(
            {"mlp_dim": [2, 3, 4, 5]},
            self._test_1d_fsdp_get_model_state_dict,
        )

    def _test_1d_fsdp_get_model_state_dict(self, mlp_dim: int):
        """
        Test model.state_dict() and distributed_state_dict parity.
        """
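        # `fully_shard` keeps each parameter as a DTensor sharded across the
        # data-parallel mesh, so `model.state_dict()` and `get_model_state_dict`
        # are expected to return matching sharded values.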
        base_model = self._get_base_model(mlp_dim)
        # Default is `reshard_after_forward=True`
        model1 = copy.deepcopy(base_model)
        for module in model1:
            fully_shard(module)
        fully_shard(model1)

        # osd: original state dict, dsd: distributed state dict
        osd = model1.state_dict()
        dsd = get_model_state_dict(model1)
        self.assertEqual(osd, dsd)

        # Check `reshard_after_forward=False` after a forward
        model2 = copy.deepcopy(base_model)
        for module in model2:
            fully_shard(module, reshard_after_forward=False)
        fully_shard(model2, reshard_after_forward=False)
        inp = torch.randn((2, mlp_dim), device="cuda")
        model2(inp)  # parameters are not resharded after this forward
        # Check that state dict hooks reshard
        osd_2 = model2.state_dict()
        dsd_2 = get_model_state_dict(model2)
        self.assertEqual(osd_2, dsd_2)

    @skip_if_lt_x_gpu(2)
    def test_1d_fsdp_cpu_offload_full_model_state_dict(self):
        """
        Test that full_state_dict and cpu_offload work for FSDP2 state_dict.
        """
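        # With `full_state_dict=True` and `cpu_offload=True`, the sharded
        # DTensors are gathered into full tensors and offloaded to CPU; the
        # assertions below check that only rank 0 receives the populated dict
        # while every other rank gets an empty one.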
        orig_model = self._get_base_model()
        fsdp_model = copy.deepcopy(orig_model)
        for module in fsdp_model:
            fully_shard(module)
        fully_shard(fsdp_model)

        osd = orig_model.state_dict()
        dsd = get_model_state_dict(
            fsdp_model, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
        )

        cpu_device = torch.device("cpu")

        def is_cpu(v):
            if isinstance(v, DTensor):
                return v.device == torch.device("cpu")
            else:
                return v.device == cpu_device

        if self.rank == 0:
            self.assertEqual(osd, dsd)
            self.assertTrue(tree_all_only((torch.Tensor, DTensor), is_cpu, osd))
        else:
            self.assertEqual(dsd, {})

    @skip_if_lt_x_gpu(2)
    def test_save_with_fsdp1_and_load_with_fsdp2(self):
        self.run_subtests(
            {
                "state_dict_type": [
                    StateDictType.FULL_STATE_DICT,
                    StateDictType.SHARDED_STATE_DICT,
                ]
            },
            self._test_save_with_fsdp1_and_load_with_fsdp2,
        )

    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def _test_save_with_fsdp1_and_load_with_fsdp2(self, state_dict_type: StateDictType):
        """
        Test that we can save a model with FSDP1 and load it with FSDP2.
        """
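        # Flow: run one optimizer step under FSDP1, save its state dict (full
        # or sharded, per the subtest) with dcp.save, then dcp.load the
        # checkpoint into DTensor state dicts from an FSDP2-wrapped copy and
        # compare the resulting full state dicts.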

        # Save state dict with model wrapped with FSDP1
        fsdp1_model = FSDP(
            self._get_base_model().cuda(),
            use_orig_params=True,
            auto_wrap_policy=always_wrap_policy,
        )

        fsdp1_optim = torch.optim.AdamW(fsdp1_model.parameters(), lr=0.1)

        fsdp1_model(torch.randn((2,), device=self.rank)).sum().backward()
        fsdp1_optim.step()

        with FSDP.state_dict_type(fsdp1_model, state_dict_type):
            fsdp1_state_dict = {
                "model": fsdp1_model.state_dict(),
                "optim": FSDP.sharded_optim_state_dict(fsdp1_model, fsdp1_optim),
            }
            dcp.save(
                fsdp1_state_dict,
                checkpoint_id=self.temp_dir,
            )

        fsdp1_full_msd = get_model_state_dict(
            fsdp1_model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )
        fsdp1_full_osd = get_optimizer_state_dict(
            fsdp1_model,
            fsdp1_optim,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )

        # Load state dict into model with FSDP2 applied
        fsdp2_model = self._get_base_model()
        for module in fsdp2_model:
            fully_shard(module)
        fully_shard(fsdp2_model)
        fsdp2_optim = torch.optim.AdamW(fsdp2_model.parameters(), lr=0.1)

        fsdp2_state_dict = {
            "model": get_model_state_dict(fsdp2_model),
            "optim": get_optimizer_state_dict(fsdp2_model, fsdp2_optim),
        }
        dcp.load(
            fsdp2_state_dict,
            checkpoint_id=self.temp_dir,
        )
        fsdp2_model.load_state_dict(fsdp2_state_dict["model"])
        fsdp2_optim.load_state_dict(fsdp2_state_dict["optim"])

        fsdp2_full_msd = get_model_state_dict(
            fsdp2_model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )
        fsdp2_full_osd = get_optimizer_state_dict(
            fsdp2_model,
            fsdp2_optim,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )

        # Compare the full state dicts to make sure they are the same.
        self.assertEqual(fsdp2_full_msd, fsdp1_full_msd)
        self.assertEqual(fsdp1_full_osd, fsdp2_full_osd)

    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_save_with_fsdp1_and_load_with_fsdp2_tp(self):
        """
        Test that we can save a model with FSDP1 and load it with FSDP2 + TP on a 2D mesh.
        """
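        # Save both a full state dict (from the plain model) and a sharded one
        # (from FSDP1 on the 2D mesh) into a single checkpoint, then load each
        # variant into an FSDP2 + TP model sharded over the ("dp", "tp") mesh
        # and compare against the plain model's full state dict.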

        def _get_base_model(mlp_dim: int = 2):
            base_model = nn.Sequential(MLP(mlp_dim), MLP(mlp_dim), MLP(mlp_dim))
            return base_model

        # init device mesh
        dp_size = 2
        global_mesh = init_device_mesh(
            "cuda",
            (dp_size, self.world_size // dp_size),
            mesh_dim_names=("dp", "tp"),
        )
        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]

        # Save state dict with original model
        base_model = _get_base_model().cuda()
        base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)

        # Save state dict with model wrapped with FSDP1
        fsdp1_model = FSDP(
            copy.deepcopy(base_model),
            device_mesh=global_mesh,
            use_orig_params=True,
            auto_wrap_policy=always_wrap_policy,
        )

        fsdp1_optim = torch.optim.AdamW(fsdp1_model.parameters(), lr=0.1)

        # one-step training to modify state dict
        inp = torch.randn((2,), device=self.rank)
        base_model(inp).sum().backward()
        base_optim.step()
        fsdp1_model(inp).sum().backward()
        fsdp1_optim.step()

        # obtain the full state dict
        base_msd = get_model_state_dict(
            base_model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )
        base_osd = get_optimizer_state_dict(
            base_model,
            base_optim,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )

        # obtain the sharded state dict
        fsdp1_msd = get_model_state_dict(
            fsdp1_model,
            options=StateDictOptions(full_state_dict=False),
        )
        fsdp1_osd = get_optimizer_state_dict(
            fsdp1_model,
            fsdp1_optim,
            options=StateDictOptions(full_state_dict=False),
        )

        # save state dict to temp dir
        source_state_dict = {
            "model_full": base_msd,
            "optim_full": base_osd,
            "model_sharded": fsdp1_msd,
            "optim_sharded": fsdp1_osd,
        }
        dcp.save(
            source_state_dict,
            checkpoint_id=self.temp_dir,
        )

        # FSDP + TP
        fsdp2_tp_model = _get_base_model()
        fsdp2_tp_model = parallelize_module(
            fsdp2_tp_model,
            device_mesh=tp_mesh,
            parallelize_plan={
                "0.in_proj": ColwiseParallel(),
                "0.out_proj": RowwiseParallel(),
                "1.in_proj": ColwiseParallel(),
                "1.out_proj": RowwiseParallel(),
                "2.in_proj": ColwiseParallel(),
                "2.out_proj": RowwiseParallel(),
            },
        )
        for module in fsdp2_tp_model:
            fully_shard(module, mesh=dp_mesh)
        fully_shard(fsdp2_tp_model, mesh=dp_mesh)

        fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1)

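        # dcp.load only resolves the keys present in the state dict passed to
        # it, which is what lets this loop pick between the "full" and
        # "sharded" variants stored in the same checkpoint.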
        # Load state dict into model with FSDP2 + TP applied
        for src_state_dict_type in ["full", "sharded"]:
            msd_name = f"model_{src_state_dict_type}"
            osd_name = f"optim_{src_state_dict_type}"
            fsdp2_tp_state_dict = {
                msd_name: get_model_state_dict(fsdp2_tp_model),
                osd_name: get_optimizer_state_dict(fsdp2_tp_model, fsdp2_tp_optim),
            }
            # load state dict from temp dir
            dcp.load(
                fsdp2_tp_state_dict,
                checkpoint_id=self.temp_dir,
            )
            fsdp2_tp_model.load_state_dict(fsdp2_tp_state_dict[msd_name])
            fsdp2_tp_optim.load_state_dict(fsdp2_tp_state_dict[osd_name])

            fsdp2_tp_full_msd = get_model_state_dict(
                fsdp2_tp_model,
                options=StateDictOptions(full_state_dict=True, cpu_offload=True),
            )
            fsdp2_tp_full_osd = get_optimizer_state_dict(
                fsdp2_tp_model,
                fsdp2_tp_optim,
                options=StateDictOptions(full_state_dict=True, cpu_offload=True),
            )

            # Compare the full state dicts to make sure they are the same.
            self.assertEqual(base_msd, fsdp2_tp_full_msd)
            self.assertEqual(base_osd, fsdp2_tp_full_osd)

    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_save_with_tp_and_load_with_fsdp2_tp(self):
        """
        Test that we can save a model with TP and load it with FSDP2 + TP on a 2D mesh.
        """
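        # A checkpoint saved from a pure-TP model (1D "tp" mesh) is loaded into
        # an FSDP2 + TP model on the 2D ("dp", "tp") mesh; DCP is expected to
        # reshard the saved tensors to match the DTensor placements of the
        # state dict it loads into.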

        def _get_base_model(mlp_dim: int = 2):
            base_model = nn.Sequential(MLP(mlp_dim), MLP(mlp_dim), MLP(mlp_dim))
            return base_model

        tp_parallelize_plan = {
            "0.in_proj": ColwiseParallel(),
            "0.out_proj": RowwiseParallel(),
            "1.in_proj": ColwiseParallel(),
            "1.out_proj": RowwiseParallel(),
            "2.in_proj": ColwiseParallel(),
            "2.out_proj": RowwiseParallel(),
        }

        # init device mesh
        dp_size = 2
        global_mesh_1d = init_device_mesh(
            "cuda", (self.world_size,), mesh_dim_names=("tp",)
        )
        global_mesh_2d = init_device_mesh(
            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
        )
        dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"]

        # Save state dict with original model
        base_model = _get_base_model().cuda()
        base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)

        # Save state dict with TP model
        tp_model = copy.deepcopy(base_model)
        tp_model = parallelize_module(
            tp_model,
            device_mesh=global_mesh_1d,
            parallelize_plan=tp_parallelize_plan,
        )
        tp_model_optim = torch.optim.AdamW(tp_model.parameters(), lr=0.1)

        # one-step training to modify state dict
        inp = torch.randn((2,), device=self.rank)
        base_model(inp).sum().backward()
        base_optim.step()
        tp_model(inp).sum().backward()
        tp_model_optim.step()

        # obtain the full state dict
        base_msd = get_model_state_dict(
            base_model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )
        base_osd = get_optimizer_state_dict(
            base_model,
            base_optim,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )

        # obtain sharded state dict
        tp_msd = get_model_state_dict(
            tp_model,
            options=StateDictOptions(full_state_dict=False),
        )
        tp_osd = get_optimizer_state_dict(
            tp_model,
            tp_model_optim,
            options=StateDictOptions(full_state_dict=False),
        )

        # save state dict to temp dir
        source_state_dict = {
            "model_full": base_msd,
            "optim_full": base_osd,
            "model_sharded": tp_msd,
            "optim_sharded": tp_osd,
        }
        dcp.save(
            source_state_dict,
            checkpoint_id=self.temp_dir,
        )

        # FSDP + TP
        fsdp2_tp_model = _get_base_model()
        fsdp2_tp_model = parallelize_module(
            fsdp2_tp_model,
            device_mesh=tp_mesh,
            parallelize_plan=tp_parallelize_plan,
        )
        for module in fsdp2_tp_model:
            fully_shard(module, mesh=dp_mesh)
        fully_shard(fsdp2_tp_model, mesh=dp_mesh)
        fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1)

        # Load state dict into model with FSDP2 + TP applied
        for src_state_dict_type in ["full", "sharded"]:
            msd_name = f"model_{src_state_dict_type}"
            osd_name = f"optim_{src_state_dict_type}"
            fsdp2_tp_state_dict = {
                msd_name: get_model_state_dict(fsdp2_tp_model),
                osd_name: get_optimizer_state_dict(fsdp2_tp_model, fsdp2_tp_optim),
            }
            # load state dict from temp dir
            dcp.load(
                fsdp2_tp_state_dict,
                checkpoint_id=self.temp_dir,
            )
            fsdp2_tp_model.load_state_dict(fsdp2_tp_state_dict[msd_name])
            fsdp2_tp_optim.load_state_dict(fsdp2_tp_state_dict[osd_name])

            fsdp2_tp_full_msd = get_model_state_dict(
                fsdp2_tp_model,
                options=StateDictOptions(full_state_dict=True, cpu_offload=True),
            )
            fsdp2_tp_full_osd = get_optimizer_state_dict(
                fsdp2_tp_model,
                fsdp2_tp_optim,
                options=StateDictOptions(full_state_dict=True, cpu_offload=True),
            )

            # Compare the full state dicts to make sure they are the same.
            self.assertEqual(base_msd, fsdp2_tp_full_msd)
            self.assertEqual(base_osd, fsdp2_tp_full_osd)

    @skip_if_lt_x_gpu(4)
    def test_save_with_fsdp2_tp_and_load_with_tp(self):
        self.run_subtests(
            {"allow_implicit_replication": [True, False]},
            self._test_save_with_fsdp2_tp_and_load_with_tp,
        )

    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def _test_save_with_fsdp2_tp_and_load_with_tp(
        self, allow_implicit_replication: bool
    ):
        """
        Test that we can save a model with FSDP2 + TP on a 2D mesh and load it with TP.
        """
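        # When `allow_implicit_replication` is True, MLP 0 is dropped from the
        # TP plan and the test runs under the experimental
        # `implicit_replication()` context, which treats plain (non-DTensor)
        # tensors as replicated when they are used in ops alongside DTensors.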

        def _get_base_model(mlp_dim: int = 2):
            base_model = nn.Sequential(MLP(mlp_dim), MLP(mlp_dim), MLP(mlp_dim))
            return base_model

        cm = (
            implicit_replication()
            if allow_implicit_replication
            else contextlib.nullcontext()
        )
        tp_parallelize_plan = {
            "0.in_proj": ColwiseParallel(),
            "0.out_proj": RowwiseParallel(),
            "1.in_proj": ColwiseParallel(),
            "1.out_proj": RowwiseParallel(),
            "2.in_proj": ColwiseParallel(),
            "2.out_proj": RowwiseParallel(),
        }
        if allow_implicit_replication:
            # intentionally pop the plans for some tp layers so that the model is not fully tensor parallelized
            tp_parallelize_plan.pop("0.in_proj")
            tp_parallelize_plan.pop("0.out_proj")

        with cm:
            # init device mesh
            dp_size = 2
            global_mesh_1d = init_device_mesh(
                "cuda", (self.world_size,), mesh_dim_names=("tp",)
            )
            global_mesh_2d = init_device_mesh(
                "cuda",
                (dp_size, self.world_size // dp_size),
                mesh_dim_names=("dp", "tp"),
            )
            dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"]

            for save_full_state_dict in [True, False]:
                # Save state dict with original model
                base_model = _get_base_model().cuda()
                base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)

                # Save state dict with FSDP2 + TP model
                fsdp2_tp_model = copy.deepcopy(base_model)
                fsdp2_tp_model = parallelize_module(
                    fsdp2_tp_model,
                    device_mesh=tp_mesh,
                    parallelize_plan=tp_parallelize_plan,
                )
                for module in fsdp2_tp_model:
                    fully_shard(module, mesh=dp_mesh)
                fully_shard(fsdp2_tp_model, mesh=dp_mesh)
                fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1)

                # one-step training to modify state dict
                inp = torch.randn((2,), device=self.rank)
                base_model(inp).sum().backward()
                base_optim.step()
                fsdp2_tp_model(inp).sum().backward()
                fsdp2_tp_optim.step()

                # obtain the unsharded state dict
                base_msd = get_model_state_dict(
                    base_model,
                    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
                )
                base_osd = get_optimizer_state_dict(
                    base_model,
                    base_optim,
                    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
                )

                # obtain FSDP2 + TP state dict
                fsdp2_tp_msd = get_model_state_dict(
                    fsdp2_tp_model,
                    options=StateDictOptions(full_state_dict=save_full_state_dict),
                )
                fsdp2_tp_osd = get_optimizer_state_dict(
                    fsdp2_tp_model,
                    fsdp2_tp_optim,
                    options=StateDictOptions(full_state_dict=save_full_state_dict),
                )

                fsdp2_tp_state_dict = {"model": fsdp2_tp_msd, "optim": fsdp2_tp_osd}
                dcp.save(fsdp2_tp_state_dict, checkpoint_id=self.temp_dir)

                fsdp2_tp_full_msd = get_model_state_dict(
                    fsdp2_tp_model,
                    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
                )
                fsdp2_tp_full_osd = get_optimizer_state_dict(
                    fsdp2_tp_model,
                    fsdp2_tp_optim,
                    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
                )

                # Load state dict into model with TP applied
                tp_model = _get_base_model()
                tp_model = parallelize_module(
                    tp_model,
                    device_mesh=global_mesh_1d,
                    parallelize_plan=tp_parallelize_plan,
                )
                tp_optim = torch.optim.AdamW(tp_model.parameters(), lr=0.1)

                tp_state_dict = {
                    "model": get_model_state_dict(tp_model),
                    "optim": get_optimizer_state_dict(tp_model, tp_optim),
                }
                dcp.load(tp_state_dict, checkpoint_id=self.temp_dir)
                tp_model.load_state_dict(tp_state_dict["model"])
                tp_optim.load_state_dict(tp_state_dict["optim"])

                tp_full_msd = get_model_state_dict(
                    tp_model,
                    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
                )
                tp_full_osd = get_optimizer_state_dict(
                    tp_model,
                    tp_optim,
                    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
                )

                # Compare the full state dicts to make sure they are the same.
                self.assertEqual(base_msd, tp_full_msd)
                self.assertEqual(base_osd, tp_full_osd)
                self.assertEqual(fsdp2_tp_full_msd, tp_full_msd)
                self.assertEqual(fsdp2_tp_full_osd, tp_full_osd)


if __name__ == "__main__":
    run_tests()