xref: /aosp_15_r20/external/pytorch/torch/mps/profiler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import contextlib
3
4import torch
5
6
7__all__ = ["start", "stop", "profile"]
8
9
def start(mode: str = "interval", wait_until_completed: bool = False) -> None:
    r"""Begin emitting OS Signposts from the MPS backend.

    Once recorded, the signposts can be inspected with the Logging tool of
    XCode Instruments.

    Before being handed to the backend, ``mode`` is lower-cased and every
    space character is removed from it, so ``"Interval, Event"`` is passed
    on as ``"interval,event"``.

    Args:
        mode(str): Which kind of OS Signpost tracing to enable: "interval",
            "event", or the combination "interval,event".
            Interval mode traces how long the operations take to execute,
            while event mode marks the completion of executions.
            See `Recording Performance Data`_ for more info.
        wait_until_completed(bool): Wait until the MPS Stream completes
            executing each encoded GPU operation. This helps produce single
            dispatches on the trace's timeline.
            Note that enabling this option affects performance negatively.

    .. _Recording Performance Data:
       https://developer.apple.com/documentation/os/logging/recording_performance_data
    """
    lowered = mode.lower()
    # Drop spaces so that e.g. "interval, event" reaches the backend compacted.
    compact_mode = lowered.replace(" ", "")
    torch._C._mps_profilerStartTrace(compact_mode, wait_until_completed)
32
33
def stop():
    r"""Stop the OS Signpost tracing emitted by the MPS backend.

    Takes no arguments; it simply forwards to the backend's stop-trace hook.
    """
    stop_trace = torch._C._mps_profilerStopTrace
    stop_trace()
37
38
@contextlib.contextmanager
def profile(mode: str = "interval", wait_until_completed: bool = False):
    r"""Context manager that enables OS Signpost tracing from the MPS backend.

    Tracing is started on entry via :func:`start` and stopped via
    :func:`stop` on exit. :func:`stop` is called whether the ``with`` body
    finishes normally or raises, and also when :func:`start` itself raises.

    Args:
        mode(str): Which kind of OS Signpost tracing to enable: "interval",
            "event", or the combination "interval,event".
            Interval mode traces how long the operations take to execute,
            while event mode marks the completion of executions.
            See `Recording Performance Data`_ for more info.
        wait_until_completed(bool): Wait until the MPS Stream completes
            executing each encoded GPU operation. This helps produce single
            dispatches on the trace's timeline.
            Note that enabling this option affects performance negatively.

    .. _Recording Performance Data:
       https://developer.apple.com/documentation/os/logging/recording_performance_data
    """
    try:
        # start() is deliberately kept inside the guarded region: stop() in
        # the finally clause runs even if starting the trace fails.
        start(mode, wait_until_completed)
        # Hand control to the body of the ``with`` statement.
        yield
    finally:
        stop()
62