#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/Parallel.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/native/Resize.h>
#include <ATen/native/SparseTensorUtils.h>
#include <cuda_runtime.h>
#include <type_traits>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_sparse_sparse_matmul_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#endif

#include <thrust/device_ptr.h>
#include <thrust/for_each.h>
#include <thrust/sequence.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/CUDAUtils.h>
#include <ATen/cuda/ThrustAllocator.h>
#include <cusparse.h>
#include <ATen/native/sparse/cuda/SparseCUDABlas.h>
#include <c10/cuda/CUDACachingAllocator.h>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/functional.h>
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/discard_iterator.h>


#if defined(__CUDACC__) && (CUSPARSE_VERSION >= 11000)
#define IS_CUSPARSE11_AVAILABLE() 1
#else
#define IS_CUSPARSE11_AVAILABLE() 0
#endif

#if IS_CUSPARSE11_AVAILABLE()
#include <library_types.h>
#endif

namespace at::native {

namespace {

using namespace at::sparse;

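// Converts a COO row-index array with `nnz` entries into a CSR row-pointer
// array of length `dim + 1` (for a matrix with `dim` rows) via the Xcoo2csr
// wrapper around cusparseXcoo2csr. The indices are copied into an int32 tensor
// first, because the cuSPARSE routines used below expect 32-bit indices.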
Tensor _to_csr_int(const Tensor& rowIndices, int64_t dim, int64_t nnz) {
  Tensor csr = at::empty({dim + 1}, CUDA(kInt));
  Tensor rowIndicesInt = at::empty({rowIndices.size(0)}, CUDA(kInt));
  rowIndicesInt.copy_(rowIndices);
  sparse::cuda::Xcoo2csr(
      rowIndicesInt.data_ptr<int32_t>(), nnz, dim, csr.data_ptr<int32_t>());
  return csr;
}


#pragma push
// NVCC complains that confirm_mult_size is not used,
// but it is used in specializations of CusparseMatrixMultiplyOp below
#pragma nv_diag_suppress 177   // Function was declared but never referenced
int confirm_mult_size(const std::vector<int>& mat1_size, const std::vector<int>& mat2_size) {
  TORCH_CHECK(
      mat1_size[1] == mat2_size[0],
      "mat1 and mat2 shapes cannot be multiplied (",
      mat1_size[0],
      "x",
      mat1_size[1],
      " and ",
      mat2_size[0],
      "x",
      mat2_size[1],
      ")");
  return mat1_size[1];
}
#pragma pop

void create_general_description_(cusparseMatDescr_t& description_) {
  TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&description_));
  TORCH_CUDASPARSE_CHECK(cusparseSetMatType(description_, CUSPARSE_MATRIX_TYPE_GENERAL));
  TORCH_CUDASPARSE_CHECK(cusparseSetMatIndexBase(description_, CUSPARSE_INDEX_BASE_ZERO));
}

// csrMatrixRef is a non-owning view of a raw CSR matrix produced inside
// `sparse_sparse_matmul_cuda_kernel`.
// It also acts as an RAII guard for the associated cuSPARSE matrix descriptor.
template<class scalar_t>
struct csrMatrixRef {
  int* csr_indices_{nullptr};
  int* csr_pointers_{nullptr};
  scalar_t* csr_values_{nullptr};
  int nnz_{0};
  std::vector<int> size_{};

  #if IS_CUSPARSE11_AVAILABLE()
    cusparseSpMatDescr_t description_{0};
  #else
    cusparseMatDescr_t description_{0};
  #endif
  csrMatrixRef() {
    #if !IS_CUSPARSE11_AVAILABLE()
      create_general_description_(description_);
    #endif
  }

  csrMatrixRef(
      int* csr_indices,
      int* csr_pointers,
      scalar_t* csr_values,
      int nnz,
      const std::vector<int>& size)
      : csr_indices_{csr_indices},
        csr_pointers_{csr_pointers},
        csr_values_{csr_values},
        nnz_{nnz},
        size_{size} {
    #if IS_CUSPARSE11_AVAILABLE()
      cudaDataType cuda_data_type = at::cuda::getCudaDataType<scalar_t>();
      TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
        &description_,
        this->size(0),
        this->size(1),
        this->nnz_,
        this->csr_pointers_,
        this->csr_indices_,
        this->csr_values_,
        CUSPARSE_INDEX_32I,
        CUSPARSE_INDEX_32I,
        CUSPARSE_INDEX_BASE_ZERO,
        cuda_data_type));
    #else
      create_general_description_(description_);
    #endif
  }

  ~csrMatrixRef() {
    #if IS_CUSPARSE11_AVAILABLE()
      cusparseDestroySpMat(description_);
    #else
      cusparseDestroyMatDescr(description_);
    #endif
  }

  int size(int index) const {
    return size_.at(index);
  }
};

// csrOutput represents the output of `CusparseMatrixMultiplyOp`.
// Unlike `csrMatrixRef`, it owns its storage: it is a materialized CSR matrix
// whose indices, pointers, and values live in Tensors.
// It also acts as an RAII guard for the associated cuSPARSE matrix descriptor.
struct csrOutput {
  Tensor csr_indices_{};
  Tensor csr_pointers_{};
  at::Tensor csr_values_{};
  int nnz_{0};
  std::vector<int> size_;

  cusparseMatDescr_t description_{0};

  csrOutput(const std::vector<int> &size) : size_{size} {
    create_general_description_(description_);
  }

  ~csrOutput() {
    cusparseDestroyMatDescr(description_);
  }

  csrOutput(const csrOutput&) = delete;
  csrOutput& operator=(const csrOutput&) = delete;
  csrOutput(csrOutput&& rhs) {
    csr_indices_ = std::move(rhs.csr_indices_);
    csr_pointers_ = std::move(rhs.csr_pointers_);
    csr_values_ = std::move(rhs.csr_values_);
    nnz_ = rhs.nnz_;
    size_ = std::move(rhs.size_);
    description_ = rhs.description_;
    rhs.description_ = 0;
  }
  csrOutput& operator=(csrOutput&&) = delete;
  int size(int index) const {
    return size_.at(index);
  }
};

#if IS_CUSPARSE11_AVAILABLE()

// RAII guard around the cuSPARSE 11 generic SpGEMM API for the `A @ B` operation.
// This single template handles every scalar type accepted by cusparseSpGEMM
// (half, bfloat16, float, double, and complex float/double); see the static_assert below.
template <class scalar_t>
struct CusparseMatrixMultiplyOp {

  cusparseSpGEMMDescr_t spgemmDesc;

  CusparseMatrixMultiplyOp() {
    static_assert(
      std::is_same<c10::Half, scalar_t>::value ||
          std::is_same<c10::BFloat16, scalar_t>::value ||
          std::is_same<float, scalar_t>::value ||
          std::is_same<double, scalar_t>::value ||
          std::is_same<c10::complex<float>, scalar_t>::value ||
          std::is_same<c10::complex<double>, scalar_t>::value,
      "cusparseSpGEMM only supports data type of half, bfloat16, float, double and complex float, double.");
    // SpGEMM Computation
    TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&spgemmDesc));
  }

  ~CusparseMatrixMultiplyOp() {
    // destroy matrix/vector descriptors
    cusparseSpGEMM_destroyDescr(spgemmDesc);
  }

  csrOutput operator ()(
      const csrMatrixRef<scalar_t>& A,
      const csrMatrixRef<scalar_t>& B,
      Tensor& output_values,
      Tensor& output_indices) {
    const int A_num_rows = A.size(0);

    const int B_num_cols = B.size(1);

    csrOutput out({A.size(0), B.size(1)});

    out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));

    int* dC_csrOffsets = out.csr_pointers_.data_ptr<int>();
    int* dC_columns = nullptr;
    scalar_t* dC_values = nullptr;

    scalar_t alpha = 1.0f;
    scalar_t beta = 0.0f;
    cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
    cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE;

    csrMatrixRef<scalar_t> C(
      dC_columns,
      dC_csrOffsets,
      dC_values,
      /*nnz*/0,
      {A_num_rows, B_num_cols}
    );

    //--------------------------------------------------------------------------
    // CUSPARSE APIs
    cusparseHandle_t handle = at::cuda::getCurrentCUDASparseHandle();
    void *dBuffer1 = NULL, *dBuffer2 = NULL;
    size_t bufferSize1 = 0, bufferSize2 = 0;

    cusparseSpMatDescr_t matA = A.description_;
    cusparseSpMatDescr_t matB = B.description_;
    cusparseSpMatDescr_t matC = C.description_;
    //--------------------------------------------------------------------------

    cudaDataType computeType = at::cuda::getCudaDataType<scalar_t>();

    // If a specific GPU model does not provide native support for a given data type,
    // the routine returns CUSPARSE_STATUS_ARCH_MISMATCH error
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(prop->major >= 5 && !((10*prop->major + prop->minor) < 53 && computeType == CUDA_R_16F),
        "sparse_mm: CUDA Float16 requires compute capability >= 53 (current: ", prop->major, prop->minor, ")");
    TORCH_CHECK(!(prop->major < 8 && computeType == CUDA_R_16BF),
        "sparse_mm: CUDA BFloat16 requires compute capability >= 80 (current: ", prop->major, prop->minor, ")");

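    // cusparseSpGEMM workflow used below: cusparseSpGEMM_workEstimation twice
    // (workspace-size query, then analysis), cusparseSpGEMM_compute twice
    // (workspace-size query, then the actual product), cusparseSpMatGetSize to
    // read the output nnz, cusparseCsrSetPointers to attach freshly allocated
    // storage to matC, and finally cusparseSpGEMM_copy to write the result.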
    // query the required size of the first external workspace (bufferSize1)
    TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(
        handle,
        opA,
        opB,
        &alpha,
        matA,
        matB,
        &beta,
        matC,
        computeType,
        CUSPARSE_SPGEMM_DEFAULT,
        spgemmDesc,
        &bufferSize1,
        NULL));

    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();

    at::DataPtr dataPtr1 = allocator.allocate(bufferSize1);
    dBuffer1 = dataPtr1.get();
    // inspect the matrices A and B to understand the memory requirement for
    // the next step
    TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(
        handle,
        opA,
        opB,
        &alpha,
        matA,
        matB,
        &beta,
        matC,
        computeType,
        CUSPARSE_SPGEMM_DEFAULT,
        spgemmDesc,
        &bufferSize1,
        dBuffer1));

    // query the required size of the second external workspace (bufferSize2)
    TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_compute(
        handle,
        opA,
        opB,
        &alpha,
        matA,
        matB,
        &beta,
        matC,
        computeType,
        CUSPARSE_SPGEMM_DEFAULT,
        spgemmDesc,
        &bufferSize2,
        NULL));

    at::DataPtr dataPtr2 = allocator.allocate(bufferSize2);
    dBuffer2 = dataPtr2.get();

    // compute the intermediate product of A * B
    TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_compute(
        handle,
        opA,
        opB,
        &alpha,
        matA,
        matB,
        &beta,
        matC,
        computeType,
        CUSPARSE_SPGEMM_DEFAULT,
        spgemmDesc,
        &bufferSize2,
        dBuffer2));
    // get matrix C non-zero entries C_num_nnz1
    int64_t C_num_rows1, C_num_cols1, C_num_nnz1;
    TORCH_CUDASPARSE_CHECK(
        cusparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_num_nnz1));
    // allocate the column indices and values of matrix C
    // (the row offsets were already allocated above)
    out.nnz_ = C_num_nnz1;

    out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
    out.csr_values_ = at::empty({out.nnz_}, output_values.options());
    dC_columns = out.csr_indices_.data_ptr<int>();
    dC_values = out.csr_values_.data_ptr<scalar_t>();

    // update matC with the new pointers
    TORCH_CUDASPARSE_CHECK(
        cusparseCsrSetPointers(matC, dC_csrOffsets, dC_columns, dC_values));

    // copy the final products to the matrix C
    TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_copy(
        handle,
        opA,
        opB,
        &alpha,
        matA,
        matB,
        &beta,
        matC,
        computeType,
        CUSPARSE_SPGEMM_DEFAULT,
        spgemmDesc));
    return out;
  }
};


template struct CusparseMatrixMultiplyOp<float>;

template struct CusparseMatrixMultiplyOp<double>;

#else // if not IS_CUSPARSE11_AVAILABLE()

using DcsrMatrixRef = csrMatrixRef<double>;
using ScsrMatrixRef = csrMatrixRef<float>;

// Support for the legacy cuSPARSE 10 csrgemm2 API for the `A @ B` operation.
// With this API `scalar_t` can only be float or double; the float and double
// specializations below do the work, and the generic template only produces a
// compile-time error for unsupported types.
template <class scalar_t>
struct CusparseMatrixMultiplyOp {
  csrOutput operator()(
      const csrMatrixRef<scalar_t>& lhs,
      const csrMatrixRef<scalar_t>& rhs,
      Tensor &output_values,
      Tensor &output_indices)
  {
    static_assert(false&&sizeof(scalar_t), "cusparse csr sparse-sparse MM only supports data type of float and double.");
  }
};

// Specialization of the `A @ B` operation for double values with cuSPARSE
template<> struct CusparseMatrixMultiplyOp<double> {
  csrgemm2Info_t gemm2Info_;

  CusparseMatrixMultiplyOp() {
    TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
  }
  ~CusparseMatrixMultiplyOp() {
    cusparseDestroyCsrgemm2Info(gemm2Info_);
  }

  csrOutput operator ()(
      const DcsrMatrixRef& lhs,
      const DcsrMatrixRef& rhs,
      Tensor &output_values,
      Tensor &output_indices) {
    double alpha = 1.0;
    DcsrMatrixRef empty;
    return Dgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
  }

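  // csrgemm2 workflow used below: query the workspace size with
  // cusparseDcsrgemm2_bufferSizeExt, compute the output row pointers and nnz
  // with cusparseXcsrgemm2Nnz, then run cusparseDcsrgemm2 to fill in the
  // column indices and values.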
  csrOutput Dgemm2(
      const DcsrMatrixRef& A,
      const DcsrMatrixRef& B,
      const DcsrMatrixRef& C,
      const double* alpha,
      const double* beta,
      Tensor &output_values,
      Tensor &output_indices) {
    void* buffer_{nullptr};
    cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
    TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));

    csrOutput out({A.size(0), B.size(1)});
    int innerSize = confirm_mult_size(A.size_, B.size_);
    out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));

    // Compute needed buffer size
    size_t new_buffer_sz;
    TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2_bufferSizeExt(
        cusparseHandle_,
        out.size(0),
        out.size(1),
        innerSize,
        alpha,
        A.description_,
        A.nnz_,
        A.csr_pointers_,
        A.csr_indices_,
        B.description_,
        B.nnz_,
        B.csr_pointers_,
        B.csr_indices_,
        beta,
        C.description_,
        C.nnz_,
        C.csr_pointers_,
        C.csr_indices_,
        gemm2Info_,
        &new_buffer_sz));

    // (Re)allocate buffer if needed
    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
    at::DataPtr data_ptr = allocator.allocate(new_buffer_sz);
    buffer_ = data_ptr.get();

    // Find the resulting non-zero pattern.
    TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
        cusparseHandle_,
        out.size(0),
        out.size(1),
        innerSize,
        A.description_,
        A.nnz_,
        A.csr_pointers_,
        A.csr_indices_,
        B.description_,
        B.nnz_,
        B.csr_pointers_,
        B.csr_indices_,
        C.description_,
        C.nnz_,
        C.csr_pointers_,
        C.csr_indices_,
        out.description_,
        out.csr_pointers_.data_ptr<int>(),
        &out.nnz_,
        gemm2Info_,
        buffer_));

    out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
    out.csr_values_ = at::empty({out.nnz_}, output_values.options());

    // Perform the gemm2 operation for doubles
    // out = alpha ∗ A ∗ B + beta ∗ C
    TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2(
        cusparseHandle_,
        out.size(0),
        out.size(1),
        innerSize,
        alpha,
        A.description_,
        A.nnz_,
        A.csr_values_,
        A.csr_pointers_,
        A.csr_indices_,
        B.description_,
        B.nnz_,
        B.csr_values_,
        B.csr_pointers_,
        B.csr_indices_,
        beta,
        C.description_,
        C.nnz_,
        C.csr_values_,
        C.csr_pointers_,
        C.csr_indices_,
        out.description_,
        out.csr_values_.data_ptr<double>(),
        out.csr_pointers_.data_ptr<int>(),
        out.csr_indices_.data_ptr<int>(),
        gemm2Info_,
        buffer_));
    return out;
  }
};

// Specialization of the `A @ B` operation for float values with cuSPARSE
template<> struct CusparseMatrixMultiplyOp<float> {
  csrgemm2Info_t gemm2Info_;

  CusparseMatrixMultiplyOp() {
    TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
  }
  ~CusparseMatrixMultiplyOp() {
    cusparseDestroyCsrgemm2Info(gemm2Info_);
  }
  csrOutput operator()(
      const ScsrMatrixRef& lhs,
      const ScsrMatrixRef& rhs,
      Tensor &output_values,
      Tensor &output_indices) {
    float alpha = 1.0;
    ScsrMatrixRef empty;
    return Sgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
  }

  csrOutput Sgemm2(
      const ScsrMatrixRef& A,
      const ScsrMatrixRef& B,
      const ScsrMatrixRef& C,
      const float* alpha,
      const float* beta,
      Tensor &output_values,
      Tensor &output_indices) {
    void* buffer_{nullptr};
    cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
    TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));

    csrOutput out({A.size(0), B.size(1)});

    int innerSize = confirm_mult_size(A.size_, B.size_);

    out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));

    // Compute needed buffer size
    size_t new_buffer_sz;
    TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2_bufferSizeExt(
        cusparseHandle_,
        out.size(0),
        out.size(1),
        innerSize,
        alpha,
        A.description_,
        A.nnz_,
        A.csr_pointers_,
        A.csr_indices_,
        B.description_,
        B.nnz_,
        B.csr_pointers_,
        B.csr_indices_,
        beta,
        C.description_,
        C.nnz_,
        C.csr_pointers_,
        C.csr_indices_,
        gemm2Info_,
        &new_buffer_sz));

    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
    at::DataPtr data_ptr = allocator.allocate(new_buffer_sz);
    buffer_ = data_ptr.get();

    // Find the resulting non-zero pattern.
    TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
        cusparseHandle_,
        out.size(0),
        out.size(1),
        innerSize,
        A.description_,
        A.nnz_,
        A.csr_pointers_,
        A.csr_indices_,
        B.description_,
        B.nnz_,
        B.csr_pointers_,
        B.csr_indices_,
        C.description_,
        C.nnz_,
        C.csr_pointers_,
        C.csr_indices_,
        out.description_,
        out.csr_pointers_.data_ptr<int>(),
        &out.nnz_,
        gemm2Info_,
        buffer_));

    out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
    out.csr_values_ = at::empty({out.nnz_}, output_values.options());

    // Perform the gemm2 operation for floats
    // out = alpha ∗ A ∗ B + beta ∗ C
    TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2(
        cusparseHandle_,
        out.size(0),
        out.size(1),
        innerSize,
        alpha,
        A.description_,
        A.nnz_,
        A.csr_values_,
        A.csr_pointers_,
        A.csr_indices_,
        B.description_,
        B.nnz_,
        B.csr_values_,
        B.csr_pointers_,
        B.csr_indices_,
        beta,
        C.description_,
        C.nnz_,
        C.csr_values_,
        C.csr_pointers_,
        C.csr_indices_,
        out.description_,
        out.csr_values_.data_ptr<float>(),
        out.csr_pointers_.data_ptr<int>(),
        out.csr_indices_.data_ptr<int>(),
        gemm2Info_,
        buffer_));
    return out;
  }
};



#endif // IS_CUSPARSE11_AVAILABLE()

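// Computes `result = mat1 @ mat2` for two coalesced 2-D sparse COO tensors:
// both operands are converted to 32-bit-indexed CSR, multiplied with
// CusparseMatrixMultiplyOp, and the CSR product is then expanded back into
// the COO indices and values of `result`.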
template <typename scalar_t>
void sparse_sparse_matmul_cuda_kernel(
    Tensor& result,
    const Tensor& mat1,
    const Tensor& mat2) {

  static_assert(
    std::is_same<c10::Half, scalar_t>::value ||
        std::is_same<c10::BFloat16, scalar_t>::value ||
        std::is_same<float, scalar_t>::value ||
        std::is_same<double, scalar_t>::value ||
        std::is_same<c10::complex<float>, scalar_t>::value ||
        std::is_same<c10::complex<double>, scalar_t>::value,
    "sparse_sparse_matmul_cuda_kernel only supports data type of half, bfloat16, float, double and complex float, double.");

  // older versions of cusparse on Windows segfault for complex128 dtype
#if defined(_WIN32) && defined(CUSPARSE_VERSION) && CUSPARSE_VERSION < 11400
  TORCH_CHECK(
      !(mat1.scalar_type() == ScalarType::ComplexDouble),
      "Sparse multiplication with complex128 dtype inputs is not supported with current CUDA version. Please upgrade to CUDA Toolkit 11.2.1+");
#endif

  Tensor mat1_indices_ = mat1._indices().contiguous();
  Tensor mat1_values = mat1._values().contiguous();

  Tensor mat1_row_indices = mat1_indices_.select(0, 0);
  Tensor mat1_col_indices = mat1_indices_.select(0, 1);

  Tensor mat1_indptr = _to_csr_int(mat1_row_indices, mat1.size(0), mat1._nnz());

  Tensor mat1_indices = at::empty(
      {mat1_col_indices.size(0)}, mat1_col_indices.options().dtype(kInt));

  mat1_indices.copy_(mat1_col_indices);

  Tensor mat2_indices_ = mat2._indices().contiguous();
  Tensor mat2_values = mat2._values().contiguous();
  Tensor mat2_row_indices = mat2_indices_.select(0, 0);
  Tensor mat2_col_indices = mat2_indices_.select(0, 1);

  Tensor mat2_indptr = _to_csr_int(mat2_row_indices, mat2.size(0), mat2._nnz());
  Tensor mat2_indices = at::empty({mat2_col_indices.size(0)}, mat2_col_indices.options().dtype(kInt));
  mat2_indices.copy_(mat2_col_indices);

  auto m = mat1.size(0);
  auto k1 = mat1.size(1);

  auto k2 = mat2.size(0);
  auto n = mat2.size(1);
  TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k1 <= INT_MAX),
    "At the moment, cusparseDcsrgemm2 only supports m, n, k, nnz with the bound [val] <= ", INT_MAX, ".",
    "If you need this, please file an issue on GitHub."
  );
  auto output_indices = result._indices();
  auto output_values = result._values();

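  // If the inner dimension is empty or the output is 0x0, the product has no
  // non-zero entries: just zero out the preallocated output and return early.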
  if ((k1 == 0 && k2 == 0) || (n == 0 && m == 0)) {
    output_indices.zero_();
    output_values.zero_();
    return;
  }

  csrMatrixRef<scalar_t> csr_mat1(
      mat1_indices.data_ptr<int>(),
      mat1_indptr.data_ptr<int>(),
      mat1_values.data_ptr<scalar_t>(),
      (int)mat1._nnz(),
      {(int)mat1.size(0), (int)mat1.size(1)});

  csrMatrixRef<scalar_t> csr_mat2(
      mat2_indices.data_ptr<int>(),
      mat2_indptr.data_ptr<int>(),
      mat2_values.data_ptr<scalar_t>(),
      (int)mat2._nnz(),
      {(int)mat2.size(0), (int)mat2.size(1)});

  // Sparse matrix multiplication
  CusparseMatrixMultiplyOp<scalar_t> op;
  csrOutput csr_output = op(csr_mat1, csr_mat2, output_values, output_indices);
  auto nnz = csr_output.nnz_;

  output_values.set_(csr_output.csr_values_);
  output_indices.resize_({2, nnz});
  auto output_indices_accessor = output_indices.packed_accessor64<int64_t, 2>();

  auto csr_output_pointers_accessor =
      csr_output.csr_pointers_.packed_accessor64<int, 1>();

  auto csr_output_ind_accessor =
      csr_output.csr_indices_.packed_accessor64<int, 1>();

  auto major_dim = result.size(0);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);

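  // The CSR row-pointer array Ap assigns the output slots [Ap[i], Ap[i + 1])
  // to row i, so converting back to COO row indices just writes i into each of
  // those slots (e.g. Ap = {0, 2, 2, 5} expands to row indices {0, 0, 2, 2, 2}).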
  // Filling the COO row indices
  thrust::for_each(
      policy,
      thrust::make_counting_iterator(int64_t(0)),
      thrust::make_counting_iterator(int64_t(major_dim)),
      [output_indices_accessor,
       csr_output_pointers_accessor,
       major_dim,
       nnz] __device__(int64_t i) {
        auto Ap = csr_output_pointers_accessor.data();
        int64_t* indices_row = output_indices_accessor[0].data();

        for (int jj = Ap[i]; jj < Ap[i + 1]; jj++) {
          indices_row[jj] = i;
        }
      });

  // Filling the COO column indices
  thrust::for_each(
    policy,
    thrust::make_counting_iterator(int64_t(0)),
    thrust::make_counting_iterator(int64_t(csr_output.nnz_)),
    [output_indices_accessor,
      csr_output_pointers_accessor,
      csr_output_ind_accessor,
      major_dim,
      nnz] __device__(int64_t i) {
      int64_t* indices_col = output_indices_accessor[1].data();
      indices_col[i] = csr_output_ind_accessor[i];
    });
}

} // end anonymous namespace

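// Computes `mat1_ @ mat2_` for two 2-D sparse COO CUDA tensors and returns a
// sparse COO tensor of shape (mat1_.size(0), mat2_.size(1)). This is the CUDA
// implementation behind the `_sparse_sparse_matmul` op (see the
// `_sparse_sparse_matmul_native.h` include above). Rough usage sketch from the
// ATen C++ API, assuming the generated `at::_sparse_sparse_matmul` wrapper
// (tensor names are illustrative only):
//
//   Tensor a = at::rand({4, 8}).to_sparse().cuda();
//   Tensor b = at::rand({8, 3}).to_sparse().cuda();
//   Tensor c = at::_sparse_sparse_matmul(a, b);  // sparse COO, shape 4x3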
Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) {
  TORCH_INTERNAL_ASSERT(mat1_.is_sparse());
  TORCH_INTERNAL_ASSERT(mat2_.is_sparse());
  TORCH_CHECK(mat1_.dim() == 2);
  TORCH_CHECK(mat2_.dim() == 2);
  TORCH_CHECK(mat1_.dense_dim() == 0, "sparse_mm: scalar values expected, mat1 got ", mat1_.dense_dim(), "D values");
  TORCH_CHECK(mat2_.dense_dim() == 0, "sparse_mm: scalar values expected, mat2 got ", mat2_.dense_dim(), "D values");

  TORCH_CHECK(
      mat1_.size(1) == mat2_.size(0), "mat1 and mat2 shapes cannot be multiplied (",
      mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");

  TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
           "mat1 dtype ", mat1_.scalar_type(), " does not match mat2 dtype ", mat2_.scalar_type());

  auto output = at::native::empty_like(mat1_);
  output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);

#if IS_CUSPARSE11_AVAILABLE()
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
  });
#else
  AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
  });
#endif
  return output;
}

} // namespace at::native