/*
Provides the implementations of MKL Sparse BLAS function templates.
*/
4 #define TORCH_ASSERT_NO_OPERATORS
5 #include <ATen/mkl/Exceptions.h>
6 #include <ATen/mkl/SparseBlas.h>
7
8 namespace at::mkl::sparse {
9
10 namespace {
11
12 template <typename scalar_t, typename MKL_Complex>
to_mkl_complex(c10::complex<scalar_t> scalar)13 MKL_Complex to_mkl_complex(c10::complex<scalar_t> scalar) {
14 MKL_Complex mkl_scalar;
15 mkl_scalar.real = scalar.real();
16 mkl_scalar.imag = scalar.imag();
17 return mkl_scalar;
18 }
19
20 } // namespace
21
22
23 template <>
create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES (float))24 void create_csr<float>(MKL_SPARSE_CREATE_CSR_ARGTYPES(float)) {
25 TORCH_MKLSPARSE_CHECK(mkl_sparse_s_create_csr(
26 A, indexing, rows, cols, rows_start, rows_end, col_indx, values));
27 }
28 template <>
create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES (double))29 void create_csr<double>(MKL_SPARSE_CREATE_CSR_ARGTYPES(double)) {
30 TORCH_MKLSPARSE_CHECK(mkl_sparse_d_create_csr(
31 A, indexing, rows, cols, rows_start, rows_end, col_indx, values));
32 }
33 template <>
create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES (c10::complex<float>))34 void create_csr<c10::complex<float>>(
35 MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex<float>)) {
36 TORCH_MKLSPARSE_CHECK(mkl_sparse_c_create_csr(
37 A,
38 indexing,
39 rows,
40 cols,
41 rows_start,
42 rows_end,
43 col_indx,
44 reinterpret_cast<MKL_Complex8*>(values)));
45 }
46 template <>
create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES (c10::complex<double>))47 void create_csr<c10::complex<double>>(
48 MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex<double>)) {
49 TORCH_MKLSPARSE_CHECK(mkl_sparse_z_create_csr(
50 A,
51 indexing,
52 rows,
53 cols,
54 rows_start,
55 rows_end,
56 col_indx,
57 reinterpret_cast<MKL_Complex16*>(values)));
58 }
59
60 template <>
create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES (float))61 void create_bsr<float>(MKL_SPARSE_CREATE_BSR_ARGTYPES(float)) {
62 TORCH_MKLSPARSE_CHECK(mkl_sparse_s_create_bsr(
63 A,
64 indexing,
65 block_layout,
66 rows,
67 cols,
68 block_size,
69 rows_start,
70 rows_end,
71 col_indx,
72 values));
73 }
74 template <>
create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES (double))75 void create_bsr<double>(MKL_SPARSE_CREATE_BSR_ARGTYPES(double)) {
76 TORCH_MKLSPARSE_CHECK(mkl_sparse_d_create_bsr(
77 A,
78 indexing,
79 block_layout,
80 rows,
81 cols,
82 block_size,
83 rows_start,
84 rows_end,
85 col_indx,
86 values));
87 }
88 template <>
create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES (c10::complex<float>))89 void create_bsr<c10::complex<float>>(
90 MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex<float>)) {
91 TORCH_MKLSPARSE_CHECK(mkl_sparse_c_create_bsr(
92 A,
93 indexing,
94 block_layout,
95 rows,
96 cols,
97 block_size,
98 rows_start,
99 rows_end,
100 col_indx,
101 reinterpret_cast<MKL_Complex8*>(values)));
102 }
103 template <>
create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES (c10::complex<double>))104 void create_bsr<c10::complex<double>>(
105 MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex<double>)) {
106 TORCH_MKLSPARSE_CHECK(mkl_sparse_z_create_bsr(
107 A,
108 indexing,
109 block_layout,
110 rows,
111 cols,
112 block_size,
113 rows_start,
114 rows_end,
115 col_indx,
116 reinterpret_cast<MKL_Complex16*>(values)));
117 }
118
119 template <>
mv(MKL_SPARSE_MV_ARGTYPES (float))120 void mv<float>(MKL_SPARSE_MV_ARGTYPES(float)) {
121 TORCH_MKLSPARSE_CHECK(
122 mkl_sparse_s_mv(operation, alpha, A, descr, x, beta, y));
123 }
124 template <>
mv(MKL_SPARSE_MV_ARGTYPES (double))125 void mv<double>(MKL_SPARSE_MV_ARGTYPES(double)) {
126 TORCH_MKLSPARSE_CHECK(
127 mkl_sparse_d_mv(operation, alpha, A, descr, x, beta, y));
128 }
129 template <>
mv(MKL_SPARSE_MV_ARGTYPES (c10::complex<float>))130 void mv<c10::complex<float>>(MKL_SPARSE_MV_ARGTYPES(c10::complex<float>)) {
131 TORCH_MKLSPARSE_CHECK(mkl_sparse_c_mv(
132 operation,
133 to_mkl_complex<float, MKL_Complex8>(alpha),
134 A,
135 descr,
136 reinterpret_cast<const MKL_Complex8*>(x),
137 to_mkl_complex<float, MKL_Complex8>(beta),
138 reinterpret_cast<MKL_Complex8*>(y)));
139 }
140 template <>
mv(MKL_SPARSE_MV_ARGTYPES (c10::complex<double>))141 void mv<c10::complex<double>>(MKL_SPARSE_MV_ARGTYPES(c10::complex<double>)) {
142 TORCH_MKLSPARSE_CHECK(mkl_sparse_z_mv(
143 operation,
144 to_mkl_complex<double, MKL_Complex16>(alpha),
145 A,
146 descr,
147 reinterpret_cast<const MKL_Complex16*>(x),
148 to_mkl_complex<double, MKL_Complex16>(beta),
149 reinterpret_cast<MKL_Complex16*>(y)));
150 }
151
152 template <>
add(MKL_SPARSE_ADD_ARGTYPES (float))153 void add<float>(MKL_SPARSE_ADD_ARGTYPES(float)) {
154 TORCH_MKLSPARSE_CHECK(mkl_sparse_s_add(operation, A, alpha, B, C));
155 }
156 template <>
add(MKL_SPARSE_ADD_ARGTYPES (double))157 void add<double>(MKL_SPARSE_ADD_ARGTYPES(double)) {
158 TORCH_MKLSPARSE_CHECK(mkl_sparse_d_add(operation, A, alpha, B, C));
159 }
160 template <>
add(MKL_SPARSE_ADD_ARGTYPES (c10::complex<float>))161 void add<c10::complex<float>>(MKL_SPARSE_ADD_ARGTYPES(c10::complex<float>)) {
162 TORCH_MKLSPARSE_CHECK(mkl_sparse_c_add(
163 operation, A, to_mkl_complex<float, MKL_Complex8>(alpha), B, C));
164 }
165 template <>
add(MKL_SPARSE_ADD_ARGTYPES (c10::complex<double>))166 void add<c10::complex<double>>(MKL_SPARSE_ADD_ARGTYPES(c10::complex<double>)) {
167 TORCH_MKLSPARSE_CHECK(mkl_sparse_z_add(
168 operation, A, to_mkl_complex<double, MKL_Complex16>(alpha), B, C));
169 }
170
171 template <>
export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES (float))172 void export_csr<float>(MKL_SPARSE_EXPORT_CSR_ARGTYPES(float)) {
173 TORCH_MKLSPARSE_CHECK(mkl_sparse_s_export_csr(
174 source, indexing, rows, cols, rows_start, rows_end, col_indx, values));
175 }
176 template <>
export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES (double))177 void export_csr<double>(MKL_SPARSE_EXPORT_CSR_ARGTYPES(double)) {
178 TORCH_MKLSPARSE_CHECK(mkl_sparse_d_export_csr(
179 source, indexing, rows, cols, rows_start, rows_end, col_indx, values));
180 }
181 template <>
export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES (c10::complex<float>))182 void export_csr<c10::complex<float>>(
183 MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex<float>)) {
184 TORCH_MKLSPARSE_CHECK(mkl_sparse_c_export_csr(
185 source,
186 indexing,
187 rows,
188 cols,
189 rows_start,
190 rows_end,
191 col_indx,
192 reinterpret_cast<MKL_Complex8**>(values)));
193 }
194 template <>
export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES (c10::complex<double>))195 void export_csr<c10::complex<double>>(
196 MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex<double>)) {
197 TORCH_MKLSPARSE_CHECK(mkl_sparse_z_export_csr(
198 source,
199 indexing,
200 rows,
201 cols,
202 rows_start,
203 rows_end,
204 col_indx,
205 reinterpret_cast<MKL_Complex16**>(values)));
206 }
207
208 template <>
mm(MKL_SPARSE_MM_ARGTYPES (float))209 void mm<float>(MKL_SPARSE_MM_ARGTYPES(float)) {
210 TORCH_MKLSPARSE_CHECK(mkl_sparse_s_mm(
211 operation, alpha, A, descr, layout, B, columns, ldb, beta, C, ldc));
212 }
213 template <>
mm(MKL_SPARSE_MM_ARGTYPES (double))214 void mm<double>(MKL_SPARSE_MM_ARGTYPES(double)) {
215 TORCH_MKLSPARSE_CHECK(mkl_sparse_d_mm(
216 operation, alpha, A, descr, layout, B, columns, ldb, beta, C, ldc));
217 }
218 template <>
mm(MKL_SPARSE_MM_ARGTYPES (c10::complex<float>))219 void mm<c10::complex<float>>(MKL_SPARSE_MM_ARGTYPES(c10::complex<float>)) {
220 TORCH_MKLSPARSE_CHECK(mkl_sparse_c_mm(
221 operation,
222 to_mkl_complex<float, MKL_Complex8>(alpha),
223 A,
224 descr,
225 layout,
226 reinterpret_cast<const MKL_Complex8*>(B),
227 columns,
228 ldb,
229 to_mkl_complex<float, MKL_Complex8>(beta),
230 reinterpret_cast<MKL_Complex8*>(C),
231 ldc));
232 }
233 template <>
mm(MKL_SPARSE_MM_ARGTYPES (c10::complex<double>))234 void mm<c10::complex<double>>(MKL_SPARSE_MM_ARGTYPES(c10::complex<double>)) {
235 TORCH_MKLSPARSE_CHECK(mkl_sparse_z_mm(
236 operation,
237 to_mkl_complex<double, MKL_Complex16>(alpha),
238 A,
239 descr,
240 layout,
241 reinterpret_cast<const MKL_Complex16*>(B),
242 columns,
243 ldb,
244 to_mkl_complex<double, MKL_Complex16>(beta),
245 reinterpret_cast<MKL_Complex16*>(C),
246 ldc));
247 }
248
249 template <>
spmmd(MKL_SPARSE_SPMMD_ARGTYPES (float))250 void spmmd<float>(MKL_SPARSE_SPMMD_ARGTYPES(float)) {
251 TORCH_MKLSPARSE_CHECK(mkl_sparse_s_spmmd(
252 operation, A, B, layout, C, ldc));
253 }
254 template <>
spmmd(MKL_SPARSE_SPMMD_ARGTYPES (double))255 void spmmd<double>(MKL_SPARSE_SPMMD_ARGTYPES(double)) {
256 TORCH_MKLSPARSE_CHECK(mkl_sparse_d_spmmd(
257 operation, A, B, layout, C, ldc));
258 }
259 template <>
spmmd(MKL_SPARSE_SPMMD_ARGTYPES (c10::complex<float>))260 void spmmd<c10::complex<float>>(MKL_SPARSE_SPMMD_ARGTYPES(c10::complex<float>)) {
261 TORCH_MKLSPARSE_CHECK(mkl_sparse_c_spmmd(
262 operation,
263 A,
264 B,
265 layout,
266 reinterpret_cast<MKL_Complex8*>(C),
267 ldc));
268 }
269 template <>
spmmd(MKL_SPARSE_SPMMD_ARGTYPES (c10::complex<double>))270 void spmmd<c10::complex<double>>(MKL_SPARSE_SPMMD_ARGTYPES(c10::complex<double>)) {
271 TORCH_MKLSPARSE_CHECK(mkl_sparse_z_spmmd(
272 operation,
273 A,
274 B,
275 layout,
276 reinterpret_cast<MKL_Complex16*>(C),
277 ldc));
278 }
279
280 template <>
trsv(MKL_SPARSE_TRSV_ARGTYPES (float))281 sparse_status_t trsv<float>(MKL_SPARSE_TRSV_ARGTYPES(float)) {
282 sparse_status_t status = mkl_sparse_s_trsv(operation, alpha, A, descr, x, y);
283 TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_s_trsv");
284 return status;
285 }
286 template <>
trsv(MKL_SPARSE_TRSV_ARGTYPES (double))287 sparse_status_t trsv<double>(MKL_SPARSE_TRSV_ARGTYPES(double)) {
288 sparse_status_t status = mkl_sparse_d_trsv(operation, alpha, A, descr, x, y);
289 TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_d_trsv");
290 return status;
291 }
292 template <>
trsv(MKL_SPARSE_TRSV_ARGTYPES (c10::complex<float>))293 sparse_status_t trsv<c10::complex<float>>(MKL_SPARSE_TRSV_ARGTYPES(c10::complex<float>)) {
294 sparse_status_t status = mkl_sparse_c_trsv(
295 operation,
296 to_mkl_complex<float, MKL_Complex8>(alpha),
297 A,
298 descr,
299 reinterpret_cast<const MKL_Complex8*>(x),
300 reinterpret_cast<MKL_Complex8*>(y));
301 TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_c_trsv");
302 return status;
303 }
304 template <>
trsv(MKL_SPARSE_TRSV_ARGTYPES (c10::complex<double>))305 sparse_status_t trsv<c10::complex<double>>(
306 MKL_SPARSE_TRSV_ARGTYPES(c10::complex<double>)) {
307 sparse_status_t status = mkl_sparse_z_trsv(
308 operation,
309 to_mkl_complex<double, MKL_Complex16>(alpha),
310 A,
311 descr,
312 reinterpret_cast<const MKL_Complex16*>(x),
313 reinterpret_cast<MKL_Complex16*>(y));
314 TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_z_trsv");
315 return status;
316 }
317
318 template <>
trsm(MKL_SPARSE_TRSM_ARGTYPES (float))319 sparse_status_t trsm<float>(MKL_SPARSE_TRSM_ARGTYPES(float)) {
320 sparse_status_t status = mkl_sparse_s_trsm(
321 operation, alpha, A, descr, layout, x, columns, ldx, y, ldy);
322 TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_s_trsm");
323 return status;
324 }
325 template <>
trsm(MKL_SPARSE_TRSM_ARGTYPES (double))326 sparse_status_t trsm<double>(MKL_SPARSE_TRSM_ARGTYPES(double)) {
327 sparse_status_t status = mkl_sparse_d_trsm(
328 operation, alpha, A, descr, layout, x, columns, ldx, y, ldy);
329 TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_d_trsm");
330 return status;
331 }
332 template <>
trsm(MKL_SPARSE_TRSM_ARGTYPES (c10::complex<float>))333 sparse_status_t trsm<c10::complex<float>>(MKL_SPARSE_TRSM_ARGTYPES(c10::complex<float>)) {
334 sparse_status_t status = mkl_sparse_c_trsm(
335 operation,
336 to_mkl_complex<float, MKL_Complex8>(alpha),
337 A,
338 descr,
339 layout,
340 reinterpret_cast<const MKL_Complex8*>(x),
341 columns,
342 ldx,
343 reinterpret_cast<MKL_Complex8*>(y),
344 ldy);
345 TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_c_trsm");
346 return status;
347 }
348 template <>
trsm(MKL_SPARSE_TRSM_ARGTYPES (c10::complex<double>))349 sparse_status_t trsm<c10::complex<double>>(
350 MKL_SPARSE_TRSM_ARGTYPES(c10::complex<double>)) {
351 sparse_status_t status = mkl_sparse_z_trsm(
352 operation,
353 to_mkl_complex<double, MKL_Complex16>(alpha),
354 A,
355 descr,
356 layout,
357 reinterpret_cast<const MKL_Complex16*>(x),
358 columns,
359 ldx,
360 reinterpret_cast<MKL_Complex16*>(y),
361 ldy);
362 TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_z_trsm");
363 return status;
364 }
365
366 } // namespace at::mkl::sparse
367