# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
from torch.distributed.distributed_c10d import (
    _get_default_group,
    _world,
    get_global_rank,
    get_world_size,
    init_process_group,
    is_initialized,
    is_nccl_available,
    ProcessGroup,
)
from torch.distributed.tensor._collective_utils import (
    mesh_broadcast,
    mesh_scatter,
    unpad_tensor,
)
from torch.distributed.tensor.placement_types import _Partial, Shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed.fake_pg import FakeStore


def _get_device_type(world_size):
    if (
        torch.cuda.is_available()
        and torch.cuda.device_count() >= world_size
        and is_nccl_available()
    ):
        device_type = "cuda"
    else:
        device_type = "cpu"
    return device_type


def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0):
    os.environ["MASTER_ADDR"] = addr
    os.environ["MASTER_PORT"] = port
    os.environ["WORLD_SIZE"] = f"{world_size}"
    os.environ["RANK"] = f"{rank}"

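# Illustrative usage of the two helpers above (a sketch, not something the test
# runner executes on its own): a rank would typically do
#
#     _set_env_var(world_size=4, rank=rank)
#     device_type = _get_device_type(4)  # "cuda" only with >= 4 GPUs and NCCL
#     DeviceMesh(device_type, torch.arange(4).reshape(2, 2))
#
# and constructing the DeviceMesh initializes the default process group from
# the env vars if none exists yet (see test_init_process_group below).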

class DeviceMeshTestGlooBackend(DTensorTestBase):
    @property
    def backend(self):
        return "gloo"

    @with_comms
    def test_device_mesh_reuse_default_group(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))
        mesh_group = mesh.get_group()
        default_group = _get_default_group()
        if torch.cuda.is_available():
            self.assertNotEqual(mesh_group, default_group)
            self.assertEqual(get_world_size(mesh_group), get_world_size(default_group))
        else:
            self.assertEqual(mesh_group, default_group)


class DeviceMeshTest(DTensorTestBase):
    @property
    def world_size(self):
        return 4

    def test_init_process_group(self):
        device_type = _get_device_type(self.world_size)
        mesh_tensor = torch.arange(4).reshape(2, 2)
        self.assertTrue(not is_initialized())
        _set_env_var(world_size=self.world_size, rank=self.rank)
        DeviceMesh(device_type, mesh_tensor)
        self.assertTrue(is_initialized())
        self.destroy_pg()

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_assert_invalid_mesh_tensor(self):
        mesh = torch.arange(self.world_size).to(self.rank)
        with self.assertRaises(ValueError):
            device_mesh = DeviceMesh(self.device_type, mesh)

    @with_comms
    def test_get_group_and_get_all_groups(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )

        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]

        self.assertEqual(mesh_2d.get_group(0), mesh_2d.get_group("dp"))
        self.assertEqual(mesh_2d.get_group(1), mesh_2d.get_group("tp"))

        self.assertEqual(mesh_2d.get_group("dp"), dp_mesh.get_group())
        self.assertEqual(mesh_2d.get_group("tp"), tp_mesh.get_group())

        groups = mesh_2d.get_all_groups()
        self.assertEqual(len(groups), 2)
        self.assertTrue(tp_mesh.get_group() in groups)
        self.assertTrue(dp_mesh.get_group() in groups)

    @with_comms
    def test_get_local_rank_raises_exception(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )

        with self.assertRaisesRegex(
            RuntimeError,
            "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
        ):
            local_rank = mesh_2d.get_local_rank()

    @with_comms
    def test_get_local_rank(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )
        self.assertEqual(mesh_2d.get_local_rank("dp"), mesh_2d.get_local_rank(0))
        self.assertEqual(mesh_2d.get_local_rank("tp"), mesh_2d.get_local_rank(1))

        dp_mesh = mesh_2d["dp"]
        tp_mesh = mesh_2d["tp"]
        self.assertEqual(dp_mesh.get_local_rank(), mesh_2d.get_local_rank("dp"))
        self.assertEqual(tp_mesh.get_local_rank(), mesh_2d.get_local_rank("tp"))

        # Verify flattened mesh local rank correctness.
        flattened_mesh = mesh_2d["dp", "tp"]._flatten()
        self.assertEqual(flattened_mesh.get_local_rank(), self.rank)

    @with_comms
    def test_device_mesh_2d(self):
        mesh_tensor = torch.arange(4).reshape(2, 2)
        # construct a device mesh from the 2D mesh tensor
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_all_groups()

        expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]]
        for dim, dim_group in enumerate(dim_to_subgroups):
            self.assertTrue(dim < 2)
            dim_ranks = expected_ranks_by_dim[dim]

            dim_group_size = get_world_size(dim_group)
            self.assertIsInstance(dim_group, ProcessGroup)
            self.assertEqual(dim_group_size, 2)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            current_rank_expected_group_ranks = (
                dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1]
            )
            self.assertEqual(global_ranks, current_rank_expected_group_ranks)

    @with_comms
    def test_device_mesh_init_backend(self):
        mesh = DeviceMesh(self.device_type, [1], _init_backend=False)

        with self.assertRaisesRegex(RuntimeError, "process groups not initialized!"):
            mesh.get_group()

        # Coordinates should always be populated when _init_backend is False, since
        # callers that skip backend init must have created the default pg already.
        mesh.get_coordinate()

    def test_fake_pg_device_mesh(self):
        fake_store = FakeStore()
        init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size)
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        mesh = DeviceMesh(device_type, torch.arange(self.world_size))

        local_tensor = torch.randn(2, 8)
        global_tensor = funcol.all_gather_tensor(
            local_tensor, gather_dim=0, group=(mesh, 0)
        )
        self.assertEqual(global_tensor.shape, (self.world_size * 2, 8))

    @with_comms
    def test_from_group_with_global_pg(self):
        # Simple test: check that `from_group` on a mesh's pg matches a mesh
        # initialized directly via `init_device_mesh`
        ref_global_mesh = init_device_mesh(self.device_type, (self.world_size,))
        mesh_pg = ref_global_mesh.get_group()
        global_mesh = DeviceMesh.from_group(mesh_pg, self.device_type)
        self.assertEqual(ref_global_mesh, global_mesh)
        self.assertEqual(ref_global_mesh._dim_group_infos, global_mesh._dim_group_infos)
        self.assertEqual(
            ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim
        )

    @with_comms
    def test_from_group_with_invalid_mesh(self):
        global_pg = _get_default_group()
        global_pg_size = global_pg.size()
        assert global_pg_size == 4, "Test assumes global world size of 4"
        invalid_mesh = [[0, 1], [2, 3]]  # 2D mesh when we need 1D
        regex = r"Invalid mesh \[\[0, 1\], \[2, 3\]\] for ProcessGroup with ranks \[0, 1, 2, 3\]"
        with self.assertRaisesRegex(ValueError, regex):
            DeviceMesh.from_group(global_pg, "cuda", invalid_mesh)

        device_mesh = init_device_mesh(self.device_type, (2, 2))
        groups = device_mesh.get_all_groups()
        invalid_mesh = (0, 1, 2, 3)  # 1D mesh when we need 2D
        regex = r"Expects mesh with ndim equal to number of ProcessGroups but got mesh \[0, 1, 2, 3\] and 2 ProcessGroups"
        with self.assertRaisesRegex(ValueError, regex):
            DeviceMesh.from_group(groups, self.device_type, invalid_mesh)

    def test_raises_invalid_device_type(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "Device type with GPU index is not supported",
        ):
            # test init_device_mesh with an invalid device type that contains a GPU index
            mesh_shape = (2, self.world_size // 2)
            mesh_2d = init_device_mesh(
                "cuda:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
            )

    @with_comms
    def test_set_mesh_dim_group_options(self):
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        _mesh_resources._set_mesh_dim_group_options(1, "fake", None)

        mesh_tensor = torch.arange(4).reshape(2, 2)
        mesh = DeviceMesh(device_type, mesh_tensor)
        self.assertEqual(mesh.get_group(1)._get_backend_name(), "fake")


class DeviceMeshTestNDim(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    def test_device_mesh_nd(self):
        # construct a device mesh from the 3D mesh tensor
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_all_groups()

        for dim, dim_group in enumerate(dim_to_subgroups):
            self.assertTrue(dim < mesh_tensor.ndim)
            dim_ranks = mesh_tensor.swapdims(-1, dim).reshape(-1, 2)

            dim_group_size = get_world_size(dim_group)
            self.assertIsInstance(dim_group, ProcessGroup)
            self.assertEqual(dim_group_size, 2)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            for ranks in dim_ranks:
                if self.rank in ranks:
                    self.assertEqual(global_ranks, ranks.tolist())

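    # Worked example for test_device_mesh_nd above: with mesh_tensor =
    # arange(8).reshape(2, 2, 2), the per-dim subgroups are
    #   dim 0: {0, 4}, {1, 5}, {2, 6}, {3, 7}
    #   dim 1: {0, 2}, {1, 3}, {4, 6}, {5, 7}
    #   dim 2: {0, 1}, {2, 3}, {4, 5}, {6, 7}
    # i.e. each group contains the ranks that differ only in that mesh dimension.
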
    @with_comms
    def test_device_mesh_hash(self):
        mesh_tensor_2d = torch.arange(8).reshape(4, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor_2d)
        mesh2 = DeviceMesh(self.device_type, mesh_tensor_2d)
        self.assertEqual(hash(mesh), hash(mesh2))
        mesh_tensor_3d = torch.arange(8).reshape(2, 2, 2)
        mesh3 = DeviceMesh(self.device_type, mesh_tensor_3d)
        self.assertNotEqual(hash(mesh), hash(mesh3))
        self.assertNotEqual(hash(mesh2), hash(mesh3))

    @with_comms
    def test_get_local_rank_3d(self):
        """
        If we have a 3D mesh to which we want to apply dp, pp, and tp, with
        mesh_dim_names = ["dp", "pp", "tp"], the mesh tensor would be:
        mesh_3d_tensor = [
            [
                [0, 1],
                [2, 3],
            ],
            [
                [4, 5],
                [6, 7],
            ],
        ]
        """
        mesh_shape = (2, 2, 2)
        mesh_3d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "pp", "tp")
        )

        # tp_rank_0: [0, 2, 4, 6], tp_rank_1: [1, 3, 5, 7]
        tp_rank = mesh_3d.get_local_rank("tp")
        expected_tp_rank = self.rank % 2
        self.assertEqual(tp_rank, expected_tp_rank)

        # pp_rank_0: [0, 1, 4, 5], pp_rank_1: [2, 3, 6, 7]
        pp_rank = mesh_3d.get_local_rank("pp")
        expected_pp_rank = 0 if self.rank % 4 <= 1 else 1
        self.assertEqual(pp_rank, expected_pp_rank)

        # dp_rank_0: [0, 1, 2, 3], dp_rank_1: [4, 5, 6, 7]
        dp_rank = mesh_3d.get_local_rank("dp")
        expected_dp_rank = self.rank // 4
        self.assertEqual(dp_rank, expected_dp_rank)

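    # Worked example for test_get_local_rank_3d above: global rank 6 sits at
    # coordinate [1, 1, 0] of the ("dp", "pp", "tp") mesh, so
    #   tp = 6 % 2 = 0, pp = 1 (since 6 % 4 = 2 > 1), dp = 6 // 4 = 1,
    # matching the expected_* formulas in the test.
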
    @with_comms
    def test_device_mesh_parent_child_hash(self):
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("DP", "TP")
        )

        mesh_group_1 = torch.arange(0, self.world_size // 2)
        mesh_group_2 = torch.arange(self.world_size // 2, self.world_size)
        ep_mesh_1 = DeviceMesh(self.device_type, mesh_group_1)
        ep_mesh_2 = DeviceMesh(self.device_type, mesh_group_2)
        ep_mesh = ep_mesh_1 if self.rank < self.world_size // 2 else ep_mesh_2
        # ep_mesh is considered different from mesh_2d["TP"]
        self.assertEqual(mesh_2d["TP"]._flatten_mesh_list, ep_mesh._flatten_mesh_list)
        self.assertEqual(mesh_2d["TP"].mesh.shape, ep_mesh.mesh.shape)
        self.assertEqual(mesh_2d["TP"].device_type, ep_mesh.device_type)
        self.assertNotEqual(mesh_2d["TP"].mesh_dim_names, ep_mesh.mesh_dim_names)
        self.assertEqual(mesh_2d["TP"]._thread_id, ep_mesh._thread_id)
        self.assertNotEqual(hash(mesh_2d["TP"]), hash(ep_mesh))
        self.assertNotEqual(mesh_2d["TP"], ep_mesh)

        another_mesh_1 = DeviceMesh(self.device_type, mesh_group_1)
        another_mesh_2 = DeviceMesh(self.device_type, mesh_group_2)
        another_mesh = (
            another_mesh_1 if self.rank < self.world_size // 2 else another_mesh_2
        )
        # another_mesh is considered the same as ep_mesh
        self.assertEqual(ep_mesh._flatten_mesh_list, another_mesh._flatten_mesh_list)
        self.assertEqual(ep_mesh.mesh.shape, another_mesh.mesh.shape)
        self.assertEqual(ep_mesh.device_type, another_mesh.device_type)
        self.assertEqual(ep_mesh.mesh_dim_names, another_mesh.mesh_dim_names)
        self.assertEqual(ep_mesh._thread_id, another_mesh._thread_id)
        self.assertEqual(hash(ep_mesh), hash(another_mesh))
        self.assertEqual(ep_mesh, another_mesh)

    @with_comms
    def test_from_group_with_mesh_shape(self):
        """Tests ``from_group`` when passing a list of process groups for a 2D mesh."""
        # Consider two different logical views of the same mesh:
        # - (4, 2) ("dp", "tp") mesh
        # - (2, 2, 2) ("dp_replicate", "dp_shard", "tp") mesh
        mesh_shape = (2, 2, 2)
        mesh_dim_names = ("dp_replicate", "dp_shard", "tp")
        ref_mesh = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        dp_shard_group = ref_mesh["dp_shard"].get_group()
        dp_replicate_group = ref_mesh["dp_replicate"].get_group()

        dp_mesh = DeviceMesh.from_group(
            [dp_replicate_group, dp_shard_group],
            self.device_type,
            mesh=ref_mesh.mesh[:, :, ref_mesh.get_local_rank(2)],
            mesh_dim_names=mesh_dim_names[:2],
        )

        ref_mesh_dp_dim_group_infos = ref_mesh._dim_group_infos[:2]
        for (_, ref_ranks, _), (_, ranks, _) in zip(
            ref_mesh_dp_dim_group_infos, dp_mesh._dim_group_infos
        ):
            self.assertEqual(ref_ranks, ranks)
        # Cannot check for mesh equality directly since the parent meshes are
        # not the same: the ref's parent mesh is 3D
        self.assertEqual(dp_mesh["dp_replicate"].mesh, ref_mesh["dp_replicate"].mesh)
        for (_, ref_ranks, _), (_, ranks, _) in zip(
            dp_mesh["dp_replicate"]._dim_group_infos,
            ref_mesh["dp_replicate"]._dim_group_infos,
        ):
            self.assertEqual(ref_ranks, ranks)
        self.assertEqual(dp_mesh["dp_shard"].mesh, ref_mesh["dp_shard"].mesh)
        for (_, ref_ranks, _), (_, ranks, _) in zip(
            dp_mesh["dp_shard"]._dim_group_infos, ref_mesh["dp_shard"]._dim_group_infos
        ):
            self.assertEqual(ref_ranks, ranks)


class InitDeviceMeshTest(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    def test_init_device_mesh(self):
        mesh_shape = (2, 4)
        mesh_dim_names = ("DP", "TP")
        ref_mesh = DeviceMesh(
            self.device_type,
            torch.arange(8).view(mesh_shape),
            mesh_dim_names=mesh_dim_names,
        )

        # test init_device_mesh with mesh_dim_names
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )
        self.assertEqual(mesh_2d, ref_mesh)
        self.assertEqual(mesh_2d.mesh_dim_names, mesh_dim_names)

    @with_comms
    def test_raises_duplicate_mesh_dim_names(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "Each mesh_dim_name must be unique.",
        ):
            mesh = init_device_mesh(
                self.device_type,
                (2, 4),
                mesh_dim_names=["dp", "dp"],
            )

    @with_comms
    def test_raises_mesh_shape_mesh_dim_names_mismatch(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "mesh_shape and mesh_dim_names should have same length!",
        ):
            mesh = init_device_mesh(
                self.device_type,
                (8,),
                mesh_dim_names=["dp", "tp"],
            )


class TestDeviceMeshGetItem(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    def test_raises_no_mesh_dim_found(self):
        with self.assertRaisesRegex(
            RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names!"
        ):
            mesh = init_device_mesh(self.device_type, (2, 4))
            child_mesh = mesh["DP"]

    @with_comms
    def test_raises_invalid_mesh_dim_name(self):
        child_mesh_dim_name = ("PP",)
        with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"):
            mesh_dim_names = ("DP", "TP")
            mesh = init_device_mesh(
                self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
            )
            child_mesh = mesh[child_mesh_dim_name]

    @with_comms
    def test_get_item_2d(self):
        mesh_shape = (2, 4)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        pg_ranks_by_dim_name = {}
        for mesh_dim_name in mesh_dim_names:
            mesh_dim = mesh_dim_names.index(mesh_dim_name)
            pg_ranks_by_dim_name[mesh_dim_name] = mesh_2d.mesh.swapdims(
                -1, mesh_dim
            ).reshape(-1, mesh_2d.mesh.size(mesh_dim))

        tp_mesh = mesh_2d["TP"]
        tp_group_idx = self.rank // 4
        self.assertEqual(tp_mesh.mesh, pg_ranks_by_dim_name["TP"][tp_group_idx])

        dp_mesh = mesh_2d["DP"]
        dp_group_idx = self.rank % 4
        self.assertEqual(mesh_2d["DP"].mesh, pg_ranks_by_dim_name["DP"][dp_group_idx])

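    # Worked example for test_get_item_2d above: the (2, 4) mesh is
    # [[0, 1, 2, 3], [4, 5, 6, 7]], so the "TP" submesh of rank r is its row
    # (index r // 4) and the "DP" submesh is its column [r % 4, r % 4 + 4]
    # (index r % 4), which is what tp_group_idx and dp_group_idx encode.
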
    @with_comms
    def test_get_item_1d(self):
        mesh = init_device_mesh(self.device_type, (8,), mesh_dim_names=("dp",))
        # Make sure slicing out a 1D mesh from a 1D mesh works.
        dp_mesh = mesh["dp"]
        self.assertEqual(dp_mesh, mesh)

        with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"):
            dp_mesh = mesh["dim0"]

    @with_comms
    def test_get_item_3d(self):
        mesh_shape = (2, 2, 2)
        mesh_dim_names = ("Replicate", "Shard", "TP")
        mesh_3d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        tp_group = [[0, 1], [2, 3], [4, 5], [6, 7]]
        tp_group_idx = int(self.rank / 2)
        self.assertEqual(mesh_3d["TP"].mesh.tolist(), tp_group[tp_group_idx])

        shard_group = [[0, 2], [1, 3], [4, 6], [5, 7]]
        shard_group_idx = self.rank % 2 + self.rank // 4 * 2
        self.assertEqual(mesh_3d["Shard"].mesh.tolist(), shard_group[shard_group_idx])

        replicate_group = [[0, 4], [1, 5], [2, 6], [3, 7]]
        replicate_group_idx = self.rank % 4
        self.assertEqual(
            mesh_3d["Replicate"].mesh.tolist(), replicate_group[replicate_group_idx]
        )

        # Both slicing syntaxes are supported for nD slicing:
        # mesh_3d[["Replicate", "Shard"]] or mesh_3d["Replicate", "Shard"]
        hsdp_mesh_1 = mesh_3d[["Replicate", "Shard"]]
        hsdp_mesh_2 = mesh_3d["Replicate", "Shard"]
        hsdp_group = [[[0, 2], [4, 6]], [[1, 3], [5, 7]]]
        hsdp_group_idx = self.rank % 2
        self.assertEqual(hsdp_mesh_1.mesh.tolist(), hsdp_group[hsdp_group_idx])
        self.assertEqual(hsdp_mesh_2.mesh.tolist(), hsdp_group[hsdp_group_idx])
        self.assertEqual(hsdp_mesh_1, hsdp_mesh_2)

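    # Worked example for the hsdp slice above: taking ["Replicate", "Shard"]
    # fixes the "TP" coordinate (self.rank % 2), so even ranks see the 2D
    # submesh [[0, 2], [4, 6]] and odd ranks see [[1, 3], [5, 7]].
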
    @with_comms
    def test_cache_and_reuse_submesh_slice_result(self):
        mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp"))

        dp_mesh = mesh["dp"]
        ref_pg_count = _world.group_count

        # Slicing "dp" a second time should not create any new pg: the cached
        # result is reused, so the pg count should stay the same.
        dp_mesh_2 = mesh["dp"]
        self.assertEqual(ref_pg_count, _world.group_count)

        # When we call the "tp" slice, it should not create a new pg, as the "tp" slice would
        # just reuse the parent mesh pg.
        tp_mesh = mesh["tp"]
        self.assertEqual(_world.group_count, ref_pg_count)

    @with_comms
    def test_get_item_3d_noncontiguous_slicing(self):
        mesh_shape = (2, 2, 2)
        mesh_dim_names = ("dp", "pp", "cp")
        mesh_3d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        # The slice order decides which mesh dim lands on which dim of the submesh.
        # For dp_cp_mesh, cp is the innermost mesh dimension.
        dp_cp_mesh = mesh_3d["dp", "cp"]
        expected_mesh_tensor = (
            torch.tensor([[0, 1], [4, 5]], dtype=torch.int)
            if self.rank in (0, 1, 4, 5)
            else torch.tensor([[2, 3], [6, 7]], dtype=torch.int)
        )
        dp_local_rank = dp_cp_mesh.get_local_rank("dp")
        self.assertEqual(dp_cp_mesh.mesh, expected_mesh_tensor)
        cp_mesh = mesh_3d["cp"]
        # Check that, for the current dp_local_rank, the cp mesh tensor is the same.
        self.assertEqual(dp_cp_mesh.mesh[dp_local_rank], cp_mesh.mesh)

        with self.assertRaisesRegex(
            KeyError,
            "Invalid mesh_dim_names",
        ):
            cp_dp_mesh = mesh_3d["cp", "dp"]

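    # Worked example for the non-contiguous slice above: the ("dp", "pp", "cp")
    # mesh is [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]; slicing ["dp", "cp"] fixes
    # the "pp" coordinate, so ranks 0, 1, 4, 5 (pp == 0) see [[0, 1], [4, 5]]
    # and ranks 2, 3, 6, 7 (pp == 1) see [[2, 3], [6, 7]]. The reversed order
    # ["cp", "dp"] is rejected, as asserted at the end of the test.
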
    @with_comms
    def test_flatten_mesh(self):
        mesh_shape = (2, 2, 2)
        mesh_dim_names = ("dp", "cp", "tp")
        mesh_3d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        # Test flatten contiguous dims
        dp_cp_mesh = mesh_3d["dp", "cp"]
        flattened_dp_cp_mesh = dp_cp_mesh._flatten()
        self.assertEqual(dp_cp_mesh.mesh.flatten(), flattened_dp_cp_mesh.mesh)
        self.assertEqual(flattened_dp_cp_mesh.mesh_dim_names[0], "dp_cp")
        root_mesh = _mesh_resources.get_root_mesh(dp_cp_mesh)
        self.assertEqual(root_mesh, mesh_3d)
        flatten_mesh_root_dims = _mesh_resources.flatten_name_to_root_dims[root_mesh][
            "dp_cp"
        ]
        self.assertEqual(flatten_mesh_root_dims, (0, 1))

        ref_pg_count = _world.group_count
        # Calling flatten again should not create a new pg.
        flattened_dp_cp_mesh_2 = dp_cp_mesh._flatten()
        self.assertEqual(flattened_dp_cp_mesh, flattened_dp_cp_mesh_2)
        self.assertEqual(ref_pg_count, _world.group_count)

        # Test flatten non-contiguous dims
        dp_tp_mesh = mesh_3d["dp", "tp"]
        flattened_dp_tp_mesh = dp_tp_mesh._flatten()
        self.assertEqual(dp_tp_mesh.mesh.flatten(), flattened_dp_tp_mesh.mesh)
        self.assertEqual(flattened_dp_tp_mesh.mesh_dim_names[0], "dp_tp")
        root_mesh = _mesh_resources.get_root_mesh(dp_tp_mesh)
        self.assertEqual(root_mesh, mesh_3d)
        flatten_mesh_root_dims = _mesh_resources.flatten_name_to_root_dims[root_mesh][
            "dp_tp"
        ]
        self.assertEqual(flatten_mesh_root_dims, (0, 2))

        # Test flatten with a flattened mesh_dim_name
        cp_tp_mesh = mesh_3d["cp", "tp"]
        cp_tp_mesh._flatten("dummy")
        self.assertEqual(mesh_3d["dummy"].mesh_dim_names[0], "dummy")

    @with_comms
    def test_reconstruct_mesh_with_flatten_dim(self):
        mesh_3d = init_device_mesh(
            self.device_type, (2, 2, 2), mesh_dim_names=("replicate", "shard", "cp")
        )
        shard_cp_mesh = mesh_3d["shard", "cp"]._flatten()
        hsdp_mesh = mesh_3d["replicate", "shard_cp"]
        expected_mesh_tensor = torch.tensor(
            [[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int
        )
        self.assertEqual(hsdp_mesh.mesh, expected_mesh_tensor)
        self.assertEqual(shard_cp_mesh.get_group(), mesh_3d["shard_cp"].get_group())
        self.assertEqual(
            shard_cp_mesh.get_group(), mesh_3d.get_group(mesh_dim="shard_cp")
        )

        mesh_3d = init_device_mesh(
            self.device_type, (2, 2, 2), mesh_dim_names=("dp", "cp", "tp")
        )
        dp_cp_mesh = mesh_3d["dp", "cp"]._flatten()
        spmd_mesh = mesh_3d["dp_cp", "tp"]
        expected_mesh_tensor = torch.tensor(
            [[0, 1], [2, 3], [4, 5], [6, 7]], dtype=torch.int
        )
        self.assertEqual(spmd_mesh.mesh, expected_mesh_tensor)
        self.assertEqual(dp_cp_mesh.get_group(), mesh_3d["dp_cp"].get_group())
        self.assertEqual(dp_cp_mesh.get_group(), mesh_3d.get_group(mesh_dim="dp_cp"))

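    # Worked example for the flatten tests above: flattening ("shard", "cp") of
    # the (2, 2, 2) ("replicate", "shard", "cp") mesh merges its last two dims,
    # giving the ("replicate", "shard_cp") view [[0, 1, 2, 3], [4, 5, 6, 7]];
    # flattening ("dp", "cp") of the ("dp", "cp", "tp") mesh merges the first
    # two dims, giving the ("dp_cp", "tp") view [[0, 1], [2, 3], [4, 5], [6, 7]].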

class TestMeshEnv(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    def test_get_root_mesh(self):
        mesh_3d = init_device_mesh(
            self.device_type, (2, 2, 2), mesh_dim_names=("dp", "cp", "tp")
        )

        dp_cp_mesh = mesh_3d["dp", "cp"]
        dp_tp_mesh = mesh_3d["dp", "tp"]
        cp_tp_mesh = mesh_3d["cp", "tp"]
        dp_mesh = mesh_3d["dp"]
        cp_mesh = mesh_3d["cp"]
        tp_mesh = mesh_3d["tp"]
        self.assertEqual(_mesh_resources.get_root_mesh(dp_cp_mesh), mesh_3d)
        self.assertEqual(_mesh_resources.get_root_mesh(dp_tp_mesh), mesh_3d)
        self.assertEqual(_mesh_resources.get_root_mesh(cp_tp_mesh), mesh_3d)
        self.assertEqual(_mesh_resources.get_root_mesh(dp_mesh), mesh_3d)
        self.assertEqual(_mesh_resources.get_root_mesh(cp_mesh), mesh_3d)
        self.assertEqual(_mesh_resources.get_root_mesh(tp_mesh), mesh_3d)

    @with_comms
    def test_get_root_mesh_dim_exist(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        self.assertEqual(_mesh_resources.get_root_mesh_dim(mesh_2d["DP"]), 0)
        self.assertEqual(_mesh_resources.get_root_mesh_dim(mesh_2d["TP"]), 1)

    @with_comms
    def test_get_root_mesh_dim_not_exist(self):
        mesh_shape = (self.world_size,)
        mesh = init_device_mesh(self.device_type, mesh_shape)

        self.assertEqual(_mesh_resources.get_root_mesh_dim(mesh), None)

    @with_comms
    def test_get_mesh_dim_by_name(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "DP"), 0)
        self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "TP"), 1)

    @with_comms
    def test_get_all_submeshes(self):
        mesh_2d = init_device_mesh(
            self.device_type, (2, 4), mesh_dim_names=("replicate", "shard")
        )
        all_submeshes = _mesh_resources._get_all_submeshes(mesh_2d, "replicate")
        self.assertEqual(len(all_submeshes), 4)
        self.assertEqual(
            all(submesh.mesh.numel() == 2 for submesh in all_submeshes), True
        )


class DeviceMeshCollectiveTest(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    def test_broadcast_1d(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
        mesh_broadcast(local_tensor, mesh, mesh_dim=0)
        self.assertEqual(local_tensor, torch.zeros(3, 3))

    @with_comms
    def test_scatter_1d(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        scatter_tensor_shape = [3, 3, 3]
        for scatter_dim in range(len(scatter_tensor_shape)):
            shard_placement = Shard(scatter_dim)
            scatter_tensor_shape[scatter_dim] *= self.world_size
            # make the random seed the same across ranks
            torch.manual_seed(0)
            global_tensor = torch.randn(scatter_tensor_shape, device=self.device_type)
            splitted_list, _ = shard_placement._split_tensor(
                global_tensor, mesh.size(), with_padding=True, contiguous=True
            )
            recv_tensor = torch.empty_like(splitted_list[mesh.get_rank()])
            # scattering on dim > 0 generates a non-contiguous tensor; verify that this works
            mesh_scatter(recv_tensor, splitted_list, mesh, mesh_dim=0)
            self.assertEqual(recv_tensor, splitted_list[mesh.get_rank()])

    @with_comms
    def test_scatter_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        tensor_to_split = torch.randn(
            device_mesh.size() + 3, device_mesh.size() + 1, device=self.device_type
        )

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)

            tensor_to_scatter = tensor_to_split.clone()
            tensor_splitted_list = list(
                torch.chunk(tensor_to_split, self.world_size, dim=shard_dim)
            )
            for _ in range(self.world_size - len(tensor_splitted_list)):
                tensor_splitted_list.append(torch.tensor([], device=self.device_type))

            padded_tensor_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_scatter,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )

            scattered_tensor = torch.empty_like(padded_tensor_list[my_rank])
            mesh_scatter(scattered_tensor, padded_tensor_list, device_mesh, mesh_dim=0)

            if pad_sizes[my_rank] != 0:
                scattered_tensor = unpad_tensor(
                    scattered_tensor, shard_dim, pad_sizes[my_rank]
                )

            if scattered_tensor.numel() == 0:
                # We need to check numel() instead of size() when the tensor is
                # empty after unpadding, since its size could still be, e.g., ([0, 8]).
                self.assertEqual(
                    scattered_tensor.numel(), tensor_splitted_list[my_rank].numel()
                )
            else:
                self.assertEqual(
                    scattered_tensor.size(), tensor_splitted_list[my_rank].size()
                )
                self.assertEqual(scattered_tensor, tensor_splitted_list[my_rank])

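    # Worked example for test_scatter_uneven above: with world_size = 8 the
    # (11, 9) tensor chunks unevenly, e.g. along dim 0 torch.chunk yields five
    # chunks of 2 rows and one of 1 row, and the test appends two empty tensors
    # for the remaining ranks; the padded scatter plus unpad_tensor must
    # reproduce exactly that uneven split on every rank.
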
    @with_comms
    def test_all_gather_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        tensor_to_split = torch.ones(
            device_mesh.size() + 3,
            device_mesh.size() + 1,
            device=self.device_type,
        )

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)
            tensor_padded_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_split,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )
            local_tensor = tensor_padded_list[my_rank]
            big_tensor = funcol.all_gather_tensor(
                local_tensor, gather_dim=shard_dim, group=(device_mesh, 0)
            )
            big_tensor_chunks = list(
                torch.chunk(big_tensor, device_mesh.size(), dim=shard_dim)
            )
            unpadded_list = [
                (
                    unpad_tensor(big_tensor, shard_dim, pad_sizes[i])
                    if pad_sizes[i] > 0
                    else big_tensor
                )
                for i, big_tensor in enumerate(big_tensor_chunks)
            ]
            all_gathered_tensor = torch.cat(unpadded_list, dim=shard_dim)

            self.assertEqual(all_gathered_tensor.size(), tensor_to_split.size())
            self.assertEqual(all_gathered_tensor, tensor_to_split)

    @with_comms
    def test_reduce_scatter_contiguous(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()

        # Init the tensor
        step = self.world_size * 2
        total_elem = step**2
        tensor = torch.arange(0, total_elem).view(step, -1).to(device=self.device_type)
        tensor = tensor * (my_rank + 1)

        # Get a non-contiguous tensor by slicing
        tensor_to_reduce = tensor[::2, :2]
        tensor_contiguous = tensor_to_reduce.clone().contiguous()

        # Partial to Shard to trigger reduce_scatter
        tensor_to_reduce = DTensor.from_local(
            tensor_to_reduce, device_mesh, [_Partial()]
        )
        tensor_contiguous = DTensor.from_local(
            tensor_contiguous, device_mesh, [_Partial()]
        )
        new_tensor = tensor_to_reduce.redistribute(device_mesh, [Shard(0)])
        new_tensor_contiguous = tensor_contiguous.redistribute(device_mesh, [Shard(0)])

        # Contiguous and non-contiguous tensors with the same values should
        # produce the same reduce_scatter result.
        new_tensor_local = new_tensor._local_tensor
        new_tensor_contiguous_local = new_tensor_contiguous._local_tensor
        self.assertEqual(new_tensor_local, new_tensor_contiguous_local)
        self.assertEqual(list(new_tensor_local.size()), [1, 2])

        # Check the numerical value of the reduction
        sum_base = (1 + self.world_size) * self.world_size / 2
        first_elem = my_rank * sum_base * step * 2
        expected_tensor = torch.tensor(
            [[first_elem, first_elem + sum_base]],
            dtype=new_tensor_local.dtype,
            device=self.device_type,
        )
        self.assertEqual(new_tensor_local, expected_tensor)

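    # Numeric sanity check for test_reduce_scatter_contiguous above (world_size
    # = 8, step = 16, sum_base = 36): rank r contributes (r + 1) * [[0, 1],
    # [32, 33], ...], so after summing over ranks and scattering rows, rank 0's
    # (1, 2) shard is [[0, 36]] and rank m's is [[m * 1152, m * 1152 + 36]],
    # matching first_elem = my_rank * sum_base * step * 2.
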
    @with_comms
    def test_reduce_scatter_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        tensor_to_split = (
            torch.ones(
                device_mesh.size() + 3,
                device_mesh.size() + 1,
                device=self.device_type,
            )
            * self.rank
        )

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)
            tensor_to_scatter = tensor_to_split.clone()

            tensor_splitted_list = list(
                torch.chunk(tensor_to_split, self.world_size, dim=shard_dim)
            )
            for _ in range(self.world_size - len(tensor_splitted_list)):
                tensor_splitted_list.append(torch.tensor([], device=self.device_type))

            padded_tensor_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_scatter,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )

            tensor_to_reduce = torch.cat(padded_tensor_list, shard_dim)

            res_num = ((0 + self.world_size - 1) * self.world_size) / 2

            scattered_tensor = funcol.reduce_scatter_tensor(
                tensor_to_reduce,
                reduceOp="sum",
                scatter_dim=shard_dim,
                group=(device_mesh, 0),
            )

            # unpad scattered_tensor
            if pad_sizes[my_rank] > 0:
                scattered_tensor = unpad_tensor(
                    scattered_tensor, shard_dim, pad_sizes[my_rank]
                )

            if scattered_tensor.numel() == 0:
                # We need to check numel() instead of size() when the tensor is
                # empty after unpadding, since its size could still be, e.g., ([0, 8]).
                self.assertEqual(
                    scattered_tensor.numel(), tensor_splitted_list[my_rank].numel()
                )
            else:
                self.assertEqual(
                    scattered_tensor.size(), tensor_splitted_list[my_rank].size()
                )
                self.assertEqual(
                    scattered_tensor,
                    torch.ones_like(tensor_splitted_list[my_rank]) * res_num,
                )

    @with_comms
    def test_broadcast_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank

        # check all dim groups
        dim_to_subgroups = mesh.get_all_groups()
        for dim, dim_group in enumerate(dim_to_subgroups):
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            cloned_local_tensor = local_tensor.clone()
            mesh_broadcast(cloned_local_tensor, mesh, mesh_dim=dim)
            res_num = global_ranks[0]
            self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num)

    @with_comms
    def test_scatter_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_all_groups()
        for dim, dim_group in enumerate(dim_to_subgroups):
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            scattered_tensors = [
                torch.ones(3, 3, device=self.device_type) * global_rank
                for global_rank in global_ranks
            ]
            received_tensor = torch.empty_like(
                scattered_tensors[mesh.get_coordinate()[dim]]
            )
            mesh_scatter(received_tensor, scattered_tensors, mesh, mesh_dim=dim)
            self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank)


if __name__ == "__main__":
    run_tests()