#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import logging
import time
from typing import Any, Callable, Dict, List, Tuple, TypeVar
from typing_extensions import ParamSpec

import torch
import torch.distributed as dist
from torch.distributed.logging_handlers import _log_handlers


__all__: List[str] = []

_DEFAULT_DESTINATION = "default"


def _get_or_create_logger(destination: str = _DEFAULT_DESTINATION) -> logging.Logger:
    """Return a DEBUG-level logger named ``c10d-<handler name>``, wired to the
    handler registered for ``destination`` in ``_log_handlers``."""
    logging_handler, log_handler_name = _get_logging_handler(destination)
    logger = logging.getLogger(f"c10d-{log_handler_name}")
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
    )
    logging_handler.setFormatter(formatter)
    # Keep records out of the root logger; this logger owns its own handler.
    logger.propagate = False
    logger.addHandler(logging_handler)
    return logger


def _get_logging_handler(
    destination: str = _DEFAULT_DESTINATION,
) -> Tuple[logging.Handler, str]:
    """Look up the handler registered for ``destination`` and return it along
    with a name of the form ``<HandlerClass>-<destination>``."""
    log_handler = _log_handlers[destination]
    log_handler_name = f"{type(log_handler).__name__}-{destination}"
    return (log_handler, log_handler_name)


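# Lookup sketch (illustrative; assumes the stock table in
# torch/distributed/logging_handlers.py maps "default" to a
# logging.NullHandler):
#
#     handler, name = _get_logging_handler()
#     # name == "NullHandler-default", i.e. c10d records are formatted and
#     # then discarded until a real handler is registered.
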
# Module-level logger shared by the decorators below.
_c10d_logger = _get_or_create_logger()


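# Usage sketch (illustrative, not part of the module): a destination other
# than "default" can be made visible by registering a handler under a new key
# before requesting the logger. The "stderr" key is a made-up example name.
#
#     _log_handlers["stderr"] = logging.StreamHandler()
#     stderr_logger = _get_or_create_logger("stderr")
#     stderr_logger.debug("c10d records now reach stderr")
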
def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
    """Assemble a structured log record for ``func_name``: the call signature
    plus, when the default process group is initialized, the topology of the
    target group."""
    if dist.is_initialized():
        group = kwargs.get("group") or kwargs.get("process_group")
        msg_dict = {
            "func_name": f"{func_name}",
            "args": f"{args}, {kwargs}",
            "pg_name": f"{dist._get_process_group_name(group)}",  # type: ignore[arg-type]
            "backend": f"{dist.get_backend(group)}",
            "world_size": f"{dist.get_world_size()}",
            "group_size": f"{dist.get_world_size(group)}",
            "global_rank": f"{dist.get_rank()}",
            # Rank within ``group`` (equals the global rank for the default
            # group).
            "local_rank": f"{dist.get_rank(group)}",
        }
        if msg_dict["backend"] == "nccl":
            nccl_version = torch.cuda.nccl.version()
            msg_dict["nccl_version"] = ".".join(str(v) for v in nccl_version)
    else:
        msg_dict = {
            "func_name": f"{func_name}",
            "args": f"{args}, {kwargs}",
        }
    return msg_dict


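# Shape sketch (illustrative): without an initialized process group,
#
#     _get_msg_dict("all_reduce", "tensor", group=None)
#     # -> {"func_name": "all_reduce", "args": "('tensor',), {'group': None}"}
#
# while an initialized NCCL job additionally carries pg_name, backend,
# world_size, group_size, global_rank, local_rank, and nccl_version.
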
# Typing helpers so the decorators below preserve the wrapped function's
# signature for type checkers.
_T = TypeVar("_T")
_P = ParamSpec("_P")


def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
    """Decorator that, when ``func`` raises, logs a structured record
    including the error message, then re-raises."""

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        try:
            return func(*args, **kwargs)
        except Exception as error:
            msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
            msg_dict["error"] = f"{error}"
            _c10d_logger.debug(msg_dict)
            raise

    return wrapper


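# Usage sketch (illustrative; ``_my_collective`` is a made-up name),
# mirroring how the c10d collectives apply this decorator:
#
#     @_exception_logger
#     def _my_collective(tensor, group=None):
#         ...
#
# A failing call emits the structured record before the exception propagates
# to the caller.
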
def _time_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
    """Decorator that logs a structured record with the wall-clock time (in
    nanoseconds) spent inside ``func``. If ``func`` raises, no timing record
    is emitted."""

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        t1 = time.time_ns()
        func_return = func(*args, **kwargs)
        time_spent = time.time_ns() - t1

        msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
        msg_dict["time_spent"] = f"{time_spent}ns"
        _c10d_logger.debug(msg_dict)

        return func_return

    return wrapper
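
if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the upstream
    # module): stack both decorators on a trivial function and call it with
    # torch.distributed uninitialized, so _get_msg_dict takes its fallback
    # branch. ``_demo_add`` is a hypothetical example name.
    @_time_logger
    @_exception_logger
    def _demo_add(x: int, y: int) -> int:
        return x + y

    print(_demo_add(1, 2))  # prints 3; the timing record goes to _c10d_logger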