xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mkl/SparseBlas.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 /*
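  Provides a subset of MKL Sparse BLAS functions as templates:

    create_csr<scalar_t>(A, indexing, rows, cols, rows_start, rows_end, col_indx, values)
    create_bsr<scalar_t>(A, indexing, block_layout, rows, cols, block_size,
                         rows_start, rows_end, col_indx, values)
    mv<scalar_t>(operation, alpha, A, descr, x, beta, y)
    add<scalar_t>(operation, A, alpha, B, C)
    export_csr<scalar_t>(source, indexing, rows, cols, rows_start, rows_end, col_indx, values)
    mm<scalar_t>(operation, alpha, A, descr, layout, B, columns, ldb, beta, C, ldc)
    spmmd<scalar_t>(operation, A, B, layout, C, ldc)
    trsv<scalar_t>(operation, alpha, A, descr, x, y)
    trsm<scalar_t>(operation, alpha, A, descr, layout, x, columns, ldx, y, ldy)

  where scalar_t is double, float, c10::complex<double> or c10::complex<float>.
  Calling any of them with another scalar_t fails a TORCH_INTERNAL_ASSERT.
  trsv and trsm return a sparse_status_t; the other functions return void.
  The functions are available in the at::mkl::sparse namespace.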
10 */
11 
#include <c10/util/Exception.h>
#include <c10/util/complex.h>

#include <mkl_spblas.h>

#include <typeinfo>
16 
17 namespace at::mkl::sparse {
18 
19 #define MKL_SPARSE_CREATE_CSR_ARGTYPES(scalar_t)                              \
20   sparse_matrix_t *A, const sparse_index_base_t indexing, const MKL_INT rows, \
21       const MKL_INT cols, MKL_INT *rows_start, MKL_INT *rows_end,             \
22       MKL_INT *col_indx, scalar_t *values
23 
24 template <typename scalar_t>
create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES (scalar_t))25 inline void create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES(scalar_t)) {
26   TORCH_INTERNAL_ASSERT(
27       false,
28       "at::mkl::sparse::create_csr: not implemented for ",
29       typeid(scalar_t).name());
30 }
31 
32 template <>
33 void create_csr<float>(MKL_SPARSE_CREATE_CSR_ARGTYPES(float));
34 template <>
35 void create_csr<double>(MKL_SPARSE_CREATE_CSR_ARGTYPES(double));
36 template <>
37 void create_csr<c10::complex<float>>(
38     MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex<float>));
39 template <>
40 void create_csr<c10::complex<double>>(
41     MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex<double>));
42 
43 #define MKL_SPARSE_CREATE_BSR_ARGTYPES(scalar_t)                   \
44   sparse_matrix_t *A, const sparse_index_base_t indexing,          \
45       const sparse_layout_t block_layout, const MKL_INT rows,      \
46       const MKL_INT cols, MKL_INT block_size, MKL_INT *rows_start, \
47       MKL_INT *rows_end, MKL_INT *col_indx, scalar_t *values
48 
49 template <typename scalar_t>
create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES (scalar_t))50 inline void create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES(scalar_t)) {
51   TORCH_INTERNAL_ASSERT(
52       false,
53       "at::mkl::sparse::create_bsr: not implemented for ",
54       typeid(scalar_t).name());
55 }
56 
57 template <>
58 void create_bsr<float>(MKL_SPARSE_CREATE_BSR_ARGTYPES(float));
59 template <>
60 void create_bsr<double>(MKL_SPARSE_CREATE_BSR_ARGTYPES(double));
61 template <>
62 void create_bsr<c10::complex<float>>(
63     MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex<float>));
64 template <>
65 void create_bsr<c10::complex<double>>(
66     MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex<double>));
67 
68 #define MKL_SPARSE_MV_ARGTYPES(scalar_t)                        \
69   const sparse_operation_t operation, const scalar_t alpha,     \
70       const sparse_matrix_t A, const struct matrix_descr descr, \
71       const scalar_t *x, const scalar_t beta, scalar_t *y
72 
73 template <typename scalar_t>
mv(MKL_SPARSE_MV_ARGTYPES (scalar_t))74 inline void mv(MKL_SPARSE_MV_ARGTYPES(scalar_t)) {
75   TORCH_INTERNAL_ASSERT(
76       false,
77       "at::mkl::sparse::mv: not implemented for ",
78       typeid(scalar_t).name());
79 }
80 
81 template <>
82 void mv<float>(MKL_SPARSE_MV_ARGTYPES(float));
83 template <>
84 void mv<double>(MKL_SPARSE_MV_ARGTYPES(double));
85 template <>
86 void mv<c10::complex<float>>(MKL_SPARSE_MV_ARGTYPES(c10::complex<float>));
87 template <>
88 void mv<c10::complex<double>>(MKL_SPARSE_MV_ARGTYPES(c10::complex<double>));
89 
90 #define MKL_SPARSE_ADD_ARGTYPES(scalar_t)                      \
91   const sparse_operation_t operation, const sparse_matrix_t A, \
92       const scalar_t alpha, const sparse_matrix_t B, sparse_matrix_t *C
93 
94 template <typename scalar_t>
add(MKL_SPARSE_ADD_ARGTYPES (scalar_t))95 inline void add(MKL_SPARSE_ADD_ARGTYPES(scalar_t)) {
96   TORCH_INTERNAL_ASSERT(
97       false,
98       "at::mkl::sparse::add: not implemented for ",
99       typeid(scalar_t).name());
100 }
101 
102 template <>
103 void add<float>(MKL_SPARSE_ADD_ARGTYPES(float));
104 template <>
105 void add<double>(MKL_SPARSE_ADD_ARGTYPES(double));
106 template <>
107 void add<c10::complex<float>>(MKL_SPARSE_ADD_ARGTYPES(c10::complex<float>));
108 template <>
109 void add<c10::complex<double>>(MKL_SPARSE_ADD_ARGTYPES(c10::complex<double>));
110 
111 #define MKL_SPARSE_EXPORT_CSR_ARGTYPES(scalar_t)                              \
112   const sparse_matrix_t source, sparse_index_base_t *indexing, MKL_INT *rows, \
113       MKL_INT *cols, MKL_INT **rows_start, MKL_INT **rows_end,                \
114       MKL_INT **col_indx, scalar_t **values
115 
116 template <typename scalar_t>
export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES (scalar_t))117 inline void export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES(scalar_t)) {
118   TORCH_INTERNAL_ASSERT(
119       false,
120       "at::mkl::sparse::export_csr: not implemented for ",
121       typeid(scalar_t).name());
122 }
123 
124 template <>
125 void export_csr<float>(MKL_SPARSE_EXPORT_CSR_ARGTYPES(float));
126 template <>
127 void export_csr<double>(MKL_SPARSE_EXPORT_CSR_ARGTYPES(double));
128 template <>
129 void export_csr<c10::complex<float>>(
130     MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex<float>));
131 template <>
132 void export_csr<c10::complex<double>>(
133     MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex<double>));
134 
135 #define MKL_SPARSE_MM_ARGTYPES(scalar_t)                                      \
136   const sparse_operation_t operation, const scalar_t alpha,                   \
137       const sparse_matrix_t A, const struct matrix_descr descr,               \
138       const sparse_layout_t layout, const scalar_t *B, const MKL_INT columns, \
139       const MKL_INT ldb, const scalar_t beta, scalar_t *C, const MKL_INT ldc
140 
141 template <typename scalar_t>
mm(MKL_SPARSE_MM_ARGTYPES (scalar_t))142 inline void mm(MKL_SPARSE_MM_ARGTYPES(scalar_t)) {
143   TORCH_INTERNAL_ASSERT(
144       false,
145       "at::mkl::sparse::mm: not implemented for ",
146       typeid(scalar_t).name());
147 }
148 
149 template <>
150 void mm<float>(MKL_SPARSE_MM_ARGTYPES(float));
151 template <>
152 void mm<double>(MKL_SPARSE_MM_ARGTYPES(double));
153 template <>
154 void mm<c10::complex<float>>(MKL_SPARSE_MM_ARGTYPES(c10::complex<float>));
155 template <>
156 void mm<c10::complex<double>>(MKL_SPARSE_MM_ARGTYPES(c10::complex<double>));
157 
158 #define MKL_SPARSE_SPMMD_ARGTYPES(scalar_t)                               \
159   const sparse_operation_t operation, const sparse_matrix_t A,            \
160       const sparse_matrix_t B, const sparse_layout_t layout, scalar_t *C, \
161       const MKL_INT ldc
162 
163 template <typename scalar_t>
spmmd(MKL_SPARSE_SPMMD_ARGTYPES (scalar_t))164 inline void spmmd(MKL_SPARSE_SPMMD_ARGTYPES(scalar_t)) {
165   TORCH_INTERNAL_ASSERT(
166       false,
167       "at::mkl::sparse::spmmd: not implemented for ",
168       typeid(scalar_t).name());
169 }
170 
171 template <>
172 void spmmd<float>(MKL_SPARSE_SPMMD_ARGTYPES(float));
173 template <>
174 void spmmd<double>(MKL_SPARSE_SPMMD_ARGTYPES(double));
175 template <>
176 void spmmd<c10::complex<float>>(MKL_SPARSE_SPMMD_ARGTYPES(c10::complex<float>));
177 template <>
178 void spmmd<c10::complex<double>>(
179     MKL_SPARSE_SPMMD_ARGTYPES(c10::complex<double>));
180 
181 #define MKL_SPARSE_TRSV_ARGTYPES(scalar_t)                      \
182   const sparse_operation_t operation, const scalar_t alpha,     \
183       const sparse_matrix_t A, const struct matrix_descr descr, \
184       const scalar_t *x, scalar_t *y
185 
186 template <typename scalar_t>
trsv(MKL_SPARSE_TRSV_ARGTYPES (scalar_t))187 inline sparse_status_t trsv(MKL_SPARSE_TRSV_ARGTYPES(scalar_t)) {
188   TORCH_INTERNAL_ASSERT(
189       false,
190       "at::mkl::sparse::trsv: not implemented for ",
191       typeid(scalar_t).name());
192 }
193 
194 template <>
195 sparse_status_t trsv<float>(MKL_SPARSE_TRSV_ARGTYPES(float));
196 template <>
197 sparse_status_t trsv<double>(MKL_SPARSE_TRSV_ARGTYPES(double));
198 template <>
199 sparse_status_t trsv<c10::complex<float>>(MKL_SPARSE_TRSV_ARGTYPES(c10::complex<float>));
200 template <>
201 sparse_status_t trsv<c10::complex<double>>(MKL_SPARSE_TRSV_ARGTYPES(c10::complex<double>));
202 
203 #define MKL_SPARSE_TRSM_ARGTYPES(scalar_t)                                    \
204   const sparse_operation_t operation, const scalar_t alpha,                   \
205       const sparse_matrix_t A, const struct matrix_descr descr,               \
206       const sparse_layout_t layout, const scalar_t *x, const MKL_INT columns, \
207       const MKL_INT ldx, scalar_t *y, const MKL_INT ldy
208 
209 template <typename scalar_t>
trsm(MKL_SPARSE_TRSM_ARGTYPES (scalar_t))210 inline sparse_status_t trsm(MKL_SPARSE_TRSM_ARGTYPES(scalar_t)) {
211   TORCH_INTERNAL_ASSERT(
212       false,
213       "at::mkl::sparse::trsm: not implemented for ",
214       typeid(scalar_t).name());
215 }
216 
217 template <>
218 sparse_status_t trsm<float>(MKL_SPARSE_TRSM_ARGTYPES(float));
219 template <>
220 sparse_status_t trsm<double>(MKL_SPARSE_TRSM_ARGTYPES(double));
221 template <>
222 sparse_status_t trsm<c10::complex<float>>(MKL_SPARSE_TRSM_ARGTYPES(c10::complex<float>));
223 template <>
224 sparse_status_t trsm<c10::complex<double>>(MKL_SPARSE_TRSM_ARGTYPES(c10::complex<double>));
225 
226 } // namespace at::mkl::sparse
227