#pragma once

/*
  Provides templated descriptor wrappers of MKL Sparse BLAS sparse matrices:

    MklSparseCsrDescriptor<scalar_t>(sparse_csr_tensor)

  where scalar_t is double, float, c10::complex<double> or c10::complex<float>.
  The descriptors are available in the at::mkl::sparse namespace.
*/
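
/*
  Usage sketch (illustrative only, not part of this file; `t` is presumed to
  be a 2D SparseCsr tensor with nnz > 0):

    at::mkl::sparse::MklSparseCsrDescriptor<double> desc(t);
    sparse_matrix_t handle = desc.descriptor();
    // `handle` can be passed to MKL Sparse BLAS routines such as
    // mkl_sparse_d_mv; it is destroyed when `desc` goes out of scope.
*/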

#include <ATen/Tensor.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Utils.h>

#include <c10/core/ScalarType.h>
#include <c10/util/MaybeOwned.h>

#include <mkl_spblas.h>

#include <memory>

namespace at::mkl::sparse {

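// Deleter for std::unique_ptr that releases an MKL handle through the
// supplied destructor (e.g. mkl_sparse_destroy) and checks the returned
// sparse_status_t; null handles are ignored.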
template <typename T, sparse_status_t (*destructor)(T*)>
struct MklSparseDescriptorDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      TORCH_MKLSPARSE_CHECK(destructor(x));
    }
  }
};

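// RAII owner of a raw MKL descriptor: the handle is freed by
// MklSparseDescriptorDeleter when the wrapper goes out of scope.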
template <typename T, sparse_status_t (*destructor)(T*)>
class MklSparseDescriptor {
 public:
  T* descriptor() const {
    return descriptor_.get();
  }
  T* descriptor() {
    return descriptor_.get();
  }

 protected:
  std::unique_ptr<T, MklSparseDescriptorDeleter<T, destructor>> descriptor_;
};

namespace {

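// Returns the indices converted to the integer width the MKL build expects,
// borrowing the input tensor when no conversion is needed.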
c10::MaybeOwned<Tensor> inline prepare_indices_for_mkl(const Tensor& indices) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      isIntegralType(indices.scalar_type(), /*includeBool=*/false));
#ifdef MKL_ILP64
  // ILP64 is the 64-bit integer API version of MKL:
  // the indices tensor must have ScalarType::Long type
  if (indices.scalar_type() == ScalarType::Long) {
    return c10::MaybeOwned<Tensor>::borrowed(indices);
  } else {
    return c10::MaybeOwned<Tensor>::owned(indices.to(ScalarType::Long));
  }
#else
  // LP64 is the 32-bit integer API version of MKL:
  // the indices tensor must have ScalarType::Int type
  if (indices.scalar_type() == ScalarType::Int) {
    return c10::MaybeOwned<Tensor>::borrowed(indices);
  } else {
    return c10::MaybeOwned<Tensor>::owned(indices.to(ScalarType::Int));
  }
#endif
}

} // anonymous namespace

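// Wraps a 2D SparseCsr or SparseBsr tensor in an MKL sparse_matrix_t handle.
// The (possibly converted) index tensors and the contiguous values tensor are
// kept alive as members so the raw pointers handed to MKL remain valid for
// the descriptor's lifetime.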
template <typename scalar_t>
class MklSparseCsrDescriptor
    : public MklSparseDescriptor<sparse_matrix, &mkl_sparse_destroy> {
 public:
  MklSparseCsrDescriptor(const Tensor& input) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        (input.layout() == kSparseCsr || input.layout() == kSparseBsr));
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() == 2);

    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        input._nnz() > 0, "MKL doesn't work with empty CSR matrices");

    IntArrayRef input_sizes = input.sizes();
    auto rows = mkl_int_cast(input_sizes[0], "rows");
    auto cols = mkl_int_cast(input_sizes[1], "cols");

    auto crow_indices = input.crow_indices();
    auto col_indices = input.col_indices();
    auto values = input.values();

    crow_indices_ = prepare_indices_for_mkl(crow_indices);
    col_indices_ = prepare_indices_for_mkl(col_indices);
    values_ = values.expect_contiguous();

    auto values_ptr = values_->data_ptr<scalar_t>();
    auto crow_indices_ptr = crow_indices_->data_ptr<MKL_INT>();
    auto col_indices_ptr = col_indices_->data_ptr<MKL_INT>();

    sparse_matrix_t raw_descriptor;

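    // For BSR, MKL takes block-row/block-column counts plus the block size;
    // crow/col indices address blocks, and each block stores
    // block_size * block_size values in row-major order
    // (SPARSE_LAYOUT_ROW_MAJOR).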
    if (input.layout() == kSparseBsr) {
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          values.dim() == 3 && crow_indices.dim() == 1 &&
          col_indices.dim() == 1);
      TORCH_CHECK(
          values.size(-1) == values.size(-2),
          "MKL Sparse doesn't support matrices with non-square blocks.");
      auto block_size = mkl_int_cast(values.size(-1), "block_size");
      create_bsr<scalar_t>(
          &raw_descriptor,
          SPARSE_INDEX_BASE_ZERO,
          SPARSE_LAYOUT_ROW_MAJOR,
          rows / block_size,
          cols / block_size,
          block_size,
          crow_indices_ptr,
          crow_indices_ptr + 1,
          col_indices_ptr,
          values_ptr);
    } else {
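      // MKL's 4-array variant expects separate rows_start/rows_end pointers;
      // for standard CSR both are views into the single crow_indices array,
      // offset from each other by one element.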
      create_csr<scalar_t>(
          &raw_descriptor,
          SPARSE_INDEX_BASE_ZERO,
          rows,
          cols,
          crow_indices_ptr,
          crow_indices_ptr + 1,
          col_indices_ptr,
          values_ptr);
    }

    descriptor_.reset(raw_descriptor);
  }

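  // Constructs an empty wrapper that owns no MKL handle; the deleter is a
  // no-op on nullptr.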
  MklSparseCsrDescriptor() {
    sparse_matrix_t raw_descriptor = nullptr;
    descriptor_.reset(raw_descriptor);
  }

 private:
  c10::MaybeOwned<Tensor> crow_indices_;
  c10::MaybeOwned<Tensor> col_indices_;
  c10::MaybeOwned<Tensor> values_;
};

} // namespace at::mkl::sparse