xref: /aosp_15_r20/external/executorch/kernels/optimized/blas/BlasKernel.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/kernels/optimized/utils/math_utils.h>
12 #include <executorch/kernels/optimized/utils/unroll.h>
13 
14 #include <executorch/extension/parallel/thread_parallel.h>
15 #include <executorch/runtime/core/portable_type/bfloat16.h>
16 
17 #include <array>
18 
19 namespace executorch {
20 namespace cpublas {
21 
template <typename scalar_t, typename opmath_t>
void scale_(int64_t m, int64_t n, opmath_t alpha, scalar_t* a, int64_t lda) {
  // Scales the m x n column-major matrix `a` (leading dimension `lda`) in
  // place: a *= alpha. Fast paths: alpha == 1 is a no-op; alpha == 0 stores
  // zeros directly instead of multiplying.
  //
  // NOTE: loop counters are int64_t to match the signedness of m/n. Using an
  // unsigned counter against int64_t bounds triggers sign-compare conversions
  // and would loop (nearly) forever if a caller ever passed a negative extent.
  if (alpha == opmath_t(1)) {
    return; // identity
  }

  if (alpha == opmath_t(0)) {
    for (int64_t j = 0; j < n; ++j) {
      for (int64_t i = 0; i < m; ++i) {
        a[j * lda + i] = scalar_t(0);
      }
    }
    return;
  }

  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      a[j * lda + i] *= alpha;
    }
  }
}
43 
44 template <typename Func>
sum(int64_t N,Func f)45 auto sum(int64_t N, Func f) {
46   constexpr int ilp_factor = 4;
47   using acc_t = decltype(f(0));
48 
49   // Calculate independent partial sums then add together at the end
50   std::array<acc_t, ilp_factor> partial_sums{};
51 
52   size_t i = 0;
53   for (; i + ilp_factor <= N; i += ilp_factor) {
54     utils::ForcedUnroll<ilp_factor>{}(
55         [&i, &f, &partial_sums](int k) { partial_sums[k] += f(i + k); });
56   }
57   for (; i < N; ++i) {
58     partial_sums[0] += f(i);
59   }
60   for (int k = 1; k < ilp_factor; ++k) {
61     partial_sums[0] += partial_sums[k];
62   }
63   return partial_sums[0];
64 }
65 
template <typename scalar_t, typename opmath_t>
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c = beta * c + alpha * (a @ b); all matrices column-major, no transposes.
  // This overload is selected when scalar_t == opmath_t (no widened
  // accumulation type is needed).
  //
  // Loop counters are int64_t to match the signedness of the extents
  // (previously size_t, which mixed signed/unsigned in every comparison).

  // c *= beta
  scale_(m, n, beta, c, ldc);

  // c += alpha * (a @ b). The innermost loop over rows of a/c is manually
  // unrolled by 4 to expose independent multiply-adds.
  for (int64_t l = 0; l < k; ++l) {
    for (int64_t j = 0; j < n; ++j) {
      opmath_t val = b[l + j * ldb] * alpha;
      int64_t i_m = m / 4;
      for (int64_t i_i = 0; i_i < i_m; ++i_i) {
        c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
        c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
        c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
        c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
      }
      // Scalar tail for the remaining m % 4 rows.
      for (int64_t i = i_m * 4; i < m; i++) {
        c[j * ldc + i] += a[i + l * lda] * val;
      }
    }
  }
}
101 
// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c = beta * c + alpha * (a @ b); column-major, no transposes.
  // Selected when scalar_t != opmath_t (e.g. reduced-precision data with a
  // wider math type): each dot product is accumulated in opmath_t and only
  // the final result is narrowed into c.
  //
  // Loop counters are int64_t to match the extents' signedness (previously
  // size_t, a signed/unsigned mismatch in every bound comparison).
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a[l * lda + i]) *
            static_cast<opmath_t>(b[j * ldb + l]);
      });
      // When beta == 0 the old contents of c are never read, so garbage in an
      // uninitialized output buffer cannot leak into the result.
      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}
132 
133 // clang-format off
template <typename scalar_t, typename opmath_t>
void gemm_transa_(
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  // c = alpha * (a.T @ b) + beta * c
  // Because `a` is used transposed, the k values feeding output row i are
  // contiguous (a_[0..k)), as are the k values of column j of b (b_[0..k)),
  // so both operands are walked with running pointers.
  //
  // Loop counters are int64_t to match the extents (previously size_t:
  // signed/unsigned mismatch against int64_t bounds).
  const scalar_t *a_ = a;
  for (int64_t i = 0; i < m; ++i) {
    const scalar_t *b_ = b;
    for (int64_t j = 0; j < n; ++j) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a_[l]) * static_cast<opmath_t>(b_[l]);
      });
      b_ += ldb;
      // beta == 0: do not read (possibly uninitialized) c.
      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
    a_ += lda;
  }
}
160 
161 #ifdef __aarch64__
162 namespace internal {
163 float bf16_dot_with_fp32_arith(const torch::executor::BFloat16* vec1, const torch::executor::BFloat16* vec2, int64_t len);
164 } // namespace internal
165 
166 template <>
167 inline void gemm_transa_<torch::executor::BFloat16, torch::executor::BFloat16>(
168     int64_t m, int64_t n, int64_t k,
169     torch::executor::BFloat16 alpha,
170     const torch::executor::BFloat16 *a, int64_t lda,
171     const torch::executor::BFloat16 *b, int64_t ldb,
172     torch::executor::BFloat16 beta,
173     torch::executor::BFloat16 *c, int64_t ldc) {
174   // c = alpha * (a.T @ b) + beta * c
175   if (alpha == 1 && beta == 0) {
176     executorch::extension::parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
177       const auto *a_ = a + begin * lda;
178       for (int i = begin; i < end; ++i) {
179         const auto *b_ = b;
180         for (int j = 0; j < n; ++j) {
181           const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
182           b_ += ldb;
183           c[j*ldc+i] = dot;
184         }
185         a_ += lda;
186       }
187     });
188     return;
189   }
190   executorch::extension::parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
191     const auto *a_ = a + begin * lda;
192     for (int i = begin; i < end; ++i) {
193       const auto *b_ = b;
194       for (int j = 0; j < n; ++j) {
195         const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
196         b_ += ldb;
197         if (beta == 0) {
198           c[j*ldc+i] = alpha*dot;
199         } else {
200           c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
201         }
202       }
203       a_ += lda;
204     }
205   });
206 }
207 #endif
208 
209 // clang-format on
210 
template <typename scalar_t, typename opmath_t>
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_transb_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c = beta * c + alpha * (a @ b.T); column-major, only b transposed.
  // Selected when scalar_t == opmath_t (no widened accumulation type needed).
  //
  // Loop counters are int64_t to match the extents' signedness (previously
  // size_t, a signed/unsigned mismatch in every bound comparison).

  // c *= beta
  scale_(m, n, beta, c, ldc);

  // c += alpha * (a @ b.T); b is indexed as b[j + l * ldb] because it is
  // used transposed. Inner loop over rows of a/c is unrolled by 4 to expose
  // independent multiply-adds.
  for (int64_t l = 0; l < k; ++l) {
    for (int64_t j = 0; j < n; ++j) {
      opmath_t val = b[j + l * ldb] * alpha;
      int64_t i_m = m / 4;
      for (int64_t i_i = 0; i_i < i_m; ++i_i) {
        c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
        c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
        c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
        c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
      }
      // Scalar tail for the remaining m % 4 rows.
      for (int64_t i = i_m * 4; i < m; i++) {
        c[j * ldc + i] += a[i + l * lda] * val;
      }
    }
  }
}
246 
// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_transb_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c = beta * c + alpha * (a @ b.T); column-major, only b transposed.
  // Selected when scalar_t != opmath_t: each dot product is accumulated in
  // opmath_t and only the final result is narrowed into c.
  //
  // Loop counters are int64_t to match the extents' signedness (previously
  // size_t, a signed/unsigned mismatch in every bound comparison).
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a[l * lda + i]) *
            static_cast<opmath_t>(b[l * ldb + j]);
      });
      // beta == 0: do not read (possibly uninitialized) c.
      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}
277 
278 // clang-format off
template <typename scalar_t, typename opmath_t>
void gemm_transab_(
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  // c = beta * c + alpha * (a.T @ b.T); both operands used transposed, so
  // a is indexed a[i * lda + l] and b is indexed b[l * ldb + j]. Each dot
  // product is accumulated via sum() in opmath_t.
  //
  // Loop counters are int64_t to match the extents' signedness (previously
  // size_t, a signed/unsigned mismatch in every bound comparison).
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a[i * lda + l]) *
            static_cast<opmath_t>(b[l * ldb + j]);
      });

      // beta == 0: do not read (possibly uninitialized) c.
      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}
303 // clang-format on
304 
305 } // namespace cpublas
306 } // namespace executorch
307