#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/Tensor.h>
#include <ATen/mkl/Sparse.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/native/mkl/SparseBlasImpl.h>

#include <c10/core/ScalarType.h>
#include <c10/util/MaybeOwned.h>

#if AT_USE_MKL_SPARSE()
#include <ATen/mkl/SparseBlas.h>
#include <ATen/mkl/SparseDescriptors.h>
#include <ATen/mkl/Utils.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat.h>
#include <ATen/ops/sparse_coo_tensor.h>
#endif

namespace at {
namespace native {
namespace sparse {
namespace impl {
namespace mkl {

namespace {

#if AT_USE_MKL_SPARSE()
c10::MaybeOwned<Tensor> prepare_dense_matrix_for_mkl(
    const Tensor& tensor) {
  if (tensor.is_non_overlapping_and_dense() ||
      is_blas_compatible_row_major_order(tensor) ||
      is_blas_compatible_column_major_order(tensor)) {
    return at::native::expect_resolved_conj(tensor);
  } else {
    return c10::MaybeOwned<Tensor>::owned(
        tensor.clone(at::MemoryFormat::Contiguous));
  }
}
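
// Note: a contiguous matrix, or a transpose view of one, already satisfies
// the checks above and is returned without copying (expect_resolved_conj
// only materializes a copy when the conjugate bit is set); for example, a
// column-sliced view with a last-dimension stride greater than one falls
// into the clone(MemoryFormat::Contiguous) branch.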

/*
  Get row-major or column-major matrix.

  Args:
  * `tensor` - 2D strided Tensor.
  * `row_major` - controls the memory layout.
*/
c10::MaybeOwned<Tensor> prepare_dense_matrix_for_mkl(
    const Tensor& tensor,
    bool row_major) {
  if (is_blas_compatible_row_major_order(tensor) && row_major) {
    return at::native::expect_resolved_conj(tensor);
  } else {
    if (row_major) {
      return c10::MaybeOwned<Tensor>::owned(
          tensor.clone(at::MemoryFormat::Contiguous));
    } else {
      return c10::MaybeOwned<Tensor>::owned(cloneBatchedColumnMajor(tensor));
    }
  }
}

c10::MaybeOwned<Tensor> inline prepare_dense_vector_for_mkl(
    const Tensor& tensor) {
  if (tensor.is_non_overlapping_and_dense()) {
    return c10::MaybeOwned<Tensor>::borrowed(tensor);
  } else {
    return c10::MaybeOwned<Tensor>::owned(
        tensor.clone(at::MemoryFormat::Contiguous));
  }
}

void inline indices_to_mkl_compatible_inplace(const Tensor& input) {
#ifdef MKL_ILP64
  // ILP64 is a 64-bit API version of MKL
  // Indices tensor must have ScalarType::Long type
  static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
      ->set_member_tensors(
          input.crow_indices().to(kLong),
          input.col_indices().to(kLong),
          input.values(),
          input.sizes());
#else
  // LP64 is a 32-bit API version of MKL
  // Indices tensor must have ScalarType::Int type
  static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
      ->set_member_tensors(
          input.crow_indices().to(kInt),
          input.col_indices().to(kInt),
          input.values(),
          input.sizes());
#endif
}

void inline col_indices_and_values_resize_(const Tensor& input, int64_t nnz) {
  static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
      ->set_member_tensors(
          input.crow_indices(),
          input.col_indices().resize_({nnz}),
          input.values().resize_({nnz}),
          input.sizes());
}

/*
  Resizes `input` tensor and fills it with the data from MKL.
*/
template <typename scalar_t>
void mkl_result_copy_(const Tensor& input, sparse_matrix_t mkl_desc) {
  sparse_index_base_t indexing = SPARSE_INDEX_BASE_ZERO;
  MKL_INT rows, cols;
  MKL_INT *rows_start = nullptr, *rows_end = nullptr, *columns = nullptr;
  scalar_t* values = nullptr;
  at::mkl::sparse::export_csr(
      mkl_desc,
      &indexing,
      &rows,
      &cols,
      &rows_start,
      &rows_end,
      &columns,
      &values);

  // Resize input using nnz information from MKL
  MKL_INT nnz = rows_end[rows - 1];
  col_indices_and_values_resize_(input, nnz);

  auto crow_indices = input.crow_indices();
  auto col_indices = input.col_indices();
  auto input_values = input.values();

  // NB: When nnz is zero it is possible that input_values.data_ptr<scalar_t>()
  // is a nullptr, if input was created via empty. As such we need to check
  // that nnz is not zero to avoid passing a nullptr to std::memcpy. We apply
  // the same precaution to crow_indices.data_ptr<MKL_INT>().
  //
  // Otherwise ASAN will complain.

  if (nnz > 0) {
    // MKL Sparse Inspector-Executor doesn't have a way to provide external
    // buffers, so we have to copy the memory allocated by MKL.
    std::memcpy(
        input_values.mutable_data_ptr<scalar_t>(), values, nnz * sizeof(scalar_t));
    std::memcpy(
        col_indices.mutable_data_ptr<MKL_INT>(), columns, nnz * sizeof(MKL_INT));
  }
  if (rows > 0) {
    std::memcpy(
        crow_indices.mutable_data_ptr<MKL_INT>(), rows_start, rows * sizeof(MKL_INT));
  }
  crow_indices.mutable_data_ptr<MKL_INT>()[rows] = nnz;
}
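
// Note on the copy above: MKL's 4-array CSR export returns separate
// rows_start and rows_end arrays, while PyTorch CSR stores a single
// crow_indices array of length rows + 1. Assuming the exported matrix uses
// the standard layout where rows_end[i] == rows_start[i + 1], crow_indices
// is rows_start followed by a final entry equal to nnz, which is why that
// last element is written explicitly after copying rows_start.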
#endif

/*
  Computes a sparse matrix-dense matrix product defined as
  C <- alpha*(A*B) + beta*C

  Args:
  * `A` - Sparse Tensor storing m x k matrix.
  * `B` - Dense Tensor storing k x n matrix.
  * `C` - [in] Dense Tensor storing matrix of size m x n.
          [out] result of the operation.
*/
void addmm_dense_result(
    const Tensor& A,
    const Tensor& B,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& C) {
#if !AT_USE_MKL_SPARSE()
  TORCH_CHECK(
      false,
      "Calling addmm on a sparse CPU tensor requires Linux platform. ",
      "Please use PyTorch built with MKL on Linux.");
#else
  c10::MaybeOwned<Tensor> C_ = prepare_dense_matrix_for_mkl(C);
  IntArrayRef C_strides = C_->strides();
  auto ndim = C_->dim();
  bool is_C_row_major = (C_strides[ndim - 1] == 1);

  // MKL requires the same storage layout for both dense matrices
  c10::MaybeOwned<Tensor> B_ = prepare_dense_matrix_for_mkl(B, is_C_row_major);
  IntArrayRef B_strides = B_->strides();
  bool is_B_row_major = (B_strides[ndim - 1] == 1);

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!(is_C_row_major ^ is_B_row_major));

  auto order =
      is_C_row_major ? SPARSE_LAYOUT_ROW_MAJOR : SPARSE_LAYOUT_COLUMN_MAJOR;
  auto ldc = is_C_row_major ? C_strides[ndim - 2] : C_strides[ndim - 1];
  auto ldb = is_B_row_major ? B_strides[ndim - 2] : B_strides[ndim - 1];
  auto columns_C = mkl_int_cast(C.size(-1), "columns_C");

  matrix_descr descrA;
  descrA.type = SPARSE_MATRIX_TYPE_GENERAL;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      C.scalar_type(), "addmm_out_sparse_csr_impl_mkl", [&] {
        auto beta_ = beta.to<scalar_t>();
        auto alpha_ = alpha.to<scalar_t>();

        auto mkl_sparse_mat =
            at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(A);
        at::mkl::sparse::mm<scalar_t>(
            SPARSE_OPERATION_NON_TRANSPOSE,
            alpha_,
            mkl_sparse_mat.descriptor(),
            descrA,
            order,
            B_->data_ptr<scalar_t>(),
            columns_C,
            ldb,
            beta_,
            C_->data_ptr<scalar_t>(),
            ldc);
      });

  if (!C.is_same(*C_)) {
    C.copy_(*C_);
  }
#endif
}
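
// Leading-dimension convention used above: for a row-major matrix MKL is
// given the row stride (stride(-2), at least the number of columns), while
// for a column-major matrix it is given the column stride (stride(-1), at
// least the number of rows).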

/*
  Computes a sparse matrix-sparse matrix product with dense result defined as
  C <- alpha*(A*B) + beta*C

  Args:
  * `A` - Sparse Tensor storing m x k matrix.
  * `B` - Sparse Tensor storing k x n matrix.
  * `C` - [in] Dense Tensor storing matrix of size m x n.
          [out] result of the operation.
*/
void addmm_sparse_input_dense_result(
    const Tensor& A,
    const Tensor& B,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& C) {
#if !AT_USE_MKL_SPARSE()
  TORCH_CHECK(
      false,
      "Calling addmm on a sparse CPU tensor requires Linux platform. ",
      "Please use PyTorch built with MKL on Linux.");
#else
  // The MKL function computes only C <- A*B,
  // so we need a temporary matrix to store the result
  // and then add it to C.
  auto C_ = at::empty(C.sizes(), C.options());
  auto order = SPARSE_LAYOUT_ROW_MAJOR;
  auto ldc = C_.stride(-2);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      C.scalar_type(), "addmm_sparse_input_dense_result", [&] {
        auto mkl_A = at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(A);
        auto mkl_B = at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(B);
        at::mkl::sparse::spmmd<scalar_t>(
            SPARSE_OPERATION_NON_TRANSPOSE,
            mkl_A.descriptor(),
            mkl_B.descriptor(),
            order,
            C_.data_ptr<scalar_t>(),
            ldc);
      });

  // If beta is zero, NaN and Inf should not be propagated to the result
  if (beta.toComplexDouble() == 0.) {
    C.zero_();
  } else {
    C.mul_(beta);
  }
  C.add_(C_, alpha);
#endif
}
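
// Note: mkl_sparse_?_spmmd writes a dense product into C_. Since at::empty
// with C's options produces a row-major contiguous tensor here, the layout
// is fixed to SPARSE_LAYOUT_ROW_MAJOR and ldc is simply C_'s row stride.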

/*
  Computes a sparse matrix-sparse matrix product defined as
  C <- alpha*(A*B) + beta*C

  Args:
  * `mat1` - Sparse CSR Tensor storing m x k matrix A.
  * `mat2` - Sparse CSR Tensor storing k x n matrix B.
  * `result` - [in] Sparse CSR Tensor storing matrix C of size m x n.
               [out] result of the operation.
*/
void addmm_sparse_result(
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
#if !AT_USE_MKL_SPARSE()
  TORCH_CHECK(
      false,
      "Calling addmm on a sparse CPU tensor requires Linux platform. ",
      "Please use PyTorch built with MKL on Linux.");
#else
  // Compute beta*result because MKL doesn't do it.
  // If beta is zero, NaN and Inf should not be propagated to the result.
  if (beta.toComplexDouble() == 0.) {
    result.values().zero_();
  } else {
    result.values().mul_(beta);
  }

  // MKL doesn't work with empty matrices
  if (mat1._nnz() == 0 || mat2._nnz() == 0) {
    return;
  }

  // MKL doesn't have an interface to compute alpha*(A*B) + beta*C at once
  Tensor mat1_mat2 = at::zeros(result.sizes(), result.options());
  indices_to_mkl_compatible_inplace(mat1_mat2);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      result.scalar_type(), "addmm_out_sparse_csr_impl_mkl_sparse", [&] {
        auto mkl_sparse_mat1 =
            at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(mat1);
        auto mkl_sparse_mat2 =
            at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(mat2);
        auto mkl_result = at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>();
        auto result_desc = mkl_result.descriptor();

        TORCH_MKLSPARSE_CHECK(mkl_sparse_spmm(
            SPARSE_OPERATION_NON_TRANSPOSE,
            mkl_sparse_mat1.descriptor(),
            mkl_sparse_mat2.descriptor(),
            &result_desc));

        // Copy the data from MKL, otherwise the computed result will be
        // destroyed together with `mkl_result`.
        mkl_result_copy_<scalar_t>(mat1_mat2, result_desc);
      });

  result.add_(mat1_mat2, alpha);
#endif
}
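
// Illustrative call path (a sketch; exact dispatch may differ across PyTorch
// versions):
//   auto a = at::rand({4, 3}).to_sparse_csr();
//   auto b = at::rand({3, 5}).to_sparse_csr();
//   auto c = at::mm(a, b);  // CSR @ CSR with a CSR result
// is expected to reach addmm_sparse_result through addmm_out_sparse_csr
// below, with beta == 0.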

} // anonymous namespace

/*
  Computes a matrix-matrix product defined as
  C <- alpha*(A*B) + beta*C

  Args:
  * `mat1` - Tensor storing m x k matrix A.
  * `mat2` - Tensor storing k x n matrix B.
  * `result` - [in] Tensor storing matrix C of size m x n.
               [out] result of the operation.
*/
void addmm_out_sparse_csr(
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      mat1.dim() == 2 && mat2.dim() == 2 && result.dim() == 2);
  TORCH_INTERNAL_ASSERT(
      !((mat1.layout() == kStrided) && (mat2.layout() == kStrided) &&
        (result.layout() == kStrided)),
      "Expected at least one sparse input");

  // Layout checks are nested mat1, mat2, result.
  // Conditions are ordered strided, csr, csc, bsr, bsc.
  // Valid combinations terminate in a return.
  // Invalid combinations are omitted and will fall through to the TORCH_CHECK
  // at the end, generating an informative error message.
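  //
  // The strided @ sparse branches below use the identity
  //   C = A*B  <=>  C^T = B^T * A^T,
  // rewriting a dense-lhs/sparse-rhs product as a sparse-lhs/dense-rhs
  // product on transposed operands (re-converting to CSR where the transpose
  // would otherwise yield a CSC tensor) and writing the result through a
  // transposed view.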
  if (mat1.layout() == kStrided) {
    if (mat2.layout() == kSparseCsr) {
      if (result.layout() == kStrided) {
        // TODO: Add native CSC support via cuSPARSE if supported.
        return addmm_dense_result(
            mat2.transpose(0, 1).to_sparse_csr(),
            mat1.transpose(0, 1),
            beta,
            alpha,
            result.transpose(0, 1));
      }
    }
    if (mat2.layout() == kSparseCsc) {
      if (result.layout() == kStrided) {
        return addmm_dense_result(
            mat2.transpose(-2, -1),
            mat1.transpose(-2, -1),
            beta,
            alpha,
            result.transpose(-2, -1));
      }
    }
    if (mat2.layout() == kSparseBsc) {
      if (result.layout() == kStrided) {
        return addmm_dense_result(
            mat2.transpose(-2, -1),
            mat1.transpose(-2, -1),
            beta,
            alpha,
            result.transpose(-2, -1));
      }
    }
  }
  if (mat1.layout() == kSparseCsr) {
    if (mat2.layout() == kStrided) {
      if (result.layout() == kStrided) {
        return addmm_dense_result(mat1, mat2, beta, alpha, result);
      }
    }
    if (mat2.layout() == kSparseCsr) {
      if (result.layout() == kStrided) {
        return addmm_sparse_input_dense_result(mat1, mat2, beta, alpha, result);
      }
      if (result.layout() == kSparseCsr) {
        return addmm_sparse_result(mat1, mat2, beta, alpha, result);
      }
    }
    if (mat2.layout() == kSparseCsc) {
      if (result.layout() == kStrided) {
        // TODO: CSR @ CSC kernel would be very fast due to format alignment
        return addmm_sparse_input_dense_result(
            mat1, mat2.to_sparse_csr(), beta, alpha, result);
      }
      if (result.layout() == kSparseCsr) {
        // TODO: CSR @ CSC kernel would be very fast due to format alignment
        return addmm_sparse_result(
            mat1, mat2.to_sparse_csr(), beta, alpha, result);
      }
    }
  }
  if (mat1.layout() == kSparseCsc) {
    if (mat2.layout() == kStrided) {
      if (result.layout() == kStrided) {
        // TODO: avoid csc->csr conversion with native csc support
        return addmm_dense_result(
            mat1.to_sparse_csr(), mat2, beta, alpha, result);
      }
    }
    if (mat2.layout() == kSparseCsr) {
      if (result.layout() == kSparseCsr) {
        // TODO: avoid csc->csr conversion with native csc support
        return addmm_sparse_result(
            mat1.to_sparse_csr(), mat2, beta, alpha, result);
      }
    }
    if (mat2.layout() == kSparseCsc) {
      if (result.layout() == kStrided) {
        return addmm_sparse_input_dense_result(
            mat2.transpose(-2, -1),
            mat1.transpose(-2, -1),
            beta,
            alpha,
            result.transpose(-2, -1));
      }
      if (result.layout() == kSparseCsr) {
        // TODO avoid csc->csr
        return addmm_sparse_result(
            mat1.to_sparse_csr(), mat2.to_sparse_csr(), beta, alpha, result);
      }
      if (result.layout() == kSparseCsc) {
        return addmm_sparse_result(
            mat2.transpose(-2, -1),
            mat1.transpose(-2, -1),
            beta,
            alpha,
            result.transpose(-2, -1));
      }
    }
  }
  if (mat1.layout() == kSparseBsr) {
    if (mat2.layout() == kStrided) {
      if (result.layout() == kStrided) {
        return addmm_dense_result(mat1, mat2, beta, alpha, result);
      }
    }
  }
  TORCH_CHECK(
      false,
      "addmm: computation on CPU is not implemented for ",
      result.layout(),
      " + ",
      mat1.layout(),
      " @ ",
      mat2.layout());
}
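
/*
  Illustrative usage (a sketch; exact routing depends on the dispatcher and
  the PyTorch version):

    at::Tensor a = at::rand({4, 3}).to_sparse_csr();  // sparse CSR mat1
    at::Tensor b = at::rand({3, 5});                  // strided mat2
    at::Tensor c = at::rand({4, 5});                  // strided result
    at::addmm_out(c, c, a, b, 1.0, 1.0);

  is expected to land in addmm_out_sparse_csr above, where the
  (CSR, strided, strided) combination takes the
  addmm_dense_result(mat1, mat2, beta, alpha, result) branch.
*/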

/*
  Computes a sparse matrix-dense vector product defined as
  y <- alpha*op(A)*x + beta*y

  Args:
  * `mat` - Tensor storing sparse m x n matrix A.
  * `vec` - Tensor storing dense vector x of size n.
  * `result` - [in] Tensor storing dense vector y of size m.
               [out] result of the operation.
*/
void addmv_out_sparse_csr(
    const Tensor& mat,
    const Tensor& vec,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
#if !AT_USE_MKL_SPARSE()
  TORCH_CHECK(
      false,
      "Calling addmv on a sparse CPU tensor requires Linux platform. ",
      "Please use PyTorch built with MKL on Linux.");
#else
  c10::MaybeOwned<Tensor> result_ = prepare_dense_vector_for_mkl(result);
  c10::MaybeOwned<Tensor> vec_ = prepare_dense_vector_for_mkl(vec);

  sparse_operation_t opA = SPARSE_OPERATION_NON_TRANSPOSE;
  matrix_descr descrA;
  descrA.type = SPARSE_MATRIX_TYPE_GENERAL;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      result.scalar_type(), "addmv_out_sparse_csr_impl_mkl", [&] {
        auto beta_ = beta.to<scalar_t>();
        auto alpha_ = alpha.to<scalar_t>();

        auto mkl_sparse_mat =
            at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(mat);

        at::mkl::sparse::mv<scalar_t>(
            opA,
            alpha_,
            mkl_sparse_mat.descriptor(),
            descrA,
            vec_->data_ptr<scalar_t>(),
            beta_,
            result_->data_ptr<scalar_t>());
      });

  if (!result.is_same(*result_)) {
    result.copy_(*result_);
  }
#endif
}
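
// Illustrative usage (a sketch; exact routing depends on the PyTorch
// version): at::addmv_out(y, y, A, x, /*beta=*/1, /*alpha=*/1) with a CSR
// matrix A and dense vectors x and y is expected to reach
// addmv_out_sparse_csr above on CPU builds with MKL.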

void add_out_sparse_csr(
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& alpha,
    const Tensor& result) {
#if !AT_USE_MKL_SPARSE()
  TORCH_CHECK(
      false,
      "Calling add on a sparse CPU tensor requires Linux platform. ",
      "Please use PyTorch built with MKL on Linux.");
#else

  // MKL doesn't work with empty matrices
  if (mat2._nnz() == 0) {
    col_indices_and_values_resize_(result, mat1._nnz());
    result.copy_(mat1);
    return;
  } else if (mat1._nnz() == 0) {
    col_indices_and_values_resize_(result, mat2._nnz());
    result.copy_(mat2);
    result.values().mul_(alpha);
    return;
  }

  // Modify `result` tensor in-place to swap indices tensors with 32-bit (or
  // 64-bit) variants
  const auto output_indices_dtype = promoteTypes(mat1.crow_indices().scalar_type(), mat2.crow_indices().scalar_type());
  auto result_crow_indices_backup = result.crow_indices();
  auto result_col_indices_backup = result.col_indices();
  indices_to_mkl_compatible_inplace(result);
  sparse_operation_t opA = SPARSE_OPERATION_NON_TRANSPOSE;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      result.scalar_type(), "add_out_sparse_csr_impl_mkl", [&] {
        auto alpha_ = alpha.to<scalar_t>();

        auto mkl_mat1 = at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(mat1);
        auto mkl_mat2 = at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(mat2);
        auto mkl_result = at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>();

        // Note that the order of mat1 and mat2 arguments is swapped because
        // MKL computes alpha*mat1 + mat2 while PyTorch needs mat1 +
        // alpha*mat2
        auto result_desc = mkl_result.descriptor();
        at::mkl::sparse::add<scalar_t>(
            opA,
            mkl_mat2.descriptor(),
            alpha_,
            mkl_mat1.descriptor(),
            &result_desc);

        // now copy data from `result_desc` to `result`
        mkl_result_copy_<scalar_t>(result, result_desc);
      });

  if (output_indices_dtype == at::kLong) {
    const auto res_nnz = result._nnz();
    static_cast<SparseCsrTensorImpl*>(result.unsafeGetTensorImpl())->set_member_tensors(
        result_crow_indices_backup.copy_(result.crow_indices()),
        result_col_indices_backup.resize_({res_nnz}).copy_(result.col_indices()),
        result.values(),
        result.sizes());
  }
#endif
}
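
// Note on the index-dtype restore above: indices_to_mkl_compatible_inplace
// converts `result`'s indices to MKL_INT before the MKL call; when the
// promoted index dtype of the inputs is int64, the original kLong index
// tensors are refilled from the MKL output and swapped back in, so the
// caller observes 64-bit indices again.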

void triangular_solve_out_sparse_csr(
    const Tensor& A_,
    const Tensor& B,
    const Tensor& X,
    bool upper,
    bool transpose,
    bool unitriangular) {
#if !AT_USE_MKL_SPARSE()
  TORCH_CHECK(
      false,
      "Calling triangular_solve on a sparse CPU tensor requires Linux platform. ",
      "Please use PyTorch built with MKL on Linux.");
#else
  if (B.numel() == 0 || X.numel() == 0 || A_._nnz() == 0) {
    // If A has no nnz, then A is singular and we can't solve.
    X.fill_(NAN);
    return;
  }

  const auto materialize_diagonal_indices = [](const Tensor& t) -> Tensor {
    const auto n = t.size(-1);
    const auto compressed_indices = std::get<0>(at::sparse_csr::getCompressedPlainIndices(t));
    const auto diag_indices = at::arange(n, compressed_indices.options()).unsqueeze(0).expand({2, n});
    const auto diag_values = at::zeros({1}, t.values().options()).expand({n});

    const auto t_coo = t.to_sparse();
    const auto expanded_indices = at::cat({t_coo._indices(), diag_indices}, /*dim=*/-1);
    const auto expanded_values = at::cat({t_coo._values(), diag_values}, /*dim=*/0);

    const auto t_expanded_coo = at::sparse_coo_tensor(expanded_indices, expanded_values, t_coo.sizes(), t_coo.options());
    return t_expanded_coo.to_sparse(t.layout());
  };
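
  // The helper above appends one explicit (i, i) entry with value zero for
  // every row via a COO round trip; duplicates are summed when converting
  // back to the compressed layout, so stored diagonal values are unchanged
  // while previously-missing diagonal positions become explicit zeros.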

  // MKL has a bug for inputs with unmaterialized diagonal indices.
  // See https://github.com/pytorch/pytorch/issues/88890 and
  // the comments within.
  const auto A = unitriangular ? materialize_diagonal_indices(A_) : A_;

  c10::MaybeOwned<Tensor> X_ = prepare_dense_matrix_for_mkl(X);
  IntArrayRef X_strides = X_->strides();
  auto ndim = X_->dim();
  bool is_X_row_major = (ndim > 1) ? (X_strides[ndim - 1] == 1) : true;

  // MKL requires the same storage layout for both dense matrices
  c10::MaybeOwned<Tensor> B_ = prepare_dense_matrix_for_mkl(B, is_X_row_major);

  sparse_operation_t opA = transpose ? SPARSE_OPERATION_TRANSPOSE : SPARSE_OPERATION_NON_TRANSPOSE;
  matrix_descr descrA;
  descrA.type = SPARSE_MATRIX_TYPE_TRIANGULAR;
  descrA.mode = upper ? SPARSE_FILL_MODE_UPPER : SPARSE_FILL_MODE_LOWER;
  descrA.diag = unitriangular ? SPARSE_DIAG_UNIT : SPARSE_DIAG_NON_UNIT;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      X.scalar_type(), "triangular_solve_out_sparse_csr_impl_mkl", [&] {
        auto mkl_sparse_mat =
            at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(A);
        scalar_t alpha = 1;

        if (B.size(-1) == 1) {
          sparse_status_t status = at::mkl::sparse::trsv<scalar_t>(
              opA,
              alpha,
              mkl_sparse_mat.descriptor(),
              descrA,
              B_->data_ptr<scalar_t>(),
              X_->data_ptr<scalar_t>());
          // Emulate behavior of old MKL versions that would set all elements
          // of the output array to -NaN in case of invalid input matrices.
          if (status == SPARSE_STATUS_INVALID_VALUE) {
            X_->fill_(-std::numeric_limits<scalar_t>::quiet_NaN());
          }
        } else {
          IntArrayRef B_strides = B_->strides();
          bool is_B_row_major = (B_strides[ndim - 1] == 1);
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!(is_X_row_major ^ is_B_row_major));

          auto order = is_X_row_major ? SPARSE_LAYOUT_ROW_MAJOR : SPARSE_LAYOUT_COLUMN_MAJOR;
          auto nrhs = mkl_int_cast(B.size(-1), "nrhs");
          auto ldx = is_X_row_major ? X_strides[ndim - 2] : X_strides[ndim - 1];
          auto ldb = is_B_row_major ? B_strides[ndim - 2] : B_strides[ndim - 1];
          sparse_status_t status = at::mkl::sparse::trsm<scalar_t>(
              opA,
              alpha,
              mkl_sparse_mat.descriptor(),
              descrA,
              order,
              B_->data_ptr<scalar_t>(),
              nrhs,
              ldb,
              X_->data_ptr<scalar_t>(),
              ldx);
          // Emulate behavior of old MKL versions that would set all elements
          // of the output array to -NaN in case of invalid input matrices.
          if (status == SPARSE_STATUS_INVALID_VALUE) {
            X_->fill_(-std::numeric_limits<scalar_t>::quiet_NaN());
          }
        }
      });

  if (!X.is_same(*X_)) {
    X.copy_(*X_);
  }
#endif
}

} // namespace mkl
} // namespace impl
} // namespace sparse
} // namespace native
} // namespace at