1 #pragma once 2 3 #include <ATen/Tensor.h> 4 #include <ATen/cuda/CUDAContext.h> 5 #include <ATen/cuda/CUDASparse.h> 6 7 #include <c10/core/ScalarType.h> 8 9 #if defined(USE_ROCM) 10 #include <type_traits> 11 #endif 12 13 namespace at::cuda::sparse { 14 15 template <typename T, cusparseStatus_t (*destructor)(T*)> 16 struct CuSparseDescriptorDeleter { operatorCuSparseDescriptorDeleter17 void operator()(T* x) { 18 if (x != nullptr) { 19 TORCH_CUDASPARSE_CHECK(destructor(x)); 20 } 21 } 22 }; 23 24 template <typename T, cusparseStatus_t (*destructor)(T*)> 25 class CuSparseDescriptor { 26 public: descriptor()27 T* descriptor() const { 28 return descriptor_.get(); 29 } descriptor()30 T* descriptor() { 31 return descriptor_.get(); 32 } 33 34 protected: 35 std::unique_ptr<T, CuSparseDescriptorDeleter<T, destructor>> descriptor_; 36 }; 37 38 #if AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 39 template <typename T, cusparseStatus_t (*destructor)(const T*)> 40 struct ConstCuSparseDescriptorDeleter { operatorConstCuSparseDescriptorDeleter41 void operator()(T* x) { 42 if (x != nullptr) { 43 TORCH_CUDASPARSE_CHECK(destructor(x)); 44 } 45 } 46 }; 47 48 template <typename T, cusparseStatus_t (*destructor)(const T*)> 49 class ConstCuSparseDescriptor { 50 public: descriptor()51 T* descriptor() const { 52 return descriptor_.get(); 53 } descriptor()54 T* descriptor() { 55 return descriptor_.get(); 56 } 57 58 protected: 59 std::unique_ptr<T, ConstCuSparseDescriptorDeleter<T, destructor>> descriptor_; 60 }; 61 #endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS || AT_USE_HIPSPARSE_CONST_DESCRIPTORS 62 63 #if defined(USE_ROCM) 64 using cusparseMatDescr = std::remove_pointer<hipsparseMatDescr_t>::type; 65 using cusparseDnMatDescr = std::remove_pointer<hipsparseDnMatDescr_t>::type; 66 using cusparseDnVecDescr = std::remove_pointer<hipsparseDnVecDescr_t>::type; 67 using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type; 68 using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type; 69 using cusparseSpGEMMDescr = std::remove_pointer<hipsparseSpGEMMDescr_t>::type; 70 #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 71 using bsrsv2Info = std::remove_pointer<bsrsv2Info_t>::type; 72 using bsrsm2Info = std::remove_pointer<bsrsm2Info_t>::type; 73 #endif 74 #endif 75 76 // NOTE: This is only needed for CUDA 11 and earlier, since CUDA 12 introduced 77 // API for const descriptors 78 cusparseStatus_t destroyConstDnMat(const cusparseDnMatDescr* dnMatDescr); 79 80 class TORCH_CUDA_CPP_API CuSparseMatDescriptor 81 : public CuSparseDescriptor<cusparseMatDescr, &cusparseDestroyMatDescr> { 82 public: CuSparseMatDescriptor()83 CuSparseMatDescriptor() { 84 cusparseMatDescr_t raw_descriptor = nullptr; 85 TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor)); 86 descriptor_.reset(raw_descriptor); 87 } 88 CuSparseMatDescriptor(bool upper,bool unit)89 CuSparseMatDescriptor(bool upper, bool unit) { 90 cusparseFillMode_t fill_mode = 91 upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER; 92 cusparseDiagType_t diag_type = 93 unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT; 94 cusparseMatDescr_t raw_descriptor = nullptr; 95 TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor)); 96 TORCH_CUDASPARSE_CHECK(cusparseSetMatFillMode(raw_descriptor, fill_mode)); 97 TORCH_CUDASPARSE_CHECK(cusparseSetMatDiagType(raw_descriptor, diag_type)); 98 descriptor_.reset(raw_descriptor); 99 } 100 }; 101 102 #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 103 104 class TORCH_CUDA_CPP_API CuSparseBsrsv2Info 105 : public CuSparseDescriptor<bsrsv2Info, &cusparseDestroyBsrsv2Info> { 106 public: CuSparseBsrsv2Info()107 CuSparseBsrsv2Info() { 108 bsrsv2Info_t raw_descriptor = nullptr; 109 TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsv2Info(&raw_descriptor)); 110 descriptor_.reset(raw_descriptor); 111 } 112 }; 113 114 class TORCH_CUDA_CPP_API CuSparseBsrsm2Info 115 : public CuSparseDescriptor<bsrsm2Info, &cusparseDestroyBsrsm2Info> { 116 public: CuSparseBsrsm2Info()117 CuSparseBsrsm2Info() { 118 bsrsm2Info_t raw_descriptor = nullptr; 119 TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsm2Info(&raw_descriptor)); 120 descriptor_.reset(raw_descriptor); 121 } 122 }; 123 124 #endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE 125 126 #if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API() 127 128 cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type); 129 130 #if AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 131 class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor 132 : public CuSparseDescriptor<cusparseDnMatDescr, &cusparseDestroyDnMat> { 133 public: 134 explicit CuSparseDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1); 135 }; 136 137 class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor 138 : public CuSparseDescriptor<const cusparseDnMatDescr, &destroyConstDnMat> { 139 public: 140 explicit CuSparseConstDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1); unsafe_mutable_descriptor()141 cusparseDnMatDescr* unsafe_mutable_descriptor() const { 142 return const_cast<cusparseDnMatDescr*>(descriptor()); 143 } unsafe_mutable_descriptor()144 cusparseDnMatDescr* unsafe_mutable_descriptor() { 145 return const_cast<cusparseDnMatDescr*>(descriptor()); 146 } 147 }; 148 149 class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor 150 : public CuSparseDescriptor<cusparseDnVecDescr, &cusparseDestroyDnVec> { 151 public: 152 explicit CuSparseDnVecDescriptor(const Tensor& input); 153 }; 154 155 class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor 156 : public CuSparseDescriptor<cusparseSpMatDescr, &cusparseDestroySpMat> {}; 157 158 #elif AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 159 class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor 160 : public ConstCuSparseDescriptor< 161 cusparseDnMatDescr, 162 &cusparseDestroyDnMat> { 163 public: 164 explicit CuSparseDnMatDescriptor( 165 const Tensor& input, 166 int64_t batch_offset = -1); 167 }; 168 169 class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor 170 : public ConstCuSparseDescriptor< 171 const cusparseDnMatDescr, 172 &destroyConstDnMat> { 173 public: 174 explicit CuSparseConstDnMatDescriptor( 175 const Tensor& input, 176 int64_t batch_offset = -1); unsafe_mutable_descriptor()177 cusparseDnMatDescr* unsafe_mutable_descriptor() const { 178 return const_cast<cusparseDnMatDescr*>(descriptor()); 179 } unsafe_mutable_descriptor()180 cusparseDnMatDescr* unsafe_mutable_descriptor() { 181 return const_cast<cusparseDnMatDescr*>(descriptor()); 182 } 183 }; 184 185 class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor 186 : public ConstCuSparseDescriptor< 187 cusparseDnVecDescr, 188 &cusparseDestroyDnVec> { 189 public: 190 explicit CuSparseDnVecDescriptor(const Tensor& input); 191 }; 192 193 class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor 194 : public ConstCuSparseDescriptor< 195 cusparseSpMatDescr, 196 &cusparseDestroySpMat> {}; 197 #endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 198 199 class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor 200 : public CuSparseSpMatDescriptor { 201 public: 202 explicit CuSparseSpMatCsrDescriptor(const Tensor& input, int64_t batch_offset = -1); 203 get_size()204 std::tuple<int64_t, int64_t, int64_t> get_size() { 205 int64_t rows = 0, cols = 0, nnz = 0; 206 TORCH_CUDASPARSE_CHECK(cusparseSpMatGetSize( 207 this->descriptor(), 208 &rows, 209 &cols, 210 &nnz)); 211 return std::make_tuple(rows, cols, nnz); 212 } 213 set_tensor(const Tensor & input)214 void set_tensor(const Tensor& input) { 215 auto crow_indices = input.crow_indices(); 216 auto col_indices = input.col_indices(); 217 auto values = input.values(); 218 219 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(crow_indices.is_contiguous()); 220 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(col_indices.is_contiguous()); 221 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous()); 222 TORCH_CUDASPARSE_CHECK(cusparseCsrSetPointers( 223 this->descriptor(), 224 crow_indices.data_ptr(), 225 col_indices.data_ptr(), 226 values.data_ptr())); 227 } 228 229 #if AT_USE_CUSPARSE_GENERIC_SPSV() set_mat_fill_mode(bool upper)230 void set_mat_fill_mode(bool upper) { 231 cusparseFillMode_t fill_mode = 232 upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER; 233 TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute( 234 this->descriptor(), 235 CUSPARSE_SPMAT_FILL_MODE, 236 &fill_mode, 237 sizeof(fill_mode))); 238 } 239 set_mat_diag_type(bool unit)240 void set_mat_diag_type(bool unit) { 241 cusparseDiagType_t diag_type = 242 unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT; 243 TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute( 244 this->descriptor(), 245 CUSPARSE_SPMAT_DIAG_TYPE, 246 &diag_type, 247 sizeof(diag_type))); 248 } 249 #endif 250 }; 251 252 #if AT_USE_CUSPARSE_GENERIC_SPSV() 253 class TORCH_CUDA_CPP_API CuSparseSpSVDescriptor 254 : public CuSparseDescriptor<cusparseSpSVDescr, &cusparseSpSV_destroyDescr> { 255 public: CuSparseSpSVDescriptor()256 CuSparseSpSVDescriptor() { 257 cusparseSpSVDescr_t raw_descriptor = nullptr; 258 TORCH_CUDASPARSE_CHECK(cusparseSpSV_createDescr(&raw_descriptor)); 259 descriptor_.reset(raw_descriptor); 260 } 261 }; 262 #endif 263 264 #if AT_USE_CUSPARSE_GENERIC_SPSM() 265 class TORCH_CUDA_CPP_API CuSparseSpSMDescriptor 266 : public CuSparseDescriptor<cusparseSpSMDescr, &cusparseSpSM_destroyDescr> { 267 public: CuSparseSpSMDescriptor()268 CuSparseSpSMDescriptor() { 269 cusparseSpSMDescr_t raw_descriptor = nullptr; 270 TORCH_CUDASPARSE_CHECK(cusparseSpSM_createDescr(&raw_descriptor)); 271 descriptor_.reset(raw_descriptor); 272 } 273 }; 274 #endif 275 276 class TORCH_CUDA_CPP_API CuSparseSpGEMMDescriptor 277 : public CuSparseDescriptor<cusparseSpGEMMDescr, &cusparseSpGEMM_destroyDescr> { 278 public: CuSparseSpGEMMDescriptor()279 CuSparseSpGEMMDescriptor() { 280 cusparseSpGEMMDescr_t raw_descriptor = nullptr; 281 TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&raw_descriptor)); 282 descriptor_.reset(raw_descriptor); 283 } 284 }; 285 286 #endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API() 287 288 } // namespace at::cuda::sparse 289