# Owner(s): ["oncall: distributed"]

import copy
import functools
import io
from copy import deepcopy
from typing import List, Type

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import replicate
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed._tensor import DTensor, init_device_mesh, Replicate, Shard
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import (
    _get_module_fsdp_state,
    clean_tensor_name,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
from torch.distributed.tensor.parallel.input_reshard import input_reshard
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP, MLPStack
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfRocm,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    ModelArgs,
    Transformer,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class SimpleModel(nn.Module):
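    """MLP used as the evenly sharded test model (cf. ``SimpleModelUneven``)."""
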
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(5, 8)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(8, 4)
        self.net3 = nn.Linear(4, 12)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        return x

    def get_input(self):
        return torch.rand(4, 5, device="cuda")


class SimpleModelUneven(nn.Module):
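    """MLP whose layer sizes do not shard evenly; the uneven counterpart to ``SimpleModel``."""
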
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Linear(5, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 15)
        self.net3 = nn.Linear(15, 30)
        self.net4 = nn.Linear(30, 5)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = self.net4(x)
        return x

    def get_input(self):
        return torch.rand(4, 5, device="cuda")


class TestFullyShard2DTraining(FSDPTest):
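    """Tests composing ``fully_shard`` (FSDP2) with tensor parallelism on a
    2D ("dp", "tp") device mesh."""
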
    global c10d_ops
    global funcol
    c10d_ops = torch.ops.c10d
    funcol = torch.ops.c10d_functional

    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    def init_global_mesh(self) -> DeviceMesh:
        # Prefer to test with >=4 GPUs, but for 2 GPUs, use 2-way TP
        dp_size = 2 if self.world_size > 2 else 1
        return init_device_mesh(
            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
        )

    # TODO: remove this test when uneven sharding is supported for FSDP+TP
    @skip_if_lt_x_gpu(2)
    def test_2d_uneven_shard_raise_error(self):
        global_mesh = self.init_global_mesh()
        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
        model = MLPStack(3)
        with self.assertRaisesRegex(NotImplementedError, "uneven sharding"):
            model.parallelize(tp_mesh, dp_mesh, False)

    @skip_if_lt_x_gpu(2)
    @skipIfRocm
    def test_train_parity_2d_mlp(self):
        global_mesh = self.init_global_mesh()
        self.run_subtests(
            {
                "reshard_after_forward": [False, True],
                "use_activation_checkpointing": [False, True],
                # TODO: change "mlp_dim" back to [3, 16, 17] when uneven sharding
                # is supported for FSDP+TP
                "mlp_dim": [4, 16, 20],
            },
            functools.partial(self._test_train_parity_2d_mlp, global_mesh),
        )

    def _test_train_parity_2d_mlp(
        self,
        global_mesh: DeviceMesh,
        reshard_after_forward: bool,
        use_activation_checkpointing: bool,
        mlp_dim: int,
    ):
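        """Checks that a TP + ``fully_shard`` MLPStack matches a ``replicate()``
        baseline loss-for-loss on every iteration."""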
        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
        dp_pg = dp_mesh.get_group()  # used for `replicate()`

        torch.manual_seed(42)
        model = MLPStack(mlp_dim)
        ref_model = copy.deepcopy(model).cuda()
        replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
        model.parallelize(
            tp_mesh,
            dp_mesh,
            use_activation_checkpointing,
            reshard_after_forward=reshard_after_forward,
        )
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)

        torch.manual_seed(42 + dp_pg.rank() + 1)
        device = torch.device("cuda")
        for iter_idx in range(10):
            inp = torch.randn((8, mlp_dim), device=device)
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    @skipIfRocm
    def test_tp_with_fsdp_offloading(self):
        global_mesh = init_device_mesh(
            "cuda", (1, self.world_size), mesh_dim_names=("dp", "tp")
        )
        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
        torch.manual_seed(42)
        mlp_dim = 16
        model = MLPStack(mlp_dim)
        ref_model = copy.deepcopy(model).cuda()
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
        # Parallelize with N-way TP and 1-way FSDP
        model.parallelize(
            tp_mesh,
            dp_mesh,
            use_activation_checkpointing=False,
            reshard_after_forward=True,
            offload_policy=CPUOffloadPolicy(),
        )
        for param in model.parameters():
            self.assertEqual(param.device.type, "cpu")
        num_mlps = sum(isinstance(module, MLP) for module in model.modules())
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)

        # NOTE: We still see the FSDP all-gather/reduce-scatter c10d ops
        # called, but they will just be no-ops without issuing any kernels.
        # We prefer to keep the no-op check at the c10d level, not in FSDP.
        inp = torch.randn((4, mlp_dim), device="cuda")  # same on all ranks
        for iter_idx in range(10):
            ref_optim.zero_grad()
            optim.zero_grad()

            with CommDebugMode() as fwd_comm_mode:
                loss = model(inp).sum()

            fwd_comm_counts = fwd_comm_mode.get_comm_counts()
            self.assertEqual(len(fwd_comm_counts), 2)
            self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps)
            self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
            ref_loss = ref_model(inp).sum()
            self.assertEqual(loss, ref_loss)

            with CommDebugMode() as bwd_comm_mode:
                loss.backward()
            bwd_comm_counts = bwd_comm_mode.get_comm_counts()
            self.assertEqual(len(bwd_comm_counts), 3)
            # First MLP's input gradient does not need to be all-reduced
            self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1)
            self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
            self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps)
            ref_loss.backward()

            optim.step()
            ref_optim.step()

    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_train_parity_2d_transformer_checkpoint_resume(self):
        """
        Tests train parity of a 2D transformer without checkpointing against a
        2D transformer with a checkpoint save/load.
        """
        self.run_subtests(
            {
                "use_seq_parallel": [False, True],
                # If reusing, then load into the same model/optimizer instance;
                # else, construct new ones (requiring eager optim state init)
                "reuse_model_optim": [False, True],
                "optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
                # TODO: need to update `parallelize` before including foreach=True for testing
                "foreach": [False],
            },
            self._test_train_parity_2d_transformer_checkpoint_resume,
        )

    def _test_train_parity_2d_transformer_checkpoint_resume(
        self,
        use_seq_parallel: bool,
        reuse_model_optim: bool,
        optimizer_class: Type[torch.optim.Optimizer],
        foreach: bool,
    ):
        def train_step(
            _model: nn.Module, _optim: torch.optim.Optimizer, _inp: torch.Tensor
        ) -> torch.Tensor:
            loss = _model(_inp).sum()
            loss.backward()
            _optim.step()
            _optim.zero_grad()
            return loss

        def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool):
            _model = Transformer.parallelize(_model, mesh["tp"], use_seq_parallel)
            for layer in _model.layers:
                fully_shard(layer, mesh=mesh["dp"])
            fully_shard(_model, mesh=mesh["dp"])
            return _model

        global_mesh = self.init_global_mesh()
        # Baseline: run two iterations without checkpointing
        seed = 42
        torch.manual_seed(seed)
        model_args = ModelArgs(dropout_p=0.0)
        model_no_cp = parallelize(
            Transformer(model_args), global_mesh, use_seq_parallel
        )
        optim_no_cp = optimizer_class(
            model_no_cp.parameters(), lr=1e-2, foreach=foreach
        )

        torch.manual_seed(42 + global_mesh["dp"].get_local_rank() + 1)
        inp = torch.randint(0, model_args.vocab_size, (3, 16), device="cuda")
        loss_no_cp1 = train_step(model_no_cp, optim_no_cp, inp)
        loss_no_cp2 = train_step(model_no_cp, optim_no_cp, inp)

        # Test: run one iteration, save checkpoint, zero states or init new
        # model/optimizer, load checkpoint, and run another iteration
        torch.manual_seed(seed)
        model_cp = parallelize(Transformer(model_args), global_mesh, use_seq_parallel)
        optim_cp = optimizer_class(model_cp.parameters(), lr=1e-2, foreach=foreach)

        loss_cp1 = train_step(model_cp, optim_cp, inp)
        self.assertEqual(loss_no_cp1, loss_cp1)

        sharded_sd = {
            "model": get_model_state_dict(model_cp),
            # Use `get_optimizer_state_dict` to handle eager optim state init
            # when constructing a new optimizer instance
            "optim": get_optimizer_state_dict(model_cp, optim_cp),
        }
        dcp.save(
            state_dict=sharded_sd,
            storage_writer=dcp.FileSystemWriter(self.temp_dir),
        )
        if reuse_model_optim:
            with torch.no_grad():
                for param in model_cp.parameters():
                    param.zero_()
                optim_sd = optim_cp.state_dict()
                for param_states in optim_sd["state"].values():
                    for state_value in param_states.values():
                        if torch.is_tensor(state_value):
                            state_value.zero_()
        else:
            torch.manual_seed(seed + 1)  # different seed
            model_cp = parallelize(
                Transformer(model_args), global_mesh, use_seq_parallel
            )
            optim_cp = optimizer_class(model_cp.parameters(), lr=1e-2, foreach=foreach)
        self.assertNotEqual(loss_no_cp2, train_step(model_cp, optim_cp, inp))

        sharded_sd = {
            "model": get_model_state_dict(model_cp),
            "optim": get_optimizer_state_dict(model_cp, optim_cp),
        }
        dcp.load(
            state_dict=sharded_sd,
            storage_reader=dcp.FileSystemReader(self.temp_dir),
        )
        self.assertGreater(len(optim_cp.state_dict()["state"]), 0)

        loss_cp2 = train_step(model_cp, optim_cp, inp)
        self.assertEqual(loss_no_cp2, loss_cp2)


class TestFullyShard2DStateDict(DTensorTestBase):
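    """State dict tests for ``fully_shard`` (FSDP2) composed with tensor parallelism."""
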
    @property
    def backend(self):
        # Need to specify the gloo backend for testing CPU offload
        return "cpu:gloo,cuda:nccl"

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_fully_shard_tp_2d_set_full_state_dict(self):
        dummy_model = SimpleModel().cuda()
        mesh_2d = init_device_mesh(
            "cuda",
            (2, self.world_size // 2),
            mesh_dim_names=("dp", "tp"),
        )
        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
            "net3": ColwiseParallel(),
        }
        model = parallelize_module(dummy_model, tp_mesh, parallelize_plan)
        fully_shard(model, mesh=dp_mesh)
        optim = torch.optim.Adam(model.parameters(), lr=0.01)
        model(model.get_input()).sum().backward()
        optim.step()
        # ref_msd and ref_osd are both in the default sharded state dict format
        ref_msd = copy.deepcopy(get_model_state_dict(model))
        ref_osd = copy.deepcopy(get_optimizer_state_dict(model, optimizers=optim))

        options = StateDictOptions(
            full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True
        )
        full_msd = get_model_state_dict(model, options=options)
        full_osd = get_optimizer_state_dict(model, optimizers=optim, options=options)
        # Load full_msd and full_osd into the model and optimizer.
        # This loads each rank's slice of the full tensors into its local DTensors.
        set_model_state_dict(model, full_msd, options=options)
        set_optimizer_state_dict(
            model, optimizers=optim, optim_state_dict=full_osd, options=options
        )

        # Check that after setting the full state dicts, the default sharded state dicts
        # of the model and optimizer match the initial default sharded state dicts.
        new_msd = get_model_state_dict(model)
        new_osd = get_optimizer_state_dict(model, optimizers=optim)
        self.assertEqual(ref_msd, new_msd)
        self.assertEqual(ref_osd, new_osd)


class Test2dFSDP1ParallelIntegration(DTensorTestBase):
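    """Tests composing tensor parallelism with DDP (via ``_pre_dp_module_transform``)
    against a plain DDP baseline."""
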
    def init_model(self, device_type, model_parallel_size=2):
        torch.manual_seed(0)
        model = MLPModule(device_type)
        torch.manual_seed(0)
        twod_model = MLPModule(device_type)
        model = DDP(model)

        # 2-D mesh is [dp, tp]
        world_size = dist.get_world_size()
        mesh_2d = init_device_mesh(
            device_type,
            (world_size // model_parallel_size, model_parallel_size),
            mesh_dim_names=("dp", "tp"),
        )

        dp_pg = mesh_2d.get_group(mesh_dim=0)

        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        twod_model = parallelize_module(twod_model, mesh_2d["tp"], parallelize_plan)
        _pre_dp_module_transform(twod_model)
        # TODO: Add tests when using gradient_as_bucket_view and static_graph for DDP.
        twod_model = DDP(twod_model, process_group=dp_pg)
        return model, twod_model, dp_pg

    def _check_module(self, m1, m2, check_grad=False):
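        """Asserts that every parameter (or gradient) of ``m2`` matches ``m1``,
        redistributing DTensors to replicated local tensors for comparison."""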
        named_parameters = dict(m1.named_parameters())
        for name, param_m2 in m2.named_parameters():
            if name not in named_parameters:
                print(name, named_parameters.keys())
            self.assertTrue(name in named_parameters)
            param_m1 = named_parameters[name]
            if check_grad:
                param_m2 = param_m2.grad
                param_m1 = param_m1.grad
            if isinstance(param_m2, DTensor):
                replicate = [Replicate()]
                param_m2 = param_m2.redistribute(
                    device_mesh=param_m2.device_mesh, placements=replicate
                ).to_local()
            self.assertEqual(param_m2, param_m1)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_2d_ddp_integration_functionality(self) -> None:
        model, twod_model, dp_pg = self.init_model(self.device_type)
        optim = torch.optim.Adam(model.parameters(), lr=3e-5)
        twod_optim = torch.optim.Adam(twod_model.parameters(), lr=3e-5)

        # Create Input
        input_seed = dist.get_rank(dp_pg)
        torch.manual_seed(input_seed + 1)
        input = torch.rand(4, 10, device=self.device_type)

        output = model(input)
        twod_output = twod_model(input)
        self.assertEqual(output, twod_output)

        output.sum().backward()
        twod_output.sum().backward()
        self._check_module(model, twod_model, check_grad=True)

        optim.step()
        twod_optim.step()
        self._check_module(model, twod_model)

        torch.manual_seed(input_seed + 1004)
        input = torch.rand(16, 10, device=self.device_type)

        output = model(input)
        twod_output = twod_model(input)
        self.assertEqual(output, twod_output)

        # TODO: Add 2D save/load verification.


# TODO: add additional tests for multi_param_group, optim_in_backward,
# and fsdp_nested.
class TestNew2dParallelTraining(DTensorTestBase):
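    """End-to-end training tests for tensor parallelism composed with FSDP1
    (``FullyShardedDataParallel``)."""
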
    def _compare_params(self, m1, m2):
        with FSDP.summon_full_params(m1):
            with FSDP.summon_full_params(m2):
                for n_p1, n_p2 in zip(m1.named_parameters(), m2.named_parameters()):
                    p1 = n_p1[1]
                    p2 = n_p2[1]
                    if n_p1[0] != n_p2[0]:
                        self.assertTrue(n_p1[0] in n_p2[0])
                    name = n_p1[0]
                    if name == "net2.bias" and self.rank != 0:
                        continue
                    if type(p2) is DTensor:
                        p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local()
                    self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_raise_invalid_tp_composition(self):
        with self.assertRaisesRegex(
            RuntimeError, r"Found TP device_mesh on the \d dimension of its parent mesh"
        ):
            mesh_2d = init_device_mesh(
                self.device_type, (2, self.world_size // 2), mesh_dim_names=("tp", "dp")
            )
            parallelize_plan = {
                "net1": ColwiseParallel(),
                "net2": RowwiseParallel(),
            }
            model_2d = parallelize_module(
                SimpleModel().cuda(), mesh_2d["tp"], parallelize_plan
            )

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_2d_fsdp_state_enable_extension(self):
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
        )
        model = FSDP(
            SimpleModel().cuda(),
            device_mesh=mesh_2d["dp"],
        )
        fsdp_state = _get_module_fsdp_state(model)
        self.assertTrue(isinstance(fsdp_state._fsdp_extension, DTensorExtensions))

    def _test_2d_e2e_training(
        self,
        use_orig_params=False,
        recompute_activation=False,
    ) -> None:
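        """Trains a 1D FSDP model and a 2D TP + FSDP model in lockstep and checks
        that their outputs and parameters stay equal."""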
        torch.manual_seed(0)
        model = SimpleModel().cuda(self.rank)
        model = FSDP(model, use_orig_params=use_orig_params)
        optim = torch.optim.Adam(model.parameters(), lr=0.01)

        torch.manual_seed(0)
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
        )
        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model_2d = parallelize_module(SimpleModel().cuda(), tp_mesh, parallelize_plan)
        model_2d = FSDP(
            model_2d,
            device_mesh=dp_mesh,
            use_orig_params=use_orig_params,
        )
        optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)

        if recompute_activation:
            model_2d = input_reshard(model_2d, mesh_2d["tp"], 0)

        # Check that named_parameters() returns the same names at least.
        param_names_2d = [
            clean_tensor_name(name) for name, _ in model_2d.named_parameters()
        ]
        for name, _ in model.named_parameters():
            name = clean_tensor_name(name)
            if name not in param_names_2d:
                print(name, param_names_2d)
            self.assertTrue(name in param_names_2d)
        self._compare_params(model, model_2d)

        # TODO: add additional tests for multi_param_group and optim_in_backward.

        for i in range(5):
            # Ensure the input is the same across all TP ranks.
            # TODO: add a get_group_rank() to DeviceMesh.
            torch.manual_seed(i + dist.get_rank(dp_mesh.get_group(mesh_dim=0)))
            input = torch.rand(4, 5).cuda(self.rank)
            output = model(input)
            output_2d = model_2d(input)
            self.assertEqual(output, output_2d)
            output.sum().backward()
            output_2d.sum().backward()
            optim.step()
            optim_2d.step()
            self.assertEqual(model(input), model_2d(input))

        # Ensure all params are still the same after optimizer update.
        self._compare_params(model, model_2d)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_2d_e2e_training_default(self):
        self._test_2d_e2e_training()

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_2d_e2e_training_use_orig_params(self):
        self._test_2d_e2e_training(use_orig_params=True)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_2d_e2e_training_not_use_orig_params(self):
        # TODO: revisit the input_reshard API to understand why it fails multi-GPU tests.
        # self._test_2d_e2e_training(recompute_activation=True)
        self._test_2d_e2e_training(recompute_activation=False)


# TODO: update all state dict unit tests to use distributed.checkpoint.state_dict,
# and consolidate all the state_dict tests in test.distributed.checkpoint.
class TestNew2dParallelStateDict(DTensorTestBase):
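    """State dict save/load tests for tensor parallelism composed with FSDP1."""
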
    @property
    def backend(self):
        # Need to specify the gloo backend for testing CPU offload
        return "cpu:gloo,cuda:nccl"

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_fsdp_2d_extension(self):
        """
        Test whether _fsdp_extension from the FSDP state has been set correctly.
        """
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
        )
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
            "net3": ColwiseParallel(),
        }
        model_2d = parallelize_module(
            SimpleModel().cuda(),
            mesh_2d["tp"],
            parallelize_plan=parallelize_plan,
        )
        model_2d = FSDP(model_2d, device_mesh=mesh_2d["dp"], use_orig_params=True)
        model_2d_fsdp_state = _get_module_fsdp_state(model_2d)
        self.assertTrue(
            isinstance(model_2d_fsdp_state._fsdp_extension, DTensorExtensions)
        )

        mesh_1d = init_device_mesh("cuda", (self.world_size,))
        model_1d = FSDP(SimpleModel().cuda(), device_mesh=mesh_1d, use_orig_params=True)
        model_1d_fsdp_state = _get_module_fsdp_state(model_1d)
        self.assertEqual(model_1d_fsdp_state._fsdp_extension, None)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("is_even_sharded_model", [True, False])
    def test_2d_state_dict(self, is_even_sharded_model):
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

        # Create a model without wrapper
        torch.manual_seed(0)
        no_wrap_model = simple_model().cuda(self.rank)
        no_wrap_state_dict = no_wrap_model.state_dict()

        # Create a model and shard it with 2D FSDP + TP
        torch.manual_seed(0)
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
        )
        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan)
        model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)

        FSDP.set_state_dict_type(
            model_2d,
            StateDictType.SHARDED_STATE_DICT,
        )
        state_dict_2d = model_2d.state_dict()

        for no_wrap_items, two_d_items in zip(
            no_wrap_state_dict.items(), state_dict_2d.items()
        ):
            no_wrap_k, no_wrap_v = no_wrap_items
            two_d_k, two_d_v = two_d_items

            self.assertEqual(no_wrap_k, two_d_k)

            # Check that all values in the 2D state_dict are DTensors
            self.assertTrue(isinstance(two_d_v, DTensor))
            self.assertEqual(len(two_d_v.placements), 2)
            # The outer dimension is the FSDP dimension, and its placement is always Shard(0)
            self.assertEqual(two_d_v.placements[0], Shard(0))
            self.assertEqual(two_d_v.device_mesh, mesh_2d)

            # Check that the parameter values match between the 2D model and the unwrapped model
            all_gather_two_d_v = two_d_v.redistribute(
                mesh_2d, (Replicate(), Replicate())
            )
            self.assertTrue(
                torch.allclose(no_wrap_v, all_gather_two_d_v.to_local())
            )

    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("is_even_sharded_model", [True, False])
    def test_2d_load_state_dict(self, is_even_sharded_model):
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

        torch.manual_seed(0)
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
        )
        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan)
        model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)
        optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)

        FSDP.set_state_dict_type(
            model_2d,
            StateDictType.SHARDED_STATE_DICT,
        )
        checkpoint = io.BytesIO()
        torch.save(model_2d.state_dict(), checkpoint)
        # Deepcopy to save current state_dict to compare with the state_dict loaded back below.
        ref_state_dict = deepcopy(model_2d.state_dict())

        # Update the parameters so model.state_dict() will be different from ref_state_dict.
        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
        optim_2d.step()

        # Load ref_state_dict back.
        checkpoint.seek(0)
        load_ref_state_dict = torch.load(checkpoint)
        model_2d.load_state_dict(load_ref_state_dict)
        new_state_dict = model_2d.state_dict()

        # Check whether new_state_dict is the same as ref_state_dict.
        for (k1, v1), (k2, v2) in zip(ref_state_dict.items(), new_state_dict.items()):
            # Check whether the FQNs are the same
            self.assertEqual(k1, k2)

            self.assertEqual(type(v1), DTensor)
            self.assertEqual(type(v2), DTensor)
            # Check whether the DTensors are the same
            # TODO: 2D DTensor comparison is not supported at this time, so we compare
            # the spec and the local tensor for now.
            # TODO: Update this to compare the two DTensors directly once 2D DTensor comparison is supported.
            self.assertEqual(v1.to_local(), v2.to_local())
            self.assertEqual(v1.device_mesh, v2.device_mesh)
            self.assertEqual(v1.placements, v2.placements)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("is_even_sharded_model", [True, False])
    def test_2d_optim_state_dict(self, is_even_sharded_model):
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

        # Create a model without wrapper
        torch.manual_seed(0)
        no_wrap_model = simple_model().cuda(self.rank)
        no_wrap_state_dict = no_wrap_model.state_dict()
        no_wrap_optim = torch.optim.Adam(no_wrap_model.parameters(), lr=0.01)
        no_wrap_model(no_wrap_model.get_input().cuda(self.rank)).sum().backward()
        no_wrap_optim.step()
        no_wrap_osd = get_optimizer_state_dict(no_wrap_model, optimizers=no_wrap_optim)

        # Create a model and shard it with 2D FSDP + TP
        torch.manual_seed(0)
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
        )
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model_2d = parallelize_module(
            simple_model().cuda(), mesh_2d["tp"], parallelize_plan
        )
        model_2d = FSDP(model_2d, device_mesh=mesh_2d["dp"], use_orig_params=True)
        FSDP.set_state_dict_type(
            model_2d,
            StateDictType.SHARDED_STATE_DICT,
        )
        optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)
        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
        optim_2d.step()
        optim_2d_osd = get_optimizer_state_dict(model_2d, optimizers=optim_2d)
        ref_optim_2d_osd = deepcopy(optim_2d_osd)

        no_wrap_osd_states = no_wrap_osd["state"]
        optim_2d_osd_states = optim_2d_osd["state"]

        self.assertEqual(len(no_wrap_osd_states), len(optim_2d_osd_states))
        self.assertEqual(no_wrap_osd_states.keys(), optim_2d_osd_states.keys())
        for fqn, states in no_wrap_osd_states.items():
            dist_states = optim_2d_osd_states.get(fqn)

            for state_name, state in states.items():
                dist_state = dist_states.get(state_name)
                # If a state is a DTensor, all-gather it across both the DP and TP
                # dimensions to compare with the no_wrap state.
                if isinstance(dist_state, DTensor):
                    dist_state = (
                        dist_state.cuda()
                        .redistribute(placements=(Replicate(), Replicate()))
                        .to_local()
                    )
                self.assertTrue(isinstance(dist_state, torch.Tensor))
                self.assertTrue(torch.allclose(state, dist_state))

        # Update the parameters so the 2D optim states will be different from ref_optim_2d_osd.
        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
        optim_2d.step()

        set_optimizer_state_dict(
            model_2d, optimizers=optim_2d, optim_state_dict=ref_optim_2d_osd
        )
        new_optim_2d_osd = get_optimizer_state_dict(model_2d, optimizers=optim_2d)

        ref_optim_2d_osd_states = ref_optim_2d_osd["state"]
        new_optim_2d_osd_states = new_optim_2d_osd["state"]

        # Compare the new optim state dict after load with the reference one
        self.assertEqual(len(ref_optim_2d_osd_states), len(new_optim_2d_osd_states))
        self.assertEqual(ref_optim_2d_osd_states.keys(), new_optim_2d_osd_states.keys())
        for fqn, states in ref_optim_2d_osd_states.items():
            new_states = new_optim_2d_osd_states.get(fqn)

            for state_name, state in states.items():
                new_state = new_states.get(state_name)

                if isinstance(new_state, DTensor):
                    self.assertEqual(new_state.placements, state.placements)
                    self.assertEqual(new_state.device_mesh, state.device_mesh)
                    self.assertTrue(
                        torch.allclose(new_state.to_local(), state.to_local())
                    )
                else:
                    self.assertEqual(new_state, state)


instantiate_parametrized_tests(TestNew2dParallelStateDict)

if __name__ == "__main__":
    run_tests()