xref: /aosp_15_r20/external/pytorch/c10/cuda/driver_api.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <cuda.h>
3 #define NVML_NO_UNVERSIONED_FUNC_DEFS
4 #include <nvml.h>
5 
// Evaluates a CUDA driver API expression and raises through AT_ERROR when the
// result is anything other than CUDA_SUCCESS.
//
// EXPR is evaluated exactly once. On failure, the human-readable message is
// looked up via the cuGetErrorString_ member of c10::cuda::DriverAPI::get().
// If that lookup does not succeed, or leaves the string pointer null, the
// generic "CUDA driver error: unknown error" message is raised instead.
//
// Hygiene notes:
//  - EXPR is wrapped in parentheses so the whole argument is captured as a
//    single expression.
//  - Locals avoid double-underscore names, which are reserved for the
//    implementation, and use a long c10_-prefixed spelling to keep clear of
//    identifiers that may appear inside EXPR.
//  - The message pointer starts out null so it is never read uninitialized.
#define C10_CUDA_DRIVER_CHECK(EXPR)                                   \
  do {                                                                \
    const CUresult c10_driver_check_err_ = (EXPR);                    \
    if (c10_driver_check_err_ != CUDA_SUCCESS) {                      \
      const char* c10_driver_check_msg_ = nullptr;                    \
      const CUresult c10_driver_check_lookup_ =                       \
          c10::cuda::DriverAPI::get()->cuGetErrorString_(             \
              c10_driver_check_err_, &c10_driver_check_msg_);         \
      if (c10_driver_check_lookup_ != CUDA_SUCCESS ||                 \
          c10_driver_check_msg_ == nullptr) {                         \
        AT_ERROR("CUDA driver error: unknown error");                 \
      } else {                                                        \
        AT_ERROR("CUDA driver error: ", c10_driver_check_msg_);       \
      }                                                               \
    }                                                                 \
  } while (0)
20 
// X-macro listing the libcuda driver entry points used through DriverAPI.
// APPLY is invoked once per symbol, in the order written here; callers supply
// a one-argument macro that turns each symbol name into whatever they need.
#define C10_LIBCUDA_DRIVER_API(APPLY)   \
  APPLY(cuDeviceGetAttribute)           \
  APPLY(cuMemAddressReserve)            \
  APPLY(cuMemRelease)                   \
  APPLY(cuMemMap)                       \
  APPLY(cuMemAddressFree)               \
  APPLY(cuMemSetAccess)                 \
  APPLY(cuMemUnmap)                     \
  APPLY(cuMemCreate)                    \
  APPLY(cuMemGetAllocationGranularity)  \
  APPLY(cuMemExportToShareableHandle)   \
  APPLY(cuMemImportFromShareableHandle) \
  APPLY(cuGetErrorString)
34 
// Additional libcuda entry points that are listed only when CUDA_VERSION is
// defined and is at least 12030. In every other configuration the macro
// expands to nothing, so users of the list need no version checks of their own.
#if !defined(CUDA_VERSION) || (CUDA_VERSION < 12030)
#define C10_LIBCUDA_DRIVER_API_12030(APPLY)
#else
#define C10_LIBCUDA_DRIVER_API_12030(APPLY) \
  APPLY(cuMulticastAddDevice)               \
  APPLY(cuMulticastBindMem)                 \
  APPLY(cuMulticastCreate)
#endif
43 
// X-macro listing the NVML entry points used through DriverAPI.
// APPLY is invoked once per symbol, in the order written here.
#define C10_NVML_DRIVER_API(APPLY)           \
  APPLY(nvmlInit_v2)                         \
  APPLY(nvmlDeviceGetHandleByPciBusId_v2)    \
  APPLY(nvmlDeviceGetNvLinkRemoteDeviceType) \
  APPLY(nvmlDeviceGetNvLinkRemotePciInfo_v2) \
  APPLY(nvmlDeviceGetComputeRunningProcesses)
50 
51 namespace c10::cuda {
52 
53 struct DriverAPI {
54 #define CREATE_MEMBER(name) decltype(&name) name##_;
55   C10_LIBCUDA_DRIVER_API(CREATE_MEMBER)
56   C10_LIBCUDA_DRIVER_API_12030(CREATE_MEMBER)
57   C10_NVML_DRIVER_API(CREATE_MEMBER)
58 #undef CREATE_MEMBER
59   static DriverAPI* get();
60   static void* get_nvml_handle();
61 };
62 
63 } // namespace c10::cuda
64