xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Fast Gemv (i.e. matrix*vector multiplication) paths.
17 // TODO(b/132094390): remove when GEMM performance is good enough on GEMV cases.
18 
19 // TFLite's runtime ops concentrate the matrix*vector use cases as much as
20 // possible on the (matrix) * (column-vector) case, as opposed to
21 // (row-vector) * (matrix).  So that is what we focus on optimizing here.
22 // Accordingly, the public cpu_backend_gemm::Gemm() entry point checks
23 // if we are in this (matrix) * (column-vector) case, and if so calls
24 // CustomGemv.
25 //
26 // cpu_backend_gemm::Gemm is also currently restricted (as enforced in
27 // ValidateParams) to the case where the left-hand side matrix is row-major.
28 //
29 // So the current scope of this CustomGemv function really is:
30 // (row-major matrix) * (column-vector).
31 
32 #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_CUSTOM_GEMV_H_
33 #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_CUSTOM_GEMV_H_
34 
35 #include <stdint.h>
36 
37 #include <algorithm>
38 #include <type_traits>
39 #include <vector>
40 
41 #include "ruy/profiler/instrumentation.h"  // from @ruy
42 #include "tensorflow/lite/kernels/cpu_backend_context.h"
43 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
44 #include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
45 #include "tensorflow/lite/kernels/internal/common.h"
46 #include "tensorflow/lite/kernels/internal/compatibility.h"
47 #include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
48 
49 namespace tflite {
50 namespace cpu_backend_gemm {
51 namespace detail {
52 
53 // CustomGemvImpl is what needs to be specialized for each custom GEMV path.
54 //
55 // It does not deal with any multi-threaded implementation detail. Rather,
56 // it provides the single-thread implementation to be run by each thread.
57 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
58           typename DstScalar, QuantizationFlavor quantization_flavor>
59 struct CustomGemvImpl {
60   // The number of rows of the left-hand-side matrix (and equivalently of the
61   // destination column-vector) that the kernel processes at a time.
62   // This will also be the minimum required number of rows for a Gemv shape
63   // to be supported by this path.
64   //
65   // Gemv implementations are expected to be able to deal with numbers of
66   // rows that aren't multiples of kKernelRows by possibly running the kernel
67   // again at an odd row_start, e.g. if kKernelRows==4, Run() should still
68   // support running on 7 rows by running twice: once with row_start=0 and then
69   // another time with row_start=3.
70   //
71   // On the other hand, gemv implementations are not expected to support
72   // running on fewer than kKernelRows rows. There is no interest in
73   // optimizing Gemv's so narrow that they are just a few dot-products.
74   // Supporting that would require custom kernel code only for that case.
75   static constexpr int kKernelRows = 1;
76 
77   // Returns true if the Gemv shape is supported by Run(), provided that
78   // (row_end - row_start) >= kKernelRows.
79   static bool IsSupportedGivenSufficientlyManyRows(
80       const MatrixParams<LhsScalar>& lhs_params,
81       const MatrixParams<RhsScalar>& rhs_params,
82       const MatrixParams<DstScalar>& dst_params,
83       const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
84     return false;
85   }
86 
87   // Performs the Gemv.
88   static void Run(
89       const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
90       const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
91       const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
92       const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
93       int row_start, int row_end) {}
94 };
95 
96 // Wraps CustomGemvImpl for multi-threaded operation.
97 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
98           typename DstScalar, QuantizationFlavor quantization_flavor>
99 class CustomGemvTask : public cpu_backend_threadpool::Task {
100  public:
101   CustomGemvTask(
102       const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
103       const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
104       const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
105       const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
106       int row_start, int row_end)
107       : lhs_params_(lhs_params),
108         lhs_data_(lhs_data),
109         rhs_params_(rhs_params),
110         rhs_data_(rhs_data),
111         dst_params_(dst_params),
112         dst_data_(dst_data),
113         params_(params),
114         row_start_(row_start),
115         row_end_(row_end) {}
116 
117   void Run() override {
118     using Impl = CustomGemvImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
119                                 quantization_flavor>;
120     Impl::Run(lhs_params_, lhs_data_, rhs_params_, rhs_data_, dst_params_,
121               dst_data_, params_, row_start_, row_end_);
122   }
123 
124  private:
125   const MatrixParams<LhsScalar>& lhs_params_;
126   const LhsScalar* lhs_data_;
127   const MatrixParams<RhsScalar>& rhs_params_;
128   const RhsScalar* rhs_data_;
129   const MatrixParams<DstScalar>& dst_params_;
130   DstScalar* dst_data_;
131   const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params_;
132   int row_start_;
133   int row_end_;
134 };
135 
136 // Either performs the requested Gemv operation and returns true,
137 // or immediately returns false.
138 //
139 // See the comment at the top of the file for the scope of what this handles.
140 // In summary: (row-major matrix) * (column-vector).
141 //
142 // Only the high-level logic lives here; the actual implementation details
143 // are in the specializations of CustomGemvImpl below. An illustrative usage
144 // sketch follows the function body.
145 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
146           typename DstScalar, QuantizationFlavor quantization_flavor>
147 bool CustomGemv(
148     const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
149     const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
150     const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
151     const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
152     CpuBackendContext* context) {
153   ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm: CustomGemv");
154   using Impl = CustomGemvImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
155                               quantization_flavor>;
156   if (lhs_params.rows < Impl::kKernelRows) {
157     return false;
158   }
159   if (!Impl::IsSupportedGivenSufficientlyManyRows(lhs_params, rhs_params,
160                                                   dst_params, params)) {
161     return false;
162   }
163   TFLITE_DCHECK_GE(lhs_params.rows, Impl::kKernelRows);
164   int thread_count = LegacyHowManyThreads<Impl::kKernelRows>(
165       context->max_num_threads(), dst_params.rows, dst_params.cols,
166       lhs_params.cols);
167   if (thread_count == 1) {
168     Impl::Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
169               params, 0, lhs_params.rows);
170   } else {
171     using Task = CustomGemvTask<LhsScalar, RhsScalar, AccumScalar, DstScalar,
172                                 quantization_flavor>;
173     std::vector<Task> tasks;
174     tasks.reserve(thread_count);
175     const int kRowsPerThread =
176         RoundUp<Impl::kKernelRows>(CeilQuotient(dst_params.rows, thread_count));
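    // Illustrative arithmetic (hypothetical numbers, not taken from a real
    // run): with dst_params.rows == 1000, thread_count == 4 and
    // Impl::kKernelRows == 4, CeilQuotient(1000, 4) == 250 and
    // RoundUp<4>(250) == 252, so the tasks below get the row ranges
    // [0, 252), [252, 504), [504, 756) and [756, 1000); in this example only
    // the last range is shorter, and it still spans at least kKernelRows rows.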
177     int row_start = 0;
178     for (int i = 0; i < thread_count; i++) {
179       int row_end = std::min(dst_params.rows, row_start + kRowsPerThread);
180       tasks.emplace_back(lhs_params, lhs_data, rhs_params, rhs_data, dst_params,
181                          dst_data, params, row_start, row_end);
182       row_start = row_end;
183     }
184     cpu_backend_threadpool::Execute(tasks.size(), tasks.data(), context);
185   }
186   return true;
187 }
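
// Illustrative usage sketch. This is an assumption about how a caller such as
// cpu_backend_gemm::Gemm() might reach this entry point, with hypothetical
// shapes and the float path; it is not code taken from this header:
//
//   MatrixParams<float> lhs_params;      // M x K row-major filter matrix
//   lhs_params.order = Order::kRowMajor;
//   lhs_params.rows = 64;
//   lhs_params.cols = 128;
//   MatrixParams<float> rhs_params;      // K x 1 column vector
//   rhs_params.order = Order::kColMajor;
//   rhs_params.rows = 128;
//   rhs_params.cols = 1;
//   MatrixParams<float> dst_params;      // M x 1 column vector
//   dst_params.order = Order::kColMajor;
//   dst_params.rows = 64;
//   dst_params.cols = 1;
//   GemmParams<float, float> gemm_params;  // bias/clamp left at their defaults
//   bool handled = detail::CustomGemv(lhs_params, lhs_data, rhs_params,
//                                     rhs_data, dst_params, dst_data,
//                                     gemm_params, context);
//   if (!handled) { /* fall back to the general Gemm path */ }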
188 
189 // USE_NEON may still be defined on x86, where we may be using the arm_neon_sse.h
190 // wrapper that implements NEON intrinsics on top of SSE4 intrinsics.
191 #ifdef USE_NEON
192 
193 // Some NEON helper functions used by CustomGemvImpl specializations below,
194 // allowing for some type genericity in them.
195 
196 inline int16x8x2_t Load16AndSubtractZeroPoint(const std::uint8_t* src,
197                                               std::uint8_t zero_point) {
198   uint8x16_t src_u8 = vld1q_u8(src);
199   int16x8_t src_s16_0 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(src_u8)));
200   int16x8_t src_s16_1 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(src_u8)));
201   int16x8x2_t result;
202   int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
203   result.val[0] = vsubq_s16(src_s16_0, zero_point_vec);
204   result.val[1] = vsubq_s16(src_s16_1, zero_point_vec);
205   return result;
206 }
207 
208 inline int16x8x2_t Load16AndSubtractZeroPoint(const std::int8_t* src,
209                                               std::int8_t zero_point) {
210   int8x16_t src_s8 = vld1q_s8(src);
211   int16x8_t src_s16_0 = vmovl_s8(vget_low_s8(src_s8));
212   int16x8_t src_s16_1 = vmovl_s8(vget_high_s8(src_s8));
213   int16x8x2_t result;
214   int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
215   result.val[0] = vsubq_s16(src_s16_0, zero_point_vec);
216   result.val[1] = vsubq_s16(src_s16_1, zero_point_vec);
217   return result;
218 }
219 
220 inline int16x8_t Load8AndSubtractZeroPoint(const std::uint8_t* src,
221                                            std::uint8_t zero_point) {
222   uint8x8_t src_u8 = vld1_u8(src);
223   int16x8_t src_s16 = vreinterpretq_s16_u16(vmovl_u8(src_u8));
224   int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
225   return vsubq_s16(src_s16, zero_point_vec);
226 }
227 
228 inline int16x8_t Load8AndSubtractZeroPoint(const std::int8_t* src,
229                                            std::int8_t zero_point) {
230   int8x8_t src_s8 = vld1_s8(src);
231   int16x8_t src_s16 = vmovl_s8(src_s8);
232   int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
233   return vsubq_s16(src_s16, zero_point_vec);
234 }
235 
236 inline void ClampAndStore(int32x4_t src, std::uint8_t clamp_min,
237                           std::uint8_t clamp_max, std::uint8_t* dst) {
238   // Narrow values down to 16 bit signed.
239   const int16x4_t res16 = vqmovn_s32(src);
240   // Narrow values down to 8 bit unsigned, saturating.
241   uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16));
242   // Apply the clamping from the activation function
243   res8 = vmax_u8(res8, vdup_n_u8(clamp_min));
244   res8 = vmin_u8(res8, vdup_n_u8(clamp_max));
245   // Store results to destination.
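  // (Presumably the reason for the four per-lane stores here and in the other
  // ClampAndStore overloads is that only 4 output values exist per call, so a
  // full 8-lane vector store could write past the end of the destination.)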
246   vst1_lane_u8(dst + 0, res8, 0);
247   vst1_lane_u8(dst + 1, res8, 1);
248   vst1_lane_u8(dst + 2, res8, 2);
249   vst1_lane_u8(dst + 3, res8, 3);
250 }
251 
252 inline void ClampAndStore(int32x4_t src, std::int8_t clamp_min,
253                           std::int8_t clamp_max, std::int8_t* dst) {
254   // Narrow values down to 16 bit signed.
255   const int16x4_t res16 = vqmovn_s32(src);
256   // Narrow values down to 8 bit unsigned, saturating.
257   int8x8_t res8 = vqmovn_s16(vcombine_s16(res16, res16));
258   // Apply the clamping from the activation function
259   res8 = vmax_s8(res8, vdup_n_s8(clamp_min));
260   res8 = vmin_s8(res8, vdup_n_s8(clamp_max));
261   // Store results to destination.
262   vst1_lane_s8(dst + 0, res8, 0);
263   vst1_lane_s8(dst + 1, res8, 1);
264   vst1_lane_s8(dst + 2, res8, 2);
265   vst1_lane_s8(dst + 3, res8, 3);
266 }
267 
268 inline void ClampAndStore(int32x4_t src, std::int16_t clamp_min,
269                           std::int16_t clamp_max, std::int16_t* dst) {
270   // Narrow values down to 16 bit signed.
271   int16x4_t res16 = vqmovn_s32(src);
272   // Apply the clamping from the activation function
273   res16 = vmax_s16(res16, vdup_n_s16(clamp_min));
274   res16 = vmin_s16(res16, vdup_n_s16(clamp_max));
275   // Store results to destination.
276   vst1_lane_s16(dst + 0, res16, 0);
277   vst1_lane_s16(dst + 1, res16, 1);
278   vst1_lane_s16(dst + 2, res16, 2);
279   vst1_lane_s16(dst + 3, res16, 3);
280 }
281 
282 template <typename LhsScalar, typename RhsScalar, typename DstScalar,
283           QuantizationFlavor quantization_flavor>
284 struct CustomGemvImpl<LhsScalar, RhsScalar, std::int32_t, DstScalar,
285                       quantization_flavor> {
286   // This partial template specialization is less generic than its declaration
287   // implies: it assumes the following constraints on its free template
288   // parameters. We guard these assumptions in the following static_assert's.
289   static_assert(std::is_same<LhsScalar, std::uint8_t>::value ||
290                     std::is_same<LhsScalar, std::int8_t>::value,
291                 "");
292   static_assert(std::is_same<RhsScalar, std::uint8_t>::value ||
293                     std::is_same<RhsScalar, std::int8_t>::value,
294                 "");
295   static_assert(std::is_same<DstScalar, std::uint8_t>::value ||
296                     std::is_same<DstScalar, std::int8_t>::value ||
297                     std::is_same<DstScalar, std::int16_t>::value,
298                 "");
299   static_assert(quantization_flavor ==
300                         QuantizationFlavor::kIntegerWithUniformMultiplier ||
301                     quantization_flavor ==
302                         QuantizationFlavor::kIntegerWithPerRowMultiplier,
303                 "");
304 
305   // This implementation's inner loop processes 4 rows of the left-hand side
306   // matrix at a time.
307   static constexpr int kKernelRows = 4;
308 
309   static bool IsSupportedGivenSufficientlyManyRows(
310       const MatrixParams<LhsScalar>& lhs_params,
311       const MatrixParams<RhsScalar>& rhs_params,
312       const MatrixParams<DstScalar>& dst_params,
313       const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params) {
314     // The kernel processes at least 8 LHS columns at once to fill NEON
315     // registers. The leftovers-handling code at the end works by loading a
316     // partially overlapping final register by walking back by a few (<8) values
317     // to avoid running past the row's end. This relies on there being
318     // at least 8 LHS columns.
319     return lhs_params.cols >= 8;
320   }
321 
322   static void Run(
323       const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
324       const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
325       const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
326       const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params,
327       int row_start, int row_end) {
328     // Handle kKernelRows ( == 4) rows of the left-hand side matrix at each
329     // iteration of this for loop.
330     TFLITE_DCHECK_GE(row_end - row_start, kKernelRows);
331     for (int row = row_start; row < row_end; row += kKernelRows) {
332       // Here is the magic where we allow this kernel to handle any number of
333       // rows as long as it's >= kKernelRows: the last group of `kKernelRows`
334       // rows is nudged to fit, possibly by starting at a `row` value that is
335       // not a multiple of kKernelRows, recomputing a few rows.
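      // Worked example (illustrative): with row_start == 0, row_end == 7 and
      // kKernelRows == 4, the first iteration runs with row == 0 (rows 0..3),
      // and the second nudges row from 4 down to 3 (rows 3..6), so row 3 is
      // simply computed twice.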
336       row = std::min(row, row_end - kKernelRows);
337       const LhsScalar* filter_ptr = lhs_data + row * lhs_params.cols;
338 
339       static constexpr int kCacheLineSize = 64;
340       for (int k = 0; k < rhs_params.rows;
341            k += kCacheLineSize / sizeof(RhsScalar)) {
342         optimized_ops_preload_l1_keep(rhs_data + k);
343       }
344 
345       // kPreloadAhead is empirically determined.
346       // End-to-end latency (ms) on mobilenet_v2_0.35_96_8bit, 1 thread,
347       // on Qualcomm S855:
348       //
349       // kPreloadAhead | big core | little core
350       // --------------+----------+------------
351       // 64            | 1.26     | 5.45
352       // 128           | 1.23     | 5.01
353       // 256           | 1.18     | 4.9
354       // 512           | 1.18     | 5.45
355       // 1024          | 1.18     | 6.5
356       // no prefetch   | 1.25     | 8.1
357       static constexpr int kPreloadAhead = 256;
358 
359       // 4 accumulator registers, one for each row being processed.
360       // Each has 4 int32 lanes that correspond to columns modulo 4, and
361       // will need to be horizontally reduced at the end.
362       int32x4_t acc0 = vdupq_n_s32(0);
363       int32x4_t acc1 = acc0;
364       int32x4_t acc2 = acc0;
365       int32x4_t acc3 = acc0;
366       int in = 0;
367       // As much as possible, handle 16 columns of the left-hand side matrix
368       // at a time. This allows for a decent NEON implementation.
369       for (; in <= lhs_params.cols - 16; in += 16) {
370         const LhsScalar* local_filter_ptr = filter_ptr;
371         int16x8x2_t input_val =
372             Load16AndSubtractZeroPoint(rhs_data + in, rhs_params.zero_point);
373         int16x8x2_t filter_val_0 =
374             Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
375         optimized_ops_preload_l1_stream(local_filter_ptr +
376                                         kPreloadAhead / sizeof(LhsScalar));
377         local_filter_ptr += lhs_params.cols;
378         int16x8x2_t filter_val_1 =
379             Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
380         optimized_ops_preload_l1_stream(local_filter_ptr +
381                                         kPreloadAhead / sizeof(LhsScalar));
382         local_filter_ptr += lhs_params.cols;
383         int16x8x2_t filter_val_2 =
384             Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
385         optimized_ops_preload_l1_stream(local_filter_ptr +
386                                         kPreloadAhead / sizeof(LhsScalar));
387         local_filter_ptr += lhs_params.cols;
388         int16x8x2_t filter_val_3 =
389             Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
390         optimized_ops_preload_l1_stream(local_filter_ptr +
391                                         kPreloadAhead / sizeof(LhsScalar));
392         filter_ptr += 16;
393         acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0.val[0]),
394                          vget_low_s16(input_val.val[0]));
395         acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1.val[0]),
396                          vget_low_s16(input_val.val[0]));
397         acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2.val[0]),
398                          vget_low_s16(input_val.val[0]));
399         acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3.val[0]),
400                          vget_low_s16(input_val.val[0]));
401         acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0.val[1]),
402                          vget_low_s16(input_val.val[1]));
403         acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1.val[1]),
404                          vget_low_s16(input_val.val[1]));
405         acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2.val[1]),
406                          vget_low_s16(input_val.val[1]));
407         acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3.val[1]),
408                          vget_low_s16(input_val.val[1]));
409         acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0.val[0]),
410                          vget_high_s16(input_val.val[0]));
411         acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1.val[0]),
412                          vget_high_s16(input_val.val[0]));
413         acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2.val[0]),
414                          vget_high_s16(input_val.val[0]));
415         acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3.val[0]),
416                          vget_high_s16(input_val.val[0]));
417         acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0.val[1]),
418                          vget_high_s16(input_val.val[1]));
419         acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1.val[1]),
420                          vget_high_s16(input_val.val[1]));
421         acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2.val[1]),
422                          vget_high_s16(input_val.val[1]));
423         acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3.val[1]),
424                          vget_high_s16(input_val.val[1]));
425       }
426       // Less than 16 values remain. Try to handle 8 more.
427       if (in <= lhs_params.cols - 8) {
428         int16x8_t input_val =
429             Load8AndSubtractZeroPoint(rhs_data + in, rhs_params.zero_point);
430         int16x8_t filter_val_0 = Load8AndSubtractZeroPoint(
431             filter_ptr + 0 * lhs_params.cols, lhs_params.zero_point);
432         int16x8_t filter_val_1 = Load8AndSubtractZeroPoint(
433             filter_ptr + 1 * lhs_params.cols, lhs_params.zero_point);
434         int16x8_t filter_val_2 = Load8AndSubtractZeroPoint(
435             filter_ptr + 2 * lhs_params.cols, lhs_params.zero_point);
436         int16x8_t filter_val_3 = Load8AndSubtractZeroPoint(
437             filter_ptr + 3 * lhs_params.cols, lhs_params.zero_point);
438         filter_ptr += 8;
439         acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0),
440                          vget_low_s16(input_val));
441         acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1),
442                          vget_low_s16(input_val));
443         acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2),
444                          vget_low_s16(input_val));
445         acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3),
446                          vget_low_s16(input_val));
447         acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
448                          vget_high_s16(input_val));
449         acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
450                          vget_high_s16(input_val));
451         acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
452                          vget_high_s16(input_val));
453         acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
454                          vget_high_s16(input_val));
455         in += 8;
456       }
457       // Less than 8 values remain. Handle the remaining values
458       // in one more copy of the above code handling 8, where we
459       // walk back a few values to be able to load 8 values without
460       // overrunning the buffer. This is where we make use of the requirement
461       // (see IsSupportedGivenSufficientlyManyRows) that there are at least
462       // 8 LHS columns.
463       if (in < lhs_params.cols) {
464         // `back` is how many entries to walk back by.
465         // Its value is necessarily between 1 and 7.
466         const int back = in + 8 - lhs_params.cols;
467         TFLITE_DCHECK_GE(back, 1);
468         TFLITE_DCHECK_LE(back, 7);
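        // Worked example (illustrative): with lhs_params.cols == 13, the loops
        // above stop at in == 8, so back == 3. The loads below read columns
        // 5..12; the first 3 lanes (columns 5..7, already accumulated above)
        // get zeroed out, so only the genuinely remaining columns 8..12
        // contribute.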
469         // Load 8 values as usual.
470         int16x8_t input_val = Load8AndSubtractZeroPoint(
471             rhs_data + lhs_params.cols - 8, rhs_params.zero_point);
472         const LhsScalar* local_filter_ptr = filter_ptr - back;
473         filter_ptr += lhs_params.cols - in;
474         int16x8_t filter_val_0 =
475             Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
476         local_filter_ptr += lhs_params.cols;
477         int16x8_t filter_val_1 =
478             Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
479         local_filter_ptr += lhs_params.cols;
480         int16x8_t filter_val_2 =
481             Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
482         local_filter_ptr += lhs_params.cols;
483         int16x8_t filter_val_3 =
484             Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
485         // Now zero out the first `back` entries of input_val.
486         // vsetq_lane_s16 takes a literal index, so we need unrolled code.
487         switch (back) {
488           case 7:
489             input_val = vsetq_lane_s16(0, input_val, 6);
490             [[clang::fallthrough]];
491           case 6:
492             input_val = vsetq_lane_s16(0, input_val, 5);
493             [[clang::fallthrough]];
494           case 5:
495             input_val = vsetq_lane_s16(0, input_val, 4);
496             [[clang::fallthrough]];
497           case 4:
498             input_val = vsetq_lane_s16(0, input_val, 3);
499             [[clang::fallthrough]];
500           case 3:
501             input_val = vsetq_lane_s16(0, input_val, 2);
502             [[clang::fallthrough]];
503           case 2:
504             input_val = vsetq_lane_s16(0, input_val, 1);
505             [[clang::fallthrough]];
506           default:
507             input_val = vsetq_lane_s16(0, input_val, 0);
508         }
509         // Multiply-accumulate 8 values as usual. The first `back` lanes
510         // of filter_val_* are junk, but it doesn't matter since they get
511         // multiplied by the zeros that we just wrote in the corresponding
512         // lanes of input_val.
513         acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0),
514                          vget_low_s16(input_val));
515         acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1),
516                          vget_low_s16(input_val));
517         acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2),
518                          vget_low_s16(input_val));
519         acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3),
520                          vget_low_s16(input_val));
521         acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
522                          vget_high_s16(input_val));
523         acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
524                          vget_high_s16(input_val));
525         acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
526                          vget_high_s16(input_val));
527         acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
528                          vget_high_s16(input_val));
529       }
530 
531       // Horizontally reduce accumulators
532       int32x2_t pairwise_reduced_acc_0 =
533           vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
534       int32x2_t pairwise_reduced_acc_1 =
535           vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1));
536       int32x2_t pairwise_reduced_acc_2 =
537           vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2));
538       int32x2_t pairwise_reduced_acc_3 =
539           vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3));
540       const int32x2_t reduced_lo =
541           vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
542       const int32x2_t reduced_hi =
543           vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
544       int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
545       // End of horizontal reduction: now `reduced` is a single int32x4
546       // containing the 4 int32 accumulators corresponding to the 4 rows
547       // being processed.
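      // For instance (illustrative): if acc0 held lanes [a0, a1, a2, a3], then
      // pairwise_reduced_acc_0 is [a0 + a1, a2 + a3] and lane 0 of `reduced`
      // ends up holding a0 + a1 + a2 + a3, i.e. the accumulated dot-product
      // for the first of the 4 rows.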
548 
549       // Add bias values.
550       if (params.bias) {
551         int32x4_t bias_vec = vld1q_s32(params.bias + row);
552         reduced = vaddq_s32(reduced, bias_vec);
553       }
554 
555       // Get multiplier parameters.
556       int32x4_t multiplier_fixedpoint;
557       int32x4_t multiplier_exponent;
558       if (quantization_flavor ==
559           QuantizationFlavor::kIntegerWithPerRowMultiplier) {
560         multiplier_exponent =
561             vld1q_s32(params.multiplier_exponent_perchannel + row);
562         multiplier_fixedpoint =
563             vld1q_s32(params.multiplier_fixedpoint_perchannel + row);
564       } else {
565         multiplier_exponent = vdupq_n_s32(params.multiplier_exponent);
566         multiplier_fixedpoint = vdupq_n_s32(params.multiplier_fixedpoint);
567       }
568 
569       // If positive exponent, shift left.
570       int32x4_t exponent_positive_part =
571           vmaxq_s32(multiplier_exponent, vdupq_n_s32(0));
572       reduced = vshlq_s32(reduced, exponent_positive_part);
573       // Multiply by the fixed-point multiplier.
574       reduced = vqrdmulhq_s32(reduced, multiplier_fixedpoint);
575       // If negative exponent, rounding-shift-right.
576       int32x4_t exponent_negative_part =
577           vminq_s32(multiplier_exponent, vdupq_n_s32(0));
578       reduced = vrshlq_s32(reduced, exponent_negative_part);
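      // Net effect of the three steps above, as a sketch of the usual TFLite
      // fixed-point rescaling (ignoring saturation): since vqrdmulhq_s32 is a
      // rounding doubling high multiply, i.e. roughly round(a * b / 2^31),
      //   reduced ~= acc * (multiplier_fixedpoint / 2^31) * 2^multiplier_exponent
      // where acc is the bias-added accumulator.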
579 
580       // Add the output offset.
581       const int32x4_t output_offset_vec = vdupq_n_s32(dst_params.zero_point);
582       reduced = vaddq_s32(reduced, output_offset_vec);
583 
584       // Finally, clamp and store to the destination.
585       ClampAndStore(reduced, params.clamp_min, params.clamp_max,
586                     dst_data + row);
587     }
588   }
589 };
590 
591 // The float specialization below is unconditionally faster than ruy
592 // because ruy does not currently have any Gemv path.
593 // But it is not unconditionally faster than Eigen, which is what is used
594 // unless TFLITE_WITH_RUY is defined. Indeed, Eigen has decently efficient
595 // Gemv paths, and they may use AVX instructions, while the present
596 // NEON intrinsics code maps at best to SSE4 on x86.
597 #ifdef TFLITE_WITH_RUY
598 
599 // We want to use fused multiply-add when it's available (that is, on A64
600 // unconditionally and on A32 with VFPv4) because it's often faster, and
601 // because a non-fused form seems not to be available on A64, so a conscientious
602 // compiler might emit slow code (separate mul and add instructions) in order
603 // to implement the vmlaq_f32 intrinsic with strict bit-for-bit exactness on A64.
604 // (Compilers seem to be generating a fused fmla instruction at the moment,
605 // but that could change).
606 //
607 // We still want to support building for A32 without VFPv4.
608 inline float32x4_t mul_add(float32x4_t acc, float32x4_t lhs, float32x4_t rhs) {
609 #ifdef __ARM_FEATURE_FMA
610   return vfmaq_f32(acc, lhs, rhs);
611 #else
612   return vmlaq_f32(acc, lhs, rhs);
613 #endif
614 }
615 
616 template <>
617 struct CustomGemvImpl<float, float, float, float,
618                       QuantizationFlavor::kFloatingPoint> {
619   // This implementation's inner loop processes 4 rows of the left-hand side
620   // matrix at a time.
621   static constexpr int kKernelRows = 4;
622 
623   static bool IsSupportedGivenSufficientlyManyRows(
624       const MatrixParams<float>& lhs_params,
625       const MatrixParams<float>& rhs_params,
626       const MatrixParams<float>& dst_params,
627       const GemmParams<float, float>& params) {
628     // The kernel processes 4 LHS columns at once to fill float32x4 registers.
629     // The leftovers-handling code at the end works by loading a partially
630     // overlapping final register by walking back by a few (<4) floats
631     // to avoid running past the row's end. This relies on there being
632     // at least 4 LHS columns.
633     return lhs_params.cols >= 4;
634   }
635   static void Run(const MatrixParams<float>& lhs_params, const float* lhs_data,
636                   const MatrixParams<float>& rhs_params, const float* rhs_data,
637                   const MatrixParams<float>& dst_params, float* dst_data,
638                   const GemmParams<float, float>& params, int row_start,
639                   int row_end) {
640     // Handle kKernelRows ( == 4) rows of the left-hand side matrix at each
641     // iteration of this for loop.
642     TFLITE_DCHECK_GE(row_end - row_start, kKernelRows);
643     for (int row = row_start; row < row_end; row += kKernelRows) {
644       // Here is the magic where we allow this kernel to handle any number of
645       // rows as long as it's >= kKernelRows: the last group of `kKernelRows`
646       // rows is nudged to fit, possibly by starting at a `row` value that is
647       // not a multiple of kKernelRows, recomputing a few rows.
648       row = std::min(row, row_end - kKernelRows);
649       const float* filter_ptr = lhs_data + row * lhs_params.cols;
650 
651       static constexpr int kCacheLineSize = 64;
652       for (int k = 0; k < rhs_params.rows;
653            k += kCacheLineSize / sizeof(float)) {
654         optimized_ops_preload_l1_keep(rhs_data + k);
655       }
656 
657       // kPreloadAhead is empirically determined.
658       // End-to-end latency (ms) on mobilenet_v2_0.35_96_float, 1 thread,
659       // on Qualcomm S855:
660       //
661       // kPreloadAhead | big core | little core
662       // --------------+----------+------------
663       // 64            | 2.4      | 15.2
664       // 128           | 2.15     | 12.9
665       // 256           | 2        | 12.9
666       // 512           | 2.08     | 13.3
667       // 1024          | 2.05     | 14.7
668       // no prefetch   | 2.1      | 28
669       static constexpr int kPreloadAhead = 256;
670 
671       // 4 accumulator registers, one for each row being processed.
672       // Each has 4 float32 lanes that correspond to columns modulo 4, and
673       // will need to be horizontally reduced at the end.
674       float32x4_t acc0 = vdupq_n_f32(0);
675       float32x4_t acc1 = acc0;
676       float32x4_t acc2 = acc0;
677       float32x4_t acc3 = acc0;
678       int in = 0;
679       // As much as possible, handle 4 columns of the left-hand side matrix
680       // at a time. This allows for a decent NEON implementation.
681       for (; in <= lhs_params.cols - 4; in += 4) {
682         float32x4_t input_val = vld1q_f32(rhs_data + in);
683         const float* local_filter_ptr = filter_ptr;
684         float32x4_t filter_val_0 = vld1q_f32(local_filter_ptr);
685         optimized_ops_preload_l1_stream(local_filter_ptr +
686                                         kPreloadAhead / sizeof(float));
687         local_filter_ptr += lhs_params.cols;
688         float32x4_t filter_val_1 = vld1q_f32(local_filter_ptr);
689         optimized_ops_preload_l1_stream(local_filter_ptr +
690                                         kPreloadAhead / sizeof(float));
691         local_filter_ptr += lhs_params.cols;
692         float32x4_t filter_val_2 = vld1q_f32(local_filter_ptr);
693         optimized_ops_preload_l1_stream(local_filter_ptr +
694                                         kPreloadAhead / sizeof(float));
695         local_filter_ptr += lhs_params.cols;
696         float32x4_t filter_val_3 = vld1q_f32(local_filter_ptr);
697         optimized_ops_preload_l1_stream(local_filter_ptr +
698                                         kPreloadAhead / sizeof(float));
699         filter_ptr += 4;
700         acc0 = mul_add(acc0, filter_val_0, input_val);
701         acc1 = mul_add(acc1, filter_val_1, input_val);
702         acc2 = mul_add(acc2, filter_val_2, input_val);
703         acc3 = mul_add(acc3, filter_val_3, input_val);
704       }
705       // Less than 4 values remain. Handle the remaining values
706       // in one more copy of the above code handling 4, where we
707       // walk back a few values to be able to load 4 values without
708       // overrunning the buffer. This is where we make use of the requirement
709       // (see IsSupportedGivenSufficientlyManyRows) that there are at least
710       // 4 LHS columns.
711       if (in < lhs_params.cols) {
712         // `back` is how many entries to walk back by.
713         // Its value is necessarily between 1 and 3.
714         const int back = in + 4 - lhs_params.cols;
715         TFLITE_DCHECK_GE(back, 1);
716         TFLITE_DCHECK_LE(back, 3);
717         // Load 4 values as usual.
718         float32x4_t input_val = vld1q_f32(rhs_data + lhs_params.cols - 4);
719         const float* local_filter_ptr = filter_ptr - back;
720         filter_ptr += lhs_params.cols - in;
721         float32x4_t filter_val_0 = vld1q_f32(local_filter_ptr);
722         local_filter_ptr += lhs_params.cols;
723         float32x4_t filter_val_1 = vld1q_f32(local_filter_ptr);
724         local_filter_ptr += lhs_params.cols;
725         float32x4_t filter_val_2 = vld1q_f32(local_filter_ptr);
726         local_filter_ptr += lhs_params.cols;
727         float32x4_t filter_val_3 = vld1q_f32(local_filter_ptr);
728         // Now zero out the first `back` entries of input_val.
729         // vsetq_lane_f32 takes a literal index, so we need unrolled code.
730         switch (back) {
731           case 3:
732             input_val = vsetq_lane_f32(0, input_val, 2);
733             [[clang::fallthrough]];
734           case 2:
735             input_val = vsetq_lane_f32(0, input_val, 1);
736             [[clang::fallthrough]];
737           default:
738             input_val = vsetq_lane_f32(0, input_val, 0);
739         }
740         // Multiply-accumulate 4 values as usual. The first `back` lanes
741         // of filter_val_* are junk, but it doesn't matter since they get
742         // multiplied by the zeros that we just wrote in the corresponding
743         // lanes of input_val.
744         acc0 = mul_add(acc0, filter_val_0, input_val);
745         acc1 = mul_add(acc1, filter_val_1, input_val);
746         acc2 = mul_add(acc2, filter_val_2, input_val);
747         acc3 = mul_add(acc3, filter_val_3, input_val);
748       }
749 
750       // Horizontally reduce accumulators
751       float32x2_t pairwise_reduced_acc_0 =
752           vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
753       float32x2_t pairwise_reduced_acc_1 =
754           vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
755       float32x2_t pairwise_reduced_acc_2 =
756           vpadd_f32(vget_low_f32(acc2), vget_high_f32(acc2));
757       float32x2_t pairwise_reduced_acc_3 =
758           vpadd_f32(vget_low_f32(acc3), vget_high_f32(acc3));
759       float32x2_t reduced_lo =
760           vpadd_f32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
761       float32x2_t reduced_hi =
762           vpadd_f32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
763       float32x4_t reduced = vcombine_f32(reduced_lo, reduced_hi);
764       // End of horizontal reduction: now `reduced` is a single float32x4
765       // containing the 4 float32 accumulators corresponding to the 4 rows
766       // being processed.
767 
768       if (params.bias) {
769         // Add bias values.
770         reduced = vaddq_f32(reduced, vld1q_f32(params.bias + row));
771       }
772 
773       // Clamp and store to destination.
774       reduced = vminq_f32(reduced, vdupq_n_f32(params.clamp_max));
775       reduced = vmaxq_f32(reduced, vdupq_n_f32(params.clamp_min));
776       vst1q_f32(dst_data + row, reduced);
777     }
778   }
779 };
780 
781 #endif  // TFLITE_WITH_RUY
782 
783 #endif  // USE_NEON
784 
785 }  // namespace detail
786 }  // namespace cpu_backend_gemm
787 }  // namespace tflite
788 
789 #endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_CUSTOM_GEMV_H_
790