# Owner(s): ["oncall: distributed"]

import os
import sys
from datetime import timedelta
from unittest.mock import patch

import torch
import torch.distributed as c10d
from torch._C._distributed_c10d import _ProcessGroupWrapper


if not c10d.is_available():
    print("c10d not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from test_c10d_common import LOOPBACK

from torch.testing._internal.common_distributed import (
    create_device,
    MultiProcessTestCase,
    requires_gloo,
    requires_nccl,
    skip_if_lt_x_gpu,
    with_dist_debug_levels,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


class AbstractProcessGroupWrapperTest(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

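    # Asserts that the wrapper's error message names the mismatched op type and,
    # for tensor collectives, the tensor shape, device, dtype, and sequence number.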
    def _validate_error(self, exception, op_type, rank, tensor, verify_diff=True):
        err = str(exception)
        self.assertTrue(
            op_type in err, f"Got {err} but expected {op_type} to be in error."
        )
        # Barrier takes no tensor, so skip the tensor checks for it.
        if op_type != "BARRIER":
            self.assertTrue(
                f"{list(tensor.shape)}" in err,
                f"Did not find shapes {list(tensor.shape)} in error {err}",
            )
            # For CUDA, only assert on the device type, not the index.
            if "cuda" in str(tensor.device):
                self.assertTrue(
                    "cuda" in err, f"Did not find cuda device in error {err}"
                )
            else:
                self.assertTrue(
                    str(tensor.device) in err,
                    f"Did not find tensor device {str(tensor.device)} in error {err}",
                )
            # C++ and Python dtype strings are not exactly the same, so match
            # on the C++ names.
            if "float" in str(tensor.dtype):
                self.assertTrue("Float" in err, "Expected Float type")
            elif "int" in str(tensor.dtype):
                self.assertTrue("Long" in err, "Expected Long type")
            else:
                self.fail(f"Unexpected dtype {str(tensor.dtype)} for error {err}")

            # Ensure the sequence number is logged in the error.
            self.assertTrue("SequenceNumber" in err)
            # Ensure info about how the collectives differ is in the error.
            if verify_diff:
                self.assertTrue(
                    "Collectives differ in the following" in err, f"Got error {err}"
                )

    def _test_collective_hang(self, wrapper_pg, use_cuda=False):
        # All ranks besides 1 call allreduce and wrapper_pg should detect a hang
        # and report an issue with rank 1.
        faulty_rank = 1
        if self.rank != faulty_rank:
            tensor = torch.randn(20, 10)
            if use_cuda:
                tensor = tensor.to(self.rank)

            if self.rank == 0:
                # Rank 0 reports faulty ranks
                err = f"Ranks {faulty_rank} failed to pass monitoredBarrier"
            else:
                err = "Please check rank 0 logs for faulty rank"

            # Gloo can sometimes throw the following error if a rank exits early
            # before rank 0 calls into the allreduce.
            err += "|Connection closed by peer|Connection reset by peer"
            with self.assertRaisesRegex(RuntimeError, err):
                wrapper_pg.allreduce([tensor])

    def _test_collectives_op_mismatch(self, wrapper_pg, use_cuda=False):
        tensor = torch.randn(20, 10)
        if use_cuda:
            tensor = tensor.to(self.rank)
        works = []
        # Run a number of successful collectives first to advance the sequence number.
        for _ in range(500):
            work = wrapper_pg.allreduce([tensor])
            works.append(work)

        for w in works:
            w.wait()

        # Simulate a mismatch: rank 0 calls allreduce while the others call reduce.
        # The raised error should include the inconsistent collective, rank, tensor
        # shape, device, and dtype.
        with self.assertRaisesRegex(RuntimeError, ".*") as cm:
            if self.rank == 0:
                wrapper_pg.allreduce([tensor])
            else:
                wrapper_pg.reduce([tensor])
        self._validate_error(
            exception=cm.exception,
            op_type="ALLREDUCE" if self.rank == 0 else "REDUCE",
            rank=self.rank,
            tensor=tensor,
        )

        with self.assertRaisesRegex(RuntimeError, ".*") as cm:
            if self.rank == 0:
                wrapper_pg.reduce([tensor])
            else:
                wrapper_pg.barrier()
        self._validate_error(
            exception=cm.exception,
            op_type="REDUCE" if self.rank == 0 else "BARRIER",
            rank=self.rank,
            tensor=tensor,
        )

        with self.assertRaisesRegex(RuntimeError, ".*") as cm:
            if self.rank == 0:
                wrapper_pg.broadcast(tensor, 0)
            else:
                output_tensors = [
                    torch.zeros_like(tensor) for _ in range(self.world_size)
                ]
                wrapper_pg.allgather([output_tensors], [tensor])
        self._validate_error(
            exception=cm.exception,
            op_type="BROADCAST" if self.rank == 0 else "ALLGATHER",
            rank=self.rank,
            tensor=tensor,
        )

    def _test_collective_shape_mismatch(self, wrapper_pg, use_cuda=False):
        wrapper_pg.barrier()
        dim = 2 if self.rank == 0 else 10
        tensor = torch.randn(20, dim)
        if use_cuda:
            tensor = tensor.to(self.rank)
        with self.assertRaisesRegex(RuntimeError, ".*") as cm:
            wrapper_pg.allreduce([tensor])
        self._validate_error(
            exception=cm.exception,
            op_type="ALLREDUCE",
            rank=self.rank,
            tensor=tensor,
        )

        # Check that errors are raised when the tensors' dimensionalities differ.
        tensor = torch.randn(20, 10, 2) if self.rank == 0 else torch.randn(20, 10)
        if use_cuda:
            tensor = tensor.to(self.rank)
        with self.assertRaisesRegex(RuntimeError, ".*") as cm:
            wrapper_pg.allreduce([tensor])
        self._validate_error(
            exception=cm.exception,
            op_type="ALLREDUCE",
            rank=self.rank,
            tensor=tensor,
        )

        # Check shape errors with scatter.
        input = [
            torch.tensor(
                [self.rank] if self.rank == 0 else [self.rank, self.rank],
                device=self.rank if use_cuda else "cpu",
            )
            for _ in range(self.world_size)
        ]
        outputs = [
            torch.tensor(
                [-1] if self.rank == 0 else [-1, -1],
                device=self.rank if use_cuda else "cpu",
            )
            for _ in range(self.world_size)
        ]
        root_rank = 0
        opts = c10d.ScatterOptions()
        opts.rootRank = root_rank
        with self.assertRaisesRegex(RuntimeError, ".*") as cm:
            if self.rank == root_rank:
                wrapper_pg.scatter([outputs[self.rank]], [input], opts).wait()
            else:
                wrapper_pg.scatter([outputs[self.rank]], [], opts).wait()
        self._validate_error(
            exception=cm.exception,
            op_type="SCATTER",
            rank=self.rank,
            tensor=outputs[self.rank],
        )


# ASAN is not safe since we are spawning processes.
if not TEST_WITH_DEV_DBG_ASAN:

    @requires_gloo()
    @requires_nccl()
    class ProcessGroupNCCLWrapperTest(AbstractProcessGroupWrapperTest):
        def setUp(self):
            super(AbstractProcessGroupWrapperTest, self).setUp()
            self._spawn_processes()
            # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING, so
            # tests that use TORCH_NCCL_BLOCKING_WAIT will still test it as expected.
            os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"

        @property
        def world_size(self) -> int:
            return 2

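        # Builds the process group under test: with_new_group=True goes through
        # c10d.new_group (which is expected to wrap the pg automatically when
        # TORCH_DISTRIBUTED_DEBUG=DETAIL); otherwise the NCCL pg is wrapped
        # explicitly via _create_process_group_wrapper.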
        def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
            store = c10d.FileStore(self.file_name, self.world_size)
            c10d.init_process_group(
                backend="nccl",
                rank=self.rank,
                world_size=self.world_size,
                store=store,
                timeout=timedelta(seconds=timeout),
            )
            if with_new_group:
                pg = c10d.new_group(backend="nccl", timeout=timedelta(seconds=timeout))
            else:
                _pg = c10d.ProcessGroupNCCL(
                    store,
                    self.rank,
                    self.world_size,
                    timeout=timedelta(seconds=timeout),
                )
                pg = c10d._create_process_group_wrapper(
                    _pg,
                    "unused",
                    store,
                    self.rank,
                    self.world_size,
                    timeout=timeout,
                )
            return pg

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        def test_collective_hang(self):
            pg = self._create_wrapper_pg(timeout=2.0)
            self._test_collective_hang(pg)

        # NOTE: these tests are separated by debug level instead of combined into
        # one due to https://github.com/pytorch/pytorch/issues/55967; they can be
        # combined once that issue is resolved.
        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @with_dist_debug_levels(levels=["DETAIL"])
        def test_collectives_op_mismatch_debug_mode(self):
            pg = self._create_wrapper_pg(with_new_group=True)
            self._test_collectives_op_mismatch(pg, use_cuda=True)
            self._test_nccl_only_op_mismatch(pg)

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @with_dist_debug_levels(levels=["OFF"])
        def test_collectives_op_mismatch(self):
            pg = self._create_wrapper_pg(with_new_group=False)
            self._test_collectives_op_mismatch(pg, use_cuda=True)
            self._test_nccl_only_op_mismatch(pg)

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @with_dist_debug_levels(levels=["DETAIL"])
        def test_collective_shape_mismatch_debug_mode_detail(self):
            pg = self._create_wrapper_pg(with_new_group=True)
            self._test_collective_shape_mismatch(pg, use_cuda=True)
            self._test_nccl_only_shape_mismatch(pg)

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @with_dist_debug_levels(levels=["OFF"])
        def test_collective_shape_mismatch_debug_mode_off(self):
            pg = self._create_wrapper_pg(with_new_group=False)
            self._test_collective_shape_mismatch(pg, use_cuda=True)
            self._test_nccl_only_shape_mismatch(pg)

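        # Mismatch checks for collectives exercised only in the NCCL tests
        # (_allgather_base and _reduce_scatter_base).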
        def _test_nccl_only_op_mismatch(self, wrapper_pg):
            device = f"cuda:{self.rank}"
            with self.assertRaisesRegex(RuntimeError, ".*") as cm:
                output = torch.zeros(4 + self.rank, device=device)
                input = torch.ones(4 * self.world_size, device=device)
                if self.rank == 0:
                    wrapper_pg._allgather_base(output, input).wait()
                else:
                    wrapper_pg._reduce_scatter_base(output, input).wait()

            op_type = "ALLGATHER_BASE" if self.rank == 0 else "REDUCE_SCATTER_BASE"
            self._validate_error(
                exception=cm.exception,
                op_type=op_type,
                rank=self.rank,
                tensor=input,
            )

        def _test_nccl_only_shape_mismatch(self, wrapper_pg):
            device = f"cuda:{self.rank}"
            with self.assertRaisesRegex(RuntimeError, ".*") as cm:
                output = torch.zeros(4 + self.rank, device=device)
                input = torch.ones(4 * (self.world_size + 1), device=device)

                wrapper_pg._reduce_scatter_base(output, input).wait()
            self._validate_error(
                exception=cm.exception,
                op_type="REDUCE_SCATTER_BASE",
                rank=self.rank,
                tensor=input,
                verify_diff=False,
            )
            with self.assertRaisesRegex(RuntimeError, ".*") as cm:
                output = torch.zeros(4, device=device)
                input = torch.ones((4 + self.rank) * self.world_size, device=device)

                wrapper_pg._reduce_scatter_base(output, input).wait()
            self._validate_error(
                exception=cm.exception,
                op_type="REDUCE_SCATTER_BASE",
                rank=self.rank,
                tensor=input,
                verify_diff=False,
            )

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @with_dist_debug_levels(levels=["DETAIL"])
        def test_coalescing_manager_debug_mode_detail(self):
            """
            Tests that the coalescing manager does not crash when
            TORCH_DISTRIBUTED_DEBUG=DETAIL is set:
            https://github.com/pytorch/pytorch/issues/109520
            """
            torch.cuda.set_device(self.rank)
            pg = self._create_wrapper_pg(with_new_group=True)
            dev = torch.cuda.current_device()
            pg._start_coalescing(torch.device(dev))
            pg.allreduce([torch.ones(1, device=dev)])
            pg._end_coalescing(torch.device(dev))

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @with_dist_debug_levels(levels=["DETAIL"])
        @patch("torch.distributed.distributed_c10d._GLOO_AVAILABLE", False)
        def test_debug_level_detail_no_gloo(self):
            with self.assertRaisesRegex(
                AssertionError, "ProcessGroupWrapper unsupported without GLOO backend"
            ):
                self._create_wrapper_pg()

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @patch("torch.distributed.distributed_c10d._GLOO_AVAILABLE", False)
        def test_new_group_no_gloo(self):
            def patched_isinstance(obj, clazz):
                if clazz is _ProcessGroupWrapper:
                    raise NameError
                else:
                    return isinstance(obj, clazz)

            with patch(
                "torch.distributed.distributed_c10d.isinstance",
                side_effect=patched_isinstance,
            ):
                self._create_wrapper_pg(with_new_group=True)
                # Nothing to assert: isinstance(pg, _ProcessGroupWrapper)
                # should never be invoked since it is preceded by the
                # _GLOO_AVAILABLE check; if it is, this test fails with
                # an unexpected NameError.


@requires_gloo()
class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
    def opts(self, threads=2, timeout=10.0):
        opts = c10d.ProcessGroupGloo._Options()
        opts._timeout = timeout
        opts._devices = [create_device(interface=LOOPBACK)]
        opts._threads = threads
        return opts

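    # Mirrors the NCCL helper above, but builds a Gloo process group; the explicit
    # path wraps ProcessGroupGloo via _create_process_group_wrapper.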
    def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="gloo", rank=self.rank, world_size=self.world_size, store=store
        )
        if with_new_group:
            pg = c10d.new_group(backend="gloo")
        else:
            _pg = c10d.ProcessGroupGloo(
                store, self.rank, self.world_size, self.opts(timeout=timeout)
            )
            pg = c10d._create_process_group_wrapper(
                _pg,
                "unused",
                store,
                self.rank,
                self.world_size,
                timeout=timeout,
            )
        return pg

    def test_collective_hang(self):
        pg = self._create_wrapper_pg(timeout=2.0)
        self._test_collective_hang(pg)

    # NOTE: these tests are separated by debug level instead of combined into
    # one due to https://github.com/pytorch/pytorch/issues/55967; they can be
    # combined once that issue is resolved.
    @with_dist_debug_levels(levels=["DETAIL"])
    def test_collectives_op_mismatch_debug_mode(self):
        pg = self._create_wrapper_pg(with_new_group=True)
        self._test_collectives_op_mismatch(pg)

    @with_dist_debug_levels(levels=["OFF"])
    def test_collectives_op_mismatch(self):
        pg = self._create_wrapper_pg(with_new_group=False)
        self._test_collectives_op_mismatch(pg)

    @with_dist_debug_levels(levels=["DETAIL"])
    def test_collective_shape_mismatch_debug_mode(self):
        pg = self._create_wrapper_pg(with_new_group=True)
        self._test_collective_shape_mismatch(pg)

    @with_dist_debug_levels(levels=["OFF"])
    def test_collective_shape_mismatch_debug_mode_off(self):
        pg = self._create_wrapper_pg(with_new_group=False)
        self._test_collective_shape_mismatch(pg)

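    # The CUDA variants below place each rank's tensor on cuda:{rank}, so they
    # require at least one GPU per rank.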
    @skip_if_lt_x_gpu(4)
    @with_dist_debug_levels(levels=["DETAIL"])
    def test_collectives_op_mismatch_cuda_debug_mode(self):
        pg = self._create_wrapper_pg(with_new_group=True)
        self._test_collectives_op_mismatch(pg, use_cuda=True)

    @skip_if_lt_x_gpu(4)
    @with_dist_debug_levels(levels=["OFF"])
    def test_collectives_op_mismatch_cuda(self):
        pg = self._create_wrapper_pg(with_new_group=False)
        self._test_collectives_op_mismatch(pg, use_cuda=True)

    @skip_if_lt_x_gpu(4)
    @with_dist_debug_levels(levels=["DETAIL"])
    def test_collective_shape_mismatch_cuda_debug_mode(self):
        pg = self._create_wrapper_pg(with_new_group=True)
        self._test_collective_shape_mismatch(pg, use_cuda=True)

    @skip_if_lt_x_gpu(4)
    @with_dist_debug_levels(levels=["OFF"])
    def test_collective_shape_mismatch_cuda(self):
        pg = self._create_wrapper_pg(with_new_group=False)
        self._test_collective_shape_mismatch(pg, use_cuda=True)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_pg_wrapper must not have initialized CUDA context on main process"

    run_tests()