xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/CUDAGraphsUtils.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cuda/CUDAGeneratorImpl.h>
4 #include <ATen/cuda/CUDAEvent.h>
5 #include <ATen/cuda/PhiloxUtils.cuh>
6 #include <ATen/cuda/detail/CUDAHooks.h>
7 #include <ATen/detail/CUDAHooksInterface.h>
8 #include <c10/core/StreamGuard.h>
9 #include <c10/cuda/CUDAGraphsC10Utils.h>
10 #include <c10/cuda/CUDAGuard.h>
11 
12 // c10/cuda/CUDAGraphsC10Utils.h has utils used by both c10 and aten.
13 // This file adds utils used by aten only.
14 
15 namespace at::cuda {
16 
// Re-export the c10::cuda capture types under at::cuda so the helpers below
// (and their callers) can name them without the c10::cuda:: qualifier.
using CaptureId_t = c10::cuda::CaptureId_t;
using CaptureStatus = c10::cuda::CaptureStatus;
19 
20 // Use this version where you don't want to create a CUDA context if none exists.
currentStreamCaptureStatus()21 inline CaptureStatus currentStreamCaptureStatus() {
22   // don't create a context if we don't have to
23   if (c10::cuda::hasPrimaryContext(c10::cuda::current_device())) {
24     return c10::cuda::currentStreamCaptureStatusMayInitCtx();
25   } else {
26     return CaptureStatus::None;
27   }
28 }
29 
assertNotCapturing(const std::string & attempt)30 inline void assertNotCapturing(const std::string& attempt) {
31   auto status = currentStreamCaptureStatus();
32   TORCH_CHECK(status == CaptureStatus::None,
33               attempt,
34               " during CUDA graph capture. If you need this call to be captured, "
35               "please file an issue. "
36               "Current cudaStreamCaptureStatus: ",
37               status);
38 }
39 
errorIfCapturingCudnnBenchmark(const std::string & version_specific)40 inline void errorIfCapturingCudnnBenchmark(const std::string& version_specific) {
41   auto status = currentStreamCaptureStatus();
42   TORCH_CHECK(status == CaptureStatus::None,
43               "Current cudaStreamCaptureStatus: ",
44               status,
45               "\nCapturing ",
46               version_specific,
47               "is prohibited. Possible causes of this error:\n"
48               "1. No warmup iterations occurred before capture.\n"
49               "2. The convolutions you're trying to capture use dynamic shapes, "
50               "in which case capturing them is generally prohibited.");
51 }
52 
53 } // namespace at::cuda
54