xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9
10"""
Each host in a distributed PyTorch job runs a single TorchElastic agent
and multiple workers (as child processes of the TorchElastic agent).
13Since the workers are user-provided (your PyTorch script/job), TorchElastic
14has a way to propagate errors on the trainers through the agent and up to the
15scheduler, which ultimately informs the end-user about the state of the job
16and applies any retry policies.
17
18TorchElastic categorizes errors into 3 categories:
19
20+----------------+----------------+--------------------------------------------------------------+
21| Category       | Sub-Category   |  Description                                                 |
22+================+================+==============================================================+
23| User Error     | Input Error    | invalid inputs to TorchElastic APIs (e.g. min > max nodes)   |
24|                +----------------+--------------------------------------------------------------+
25|                | Worker Failure | any failures on the worker child process                     |
26+----------------+----------------+--------------------------------------------------------------+
27| Platform Error |      n/a       | failures caused by the agent                                 |
28+----------------+----------------+--------------------------------------------------------------+
29| Infra Error    |      n/a       | failures outside the domain of the agent and workers         |
30|                |                | (e.g. host failures)                                         |
31+----------------+----------------+--------------------------------------------------------------+
32
All errors other than "Worker Failure" are either raised canonically from the
agent process or implicitly or explicitly crash the agent process. Therefore the
standard Python exception-handling strategies apply.
36
Worker Failures are special because the exception/failure originates in a different
process from the agent, so the error needs to be propagated inter-process
(i.e. the agent cannot simply ``try-catch`` an exception raised in the worker process).
40
TorchElastic agents use :func:`torch.distributed.elastic.multiprocessing.start_processes`
to launch the workers; it has simple, file-based inter-process error propagation
built in.
44
45Any function or binary entrypoint decorated with :func:`record`
46will write uncaught exceptions (with the trace information) to a file specified by the
47environment variable ``TORCHELASTIC_ERROR_FILE``. The parent process (e.g. agent)
sets this env var on each child it launches, then aggregates the error files for all
children, and propagates the one with the **smallest** timestamp (i.e. the **first** error).
50"""
51
52import json
53import os
54import signal
55import socket
56import time
57import warnings
58from dataclasses import dataclass, field
59from datetime import datetime
60from functools import wraps
61from string import Template
62from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
63
64from torch.distributed.elastic.utils.logging import get_logger
65
66from .error_handler import ErrorHandler  # noqa: F401
67from .handlers import get_error_handler  # noqa: F401
68
69
__all__ = [
    "ProcessFailure",
    "ChildFailedError",
    "record",
    "ErrorHandler",
    "get_error_handler",
]

# Module-level logger used by the failure/report helpers below.
logger = get_logger(__name__)

# Type variable for the return value of functions wrapped by ``record``.
T = TypeVar("T")

# Alias used to annotate decoded JSON objects (e.g. error-file contents).
JSON = Dict

# Placeholder for values that cannot be determined (e.g. a missing error file
# or an unknown signal name).
_NOT_AVAILABLE = "<N/A>"
# Default error data for a ProcessFailure whose error file does not exist.
_EMPTY_ERROR_DATA = {"message": "<NONE>"}
87
88
@dataclass
class ProcessFailure:
    """
    Represent the failed process result. When the worker process fails, it may record failure root cause into the file.

    Tries to read the failure timestamp from the provided ``error_file``,
    if the ``error_file`` does not exist, the timestamp is the current
    timestamp (seconds since epoch).

    The ``message`` field is a concise explanation of the failure. If
    the error file exists then the message is obtained from the error file.
    Otherwise one is generated based on the failure signature.

    Attributes:
        local_rank: local rank of the failed worker.
        pid: process id of the failed worker.
        exitcode: exit code of the worker; a negative value ``-N`` is treated
            as termination by signal ``N``.
        error_file: path of the error file; replaced with ``"<N/A>"`` when the
            file does not exist.
        error_file_data: decoded JSON contents of ``error_file``, or a private
            copy of ``{"message": "<NONE>"}`` when there is no error file.
        message: a ``str``, or -- for error files using the nested layout --
            the ``dict`` stored under the file's ``"message"`` key.
        timestamp: failure time in seconds since the epoch.

    Raises:
        Exception: any error raised while reading or decoding an existing
            ``error_file`` is logged and re-raised from ``__init__``.

    .. note:: It is assumed that the ``error_file`` is written by
              ``torch.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler``.
              Otherwise the behavior is undefined.

    """

    local_rank: int
    pid: int
    exitcode: int
    error_file: str
    error_file_data: JSON = field(init=False)
    # NOTE: may actually hold a dict (nested error-file layout); see _get_error_data.
    message: str = field(init=False)
    timestamp: int = field(init=False)

    def __post_init__(self):
        # Copy the default so that mutating one failure's data can never leak
        # into the module-level dict shared by every other instance.
        self.error_file_data = dict(_EMPTY_ERROR_DATA)
        if os.path.isfile(self.error_file):
            try:
                # Only the read needs the file handle; close it before parsing.
                with open(self.error_file) as fp:
                    self.error_file_data = json.load(fp)
                logger.debug(
                    "User process failed with error data: %s",
                    json.dumps(self.error_file_data, indent=2),
                )
                self.message, self.timestamp = self._get_error_data(
                    self.error_file_data
                )
            except Exception:
                logger.exception("Failed to parse reply file: %s", self.error_file)
                raise
        else:
            self._set_no_reply_file()

        # make up an informative message if not already present
        if not self.message:
            # signals typically do not generate an error file message
            if self.exitcode < 0:
                self.message = (
                    f"Signal {-self.exitcode} ({self.signal_name()})"
                    f" received by PID {self.pid}"
                )
            else:
                self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html"

    def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]:
        """
        Extract ``(message, timestamp)`` from decoded error-file data.

        Two layouts are accepted:

        * ``{"message": "<str>", "timestamp": <ts>}`` -- ``timestamp`` is
          optional and defaults to 0.
        * ``{"message": {..., "extraInfo": {"timestamp": <ts>}}}`` -- the
          nested ``dict`` itself is returned as the message.

        Errors (e.g. ``KeyError``) propagate when the required keys are missing.
        """
        message = error_file_data["message"]
        if isinstance(message, str):
            timestamp = int(error_file_data.get("timestamp", 0))
        else:
            timestamp = int(message["extraInfo"]["timestamp"])
        return (message, timestamp)

    def _set_no_reply_file(self):
        """Reset the error-file fields for a process that left no error file."""
        self.error_file = _NOT_AVAILABLE
        self.error_file_data = dict(_EMPTY_ERROR_DATA)
        self.message = ""
        self.timestamp = int(time.time())

    def signal_name(self) -> str:
        """Return the signal name for a negative ``exitcode``, else ``"<N/A>"``."""
        if self.exitcode < 0:
            # We don't want to kill the parent process trying to find the signal name.
            # if the signal doesn't map to a known name, use not available.
            try:
                return signal.Signals(-self.exitcode).name
            except Exception:
                return _NOT_AVAILABLE
        else:
            return _NOT_AVAILABLE

    def timestamp_isoformat(self):
        """Return timestamp in ISO format (YYYY-MM-DD_HH:MM:SS), in local time."""
        return datetime.fromtimestamp(self.timestamp).isoformat(sep="_")
174
175
# A worker's global rank; the key type of ``ChildFailedError.failures``.
GlobalRank = int

# Per-failure entry of a ChildFailedError report (``string.Template`` syntax).
# Note: no trailing newline; entries are joined by the caller.
_FAILURE_FORMAT_TEMPLATE = "\n".join(
    (
        "[${idx}]:",
        "  time      : ${time}",
        "  host      : ${hostname}",
        "  rank      : ${rank} (local_rank: ${local_rank})",
        "  exitcode  : ${exitcode} (pid: ${pid})",
        "  error_file: ${error_file}",
        "  traceback : ${message}",
    )
)

# Full ChildFailedError report (``string.Template`` syntax). The leading empty
# line is intentional (presumably so the report starts on its own line after
# the exception prefix); there is no trailing newline.
_MSG_FORMAT_TEMPLATE = "\n" + "\n".join(
    (
        "${boarder}",
        "${title}",
        "${section}",
        "Failures:",
        "${other_failures}",
        "${section}",
        "Root Cause (first observed failure):",
        "${root_failure}",
        "${boarder}",
    )
)
197
198
class ChildFailedError(Exception):
    """
    Special exception type that can be raised from a function annotated with the
    ``@record`` decorator to have the child process' (root exception) propagate
    up the stack as-is (i.e. without being wrapped in the parent's traceback).

    Useful in cases where the parent is a simple nanny process
    and the child (worker) processes are actually doing meaningful compute.
    In this case, errors typically occur on the child process as the parent
    is not doing anything non-trivial, and child errors should be propagated
    to the scheduler for accurate root cause diagnostics.

    .. note:: The propagation relies on error files rather than exception handling to
              support both function and binary launches.

    Example:
    ::

     # process tree on a host (container)
     0: scheduler-init-process:
                |- 1: torchelastic_agent:
                         |- 2: trainer_0 (ok)
                         |- 3: trainer_1 (fail) -> error.json
                         |- ...
                         |- n+2: trainer_n (ok)
                |- n+3: other processes
                |- ...

    In the example above, trainer 1's failure (written into error.json) is
    the root cause and should be reported to the scheduler's init process.
    Upon detecting trainer 1's failure, the torchelastic agent raises a
    ``ChildFailedError("trainer", {1: failure})``, where ``failure`` is the
    :class:`ProcessFailure` of trainer 1 (whose ``error_file`` is
    ``trainer_1/error.json``). This propagates the contents of trainer 1's
    error file to the scheduler's init process.

    Args:
        name: name shown in the report title (``"<name> FAILED"``).
        failures: mapping of global rank to :class:`ProcessFailure`; must not
            be empty.

    Raises:
        ValueError: if ``failures`` is empty.
    """

    def __init__(self, name: str, failures: Dict[GlobalRank, ProcessFailure]):
        self.name = name
        self.failures = failures
        # It does not make sense to create a ChildFailedError with no failures.
        # Checked explicitly (not with ``assert``) so it also holds under ``python -O``.
        if not self.failures:
            raise ValueError("ChildFailedError requires at least one failure")
        super().__init__(self.format_msg())

    def get_first_failure(self) -> Tuple[GlobalRank, ProcessFailure]:
        """
        Return ``(rank, failure)`` for the failure with the smallest timestamp.

        Ties go to the rank that appears first in ``failures``.
        """
        return min(self.failures.items(), key=lambda item: item[1].timestamp)

    def format_msg(self, boarder_delim="=", section_delim="-"):
        """
        Render the multi-line report used as this exception's message.

        The root cause (see ``get_first_failure``) gets its own section; all
        other failures are listed under "Failures:". Border and section lines
        are as wide as the widest rendered line, capped at 60 repetitions.

        Args:
            boarder_delim: string repeated to draw the top and bottom borders.
            section_delim: string repeated to draw the section separators.
        """
        title = f"{self.name} FAILED"
        root_rank, _ = self.get_first_failure()
        # The host is that of the process building the report, so it is the same
        # for every entry; resolve it once since getfqdn() may do a DNS lookup.
        hostname = socket.getfqdn()

        root_failure_fmt: str = ""
        other_failures_fmt: List[str] = []
        width = len(title)
        for idx, (rank, failure) in enumerate(self.failures.items()):
            fmt, w = self._format_failure(idx, rank, failure, hostname=hostname)
            width = max(width, w)
            if rank == root_rank:
                root_failure_fmt = fmt
            else:
                other_failures_fmt.append(fmt)

        # upper boundary on width
        width = min(width, 60)

        return Template(_MSG_FORMAT_TEMPLATE).substitute(
            boarder=boarder_delim * width,
            title=title,
            section=section_delim * width,
            root_failure=root_failure_fmt,
            other_failures="\n".join(other_failures_fmt or ["  <NO_OTHER_FAILURES>"]),
        )

    def _format_failure(
        self,
        idx: int,
        rank: int,
        failure: ProcessFailure,
        hostname: Optional[str] = None,
    ) -> Tuple[str, int]:
        """
        Render one failure entry with ``_FAILURE_FORMAT_TEMPLATE``.

        Returns the rendered text and the length of its longest line.
        ``hostname`` defaults to ``socket.getfqdn()`` when not supplied.
        """
        # failure.message is either a str (when the failure does not generate a traceback - e.g. signals)
        # or a dict (json) of the form
        # {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}}
        # so the display logic is:
        # 1. if failure.message is not a dict (it is a str) just show it as is
        # 2. else try to get the traceback (py_callstack)
        # 3.      if the traceback is not there, use the message
        # 4.      if the message  is not there show <N/A>
        msg = failure.message
        if isinstance(msg, dict):
            msg = msg.get("extraInfo", {}).get(
                "py_callstack", msg.get("message", "<N/A>")
            )
        if isinstance(msg, str):
            # indent continuation lines so multi-line messages/tracebacks stay
            # visually inside this entry
            msg = msg.replace("\n", "\n  ")

        if hostname is None:
            hostname = socket.getfqdn()

        fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute(
            idx=idx,
            time=failure.timestamp_isoformat(),
            hostname=hostname,
            rank=rank,
            local_rank=failure.local_rank,
            exitcode=failure.exitcode,
            pid=failure.pid,
            error_file=failure.error_file,
            message=msg,
        )
        width = max(len(line) for line in fmt.split("\n"))
        return fmt, width
306
307
def record(
    fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None
) -> Callable[..., T]:
    """
    Syntactic sugar to record errors/exceptions that happened in the decorated
    function using the provided ``error_handler``.

    Using this decorator is equivalent to:

    ::

     error_handler = get_error_handler()
     error_handler.initialize()
     try:
        foobar()
     except ChildFailedError as e:
        _, failure = e.get_first_failure()
        error_handler.dump_error_file(failure.error_file, failure.exitcode)
        raise
     except Exception as e:
        error_handler.record_exception(e)
        raise

    Additionally, a ``SystemExit`` with code ``0`` is swallowed (the wrapper
    returns ``None``), and a ``ChildFailedError`` whose root failure has no
    error file is only logged before being re-raised.

    .. important:: use this decorator once per process at the top level method,
                   typically this is the main method.

    Example

    ::

     @record
     def main():
         pass

     if __name__=="__main__":
        main()

    Args:
        fn: the function to wrap.
        error_handler: handler used to record failures; defaults to
            ``get_error_handler()`` (resolved once, at decoration time).

    Returns:
        The wrapped function.
    """
    if not error_handler:
        error_handler = get_error_handler()

    def wrap(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            assert error_handler is not None  # assertion for mypy type checker
            error_handler.initialize()
            try:
                return f(*args, **kwargs)
            except SystemExit as se:
                # For run_path based entrypoints, SystemExit with code = 0 will never exit.
                # Handling it here by returning a value:
                if se.code == 0:
                    return None
                else:
                    raise
            except ChildFailedError as e:
                rank, failure = e.get_first_failure()
                if failure.error_file != _NOT_AVAILABLE:
                    error_handler.dump_error_file(failure.error_file, failure.exitcode)
                else:
                    # Format string and args are passed separately so that the
                    # logging module interpolates them ``rank`` is the global rank
                    # (the key in ChildFailedError.failures).
                    logger.info(
                        "rank %s (local_rank %s) FAILED with no error file."
                        " Decorate your entrypoint fn with @record for traceback info."
                        " See: https://pytorch.org/docs/stable/elastic/errors.html",
                        rank,
                        failure.local_rank,
                    )
                raise
            except Exception as e:
                error_handler.record_exception(e)
                raise

        return wrapper

    return wrap(fn)
384