# Owner(s): ["oncall: distributed"]

import json
import logging
import os
import re
import sys
import time
from functools import partial, wraps

import torch
import torch.distributed as dist
from torch.distributed.c10d_logger import _c10d_logger, _exception_logger, _time_logger


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

BACKEND = dist.Backend.NCCL
# Run with between 2 and 4 ranks, scaled to the number of visible GPUs.
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))


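# Decorator for test methods: skip the test when NCCL is the backend but there
# are not enough GPUs for the requested world size; otherwise initialize the
# process group before the test body runs and tear it down afterwards.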
def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        # Forward any extra arguments to the wrapped test function.
        func(self, *args, **kwargs)
        self.destroy_comms()

    return wrapper


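# Exercises the structured logging hooks in torch.distributed.c10d_logger
# (_c10d_logger, _exception_logger, _time_logger) across spawned worker
# processes.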
class C10dErrorLoggerTest(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = BACKEND
        self._spawn_processes()

    @property
    def device(self):
        return (
            torch.device(self.rank)
            if BACKEND == dist.Backend.NCCL
            else torch.device("cpu")
        )

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

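    # Create the default process group via a shared file and, for NCCL, bind
    # this rank to its GPU so collectives have a device to run on.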
    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

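    # The shared c10d logger should exist and come preconfigured with a single
    # logging.NullHandler.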
    def test_get_or_create_logger(self):
        self.assertIsNotNone(_c10d_logger)
        self.assertEqual(1, len(_c10d_logger.handlers))
        self.assertIsInstance(_c10d_logger.handlers[0], logging.NullHandler)

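    # Helpers that broadcast from a nonexistent source rank (world_size + 1):
    # the first lets the resulting exception propagate, the second swallows it
    # so that only the log record produced by @_exception_logger remains.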
    @_exception_logger
    def _failed_broadcast_raise_exception(self):
        tensor = torch.arange(2, dtype=torch.int64)
        dist.broadcast(tensor, self.world_size + 1)

    @_exception_logger
    def _failed_broadcast_not_raise_exception(self):
        try:
            tensor = torch.arange(2, dtype=torch.int64)
            dist.broadcast(tensor, self.world_size + 1)
        except Exception:
            pass

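    # @_exception_logger should log a structured record describing the failed
    # collective; capture it, parse the dict it contains, and check each field.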
    @with_comms
    def test_exception_logger(self) -> None:
        with self.assertRaises(Exception):
            self._failed_broadcast_raise_exception()

        with self.assertLogs(_c10d_logger, level="DEBUG") as captured:
            self._failed_broadcast_not_raise_exception()
            error_msg_dict = json.loads(
                re.search("({.+})", captured.output[0]).group(0).replace("'", '"')
            )

            self.assertEqual(len(error_msg_dict), 10)

            self.assertIn("pg_name", error_msg_dict.keys())
            self.assertEqual("None", error_msg_dict["pg_name"])

            self.assertIn("func_name", error_msg_dict.keys())
            self.assertEqual("broadcast", error_msg_dict["func_name"])

            self.assertIn("args", error_msg_dict.keys())

            self.assertIn("backend", error_msg_dict.keys())
            self.assertEqual("nccl", error_msg_dict["backend"])

            self.assertIn("nccl_version", error_msg_dict.keys())
            nccl_ver = torch.cuda.nccl.version()
            self.assertEqual(
                ".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"]
            )

            # In this test case, group_size == world_size, since the default (WORLD) group contains every rank.
            self.assertIn("group_size", error_msg_dict.keys())
            self.assertEqual(str(self.world_size), error_msg_dict["group_size"])

            self.assertIn("world_size", error_msg_dict.keys())
            self.assertEqual(str(self.world_size), error_msg_dict["world_size"])

            self.assertIn("global_rank", error_msg_dict.keys())
            self.assertIn(str(dist.get_rank()), error_msg_dict["global_rank"])

            # In this test case, local_rank == global_rank, since all ranks run on a single node.
            self.assertIn("local_rank", error_msg_dict.keys())
            self.assertIn(str(dist.get_rank()), error_msg_dict["local_rank"])

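    # Trivial workload wrapped by @_time_logger so the test can check the
    # recorded duration.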
    @_time_logger
    def _dummy_sleep(self):
        time.sleep(5)

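    # @_time_logger should log the same metadata fields as the exception logger
    # plus a time_spent entry; the 5-second sleep should be reflected there
    # (recorded in nanoseconds).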
    @with_comms
    def test_time_logger(self) -> None:
        with self.assertLogs(_c10d_logger, level="DEBUG") as captured:
            self._dummy_sleep()
            msg_dict = json.loads(
                re.search("({.+})", captured.output[0]).group(0).replace("'", '"')
            )
            self.assertEqual(len(msg_dict), 10)

            self.assertIn("pg_name", msg_dict.keys())
            self.assertEqual("None", msg_dict["pg_name"])

            self.assertIn("func_name", msg_dict.keys())
            self.assertEqual("_dummy_sleep", msg_dict["func_name"])

            self.assertIn("args", msg_dict.keys())

            self.assertIn("backend", msg_dict.keys())
            self.assertEqual("nccl", msg_dict["backend"])

            self.assertIn("nccl_version", msg_dict.keys())
            nccl_ver = torch.cuda.nccl.version()
            self.assertEqual(
                ".".join(str(v) for v in nccl_ver), msg_dict["nccl_version"]
            )

            # In this test case, group_size == world_size, since the default (WORLD) group contains every rank.
            self.assertIn("group_size", msg_dict.keys())
            self.assertEqual(str(self.world_size), msg_dict["group_size"])

            self.assertIn("world_size", msg_dict.keys())
            self.assertEqual(str(self.world_size), msg_dict["world_size"])

            self.assertIn("global_rank", msg_dict.keys())
            self.assertIn(str(dist.get_rank()), msg_dict["global_rank"])

            # In this test case, local_rank == global_rank, since all ranks run on a single node.
            self.assertIn("local_rank", msg_dict.keys())
            self.assertIn(str(dist.get_rank()), msg_dict["local_rank"])

            self.assertIn("time_spent", msg_dict.keys())
            time_ns = re.findall(r"\d+", msg_dict["time_spent"])[0]
            self.assertEqual(5, int(float(time_ns) / pow(10, 9)))


if __name__ == "__main__":
    run_tests()