1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/optimized/blas/CPUBlas.h>
10
11 #include <limits.h>
12
13 #ifdef ET_BUILD_WITH_BLAS
14 #ifdef ET_BUILD_FOR_APPLE
15 #include <Accelerate/Accelerate.h>
16 #else
17 // clang-format off
18 extern "C" void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, const double *a, int *lda, const double *b, int *ldb, double *beta, double *c, int *ldc);
19 extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc);
20 // clang-format on
21 #endif // ET_BUILD_FOR_APPLE
22 #endif // ET_BUILD_WITH_BLAS
23
24 namespace executorch {
25 namespace cpublas {
26
27 using exec_aten::BFloat16;
28 using exec_aten::Half;
29
30 #ifdef ET_BUILD_WITH_BLAS
31 #ifdef ET_BUILD_FOR_APPLE
to_cblas_transpose(TransposeType trans)32 inline CBLAS_TRANSPOSE to_cblas_transpose(TransposeType trans) {
33 switch (trans) {
34 case TransposeType::Transpose:
35 return CblasTrans;
36 case TransposeType::NoTranspose:
37 return CblasNoTrans;
38 case TransposeType::ConjTranspose:
39 return CblasConjTrans;
40 }
41 // Assume no transpose by default
42 return CblasNoTrans;
43 }
44 #endif // ET_BUILD_FOR_APPLE
45 #endif // ET_BUILD_WITH_BLAS
46
47 // clang-format off
normalize_last_dims(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,int64_t * lda,int64_t * ldb,int64_t * ldc)48 void normalize_last_dims(
49 TransposeType transa, TransposeType transb,
50 int64_t m, int64_t n, int64_t k,
51 int64_t *lda, int64_t *ldb, int64_t *ldc) {
52 if (n == 1) {
53 *ldc = m;
54 }
55
56 if(transa != TransposeType::NoTranspose) {
57 if (m == 1) {
58 *lda = k;
59 }
60 } else if(k == 1) {
61 *lda = m;
62 }
63
64 if(transb != TransposeType::NoTranspose) {
65 if (k == 1) {
66 *ldb = n;
67 }
68 } else if (n == 1) {
69 *ldb = k;
70 }
71 }
72 // clang-format on
73
74 // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const double alpha,const double * a,int64_t lda,const double * b,int64_t ldb,const double beta,double * c,int64_t ldc)75 void gemm(
76 TransposeType transa, TransposeType transb,
77 int64_t m, int64_t n, int64_t k,
78 const double alpha,
79 const double *a, int64_t lda,
80 const double *b, int64_t ldb,
81 const double beta,
82 double *c, int64_t ldc) {
83 normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
84 #ifdef ET_BUILD_WITH_BLAS
85 #ifdef ET_BUILD_FOR_APPLE
86 cblas_dgemm(CblasColMajor, to_cblas_transpose(transa), to_cblas_transpose(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
87 #else
88 int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
89 double alpha_ = alpha, beta_ = beta;
90 char transa_ = to_blas(transa), transb_ = to_blas(transb);
91 dgemm_(
92 &transa_, &transb_,
93 &m_, &n_, &k_,
94 &alpha_,
95 a, &lda_,
96 b, &ldb_,
97 &beta_,
98 c, &ldc_);
99 #endif // ET_BUILD_FOR_APPLE
100 #else
101 using acc_type = utils::compute_dtype<float>;
102 gemm_impl(
103 transa, transb,
104 m, n, k,
105 static_cast<const acc_type>(alpha),
106 a, lda,
107 b, ldb,
108 static_cast<const acc_type>(beta),
109 c, ldc);
110 #endif
111 }
112 // clang-format on
113
114 // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const float alpha,const float * a,int64_t lda,const float * b,int64_t ldb,const float beta,float * c,int64_t ldc)115 void gemm(
116 TransposeType transa, TransposeType transb,
117 int64_t m, int64_t n, int64_t k,
118 const float alpha,
119 const float *a, int64_t lda,
120 const float *b, int64_t ldb,
121 const float beta,
122 float *c, int64_t ldc) {
123 normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
124 #ifdef ET_BUILD_WITH_BLAS
125 #ifdef ET_BUILD_FOR_APPLE
126 cblas_sgemm(CblasColMajor, to_cblas_transpose(transa), to_cblas_transpose(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
127 #else
128 int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
129 float alpha_ = alpha, beta_ = beta;
130 char transa_ = to_blas(transa), transb_ = to_blas(transb);
131 sgemm_(
132 &transa_, &transb_,
133 &m_, &n_, &k_,
134 &alpha_,
135 a, &lda_,
136 b, &ldb_,
137 &beta_,
138 c, &ldc_);
139 #endif // ET_BUILD_FOR_APPLE
140
141 #else
142 using acc_type = utils::compute_dtype<float>;
143 gemm_impl(
144 transa, transb,
145 m, n, k,
146 static_cast<const acc_type>(alpha),
147 a, lda,
148 b, ldb,
149 static_cast<const acc_type>(beta),
150 c, ldc);
151 #endif
152 }
153
154 // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const Half alpha,const Half * a,int64_t lda,const Half * b,int64_t ldb,const Half beta,Half * c,int64_t ldc)155 void gemm(
156 TransposeType transa, TransposeType transb,
157 int64_t m, int64_t n, int64_t k,
158 const Half alpha,
159 const Half *a, int64_t lda,
160 const Half *b, int64_t ldb,
161 const Half beta,
162 Half *c, int64_t ldc) {
163 normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
164
165 using acc_type = utils::compute_dtype<Half>;
166 gemm_impl(
167 transa, transb,
168 m, n, k,
169 static_cast<const acc_type>(alpha),
170 a, lda,
171 b, ldb,
172 static_cast<const acc_type>(beta),
173 c, ldc);
174 }
175 // clang-format on
176
177 // clang-format off
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const BFloat16 alpha,const BFloat16 * a,int64_t lda,const BFloat16 * b,int64_t ldb,const BFloat16 beta,BFloat16 * c,int64_t ldc)178 void gemm(
179 TransposeType transa, TransposeType transb,
180 int64_t m, int64_t n, int64_t k,
181 const BFloat16 alpha,
182 const BFloat16 *a, int64_t lda,
183 const BFloat16 *b, int64_t ldb,
184 const BFloat16 beta,
185 BFloat16 *c, int64_t ldc) {
186 normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
187
188 using acc_type = utils::compute_dtype<BFloat16>;
189 gemm_impl(
190 transa, transb,
191 m, n, k,
192 static_cast<const acc_type>(alpha),
193 a, lda,
194 b, ldb,
195 static_cast<const acc_type>(beta),
196 c, ldc);
197 }
198 // clang-format on
199
200 } // namespace cpublas
201 } // namespace executorch
202