# Owner(s): ["oncall: distributed"]
import copy
import sys
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import torch
from torch import distributed as dist
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_module,
    DTensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    MLPModule,
    RMSNormPython,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class SimpleModel(torch.nn.Module):
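    # net1 and net2 are tensor-parallelized in the tests below (column-wise and
    # row-wise, respectively); net3 stays replicated across the TP group.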
    def __init__(self) -> None:
        super().__init__()
        self.net1 = torch.nn.Linear(5, 8)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(8, 4)
        self.net3 = torch.nn.Linear(4, 12)

    def forward(self, x):
        return self.net3(self.net2(self.relu(self.net1(x))))

    @staticmethod
    def get_sharded_param_names() -> List[str]:
        return ["net1.weight", "net1.bias", "net2.weight"]

    @staticmethod
    def get_non_sharded_param_names() -> List[str]:
        return ["net3.weight", "net3.bias"]


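# Shards the RMSNorm module's input on dim 0 across the mesh via DTensor and
# converts the module output back to a regular local tensor.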
def distribute_rmsnorm(module, device_mesh):
    def prepare_input_fn(mod, inputs, device_mesh):
        shard_tensor = DTensor.from_local(inputs[0], device_mesh, [Shard(0)])
        return shard_tensor

    def prepare_output_fn(mod, outputs, device_mesh):
        return outputs.to_local()

    return distribute_module(
        module, device_mesh, input_fn=prepare_input_fn, output_fn=prepare_output_fn
    )


class TestTPFSDPIntegration(FSDPTest):
    def _get_params_and_sharding_info(
        self,
        model: SimpleModel,
        sharded_param_names: List[str],
        tensor_parallel_size: int,
    ) -> Tuple[Dict[str, int], Dict[str, Tuple[torch.Size, int]]]:
        """
        Returns a mapping from parameter name to numel (counting only this TP
        rank's shard for TP-sharded parameters) and a mapping from each
        TP-sharded parameter name to its unsharded size and sharding dim.
        """
        assert (
            type(model) is SimpleModel
        ), "Expects a `SimpleModel` since the sharding cases depend on the model definition"
        param_name_to_numel = OrderedDict()
        param_name_to_sharding_info = OrderedDict()
        for param_name, param in model.named_parameters():
            if param_name not in sharded_param_names:
                param_name_to_numel[param_name] = param.numel()
            else:
                param_name_to_numel[param_name] = param.numel() // tensor_parallel_size
                param_name_to_sharding_info[param_name] = (
                    param.size(),
                    0 if "net1" in param_name else 1,
                )
        return param_name_to_numel, param_name_to_sharding_info

    def _get_sub_pgs(self, tensor_parallel_size: int):
        """
        Generates the TP and FSDP process subgroups. ``tensor_parallel_size``
        gives the TP process group size.

        For example, if the global world size is 8 and the tensor parallel size
        is 2, then this creates:
        - 4 TP subgroups: [0, 1], [2, 3], [4, 5], [6, 7]
        - 2 FSDP subgroups: [0, 2, 4, 6], [1, 3, 5, 7]
        """
        # Build a 2-D mesh with shape [dp, tp]: dim 0 spans the FSDP
        # (data-parallel) ranks and dim 1 spans the TP ranks
        twod_mesh = DeviceMesh(
            device_type="cuda",
            mesh=torch.arange(0, self.world_size).view(-1, tensor_parallel_size),
        )

        fsdp_pg = twod_mesh.get_group(mesh_dim=0)
        tp_pg = twod_mesh.get_group(mesh_dim=1)
        return twod_mesh, fsdp_pg, tp_pg

    def _sync_tp_grads(
        self,
        tp_fsdp_model: FSDP,
        tp_pg: dist.ProcessGroup,
        param_name_to_numel: Dict[str, int],
        non_sharded_param_names: List[str],
    ) -> None:
        """
        Syncs the tensor parallel parameters' gradients following the data
        parallel paradigm where gradients are averaged over ranks (in this
        case, the ones in the tensor parallel process group).
        """
        tp_world_size = tp_pg.size()
        fsdp_world_size = self.world_size // tp_world_size
        assert (
            type(tp_fsdp_model) is FSDP
            and len([m for m in tp_fsdp_model.modules() if type(m) is FSDP]) == 1
        ), (
            "The following logic assumes a single top-level-only FSDP wrapping "
            "the model with TP already applied"
        )
        for flat_param in tp_fsdp_model.params:
            splits = tuple(param_name_to_numel.values())
            # Create a mask over the gradient elements to manually reduce
            unsharded_size = torch.Size([flat_param.numel() * fsdp_world_size])
            unsharded_zeros = torch.zeros(unsharded_size, device=flat_param.device)
            per_param_masks = unsharded_zeros.split(splits)
            for param_idx, param_name in enumerate(
                param_name_to_numel.keys()
            ):  # assumes fixed order
                if param_name not in non_sharded_param_names:
                    per_param_masks[param_idx][:] = 1
            unsharded_mask = (
                torch.cat(per_param_masks).contiguous().type(torch.BoolTensor)
            )
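            # `self.rank // tp_world_size` is this rank's index within its FSDP
            # process group, so this selects the mask slice corresponding to
            # the local FSDP shard of the flat parameter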
            sharded_mask = unsharded_mask.chunk(fsdp_world_size)[
                self.rank // tp_world_size
            ]
            grad_device = flat_param.grad.device
            grad = flat_param.grad.detach().clone().cuda(self.rank)
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=tp_pg)
            grad = grad.to(grad_device)
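            # Replicated (non-TP-sharded) parameters accumulate gradients on
            # every TP rank, so they take the summed value; TP-sharded
            # parameters keep their locally computed gradient shard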
            flat_param.grad[~sharded_mask] = grad[~sharded_mask]
            # Average *all* gradient elements to match the FSDP-only semantics
            flat_param.grad /= tp_world_size

    def _get_grads_as_flattened(
        self,
        model: FSDP,
        uses_tp: bool,
        param_name_to_numel: Dict[str, int],
        param_name_to_sharding_info: Dict[str, Tuple[torch.Size, int]],
        tp_pg: Optional[dist.ProcessGroup],
        fsdp_pg: Optional[dist.ProcessGroup],
        sharded_param_names: Optional[List[str]],
    ) -> torch.Tensor:
        """
        Returns all unsharded gradients as a single flattened tensor. This
        returns the same value on all ranks.
        """
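        # Flatten this rank's local gradients, all-gather them across the FSDP
        # group, and (when TP is used) re-assemble the TP-sharded gradients
        # across the TP group below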
        local_grads_as_flattened = (
            torch.cat(
                [
                    torch.flatten(param.grad)
                    if param.grad is not None
                    else torch.zeros_like(torch.flatten(param))
                    for param in model.parameters()
                ]
            )
            .contiguous()
            .cuda(self.rank)
        )
        all_grads_as_flattened = torch.cat(
            [torch.empty_like(local_grads_as_flattened) for _ in range(fsdp_pg.size())]
        ).contiguous()
        dist.all_gather_into_tensor(
            all_grads_as_flattened, local_grads_as_flattened, group=fsdp_pg
        )
        if not uses_tp:
            return all_grads_as_flattened
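        # With TP, the gathered entries for each TP-sharded parameter only
        # cover this rank's TP shard, so all-gather them across the TP group
        # and concatenate along the sharding dim to rebuild the full gradient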
        splits = tuple(param_name_to_numel.values())
        all_grads_per_param = list(all_grads_as_flattened.split(splits))
        for param_idx, param_name in enumerate(
            param_name_to_numel.keys()
        ):  # assumes fixed order
            if param_name in sharded_param_names:
                local_tensor_size = list(param_name_to_sharding_info[param_name][0])
                sharding_dim = param_name_to_sharding_info[param_name][1]
                local_tensor_size[sharding_dim] //= tp_pg.size()
                local_tensor = all_grads_per_param[param_idx].view(*local_tensor_size)
                local_tensors = [
                    torch.empty_like(local_tensor) for _ in range(tp_pg.size())
                ]
                dist.all_gather(local_tensors, local_tensor, group=tp_pg)
                all_grads_per_param[param_idx] = torch.cat(
                    local_tensors, dim=sharding_dim
                ).reshape(-1)
        return torch.cat(all_grads_per_param).contiguous()

    @skip_if_lt_x_gpu(4)
    def test_fsdp_tp_integration(self):
        self.run_subtests(
            {
                "cpu_offload": [
                    CPUOffload(offload_params=False),
                    CPUOffload(offload_params=True),
                ],
                "sharding_strategy": [None, ShardingStrategy.SHARD_GRAD_OP],
                "use_orig_params": [False, True],
            },
            self._test_fsdp_tp_integration,
        )

    def _test_fsdp_tp_integration(
        self, cpu_offload, sharding_strategy, use_orig_params
    ):
        """
        Tests training for TP + FSDP integration by comparing an FSDP-only
        model with a TP + FSDP model.
        """
        tensor_parallel_size = 2
        LR = 3e-5
        torch.manual_seed(0)
        model = SimpleModel().cuda(self.rank)
        tp_fsdp_model = copy.deepcopy(model)
        sharded_param_names = SimpleModel.get_sharded_param_names()
        non_sharded_param_names = SimpleModel.get_non_sharded_param_names()
        (
            param_name_to_numel,
            param_name_to_sharding_info,
        ) = self._get_params_and_sharding_info(
            model,
            sharded_param_names,
            tensor_parallel_size,
        )

        input_seed = self.rank
        torch.manual_seed(input_seed + 1)
        inp_size = [2, 3, 5]
        inp = torch.rand(*inp_size).cuda(self.rank)
        self.assertEqual(model(inp), tp_fsdp_model(inp))  # sanity check

        mesh_1d = init_device_mesh("cuda", (self.world_size,))
        fsdp_model = FSDP(
            model,
            cpu_offload=cpu_offload,
            device_mesh=mesh_1d,
            sharding_strategy=sharding_strategy,
            use_orig_params=use_orig_params,
        )
        mesh_2d = init_device_mesh(
            "cuda",
            (self.world_size // tensor_parallel_size, tensor_parallel_size),
            mesh_dim_names=["dp", "tp"],
        )
        # Shard with TP and then wrap with FSDP
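        # net1 is column-wise sharded and takes its input sharded on dim 0
        # (sequence parallel); net2 is row-wise sharded and returns its output
        # sharded on dim 0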
        sequence_parallelize_plan = {
            "net1": ColwiseParallel(input_layouts=Shard(0)),
            "net2": RowwiseParallel(output_layouts=Shard(0)),
        }
        tp_fsdp_model = parallelize_module(
            tp_fsdp_model,
            mesh_2d["tp"],
            sequence_parallelize_plan,
        )
        tp_pg = mesh_2d["tp"].get_group(mesh_dim=0)
        assert isinstance(tp_fsdp_model.net1.weight, DTensor)
        assert isinstance(tp_fsdp_model.net2.weight, DTensor)
        tp_fsdp_model = FSDP(
            tp_fsdp_model,
            cpu_offload=cpu_offload,
            device_mesh=mesh_2d["dp"],
            sharding_strategy=sharding_strategy,
            use_orig_params=use_orig_params,
        )
        fsdp_pg = mesh_2d["dp"].get_group(mesh_dim=0)

        # Check the forward by checking output equality
        fsdp_out = fsdp_model(inp)
        tp_fsdp_out = tp_fsdp_model(inp)
        self.assertEqual(fsdp_out, tp_fsdp_out)

        # Check the backward by checking gradient equality
        fsdp_out.sum().backward()
        tp_fsdp_out.sum().backward()
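        # FSDP in the TP + FSDP model only reduces gradients over its smaller
        # data-parallel group, so also reduce over the TP group to match the
        # FSDP-only model's data-parallel gradient semantics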
        self._sync_tp_grads(
            tp_fsdp_model,
            tp_pg,
            param_name_to_numel,
            non_sharded_param_names,
        )
        model_grads = self._get_grads_as_flattened(
            fsdp_model,
            False,
            param_name_to_numel,
            param_name_to_sharding_info,
            None,
            self.process_group,
            None,
        )
        model_tp_grads = self._get_grads_as_flattened(
            tp_fsdp_model,
            True,
            param_name_to_numel,
            param_name_to_sharding_info,
            tp_pg,
            fsdp_pg,
            sharded_param_names,
        )
        self.assertEqual(model_grads, model_tp_grads)

        # Check the optimizer step by performing a second forward pass
        fsdp_optim = torch.optim.SGD(fsdp_model.parameters(), lr=LR)
        tp_fsdp_optim = torch.optim.SGD(tp_fsdp_model.parameters(), lr=LR)
        fsdp_optim.step()
        tp_fsdp_optim.step()
        torch.manual_seed(input_seed + 16)
        inp = torch.rand(*inp_size).cuda(self.rank)
        fsdp_out = fsdp_model(inp)
        tp_fsdp_out = tp_fsdp_model(inp)
        self.assertEqual(fsdp_out, tp_fsdp_out)

    @skip_if_lt_x_gpu(4)
    def test_fsdp_tp_extension_grad(self):
        """
        Tests the TP + FSDP extension and checks that the expected collectives
        run and that the resulting gradients contain no NaNs.
        """
        mesh_2d = init_device_mesh(
            "cuda", (self.world_size // 2, 2), mesh_dim_names=["dp", "tp"]
        )

        class TestModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mlp = MLPModule("cuda")
                self.mlp_norm = RMSNormPython(10)

            def forward(self, x):
                return self.mlp(self.mlp_norm(x))

        model = TestModel().cuda(self.rank)

        # Shard with TP and test gradient
        tp_mesh = mesh_2d["tp"]
        tp_model = parallelize_module(
            model,
            tp_mesh,
            {
                "mlp.net1": ColwiseParallel(input_layouts=Shard(0)),
                "mlp.net2": RowwiseParallel(output_layouts=Shard(0)),
            },
        )
        distribute_rmsnorm(tp_model.mlp_norm, tp_mesh)

        fsdp_2d_model = FSDP(tp_model, device_mesh=mesh_2d["dp"])
        comm_mode = CommDebugMode()

        with comm_mode:
            fsdp_2d_model(torch.rand(2, 10).cuda(self.rank)).sum().backward()

        funcol = torch.ops.c10d_functional
        c10d_ops = torch.ops.c10d
        comm_counts = comm_mode.get_comm_counts()
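        # Expect 7 collectives in total: 5 from TP (2 reduce-scatters,
        # 2 all-gathers, 1 all-reduce) and 2 from FSDP (1 all-gather,
        # 1 reduce-scatter)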
        self.assertEqual(comm_mode.get_total_counts(), 7)
        # TP comms
        self.assertEqual(comm_counts[funcol.reduce_scatter_tensor], 2)
        self.assertEqual(comm_counts[funcol.all_gather_into_tensor], 2)
        self.assertEqual(comm_counts[funcol.all_reduce], 1)
        # FSDP comms
        self.assertEqual(comm_counts[c10d_ops._allgather_base_], 1)
        self.assertEqual(comm_counts[c10d_ops._reduce_scatter_base_], 1)

        grads = [p.grad for p in fsdp_2d_model.parameters() if p.grad is not None]

        for grad in grads:
            self.assertFalse(grad.isnan().any().item())

    @skip_if_lt_x_gpu(4)
    def test_fsdp_tp_sync_module_state(self):
        mesh_2d = init_device_mesh(
            "cuda", (self.world_size // 2, 2), mesh_dim_names=["dp", "tp"]
        )
        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]

        # set random seed for each rank
        torch.manual_seed(mesh_2d.get_rank())
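        # Seeding per rank makes the "replicated" DTensor values initially
        # differ across ranks; sync_module_states should later reconcile them
        # along the dp mesh dimension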

        class TestModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                replicated_dt = DTensor.from_local(
                    torch.randn(8, 8), tp_mesh, [Replicate()], run_check=False
                )
                replicated_buffer_dt = DTensor.from_local(
                    torch.randn(8, 8), tp_mesh, [Replicate()], run_check=False
                )
                self.param = torch.nn.Parameter(replicated_dt)
                self.buf = torch.nn.Buffer(replicated_buffer_dt)

            def forward(self, x):
                return self.param + self.buf + 1

        model = TestModel()

        def assert_local_shard_across_ranks(local_tensor, group, check_equal=True):
            gathered_tensors = [
                torch.empty_like(local_tensor) for _ in range(group.size())
            ]
            dist.all_gather(gathered_tensors, local_tensor, group=group)
            # compare the gathered local tensors across all ranks in the group
            tensor_to_compare = gathered_tensors[0]
            for tensor in gathered_tensors[1:]:
                if check_equal:
                    self.assertTrue(torch.equal(tensor, tensor_to_compare))
                else:
                    self.assertFalse(torch.equal(tensor, tensor_to_compare))

        dp_group = dp_mesh.get_group()

        # check that the param's local tensors differ across the dp mesh dim
        local_param = model.param.to_local()
        assert_local_shard_across_ranks(local_param, dp_group, check_equal=False)
        # check that the buffer's local tensors differ across the dp mesh dim
        local_buf = model.buf.to_local()
        assert_local_shard_across_ranks(local_buf, dp_group, check_equal=False)

        # wrapping with FSDP with sync_module_states=True should sync the dp mesh dim
        fsdp_mod = FSDP(model, device_mesh=dp_mesh, sync_module_states=True)
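        # summon_full_params unshards the parameters so that their DTensor
        # local tensors can be compared across the dp group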
        with fsdp_mod.summon_full_params(fsdp_mod):
            # after sync_module_states, param local tensors match across the dp mesh dim
            local_param = fsdp_mod.param.to_local()
            assert_local_shard_across_ranks(local_param, dp_group, check_equal=True)

            # after sync_module_states, buffer local tensors match across the dp mesh dim
            local_buf = fsdp_mod.buf.to_local()
            assert_local_shard_across_ranks(local_buf, dp_group, check_equal=True)


instantiate_parametrized_tests(TestTPFSDPIntegration)

if __name__ == "__main__":
    run_tests()