// /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkl/SparseBlasImpl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/Tensor.h>
#include <ATen/mkl/Sparse.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/native/mkl/SparseBlasImpl.h>

#include <c10/core/ScalarType.h>
#include <c10/util/MaybeOwned.h>

#if AT_USE_MKL_SPARSE()
#include <ATen/mkl/SparseBlas.h>
#include <ATen/mkl/SparseDescriptors.h>
#include <ATen/mkl/Utils.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat.h>
#include <ATen/ops/sparse_coo_tensor.h>
#endif

namespace at {
namespace native {
namespace sparse {
namespace impl {
namespace mkl {

namespace {

#if AT_USE_MKL_SPARSE()
c10::MaybeOwned<Tensor> prepare_dense_matrix_for_mkl(
    const Tensor& tensor) {
  if (tensor.is_non_overlapping_and_dense() ||
      is_blas_compatible_row_major_order(tensor) ||
      is_blas_compatible_column_major_order(tensor)) {
    return at::native::expect_resolved_conj(tensor);
  } else {
    return c10::MaybeOwned<Tensor>::owned(
        tensor.clone(at::MemoryFormat::Contiguous));
  }
}

/*
  Get row-major or column-major matrix.

  Args:
  * `tensor` - 2D strided Tensor.
  * `row_major` - controls the memory layout.
*/
c10::MaybeOwned<Tensor> prepare_dense_matrix_for_mkl(
    const Tensor& tensor,
    bool row_major) {
  if (is_blas_compatible_row_major_order(tensor) && row_major) {
    return at::native::expect_resolved_conj(tensor);
  } else {
    if (row_major) {
      return c10::MaybeOwned<Tensor>::owned(
          tensor.clone(at::MemoryFormat::Contiguous));
    } else {
      return c10::MaybeOwned<Tensor>::owned(cloneBatchedColumnMajor(tensor));
    }
  }
}
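
// For illustration: a contiguous row-major m x n matrix has strides {n, 1},
// while a column-major one has strides {1, m}. A contiguous 3 x 4 tensor
// (strides {4, 1}) is therefore returned as-is when row_major is true, and
// cloneBatchedColumnMajor yields a copy with strides {1, 3} otherwise.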

c10::MaybeOwned<Tensor> inline prepare_dense_vector_for_mkl(
    const Tensor& tensor) {
  if (tensor.is_non_overlapping_and_dense()) {
    return c10::MaybeOwned<Tensor>::borrowed(tensor);
  } else {
    return c10::MaybeOwned<Tensor>::owned(
        tensor.clone(at::MemoryFormat::Contiguous));
  }
}

void inline indices_to_mkl_compatible_inplace(const Tensor& input) {
#ifdef MKL_ILP64
  // ILP64 is the 64-bit integer API version of MKL,
  // so the indices tensors must have ScalarType::Long
  static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
      ->set_member_tensors(
          input.crow_indices().to(kLong),
          input.col_indices().to(kLong),
          input.values(),
          input.sizes());
#else
  // LP64 is the 32-bit integer API version of MKL,
  // so the indices tensors must have ScalarType::Int
  static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
      ->set_member_tensors(
          input.crow_indices().to(kInt),
          input.col_indices().to(kInt),
          input.values(),
          input.sizes());
#endif
}
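
// In practice this means the indices are widened or narrowed to match
// MKL_INT, which is a 64-bit integer in ILP64 builds of MKL and a 32-bit
// integer in LP64 builds; e.g. a CSR tensor with int32 crow/col indices is
// converted to int64 indices before being handed to an ILP64 build.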

void inline col_indices_and_values_resize_(const Tensor& input, int64_t nnz) {
  static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
      ->set_member_tensors(
          input.crow_indices(),
          input.col_indices().resize_({nnz}),
          input.values().resize_({nnz}),
          input.sizes());
}

/*
  Resizes `input` tensor and fills it with the data from MKL.
*/
template <typename scalar_t>
void mkl_result_copy_(const Tensor& input, sparse_matrix_t mkl_desc) {
  sparse_index_base_t indexing = SPARSE_INDEX_BASE_ZERO;
  MKL_INT rows, cols;
  MKL_INT *rows_start = nullptr, *rows_end = nullptr, *columns = nullptr;
  scalar_t* values = nullptr;
  at::mkl::sparse::export_csr(
      mkl_desc,
      &indexing,
      &rows,
      &cols,
      &rows_start,
      &rows_end,
      &columns,
      &values);

  // Resize input using nnz information from MKL
  MKL_INT nnz = rows_end[rows - 1];
  col_indices_and_values_resize_(input, nnz);

  auto crow_indices = input.crow_indices();
  auto col_indices = input.col_indices();
  auto input_values = input.values();

  // NB: When nnz is zero it is possible that input_values.data_ptr<scalar_t>
  // is a nullptr, if input was created via empty. As such we need to check
  // that nnz is not zero to avoid passing a nullptr to std::memcpy. We apply
  // the same precaution to crow_indices.data_ptr<MKL_INT>.
  //
  // Otherwise ASAN will complain.

  if (nnz > 0) {
    // MKL Sparse Inspector-Executor doesn't have a way to provide external
    // buffers, so we have to copy the memory allocated by MKL.
    std::memcpy(
        input_values.mutable_data_ptr<scalar_t>(), values, nnz * sizeof(scalar_t));
    std::memcpy(
        col_indices.mutable_data_ptr<MKL_INT>(), columns, nnz * sizeof(MKL_INT));
  }
  if (rows > 0) {
    std::memcpy(
        crow_indices.mutable_data_ptr<MKL_INT>(), rows_start, rows * sizeof(MKL_INT));
  }
  crow_indices.mutable_data_ptr<MKL_INT>()[rows] = nnz;
}
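
// Worked example of the copy above: for a 2 x 3 CSR matrix
//   [[0, 5, 0],
//    [7, 0, 8]]
// MKL exports rows_start = {0, 1}, rows_end = {1, 3}, columns = {1, 0, 2}
// and values = {5, 7, 8}, so nnz = rows_end[rows - 1] = 3. PyTorch's
// crow_indices has rows + 1 entries, {0, 1, 3}: the first `rows` entries are
// copied from rows_start and the final entry is set to nnz by hand.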
#endif

/*
  Computes a sparse matrix-dense matrix product defined as
  C <- alpha*(A*B) + beta*C

  Args:
  * `A` - Sparse Tensor storing m x k matrix.
  * `B` - Dense Tensor storing k x n matrix.
  * `C` - [in] Dense Tensor storing matrix of size m x n.
          [out] result of the operation.
*/
void addmm_dense_result(
    const Tensor& A,
    const Tensor& B,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& C) {
#if !AT_USE_MKL_SPARSE()
  TORCH_CHECK(
      false,
      "Calling addmm on a sparse CPU tensor requires Linux platform. ",
      "Please use PyTorch built with MKL on Linux.");
#else
  c10::MaybeOwned<Tensor> C_ = prepare_dense_matrix_for_mkl(C);
  IntArrayRef C_strides = C_->strides();
  auto ndim = C_->dim();
  bool is_C_row_major = (C_strides[ndim - 1] == 1);

  // MKL requires the same memory layout for both dense matrices
  c10::MaybeOwned<Tensor> B_ = prepare_dense_matrix_for_mkl(B, is_C_row_major);
  IntArrayRef B_strides = B_->strides();
  bool is_B_row_major = (B_strides[ndim - 1] == 1);

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!(is_C_row_major ^ is_B_row_major));

  auto order =
      is_C_row_major ? SPARSE_LAYOUT_ROW_MAJOR : SPARSE_LAYOUT_COLUMN_MAJOR;
  auto ldc = is_C_row_major ? C_strides[ndim - 2] : C_strides[ndim - 1];
  auto ldb = is_B_row_major ? B_strides[ndim - 2] : B_strides[ndim - 1];
  auto columns_C = mkl_int_cast(C.size(-1), "columns_C");

  matrix_descr descrA;
  descrA.type = SPARSE_MATRIX_TYPE_GENERAL;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      C.scalar_type(), "addmm_out_sparse_csr_impl_mkl", [&] {
        auto beta_ = beta.to<scalar_t>();
        auto alpha_ = alpha.to<scalar_t>();

        auto mkl_sparse_mat =
            at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(A);
        at::mkl::sparse::mm<scalar_t>(
            SPARSE_OPERATION_NON_TRANSPOSE,
            alpha_,
            mkl_sparse_mat.descriptor(),
            descrA,
            order,
            B_->data_ptr<scalar_t>(),
            columns_C,
            ldb,
            beta_,
            C_->data_ptr<scalar_t>(),
            ldc);
      });

  if (!C.is_same(*C_)) {
    C.copy_(*C_);
  }
#endif
}
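
// A note on the leading dimensions above: ldb and ldc are the strides between
// consecutive rows (row-major) or columns (column-major) of B and C. For a
// contiguous row-major k x n matrix B, ldb is simply n; reading it from the
// strides keeps the call correct for any BLAS-compatible layout that
// prepare_dense_matrix_for_mkl lets through unchanged.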

/*
  Computes a sparse matrix-sparse matrix product with dense result defined as
  C <- alpha*(A*B) + beta*C

  Args:
  * `A` - Sparse Tensor storing m x k matrix.
  * `B` - Sparse Tensor storing k x n matrix.
  * `C` - [in] Dense Tensor storing matrix of size m x n.
          [out] result of the operation.
*/
void addmm_sparse_input_dense_result(
    const Tensor& A,
    const Tensor& B,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& C) {
#if !AT_USE_MKL_SPARSE()
  TORCH_CHECK(
      false,
      "Calling addmm on a sparse CPU tensor requires Linux platform. ",
      "Please use PyTorch built with MKL on Linux.");
#else
  // The MKL function computes C <- A*B, so we need a temporary matrix
  // to store the product and then accumulate it into C
  auto C_ = at::empty(C.sizes(), C.options());
  auto order = SPARSE_LAYOUT_ROW_MAJOR;
  auto ldc = C_.stride(-2);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      C.scalar_type(), "addmm_sparse_input_dense_result", [&] {
        auto mkl_A = at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(A);
        auto mkl_B = at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(B);
        at::mkl::sparse::spmmd<scalar_t>(
            SPARSE_OPERATION_NON_TRANSPOSE,
            mkl_A.descriptor(),
            mkl_B.descriptor(),
            order,
            C_.data_ptr<scalar_t>(),
            ldc);
      });

  // If beta is zero, NaN and Inf should not be propagated to the result
  if (beta.toComplexDouble() == 0.) {
    C.zero_();
  } else {
    C.mul_(beta);
  }
  C.add_(C_, alpha);
#endif
}
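
// Sketch of the overall computation above: mkl_sparse_?_spmmd writes the
// dense product A*B into the scratch buffer C_, and the epilogue forms
// beta*C + alpha*(A*B) in PyTorch. Zeroing C when beta == 0 (rather than
// multiplying by 0) ensures that NaN/Inf values already present in C do not
// leak into the result, matching the documented addmm semantics.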

/*
  Computes a sparse matrix-sparse matrix product defined as
  C <- alpha*(A*B) + beta*C

  Args:
  * `mat1` - Sparse CSR Tensor storing m x k matrix A.
  * `mat2` - Sparse CSR Tensor storing k x n matrix B.
  * `result` - [in] Sparse CSR Tensor storing matrix C of size m x n.
               [out] result of the operation.
*/
void addmm_sparse_result(
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
#if !AT_USE_MKL_SPARSE()
  TORCH_CHECK(
      false,
      "Calling addmm on a sparse CPU tensor requires Linux platform. ",
      "Please use PyTorch built with MKL on Linux.");
#else
  // Compute beta*result because MKL doesn't do it
  // If beta is zero, NaN and Inf should not be propagated to the result
  if (beta.toComplexDouble() == 0.) {
    result.values().zero_();
  } else {
    result.values().mul_(beta);
  }

  // MKL doesn't work with empty matrices
  if (mat1._nnz() == 0 || mat2._nnz() == 0) {
    return;
  }

  // MKL doesn't have an interface to compute alpha*(A*B) + beta*C at once
  Tensor mat1_mat2 = at::zeros(result.sizes(), result.options());
  indices_to_mkl_compatible_inplace(mat1_mat2);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      result.scalar_type(), "addmm_out_sparse_csr_impl_mkl_sparse", [&] {
        auto mkl_sparse_mat1 =
            at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(mat1);
        auto mkl_sparse_mat2 =
            at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(mat2);
        auto mkl_result = at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>();
        auto result_desc = mkl_result.descriptor();

        TORCH_MKLSPARSE_CHECK(mkl_sparse_spmm(
            SPARSE_OPERATION_NON_TRANSPOSE,
            mkl_sparse_mat1.descriptor(),
            mkl_sparse_mat2.descriptor(),
            &result_desc));

        // Copy the data from MKL; otherwise the computed result would be
        // destroyed together with `mkl_result`
        mkl_result_copy_<scalar_t>(mat1_mat2, result_desc);
      });

  result.add_(mat1_mat2, alpha);
#endif
}
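
// End-to-end sketch for the CSR @ CSR -> CSR case: mkl_sparse_spmm produces
// a new MKL-owned CSR matrix holding mat1*mat2, mkl_result_copy_ transfers it
// into the temporary `mat1_mat2` tensor, and result <- beta*result +
// alpha*mat1_mat2 is assembled with PyTorch ops, since the MKL
// inspector-executor API has no fused alpha/beta variant of spmm.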

} // anonymous namespace

/*
  Computes a matrix-matrix product defined as
  C <- alpha*(A*B) + beta*C

  Args:
  * `mat1` - Tensor storing m x k matrix A.
  * `mat2` - Tensor storing k x n matrix B.
  * `result` - [in] Tensor storing matrix C of size m x n.
               [out] result of the operation.
*/
void addmm_out_sparse_csr(
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      mat1.dim() == 2 && mat2.dim() == 2 && result.dim() == 2);
  TORCH_INTERNAL_ASSERT(
      !((mat1.layout() == kStrided) && (mat2.layout() == kStrided) &&
        (result.layout() == kStrided)),
      "Expected at least one sparse input");

  // Layout checks are nested mat1, mat2, result.
  // Conditions are ordered strided, csr, csc, bsr, bsc.
  // Valid combinations terminate in a return.
  // Invalid combinations are omitted and will fall through to the
  // TORCH_CHECK below, which generates an informative error message.
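  //
  // For the dense @ sparse combinations below, the call is rewritten via the
  // identity (A*B)^T = B^T * A^T: e.g. strided mat1 @ CSR mat2 with a strided
  // result is computed as mat2^T @ mat1^T written into result^T, so that the
  // sparse operand is always the first argument of the MKL kernel.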
  if (mat1.layout() == kStrided) {
    if (mat2.layout() == kSparseCsr) {
      if (result.layout() == kStrided) {
        // TODO: Add native CSC support via cuSPARSE if supported.
        return addmm_dense_result(
            mat2.transpose(0, 1).to_sparse_csr(),
            mat1.transpose(0, 1),
            beta,
            alpha,
            result.transpose(0, 1));
      }
    }
    if (mat2.layout() == kSparseCsc) {
      if (result.layout() == kStrided) {
        return addmm_dense_result(
            mat2.transpose(-2, -1),
            mat1.transpose(-2, -1),
            beta,
            alpha,
            result.transpose(-2, -1));
      }
    }
    if (mat2.layout() == kSparseBsc) {
      if (result.layout() == kStrided) {
        return addmm_dense_result(
            mat2.transpose(-2, -1),
            mat1.transpose(-2, -1),
            beta,
            alpha,
            result.transpose(-2, -1));
      }
    }
  }
  if (mat1.layout() == kSparseCsr) {
    if (mat2.layout() == kStrided) {
      if (result.layout() == kStrided) {
        return addmm_dense_result(mat1, mat2, beta, alpha, result);
      }
    }
    if (mat2.layout() == kSparseCsr) {
      if (result.layout() == kStrided) {
        return addmm_sparse_input_dense_result(mat1, mat2, beta, alpha, result);
      }
      if (result.layout() == kSparseCsr) {
        return addmm_sparse_result(mat1, mat2, beta, alpha, result);
      }
    }
    if (mat2.layout() == kSparseCsc) {
      if (result.layout() == kStrided) {
        // TODO: CSR @ CSC kernel would be very fast due to format alignment
        return addmm_sparse_input_dense_result(
            mat1, mat2.to_sparse_csr(), beta, alpha, result);
      }
      if (result.layout() == kSparseCsr) {
        // TODO: CSR @ CSC kernel would be very fast due to format alignment
        return addmm_sparse_result(
            mat1, mat2.to_sparse_csr(), beta, alpha, result);
      }
    }
  }
  if (mat1.layout() == kSparseCsc) {
    if (mat2.layout() == kStrided) {
      if (result.layout() == kStrided) {
        // TODO: avoid csc->csr conversion with native csc support
        return addmm_dense_result(
            mat1.to_sparse_csr(), mat2, beta, alpha, result);
      }
    }
    if (mat2.layout() == kSparseCsr) {
      if (result.layout() == kSparseCsr) {
        // TODO: avoid csc->csr conversion with native csc support
        return addmm_sparse_result(
            mat1.to_sparse_csr(), mat2, beta, alpha, result);
      }
    }
    if (mat2.layout() == kSparseCsc) {
      if (result.layout() == kStrided) {
        return addmm_sparse_input_dense_result(
            mat2.transpose(-2, -1),
            mat1.transpose(-2, -1),
            beta,
            alpha,
            result.transpose(-2, -1));
      }
      if (result.layout() == kSparseCsr) {
        // TODO avoid csc->csr
        return addmm_sparse_result(
            mat1.to_sparse_csr(), mat2.to_sparse_csr(), beta, alpha, result);
      }
      if (result.layout() == kSparseCsc) {
        return addmm_sparse_result(
            mat2.transpose(-2, -1),
            mat1.transpose(-2, -1),
            beta,
            alpha,
            result.transpose(-2, -1));
      }
    }
  }
  if (mat1.layout() == kSparseBsr) {
    if (mat2.layout() == kStrided) {
      if (result.layout() == kStrided) {
        return addmm_dense_result(mat1, mat2, beta, alpha, result);
      }
    }
  }
  TORCH_CHECK(
      false,
      "addmm: computation on CPU is not implemented for ",
      result.layout(),
      " + ",
      mat1.layout(),
      " @ ",
      mat2.layout());
}
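
// At the Python level this dispatcher is reached (on CPU builds with MKL)
// from operations like torch.addmm(C, A, B) where at least one of A and B
// uses a sparse compressed layout; a minimal illustrative call, assuming the
// usual Python bindings, would be
//   torch.addmm(C, dense_A.to_sparse_csr(), dense_B)
// which takes the CSR @ strided -> strided branch above.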

/*
  Computes a sparse matrix-dense vector product defined as
  y <- alpha*op(A)*x + beta*y

  Args:
  * `mat` - Tensor storing sparse m x n matrix A.
  * `vec` - Tensor storing dense vector x of size n.
  * `result` - [in] Tensor storing dense vector y of size m.
               [out] result of the operation.
*/
void addmv_out_sparse_csr(
    const Tensor& mat,
    const Tensor& vec,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
#if !AT_USE_MKL_SPARSE()
  TORCH_CHECK(
      false,
      "Calling addmv on a sparse CPU tensor requires Linux platform. ",
      "Please use PyTorch built with MKL on Linux.");
#else
  c10::MaybeOwned<Tensor> result_ = prepare_dense_vector_for_mkl(result);
  c10::MaybeOwned<Tensor> vec_ = prepare_dense_vector_for_mkl(vec);

  sparse_operation_t opA = SPARSE_OPERATION_NON_TRANSPOSE;
  matrix_descr descrA;
  descrA.type = SPARSE_MATRIX_TYPE_GENERAL;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      result.scalar_type(), "addmv_out_sparse_csr_impl_mkl", [&] {
        auto beta_ = beta.to<scalar_t>();
        auto alpha_ = alpha.to<scalar_t>();

        auto mkl_sparse_mat =
            at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(mat);

        at::mkl::sparse::mv<scalar_t>(
            opA,
            alpha_,
            mkl_sparse_mat.descriptor(),
            descrA,
            vec_->data_ptr<scalar_t>(),
            beta_,
            result_->data_ptr<scalar_t>());
      });

  if (!result.is_same(*result_)) {
    result.copy_(*result_);
  }
#endif
}
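
// Note on the vector preparation above: prepare_dense_vector_for_mkl borrows
// the tensor when it is already dense and non-overlapping and clones it to
// contiguous memory otherwise, since mkl_sparse_?_mv takes plain pointers to
// x and y. When `result` had to be cloned, the final copy_ writes the
// computed values back into the caller-provided tensor.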

void add_out_sparse_csr(
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& alpha,
    const Tensor& result) {
#if !AT_USE_MKL_SPARSE()
  TORCH_CHECK(
      false,
      "Calling add on a sparse CPU tensor requires Linux platform. ",
      "Please use PyTorch built with MKL on Linux.");
#else

  // MKL doesn't work with empty matrices
  if (mat2._nnz() == 0) {
    col_indices_and_values_resize_(result, mat1._nnz());
    result.copy_(mat1);
    return;
  } else if (mat1._nnz() == 0) {
    col_indices_and_values_resize_(result, mat2._nnz());
    result.copy_(mat2);
    result.values().mul_(alpha);
    return;
  }

  // Modify `result` tensor in-place to swap its indices tensors with 32-bit
  // (or 64-bit) variants as required by MKL. The original indices are backed
  // up so that the promoted output index dtype can be restored afterwards.
  const auto output_indices_dtype = promoteTypes(mat1.crow_indices().scalar_type(), mat2.crow_indices().scalar_type());
  auto result_crow_indices_backup = result.crow_indices();
  auto result_col_indices_backup = result.col_indices();
  indices_to_mkl_compatible_inplace(result);
  sparse_operation_t opA = SPARSE_OPERATION_NON_TRANSPOSE;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      result.scalar_type(), "add_out_sparse_csr_impl_mkl", [&] {
        auto alpha_ = alpha.to<scalar_t>();

        auto mkl_mat1 = at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(mat1);
        auto mkl_mat2 = at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(mat2);
        auto mkl_result = at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>();

        // Note that the order of the mat1 and mat2 arguments is swapped
        // because MKL computes alpha*mat1 + mat2, while PyTorch needs
        // mat1 + alpha*mat2
        auto result_desc = mkl_result.descriptor();
        at::mkl::sparse::add<scalar_t>(
            opA,
            mkl_mat2.descriptor(),
            alpha_,
            mkl_mat1.descriptor(),
            &result_desc);

        // Now copy the data from `result_desc` to `result`
        mkl_result_copy_<scalar_t>(result, result_desc);
      });

  if (output_indices_dtype == at::kLong) {
    const auto res_nnz = result._nnz();
    static_cast<SparseCsrTensorImpl*>(result.unsafeGetTensorImpl())->set_member_tensors(
        result_crow_indices_backup.copy_(result.crow_indices()),
        result_col_indices_backup.resize_({res_nnz}).copy_(result.col_indices()),
        result.values(),
        result.sizes());
  }
#endif
}

void triangular_solve_out_sparse_csr(
    const Tensor& A_,
    const Tensor& B,
    const Tensor& X,
    bool upper,
    bool transpose,
    bool unitriangular) {
#if !AT_USE_MKL_SPARSE()
  TORCH_CHECK(
      false,
      "Calling triangular_solve on a sparse CPU tensor requires Linux platform. ",
      "Please use PyTorch built with MKL on Linux.");
#else
  if (B.numel() == 0 || X.numel() == 0 || A_._nnz() == 0) {
    // If A has no nnz, then A is singular and we can't solve.
    X.fill_(NAN);
    return;
  }

  const auto materialize_diagonal_indices = [](const Tensor& t) -> Tensor {
    const auto n = t.size(-1);
    const auto compressed_indices = std::get<0>(at::sparse_csr::getCompressedPlainIndices(t));
    const auto diag_indices = at::arange(n, compressed_indices.options()).unsqueeze(0).expand({2, n});
    const auto diag_values = at::zeros({1}, t.values().options()).expand({n});

    const auto t_coo = t.to_sparse();
    const auto expanded_indices = at::cat({t_coo._indices(), diag_indices}, /*dim=*/-1);
    const auto expanded_values = at::cat({t_coo._values(), diag_values}, /*dim=*/0);

    const auto t_expanded_coo = at::sparse_coo_tensor(expanded_indices, expanded_values, t_coo.sizes(), t_coo.options());
    return t_expanded_coo.to_sparse(t.layout());
  };
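
  // For example, a CSR matrix [[1, 0], [2, 0]] stores no entry at position
  // (1, 1); its diagonal is "unmaterialized" there. The lambda above converts
  // to COO, appends explicit zero entries for every (i, i), and converts back
  // so that all diagonal positions are present in the sparse structure
  // without changing the matrix values.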

  // MKL has a bug for inputs with unmaterialized diagonal indices.
  // See https://github.com/pytorch/pytorch/issues/88890 and
  // the comments within.
  const auto A = unitriangular ? materialize_diagonal_indices(A_) : A_;

  c10::MaybeOwned<Tensor> X_ = prepare_dense_matrix_for_mkl(X);
  IntArrayRef X_strides = X_->strides();
  auto ndim = X_->dim();
  bool is_X_row_major = (ndim > 1) ? (X_strides[ndim - 1] == 1) : true;

  // MKL requires the same memory layout for both dense matrices
  c10::MaybeOwned<Tensor> B_ = prepare_dense_matrix_for_mkl(B, is_X_row_major);

  sparse_operation_t opA = transpose ? SPARSE_OPERATION_TRANSPOSE : SPARSE_OPERATION_NON_TRANSPOSE;
  matrix_descr descrA;
  descrA.type = SPARSE_MATRIX_TYPE_TRIANGULAR;
  descrA.mode = upper ? SPARSE_FILL_MODE_UPPER : SPARSE_FILL_MODE_LOWER;
  descrA.diag = unitriangular ? SPARSE_DIAG_UNIT : SPARSE_DIAG_NON_UNIT;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      X.scalar_type(), "triangular_solve_out_sparse_csr_impl_mkl", [&] {
        auto mkl_sparse_mat =
            at::mkl::sparse::MklSparseCsrDescriptor<scalar_t>(A);
        scalar_t alpha = 1;

        if (B.size(-1) == 1) {
          sparse_status_t status = at::mkl::sparse::trsv<scalar_t>(
              opA,
              alpha,
              mkl_sparse_mat.descriptor(),
              descrA,
              B_->data_ptr<scalar_t>(),
              X_->data_ptr<scalar_t>());
          // Emulate the behavior of older MKL versions, which set all
          // elements of the output array to -NaN for invalid input matrices.
          if (status == SPARSE_STATUS_INVALID_VALUE) {
            X_->fill_(-std::numeric_limits<scalar_t>::quiet_NaN());
          }
        } else {
          IntArrayRef B_strides = B_->strides();
          bool is_B_row_major = (B_strides[ndim - 1] == 1);
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!(is_X_row_major ^ is_B_row_major));

          auto order = is_X_row_major ? SPARSE_LAYOUT_ROW_MAJOR : SPARSE_LAYOUT_COLUMN_MAJOR;
          auto nrhs = mkl_int_cast(B.size(-1), "nrhs");
          auto ldx = is_X_row_major ? X_strides[ndim - 2] : X_strides[ndim - 1];
          auto ldb = is_B_row_major ? B_strides[ndim - 2] : B_strides[ndim - 1];
          sparse_status_t status = at::mkl::sparse::trsm<scalar_t>(
              opA,
              alpha,
              mkl_sparse_mat.descriptor(),
              descrA,
              order,
              B_->data_ptr<scalar_t>(),
              nrhs,
              ldb,
              X_->data_ptr<scalar_t>(),
              ldx);
          // Emulate the behavior of older MKL versions, which set all
          // elements of the output array to -NaN for invalid input matrices.
          if (status == SPARSE_STATUS_INVALID_VALUE) {
            X_->fill_(-std::numeric_limits<scalar_t>::quiet_NaN());
          }
        }
      });

  if (!X.is_same(*X_)) {
    X.copy_(*X_);
  }
#endif
}

} // namespace mkl
} // namespace impl
} // namespace sparse
} // namespace native
} // namespace at