# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import socket
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, Optional

from torch.distributed import Store
from torch.distributed.elastic.utils.distributed import get_free_port as _get_free_port


__all__ = [
    "RendezvousClosedError",
    "RendezvousConnectionError",
    "RendezvousError",
    "RendezvousGracefulExitError",
    "RendezvousHandler",
    "RendezvousHandlerCreator",
    "RendezvousHandlerRegistry",
    "RendezvousInfo",
    "RendezvousParameters",
    "RendezvousStateError",
    "RendezvousStoreInfo",
    "RendezvousTimeoutError",
    "rendezvous_handler_registry",
]


class RendezvousError(Exception):
    """Represents the base type for rendezvous errors."""


class RendezvousClosedError(RendezvousError):
    """Raised when a rendezvous is closed."""


class RendezvousTimeoutError(RendezvousError):
    """Raised when a rendezvous did not complete on time."""


class RendezvousConnectionError(RendezvousError):
    """Raised when the connection to a rendezvous backend has failed."""


class RendezvousStateError(RendezvousError):
    """Raised when the state of a rendezvous is corrupt."""


class RendezvousGracefulExitError(RendezvousError):
    """Raised when a node was not included in the rendezvous and gracefully exits.

    The exception is a mechanism to exit the stack; it does not mean a failure.
    """


@dataclass
class RendezvousStoreInfo:
    """Store address and port that can be used to bootstrap trainer distributed comms."""

    # Keys under which rank 0 publishes the address/port in the store.
    MASTER_ADDR_KEY: ClassVar[str] = "MASTER_ADDR"
    MASTER_PORT_KEY: ClassVar[str] = "MASTER_PORT"
    master_addr: str
    master_port: int

    @staticmethod
    def build(
        rank: int, store: Store, local_addr: Optional[str]
    ) -> "RendezvousStoreInfo":
        """Factory method, finds unused new port on rank0 host and shares addr/port info with all ranks.

        If master_addr/master_port is known (useful when sharing an existing tcp store server),
        use the constructor instead.

        Args:
            rank: rank of the current node
            store: store to use for rendezvous
            local_addr: address of the current node, if not provided will be resolved from hostname

        Returns:
            A ``RendezvousStoreInfo`` populated with the address and port read
            back from ``store``.
        """
        # TODO swap to collectives comms API
        if rank == 0:
            # Only rank 0 publishes; every rank (rank 0 included) reads the
            # values back from the store below.
            addr = local_addr or socket.getfqdn()
            port = _get_free_port()
            store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8"))  # type: ignore[arg-type]
            store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8"))  # type: ignore[arg-type]

        addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8")
        port = int(
            store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8")
        )
        return RendezvousStoreInfo(master_addr=addr, master_port=port)


class RendezvousInfo:
    """Holds the information about the rendezvous."""

    def __init__(
        self,
        store: Store,
        rank: int,
        world_size: int,
        bootstrap_store_info: RendezvousStoreInfo,
    ):
        self._store = store
        self._rank = rank
        self._world_size = world_size
        self._bootstrap_store_info = bootstrap_store_info

    @property
    def store(self) -> Store:
        """Store used by torchelastic control plane."""
        return self._store

    @property
    def rank(self) -> int:
        """Rank within a group."""
        return self._rank

    @property
    def world_size(self) -> int:
        """Global group size."""
        return self._world_size

    @property
    def bootstrap_store_info(self) -> Optional[RendezvousStoreInfo]:
        """Store information that can be used by trainer code to bootstrap distributed comms."""
        return self._bootstrap_store_info


class RendezvousHandler(ABC):
    """Main rendezvous interface.

    Note:
        Distributed Torch users normally **do not** need to implement their own
        ``RendezvousHandler``. An implementation based on C10d Store is already
        provided, and is recommended for most users.
    """

    @abstractmethod
    def get_backend(self) -> str:
        """Return the name of the rendezvous backend."""

    @property
    def use_agent_store(self) -> bool:
        """Indicates that store reference returned by :py:meth:`next_rendezvous` can be shared with user
        applications and will be available during application lifecycle.

        Rendezvous handler impl will share store details as instance of :py:class:`RendezvousStoreInfo`.
        Applications as a convention use `MASTER_ADDR`/`MASTER_PORT` env variables to lookup the store.
        """
        return False

    @abstractmethod
    def next_rendezvous(self) -> RendezvousInfo:
        """Main entry-point into the rendezvous barrier.

        Blocks until the rendezvous is complete and the current process is
        included in the formed worker group, or a timeout occurs, or the
        rendezvous was marked closed.

        Returns:
            Instance of :py:class:`RendezvousInfo`.

        Raises:
            RendezvousClosedError:
                The rendezvous is closed.
            RendezvousConnectionError:
                The connection to the rendezvous backend has failed.
            RendezvousStateError:
                The rendezvous state is corrupt.
            RendezvousTimeoutError:
                The rendezvous did not complete on time.
        """

    @abstractmethod
    def is_closed(self) -> bool:
        """Check whether the rendezvous has been closed.

        A closed rendezvous means all future attempts to re-rendezvous within
        same job will fail.

        ``is_closed()`` and :py:meth:`set_closed` have semantics of eventual
        propagation and should not be used for synchronization. The intention is
        that if at least one node decides the job is finished, it will close the
        rendezvous, and other nodes will soon observe this and stop running as
        well.
        """

    @abstractmethod
    def set_closed(self):
        """Mark the rendezvous as closed."""

    @abstractmethod
    def num_nodes_waiting(self) -> int:
        """Return the number of nodes who arrived late at the rendezvous
        barrier, hence were not included in the current worker group.

        Callers should periodically call this method to check whether new
        nodes are waiting to join the job and if so admit them by calling
        :py:meth:`next_rendezvous()` (re-rendezvous).
        """

    @abstractmethod
    def get_run_id(self) -> str:
        """Return the run id of the rendezvous.

        The run id is a user-defined id that uniquely identifies an instance of
        a distributed application. It typically maps to a job id and is used to
        allow nodes to join the correct distributed application.
        """

    @abstractmethod
    def shutdown(self) -> bool:
        """Close all resources that were open for the rendezvous.

        Example::

            rdzv_handler = ...
            try:
                rdzv_info = rdzv_handler.next_rendezvous()
            finally:
                rdzv_handler.shutdown()
        """


class RendezvousParameters:
    """Hold the parameters to construct a :py:class:`RendezvousHandler`.

    Args:
        backend:
            The name of the backend to use to handle the rendezvous.
        endpoint:
            The endpoint of the rendezvous, usually in form <hostname>[:<port>].
        run_id:
            The id of the rendezvous.
        min_nodes:
            The minimum number of nodes to admit to the rendezvous.
        max_nodes:
            The maximum number of nodes to admit to the rendezvous.
        local_addr:
            The address of the local node.
        **kwargs:
            Additional parameters for the specified backend.

    Raises:
        ValueError:
            ``backend`` is empty, ``min_nodes`` is less than one, or
            ``max_nodes`` is less than ``min_nodes``.
    """

    def __init__(
        self,
        backend: str,
        endpoint: str,
        run_id: str,
        min_nodes: int,
        max_nodes: int,
        local_addr: Optional[str] = None,
        **kwargs,
    ):
        if not backend:
            raise ValueError("The rendezvous backend name must be a non-empty string.")

        if min_nodes < 1:
            raise ValueError(
                f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero."
            )
        if max_nodes < min_nodes:
            raise ValueError(
                f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or "
                f"equal to the minimum number of rendezvous nodes ({min_nodes})."
            )

        self.backend = backend
        self.endpoint = endpoint
        self.run_id = run_id
        self.min_nodes = min_nodes
        self.max_nodes = max_nodes
        # Backend-specific options, looked up through ``get``/``get_as_*``.
        self.config = kwargs
        self.local_addr = local_addr

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value for ``key`` if ``key`` exists, else ``default``."""
        return self.config.get(key, default)

    def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]:
        """Return the value for ``key`` as a ``bool``.

        Accepts ``bool`` values, the integers ``0``/``1``, and the
        case-insensitive strings ``1/true/t/yes/y`` and ``0/false/f/no/n``.
        ``None`` (missing key with no default) is returned unchanged.

        Raises:
            ValueError: The value cannot be interpreted as a boolean.
        """
        value = self.get(key, default)
        if value is None or isinstance(value, bool):
            return value
        if isinstance(value, int):
            if value == 1:
                return True
            if value == 0:
                return False
        elif isinstance(value, str):
            lowered = value.lower()
            if lowered in ["1", "true", "t", "yes", "y"]:
                return True
            if lowered in ["0", "false", "f", "no", "n"]:
                return False
        raise ValueError(
            f"The rendezvous configuration option '{key}' does not represent a valid boolean value."
        )

    def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
        """Return the value for ``key`` as an ``int``.

        ``None`` (missing key with no default) is returned unchanged.

        Raises:
            ValueError: The value cannot be converted with ``int()``.
        """
        value = self.get(key, default)
        if value is None:
            return value
        try:
            return int(value)
        except (TypeError, ValueError) as e:
            # ``int()`` raises TypeError for unsupported types (e.g. a list);
            # report it as ValueError so that both ``get_as_*`` accessors signal
            # an invalid option the same way.
            raise ValueError(
                f"The rendezvous configuration option '{key}' does not represent a valid integer "
                "value."
            ) from e


RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]


class RendezvousHandlerRegistry:
    """Represent a registry of :py:class:`RendezvousHandler` backends."""

    _registry: Dict[str, RendezvousHandlerCreator]

    def __init__(self) -> None:
        self._registry = {}

    def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
        """Register a new rendezvous backend.

        Registering the same ``creator`` again for a backend is a no-op.

        Args:
            backend:
                The name of the backend.
            creator:
                The callback to invoke to construct the
                :py:class:`RendezvousHandler`.

        Raises:
            ValueError:
                ``backend`` is empty, or it is already registered with a
                different creator.
        """
        if not backend:
            raise ValueError("The rendezvous backend name must be a non-empty string.")

        current_creator: Optional[RendezvousHandlerCreator] = self._registry.get(
            backend
        )

        if current_creator is not None and current_creator != creator:
            raise ValueError(
                f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
                f"is already registered with '{current_creator}'."
            )

        self._registry[backend] = creator

    def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
        """Create a new :py:class:`RendezvousHandler`.

        Args:
            params:
                The parameters passed to the creator registered for
                ``params.backend``.

        Raises:
            ValueError:
                No creator is registered for ``params.backend``.
            RuntimeError:
                The created handler reports a backend name different from
                ``params.backend``.
        """
        try:
            creator = self._registry[params.backend]
        except KeyError as e:
            raise ValueError(
                f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
                f"to call `{self.register.__name__}`?"
            ) from e

        handler = creator(params)

        # Do some sanity check.
        if handler.get_backend() != params.backend:
            raise RuntimeError(
                f"The rendezvous backend '{handler.get_backend()}' does not match the requested "
                f"backend '{params.backend}'."
            )

        return handler


# The default global registry instance used by launcher scripts to instantiate
# rendezvous handlers.
rendezvous_handler_registry = RendezvousHandlerRegistry()