xref: /aosp_15_r20/external/executorch/kernels/optimized/blas/CPUBlas.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/optimized/blas/CPUBlas.h>
10*523fa7a6SAndroid Build Coastguard Worker 
11*523fa7a6SAndroid Build Coastguard Worker #include <limits.h>
12*523fa7a6SAndroid Build Coastguard Worker 
13*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_WITH_BLAS
14*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_FOR_APPLE
15*523fa7a6SAndroid Build Coastguard Worker #include <Accelerate/Accelerate.h>
16*523fa7a6SAndroid Build Coastguard Worker #else
17*523fa7a6SAndroid Build Coastguard Worker // clang-format off
18*523fa7a6SAndroid Build Coastguard Worker extern "C" void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, const double *a, int *lda, const double *b, int *ldb, double *beta, double *c, int *ldc);
19*523fa7a6SAndroid Build Coastguard Worker extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc);
20*523fa7a6SAndroid Build Coastguard Worker // clang-format on
21*523fa7a6SAndroid Build Coastguard Worker #endif // ET_BUILD_FOR_APPLE
22*523fa7a6SAndroid Build Coastguard Worker #endif // ET_BUILD_WITH_BLAS
23*523fa7a6SAndroid Build Coastguard Worker 
24*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
25*523fa7a6SAndroid Build Coastguard Worker namespace cpublas {
26*523fa7a6SAndroid Build Coastguard Worker 
27*523fa7a6SAndroid Build Coastguard Worker using exec_aten::BFloat16;
28*523fa7a6SAndroid Build Coastguard Worker using exec_aten::Half;
29*523fa7a6SAndroid Build Coastguard Worker 
30*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_WITH_BLAS
31*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_FOR_APPLE
to_cblas_transpose(TransposeType trans)32*523fa7a6SAndroid Build Coastguard Worker inline CBLAS_TRANSPOSE to_cblas_transpose(TransposeType trans) {
33*523fa7a6SAndroid Build Coastguard Worker   switch (trans) {
34*523fa7a6SAndroid Build Coastguard Worker     case TransposeType::Transpose:
35*523fa7a6SAndroid Build Coastguard Worker       return CblasTrans;
36*523fa7a6SAndroid Build Coastguard Worker     case TransposeType::NoTranspose:
37*523fa7a6SAndroid Build Coastguard Worker       return CblasNoTrans;
38*523fa7a6SAndroid Build Coastguard Worker     case TransposeType::ConjTranspose:
39*523fa7a6SAndroid Build Coastguard Worker       return CblasConjTrans;
40*523fa7a6SAndroid Build Coastguard Worker   }
41*523fa7a6SAndroid Build Coastguard Worker   // Assume no transpose by default
42*523fa7a6SAndroid Build Coastguard Worker   return CblasNoTrans;
43*523fa7a6SAndroid Build Coastguard Worker }
44*523fa7a6SAndroid Build Coastguard Worker #endif // ET_BUILD_FOR_APPLE
45*523fa7a6SAndroid Build Coastguard Worker #endif // ET_BUILD_WITH_BLAS
46*523fa7a6SAndroid Build Coastguard Worker 
47*523fa7a6SAndroid Build Coastguard Worker // clang-format off
normalize_last_dims(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,int64_t * lda,int64_t * ldb,int64_t * ldc)48*523fa7a6SAndroid Build Coastguard Worker void normalize_last_dims(
49*523fa7a6SAndroid Build Coastguard Worker     TransposeType transa, TransposeType transb,
50*523fa7a6SAndroid Build Coastguard Worker     int64_t m, int64_t n, int64_t k,
51*523fa7a6SAndroid Build Coastguard Worker     int64_t *lda, int64_t *ldb, int64_t *ldc) {
52*523fa7a6SAndroid Build Coastguard Worker   if (n == 1) {
53*523fa7a6SAndroid Build Coastguard Worker     *ldc = m;
54*523fa7a6SAndroid Build Coastguard Worker   }
55*523fa7a6SAndroid Build Coastguard Worker 
56*523fa7a6SAndroid Build Coastguard Worker   if(transa != TransposeType::NoTranspose) {
57*523fa7a6SAndroid Build Coastguard Worker     if (m == 1) {
58*523fa7a6SAndroid Build Coastguard Worker       *lda = k;
59*523fa7a6SAndroid Build Coastguard Worker     }
60*523fa7a6SAndroid Build Coastguard Worker   } else if(k == 1) {
61*523fa7a6SAndroid Build Coastguard Worker     *lda = m;
62*523fa7a6SAndroid Build Coastguard Worker   }
63*523fa7a6SAndroid Build Coastguard Worker 
64*523fa7a6SAndroid Build Coastguard Worker   if(transb != TransposeType::NoTranspose) {
65*523fa7a6SAndroid Build Coastguard Worker     if (k == 1) {
66*523fa7a6SAndroid Build Coastguard Worker       *ldb = n;
67*523fa7a6SAndroid Build Coastguard Worker     }
68*523fa7a6SAndroid Build Coastguard Worker   } else if (n == 1) {
69*523fa7a6SAndroid Build Coastguard Worker     *ldb = k;
70*523fa7a6SAndroid Build Coastguard Worker   }
71*523fa7a6SAndroid Build Coastguard Worker }
72*523fa7a6SAndroid Build Coastguard Worker // clang-format on
73*523fa7a6SAndroid Build Coastguard Worker 
74*523fa7a6SAndroid Build Coastguard Worker // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const double alpha,const double * a,int64_t lda,const double * b,int64_t ldb,const double beta,double * c,int64_t ldc)75*523fa7a6SAndroid Build Coastguard Worker void gemm(
76*523fa7a6SAndroid Build Coastguard Worker     TransposeType transa, TransposeType transb,
77*523fa7a6SAndroid Build Coastguard Worker     int64_t m, int64_t n, int64_t k,
78*523fa7a6SAndroid Build Coastguard Worker     const double alpha,
79*523fa7a6SAndroid Build Coastguard Worker     const double *a, int64_t lda,
80*523fa7a6SAndroid Build Coastguard Worker     const double *b, int64_t ldb,
81*523fa7a6SAndroid Build Coastguard Worker     const double beta,
82*523fa7a6SAndroid Build Coastguard Worker     double *c, int64_t ldc) {
83*523fa7a6SAndroid Build Coastguard Worker   normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
84*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_WITH_BLAS
85*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_FOR_APPLE
86*523fa7a6SAndroid Build Coastguard Worker   cblas_dgemm(CblasColMajor, to_cblas_transpose(transa), to_cblas_transpose(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
87*523fa7a6SAndroid Build Coastguard Worker #else
88*523fa7a6SAndroid Build Coastguard Worker   int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
89*523fa7a6SAndroid Build Coastguard Worker   double alpha_ = alpha, beta_ = beta;
90*523fa7a6SAndroid Build Coastguard Worker   char transa_ = to_blas(transa), transb_ = to_blas(transb);
91*523fa7a6SAndroid Build Coastguard Worker   dgemm_(
92*523fa7a6SAndroid Build Coastguard Worker       &transa_, &transb_,
93*523fa7a6SAndroid Build Coastguard Worker       &m_, &n_, &k_,
94*523fa7a6SAndroid Build Coastguard Worker       &alpha_,
95*523fa7a6SAndroid Build Coastguard Worker       a, &lda_,
96*523fa7a6SAndroid Build Coastguard Worker       b, &ldb_,
97*523fa7a6SAndroid Build Coastguard Worker       &beta_,
98*523fa7a6SAndroid Build Coastguard Worker       c, &ldc_);
99*523fa7a6SAndroid Build Coastguard Worker #endif // ET_BUILD_FOR_APPLE
100*523fa7a6SAndroid Build Coastguard Worker #else
101*523fa7a6SAndroid Build Coastguard Worker   using acc_type = utils::compute_dtype<float>;
102*523fa7a6SAndroid Build Coastguard Worker   gemm_impl(
103*523fa7a6SAndroid Build Coastguard Worker       transa, transb,
104*523fa7a6SAndroid Build Coastguard Worker       m, n, k,
105*523fa7a6SAndroid Build Coastguard Worker       static_cast<const acc_type>(alpha),
106*523fa7a6SAndroid Build Coastguard Worker       a, lda,
107*523fa7a6SAndroid Build Coastguard Worker       b, ldb,
108*523fa7a6SAndroid Build Coastguard Worker       static_cast<const acc_type>(beta),
109*523fa7a6SAndroid Build Coastguard Worker       c, ldc);
110*523fa7a6SAndroid Build Coastguard Worker #endif
111*523fa7a6SAndroid Build Coastguard Worker }
112*523fa7a6SAndroid Build Coastguard Worker // clang-format on
113*523fa7a6SAndroid Build Coastguard Worker 
114*523fa7a6SAndroid Build Coastguard Worker // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const float alpha,const float * a,int64_t lda,const float * b,int64_t ldb,const float beta,float * c,int64_t ldc)115*523fa7a6SAndroid Build Coastguard Worker void gemm(
116*523fa7a6SAndroid Build Coastguard Worker     TransposeType transa, TransposeType transb,
117*523fa7a6SAndroid Build Coastguard Worker     int64_t m, int64_t n, int64_t k,
118*523fa7a6SAndroid Build Coastguard Worker     const float alpha,
119*523fa7a6SAndroid Build Coastguard Worker     const float *a, int64_t lda,
120*523fa7a6SAndroid Build Coastguard Worker     const float *b, int64_t ldb,
121*523fa7a6SAndroid Build Coastguard Worker     const float beta,
122*523fa7a6SAndroid Build Coastguard Worker     float *c, int64_t ldc) {
123*523fa7a6SAndroid Build Coastguard Worker   normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
124*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_WITH_BLAS
125*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_FOR_APPLE
126*523fa7a6SAndroid Build Coastguard Worker   cblas_sgemm(CblasColMajor, to_cblas_transpose(transa), to_cblas_transpose(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
127*523fa7a6SAndroid Build Coastguard Worker #else
128*523fa7a6SAndroid Build Coastguard Worker   int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
129*523fa7a6SAndroid Build Coastguard Worker   float alpha_ = alpha, beta_ = beta;
130*523fa7a6SAndroid Build Coastguard Worker   char transa_ = to_blas(transa), transb_ = to_blas(transb);
131*523fa7a6SAndroid Build Coastguard Worker   sgemm_(
132*523fa7a6SAndroid Build Coastguard Worker       &transa_, &transb_,
133*523fa7a6SAndroid Build Coastguard Worker       &m_, &n_, &k_,
134*523fa7a6SAndroid Build Coastguard Worker       &alpha_,
135*523fa7a6SAndroid Build Coastguard Worker       a, &lda_,
136*523fa7a6SAndroid Build Coastguard Worker       b, &ldb_,
137*523fa7a6SAndroid Build Coastguard Worker       &beta_,
138*523fa7a6SAndroid Build Coastguard Worker       c, &ldc_);
139*523fa7a6SAndroid Build Coastguard Worker #endif // ET_BUILD_FOR_APPLE
140*523fa7a6SAndroid Build Coastguard Worker 
141*523fa7a6SAndroid Build Coastguard Worker #else
142*523fa7a6SAndroid Build Coastguard Worker   using acc_type = utils::compute_dtype<float>;
143*523fa7a6SAndroid Build Coastguard Worker   gemm_impl(
144*523fa7a6SAndroid Build Coastguard Worker       transa, transb,
145*523fa7a6SAndroid Build Coastguard Worker       m, n, k,
146*523fa7a6SAndroid Build Coastguard Worker       static_cast<const acc_type>(alpha),
147*523fa7a6SAndroid Build Coastguard Worker       a, lda,
148*523fa7a6SAndroid Build Coastguard Worker       b, ldb,
149*523fa7a6SAndroid Build Coastguard Worker       static_cast<const acc_type>(beta),
150*523fa7a6SAndroid Build Coastguard Worker       c, ldc);
151*523fa7a6SAndroid Build Coastguard Worker #endif
152*523fa7a6SAndroid Build Coastguard Worker }
153*523fa7a6SAndroid Build Coastguard Worker 
154*523fa7a6SAndroid Build Coastguard Worker // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const Half alpha,const Half * a,int64_t lda,const Half * b,int64_t ldb,const Half beta,Half * c,int64_t ldc)155*523fa7a6SAndroid Build Coastguard Worker void gemm(
156*523fa7a6SAndroid Build Coastguard Worker     TransposeType transa, TransposeType transb,
157*523fa7a6SAndroid Build Coastguard Worker     int64_t m, int64_t n, int64_t k,
158*523fa7a6SAndroid Build Coastguard Worker     const Half alpha,
159*523fa7a6SAndroid Build Coastguard Worker     const Half *a, int64_t lda,
160*523fa7a6SAndroid Build Coastguard Worker     const Half *b, int64_t ldb,
161*523fa7a6SAndroid Build Coastguard Worker     const Half beta,
162*523fa7a6SAndroid Build Coastguard Worker     Half *c, int64_t ldc) {
163*523fa7a6SAndroid Build Coastguard Worker   normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
164*523fa7a6SAndroid Build Coastguard Worker 
165*523fa7a6SAndroid Build Coastguard Worker   using acc_type = utils::compute_dtype<Half>;
166*523fa7a6SAndroid Build Coastguard Worker   gemm_impl(
167*523fa7a6SAndroid Build Coastguard Worker       transa, transb,
168*523fa7a6SAndroid Build Coastguard Worker       m, n, k,
169*523fa7a6SAndroid Build Coastguard Worker       static_cast<const acc_type>(alpha),
170*523fa7a6SAndroid Build Coastguard Worker       a, lda,
171*523fa7a6SAndroid Build Coastguard Worker       b, ldb,
172*523fa7a6SAndroid Build Coastguard Worker       static_cast<const acc_type>(beta),
173*523fa7a6SAndroid Build Coastguard Worker       c, ldc);
174*523fa7a6SAndroid Build Coastguard Worker }
175*523fa7a6SAndroid Build Coastguard Worker // clang-format on
176*523fa7a6SAndroid Build Coastguard Worker 
177*523fa7a6SAndroid Build Coastguard Worker // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const BFloat16 alpha,const BFloat16 * a,int64_t lda,const BFloat16 * b,int64_t ldb,const BFloat16 beta,BFloat16 * c,int64_t ldc)178*523fa7a6SAndroid Build Coastguard Worker void gemm(
179*523fa7a6SAndroid Build Coastguard Worker     TransposeType transa, TransposeType transb,
180*523fa7a6SAndroid Build Coastguard Worker     int64_t m, int64_t n, int64_t k,
181*523fa7a6SAndroid Build Coastguard Worker     const BFloat16 alpha,
182*523fa7a6SAndroid Build Coastguard Worker     const BFloat16 *a, int64_t lda,
183*523fa7a6SAndroid Build Coastguard Worker     const BFloat16 *b, int64_t ldb,
184*523fa7a6SAndroid Build Coastguard Worker     const BFloat16 beta,
185*523fa7a6SAndroid Build Coastguard Worker     BFloat16 *c, int64_t ldc) {
186*523fa7a6SAndroid Build Coastguard Worker   normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
187*523fa7a6SAndroid Build Coastguard Worker 
188*523fa7a6SAndroid Build Coastguard Worker   using acc_type = utils::compute_dtype<BFloat16>;
189*523fa7a6SAndroid Build Coastguard Worker   gemm_impl(
190*523fa7a6SAndroid Build Coastguard Worker       transa, transb,
191*523fa7a6SAndroid Build Coastguard Worker       m, n, k,
192*523fa7a6SAndroid Build Coastguard Worker       static_cast<const acc_type>(alpha),
193*523fa7a6SAndroid Build Coastguard Worker       a, lda,
194*523fa7a6SAndroid Build Coastguard Worker       b, ldb,
195*523fa7a6SAndroid Build Coastguard Worker       static_cast<const acc_type>(beta),
196*523fa7a6SAndroid Build Coastguard Worker       c, ldc);
197*523fa7a6SAndroid Build Coastguard Worker }
198*523fa7a6SAndroid Build Coastguard Worker // clang-format on
199*523fa7a6SAndroid Build Coastguard Worker 
200*523fa7a6SAndroid Build Coastguard Worker } // namespace cpublas
201*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
202