xref: /aosp_15_r20/external/pytorch/torch/distributed/launcher/api.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9import sys
10import uuid
11from dataclasses import dataclass, field
12from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13
14import torch.distributed.elastic.rendezvous.registry as rdzv_registry
15from torch.distributed.elastic import events, metrics
16from torch.distributed.elastic.agent.server.api import WorkerSpec
17from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
18from torch.distributed.elastic.multiprocessing import (
19    DefaultLogsSpecs,
20    LogsSpecs,
21    SignalException,
22)
23from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
24from torch.distributed.elastic.rendezvous import RendezvousParameters
25from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
26from torch.distributed.elastic.utils.logging import get_logger
27
28
29__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"]
30
31logger = get_logger(__name__)
32
33
34@dataclass
35class LaunchConfig:
36    """
37    Creates a rendezvous config.
38
39    Args:
40        min_nodes: Minimum amount of nodes that the user function will
41                        be launched on. Elastic agent ensures that the user
42                        function start only when the min_nodes amount enters
43                        the rendezvous.
44        max_nodes: Maximum amount of nodes that the user function
45                        will be launched on.
46        nproc_per_node: On each node the elastic agent will launch
47                            this amount of workers that will execute user
48                            defined function.
49        rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd).
50        rdzv_endpoint: The endpoint of the rdzv sync. storage.
51        rdzv_configs: Key, value pair that specifies rendezvous specific configuration.
52        rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going
53            to be removed in future versions, see the note below. The default timeout is 900 seconds.
54        run_id: The unique run id of the job (if not passed a unique one will be
55                deduced from run environment - flow workflow id in flow - or auto generated).
56        role: User defined role of the worker (defaults to "trainer").
57        max_restarts: The maximum amount of restarts that elastic agent will conduct
58                    on workers before failure.
59        monitor_interval: The interval in seconds that is used by the elastic_agent
60                        as a period of monitoring workers.
61        start_method: The method is used by the elastic agent to start the
62                    workers (spawn, fork, forkserver).
63        metrics_cfg: configuration to initialize metrics.
64        local_addr: address of the local node if any. If not set, a lookup on the local
65                machine's FQDN will be performed.
66        local_ranks_filter: ranks for which to show logs in console. If not set, show from all.
67    ..note:
68        `rdzv_timeout` is a legacy argument that will be removed in future.
69        Set the timeout via `rdzv_configs['timeout']`
70
71    """
72
73    min_nodes: int
74    max_nodes: int
75    nproc_per_node: int
76    logs_specs: Optional[LogsSpecs] = None
77    run_id: str = ""
78    role: str = "default_role"
79    rdzv_endpoint: str = ""
80    rdzv_backend: str = "etcd"
81    rdzv_configs: Dict[str, Any] = field(default_factory=dict)
82    rdzv_timeout: int = -1
83    max_restarts: int = 3
84    monitor_interval: float = 0.1
85    start_method: str = "spawn"
86    log_line_prefix_template: Optional[str] = None
87    metrics_cfg: Dict[str, str] = field(default_factory=dict)
88    local_addr: Optional[str] = None
89
90    def __post_init__(self):
91        default_timeout = 900
92        if self.rdzv_timeout != -1:
93            self.rdzv_configs["timeout"] = self.rdzv_timeout
94        elif "timeout" not in self.rdzv_configs:
95            self.rdzv_configs["timeout"] = default_timeout
96
97        # Post-processing to enable refactoring to introduce logs_specs due to non-torchrun API usage
98        if self.logs_specs is None:
99            self.logs_specs = DefaultLogsSpecs()
100
101
class elastic_launch:
    """
    Launches a torchelastic agent on the container that invoked the entrypoint.

        1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters).
           ``entrypoint`` can be a function or a command.
        2. The return value is a map of each worker's output mapped
           by their respective global rank.

    Usage

    ::

        def worker_fn(foo):
            # ...

        def main():
            # entrypoint is a function.
            outputs = elastic_launch(LaunchConfig, worker_fn)(foo)
            # return rank 0's output
            return outputs[0]

            # entrypoint is a command and ``script.py`` is the python module.
            outputs = elastic_launch(LaunchConfig, "script.py")(args)
            outputs = elastic_launch(LaunchConfig, "python")("script.py")
    """

    def __init__(
        self,
        config: LaunchConfig,
        entrypoint: Union[Callable, str, None],
    ):
        # Nothing is launched here; the pair is only remembered for __call__.
        self._entrypoint = entrypoint
        self._config = config

    def __call__(self, *args):
        """Run ``launch_agent`` with the stored config/entrypoint and ``args`` as a list."""
        positional_args = [*args]
        return launch_agent(self._config, self._entrypoint, positional_args)
139
140
def _get_entrypoint_name(
    entrypoint: Union[Callable, str, None], args: List[Any]
) -> str:
    """Retrieve entrypoint name with the rule:
    1. If entrypoint is callable, use ``entrypoint.__name__``. Callables that
       have no ``__name__`` (for example ``functools.partial`` objects) fall
       back to the name of their type.
    2. If entrypoint is a string, check its value:
        2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first element from ``args``
            whose string form does not start with a hyphen (for example, "-u" will be skipped).
        2.2 otherwise, use ``entrypoint`` value.
    3. Otherwise, return empty string.

    Args:
        entrypoint: The function, command string, or ``None`` being launched.
        args: Arguments that will be passed to the entrypoint.

    Returns:
        The name to report for the entrypoint, or ``""`` if none can be derived.
    """
    if callable(entrypoint):
        # Not every callable carries __name__ (partials, callable instances),
        # so never let name lookup abort the launch.
        return getattr(entrypoint, "__name__", type(entrypoint).__name__)

    if isinstance(entrypoint, str):
        if entrypoint != sys.executable:
            return entrypoint
        for arg in args:
            # args is List[Any]: compare on the string form so that empty
            # strings and non-string values cannot raise while indexing.
            text = str(arg)
            if not text.startswith("-"):
                return text
        return ""

    return ""
161
162
def _get_addr_and_port(
    rdzv_parameters: RendezvousParameters,
) -> Tuple[Optional[str], Optional[int]]:
    """Derive the master address and port from the rendezvous endpoint.

    Only the ``"static"`` backend yields a value; every other backend gets
    ``(None, None)``.

    Args:
        rdzv_parameters: Rendezvous parameters whose ``backend`` and
            ``endpoint`` attributes are read.

    Returns:
        ``(master_addr, master_port)`` for the static backend, otherwise
        ``(None, None)``.

    Raises:
        ValueError: If the backend is static and the endpoint is blank or
            carries no port.
    """
    is_static_backend = rdzv_parameters.backend == "static"
    if not is_static_backend:
        return (None, None)

    stripped_endpoint = rdzv_parameters.endpoint.strip()
    if not stripped_endpoint:
        raise ValueError(
            "Endpoint is missing in endpoint. Try to add --master-addr and --master-port"
        )

    # -1 is the sentinel meaning "the endpoint string carried no port".
    host_and_port = parse_rendezvous_endpoint(stripped_endpoint, default_port=-1)
    master_addr, master_port = host_and_port
    if master_port == -1:
        raise ValueError(
            f"port is missing in endpoint: {stripped_endpoint}. Try to specify --master-port"
        )
    return (master_addr, master_port)
180
181
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    """Build a ``LocalElasticAgent`` from ``config``, run it, and return the result values.

    Args:
        config: Launch configuration. If ``config.run_id`` is empty a random
            one is generated and written back to ``config``.
        entrypoint: Function or command the workers execute.
        args: Arguments passed to ``entrypoint``.

    Returns:
        ``return_values`` of the agent's run result.

    Raises:
        ChildFailedError: If the run result reports failure.
        Exception: Anything raised while running the agent is re-raised after
            a failure event is recorded.
    """
    if not config.run_id:
        generated_run_id = str(uuid.uuid4().int)
        logger.warning(
            "config has no run_id, generated a random run_id: %s", generated_run_id
        )
        config.run_id = generated_run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    # Values for the %(name)s placeholders of the startup log message.
    launch_summary = {
        "entrypoint": entrypoint_name,
        "min_nodes": config.min_nodes,
        "max_nodes": config.max_nodes,
        "nproc_per_node": config.nproc_per_node,
        "run_id": config.run_id,
        "rdzv_backend": config.rdzv_backend,
        "rdzv_endpoint": config.rdzv_endpoint,
        "rdzv_configs": config.rdzv_configs,
        "max_restarts": config.max_restarts,
        "monitor_interval": config.monitor_interval,
        "log_dir": config.logs_specs.root_log_dir,  # type: ignore[union-attr]
        "metrics_cfg": config.metrics_cfg,
    }
    logger.info(
        "Starting elastic_operator with launch configs:\n"
        "  entrypoint       : %(entrypoint)s\n"
        "  min_nodes        : %(min_nodes)s\n"
        "  max_nodes        : %(max_nodes)s\n"
        "  nproc_per_node   : %(nproc_per_node)s\n"
        "  run_id           : %(run_id)s\n"
        "  rdzv_backend     : %(rdzv_backend)s\n"
        "  rdzv_endpoint    : %(rdzv_endpoint)s\n"
        "  rdzv_configs     : %(rdzv_configs)s\n"
        "  max_restarts     : %(max_restarts)s\n"
        "  monitor_interval : %(monitor_interval)s\n"
        "  log_dir          : %(log_dir)s\n"
        "  metrics_cfg      : %(metrics_cfg)s\n",
        launch_summary,
    )

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        local_addr=config.local_addr,
        **config.rdzv_configs,
    )

    # Endpoint validation happens before the rendezvous handler is created.
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)

    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=tuple(args),
        rdzv_handler=rdzv_handler,
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        master_addr=master_addr,
        master_port=master_port,
        local_addr=config.local_addr,
    )

    agent = LocalElasticAgent(
        spec=spec,
        logs_specs=config.logs_specs,  # type: ignore[arg-type]
        start_method=config.start_method,
        log_line_prefix_template=config.log_line_prefix_template,
    )

    keep_rdzv_open = False
    try:
        metrics_config = metrics.MetricsConfig(config.metrics_cfg)
        metrics.initialize_metrics(metrics_config)

        run_result = agent.run()
        # records that agent.run() has succeeded NOT that workers have succeeded
        events.record(agent.get_event_succeeded())

        if not run_result.is_failed():
            return run_result.return_values

        # ChildFailedError is treated specially by @record
        # if the error files for the failed children exist
        # @record will copy the first error (root cause)
        # to the error file of the launcher process.
        raise ChildFailedError(
            name=entrypoint_name,
            failures=run_result.failures,
        )
    except ChildFailedError:
        # Propagated as-is: no failure event is recorded for this case.
        raise
    except SignalException:
        # when the agent dies with a signal do NOT shutdown the rdzv_handler
        # since this closes the rendezvous on this rdzv_id permanently and
        # prevents any additional scaling events
        keep_rdzv_open = True
        events.record(agent.get_event_failed())
        raise
    except Exception:
        events.record(agent.get_event_failed())
        raise
    finally:
        if not keep_rdzv_open:
            spec.rdzv_handler.shutdown()
290