xref: /aosp_15_r20/external/pytorch/torch/distributed/checkpoint/logger.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import time
4from typing import Any, Callable, Dict, List, TypeVar
5from typing_extensions import ParamSpec
6from uuid import uuid4
7
8import torch.distributed.c10d_logger as c10d_logger
9from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME
10
11
__all__: List[str] = []

# Module-wide logger used by ``_dcp_method_logger`` below, obtained through the
# c10d logging helper under the DCP logger name. (A ``global`` statement is not
# needed at module scope: top-level assignments already bind module globals.)
_dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME)

# Type variables used to preserve the wrapped function's signature in the
# decorator's type annotations.
_T = TypeVar("_T")
_P = ParamSpec("_P")
19
20
21def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]:
22    """
23    Extracts log data from dcp method args
24    """
25    msg_dict = {}
26
27    # checkpoint ID can be passed in through the serializer or through the checkpoint id directly
28    storage_writer = kwargs.get("storage_writer", None)
29    storage_reader = kwargs.get("storage_reader", None)
30    planner = kwargs.get("planner", None)
31
32    checkpoint_id = kwargs.get("checkpoint_id", None)
33    if not checkpoint_id and (serializer := storage_writer or storage_reader):
34        checkpoint_id = getattr(serializer, "checkpoint_id", None)
35
36    msg_dict["checkpoint_id"] = (
37        str(checkpoint_id) if checkpoint_id is not None else checkpoint_id
38    )
39
40    # Uniquely identify a _dcp_method_logger wrapped function call.
41    msg_dict["uuid"] = str(uuid4().int)
42
43    if storage_writer:
44        msg_dict["storage_writer"] = storage_writer.__class__.__name__
45
46    if storage_reader:
47        msg_dict["storage_reader"] = storage_reader.__class__.__name__
48
49    if planner:
50        msg_dict["planner"] = planner.__class__.__name__
51
52    return msg_dict
53
54
def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
    """
    Build the full log payload for a call to ``func_name``.

    The DCP-specific fields come from ``_msg_dict_from_dcp_method_args``. Those
    fields are then passed as keyword arguments to ``c10d_logger._get_msg_dict``,
    and its result is merged on top. On a key collision, the c10d value wins.
    """
    payload = _msg_dict_from_dcp_method_args(*args, **kwargs)
    c10d_fields = c10d_logger._get_msg_dict(func_name, **payload)
    payload.update(c10d_fields)
    return payload
60
61
def _dcp_method_logger(
    log_exceptions: bool = False, **wrapper_kwargs: Any
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:  # pyre-ignore
    """Decorator factory that logs the start, end, and (optionally) exceptions of a call.

    Args:
        log_exceptions: When True, an ``"exception"`` event is logged at ERROR
            level before any exception raised by the wrapped function is
            re-raised. When False, exceptions propagate without being logged.
        **wrapper_kwargs: Extra keyword arguments merged underneath the call's
            own kwargs (call kwargs win on collisions) when building the log
            payload. They are used only for logging and are NOT forwarded to
            the wrapped function.

    Returns:
        A decorator that wraps a function while preserving its signature.

    Each call logs a ``"start"`` event and, on success, an ``"end"`` event at
    DEBUG level. All events for one call mutate and re-log the same payload
    dict, so they share its per-call ``uuid``. ``time`` is a wall-clock
    timestamp in nanoseconds (``time.time_ns``). ``times_spent`` is the call
    duration in nanoseconds, measured with a monotonic performance counter.
    """

    def decorator(func: Callable[_P, _T]):
        @functools.wraps(func)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            msg_dict = _get_msg_dict(
                func.__name__, *args, **{**wrapper_kwargs, **kwargs}
            )

            # log start event
            msg_dict["event"] = "start"
            # Durations come from a monotonic counter: wall-clock time can jump
            # (e.g. NTP adjustments) and would yield skewed or negative values.
            start_ns = time.perf_counter_ns()
            msg_dict["time"] = time.time_ns()
            msg_dict["log_exceptions"] = log_exceptions
            _dcp_logger.debug(msg_dict)

            try:
                result = func(*args, **kwargs)
            except BaseException as error:
                # BaseException so interrupts/SystemExit are recorded too; the
                # exception is always re-raised unchanged.
                if log_exceptions:
                    msg_dict["event"] = "exception"
                    msg_dict["error"] = f"{error}"
                    msg_dict["time"] = time.time_ns()
                    _dcp_logger.error(msg_dict)
                raise

            # log end event
            msg_dict["event"] = "end"
            # Stamp the end time with a single wall-clock read (the original
            # read the clock twice, so "time" disagreed with "times_spent").
            msg_dict["time"] = time.time_ns()
            msg_dict["times_spent"] = time.perf_counter_ns() - start_ns
            _dcp_logger.debug(msg_dict)

            return result

        return wrapper

    return decorator
104