// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mkl/SparseCsrLinearAlgebra.h>
#include <ATen/native/SparseTensorUtils.h>

// Don't compile with MKL for macOS since linking the sparse MKL routines
// needs some build fixes.
// Macros source:
// https://web.archive.org/web/20191012035921/http://nadeausoftware.com/articles/2012/01/c_c_tip_how_use_compiler_predefined_macros_detect_operating_system
#if !AT_MKL_ENABLED() || defined(__APPLE__) || \
    defined(__MACH__)

namespace at {
namespace sparse_csr {
Tensor& _sparse_mm_mkl_(
    Tensor& self,
    const SparseCsrTensor& sparse_,
    const Tensor& dense,
    const Tensor& t,
    const Scalar& alpha,
    const Scalar& beta) {
#if __APPLE__ || __MACH__
  AT_ERROR("sparse_mm_mkl: MKL support is disabled on macOS/iOS.");
#else
  AT_ERROR("sparse_mm_mkl: ATen not compiled with MKL support");
#endif
  return self; // suppress "missing return" compiler warnings.
}
} // namespace sparse_csr
} // namespace at

#else // AT_MKL_ENABLED

#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Limits.h>
#include <mkl.h>
#include <mkl_spblas.h>

#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/SparseCsrTensorImpl.h>

namespace at {
namespace sparse_csr {

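// MKL_INT is 64-bit under MKL's ILP64 interface and 32-bit under LP64, so the
// Torch index dtype used for crow_indices/col_indices has to match.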
#ifdef MKL_ILP64
static constexpr ScalarType TORCH_INT_TYPE = at::kLong;
#else
static constexpr ScalarType TORCH_INT_TYPE = at::kInt;
#endif

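// Thin RAII wrapper around MKL's inspector-executor sparse handle: the
// constructors build a CSR matrix handle via mkl_sparse_s_create_csr /
// mkl_sparse_d_create_csr (zero-based indexing, general matrix descriptor),
// sparse_mm performs the multiplication, and the destructor releases the
// handle with mkl_sparse_destroy.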
class SparseCsrMKLInterface {
 private:
  sparse_matrix_t A{nullptr};
  matrix_descr desc;

 public:
  SparseCsrMKLInterface(
      MKL_INT* col_indices,
      MKL_INT* crow_indices,
      double* values,
      MKL_INT nrows,
      MKL_INT ncols) {
    desc.type = SPARSE_MATRIX_TYPE_GENERAL;
    int retval = mkl_sparse_d_create_csr(
        &A,
        SPARSE_INDEX_BASE_ZERO,
        nrows,
        ncols,
        crow_indices,
        crow_indices + 1,
        col_indices,
        values);
    TORCH_CHECK(
        retval == 0,
        "mkl_sparse_d_create_csr failed with error code: ",
        retval);
  }

  SparseCsrMKLInterface(
      MKL_INT* col_indices,
      MKL_INT* crow_indices,
      float* values,
      MKL_INT nrows,
      MKL_INT ncols) {
    desc.type = SPARSE_MATRIX_TYPE_GENERAL;
    int retval = mkl_sparse_s_create_csr(
        &A,
        SPARSE_INDEX_BASE_ZERO,
        nrows,
        ncols,
        crow_indices,
        crow_indices + 1,
        col_indices,
        values);
    TORCH_CHECK(
        retval == 0,
        "mkl_sparse_s_create_csr failed with error code: ",
        retval);
  }

  // res(nrows, dense_ncols) = sparse(nrows x ncols) @ dense(ncols x dense_ncols)
  inline void sparse_mm(
      float* res,
      float* dense,
      float alpha,
      float beta,
      MKL_INT nrows,
      MKL_INT ncols,
      MKL_INT dense_ncols) {
    int stat;
    if (dense_ncols == 1) {
      stat = mkl_sparse_s_mv(
        SPARSE_OPERATION_NON_TRANSPOSE,
        alpha,
        A,
        desc,
        dense,
        beta,
        res);
      TORCH_CHECK(stat == 0, "mkl_sparse_s_mv failed with error code: ", stat);
    } else {
      stat = mkl_sparse_s_mm(
        SPARSE_OPERATION_NON_TRANSPOSE,
        alpha,
        A,
        desc,
        SPARSE_LAYOUT_ROW_MAJOR,
        dense,
        nrows,
        ncols,
        beta,
        res,
        dense_ncols);
      TORCH_CHECK(stat == 0, "mkl_sparse_s_mm failed with error code: ", stat);
    }
  }

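  // Double-precision counterpart of the overload above: dispatches to
  // mkl_sparse_d_mv for a single dense column and mkl_sparse_d_mm otherwise.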
  inline void sparse_mm(
      double* res,
      double* dense,
      double alpha,
      double beta,
      MKL_INT nrows,
      MKL_INT ncols,
      MKL_INT dense_ncols) {
    int stat;
    if (dense_ncols == 1) {
      stat = mkl_sparse_d_mv(
        SPARSE_OPERATION_NON_TRANSPOSE,
        alpha,
        A,
        desc,
        dense,
        beta,
        res);
      TORCH_CHECK(stat == 0, "mkl_sparse_d_mv failed with error code: ", stat);
    } else {
      stat = mkl_sparse_d_mm(
        SPARSE_OPERATION_NON_TRANSPOSE,
        alpha,
        A,
        desc,
        SPARSE_LAYOUT_ROW_MAJOR,
        dense,
        nrows,
        ncols,
        beta,
        res,
        dense_ncols);
      TORCH_CHECK(stat == 0, "mkl_sparse_d_mm failed with error code: ", stat);
    }
  }

  ~SparseCsrMKLInterface() {
    mkl_sparse_destroy(A);
  }
};

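// Builds the MKL CSR handle from the tensor components and computes
// res = alpha * (sparse @ dense) + beta * res. The `t` argument is accepted
// but not referenced in this implementation.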
template <typename scalar_t>
static inline void sparse_mm_mkl_template(
    Tensor& res,
    const Tensor& col_indices,
    const Tensor& crow_indices,
    const Tensor& values,
    const Tensor& dense,
    const Tensor& t,
    const Scalar& alpha,
    const Scalar& beta,
    IntArrayRef size,
    IntArrayRef dense_size) {
  SparseCsrMKLInterface mkl_impl(
      col_indices.data_ptr<MKL_INT>(),
      crow_indices.data_ptr<MKL_INT>(),
      values.data_ptr<scalar_t>(),
      size[0],
      size[1]);
  mkl_impl.sparse_mm(
      res.data_ptr<scalar_t>(),
      dense.data_ptr<scalar_t>(),
      alpha.to<scalar_t>(),
      beta.to<scalar_t>(),
      size[0],
      size[1],
      dense_size[1]);
}

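// True when MKL is built with the LP64 interface (32-bit MKL_INT),
// false under ILP64 (64-bit MKL_INT).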
static inline constexpr bool is_mkl_int32_index() {
#ifdef MKL_ILP64
  return false;
#else
  return true;
#endif
}

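// In-place sparse-dense matmul via MKL: converts the CSR indices to the MKL
// integer width (warning when a conversion is needed) and computes
// self = alpha * (sparse_ @ dense) + beta * self for float/double values.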
Tensor& _sparse_mm_mkl_(
    Tensor& self,
    const SparseCsrTensor& sparse_,
    const Tensor& dense,
    const Tensor& t,
    const Scalar& alpha,
    const Scalar& beta) {
  if (is_mkl_int32_index()) {
    if (sparse_.crow_indices().scalar_type() != kInt) {
      TORCH_WARN(
          "PyTorch is compiled with MKL LP64 and will convert crow_indices to int32.");
    }
    if (sparse_.col_indices().scalar_type() != kInt) {
      TORCH_WARN(
          "PyTorch is compiled with MKL LP64 and will convert col_indices to int32.");
    }
  } else { // Future-proofing in case we ever switch to MKL ILP64.
    if (sparse_.crow_indices().scalar_type() != kLong) {
      TORCH_WARN(
          "PyTorch is compiled with MKL ILP64 and will convert crow_indices to int64.");
    }
    if (sparse_.col_indices().scalar_type() != kLong) {
      TORCH_WARN(
          "PyTorch is compiled with MKL ILP64 and will convert col_indices to int64.");
    }
  }
  AT_DISPATCH_FLOATING_TYPES(
      dense.scalar_type(), "addmm_sparse_csr_dense", [&] {
        sparse_mm_mkl_template<scalar_t>(
            self,
            sparse_.col_indices().to(TORCH_INT_TYPE),
            sparse_.crow_indices().to(TORCH_INT_TYPE),
            sparse_.values(),
            dense,
            t,
            alpha,
            beta,
            sparse_.sizes(),
            dense.sizes());
      });
  return self;
}
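// Minimal usage sketch (illustrative only; the tensor names are hypothetical).
// Given a float CSR tensor `A` and a dense matrix `B`, the output buffer also
// serves as the beta-scaled accumulator:
//
//   at::Tensor C = at::zeros({A.size(0), B.size(1)}, B.options());
//   at::sparse_csr::_sparse_mm_mkl_(C, A, B, C, /*alpha=*/1.0, /*beta=*/0.0);
//   // C now holds alpha * (A @ B) + beta * C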

} // namespace sparse_csr
} // namespace at

#endif // AT_MKL_ENABLED