#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/CUDASparse.h>
#include <ATen/cuda/CUDASparseBlas.h>
#include <ATen/cuda/CUDASparseDescriptors.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/cuda/MiscUtils.h>
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/native/sparse/cuda/SparseBlasImpl.h>
#include <ATen/native/sparse/cuda/SparseBlasLegacy.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
#include <ATen/ops/empty_strided.h>
#endif

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/MaybeOwned.h>

namespace at::native::sparse::impl::cuda {

namespace {

c10::MaybeOwned<Tensor> prepare_column_major_matrix_for_cusparse(
    const Tensor& tensor) {
  if (is_blas_compatible_column_major_order(tensor)) {
    return at::native::expect_resolved_conj(tensor);
  } else {
    return c10::MaybeOwned<Tensor>::owned(cloneBatchedColumnMajor(tensor));
  }
}

c10::MaybeOwned<Tensor> inline prepare_dense_matrix_for_cusparse(
    const Tensor& tensor) {
#if defined(USE_ROCM)
  // CUDA < 11.0 doesn't support row-major layout, return column-major in this case
  return prepare_column_major_matrix_for_cusparse(tensor);
#else
  if (is_blas_compatible_row_major_order(tensor) ||
      is_blas_compatible_column_major_order(tensor)) {
    return at::native::expect_resolved_conj(tensor);
  } else {
    return c10::MaybeOwned<Tensor>::owned(
        tensor.clone(at::MemoryFormat::Contiguous));
  }
#endif
}

Tensor copy_strided(const Tensor& tensor, IntArrayRef strides) {
  Tensor result = at::empty_strided(tensor.sizes(), strides, tensor.options());
  result.copy_(tensor);
  return result;
}

c10::MaybeOwned<Tensor> prepare_dense_matrix_for_cusparse(
    const Tensor& tensor,
    IntArrayRef strides) {
  if (tensor.strides().equals(strides)) {
    return c10::MaybeOwned<Tensor>::borrowed(tensor);
  } else {
    return c10::MaybeOwned<Tensor>::owned(copy_strided(tensor, strides));
  }
}

// This function is used for old CUDA Toolkit versions that don't support the new cuSPARSE Generic API
void addmm_out_legacy(
    const at::sparse_csr::SparseCsrTensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat1.is_sparse_csr());
  auto nnz = mat1._nnz();
  auto m = mat1.size(0);
  auto k = mat1.size(1);
  auto n = mat2.size(1);
  auto crow_indices = mat1.crow_indices().to(kInt);
  auto col_indices = mat1.col_indices().to(kInt);
  auto values = mat1.values();
  auto mat2_ = at::native::expect_resolved_conj(mat2);
  auto result_ = at::native::expect_resolved_conj(result);
  at::native::s_addmm_out_csr_sparse_dense_cuda_worker(nnz, m, n, k, result, beta, *result_, alpha, crow_indices, col_indices, values, *mat2_);
  if (!result.is_same(*result_)) {
    result.copy_(*result_);
  }
}

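// Returns `tensor` itself if it is already non-overlapping and dense,
// otherwise returns an owned contiguous clone that cuSPARSE can consume.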
c10::MaybeOwned<Tensor> inline prepare_dense_vector_for_cusparse(
    const Tensor& tensor) {
  if (tensor.is_non_overlapping_and_dense()) {
    return c10::MaybeOwned<Tensor>::borrowed(tensor);
  } else {
    return c10::MaybeOwned<Tensor>::owned(
        tensor.clone(at::MemoryFormat::Contiguous));
  }
}

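// Swaps the crow/col indices of `input` for 32-bit copies in-place;
// several cuSPARSE routines used below accept only int32 indices.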
void inline indices_to_32_bit_inplace(const Tensor& input) {
  static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())->set_member_tensors(
      input.crow_indices().to(kInt),
      input.col_indices().to(kInt),
      input.values(),
      input.sizes());
}

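// Resizes col_indices and values of `input` to `nnz` elements in-place,
// keeping crow_indices intact; used once cuSPARSE reports the output nnz.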
void inline col_indices_and_values_resize_(const Tensor& input, int64_t nnz) {
  static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())->set_member_tensors(
      input.crow_indices(),
      input.col_indices().resize_({nnz}),
      input.values().resize_({nnz}),
      input.sizes());
}

void inline bsrsv2_bsrsm2_may_need_to_sync() {
#if defined(CUSPARSE_VERSION) && CUSPARSE_VERSION < 11703
  // cusparse bsrsv2 and bsrsm2 have a synchronization issue that may cause illegal memory access in cuda <= 11.6.x
  // See https://github.com/pytorch/pytorch/issues/71297
  ::c10::cuda::device_synchronize();
#endif
  // else: do nothing!
}

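/*
  Solves a triangular system op(A) x = b with a single right-hand side, where
  A is a block sparse (BSR) matrix, using the cuSPARSE bsrsv2 routines
  (bufferSize -> analysis -> zero-pivot check -> solve).

  Args:
  * `A` - sparse BSR Tensor of size m x m with square blocks of size > 1.
  * `B` - dense Tensor storing the right-hand side b.
  * `X` - dense Tensor that holds the solution x on exit.
  * `upper` - controls whether the upper or lower triangular part of A is used.
  * `transpose` - if true then op(A) = A^T.
  * `unitriangular` - if true then the diagonal elements of A are assumed to be one.
*/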
void block_sparse_triangular_solve_vec(
    const at::sparse_csr::SparseCsrTensor& A,
    const Tensor& B,
    const Tensor& X,
    bool upper,
    bool transpose,
    bool unitriangular) {
#if !AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
  TORCH_CHECK(
      false,
      "Calling triangular solver with block sparse GPU tensors requires compiling ",
      "PyTorch with ROCm 4.5.0+. ",
      "Please use PyTorch built with newer ROCm version.");
#else
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.layout() == kSparseBsr);
  // values is expected to hold the blocks of the sparse matrix
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.values().dim() == 3);
  // blocks are expected to be square
  TORCH_INTERNAL_ASSERT(A.values().size(2) == A.values().size(1));
  // only blocks of size > 1 are supported in cuSPARSE
  TORCH_INTERNAL_ASSERT(A.values().size(-1) > 1);
  // blocks are expected to be in row- or column-major order
  TORCH_INTERNAL_ASSERT(
      A.values().is_contiguous() ||
      A.values().transpose(-2, -1).is_contiguous());

  // cuSPARSE can't work with empty sparse matrices
  if (A._nnz() == 0) {
    X.fill_(NAN);
    return;
  }

  const cusparseDirection_t block_layout = A.values().is_contiguous()
      ? CUSPARSE_DIRECTION_ROW
      : CUSPARSE_DIRECTION_COLUMN;

  c10::MaybeOwned<Tensor> X_ = prepare_dense_matrix_for_cusparse(X);
  c10::MaybeOwned<Tensor> B_ = prepare_dense_matrix_for_cusparse(B);

  auto block_size = cuda_int_cast(A.values().size(2), "block_size");
  auto nnzb = cuda_int_cast(A._nnz(), "nnzb");
  auto mb = cuda_int_cast(A.size(0), "mb") / block_size;

  auto desc = at::cuda::sparse::CuSparseMatDescriptor(upper, unitriangular);
  cusparseOperation_t opA = transpose ? CUSPARSE_OPERATION_TRANSPOSE
                                      : CUSPARSE_OPERATION_NON_TRANSPOSE;

  auto info = at::cuda::sparse::CuSparseBsrsv2Info();

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      X.scalar_type(), "block_sparse_triangular_solve_vec", [&] {
        scalar_t alpha = 1;
        auto values = A.values();
        auto values_data_ptr = values.data_ptr<scalar_t>();
        auto crow_indices = A.crow_indices().to(kInt);
        auto crow_indices_data_ptr = crow_indices.data_ptr<int>();
        auto col_indices = A.col_indices().to(kInt);
        auto col_indices_data_ptr = col_indices.data_ptr<int>();
        auto handle = at::cuda::getCurrentCUDASparseHandle();
        int buffer_size = 0;

        at::cuda::sparse::bsrsv2_bufferSize(
            handle,
            block_layout,
            opA,
            mb,
            nnzb,
            desc.descriptor(),
            values_data_ptr,
            crow_indices_data_ptr,
            col_indices_data_ptr,
            block_size,
            info.descriptor(),
            &buffer_size);

        auto& allocator = *c10::cuda::CUDACachingAllocator::get();
        auto work_data = allocator.allocate(buffer_size);

        at::cuda::sparse::bsrsv2_analysis(
            handle,
            block_layout,
            opA,
            mb,
            nnzb,
            desc.descriptor(),
            values_data_ptr,
            crow_indices_data_ptr,
            col_indices_data_ptr,
            block_size,
            info.descriptor(),
            CUSPARSE_SOLVE_POLICY_NO_LEVEL,
            work_data.get());

        if (!unitriangular) {
          int first_zero_diag_idx = -1;
          cusparseStatus_t status = cusparseXbsrsv2_zeroPivot(handle, info.descriptor(), &first_zero_diag_idx);
          if (status == CUSPARSE_STATUS_ZERO_PIVOT) {
            X_->fill_(NAN);
            return;
          }
        }

        at::cuda::sparse::bsrsv2_solve(
            handle,
            block_layout,
            opA,
            mb,
            nnzb,
            &alpha,
            desc.descriptor(),
            values_data_ptr,
            crow_indices_data_ptr,
            col_indices_data_ptr,
            block_size,
            info.descriptor(),
            B_->data_ptr<scalar_t>(),
            X_->data_ptr<scalar_t>(),
            CUSPARSE_SOLVE_POLICY_NO_LEVEL,
            work_data.get());

        bsrsv2_bsrsm2_may_need_to_sync();
      });
  if (!X.is_same(*X_)) {
    X.copy_(*X_);
  }
#endif
}

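/*
  Solves a triangular system op(A) X = B with multiple right-hand sides, where
  A is a block sparse (BSR) matrix, using the cuSPARSE bsrsm2 routines
  (bufferSize -> analysis -> zero-pivot check -> solve). B and X are converted
  to column-major order before the call.

  Args:
  * `A` - sparse BSR Tensor of size m x m with square blocks of size > 1.
  * `B` - dense Tensor of size m x nrhs.
  * `X` - dense Tensor of size m x nrhs, holds the solution on exit.
  * `upper` - controls whether the upper or lower triangular part of A is used.
  * `transpose` - if true then op(A) = A^T.
  * `unitriangular` - if true then the diagonal elements of A are assumed to be one.
*/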
void block_sparse_triangular_solve_mat(
    const at::sparse_csr::SparseCsrTensor& A,
    const Tensor& B,
    const Tensor& X,
    bool upper,
    bool transpose,
    bool unitriangular) {
#if !AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
  TORCH_CHECK(
      false,
      "Calling triangular solver with block sparse GPU tensors requires compiling ",
      "PyTorch with ROCm 4.5.0+. ",
      "Please use PyTorch built with newer ROCm version.");
#else
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.layout() == kSparseBsr);
  // values is expected to hold the blocks of the sparse matrix
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.values().dim() == 3);
  // blocks are expected to be square
  TORCH_INTERNAL_ASSERT(A.values().size(2) == A.values().size(1));
  // only blocks of size > 1 are supported in cuSPARSE
  TORCH_INTERNAL_ASSERT(A.values().size(-1) > 1);
  // blocks are expected to be in row- or column-major order
  TORCH_INTERNAL_ASSERT(
      A.values().is_contiguous() ||
      A.values().transpose(-2, -1).is_contiguous());

  // cuSPARSE can't work with empty sparse matrices
  if (A._nnz() == 0) {
    X.fill_(NAN);
    return;
  }

  const cusparseDirection_t block_layout = A.values().is_contiguous()
      ? CUSPARSE_DIRECTION_ROW
      : CUSPARSE_DIRECTION_COLUMN;

  c10::MaybeOwned<Tensor> X_ = prepare_column_major_matrix_for_cusparse(X);
  c10::MaybeOwned<Tensor> B_ = prepare_column_major_matrix_for_cusparse(B);

  int ldb = cuda_int_cast(B_->stride(-1), "ldb");
  int ldx = cuda_int_cast(X_->stride(-1), "ldx");

  cusparseOperation_t opX = CUSPARSE_OPERATION_NON_TRANSPOSE;
  cusparseOperation_t opA = transpose ? CUSPARSE_OPERATION_TRANSPOSE
                                      : CUSPARSE_OPERATION_NON_TRANSPOSE;

  auto block_size = cuda_int_cast(A.values().size(2), "block_size");
  auto nnzb = cuda_int_cast(A._nnz(), "nnzb");
  auto mb = cuda_int_cast(A.size(0), "mb") / block_size;
  auto n = cuda_int_cast(B.size(-1), "n");

  auto desc = at::cuda::sparse::CuSparseMatDescriptor(upper, unitriangular);
  auto info = at::cuda::sparse::CuSparseBsrsm2Info();

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      X.scalar_type(), "block_sparse_triangular_solve_mat", [&] {
        scalar_t alpha = 1;
        auto values = A.values();
        auto values_data_ptr = values.data_ptr<scalar_t>();
        auto crow_indices = A.crow_indices().to(kInt);
        auto crow_indices_data_ptr = crow_indices.data_ptr<int>();
        auto col_indices = A.col_indices().to(kInt);
        auto col_indices_data_ptr = col_indices.data_ptr<int>();
        auto handle = at::cuda::getCurrentCUDASparseHandle();
        int buffer_size = 0;

        at::cuda::sparse::bsrsm2_bufferSize(
            handle,
            block_layout,
            opA,
            opX,
            mb,
            n,
            nnzb,
            desc.descriptor(),
            values_data_ptr,
            crow_indices_data_ptr,
            col_indices_data_ptr,
            block_size,
            info.descriptor(),
            &buffer_size);

        auto& allocator = *c10::cuda::CUDACachingAllocator::get();
        auto work_data = allocator.allocate(buffer_size);

        at::cuda::sparse::bsrsm2_analysis(
            handle,
            block_layout,
            opA,
            opX,
            mb,
            n,
            nnzb,
            desc.descriptor(),
            values_data_ptr,
            crow_indices_data_ptr,
            col_indices_data_ptr,
            block_size,
            info.descriptor(),
            CUSPARSE_SOLVE_POLICY_NO_LEVEL,
            work_data.get());

        if (!unitriangular) {
          int first_zero_diag_idx = -1;
          cusparseStatus_t status = cusparseXbsrsm2_zeroPivot(handle, info.descriptor(), &first_zero_diag_idx);
          if (status == CUSPARSE_STATUS_ZERO_PIVOT) {
            X_->fill_(NAN);
            return;
          }
        }

        at::cuda::sparse::bsrsm2_solve(
            handle,
            block_layout,
            opA,
            opX,
            mb,
            n,
            nnzb,
            &alpha,
            desc.descriptor(),
            values_data_ptr,
            crow_indices_data_ptr,
            col_indices_data_ptr,
            block_size,
            info.descriptor(),
            B_->data_ptr<scalar_t>(),
            ldb,
            X_->data_ptr<scalar_t>(),
            ldx,
            CUSPARSE_SOLVE_POLICY_NO_LEVEL,
            work_data.get());

        bsrsv2_bsrsm2_may_need_to_sync();
      });
  if (!X.is_same(*X_)) {
    X.copy_(*X_);
  }
#endif
}

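/*
  Computes a block sparse matrix-dense vector product defined as
  result <- alpha*(mat @ vec) + beta*result
  using the cuSPARSE bsrmv routine.

  Args:
  * `mat` - sparse BSR Tensor storing the m x n matrix.
  * `vec` - dense Tensor storing the vector of size n.
  * `result` - [in] dense Tensor storing the vector of size m.
               [out] result of the operation.
*/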
void block_sparse_mv(
    const at::sparse_csr::SparseCsrTensor& mat,
    const Tensor& vec,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat.layout() == kSparseBsr);
  // values is expected to hold the blocks of the sparse matrix
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat.values().dim() == 3);
  // blocks are expected to be square
  TORCH_INTERNAL_ASSERT(mat.values().size(2) == mat.values().size(1));
  // only blocks of size > 1 are supported in cuSPARSE
  TORCH_INTERNAL_ASSERT(mat.values().size(-1) > 1);
  // blocks are expected to be in row- or column-major order
  TORCH_INTERNAL_ASSERT(
      mat.values().is_contiguous() ||
      mat.values().transpose(-2, -1).is_contiguous());

  const cusparseDirection_t block_layout = mat.values().is_contiguous()
      ? CUSPARSE_DIRECTION_ROW
      : CUSPARSE_DIRECTION_COLUMN;

  c10::MaybeOwned<Tensor> result_ = prepare_dense_vector_for_cusparse(result);
  c10::MaybeOwned<Tensor> vec_ = prepare_dense_vector_for_cusparse(vec);

  auto block_size = cuda_int_cast(mat.values().size(2), "block_size");
  auto nnzb = cuda_int_cast(mat._nnz(), "nnzb");
  auto mb = cuda_int_cast(mat.size(0), "mb") / block_size;
  auto nb = cuda_int_cast(mat.size(1), "nb") / block_size;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      result.scalar_type(), "block_sparse_mv", [&] {
        auto beta_ = beta.to<scalar_t>();
        auto alpha_ = alpha.to<scalar_t>();
        auto handle = at::cuda::getCurrentCUDASparseHandle();
        auto desc = at::cuda::sparse::CuSparseMatDescriptor();
        auto values = mat.values();
        auto values_data_ptr = values.data_ptr<scalar_t>();
        auto crow_indices = mat.crow_indices().to(kInt);
        auto crow_indices_data_ptr = crow_indices.data_ptr<int>();
        auto col_indices = mat.col_indices().to(kInt);
        auto col_indices_data_ptr = col_indices.data_ptr<int>();
        at::cuda::sparse::bsrmv(
            handle,
            block_layout,
            CUSPARSE_OPERATION_NON_TRANSPOSE,
            mb,
            nb,
            nnzb,
            &alpha_,
            desc.descriptor(),
            values_data_ptr,
            crow_indices_data_ptr,
            col_indices_data_ptr,
            block_size,
            vec_->data_ptr<scalar_t>(),
            &beta_,
            result_->data_ptr<scalar_t>());
      });
  if (!result.is_same(*result_)) {
    result.copy_(*result_);
  }
}

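/*
  Computes a block sparse matrix-dense matrix product defined as
  result <- alpha*(mat1 @ mat2) + beta*input

  For float, half and bfloat16 inputs the strided fallback
  _compressed_row_strided_addmm_out is used (it allows arbitrary block sizes);
  for other dtypes the cuSPARSE bsrmm routine is used and the result is
  converted to column-major order as required by cuSPARSE.

  Args:
  * `input` - dense Tensor combined with the product (scaled by beta).
  * `mat1` - sparse BSR Tensor storing the m x k matrix.
  * `mat2` - dense Tensor storing the k x n matrix.
  * `result` - [in/out] dense Tensor of size m x n.
*/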
void block_sparse_mm(
    const Tensor& input,
    const at::sparse_csr::SparseCsrTensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat1.layout() == kSparseBsr);
  // values is expected to hold the blocks of the sparse matrix
  TORCH_INTERNAL_ASSERT(mat1.values().dim() == 3);
  // blocks are expected to be square
  TORCH_INTERNAL_ASSERT(mat1.values().size(2) == mat1.values().size(1));
  // only blocks of size > 1 are supported in cuSPARSE
  TORCH_INTERNAL_ASSERT(mat1.values().size(-1) > 1);
  // blocks are expected to be in row- or column-major order
  TORCH_INTERNAL_ASSERT(
      mat1.values().is_contiguous() ||
      mat1.values().transpose(-2, -1).is_contiguous());

  // NOTE: the code below allows arbitrary block sizes
  // and might potentially be faster than the cuSPARSE implementation,
  // especially for inputs that are not very sparse.
  if (mat1.scalar_type() == ScalarType::Half
      || mat1.scalar_type() == ScalarType::BFloat16
      || mat1.scalar_type() == ScalarType::Float) {
    at::native::sparse::impl::_compressed_row_strided_addmm_out(
        input,
        mat1,
        mat2,
        /*beta=*/beta,
        /*alpha=*/alpha,
        // @nikitaved: not sure whether `const Tensor& result` makes sense,
        // but let's keep the interface intact, hence the const cast.
        const_cast<Tensor&>(result));
    return;
  }

  if (beta.toComplexDouble() != 0. && !result.is_same(input)) {
    result.copy_(input);
  }

  const cusparseDirection_t block_layout = mat1.values().is_contiguous()
      ? CUSPARSE_DIRECTION_ROW
      : CUSPARSE_DIRECTION_COLUMN;

  c10::MaybeOwned<Tensor> mat2_ = prepare_dense_matrix_for_cusparse(mat2);

  // cuSPARSE expects column-major strides for result and we can't manipulate
  // the transpose flag of mat1
  c10::MaybeOwned<Tensor> result_ =
      prepare_column_major_matrix_for_cusparse(result);

  IntArrayRef result_strides = result_->strides();
  IntArrayRef mat2_strides = mat2_->strides();
  auto ndim = result_->dim();

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim == 2);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat1.dim() == 2);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat2.dim() == 2);

  bool is_mat2_row_major = (mat2_strides[ndim - 1] == 1);
  int ldb = is_mat2_row_major ? cuda_int_cast(mat2_strides[ndim - 2], "ldb")
                              : cuda_int_cast(mat2_strides[ndim - 1], "ldb");
  int ldc = cuda_int_cast(result_strides[ndim - 1], "ldc");
  auto block_size = cuda_int_cast(mat1.values().size(2), "block_size");
  auto nnzb = cuda_int_cast(mat1._nnz(), "nnzb");
  auto mb = cuda_int_cast(mat1.size(0), "mb") / block_size;
  auto kb = cuda_int_cast(mat1.size(1), "kb") / block_size;
  auto n = cuda_int_cast(mat2.size(1), "n");

  // according to cuSPARSE documentation, opA can only be NON_TRANSPOSE
  cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
  cusparseOperation_t opB = is_mat2_row_major
      ? CUSPARSE_OPERATION_TRANSPOSE
      : CUSPARSE_OPERATION_NON_TRANSPOSE;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      result.scalar_type(), "block_sparse_mm", [&] {
        auto beta_ = beta.to<scalar_t>();
        auto alpha_ = alpha.to<scalar_t>();
        auto handle = at::cuda::getCurrentCUDASparseHandle();
        auto desc = at::cuda::sparse::CuSparseMatDescriptor();

        auto values = mat1.values();
        auto values_data_ptr = values.data_ptr<scalar_t>();
        auto crow_indices = mat1.crow_indices().to(kInt);
        auto crow_indices_data_ptr = crow_indices.data_ptr<int>();
        auto col_indices = mat1.col_indices().to(kInt);
        auto col_indices_data_ptr = col_indices.data_ptr<int>();

        at::cuda::sparse::bsrmm(
            handle,
            block_layout,
            opA,
            opB,
            mb,
            n,
            kb,
            nnzb,
            &alpha_,
            desc.descriptor(),
            values_data_ptr,
            crow_indices_data_ptr,
            col_indices_data_ptr,
            block_size,
            mat2_->data_ptr<scalar_t>(),
            ldb,
            &beta_,
            result_->data_ptr<scalar_t>(),
            ldc);
      });

  if (!result.is_same(*result_)) {
    result.copy_(*result_);
  }
}

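/*
  Computes a sparse matrix-dense matrix product defined as
  result <- alpha*(mat1 @ mat2) + beta*result
  where mat1 is a sparse CSR matrix and mat2, result are dense matrices.
  Uses the cuSPARSE Generic API (cusparseSpMM) when available and falls back
  to the legacy addmm_out_legacy path otherwise.
*/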
void spmm(
    const at::sparse_csr::SparseCsrTensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
#if !(AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API())
  addmm_out_legacy(mat1, mat2, beta, alpha, result);
#else
  c10::MaybeOwned<Tensor> result_ = prepare_dense_matrix_for_cusparse(result);
  c10::MaybeOwned<Tensor> mat2_ = prepare_dense_matrix_for_cusparse(mat2);

  // Here subscript "c" stands for column-major and subscript "r" stands for
  // row-major order. Both orders are supported by cuSPARSE. For mixed input we
  // need to cast 'mat2' to the order of 'result'. We compute
  //   result = mat1 @ op(mat2) + result.
  // If the order of 'mat2' and 'result' matches, the op is the identity;
  // op(mat2) == mat2. If 'result' is column-major and 'mat2' is
  // row-major we pass 'mat2' as column-major and compute
  //   result_c = mat1 @ transpose(mat2_c) + result_c; mat2_r == transpose(mat2_c).
  // If 'result' is row-major and 'mat2' is column-major we pass 'mat2'
  // as row-major and compute
  //   result_r = mat1 @ transpose(mat2_r) + result_r; mat2_c == transpose(mat2_r).
  IntArrayRef result_strides = result_->strides();
  IntArrayRef mat2_strides = mat2_->strides();
  auto ndim = result_->dim();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim == 2 || ndim == 3);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat1.dim() == 2 || mat1.dim() == 3);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat2.dim() == 2 || mat2.dim() == 3);
  bool is_result_row_major = (result_strides[ndim - 1] == 1);
  bool is_mat2_row_major = (mat2_strides[ndim - 1] == 1);
  bool transpose_B = (is_result_row_major ^ is_mat2_row_major);

  cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
  cusparseOperation_t opB = transpose_B ? CUSPARSE_OPERATION_TRANSPOSE
                                        : CUSPARSE_OPERATION_NON_TRANSPOSE;

  // CUDA < 11.0 doesn't support 64-bit indices and doesn't raise an error
  // about this, silently returning incorrect results
#if defined(USE_ROCM)
  auto mat1_32 = at::native::_sparse_csr_tensor_unsafe(
      mat1.crow_indices().to(kInt),
      mat1.col_indices().to(kInt),
      mat1.values(),
      mat1.sizes(),
      mat1.scalar_type(),
      mat1.layout(),
      mat1.device());
  auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat1_32);
  auto algorithm = CUSPARSE_MM_ALG_DEFAULT;
#else
  // TODO: update this to support COO sparse layout
  auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat1);
  auto algorithm = CUSPARSE_SPMM_CSR_ALG2;
#endif

  auto descB = at::cuda::sparse::CuSparseConstDnMatDescriptor(
      transpose_B ? mat2_->mT() : *mat2_);
  auto descC = at::cuda::sparse::CuSparseDnMatDescriptor(*result_);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      kHalf,
      kBFloat16,
      result.scalar_type(),
      "spmm",
      [&] {
        using opmath_t = at::opmath_type<scalar_t>;
        auto beta_ = beta.to<opmath_t>();
        auto alpha_ = alpha.to<opmath_t>();
        cudaDataType compute_type = at::cuda::getCudaDataType<opmath_t>();
        auto handle = at::cuda::getCurrentCUDASparseHandle();

        size_t buffer_size;
        TORCH_CUDASPARSE_CHECK(cusparseSpMM_bufferSize(
            handle,
            opA,
            opB,
            &alpha_,
            descA.descriptor(),
            descB.unsafe_mutable_descriptor(),
            &beta_,
            descC.descriptor(),
            compute_type,
            algorithm,
            &buffer_size // output
            ));

        auto& allocator = *c10::cuda::CUDACachingAllocator::get();
        auto work_data = allocator.allocate(buffer_size);

        TORCH_CUDASPARSE_CHECK(cusparseSpMM(
            handle,
            opA,
            opB,
            &alpha_,
            descA.descriptor(),
            descB.unsafe_mutable_descriptor(),
            &beta_,
            descC.descriptor(),
            compute_type,
            algorithm,
            work_data.get()));
      });

  if (!result.is_same(*result_)) {
    result.copy_(*result_);
  }
#endif // !(AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API())
}

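/*
  Computes a sparse * sparse matrix multiplication defined as
  C <- alpha*(A @ B) + beta*C
  where A, B and C are sparse CSR matrices, using the cuSPARSE SpGEMM API
  (workEstimation -> compute -> copy). C's indices are converted to 32-bit
  in-place and its col_indices/values are resized to the nnz reported by
  cuSPARSE.
*/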
void spgemm(
    const at::sparse_csr::SparseCsrTensor& A,
    const at::sparse_csr::SparseCsrTensor& B,
    const Scalar& beta,
    const Scalar& alpha,
    const at::sparse_csr::SparseCsrTensor& C) {
  // older versions of cusparse on Windows segfault for complex128 dtype
#if defined(_WIN32) && defined(CUSPARSE_VERSION) && CUSPARSE_VERSION < 11400
  TORCH_CHECK(
      !(A.scalar_type() == ScalarType::ComplexDouble),
      "Sparse multiplication with complex128 dtype inputs is not supported with current CUDA version. Please upgrade to CUDA Toolkit 11.2.1+");
#endif

  IntArrayRef A_sizes = A.sizes();
  auto ndim = A.dim();
  auto m = A_sizes[ndim - 2];

  IntArrayRef B_sizes = B.sizes();
  auto n = B_sizes[ndim - 1];

  // Only 32-bit indices are supported
  auto A_32 = at::native::_sparse_csr_tensor_unsafe(A.crow_indices().to(kInt), A.col_indices().to(kInt), A.values(), A.sizes(), A.scalar_type(), A.layout(), A.device());
  auto B_32 = at::native::_sparse_csr_tensor_unsafe(B.crow_indices().to(kInt), B.col_indices().to(kInt), B.values(), B.sizes(), B.scalar_type(), B.layout(), B.device());

  // Modify C tensor in-place to swap indices tensors with 32-bit variants
  indices_to_32_bit_inplace(C);

  auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(A_32);
  auto descB = at::cuda::sparse::CuSparseSpMatCsrDescriptor(B_32);
  auto descC = at::cuda::sparse::CuSparseSpMatCsrDescriptor(C);

  auto spgemm_desc = at::cuda::sparse::CuSparseSpGEMMDescriptor();
  cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
  cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      kHalf,
      kBFloat16,
      C.scalar_type(),
      "spgemm",
      [&] {
        auto beta_ = beta.to<scalar_t>();
        auto alpha_ = alpha.to<scalar_t>();
        auto compute_type = at::cuda::getCudaDataType<scalar_t>();
        auto handle = at::cuda::getCurrentCUDASparseHandle();

        // It's required to call workEstimation twice
        size_t buffer_size1 = 0;
        TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(
            handle,
            opA,
            opB,
            &alpha_,
            descA.descriptor(),
            descB.descriptor(),
            &beta_,
            descC.descriptor(),
            compute_type,
            CUSPARSE_SPGEMM_DEFAULT,
            spgemm_desc.descriptor(),
            &buffer_size1,
            nullptr));

        auto& allocator = *c10::cuda::CUDACachingAllocator::get();
        auto buffer1 = allocator.allocate(buffer_size1);

        TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(
            handle,
            opA,
            opB,
            &alpha_,
            descA.descriptor(),
            descB.descriptor(),
            &beta_,
            descC.descriptor(),
            compute_type,
            CUSPARSE_SPGEMM_DEFAULT,
            spgemm_desc.descriptor(),
            &buffer_size1,
            buffer1.get()));

        // It's required to call compute twice
        size_t buffer_size2 = 0;
        TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_compute(
            handle,
            opA,
            opB,
            &alpha_,
            descA.descriptor(),
            descB.descriptor(),
            &beta_,
            descC.descriptor(),
            compute_type,
            CUSPARSE_SPGEMM_DEFAULT,
            spgemm_desc.descriptor(),
            &buffer_size2,
            nullptr));

        auto buffer2 = allocator.allocate(buffer_size2);

        TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_compute(
            handle,
            opA,
            opB,
            &alpha_,
            descA.descriptor(),
            descB.descriptor(),
            &beta_,
            descC.descriptor(),
            compute_type,
            CUSPARSE_SPGEMM_DEFAULT,
            spgemm_desc.descriptor(),
            &buffer_size2,
            buffer2.get()));

        // Get how many specified elements there are in C
        auto [C_num_rows, C_num_cols, C_nnz] = descC.get_size();

        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(C_num_rows == m);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(C_num_cols == n);

        // Resize result using nnz information from cusparse
        col_indices_and_values_resize_(C, C_nnz);

        // Update matC with the new pointers
        descC.set_tensor(C);

        // Copy the data into C
        TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_copy(
            handle,
            opA,
            opB,
            &alpha_,
            descA.descriptor(),
            descB.descriptor(),
            &beta_,
            descC.descriptor(),
            compute_type,
            CUSPARSE_SPGEMM_DEFAULT,
            spgemm_desc.descriptor()));
      });
}

} // anonymous namespace

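/*
  Computes result <- alpha*(mat1 @ mat2) + beta*input where at least one of
  the operands is sparse. Dispatches on the layouts of mat1, mat2 and result
  to block_sparse_mm, spmm or spgemm, transposing the operands and/or
  converting CSC/BSC inputs as needed; unsupported layout combinations fall
  through to an informative TORCH_CHECK error.
*/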
void addmm_out_sparse_csr(
    const Tensor& input,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
  TORCH_INTERNAL_ASSERT(
      !((mat1.layout() == kStrided) && (mat2.layout() == kStrided) &&
        (result.layout() == kStrided)),
      "Expected at least one sparse input");

  // Layout checks are nested mat1, mat2, result
  // Conditions are ordered strided, csr, csc, bsr, bsc.
  // Valid combinations terminate in a return
  // Invalid combinations are omitted and will fall through to the TORCH check
  // generating an informative error message

  // mm functions that copy input to result when needed (e.g. mm
  // triton kernels do not require result being initialized with
  // input):
  if (mat1.layout() == kSparseBsr) {
    if (mat2.layout() == kStrided) {
      if (result.layout() == kStrided)
        return block_sparse_mm(input, mat1, mat2, beta, alpha, result);
    }
  }

  if (mat1.layout() == kStrided) {
    if (mat2.layout() == kSparseBsc) {
      if (result.layout() == kStrided) {
        auto result_t = result.transpose(-2, -1);
        auto input_t = (result.is_same(input) ? result_t : input.transpose(-2, -1));
        return block_sparse_mm(
            input_t,
            mat2.transpose(-2, -1),
            mat1.transpose(-2, -1),
            beta,
            alpha,
            result_t);
      }
    }
  }

  // copy input to result:
  if (beta.toComplexDouble() != 0. && !result.is_same(input)) {
    result.copy_(input);
  }

  // mm functions that assume that result contains input:
  if (mat1.layout() == kStrided) {
    if (mat2.layout() == kSparseCsr) {
      if (result.layout() == kStrided) {
        // TODO: Add native CSC support via cuSPARSE if supported.
        return spmm(
            mat2.transpose(0, 1).to_sparse_csr(),
            mat1.transpose(0, 1),
            beta,
            alpha,
            result.transpose(0, 1));
      }
    }
    if (mat2.layout() == kSparseCsc) {
      if (result.layout() == kStrided) {
        return spmm(
            mat2.transpose(-2, -1),
            mat1.transpose(-2, -1),
            beta,
            alpha,
            result.transpose(-2, -1));
      }
    }
  }
  if (mat1.layout() == kSparseCsr) {
    if (mat2.layout() == kStrided) {
      if (result.layout() == kStrided) {
        return spmm(mat1, mat2, beta, alpha, result);
      }
    }
    if (mat2.layout() == kSparseCsr) {
      if (result.layout() == kSparseCsr) {
        return spgemm(mat1, mat2, beta, alpha, result);
      }
    }
    if (mat2.layout() == kSparseCsc) {
      if (result.layout() == kSparseCsr) {
        // TODO: Add native CSC support via cuSPARSE if supported.
        // CSR @ CSC kernel would be very fast due to format alignment
        return spgemm(mat1, mat2.to_sparse_csr(), beta, alpha, result);
      }
    }
  }
  if (mat1.layout() == kSparseCsc) {
    if (mat2.layout() == kStrided) {
      if (result.layout() == kStrided) {
        // TODO: Add native CSC support via cuSPARSE if supported.
        return spmm(mat1.to_sparse_csr(), mat2, beta, alpha, result);
      }
    }
    if (mat2.layout() == kSparseCsr) {
      if (result.layout() == kSparseCsr)
        // TODO: Add native CSC support via cuSPARSE if supported.
        return spgemm(mat1.to_sparse_csr(), mat2, beta, alpha, result);
    }
    if (mat2.layout() == kSparseCsc) {
      if (result.layout() == kSparseCsr) {
        // TODO: Add native CSC support via cuSPARSE if supported.
        return spgemm(
            mat1.to_sparse_csr(), mat2.to_sparse_csr(), beta, alpha, result);
      }
      if (result.layout() == kSparseCsc) {
        return spgemm(
            mat2.transpose(-2, -1),
            mat1.transpose(-2, -1),
            beta,
            alpha,
            result.transpose(-2, -1));
      }
    }
  }
  TORCH_CHECK(
      false,
      "addmm: computation on CUDA is not implemented for ",
      result.layout(),
      " + ",
      mat1.layout(),
      " @ ",
      mat2.layout());
}

/*
  Computes a sparse matrix-dense vector product defined as
  y <- alpha*op(A)*x + beta*y

  Args:
  * `mat` - Tensor storing sparse m x n matrix A.
  * `vec` - Tensor storing dense vector x of size n.
  * `result` - [in] Tensor storing dense vector y of size m.
               [out] result of the operation.
*/
void addmv_out_sparse_csr(
    const at::sparse_csr::SparseCsrTensor& mat,
    const Tensor& vec,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
  if (mat.layout() == kSparseBsr) {
    return block_sparse_mv(mat, vec, beta, alpha, result);
  }
#if !(AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API())
  TORCH_CHECK(
      false,
      "Calling addmv on a sparse GPU tensor requires compiling ",
      "PyTorch with CUDA 10.2+ (CUDA 11+ on Windows). ",
      "Please use PyTorch built with newer CUDA version.");
#else
  cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;

  c10::MaybeOwned<Tensor> result_ = prepare_dense_vector_for_cusparse(result);
  c10::MaybeOwned<Tensor> vec_ = prepare_dense_vector_for_cusparse(vec);

  // TODO: update this to support COO sparse layout
  auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(mat);
  auto descX = at::cuda::sparse::CuSparseDnVecDescriptor(*vec_);
  auto descY = at::cuda::sparse::CuSparseDnVecDescriptor(*result_);

  // cusparseSpMVAlg_t was updated in cuda 11.2.1 (cusparse 11.4.0)
#if CUSPARSE_VERSION >= 11400
  cusparseSpMVAlg_t alg = CUSPARSE_SPMV_ALG_DEFAULT;
#else
  cusparseSpMVAlg_t alg = CUSPARSE_MV_ALG_DEFAULT;
#endif

  // SpMV doesn't support uniform precision computation
  // For float16/bfloat16 inputs compute_type must be CUDA_R_32F
  // and type of alpha, beta must be float
  auto dispatch_scalar_type = result.scalar_type();
  if (dispatch_scalar_type == at::ScalarType::Half ||
      dispatch_scalar_type == at::ScalarType::BFloat16) {
    dispatch_scalar_type = at::ScalarType::Float;
  }

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      dispatch_scalar_type,
      "addmv_out_sparse_csr_cuda_impl",
      [&] {
        auto beta_ = beta.to<scalar_t>();
        auto alpha_ = alpha.to<scalar_t>();
        cudaDataType compute_type = at::cuda::getCudaDataType<scalar_t>();
        auto handle = at::cuda::getCurrentCUDASparseHandle();

        size_t buffer_size;
        TORCH_CUDASPARSE_CHECK(cusparseSpMV_bufferSize(
            handle,
            opA,
            &alpha_,
            descA.descriptor(),
            descX.descriptor(),
            &beta_,
            descY.descriptor(),
            compute_type,
            alg,
            &buffer_size // output
            ));

        auto& allocator = *c10::cuda::CUDACachingAllocator::get();
        auto work_data = allocator.allocate(buffer_size);

        TORCH_CUDASPARSE_CHECK(cusparseSpMV(
            handle,
            opA,
            &alpha_,
            descA.descriptor(),
            descX.descriptor(),
            &beta_,
            descY.descriptor(),
            compute_type,
            alg,
            work_data.get()));
      });
  if (!result.is_same(*result_)) {
    result.copy_(*result_);
  }
#endif // !(AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API())
}

/*
  Computes C = alpha * A + beta * B

  Args:
  * `A` - [in] sparse Tensor of size m × n.
  * `B` - [in] sparse Tensor of size m × n.
  * `C` - [out] sparse Tensor of size m × n.
*/
void add_out_sparse_csr(
    const at::sparse_csr::SparseCsrTensor& A,
    const at::sparse_csr::SparseCsrTensor& B,
    const Scalar& alpha,
    const Scalar& beta,
    const at::sparse_csr::SparseCsrTensor& C) {
  IntArrayRef A_sizes = A.sizes();
  auto ndim = A.dim();
  int m = at::native::cuda_int_cast(A_sizes[ndim - 2], "m");
  int n = at::native::cuda_int_cast(A_sizes[ndim - 1], "n");

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.sizes().equals(B.sizes()) && A.sizes().equals(C.sizes()));

  // Only 32-bit indices are supported
  const auto output_indices_dtype = promoteTypes(A.crow_indices().scalar_type(), B.crow_indices().scalar_type());
  auto A_32 = at::native::_sparse_csr_tensor_unsafe(
      A.crow_indices().to(kInt),
      A.col_indices().to(kInt),
      A.values(),
      A.sizes(),
      A.scalar_type(),
      A.layout(),
      A.device());
  auto B_32 = at::native::_sparse_csr_tensor_unsafe(
      B.crow_indices().to(kInt),
      B.col_indices().to(kInt),
      B.values(),
      B.sizes(),
      B.scalar_type(),
      B.layout(),
      B.device());

  // Modify C tensor in-place to swap indices tensors with 32-bit variants
  auto C_crow_indices_backup = C.crow_indices();
  auto C_col_indices_backup = C.col_indices();
  indices_to_32_bit_inplace(C); // no-op with 32-bit indices

  int nnzA = at::native::cuda_int_cast(A_32._nnz(), "nnzA");
  int nnzB = at::native::cuda_int_cast(B_32._nnz(), "nnzB");

  auto desc = at::cuda::sparse::CuSparseMatDescriptor();

  auto A_crow_indices = A_32.crow_indices();
  auto B_crow_indices = B_32.crow_indices();
  auto C_crow_indices = C.crow_indices();
  auto A_crow_indices_ptr = A_crow_indices.data_ptr<int>();
  auto B_crow_indices_ptr = B_crow_indices.data_ptr<int>();
  auto C_crow_indices_ptr = C_crow_indices.data_ptr<int>();

  auto A_col_indices = A_32.col_indices();
  auto B_col_indices = B_32.col_indices();
  auto C_col_indices = C.col_indices();
  auto A_col_indices_ptr = A_col_indices.data_ptr<int>();
  auto B_col_indices_ptr = B_col_indices.data_ptr<int>();
  auto C_col_indices_ptr = C_col_indices.data_ptr<int>();

  // Windows compilers don't support nested macros
  // so we need this lambda outside of the
  // AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES
  auto fix_nnz = [
#if AT_ROCM_ENABLED()
                     &C_crow_indices,
                     &m
#endif
  ](int nnz) -> int {
    // For some reason POINTER_MODE_HOST is not working here
    // Let's manually extract the nnz from C_crow_indices
#if AT_ROCM_ENABLED()
    return std::max({nnz, C_crow_indices.narrow(-1, m, 1).item<int>()});
#else
    return nnz;
#endif
  };

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      C.scalar_type(), "add_out_sparse_csr_cuda_impl", [&] {
        auto beta_ = beta.to<scalar_t>();
        auto alpha_ = alpha.to<scalar_t>();

        auto A_values = A_32.values();
        auto B_values = B_32.values();
        auto C_values = C.values();
        auto A_values_ptr = A_values.data_ptr<scalar_t>();
        auto B_values_ptr = B_values.data_ptr<scalar_t>();
        auto C_values_ptr = C_values.data_ptr<scalar_t>();

        auto handle = at::cuda::getCurrentCUDASparseHandle();
        TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST));

        size_t buffer_size;
        at::cuda::sparse::csrgeam2_bufferSizeExt<scalar_t>(
            handle,
            m,
            n,
            &alpha_,
            desc.descriptor(),
            nnzA,
            A_values_ptr,
            A_crow_indices_ptr,
            A_col_indices_ptr,
            &beta_,
            desc.descriptor(),
            nnzB,
            B_values_ptr,
            B_crow_indices_ptr,
            B_col_indices_ptr,
            desc.descriptor(),
            C_values_ptr,
            C_crow_indices_ptr,
            C_col_indices_ptr,
            &buffer_size // output
        );

        auto& allocator = *c10::cuda::CUDACachingAllocator::get();
        auto work_data = allocator.allocate(buffer_size);

        int nnzC = -1;
        at::cuda::sparse::csrgeam2Nnz<scalar_t>(
            handle,
            m,
            n,
            desc.descriptor(),
            nnzA,
            A_crow_indices_ptr,
            A_col_indices_ptr,
            desc.descriptor(),
            nnzB,
            B_crow_indices_ptr,
            B_col_indices_ptr,
            desc.descriptor(),
            C_crow_indices_ptr,
            &nnzC,
            work_data.get());

        nnzC = fix_nnz(nnzC);

        // Resize result using nnz information from cusparse
        col_indices_and_values_resize_(C, nnzC);
        C_col_indices = C.col_indices();
        C_values = C.values();

        C_col_indices_ptr = C_col_indices.data_ptr<int>();
        C_values_ptr = C_values.data_ptr<scalar_t>();

        at::cuda::sparse::csrgeam2<scalar_t>(
            handle,
            m,
            n,
            &alpha_,
            desc.descriptor(),
            nnzA,
            A_values_ptr,
            A_crow_indices_ptr,
            A_col_indices_ptr,
            &beta_,
            desc.descriptor(),
            nnzB,
            B_values_ptr,
            B_crow_indices_ptr,
            B_col_indices_ptr,
            desc.descriptor(),
            C_values_ptr,
            C_crow_indices_ptr,
            C_col_indices_ptr,
            work_data.get());

        if (output_indices_dtype == at::kLong) {
          static_cast<SparseCsrTensorImpl*>(C.unsafeGetTensorImpl())->set_member_tensors(
              C_crow_indices_backup.copy_(C.crow_indices()),
              C_col_indices_backup.resize_({nnzC}).copy_(C.col_indices()),
              C.values(),
              C.sizes());
        }
      });
}

/*
  Solves a system of linear equations whose coefficients are represented in a sparse triangular matrix A:
  op(A) X = B.

  Args:
  * `A` - sparse Tensor of size m × m.
  * `B` - dense Tensor of size m × nrhs.
  * `X` - dense Tensor of size m × nrhs.
  * `upper` - controls whether the upper or lower triangular part of A is considered in computations.
  * `transpose` - if true then op(A) = A^T.
  * `unitriangular` - if true then the diagonal elements of A are assumed to be one.
*/
void triangular_solve_out_sparse_csr(
    const at::sparse_csr::SparseCsrTensor& A,
    const Tensor& B,
    const Tensor& X,
    bool upper,
    bool transpose,
    bool unitriangular) {
  if (B.numel() == 0 || X.numel() == 0 || A._nnz() == 0) {
    // If A has no nnz, then A is singular and we can't solve.
    X.fill_(NAN);
    return;
  }
  if (A.layout() == kSparseBsr) {
    if (B.size(-1) == 1) {
      return block_sparse_triangular_solve_vec(A, B, X, upper, transpose, unitriangular);
    } else {
      return block_sparse_triangular_solve_mat(A, B, X, upper, transpose, unitriangular);
    }
  }
#if !AT_USE_CUSPARSE_GENERIC_SPSV()
  TORCH_CHECK(
      false,
      "Calling triangular solve on a sparse GPU tensor requires compiling ",
      "PyTorch with at least CUDA 11.3. ",
      "Please use PyTorch built with newer CUDA version.");
#else
  c10::MaybeOwned<Tensor> X_ = prepare_dense_matrix_for_cusparse(X);
  // It should be possible to use mixed memory format,
  // but there is a bug in CUDA 11.3.1:
  // strides of matrix B are used to write the result to matrix X.
  // As a workaround we need to convert the matrices to have the same strides.
  c10::MaybeOwned<Tensor> B_ = prepare_dense_matrix_for_cusparse(B, X_->strides());

  // TODO: update this to support COO sparse layout
  auto descA = at::cuda::sparse::CuSparseSpMatCsrDescriptor(A);
  descA.set_mat_fill_mode(upper);
  descA.set_mat_diag_type(unitriangular);
  cusparseOperation_t opA = transpose ? CUSPARSE_OPERATION_TRANSPOSE
                                      : CUSPARSE_OPERATION_NON_TRANSPOSE;

  if (B.size(-1) == 1) {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
        X.scalar_type(), "triangular_solve_out_sparse_csr_cuda_impl", [&] {
          scalar_t alpha = 1;
          cudaDataType compute_type = at::cuda::getCudaDataType<scalar_t>();
          auto handle = at::cuda::getCurrentCUDASparseHandle();
          size_t buffer_size;

          auto desc_spsv = at::cuda::sparse::CuSparseSpSVDescriptor();
          auto descB = at::cuda::sparse::CuSparseDnVecDescriptor(*B_);
          auto descX = at::cuda::sparse::CuSparseDnVecDescriptor(*X_);
          TORCH_CUDASPARSE_CHECK(cusparseSpSV_bufferSize(
              handle,
              opA,
              &alpha,
              descA.descriptor(),
              descB.descriptor(),
              descX.descriptor(),
              compute_type,
              CUSPARSE_SPSV_ALG_DEFAULT,
              desc_spsv.descriptor(),
              &buffer_size // output
              ));

          auto& allocator = *c10::cuda::CUDACachingAllocator::get();
          auto work_data = allocator.allocate(buffer_size);

          TORCH_CUDASPARSE_CHECK(cusparseSpSV_analysis(
              handle,
              opA,
              &alpha,
              descA.descriptor(),
              descB.descriptor(),
              descX.descriptor(),
              compute_type,
              CUSPARSE_SPSV_ALG_DEFAULT,
              desc_spsv.descriptor(),
              work_data.get()));

          TORCH_CUDASPARSE_CHECK(cusparseSpSV_solve(
              handle,
              opA,
              &alpha,
              descA.descriptor(),
              descB.descriptor(),
              descX.descriptor(),
              compute_type,
              CUSPARSE_SPSV_ALG_DEFAULT,
              desc_spsv.descriptor()));
        });
  } else {
#if !AT_USE_CUSPARSE_GENERIC_SPSM()
    TORCH_CHECK(
        false,
        "Calling triangular solve on a sparse GPU tensor requires compiling ",
        "PyTorch with at least CUDA 11.3.1. ",
        "Please use PyTorch built with newer CUDA version.");
#else
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
        X.scalar_type(), "triangular_solve_out_sparse_csr_cuda_impl", [&] {
          scalar_t alpha = 1;
          cudaDataType compute_type = at::cuda::getCudaDataType<scalar_t>();
          auto handle = at::cuda::getCurrentCUDASparseHandle();
          size_t buffer_size;

          cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE;
          auto desc_spsm = at::cuda::sparse::CuSparseSpSMDescriptor();
          auto descB = at::cuda::sparse::CuSparseDnMatDescriptor(*B_);
          auto descX = at::cuda::sparse::CuSparseDnMatDescriptor(*X_);
          TORCH_CUDASPARSE_CHECK(cusparseSpSM_bufferSize(
              handle,
              opA,
              opB,
              &alpha,
              descA.descriptor(),
              descB.descriptor(),
              descX.descriptor(),
              compute_type,
              CUSPARSE_SPSM_ALG_DEFAULT,
              desc_spsm.descriptor(),
              &buffer_size // output
              ));

          auto& allocator = *c10::cuda::CUDACachingAllocator::get();
          auto work_data = allocator.allocate(buffer_size);

          TORCH_CUDASPARSE_CHECK(cusparseSpSM_analysis(
              handle,
              opA,
              opB,
              &alpha,
              descA.descriptor(),
              descB.descriptor(),
              descX.descriptor(),
              compute_type,
              CUSPARSE_SPSM_ALG_DEFAULT,
              desc_spsm.descriptor(),
              work_data.get()));

          TORCH_CUDASPARSE_CHECK(cusparseSpSM_solve(
              handle,
              opA,
              opB,
              &alpha,
              descA.descriptor(),
              descB.descriptor(),
              descX.descriptor(),
              compute_type,
              CUSPARSE_SPSM_ALG_DEFAULT,
              desc_spsm.descriptor()));
        });
#endif // !AT_USE_CUSPARSE_GENERIC_SPSM()
  }
  if (!X.is_same(*X_)) {
    X.copy_(*X_);
  }
#endif // !AT_USE_CUSPARSE_GENERIC_SPSV()
}

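/*
  Computes a sampled dense-dense matrix multiplication defined as
  C <- alpha*(A @ B) ∘ spy(C) + beta*C
  where A and B are dense matrices, C is a sparse CSR matrix and spy(C) is its
  sparsity pattern, using the cuSPARSE SDDMM API. Batches are processed in a
  loop because batched SDDMM is not supported by cuSPARSE.
*/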
void sampled_addmm_out_sparse_csr(
    const Tensor& A,
    const Tensor& B,
    const Scalar& beta,
    const Scalar& alpha,
    const at::sparse_csr::SparseCsrTensor& C) {
#if !(AT_USE_CUSPARSE_GENERIC_SDDMM() || AT_USE_HIPSPARSE_GENERIC_API())
  TORCH_CHECK(
      false,
      "Calling sampled_addmm with sparse GPU tensors requires compiling ",
      "PyTorch with CUDA 11.2.1+. ",
      "Please use PyTorch built with newer CUDA version.");
#else
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.layout() == Layout::Strided);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(B.layout() == Layout::Strided);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(C.is_sparse_csr());

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batchCount(A) == batchCount(B));
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batchCount(A) == batchCount(C));

  cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
  cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE;

  c10::MaybeOwned<Tensor> A_ = prepare_dense_matrix_for_cusparse(A);
  c10::MaybeOwned<Tensor> B_ = prepare_dense_matrix_for_cusparse(B);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      C.scalar_type(),
      "sampled_addmm_out_sparse_csr",
      [&] {
        // CUDA 11.6 doesn't support batched inputs, it raises an error:
        // ** On entry to cusparseSDDMM_bufferSize(): batched SDDMM is not supported
        // So we need to resort to a for loop
        for (const auto i : c10::irange(batchCount(A))) {
          auto descA = at::cuda::sparse::CuSparseConstDnMatDescriptor(*A_, /*batch_offset=*/i);
          auto descB = at::cuda::sparse::CuSparseConstDnMatDescriptor(*B_, /*batch_offset=*/i);
          auto descC = at::cuda::sparse::CuSparseSpMatCsrDescriptor(C, /*batch_offset=*/i);

          auto beta_ = beta.to<scalar_t>();
          auto alpha_ = alpha.to<scalar_t>();
          auto compute_type = at::cuda::getCudaDataType<scalar_t>();
          auto handle = at::cuda::getCurrentCUDASparseHandle();
          size_t buffer_size = 0;
          TORCH_CUDASPARSE_CHECK(cusparseSDDMM_bufferSize(
              handle,
              opA,
              opB,
              &alpha_,
              descA.unsafe_mutable_descriptor(),
              descB.unsafe_mutable_descriptor(),
              &beta_,
              descC.descriptor(),
              compute_type,
              CUSPARSE_SDDMM_ALG_DEFAULT,
              &buffer_size // output
              ));

          auto& allocator = *c10::cuda::CUDACachingAllocator::get();
          auto buffer = allocator.allocate(buffer_size);

          TORCH_CUDASPARSE_CHECK(cusparseSDDMM_preprocess(
              handle,
              opA,
              opB,
              &alpha_,
              descA.unsafe_mutable_descriptor(),
              descB.unsafe_mutable_descriptor(),
              &beta_,
              descC.descriptor(),
              compute_type,
              CUSPARSE_SDDMM_ALG_DEFAULT,
              buffer.get()));

          TORCH_CUDASPARSE_CHECK(cusparseSDDMM(
              handle,
              opA,
              opB,
              &alpha_,
              descA.unsafe_mutable_descriptor(),
              descB.unsafe_mutable_descriptor(),
              &beta_,
              descC.descriptor(),
              compute_type,
              CUSPARSE_SDDMM_ALG_DEFAULT,
              buffer.get()));
        }
      });
#endif
}

} // namespace at::native::sparse::impl::cuda