xref: /aosp_15_r20/external/executorch/kernels/optimized/blas/CPUBlas.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/optimized/blas/CPUBlas.h>
10 
11 #include <limits.h>
12 
13 #ifdef ET_BUILD_WITH_BLAS
14 #ifdef ET_BUILD_FOR_APPLE
15 #include <Accelerate/Accelerate.h>
16 #else
17 // clang-format off
18 extern "C" void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, const double *a, int *lda, const double *b, int *ldb, double *beta, double *c, int *ldc);
19 extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc);
20 // clang-format on
21 #endif // ET_BUILD_FOR_APPLE
22 #endif // ET_BUILD_WITH_BLAS
23 
24 namespace executorch {
25 namespace cpublas {
26 
27 using exec_aten::BFloat16;
28 using exec_aten::Half;
29 
#ifdef ET_BUILD_WITH_BLAS
#ifdef ET_BUILD_FOR_APPLE
// Translates this library's TransposeType into the matching Accelerate CBLAS
// enumerator. A value not handled by the switch yields CblasNoTrans.
inline CBLAS_TRANSPOSE to_cblas_transpose(TransposeType trans) {
  CBLAS_TRANSPOSE result = CblasNoTrans;
  switch (trans) {
    case TransposeType::NoTranspose:
      result = CblasNoTrans;
      break;
    case TransposeType::Transpose:
      result = CblasTrans;
      break;
    case TransposeType::ConjTranspose:
      result = CblasConjTrans;
      break;
  }
  return result;
}
#endif // ET_BUILD_FOR_APPLE
#endif // ET_BUILD_WITH_BLAS
46 
47 // clang-format off
normalize_last_dims(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,int64_t * lda,int64_t * ldb,int64_t * ldc)48 void normalize_last_dims(
49     TransposeType transa, TransposeType transb,
50     int64_t m, int64_t n, int64_t k,
51     int64_t *lda, int64_t *ldb, int64_t *ldc) {
52   if (n == 1) {
53     *ldc = m;
54   }
55 
56   if(transa != TransposeType::NoTranspose) {
57     if (m == 1) {
58       *lda = k;
59     }
60   } else if(k == 1) {
61     *lda = m;
62   }
63 
64   if(transb != TransposeType::NoTranspose) {
65     if (k == 1) {
66       *ldb = n;
67     }
68   } else if (n == 1) {
69     *ldb = k;
70   }
71 }
72 // clang-format on
73 
74 // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const double alpha,const double * a,int64_t lda,const double * b,int64_t ldb,const double beta,double * c,int64_t ldc)75 void gemm(
76     TransposeType transa, TransposeType transb,
77     int64_t m, int64_t n, int64_t k,
78     const double alpha,
79     const double *a, int64_t lda,
80     const double *b, int64_t ldb,
81     const double beta,
82     double *c, int64_t ldc) {
83   normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
84 #ifdef ET_BUILD_WITH_BLAS
85 #ifdef ET_BUILD_FOR_APPLE
86   cblas_dgemm(CblasColMajor, to_cblas_transpose(transa), to_cblas_transpose(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
87 #else
88   int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
89   double alpha_ = alpha, beta_ = beta;
90   char transa_ = to_blas(transa), transb_ = to_blas(transb);
91   dgemm_(
92       &transa_, &transb_,
93       &m_, &n_, &k_,
94       &alpha_,
95       a, &lda_,
96       b, &ldb_,
97       &beta_,
98       c, &ldc_);
99 #endif // ET_BUILD_FOR_APPLE
100 #else
101   using acc_type = utils::compute_dtype<float>;
102   gemm_impl(
103       transa, transb,
104       m, n, k,
105       static_cast<const acc_type>(alpha),
106       a, lda,
107       b, ldb,
108       static_cast<const acc_type>(beta),
109       c, ldc);
110 #endif
111 }
112 // clang-format on
113 
114 // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const float alpha,const float * a,int64_t lda,const float * b,int64_t ldb,const float beta,float * c,int64_t ldc)115 void gemm(
116     TransposeType transa, TransposeType transb,
117     int64_t m, int64_t n, int64_t k,
118     const float alpha,
119     const float *a, int64_t lda,
120     const float *b, int64_t ldb,
121     const float beta,
122     float *c, int64_t ldc) {
123   normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
124 #ifdef ET_BUILD_WITH_BLAS
125 #ifdef ET_BUILD_FOR_APPLE
126   cblas_sgemm(CblasColMajor, to_cblas_transpose(transa), to_cblas_transpose(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
127 #else
128   int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
129   float alpha_ = alpha, beta_ = beta;
130   char transa_ = to_blas(transa), transb_ = to_blas(transb);
131   sgemm_(
132       &transa_, &transb_,
133       &m_, &n_, &k_,
134       &alpha_,
135       a, &lda_,
136       b, &ldb_,
137       &beta_,
138       c, &ldc_);
139 #endif // ET_BUILD_FOR_APPLE
140 
141 #else
142   using acc_type = utils::compute_dtype<float>;
143   gemm_impl(
144       transa, transb,
145       m, n, k,
146       static_cast<const acc_type>(alpha),
147       a, lda,
148       b, ldb,
149       static_cast<const acc_type>(beta),
150       c, ldc);
151 #endif
152 }
153 
154 // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const Half alpha,const Half * a,int64_t lda,const Half * b,int64_t ldb,const Half beta,Half * c,int64_t ldc)155 void gemm(
156     TransposeType transa, TransposeType transb,
157     int64_t m, int64_t n, int64_t k,
158     const Half alpha,
159     const Half *a, int64_t lda,
160     const Half *b, int64_t ldb,
161     const Half beta,
162     Half *c, int64_t ldc) {
163   normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
164 
165   using acc_type = utils::compute_dtype<Half>;
166   gemm_impl(
167       transa, transb,
168       m, n, k,
169       static_cast<const acc_type>(alpha),
170       a, lda,
171       b, ldb,
172       static_cast<const acc_type>(beta),
173       c, ldc);
174 }
175 // clang-format on
176 
177 // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const BFloat16 alpha,const BFloat16 * a,int64_t lda,const BFloat16 * b,int64_t ldb,const BFloat16 beta,BFloat16 * c,int64_t ldc)178 void gemm(
179     TransposeType transa, TransposeType transb,
180     int64_t m, int64_t n, int64_t k,
181     const BFloat16 alpha,
182     const BFloat16 *a, int64_t lda,
183     const BFloat16 *b, int64_t ldb,
184     const BFloat16 beta,
185     BFloat16 *c, int64_t ldc) {
186   normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
187 
188   using acc_type = utils::compute_dtype<BFloat16>;
189   gemm_impl(
190       transa, transb,
191       m, n, k,
192       static_cast<const acc_type>(alpha),
193       a, lda,
194       b, ldb,
195       static_cast<const acc_type>(beta),
196       c, ldc);
197 }
198 // clang-format on
199 
200 } // namespace cpublas
201 } // namespace executorch
202