# mypy: allow-untyped-defs


import collections
import enum
from typing import cast, Dict, List, Set, Tuple

import torch
import torch.distributed as dist

from . import api, constants as rpc_constants
from ._utils import _group_membership_management, _update_group_membership


__all__ = [
    "backend_registered",
    "register_backend",
    "construct_rpc_backend_options",
    "init_backend",
    "BackendValue",
    "BackendType",
]

BackendValue = collections.namedtuple(
    "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"]
)


def _backend_type_repr(self):
    return "BackendType." + self.name


_backend_type_doc = """
    An enum class of available backends.

    PyTorch ships with a builtin ``BackendType.TENSORPIPE`` backend.
    Additional ones can be registered using the
    :func:`~torch.distributed.rpc.backend_registry.register_backend` function.
"""

# Create an enum type, `BackendType`, with empty members.
# Can't handle Function Enum API (mypy bug #9079)
BackendType = enum.Enum(value="BackendType", names={})  # type: ignore[misc]
# Unable to assign a function as a method (mypy bug #2427)
BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]

if BackendType.__doc__:
    BackendType.__doc__ = _backend_type_doc


def backend_registered(backend_name):
    """
    Checks if backend_name is registered as an RPC backend.

    Args:
        backend_name (str): string to identify the RPC backend.
    Returns:
        True if the backend has been registered with ``register_backend``, else
        False.
    """
    return backend_name in BackendType.__members__.keys()


def register_backend(
    backend_name, construct_rpc_backend_options_handler, init_backend_handler
):
    """Registers a new RPC backend.

    Args:
        backend_name (str): backend string to identify the handler.
        construct_rpc_backend_options_handler (function):
            Handler that is invoked when
            rpc_backend.construct_rpc_backend_options(**dict) is called.
        init_backend_handler (function): Handler that is invoked when the
            `_init_rpc_backend()` function is called with a backend.
            This returns the agent.
    """
    global BackendType
    if backend_registered(backend_name):
        raise RuntimeError(f"RPC backend {backend_name}: already registered")
    # Create a new enum type, `BackendType`, with extended members.
    existing_enum_dict = {member.name: member.value for member in BackendType}
    extended_enum_dict = dict(
        {
            backend_name: BackendValue(
                construct_rpc_backend_options_handler=construct_rpc_backend_options_handler,
                init_backend_handler=init_backend_handler,
            )
        },
        **existing_enum_dict,
    )
    # Can't handle Function Enum API (mypy bug #9079)
    BackendType = enum.Enum(value="BackendType", names=extended_enum_dict)  # type: ignore[misc]
    # Unable to assign a function as a method (mypy bug #2427)
    BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]
    if BackendType.__doc__:
        BackendType.__doc__ = _backend_type_doc
    return BackendType[backend_name]


def construct_rpc_backend_options(
    backend,
    rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
    init_method=rpc_constants.DEFAULT_INIT_METHOD,
    **kwargs,
):
    return backend.value.construct_rpc_backend_options_handler(
        rpc_timeout, init_method, **kwargs
    )


def init_backend(backend, *args, **kwargs):
    return backend.value.init_backend_handler(*args, **kwargs)
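

# Illustrative sketch only (not part of this module, never called): how an
# out-of-tree backend could plug into register_backend(). The "EXAMPLE" name
# and both handler bodies are hypothetical; the init handler signature simply
# mirrors the TensorPipe handler defined later in this file.
def _example_register_custom_backend():
    def _construct_options_handler(rpc_timeout, init_method, **kwargs):
        # construct_rpc_backend_options() forwards rpc_timeout and init_method
        # positionally, so any options object understood by the init handler
        # can be returned here.
        return {"rpc_timeout": rpc_timeout, "init_method": init_method, **kwargs}

    def _init_backend_handler(store, name, rank, world_size, rpc_backend_options):
        # A real handler must construct and return an RpcAgent; raising keeps
        # this sketch honest about what is omitted.
        raise NotImplementedError("construct and return an RpcAgent here")

    if not backend_registered("EXAMPLE"):
        register_backend("EXAMPLE", _construct_options_handler, _init_backend_handler)
    return BackendType["EXAMPLE"]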


def _init_process_group(store, rank, world_size):
    # Initialize ProcessGroup.
    process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT

    # We're using a bunch of private APIs here since `new_group` requires the
    # default group to be initialized.
    group = dist.ProcessGroupGloo(store, rank, world_size, process_group_timeout)

    assert group is not None, "Failed to initialize default ProcessGroup."

    if (rank != -1) and (rank != group.rank()):
        raise RuntimeError(f"rank argument {rank} doesn't match pg rank {group.rank()}")
    if (world_size != -1) and (world_size != group.size()):
        raise RuntimeError(
            f"world_size argument {world_size} doesn't match pg size {group.size()}"
        )
    return group


def _tensorpipe_construct_rpc_backend_options_handler(
    rpc_timeout,
    init_method,
    num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS,
    _transports=None,
    _channels=None,
    **kwargs,
):
    from . import TensorPipeRpcBackendOptions

    return TensorPipeRpcBackendOptions(
        rpc_timeout=rpc_timeout,
        init_method=init_method,
        num_worker_threads=num_worker_threads,
        _transports=_transports,
        _channels=_channels,
    )


def _tensorpipe_validate_devices(devices, device_count):
    return all(
        d.type == "cpu" or (d.type == "cuda" and 0 <= d.index < device_count)
        for d in devices
    )
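

# Illustrative sketch only (not used by the module): concrete inputs for the
# check above. With a visible device count of 2, cpu plus cuda:0/cuda:1 pass,
# while cuda:2 fails because its index is out of range.
def _example_validate_devices():
    ok = _tensorpipe_validate_devices(
        [torch.device("cpu"), torch.device("cuda:0"), torch.device("cuda:1")], 2
    )
    bad = _tensorpipe_validate_devices([torch.device("cuda:2")], 2)
    return ok, bad  # (True, False)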


# Detect if any worker has invalid device_map configurations, and return
# reverse device maps.
def _tensorpipe_exchange_and_check_all_device_maps(
    my_name, my_device_count, my_device_maps, my_devices, group
):
    gathered: List[
        Tuple[str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]]
    ] = [("", 0, {}, []) for _ in range(group.size())]
    dist.all_gather_object(
        gathered, (my_name, my_device_count, my_device_maps, my_devices), group
    )
    all_names = [name for name, _, _, _ in gathered]
    all_device_counts = {name: count for name, count, _, _ in gathered}
    all_device_maps = {name: map_ for name, _, map_, _ in gathered}
    all_devices = {name: devices for name, _, _, devices in gathered}

    _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices)

    # passed all checks, construct reverse mapping and get list of devices handled by this agent
    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)
    my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps)
    return reverse_device_maps, my_devices


def _validate_device_maps(
    all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True
):
    for node in all_names:
        devices = all_devices[node]
        if len(set(devices)) != len(devices):
            raise ValueError(
                f"Node {node} has duplicated devices\n" f"devices = {devices}"
            )
        if not _tensorpipe_validate_devices(devices, all_device_counts[node]):
            raise ValueError(
                f"Node {node} has devices with invalid indices\n"
                f"devices = {devices}\n"
                f"device count = {all_device_counts[node]}"
            )

    for source_node in all_names:
        # For dynamic group (non-static) do not check the target node name since it may not have joined yet
        if is_static_group and not set(all_device_maps[source_node].keys()).issubset(
            all_names
        ):
            raise ValueError(
                f"Node {source_node} has invalid target node names in its device maps\n"
                f"device maps = {all_device_maps[source_node].keys()}\n"
                f"node names = {all_names}"
            )
        for target_node, map_ in all_device_maps[source_node].items():
            if len(set(map_.values())) != len(map_):
                raise ValueError(
                    f"Node {source_node} has duplicated target devices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}"
                )
            if all_devices[source_node]:
                if not set(map_.keys()).issubset(all_devices[source_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected source devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[source_node]}"
                    )
            elif not _tensorpipe_validate_devices(
                map_.keys(), all_device_counts[source_node]
            ):
                raise ValueError(
                    f"Node {source_node} has source devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[source_node]}"
                )
            if all_devices.get(target_node, []):
                if not set(map_.values()).issubset(all_devices[target_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected target devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[target_node]}"
                    )
            elif target_node in all_device_counts and not _tensorpipe_validate_devices(
                map_.values(), all_device_counts[target_node]
            ):
                raise ValueError(
                    f"Node {source_node} has target devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[target_node]}"
                )


def _create_device_list(my_devices, my_device_maps, reverse_device_maps):
    if not my_devices:
        devices_set: Set[torch.device] = set()
        for map_ in my_device_maps.values():
            devices_set.update(map_.keys())
        for map_ in reverse_device_maps.values():
            devices_set.update(map_.keys())
        devices_set.discard(torch.device("cpu"))
        my_devices = list(devices_set)
    my_devices = sorted(my_devices, key=lambda d: d.index)
    return my_devices


def _create_reverse_mapping(my_name, all_names, all_device_maps):
    reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {}
    for node in all_names:
        if my_name in all_device_maps[node]:
            reverse_device_maps[node] = {
                v: k for k, v in all_device_maps[node][my_name].items()
            }
    return reverse_device_maps
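

# Illustrative sketch only (not used by the module), with made-up worker names:
# if "ps" maps its cuda:0 to "trainer"'s cuda:1, then from "trainer"'s point of
# view the reverse map for "ps" sends cuda:1 back to cuda:0, which is exactly
# what _create_reverse_mapping builds for the local agent.
def _example_reverse_device_map():
    all_device_maps = {
        "ps": {"trainer": {torch.device("cuda:0"): torch.device("cuda:1")}},
        "trainer": {},
    }
    reverse = _create_reverse_mapping("trainer", ["ps", "trainer"], all_device_maps)
    assert reverse == {"ps": {torch.device("cuda:1"): torch.device("cuda:0")}}
    return reverse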


def _get_device_infos():
    from . import TensorPipeAgent

    agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
    opts = agent._get_backend_options()
    device_count = torch.cuda.device_count()
    if torch.cuda.is_available() and opts.devices:
        torch.cuda.init()
    return device_count, opts.device_maps, opts.devices


def _set_devices_and_reverse_device_map(agent):
    from . import TensorPipeAgent

    agent = cast(TensorPipeAgent, agent)
    # Group state is retrieved from local agent
    # On initialization, tensorpipe agent retrieves information from all existing workers, so group state is valid
    my_worker_info = agent.get_worker_info()
    my_name = my_worker_info.name
    all_worker_infos = agent.get_worker_infos()
    # One round to get device_maps of all workers and construct reverse device maps
    all_device_counts, all_device_maps, all_devices, all_names = {}, {}, {}, []
    for worker_info in all_worker_infos:
        worker_name = worker_info.name
        if worker_name != my_name:
            # TODO: make async?
            device_count, device_map, devices = api.rpc_sync(
                worker_name, _get_device_infos
            )
        else:
            opts = agent._get_backend_options()
            device_count, device_map, devices = (
                torch.cuda.device_count(),
                opts.device_maps,
                opts.devices,
            )
        all_device_counts[worker_name] = device_count
        all_device_maps[worker_name] = device_map
        all_devices[worker_name] = devices
        all_names.append(worker_name)

    _validate_device_maps(
        all_names,
        all_device_counts,
        all_device_maps,
        all_devices,
        is_static_group=False,
    )
    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)

    # Perform RPC call to all workers, including itself, to include newly joined worker information and device maps
    for worker_name in all_names:
        # Set device list for each worker
        all_devices[worker_name] = _create_device_list(
            all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps
        )
        api.rpc_sync(
            worker_name,
            _update_group_membership,
            args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True),
        )


def _tensorpipe_init_backend_handler(
    store, name, rank, world_size, rpc_backend_options
):
    from . import TensorPipeAgent, TensorPipeRpcBackendOptions

    if not isinstance(store, dist.Store):
        raise TypeError(f"`store` must be a c10d::Store. {store}")

    if not isinstance(rpc_backend_options, TensorPipeRpcBackendOptions):
        raise TypeError(
            f"`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {rpc_backend_options}"
        )

    device_count = torch.cuda.device_count()

    is_static_group = True if world_size else False
    # world_size is specified so this is a static group (ranks cannot join and leave)
    if is_static_group:
        # The agent's join method is required to behave like a barrier and perform
        # collective operations, for which it relies on a process group, instead of
        # re-implementing this on top of RPCs.
        group = _init_process_group(store, rank, world_size)

        reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps(
            name,
            device_count,
            rpc_backend_options.device_maps,
            rpc_backend_options.devices,
            group,
        )

        if torch.cuda.is_available() and devices:
            # It's necessary to initialize PyTorch CUDA states here (e.g.,
            # CUDACachingAllocator). If this is missing, we could hit errors like
            # "allocator not initialized", because other processes might send
            # CUDA-related RPC request to this process before user code in this
            # process initializes its PyTorch CUDA states.
            torch.cuda.init()

        # TODO: add try-except and destroy _agent in all processes if any fails.
        agent = TensorPipeAgent(
            store,
            name,
            rank,
            world_size,
            rpc_backend_options,
            reverse_device_maps,
            devices,
        )

        api._init_rpc_states(agent)

        # Run one dummy round of RPC to initialize channels/transports. Without
        # this, it's easy to hit timeout in rpc.shutdown() if there is no other RPC
        # on that process before rpc.shutdown(), as the agent initialization can
        # take longer than 5s.
        api._all_gather(None, timeout=rpc_backend_options.rpc_timeout)
        # Need a barrier here to make sure no peers leave before the rank0 finishes
        # _all_gather
        group.barrier().wait()

        return agent
    # initialization for dynamic rpc (ranks can join and leave)
    else:
        with _group_membership_management(store, name, True):
            # Construct TPAgent with empty reverse_device_map and devices
            # these properties will be updated after initialization
            agent = TensorPipeAgent(
                store,
                name,
                rank,
                world_size,
                rpc_backend_options,
                {},
                [],
            )
            api._init_rpc_states(agent)

            try:
                # Notify all workers in group this rank has joined and set devices and reverse_device_map
                # This is a synchronous operation that completes once all existing ranks are updated
                _set_devices_and_reverse_device_map(agent)
            except Exception:
                api.shutdown()
                raise
            return agent


register_backend(
    "TENSORPIPE",
    _tensorpipe_construct_rpc_backend_options_handler,
    _tensorpipe_init_backend_handler,
)
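

# Illustrative sketch only (never executed on import): how the TENSORPIPE
# backend registered above is typically reached through the public
# torch.distributed.rpc API. The worker name, rank, world size, and
# init_method below are made-up values for demonstration.
def _example_init_rpc_with_tensorpipe():
    from torch.distributed import rpc

    options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16,
        rpc_timeout=60,
        init_method="tcp://localhost:29500",
    )
    # init_rpc() resolves BackendType.TENSORPIPE and eventually invokes
    # _tensorpipe_init_backend_handler() defined in this module.
    rpc.init_rpc(
        "worker0",
        backend=rpc.BackendType.TENSORPIPE,
        rank=0,
        world_size=1,
        rpc_backend_options=options,
    )
    rpc.shutdown()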