1# Owner(s): ["oncall: distributed"] 2 3import json 4import logging 5import os 6import re 7import sys 8import time 9from functools import partial, wraps 10 11import torch 12import torch.distributed as dist 13from torch.distributed.c10d_logger import _c10d_logger, _exception_logger, _time_logger 14 15 16if not dist.is_available(): 17 print("Distributed not available, skipping tests", file=sys.stderr) 18 sys.exit(0) 19 20from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS 21from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN 22 23 24if TEST_WITH_DEV_DBG_ASAN: 25 print( 26 "Skip dev-asan as torch + multiprocessing spawn have known issues", 27 file=sys.stderr, 28 ) 29 sys.exit(0) 30 31BACKEND = dist.Backend.NCCL 32WORLD_SIZE = min(4, max(2, torch.cuda.device_count())) 33 34 35def with_comms(func=None): 36 if func is None: 37 return partial( 38 with_comms, 39 ) 40 41 @wraps(func) 42 def wrapper(self, *args, **kwargs): 43 if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size: 44 sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) 45 self.dist_init() 46 func(self) 47 self.destroy_comms() 48 49 return wrapper 50 51 52class C10dErrorLoggerTest(MultiProcessTestCase): 53 def setUp(self): 54 super().setUp() 55 os.environ["WORLD_SIZE"] = str(self.world_size) 56 os.environ["BACKEND"] = BACKEND 57 self._spawn_processes() 58 59 @property 60 def device(self): 61 return ( 62 torch.device(self.rank) 63 if BACKEND == dist.Backend.NCCL 64 else torch.device("cpu") 65 ) 66 67 @property 68 def world_size(self): 69 return WORLD_SIZE 70 71 @property 72 def process_group(self): 73 return dist.group.WORLD 74 75 def destroy_comms(self): 76 # Wait for all ranks to reach here before starting shutdown. 
77 dist.barrier() 78 dist.destroy_process_group() 79 80 def dist_init(self): 81 dist.init_process_group( 82 backend=BACKEND, 83 world_size=self.world_size, 84 rank=self.rank, 85 init_method=f"file://{self.file_name}", 86 ) 87 88 # set device for nccl pg for collectives 89 if BACKEND == "nccl": 90 torch.cuda.set_device(self.rank) 91 92 def test_get_or_create_logger(self): 93 self.assertIsNotNone(_c10d_logger) 94 self.assertEqual(1, len(_c10d_logger.handlers)) 95 self.assertIsInstance(_c10d_logger.handlers[0], logging.NullHandler) 96 97 @_exception_logger 98 def _failed_broadcast_raise_exception(self): 99 tensor = torch.arange(2, dtype=torch.int64) 100 dist.broadcast(tensor, self.world_size + 1) 101 102 @_exception_logger 103 def _failed_broadcast_not_raise_exception(self): 104 try: 105 tensor = torch.arange(2, dtype=torch.int64) 106 dist.broadcast(tensor, self.world_size + 1) 107 except Exception: 108 pass 109 110 @with_comms 111 def test_exception_logger(self) -> None: 112 with self.assertRaises(Exception): 113 self._failed_broadcast_raise_exception() 114 115 with self.assertLogs(_c10d_logger, level="DEBUG") as captured: 116 self._failed_broadcast_not_raise_exception() 117 error_msg_dict = json.loads( 118 re.search("({.+})", captured.output[0]).group(0).replace("'", '"') 119 ) 120 121 self.assertEqual(len(error_msg_dict), 10) 122 123 self.assertIn("pg_name", error_msg_dict.keys()) 124 self.assertEqual("None", error_msg_dict["pg_name"]) 125 126 self.assertIn("func_name", error_msg_dict.keys()) 127 self.assertEqual("broadcast", error_msg_dict["func_name"]) 128 129 self.assertIn("args", error_msg_dict.keys()) 130 131 self.assertIn("backend", error_msg_dict.keys()) 132 self.assertEqual("nccl", error_msg_dict["backend"]) 133 134 self.assertIn("nccl_version", error_msg_dict.keys()) 135 nccl_ver = torch.cuda.nccl.version() 136 self.assertEqual( 137 ".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"] 138 ) 139 140 # In this test case, group_size = world_size, since we don't have multiple processes on one node. 141 self.assertIn("group_size", error_msg_dict.keys()) 142 self.assertEqual(str(self.world_size), error_msg_dict["group_size"]) 143 144 self.assertIn("world_size", error_msg_dict.keys()) 145 self.assertEqual(str(self.world_size), error_msg_dict["world_size"]) 146 147 self.assertIn("global_rank", error_msg_dict.keys()) 148 self.assertIn(str(dist.get_rank()), error_msg_dict["global_rank"]) 149 150 # In this test case, local_rank = global_rank, since we don't have multiple processes on one node. 
151 self.assertIn("local_rank", error_msg_dict.keys()) 152 self.assertIn(str(dist.get_rank()), error_msg_dict["local_rank"]) 153 154 @_time_logger 155 def _dummy_sleep(self): 156 time.sleep(5) 157 158 @with_comms 159 def test_time_logger(self) -> None: 160 with self.assertLogs(_c10d_logger, level="DEBUG") as captured: 161 self._dummy_sleep() 162 msg_dict = json.loads( 163 re.search("({.+})", captured.output[0]).group(0).replace("'", '"') 164 ) 165 self.assertEqual(len(msg_dict), 10) 166 167 self.assertIn("pg_name", msg_dict.keys()) 168 self.assertEqual("None", msg_dict["pg_name"]) 169 170 self.assertIn("func_name", msg_dict.keys()) 171 self.assertEqual("_dummy_sleep", msg_dict["func_name"]) 172 173 self.assertIn("args", msg_dict.keys()) 174 175 self.assertIn("backend", msg_dict.keys()) 176 self.assertEqual("nccl", msg_dict["backend"]) 177 178 self.assertIn("nccl_version", msg_dict.keys()) 179 nccl_ver = torch.cuda.nccl.version() 180 self.assertEqual( 181 ".".join(str(v) for v in nccl_ver), msg_dict["nccl_version"] 182 ) 183 184 # In this test case, group_size = world_size, since we don't have multiple processes on one node. 185 self.assertIn("group_size", msg_dict.keys()) 186 self.assertEqual(str(self.world_size), msg_dict["group_size"]) 187 188 self.assertIn("world_size", msg_dict.keys()) 189 self.assertEqual(str(self.world_size), msg_dict["world_size"]) 190 191 self.assertIn("global_rank", msg_dict.keys()) 192 self.assertIn(str(dist.get_rank()), msg_dict["global_rank"]) 193 194 # In this test case, local_rank = global_rank, since we don't have multiple processes on one node. 195 self.assertIn("local_rank", msg_dict.keys()) 196 self.assertIn(str(dist.get_rank()), msg_dict["local_rank"]) 197 198 self.assertIn("time_spent", msg_dict.keys()) 199 time_ns = re.findall(r"\d+", msg_dict["time_spent"])[0] 200 self.assertEqual(5, int(float(time_ns) / pow(10, 9))) 201 202 203if __name__ == "__main__": 204 run_tests() 205