/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/optimized/blas/BlasKernel.h>

#ifdef __aarch64__
#include <arm_neon.h>
#include <cpuinfo.h>
#endif

using torch::executor::BFloat16;

namespace executorch {
namespace cpublas {
namespace internal {
#ifdef __aarch64__
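// Multiply-accumulate: computes a + b * c per lane, using a fused
// multiply-add when the compiler advertises __ARM_FEATURE_FMA and a
// separate multiply and add otherwise.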
static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
#ifdef __ARM_FEATURE_FMA
  return vfmaq_f32(a, b, c);
#else
  return vaddq_f32(a, vmulq_f32(b, c));
#endif // __ARM_FEATURE_FMA
}

// The below reduce overload and bf16_dot_with_fp32_arith are adapted
// from llama.cpp's ggml_vec_dot_f32 and surrounding utility
// functions. See NOTE [ GGML Copyright Notice ] above for the
// required notice.

// We need the shift for reduce(), hence the extra constants.
static constexpr auto kF32ElementsPerIterationShift = 5;
static constexpr auto kF32ElementsPerIteration = 1
    << kF32ElementsPerIterationShift;
static_assert(kF32ElementsPerIteration == 32);

static constexpr auto kF32ElementsPerRegisterShift = 2;
static constexpr auto kF32ElementsPerRegister = 1
    << kF32ElementsPerRegisterShift;
static_assert(kF32ElementsPerRegister == 4);

static constexpr auto kF32RegisterPairsPerIteration = 4;
static constexpr auto kF32RegistersPerIteration =
    kF32RegisterPairsPerIteration * 2;
static constexpr auto kF32RegistersPerIterationShift = 3;
static_assert(
    kF32RegistersPerIteration ==
    kF32ElementsPerIteration / kF32ElementsPerRegister);
static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);

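// Horizontally sum all kF32RegistersPerIteration accumulators: pairwise
// fold the upper half of the register array into the lower half
// (kF32RegistersPerIterationShift rounds), then reduce the remaining
// vector with vaddvq_f32.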
static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
  int offset = kF32RegistersPerIteration;
  utils::ForcedUnroll<kF32RegistersPerIterationShift>{}(
      [&offset, &x](auto idx) ET_INLINE_ATTRIBUTE {
        offset /= 2;
        for (int i = 0; i < offset; ++i) {
          x[i] = vaddq_f32(x[i], x[offset + i]);
        }
      });
  return vaddvq_f32(x[0]);
}

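// Widen four bf16 values (passed as their raw uint16 bit patterns) to
// float32 by shifting them into the high 16 bits of each 32-bit lane.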
static ET_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
  int32x4_t shift = vdupq_n_s32(16);
  return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
}

static ET_INLINE float32x4_t
f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
  return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
}

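// f32_dot_bf16 requires the BF16 ISA extension, so the target attribute
// compiles it (and its callers) for armv8.2-a+bf16 even when the rest of
// the file is built for a baseline without BF16. Whether this path is
// taken is decided at runtime in bf16_dot_with_fp32_arith below.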
#define ET_TARGET_ARM_BF16_ATTRIBUTE \
  __attribute__((target("arch=armv8.2-a+bf16")))
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE float32x4_t
f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
  return vbfdotq_f32(a, b, c);
}

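// Main-loop body when hardware BFDOT is available: load one register pair
// (8 bf16 elements) from each input and let vbfdotq_f32 accumulate the
// four 2-element dot products into a single float32x4_t accumulator.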
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
dot_with_fp32_arith_main_inner_loop_bfdot(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
  const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
  sum[registerPairIndex] =
      f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
}

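// Fallback main-loop body without BFDOT: widen the low and high halves of
// each 8-element bf16 load to float32 and accumulate them with ordinary
// FMAs into two float32x4_t accumulators per register pair.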
static ET_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));

  sum[2 * registerPairIndex] = f32_fma_bf16(
      sum[2 * registerPairIndex],
      vget_low_u16(temp_vec1),
      vget_low_u16(temp_vec2));
  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
      sum[2 * registerPairIndex + 1],
      vget_high_u16(temp_vec1),
      vget_high_u16(temp_vec2));
}

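// Compile-time dispatch between the BFDOT and non-BFDOT inner loops.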
template <bool useBfdot>
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
dot_with_fp32_arith_main_inner_loop(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  if constexpr (useBfdot) {
    dot_with_fp32_arith_main_inner_loop_bfdot(
        vec1, vec2, sum, registerPairIndex);
  } else {
    dot_with_fp32_arith_main_inner_loop_no_bfdot(
        vec1, vec2, sum, registerPairIndex);
  }
}

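// Vectorized tail: handles four elements at a time once fewer than
// kF32ElementsPerIteration elements remain.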
static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t* tailSum,
    int idx) {
  const auto temp_vec1 =
      vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
  const auto temp_vec2 =
      vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
  *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
}

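// Counterpart to utils::ForcedUnroll (used in reduce() above) that carries
// ET_TARGET_ARM_BF16_ATTRIBUTE on operator() so the unrolled lambda can
// call into the BF16-targeted inner loop.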
namespace {
template <int n>
struct ForcedUnrollTargetBFloat16 {
  template <typename Func>
  ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
    ForcedUnrollTargetBFloat16<n - 1>{}(f);
    f(n - 1);
  }
};

template <>
struct ForcedUnrollTargetBFloat16<1> {
  template <typename Func>
  ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
    f(0);
  }
};

} // namespace

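// Dot product of vec1 and vec2 accumulated in float32. The main loop
// consumes kF32ElementsPerIteration elements per pass, spread across
// kF32RegistersPerIteration accumulators; leftovers are handled by the
// vectorized and scalar tail loops below.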
template <typename T, bool useBFloat16Dot>
ET_TARGET_ARM_BF16_ATTRIBUTE float
dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
  float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
    const auto* vec1_ = vec1 + j;
    const auto* vec2_ = vec2 + j;
    ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}(
        [vec1_, vec2_, &sum](auto k)
            ET_INLINE_ATTRIBUTE ET_TARGET_ARM_BF16_ATTRIBUTE {
              dot_with_fp32_arith_main_inner_loop<useBFloat16Dot>(
                  vec1_, vec2_, sum, k);
            });
  }
  auto reducedSum = reduce(sum);

  // First-tier tail fixup: make sure we handle workloads that can
  // benefit from vectorization, but don't fit into our fully unrolled
  // loop above.
  float32x4_t tailSum = vdupq_n_f32(0);
  const auto len_aligned_4 = len & ~3;
  for (int j = len_aligned; j < len_aligned_4; j += 4) {
    dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
  }
  auto reducedTail = vpaddq_f32(tailSum, tailSum);
  reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);

  // Second-tier tail fixup: handle all workloads.
  for (int j = len_aligned_4; j < len; ++j) {
    reducedSum += vec1[j] * vec2[j];
  }
  return reducedSum;
}

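// Public entry point: use the BFDOT path when cpuinfo reports BF16 support
// on the running CPU, otherwise fall back to FMA-based accumulation.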
float bf16_dot_with_fp32_arith(
    const BFloat16* vec1,
    const BFloat16* vec2,
    int64_t len) {
  if (cpuinfo_has_arm_bf16()) {
    return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
  } else {
    return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
  }
}
#endif // __aarch64__
} // namespace internal
} // namespace cpublas
} // namespace executorch