# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from functools import partial

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import (
    distribute_tensor,
    DTensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed._tensor.experimental import local_map
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


funcol_py = torch.ops.c10d_functional


row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
replicate = [Replicate()]  # replicate placements on 1-d mesh


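# The helper functions below operate on each rank's local shard (a plain torch.Tensor)
# and issue any collectives manually via funcol; the tests wrap them with `local_map`
# so they can be called on DTensor arguments.


# all-gather each rank's equality flag and return True only if X equals Y on every rank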
def equal_allgather_forward(device_mesh, X, Y):
    eq = torch.tensor([torch.equal(X, Y)], device=X.device)
    eq_gather = funcol.all_gather_tensor(eq, 0, device_mesh)
    return torch.all(eq_gather).item()


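# matmul on the local shards, then all-gather the per-rank results along dim 0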
def mm_all_gather_forward(device_mesh, A, B):
    local_mm_result = torch.mm(A, B)
    return funcol.all_gather_tensor(local_mm_result, 0, device_mesh).wait()


def mm_forward(A, B):  # no device mesh needed since no collective is performed
    return torch.mm(A, B)


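# matmul on the local shards produces a partial sum; all-reduce the partial sums so
# every rank holds the full (replicated) result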
def mm_allreduce_forward(device_mesh, A, B):
    partial_sum_tensor = torch.mm(A, B)
    return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()


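# same computation as mm_allreduce_forward, but wrapped with local_map at definition
# time by applying functools.partial(local_map, ...) as a decorator; the leading None
# in in_placements means the device_mesh argument needs no DTensor conversion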
@partial(
    local_map,
    out_placements=replicate,
    in_placements=(None, col_wise, row_wise),
)
def mm_allreduce_forward_decorated(device_mesh, A, B):
    partial_sum_tensor = torch.mm(A, B)
    return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()


def mul_forward(X, scalar):  # no device mesh needed since no collective is performed
    return torch.mul(X, scalar)


class TestLocalMap(DTensorTestBase):
    @property
    def world_size(self):
        return 2

    # simple correctness check
    @with_comms
    def test_local_map_correctness(self):
        device_mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        comm_mode = CommDebugMode()

        # Y = X @ W
        X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
        W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
        Y = torch.mm(X, W)

        X_dt = distribute_tensor(
            X, device_mesh, col_wise
        )  # column-wise sharded X tensor
        W_dt = distribute_tensor(
            W, device_mesh, row_wise
        )  # row-wise sharded W tensor

        # Test 1: use the function returned from calling local_map
        # to get the function wrapped with DTensor/Tensor conversion.
        # mm_allreduce_forward operates on plain Tensors and issues the collective manually;
        # local_mm_allreduce_forward does the same computation but operates on
        # DTensors' `_local_tensor`.
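        # local_map unwraps each DTensor argument to its local shard according to
        # in_placements, calls the wrapped function on the shards, and wraps Tensor
        # outputs back into DTensors using out_placements.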
        local_mm_allreduce_forward = local_map(
            mm_allreduce_forward,
            out_placements=replicate,
            in_placements=(None, col_wise, row_wise),
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)

        # one collective: the all_reduce that makes the output Replicate
        self.assertEqual(comm_mode.get_total_counts(), 1)
        # check output placements
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_replicate())
        # check output value
        self.assertEqual(Y_dt.to_local(), Y)

        # Test 2: use the local_map decorator
        with comm_mode:
            Y_dt = mm_allreduce_forward_decorated(device_mesh, X_dt, W_dt)

        # one collective: the all_reduce that makes the output Replicate
        self.assertEqual(comm_mode.get_total_counts(), 1)
        # check output placements
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_replicate())
        # check output value
        self.assertEqual(Y_dt.to_local(), Y)

    # check for `out_placements`
    @with_comms
    def test_local_map_out_placements(self):
        # Test 1: `out_placements=None` and the wrapped function returns a non-Tensor,
        # so the output is passed through unwrapped
        device_mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        comm_mode = CommDebugMode()

        # X.equal(Y)
        X = torch.randn(8, 8, device=self.device_type, requires_grad=False)
        Y = torch.randn(8, 8, device=self.device_type, requires_grad=False)
        X_dt = distribute_tensor(X, device_mesh, row_wise)
        Y_dt = distribute_tensor(Y, device_mesh, row_wise)
        local_equal_allgather_forward = local_map(
            equal_allgather_forward,
            out_placements=None,
        )
        with comm_mode:
            equal_dt = local_equal_allgather_forward(device_mesh, X_dt, Y_dt)  # a bool

        self.assertEqual(comm_mode.get_total_counts(), 1)
        self.assertTrue(not equal_dt)
        self.assertTrue(not (X.equal(Y)))

        # Test 2: return the output directly if no argument is a DTensor
        # matmul in DDP
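        # X is created directly as this rank's local shard (4 // world_size rows), so no
        # argument is a DTensor and local_map returns the plain Tensor output unchanged.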
        X = torch.randn(
            4 // self.world_size, 4, device=self.device_type, requires_grad=False
        )
        W = torch.randn(4, 4, device=self.device_type, requires_grad=False)
        local_mm_all_gather_forward = local_map(
            mm_all_gather_forward,
            out_placements=row_wise,
            in_placements=(None, row_wise, replicate),
        )
        with comm_mode:
            Y = local_mm_all_gather_forward(device_mesh, X, W)

        self.assertEqual(comm_mode.get_total_counts(), 1)
        self.assertEqual(
            comm_mode.get_comm_counts()[funcol_py.all_gather_into_tensor], 1
        )
        X_replicate = funcol.all_gather_tensor(X, 0, device_mesh).wait()
        Y_replicate = torch.mm(X_replicate, W)
        self.assertEqual(Y, Y_replicate)  # Y is a torch.Tensor

    # check for `in_placements` handling
    @with_comms
    def test_local_map_in_placements(self):
        device_mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        comm_mode = CommDebugMode()

        # Y = X @ W
        X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
        W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
        Y = torch.mm(X, W)

        X_dt = distribute_tensor(
            X, device_mesh, row_wise
        )  # row-wise sharded X tensor
        W_dt = distribute_tensor(W, device_mesh, replicate)  # replicated W tensor

        # Test 1: explicitly pass `in_placements`
        local_mm_forward = local_map(
            mm_forward,
            out_placements=row_wise,
            in_placements=(row_wise, replicate),
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt = local_mm_forward(X_dt, W_dt)

        # no communication should occur in this case
        self.assertEqual(comm_mode.get_total_counts(), 0)
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_shard(dim=0))
        self.assertEqual(Y_dt.full_tensor(), Y)

        # Test 2: `in_placements=None`
        local_mm_forward = local_map(
            mm_forward,
            out_placements=row_wise,
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt = local_mm_forward(X_dt, W_dt)

        self.assertEqual(comm_mode.get_total_counts(), 0)
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_shard(dim=0))
        self.assertEqual(Y_dt.full_tensor(), Y)

        # Test 3: `None` placements for non-Tensor input argument
        # Y = X * 2.0
        local_mul_forward = local_map(
            mul_forward,
            in_placements=(row_wise, None),
            out_placements=row_wise,
            device_mesh=device_mesh,
        )
        Y = torch.mul(X, 2.0)
        with comm_mode:
            Y_dt = local_mul_forward(X_dt, 2.0)

        self.assertEqual(comm_mode.get_total_counts(), 0)
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_shard(dim=0))
        self.assertEqual(Y_dt.full_tensor(), Y)

        # Test 4: `None` placements for Tensor input argument
        local_mm_forward = local_map(
            mm_forward,
            out_placements=None,
            in_placements=(None, None),
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt_local = local_mm_forward(X_dt.to_local(), W_dt.to_local())

        self.assertEqual(comm_mode.get_total_counts(), 0)
        self.assertEqual(
            DTensor.from_local(Y_dt_local, device_mesh, row_wise).full_tensor(),
            torch.mm(X, W),
        )

        # Test 5: non-None placements for plain Tensor input arguments
        local_mm_forward = local_map(
            mm_forward,
            out_placements=None,
            in_placements=(replicate, row_wise),
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt_local = local_mm_forward(X_dt.to_local(), W_dt.to_local())

        self.assertEqual(comm_mode.get_total_counts(), 0)
        self.assertEqual(
            DTensor.from_local(Y_dt_local, device_mesh, row_wise).full_tensor(),
            torch.mm(X, W),
        )

        # Test 6: expect error - `None` placements for a DTensor input argument
        local_mm_forward = local_map(
            mm_forward,
            out_placements=row_wise,
            in_placements=(row_wise, None),
            device_mesh=device_mesh,
        )
        with self.assertRaisesRegex(AssertionError, "expects placements"):
            Y_dt = local_mm_forward(X_dt, W_dt)

    # check for `redistribute_inputs` handling
    @with_comms
    def test_local_map_redistribute(self):
        device_mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        comm_mode = CommDebugMode()

        # Y = X @ W
        X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
        W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
        Y = torch.mm(X, W)

        X_dt = distribute_tensor(
            X, device_mesh, row_wise
        )  # row-wise sharded X tensor which will be redistributed
        W_dt = distribute_tensor(
            W, device_mesh, col_wise
        )  # column-wise sharded W tensor which will be redistributed

        # Test 1: allow input redistribution
        local_mm_allreduce_forward = local_map(
            mm_allreduce_forward,
            out_placements=replicate,
            in_placements=(None, col_wise, row_wise),
            device_mesh=device_mesh,
            redistribute_inputs=True,
        )
        with comm_mode:
            Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)

        # 2 collectives for input redistribution and 1 for the all_reduce on the output
        self.assertEqual(comm_mode.get_total_counts(), 3)
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_replicate())
        self.assertEqual(Y_dt.to_local(), Y)

        # Test 2: input redistribution is not allowed
        local_mm_allreduce_forward = local_map(
            mm_allreduce_forward,
            out_placements=replicate,
            in_placements=(None, col_wise, row_wise),
            device_mesh=device_mesh,
            redistribute_inputs=False,
        )
        with self.assertRaisesRegex(ValueError, "set redistribute_inputs=True"):
            Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)


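# These tests are typically run directly, e.g.
#   python test/distributed/_tensor/experimental/test_local_map.py
# DTensorTestBase spawns one process per rank (world_size is 2 here).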
if __name__ == "__main__":
    run_tests()