/*
  Provides the implementations of MKL Sparse BLAS function templates.

  Origin: /aosp_15_r20/external/pytorch/aten/src/ATen/mkl/SparseBlas.cpp
  (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
*/
4 #define TORCH_ASSERT_NO_OPERATORS
5 #include <ATen/mkl/Exceptions.h>
6 #include <ATen/mkl/SparseBlas.h>
7 
8 namespace at::mkl::sparse {
9 
10 namespace {
11 
12 template <typename scalar_t, typename MKL_Complex>
to_mkl_complex(c10::complex<scalar_t> scalar)13 MKL_Complex to_mkl_complex(c10::complex<scalar_t> scalar) {
14   MKL_Complex mkl_scalar;
15   mkl_scalar.real = scalar.real();
16   mkl_scalar.imag = scalar.imag();
17   return mkl_scalar;
18 }
19 
20 } // namespace
21 
22 
23 template <>
create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES (float))24 void create_csr<float>(MKL_SPARSE_CREATE_CSR_ARGTYPES(float)) {
25   TORCH_MKLSPARSE_CHECK(mkl_sparse_s_create_csr(
26       A, indexing, rows, cols, rows_start, rows_end, col_indx, values));
27 }
28 template <>
create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES (double))29 void create_csr<double>(MKL_SPARSE_CREATE_CSR_ARGTYPES(double)) {
30   TORCH_MKLSPARSE_CHECK(mkl_sparse_d_create_csr(
31       A, indexing, rows, cols, rows_start, rows_end, col_indx, values));
32 }
33 template <>
create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES (c10::complex<float>))34 void create_csr<c10::complex<float>>(
35     MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex<float>)) {
36   TORCH_MKLSPARSE_CHECK(mkl_sparse_c_create_csr(
37       A,
38       indexing,
39       rows,
40       cols,
41       rows_start,
42       rows_end,
43       col_indx,
44       reinterpret_cast<MKL_Complex8*>(values)));
45 }
46 template <>
create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES (c10::complex<double>))47 void create_csr<c10::complex<double>>(
48     MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex<double>)) {
49   TORCH_MKLSPARSE_CHECK(mkl_sparse_z_create_csr(
50       A,
51       indexing,
52       rows,
53       cols,
54       rows_start,
55       rows_end,
56       col_indx,
57       reinterpret_cast<MKL_Complex16*>(values)));
58 }
59 
60 template <>
create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES (float))61 void create_bsr<float>(MKL_SPARSE_CREATE_BSR_ARGTYPES(float)) {
62   TORCH_MKLSPARSE_CHECK(mkl_sparse_s_create_bsr(
63       A,
64       indexing,
65       block_layout,
66       rows,
67       cols,
68       block_size,
69       rows_start,
70       rows_end,
71       col_indx,
72       values));
73 }
74 template <>
create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES (double))75 void create_bsr<double>(MKL_SPARSE_CREATE_BSR_ARGTYPES(double)) {
76   TORCH_MKLSPARSE_CHECK(mkl_sparse_d_create_bsr(
77       A,
78       indexing,
79       block_layout,
80       rows,
81       cols,
82       block_size,
83       rows_start,
84       rows_end,
85       col_indx,
86       values));
87 }
88 template <>
create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES (c10::complex<float>))89 void create_bsr<c10::complex<float>>(
90     MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex<float>)) {
91   TORCH_MKLSPARSE_CHECK(mkl_sparse_c_create_bsr(
92       A,
93       indexing,
94       block_layout,
95       rows,
96       cols,
97       block_size,
98       rows_start,
99       rows_end,
100       col_indx,
101       reinterpret_cast<MKL_Complex8*>(values)));
102 }
103 template <>
create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES (c10::complex<double>))104 void create_bsr<c10::complex<double>>(
105     MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex<double>)) {
106   TORCH_MKLSPARSE_CHECK(mkl_sparse_z_create_bsr(
107       A,
108       indexing,
109       block_layout,
110       rows,
111       cols,
112       block_size,
113       rows_start,
114       rows_end,
115       col_indx,
116       reinterpret_cast<MKL_Complex16*>(values)));
117 }
118 
119 template <>
mv(MKL_SPARSE_MV_ARGTYPES (float))120 void mv<float>(MKL_SPARSE_MV_ARGTYPES(float)) {
121   TORCH_MKLSPARSE_CHECK(
122       mkl_sparse_s_mv(operation, alpha, A, descr, x, beta, y));
123 }
124 template <>
mv(MKL_SPARSE_MV_ARGTYPES (double))125 void mv<double>(MKL_SPARSE_MV_ARGTYPES(double)) {
126   TORCH_MKLSPARSE_CHECK(
127       mkl_sparse_d_mv(operation, alpha, A, descr, x, beta, y));
128 }
129 template <>
mv(MKL_SPARSE_MV_ARGTYPES (c10::complex<float>))130 void mv<c10::complex<float>>(MKL_SPARSE_MV_ARGTYPES(c10::complex<float>)) {
131   TORCH_MKLSPARSE_CHECK(mkl_sparse_c_mv(
132       operation,
133       to_mkl_complex<float, MKL_Complex8>(alpha),
134       A,
135       descr,
136       reinterpret_cast<const MKL_Complex8*>(x),
137       to_mkl_complex<float, MKL_Complex8>(beta),
138       reinterpret_cast<MKL_Complex8*>(y)));
139 }
140 template <>
mv(MKL_SPARSE_MV_ARGTYPES (c10::complex<double>))141 void mv<c10::complex<double>>(MKL_SPARSE_MV_ARGTYPES(c10::complex<double>)) {
142   TORCH_MKLSPARSE_CHECK(mkl_sparse_z_mv(
143       operation,
144       to_mkl_complex<double, MKL_Complex16>(alpha),
145       A,
146       descr,
147       reinterpret_cast<const MKL_Complex16*>(x),
148       to_mkl_complex<double, MKL_Complex16>(beta),
149       reinterpret_cast<MKL_Complex16*>(y)));
150 }
151 
152 template <>
add(MKL_SPARSE_ADD_ARGTYPES (float))153 void add<float>(MKL_SPARSE_ADD_ARGTYPES(float)) {
154   TORCH_MKLSPARSE_CHECK(mkl_sparse_s_add(operation, A, alpha, B, C));
155 }
156 template <>
add(MKL_SPARSE_ADD_ARGTYPES (double))157 void add<double>(MKL_SPARSE_ADD_ARGTYPES(double)) {
158   TORCH_MKLSPARSE_CHECK(mkl_sparse_d_add(operation, A, alpha, B, C));
159 }
160 template <>
add(MKL_SPARSE_ADD_ARGTYPES (c10::complex<float>))161 void add<c10::complex<float>>(MKL_SPARSE_ADD_ARGTYPES(c10::complex<float>)) {
162   TORCH_MKLSPARSE_CHECK(mkl_sparse_c_add(
163       operation, A, to_mkl_complex<float, MKL_Complex8>(alpha), B, C));
164 }
165 template <>
add(MKL_SPARSE_ADD_ARGTYPES (c10::complex<double>))166 void add<c10::complex<double>>(MKL_SPARSE_ADD_ARGTYPES(c10::complex<double>)) {
167   TORCH_MKLSPARSE_CHECK(mkl_sparse_z_add(
168       operation, A, to_mkl_complex<double, MKL_Complex16>(alpha), B, C));
169 }
170 
171 template <>
export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES (float))172 void export_csr<float>(MKL_SPARSE_EXPORT_CSR_ARGTYPES(float)) {
173   TORCH_MKLSPARSE_CHECK(mkl_sparse_s_export_csr(
174       source, indexing, rows, cols, rows_start, rows_end, col_indx, values));
175 }
176 template <>
export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES (double))177 void export_csr<double>(MKL_SPARSE_EXPORT_CSR_ARGTYPES(double)) {
178   TORCH_MKLSPARSE_CHECK(mkl_sparse_d_export_csr(
179       source, indexing, rows, cols, rows_start, rows_end, col_indx, values));
180 }
181 template <>
export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES (c10::complex<float>))182 void export_csr<c10::complex<float>>(
183     MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex<float>)) {
184   TORCH_MKLSPARSE_CHECK(mkl_sparse_c_export_csr(
185       source,
186       indexing,
187       rows,
188       cols,
189       rows_start,
190       rows_end,
191       col_indx,
192       reinterpret_cast<MKL_Complex8**>(values)));
193 }
194 template <>
export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES (c10::complex<double>))195 void export_csr<c10::complex<double>>(
196     MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex<double>)) {
197   TORCH_MKLSPARSE_CHECK(mkl_sparse_z_export_csr(
198       source,
199       indexing,
200       rows,
201       cols,
202       rows_start,
203       rows_end,
204       col_indx,
205       reinterpret_cast<MKL_Complex16**>(values)));
206 }
207 
208 template <>
mm(MKL_SPARSE_MM_ARGTYPES (float))209 void mm<float>(MKL_SPARSE_MM_ARGTYPES(float)) {
210   TORCH_MKLSPARSE_CHECK(mkl_sparse_s_mm(
211       operation, alpha, A, descr, layout, B, columns, ldb, beta, C, ldc));
212 }
213 template <>
mm(MKL_SPARSE_MM_ARGTYPES (double))214 void mm<double>(MKL_SPARSE_MM_ARGTYPES(double)) {
215   TORCH_MKLSPARSE_CHECK(mkl_sparse_d_mm(
216       operation, alpha, A, descr, layout, B, columns, ldb, beta, C, ldc));
217 }
218 template <>
mm(MKL_SPARSE_MM_ARGTYPES (c10::complex<float>))219 void mm<c10::complex<float>>(MKL_SPARSE_MM_ARGTYPES(c10::complex<float>)) {
220   TORCH_MKLSPARSE_CHECK(mkl_sparse_c_mm(
221       operation,
222       to_mkl_complex<float, MKL_Complex8>(alpha),
223       A,
224       descr,
225       layout,
226       reinterpret_cast<const MKL_Complex8*>(B),
227       columns,
228       ldb,
229       to_mkl_complex<float, MKL_Complex8>(beta),
230       reinterpret_cast<MKL_Complex8*>(C),
231       ldc));
232 }
233 template <>
mm(MKL_SPARSE_MM_ARGTYPES (c10::complex<double>))234 void mm<c10::complex<double>>(MKL_SPARSE_MM_ARGTYPES(c10::complex<double>)) {
235   TORCH_MKLSPARSE_CHECK(mkl_sparse_z_mm(
236       operation,
237       to_mkl_complex<double, MKL_Complex16>(alpha),
238       A,
239       descr,
240       layout,
241       reinterpret_cast<const MKL_Complex16*>(B),
242       columns,
243       ldb,
244       to_mkl_complex<double, MKL_Complex16>(beta),
245       reinterpret_cast<MKL_Complex16*>(C),
246       ldc));
247 }
248 
249 template <>
spmmd(MKL_SPARSE_SPMMD_ARGTYPES (float))250 void spmmd<float>(MKL_SPARSE_SPMMD_ARGTYPES(float)) {
251   TORCH_MKLSPARSE_CHECK(mkl_sparse_s_spmmd(
252       operation, A, B, layout, C, ldc));
253 }
254 template <>
spmmd(MKL_SPARSE_SPMMD_ARGTYPES (double))255 void spmmd<double>(MKL_SPARSE_SPMMD_ARGTYPES(double)) {
256   TORCH_MKLSPARSE_CHECK(mkl_sparse_d_spmmd(
257       operation, A, B, layout, C, ldc));
258 }
259 template <>
spmmd(MKL_SPARSE_SPMMD_ARGTYPES (c10::complex<float>))260 void spmmd<c10::complex<float>>(MKL_SPARSE_SPMMD_ARGTYPES(c10::complex<float>)) {
261   TORCH_MKLSPARSE_CHECK(mkl_sparse_c_spmmd(
262       operation,
263       A,
264       B,
265       layout,
266       reinterpret_cast<MKL_Complex8*>(C),
267       ldc));
268 }
269 template <>
spmmd(MKL_SPARSE_SPMMD_ARGTYPES (c10::complex<double>))270 void spmmd<c10::complex<double>>(MKL_SPARSE_SPMMD_ARGTYPES(c10::complex<double>)) {
271   TORCH_MKLSPARSE_CHECK(mkl_sparse_z_spmmd(
272       operation,
273       A,
274       B,
275       layout,
276       reinterpret_cast<MKL_Complex16*>(C),
277       ldc));
278 }
279 
280 template <>
trsv(MKL_SPARSE_TRSV_ARGTYPES (float))281 sparse_status_t trsv<float>(MKL_SPARSE_TRSV_ARGTYPES(float)) {
282   sparse_status_t status = mkl_sparse_s_trsv(operation, alpha, A, descr, x, y);
283   TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_s_trsv");
284   return status;
285 }
286 template <>
trsv(MKL_SPARSE_TRSV_ARGTYPES (double))287 sparse_status_t trsv<double>(MKL_SPARSE_TRSV_ARGTYPES(double)) {
288   sparse_status_t status = mkl_sparse_d_trsv(operation, alpha, A, descr, x, y);
289   TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_d_trsv");
290   return status;
291 }
292 template <>
trsv(MKL_SPARSE_TRSV_ARGTYPES (c10::complex<float>))293 sparse_status_t trsv<c10::complex<float>>(MKL_SPARSE_TRSV_ARGTYPES(c10::complex<float>)) {
294   sparse_status_t status = mkl_sparse_c_trsv(
295       operation,
296       to_mkl_complex<float, MKL_Complex8>(alpha),
297       A,
298       descr,
299       reinterpret_cast<const MKL_Complex8*>(x),
300       reinterpret_cast<MKL_Complex8*>(y));
301   TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_c_trsv");
302   return status;
303 }
304 template <>
trsv(MKL_SPARSE_TRSV_ARGTYPES (c10::complex<double>))305 sparse_status_t trsv<c10::complex<double>>(
306     MKL_SPARSE_TRSV_ARGTYPES(c10::complex<double>)) {
307   sparse_status_t status = mkl_sparse_z_trsv(
308       operation,
309       to_mkl_complex<double, MKL_Complex16>(alpha),
310       A,
311       descr,
312       reinterpret_cast<const MKL_Complex16*>(x),
313       reinterpret_cast<MKL_Complex16*>(y));
314   TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_z_trsv");
315   return status;
316 }
317 
318 template <>
trsm(MKL_SPARSE_TRSM_ARGTYPES (float))319 sparse_status_t trsm<float>(MKL_SPARSE_TRSM_ARGTYPES(float)) {
320   sparse_status_t status = mkl_sparse_s_trsm(
321       operation, alpha, A, descr, layout, x, columns, ldx, y, ldy);
322   TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_s_trsm");
323   return status;
324 }
325 template <>
trsm(MKL_SPARSE_TRSM_ARGTYPES (double))326 sparse_status_t trsm<double>(MKL_SPARSE_TRSM_ARGTYPES(double)) {
327   sparse_status_t status = mkl_sparse_d_trsm(
328       operation, alpha, A, descr, layout, x, columns, ldx, y, ldy);
329   TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_d_trsm");
330   return status;
331 }
332 template <>
trsm(MKL_SPARSE_TRSM_ARGTYPES (c10::complex<float>))333 sparse_status_t trsm<c10::complex<float>>(MKL_SPARSE_TRSM_ARGTYPES(c10::complex<float>)) {
334   sparse_status_t status = mkl_sparse_c_trsm(
335       operation,
336       to_mkl_complex<float, MKL_Complex8>(alpha),
337       A,
338       descr,
339       layout,
340       reinterpret_cast<const MKL_Complex8*>(x),
341       columns,
342       ldx,
343       reinterpret_cast<MKL_Complex8*>(y),
344       ldy);
345   TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_c_trsm");
346   return status;
347 }
348 template <>
trsm(MKL_SPARSE_TRSM_ARGTYPES (c10::complex<double>))349 sparse_status_t trsm<c10::complex<double>>(
350     MKL_SPARSE_TRSM_ARGTYPES(c10::complex<double>)) {
351   sparse_status_t status = mkl_sparse_z_trsm(
352       operation,
353       to_mkl_complex<double, MKL_Complex16>(alpha),
354       A,
355       descr,
356       layout,
357       reinterpret_cast<const MKL_Complex16*>(x),
358       columns,
359       ldx,
360       reinterpret_cast<MKL_Complex16*>(y),
361       ldy);
362   TORCH_MKLSPARSE_CHECK_SUCCESS_OR_INVALID(status, "mkl_sparse_z_trsm");
363   return status;
364 }
365 
366 } // namespace at::mkl::sparse
367