1 #pragma once 2 #include <cuda.h> 3 #define NVML_NO_UNVERSIONED_FUNC_DEFS 4 #include <nvml.h> 5 6 #define C10_CUDA_DRIVER_CHECK(EXPR) \ 7 do { \ 8 CUresult __err = EXPR; \ 9 if (__err != CUDA_SUCCESS) { \ 10 const char* err_str; \ 11 CUresult get_error_str_err C10_UNUSED = \ 12 c10::cuda::DriverAPI::get()->cuGetErrorString_(__err, &err_str); \ 13 if (get_error_str_err != CUDA_SUCCESS) { \ 14 AT_ERROR("CUDA driver error: unknown error"); \ 15 } else { \ 16 AT_ERROR("CUDA driver error: ", err_str); \ 17 } \ 18 } \ 19 } while (0) 20 21 #define C10_LIBCUDA_DRIVER_API(_) \ 22 _(cuDeviceGetAttribute) \ 23 _(cuMemAddressReserve) \ 24 _(cuMemRelease) \ 25 _(cuMemMap) \ 26 _(cuMemAddressFree) \ 27 _(cuMemSetAccess) \ 28 _(cuMemUnmap) \ 29 _(cuMemCreate) \ 30 _(cuMemGetAllocationGranularity) \ 31 _(cuMemExportToShareableHandle) \ 32 _(cuMemImportFromShareableHandle) \ 33 _(cuGetErrorString) 34 35 #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) 36 #define C10_LIBCUDA_DRIVER_API_12030(_) \ 37 _(cuMulticastAddDevice) \ 38 _(cuMulticastBindMem) \ 39 _(cuMulticastCreate) 40 #else 41 #define C10_LIBCUDA_DRIVER_API_12030(_) 42 #endif 43 44 #define C10_NVML_DRIVER_API(_) \ 45 _(nvmlInit_v2) \ 46 _(nvmlDeviceGetHandleByPciBusId_v2) \ 47 _(nvmlDeviceGetNvLinkRemoteDeviceType) \ 48 _(nvmlDeviceGetNvLinkRemotePciInfo_v2) \ 49 _(nvmlDeviceGetComputeRunningProcesses) 50 51 namespace c10::cuda { 52 53 struct DriverAPI { 54 #define CREATE_MEMBER(name) decltype(&name) name##_; 55 C10_LIBCUDA_DRIVER_API(CREATE_MEMBER) 56 C10_LIBCUDA_DRIVER_API_12030(CREATE_MEMBER) 57 C10_NVML_DRIVER_API(CREATE_MEMBER) 58 #undef CREATE_MEMBER 59 static DriverAPI* get(); 60 static void* get_nvml_handle(); 61 }; 62 63 } // namespace c10::cuda 64