#pragma once

#include <c10/cuda/CUDADeviceAssertionHost.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAMiscFunctions.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <cuda.h>

// Note [CHECK macro]
// ~~~~~~~~~~~~~~~~~~
// This is a macro so that AT_ERROR can get accurate __LINE__
// and __FILE__ information. We could split this into a short
// macro and a function implementation if we pass along __LINE__
// and __FILE__, but no one has found this worth doing.

// Used to denote errors from the CUDA framework.
// This needs to be declared here instead of in util/Exception.h for proper
// conversion during hipify.
21*da0073e9SAndroid Build Coastguard Worker namespace c10 { 22*da0073e9SAndroid Build Coastguard Worker class C10_CUDA_API CUDAError : public c10::Error { 23*da0073e9SAndroid Build Coastguard Worker using Error::Error; 24*da0073e9SAndroid Build Coastguard Worker }; 25*da0073e9SAndroid Build Coastguard Worker } // namespace c10 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker #define C10_CUDA_CHECK(EXPR) \ 28*da0073e9SAndroid Build Coastguard Worker do { \ 29*da0073e9SAndroid Build Coastguard Worker const cudaError_t __err = EXPR; \ 30*da0073e9SAndroid Build Coastguard Worker c10::cuda::c10_cuda_check_implementation( \ 31*da0073e9SAndroid Build Coastguard Worker static_cast<int32_t>(__err), \ 32*da0073e9SAndroid Build Coastguard Worker __FILE__, \ 33*da0073e9SAndroid Build Coastguard Worker __func__, /* Line number data type not well-defined between \ 34*da0073e9SAndroid Build Coastguard Worker compilers, so we perform an explicit cast */ \ 35*da0073e9SAndroid Build Coastguard Worker static_cast<uint32_t>(__LINE__), \ 36*da0073e9SAndroid Build Coastguard Worker true); \ 37*da0073e9SAndroid Build Coastguard Worker } while (0) 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker #define C10_CUDA_CHECK_WARN(EXPR) \ 40*da0073e9SAndroid Build Coastguard Worker do { \ 41*da0073e9SAndroid Build Coastguard Worker const cudaError_t __err = EXPR; \ 42*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(__err != cudaSuccess)) { \ 43*da0073e9SAndroid Build Coastguard Worker auto error_unused C10_UNUSED = cudaGetLastError(); \ 44*da0073e9SAndroid Build Coastguard Worker (void)error_unused; \ 45*da0073e9SAndroid Build Coastguard Worker TORCH_WARN("CUDA warning: ", cudaGetErrorString(__err)); \ 46*da0073e9SAndroid Build Coastguard Worker } \ 47*da0073e9SAndroid Build Coastguard Worker } while (0) 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker // Indicates that a 
CUDA error is handled in a non-standard way 50*da0073e9SAndroid Build Coastguard Worker #define C10_CUDA_ERROR_HANDLED(EXPR) EXPR 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker // Intentionally ignore a CUDA error 53*da0073e9SAndroid Build Coastguard Worker #define C10_CUDA_IGNORE_ERROR(EXPR) \ 54*da0073e9SAndroid Build Coastguard Worker do { \ 55*da0073e9SAndroid Build Coastguard Worker const cudaError_t __err = EXPR; \ 56*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(__err != cudaSuccess)) { \ 57*da0073e9SAndroid Build Coastguard Worker cudaError_t error_unused C10_UNUSED = cudaGetLastError(); \ 58*da0073e9SAndroid Build Coastguard Worker (void)error_unused; \ 59*da0073e9SAndroid Build Coastguard Worker } \ 60*da0073e9SAndroid Build Coastguard Worker } while (0) 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker // Clear the last CUDA error 63*da0073e9SAndroid Build Coastguard Worker #define C10_CUDA_CLEAR_ERROR() \ 64*da0073e9SAndroid Build Coastguard Worker do { \ 65*da0073e9SAndroid Build Coastguard Worker cudaError_t error_unused C10_UNUSED = cudaGetLastError(); \ 66*da0073e9SAndroid Build Coastguard Worker (void)error_unused; \ 67*da0073e9SAndroid Build Coastguard Worker } while (0) 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker // This should be used directly after every kernel launch to ensure 70*da0073e9SAndroid Build Coastguard Worker // the launch happened correctly and provide an early, close-to-source 71*da0073e9SAndroid Build Coastguard Worker // diagnostic if it didn't. 
72*da0073e9SAndroid Build Coastguard Worker #define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError()) 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker /// Launches a CUDA kernel appending to it all the information need to handle 75*da0073e9SAndroid Build Coastguard Worker /// device-side assertion failures. Checks that the launch was successful. 76*da0073e9SAndroid Build Coastguard Worker #define TORCH_DSA_KERNEL_LAUNCH( \ 77*da0073e9SAndroid Build Coastguard Worker kernel, blocks, threads, shared_mem, stream, ...) \ 78*da0073e9SAndroid Build Coastguard Worker do { \ 79*da0073e9SAndroid Build Coastguard Worker auto& launch_registry = \ 80*da0073e9SAndroid Build Coastguard Worker c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref(); \ 81*da0073e9SAndroid Build Coastguard Worker kernel<<<blocks, threads, shared_mem, stream>>>( \ 82*da0073e9SAndroid Build Coastguard Worker __VA_ARGS__, \ 83*da0073e9SAndroid Build Coastguard Worker launch_registry.get_uvm_assertions_ptr_for_current_device(), \ 84*da0073e9SAndroid Build Coastguard Worker launch_registry.insert( \ 85*da0073e9SAndroid Build Coastguard Worker __FILE__, __FUNCTION__, __LINE__, #kernel, stream.id())); \ 86*da0073e9SAndroid Build Coastguard Worker C10_CUDA_KERNEL_LAUNCH_CHECK(); \ 87*da0073e9SAndroid Build Coastguard Worker } while (0) 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda { 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker /// In the event of a CUDA failure, formats a nice error message about that 92*da0073e9SAndroid Build Coastguard Worker /// failure and also checks for device-side assertion failures 93*da0073e9SAndroid Build Coastguard Worker C10_CUDA_API void c10_cuda_check_implementation( 94*da0073e9SAndroid Build Coastguard Worker const int32_t err, 95*da0073e9SAndroid Build Coastguard Worker const char* filename, 96*da0073e9SAndroid 
Build Coastguard Worker const char* function_name, 97*da0073e9SAndroid Build Coastguard Worker const int line_number, 98*da0073e9SAndroid Build Coastguard Worker const bool include_device_assertions); 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda 101