# Owner(s): ["oncall: distributed"]

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import requires_nccl
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore


c10d_functional = torch.ops.c10d_functional
c10d_ops = torch.ops.c10d


class TestCommMode(TestCase):
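    """Tests for CommDebugMode, which counts the collective ops issued while it is active."""
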
    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()

    def setUp(self):
        super().setUp()
        self.world_size = 2
        store = FakeStore()
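        # The fake backend records collective calls without doing any real
        # communication, so the whole suite can run in a single process.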
        dist.init_process_group(
            backend="fake", rank=1, world_size=self.world_size, store=store
        )
        self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
        self.world_pg = dist.distributed_c10d._get_default_group()

    def checksAssert(self, comm_mode, key, expected_value, expected_total_value):
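        """Check that `key` was traced `expected_value` times and that the
        total count matches `expected_total_value`."""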
        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(comm_mode.get_total_counts(), expected_total_value)
        self.assertEqual(comm_counts[key], expected_value)

        return

    def test_comm_mode(self):
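        """Functional collectives issued in a module forward are counted per op."""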
        world_pg = self.world_pg

        class WrapperModel(nn.Module):
            def __init__(self, device):
                super().__init__()
                self.model = MLPModule(device=device)

            def forward(self, x):
                x = funcol.all_gather_tensor(x, 0, world_pg)
                x = funcol.reduce_scatter_tensor(x, "sum", 0, world_pg)
                out = self.model(x)
                return funcol.all_reduce(out, "sum", world_pg)

        model = WrapperModel(self.device_type)

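        # CommDebugMode counts every collective dispatched while it is active
        # as a context manager.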
        comm_mode = CommDebugMode()
        with comm_mode:
            model(torch.randn(20, 10, device=self.device_type))

        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(comm_mode.get_total_counts(), 3)
        self.assertEqual(comm_counts[c10d_functional.all_reduce], 1)
        self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
        self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 1)

    def test_comm_mode_coalesced(self):
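        """Coalesced functional collectives are tracked under their coalesced op names."""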
        world_pg = self.world_pg

        class WrapperModelCoalesced(nn.Module):
            def __init__(self, device):
                super().__init__()
                self.model = MLPModule(device=device)

            def forward(self, x):
                x = funcol.all_gather_tensor(x, 0, world_pg)
                x = funcol.reduce_scatter_tensor(x, "sum", 0, world_pg)
                out = self.model(x)
                return funcol.all_reduce_coalesced([out], "sum", world_pg)

        model = WrapperModelCoalesced(self.device_type)

        comm_mode = CommDebugMode()
        with comm_mode:
            model(torch.randn(20, 10, device=self.device_type))

        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(comm_mode.get_total_counts(), 3)
        self.assertEqual(comm_counts[c10d_functional.all_reduce_coalesced], 1)
        self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
        self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 1)

    def test_comm_mode_with_dtensor(self):
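        """Collectives triggered implicitly by DTensor sharding propagation are counted too."""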
        world_pg = self.world_pg
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        def f(x, y):
            return torch.mm(x, y)

        comm_mode = CommDebugMode()
        x = torch.randn(4, 8, requires_grad=True)
        y = torch.randn(4, 32, requires_grad=True)
        x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)

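        # The sharded matmul has to redistribute one operand, so a single
        # all-gather (and no all-reduce or reduce-scatter) is expected here.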
        with comm_mode:
            f(x_dtensor, y_dtensor)

        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(comm_mode.get_total_counts(), 1)
        self.assertEqual(comm_counts[c10d_functional.all_reduce], 0)
        self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
        self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 0)

    @requires_nccl()
    def test_comm_mode_with_c10d(self):
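        """The non-functional c10d ops (torch.ops.c10d.*) are traced as well."""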
        if not torch.cuda.is_available():
            return

        world_pg = self.world_pg

        inp = torch.rand(2, 8, 16).cuda()
        all_gather_out = inp.new_empty(self.world_size * 2, 8, 16)

        comm_mode = CommDebugMode()

        # tests c10d all_reduce tracing
        with comm_mode:
            dist.all_reduce(inp)

        self.checksAssert(comm_mode, c10d_ops.allreduce_, 1, 1)

        # tests c10d all_gather_into_tensor tracing
        with comm_mode:
            dist.all_gather_into_tensor(all_gather_out, inp)

        self.checksAssert(comm_mode, c10d_ops._allgather_base_, 1, 1)

        # tests c10d reduce_scatter tracing
        with comm_mode:
            dist.reduce_scatter_tensor(inp, all_gather_out)

        self.checksAssert(comm_mode, c10d_ops._reduce_scatter_base_, 1, 1)

        # tests c10d broadcast tracing
        with comm_mode:
            dist.broadcast(inp, 0)

        self.checksAssert(comm_mode, c10d_ops.broadcast_, 1, 1)

        # tests c10d gather tracing
        with comm_mode:
            dist.gather(inp, None, 0)

        self.checksAssert(comm_mode, c10d_ops.gather_, 1, 1)

        # tests c10d reduce tracing
        with comm_mode:
            dist.reduce(inp, 0)

        self.checksAssert(comm_mode, c10d_ops.reduce_, 1, 1)

        # tests c10d scatter tracing
        with comm_mode:
            dist.scatter(inp, None, 0)

        self.checksAssert(comm_mode, c10d_ops.scatter_, 1, 1)

        # tests c10d all_gather tracing
        output_list = []

        with comm_mode:
            dist.all_gather(output_list, inp, None)

        self.checksAssert(comm_mode, c10d_ops.allgather_, 1, 1)

        # tests c10d allgather_coalesced_ tracing
        output_list = []

        with comm_mode:
            dist.all_gather_coalesced(output_list, [inp], None)

        self.checksAssert(comm_mode, c10d_ops.allgather_coalesced_, 1, 1)

        # tests c10d allgather_into_tensor_coalesced_ tracing
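        # Collectives issued inside _coalescing_manager() are batched and
        # dispatch to their *_coalesced_ variants.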
        with comm_mode, dist._coalescing_manager():
            dist.all_gather_into_tensor(all_gather_out, inp)

        self.checksAssert(comm_mode, c10d_ops.allgather_into_tensor_coalesced_, 1, 1)

        # tests c10d allreduce_coalesced_ tracing
        with comm_mode:
            dist.all_reduce_coalesced(inp)

        self.checksAssert(comm_mode, c10d_ops.allreduce_coalesced_, 1, 1)

        # tests c10d reduce_scatter_ tracing
        with comm_mode:
            dist.reduce_scatter(all_gather_out, [inp])

        self.checksAssert(comm_mode, c10d_ops.reduce_scatter_, 1, 1)

        # tests c10d reduce_scatter_tensor_coalesced_ tracing
        with comm_mode, dist._coalescing_manager():
            dist.reduce_scatter_tensor(all_gather_out, inp)

        self.checksAssert(comm_mode, c10d_ops.reduce_scatter_tensor_coalesced_, 1, 1)

        # tests c10d alltoall_ tracing
        with comm_mode:
            dist.all_to_all([inp], [inp])

        self.checksAssert(comm_mode, c10d_ops.alltoall_, 1, 1)

        # tests c10d alltoall_base_ tracing
        with comm_mode:
            dist.all_to_all_single(inp, inp)

        self.checksAssert(comm_mode, c10d_ops.alltoall_base_, 1, 1)


if __name__ == "__main__":
    run_tests()