# mypy: allow-untyped-defs
import contextlib
import tempfile

import torch

from . import check_error, cudart


__all__ = ["init", "start", "stop", "profile"]

# Profiler options written to the config file when ``init`` is called
# without an explicit ``flags`` argument.
DEFAULT_FLAGS = [
    "gpustarttimestamp",
    "gpuendtimestamp",
    "gridsize3d",
    "threadblocksize",
    "streamid",
    "enableonstart 0",
    "conckerneltrace",
]


def init(output_file, flags=None, output_mode="key_value"):
    r"""Initialize the CUDA profiler through ``cudaProfilerInitialize``.

    The flags are written, one per line, to a temporary config file whose
    path is handed to the runtime together with ``output_file``.

    Args:
        output_file: forwarded unchanged to ``cudaProfilerInitialize`` as the
            output file argument.
        flags: iterable of ASCII-encodable option strings; ``DEFAULT_FLAGS``
            is used when ``None``.
        output_mode: either ``"key_value"`` or ``"csv"``.

    Raises:
        AssertionError: if the runtime binding has no ``cudaOutputMode``
            attribute (reported as HIP), or if ``torch.version.cuda`` has a
            major version of 12 or higher.
        RuntimeError: if ``output_mode`` is not one of the supported values.

    .. warning::
        The runtime call's status is passed to ``check_error``, which is
        expected to raise if initialization fails.
    """
    rt = cudart()
    if not hasattr(rt, "cudaOutputMode"):
        raise AssertionError("HIP does not support profiler initialization!")
    if (
        hasattr(torch.version, "cuda")
        and torch.version.cuda is not None
        and int(torch.version.cuda.split(".")[0]) >= 12
    ):
        # Check https://github.com/pytorch/pytorch/pull/91118
        # cudaProfilerInitialize is no longer needed after CUDA 12
        raise AssertionError("CUDA12+ does not need profiler initialization!")
    flags = DEFAULT_FLAGS if flags is None else flags
    if output_mode == "key_value":
        output_mode_enum = rt.cudaOutputMode.KeyValuePair
    elif output_mode == "csv":
        output_mode_enum = rt.cudaOutputMode.CSV
    else:
        raise RuntimeError(
            "supported CUDA profiler output modes are: key_value and csv"
        )
    with tempfile.NamedTemporaryFile(delete=True) as config_file:
        # Distinct names for the file handle and the loop variable: the
        # handle is still needed after the join for flush() and .name.
        config_file.write(b"\n".join(flag.encode("ascii") for flag in flags))
        # Flush so the runtime sees the full contents when it opens the
        # file by name; the file is deleted when the ``with`` block exits.
        config_file.flush()
        check_error(
            rt.cudaProfilerInitialize(config_file.name, output_file, output_mode_enum)
        )


def start():
    r"""Start CUDA profiler data collection.

    .. warning::
        Raises CudaError if it is unable to start the profiler.
    """
    check_error(cudart().cudaProfilerStart())


def stop():
    r"""Stop CUDA profiler data collection.

    .. warning::
        Raises CudaError if it is unable to stop the profiler.
    """
    check_error(cudart().cudaProfilerStop())


@contextlib.contextmanager
def profile():
    """
    Enable profiling.

    Context manager that enables profile collection by the active profiling
    tool from the CUDA backend. ``start()`` is called on entry and ``stop()``
    is always called on exit, even if the body raises.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> model = torch.nn.Linear(20, 30).cuda()
        >>> inputs = torch.randn(128, 20).cuda()
        >>> with torch.cuda.profiler.profile() as prof:
        ...     model(inputs)
    """
    try:
        start()
        yield
    finally:
        stop()