xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/detail/LazyNVRTC.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/cuda/detail/LazyNVRTC.h>
2 
3 #include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
4 #include <ATen/DynamicLibrary.h>
5 #include <stdexcept>
6 
7 namespace at {
8 namespace cuda {
9 namespace detail {
10 namespace _stubs {
11 
getCUDALibrary()12 at::DynamicLibrary& getCUDALibrary() {
13 #if defined(_WIN32)
14   static at::DynamicLibrary lib("nvcuda.dll");
15 #else
16   static at::DynamicLibrary lib("libcuda.so.1");
17 #endif
18   return lib;
19 }
20 
getLibVersion()21 static std::string getLibVersion() {
22   // [NVRTC versioning]
23   // Quote of https://docs.nvidia.com/cuda/nvrtc/index.html Section 8.1. NVRTC library versioning
24   //
25   // In the following, MAJOR and MINOR denote the major and minor versions of the CUDA Toolkit.
26   // e.g. for CUDA 11.2, MAJOR is "11" and MINOR is "2".
27   //
28   // Linux:
29   //   - In CUDA toolkits prior to CUDA 11.3, the soname was set to "MAJOR.MINOR".
30   //   - In CUDA 11.3 and later 11.x toolkits, the soname field is set to "11.2".
31   //   - In CUDA toolkits with major version > 11 (e.g. CUDA 12.x), the soname field is set to "MAJOR".
32   //
33   // Windows:
34   //   - In CUDA toolkits prior to cuda 11.3, the DLL name was of the form "nvrtc64_XY_0.dll", where X = MAJOR, Y = MINOR.
35   //   - In CUDA 11.3 and later 11.x toolkits, the DLL name is "nvrtc64_112_0.dll".
36   //   - In CUDA toolkits with major version > 11 (e.g. CUDA 12.x), the DLL name is of the form "nvrtc64_X0_0.dll" where X = MAJOR.
37   //
38   // Consider a CUDA toolkit with major version > 11. The NVRTC library in this CUDA toolkit will have the same soname (Linux)
39   // or DLL name (Windows) as an NVRTC library in a previous minor version of the same CUDA toolkit. Similarly, the NVRTC
40   // library in CUDA 11.3 and later 11.x releases will have the same soname (Linux) or DLL name (Windows) as the NVRTC library in CUDA 11.2.
41   constexpr auto major = CUDA_VERSION / 1000;
42   constexpr auto minor = ( CUDA_VERSION / 10 ) % 10;
43 #if defined(_WIN32)
44   if (major < 11 || (major == 11 && minor < 3)) {
45     return std::to_string(major) + std::to_string(minor);
46   } else if (major == 11) {
47     return "112";
48   } else {
49     return std::to_string(major) + "0";
50   }
51 #else
52   if (major < 11 || (major == 11 && minor < 3)) {
53     return std::to_string(major) + "." + std::to_string(minor);
54   } else if (major == 11) {
55     return "11.2";
56   } else {
57     return std::to_string(major);
58   }
59 #endif
60 }
61 
getLibName()62 static std::string getLibName() {
63 #if defined(_WIN32)
64   return std::string("nvrtc64_") + getLibVersion() + "_0.dll";
65 #else
66   return std::string("libnvrtc.so.") + getLibVersion();
67 #endif
68 }
69 
// Secondary NVRTC library name. getNVRTCLibrary() passes it to
// at::DynamicLibrary as the second name. It is non-empty only on non-Windows
// builds that define NVRTC_SHORTHASH, where it has the form
// "libnvrtc-<NVRTC_SHORTHASH>.so.<version>". Otherwise an empty string is
// returned.
static std::string getAltLibName() {
  std::string alt_name;
#if !defined(_WIN32) && defined(NVRTC_SHORTHASH)
  alt_name = "libnvrtc-";
  alt_name += C10_STRINGIZE(NVRTC_SHORTHASH);
  alt_name += ".so.";
  alt_name += getLibVersion();
#endif
  return alt_name;
}
77 
getNVRTCLibrary()78 at::DynamicLibrary& getNVRTCLibrary() {
79   static std::string libname = getLibName();
80   static std::string alt_libname = getAltLibName();
81   static at::DynamicLibrary lib(libname.c_str(), alt_libname.empty() ? nullptr : alt_libname.c_str());
82   return lib;
83 }
84 
// Lazy-binding stub generators.
//
// _STUB_N(LIB, NAME, RETTYPE, ARG1..ARGN) defines `RETTYPE NAME(ARG1 a1, ...)`.
// When called, the generated function:
//   1. looks up the symbol named by its own __func__ in get<LIB>Library();
//   2. throws std::runtime_error("Can't get <NAME>") if the lookup yields null;
//   3. stores the resolved pointer into lazyNVRTC.NAME;
//   4. forwards its arguments to the resolved function.
//
// _STUB_BIND holds steps 1-3, which are shared by every arity. It declares a
// local `resolved` in the enclosing stub body. The library accessor name is
// token-pasted inside _STUB_N and handed over as LIBRARY_GETTER.
#define _STUB_BIND(LIBRARY_GETTER, NAME)                                 \
  auto resolved =                                                        \
      reinterpret_cast<decltype(&NAME)>(LIBRARY_GETTER().sym(__func__)); \
  if (resolved == nullptr) {                                             \
    throw std::runtime_error("Can't get " C10_STRINGIZE(NAME));          \
  }                                                                      \
  lazyNVRTC.NAME = resolved

#define _STUB_1(LIB, NAME, RETTYPE, ARG1) \
  RETTYPE NAME(ARG1 a1) {                 \
    _STUB_BIND(get##LIB##Library, NAME);  \
    return resolved(a1);                  \
  }

#define _STUB_2(LIB, NAME, RETTYPE, ARG1, ARG2) \
  RETTYPE NAME(ARG1 a1, ARG2 a2) {              \
    _STUB_BIND(get##LIB##Library, NAME);        \
    return resolved(a1, a2);                    \
  }

#define _STUB_3(LIB, NAME, RETTYPE, ARG1, ARG2, ARG3) \
  RETTYPE NAME(ARG1 a1, ARG2 a2, ARG3 a3) {           \
    _STUB_BIND(get##LIB##Library, NAME);              \
    return resolved(a1, a2, a3);                      \
  }

#define _STUB_4(LIB, NAME, RETTYPE, ARG1, ARG2, ARG3, ARG4) \
  RETTYPE NAME(ARG1 a1, ARG2 a2, ARG3 a3, ARG4 a4) {        \
    _STUB_BIND(get##LIB##Library, NAME);                    \
    return resolved(a1, a2, a3, a4);                        \
  }

// Driver API shorthands: return `CUresult CUDAAPI`, resolved from getCUDALibrary().
#define CUDA_STUB1(NAME, A1) _STUB_1(CUDA, NAME, CUresult CUDAAPI, A1)
#define CUDA_STUB2(NAME, A1, A2) _STUB_2(CUDA, NAME, CUresult CUDAAPI, A1, A2)
#define CUDA_STUB3(NAME, A1, A2, A3) _STUB_3(CUDA, NAME, CUresult CUDAAPI, A1, A2, A3)
#define CUDA_STUB4(NAME, A1, A2, A3, A4) _STUB_4(CUDA, NAME, CUresult CUDAAPI, A1, A2, A3, A4)

// NVRTC shorthands: return nvrtcResult, resolved from getNVRTCLibrary().
#define NVRTC_STUB1(NAME, A1) _STUB_1(NVRTC, NAME, nvrtcResult, A1)
#define NVRTC_STUB2(NAME, A1, A2) _STUB_2(NVRTC, NAME, nvrtcResult, A1, A2)
#define NVRTC_STUB3(NAME, A1, A2, A3) _STUB_3(NVRTC, NAME, nvrtcResult, A1, A2, A3)
129 
130 NVRTC_STUB2(nvrtcVersion, int*, int*);
131 NVRTC_STUB2(nvrtcAddNameExpression, nvrtcProgram, const char * const);
132 
nvrtcCreateProgram(nvrtcProgram * prog,const char * src,const char * name,int numHeaders,const char * const * headers,const char * const * includeNames)133 nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog,
134                                const char *src,
135                                const char *name,
136                                int numHeaders,
137                                const char * const *headers,
138                                const char * const *includeNames) {
139   auto fn = reinterpret_cast<decltype(&nvrtcCreateProgram)>(getNVRTCLibrary().sym(__func__));
140   if (!fn)
141     throw std::runtime_error("Can't get nvrtcCreateProgram");
142   lazyNVRTC.nvrtcCreateProgram = fn;
143   return fn(prog, src, name, numHeaders, headers, includeNames);
144 }
145 
146 NVRTC_STUB1(nvrtcDestroyProgram, nvrtcProgram *);
147 NVRTC_STUB2(nvrtcGetPTXSize, nvrtcProgram, size_t *);
148 NVRTC_STUB2(nvrtcGetPTX, nvrtcProgram, char *);
149 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
150 NVRTC_STUB2(nvrtcGetCUBINSize, nvrtcProgram, size_t *);
151 NVRTC_STUB2(nvrtcGetCUBIN, nvrtcProgram, char *);
152 #endif
153 NVRTC_STUB3(nvrtcCompileProgram, nvrtcProgram, int, const char * const *);
154 _STUB_1(NVRTC, nvrtcGetErrorString, const char *, nvrtcResult);
155 NVRTC_STUB2(nvrtcGetProgramLogSize,nvrtcProgram, size_t*);
156 NVRTC_STUB2(nvrtcGetProgramLog, nvrtcProgram, char *);
157 NVRTC_STUB3(nvrtcGetLoweredName, nvrtcProgram, const char *, const char **);
158 
159 CUDA_STUB2(cuModuleLoadData, CUmodule *, const void *);
160 CUDA_STUB3(cuModuleGetFunction, CUfunction *, CUmodule, const char *);
161 CUDA_STUB4(cuOccupancyMaxActiveBlocksPerMultiprocessor, int *, CUfunction, int, size_t);
162 CUDA_STUB2(cuGetErrorString, CUresult, const char **);
163 CUDA_STUB1(cuCtxGetCurrent, CUcontext *);
164 CUDA_STUB1(cuCtxSetCurrent, CUcontext);
165 CUDA_STUB1(cuModuleUnload, CUmodule);
166 CUDA_STUB3(cuDevicePrimaryCtxGetState, CUdevice, unsigned int *, int *);
167 CUDA_STUB2(cuDevicePrimaryCtxRetain, CUcontext *, CUdevice);
168 CUDA_STUB4(cuLinkCreate, unsigned int, CUjit_option *, void **, CUlinkState *);
169 CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *);
170 CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int);
171 CUDA_STUB3(cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction);
172 
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
// Hand-written lazy stub for cuTensorMapEncodeTiled. It is compiled only when
// building against CUDA >= 12.0 and takes twelve parameters. Resolves the
// driver entry point, records it in lazyNVRTC, then forwards every argument.
CUresult CUDAAPI cuTensorMapEncodeTiled(
    CUtensorMap* tensorMap,
    CUtensorMapDataType tensorDataType,
    cuuint32_t tensorRank,
    void* globalAddress,
    const cuuint64_t* globalDim,
    const cuuint64_t* globalStrides,
    const cuuint32_t* boxDim,
    const cuuint32_t* elementStrides,
    CUtensorMapInterleave interleave,
    CUtensorMapSwizzle swizzle,
    CUtensorMapL2promotion l2Promotion,
    CUtensorMapFloatOOBfill oobFill) {
  using EncodeTiledFn = decltype(&cuTensorMapEncodeTiled);
  const auto real_fn =
      reinterpret_cast<EncodeTiledFn>(getCUDALibrary().sym(__func__));
  if (real_fn == nullptr) {
    throw std::runtime_error("Can't get cuTensorMapEncodeTiled");
  }
  lazyNVRTC.cuTensorMapEncodeTiled = real_fn;
  return real_fn(tensorMap, tensorDataType, tensorRank, globalAddress,
                 globalDim, globalStrides, boxDim, elementStrides,
                 interleave, swizzle, l2Promotion, oobFill);
}

#endif
209 
210 // Irregularly shaped functions
cuLaunchKernel(CUfunction f,unsigned int gridDimX,unsigned int gridDimY,unsigned int gridDimZ,unsigned int blockDimX,unsigned int blockDimY,unsigned int blockDimZ,unsigned int sharedMemBytes,CUstream hStream,void ** kernelParams,void ** extra)211 CUresult CUDAAPI cuLaunchKernel(CUfunction f,
212                                 unsigned int gridDimX,
213                                 unsigned int gridDimY,
214                                 unsigned int gridDimZ,
215                                 unsigned int blockDimX,
216                                 unsigned int blockDimY,
217                                 unsigned int blockDimZ,
218                                 unsigned int sharedMemBytes,
219                                 CUstream hStream,
220                                 void **kernelParams,
221                                 void **extra) {
222   auto fn = reinterpret_cast<decltype(&cuLaunchKernel)>(getCUDALibrary().sym(__func__));
223   if (!fn)
224     throw std::runtime_error("Can't get cuLaunchKernel");
225   lazyNVRTC.cuLaunchKernel = fn;
226   return fn(f,
227             gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ,
228             sharedMemBytes, hStream, kernelParams, extra);
229 }
230 
231 // Irregularly shaped functions
cuLaunchCooperativeKernel(CUfunction f,unsigned int gridDimX,unsigned int gridDimY,unsigned int gridDimZ,unsigned int blockDimX,unsigned int blockDimY,unsigned int blockDimZ,unsigned int sharedMemBytes,CUstream hStream,void ** kernelParams)232 CUresult CUDAAPI cuLaunchCooperativeKernel(
233     CUfunction f,
234     unsigned int gridDimX,
235     unsigned int gridDimY,
236     unsigned int gridDimZ,
237     unsigned int blockDimX,
238     unsigned int blockDimY,
239     unsigned int blockDimZ,
240     unsigned int sharedMemBytes,
241     CUstream hStream,
242     void** kernelParams) {
243   auto fn = reinterpret_cast<decltype(&cuLaunchCooperativeKernel)>(
244       getCUDALibrary().sym(__func__));
245   if (!fn)
246     throw std::runtime_error("Can't get cuLaunchCooperativeKernel");
247   lazyNVRTC.cuLaunchCooperativeKernel = fn;
248   return fn(
249       f,
250       gridDimX,
251       gridDimY,
252       gridDimZ,
253       blockDimX,
254       blockDimY,
255       blockDimZ,
256       sharedMemBytes,
257       hStream,
258       kernelParams);
259 }
260 
cuModuleLoadDataEx(CUmodule * module,const void * image,unsigned int numOptions,CUjit_option * options,void ** optionValues)261 CUresult CUDAAPI cuModuleLoadDataEx(CUmodule *module,
262                                     const void *image,
263                                     unsigned int numOptions,
264                                     CUjit_option *options,
265                                     void **optionValues) {
266   auto fn = reinterpret_cast<decltype(&cuModuleLoadDataEx)>(getCUDALibrary().sym(__func__));
267   if (!fn)
268     throw std::runtime_error("Can't get cuModuleLoadDataEx");
269   lazyNVRTC.cuModuleLoadDataEx = fn;
270   return fn(module, image, numOptions, options, optionValues);
271 }
272 
273 CUresult CUDAAPI
cuLinkAddData(CUlinkState state,CUjitInputType type,void * data,size_t size,const char * name,unsigned int numOptions,CUjit_option * options,void ** optionValues)274 cuLinkAddData(CUlinkState state,
275               CUjitInputType type,
276               void *data,
277               size_t size,
278               const char *name,
279               unsigned int numOptions,
280               CUjit_option *options,
281               void **optionValues) {
282   auto fn = reinterpret_cast<decltype(&cuLinkAddData)>(getCUDALibrary().sym(__func__));
283   if (!fn)
284     throw std::runtime_error("Can't get cuLinkAddData");
285   lazyNVRTC.cuLinkAddData = fn;
286   return fn(state, type, data, size, name, numOptions, options, optionValues);
287 }
288 
289 } // namespace _stubs
290 
291 NVRTC lazyNVRTC = {
292 #define _REFERENCE_MEMBER(name) _stubs::name,
293   AT_FORALL_NVRTC(_REFERENCE_MEMBER)
294 #undef _REFERENCE_MEMBER
295 };
296 } // namespace detail
297 } // namespace cuda
298 } // namespace at
299