#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/Parallel.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/native/Resize.h>
#include <ATen/native/SparseTensorUtils.h>
#include <cuda_runtime.h>
#include <type_traits>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_sparse_sparse_matmul_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#endif

#include <thrust/device_ptr.h>
#include <thrust/for_each.h>
#include <thrust/sequence.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/CUDAUtils.h>
#include <ATen/cuda/ThrustAllocator.h>
#include <cusparse.h>
#include <ATen/native/sparse/cuda/SparseCUDABlas.h>
#include <c10/cuda/CUDACachingAllocator.h>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/functional.h>
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/discard_iterator.h>

#if defined(__CUDACC__) && (CUSPARSE_VERSION >= 11000)
#define IS_CUSPARSE11_AVAILABLE() 1
#else
#define IS_CUSPARSE11_AVAILABLE() 0
#endif

#if IS_CUSPARSE11_AVAILABLE()
#include <library_types.h>
#endif
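
// When the cuSPARSE 11 generic API is available, sparse-sparse matmul is
// implemented with cusparseSpGEMM (see CusparseMatrixMultiplyOp below), which
// supports half, bfloat16, float, double and complex types. Otherwise we fall
// back to the legacy csrgemm2 API, which only supports float and double.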

namespace at::native {

namespace {

using namespace at::sparse;

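// Converts COO row indices (one entry per nonzero) into a CSR row-pointer
// array of length `dim + 1` using cusparse's Xcoo2csr. The indices are first
// copied into an int32 tensor because this cuSPARSE routine works with
// 32-bit indices.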
Tensor _to_csr_int(const Tensor& rowIndices, int64_t dim, int64_t nnz) {
  Tensor csr = at::empty({dim + 1}, CUDA(kInt));
  Tensor rowIndicesInt = at::empty({rowIndices.size(0)}, CUDA(kInt));
  rowIndicesInt.copy_(rowIndices);
  sparse::cuda::Xcoo2csr(
      rowIndicesInt.data_ptr<int32_t>(), nnz, dim, csr.data_ptr<int32_t>());
  return csr;
}


#pragma push
// NVCC complains that confirm_mult_size is not used,
// but it is used in specializations of CusparseMatrixMultiplyOp below
#pragma nv_diag_suppress 177 // Function was declared but never referenced
int confirm_mult_size(const std::vector<int>& mat1_size, const std::vector<int>& mat2_size) {
  TORCH_CHECK(
      mat1_size[1] == mat2_size[0],
      "mat1 and mat2 shapes cannot be multiplied (",
      mat1_size[0],
      "x",
      mat1_size[1],
      " and ",
      mat2_size[0],
      "x",
      mat2_size[1],
      ")");
  return mat1_size[1];
}
#pragma pop

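// Initializes a legacy cusparseMatDescr_t as a general matrix with
// zero-based indexing.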
void create_general_description_(cusparseMatDescr_t& description_) {
  TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&description_));
  TORCH_CUDASPARSE_CHECK(cusparseSetMatType(description_, CUSPARSE_MATRIX_TYPE_GENERAL));
  TORCH_CUDASPARSE_CHECK(cusparseSetMatIndexBase(description_, CUSPARSE_INDEX_BASE_ZERO));
}

// csrMatrixRef is a non-owning view of a raw CSR matrix produced by the
// `sparse_sparse_matmul_cuda_kernel` function.
// It also acts as a RAII guard for the corresponding cusparse descriptor.
template<class scalar_t>
struct csrMatrixRef {
  int* csr_indices_{nullptr};
  int* csr_pointers_{nullptr};
  scalar_t* csr_values_{nullptr};
  int nnz_{0};
  std::vector<int> size_{};

#if IS_CUSPARSE11_AVAILABLE()
  cusparseSpMatDescr_t description_{0};
#else
  cusparseMatDescr_t description_{0};
#endif

  csrMatrixRef() {
#if !IS_CUSPARSE11_AVAILABLE()
    create_general_description_(description_);
#endif
  }

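  // Wraps the given device pointers without taking ownership. With
  // cuSPARSE >= 11 this builds a cusparseSpMatDescr_t CSR descriptor
  // (32-bit row offsets and column indices, zero-based); otherwise a
  // legacy general matrix descriptor is created.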
  csrMatrixRef(
      int* csr_indices,
      int* csr_pointers,
      scalar_t* csr_values,
      int nnz,
      const std::vector<int>& size)
      : csr_indices_{csr_indices},
        csr_pointers_{csr_pointers},
        csr_values_{csr_values},
        nnz_{nnz},
        size_{size} {
#if IS_CUSPARSE11_AVAILABLE()
    cudaDataType cuda_data_type = at::cuda::getCudaDataType<scalar_t>();
    TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
        &description_,
        this->size(0),
        this->size(1),
        this->nnz_,
        this->csr_pointers_,
        this->csr_indices_,
        this->csr_values_,
        CUSPARSE_INDEX_32I,
        CUSPARSE_INDEX_32I,
        CUSPARSE_INDEX_BASE_ZERO,
        cuda_data_type));
#else
    create_general_description_(description_);
#endif
  }

  ~csrMatrixRef() {
#if IS_CUSPARSE11_AVAILABLE()
    cusparseDestroySpMat(description_);
#else
    cusparseDestroyMatDescr(description_);
#endif
  }

  int size(int index) const {
    return size_.at(index);
  }
};

// csrOutput represents the output of `CusparseMatrixMultiplyOp`.
// Unlike `csrMatrixRef`, which only references existing memory, `csrOutput`
// owns its storage as ATen tensors, i.e. it is a materialized CSR matrix.
// It also acts as a RAII guard for the corresponding cusparse descriptor.
struct csrOutput {
  Tensor csr_indices_{};
  Tensor csr_pointers_{};
  at::Tensor csr_values_{};
  int nnz_{0};
  std::vector<int> size_;

  cusparseMatDescr_t description_{0};

  csrOutput(const std::vector<int>& size) : size_{size} {
    create_general_description_(description_);
  }

  ~csrOutput() {
    cusparseDestroyMatDescr(description_);
  }

  csrOutput(const csrOutput&) = delete;
  csrOutput& operator=(const csrOutput&) = delete;
  csrOutput(csrOutput&& rhs) {
    csr_indices_ = std::move(rhs.csr_indices_);
    csr_pointers_ = std::move(rhs.csr_pointers_);
    csr_values_ = std::move(rhs.csr_values_);
    nnz_ = rhs.nnz_;
    size_ = std::move(rhs.size_);
    description_ = rhs.description_;
    rhs.description_ = 0;
  }
  csrOutput& operator=(csrOutput&&) = delete;
  int size(int index) const {
    return size_.at(index);
  }
};

#if IS_CUSPARSE11_AVAILABLE()

// RAII wrapper implementing `A @ B` with the cuSPARSE 11 generic SpGEMM API.
// This is a template because the cuSPARSE calls depend on the scalar type
// (half, bfloat16, float, double, or complex).
template <class scalar_t>
struct CusparseMatrixMultiplyOp {

  cusparseSpGEMMDescr_t spgemmDesc;

  CusparseMatrixMultiplyOp() {
    static_assert(
        std::is_same<c10::Half, scalar_t>::value ||
            std::is_same<c10::BFloat16, scalar_t>::value ||
            std::is_same<float, scalar_t>::value ||
            std::is_same<double, scalar_t>::value ||
            std::is_same<c10::complex<float>, scalar_t>::value ||
            std::is_same<c10::complex<double>, scalar_t>::value,
        "cusparseSpGEMM only supports data type of half, bfloat16, float, double and complex float, double.");
    // SpGEMM Computation
    TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&spgemmDesc));
  }

  ~CusparseMatrixMultiplyOp() {
    // destroy matrix/vector descriptors
    cusparseSpGEMM_destroyDescr(spgemmDesc);
  }

  csrOutput operator()(
      const csrMatrixRef<scalar_t>& A,
      const csrMatrixRef<scalar_t>& B,
      Tensor& output_values,
      Tensor& output_indices) {
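    // SpGEMM workflow (cuSPARSE generic API):
    //   1. cusparseSpGEMM_workEstimation: query the size of buffer1, then run
    //      it again with the allocated buffer to inspect A and B.
    //   2. cusparseSpGEMM_compute: query the size of buffer2, then compute the
    //      intermediate product A * B into internal storage.
    //   3. cusparseSpMatGetSize: read back nnz of C, allocate C's column index
    //      and value arrays, and bind them with cusparseCsrSetPointers.
    //   4. cusparseSpGEMM_copy: copy the result into C.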
    const int A_num_rows = A.size(0);

    const int B_num_cols = B.size(1);

    csrOutput out({A.size(0), B.size(1)});

    out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));

    int* dC_csrOffsets = out.csr_pointers_.data_ptr<int>();
    int* dC_columns = nullptr;
    scalar_t* dC_values = nullptr;

    scalar_t alpha = 1.0f;
    scalar_t beta = 0.0f;
    cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
    cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE;

    csrMatrixRef<scalar_t> C(
        dC_columns,
        dC_csrOffsets,
        dC_values,
        /*nnz*/0,
        {A_num_rows, B_num_cols}
    );

    //--------------------------------------------------------------------------
    // CUSPARSE APIs
    cusparseHandle_t handle = at::cuda::getCurrentCUDASparseHandle();
    void *dBuffer1 = NULL, *dBuffer2 = NULL;
    size_t bufferSize1 = 0, bufferSize2 = 0;

    cusparseSpMatDescr_t matA = A.description_;
    cusparseSpMatDescr_t matB = B.description_;
    cusparseSpMatDescr_t matC = C.description_;
    //--------------------------------------------------------------------------

    cudaDataType computeType = at::cuda::getCudaDataType<scalar_t>();

    // If a specific GPU model does not provide native support for a given
    // data type, the routine returns a CUSPARSE_STATUS_ARCH_MISMATCH error.
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(prop->major >= 5 && !((10*prop->major + prop->minor) < 53 && computeType == CUDA_R_16F),
        "sparse_mm: CUDA Float16 requires compute capability >= 53 (current: ", prop->major, prop->minor, ")");
    TORCH_CHECK(!(prop->major < 8 && computeType == CUDA_R_16BF),
        "sparse_mm: CUDA BFloat16 requires compute capability >= 80 (current: ", prop->major, prop->minor, ")");

    // ask for bufferSize1 bytes of external memory
    TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(
        handle,
        opA,
        opB,
        &alpha,
        matA,
        matB,
        &beta,
        matC,
        computeType,
        CUSPARSE_SPGEMM_DEFAULT,
        spgemmDesc,
        &bufferSize1,
        NULL));

    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();

    at::DataPtr dataPtr1 = allocator.allocate(bufferSize1);
    dBuffer1 = dataPtr1.get();
    // inspect the matrices A and B to understand the memory requirement for
    // the next step
    TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(
        handle,
        opA,
        opB,
        &alpha,
        matA,
        matB,
        &beta,
        matC,
        computeType,
        CUSPARSE_SPGEMM_DEFAULT,
        spgemmDesc,
        &bufferSize1,
        dBuffer1));

    // ask for bufferSize2 bytes of external memory
    TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_compute(
        handle,
        opA,
        opB,
        &alpha,
        matA,
        matB,
        &beta,
        matC,
        computeType,
        CUSPARSE_SPGEMM_DEFAULT,
        spgemmDesc,
        &bufferSize2,
        NULL));

    at::DataPtr dataPtr2 = allocator.allocate(bufferSize2);
    dBuffer2 = dataPtr2.get();

    // compute the intermediate product of A * B
    TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_compute(
        handle,
        opA,
        opB,
        &alpha,
        matA,
        matB,
        &beta,
        matC,
        computeType,
        CUSPARSE_SPGEMM_DEFAULT,
        spgemmDesc,
        &bufferSize2,
        dBuffer2));
    // get matrix C non-zero entries C_num_nnz1
    int64_t C_num_rows1, C_num_cols1, C_num_nnz1;
    TORCH_CUDASPARSE_CHECK(
        cusparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_num_nnz1));
    // allocate matrix C
    // allocate C offsets
    out.nnz_ = C_num_nnz1;

    out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
    out.csr_values_ = at::empty({out.nnz_}, output_values.options());
    dC_columns = out.csr_indices_.data_ptr<int>();
    dC_values = out.csr_values_.data_ptr<scalar_t>();

    // update matC with the new pointers
    TORCH_CUDASPARSE_CHECK(
        cusparseCsrSetPointers(matC, dC_csrOffsets, dC_columns, dC_values));

    // copy the final products to the matrix C
    TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_copy(
        handle,
        opA,
        opB,
        &alpha,
        matA,
        matB,
        &beta,
        matC,
        computeType,
        CUSPARSE_SPGEMM_DEFAULT,
        spgemmDesc));
    return out;
  }
};


template struct CusparseMatrixMultiplyOp<float>;

template struct CusparseMatrixMultiplyOp<double>;
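
// The explicit instantiations above cover float and double; the remaining
// supported types (Half, BFloat16, complex) are instantiated implicitly at
// the dispatch site in sparse_sparse_matmul_cuda below.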

#else // if not IS_CUSPARSE11_AVAILABLE()

using DcsrMatrixRef = csrMatrixRef<double>;
using ScsrMatrixRef = csrMatrixRef<float>;

// RAII wrapper implementing `A @ B` with the legacy cuSPARSE 10 csrgemm2 API.
// This is a template because the cuSPARSE calls depend on the scalar type;
// only float and double are supported on this path.
template <class scalar_t>
struct CusparseMatrixMultiplyOp {
  csrOutput operator()(
      const csrMatrixRef<scalar_t>& lhs,
      const csrMatrixRef<scalar_t>& rhs,
      Tensor& output_values,
      Tensor& output_indices)
  {
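    // The `false && sizeof(scalar_t)` expression makes the assertion depend
    // on the template parameter, so it only fires if this primary template is
    // actually instantiated (i.e. for an unsupported scalar type).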
    static_assert(false && sizeof(scalar_t), "cusparse csr sparse-sparse MM only supports data type of float and double.");
  }
};

// Specialization of the `A @ B` operation for double values with cuSparse
template<> struct CusparseMatrixMultiplyOp<double> {
  csrgemm2Info_t gemm2Info_;

  CusparseMatrixMultiplyOp() {
    TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
  }
  ~CusparseMatrixMultiplyOp() {
    cusparseDestroyCsrgemm2Info(gemm2Info_);
  }

  csrOutput operator()(
      const DcsrMatrixRef& lhs,
      const DcsrMatrixRef& rhs,
      Tensor& output_values,
      Tensor& output_indices) {
    double alpha = 1.0;
    DcsrMatrixRef empty;
    return Dgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
  }

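  // Computes out = alpha * A * B + beta * C with the csrgemm2 API:
  //   1. cusparseDcsrgemm2_bufferSizeExt: query the work buffer size.
  //   2. cusparseXcsrgemm2Nnz: compute the output row pointers and nnz.
  //   3. cusparseDcsrgemm2: compute the output column indices and values.
  // When called from operator() above, C is empty and beta is nullptr, so
  // this reduces to A @ B.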
  csrOutput Dgemm2(
      const DcsrMatrixRef& A,
      const DcsrMatrixRef& B,
      const DcsrMatrixRef& C,
      const double* alpha,
      const double* beta,
      Tensor& output_values,
      Tensor& output_indices) {
    void* buffer_{nullptr};
    cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
    TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));

    csrOutput out({A.size(0), B.size(1)});
    int innerSize = confirm_mult_size(A.size_, B.size_);
    out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));

    // Compute needed buffer size
    size_t new_bubber_sz;
    TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2_bufferSizeExt(
        cusparseHandle_,
        out.size(0),
        out.size(1),
        innerSize,
        alpha,
        A.description_,
        A.nnz_,
        A.csr_pointers_,
        A.csr_indices_,
        B.description_,
        B.nnz_,
        B.csr_pointers_,
        B.csr_indices_,
        beta,
        C.description_,
        C.nnz_,
        C.csr_pointers_,
        C.csr_indices_,
        gemm2Info_,
        &new_bubber_sz));

    // (Re)allocate buffer if needed
    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
    at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
    buffer_ = data_ptr.get();

    // Find the resulting non-zero pattern.
    TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
        cusparseHandle_,
        out.size(0),
        out.size(1),
        innerSize,
        A.description_,
        A.nnz_,
        A.csr_pointers_,
        A.csr_indices_,
        B.description_,
        B.nnz_,
        B.csr_pointers_,
        B.csr_indices_,
        C.description_,
        C.nnz_,
        C.csr_pointers_,
        C.csr_indices_,
        out.description_,
        out.csr_pointers_.data_ptr<int>(),
        &out.nnz_,
        gemm2Info_,
        buffer_));

    out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
    out.csr_values_ = at::empty({out.nnz_}, output_values.options());

    // Perform the gemm2 operation for doubles
    // out = alpha * A * B + beta * C
    TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2(
        cusparseHandle_,
        out.size(0),
        out.size(1),
        innerSize,
        alpha,
        A.description_,
        A.nnz_,
        A.csr_values_,
        A.csr_pointers_,
        A.csr_indices_,
        B.description_,
        B.nnz_,
        B.csr_values_,
        B.csr_pointers_,
        B.csr_indices_,
        beta,
        C.description_,
        C.nnz_,
        C.csr_values_,
        C.csr_pointers_,
        C.csr_indices_,
        out.description_,
        out.csr_values_.data_ptr<double>(),
        out.csr_pointers_.data_ptr<int>(),
        out.csr_indices_.data_ptr<int>(),
        gemm2Info_,
        buffer_));
    return out;
  }
};

// Specialization of the `A @ B` operation for float values with cuSparse
template<> struct CusparseMatrixMultiplyOp<float> {
  csrgemm2Info_t gemm2Info_;

  CusparseMatrixMultiplyOp() {
    TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
  }
  ~CusparseMatrixMultiplyOp() {
    cusparseDestroyCsrgemm2Info(gemm2Info_);
  }
  csrOutput operator()(
      const ScsrMatrixRef& lhs,
      const ScsrMatrixRef& rhs,
      Tensor& output_values,
      Tensor& output_indices) {
    float alpha = 1.0;
    ScsrMatrixRef empty;
    return Sgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
  }

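  // Sgemm2 mirrors Dgemm2 above, using the single-precision csrgemm2 entry
  // points (cusparseScsrgemm2_bufferSizeExt / cusparseScsrgemm2).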
  csrOutput Sgemm2(
      const ScsrMatrixRef& A,
      const ScsrMatrixRef& B,
      const ScsrMatrixRef& C,
      const float* alpha,
      const float* beta,
      Tensor& output_values,
      Tensor& output_indices) {
    void* buffer_{nullptr};
    cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
    TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));

    csrOutput out({A.size(0), B.size(1)});

    int innerSize = confirm_mult_size(A.size_, B.size_);

    out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));

    // Compute needed buffer size
    size_t new_bubber_sz;
    TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2_bufferSizeExt(
        cusparseHandle_,
        out.size(0),
        out.size(1),
        innerSize,
        alpha,
        A.description_,
        A.nnz_,
        A.csr_pointers_,
        A.csr_indices_,
        B.description_,
        B.nnz_,
        B.csr_pointers_,
        B.csr_indices_,
        beta,
        C.description_,
        C.nnz_,
        C.csr_pointers_,
        C.csr_indices_,
        gemm2Info_,
        &new_bubber_sz));

    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
    at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
    buffer_ = data_ptr.get();

    // Find the resulting non-zero pattern.
    TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
        cusparseHandle_,
        out.size(0),
        out.size(1),
        innerSize,
        A.description_,
        A.nnz_,
        A.csr_pointers_,
        A.csr_indices_,
        B.description_,
        B.nnz_,
        B.csr_pointers_,
        B.csr_indices_,
        C.description_,
        C.nnz_,
        C.csr_pointers_,
        C.csr_indices_,
        out.description_,
        out.csr_pointers_.data_ptr<int>(),
        &out.nnz_,
        gemm2Info_,
        buffer_));

    out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
    out.csr_values_ = at::empty({out.nnz_}, output_values.options());

    // Perform the gemm2 operation for floats
    // out = alpha * A * B + beta * C
    TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2(
        cusparseHandle_,
        out.size(0),
        out.size(1),
        innerSize,
        alpha,
        A.description_,
        A.nnz_,
        A.csr_values_,
        A.csr_pointers_,
        A.csr_indices_,
        B.description_,
        B.nnz_,
        B.csr_values_,
        B.csr_pointers_,
        B.csr_indices_,
        beta,
        C.description_,
        C.nnz_,
        C.csr_values_,
        C.csr_pointers_,
        C.csr_indices_,
        out.description_,
        out.csr_values_.data_ptr<float>(),
        out.csr_pointers_.data_ptr<int>(),
        out.csr_indices_.data_ptr<int>(),
        gemm2Info_,
        buffer_));
    return out;
  }
};


#endif // IS_CUSPARSE11_AVAILABLE()

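// Computes result = mat1 @ mat2 for two coalesced 2-D sparse COO tensors:
//   1. Convert both operands from COO to CSR (row pointers via _to_csr_int,
//      column indices cast to int32).
//   2. Multiply them with CusparseMatrixMultiplyOp.
//   3. Expand the resulting CSR row pointers back into COO row indices and
//      copy the column indices, writing directly into `result`.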
template <typename scalar_t>
void sparse_sparse_matmul_cuda_kernel(
    Tensor& result,
    const Tensor& mat1,
    const Tensor& mat2) {

  static_assert(
      std::is_same<c10::Half, scalar_t>::value ||
          std::is_same<c10::BFloat16, scalar_t>::value ||
          std::is_same<float, scalar_t>::value ||
          std::is_same<double, scalar_t>::value ||
          std::is_same<c10::complex<float>, scalar_t>::value ||
          std::is_same<c10::complex<double>, scalar_t>::value,
      "sparse_sparse_matmul_cuda_kernel only supports data type of half, bfloat16, float, double and complex float, double.");

  // older versions of cusparse on Windows segfault for complex128 dtype
#if defined(_WIN32) && defined(CUSPARSE_VERSION) && CUSPARSE_VERSION < 11400
  TORCH_CHECK(
      !(mat1.scalar_type() == ScalarType::ComplexDouble),
      "Sparse multiplication with complex128 dtype inputs is not supported with current CUDA version. Please upgrade to CUDA Toolkit 11.2.1+");
#endif

  Tensor mat1_indices_ = mat1._indices().contiguous();
  Tensor mat1_values = mat1._values().contiguous();

  Tensor mat1_row_indices = mat1_indices_.select(0, 0);
  Tensor mat1_col_indices = mat1_indices_.select(0, 1);

  Tensor mat1_indptr = _to_csr_int(mat1_row_indices, mat1.size(0), mat1._nnz());

  Tensor mat1_indices = at::empty(
      {mat1_col_indices.size(0)}, mat1_col_indices.options().dtype(kInt));

  mat1_indices.copy_(mat1_col_indices);

  Tensor mat2_indices_ = mat2._indices().contiguous();
  Tensor mat2_values = mat2._values().contiguous();
  Tensor mat2_row_indices = mat2_indices_.select(0, 0);
  Tensor mat2_col_indices = mat2_indices_.select(0, 1);

  Tensor mat2_indptr = _to_csr_int(mat2_row_indices, mat2.size(0), mat2._nnz());
  Tensor mat2_indices = at::empty({mat2_col_indices.size(0)}, mat2_col_indices.options().dtype(kInt));
  mat2_indices.copy_(mat2_col_indices);

  auto m = mat1.size(0);
  auto k1 = mat1.size(1);

  auto k2 = mat2.size(0);
  auto n = mat2.size(1);
  TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k1 <= INT_MAX),
      "At the moment, cusparseDcsrgemm2 only supports m, n, k, nnz with the bound [val] <= ", INT_MAX, ".",
      "If you need this, please file an issue on GitHub."
  );
  auto output_indices = result._indices();
  auto output_values = result._values();

  if ((k1 == 0 && k2 == 0) || (n == 0 && m == 0)) {
    output_indices.zero_();
    output_values.zero_();
    return;
  }

  csrMatrixRef<scalar_t> csr_mat1(
      mat1_indices.data_ptr<int>(),
      mat1_indptr.data_ptr<int>(),
      mat1_values.data_ptr<scalar_t>(),
      (int)mat1._nnz(),
      {(int)mat1.size(0), (int)mat1.size(1)});

  csrMatrixRef<scalar_t> csr_mat2(
      mat2_indices.data_ptr<int>(),
      mat2_indptr.data_ptr<int>(),
      mat2_values.data_ptr<scalar_t>(),
      (int)mat2._nnz(),
      {(int)mat2.size(0), (int)mat2.size(1)});

  // Sparse matrix multiplication
  CusparseMatrixMultiplyOp<scalar_t> op;
  csrOutput csr_output = op(csr_mat1, csr_mat2, output_values, output_indices);
  auto nnz = csr_output.nnz_;

  output_values.set_(csr_output.csr_values_);
  output_indices.resize_({2, nnz});
  auto output_indices_accessor = output_indices.packed_accessor64<int64_t, 2>();

  auto csr_output_pointers_accessor =
      csr_output.csr_pointers_.packed_accessor64<int, 1>();

  auto csr_output_ind_accessor =
      csr_output.csr_indices_.packed_accessor64<int, 1>();

  auto major_dim = result.size(0);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);

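  // Each CSR row i owns the index range [Ap[i], Ap[i+1]); write i into those
  // positions of the COO row-index array to undo the CSR compression.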
  // Filling the COO row indices
  thrust::for_each(
      policy,
      thrust::make_counting_iterator(int64_t(0)),
      thrust::make_counting_iterator(int64_t(major_dim)),
      [output_indices_accessor,
       csr_output_pointers_accessor,
       major_dim,
       nnz] __device__(int64_t i) {
        auto Ap = csr_output_pointers_accessor.data();
        int64_t* indices_row = output_indices_accessor[0].data();

        for (int jj = Ap[i]; jj < Ap[i + 1]; jj++) {
          indices_row[jj] = i;
        }
      });

  // Filling the COO column indices
  thrust::for_each(
      policy,
      thrust::make_counting_iterator(int64_t(0)),
      thrust::make_counting_iterator(int64_t(csr_output.nnz_)),
      [output_indices_accessor,
       csr_output_pointers_accessor,
       csr_output_ind_accessor,
       major_dim,
       nnz] __device__(int64_t i) {
        int64_t* indices_col = output_indices_accessor[1].data();
        indices_col[i] = csr_output_ind_accessor[i];
      });
}

} // end anonymous namespace

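// CUDA entry point for sparse-sparse matrix multiplication. Both inputs must
// be 2-D sparse COO tensors with scalar (0-dim dense) values and matching
// dtypes; the result is a sparse COO tensor of shape
// (mat1.size(0), mat2.size(1)).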
Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) {
  TORCH_INTERNAL_ASSERT(mat1_.is_sparse());
  TORCH_INTERNAL_ASSERT(mat2_.is_sparse());
  TORCH_CHECK(mat1_.dim() == 2);
  TORCH_CHECK(mat2_.dim() == 2);
  TORCH_CHECK(mat1_.dense_dim() == 0, "sparse_mm: scalar values expected, mat1 got ", mat1_.dense_dim(), "D values");
  TORCH_CHECK(mat2_.dense_dim() == 0, "sparse_mm: scalar values expected, mat2 got ", mat2_.dense_dim(), "D values");

  TORCH_CHECK(
      mat1_.size(1) == mat2_.size(0), "mat1 and mat2 shapes cannot be multiplied (",
      mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");

  TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
      "mat1 dtype ", mat1_.scalar_type(), " does not match mat2 dtype ", mat2_.scalar_type());

  auto output = at::native::empty_like(mat1_);
  output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);

#if IS_CUSPARSE11_AVAILABLE()
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
  });
#else
  AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
    sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
  });
#endif
  return output;
}

} // namespace at::native