# Owner(s): ["oncall: distributed"]
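"""
Tests for the experimental tensor-parallel transformation over exported
programs: parameters are sharded according to the given parallel strategies
and the required functional c10d collectives are inserted into the graph.
"""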
from collections import defaultdict
from typing import Dict

import torch
from torch.distributed._tensor.experimental._tp_transform import (
    tensor_parallel_transformation,
)
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    ParallelStyle,
    RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


class MLPListModule(torch.nn.Module):
22    """
23    A dummy model with list of MLPs.
24    """
25
26    def __init__(self, num_mlps=3, bias=True):
27        super().__init__()
28        self.mlps = torch.nn.ModuleList()
29        for _ in range(num_mlps):
30            self.mlps.append(
31                torch.nn.Sequential(
32                    torch.nn.Linear(6, 18),
33                    torch.nn.ReLU(),
34                    torch.nn.Linear(18, 6, bias=bias),
35                )
36            )
37
38    def forward(self, x: torch.Tensor) -> torch.Tensor:
39        x = torch.chunk(x, 2, dim=1)[0]
40        for mlp in self.mlps:
41            x = mlp(x)
42        return x + torch.ones_like(x)
43
44
45class DummyModel(torch.nn.Module):
46    def __init__(self) -> None:
47        super().__init__()
48        self.fc = torch.nn.Linear(3, 5)
49        self.bn = torch.nn.BatchNorm1d(5)
50
51    def forward(self, x):
52        return self.bn(self.fc(x))
53
54
55class TensorParallelTest(DTensorTestBase):
56    def setUp(self) -> None:
57        super().setUp()
58
59    def assert_has_c10d_ops(
60        self, gm: torch.fx.GraphModule, expected_ops_count: Dict[str, int]
61    ) -> None:
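        """
        Count the functional c10d collective ops appearing as ``call_function``
        nodes in the graph and assert the counts match ``expected_ops_count``
        exactly.
        """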
        actual_ops_count: Dict[str, int] = defaultdict(int)
        for node in gm.graph.nodes:
            if node.op == "call_function" and "c10d_functional" in str(node.target):
                actual_ops_count[str(node.target)] += 1
        self.assertDictEqual(expected_ops_count, actual_ops_count)

    @with_comms
    def test_tp_transform_with_uncovered_op(self):
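        """
        Parallelize only ``fc`` column-wise, leaving ``bn`` uncovered. The
        transform must gather the sharded ``fc`` output (one all_gather plus
        its wait_tensor) before the uncovered BatchNorm op consumes it.
        """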
        model = DummyModel().to(device=self.device_type)
        inputs = (torch.randn(7, 3, requires_grad=False).to(device=self.device_type),)
        with torch.no_grad():
            res = model(*inputs)
            exported_program = torch.export.export(
                model,
                inputs,
            ).run_decompositions()
        tp_exported_program = tensor_parallel_transformation(
            exported_program,
            self.rank,
            self.world_size,
            self.device_type,
            {"fc": ColwiseParallel},
        )
        tp_model = tp_exported_program.module()
        with torch.no_grad():
            tp_res = tp_model(*inputs)
        self.assertEqual(res, tp_res)
        # Expect an all_gather to be inserted to gather the sharded fc results
        self.assert_has_c10d_ops(
            tp_exported_program.graph_module,
            {
                "_c10d_functional.all_gather_into_tensor.default": 1,
                "_c10d_functional.wait_tensor.default": 1,
            },
        )

    @with_comms
    def test_tp_transform_e2e(self):
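        """
        Pair ColwiseParallel and RowwiseParallel within each of the two MLPs
        so the intermediate activation stays sharded; the row-wise partial
        outputs are combined with one all_reduce per MLP.
        """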
        torch.manual_seed(0)
        model = MLPListModule(2).to(device=self.device_type)
        inputs = (torch.randn((10, 12)).to(device=self.device_type),)
        parallel_strategies: Dict[str, ParallelStyle] = {
            "mlps.0.0": ColwiseParallel,
            "mlps.0.2": RowwiseParallel,
            "mlps.1.0": ColwiseParallel,
            "mlps.1.2": RowwiseParallel,
        }

        with torch.inference_mode():
            res = model(*inputs)
            exported_program = torch.export.export(
                model,
                inputs,
            ).run_decompositions()
        tp_exported_program = tensor_parallel_transformation(
            exported_program,
            self.rank,
            self.world_size,
            self.device_type,
            parallel_strategies,
        )
        tp_model = tp_exported_program.module()
        with torch.inference_mode():
            tp_res = tp_model(*inputs)
        self.assertEqual(res, tp_res)
        # Expect all_reduce to be inserted at the end of each MLP
        self.assert_has_c10d_ops(
            tp_exported_program.graph_module,
            {
                "_c10d_functional.all_reduce.default": 2,
                "_c10d_functional.wait_tensor.default": 2,
            },
        )

    @with_comms
    def test_tp_transform_no_bias(self):
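        """
        Same colwise/rowwise pairing as the e2e test, but on a single MLP
        whose second Linear has no bias; expects exactly one all_reduce.
        """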
        torch.manual_seed(0)
        model = MLPListModule(1, bias=False).to(device=self.device_type)
        inputs = (torch.randn((10, 12)).to(device=self.device_type),)
        parallel_strategies: Dict[str, ParallelStyle] = {
            "mlps.0.0": ColwiseParallel,
            "mlps.0.2": RowwiseParallel,
        }

        with torch.inference_mode():
            res = model(*inputs)
            exported_program = torch.export.export(
                model,
                inputs,
            ).run_decompositions()
        tp_exported_program = tensor_parallel_transformation(
            exported_program,
            self.rank,
            self.world_size,
            self.device_type,
            parallel_strategies,
        )
        tp_model = tp_exported_program.module()
        with torch.inference_mode():
            tp_res = tp_model(*inputs)
        self.assertEqual(res, tp_res)
        self.assert_has_c10d_ops(
            tp_exported_program.graph_module,
            {
                "_c10d_functional.all_reduce.default": 1,
                "_c10d_functional.wait_tensor.default": 1,
            },
        )


if __name__ == "__main__":
    run_tests()