# Owner(s): ["oncall: distributed"]

import os
import sys
import unittest
from functools import partial, wraps

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
import torch.distributed._tensor as dt
import torch.distributed.distributed_c10d as c10d
from functorch import make_fx
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._triton import has_triton


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    MultiThreadedTestCase,
    requires_nccl,
    TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


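# Helper: partition the world into contiguous subgroups of `group_size` ranks
# (e.g. world_size=4, group_size=2 -> [0, 1] and [2, 3]). Returns the subgroup
# containing the calling rank plus the list of all subgroups; every subgroup is
# created through c10d._new_group_with_tag so it carries the given pg_tag.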
def new_subgroups(group_size: int, pg_tag=None):
    world_size = dist.get_world_size()
    subgroups = []
    cur_subgroup = None

    for subgroup_id in range(world_size // group_size):
        start_rank = subgroup_id * group_size
        end_rank = start_rank + group_size
        ranks_in_subgroup = list(range(start_rank, end_rank))
        subgroup = c10d._new_group_with_tag(
            ranks=ranks_in_subgroup,
            pg_tag=pg_tag,
        )
        subgroups.append(subgroup)

        rank = dist.get_rank()
        if rank in ranks_in_subgroup:
            cur_subgroup = subgroup

    return cur_subgroup, subgroups


class TestExpand(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_expand_1d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3], "bla")
        self.assertEqual("bla", tag)

    def test_expand_2d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]], "blu")
        self.assertEqual("blu", tag)

        with self.assertRaisesRegex(ValueError, "group sizes must be identical"):
            ft_c._expand_group([[0], [1, 2, 3]])

    def test_expand_process_group(self):
        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD)
        self.assertEqual(c10d._get_group_tag(dist.group.WORLD), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla")
        self.assertEqual("bla", tag)

        my_pg, others = new_subgroups(group_size=2)
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual(c10d._get_group_tag(my_pg), tag)
        self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
        self.assertEqual(2, group_size)

        my_pg = None
        for i in range(dist.get_world_size()):
            group = c10d._new_group_with_tag([i], pg_tag="my_pg")
            if i == dist.get_rank():
                my_pg = group
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual("my_pg", tag)
        self.assertEqual([dist.get_rank()], rankset)
        self.assertEqual(1, group_size)

        tag, rankset, group_size = ft_c._expand_group(my_pg, "bla")
        self.assertEqual("bla", tag)

    def test_expand_device_mesh(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

    def test_expand_device_mesh_tuple(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
        with self.assertRaisesRegex(AssertionError, "Only 1D mesh"):
            tag, rankset, group_size = ft_c._expand_group(mesh)

        tag, rankset, group_size = ft_c._expand_group((mesh, 0))
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        expected_rankset = [0, 2] if dist.get_rank() in [0, 2] else [1, 3]
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group((mesh, 1))
        expected_rankset = [0, 1] if dist.get_rank() in [0, 1] else [2, 3]
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=1)), tag)
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)


class TestPgTag(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    """
    The behavior we want is as follows:

    - rankset + tag will always resolve to the same PG.
      Do we enforce this by failing creation of new PGs or by returning existing ones?
        Return the existing one.

    - The default tag gives the existing behavior.
        This means we should create duplicates.
    - _expand_group on a default-tagged PG should always resolve to it.
        This means we can't depend on the empty tag + rankset alone.
    """
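    # Illustrative sketch of the round-trip contract exercised below (see the
    # roundtrip() helpers in the tests):
    #
    #   tag, rankset, _ = ft_c._expand_group(pg)
    #   assert c10d._find_pg_by_ranks_and_tag(tag, rankset) == pg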

    def test_pg_creation_with_tag(self):
        my_group, _ = new_subgroups(group_size=2, pg_tag="blu")
        my_group2, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(my_group, my_group2)

        my_group3, _ = new_subgroups(group_size=2, pg_tag="blu2")
        self.assertNotEqual(my_group, my_group3)

        my_group4, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group, my_group4)

        my_group5, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group4, my_group5)

    def test_pg_lookup_roundtrip(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="blu2")
        pg_notag0, _ = new_subgroups(group_size=2)
        pg_notag1, _ = new_subgroups(group_size=2)

        def roundtrip(pg):
            tag, rankset, _ = ft_c._expand_group(pg)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag0))
        self.assertEqual(pg_tag1, roundtrip(pg_tag1))
        self.assertEqual(pg_notag0, roundtrip(pg_notag0))
        self.assertEqual(pg_notag1, roundtrip(pg_notag1))

    def test_pg_lookup_with_tag(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="bla")
        pg_notag0, _ = new_subgroups(group_size=2)

        def roundtrip(pg, pg_tag):
            tag, rankset, _ = ft_c._expand_group(pg, pg_tag)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag1, "blu"))
        self.assertEqual(pg_tag0, roundtrip(pg_notag0, "blu"))
        # Cannot erase the tag of a PG
        self.assertEqual(pg_tag0, roundtrip(pg_tag0, ""))

    def test_find_or_create_pg(self):
        pg = c10d._find_or_create_pg_by_ranks_and_tag("blu", [0, 1, 2, 3], 2)
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(pg, pg_tag0)

    def test_find_root_pg(self):
        pg = c10d._find_pg_by_ranks_and_tag("", [0, 1, 2, 3])
        self.assertEqual(dist.group.WORLD, pg)


@instantiate_parametrized_tests
class TestTraceableCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    @parametrize("device", ["cpu", "cuda"])
    def test_broadcast(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        if dist.get_rank() == 0:
            tensor = torch.ones([4], device=device)
        else:
            tensor = torch.zeros([4], device=device)

        mesh = dt.DeviceMesh(device, torch.arange(4))
        res = ft_c.broadcast(tensor, 0, mesh)
        self.assertEqual(res, torch.ones([4], device=device))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_reduce_eager(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        tensor = torch.ones([4], device=device)
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_reduce(tensor, "sum", mesh)
        self.assertEqual(res, torch.tensor([4, 4, 4, 4], dtype=torch.float))

        mesh = dt.DeviceMesh(device, torch.arange(4).view(2, 2))
        res2 = ft_c.all_reduce(tensor, "sum", (mesh, 1))
        self.assertEqual(res2, torch.tensor([2, 2, 2, 2], dtype=torch.float))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_reduce_coalesced_eager(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        t0 = torch.ones([4], device=device)
        t1 = torch.ones([6], device=device) + 2
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_reduce_coalesced([t0, t1], "sum", mesh)
        self.assertEqual(res[0], t0 * 4)
        self.assertEqual(res[1], t1 * 4)

    @parametrize("device", ["cpu", "cuda"])
    def test_all_gather_tensor(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        # test with both a 1d and a 2d mesh; group=(mesh, 0) gathers along mesh dim 0
        mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
        mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
        for mesh in [mesh_1d, mesh_2d]:
            dims_to_gather = [0, 1, 2]
            for dim in dims_to_gather:
                output_size = [3, 3, 3]
                output_size[dim] *= mesh.size(0)
                # each rank has its own tensor; all_gather returns a bigger, concatenated tensor
                local_tensor = torch.ones([3, 3, 3], device=device)
                gathered_tensor = ft_c.all_gather_tensor(
                    local_tensor, gather_dim=dim, group=(mesh, 0)
                )
                self.assertEqual(gathered_tensor, torch.ones(output_size))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_gather_into_tensor_coalesced(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        tensors = [torch.ones([4], device=device), torch.ones([4], device=device) + 1]
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()], device=device), res[0])
        self.assertEqual(
            torch.ones([4 * dist.get_world_size()], device=device) + 1, res[1]
        )

    @parametrize("device", ["cpu", "cuda"])
    def test_reduce_scatter_tensor(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        # test with both a 1d and a 2d mesh; group=(mesh, 0) reduce-scatters along mesh dim 0
        mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
        mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
        for mesh in [mesh_1d, mesh_2d]:
            dims_to_scatter = [0, 1]
            for dim in dims_to_scatter:
                group_size = mesh.size(0)
                input_size = [3, 3]
                output_size = [3, 3]
                output_size[dim] *= group_size
                input_tensor = torch.ones(output_size, device=device)
                res_num = 1 * group_size
                rs_tensor = ft_c.reduce_scatter_tensor(
                    input_tensor, "sum", scatter_dim=dim, group=(mesh, 0)
                )
                self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)

    @parametrize("device", ["cpu", "cuda"])
    def test_reduce_scatter_into_tensor_coalesced(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())
        tensors = [
            torch.ones([4], dtype=torch.int64, device=device),
            torch.ones([4], dtype=torch.int64, device=device) + 1,
        ]
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.reduce_scatter_tensor_coalesced(tensors, "sum", [0, 0], mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.tensor([4], device=device), res[0])
        self.assertEqual(torch.tensor([8], device=device), res[1])


class TestMetaCollectives(TestCase):
    def test_all_reduce(self):
        x = torch.rand((2, 3, 4), device="meta")
        out = ft_c.all_reduce(x, "sum", "0")
        self.assertEqual(x.size(), out.size())


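# The plain (non-autograd) functional collectives do not hook into autograd:
# backward() through their output leaves the input's .grad as None.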
class TestGradCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_all_reduce(self):
        x = torch.rand([4], requires_grad=True)
        y = torch.rand([4], requires_grad=True)
        out = ft_c.all_reduce(x, "sum", dist.group.WORLD)
        (out + y).sum().backward()
        self.assertIsNone(x.grad)


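# make_fx traces the functional collectives into an FX graph; FileCheck is used
# below to assert that the graph contains the expected all_reduce / wait_tensor
# calls and no stray get_attr nodes.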
class TestMakeFx(TestCase):
    def setUp(self):
        # make_fx is not thread-safe because it patches and mutates global state,
        # so use a fake process group instead of spawning threads.
        self.rank = 0
        self.world_size = 2
        store = FakeStore()
        dist.init_process_group(
            backend="fake",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

    def tearDown(self):
        super().tearDown()

        self.assertFalse(torch.fx._symbolic_trace.is_fx_tracing())

    def test_all_reduce_tracing(self):
        def allred(input):
            return ft_c.all_reduce(input, "sum", group=dist.group.WORLD) + 1

        graph = make_fx(allred)(torch.rand(4))
        FileCheck().check("all_reduce").check("wait_tensor").run(str(graph.graph))

        mesh = dt.DeviceMesh("cpu", torch.arange(self.world_size))

        def allred_mesh(input):
            return ft_c.all_reduce(input, "sum", mesh) + 1

        mesh_graph = make_fx(allred_mesh)(torch.rand(4))
        FileCheck().check_not("get_attr").check("wait_tensor").run(
            str(mesh_graph.graph)
        )

        def allred_mesh_dim(input):
            return ft_c.all_reduce(input, "sum", (mesh, 0)) + 1

        mesh_dim_graph = make_fx(allred_mesh_dim)(torch.rand(4))
        FileCheck().check_not("get_attr").check("wait_tensor").run(
            str(mesh_dim_graph.graph)
        )


BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = 2


def exit_if_lt_x_gpu(x):
    if torch.cuda.device_count() < x:
        sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)


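# with_comms initializes a real process group (self.dist_init()) before the
# wrapped test runs and tears it down afterwards (self.destroy_comms()). The
# backend comes from the module-level BACKEND, which can be overridden via the
# BACKEND environment variable; NCCL runs exit early when there are fewer GPUs
# than ranks.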
def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        global BACKEND

        if "BACKEND" in os.environ:
            BACKEND = os.environ["BACKEND"]
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        func(self)
        self.destroy_comms()

    return wrapper


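# These tests spawn one subprocess per rank and exercise the eager functional
# collectives over a real process group; setUp forces BACKEND to NCCL, so they
# are skipped on machines without enough GPUs.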
class TestCollectivesWithNCCL(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = dist.Backend.NCCL
        BACKEND = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank)

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    @requires_nccl()
    @with_comms()
    def test_all_gather_into_tensor_coalesced(self):
        exit_if_lt_x_gpu(self.world_size)

        tensors = [
            torch.ones([4], device=f"cuda:{self.rank}"),
            torch.ones([4], device=f"cuda:{self.rank}") + 1,
        ]
        mesh = dt.DeviceMesh(f"cuda:{self.rank}", torch.arange(self.world_size))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()]), res[0])
        self.assertEqual(torch.ones([4 * dist.get_world_size()]) + 1, res[1])

    @with_comms()
    def test_all_to_all_single(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

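        # Rank `rank` sends (i + 1) * (rank + 1) rows to each peer i, so the total
        # number of rows it holds is
        #   (rank + 1) * (1 + 2 + ... + world_size) = world_size * (rank + 1) * (world_size + 1) / 2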
        row = self.world_size * (rank + 1) * (self.world_size + 1) / 2
        x = torch.ones(int(row), 5, device=device) * (rank + 1)
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]
        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_1d_input(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        row = self.world_size * (rank + 1) * (self.world_size + 1) / 2
        x = torch.ones(int(row), device=device) * (rank + 1)
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]
        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_split_sizes_none(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        x = torch.ones(self.world_size, self.world_size, device=device) * (rank + 1)
        y = ft_c.all_to_all_single(
            x, output_split_sizes=None, input_split_sizes=None, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.chunk(x, self.world_size)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @requires_nccl()
    @with_comms()
    def test_tracing(self):
        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        compiled_allreduce = torch.compile(allreduce, fullgraph=True)
        compiled_allreduce(torch.randn(8, device=self.device), self.process_group)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_tracing_with_fakepg(self):
        exit_if_lt_x_gpu(self.world_size)

        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        compiled_allreduce = torch.compile(allreduce, fullgraph=True)
        dist.init_process_group(
            backend="fake",
            rank=0,
            world_size=8,
            store=FakeStore(),
        )
        allreduce(torch.randn(8, device=self.device), pg=dist.group.WORLD)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @requires_nccl()
    @with_comms()
    def test_tracing_with_dce_code(self):
        if self.world_size > 2:
            return

        def func(batch, group, rank):
            ret = ft_c.permute_tensor(batch, [1, 0], group)
            if hasattr(ret, "wait"):
                ret = ret.wait()
            if rank == 0:
                return ret
            else:
                return batch * 5

        compiled_func = torch.compile(func)
        ret = compiled_func(
            torch.ones((100,), device="cuda"), self.process_group, self.rank
        )
        dist.barrier()


class TestNCCLCollectivesWithWorldSize4(TestCollectivesWithNCCL):
    @property
    def world_size(self):
        return 4

    @requires_nccl()
    @with_comms()
    def test_permute_tensor_with_sub_group(self):
        exit_if_lt_x_gpu(self.world_size)

        device = "cuda"
        mesh_dim_names = ["dp", "tp"]

        mesh_2d = dt.init_device_mesh(
            device, (2, self.world_size // 2), mesh_dim_names=mesh_dim_names
        )

        for mesh_name in mesh_dim_names:
            mesh = mesh_2d[mesh_name]
            rank = mesh.get_local_rank()

            # rank0: [0., 1.], rank1: [2., 3.]
            send_tensor = torch.arange(2, dtype=torch.float32, device=device) + 2 * rank
            recvd_tensor = ft_c.permute_tensor(send_tensor, [1, 0], group=mesh)

            # rank0: [2., 3.], rank1: [0., 1.]
            expected = torch.arange(2, dtype=torch.float32, device=device) + 2 * (
                (rank - 1 + 2) % 2
            )
            self.assertEqual(
                recvd_tensor,
                expected,
                msg=f"Expected {expected} on {self.rank=} (local_rank={rank}), "
                f"but received {recvd_tensor} instead.",
            )


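# The *_autograd variants of the functional collectives are differentiable; the
# tests below run them eagerly and under torch.compile and check that gradients
# reach the inputs.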
@instantiate_parametrized_tests
class TestFunctionalAutograd(MultiThreadedTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_threads()

    @property
    def world_size(self):
        return 2

    @parametrize("compile", [True, False])
    def test_all_to_all_single(self, compile: bool = True) -> None:
        group = dist.group.WORLD.group_name

        t = torch.ones((self.world_size, 2), requires_grad=True)

        def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
            sizes = [1] * world_size
            t = t * 2
            assert t.requires_grad
            out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
            out = out + 0
            return out

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        out = compiled(t, self.world_size)
        self.assertEqual(out.shape, t.shape)
        self.assertEqual(out, torch.full_like(t, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)
        loss = out.sum()
        loss.backward()
        self.assertEqual(t.grad, torch.full_like(t, 2.0))

    def test_all_to_all_single_inductor(self) -> None:
        group = dist.group.WORLD.group_name

        t = torch.rand((self.world_size, 2), requires_grad=True)

        def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
            sizes = [1] * world_size
            t = t * 10
            assert t.requires_grad
            out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
            out = out + 2
            return out.sum()

        compiled = torch.compile(my_func, fullgraph=True)

        def run_with_backward():
            out = compiled(t, self.world_size)
            out.backward()

        res, codes = run_and_get_code(run_with_backward)
        for code in codes:
            FileCheck().check_count(
                "_c10d_functional.all_to_all_single.default", 1, exactly=True
            ).check_count("_c10d_functional.wait_tensor.default", 1, exactly=True).run(
                code
            )

        self.assertIsNotNone(t.grad)

    @parametrize("compile", [True, False])
    def test_all_gather_tensor(self, compile: bool) -> None:
        group = dist.group.WORLD.group_name

        def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
            assert t.requires_grad
            out = ft_c.all_gather_tensor_autograd(
                t * 1.0,
                gather_dim=dim,
                group=group,
            )
            out = out * 1.0
            return out

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        dims_to_gather = [0, 1, 2]
        for dim in dims_to_gather:
            output_size = [3, 3, 3]
            output_size[dim] *= self.world_size
            # each rank has its own tensor; all_gather returns a bigger, concatenated tensor
            local_tensor = torch.ones([3, 3, 3], requires_grad=True)
            gathered_tensor = compiled(local_tensor, dim)
            self.assertEqual(gathered_tensor, torch.ones(output_size))

            gathered_tensor.sum().backward()
            self.assertEqual(
                local_tensor.grad,
                torch.full((3, 3, 3), fill_value=float(self.world_size)),
            )

    @parametrize("compile", [True, False])
    def test_reduce_scatter_tensor(self, compile: bool) -> None:
        group = dist.group.WORLD.group_name

        def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
            assert t.requires_grad
            rs_tensor = (
                ft_c.reduce_scatter_tensor_autograd(
                    t * 1.0, "sum", scatter_dim=dim, group=group
                )
                * 1.0
            )
            return rs_tensor

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        dims_to_scatter = [0, 1]
        for dim in dims_to_scatter:
            group_size = self.world_size
            input_size = [3, 3]
            output_size = [3, 3]
            output_size[dim] *= group_size
            input_tensor = torch.ones(output_size, requires_grad=True)
            rs_tensor = compiled(input_tensor, dim)
            res_num = 1 * group_size
            self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)
            rs_tensor.sum().backward()
            self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))


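# Repeats the all_to_all_single_autograd check, this time in real subprocesses
# over an NCCL process group.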
class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank)

    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    @requires_nccl()
    @with_comms()
    def test_all_to_all_single(self) -> None:
        group = self.process_group.group_name

        t = torch.ones((self.world_size, 2), requires_grad=True, device=self.device)

        sizes = [1] * self.world_size
        assert t.requires_grad
        out = ft_c.all_to_all_single_autograd(t * 2, sizes, sizes, group) + 0

        self.assertEqual(out.shape, t.shape)
        self.assertEqual(out, torch.full_like(t, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)
        loss = out.sum()
        loss.backward()
        self.assertEqual(t.grad, torch.full_like(t, 2.0))


if __name__ == "__main__":
    run_tests()