xref: /aosp_15_r20/external/pytorch/torch/monitor/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from torch._C._monitor import *  # noqa: F403
2from typing import TYPE_CHECKING
3
4from torch._C._monitor import _WaitCounter  # type: ignore[attr-defined]
5
6if TYPE_CHECKING:
7    from torch.utils.tensorboard import SummaryWriter
8
9
# Name carried by ``torch.monitor.Stat`` events; TensorboardEventHandler only
# forwards events whose ``name`` equals this string.
STAT_EVENT: str = "torch.monitor.Stat"
11
12
class TensorboardEventHandler:
    """
    Event handler that forwards recognised ``torch.monitor`` events to a
    TensorBoard ``SummaryWriter``.

    Only events named ``torch.monitor.Stat`` are handled at present. Every
    key/value pair in such an event's ``data`` mapping is written with
    ``SummaryWriter.add_scalar``. Events with any other name are ignored.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_MONITOR)
        >>> # xdoctest: +REQUIRES(module:tensorboard)
        >>> from torch.utils.tensorboard import SummaryWriter
        >>> from torch.monitor import TensorboardEventHandler, register_event_handler
        >>> writer = SummaryWriter("log_dir")
        >>> register_event_handler(TensorboardEventHandler(writer))
    """

    def __init__(self, writer: "SummaryWriter") -> None:
        """
        Constructs the ``TensorboardEventHandler``.

        Args:
            writer: the ``SummaryWriter`` that scalar values are written to.
                It is stored as-is and used on every handled event.
        """
        self._writer = writer

    def __call__(self, event: Event) -> None:
        """
        Handle one ``torch.monitor`` event.

        If ``event.name`` is not ``STAT_EVENT``, the event is ignored.
        Otherwise each ``(tag, value)`` entry of ``event.data`` is written as a
        scalar. ``walltime`` is taken from ``event.timestamp.timestamp()``,
        presumably a ``datetime`` converted to POSIX seconds (TODO confirm).
        No ``global_step`` is passed to ``add_scalar``.

        Args:
            event: the event delivered by the ``torch.monitor`` dispatcher.
        """
        # Guard clause: anything other than a Stat event is not forwarded.
        if event.name != STAT_EVENT:
            return

        for tag, scalar in event.data.items():
            # The timestamp is read per entry, as the original code did, so an
            # event with empty ``data`` never touches ``event.timestamp``.
            self._writer.add_scalar(
                tag,
                scalar,
                walltime=event.timestamp.timestamp(),
            )
39