1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/mkl/SparseCsrLinearAlgebra.h>
3 #include <ATen/native/SparseTensorUtils.h>
4
// Don't compile with MKL on macOS, since linking the sparse MKL routines
// there needs some build fixes.
7 // Macros source:
8 // https://web.archive.org/web/20191012035921/http://nadeausoftware.com/articles/2012/01/c_c_tip_how_use_compiler_predefined_macros_detect_operating_system
9 #if !AT_MKL_ENABLED() || defined(__APPLE__) || \
10 defined(__MACH__)
11
12 namespace at {
13 namespace sparse_csr {
_sparse_mm_mkl_(Tensor & self,const SparseCsrTensor & sparse_,const Tensor & dense,const Tensor & t,const Scalar & alpha,const Scalar & beta)14 Tensor& _sparse_mm_mkl_(
15 Tensor& self,
16 const SparseCsrTensor& sparse_,
17 const Tensor& dense,
18 const Tensor& t,
19 const Scalar& alpha,
20 const Scalar& beta) {
21 #if __APPLE__ || __MACH__
22 AT_ERROR("sparse_mm_mkl: MKL support is disabled on macos/iOS.");
23 #else
24 AT_ERROR("sparse_mm_mkl: ATen not compiled with MKL support");
25 #endif
26 return self; // for stopping compiler warnings.
27 }
} // namespace sparse_csr
29 } // namespace at
30
31 #else // AT_MKL_ENABLED
32
33 #include <ATen/mkl/Descriptors.h>
34 #include <ATen/mkl/Exceptions.h>
35 #include <ATen/mkl/Limits.h>
36 #include <mkl.h>
37 #include <mkl_spblas.h>
38
39 #include <ATen/Dispatch.h>
40 #include <ATen/ExpandUtils.h>
41 #include <ATen/SparseCsrTensorImpl.h>
42
43 namespace at {
44 namespace sparse_csr {
45
46 #ifdef MKL_ILP64
47 static constexpr ScalarType TORCH_INT_TYPE = at::kLong;
48 #else
49 static constexpr ScalarType TORCH_INT_TYPE = at::kInt;
50 #endif
51
52 class SparseCsrMKLInterface {
53 private:
54 sparse_matrix_t A{nullptr};
55 matrix_descr desc;
56
57 public:
SparseCsrMKLInterface(MKL_INT * col_indices,MKL_INT * crow_indices,double * values,MKL_INT nrows,MKL_INT ncols)58 SparseCsrMKLInterface(
59 MKL_INT* col_indices,
60 MKL_INT* crow_indices,
61 double* values,
62 MKL_INT nrows,
63 MKL_INT ncols) {
64 desc.type = SPARSE_MATRIX_TYPE_GENERAL;
65 int retval = mkl_sparse_d_create_csr(
66 &A,
67 SPARSE_INDEX_BASE_ZERO,
68 nrows,
69 ncols,
70 crow_indices,
71 crow_indices + 1,
72 col_indices,
73 values);
74 TORCH_CHECK(
75 retval == 0,
76 "mkl_sparse_d_create_csr failed with error code: ",
77 retval);
78 }
79
SparseCsrMKLInterface(MKL_INT * col_indices,MKL_INT * crow_indices,float * values,MKL_INT nrows,MKL_INT ncols)80 SparseCsrMKLInterface(
81 MKL_INT* col_indices,
82 MKL_INT* crow_indices,
83 float* values,
84 MKL_INT nrows,
85 MKL_INT ncols) {
86 desc.type = SPARSE_MATRIX_TYPE_GENERAL;
87 int retval = mkl_sparse_s_create_csr(
88 &A,
89 SPARSE_INDEX_BASE_ZERO,
90 nrows,
91 ncols,
92 crow_indices,
93 crow_indices + 1,
94 col_indices,
95 values);
96 TORCH_CHECK(
97 retval == 0,
98 "mkl_sparse_s_create_csr failed with error code: ",
99 retval);
100 }
101
102 // res(nrows, dense_ncols) = (sparse(nrows * ncols) @ dense(ncols x dense_ncols))
sparse_mm(float * res,float * dense,float alpha,float beta,MKL_INT nrows,MKL_INT ncols,MKL_INT dense_ncols)103 inline void sparse_mm(
104 float* res,
105 float* dense,
106 float alpha,
107 float beta,
108 MKL_INT nrows,
109 MKL_INT ncols,
110 MKL_INT dense_ncols) {
111 int stat;
112 if (dense_ncols == 1) {
113 stat = mkl_sparse_s_mv(
114 SPARSE_OPERATION_NON_TRANSPOSE,
115 alpha,
116 A,
117 desc,
118 dense,
119 beta,
120 res);
121 TORCH_CHECK(stat == 0, "mkl_sparse_s_mv failed with error code: ", stat);
122 } else {
123 stat = mkl_sparse_s_mm(
124 SPARSE_OPERATION_NON_TRANSPOSE,
125 alpha,
126 A,
127 desc,
128 SPARSE_LAYOUT_ROW_MAJOR,
129 dense,
130 nrows,
131 ncols,
132 beta,
133 res,
134 dense_ncols);
135 TORCH_CHECK(stat == 0, "mkl_sparse_s_mm failed with error code: ", stat);
136 }
137 }
138
sparse_mm(double * res,double * dense,double alpha,double beta,MKL_INT nrows,MKL_INT ncols,MKL_INT dense_ncols)139 inline void sparse_mm(
140 double* res,
141 double* dense,
142 double alpha,
143 double beta,
144 MKL_INT nrows,
145 MKL_INT ncols,
146 MKL_INT dense_ncols) {
147 int stat;
148 if (dense_ncols == 1) {
149 stat = mkl_sparse_d_mv(
150 SPARSE_OPERATION_NON_TRANSPOSE,
151 alpha,
152 A,
153 desc,
154 dense,
155 beta,
156 res);
157 TORCH_CHECK(stat == 0, "mkl_sparse_d_mv failed with error code: ", stat);
158 }
159 else {
160 stat = mkl_sparse_d_mm(
161 SPARSE_OPERATION_NON_TRANSPOSE,
162 alpha,
163 A,
164 desc,
165 SPARSE_LAYOUT_ROW_MAJOR,
166 dense,
167 nrows,
168 ncols,
169 beta,
170 res,
171 dense_ncols);
172 TORCH_CHECK(stat == 0, "mkl_sparse_d_mm failed with error code: ", stat);
173 }
174 }
175
~SparseCsrMKLInterface()176 ~SparseCsrMKLInterface() {
177 mkl_sparse_destroy(A);
178 }
179 };
180
181 template <typename scalar_t>
sparse_mm_mkl_template(Tensor & res,const Tensor & col_indices,const Tensor & crow_indices,const Tensor & values,const Tensor & dense,const Tensor & t,const Scalar & alpha,const Scalar & beta,IntArrayRef size,IntArrayRef dense_size)182 static inline void sparse_mm_mkl_template(
183 Tensor& res,
184 const Tensor& col_indices,
185 const Tensor& crow_indices,
186 const Tensor& values,
187 const Tensor& dense,
188 const Tensor& t,
189 const Scalar& alpha,
190 const Scalar& beta,
191 IntArrayRef size,
192 IntArrayRef dense_size) {
193 SparseCsrMKLInterface mkl_impl(
194 col_indices.data_ptr<MKL_INT>(),
195 crow_indices.data_ptr<MKL_INT>(),
196 values.data_ptr<scalar_t>(),
197 size[0],
198 size[1]);
199 mkl_impl.sparse_mm(
200 res.data_ptr<scalar_t>(),
201 dense.data_ptr<scalar_t>(),
202 alpha.to<scalar_t>(),
203 beta.to<scalar_t>(),
204 size[0],
205 size[1],
206 dense_size[1]);
207 }
208
/// Reports whether MKL_INT is a 32-bit integer. That is the case for the
/// default LP64 interface; defining MKL_ILP64 selects 64-bit indices instead.
static bool inline constexpr is_mkl_int32_index() {
#ifndef MKL_ILP64
  return true;
#else
  return false;
#endif
}
216
_sparse_mm_mkl_(Tensor & self,const SparseCsrTensor & sparse_,const Tensor & dense,const Tensor & t,const Scalar & alpha,const Scalar & beta)217 Tensor& _sparse_mm_mkl_(
218 Tensor& self,
219 const SparseCsrTensor& sparse_,
220 const Tensor& dense,
221 const Tensor& t,
222 const Scalar& alpha,
223 const Scalar& beta) {
224 if (is_mkl_int32_index()) {
225 if (sparse_.crow_indices().scalar_type() != kInt) {
226 TORCH_WARN(
227 "Pytorch is compiled with MKL LP64 and will convert crow_indices to int32.");
228 }
229 if (sparse_.col_indices().scalar_type() != kInt) {
230 TORCH_WARN(
231 "Pytorch is compiled with MKL LP64 and will convert col_indices to int32.");
232 }
233 } else { // This is for future proofing if we ever change to using MKL ILP64.
234 if (sparse_.crow_indices().scalar_type() != kLong) {
235 TORCH_WARN(
236 "Pytorch is compiled with MKL ILP64 and will convert crow_indices dtype to int64.");
237 }
238 if (sparse_.col_indices().scalar_type() != kLong) {
239 TORCH_WARN(
240 "Pytorch is compiled with MKL ILP64 and will convert col_indices dtype to int64.");
241 }
242 }
243 AT_DISPATCH_FLOATING_TYPES(
244 dense.scalar_type(), "addmm_sparse_csr_dense", [&] {
245 sparse_mm_mkl_template<scalar_t>(
246 self,
247 sparse_.col_indices().to(TORCH_INT_TYPE),
248 sparse_.crow_indices().to(TORCH_INT_TYPE),
249 sparse_.values(),
250 dense,
251 t,
252 alpha,
253 beta,
254 sparse_.sizes(),
255 dense.sizes());
256 });
257 return self;
258 }
259
} // namespace sparse_csr
261 } // namespace at
262
263 #endif // AT_MKL_ENABLED
264