#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import logging
import time
from typing import Any, Callable, Dict, List, Tuple, TypeVar
from typing_extensions import ParamSpec

import torch
import torch.distributed as dist
from torch.distributed.logging_handlers import _log_handlers


# Module exports nothing publicly; all helpers are c10d-internal.
__all__: List[str] = []

_DEFAULT_DESTINATION = "default"


def _get_or_create_logger(destination: str = _DEFAULT_DESTINATION) -> logging.Logger:
    """Return a DEBUG-level logger wired to the handler registered for *destination*.

    The logger is named ``c10d-<handler name>`` and has ``propagate`` disabled,
    so c10d records go only to the configured handler, never the root logger.
    """
    logging_handler, log_handler_name = _get_logging_handler(destination)
    logger = logging.getLogger(f"c10d-{log_handler_name}")
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
    )
    logging_handler.setFormatter(formatter)
    logger.propagate = False
    logger.addHandler(logging_handler)
    return logger


def _get_logging_handler(
    destination: str = _DEFAULT_DESTINATION,
) -> Tuple[logging.Handler, str]:
    """Look up the handler registered under *destination* in ``_log_handlers``.

    Returns the handler together with a name of the form
    ``<HandlerClass>-<destination>``.  Raises ``KeyError`` if *destination*
    was never registered.
    """
    log_handler = _log_handlers[destination]
    log_handler_name = f"{type(log_handler).__name__}-{destination}"
    return (log_handler, log_handler_name)


# Shared module-level logger for the decorators below.
# NOTE: a ``global _c10d_logger`` statement stood here originally; ``global``
# is a no-op at module scope, so it was removed.
_c10d_logger = _get_or_create_logger()


def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
    """Build a structured (all-string-valued) log record for a collective call.

    When the default process group is initialized, the record also carries
    backend, world-size, group-size and rank details (plus the NCCL version
    for the ``nccl`` backend); otherwise only the function name and its
    arguments are captured.
    """
    if dist.is_initialized():
        # Callers pass the group under either keyword; ``None`` means the
        # default process group for the dist.get_* calls below.
        group = kwargs.get("group") or kwargs.get("process_group")
        msg_dict = {
            "func_name": f"{func_name}",
            "args": f"{args}, {kwargs}",
            "pg_name": f"{dist._get_process_group_name(kwargs.get('pg'))}",  # type: ignore[arg-type]
            "backend": f"{dist.get_backend(group)}",
            "world_size": f"{dist.get_world_size()}",
            "group_size": f"{dist.get_world_size(group)}",
            "global_rank": f"{dist.get_rank()}",
            "local_rank": f"{dist.get_rank(group)}",
        }
        if msg_dict["backend"] == "nccl":
            nccl_version = torch.cuda.nccl.version()
            msg_dict["nccl_version"] = ".".join(str(v) for v in nccl_version)
    else:
        msg_dict = {
            "func_name": f"{func_name}",
            "args": f"{args}, {kwargs}",
        }
    return msg_dict


_T = TypeVar("_T")
_P = ParamSpec("_P")


def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
    """Decorator: log a structured record when *func* raises, then re-raise."""

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        try:
            return func(*args, **kwargs)
        except Exception as error:
            msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
            msg_dict["error"] = f"{error}"
            _c10d_logger.debug(msg_dict)
            raise

    return wrapper


def _time_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
    """Decorator: log a structured record with *func*'s wall-clock duration in ns."""

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        t1 = time.time_ns()
        func_return = func(*args, **kwargs)
        time_spent = time.time_ns() - t1

        msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
        msg_dict["time_spent"] = f"{time_spent}ns"
        _c10d_logger.debug(msg_dict)

        return func_return

    return wrapper