xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/multiprocessing/errors/error_handler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9import faulthandler
10import json
11import logging
12import os
13import time
14import traceback
15import warnings
16from typing import Any, Dict, Optional
17
18
# Names exported by a star-import of this module: only the handler class.
__all__ = ["ErrorHandler"]

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
22
23
class ErrorHandler:
    """
    Write the provided exception object along with some other metadata about
    the error in a structured way in JSON format to an error file specified by the
    environment variable: ``TORCHELASTIC_ERROR_FILE``. If this environment
    variable is not set, then simply logs the contents of what would have been
    written to the error file.

    This handler may be subclassed to customize the handling of the error.
    Subclasses should override ``initialize()`` and ``record_exception()``.
    """

    def _get_error_file_path(self) -> Optional[str]:
        """
        Return the error file path.

        The path is read from the ``TORCHELASTIC_ERROR_FILE`` environment
        variable on every call. May return ``None`` to have the structured
        error be logged only.
        """
        return os.environ.get("TORCHELASTIC_ERROR_FILE", None)

    def initialize(self) -> None:
        """
        Call prior to running code that we wish to capture errors/exceptions.

        Typically registers signal/fault handlers. Users can override this
        function to add custom initialization/registrations that aid in
        propagation/information of errors/signals/exceptions/faults.

        The default implementation enables ``faulthandler`` for all threads;
        a failure to do so is downgraded to a warning.
        """
        try:
            faulthandler.enable(all_threads=True)
        except Exception as e:
            warnings.warn(f"Unable to enable fault handler. {type(e).__name__}: {e}")

    def _write_error_file(self, file_path: str, error_msg: str) -> None:
        """
        Write ``error_msg`` to ``file_path``, replacing any existing content.

        Best-effort: any failure is reported through ``warnings.warn`` instead
        of being raised, so that error reporting never hides the error that
        is being reported.
        """
        try:
            with open(file_path, "w") as fp:
                fp.write(error_msg)
        except Exception as e:
            warnings.warn(f"Unable to write error to file. {type(e).__name__}: {e}")

    def record_exception(self, e: BaseException) -> None:
        """
        Write a structured information about the exception into an error file in JSON format.

        If the error file cannot be determined, then logs the content
        that would have been written to the error file.

        Args:
            e: the exception to record. Its own traceback (``e.__traceback__``)
               is what gets recorded, so the method also works when called
               outside of the ``except`` block that caught ``e``.
        """
        data = {
            "message": {
                "message": f"{type(e).__name__}: {e}",
                "extraInfo": {
                    # Format ``e`` itself rather than "the exception currently
                    # being handled": the two are identical inside the
                    # matching ``except`` block, but only this form is correct
                    # when ``e`` is recorded later or from another handler.
                    "py_callstack": "".join(
                        traceback.format_exception(type(e), e, e.__traceback__)
                    ),
                    "timestamp": str(int(time.time())),
                },
            }
        }
        file = self._get_error_file_path()
        if file:
            # Go through the best-effort writer (as ``dump_error_file`` does)
            # so that an I/O failure cannot mask the exception being recorded.
            self._write_error_file(file, json.dumps(data))
        else:
            # Documented fallback: no error file configured, log the payload.
            logger.error(
                "no error file defined, logging the error that would have been"
                " written to it:\n%s",
                json.dumps(data, indent=2),
            )

    def override_error_code_in_rootcause_data(
        self,
        rootcause_error_file: str,
        rootcause_error: Dict[str, Any],
        error_code: int = 0,
    ) -> None:
        """
        Modify the rootcause_error read from the file, to correctly set the exit code.

        ``rootcause_error`` is updated in place: when its ``"message"`` entry
        is a mapping, ``error_code`` is stored under ``"errorCode"``. When the
        entry is missing or is a plain string, a warning is logged and the
        data is left untouched.

        Args:
            rootcause_error_file: path of the child error file (used in log
                messages only).
            rootcause_error: parsed content of the child error file.
            error_code: exit code to store.
        """
        if "message" not in rootcause_error:
            logger.warning(
                "child error file (%s) does not have field `message`. \n"
                "cannot override error code: %s",
                rootcause_error_file,
                error_code,
            )
        elif isinstance(rootcause_error["message"], str):
            logger.warning(
                "child error file (%s) has a new message format. \n"
                "skipping error code override",
                rootcause_error_file,
            )
        else:
            rootcause_error["message"]["errorCode"] = error_code

    def dump_error_file(self, rootcause_error_file: str, error_code: int = 0) -> None:
        """
        Dump parent error file from child process's root cause error and error code.

        Args:
            rootcause_error_file: path of the JSON error file written by the
                child process.
            error_code: exit code of the child; when non-zero it is injected
                into the loaded data before it is written out.

        Raises:
            OSError: if ``rootcause_error_file`` cannot be opened.
            ValueError: if ``rootcause_error_file`` does not contain valid JSON.
        """
        # The child file is only needed for loading; close it before logging.
        with open(rootcause_error_file) as fp:
            rootcause_error = json.load(fp)

        # Override error code since the child process cannot capture the error code if it
        # is terminated by signals like SIGSEGV.
        if error_code:
            self.override_error_code_in_rootcause_data(
                rootcause_error_file, rootcause_error, error_code
            )
        logger.debug(
            "child error file (%s) contents:\n%s",
            rootcause_error_file,
            json.dumps(rootcause_error, indent=2),
        )

        my_error_file = self._get_error_file_path()
        if my_error_file:
            # Guard against existing error files
            # This can happen when the child is created using multiprocessing
            # and the same env var (TORCHELASTIC_ERROR_FILE) is used on the
            # parent and child to specify the error files (respectively)
            # because the env vars on the child is set in the wrapper function
            # and by default the child inherits the parent's env vars, if the child
            # process receives a signal before the wrapper function kicks in
            # and the signal handler writes to the error file, then the child
            # will write to the parent's error file. In this case just log the
            # original error file contents and overwrite the error file.
            self._rm(my_error_file)
            self._write_error_file(my_error_file, json.dumps(rootcause_error))
            logger.info("dumped error file to parent's %s", my_error_file)
        else:
            logger.error(
                "no error file defined for parent, to copy child error file (%s)",
                rootcause_error_file,
            )

    def _rm(self, my_error_file) -> None:
        """
        Remove ``my_error_file`` if it exists, logging what is being discarded.

        The previous content is logged as pretty-printed JSON when it can be
        parsed; otherwise only the fact that the file is overwritten is
        logged. Does nothing when the path is not an existing regular file.
        """
        if not os.path.isfile(my_error_file):
            return

        # Log the contents of the original file.
        with open(my_error_file) as fp:
            try:
                original = json.dumps(json.load(fp), indent=2)
            except ValueError:
                # ``json.JSONDecodeError`` (malformed JSON) and
                # ``UnicodeDecodeError`` (not valid text at all) both derive
                # from ``ValueError``; neither should prevent the overwrite.
                logger.warning(
                    "%s already exists"
                    " and will be overwritten."
                    " Unable to load original contents:\n",
                    my_error_file,
                )
            else:
                logger.warning(
                    "%s already exists"
                    " and will be overwritten."
                    " Original contents:\n%s",
                    my_error_file,
                    original,
                )
        os.remove(my_error_file)
167