/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/optimized/blas/BlasKernel.h>

#ifdef __aarch64__
#include <arm_neon.h>
#include <cpuinfo.h>
#endif

using torch::executor::BFloat16;

namespace executorch {
namespace cpublas {
namespace internal {
#ifdef __aarch64__
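// Fused multiply-add on float32x4_t: returns a + b * c per lane, using the
// hardware FMA instruction when available and a separate multiply/add
// otherwise.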
static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
#ifdef __ARM_FEATURE_FMA
  return vfmaq_f32(a, b, c);
#else
  return vaddq_f32(a, vmulq_f32(b, c));
#endif // __ARM_FEATURE_FMA
}

// The reduce() and dot_with_fp32_arith functions below are adapted
// from llama.cpp's ggml_vec_dot_f32 and surrounding utility
// functions. See NOTE [ GGML Copyright Notice ] above for the
// required notice.

// We need the shift for reduce(), hence the extra constants.
static constexpr auto kF32ElementsPerIterationShift = 5;
static constexpr auto kF32ElementsPerIteration = 1
    << kF32ElementsPerIterationShift;
static_assert(kF32ElementsPerIteration == 32);

static constexpr auto kF32ElementsPerRegisterShift = 2;
static constexpr auto kF32ElementsPerRegister = 1
    << kF32ElementsPerRegisterShift;
static_assert(kF32ElementsPerRegister == 4);

static constexpr auto kF32RegisterPairsPerIteration = 4;
static constexpr auto kF32RegistersPerIteration =
    kF32RegisterPairsPerIteration * 2;
static constexpr auto kF32RegistersPerIterationShift = 3;
static_assert(
    kF32RegistersPerIteration ==
    kF32ElementsPerIteration / kF32ElementsPerRegister);
static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);

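// Sums the kF32RegistersPerIteration partial-sum registers into a single
// scalar: halve the number of live registers log2(n) times via pairwise
// vector adds, then horizontally add the lanes of the surviving register.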
static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
  int offset = kF32RegistersPerIteration;
  utils::ForcedUnroll<kF32RegistersPerIterationShift>{}(
      [&offset, &x](auto idx) ET_INLINE_ATTRIBUTE {
        offset /= 2;
        for (int i = 0; i < offset; ++i) {
          x[i] = vaddq_f32(x[i], x[offset + i]);
        }
      });
  return vaddvq_f32(x[0]);
}

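// Widens 4 bfloat16 values (passed as their raw uint16 bit patterns) to
// float32 by shifting each into the upper 16 bits of a 32-bit lane; a
// bfloat16 is exactly the high half of the corresponding IEEE-754 float32.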
static ET_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
  int32x4_t shift = vdupq_n_s32(16);
  return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
}

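// FMA with bfloat16 operands: widens b and c to float32 and accumulates
// their product into a.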
static ET_INLINE float32x4_t
f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
  return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
}

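// ET_TARGET_ARM_BF16_ATTRIBUTE compiles a function for armv8.2-a+bf16 so it
// can use the BFDOT intrinsic even though the translation unit's baseline
// target does not include BF16. The BFDOT path is only taken when
// bf16_dot_with_fp32_arith() detects hardware support at runtime.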
#define ET_TARGET_ARM_BF16_ATTRIBUTE \
  __attribute__((target("arch=armv8.2-a+bf16")))
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE float32x4_t
f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
  return vbfdotq_f32(a, b, c);
}

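// BFDOT inner loop body: loads 8 bfloat16 elements from each input and
// accumulates four float32 partial dot products (one per lane, each over a
// pair of bf16 elements) with a single vbfdotq_f32.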
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
dot_with_fp32_arith_main_inner_loop_bfdot(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
  const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
  sum[registerPairIndex] =
      f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
}

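// Fallback inner loop body for CPUs without BFDOT: loads the same 8 bf16
// elements per input as raw uint16s and performs two widening FMAs (low and
// high halves) into a pair of adjacent accumulators.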
static ET_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));

  sum[2 * registerPairIndex] = f32_fma_bf16(
      sum[2 * registerPairIndex],
      vget_low_u16(temp_vec1),
      vget_low_u16(temp_vec2));
  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
      sum[2 * registerPairIndex + 1],
      vget_high_u16(temp_vec1),
      vget_high_u16(temp_vec2));
}

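// Selects between the BFDOT and fallback inner loops at compile time.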
template <bool useBfdot>
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
dot_with_fp32_arith_main_inner_loop(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  if constexpr (useBfdot) {
    dot_with_fp32_arith_main_inner_loop_bfdot(
        vec1, vec2, sum, registerPairIndex);
  } else {
    dot_with_fp32_arith_main_inner_loop_no_bfdot(
        vec1, vec2, sum, registerPairIndex);
  }
}

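// Tail helper: accumulates one 4-element bf16 chunk (starting at idx) into a
// single float32x4_t partial sum.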
static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t* tailSum,
    int idx) {
  const auto temp_vec1 =
      vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
  const auto temp_vec2 =
      vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
  *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
}

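// Compile-time loop unroller mirroring utils::ForcedUnroll, but with
// ET_TARGET_ARM_BF16_ATTRIBUTE applied so the unrolled callable can be
// inlined into functions compiled for armv8.2-a+bf16.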
namespace {
template <int n>
struct ForcedUnrollTargetBFloat16 {
  template <typename Func>
  ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
    ForcedUnrollTargetBFloat16<n - 1>{}(f);
    f(n - 1);
  }
};

template <>
struct ForcedUnrollTargetBFloat16<1> {
  template <typename Func>
  ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
    f(0);
  }
};

} // namespace

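// Core dot-product kernel, accumulated in fp32: the main loop consumes
// kF32ElementsPerIteration (32) elements per pass across eight accumulator
// registers; a first-tier tail loop handles remaining 4-element chunks and a
// final scalar loop handles whatever is left.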
template <typename T, bool useBFloat16Dot>
ET_TARGET_ARM_BF16_ATTRIBUTE float
dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
  float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
    const auto* vec1_ = vec1 + j;
    const auto* vec2_ = vec2 + j;
    ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}(
        [vec1_, vec2_, &sum](auto k)
            ET_INLINE_ATTRIBUTE ET_TARGET_ARM_BF16_ATTRIBUTE {
              dot_with_fp32_arith_main_inner_loop<useBFloat16Dot>(
                  vec1_, vec2_, sum, k);
            });
  }
  auto reducedSum = reduce(sum);

  // First-tier tail fixup: make sure we handle workloads that can
  // benefit from vectorization, but don't fit into our fully unrolled
  // loop above.
  float32x4_t tailSum = vdupq_n_f32(0);
  const auto len_aligned_4 = len & ~3;
  for (int j = len_aligned; j < len_aligned_4; j += 4) {
    dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
  }
  auto reducedTail = vpaddq_f32(tailSum, tailSum);
  reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);

  // Second-tier tail fixup: handle all workloads.
  for (int j = len_aligned_4; j < len; ++j) {
    reducedSum += vec1[j] * vec2[j];
  }
  return reducedSum;
}

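// Public entry point: computes the dot product of two length-`len` bfloat16
// vectors, accumulating in fp32. Dispatches to the BFDOT kernel when cpuinfo
// reports Arm BF16 support at runtime, and to the widening-FMA fallback
// otherwise.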
float bf16_dot_with_fp32_arith(
    const BFloat16* vec1,
    const BFloat16* vec2,
    int64_t len) {
  if (cpuinfo_has_arm_bf16()) {
    return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
  } else {
    return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
  }
}
#endif // __aarch64__
} // namespace internal
} // namespace cpublas
} // namespace executorch