# mypy: allow-untyped-defs
r"""This module adds support for the NVIDIA Tools Extension (NVTX), used in profiling."""

from contextlib import contextmanager


try:
    from torch._C import _nvtx
except ImportError:

    class _NVTXStub:
        @staticmethod
        def _fail(*args, **kwargs):
            raise RuntimeError(
                "NVTX functions not installed. Are you sure you have a CUDA build?"
            )

        # Every binding used below must be stubbed; otherwise range_start() and
        # range_end() would fail with AttributeError instead of this RuntimeError.
        rangePushA = _fail
        rangePop = _fail
        rangeStartA = _fail
        rangeEnd = _fail
        markA = _fail

    _nvtx = _NVTXStub()  # type: ignore[assignment]

__all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"]


def range_push(msg):
    """
    Push a range onto a stack of nested range spans. Returns the zero-based depth of the range that is started.

    Args:
        msg (str): ASCII message to associate with the range.
    """
    return _nvtx.rangePushA(msg)


def range_pop():
    """Pop a range off the stack of nested range spans. Returns the zero-based depth of the range that is ended."""
    return _nvtx.rangePop()


def range_start(msg) -> int:
    """
    Mark the start of a range with a string message. Returns a unique handle
    for this range, to be passed to the corresponding call to range_end().

    Unlike range_push/range_pop, the range_start/range_end pair supports
    ranges that span threads (started on one thread and ended on another).

    Args:
        msg (str): ASCII message to associate with the range.

    Returns:
        int: A range handle (a ``uint64_t`` in NVTX) that can be passed to range_end().
    """
    return _nvtx.rangeStartA(msg)


def range_end(range_id) -> None:
    """
    Mark the end of the range identified by range_id.

    Args:
        range_id (int): the unique handle returned by the range_start() call that opened the range.
    """
    _nvtx.rangeEnd(range_id)


def mark(msg):
    """
    Record an instantaneous event (a marker) at the point of the call.

    Args:
        msg (str): ASCII message to associate with the event.
    """
    return _nvtx.markA(msg)


@contextmanager
def range(msg, *args, **kwargs):
    """
    Context manager / decorator that pushes an NVTX range at the beginning
    of its scope and pops it at the end, even if the scope raises an exception.
    If extra arguments are given, they are passed to msg.format().

    Args:
        msg (str): message to associate with the range.
    """
    range_push(msg.format(*args, **kwargs))
    try:
        yield
    finally:
        range_pop()