1# mypy: allow-untyped-defs 2import functools 3import time 4from typing import Any, Callable, Dict, List, TypeVar 5from typing_extensions import ParamSpec 6from uuid import uuid4 7 8import torch.distributed.c10d_logger as c10d_logger 9from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME 10 11 12__all__: List[str] = [] 13 14global _dcp_logger 15_dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME) 16 17_T = TypeVar("_T") 18_P = ParamSpec("_P") 19 20 21def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]: 22 """ 23 Extracts log data from dcp method args 24 """ 25 msg_dict = {} 26 27 # checkpoint ID can be passed in through the serializer or through the checkpoint id directly 28 storage_writer = kwargs.get("storage_writer", None) 29 storage_reader = kwargs.get("storage_reader", None) 30 planner = kwargs.get("planner", None) 31 32 checkpoint_id = kwargs.get("checkpoint_id", None) 33 if not checkpoint_id and (serializer := storage_writer or storage_reader): 34 checkpoint_id = getattr(serializer, "checkpoint_id", None) 35 36 msg_dict["checkpoint_id"] = ( 37 str(checkpoint_id) if checkpoint_id is not None else checkpoint_id 38 ) 39 40 # Uniquely identify a _dcp_method_logger wrapped function call. 41 msg_dict["uuid"] = str(uuid4().int) 42 43 if storage_writer: 44 msg_dict["storage_writer"] = storage_writer.__class__.__name__ 45 46 if storage_reader: 47 msg_dict["storage_reader"] = storage_reader.__class__.__name__ 48 49 if planner: 50 msg_dict["planner"] = planner.__class__.__name__ 51 52 return msg_dict 53 54 55def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]: 56 msg_dict = _msg_dict_from_dcp_method_args(*args, **kwargs) 57 msg_dict.update(c10d_logger._get_msg_dict(func_name, **msg_dict)) 58 59 return msg_dict 60 61 62def _dcp_method_logger( 63 log_exceptions: bool = False, **wrapper_kwargs: Any 64) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: # pyre-ignore 65 """This method decorator logs the start, end, and exception of wrapped events.""" 66 67 def decorator(func: Callable[_P, _T]): 68 @functools.wraps(func) 69 def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: 70 msg_dict = _get_msg_dict( 71 func.__name__, *args, **{**wrapper_kwargs, **kwargs} 72 ) 73 74 # log start event 75 msg_dict["event"] = "start" 76 t0 = time.time_ns() 77 msg_dict["time"] = t0 78 msg_dict["log_exceptions"] = log_exceptions 79 _dcp_logger.debug(msg_dict) 80 81 # exceptions 82 try: 83 result = func(*args, **kwargs) 84 except BaseException as error: 85 if log_exceptions: 86 msg_dict["event"] = "exception" 87 msg_dict["error"] = f"{error}" 88 msg_dict["time"] = time.time_ns() 89 _dcp_logger.error(msg_dict) 90 raise 91 92 # end event 93 msg_dict["event"] = "end" 94 t1 = time.time_ns() 95 msg_dict["time"] = time.time_ns() 96 msg_dict["times_spent"] = t1 - t0 97 _dcp_logger.debug(msg_dict) 98 99 return result 100 101 return wrapper 102 103 return decorator 104