1 #pragma once
2
/*
  Provides a subset of MKL Sparse BLAS functions as templates, for example:

    mv<scalar_t>(operation, alpha, A, descr, x, beta, y)

  The templates declared here are create_csr, create_bsr, mv, add, export_csr,
  mm, spmmd, trsv and trsm, where scalar_t is double, float,
  c10::complex<double> or c10::complex<float>.
  The functions are available in at::mkl::sparse namespace.
*/
11
12 #include <c10/util/Exception.h>
13 #include <c10/util/complex.h>
14
15 #include <mkl_spblas.h>
16
17 namespace at::mkl::sparse {
18
19 #define MKL_SPARSE_CREATE_CSR_ARGTYPES(scalar_t) \
20 sparse_matrix_t *A, const sparse_index_base_t indexing, const MKL_INT rows, \
21 const MKL_INT cols, MKL_INT *rows_start, MKL_INT *rows_end, \
22 MKL_INT *col_indx, scalar_t *values
23
24 template <typename scalar_t>
create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES (scalar_t))25 inline void create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES(scalar_t)) {
26 TORCH_INTERNAL_ASSERT(
27 false,
28 "at::mkl::sparse::create_csr: not implemented for ",
29 typeid(scalar_t).name());
30 }
31
32 template <>
33 void create_csr<float>(MKL_SPARSE_CREATE_CSR_ARGTYPES(float));
34 template <>
35 void create_csr<double>(MKL_SPARSE_CREATE_CSR_ARGTYPES(double));
36 template <>
37 void create_csr<c10::complex<float>>(
38 MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex<float>));
39 template <>
40 void create_csr<c10::complex<double>>(
41 MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex<double>));
42
43 #define MKL_SPARSE_CREATE_BSR_ARGTYPES(scalar_t) \
44 sparse_matrix_t *A, const sparse_index_base_t indexing, \
45 const sparse_layout_t block_layout, const MKL_INT rows, \
46 const MKL_INT cols, MKL_INT block_size, MKL_INT *rows_start, \
47 MKL_INT *rows_end, MKL_INT *col_indx, scalar_t *values
48
49 template <typename scalar_t>
create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES (scalar_t))50 inline void create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES(scalar_t)) {
51 TORCH_INTERNAL_ASSERT(
52 false,
53 "at::mkl::sparse::create_bsr: not implemented for ",
54 typeid(scalar_t).name());
55 }
56
57 template <>
58 void create_bsr<float>(MKL_SPARSE_CREATE_BSR_ARGTYPES(float));
59 template <>
60 void create_bsr<double>(MKL_SPARSE_CREATE_BSR_ARGTYPES(double));
61 template <>
62 void create_bsr<c10::complex<float>>(
63 MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex<float>));
64 template <>
65 void create_bsr<c10::complex<double>>(
66 MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex<double>));
67
68 #define MKL_SPARSE_MV_ARGTYPES(scalar_t) \
69 const sparse_operation_t operation, const scalar_t alpha, \
70 const sparse_matrix_t A, const struct matrix_descr descr, \
71 const scalar_t *x, const scalar_t beta, scalar_t *y
72
73 template <typename scalar_t>
mv(MKL_SPARSE_MV_ARGTYPES (scalar_t))74 inline void mv(MKL_SPARSE_MV_ARGTYPES(scalar_t)) {
75 TORCH_INTERNAL_ASSERT(
76 false,
77 "at::mkl::sparse::mv: not implemented for ",
78 typeid(scalar_t).name());
79 }
80
81 template <>
82 void mv<float>(MKL_SPARSE_MV_ARGTYPES(float));
83 template <>
84 void mv<double>(MKL_SPARSE_MV_ARGTYPES(double));
85 template <>
86 void mv<c10::complex<float>>(MKL_SPARSE_MV_ARGTYPES(c10::complex<float>));
87 template <>
88 void mv<c10::complex<double>>(MKL_SPARSE_MV_ARGTYPES(c10::complex<double>));
89
90 #define MKL_SPARSE_ADD_ARGTYPES(scalar_t) \
91 const sparse_operation_t operation, const sparse_matrix_t A, \
92 const scalar_t alpha, const sparse_matrix_t B, sparse_matrix_t *C
93
94 template <typename scalar_t>
add(MKL_SPARSE_ADD_ARGTYPES (scalar_t))95 inline void add(MKL_SPARSE_ADD_ARGTYPES(scalar_t)) {
96 TORCH_INTERNAL_ASSERT(
97 false,
98 "at::mkl::sparse::add: not implemented for ",
99 typeid(scalar_t).name());
100 }
101
102 template <>
103 void add<float>(MKL_SPARSE_ADD_ARGTYPES(float));
104 template <>
105 void add<double>(MKL_SPARSE_ADD_ARGTYPES(double));
106 template <>
107 void add<c10::complex<float>>(MKL_SPARSE_ADD_ARGTYPES(c10::complex<float>));
108 template <>
109 void add<c10::complex<double>>(MKL_SPARSE_ADD_ARGTYPES(c10::complex<double>));
110
111 #define MKL_SPARSE_EXPORT_CSR_ARGTYPES(scalar_t) \
112 const sparse_matrix_t source, sparse_index_base_t *indexing, MKL_INT *rows, \
113 MKL_INT *cols, MKL_INT **rows_start, MKL_INT **rows_end, \
114 MKL_INT **col_indx, scalar_t **values
115
116 template <typename scalar_t>
export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES (scalar_t))117 inline void export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES(scalar_t)) {
118 TORCH_INTERNAL_ASSERT(
119 false,
120 "at::mkl::sparse::export_csr: not implemented for ",
121 typeid(scalar_t).name());
122 }
123
124 template <>
125 void export_csr<float>(MKL_SPARSE_EXPORT_CSR_ARGTYPES(float));
126 template <>
127 void export_csr<double>(MKL_SPARSE_EXPORT_CSR_ARGTYPES(double));
128 template <>
129 void export_csr<c10::complex<float>>(
130 MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex<float>));
131 template <>
132 void export_csr<c10::complex<double>>(
133 MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex<double>));
134
135 #define MKL_SPARSE_MM_ARGTYPES(scalar_t) \
136 const sparse_operation_t operation, const scalar_t alpha, \
137 const sparse_matrix_t A, const struct matrix_descr descr, \
138 const sparse_layout_t layout, const scalar_t *B, const MKL_INT columns, \
139 const MKL_INT ldb, const scalar_t beta, scalar_t *C, const MKL_INT ldc
140
141 template <typename scalar_t>
mm(MKL_SPARSE_MM_ARGTYPES (scalar_t))142 inline void mm(MKL_SPARSE_MM_ARGTYPES(scalar_t)) {
143 TORCH_INTERNAL_ASSERT(
144 false,
145 "at::mkl::sparse::mm: not implemented for ",
146 typeid(scalar_t).name());
147 }
148
149 template <>
150 void mm<float>(MKL_SPARSE_MM_ARGTYPES(float));
151 template <>
152 void mm<double>(MKL_SPARSE_MM_ARGTYPES(double));
153 template <>
154 void mm<c10::complex<float>>(MKL_SPARSE_MM_ARGTYPES(c10::complex<float>));
155 template <>
156 void mm<c10::complex<double>>(MKL_SPARSE_MM_ARGTYPES(c10::complex<double>));
157
158 #define MKL_SPARSE_SPMMD_ARGTYPES(scalar_t) \
159 const sparse_operation_t operation, const sparse_matrix_t A, \
160 const sparse_matrix_t B, const sparse_layout_t layout, scalar_t *C, \
161 const MKL_INT ldc
162
163 template <typename scalar_t>
spmmd(MKL_SPARSE_SPMMD_ARGTYPES (scalar_t))164 inline void spmmd(MKL_SPARSE_SPMMD_ARGTYPES(scalar_t)) {
165 TORCH_INTERNAL_ASSERT(
166 false,
167 "at::mkl::sparse::spmmd: not implemented for ",
168 typeid(scalar_t).name());
169 }
170
171 template <>
172 void spmmd<float>(MKL_SPARSE_SPMMD_ARGTYPES(float));
173 template <>
174 void spmmd<double>(MKL_SPARSE_SPMMD_ARGTYPES(double));
175 template <>
176 void spmmd<c10::complex<float>>(MKL_SPARSE_SPMMD_ARGTYPES(c10::complex<float>));
177 template <>
178 void spmmd<c10::complex<double>>(
179 MKL_SPARSE_SPMMD_ARGTYPES(c10::complex<double>));
180
181 #define MKL_SPARSE_TRSV_ARGTYPES(scalar_t) \
182 const sparse_operation_t operation, const scalar_t alpha, \
183 const sparse_matrix_t A, const struct matrix_descr descr, \
184 const scalar_t *x, scalar_t *y
185
186 template <typename scalar_t>
trsv(MKL_SPARSE_TRSV_ARGTYPES (scalar_t))187 inline sparse_status_t trsv(MKL_SPARSE_TRSV_ARGTYPES(scalar_t)) {
188 TORCH_INTERNAL_ASSERT(
189 false,
190 "at::mkl::sparse::trsv: not implemented for ",
191 typeid(scalar_t).name());
192 }
193
194 template <>
195 sparse_status_t trsv<float>(MKL_SPARSE_TRSV_ARGTYPES(float));
196 template <>
197 sparse_status_t trsv<double>(MKL_SPARSE_TRSV_ARGTYPES(double));
198 template <>
199 sparse_status_t trsv<c10::complex<float>>(MKL_SPARSE_TRSV_ARGTYPES(c10::complex<float>));
200 template <>
201 sparse_status_t trsv<c10::complex<double>>(MKL_SPARSE_TRSV_ARGTYPES(c10::complex<double>));
202
203 #define MKL_SPARSE_TRSM_ARGTYPES(scalar_t) \
204 const sparse_operation_t operation, const scalar_t alpha, \
205 const sparse_matrix_t A, const struct matrix_descr descr, \
206 const sparse_layout_t layout, const scalar_t *x, const MKL_INT columns, \
207 const MKL_INT ldx, scalar_t *y, const MKL_INT ldy
208
209 template <typename scalar_t>
trsm(MKL_SPARSE_TRSM_ARGTYPES (scalar_t))210 inline sparse_status_t trsm(MKL_SPARSE_TRSM_ARGTYPES(scalar_t)) {
211 TORCH_INTERNAL_ASSERT(
212 false,
213 "at::mkl::sparse::trsm: not implemented for ",
214 typeid(scalar_t).name());
215 }
216
217 template <>
218 sparse_status_t trsm<float>(MKL_SPARSE_TRSM_ARGTYPES(float));
219 template <>
220 sparse_status_t trsm<double>(MKL_SPARSE_TRSM_ARGTYPES(double));
221 template <>
222 sparse_status_t trsm<c10::complex<float>>(MKL_SPARSE_TRSM_ARGTYPES(c10::complex<float>));
223 template <>
224 sparse_status_t trsm<c10::complex<double>>(MKL_SPARSE_TRSM_ARGTYPES(c10::complex<double>));
225
226 } // namespace at::mkl::sparse
227