#pragma once

#include <c10/core/Allocator.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>

#include <ATen/detail/AcceleratorHooksInterface.h>

// Forward-declares at::Generator and at::cuda::NVRTC
namespace at {
struct Generator;
namespace cuda {
struct NVRTC;
} // namespace cuda
} // namespace at

// NB: Class must live in `at` due to limitations of Registry.h.
namespace at {

#ifdef _MSC_VER
constexpr const char* CUDA_HELP =
    "PyTorch splits its backend into two shared libraries: a CPU library "
    "and a CUDA library; this error has occurred because you are trying "
    "to use some CUDA functionality, but the CUDA library has not been "
    "loaded by the dynamic linker for some reason. The CUDA library MUST "
    "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! "
    "One common culprit is a lack of -INCLUDE:?warp_size@cuda@at@@YAHXZ "
    "in your link arguments; many dynamic linkers will delete dynamic library "
    "dependencies if you don't depend on any of their symbols. You can check "
    "if this has occurred by using link on your binary to see if there is a "
    "dependency on *_cuda.dll library.";
#else
constexpr const char* CUDA_HELP =
    "PyTorch splits its backend into two shared libraries: a CPU library "
    "and a CUDA library; this error has occurred because you are trying "
    "to use some CUDA functionality, but the CUDA library has not been "
    "loaded by the dynamic linker for some reason. The CUDA library MUST "
    "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! "
    "One common culprit is a lack of -Wl,--no-as-needed in your link arguments; many "
    "dynamic linkers will delete dynamic library dependencies if you don't "
    "depend on any of their symbols. You can check if this has occurred by "
    "using ldd on your binary to see if there is a dependency on *_cuda.so "
    "library.";
#endif

// The CUDAHooksInterface is an omnibus interface for any CUDA functionality
// which we may want to call into from CPU code (and thus must be dynamically
// dispatched, to allow for separate compilation of CUDA code). How do I
// decide if a function should live in this class? There are two tests:
//
// 1. Does the *implementation* of this function require linking against
//    CUDA libraries?
//
// 2. Is this function *called* from non-CUDA ATen code?
//
// (2) should filter out many ostensible use-cases, since many times a CUDA
// function provided by ATen is only really ever used by actual CUDA code.
//
// TODO: Consider putting the stub definitions in another class, so that one
// never forgets to implement each virtual function in the real implementation
// in CUDAHooks. This probably doesn't buy us much though.
struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
  // This should never actually be implemented, but it is used to
  // squelch -Werror=non-virtual-dtor
  ~CUDAHooksInterface() override = default;

  // Initialize THCState and, transitively, the CUDA state
  virtual void initCUDA() const {
    TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
  }

  virtual const Generator& getDefaultCUDAGenerator(C10_UNUSED DeviceIndex device_index = -1) const {
    TORCH_CHECK(false, "Cannot get default CUDA generator without ATen_cuda library. ", CUDA_HELP);
  }

  virtual Device getDeviceFromPtr(void* /*data*/) const {
    TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
  }

  bool isPinnedPtr(const void* data) const override {
    return false;
  }

  virtual bool hasCUDA() const {
    return false;
  }

  virtual bool hasCUDART() const {
    return false;
  }

  virtual bool hasMAGMA() const {
    return false;
  }

  virtual bool hasCuDNN() const {
    return false;
  }

  virtual bool hasCuSOLVER() const {
    return false;
  }

  virtual bool hasCuBLASLt() const {
    return false;
  }

  virtual bool hasROCM() const {
    return false;
  }

  virtual const at::cuda::NVRTC& nvrtc() const {
    TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
  }

  bool hasPrimaryContext(DeviceIndex device_index) const override {
    TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP);
  }

  virtual DeviceIndex current_device() const {
    return -1;
  }

  Allocator* getPinnedMemoryAllocator() const override {
    TORCH_CHECK(false, "Pinned memory requires CUDA. ", CUDA_HELP);
  }

  virtual Allocator* getCUDADeviceAllocator() const {
    TORCH_CHECK(false, "CUDADeviceAllocator requires CUDA. ", CUDA_HELP);
  }

  virtual bool compiledWithCuDNN() const {
    return false;
  }

  virtual bool compiledWithMIOpen() const {
    return false;
  }

  virtual bool supportsDilatedConvolutionWithCuDNN() const {
    return false;
  }

  virtual bool supportsDepthwiseConvolutionWithCuDNN() const {
    return false;
  }

  virtual bool supportsBFloat16ConvolutionWithCuDNNv8() const {
    return false;
  }

  virtual long versionCuDNN() const {
    TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
  }

  virtual long versionCUDART() const {
    TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. ", CUDA_HELP);
  }

  virtual std::string showConfig() const {
    TORCH_CHECK(false, "Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP);
  }

  virtual double batchnormMinEpsilonCuDNN() const {
    TORCH_CHECK(false,
        "Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP);
  }

  virtual int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex /*device_index*/) const {
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  }

  virtual void cuFFTSetPlanCacheMaxSize(DeviceIndex /*device_index*/, int64_t /*max_size*/) const {
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  }

  virtual int64_t cuFFTGetPlanCacheSize(DeviceIndex /*device_index*/) const {
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  }

  virtual void cuFFTClearPlanCache(DeviceIndex /*device_index*/) const {
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  }

  virtual int getNumGPUs() const {
    return 0;
  }

#ifdef USE_ROCM
  virtual bool isGPUArch(DeviceIndex /*device_index*/, const std::vector<std::string>& /*archs*/) const {
    TORCH_CHECK(false, "Cannot check GPU arch without ATen_cuda library. ", CUDA_HELP);
  }
#endif

  virtual void deviceSynchronize(DeviceIndex /*device_index*/) const {
    TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP);
  }
};

// NB: dummy argument to suppress "ISO C++11 requires at least one argument
// for the "..." in a variadic macro"
struct TORCH_API CUDAHooksArgs {};

TORCH_DECLARE_REGISTRY(CUDAHooksRegistry, CUDAHooksInterface, CUDAHooksArgs);
#define REGISTER_CUDA_HOOKS(clsname) \
  C10_REGISTER_CLASS(CUDAHooksRegistry, clsname, clsname)

namespace detail {
TORCH_API const CUDAHooksInterface& getCUDAHooks();
} // namespace detail
} // namespace at