xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/OpMathType.h>
5 #include <ATen/cuda/CUDADataType.h>
6 #include <ATen/cuda/CUDASparse.h>
7 #include <ATen/cuda/CUDASparseBlas.h>
8 #include <ATen/cuda/CUDASparseDescriptors.h>
9 #include <ATen/native/LinearAlgebraUtils.h>
10 #include <ATen/native/cuda/MiscUtils.h>
11 #include <ATen/native/sparse/SparseBlasImpl.h>
12 #include <ATen/native/sparse/cuda/SparseBlasImpl.h>
13 #include <ATen/native/sparse/cuda/SparseBlasLegacy.h>
14 
15 #ifndef AT_PER_OPERATOR_HEADERS
16 #include <ATen/Functions.h>
17 #include <ATen/NativeFunctions.h>
18 #else
19 #include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
20 #include <ATen/ops/empty_strided.h>
21 #endif
22 
23 #include <c10/cuda/CUDACachingAllocator.h>
24 #include <c10/util/MaybeOwned.h>
25 
26 namespace at::native::sparse::impl::cuda {
27 
28 namespace {
29 
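// The prepare_* helpers below either borrow the input tensor as-is when its
// memory layout is already usable by cuSPARSE (resolving a pending conjugation
// if needed), or return an owned copy with a compatible layout. Callers copy
// the data back into the original tensor whenever the returned tensor is a
// different object.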
30 c10::MaybeOwned<Tensor> prepare_column_major_matrix_for_cusparse(
31     const Tensor& tensor) {
32   if (is_blas_compatible_column_major_order(tensor)) {
33     return at::native::expect_resolved_conj(tensor);
34   } else {
35     return c10::MaybeOwned<Tensor>::owned(cloneBatchedColumnMajor(tensor));
36   }
37 }
38 
39 c10::MaybeOwned<Tensor> inline prepare_dense_matrix_for_cusparse(
40     const Tensor& tensor) {
41 #if defined(USE_ROCM)
42   // CUDA < 11.0 doesn't support row-major layout; return column-major in this case
43   return prepare_column_major_matrix_for_cusparse(tensor);
44 #else
45   if (is_blas_compatible_row_major_order(tensor) ||
46       is_blas_compatible_column_major_order(tensor)) {
47     return at::native::expect_resolved_conj(tensor);
48   } else {
49     return c10::MaybeOwned<Tensor>::owned(
50         tensor.clone(at::MemoryFormat::Contiguous));
51   }
52 #endif
53 }
54 
55 Tensor copy_strided(const Tensor& tensor, IntArrayRef strides) {
56   Tensor result = at::empty_strided(tensor.sizes(), strides, tensor.options());
57   result.copy_(tensor);
58   return result;
59 }
60 
61 c10::MaybeOwned<Tensor> prepare_dense_matrix_for_cusparse(
62     const Tensor& tensor,
63     IntArrayRef strides) {
64   if (tensor.strides().equals(strides)) {
65     return c10::MaybeOwned<Tensor>::borrowed(tensor);
66   } else {
67     return c10::MaybeOwned<Tensor>::owned(copy_strided(tensor, strides));
68   }
69 }
70 
71 // This function is used for old CUDA Toolkit versions that don't support the new cuSPARSE Generic API
72 void addmm_out_legacy(
73     const at::sparse_csr::SparseCsrTensor& mat1,
74     const Tensor& mat2,
75     const Scalar& beta,
76     const Scalar& alpha,
77     const Tensor& result) {
78   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat1.is_sparse_csr());
79   auto nnz = mat1._nnz();
80   auto m = mat1.size(0);
81   auto k = mat1.size(1);
82   auto n = mat2.size(1);
83   auto crow_indices = mat1.crow_indices().to(kInt);
84   auto col_indices = mat1.col_indices().to(kInt);
85   auto values = mat1.values();
86   auto mat2_ = at::native::expect_resolved_conj(mat2);
87   auto result_ = at::native::expect_resolved_conj(result);
88   at::native::s_addmm_out_csr_sparse_dense_cuda_worker(nnz, m, n, k, result, beta, *result_, alpha, crow_indices, col_indices, values, *mat2_);
89   if (!result.is_same(*result_)) {
90     result.copy_(*result_);
91   }
92 }
93 
94 c10::MaybeOwned<Tensor> inline prepare_dense_vector_for_cusparse(
95     const Tensor& tensor) {
96   if (tensor.is_non_overlapping_and_dense()) {
97     return c10::MaybeOwned<Tensor>::borrowed(tensor);
98   } else {
99     return c10::MaybeOwned<Tensor>::owned(
100         tensor.clone(at::MemoryFormat::Contiguous));
101   }
102 }
103 
104 void inline indices_to_32_bit_inplace(const Tensor& input) {
105   static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())->set_member_tensors(
106       input.crow_indices().to(kInt),
107       input.col_indices().to(kInt),
108       input.values(),
109       input.sizes());
110 }
111 
112 void inline col_indices_and_values_resize_(const Tensor& input, int64_t nnz) {
113   static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())->set_member_tensors(
114       input.crow_indices(),
115       input.col_indices().resize_({nnz}),
116       input.values().resize_({nnz}),
117       input.sizes());
118 }
119 
120 void inline bsrsv2_bsrsm2_may_need_to_sync() {
121 #if defined(CUSPARSE_VERSION) && CUSPARSE_VERSION < 11703
122   // cusparse bsrsv2 and bsrsm2 have a synchronization issue that may cause illegal memory access in cuda <= 11.6.x
123   // See https://github.com/pytorch/pytorch/issues/71297
124   ::c10::cuda::device_synchronize();
125 #endif
126   // else: do nothing!
127 }
128 
129 void block_sparse_triangular_solve_vec(
130     const at::sparse_csr::SparseCsrTensor& A,
131     const Tensor& B,
132     const Tensor& X,
133     bool upper,
134     bool transpose,
135     bool unitriangular) {
136 #if !AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
137   TORCH_CHECK(
138       false,
139       "Calling triangular solver with block sparse GPU tensors requires compiling ",
140       "PyTorch with ROCm 4.5.0+. ",
141       "Please use PyTorch built with newer ROCm version.");
142 #else
143   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.layout() == kSparseBsr);
144   // values is expected to hold the blocks of the sparse matrix
145   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.values().dim() == 3);
146   // blocks are expected to be square
147   TORCH_INTERNAL_ASSERT(A.values().size(2) == A.values().size(1));
148   // only blocks of size > 1 are supported in cuSPARSE
149   TORCH_INTERNAL_ASSERT(A.values().size(-1) > 1);
150   // blocks are expected to be in row- or column-major order
151   TORCH_INTERNAL_ASSERT(
152       A.values().is_contiguous() ||
153       A.values().transpose(-2, -1).is_contiguous());
154 
155   // cuSPARSE can't work with empty sparse matrices
156   if (A._nnz() == 0) {
157     X.fill_(NAN);
158     return;
159   }
160 
161   const cusparseDirection_t block_layout = A.values().is_contiguous()
162       ? CUSPARSE_DIRECTION_ROW
163       : CUSPARSE_DIRECTION_COLUMN;
164 
165   c10::MaybeOwned<Tensor> X_ = prepare_dense_matrix_for_cusparse(X);
166   c10::MaybeOwned<Tensor> B_ = prepare_dense_matrix_for_cusparse(B);
167 
168   auto block_size = cuda_int_cast(A.values().size(2), "block_size");
169   auto nnzb = cuda_int_cast(A._nnz(), "nnzb");
170   auto mb = cuda_int_cast(A.size(0), "mb") / block_size;
171 
172   auto desc = at::cuda::sparse::CuSparseMatDescriptor(upper, unitriangular);
173   cusparseOperation_t opA = transpose ? CUSPARSE_OPERATION_TRANSPOSE
174                                       : CUSPARSE_OPERATION_NON_TRANSPOSE;
175 
176   auto info = at::cuda::sparse::CuSparseBsrsv2Info();
177 
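  // bsrsv2 workflow: query the required buffer size, run the analysis phase,
  // check for a structural zero pivot (a singular triangular matrix, reported
  // back as NaNs in X), then solve. The bsrsm2 path in
  // block_sparse_triangular_solve_mat below follows the same sequence.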
178   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
179       X.scalar_type(), "block_sparse_triangular_solve_vec", [&] {
180         scalar_t alpha = 1;
181         auto values = A.values();
182         auto values_data_ptr = values.data_ptr<scalar_t>();
183         auto crow_indices = A.crow_indices().to(kInt);
184         auto crow_indices_data_ptr = crow_indices.data_ptr<int>();
185         auto col_indices = A.col_indices().to(kInt);
186         auto col_indices_data_ptr = col_indices.data_ptr<int>();
187         auto handle = at::cuda::getCurrentCUDASparseHandle();
188         int buffer_size = 0;
189 
190         at::cuda::sparse::bsrsv2_bufferSize(
191             handle,
192             block_layout,
193             opA,
194             mb,
195             nnzb,
196             desc.descriptor(),
197             values_data_ptr,
198             crow_indices_data_ptr,
199             col_indices_data_ptr,
200             block_size,
201             info.descriptor(),
202             &buffer_size);
203 
204         auto& allocator = *c10::cuda::CUDACachingAllocator::get();
205         auto work_data = allocator.allocate(buffer_size);
206 
207         at::cuda::sparse::bsrsv2_analysis(
208             handle,
209             block_layout,
210             opA,
211             mb,
212             nnzb,
213             desc.descriptor(),
214             values_data_ptr,
215             crow_indices_data_ptr,
216             col_indices_data_ptr,
217             block_size,
218             info.descriptor(),
219             CUSPARSE_SOLVE_POLICY_NO_LEVEL,
220             work_data.get());
221 
222         if (!unitriangular) {
223           int first_zero_diag_idx = -1;
224           cusparseStatus_t status = cusparseXbsrsv2_zeroPivot(handle, info.descriptor(), &first_zero_diag_idx);
225           if (status == CUSPARSE_STATUS_ZERO_PIVOT) {
226             X_->fill_(NAN);
227             return;
228           }
229         }
230 
231         at::cuda::sparse::bsrsv2_solve(
232             handle,
233             block_layout,
234             opA,
235             mb,
236             nnzb,
237             &alpha,
238             desc.descriptor(),
239             values_data_ptr,
240             crow_indices_data_ptr,
241             col_indices_data_ptr,
242             block_size,
243             info.descriptor(),
244             B_->data_ptr<scalar_t>(),
245             X_->data_ptr<scalar_t>(),
246             CUSPARSE_SOLVE_POLICY_NO_LEVEL,
247             work_data.get());
248 
249         bsrsv2_bsrsm2_may_need_to_sync();
250       });
251   if (!X.is_same(*X_)) {
252     X.copy_(*X_);
253   }
254 #endif
255 }
256 
257 void block_sparse_triangular_solve_mat(
258     const at::sparse_csr::SparseCsrTensor& A,
259     const Tensor& B,
260     const Tensor& X,
261     bool upper,
262     bool transpose,
263     bool unitriangular) {
264 #if !AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
265   TORCH_CHECK(
266       false,
267       "Calling triangular solver with block sparse GPU tensors requires compiling ",
268       "PyTorch with ROCm 4.5.0+. ",
269       "Please use PyTorch built with newer ROCm version.");
270 #else
271   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.layout() == kSparseBsr);
272   // values is expected to hold the blocks of the sparse matrix
273   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.values().dim() == 3);
274   // blocks are expected to be square
275   TORCH_INTERNAL_ASSERT(A.values().size(2) == A.values().size(1));
276   // only blocks of size > 1 are supported in cuSPARSE
277   TORCH_INTERNAL_ASSERT(A.values().size(-1) > 1);
278   // blocks are expected to be in row- or column-major order
279   TORCH_INTERNAL_ASSERT(
280       A.values().is_contiguous() ||
281       A.values().transpose(-2, -1).is_contiguous());
282 
283   // cuSPARSE can't work with empty sparse matrices
284   if (A._nnz() == 0) {
285     X.fill_(NAN);
286     return;
287   }
288 
289   const cusparseDirection_t block_layout = A.values().is_contiguous()
290       ? CUSPARSE_DIRECTION_ROW
291       : CUSPARSE_DIRECTION_COLUMN;
292 
293   c10::MaybeOwned<Tensor> X_ = prepare_column_major_matrix_for_cusparse(X);
294   c10::MaybeOwned<Tensor> B_ = prepare_column_major_matrix_for_cusparse(B);
295 
296   int ldb = cuda_int_cast(B_->stride(-1), "ldb");
297   int ldx = cuda_int_cast(X_->stride(-1), "ldx");
298 
299   cusparseOperation_t opX = CUSPARSE_OPERATION_NON_TRANSPOSE;
300   cusparseOperation_t opA = transpose ? CUSPARSE_OPERATION_TRANSPOSE
301                                       : CUSPARSE_OPERATION_NON_TRANSPOSE;
302 
303   auto block_size = cuda_int_cast(A.values().size(2), "block_size");
304   auto nnzb = cuda_int_cast(A._nnz(), "nnzb");
305   auto mb = cuda_int_cast(A.size(0), "mb") / block_size;
306   auto n = cuda_int_cast(B.size(-1), "n");
307 
308   auto desc = at::cuda::sparse::CuSparseMatDescriptor(upper, unitriangular);
309   auto info = at::cuda::sparse::CuSparseBsrsm2Info();
310 
311   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
312       X.scalar_type(), "block_sparse_triangular_solve_mat", [&] {
313         scalar_t alpha = 1;
314         auto values = A.values();
315         auto values_data_ptr = values.data_ptr<scalar_t>();
316         auto crow_indices = A.crow_indices().to(kInt);
317         auto crow_indices_data_ptr = crow_indices.data_ptr<int>();
318         auto col_indices = A.col_indices().to(kInt);
319         auto col_indices_data_ptr = col_indices.data_ptr<int>();
320         auto handle = at::cuda::getCurrentCUDASparseHandle();
321         int buffer_size = 0;
322 
323         at::cuda::sparse::bsrsm2_bufferSize(
324             handle,
325             block_layout,
326             opA,
327             opX,
328             mb,
329             n,
330             nnzb,
331             desc.descriptor(),
332             values_data_ptr,
333             crow_indices_data_ptr,
334             col_indices_data_ptr,
335             block_size,
336             info.descriptor(),
337             &buffer_size);
338 
339         auto& allocator = *c10::cuda::CUDACachingAllocator::get();
340         auto work_data = allocator.allocate(buffer_size);
341 
342         at::cuda::sparse::bsrsm2_analysis(
343             handle,
344             block_layout,
345             opA,
346             opX,
347             mb,
348             n,
349             nnzb,
350             desc.descriptor(),
351             values_data_ptr,
352             crow_indices_data_ptr,
353             col_indices_data_ptr,
354             block_size,
355             info.descriptor(),
356             CUSPARSE_SOLVE_POLICY_NO_LEVEL,
357             work_data.get());
358 
359         if (!unitriangular) {
360           int first_zero_diag_idx = -1;
361           cusparseStatus_t status = cusparseXbsrsm2_zeroPivot(handle, info.descriptor(), &first_zero_diag_idx);
362           if (status == CUSPARSE_STATUS_ZERO_PIVOT) {
363             X_->fill_(NAN);
364             return;
365           }
366         }
367 
368         at::cuda::sparse::bsrsm2_solve(
369             handle,
370             block_layout,
371             opA,
372             opX,
373             mb,
374             n,
375             nnzb,
376             &alpha,
377             desc.descriptor(),
378             values_data_ptr,
379             crow_indices_data_ptr,
380             col_indices_data_ptr,
381             block_size,
382             info.descriptor(),
383             B_->data_ptr<scalar_t>(),
384             ldb,
385             X_->data_ptr<scalar_t>(),
386             ldx,
387             CUSPARSE_SOLVE_POLICY_NO_LEVEL,
388             work_data.get());
389 
390         bsrsv2_bsrsm2_may_need_to_sync();
391       });
392   if (!X.is_same(*X_)) {
393     X.copy_(*X_);
394   }
395 #endif
396 }
397 
398 void block_sparse_mv(
399     const at::sparse_csr::SparseCsrTensor& mat,
400     const Tensor& vec,
401     const Scalar& beta,
402     const Scalar& alpha,
403     const Tensor& result) {
404   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat.layout() == kSparseBsr);
405   // values is expected to hold the blocks of the sparse matrix
406   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat.values().dim() == 3);
407   // blocks are expected to be square
408   TORCH_INTERNAL_ASSERT(mat.values().size(2) == mat.values().size(1));
409   // only blocks of size > 1 are supported in cuSPARSE
410   TORCH_INTERNAL_ASSERT(mat.values().size(-1) > 1);
411   // blocks are expected to be in row- or column-major order
412   TORCH_INTERNAL_ASSERT(
413       mat.values().is_contiguous() ||
414       mat.values().transpose(-2, -1).is_contiguous());
415 
416   const cusparseDirection_t block_layout = mat.values().is_contiguous()
417       ? CUSPARSE_DIRECTION_ROW
418       : CUSPARSE_DIRECTION_COLUMN;
419 
420   c10::MaybeOwned<Tensor> result_ = prepare_dense_vector_for_cusparse(result);
421   c10::MaybeOwned<Tensor> vec_ = prepare_dense_vector_for_cusparse(vec);
422 
423   auto block_size = cuda_int_cast(mat.values().size(2), "block_size");
424   auto nnzb = cuda_int_cast(mat._nnz(), "nnzb");
425   auto mb = cuda_int_cast(mat.size(0), "mb") / block_size;
426   auto nb = cuda_int_cast(mat.size(1), "nb") / block_size;
427 
428   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
429       result.scalar_type(), "block_sparse_mv", [&] {
430         auto beta_ = beta.to<scalar_t>();
431         auto alpha_ = alpha.to<scalar_t>();
432         auto handle = at::cuda::getCurrentCUDASparseHandle();
433         auto desc = at::cuda::sparse::CuSparseMatDescriptor();
434         auto values = mat.values();
435         auto values_data_ptr = values.data_ptr<scalar_t>();
436         auto crow_indices = mat.crow_indices().to(kInt);
437         auto crow_indices_data_ptr = crow_indices.data_ptr<int>();
438         auto col_indices = mat.col_indices().to(kInt);
439         auto col_indices_data_ptr = col_indices.data_ptr<int>();
440         at::cuda::sparse::bsrmv(
441             handle,
442             block_layout,
443             CUSPARSE_OPERATION_NON_TRANSPOSE,
444             mb,
445             nb,
446             nnzb,
447             &alpha_,
448             desc.descriptor(),
449             values_data_ptr,
450             crow_indices_data_ptr,
451             col_indices_data_ptr,
452             block_size,
453             vec_->data_ptr<scalar_t>(),
454             &beta_,
455             result_->data_ptr<scalar_t>());
456       });
457   if (!result.is_same(*result_)) {
458     result.copy_(*result_);
459   }
460 }
461 
462 void block_sparse_mm(
463     const Tensor& input,
464     const at::sparse_csr::SparseCsrTensor& mat1,
465     const Tensor& mat2,
466     const Scalar& beta,
467     const Scalar& alpha,
468     const Tensor& result) {
469   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat1.layout() == kSparseBsr);
470   // values is expected to hold the blocks of the sparse matrix
471   TORCH_INTERNAL_ASSERT(mat1.values().dim() == 3);
472   // blocks are expected to be square
473   TORCH_INTERNAL_ASSERT(mat1.values().size(2) == mat1.values().size(1));
474   // only blocks of size > 1 are supported in cuSPARSE
475   TORCH_INTERNAL_ASSERT(mat1.values().size(-1) > 1);
476   // blocks are expected to be in row- or column-major order
477   TORCH_INTERNAL_ASSERT(
478       mat1.values().is_contiguous() ||
479       mat1.values().transpose(-2, -1).is_contiguous());
480 
481   // NOTE: the code below allows arbitrary block sizes
482   // and is potentially faster than the cuSPARSE implementation,
483   // especially for inputs that are not very sparse.
484   if (mat1.scalar_type() == ScalarType::Half
485       || mat1.scalar_type() == ScalarType::BFloat16
486       || mat1.scalar_type() == ScalarType::Float) {
487     at::native::sparse::impl::_compressed_row_strided_addmm_out(
488         input,
489         mat1,
490         mat2,
491         /*beta=*/beta,
492         /*alpha=*/alpha,
493         // @nikitaved: not sure whether `const Tensor& result` makes sense,
494         // but let's keep the interface intact, hence the const cast.
495         const_cast<Tensor&>(result));
496     return;
497   }
498 
499   if (beta.toComplexDouble() != 0. && !result.is_same(input)) {
500     result.copy_(input);
501   }
502 
503   const cusparseDirection_t block_layout = mat1.values().is_contiguous()
504       ? CUSPARSE_DIRECTION_ROW
505       : CUSPARSE_DIRECTION_COLUMN;
506 
507   c10::MaybeOwned<Tensor> mat2_ = prepare_dense_matrix_for_cusparse(mat2);
508 
509   // cuSPARSE expects column-major strides for result and we can't manipulate
510   // the transpose flag of mat1
511   c10::MaybeOwned<Tensor> result_ =
512       prepare_column_major_matrix_for_cusparse(result);
513 
514   IntArrayRef result_strides = result_->strides();
515   IntArrayRef mat2_strides = mat2_->strides();
516   auto ndim = result_->dim();
517 
518   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim == 2);
519   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat1.dim() == 2);
520   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat2.dim() == 2);
521 
522   bool is_mat2_row_major = (mat2_strides[ndim - 1] == 1);
523   int ldb = is_mat2_row_major ? cuda_int_cast(mat2_strides[ndim - 2], "ldb")
524                               : cuda_int_cast(mat2_strides[ndim - 1], "ldb");
525   int ldc = cuda_int_cast(result_strides[ndim - 1], "ldc");
526   auto block_size = cuda_int_cast(mat1.values().size(2), "block_size");
527   auto nnzb = cuda_int_cast(mat1._nnz(), "nnzb");
528   auto mb = cuda_int_cast(mat1.size(0), "mb") / block_size;
529   auto kb = cuda_int_cast(mat1.size(1), "kb") / block_size;
530   auto n = cuda_int_cast(mat2.size(1), "n");
531 
532   // according to cuSPARSE documentation, opA can only be NON_TRANSPOSE
533   cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
534   cusparseOperation_t opB = is_mat2_row_major
535       ? CUSPARSE_OPERATION_TRANSPOSE
536       : CUSPARSE_OPERATION_NON_TRANSPOSE;
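  // When mat2 is row-major, its buffer read in column-major order is
  // transpose(mat2); passing it with opB = CUSPARSE_OPERATION_TRANSPOSE keeps
  // the product equal to mat1 @ mat2, and ldb above is chosen to match that
  // interpretation.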
537 
538   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
539       result.scalar_type(), "block_sparse_mm", [&] {
540         auto beta_ = beta.to<scalar_t>();
541         auto alpha_ = alpha.to<scalar_t>();
542         auto handle = at::cuda::getCurrentCUDASparseHandle();
543         auto desc = at::cuda::sparse::CuSparseMatDescriptor();
544 
545         auto values = mat1.values();
546         auto values_data_ptr = values.data_ptr<scalar_t>();
547         auto crow_indices = mat1.crow_indices().to(kInt);
548         auto crow_indices_data_ptr = crow_indices.data_ptr<int>();
549         auto col_indices = mat1.col_indices().to(kInt);
550         auto col_indices_data_ptr = col_indices.data_ptr<int>();
551 
552         at::cuda::sparse::bsrmm(
553             handle,
554             block_layout,
555             opA,
556             opB,
557             mb,
558             n,
559             kb,
560             nnzb,
561             &alpha_,
562             desc.descriptor(),
563             values_data_ptr,
564             crow_indices_data_ptr,
565             col_indices_data_ptr,
566             block_size,
567             mat2_->data_ptr<scalar_t>(),
568             ldb,
569             &beta_,
570             result_->data_ptr<scalar_t>(),
571             ldc);
572       });
573 
574   if (!result.is_same(*result_)) {
575     result.copy_(*result_);
576   }
577 }
578 
579 void spmm(
580     const at::sparse_csr::SparseCsrTensor& mat1,
581     const Tensor& mat2,
582     const Scalar& beta,
583     const Scalar& alpha,
584     const Tensor& result) {
585 #if !(AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API())
586   addmm_out_legacy(mat1, mat2, beta, alpha, result);
587 #else
588   c10::MaybeOwned<Tensor> result_ = prepare_dense_matrix_for_cusparse(result);
589   c10::MaybeOwned<Tensor> mat2_ = prepare_dense_matrix_for_cusparse(mat2);
590 
591   // Here subscript "c" stands for column-major and subscript "r" for
592   // row-major order. Both orders are supported by cuSPARSE. For mixed input we
593   // need to cast 'mat2' to the order of 'result'. We compute
594   // result = mat1 @ op(mat2) + result.
595   // If the orders of 'mat2' and 'result' match, op is the
596   // identity; op(mat2) == mat2. If 'result' is column-major and 'mat2' is
597   // row-major we pass 'mat2' as column-major and compute
598   // result_c = mat1 @ transpose(mat2_c) + result_c; mat2_r == transpose(mat2_c).
599   // If 'result' is row-major and 'mat2' is column-major we pass 'mat2'
600   // as row-major and compute
601   // result_r = mat1 @ transpose(mat2_r) + result_r; mat2_c == transpose(mat2_r).
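  // Below this is expressed as transpose_B = (is_result_row_major ^
  // is_mat2_row_major): when set, descB wraps mat2.mT() (a view with the same
  // memory order as result) and opB is CUSPARSE_OPERATION_TRANSPOSE, so
  // op(descB) == mat2 and the computed product is unchanged.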
602   IntArrayRef result_strides = result_->strides();
603   IntArrayRef mat2_strides = mat2_->strides();
604   auto ndim = result_->dim();
605   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim == 2 || ndim == 3);
606   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat1.dim() == 2 || mat1.dim() == 3);
607   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat2.dim() == 2 || mat2.dim() == 3);
608   bool is_result_row_major = (result_strides[ndim - 1] == 1);
609   bool is_mat2_row_major = (mat2_strides[ndim - 1] == 1);
610   bool transpose_B = (is_result_row_major ^ is_mat2_row_major);
611 
612   cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
613   cusparseOperation_t opB = transpose_B ? CUSPARSE_OPERATION_TRANSPOSE
614                                         : CUSPARSE_OPERATION_NON_TRANSPOSE;
615 
616   // CUDA < 11.0 doesn't support 64-bit indices and doesn't raise an error about this,
617   // silently returning incorrect results.
618 #if defined(USE_ROCM)
619   auto mat1_32 = at::native::_sparse_csr_tensor_unsafe(
620       mat1.crow_indices().to(kInt),
621       mat1.col_indices().to(kInt),
622       mat1.values(),
623       mat1.sizes(),
624       mat1.scalar_type(),
625       mat1.layout(),
626       mat1.device());
627   auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat1_32);
628   auto algorithm = CUSPARSE_MM_ALG_DEFAULT;
629 #else
630   // TODO: update this to support COO sparse layout
631   auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat1);
632   auto algorithm = CUSPARSE_SPMM_CSR_ALG2;
633 #endif
634 
635   auto descB = at::cuda::sparse::CuSparseConstDnMatDescriptor(
636       transpose_B ? mat2_->mT() : *mat2_);
637   auto descC = at::cuda::sparse::CuSparseDnMatDescriptor(*result_);
638 
639   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
640       kHalf,
641       kBFloat16,
642       result.scalar_type(),
643       "spmm",
644       [&] {
645         using opmath_t = at::opmath_type<scalar_t>;
646         auto beta_ = beta.to<opmath_t>();
647         auto alpha_ = alpha.to<opmath_t>();
648         cudaDataType compute_type = at::cuda::getCudaDataType<opmath_t>();
649         auto handle = at::cuda::getCurrentCUDASparseHandle();
650 
651         size_t buffer_size;
652         TORCH_CUDASPARSE_CHECK(cusparseSpMM_bufferSize(
653             handle,
654             opA,
655             opB,
656             &alpha_,
657             descA.descriptor(),
658             descB.unsafe_mutable_descriptor(),
659             &beta_,
660             descC.descriptor(),
661             compute_type,
662             algorithm,
663             &buffer_size // output
664             ));
665 
666         auto& allocator = *c10::cuda::CUDACachingAllocator::get();
667         auto work_data = allocator.allocate(buffer_size);
668 
669         TORCH_CUDASPARSE_CHECK(cusparseSpMM(
670             handle,
671             opA,
672             opB,
673             &alpha_,
674             descA.descriptor(),
675             descB.unsafe_mutable_descriptor(),
676             &beta_,
677             descC.descriptor(),
678             compute_type,
679             algorithm,
680             work_data.get()));
681       });
682 
683   if (!result.is_same(*result_)) {
684     result.copy_(*result_);
685   }
686 #endif // !(AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API())
687 }
688 
689 void spgemm(
690     const at::sparse_csr::SparseCsrTensor& A,
691     const at::sparse_csr::SparseCsrTensor& B,
692     const Scalar& beta,
693     const Scalar& alpha,
694     const at::sparse_csr::SparseCsrTensor& C) {
695   // older versions of cusparse on Windows segfault for complex128 dtype
696 #if defined(_WIN32) && defined(CUSPARSE_VERSION) && CUSPARSE_VERSION < 11400
697   TORCH_CHECK(
698       !(A.scalar_type() == ScalarType::ComplexDouble),
699       "Sparse multiplication with complex128 dtype inputs is not supported with current CUDA version. Please upgrade to CUDA Toolkit 11.2.1+");
700 #endif
701 
702   IntArrayRef A_sizes = A.sizes();
703   auto ndim = A.dim();
704   auto m = A_sizes[ndim - 2];
705 
706   IntArrayRef B_sizes = B.sizes();
707   auto n = B_sizes[ndim - 1];
708 
709   // Only 32-bit indices are supported
710   auto A_32 = at::native::_sparse_csr_tensor_unsafe(A.crow_indices().to(kInt), A.col_indices().to(kInt), A.values(), A.sizes(), A.scalar_type(), A.layout(), A.device());
711   auto B_32 = at::native::_sparse_csr_tensor_unsafe(B.crow_indices().to(kInt), B.col_indices().to(kInt), B.values(), B.sizes(), B.scalar_type(), B.layout(), B.device());
712 
713   // Modify C tensor in-place to swap indices tensors with 32-bit variants
714   indices_to_32_bit_inplace(C);
715 
716   auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(A_32);
717   auto descB = at::cuda::sparse::CuSparseSpMatCsrDescriptor(B_32);
718   auto descC = at::cuda::sparse::CuSparseSpMatCsrDescriptor(C);
719 
720   auto spgemm_desc = at::cuda::sparse::CuSparseSpGEMMDescriptor();
721   cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
722   cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE;
723 
724   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
725       kHalf,
726       kBFloat16,
727       C.scalar_type(),
728       "spgemm",
729       [&] {
730         auto beta_ = beta.to<scalar_t>();
731         auto alpha_ = alpha.to<scalar_t>();
732         auto compute_type = at::cuda::getCudaDataType<scalar_t>();
733         auto handle = at::cuda::getCurrentCUDASparseHandle();
734 
735         // It's required to call workEstimation twice
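        // The first call, with a null external buffer, only reports the
        // required buffer size; the second call, with the allocated buffer,
        // performs the actual work estimation. cusparseSpGEMM_compute below
        // follows the same two-call pattern.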
736         size_t buffer_size1 = 0;
737         TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(
738             handle,
739             opA,
740             opB,
741             &alpha_,
742             descA.descriptor(),
743             descB.descriptor(),
744             &beta_,
745             descC.descriptor(),
746             compute_type,
747             CUSPARSE_SPGEMM_DEFAULT,
748             spgemm_desc.descriptor(),
749             &buffer_size1,
750             nullptr));
751 
752         auto& allocator = *c10::cuda::CUDACachingAllocator::get();
753         auto buffer1 = allocator.allocate(buffer_size1);
754 
755         TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(
756             handle,
757             opA,
758             opB,
759             &alpha_,
760             descA.descriptor(),
761             descB.descriptor(),
762             &beta_,
763             descC.descriptor(),
764             compute_type,
765             CUSPARSE_SPGEMM_DEFAULT,
766             spgemm_desc.descriptor(),
767             &buffer_size1,
768             buffer1.get()));
769 
770         // It's required to call compute twice
771         size_t buffer_size2 = 0;
772         TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_compute(
773             handle,
774             opA,
775             opB,
776             &alpha_,
777             descA.descriptor(),
778             descB.descriptor(),
779             &beta_,
780             descC.descriptor(),
781             compute_type,
782             CUSPARSE_SPGEMM_DEFAULT,
783             spgemm_desc.descriptor(),
784             &buffer_size2,
785             nullptr));
786 
787         auto buffer2 = allocator.allocate(buffer_size2);
788 
789         TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_compute(
790             handle,
791             opA,
792             opB,
793             &alpha_,
794             descA.descriptor(),
795             descB.descriptor(),
796             &beta_,
797             descC.descriptor(),
798             compute_type,
799             CUSPARSE_SPGEMM_DEFAULT,
800             spgemm_desc.descriptor(),
801             &buffer_size2,
802             buffer2.get()));
803 
804         // Get how many specified elements there are in C
805         auto [C_num_rows, C_num_cols, C_nnz] = descC.get_size();
806 
807         TORCH_INTERNAL_ASSERT_DEBUG_ONLY(C_num_rows == m);
808         TORCH_INTERNAL_ASSERT_DEBUG_ONLY(C_num_cols == n);
809 
810         // Resize result using nnz information from cusparse
811         col_indices_and_values_resize_(C, C_nnz);
812 
813         // Update matC with the new pointers
814         descC.set_tensor(C);
815 
816         // Copy the data into C
817         TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_copy(
818             handle,
819             opA,
820             opB,
821             &alpha_,
822             descA.descriptor(),
823             descB.descriptor(),
824             &beta_,
825             descC.descriptor(),
826             compute_type,
827             CUSPARSE_SPGEMM_DEFAULT,
828             spgemm_desc.descriptor()));
829       });
830 }
831 
832 } // anonymous namespace
833 
834 void addmm_out_sparse_csr(
835     const Tensor& input,
836     const Tensor& mat1,
837     const Tensor& mat2,
838     const Scalar& beta,
839     const Scalar& alpha,
840     const Tensor& result) {
841   TORCH_INTERNAL_ASSERT(
842       !((mat1.layout() == kStrided) && (mat2.layout() == kStrided) &&
843         (result.layout() == kStrided)),
844       "Expected at least one sparse input");
845 
846   // Layout checks are nested mat1, mat2, result.
847   // Conditions are ordered strided, csr, csc, bsr, bsc.
848   // Valid combinations terminate in a return.
849   // Invalid combinations are omitted and will fall through to the TORCH check
850   // generating an informative error message.
851 
852   // mm functions that copy input to result when needed (e.g. mm
853   // triton kernels do not require result to be initialized with
854   // input):
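  // For example, Bsr @ strided -> strided is dispatched to block_sparse_mm,
  // strided @ Bsc -> strided is rewritten as the transposed Bsr problem, and
  // Csr @ Csr -> Csr goes to spgemm.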
855   if (mat1.layout() == kSparseBsr) {
856     if (mat2.layout() == kStrided) {
857       if (result.layout() == kStrided)
858         return block_sparse_mm(input, mat1, mat2, beta, alpha, result);
859     }
860   }
861 
862   if (mat1.layout() == kStrided) {
863     if (mat2.layout() == kSparseBsc) {
864       if (result.layout() == kStrided) {
865         auto result_t = result.transpose(-2, -1);
866         auto input_t = (result.is_same(input) ? result_t : input.transpose(-2, -1));
867         return block_sparse_mm(
868             input_t,
869             mat2.transpose(-2, -1),
870             mat1.transpose(-2, -1),
871             beta,
872             alpha,
873             result_t);
874       }
875     }
876   }
877 
878   // copy input to result:
879   if (beta.toComplexDouble() != 0. && !result.is_same(input)) {
880     result.copy_(input);
881   }
882 
883   // mm functions that assume that result contains input:
884   if (mat1.layout() == kStrided) {
885     if (mat2.layout() == kSparseCsr) {
886       if (result.layout() == kStrided) {
887         // TODO: Add native CSC support via cuSPARSE if supported.
888         return spmm(
889             mat2.transpose(0, 1).to_sparse_csr(),
890             mat1.transpose(0, 1),
891             beta,
892             alpha,
893             result.transpose(0, 1));
894       }
895     }
896     if (mat2.layout() == kSparseCsc) {
897       if (result.layout() == kStrided) {
898         return spmm(
899             mat2.transpose(-2, -1),
900             mat1.transpose(-2, -1),
901             beta,
902             alpha,
903             result.transpose(-2, -1));
904       }
905     }
906   }
907   if (mat1.layout() == kSparseCsr) {
908     if (mat2.layout() == kStrided) {
909       if (result.layout() == kStrided) {
910         return spmm(mat1, mat2, beta, alpha, result);
911       }
912     }
913     if (mat2.layout() == kSparseCsr) {
914       if (result.layout() == kSparseCsr) {
915         return spgemm(mat1, mat2, beta, alpha, result);
916       }
917     }
918     if (mat2.layout() == kSparseCsc) {
919       if (result.layout() == kSparseCsr) {
920         // TODO: Add native CSC support via cuSPARSE if supported.
921         // CSR @ CSC kernel would be very fast due to format alignment
922         return spgemm(mat1, mat2.to_sparse_csr(), beta, alpha, result);
923       }
924     }
925   }
926   if (mat1.layout() == kSparseCsc) {
927     if (mat2.layout() == kStrided) {
928       if (result.layout() == kStrided) {
929         // TODO: Add native CSC support via cuSPARSE if supported.
930         return spmm(mat1.to_sparse_csr(), mat2, beta, alpha, result);
931       }
932     }
933     if (mat2.layout() == kSparseCsr) {
934       if (result.layout() == kSparseCsr)
935         // TODO: Add native CSC support via cuSPARSE if supported.
936         return spgemm(mat1.to_sparse_csr(), mat2, beta, alpha, result);
937     }
938     if (mat2.layout() == kSparseCsc) {
939       if (result.layout() == kSparseCsr) {
940         // TODO: Add native CSC support via cuSPARSE if supported.
941         return spgemm(
942             mat1.to_sparse_csr(), mat2.to_sparse_csr(), beta, alpha, result);
943       }
944       if (result.layout() == kSparseCsc) {
945         return spgemm(
946             mat2.transpose(-2, -1),
947             mat1.transpose(-2, -1),
948             beta,
949             alpha,
950             result.transpose(-2, -1));
951       }
952     }
953   }
954   TORCH_CHECK(
955       false,
956       "addmm: computation on CUDA is not implemented for ",
957       result.layout(),
958       " + ",
959       mat1.layout(),
960       " @ ",
961       mat2.layout());
962 }
963 
964 /*
965   Computes a sparse matrix-dense vector product defined as
966   y <- alpha*op(A)*x + beta*y
967 
968   Args:
969   * `mat` - Tensor storing sparse m x n matrix A.
970   * `vec` - Tensor storing dense vector x of size n.
971   * `result` - [in] Tensor storing dense vector y of size m.
972                [out] result of the operation.
973 */
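// Note: op(A) is the identity here (opA below is
// CUSPARSE_OPERATION_NON_TRANSPOSE); BSR inputs are dispatched to
// block_sparse_mv, everything else goes through the generic cusparseSpMV path
// when the generic API is available.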
974 void addmv_out_sparse_csr(
975     const at::sparse_csr::SparseCsrTensor& mat,
976     const Tensor& vec,
977     const Scalar& beta,
978     const Scalar& alpha,
979     const Tensor& result) {
980   if (mat.layout() == kSparseBsr) {
981     return block_sparse_mv(mat, vec, beta, alpha, result);
982   }
983 #if !(AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API())
984   TORCH_CHECK(
985       false,
986       "Calling addmv on a sparse GPU tensor requires compiling ",
987       "PyTorch with CUDA 10.2+ (CUDA 11+ on Windows). ",
988       "Please use PyTorch built with newer CUDA version.");
989 #else
990   cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
991 
992   c10::MaybeOwned<Tensor> result_ = prepare_dense_vector_for_cusparse(result);
993   c10::MaybeOwned<Tensor> vec_ = prepare_dense_vector_for_cusparse(vec);
994 
995   // TODO: update this to support COO sparse layout
996   auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat);
997   auto descX = at::cuda::sparse::CuSparseDnVecDescriptor(*vec_);
998   auto descY = at::cuda::sparse::CuSparseDnVecDescriptor(*result_);
999 
1000   // cusparseSpMVAlg_t was updated in cuda 11.2.1 (cusparse 11.4.0)
1001 #if CUSPARSE_VERSION >= 11400
1002   cusparseSpMVAlg_t alg = CUSPARSE_SPMV_ALG_DEFAULT;
1003 #else
1004   cusparseSpMVAlg_t alg = CUSPARSE_MV_ALG_DEFAULT;
1005 #endif
1006 
1007   // SpMV doesn't support uniform precision computation
1008   // For float16/bfloat16 inputs compute_type must be CUDA_R_32F
1009   // and type of alpha, beta must be float
1010   auto dispatch_scalar_type = result.scalar_type();
1011   if (dispatch_scalar_type == at::ScalarType::Half ||
1012       dispatch_scalar_type == at::ScalarType::BFloat16) {
1013     dispatch_scalar_type = at::ScalarType::Float;
1014   }
1015 
1016   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
1017       dispatch_scalar_type,
1018       "addmv_out_sparse_csr_cuda_impl",
1019       [&] {
1020         auto beta_ = beta.to<scalar_t>();
1021         auto alpha_ = alpha.to<scalar_t>();
1022         cudaDataType compute_type = at::cuda::getCudaDataType<scalar_t>();
1023         auto handle = at::cuda::getCurrentCUDASparseHandle();
1024 
1025         size_t buffer_size;
1026         TORCH_CUDASPARSE_CHECK(cusparseSpMV_bufferSize(
1027             handle,
1028             opA,
1029             &alpha_,
1030             descA.descriptor(),
1031             descX.descriptor(),
1032             &beta_,
1033             descY.descriptor(),
1034             compute_type,
1035             alg,
1036             &buffer_size // output
1037             ));
1038 
1039         auto& allocator = *c10::cuda::CUDACachingAllocator::get();
1040         auto work_data = allocator.allocate(buffer_size);
1041 
1042         TORCH_CUDASPARSE_CHECK(cusparseSpMV(
1043             handle,
1044             opA,
1045             &alpha_,
1046             descA.descriptor(),
1047             descX.descriptor(),
1048             &beta_,
1049             descY.descriptor(),
1050             compute_type,
1051             alg,
1052             work_data.get()));
1053       });
1054   if (!result.is_same(*result_)) {
1055     result.copy_(*result_);
1056   }
1057 #endif // !(AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API())
1058 }
1059 
1060 /*
1061   Computes C = alpha * A + beta * B
1062 
1063   Args:
1064   * `A` - [in] sparse Tensor of size m × n.
1065   * `B` - [in] sparse Tensor of size m × n.
1066   * `C` - [out] sparse Tensor of size m × n.
1067 */
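// The sum is computed with cuSPARSE csrgeam2, which only supports 32-bit
// indices (hence the temporary index conversion below) and reads alpha/beta
// from host memory (CUSPARSE_POINTER_MODE_HOST).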
1068 void add_out_sparse_csr(
1069     const at::sparse_csr::SparseCsrTensor& A,
1070     const at::sparse_csr::SparseCsrTensor& B,
1071     const Scalar& alpha,
1072     const Scalar& beta,
1073     const at::sparse_csr::SparseCsrTensor& C) {
1074   IntArrayRef A_sizes = A.sizes();
1075   auto ndim = A.dim();
1076   int m = at::native::cuda_int_cast(A_sizes[ndim - 2], "m");
1077   int n = at::native::cuda_int_cast(A_sizes[ndim - 1], "n");
1078 
1079   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.sizes().equals(B.sizes()) && A.sizes().equals(C.sizes()));
1080 
1081   // Only 32-bit indices are supported
1082   const auto output_indices_dtype = promoteTypes(A.crow_indices().scalar_type(), B.crow_indices().scalar_type());
1083   auto A_32 = at::native::_sparse_csr_tensor_unsafe(
1084       A.crow_indices().to(kInt),
1085       A.col_indices().to(kInt),
1086       A.values(),
1087       A.sizes(),
1088       A.scalar_type(),
1089       A.layout(),
1090       A.device());
1091   auto B_32 = at::native::_sparse_csr_tensor_unsafe(
1092       B.crow_indices().to(kInt),
1093       B.col_indices().to(kInt),
1094       B.values(),
1095       B.sizes(),
1096       B.scalar_type(),
1097       B.layout(),
1098       B.device());
1099 
1100   // Modify C tensor in-place to swap indices tensors with 32-bit variants
1101   auto C_crow_indices_backup = C.crow_indices();
1102   auto C_col_indices_backup = C.col_indices();
1103   indices_to_32_bit_inplace(C); // no-op with 32-bit indices
1104 
1105   int nnzA = at::native::cuda_int_cast(A_32._nnz(), "nnzA");
1106   int nnzB = at::native::cuda_int_cast(B_32._nnz(), "nnzB");
1107 
1108   auto desc = at::cuda::sparse::CuSparseMatDescriptor();
1109 
1110   auto A_crow_indices = A_32.crow_indices();
1111   auto B_crow_indices = B_32.crow_indices();
1112   auto C_crow_indices = C.crow_indices();
1113   auto A_crow_indices_ptr = A_crow_indices.data_ptr<int>();
1114   auto B_crow_indices_ptr = B_crow_indices.data_ptr<int>();
1115   auto C_crow_indices_ptr = C_crow_indices.data_ptr<int>();
1116 
1117   auto A_col_indices = A_32.col_indices();
1118   auto B_col_indices = B_32.col_indices();
1119   auto C_col_indices = C.col_indices();
1120   auto A_col_indices_ptr = A_col_indices.data_ptr<int>();
1121   auto B_col_indices_ptr = B_col_indices.data_ptr<int>();
1122   auto C_col_indices_ptr = C_col_indices.data_ptr<int>();
1123 
1124   // Windows compilers don't support nested macros
1125   // so we need this lambda outside of the
1126   // AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES
1127   auto fix_nnz = [
1128 #if AT_ROCM_ENABLED()
1129                      &C_crow_indices,
1130                      &m
1131 #endif
1132   ](int nnz) -> int {
1133 // For some reason POINTER_MODE_HOST is not working here,
1134 // so let's manually extract the nnz from C_crow_indices
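// (in CSR the last entry of crow_indices equals the total number of specified
// elements)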
1135 #if AT_ROCM_ENABLED()
1136     return std::max({nnz, C_crow_indices.narrow(-1, m, 1).item<int>()});
1137 #else
1138     return nnz;
1139 #endif
1140   };
1141 
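  // csrgeam2 workflow: query the buffer size, compute C's row pointers and nnz
  // with csrgeam2Nnz, resize C's col_indices/values accordingly, then fill in
  // the column indices and values with csrgeam2.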
1142   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
1143       C.scalar_type(), "add_out_sparse_csr_cuda_impl", [&] {
1144         auto beta_ = beta.to<scalar_t>();
1145         auto alpha_ = alpha.to<scalar_t>();
1146 
1147         auto A_values = A_32.values();
1148         auto B_values = B_32.values();
1149         auto C_values = C.values();
1150         auto A_values_ptr = A_values.data_ptr<scalar_t>();
1151         auto B_values_ptr = B_values.data_ptr<scalar_t>();
1152         auto C_values_ptr = C_values.data_ptr<scalar_t>();
1153 
1154         auto handle = at::cuda::getCurrentCUDASparseHandle();
1155         TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST));
1156 
1157         size_t buffer_size;
1158         at::cuda::sparse::csrgeam2_bufferSizeExt<scalar_t>(
1159             handle,
1160             m,
1161             n,
1162             &alpha_,
1163             desc.descriptor(),
1164             nnzA,
1165             A_values_ptr,
1166             A_crow_indices_ptr,
1167             A_col_indices_ptr,
1168             &beta_,
1169             desc.descriptor(),
1170             nnzB,
1171             B_values_ptr,
1172             B_crow_indices_ptr,
1173             B_col_indices_ptr,
1174             desc.descriptor(),
1175             C_values_ptr,
1176             C_crow_indices_ptr,
1177             C_col_indices_ptr,
1178             &buffer_size // output
1179         );
1180 
1181         auto& allocator = *c10::cuda::CUDACachingAllocator::get();
1182         auto work_data = allocator.allocate(buffer_size);
1183 
1184         int nnzC = -1;
1185         at::cuda::sparse::csrgeam2Nnz<scalar_t>(
1186             handle,
1187             m,
1188             n,
1189             desc.descriptor(),
1190             nnzA,
1191             A_crow_indices_ptr,
1192             A_col_indices_ptr,
1193             desc.descriptor(),
1194             nnzB,
1195             B_crow_indices_ptr,
1196             B_col_indices_ptr,
1197             desc.descriptor(),
1198             C_crow_indices_ptr,
1199             &nnzC,
1200             work_data.get());
1201 
1202         nnzC = fix_nnz(nnzC);
1203 
1204         // Resize result using nnz information from cusparse
1205         col_indices_and_values_resize_(C, nnzC);
1206         C_col_indices = C.col_indices();
1207         C_values = C.values();
1208 
1209         C_col_indices_ptr = C_col_indices.data_ptr<int>();
1210         C_values_ptr = C_values.data_ptr<scalar_t>();
1211 
1212         at::cuda::sparse::csrgeam2<scalar_t>(
1213             handle,
1214             m,
1215             n,
1216             &alpha_,
1217             desc.descriptor(),
1218             nnzA,
1219             A_values_ptr,
1220             A_crow_indices_ptr,
1221             A_col_indices_ptr,
1222             &beta_,
1223             desc.descriptor(),
1224             nnzB,
1225             B_values_ptr,
1226             B_crow_indices_ptr,
1227             B_col_indices_ptr,
1228             desc.descriptor(),
1229             C_values_ptr,
1230             C_crow_indices_ptr,
1231             C_col_indices_ptr,
1232             work_data.get());
1233 
1234         if (output_indices_dtype == at::kLong) {
1235           static_cast<SparseCsrTensorImpl*>(C.unsafeGetTensorImpl())->set_member_tensors(
1236               C_crow_indices_backup.copy_(C.crow_indices()),
1237               C_col_indices_backup.resize_({nnzC}).copy_(C.col_indices()),
1238               C.values(),
1239               C.sizes());
1240         }
1241       });
1242 }
1243 
1244 /*
1245   Solves a system of linear equations whose coefficients are represented in a sparse triangular matrix A:
1246   op(A) X = B.
1247 
1248   Args:
1249   * `A` - sparse Tensor of size m × m.
1250   * `B` - dense Tensor of size m × nrhs.
1251   * `X` - dense Tensor of size m × nrhs.
1252   * `upper` - controls whether upper or lower triangular part of A is considered in computations.
1253   * `transpose` - if true then op(A) = A^T.
1254   * `unitriangular` - if true then the diagonal elements of A are assumed to be one.
1255 */
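// BSR inputs are dispatched to the bsrsv2/bsrsm2 helpers above; CSR inputs use
// the generic SpSV (single right-hand side) or SpSM (matrix right-hand side)
// API below, each following the buffer-size / analysis / solve sequence.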
1256 void triangular_solve_out_sparse_csr(
1257     const at::sparse_csr::SparseCsrTensor& A,
1258     const Tensor& B,
1259     const Tensor& X,
1260     bool upper,
1261     bool transpose,
1262     bool unitriangular) {
1263   if (B.numel() == 0 || X.numel() == 0 || A._nnz() == 0) {
1264     // If A has no nnz, then A is singular and we can't solve.
1265     X.fill_(NAN);
1266     return;
1267   }
1268   if (A.layout() == kSparseBsr) {
1269     if (B.size(-1) == 1) {
1270       return block_sparse_triangular_solve_vec(A, B, X, upper, transpose, unitriangular);
1271     } else {
1272       return block_sparse_triangular_solve_mat(A, B, X, upper, transpose, unitriangular);
1273     }
1274   }
1275 #if !AT_USE_CUSPARSE_GENERIC_SPSV()
1276   TORCH_CHECK(
1277       false,
1278       "Calling triangular solve on a sparse GPU tensor requires compiling ",
1279       "PyTorch with at least CUDA 11.3. ",
1280       "Please use PyTorch built with newer CUDA version.");
1281 #else
1282   c10::MaybeOwned<Tensor> X_ = prepare_dense_matrix_for_cusparse(X);
1283   // It should be possible to use mixed memory format
1284   // but there is a bug in CUDA 11.3.1:
1285   // strides of matrix B are used to write result to matrix X.
1286   // As a workaround we need to convert matrices to have the same strides.
1287   c10::MaybeOwned<Tensor> B_ = prepare_dense_matrix_for_cusparse(B, X_->strides());
1288 
1289   // TODO: update this to support COO sparse layout
1290   auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(A);
1291   descA.set_mat_fill_mode(upper);
1292   descA.set_mat_diag_type(unitriangular);
1293   cusparseOperation_t opA = transpose ? CUSPARSE_OPERATION_TRANSPOSE
1294                                       : CUSPARSE_OPERATION_NON_TRANSPOSE;
1295 
1296   if (B.size(-1) == 1) {
1297     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
1298         X.scalar_type(), "triangular_solve_out_sparse_csr_cuda_impl", [&] {
1299           scalar_t alpha = 1;
1300           cudaDataType compute_type = at::cuda::getCudaDataType<scalar_t>();
1301           auto handle = at::cuda::getCurrentCUDASparseHandle();
1302           size_t buffer_size;
1303 
1304           auto desc_spsv = at::cuda::sparse::CuSparseSpSVDescriptor();
1305           auto descB = at::cuda::sparse::CuSparseDnVecDescriptor(*B_);
1306           auto descX = at::cuda::sparse::CuSparseDnVecDescriptor(*X_);
1307           TORCH_CUDASPARSE_CHECK(cusparseSpSV_bufferSize(
1308               handle,
1309               opA,
1310               &alpha,
1311               descA.descriptor(),
1312               descB.descriptor(),
1313               descX.descriptor(),
1314               compute_type,
1315               CUSPARSE_SPSV_ALG_DEFAULT,
1316               desc_spsv.descriptor(),
1317               &buffer_size // output
1318               ));
1319 
1320           auto& allocator = *c10::cuda::CUDACachingAllocator::get();
1321           auto work_data = allocator.allocate(buffer_size);
1322 
1323           TORCH_CUDASPARSE_CHECK(cusparseSpSV_analysis(
1324               handle,
1325               opA,
1326               &alpha,
1327               descA.descriptor(),
1328               descB.descriptor(),
1329               descX.descriptor(),
1330               compute_type,
1331               CUSPARSE_SPSV_ALG_DEFAULT,
1332               desc_spsv.descriptor(),
1333               work_data.get()));
1334 
1335           TORCH_CUDASPARSE_CHECK(cusparseSpSV_solve(
1336               handle,
1337               opA,
1338               &alpha,
1339               descA.descriptor(),
1340               descB.descriptor(),
1341               descX.descriptor(),
1342               compute_type,
1343               CUSPARSE_SPSV_ALG_DEFAULT,
1344               desc_spsv.descriptor()));
1345         });
1346   } else {
1347 #if !AT_USE_CUSPARSE_GENERIC_SPSM()
1348     TORCH_CHECK(
1349         false,
1350         "Calling triangular solve on a sparse GPU tensor requires compiling ",
1351         "PyTorch with at least CUDA 11.3.1. ",
1352         "Please use PyTorch built with newer CUDA version.");
1353 #else
1354     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
1355         X.scalar_type(), "triangular_solve_out_sparse_csr_cuda_impl", [&] {
1356           scalar_t alpha = 1;
1357           cudaDataType compute_type = at::cuda::getCudaDataType<scalar_t>();
1358           auto handle = at::cuda::getCurrentCUDASparseHandle();
1359           size_t buffer_size;
1360 
1361           cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE;
1362           auto desc_spsm = at::cuda::sparse::CuSparseSpSMDescriptor();
1363           auto descB = at::cuda::sparse::CuSparseDnMatDescriptor(*B_);
1364           auto descX = at::cuda::sparse::CuSparseDnMatDescriptor(*X_);
1365           TORCH_CUDASPARSE_CHECK(cusparseSpSM_bufferSize(
1366               handle,
1367               opA,
1368               opB,
1369               &alpha,
1370               descA.descriptor(),
1371               descB.descriptor(),
1372               descX.descriptor(),
1373               compute_type,
1374               CUSPARSE_SPSM_ALG_DEFAULT,
1375               desc_spsm.descriptor(),
1376               &buffer_size // output
1377               ));
1378 
1379           auto& allocator = *c10::cuda::CUDACachingAllocator::get();
1380           auto work_data = allocator.allocate(buffer_size);
1381 
1382           TORCH_CUDASPARSE_CHECK(cusparseSpSM_analysis(
1383               handle,
1384               opA,
1385               opB,
1386               &alpha,
1387               descA.descriptor(),
1388               descB.descriptor(),
1389               descX.descriptor(),
1390               compute_type,
1391               CUSPARSE_SPSM_ALG_DEFAULT,
1392               desc_spsm.descriptor(),
1393               work_data.get()));
1394 
1395           TORCH_CUDASPARSE_CHECK(cusparseSpSM_solve(
1396               handle,
1397               opA,
1398               opB,
1399               &alpha,
1400               descA.descriptor(),
1401               descB.descriptor(),
1402               descX.descriptor(),
1403               compute_type,
1404               CUSPARSE_SPSM_ALG_DEFAULT,
1405               desc_spsm.descriptor()));
1406         });
1407 #endif // !AT_USE_CUSPARSE_GENERIC_SPSM()
1408   }
1409   if (!X.is_same(*X_)) {
1410     X.copy_(*X_);
1411   }
1412 #endif // !AT_USE_CUSPARSE_GENERIC_SPSV()
1413 }
1414 
1415 void sampled_addmm_out_sparse_csr(
1416     const Tensor& A,
1417     const Tensor& B,
1418     const Scalar& beta,
1419     const Scalar& alpha,
1420     const at::sparse_csr::SparseCsrTensor& C) {
1421 #if !(AT_USE_CUSPARSE_GENERIC_SDDMM() || AT_USE_HIPSPARSE_GENERIC_API())
1422   TORCH_CHECK(
1423       false,
1424       "Calling sampled_addmm with sparse GPU tensors requires compiling ",
1425       "PyTorch with CUDA 11.2.1+. ",
1426       "Please use PyTorch built with newer CUDA version.");
1427 #else
1428   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.layout() == Layout::Strided);
1429   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(B.layout() == Layout::Strided);
1430   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(C.is_sparse_csr());
1431 
1432   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batchCount(A) == batchCount(B));
1433   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batchCount(A) == batchCount(C));
1434 
1435   cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
1436   cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE;
1437 
1438   c10::MaybeOwned<Tensor> A_ = prepare_dense_matrix_for_cusparse(A);
1439   c10::MaybeOwned<Tensor> B_ = prepare_dense_matrix_for_cusparse(B);
1440 
1441   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
1442       C.scalar_type(),
1443       "sampled_addmm_out_sparse_csr",
1444       [&] {
1445         // CUDA 11.6 doesn't support batched inputs; it raises an error:
1446         // ** On entry to cusparseSDDMM_bufferSize(): batched SDDMM is not supported
1447         // So we need to resort to a for loop
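        // Descriptors are created with a per-iteration batch_offset so each 2D
        // slice of A, B and C is handled by a separate SDDMM call.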
1448         for (const auto i : c10::irange(batchCount(A))) {
1449           auto descA = at::cuda::sparse::CuSparseConstDnMatDescriptor(*A_, /*batch_offset=*/i);
1450           auto descB = at::cuda::sparse::CuSparseConstDnMatDescriptor(*B_, /*batch_offset=*/i);
1451           auto descC = at::cuda::sparse::CuSparseSpMatCsrDescriptor(C, /*batch_offset=*/i);
1452 
1453           auto beta_ = beta.to<scalar_t>();
1454           auto alpha_ = alpha.to<scalar_t>();
1455           auto compute_type = at::cuda::getCudaDataType<scalar_t>();
1456           auto handle = at::cuda::getCurrentCUDASparseHandle();
1457           size_t buffer_size = 0;
1458           TORCH_CUDASPARSE_CHECK(cusparseSDDMM_bufferSize(
1459               handle,
1460               opA,
1461               opB,
1462               &alpha_,
1463               descA.unsafe_mutable_descriptor(),
1464               descB.unsafe_mutable_descriptor(),
1465               &beta_,
1466               descC.descriptor(),
1467               compute_type,
1468               CUSPARSE_SDDMM_ALG_DEFAULT,
1469               &buffer_size // output
1470               ));
1471 
1472           auto& allocator = *c10::cuda::CUDACachingAllocator::get();
1473           auto buffer = allocator.allocate(buffer_size);
1474 
1475           TORCH_CUDASPARSE_CHECK(cusparseSDDMM_preprocess(
1476               handle,
1477               opA,
1478               opB,
1479               &alpha_,
1480               descA.unsafe_mutable_descriptor(),
1481               descB.unsafe_mutable_descriptor(),
1482               &beta_,
1483               descC.descriptor(),
1484               compute_type,
1485               CUSPARSE_SDDMM_ALG_DEFAULT,
1486               buffer.get()));
1487 
1488           TORCH_CUDASPARSE_CHECK(cusparseSDDMM(
1489               handle,
1490               opA,
1491               opB,
1492               &alpha_,
1493               descA.unsafe_mutable_descriptor(),
1494               descB.unsafe_mutable_descriptor(),
1495               &beta_,
1496               descC.descriptor(),
1497               compute_type,
1498               CUSPARSE_SDDMM_ALG_DEFAULT,
1499               buffer.get()));
1500         }
1501       });
1502 #endif
1503 }
1504 
1505 } // namespace at::native::sparse::impl::cuda
1506