xref: /aosp_15_r20/external/pytorch/torch/cuda/nvtx.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
r"""This module adds support for the NVIDIA Tools Extension (NVTX), which is used in profiling."""
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerfrom contextlib import contextmanager
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker
try:
    from torch._C import _nvtx
except ImportError:

    class _NVTXStub:
        """Stand-in for ``torch._C._nvtx`` when that extension cannot be imported.

        Every NVTX entry point called by this module is aliased to ``_fail``,
        so using any public helper raises a descriptive ``RuntimeError``
        instead of an opaque ``AttributeError``.
        """

        @staticmethod
        def _fail(*args, **kwargs):
            raise RuntimeError(
                "NVTX functions not installed. Are you sure you have a CUDA build?"
            )

        rangePushA = _fail
        rangePop = _fail
        # range_start()/range_end() call these two; without them the stub
        # surfaced an AttributeError instead of the message above.
        rangeStartA = _fail
        rangeEnd = _fail
        markA = _fail

    _nvtx = _NVTXStub()  # type: ignore[assignment]
23*da0073e9SAndroid Build Coastguard Worker
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "range_push",
    "range_pop",
    "range_start",
    "range_end",
    "mark",
    "range",
]
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
def range_push(msg):
    """
    Push a new range onto the stack of nested range spans.

    Args:
        msg (str): ASCII message to associate with the range.

    Returns:
        The zero-based depth of the range that was just started.
    """
    push = _nvtx.rangePushA
    depth = push(msg)
    return depth
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker
def range_pop():
    """
    Pop the innermost range off the stack of nested range spans.

    Returns:
        The zero-based depth of the range that was just ended.
    """
    pop = _nvtx.rangePop
    depth = pop()
    return depth
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker
def range_start(msg) -> int:
    """
    Mark the start of a range tagged with a string message.

    The returned handle identifies this particular range and must be passed
    to the matching range_end() call. Unlike range_push()/range_pop(), a range
    opened with range_start() may be ended on a different thread from the one
    that started it.

    Args:
        msg (str): ASCII message to associate with the range.

    Returns:
        A unique range handle (uint64_t) to pass to range_end().
    """
    start = _nvtx.rangeStartA
    handle = start(msg)
    return handle
57*da0073e9SAndroid Build Coastguard Worker
58*da0073e9SAndroid Build Coastguard Worker
def range_end(range_id) -> None:
    """
    Mark the end of the range identified by ``range_id``.

    Args:
        range_id (int): the unique handle returned by the matching
            range_start() call.
    """
    end = _nvtx.rangeEnd
    end(range_id)
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker
def mark(msg):
    """
    Record an instantaneous (zero-duration) event.

    Args:
        msg (str): ASCII message to associate with the event.

    Returns:
        Whatever the underlying ``_nvtx.markA`` call returns.
    """
    emit = _nvtx.markA
    result = emit(msg)
    return result
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker
# NOTE: this public name shadows the builtin ``range`` inside this module.
@contextmanager
def range(msg, *args, **kwargs):
    """
    Context manager / decorator that pushes an NVTX range at the beginning
    of its scope, and pops it at the end. If extra arguments are given,
    they are passed as arguments to msg.format().

    Args:
        msg (str): message to associate with the range. When no extra
            positional or keyword arguments are supplied, it is used
            verbatim, so messages containing literal braces (for example
            ``"dict {a}"``) are safe.
        *args: positional arguments forwarded to ``msg.format()``.
        **kwargs: keyword arguments forwarded to ``msg.format()``.
    """
    # Format only when formatting arguments were actually provided. Calling
    # str.format() unconditionally made brace-containing messages raise
    # KeyError/IndexError even though no formatting was requested.
    label = msg.format(*args, **kwargs) if args or kwargs else msg
    # Push outside the try block: if the push itself fails, there is nothing to pop.
    range_push(label)
    try:
        yield
    finally:
        # Pop even when the managed body raises, keeping push/pop balanced.
        range_pop()
93*da0073e9SAndroid Build Coastguard Worker        range_pop()
94