xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/CUDASparseDescriptors.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Tensor.h>
4 #include <ATen/cuda/CUDAContext.h>
5 #include <ATen/cuda/CUDASparse.h>
6 
7 #include <c10/core/ScalarType.h>
8 
9 #if defined(USE_ROCM)
10 #include <type_traits>
11 #endif
12 
13 namespace at::cuda::sparse {
14 
15 template <typename T, cusparseStatus_t (*destructor)(T*)>
16 struct CuSparseDescriptorDeleter {
operatorCuSparseDescriptorDeleter17   void operator()(T* x) {
18     if (x != nullptr) {
19       TORCH_CUDASPARSE_CHECK(destructor(x));
20     }
21   }
22 };
23 
24 template <typename T, cusparseStatus_t (*destructor)(T*)>
25 class CuSparseDescriptor {
26  public:
descriptor()27   T* descriptor() const {
28     return descriptor_.get();
29   }
descriptor()30   T* descriptor() {
31     return descriptor_.get();
32   }
33 
34  protected:
35   std::unique_ptr<T, CuSparseDescriptorDeleter<T, destructor>> descriptor_;
36 };
37 
38 #if AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
39 template <typename T, cusparseStatus_t (*destructor)(const T*)>
40 struct ConstCuSparseDescriptorDeleter {
operatorConstCuSparseDescriptorDeleter41   void operator()(T* x) {
42     if (x != nullptr) {
43       TORCH_CUDASPARSE_CHECK(destructor(x));
44     }
45   }
46 };
47 
48 template <typename T, cusparseStatus_t (*destructor)(const T*)>
49 class ConstCuSparseDescriptor {
50  public:
descriptor()51   T* descriptor() const {
52     return descriptor_.get();
53   }
descriptor()54   T* descriptor() {
55     return descriptor_.get();
56   }
57 
58  protected:
59   std::unique_ptr<T, ConstCuSparseDescriptorDeleter<T, destructor>> descriptor_;
60 };
61 #endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS || AT_USE_HIPSPARSE_CONST_DESCRIPTORS
62 
63 #if defined(USE_ROCM)
64 using cusparseMatDescr = std::remove_pointer<hipsparseMatDescr_t>::type;
65 using cusparseDnMatDescr = std::remove_pointer<hipsparseDnMatDescr_t>::type;
66 using cusparseDnVecDescr = std::remove_pointer<hipsparseDnVecDescr_t>::type;
67 using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type;
68 using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type;
69 using cusparseSpGEMMDescr = std::remove_pointer<hipsparseSpGEMMDescr_t>::type;
70 #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
71 using bsrsv2Info = std::remove_pointer<bsrsv2Info_t>::type;
72 using bsrsm2Info = std::remove_pointer<bsrsm2Info_t>::type;
73 #endif
74 #endif
75 
76 // NOTE: This is only needed for CUDA 11 and earlier, since CUDA 12 introduced
77 // API for const descriptors
78 cusparseStatus_t destroyConstDnMat(const cusparseDnMatDescr* dnMatDescr);
79 
80 class TORCH_CUDA_CPP_API CuSparseMatDescriptor
81     : public CuSparseDescriptor<cusparseMatDescr, &cusparseDestroyMatDescr> {
82  public:
CuSparseMatDescriptor()83   CuSparseMatDescriptor() {
84     cusparseMatDescr_t raw_descriptor = nullptr;
85     TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor));
86     descriptor_.reset(raw_descriptor);
87   }
88 
CuSparseMatDescriptor(bool upper,bool unit)89   CuSparseMatDescriptor(bool upper, bool unit) {
90     cusparseFillMode_t fill_mode =
91         upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER;
92     cusparseDiagType_t diag_type =
93         unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT;
94     cusparseMatDescr_t raw_descriptor = nullptr;
95     TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor));
96     TORCH_CUDASPARSE_CHECK(cusparseSetMatFillMode(raw_descriptor, fill_mode));
97     TORCH_CUDASPARSE_CHECK(cusparseSetMatDiagType(raw_descriptor, diag_type));
98     descriptor_.reset(raw_descriptor);
99   }
100 };
101 
102 #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
103 
104 class TORCH_CUDA_CPP_API CuSparseBsrsv2Info
105     : public CuSparseDescriptor<bsrsv2Info, &cusparseDestroyBsrsv2Info> {
106  public:
CuSparseBsrsv2Info()107   CuSparseBsrsv2Info() {
108     bsrsv2Info_t raw_descriptor = nullptr;
109     TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsv2Info(&raw_descriptor));
110     descriptor_.reset(raw_descriptor);
111   }
112 };
113 
114 class TORCH_CUDA_CPP_API CuSparseBsrsm2Info
115     : public CuSparseDescriptor<bsrsm2Info, &cusparseDestroyBsrsm2Info> {
116  public:
CuSparseBsrsm2Info()117   CuSparseBsrsm2Info() {
118     bsrsm2Info_t raw_descriptor = nullptr;
119     TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsm2Info(&raw_descriptor));
120     descriptor_.reset(raw_descriptor);
121   }
122 };
123 
124 #endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
125 
126 #if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
127 
128 cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type);
129 
130 #if AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS()
131 class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
132     : public CuSparseDescriptor<cusparseDnMatDescr, &cusparseDestroyDnMat> {
133  public:
134   explicit CuSparseDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1);
135 };
136 
137 class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor
138     : public CuSparseDescriptor<const cusparseDnMatDescr, &destroyConstDnMat> {
139  public:
140   explicit CuSparseConstDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1);
unsafe_mutable_descriptor()141   cusparseDnMatDescr* unsafe_mutable_descriptor() const {
142     return const_cast<cusparseDnMatDescr*>(descriptor());
143   }
unsafe_mutable_descriptor()144   cusparseDnMatDescr* unsafe_mutable_descriptor() {
145     return const_cast<cusparseDnMatDescr*>(descriptor());
146   }
147 };
148 
149 class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
150     : public CuSparseDescriptor<cusparseDnVecDescr, &cusparseDestroyDnVec> {
151  public:
152   explicit CuSparseDnVecDescriptor(const Tensor& input);
153 };
154 
155 class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
156     : public CuSparseDescriptor<cusparseSpMatDescr, &cusparseDestroySpMat> {};
157 
158 #elif AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
159   class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
160       : public ConstCuSparseDescriptor<
161             cusparseDnMatDescr,
162             &cusparseDestroyDnMat> {
163    public:
164     explicit CuSparseDnMatDescriptor(
165         const Tensor& input,
166         int64_t batch_offset = -1);
167   };
168 
169   class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor
170       : public ConstCuSparseDescriptor<
171             const cusparseDnMatDescr,
172             &destroyConstDnMat> {
173    public:
174     explicit CuSparseConstDnMatDescriptor(
175         const Tensor& input,
176         int64_t batch_offset = -1);
unsafe_mutable_descriptor()177   cusparseDnMatDescr* unsafe_mutable_descriptor() const {
178     return const_cast<cusparseDnMatDescr*>(descriptor());
179   }
unsafe_mutable_descriptor()180   cusparseDnMatDescr* unsafe_mutable_descriptor() {
181     return const_cast<cusparseDnMatDescr*>(descriptor());
182   }
183   };
184 
185   class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
186       : public ConstCuSparseDescriptor<
187             cusparseDnVecDescr,
188             &cusparseDestroyDnVec> {
189    public:
190     explicit CuSparseDnVecDescriptor(const Tensor& input);
191   };
192 
193   class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
194       : public ConstCuSparseDescriptor<
195             cusparseSpMatDescr,
196             &cusparseDestroySpMat> {};
197 #endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
198 
199 class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor
200     : public CuSparseSpMatDescriptor {
201  public:
202   explicit CuSparseSpMatCsrDescriptor(const Tensor& input, int64_t batch_offset = -1);
203 
get_size()204   std::tuple<int64_t, int64_t, int64_t> get_size() {
205     int64_t rows = 0, cols = 0, nnz = 0;
206     TORCH_CUDASPARSE_CHECK(cusparseSpMatGetSize(
207         this->descriptor(),
208         &rows,
209         &cols,
210         &nnz));
211     return std::make_tuple(rows, cols, nnz);
212   }
213 
set_tensor(const Tensor & input)214   void set_tensor(const Tensor& input) {
215     auto crow_indices = input.crow_indices();
216     auto col_indices = input.col_indices();
217     auto values = input.values();
218 
219     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(crow_indices.is_contiguous());
220     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(col_indices.is_contiguous());
221     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous());
222     TORCH_CUDASPARSE_CHECK(cusparseCsrSetPointers(
223         this->descriptor(),
224         crow_indices.data_ptr(),
225         col_indices.data_ptr(),
226         values.data_ptr()));
227   }
228 
229 #if AT_USE_CUSPARSE_GENERIC_SPSV()
set_mat_fill_mode(bool upper)230   void set_mat_fill_mode(bool upper) {
231     cusparseFillMode_t fill_mode =
232         upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER;
233     TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute(
234         this->descriptor(),
235         CUSPARSE_SPMAT_FILL_MODE,
236         &fill_mode,
237         sizeof(fill_mode)));
238   }
239 
set_mat_diag_type(bool unit)240   void set_mat_diag_type(bool unit) {
241     cusparseDiagType_t diag_type =
242         unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT;
243     TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute(
244         this->descriptor(),
245         CUSPARSE_SPMAT_DIAG_TYPE,
246         &diag_type,
247         sizeof(diag_type)));
248   }
249 #endif
250 };
251 
252 #if AT_USE_CUSPARSE_GENERIC_SPSV()
253 class TORCH_CUDA_CPP_API CuSparseSpSVDescriptor
254     : public CuSparseDescriptor<cusparseSpSVDescr, &cusparseSpSV_destroyDescr> {
255  public:
CuSparseSpSVDescriptor()256   CuSparseSpSVDescriptor() {
257     cusparseSpSVDescr_t raw_descriptor = nullptr;
258     TORCH_CUDASPARSE_CHECK(cusparseSpSV_createDescr(&raw_descriptor));
259     descriptor_.reset(raw_descriptor);
260   }
261 };
262 #endif
263 
264 #if AT_USE_CUSPARSE_GENERIC_SPSM()
265 class TORCH_CUDA_CPP_API CuSparseSpSMDescriptor
266     : public CuSparseDescriptor<cusparseSpSMDescr, &cusparseSpSM_destroyDescr> {
267  public:
CuSparseSpSMDescriptor()268   CuSparseSpSMDescriptor() {
269     cusparseSpSMDescr_t raw_descriptor = nullptr;
270     TORCH_CUDASPARSE_CHECK(cusparseSpSM_createDescr(&raw_descriptor));
271     descriptor_.reset(raw_descriptor);
272   }
273 };
274 #endif
275 
276 class TORCH_CUDA_CPP_API CuSparseSpGEMMDescriptor
277     : public CuSparseDescriptor<cusparseSpGEMMDescr, &cusparseSpGEMM_destroyDescr> {
278  public:
CuSparseSpGEMMDescriptor()279   CuSparseSpGEMMDescriptor() {
280     cusparseSpGEMMDescr_t raw_descriptor = nullptr;
281     TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&raw_descriptor));
282     descriptor_.reset(raw_descriptor);
283   }
284 };
285 
286 #endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
287 
288 } // namespace at::cuda::sparse
289