1 #pragma once
2
3 /*
4 Provides templated descriptor wrappers of MKL Sparse BLAS sparse matrices:
5
6 MklSparseCsrDescriptor<scalar_t>(sparse_csr_tensor)
7
8 where scalar_t is double, float, c10::complex<double> or c10::complex<float>.
9 The descriptors are available in at::mkl::sparse namespace.
10 */
11
12 #include <ATen/Tensor.h>
13 #include <ATen/mkl/Exceptions.h>
14 #include <ATen/mkl/Utils.h>
15
16 #include <c10/core/ScalarType.h>
17 #include <c10/util/MaybeOwned.h>
18
19 #include <mkl_spblas.h>
20
21 namespace at::mkl::sparse {
22
23 template <typename T, sparse_status_t (*destructor)(T*)>
24 struct MklSparseDescriptorDeleter {
operatorMklSparseDescriptorDeleter25 void operator()(T* x) {
26 if (x != nullptr) {
27 TORCH_MKLSPARSE_CHECK(destructor(x));
28 }
29 }
30 };
31
32 template <typename T, sparse_status_t (*destructor)(T*)>
33 class MklSparseDescriptor {
34 public:
descriptor()35 T* descriptor() const {
36 return descriptor_.get();
37 }
descriptor()38 T* descriptor() {
39 return descriptor_.get();
40 }
41
42 protected:
43 std::unique_ptr<T, MklSparseDescriptorDeleter<T, destructor>> descriptor_;
44 };
45
46 namespace {
47
prepare_indices_for_mkl(const Tensor & indices)48 c10::MaybeOwned<Tensor> inline prepare_indices_for_mkl(const Tensor& indices) {
49 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
50 isIntegralType(indices.scalar_type(), /*includeBool=*/false));
51 #ifdef MKL_ILP64
52 // ILP64 is a 64-bit API version of MKL
53 // Indices tensor must have ScalarType::Long type
54 if (indices.scalar_type() == ScalarType::Long) {
55 return c10::MaybeOwned<Tensor>::borrowed(indices);
56 } else {
57 return c10::MaybeOwned<Tensor>::owned(indices.to(ScalarType::Long));
58 }
59 #else
60 // LP64 is a 32-bit API version of MKL
61 // Indices tensor must have ScalarType::Int type
62 if (indices.scalar_type() == ScalarType::Int) {
63 return c10::MaybeOwned<Tensor>::borrowed(indices);
64 } else {
65 return c10::MaybeOwned<Tensor>::owned(indices.to(ScalarType::Int));
66 }
67 #endif
68 }
69
70 } // anonymous namespace
71
72 template <typename scalar_t>
73 class MklSparseCsrDescriptor
74 : public MklSparseDescriptor<sparse_matrix, &mkl_sparse_destroy> {
75 public:
MklSparseCsrDescriptor(const Tensor & input)76 MklSparseCsrDescriptor(const Tensor& input) {
77 TORCH_INTERNAL_ASSERT_DEBUG_ONLY((input.layout() == kSparseCsr || input.layout() == kSparseBsr));
78 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() == 2);
79
80 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
81 input._nnz() > 0, "MKL doesn't work with empty CSR matrices");
82
83 IntArrayRef input_sizes = input.sizes();
84 auto rows = mkl_int_cast(input_sizes[0], "rows");
85 auto cols = mkl_int_cast(input_sizes[1], "cols");
86
87 auto crow_indices = input.crow_indices();
88 auto col_indices = input.col_indices();
89 auto values = input.values();
90
91 crow_indices_ = prepare_indices_for_mkl(crow_indices);
92 col_indices_ = prepare_indices_for_mkl(col_indices);
93 values_ = values.expect_contiguous();
94
95 auto values_ptr = values_->data_ptr<scalar_t>();
96 auto crow_indices_ptr = crow_indices_->data_ptr<MKL_INT>();
97 auto col_indices_ptr = col_indices_->data_ptr<MKL_INT>();
98
99 sparse_matrix_t raw_descriptor;
100
101 if (input.layout() == kSparseBsr) {
102 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
103 values.dim() == 3 && crow_indices.dim() == 1 &&
104 col_indices.dim() == 1);
105 TORCH_CHECK(
106 values.size(-1) == values.size(-2),
107 "MKL Sparse doesn't support matrices with non-square blocks.");
108 auto block_size = mkl_int_cast(values.size(-1), "block_size");
109 create_bsr<scalar_t>(
110 &raw_descriptor,
111 SPARSE_INDEX_BASE_ZERO,
112 SPARSE_LAYOUT_ROW_MAJOR,
113 rows / block_size,
114 cols / block_size,
115 block_size,
116 crow_indices_ptr,
117 crow_indices_ptr + 1,
118 col_indices_ptr,
119 values_ptr);
120 } else {
121 create_csr<scalar_t>(
122 &raw_descriptor,
123 SPARSE_INDEX_BASE_ZERO,
124 rows,
125 cols,
126 crow_indices_ptr,
127 crow_indices_ptr + 1,
128 col_indices_ptr,
129 values_ptr);
130 }
131
132 descriptor_.reset(raw_descriptor);
133 }
134
MklSparseCsrDescriptor()135 MklSparseCsrDescriptor() {
136 sparse_matrix_t raw_descriptor = nullptr;
137 descriptor_.reset(raw_descriptor);
138 }
139
140 private:
141 c10::MaybeOwned<Tensor> crow_indices_;
142 c10::MaybeOwned<Tensor> col_indices_;
143 c10::MaybeOwned<Tensor> values_;
144 };
145
146 } // namespace at::mkl::sparse
147