#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Each host in a distributed PyTorch job runs with a single TorchElastic agent,
and multiple workers (as children processes of the TorchElastic agent).
Since the workers are user-provided (your PyTorch script/job), TorchElastic
has a way to propagate errors on the trainers through the agent and up to the
scheduler, which ultimately informs the end-user about the state of the job
and applies any retry policies.

TorchElastic categorizes errors into 3 categories:

+----------------+----------------+--------------------------------------------------------------+
| Category       | Sub-Category   |  Description                                                 |
+================+================+==============================================================+
| User Error     | Input Error    | invalid inputs to TorchElastic APIs (e.g. min > max nodes)   |
|                +----------------+--------------------------------------------------------------+
|                | Worker Failure | any failures on the worker child process                     |
+----------------+----------------+--------------------------------------------------------------+
| Platform Error |      n/a       | failures caused by the agent                                 |
+----------------+----------------+--------------------------------------------------------------+
| Infra Error    |      n/a       | failures outside the domain of the agent and workers         |
|                |                | (e.g. host failures)                                         |
+----------------+----------------+--------------------------------------------------------------+

All errors other than "Worker Failure" are either raised canonically from the
agent process or implicitly or explicitly crash the agent process. So the
standard language (python) provided exception handling strategies apply.

Worker Failures are special because the exception/failure originates on a different
process from the agent so the error needs to be propagated inter-process
(e.g. the agent cannot simply ``try-catch`` an exception raised on the worker process).

TorchElastic agents use :func:`torch.distributed.elastic.multiprocessing.start_processes`
to launch the workers which has a simple file based inter-process error propagation
built-in.

Any function or binary entrypoint decorated with :func:`record`
will write uncaught exceptions (with the trace information) to a file specified by the
environment variable ``TORCHELASTIC_ERROR_FILE``. The parent process (e.g. agent)
sets this env var on each child it launches, then aggregates the error files for all
children, and propagates the one with the **smallest** timestamp (e.g. the **first** error).
"""

import json
import os
import signal
import socket
import time
import warnings
from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps
from string import Template
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar

from torch.distributed.elastic.utils.logging import get_logger

from .error_handler import ErrorHandler  # noqa: F401
from .handlers import get_error_handler  # noqa: F401


__all__ = [
    "ProcessFailure",
    "ChildFailedError",
    "record",
    "ErrorHandler",
    "get_error_handler",
]

logger = get_logger(__name__)


JSON = Dict

# Placeholder payload used when a failed process left no error file behind.
_EMPTY_ERROR_DATA = {"message": "<NONE>"}
_NOT_AVAILABLE = "<N/A>"

T = TypeVar("T")


@dataclass
class ProcessFailure:
    """
    Represent the failed process result. When the worker process fails, it may record failure root cause into the file.

    Tries to read the failure timestamp from the provided ``error_file``,
    if the ``error_file`` does not exist, the timestamp is the current
    timestamp (seconds since epoch).

    The ``message`` field is a concise explanation of the failure. If
    the error file exists then the message is obtained from the error file.
    Otherwise one is generated based on the failure signature.

    .. note:: It is assumed that the ``error_file`` is written by
              ``torch.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler``.
              Otherwise the behavior is undefined.

    """

    local_rank: int
    pid: int
    exitcode: int
    error_file: str
    # The three fields below are derived in ``__post_init__`` and are not
    # accepted by the generated ``__init__``.
    error_file_data: JSON = field(init=False)
    # Holds whatever the error file stored under "message": a plain string or
    # a nested dict (see ``ChildFailedError._format_failure``).
    message: str = field(init=False)
    timestamp: int = field(init=False)

    def __post_init__(self):
        """Populate ``error_file_data``, ``message`` and ``timestamp``.

        If ``error_file`` exists it is parsed as JSON; any failure while
        reading or interpreting it is logged and re-raised. If it does not
        exist, the "no reply file" defaults are applied. Finally, an empty
        message is replaced by a generated one.
        """
        self.error_file_data = _EMPTY_ERROR_DATA
        if os.path.isfile(self.error_file):
            try:
                with open(self.error_file) as fp:
                    self.error_file_data = json.load(fp)
                    logger.debug(
                        "User process failed with error data: %s",
                        json.dumps(self.error_file_data, indent=2),
                    )
                    self.message, self.timestamp = self._get_error_data(
                        self.error_file_data
                    )
            except Exception:
                logger.exception("Failed to parse reply file: %s", self.error_file)
                raise
        else:
            self._set_no_reply_file()

        # make up an informative message if not already present
        if not self.message:
            # signals typically do not generate an error file message
            if self.exitcode < 0:
                self.message = (
                    f"Signal {-self.exitcode} ({self.signal_name()})"
                    f" received by PID {self.pid}"
                )
            else:
                self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html"

    def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]:
        """Extract ``(message, timestamp)`` from parsed error-file content.

        When ``error_file_data["message"]`` is a string the timestamp is read
        from the top-level ``"timestamp"`` key (defaulting to 0); otherwise it
        is read from ``message["extraInfo"]["timestamp"]``.

        Raises:
            KeyError: if ``"message"`` is missing, or if a non-string message
                lacks ``"extraInfo"``/``"timestamp"``.
        """
        message = error_file_data["message"]
        if isinstance(message, str):
            timestamp = int(error_file_data.get("timestamp", 0))
        else:
            timestamp = int(message["extraInfo"]["timestamp"])
        return (message, timestamp)

    def _set_no_reply_file(self):
        """Apply defaults for a failure that produced no error file.

        ``error_file`` is replaced by the ``<N/A>`` marker, the message is left
        empty (so ``__post_init__`` generates one) and the timestamp is "now".
        """
        self.error_file = _NOT_AVAILABLE
        self.error_file_data = _EMPTY_ERROR_DATA
        self.message = ""
        self.timestamp = int(time.time())

    def signal_name(self) -> str:
        """Return the name of the signal encoded by a negative ``exitcode``.

        Returns ``<N/A>`` when ``exitcode`` is non-negative or when the signal
        number does not map to a known ``signal.Signals`` member.
        """
        if self.exitcode < 0:
            # We don't want to kill the parent process trying to find the signal name.
            # if the signal doesn't map to a known name, use not available.
            try:
                return signal.Signals(-self.exitcode).name
            except Exception:
                return _NOT_AVAILABLE
        else:
            return _NOT_AVAILABLE

    def timestamp_isoformat(self):
        """Return timestamp in ISO format (YYYY-MM-DD_HH:MM:SS)."""
        return datetime.fromtimestamp(self.timestamp).isoformat(sep="_")


GlobalRank = int

_FAILURE_FORMAT_TEMPLATE = """[${idx}]:
  time      : ${time}
  host      : ${hostname}
  rank      : ${rank} (local_rank: ${local_rank})
  exitcode  : ${exitcode} (pid: ${pid})
  error_file: ${error_file}
  traceback : ${message}"""

# extra new lines before and after are intentional
_MSG_FORMAT_TEMPLATE = """
${boarder}
${title}
${section}
Failures:
${other_failures}
${section}
Root Cause (first observed failure):
${root_failure}
${boarder}"""


class ChildFailedError(Exception):
    """
    Special exception type that can be raised from a function annotated with the
    ``@record`` decorator to have the child process' (root exception) propagate
    up the stack as-is (e.g. without being wrapped in the parent's traceback).

    Useful in cases where the parent is a simple nanny process
    and the child (worker) processes are actually doing meaningful compute.
    In this case, errors typically occur on the child process as the parent
    is not doing anything non-trivial, and child errors should be propagated
    to the scheduler for accurate root cause diagnostics.

    .. note:: The propagation relies on error files rather than exception handling to
              support both function and binary launches.

    Example:
    ::

     # process tree on a host (container)
     0: scheduler-init-process:
                |- 1: torchelastic_agent:
                         |- 2: trainer_0 (ok)
                         |- 3: trainer_1 (fail) -> error.json
                         |- ...
                         |- n+2: trainer_n (ok)
                |- n+3: other processes
                |- ...

    In the example above, trainer 1's failure (written into error.json) is
    the root cause and should be reported to the scheduler's init process.
    The torchelastic agent raises a ``ChildFailedError("trainer", {1: "trainer_1/error.json"})``
    upon detecting trainer 1's failure which would propagate the contents
    of trainer 1's error file to the scheduler's init process.
    """

    def __init__(self, name: str, failures: Dict[GlobalRank, ProcessFailure]):
        self.name = name
        self.failures = failures
        assert (
            self.failures
        )  # does not make sense to create a ChildFailedError with no failures
        super().__init__(self.format_msg())

    def get_first_failure(self) -> Tuple[GlobalRank, ProcessFailure]:
        """Return ``(rank, failure)`` for the failure with the smallest timestamp."""
        rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp)
        return rank, self.failures[rank]

    def format_msg(self, boarder_delim="=", section_delim="-"):
        """Render all failures as a report, with the first failure as root cause.

        Args:
            boarder_delim: character repeated for the top and bottom rules.
            section_delim: character repeated for the inner section rules.

        Returns:
            The formatted multi-line message. Rule width follows the longest
            rendered line (or the title), capped at 60 characters.
        """
        title = f"{self.name} FAILED"
        root_rank, root_failure = self.get_first_failure()

        root_failure_fmt: str = ""
        other_failures_fmt: List[str] = []
        width = len(title)
        for idx, (rank, failure) in enumerate(self.failures.items()):
            fmt, w = self._format_failure(idx, rank, failure)
            width = max(width, w)
            if rank == root_rank:
                root_failure_fmt = fmt
            else:
                other_failures_fmt.append(fmt)

        # upper boundary on width
        width = min(width, 60)

        return Template(_MSG_FORMAT_TEMPLATE).substitute(
            boarder=boarder_delim * width,
            title=title,
            section=section_delim * width,
            root_failure=root_failure_fmt,
            other_failures="\n".join(other_failures_fmt or ["  <NO_OTHER_FAILURES>"]),
        )

    def _format_failure(
        self, idx: int, rank: int, failure: ProcessFailure
    ) -> Tuple[str, int]:
        """Render one failure entry and return it with its widest line length."""
        # failure.message is either a str (when the failure does not generate a traceback - e.g. signals)
        # or a dict (json) of the form
        # {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}}
        # so the display logic is:
        # 1. if failure.message is not a dict (it is a str) just show it as is
        # 2. else try to get the traceback (py_callstack)
        # 3. if the traceback is not there, use the message
        # 4. if the message is not there show <N/A>
        msg = failure.message
        if isinstance(failure.message, dict):
            msg = (
                failure.message.get("extraInfo", {})
                .get("py_callstack", failure.message.get("message", "<N/A>"))
                .replace("\n", "\n  ")  # to properly indent the traceback
            )

        fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute(
            idx=idx,
            time=failure.timestamp_isoformat(),
            hostname=socket.getfqdn(),
            rank=rank,
            local_rank=failure.local_rank,
            exitcode=failure.exitcode,
            pid=failure.pid,
            error_file=failure.error_file,
            message=msg,
        )
        width = 0
        for line in fmt.split("\n"):
            width = max(width, len(line))
        return fmt, width


def record(
    fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None
) -> Callable[..., T]:
    """
    Syntactic sugar to record errors/exceptions that happened in the decorated
    function using the provided ``error_handler``.

    Using this decorator is equivalent to:

    ::

     error_handler = get_error_handler()
     error_handler.initialize()
     try:
        foobar()
     except SystemExit as se:
        # a zero exit code is treated as a normal return
        if se.code != 0:
            raise
     except ChildFailedError as e:
        _, failure = e.get_first_failure()
        # only when the failed child actually produced an error file
        error_handler.dump_error_file(failure.error_file, failure.exitcode)
        raise
     except Exception as e:
        error_handler.record_exception(e)
        raise

    .. important:: use this decorator once per process at the top level method,
                   typically this is the main method.

    Example

    ::

     @record
     def main():
         pass

     if __name__=="__main__":
        main()

    """
    if not error_handler:
        error_handler = get_error_handler()

    def wrap(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            assert error_handler is not None  # assertion for mypy type checker
            error_handler.initialize()
            try:
                return f(*args, **kwargs)
            except SystemExit as se:
                # For run_path based entrypoints, SystemExit with code = 0 will never exit.
                # Handling it here by returning a value:
                if se.code == 0:
                    return None
                else:
                    raise
            except ChildFailedError as e:
                rank, failure = e.get_first_failure()
                if failure.error_file != _NOT_AVAILABLE:
                    error_handler.dump_error_file(failure.error_file, failure.exitcode)
                else:
                    # The format string and its argument must be passed as
                    # separate parameters; wrapping them in one tuple makes the
                    # logger emit the tuple's repr with "%s" left uninterpolated.
                    # NOTE(review): ``rank`` is the key of ``e.failures`` (typed
                    # GlobalRank) although the text says "local_rank" - confirm
                    # whether ``failure.local_rank`` was intended.
                    logger.info(
                        "local_rank %s FAILED with no error file."
                        " Decorate your entrypoint fn with @record for traceback info."
                        " See: https://pytorch.org/docs/stable/elastic/errors.html",
                        rank,
                    )
                raise
            except Exception as e:
                error_handler.record_exception(e)
                raise

        return wrapper

    return wrap(fn)