# Owner(s): ["oncall: distributed"]
from collections import OrderedDict
from copy import deepcopy

import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    PrepareModuleInput,
    PrepareModuleOutput,
    RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    MLPStacked,
    with_comms,
)


class DummyModule(torch.nn.Module):
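    """Identity module used as the target for the PrepareModuleInput/Output tests."""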
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x


class TensorParallelAPITests(DTensorTestBase):
    @property
    def world_size(self):
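        # Use every GPU when the count is even and greater than 4; otherwise run with 4 ranks.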
        gpu_num = torch.cuda.device_count()
        return gpu_num if gpu_num % 2 == 0 and gpu_num > 4 else 4

    def _compare_params(
        self,
        local_module,
        dist_module,
        rank0_only,
        skip_rowwise_bias=False,
        compare_grad=False,
    ):
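        """Redistribute each DTensor parameter of ``dist_module`` to ``Replicate()`` and
        compare it (or its gradient when ``compare_grad`` is set) against the matching
        parameter of ``local_module``. With ``rank0_only``, ranks other than 0 skip
        ``net2.bias`` (and ``bias`` when ``skip_rowwise_bias`` is set)."""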
        replicate = [Replicate()]
        for name, param in local_module.named_parameters():
            dist_param = dist_module.get_parameter(name)
            param = param.grad if compare_grad else param
            dist_param = dist_param.grad if compare_grad else dist_param
            if (
                (not rank0_only)
                or (self.rank == 0)
                or (
                    name not in ["net2.bias"]
                    and not skip_rowwise_bias
                    or name not in ["bias", "net2.bias"]
                )
            ):
                self.assertEqual(
                    param,
                    dist_param.redistribute(
                        device_mesh=dist_param.device_mesh, placements=replicate
                    ).to_local(),
                    f"{name} not equal between dist and non-dist",
                )

    def _compare_module(
        self, local_module, dist_module, inp_size, rank0_only=True, rowwise=False
    ):
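        """Run one forward/backward/optimizer step on both modules with the same seeded
        input and assert that parameters, outputs, and gradients all match. For rowwise
        modules each rank feeds its chunk of the input along the last dimension."""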
        LR = 0.25  # the learning rate we use for testing
        local_optim = torch.optim.SGD(local_module.parameters(), lr=LR)
        dist_optim = torch.optim.SGD(dist_module.parameters(), lr=LR)
        torch.manual_seed(0)
        inp = torch.rand(*inp_size, device=self.device_type)
        self._compare_params(local_module, dist_module, rank0_only)

        # check forward correctness
        local_output = local_module(inp)
        inp = inp.chunk(self.world_size, dim=-1)[self.rank] if rowwise else inp
        dist_output = dist_module(inp)
        dist_output = (
            dist_output.redistribute(dist_output.device_mesh, [Replicate()]).to_local()
            if isinstance(dist_output, DTensor)
            else dist_output
        )
        self.assertEqual(local_output, dist_output)

        local_output.sum().backward()
        dist_output.sum().backward()

        # check backward and ensure gradients are the same
        self._compare_params(local_module, dist_module, rank0_only, rowwise, True)

        local_optim.step()
        dist_optim.step()
        self._compare_params(local_module, dist_module, rank0_only, rowwise)

    @with_comms
    def test_parallelize_mlp_with_module_api(self):
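        """Parallelize both MLP linears column-wise via an explicit per-module plan."""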
        inp_size = [12, 10]
        model = MLPModule(self.device_type)
        model_tp = deepcopy(model)

        # Parallelize module.
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "net1": ColwiseParallel(output_layouts=Replicate()),
                "net2": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)

    @with_comms
    def test_parallelize_mlp_with_module_api_nested(self):
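        """Same plan as above, but the modules are addressed through fully qualified
        nested paths ("dummy_encoder.net1", "dummy_encoder.net2")."""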
        inp_size = [12, 10]
        model = torch.nn.Sequential(
            OrderedDict([("dummy_encoder", MLPModule(self.device_type))])
        )
        model_tp = deepcopy(model)

        # Parallelize module.
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "dummy_encoder.net1": ColwiseParallel(output_layouts=Replicate()),
                "dummy_encoder.net2": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)

    @with_comms
    def test_linear_row_wise_parallel(self):
        # test RowwiseParallel
        inp_size = [9, 16]
        rowwise = RowwiseParallel()

        torch.manual_seed(5)
        model = torch.nn.Linear(16, 10, device=self.device_type)
        model_tp = deepcopy(model)

        # parallelize model_tp
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        model_tp = parallelize_module(model_tp, device_mesh, rowwise)

        # let each rank generate unique local input
        torch.manual_seed(self.rank)
        self._compare_module(model, model_tp, inp_size, rowwise=True)

    @with_comms
    def test_linear_col_wise_parallel(self):
        # test ColwiseParallel
        inp_size = [8, 10]
        colwise = ColwiseParallel(output_layouts=Replicate())

        torch.manual_seed(5)
        model = torch.nn.Linear(10, 16, device=self.device_type)
        model_tp = deepcopy(model)

        # parallelize model_tp
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        model_tp = parallelize_module(model_tp, device_mesh, colwise)

        self._compare_module(model, model_tp, inp_size)

    @with_comms
    def test_prepare_module_input(self):
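        """PrepareModuleInput should annotate a Shard(0) input and redistribute it to
        Replicate() before the wrapped module runs."""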
        module = DummyModule()
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        parallelize_module(
            module,
            device_mesh,
            PrepareModuleInput(
                input_layouts=Shard(0), desired_input_layouts=Replicate()
            ),
        )
        inp = torch.rand(5, 7, device=self.device_type)
        output = module(inp).redistribute(device_mesh, [Shard(0)]).to_local()
        self.assertEqual(inp, output)

    @with_comms
    def test_prepare_module_output(self):
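        """PrepareModuleOutput should redistribute a replicated output to Shard(0) after
        the wrapped module runs."""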
        module = DummyModule()
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        parallelize_module(
            module,
            device_mesh,
            PrepareModuleOutput(
                output_layouts=Replicate(), desired_output_layouts=Shard(0)
            ),
        )
        torch.manual_seed(15)
        inp = torch.rand(16, 7, device=self.device_type)
        dtensor = DTensor.from_local(inp, device_mesh, [Replicate()], run_check=False)
        output = module(dtensor)
        inp = dtensor.redistribute(device_mesh, [Shard(0)]).to_local()
        self.assertEqual(inp, output)

    @with_comms
    def test_parallelize_module_with_star(self):
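        """A "net*" wildcard key should match both net1 and net2 of the MLP."""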
        inp_size = [12, 10]
        model = MLPModule(self.device_type)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model_tp = deepcopy(model)
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "net*": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)

    @with_comms
    def test_parallelize_module_with_question(self):
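        """A "net?" key should match any single trailing character, i.e. both net1 and net2."""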
        inp_size = [12, 10]
        model = MLPModule(self.device_type)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model_tp = deepcopy(model)
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "net?": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)

    @with_comms
    def test_parallelize_module_with_digit(self):
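        """A "net[1-2]" character-range key should match both net1 and net2."""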
        inp_size = [12, 10]
        model = MLPModule(self.device_type)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model_tp = deepcopy(model)
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "net[1-2]": ColwiseParallel(output_layouts=Replicate()),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)

    @with_comms
    def test_parallelize_module_multi_wildcard(self):
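        """Wildcards inside nested paths: apply ColwiseParallel to every layers.*.net1
        and RowwiseParallel to every layers.*.net2 of a stacked MLP."""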
        inp_size = [12, 10]
        model = MLPStacked(self.device_type, n_layers=2)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model_tp = deepcopy(model)
        model_tp = parallelize_module(
            model_tp,
            device_mesh,
            {
                "layers.*.net[1]": ColwiseParallel(),
                "layers.*.net[2]": RowwiseParallel(),
            },
        )
        self._compare_module(model, model_tp, inp_size, rank0_only=False)


if __name__ == "__main__":
    run_tests()