1#!/usr/bin/env python3 2# mypy: allow-untyped-defs 3 4# Copyright (c) Facebook, Inc. and its affiliates. 5# All rights reserved. 6# 7# This source code is licensed under the BSD-style license found in the 8# LICENSE file in the root directory of this source tree. 9import sys 10import uuid 11from dataclasses import dataclass, field 12from typing import Any, Callable, Dict, List, Optional, Tuple, Union 13 14import torch.distributed.elastic.rendezvous.registry as rdzv_registry 15from torch.distributed.elastic import events, metrics 16from torch.distributed.elastic.agent.server.api import WorkerSpec 17from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent 18from torch.distributed.elastic.multiprocessing import ( 19 DefaultLogsSpecs, 20 LogsSpecs, 21 SignalException, 22) 23from torch.distributed.elastic.multiprocessing.errors import ChildFailedError 24from torch.distributed.elastic.rendezvous import RendezvousParameters 25from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint 26from torch.distributed.elastic.utils.logging import get_logger 27 28 29__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"] 30 31logger = get_logger(__name__) 32 33 34@dataclass 35class LaunchConfig: 36 """ 37 Creates a rendezvous config. 38 39 Args: 40 min_nodes: Minimum amount of nodes that the user function will 41 be launched on. Elastic agent ensures that the user 42 function start only when the min_nodes amount enters 43 the rendezvous. 44 max_nodes: Maximum amount of nodes that the user function 45 will be launched on. 46 nproc_per_node: On each node the elastic agent will launch 47 this amount of workers that will execute user 48 defined function. 49 rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd). 50 rdzv_endpoint: The endpoint of the rdzv sync. storage. 51 rdzv_configs: Key, value pair that specifies rendezvous specific configuration. 52 rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going 53 to be removed in future versions, see the note below. The default timeout is 900 seconds. 54 run_id: The unique run id of the job (if not passed a unique one will be 55 deduced from run environment - flow workflow id in flow - or auto generated). 56 role: User defined role of the worker (defaults to "trainer"). 57 max_restarts: The maximum amount of restarts that elastic agent will conduct 58 on workers before failure. 59 monitor_interval: The interval in seconds that is used by the elastic_agent 60 as a period of monitoring workers. 61 start_method: The method is used by the elastic agent to start the 62 workers (spawn, fork, forkserver). 63 metrics_cfg: configuration to initialize metrics. 64 local_addr: address of the local node if any. If not set, a lookup on the local 65 machine's FQDN will be performed. 66 local_ranks_filter: ranks for which to show logs in console. If not set, show from all. 67 ..note: 68 `rdzv_timeout` is a legacy argument that will be removed in future. 69 Set the timeout via `rdzv_configs['timeout']` 70 71 """ 72 73 min_nodes: int 74 max_nodes: int 75 nproc_per_node: int 76 logs_specs: Optional[LogsSpecs] = None 77 run_id: str = "" 78 role: str = "default_role" 79 rdzv_endpoint: str = "" 80 rdzv_backend: str = "etcd" 81 rdzv_configs: Dict[str, Any] = field(default_factory=dict) 82 rdzv_timeout: int = -1 83 max_restarts: int = 3 84 monitor_interval: float = 0.1 85 start_method: str = "spawn" 86 log_line_prefix_template: Optional[str] = None 87 metrics_cfg: Dict[str, str] = field(default_factory=dict) 88 local_addr: Optional[str] = None 89 90 def __post_init__(self): 91 default_timeout = 900 92 if self.rdzv_timeout != -1: 93 self.rdzv_configs["timeout"] = self.rdzv_timeout 94 elif "timeout" not in self.rdzv_configs: 95 self.rdzv_configs["timeout"] = default_timeout 96 97 # Post-processing to enable refactoring to introduce logs_specs due to non-torchrun API usage 98 if self.logs_specs is None: 99 self.logs_specs = DefaultLogsSpecs() 100 101 102class elastic_launch: 103 """ 104 Launches an torchelastic agent on the container that invoked the entrypoint. 105 106 1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/ 107 ``entrypoint`` can be a function or a command. 108 2. The return value is a map of each worker's output mapped 109 by their respective global rank. 110 111 Usage 112 113 :: 114 115 def worker_fn(foo): 116 # ... 117 118 def main(): 119 # entrypoint is a function. 120 outputs = elastic_launch(LaunchConfig, worker_fn)(foo) 121 # return rank 0's output 122 return outputs[0] 123 124 # entrypoint is a command and ``script.py`` is the python module. 125 outputs = elastic_launch(LaunchConfig, "script.py")(args) 126 outputs = elastic_launch(LaunchConfig, "python")("script.py") 127 """ 128 129 def __init__( 130 self, 131 config: LaunchConfig, 132 entrypoint: Union[Callable, str, None], 133 ): 134 self._config = config 135 self._entrypoint = entrypoint 136 137 def __call__(self, *args): 138 return launch_agent(self._config, self._entrypoint, list(args)) 139 140 141def _get_entrypoint_name( 142 entrypoint: Union[Callable, str, None], args: List[Any] 143) -> str: 144 """Retrieve entrypoint name with the rule: 145 1. If entrypoint is a function, use ``entrypoint.__qualname__``. 146 2. If entrypoint is a string, check its value: 147 2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first element from ``args`` 148 which does not start with hifen letter (for example, "-u" will be skipped). 149 2.2 otherwise, use ``entrypoint`` value. 150 3. Otherwise, return empty string. 151 """ 152 if isinstance(entrypoint, Callable): # type: ignore[arg-type] 153 return entrypoint.__name__ # type: ignore[union-attr] 154 elif isinstance(entrypoint, str): 155 if entrypoint == sys.executable: 156 return next((arg for arg in args if arg[0] != "-"), "") 157 else: 158 return entrypoint 159 else: 160 return "" 161 162 163def _get_addr_and_port( 164 rdzv_parameters: RendezvousParameters, 165) -> Tuple[Optional[str], Optional[int]]: 166 if rdzv_parameters.backend != "static": 167 return (None, None) 168 endpoint = rdzv_parameters.endpoint 169 endpoint = endpoint.strip() 170 if not endpoint: 171 raise ValueError( 172 "Endpoint is missing in endpoint. Try to add --master-addr and --master-port" 173 ) 174 master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1) 175 if master_port == -1: 176 raise ValueError( 177 f"port is missing in endpoint: {endpoint}. Try to specify --master-port" 178 ) 179 return (master_addr, master_port) 180 181 182def launch_agent( 183 config: LaunchConfig, 184 entrypoint: Union[Callable, str, None], 185 args: List[Any], 186) -> Dict[int, Any]: 187 if not config.run_id: 188 run_id = str(uuid.uuid4().int) 189 logger.warning("config has no run_id, generated a random run_id: %s", run_id) 190 config.run_id = run_id 191 192 entrypoint_name = _get_entrypoint_name(entrypoint, args) 193 194 logger.info( 195 "Starting elastic_operator with launch configs:\n" 196 " entrypoint : %(entrypoint)s\n" 197 " min_nodes : %(min_nodes)s\n" 198 " max_nodes : %(max_nodes)s\n" 199 " nproc_per_node : %(nproc_per_node)s\n" 200 " run_id : %(run_id)s\n" 201 " rdzv_backend : %(rdzv_backend)s\n" 202 " rdzv_endpoint : %(rdzv_endpoint)s\n" 203 " rdzv_configs : %(rdzv_configs)s\n" 204 " max_restarts : %(max_restarts)s\n" 205 " monitor_interval : %(monitor_interval)s\n" 206 " log_dir : %(log_dir)s\n" 207 " metrics_cfg : %(metrics_cfg)s\n", 208 { 209 "entrypoint": entrypoint_name, 210 "min_nodes": config.min_nodes, 211 "max_nodes": config.max_nodes, 212 "nproc_per_node": config.nproc_per_node, 213 "run_id": config.run_id, 214 "rdzv_backend": config.rdzv_backend, 215 "rdzv_endpoint": config.rdzv_endpoint, 216 "rdzv_configs": config.rdzv_configs, 217 "max_restarts": config.max_restarts, 218 "monitor_interval": config.monitor_interval, 219 "log_dir": config.logs_specs.root_log_dir, # type: ignore[union-attr] 220 "metrics_cfg": config.metrics_cfg, 221 }, 222 ) 223 224 rdzv_parameters = RendezvousParameters( 225 backend=config.rdzv_backend, 226 endpoint=config.rdzv_endpoint, 227 run_id=config.run_id, 228 min_nodes=config.min_nodes, 229 max_nodes=config.max_nodes, 230 local_addr=config.local_addr, 231 **config.rdzv_configs, 232 ) 233 234 master_addr, master_port = _get_addr_and_port(rdzv_parameters) 235 236 spec = WorkerSpec( 237 role=config.role, 238 local_world_size=config.nproc_per_node, 239 entrypoint=entrypoint, 240 args=tuple(args), 241 rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters), 242 max_restarts=config.max_restarts, 243 monitor_interval=config.monitor_interval, 244 master_addr=master_addr, 245 master_port=master_port, 246 local_addr=config.local_addr, 247 ) 248 249 agent = LocalElasticAgent( 250 spec=spec, 251 logs_specs=config.logs_specs, # type: ignore[arg-type] 252 start_method=config.start_method, 253 log_line_prefix_template=config.log_line_prefix_template, 254 ) 255 256 shutdown_rdzv = True 257 try: 258 metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg)) 259 260 result = agent.run() 261 # records that agent.run() has succeeded NOT that workers have succeeded 262 events.record(agent.get_event_succeeded()) 263 264 if result.is_failed(): 265 # ChildFailedError is treated specially by @record 266 # if the error files for the failed children exist 267 # @record will copy the first error (root cause) 268 # to the error file of the launcher process. 269 raise ChildFailedError( 270 name=entrypoint_name, 271 failures=result.failures, 272 ) 273 274 return result.return_values 275 except ChildFailedError: 276 raise 277 except SignalException: 278 # when the agent dies with a signal do NOT shutdown the rdzv_handler 279 # since this closes the rendezvous on this rdzv_id permanently and 280 # prevents any additional scaling events 281 shutdown_rdzv = False 282 events.record(agent.get_event_failed()) 283 raise 284 except Exception: 285 events.record(agent.get_event_failed()) 286 raise 287 finally: 288 if shutdown_rdzv: 289 spec.rdzv_handler.shutdown() 290