xref: /aosp_15_r20/external/pytorch/c10/cuda/CUDAException.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDADeviceAssertionHost.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAMacros.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAMiscFunctions.h>
6*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h>
7*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
8*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
9*da0073e9SAndroid Build Coastguard Worker #include <cuda.h>
10*da0073e9SAndroid Build Coastguard Worker 
11*da0073e9SAndroid Build Coastguard Worker // Note [CHECK macro]
12*da0073e9SAndroid Build Coastguard Worker // ~~~~~~~~~~~~~~~~~~
13*da0073e9SAndroid Build Coastguard Worker // This is a macro so that AT_ERROR can get accurate __LINE__
14*da0073e9SAndroid Build Coastguard Worker // and __FILE__ information.  We could split this into a short
15*da0073e9SAndroid Build Coastguard Worker // macro and a function implementation if we pass along __LINE__
16*da0073e9SAndroid Build Coastguard Worker // and __FILE__, but no one has found this worth doing.
17*da0073e9SAndroid Build Coastguard Worker 
18*da0073e9SAndroid Build Coastguard Worker // Used to denote errors from CUDA framework.
19*da0073e9SAndroid Build Coastguard Worker // This needs to be declared here instead util/Exception.h for proper conversion
20*da0073e9SAndroid Build Coastguard Worker // during hipify.
21*da0073e9SAndroid Build Coastguard Worker namespace c10 {
22*da0073e9SAndroid Build Coastguard Worker class C10_CUDA_API CUDAError : public c10::Error {
23*da0073e9SAndroid Build Coastguard Worker   using Error::Error;
24*da0073e9SAndroid Build Coastguard Worker };
25*da0073e9SAndroid Build Coastguard Worker } // namespace c10
26*da0073e9SAndroid Build Coastguard Worker 
// Evaluates EXPR (an expression yielding a cudaError_t) exactly once and
// forwards the result, together with the call site's file, function and line,
// to c10::cuda::c10_cuda_check_implementation with include_device_assertions
// set to true. See Note [CHECK macro] for why this is a macro.
#define C10_CUDA_CHECK(EXPR)                                         \
  do {                                                               \
    const cudaError_t __err = EXPR;                                  \
    c10::cuda::c10_cuda_check_implementation(                        \
        static_cast<int32_t>(__err),                                 \
        __FILE__,                                                    \
        __func__, /* Line number data type not well-defined between  \
                      compilers, so we cast explicitly to the `int`  \
                      that c10_cuda_check_implementation declares */ \
        static_cast<int>(__LINE__),                                  \
        true);                                                       \
  } while (0)
38*da0073e9SAndroid Build Coastguard Worker 
// Evaluates EXPR once. If the result is not cudaSuccess, calls
// cudaGetLastError() (discarding what it returns) and then emits a
// TORCH_WARN carrying the error string of EXPR's result. Never raises by
// itself; execution continues after the warning.
#define C10_CUDA_CHECK_WARN(EXPR)                                  \
  do {                                                             \
    const cudaError_t __err = EXPR;                                \
    if (C10_UNLIKELY(cudaSuccess != __err)) {                      \
      /* Read-and-clear the runtime's last error before warning */ \
      cudaError_t cleared_error C10_UNUSED = cudaGetLastError();   \
      (void)cleared_error;                                         \
      TORCH_WARN("CUDA warning: ", cudaGetErrorString(__err));     \
    }                                                              \
  } while (0)
48*da0073e9SAndroid Build Coastguard Worker 
// Indicates that a CUDA error is handled in a non-standard way.
// Expands to EXPR itself (no checking is added); the parentheses keep the
// argument grouped as one operand when the macro is used inside a larger
// expression.
#define C10_CUDA_ERROR_HANDLED(EXPR) (EXPR)
51*da0073e9SAndroid Build Coastguard Worker 
// Intentionally ignore a CUDA error.
// Evaluates EXPR once; when the result is not cudaSuccess, cudaGetLastError()
// is called and its return value is thrown away, so nothing is reported.
#define C10_CUDA_IGNORE_ERROR(EXPR)                                \
  do {                                                             \
    const cudaError_t __err = EXPR;                                \
    if (C10_UNLIKELY(cudaSuccess != __err)) {                      \
      cudaError_t discarded_error C10_UNUSED = cudaGetLastError(); \
      (void)discarded_error;                                       \
    }                                                              \
  } while (0)
61*da0073e9SAndroid Build Coastguard Worker 
// Clear the last CUDA error.
// Unconditionally calls cudaGetLastError() and drops the value it returns.
#define C10_CUDA_CLEAR_ERROR()                                   \
  do {                                                           \
    cudaError_t discarded_error C10_UNUSED = cudaGetLastError(); \
    (void)discarded_error;                                       \
  } while (0)
68*da0073e9SAndroid Build Coastguard Worker 
// Place this directly after every kernel launch: it runs C10_CUDA_CHECK on
// cudaGetLastError() so that a launch that did not happen correctly gets an
// early diagnostic close to its source.
#define C10_CUDA_KERNEL_LAUNCH_CHECK() \
  C10_CUDA_CHECK(cudaGetLastError())
73*da0073e9SAndroid Build Coastguard Worker 
/// Launches `kernel` as kernel<<<blocks, threads, shared_mem, stream>>>,
/// passing the caller's arguments followed by two extra trailing arguments
/// that carry the information needed to handle device-side assertion
/// failures:
///   1. the value of get_uvm_assertions_ptr_for_current_device(), and
///   2. the value returned by registering this launch (file, function, line,
///      stringized kernel name, stream id) with the CUDAKernelLaunchRegistry
///      singleton.
/// Afterwards checks that the launch was successful with
/// C10_CUDA_KERNEL_LAUNCH_CHECK().
///
/// Notes for callers, visible from the expansion:
///   - `stream` is expanded twice and must offer `.id()` in addition to being
///     accepted as the stream launch parameter.
///   - `__VA_ARGS__` is followed by a comma, so at least one kernel argument
///     has to be supplied.
#define TORCH_DSA_KERNEL_LAUNCH(                                      \
    kernel, blocks, threads, shared_mem, stream, ...)                 \
  do {                                                                \
    auto& dsa_registry =                                              \
        c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();     \
    kernel<<<blocks, threads, shared_mem, stream>>>(                  \
        __VA_ARGS__,                                                  \
        dsa_registry.get_uvm_assertions_ptr_for_current_device(),     \
        dsa_registry.insert(                                          \
            __FILE__, __FUNCTION__, __LINE__, #kernel, stream.id())); \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                   \
  } while (0)
88*da0073e9SAndroid Build Coastguard Worker 
89*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda {
90*da0073e9SAndroid Build Coastguard Worker 
91*da0073e9SAndroid Build Coastguard Worker /// In the event of a CUDA failure, formats a nice error message about that
92*da0073e9SAndroid Build Coastguard Worker /// failure and also checks for device-side assertion failures
93*da0073e9SAndroid Build Coastguard Worker C10_CUDA_API void c10_cuda_check_implementation(
94*da0073e9SAndroid Build Coastguard Worker     const int32_t err,
95*da0073e9SAndroid Build Coastguard Worker     const char* filename,
96*da0073e9SAndroid Build Coastguard Worker     const char* function_name,
97*da0073e9SAndroid Build Coastguard Worker     const int line_number,
98*da0073e9SAndroid Build Coastguard Worker     const bool include_device_assertions);
99*da0073e9SAndroid Build Coastguard Worker 
100*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda
101