# mypy: allow-untyped-defs
try:
    from urllib.parse import urlparse, urlunparse
except ImportError as e:
    raise ImportError(
        "urllib cannot be found, urlparse from python2 is no longer supported."
    ) from e

import numbers
import os
import sys
from datetime import timedelta
from typing import Callable, Dict, Iterator, Optional, Tuple

from torch.distributed import FileStore, PrefixStore, Store, TCPStore

from .constants import default_pg_timeout


# Registry mapping a URL scheme (e.g. "tcp") to the generator function that
# performs rendezvous for that scheme. Filled via ``register_rendezvous_handler``.
_rendezvous_handlers: Dict[str, Callable[..., Iterator[Tuple[Store, int, int]]]] = {}

__all__ = ["register_rendezvous_handler", "rendezvous"]


def register_rendezvous_handler(scheme, handler):
    """
    Register a new rendezvous handler.

    Before we can run collective algorithms, participating processes
    need to find each other and exchange information to be able to
    communicate. We call this process rendezvous.

    The outcome of the rendezvous process is a triplet containing a
    shared key/value store, the rank of the process, and the total
    number of participating processes.

    If none of the bundled rendezvous methods apply to your execution
    environment you can opt to register your own rendezvous handler.
    Pick a unique name and use the URL scheme to identify it when
    calling the `rendezvous()` function.

    Args:
        scheme (str): URL scheme to identify your rendezvous handler.
        handler (function): Handler that is invoked when the
            `rendezvous()` function is called with a URL that uses
            the corresponding scheme. It must be a generator function
            that yields the triplet.

    Raises:
        RuntimeError: If a handler is already registered for ``scheme``.
    """
    # The registry dict is only mutated (never rebound), so no ``global``
    # declaration is required.
    if scheme in _rendezvous_handlers:
        raise RuntimeError(f"Rendezvous handler for {scheme}:// already registered")
    _rendezvous_handlers[scheme] = handler


# Query will have format "rank=0&world_size=1" and is
# converted into {"rank": "0", "world_size": "1"} (values stay strings).
def _query_to_dict(query: str) -> Dict[str, str]:
    """
    Parse a URL query string of ``key=value`` pairs joined by ``&``.

    Empty segments (e.g. from ``"a=1&&b=2"`` or a trailing ``&``) are skipped.
    Each pair is split on the FIRST ``=`` only, so a value that itself
    contains ``=`` is preserved intact instead of being truncated. Values are
    not percent-decoded. If a key repeats, the last occurrence wins.

    Raises:
        IndexError: If a non-empty segment contains no ``=``.
    """
    return {
        pair[0]: pair[1]
        for pair in (pair.split("=", 1) for pair in filter(None, query.split("&")))
    }


def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool:
    """
    Decide whether the TCPStore should use the libuv backend.

    The ``use_libuv`` query parameter takes precedence; when absent, the
    ``USE_LIBUV`` environment variable is consulted; when that is unset too,
    the default is ``"1"``. Only the exact string ``"1"`` enables libuv.
    """
    # libuv is the default backend for TCPStore. To enable the non-libuv backend,
    # user can explicitly specify ``use_libuv=0`` in the URL parameter.
    return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "1")) == "1"


def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs):
    """
    Inject ``rank``/``world_size`` into ``url`` and dispatch to its scheme handler.

    Args:
        url: Rendezvous URL; its scheme selects the registered handler.
        rank: Rank of this process, or ``-1`` if unknown. When
            ``world_size_opt`` is ``None`` and the scheme is ``env``, it is
            overridden by the ``RANK`` environment variable if set.
        world_size_opt: World size, ``-1`` if unknown, or ``None`` to have it
            resolved here (from ``WORLD_SIZE`` for ``env://``, otherwise
            ``-1``) and always written into the URL.
        **kwargs: Forwarded verbatim to the handler (e.g. ``timeout``).

    Returns:
        Whatever the handler returns; for the bundled handlers, a generator
        yielding ``(store, rank, world_size)``.

    Raises:
        AssertionError: If the URL already carries ``rank`` or ``world_size``
            while this function needs to add them.
        RuntimeError: If no handler is registered for the URL's scheme.
    """
    result = urlparse(url)
    if world_size_opt is None:
        world_size = -1
        if result.scheme == "env":
            rank = int(os.environ.get("RANK", rank))
            # If the world_size env variable is not present then it is a dynamic group
            world_size = int(os.environ.get("WORLD_SIZE", world_size))
    else:
        world_size = world_size_opt
    if rank != -1 or world_size != -1 or world_size_opt is None:
        query_dict = _query_to_dict(result.query)
        # NOTE: kept as ``assert`` because callers/tests may rely on the
        # AssertionError type; it is stripped under ``python -O``.
        assert (
            "rank" not in query_dict and "world_size" not in query_dict
        ), f"The url: {url} has node-specific arguments(rank, world_size) already."
        if rank != -1:
            query_dict["rank"] = str(rank)
        if world_size != -1 or world_size_opt is None:
            query_dict["world_size"] = str(world_size)
        result = result._replace(
            query="&".join(f"{k}={v}" for k, v in query_dict.items())
        )
        url = urlunparse(result)

    if result.scheme not in _rendezvous_handlers:
        raise RuntimeError(f"No rendezvous handler for {result.scheme}://")
    return _rendezvous_handlers[result.scheme](url, **kwargs)


def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
    """
    Perform rendezvous for ``url`` and return the scheme handler's result.

    Args:
        url: Rendezvous URL (e.g. ``tcp://host:port``, ``env://``,
            ``file:///path``); its scheme selects the registered handler.
        rank: Rank of this process; ``-1`` means "not specified".
        world_size: Number of participating processes; ``-1`` means
            "not specified".
        **kwargs: Forwarded to the handler (the bundled ``tcp``/``env``
            handlers accept ``timeout``).

    Returns:
        The handler's return value; for the bundled handlers, a generator
        yielding ``(store, rank, world_size)`` (advance it with ``next()``).

    Raises:
        RuntimeError: If an argument has the wrong type or no handler is
            registered for the URL's scheme.
    """
    # NOTE(review): bytes passes this check, but urlparse() of bytes yields
    # bytes components, which will not match the str-keyed handler registry;
    # confirm whether bytes URLs are actually meant to be supported.
    if not isinstance(url, (str, bytes)):
        raise RuntimeError(f"`url` must be a string. {type(url)}: {url}")

    if not isinstance(rank, numbers.Integral):
        raise RuntimeError(f"`rank` must be an integer. {rank}")

    if not isinstance(world_size, numbers.Integral):
        raise RuntimeError(f"`world_size` must be an integer. {world_size}")

    return _rendezvous_helper(url, rank, world_size, **kwargs)


def _create_store_from_options(backend_options, rank):
    """
    Create a store from ``backend_options.init_method`` for ``rank``.

    Passing ``None`` as the world size lets ``_rendezvous_helper`` resolve it
    (from ``WORLD_SIZE`` for ``env://``, otherwise ``-1``). Only the store from
    the first rendezvous result is returned; rank and world size are dropped.
    """
    store, _, _ = next(_rendezvous_helper(backend_options.init_method, rank, None))
    return store


def _rendezvous_error(msg):
    """Build (not raise) a ValueError with a common rendezvous error prefix."""
    return ValueError("Error initializing torch.distributed using " + msg)


def _file_rendezvous_handler(url: str, **kwargs):
    """
    Rendezvous handler for ``file://`` URLs, backed by a ``FileStore``.

    The URL must carry ``rank`` and ``world_size`` query parameters. This is a
    generator: validation errors surface on the first ``next()``, which then
    yields ``(store, rank, world_size)`` once. Advancing it again raises
    ``RuntimeError`` because file-based re-rendezvous is unsupported.

    Raises:
        ValueError: If the path, ``rank`` or ``world_size`` is missing, or
            ``rank``/``world_size`` is not an integer.
        RuntimeError: On a second ``next()`` (re-rendezvous).
    """

    def _error(msg):
        return _rendezvous_error("file:// rendezvous: " + msg)

    result = urlparse(url)
    path = result.path
    if sys.platform == "win32":
        import urllib.request

        # Rebuild from netloc + path (presumably so a host/drive part that
        # urlparse put into ``netloc`` is kept) before converting to a
        # native Windows path.
        full_path = result.netloc + result.path
        path = urllib.request.url2pathname(full_path)
        if path:
            # Normalizing an empty string produces ".", which is not expected.
            path = os.path.normpath(path)

    if not path:
        raise _error("path missing")
    query_dict = _query_to_dict(result.query)
    if "rank" not in query_dict:
        raise _error("rank parameter missing")
    if "world_size" not in query_dict:
        raise _error("world size parameter missing")

    rank = int(query_dict["rank"])
    world_size = int(query_dict["world_size"])
    store = FileStore(path, world_size)
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using file:// method")


def _torchelastic_use_agent_store() -> bool:
    """Return True only if TORCHELASTIC_USE_AGENT_STORE is exactly ``"True"``."""
    return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True)


def _create_c10d_store(
    hostname, port, rank, world_size, timeout, use_libuv=True
) -> Store:
    """
    Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store.

    The TCPStore server is assumed to be hosted
    on ``hostname:port``.

    By default, the TCPStore server uses the asynchronous implementation
    ``LibUVStoreDaemon`` which utilizes libuv.

    If ``torchelastic_use_agent_store()`` is ``True``, then it is assumed that
    the agent leader (node rank 0) hosts the TCPStore server (for which the
    endpoint is specified by the given ``hostname:port``). Hence
    ALL ranks will create and return a TCPStore client (i.e. ``start_daemon=False``),
    wrapped in a ``PrefixStore`` scoped to the current restart attempt.

    If ``torchelastic_use_agent_store()`` is ``False``, then rank 0 will host
    the TCPStore (with multi-tenancy) and it is assumed that rank 0's hostname
    and port are correctly passed via ``hostname`` and ``port``. All
    non-zero ranks will create and return a TCPStore client.

    Raises:
        ValueError: If ``port`` is outside ``[0, 65535]``.
        KeyError: In agent-store mode, if ``TORCHELASTIC_RESTART_COUNT`` is unset.
    """
    # check if port is uint16_t
    if not 0 <= port < 2**16:
        raise ValueError(f"port must have value from 0 to 65535 but was {port}.")

    if _torchelastic_use_agent_store():
        attempt = os.environ["TORCHELASTIC_RESTART_COUNT"]
        # NOTE(review): ``use_libuv`` is not forwarded for this client-only
        # store; presumably it only matters for the server side. Confirm.
        tcp_store = TCPStore(hostname, port, world_size, False, timeout)
        # Prefix keys per restart attempt so a new attempt does not see keys
        # left in the shared agent store by a previous one.
        return PrefixStore(f"/worker/attempt_{attempt}", tcp_store)
    else:
        start_daemon = rank == 0
        return TCPStore(
            hostname,
            port,
            world_size,
            start_daemon,
            timeout,
            multi_tenant=True,
            use_libuv=use_libuv,
        )


def _tcp_rendezvous_handler(
    url: str, timeout: timedelta = default_pg_timeout, **kwargs
):
    """
    Rendezvous handler for ``tcp://host:port`` URLs, backed by a TCPStore.

    The URL must carry ``rank`` and ``world_size`` query parameters and may
    carry ``use_libuv``. This is a generator: validation errors surface on the
    first ``next()``, which then yields ``(store, rank, world_size)`` once.
    Advancing it again raises ``RuntimeError`` (re-rendezvous unsupported).

    Raises:
        ValueError: If the port, hostname, ``rank`` or ``world_size`` is
            missing, or ``rank``/``world_size`` is not an integer.
        RuntimeError: On a second ``next()`` (re-rendezvous).
    """

    def _error(msg):
        return _rendezvous_error("tcp:// rendezvous: " + msg)

    result = urlparse(url)
    if not result.port:
        raise _error("port number missing")
    query_dict = _query_to_dict(result.query)
    if "rank" not in query_dict:
        raise _error("rank parameter missing")
    if "world_size" not in query_dict:
        raise _error("world size parameter missing")

    rank = int(query_dict["rank"])
    world_size = int(query_dict["world_size"])
    use_libuv = _get_use_libuv_from_query_dict(query_dict)

    # Explicit check instead of ``assert``: it survives ``python -O`` and
    # reports the problem the same way as the other validation errors above.
    if result.hostname is None:
        raise _error("hostname missing")

    store = _create_c10d_store(
        result.hostname, result.port, rank, world_size, timeout, use_libuv
    )

    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform re-rendezvous using tcp:// method")


def _env_rendezvous_handler(
    url: str, timeout: timedelta = default_pg_timeout, **kwargs
):
    """
    Rendezvous handler for ``env://`` URLs, configured from the environment.

    ``rank``/``world_size`` come from the URL query if present, otherwise from
    the ``RANK``/``WORLD_SIZE`` environment variables. ``MASTER_ADDR`` and
    ``MASTER_PORT`` are always read from the environment. An empty environment
    value is treated the same as an unset one. This is a generator that
    yields ``(store, rank, world_size)`` once; advancing it again raises
    ``RuntimeError`` (re-rendezvous unsupported).

    Raises:
        ValueError: If a required environment variable is missing/empty or a
            numeric value cannot be parsed.
        RuntimeError: On a second ``next()`` (re-rendezvous).
    """

    def _error(msg):
        return _rendezvous_error("env:// rendezvous: " + msg)

    def _env_error(var):
        return _error(f"environment variable {var} expected, but not set")

    def _get_env_or_raise(env_var: str) -> str:
        # ``not env_val`` also rejects an empty string, not just a missing var.
        env_val = os.environ.get(env_var, None)
        if not env_val:
            raise _env_error(env_var)
        else:
            return env_val

    result = urlparse(url)
    query_dict = _query_to_dict(result.query)

    rank: int
    world_size: int
    master_port: int
    master_addr: str

    if "rank" in query_dict:
        rank = int(query_dict["rank"])
    else:
        rank = int(_get_env_or_raise("RANK"))

    if "world_size" in query_dict:
        world_size = int(query_dict["world_size"])
    else:
        world_size = int(_get_env_or_raise("WORLD_SIZE"))

    master_addr = _get_env_or_raise("MASTER_ADDR")
    master_port = int(_get_env_or_raise("MASTER_PORT"))
    use_libuv = _get_use_libuv_from_query_dict(query_dict)

    store = _create_c10d_store(
        master_addr, master_port, rank, world_size, timeout, use_libuv
    )

    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform re-rendezvous using env:// method")


# Register the bundled handlers at import time.
register_rendezvous_handler("tcp", _tcp_rendezvous_handler)
register_rendezvous_handler("env", _env_rendezvous_handler)
register_rendezvous_handler("file", _file_rendezvous_handler)