1 #include <ATen/cuda/detail/LazyNVRTC.h>
2
3 #include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
4 #include <ATen/DynamicLibrary.h>
5 #include <stdexcept>
6
7 namespace at {
8 namespace cuda {
9 namespace detail {
10 namespace _stubs {
11
getCUDALibrary()12 at::DynamicLibrary& getCUDALibrary() {
13 #if defined(_WIN32)
14 static at::DynamicLibrary lib("nvcuda.dll");
15 #else
16 static at::DynamicLibrary lib("libcuda.so.1");
17 #endif
18 return lib;
19 }
20
getLibVersion()21 static std::string getLibVersion() {
22 // [NVRTC versioning]
23 // Quote of https://docs.nvidia.com/cuda/nvrtc/index.html Section 8.1. NVRTC library versioning
24 //
25 // In the following, MAJOR and MINOR denote the major and minor versions of the CUDA Toolkit.
26 // e.g. for CUDA 11.2, MAJOR is "11" and MINOR is "2".
27 //
28 // Linux:
29 // - In CUDA toolkits prior to CUDA 11.3, the soname was set to "MAJOR.MINOR".
30 // - In CUDA 11.3 and later 11.x toolkits, the soname field is set to "11.2".
31 // - In CUDA toolkits with major version > 11 (e.g. CUDA 12.x), the soname field is set to "MAJOR".
32 //
33 // Windows:
34 // - In CUDA toolkits prior to cuda 11.3, the DLL name was of the form "nvrtc64_XY_0.dll", where X = MAJOR, Y = MINOR.
35 // - In CUDA 11.3 and later 11.x toolkits, the DLL name is "nvrtc64_112_0.dll".
36 // - In CUDA toolkits with major version > 11 (e.g. CUDA 12.x), the DLL name is of the form "nvrtc64_X0_0.dll" where X = MAJOR.
37 //
38 // Consider a CUDA toolkit with major version > 11. The NVRTC library in this CUDA toolkit will have the same soname (Linux)
39 // or DLL name (Windows) as an NVRTC library in a previous minor version of the same CUDA toolkit. Similarly, the NVRTC
40 // library in CUDA 11.3 and later 11.x releases will have the same soname (Linux) or DLL name (Windows) as the NVRTC library in CUDA 11.2.
41 constexpr auto major = CUDA_VERSION / 1000;
42 constexpr auto minor = ( CUDA_VERSION / 10 ) % 10;
43 #if defined(_WIN32)
44 if (major < 11 || (major == 11 && minor < 3)) {
45 return std::to_string(major) + std::to_string(minor);
46 } else if (major == 11) {
47 return "112";
48 } else {
49 return std::to_string(major) + "0";
50 }
51 #else
52 if (major < 11 || (major == 11 && minor < 3)) {
53 return std::to_string(major) + "." + std::to_string(minor);
54 } else if (major == 11) {
55 return "11.2";
56 } else {
57 return std::to_string(major);
58 }
59 #endif
60 }
61
getLibName()62 static std::string getLibName() {
63 #if defined(_WIN32)
64 return std::string("nvrtc64_") + getLibVersion() + "_0.dll";
65 #else
66 return std::string("libnvrtc.so.") + getLibVersion();
67 #endif
68 }
69
// Builds the alternative NVRTC library file name,
// "libnvrtc-<NVRTC_SHORTHASH>.so.<version>". This name only exists on
// non-Windows builds that define NVRTC_SHORTHASH; in every other
// configuration an empty string is returned.
static std::string getAltLibName() {
#if defined(_WIN32) || !defined(NVRTC_SHORTHASH)
  return std::string();
#else
  std::string alt_name("libnvrtc-");
  alt_name += C10_STRINGIZE(NVRTC_SHORTHASH);
  alt_name += ".so.";
  alt_name += getLibVersion();
  return alt_name;
#endif
}
77
getNVRTCLibrary()78 at::DynamicLibrary& getNVRTCLibrary() {
79 static std::string libname = getLibName();
80 static std::string alt_libname = getAltLibName();
81 static at::DynamicLibrary lib(libname.c_str(), alt_libname.empty() ? nullptr : alt_libname.c_str());
82 return lib;
83 }
84
// Defines a one-argument lazy stub `RETTYPE NAME(ARG1)`.
// The generated function looks up the symbol named after itself (__func__)
// in the library returned by get<LIB>Library(), throws std::runtime_error if
// the lookup yields null, stores the resolved pointer in lazyNVRTC.NAME and
// then forwards the call to it.
#define _STUB_1(LIB, NAME, RETTYPE, ARG1)                           \
  RETTYPE NAME(ARG1 a1) {                                           \
    auto resolved = reinterpret_cast<decltype(&NAME)>(              \
        get##LIB##Library().sym(__func__));                         \
    if (resolved == nullptr) {                                      \
      throw std::runtime_error("Can't get " C10_STRINGIZE(NAME));   \
    }                                                               \
    lazyNVRTC.NAME = resolved;                                      \
    return resolved(a1);                                            \
  }
93
// Defines a two-argument lazy stub `RETTYPE NAME(ARG1, ARG2)`.
// The generated function looks up the symbol named after itself (__func__)
// in the library returned by get<LIB>Library(), throws std::runtime_error if
// the lookup yields null, stores the resolved pointer in lazyNVRTC.NAME and
// then forwards the call to it.
#define _STUB_2(LIB, NAME, RETTYPE, ARG1, ARG2)                     \
  RETTYPE NAME(ARG1 a1, ARG2 a2) {                                  \
    auto resolved = reinterpret_cast<decltype(&NAME)>(              \
        get##LIB##Library().sym(__func__));                         \
    if (resolved == nullptr) {                                      \
      throw std::runtime_error("Can't get " C10_STRINGIZE(NAME));   \
    }                                                               \
    lazyNVRTC.NAME = resolved;                                      \
    return resolved(a1, a2);                                        \
  }
102
// Defines a three-argument lazy stub `RETTYPE NAME(ARG1, ARG2, ARG3)`.
// The generated function looks up the symbol named after itself (__func__)
// in the library returned by get<LIB>Library(), throws std::runtime_error if
// the lookup yields null, stores the resolved pointer in lazyNVRTC.NAME and
// then forwards the call to it.
#define _STUB_3(LIB, NAME, RETTYPE, ARG1, ARG2, ARG3)               \
  RETTYPE NAME(ARG1 a1, ARG2 a2, ARG3 a3) {                         \
    auto resolved = reinterpret_cast<decltype(&NAME)>(              \
        get##LIB##Library().sym(__func__));                         \
    if (resolved == nullptr) {                                      \
      throw std::runtime_error("Can't get " C10_STRINGIZE(NAME));   \
    }                                                               \
    lazyNVRTC.NAME = resolved;                                      \
    return resolved(a1, a2, a3);                                    \
  }
111
// Defines a four-argument lazy stub `RETTYPE NAME(ARG1, ARG2, ARG3, ARG4)`.
// The generated function looks up the symbol named after itself (__func__)
// in the library returned by get<LIB>Library(), throws std::runtime_error if
// the lookup yields null, stores the resolved pointer in lazyNVRTC.NAME and
// then forwards the call to it.
#define _STUB_4(LIB, NAME, RETTYPE, ARG1, ARG2, ARG3, ARG4)         \
  RETTYPE NAME(ARG1 a1, ARG2 a2, ARG3 a3, ARG4 a4) {                \
    auto resolved = reinterpret_cast<decltype(&NAME)>(              \
        get##LIB##Library().sym(__func__));                         \
    if (resolved == nullptr) {                                      \
      throw std::runtime_error("Can't get " C10_STRINGIZE(NAME));   \
    }                                                               \
    lazyNVRTC.NAME = resolved;                                      \
    return resolved(a1, a2, a3, a4);                                \
  }
120
// Shorthands for CUDA driver API stubs with one to four arguments. Each one
// forwards to the matching _STUB_N with LIB = CUDA (so symbols are resolved
// through getCUDALibrary()) and a `CUresult CUDAAPI` return type.
#define CUDA_STUB1(NAME, T1) \
  _STUB_1(CUDA, NAME, CUresult CUDAAPI, T1)
#define CUDA_STUB2(NAME, T1, T2) \
  _STUB_2(CUDA, NAME, CUresult CUDAAPI, T1, T2)
#define CUDA_STUB3(NAME, T1, T2, T3) \
  _STUB_3(CUDA, NAME, CUresult CUDAAPI, T1, T2, T3)
#define CUDA_STUB4(NAME, T1, T2, T3, T4) \
  _STUB_4(CUDA, NAME, CUresult CUDAAPI, T1, T2, T3, T4)
125
// Shorthands for NVRTC stubs with one to three arguments. Each one forwards
// to the matching _STUB_N with LIB = NVRTC (so symbols are resolved through
// getNVRTCLibrary()) and an `nvrtcResult` return type.
#define NVRTC_STUB1(NAME, T1) \
  _STUB_1(NVRTC, NAME, nvrtcResult, T1)
#define NVRTC_STUB2(NAME, T1, T2) \
  _STUB_2(NVRTC, NAME, nvrtcResult, T1, T2)
#define NVRTC_STUB3(NAME, T1, T2, T3) \
  _STUB_3(NVRTC, NAME, nvrtcResult, T1, T2, T3)
129
// Two-argument NVRTC stubs generated by NVRTC_STUB2: the first macro argument
// is the function name, the remaining ones are its parameter types.
NVRTC_STUB2(nvrtcVersion, int*, int*);
NVRTC_STUB2(nvrtcAddNameExpression, nvrtcProgram, const char * const);
132
nvrtcCreateProgram(nvrtcProgram * prog,const char * src,const char * name,int numHeaders,const char * const * headers,const char * const * includeNames)133 nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog,
134 const char *src,
135 const char *name,
136 int numHeaders,
137 const char * const *headers,
138 const char * const *includeNames) {
139 auto fn = reinterpret_cast<decltype(&nvrtcCreateProgram)>(getNVRTCLibrary().sym(__func__));
140 if (!fn)
141 throw std::runtime_error("Can't get nvrtcCreateProgram");
142 lazyNVRTC.nvrtcCreateProgram = fn;
143 return fn(prog, src, name, numHeaders, headers, includeNames);
144 }
145
// Macro-generated NVRTC stubs: the first macro argument is the function name,
// the remaining ones are its parameter types.
NVRTC_STUB1(nvrtcDestroyProgram, nvrtcProgram *);
NVRTC_STUB2(nvrtcGetPTXSize, nvrtcProgram, size_t *);
NVRTC_STUB2(nvrtcGetPTX, nvrtcProgram, char *);
// The CUBIN getters are only stubbed when CUDA_VERSION is at least 11010.
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
NVRTC_STUB2(nvrtcGetCUBINSize, nvrtcProgram, size_t *);
NVRTC_STUB2(nvrtcGetCUBIN, nvrtcProgram, char *);
#endif
NVRTC_STUB3(nvrtcCompileProgram, nvrtcProgram, int, const char * const *);
// nvrtcGetErrorString returns `const char *` rather than nvrtcResult, so it
// uses _STUB_1 directly with an explicit return type.
_STUB_1(NVRTC, nvrtcGetErrorString, const char *, nvrtcResult);
NVRTC_STUB2(nvrtcGetProgramLogSize,nvrtcProgram, size_t*);
NVRTC_STUB2(nvrtcGetProgramLog, nvrtcProgram, char *);
NVRTC_STUB3(nvrtcGetLoweredName, nvrtcProgram, const char *, const char **);
158
// Macro-generated CUDA driver API stubs. Through the CUDA_STUBn macros each
// one resolves its symbol via getCUDALibrary() and returns CUresult; the first
// macro argument is the function name, the rest are its parameter types.
CUDA_STUB2(cuModuleLoadData, CUmodule *, const void *);
CUDA_STUB3(cuModuleGetFunction, CUfunction *, CUmodule, const char *);
CUDA_STUB4(cuOccupancyMaxActiveBlocksPerMultiprocessor, int *, CUfunction, int, size_t);
CUDA_STUB2(cuGetErrorString, CUresult, const char **);
CUDA_STUB1(cuCtxGetCurrent, CUcontext *);
CUDA_STUB1(cuCtxSetCurrent, CUcontext);
CUDA_STUB1(cuModuleUnload, CUmodule);
CUDA_STUB3(cuDevicePrimaryCtxGetState, CUdevice, unsigned int *, int *);
CUDA_STUB2(cuDevicePrimaryCtxRetain, CUcontext *, CUdevice);
CUDA_STUB4(cuLinkCreate, unsigned int, CUjit_option *, void **, CUlinkState *);
CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *);
CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int);
CUDA_STUB3(cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction);
172
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
// Hand-written lazy stub for cuTensorMapEncodeTiled, compiled only when
// CUDA_VERSION is at least 12000. It resolves the symbol named after this
// function in the CUDA driver library, throws std::runtime_error when the
// lookup yields null, records the resolved pointer in lazyNVRTC and forwards
// all twelve arguments unchanged.
CUresult CUDAAPI cuTensorMapEncodeTiled(
    CUtensorMap* tensorMap,
    CUtensorMapDataType tensorDataType,
    cuuint32_t tensorRank,
    void* globalAddress,
    const cuuint64_t* globalDim,
    const cuuint64_t* globalStrides,
    const cuuint32_t* boxDim,
    const cuuint32_t* elementStrides,
    CUtensorMapInterleave interleave,
    CUtensorMapSwizzle swizzle,
    CUtensorMapL2promotion l2Promotion,
    CUtensorMapFloatOOBfill oobFill) {
  using StubFn = decltype(&cuTensorMapEncodeTiled);
  // __func__ has to be evaluated directly in this body so it names this stub.
  const auto resolved =
      reinterpret_cast<StubFn>(getCUDALibrary().sym(__func__));
  if (resolved == nullptr) {
    throw std::runtime_error("Can't get cuTensorMapEncodeTiled");
  }
  lazyNVRTC.cuTensorMapEncodeTiled = resolved;
  return resolved(
      tensorMap, tensorDataType, tensorRank, globalAddress,
      globalDim, globalStrides, boxDim, elementStrides,
      interleave, swizzle, l2Promotion, oobFill);
}

#endif
209
// Irregularly shaped functions: these take more arguments than the _STUB_N
// macros support, so their stubs are written out by hand.
cuLaunchKernel(CUfunction f,unsigned int gridDimX,unsigned int gridDimY,unsigned int gridDimZ,unsigned int blockDimX,unsigned int blockDimY,unsigned int blockDimZ,unsigned int sharedMemBytes,CUstream hStream,void ** kernelParams,void ** extra)211 CUresult CUDAAPI cuLaunchKernel(CUfunction f,
212 unsigned int gridDimX,
213 unsigned int gridDimY,
214 unsigned int gridDimZ,
215 unsigned int blockDimX,
216 unsigned int blockDimY,
217 unsigned int blockDimZ,
218 unsigned int sharedMemBytes,
219 CUstream hStream,
220 void **kernelParams,
221 void **extra) {
222 auto fn = reinterpret_cast<decltype(&cuLaunchKernel)>(getCUDALibrary().sym(__func__));
223 if (!fn)
224 throw std::runtime_error("Can't get cuLaunchKernel");
225 lazyNVRTC.cuLaunchKernel = fn;
226 return fn(f,
227 gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ,
228 sharedMemBytes, hStream, kernelParams, extra);
229 }
230
231 // Irregularly shaped functions
cuLaunchCooperativeKernel(CUfunction f,unsigned int gridDimX,unsigned int gridDimY,unsigned int gridDimZ,unsigned int blockDimX,unsigned int blockDimY,unsigned int blockDimZ,unsigned int sharedMemBytes,CUstream hStream,void ** kernelParams)232 CUresult CUDAAPI cuLaunchCooperativeKernel(
233 CUfunction f,
234 unsigned int gridDimX,
235 unsigned int gridDimY,
236 unsigned int gridDimZ,
237 unsigned int blockDimX,
238 unsigned int blockDimY,
239 unsigned int blockDimZ,
240 unsigned int sharedMemBytes,
241 CUstream hStream,
242 void** kernelParams) {
243 auto fn = reinterpret_cast<decltype(&cuLaunchCooperativeKernel)>(
244 getCUDALibrary().sym(__func__));
245 if (!fn)
246 throw std::runtime_error("Can't get cuLaunchCooperativeKernel");
247 lazyNVRTC.cuLaunchCooperativeKernel = fn;
248 return fn(
249 f,
250 gridDimX,
251 gridDimY,
252 gridDimZ,
253 blockDimX,
254 blockDimY,
255 blockDimZ,
256 sharedMemBytes,
257 hStream,
258 kernelParams);
259 }
260
cuModuleLoadDataEx(CUmodule * module,const void * image,unsigned int numOptions,CUjit_option * options,void ** optionValues)261 CUresult CUDAAPI cuModuleLoadDataEx(CUmodule *module,
262 const void *image,
263 unsigned int numOptions,
264 CUjit_option *options,
265 void **optionValues) {
266 auto fn = reinterpret_cast<decltype(&cuModuleLoadDataEx)>(getCUDALibrary().sym(__func__));
267 if (!fn)
268 throw std::runtime_error("Can't get cuModuleLoadDataEx");
269 lazyNVRTC.cuModuleLoadDataEx = fn;
270 return fn(module, image, numOptions, options, optionValues);
271 }
272
273 CUresult CUDAAPI
cuLinkAddData(CUlinkState state,CUjitInputType type,void * data,size_t size,const char * name,unsigned int numOptions,CUjit_option * options,void ** optionValues)274 cuLinkAddData(CUlinkState state,
275 CUjitInputType type,
276 void *data,
277 size_t size,
278 const char *name,
279 unsigned int numOptions,
280 CUjit_option *options,
281 void **optionValues) {
282 auto fn = reinterpret_cast<decltype(&cuLinkAddData)>(getCUDALibrary().sym(__func__));
283 if (!fn)
284 throw std::runtime_error("Can't get cuLinkAddData");
285 lazyNVRTC.cuLinkAddData = fn;
286 return fn(state, type, data, size, name, numOptions, options, optionValues);
287 }
288
289 } // namespace _stubs
290
291 NVRTC lazyNVRTC = {
292 #define _REFERENCE_MEMBER(name) _stubs::name,
293 AT_FORALL_NVRTC(_REFERENCE_MEMBER)
294 #undef _REFERENCE_MEMBER
295 };
296 } // namespace detail
297 } // namespace cuda
298 } // namespace at
299