1 #pragma once 2 // Light-weight version of CUDAContext.h with fewer transitive includes 3 4 #include <cstdint> 5 6 #include <cuda_runtime_api.h> 7 #include <cusparse.h> 8 #include <cublas_v2.h> 9 10 // cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also 11 // added bf16 support 12 #include <cublasLt.h> 13 14 #ifdef CUDART_VERSION 15 #include <cusolverDn.h> 16 #endif 17 18 #if defined(USE_CUDSS) 19 #include <cudss.h> 20 #endif 21 22 #if defined(USE_ROCM) 23 #include <hipsolver/hipsolver.h> 24 #endif 25 26 #include <c10/core/Allocator.h> 27 #include <c10/cuda/CUDAFunctions.h> 28 29 namespace c10 { 30 struct Allocator; 31 } 32 33 namespace at::cuda { 34 35 /* 36 A common CUDA interface for ATen. 37 38 This interface is distinct from CUDAHooks, which defines an interface that links 39 to both CPU-only and CUDA builds. That interface is intended for runtime 40 dispatch and should be used from files that are included in both CPU-only and 41 CUDA builds. 42 43 CUDAContext, on the other hand, should be preferred by files only included in 44 CUDA builds. It is intended to expose CUDA functionality in a consistent 45 manner. 46 47 This means there is some overlap between the CUDAContext and CUDAHooks, but 48 the choice of which to use is simple: use CUDAContext when in a CUDA-only file, 49 use CUDAHooks otherwise. 50 51 Note that CUDAContext simply defines an interface with no associated class. 52 It is expected that the modules whose functions compose this interface will 53 manage their own state. There is only a single CUDA context/state. 54 */ 55 56 /** 57 * DEPRECATED: use device_count() instead 58 */ getNumGPUs()59inline int64_t getNumGPUs() { 60 return c10::cuda::device_count(); 61 } 62 63 /** 64 * CUDA is available if we compiled with CUDA, and there are one or more 65 * devices. If we compiled with CUDA but there is a driver problem, etc., 66 * this function will report CUDA is not available (rather than raise an error.) 67 */ is_available()68inline bool is_available() { 69 return c10::cuda::device_count() > 0; 70 } 71 72 TORCH_CUDA_CPP_API cudaDeviceProp* getCurrentDeviceProperties(); 73 74 TORCH_CUDA_CPP_API int warp_size(); 75 76 TORCH_CUDA_CPP_API cudaDeviceProp* getDeviceProperties(c10::DeviceIndex device); 77 78 TORCH_CUDA_CPP_API bool canDeviceAccessPeer( 79 c10::DeviceIndex device, 80 c10::DeviceIndex peer_device); 81 82 TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator(); 83 84 /* Handles */ 85 TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle(); 86 TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle(); 87 TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle(); 88 89 TORCH_CUDA_CPP_API void clearCublasWorkspaces(); 90 91 #if defined(CUDART_VERSION) || defined(USE_ROCM) 92 TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle(); 93 #endif 94 95 #if defined(USE_CUDSS) 96 TORCH_CUDA_CPP_API cudssHandle_t getCurrentCudssHandle(); 97 #endif 98 99 } // namespace at::cuda 100