1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/optimized/blas/CPUBlas.h>
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Worker #include <limits.h>
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_WITH_BLAS
14*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_FOR_APPLE
15*523fa7a6SAndroid Build Coastguard Worker #include <Accelerate/Accelerate.h>
16*523fa7a6SAndroid Build Coastguard Worker #else
17*523fa7a6SAndroid Build Coastguard Worker // clang-format off
18*523fa7a6SAndroid Build Coastguard Worker extern "C" void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, const double *a, int *lda, const double *b, int *ldb, double *beta, double *c, int *ldc);
19*523fa7a6SAndroid Build Coastguard Worker extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc);
20*523fa7a6SAndroid Build Coastguard Worker // clang-format on
21*523fa7a6SAndroid Build Coastguard Worker #endif // ET_BUILD_FOR_APPLE
22*523fa7a6SAndroid Build Coastguard Worker #endif // ET_BUILD_WITH_BLAS
23*523fa7a6SAndroid Build Coastguard Worker
24*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
25*523fa7a6SAndroid Build Coastguard Worker namespace cpublas {
26*523fa7a6SAndroid Build Coastguard Worker
27*523fa7a6SAndroid Build Coastguard Worker using exec_aten::BFloat16;
28*523fa7a6SAndroid Build Coastguard Worker using exec_aten::Half;
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_WITH_BLAS
31*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_FOR_APPLE
to_cblas_transpose(TransposeType trans)32*523fa7a6SAndroid Build Coastguard Worker inline CBLAS_TRANSPOSE to_cblas_transpose(TransposeType trans) {
33*523fa7a6SAndroid Build Coastguard Worker switch (trans) {
34*523fa7a6SAndroid Build Coastguard Worker case TransposeType::Transpose:
35*523fa7a6SAndroid Build Coastguard Worker return CblasTrans;
36*523fa7a6SAndroid Build Coastguard Worker case TransposeType::NoTranspose:
37*523fa7a6SAndroid Build Coastguard Worker return CblasNoTrans;
38*523fa7a6SAndroid Build Coastguard Worker case TransposeType::ConjTranspose:
39*523fa7a6SAndroid Build Coastguard Worker return CblasConjTrans;
40*523fa7a6SAndroid Build Coastguard Worker }
41*523fa7a6SAndroid Build Coastguard Worker // Assume no transpose by default
42*523fa7a6SAndroid Build Coastguard Worker return CblasNoTrans;
43*523fa7a6SAndroid Build Coastguard Worker }
44*523fa7a6SAndroid Build Coastguard Worker #endif // ET_BUILD_FOR_APPLE
45*523fa7a6SAndroid Build Coastguard Worker #endif // ET_BUILD_WITH_BLAS
46*523fa7a6SAndroid Build Coastguard Worker
47*523fa7a6SAndroid Build Coastguard Worker // clang-format off
normalize_last_dims(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,int64_t * lda,int64_t * ldb,int64_t * ldc)48*523fa7a6SAndroid Build Coastguard Worker void normalize_last_dims(
49*523fa7a6SAndroid Build Coastguard Worker TransposeType transa, TransposeType transb,
50*523fa7a6SAndroid Build Coastguard Worker int64_t m, int64_t n, int64_t k,
51*523fa7a6SAndroid Build Coastguard Worker int64_t *lda, int64_t *ldb, int64_t *ldc) {
52*523fa7a6SAndroid Build Coastguard Worker if (n == 1) {
53*523fa7a6SAndroid Build Coastguard Worker *ldc = m;
54*523fa7a6SAndroid Build Coastguard Worker }
55*523fa7a6SAndroid Build Coastguard Worker
56*523fa7a6SAndroid Build Coastguard Worker if(transa != TransposeType::NoTranspose) {
57*523fa7a6SAndroid Build Coastguard Worker if (m == 1) {
58*523fa7a6SAndroid Build Coastguard Worker *lda = k;
59*523fa7a6SAndroid Build Coastguard Worker }
60*523fa7a6SAndroid Build Coastguard Worker } else if(k == 1) {
61*523fa7a6SAndroid Build Coastguard Worker *lda = m;
62*523fa7a6SAndroid Build Coastguard Worker }
63*523fa7a6SAndroid Build Coastguard Worker
64*523fa7a6SAndroid Build Coastguard Worker if(transb != TransposeType::NoTranspose) {
65*523fa7a6SAndroid Build Coastguard Worker if (k == 1) {
66*523fa7a6SAndroid Build Coastguard Worker *ldb = n;
67*523fa7a6SAndroid Build Coastguard Worker }
68*523fa7a6SAndroid Build Coastguard Worker } else if (n == 1) {
69*523fa7a6SAndroid Build Coastguard Worker *ldb = k;
70*523fa7a6SAndroid Build Coastguard Worker }
71*523fa7a6SAndroid Build Coastguard Worker }
72*523fa7a6SAndroid Build Coastguard Worker // clang-format on
73*523fa7a6SAndroid Build Coastguard Worker
74*523fa7a6SAndroid Build Coastguard Worker // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const double alpha,const double * a,int64_t lda,const double * b,int64_t ldb,const double beta,double * c,int64_t ldc)75*523fa7a6SAndroid Build Coastguard Worker void gemm(
76*523fa7a6SAndroid Build Coastguard Worker TransposeType transa, TransposeType transb,
77*523fa7a6SAndroid Build Coastguard Worker int64_t m, int64_t n, int64_t k,
78*523fa7a6SAndroid Build Coastguard Worker const double alpha,
79*523fa7a6SAndroid Build Coastguard Worker const double *a, int64_t lda,
80*523fa7a6SAndroid Build Coastguard Worker const double *b, int64_t ldb,
81*523fa7a6SAndroid Build Coastguard Worker const double beta,
82*523fa7a6SAndroid Build Coastguard Worker double *c, int64_t ldc) {
83*523fa7a6SAndroid Build Coastguard Worker normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
84*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_WITH_BLAS
85*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_FOR_APPLE
86*523fa7a6SAndroid Build Coastguard Worker cblas_dgemm(CblasColMajor, to_cblas_transpose(transa), to_cblas_transpose(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
87*523fa7a6SAndroid Build Coastguard Worker #else
88*523fa7a6SAndroid Build Coastguard Worker int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
89*523fa7a6SAndroid Build Coastguard Worker double alpha_ = alpha, beta_ = beta;
90*523fa7a6SAndroid Build Coastguard Worker char transa_ = to_blas(transa), transb_ = to_blas(transb);
91*523fa7a6SAndroid Build Coastguard Worker dgemm_(
92*523fa7a6SAndroid Build Coastguard Worker &transa_, &transb_,
93*523fa7a6SAndroid Build Coastguard Worker &m_, &n_, &k_,
94*523fa7a6SAndroid Build Coastguard Worker &alpha_,
95*523fa7a6SAndroid Build Coastguard Worker a, &lda_,
96*523fa7a6SAndroid Build Coastguard Worker b, &ldb_,
97*523fa7a6SAndroid Build Coastguard Worker &beta_,
98*523fa7a6SAndroid Build Coastguard Worker c, &ldc_);
99*523fa7a6SAndroid Build Coastguard Worker #endif // ET_BUILD_FOR_APPLE
100*523fa7a6SAndroid Build Coastguard Worker #else
101*523fa7a6SAndroid Build Coastguard Worker using acc_type = utils::compute_dtype<float>;
102*523fa7a6SAndroid Build Coastguard Worker gemm_impl(
103*523fa7a6SAndroid Build Coastguard Worker transa, transb,
104*523fa7a6SAndroid Build Coastguard Worker m, n, k,
105*523fa7a6SAndroid Build Coastguard Worker static_cast<const acc_type>(alpha),
106*523fa7a6SAndroid Build Coastguard Worker a, lda,
107*523fa7a6SAndroid Build Coastguard Worker b, ldb,
108*523fa7a6SAndroid Build Coastguard Worker static_cast<const acc_type>(beta),
109*523fa7a6SAndroid Build Coastguard Worker c, ldc);
110*523fa7a6SAndroid Build Coastguard Worker #endif
111*523fa7a6SAndroid Build Coastguard Worker }
112*523fa7a6SAndroid Build Coastguard Worker // clang-format on
113*523fa7a6SAndroid Build Coastguard Worker
114*523fa7a6SAndroid Build Coastguard Worker // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const float alpha,const float * a,int64_t lda,const float * b,int64_t ldb,const float beta,float * c,int64_t ldc)115*523fa7a6SAndroid Build Coastguard Worker void gemm(
116*523fa7a6SAndroid Build Coastguard Worker TransposeType transa, TransposeType transb,
117*523fa7a6SAndroid Build Coastguard Worker int64_t m, int64_t n, int64_t k,
118*523fa7a6SAndroid Build Coastguard Worker const float alpha,
119*523fa7a6SAndroid Build Coastguard Worker const float *a, int64_t lda,
120*523fa7a6SAndroid Build Coastguard Worker const float *b, int64_t ldb,
121*523fa7a6SAndroid Build Coastguard Worker const float beta,
122*523fa7a6SAndroid Build Coastguard Worker float *c, int64_t ldc) {
123*523fa7a6SAndroid Build Coastguard Worker normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
124*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_WITH_BLAS
125*523fa7a6SAndroid Build Coastguard Worker #ifdef ET_BUILD_FOR_APPLE
126*523fa7a6SAndroid Build Coastguard Worker cblas_sgemm(CblasColMajor, to_cblas_transpose(transa), to_cblas_transpose(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
127*523fa7a6SAndroid Build Coastguard Worker #else
128*523fa7a6SAndroid Build Coastguard Worker int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
129*523fa7a6SAndroid Build Coastguard Worker float alpha_ = alpha, beta_ = beta;
130*523fa7a6SAndroid Build Coastguard Worker char transa_ = to_blas(transa), transb_ = to_blas(transb);
131*523fa7a6SAndroid Build Coastguard Worker sgemm_(
132*523fa7a6SAndroid Build Coastguard Worker &transa_, &transb_,
133*523fa7a6SAndroid Build Coastguard Worker &m_, &n_, &k_,
134*523fa7a6SAndroid Build Coastguard Worker &alpha_,
135*523fa7a6SAndroid Build Coastguard Worker a, &lda_,
136*523fa7a6SAndroid Build Coastguard Worker b, &ldb_,
137*523fa7a6SAndroid Build Coastguard Worker &beta_,
138*523fa7a6SAndroid Build Coastguard Worker c, &ldc_);
139*523fa7a6SAndroid Build Coastguard Worker #endif // ET_BUILD_FOR_APPLE
140*523fa7a6SAndroid Build Coastguard Worker
141*523fa7a6SAndroid Build Coastguard Worker #else
142*523fa7a6SAndroid Build Coastguard Worker using acc_type = utils::compute_dtype<float>;
143*523fa7a6SAndroid Build Coastguard Worker gemm_impl(
144*523fa7a6SAndroid Build Coastguard Worker transa, transb,
145*523fa7a6SAndroid Build Coastguard Worker m, n, k,
146*523fa7a6SAndroid Build Coastguard Worker static_cast<const acc_type>(alpha),
147*523fa7a6SAndroid Build Coastguard Worker a, lda,
148*523fa7a6SAndroid Build Coastguard Worker b, ldb,
149*523fa7a6SAndroid Build Coastguard Worker static_cast<const acc_type>(beta),
150*523fa7a6SAndroid Build Coastguard Worker c, ldc);
151*523fa7a6SAndroid Build Coastguard Worker #endif
152*523fa7a6SAndroid Build Coastguard Worker }
153*523fa7a6SAndroid Build Coastguard Worker
154*523fa7a6SAndroid Build Coastguard Worker // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const Half alpha,const Half * a,int64_t lda,const Half * b,int64_t ldb,const Half beta,Half * c,int64_t ldc)155*523fa7a6SAndroid Build Coastguard Worker void gemm(
156*523fa7a6SAndroid Build Coastguard Worker TransposeType transa, TransposeType transb,
157*523fa7a6SAndroid Build Coastguard Worker int64_t m, int64_t n, int64_t k,
158*523fa7a6SAndroid Build Coastguard Worker const Half alpha,
159*523fa7a6SAndroid Build Coastguard Worker const Half *a, int64_t lda,
160*523fa7a6SAndroid Build Coastguard Worker const Half *b, int64_t ldb,
161*523fa7a6SAndroid Build Coastguard Worker const Half beta,
162*523fa7a6SAndroid Build Coastguard Worker Half *c, int64_t ldc) {
163*523fa7a6SAndroid Build Coastguard Worker normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
164*523fa7a6SAndroid Build Coastguard Worker
165*523fa7a6SAndroid Build Coastguard Worker using acc_type = utils::compute_dtype<Half>;
166*523fa7a6SAndroid Build Coastguard Worker gemm_impl(
167*523fa7a6SAndroid Build Coastguard Worker transa, transb,
168*523fa7a6SAndroid Build Coastguard Worker m, n, k,
169*523fa7a6SAndroid Build Coastguard Worker static_cast<const acc_type>(alpha),
170*523fa7a6SAndroid Build Coastguard Worker a, lda,
171*523fa7a6SAndroid Build Coastguard Worker b, ldb,
172*523fa7a6SAndroid Build Coastguard Worker static_cast<const acc_type>(beta),
173*523fa7a6SAndroid Build Coastguard Worker c, ldc);
174*523fa7a6SAndroid Build Coastguard Worker }
175*523fa7a6SAndroid Build Coastguard Worker // clang-format on
176*523fa7a6SAndroid Build Coastguard Worker
177*523fa7a6SAndroid Build Coastguard Worker // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const BFloat16 alpha,const BFloat16 * a,int64_t lda,const BFloat16 * b,int64_t ldb,const BFloat16 beta,BFloat16 * c,int64_t ldc)178*523fa7a6SAndroid Build Coastguard Worker void gemm(
179*523fa7a6SAndroid Build Coastguard Worker TransposeType transa, TransposeType transb,
180*523fa7a6SAndroid Build Coastguard Worker int64_t m, int64_t n, int64_t k,
181*523fa7a6SAndroid Build Coastguard Worker const BFloat16 alpha,
182*523fa7a6SAndroid Build Coastguard Worker const BFloat16 *a, int64_t lda,
183*523fa7a6SAndroid Build Coastguard Worker const BFloat16 *b, int64_t ldb,
184*523fa7a6SAndroid Build Coastguard Worker const BFloat16 beta,
185*523fa7a6SAndroid Build Coastguard Worker BFloat16 *c, int64_t ldc) {
186*523fa7a6SAndroid Build Coastguard Worker normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
187*523fa7a6SAndroid Build Coastguard Worker
188*523fa7a6SAndroid Build Coastguard Worker using acc_type = utils::compute_dtype<BFloat16>;
189*523fa7a6SAndroid Build Coastguard Worker gemm_impl(
190*523fa7a6SAndroid Build Coastguard Worker transa, transb,
191*523fa7a6SAndroid Build Coastguard Worker m, n, k,
192*523fa7a6SAndroid Build Coastguard Worker static_cast<const acc_type>(alpha),
193*523fa7a6SAndroid Build Coastguard Worker a, lda,
194*523fa7a6SAndroid Build Coastguard Worker b, ldb,
195*523fa7a6SAndroid Build Coastguard Worker static_cast<const acc_type>(beta),
196*523fa7a6SAndroid Build Coastguard Worker c, ldc);
197*523fa7a6SAndroid Build Coastguard Worker }
198*523fa7a6SAndroid Build Coastguard Worker // clang-format on
199*523fa7a6SAndroid Build Coastguard Worker
200*523fa7a6SAndroid Build Coastguard Worker } // namespace cpublas
201*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
202