# mypy: allow-untyped-defs
import traceback as tb
from typing import Any, Dict, Tuple


WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]

__all__ = ["CheckpointException"]


def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION:
    """Pair ``exc`` with the stack frames extracted from its ``__traceback__``."""
    return (exc, tb.extract_tb(exc.__traceback__))


def _is_wrapped_exception(obj: Any) -> bool:
    """Return True if ``obj`` has the shape produced by ``_wrap_exception``.

    That shape is a 2-tuple of ``(BaseException, traceback.StackSummary)``.
    """
    return (
        isinstance(obj, tuple)
        and len(obj) == 2
        and isinstance(obj[0], BaseException)
        and isinstance(obj[1], tb.StackSummary)
    )


class CheckpointException(BaseException):
    """Exception raised if failure was detected as part of a checkpoint load or save.

    Note that this derives from ``BaseException`` rather than ``Exception``, so a
    plain ``except Exception:`` handler will not catch it.

    Args:
        msg: Human-readable summary of the failure.
        failures: Mapping from rank to the ``(exception, stack summary)`` pair
            that rank reported.
    """

    def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]) -> None:
        # Both values are forwarded so they are visible in ``self.args``.
        super().__init__(msg, failures)
        self._failures = failures

    @property
    def failures(self) -> Dict[int, WRAPPED_EXCEPTION]:
        """Return a dictionary mapping node ranks to their associated exceptions in case of failure."""
        return self._failures

    def __str__(self) -> str:
        """Render each rank's failure as a Python-style traceback section."""
        # The header keeps the ``dict_keys([...])`` repr exactly as before so the
        # rendered text is unchanged.
        parts = [f"CheckpointException ranks:{self._failures.keys()}\n"]
        for rank, (exc, trace) in self._failures.items():
            parts.append(f"Traceback (most recent call last): (RANK {rank})\n")
            # When no stack summary is available only the exception line itself
            # is rendered for this rank.
            if trace is not None:
                parts.extend(tb.format_list(trace))
            parts.extend(tb.format_exception_only(type(exc), value=exc))
        # Collect-then-join avoids shadowing the builtin ``str`` and avoids
        # repeated string concatenation.
        return "".join(parts)