/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/kernels/optimized/utils/math_utils.h>
#include <executorch/kernels/optimized/utils/unroll.h>

#include <executorch/extension/parallel/thread_parallel.h>
#include <executorch/runtime/core/portable_type/bfloat16.h>

#include <array>

namespace executorch {
namespace cpublas {

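// Scales the m x n column-major matrix `a` (leading dimension `lda`) by
// `alpha` in place. `alpha == 1` is a no-op; `alpha == 0` stores zeros
// explicitly instead of multiplying, so NaN/Inf values already in `a` do not
// propagate into the result.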
template <typename scalar_t, typename opmath_t>
void scale_(int64_t m, int64_t n, opmath_t alpha, scalar_t* a, int64_t lda) {
  if (alpha == opmath_t(1)) {
    return; // identity
  }

  if (alpha == opmath_t(0)) {
    for (size_t j = 0; j < n; ++j) {
      for (size_t i = 0; i < m; ++i) {
        a[j * lda + i] = scalar_t(0);
      }
    }
    return;
  }

  for (size_t j = 0; j < n; ++j) {
    for (size_t i = 0; i < m; ++i) {
      a[j * lda + i] *= alpha;
    }
  }
}

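// Sums f(0) + f(1) + ... + f(N - 1) using `ilp_factor` independent
// accumulators so the partial sums can proceed in parallel at the instruction
// level, then reduces them into a single result.
//
// Illustrative use only (not part of this header): a dot product over two
// float arrays x and y of length n could be written as
//   auto dot = sum(n, [&](int64_t i) { return x[i] * y[i]; });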
template <typename Func>
auto sum(int64_t N, Func f) {
  constexpr int ilp_factor = 4;
  using acc_t = decltype(f(0));

  // Calculate independent partial sums then add together at the end
  std::array<acc_t, ilp_factor> partial_sums{};

  size_t i = 0;
  for (; i + ilp_factor <= N; i += ilp_factor) {
    utils::ForcedUnroll<ilp_factor>{}(
        [&i, &f, &partial_sums](int k) { partial_sums[k] += f(i + k); });
  }
  for (; i < N; ++i) {
    partial_sums[0] += f(i);
  }
  for (int k = 1; k < ilp_factor; ++k) {
    partial_sums[0] += partial_sums[k];
  }
  return partial_sums[0];
}

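// gemm_notrans_: c = beta * c + alpha * (a @ b), all matrices column-major;
// `a` is m x k (leading dimension lda), `b` is k x n (ldb), `c` is m x n
// (ldc). This overload, where scalar_t and opmath_t are the same type, scales
// `c` by beta first and then accumulates with the inner loop unrolled by 4.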
template <typename scalar_t, typename opmath_t>
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c *= beta
  scale_(m, n, beta, c, ldc);

  // c += alpha * (a @ b)
  for (size_t l = 0; l < k; ++l) {
    for (size_t j = 0; j < n; ++j) {
      opmath_t val = b[l + j * ldb] * alpha;
      int64_t i_m = m / 4;
      for (int64_t i_i = 0; i_i < i_m; ++i_i) {
        c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
        c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
        c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
        c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
      }
      int64_t i = i_m * 4;
      for (; i < m; i++) {
        c[j * ldc + i] += a[i + l * lda] * val;
      }
    }
  }
}

// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c = beta * c + alpha * (a @ b), with each dot product accumulated in
  // opmath_t precision
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < n; ++j) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a[l * lda + i]) *
            static_cast<opmath_t>(b[j * ldb + l]);
      });
      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}

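// gemm_transa_: c = beta * c + alpha * (a.T @ b), i.e. `a` is supplied
// transposed (stored k x m with leading dimension lda). Each output element is
// a length-k dot product computed in opmath_t precision via sum().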
// clang-format off
template <typename scalar_t, typename opmath_t>
void gemm_transa_(
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  // c = alpha * (a.T @ b) + beta * c
  const scalar_t *a_ = a;
  for (size_t i = 0; i < m; ++i) {
    const scalar_t *b_ = b;
    for (size_t j = 0; j < n; ++j) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a_[l]) * static_cast<opmath_t>(b_[l]);
      });
      b_ += ldb;
      if (beta == opmath_t(0)) {
        c[j*ldc+i] = alpha*dot;
      } else {
        c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
      }
    }
    a_ += lda;
  }
}

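// On aarch64, the BFloat16 instantiation of gemm_transa_ is specialized below:
// rows of the output are distributed across threads with parallel_for, and
// each length-k dot product is computed by bf16_dot_with_fp32_arith, which
// accumulates in fp32.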
#ifdef __aarch64__
namespace internal {
float bf16_dot_with_fp32_arith(const torch::executor::BFloat16* vec1, const torch::executor::BFloat16* vec2, int64_t len);
} // namespace internal

template <>
inline void gemm_transa_<torch::executor::BFloat16, torch::executor::BFloat16>(
    int64_t m, int64_t n, int64_t k,
    torch::executor::BFloat16 alpha,
    const torch::executor::BFloat16 *a, int64_t lda,
    const torch::executor::BFloat16 *b, int64_t ldb,
    torch::executor::BFloat16 beta,
    torch::executor::BFloat16 *c, int64_t ldc) {
  // c = alpha * (a.T @ b) + beta * c
  if (alpha == 1 && beta == 0) {
    executorch::extension::parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
      const auto *a_ = a + begin * lda;
      for (int i = begin; i < end; ++i) {
        const auto *b_ = b;
        for (int j = 0; j < n; ++j) {
          const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
          b_ += ldb;
          c[j*ldc+i] = dot;
        }
        a_ += lda;
      }
    });
    return;
  }
  executorch::extension::parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
    const auto *a_ = a + begin * lda;
    for (int i = begin; i < end; ++i) {
      const auto *b_ = b;
      for (int j = 0; j < n; ++j) {
        const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
        b_ += ldb;
        if (beta == 0) {
          c[j*ldc+i] = alpha*dot;
        } else {
          c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
        }
      }
      a_ += lda;
    }
  });
}
#endif

// clang-format on

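// gemm_transb_: c = beta * c + alpha * (a @ b.T), i.e. `b` is supplied
// transposed (stored n x k with leading dimension ldb). As with gemm_notrans_,
// the same-type overload scales `c` by beta and then accumulates, while the
// mixed-precision overload below computes each dot product in opmath_t.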
template <typename scalar_t, typename opmath_t>
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_transb_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c *= beta
  scale_(m, n, beta, c, ldc);

  // c += alpha * (a @ b.T)
  for (size_t l = 0; l < k; ++l) {
    for (size_t j = 0; j < n; ++j) {
      opmath_t val = b[j + l * ldb] * alpha;
      int64_t i_m = m / 4;
      for (int64_t i_i = 0; i_i < i_m; ++i_i) {
        c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
        c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
        c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
        c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
      }
      int64_t i = i_m * 4;
      for (; i < m; i++) {
        c[j * ldc + i] += a[i + l * lda] * val;
      }
    }
  }
}

// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_transb_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c = beta * c + alpha * (a @ b.T), with each dot product accumulated in
  // opmath_t precision
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < n; ++j) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a[l * lda + i]) *
            static_cast<opmath_t>(b[l * ldb + j]);
      });
      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}

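// gemm_transab_: c = beta * c + alpha * (a.T @ b.T), with both inputs supplied
// transposed. Each output element is a length-k dot product computed in
// opmath_t precision via sum().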
// clang-format off
template <typename scalar_t, typename opmath_t>
void gemm_transab_(
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  // c = beta * c + alpha * (a.T @ b.T)
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < n; ++j) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a[i * lda + l]) *
            static_cast<opmath_t>(b[l * ldb + j]);
      });

      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}
// clang-format on

} // namespace cpublas
} // namespace executorch