xref: /aosp_15_r20/external/pytorch/torch/profiler/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2r"""
3PyTorch Profiler is a tool that allows the collection of performance metrics during training and inference.
Profiler's context manager API can be used to better understand which model operators are the most expensive,
5examine their input shapes and stack traces, study device kernel activity and visualize the execution trace.
6
7.. note::
    An earlier version of the API in the :mod:`torch.autograd` module is considered legacy and will be deprecated.
9
10"""
11import os
12
13from torch._C._autograd import _supported_activities, DeviceType, kineto_available
14from torch._C._profiler import _ExperimentalConfig, ProfilerActivity, RecordScope
15from torch.autograd.profiler import KinetoStepTracker, record_function
16from torch.optim.optimizer import register_optimizer_step_post_hook
17
18from .profiler import (
19    _KinetoProfile,
20    ExecutionTraceObserver,
21    profile,
22    ProfilerAction,
23    schedule,
24    supported_activities,
25    tensorboard_trace_handler,
26)
27
28
# Public names exported by ``from torch.profiler import *``.
# The order is significant only for equality checks on ``__all__`` and is
# kept identical to the historical listing.
__all__ = [
    "profile", "schedule", "supported_activities", "tensorboard_trace_handler",
    "ProfilerAction", "ProfilerActivity", "kineto_available", "DeviceType",
    "record_function", "ExecutionTraceObserver",
]
41
42from . import itt
43
44
45def _optimizer_post_hook(optimizer, args, kwargs):
46    KinetoStepTracker.increment_step("Optimizer")
47
48
# If KINETO_USE_DAEMON is set to a non-empty value, register
# _optimizer_post_hook as an optimizer step post-hook. The returned handle
# is kept at module scope as ``_``.
if os.getenv("KINETO_USE_DAEMON"):
    _ = register_optimizer_step_post_hook(_optimizer_post_hook)
51