xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <sys/types.h>
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <cstddef>
20 #include <cstdint>
21 #include <cstdlib>
22 #include <cstring>
23 #include <limits>
24 #include <utility>
25 
26 #include "ruy/ruy.h"  // from @ruy
27 #include "tensorflow/lite/kernels/cpu_backend_context.h"
28 #include "tensorflow/lite/kernels/cpu_backend_gemm.h"
29 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
30 #include "tensorflow/lite/kernels/internal/common.h"
31 #include "tensorflow/lite/kernels/internal/compatibility.h"
32 #include "tensorflow/lite/kernels/internal/cppmath.h"
33 #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
34 #include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h"
35 
36 #ifdef USE_NEON
37 
38 // aligned_alloc is available (via cstdlib/stdlib.h) with C++17/C11.
39 #if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L
40 #if !defined(__ANDROID__) || __ANDROID_API__ >= 28
41 // Neither Apple nor Windows provide aligned_alloc.
42 #if !defined(__APPLE__) && !defined(_WIN32)
// TODO(miaowang): Re-enable std::aligned_alloc when it is available in Android.
44 // #define TFLITE_USE_STD_ALIGNED_ALLOC
45 #endif
46 #endif
47 #endif
48 
// TFLITE_HAS_BUILTIN(x): portable wrapper around the compiler's __has_builtin
// feature test. Equivalent to ABSL_HAVE_BUILTIN, which cannot be used here
// because the absl header is not included. Evaluates to 0 on compilers that
// do not provide __has_builtin.
#if defined(__has_builtin)
#define TFLITE_HAS_BUILTIN(x) __has_builtin(x)
#else
#define TFLITE_HAS_BUILTIN(x) 0
#endif

// TFLITE_UNLIKELY(x): branch-prediction hint that `x` is expected to be false.
// Equivalent to ABSL_PREDICT_FALSE. GCC always provides __builtin_expect;
// other compilers use it only when __has_builtin reports it. Otherwise the
// macro degrades to the plain expression.
#if defined(__GNUC__) && !defined(__clang__)
#define TFLITE_UNLIKELY(x) (__builtin_expect(false || (x), false))
#elif TFLITE_HAS_BUILTIN(__builtin_expect)
#define TFLITE_UNLIKELY(x) (__builtin_expect(false || (x), false))
#else
#define TFLITE_UNLIKELY(x) (x)
#endif
63 
64 namespace tflite {
65 namespace tensor_utils {
66 namespace {
67 
// Number of lanes of each element type that fit in one 128-bit NEON register.
constexpr int kFloatValuesPerNeonVector{4};  // 4 x float32
constexpr int kInt16ValuesPerNeonVector{8};  // 8 x int16
constexpr int kInt8ValuesPerNeonVector{16};  // 16 x int8
// Alignment (in bytes) requested for the scratch buffers allocated in this
// file via aligned_alloc().
constexpr int kNeonVectorAlignment{4};
// Rounds `size` down to a multiple of PerNeonSize (the lane count of one NEON
// register). The result is the number of leading elements that can be
// processed in whole vectors; the remaining `size - result` elements form the
// scalar postamble.
// The implementation clears the low bits with a mask, which is only correct
// when PerNeonSize is a power of two; this is enforced at compile time.
template <int PerNeonSize>
inline int RoundDownVectors(int size) {
  static_assert(PerNeonSize > 0 && (PerNeonSize & (PerNeonSize - 1)) == 0,
                "PerNeonSize must be a positive power of two");
  return size & ~(PerNeonSize - 1);
}
76 
// Allocates at least `size` bytes of uninitialized storage whose address is a
// multiple of `alignment`. `size` does not need to be a multiple of
// `alignment`.
//
// The returned pointer may point into the middle of the underlying
// allocation, so the caller is responsible for freeing the memory by calling
// free() on `*freeing_buffer`, never on the returned pointer.
//
// Returns nullptr and sets `*freeing_buffer` to nullptr when `alignment` is
// zero or when padding `size` up by `alignment` would overflow size_t. If the
// underlying allocation fails, both are nullptr as well.
inline void* aligned_alloc(size_t alignment, size_t size,
                           void** freeing_buffer) {
  // Both code paths below add up to `alignment` bytes to `size`. Refuse
  // requests where that would wrap around and silently under-allocate; also
  // reject a zero alignment, which would divide by zero.
  if (alignment == 0 ||
      size > std::numeric_limits<size_t>::max() - alignment) {
    *freeing_buffer = nullptr;
    return nullptr;
  }
#ifdef TFLITE_USE_STD_ALIGNED_ALLOC
  // Round the size up because C11 aligned_alloc requires it to be an integral
  // multiple of the alignment.
  *freeing_buffer = std::aligned_alloc(
      alignment, (size + alignment - 1) / alignment * alignment);
  return *freeing_buffer;
#else
  // Over-allocate by `alignment` bytes, then advance to the first aligned
  // address inside the block; at most `alignment - 1` bytes are skipped.
  *freeing_buffer = malloc(size + alignment);
  if (*freeing_buffer == nullptr) {
    return nullptr;
  }
  const uintptr_t address = reinterpret_cast<uintptr_t>(*freeing_buffer);
  const size_t offset = address % alignment;
  return offset == 0
             ? *freeing_buffer
             : static_cast<char*>(*freeing_buffer) + (alignment - offset);
#endif
}
96 
HasSdotInstruction()97 bool HasSdotInstruction() {
98   static const bool has_dotprod = DetectArmNeonDotprod();
99   return has_dotprod;
100 }
101 
AccumulateNeonLane(const float32x4_t lane)102 inline float AccumulateNeonLane(const float32x4_t lane) {
103 #ifdef __aarch64__
104   return vaddvq_f32(lane);
105 #else
106   return vgetq_lane_f32(lane, 0) + vgetq_lane_f32(lane, 1) +
107          vgetq_lane_f32(lane, 2) + vgetq_lane_f32(lane, 3);
108 #endif
109 }
110 
111 // Empirically determined breakpoints on when to use CpuBackendGemm vs.
112 // standard MatrixBatchVectorMultiplyAccumulate. Briefly, if the batch size
113 // is above 8 and the device does not have sdot, use CpuBackendGemm. Otherwise,
114 // for large batch sizes, it makes sense to use CpuBackendGemm if the matrix
115 // is not extremely rectangular.
UseCpuBackendGemm(int rows,int cols,int batch)116 bool UseCpuBackendGemm(int rows, int cols, int batch) {
117   if (!HasSdotInstruction()) {
118     return batch >= 8;
119   }
120   if (batch < 16) {
121     return false;
122   }
123   constexpr int kCpuBackendGemmThreshold = 2;
124   // Calculate "rectangularness" as a measure of how far from square the
125   // the LHS matrix is.
126   int row_rect = rows / cols;
127   int col_rect = cols / rows;
128   int rectangularness_lg2 =
129       row_rect > 0 ? FloorLog2(row_rect) : FloorLog2(col_rect);
130   int batch_lg2 = FloorLog2(batch);
131   // Large batch sizes move us above the threshold, but can be offset
132   // by significant rectangularness.
133   int batch_lg2_minus_rect_lg2 = batch_lg2 - rectangularness_lg2;
134   return batch_lg2_minus_rect_lg2 > kCpuBackendGemmThreshold;
135 }
136 
AccumulateNeonLane(const int32x4_t lane)137 inline int32_t AccumulateNeonLane(const int32x4_t lane) {
138 #ifdef __aarch64__
139   return vaddvq_s32(lane);
140 #else
141   int64x2_t pairwiseAdded = vpaddlq_s32(lane);
142   return vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1);
143 #endif
144 }
145 
146 // Single-rounding MultiplyByQuantizedMultiplier
147 #if TFLITE_SINGLE_ROUNDING
MultiplyByQuantizedMultiplier2Rows(int32x4x2_t input_val,int32_t quantized_multiplier,int shift)148 inline int32x4x2_t MultiplyByQuantizedMultiplier2Rows(
149     int32x4x2_t input_val, int32_t quantized_multiplier, int shift) {
150   TFLITE_DCHECK(quantized_multiplier >= 0);
151   const int right_shift = std::min(-1, shift);
152   const int left_shift = shift - right_shift;
153 
154   const int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier);
155   const int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
156   const int32x4_t right_shift_dup = vdupq_n_s32(right_shift);
157 
158   int32x4x2_t result;
159   result.val[0] = vrshlq_s32(
160       vqdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup), multiplier_dup),
161       right_shift_dup);
162 
163   result.val[1] = vrshlq_s32(
164       vqdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup), multiplier_dup),
165       right_shift_dup);
166 
167   return result;
168 }
169 // Double-rounding MultiplyByQuantizedMultiplier
170 #else
MultiplyByQuantizedMultiplier2Rows(int32x4x2_t input_val,int32 quantized_multiplier,int shift)171 inline int32x4x2_t MultiplyByQuantizedMultiplier2Rows(
172     int32x4x2_t input_val, int32 quantized_multiplier, int shift) {
173   using gemmlowp::RoundingDivideByPOT;
174   using gemmlowp::SaturatingRoundingDoublingHighMul;
175   const int left_shift = shift > 0 ? shift : 0;
176   const int right_shift = shift > 0 ? 0 : -shift;
177   int32x4x2_t result;
178   // The vector type support for SaturatingRoundingDoublingHighMulth in gemmlowp
179   // is limited to NEON.
180 #ifdef GEMMLOWP_NEON
181   const int32x4_t left_shifted_one_dup = vdupq_n_s32(1 << left_shift);
182   result.val[0] =
183       RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
184                               vmulq_s32(input_val.val[0], left_shifted_one_dup),
185                               quantized_multiplier),
186                           right_shift);
187   result.val[1] =
188       RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
189                               vmulq_s32(input_val.val[1], left_shifted_one_dup),
190                               quantized_multiplier),
191                           right_shift);
192 #else
193   for (int i = 0; i < 2; ++i) {
194     int32_t vals[4];
195     vals[0] = RoundingDivideByPOT(
196         SaturatingRoundingDoublingHighMul(
197             vgetq_lane_s32(input_val.val[i], 0) * (1 << left_shift),
198             quantized_multiplier),
199         right_shift);
200     vals[1] = RoundingDivideByPOT(
201         SaturatingRoundingDoublingHighMul(
202             vgetq_lane_s32(input_val.val[i], 1) * (1 << left_shift),
203             quantized_multiplier),
204         right_shift);
205     vals[2] = RoundingDivideByPOT(
206         SaturatingRoundingDoublingHighMul(
207             vgetq_lane_s32(input_val.val[i], 2) * (1 << left_shift),
208             quantized_multiplier),
209         right_shift);
210     vals[3] = RoundingDivideByPOT(
211         SaturatingRoundingDoublingHighMul(
212             vgetq_lane_s32(input_val.val[i], 3) * (1 << left_shift),
213             quantized_multiplier),
214         right_shift);
215 
216     result.val[i] = vld1q_s32(reinterpret_cast<int32_t*>(&vals));
217   }
218 #endif
219   return result;
220 }
221 #endif  // TFLITE_SINGLE_ROUNDING
222 
223 }  // namespace
224 
NeonMatrixBatchVectorMultiplyAccumulate(const float * matrix,int m_rows,int m_cols,const float * vector,int n_batch,float * result)225 void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
226                                              int m_cols, const float* vector,
227                                              int n_batch, float* result) {
228   // If v_size is not divisible by the vector size, then we need to process the
229   // final few elements sequentially. postamble_start shows the start index
230   // where this should happen.
231   const int postamble_start =
232       RoundDownVectors<kFloatValuesPerNeonVector>(m_cols);
233 
234   for (int b = 0; b < n_batch; b++) {
235     float* result_in_batch = result + b * m_rows;
236     const float* vector_in_batch = vector + b * m_cols;
237     const float* matrix_row = matrix;
238 
239     // Main matrix by vector multiplication loop
240     for (int r = 0; r < m_rows; r++) {
241       float32x4_t acc_32x4 = vmovq_n_f32(0.0);
242       int c = 0;
243       for (; c < postamble_start; c += kFloatValuesPerNeonVector) {
244         // Load 4 float values from vector and matrix row.
245         float32x4_t vector_f32x4 = vld1q_f32(vector_in_batch + c);
246         float32x4_t matrix_f32x4 = vld1q_f32(matrix_row + c);
247         // Multiply the vector and matrix row and add to accumulator.
248         acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
249       }
250       // Add the 4 intermediate sum values to get the final dot-prod value for
251       // this column.
252       *result_in_batch += AccumulateNeonLane(acc_32x4);
253       for (; TFLITE_UNLIKELY(c < m_cols); c++) {
254         *result_in_batch += matrix_row[c] * vector_in_batch[c];
255       }
256       matrix_row += m_cols;
257       ++result_in_batch;
258     }
259   }
260 }
261 
262 #ifdef __aarch64__
263 
264 // We interleave vector data to make the dot product logic more efficient.
265 // Suppose that vectors is:
266 //     a0 a1 a2 a3 a4 a5 ...
267 //     b0 b1 b2 b3 b4 b5 ...
268 //     c0 c1 c2 c3 c4 c5 ...
269 //     d0 d1 d2 d3 d4 d5 ...
270 //     e0 e1 e2 e3 e4 e5 ...
271 // This code interleaves them like this:
272 //     a0 a1 a2 a3 b0 b1 b2 b3 c0 c1 c2 c3 d0 d1 d2 d3 a4 a5 a6 a7 b4 ...
273 //     e0 e1 e2 e3 f0 f1 f2 f3 ...
274 // Once the data is interleaved, each 16-byte read from the vectors pointer
275 // contains 4 bytes from each of 4 vectors.
ShuffleVectors(const int8_t * vectors,const int n_batch,const int m_cols,void ** shuffled_vectors_free)276 const int8_t* ShuffleVectors(const int8_t* vectors, const int n_batch,
277                              const int m_cols, void** shuffled_vectors_free) {
278   int8* shuffled_vectors = reinterpret_cast<int8*>(aligned_alloc(
279       kNeonVectorAlignment, n_batch * m_cols, shuffled_vectors_free));
280 
281   for (int i = 0; i < n_batch; i += 4) {
282     int8* shuffled_vectors_ptr = shuffled_vectors + (i * m_cols);
283     const int8* unshuffled_vec0_ptr =
284         reinterpret_cast<const int8*>(vectors) + (i * m_cols);
285     const int8* unshuffled_vec1_ptr =
286         reinterpret_cast<const int8*>(vectors) + ((i + 1) * m_cols);
287     const int8* unshuffled_vec2_ptr =
288         reinterpret_cast<const int8*>(vectors) + ((i + 2) * m_cols);
289     const int8* unshuffled_vec3_ptr =
290         reinterpret_cast<const int8*>(vectors) + ((i + 3) * m_cols);
291     const int8* const end_vec0_ptr = unshuffled_vec1_ptr;
292 
293     while (unshuffled_vec0_ptr != end_vec0_ptr) {
294       asm volatile(
295           // This code path requires that (n_cols % 16) == 0 so we can safely
296           // read in 16-byte chunks from each row.
297           "ld1 {v0.16b}, [%[unshuffled_vec0_ptr]], #16\n"
298           "ld1 {v1.16b}, [%[unshuffled_vec1_ptr]], #16\n"
299           "ld1 {v2.16b}, [%[unshuffled_vec2_ptr]], #16\n"
300           "ld1 {v3.16b}, [%[unshuffled_vec3_ptr]], #16\n"
301 
302           "st4 {v0.s, v1.s, v2.s, v3.s}[0], [%[shuffled_vectors_ptr]], #16\n"
303           "st4 {v0.s, v1.s, v2.s, v3.s}[1], [%[shuffled_vectors_ptr]], #16\n"
304           "st4 {v0.s, v1.s, v2.s, v3.s}[2], [%[shuffled_vectors_ptr]], #16\n"
305           "st4 {v0.s, v1.s, v2.s, v3.s}[3], [%[shuffled_vectors_ptr]], #16\n"
306 
307           : [unshuffled_vec0_ptr] "+r"(unshuffled_vec0_ptr),
308             [unshuffled_vec1_ptr] "+r"(unshuffled_vec1_ptr),
309             [unshuffled_vec2_ptr] "+r"(unshuffled_vec2_ptr),
310             [unshuffled_vec3_ptr] "+r"(unshuffled_vec3_ptr),
311             [shuffled_vectors_ptr] "+r"(shuffled_vectors_ptr)
312           :
313           : "v0", "v1", "v2", "v3", "cc", "memory");
314     }
315   }
316 
317   return reinterpret_cast<const int8_t*>(shuffled_vectors);
318 }
319 
// Hybrid int8 matrix x int8 batch-vector product, accumulated into float:
//   result[b * m_rows + r] += scaling_factors[b] * dot(matrix row r, vector b)
// for every row r and batch vector b (vector b starts at vectors + b * m_cols).
// The integer dot products use the sdot instruction, emitted as raw .word
// encodings. Each asm invocation processes two matrix rows against four batch
// vectors.
// Preconditions implied by the loop structure (not checked here): m_rows is
// even, n_batch is a multiple of 4, and m_cols is a positive multiple of 16.
// The column loop consumes 16 bytes per iteration and always runs at least
// once.
//
// Notes about the speed of this version vs. the baseline (from memory):
// - With 256K of L1, we can keep a lot of vectors in cache.
//   I recall a reasonable speedup just by rearranging the loop to have
//   row on the outside and batch on the inside.
// - I also recall getting a nice speedup from sdot.
// - I tried many times to do better than the current implementation, using
//   loop unrolling and instruction reordering to avoid stalls, etc.
//   but I was not able to do significantly better. This code is, however,
//   much worse than what the processor spec sheet suggests is possible.
static void DotprodMatrixBatchFourVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* vectors, const float* scaling_factors, int n_batch,
    float* __restrict__ result) {
  void* shuffled_vectors_free;

  // Interleave the batch vectors in groups of four so that one 16-byte load
  // from vec_ptr yields 4 bytes from each of 4 vectors.
  const int8_t* shuffled_vectors =
      ShuffleVectors(vectors, n_batch, m_cols, &shuffled_vectors_free);

  for (int row = 0; row < m_rows; row += 2) {
    for (int batch = 0; batch < n_batch; batch += 4) {
      float* result_ptr = result + (batch * m_rows) + row;
      const int8* mat_ptr0 = matrix + (row * m_cols);
      const int8* mat_ptr1 = matrix + ((row + 1) * m_cols);
      const int8* mat_ptr0_end = mat_ptr1;
      const int8* vec_ptr = shuffled_vectors + (batch * m_cols);
      const float* scaling_factors_ptr = scaling_factors + batch;
      // Byte distance between result[b * m_rows + r] and
      // result[(b + 1) * m_rows + r].
      const uint64_t wide_rows = m_rows * sizeof(float);
      // Prefetch addresses two rows ahead. For the last row pair they point
      // past the end of `matrix`; they are only used by prfm hints.
      const int8* mat_ptr2 = matrix + ((row + 2) * m_cols);
      const int8* mat_ptr3 = matrix + ((row + 3) * m_cols);

      // Register usage: v0/v1 accumulate row `row` and v2/v3 row `row + 1`.
      // Lane j of each accumulator belongs to batch vector `batch + j`.
      asm volatile(
          // Zero out the accumulator registers.
          "movi v0.4s, #0\n"
          "movi v1.4s, #0\n"
          "movi v2.4s, #0\n"
          "movi v3.4s, #0\n"

          "1:\n"  // batch_cols_loop

          // Read the next 16 bytes of matrix row `row` (row `row + 1` is read
          // further below).
          "ld1 {v12.16b}, [%[mat_ptr0]], #16\n"

          // Prefetch two rows ahead.
          "prfm pldl1strm, [%[mat_ptr2]]\n"
          "prfm pldl1strm, [%[mat_ptr3]]\n"

          // Read from input vectors 4 times; 64 bytes total.
          // Each 16-byte register contains parts of 4 vectors; see the
          // shuffle logic above.

          // From Benoit, places to look in the future:
          // - Move load instructions further from sdot
          // - Switch loop use-then-reload
          // - Do partial unrolling to use register space better
          "ld1 {v8.16b}, [%[vec_ptr]], #16\n"
          ".word 0x4f8ce100  // sdot v0.4s, v8.16b, v12.4b[0]\n"
          "ld1 {v9.16b}, [%[vec_ptr]], #16\n"
          ".word 0x4face121  // sdot v1.4s, v9.16b, v12.4b[1]\n"
          "ld1 {v10.16b}, [%[vec_ptr]], #16\n"
          ".word 0x4f8ce940  // sdot v0.4s, v10.16b, v12.4b[2]\n"
          "ld1 {v11.16b}, [%[vec_ptr]], #16\n"
          ".word 0x4face961  // sdot v1.4s, v11.16b, v12.4b[3]\n"

          // Update prefetch pointers.
          "add %[mat_ptr2], %[mat_ptr2], #16\n"
          "add %[mat_ptr3], %[mat_ptr3], #16\n"

          // Re-use those vectors for the next row as well.
          "ld1 {v13.16b}, [%[mat_ptr1]], #16\n"
          ".word 0x4f8de102  // sdot v2.4s, v8.16b, v13.4b[0]\n"
          ".word 0x4fade123  // sdot v3.4s, v9.16b, v13.4b[1]\n"
          ".word 0x4f8de942  // sdot v2.4s, v10.16b, v13.4b[2]\n"
          ".word 0x4fade963  // sdot v3.4s, v11.16b, v13.4b[3]\n"

          // If we're not done with these rows, continue.
          "cmp %[mat_ptr0], %[mat_ptr0_end]\n"
          "bne 1b\n"  // batch_cols_loop

          // Done with the rows, sum the results.
          "add v0.4s, v0.4s, v1.4s\n"
          "add v2.4s, v2.4s, v3.4s\n"

          // Convert the per-vector sums to floating point.
          "scvtf v0.4s, v0.4s\n"
          "scvtf v1.4s, v2.4s\n"

          // Fetch scale factors.
          "ld1 {v4.4s}, [%[scaling_factors_ptr]]\n"

          // Multiply scale factors times sums.
          "fmul v0.4s, v4.4s, v0.4s\n"
          "fmul v1.4s, v4.4s, v1.4s\n"

          // Load previous result values.
          // The result position is:
          //   result[batch * m_rows + row]
          // Here that is factored into:
          //   result_ptr = result + row
          //   *result_ptr = res[0]
          //   (uint8*)result_ptr += (m_rows * sizeof(float))
          //   *result_ptr = res[1]
          //   ...
          // Since we're reading two rows at a time, though, we read both
          //   result[batch * m_rows + row]
          // and
          //   result[batch * m_rows + row + 1]
          "ld2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
          "ld2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
          "ld2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
          "ld2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"

          // Go back to the starting position (subtract wide_rows * 4).
          "sub %[result_ptr], %[result_ptr], %[wide_rows], lsl #2\n"

          // Add previous result values.
          "fadd v9.4s, v9.4s, v0.4s\n"
          "fadd v10.4s, v10.4s, v1.4s\n"

          // Store results.
          "st2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
          "st2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
          "st2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
          "st2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
          : [mat_ptr0] "+r"(mat_ptr0), [mat_ptr1] "+r"(mat_ptr1),
            [vec_ptr] "+r"(vec_ptr), [result_ptr] "+r"(result_ptr),
            [mat_ptr2] "+r"(mat_ptr2), [mat_ptr3] "+r"(mat_ptr3)
          : [mat_ptr0_end] "r"(mat_ptr0_end),
            [scaling_factors_ptr] "r"(scaling_factors_ptr),
            [wide_rows] "r"(wide_rows)
          // NOTE(review): x0 and v5-v7 are listed as clobbered but are not
          // referenced by the asm above.
          : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
            "v10", "v11", "v12", "v13", "cc", "memory");
    }
  }

  free(shuffled_vectors_free);
}
456 
// sdot-based hybrid kernel with per-channel scales and input zero points.
// For every row r and batch vector b it computes
//   result[b * m_rows + r] += scale(r, b) *
//       (dot(matrix row r, vector b) - row_sum(r) * input_offset[b])
// where:
//   scale(r, b)  = per_channel_scale[r] * scaling_factors[b] if
//                  per_channel_scale is non-null, else scaling_factors[b];
//   row_sum(r)   = row_sums[r] if row_sums is non-null, else the sum of the
//                  int8 entries of matrix row r. In that case the sum is
//                  computed inside the asm, once per group of four batch
//                  vectors.
// input_offset is always dereferenced (the asm loads it unconditionally), so it
// must be non-null.
// Preconditions implied by the loop structure (not checked here): m_rows is
// even, n_batch is a multiple of 4, and m_cols is a positive multiple of 16.
static void DotprodMatrixBatchFourVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* vectors, const float* scaling_factors, int n_batch,
    float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, int32_t* row_sums) {
  void* shuffled_vectors_free;
  const int8_t* shuffled_vectors =
      ShuffleVectors(vectors, n_batch, m_cols, &shuffled_vectors_free);

  for (int row = 0; row < m_rows; row += 2) {
    for (int batch = 0; batch < n_batch; batch += 4) {
      // Computed even when per_channel_scale is null; the asm only
      // dereferences it when is_channel_scale_nullptr is 0.
      const float* channel_scales_ptr = per_channel_scale + row;
      int32_t* row_sums_ptr = row_sums ? row_sums + row : nullptr;

      float* result_ptr = result + (batch * m_rows) + row;
      const int8* mat_ptr0 = matrix + (row * m_cols);
      const int8* mat_ptr1 = matrix + ((row + 1) * m_cols);
      const int8* mat_ptr0_end = mat_ptr1;
      const int8* vec_ptr = shuffled_vectors + (batch * m_cols);
      const float* scaling_factors_ptr = scaling_factors + batch;
      // Byte distance between result[b * m_rows + r] and
      // result[(b + 1) * m_rows + r].
      const uint64_t wide_rows = m_rows * sizeof(float);
      const int32_t* batch_offsets_ptr = input_offset + batch;
      // 0/1 flags tested inside the asm to select the optional paths.
      const int32_t is_channel_scale_nullptr = per_channel_scale == nullptr;
      const int32_t is_row_sums_nullptr = row_sums_ptr == nullptr;
      // Register usage:
      //   v0/v1  dot-product accumulators for row `row` (lane j = batch + j)
      //   v2/v3  dot-product accumulators for row `row + 1`
      //   v4     scaling_factors[batch .. batch + 3]
      //   v7     input_offset[batch .. batch + 3]
      //   v14/15 row sums for rows `row` / `row + 1`
      //   v16/17 combined scales for rows `row` / `row + 1`
      asm volatile(
          "movi v0.4s, #0\n"
          "movi v1.4s, #0\n"
          "movi v2.4s, #0\n"
          "movi v3.4s, #0\n"
          // Load zero points.
          "ld1 {v7.4s}, [%[batch_offsets_ptr]]\n"
          "ld1 {v4.4s}, [%[scaling_factors_ptr]]\n"
          // Zero out zero point accumulators.
          "movi v14.4s, #0\n"
          "movi v15.4s, #0\n"

          // Load per channel scales if not null.
          "cmp %w[is_channel_scale_nullptr], #0\n"
          "bne 1f\n"
          "ld1r {v16.4s}, [%[channel_scales_ptr]], #4\n"
          "ld1r {v17.4s}, [%[channel_scales_ptr]]\n"
          "fmul v16.4s, v16.4s, v4.4s\n"
          "fmul v17.4s, v17.4s, v4.4s\n"
          "b 2f\n"
          "1:\n"
          "mov v16.16b, v4.16b\n"
          "mov v17.16b, v4.16b\n"
          // Label 2 is both the target of the forward branch above and the
          // head of the column loop (see "bne 2b" below).
          "2:\n"
          "ld1 {v12.16b}, [%[mat_ptr0]], #16\n"
          "ld1 {v8.16b}, [%[vec_ptr]], #16\n"
          ".word 0x4f8ce100  // sdot v0.4s, v8.16b, v12.4b[0]\n"
          "ld1 {v9.16b}, [%[vec_ptr]], #16\n"
          ".word 0x4face121  // sdot v1.4s, v9.16b, v12.4b[1]\n"
          "ld1 {v10.16b}, [%[vec_ptr]], #16\n"
          ".word 0x4f8ce940  // sdot v0.4s, v10.16b, v12.4b[2]\n"
          "ld1 {v11.16b}, [%[vec_ptr]], #16\n"
          ".word 0x4face961  // sdot v1.4s, v11.16b, v12.4b[3]\n"
          "ld1 {v13.16b}, [%[mat_ptr1]], #16\n"
          ".word 0x4f8de102  // sdot v2.4s, v8.16b, v13.4b[0]\n"
          ".word 0x4fade123  // sdot v3.4s, v9.16b, v13.4b[1]\n"
          ".word 0x4f8de942  // sdot v2.4s, v10.16b, v13.4b[2]\n"
          ".word 0x4fade963  // sdot v3.4s, v11.16b, v13.4b[3]\n"
          // Skip the row-sum accumulation when precomputed row_sums exist.
          "cmp %w[is_row_sums_nullptr], #1\n"
          "bne 3f\n"
          // Accumulate row_sums for zero point calculations.
          // v12/v13 have already been consumed by the sdots above, so they
          // can be overwritten here.
          "saddlp v12.8h, v12.16b\n"
          "saddlp v13.8h, v13.16b\n"
          "sadalp v14.4s, v12.8h\n"
          "sadalp v15.4s, v13.8h\n"
          "3:\n"
          "cmp %[mat_ptr0], %[mat_ptr0_end]\n"
          "bne 2b\n"
          "add v0.4s, v0.4s, v1.4s\n"
          "add v2.4s, v2.4s, v3.4s\n"

          "cmp %w[is_row_sums_nullptr], #1\n"
          "bne 4f\n"
          // Calculate zero point offsets.
          // Reduce the in-loop partial sums and broadcast them to all lanes.
          "addv s14, v14.4s\n"
          "addv s15, v15.4s\n"
          "dup v14.4s, v14.s[0]\n"
          "dup v15.4s, v15.s[0]\n"
          "b 5f\n"
          "4:\n"
          // Use the caller-provided row sums for rows `row` and `row + 1`.
          "ld1r {v14.4s}, [%[row_sums_ptr]], #4\n"
          "ld1r {v15.4s}, [%[row_sums_ptr]]\n"
          "5:\n"

          // Subtract row_sum * input_offset[batch + j] from each dot product.
          "mul v14.4s, v14.4s, v7.4s\n"
          "mul v15.4s, v15.4s, v7.4s\n"
          "sub v0.4s, v0.4s, v14.4s\n"
          "sub v2.4s, v2.4s, v15.4s\n"

          "scvtf v0.4s, v0.4s\n"
          "scvtf v1.4s, v2.4s\n"

          // Multiply scale.
          "fmul v0.4s, v16.4s, v0.4s\n"
          "fmul v1.4s, v17.4s, v1.4s\n"

          // Read-modify-write result[(batch + k) * m_rows + row] (v9) and
          // result[(batch + k) * m_rows + row + 1] (v10) for k = 0..3.
          "ld2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
          "ld2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
          "ld2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
          "ld2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
          "sub %[result_ptr], %[result_ptr], %[wide_rows], lsl #2\n"
          "fadd v9.4s, v9.4s, v0.4s\n"
          "fadd v10.4s, v10.4s, v1.4s\n"
          "st2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
          "st2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
          "st2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
          "st2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
          : [mat_ptr0] "+r"(mat_ptr0), [mat_ptr1] "+r"(mat_ptr1),
            [vec_ptr] "+r"(vec_ptr), [result_ptr] "+r"(result_ptr),
            [row_sums_ptr] "+r"(row_sums_ptr),
            [channel_scales_ptr] "+r"(channel_scales_ptr)
          : [mat_ptr0_end] "r"(mat_ptr0_end),
            [scaling_factors_ptr] "r"(scaling_factors_ptr),
            [wide_rows] "r"(wide_rows),
            [batch_offsets_ptr] "r"(batch_offsets_ptr),
            [is_channel_scale_nullptr] "r"(is_channel_scale_nullptr),
            [is_row_sums_nullptr] "r"(is_row_sums_nullptr)
          // NOTE(review): x0, w0, w1, v5 and v6 are listed as clobbered but
          // are not referenced by the asm above.
          : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
            "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "w0", "w1",
            "cc", "memory");
    }
  }

  free(shuffled_vectors_free);
}
586 
DotprodMatrixBatchFourVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const int m_rows,const int m_cols,const int8_t * vectors,const float * scaling_factors,int n_batch,float * __restrict__ result,const float * per_channel_scale,const int32_t * input_offset)587 static void DotprodMatrixBatchFourVectorMultiplyAccumulate(
588     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
589     const int8_t* vectors, const float* scaling_factors, int n_batch,
590     float* __restrict__ result, const float* per_channel_scale,
591     const int32_t* input_offset) {
592   DotprodMatrixBatchFourVectorMultiplyAccumulate(
593       matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
594       per_channel_scale, input_offset, nullptr);
595 }
596 
597 // The DotprodMatrixBatchFourVectorMultiplyAccumulate kernel processes 4
598 // vectors in the same time as the baseline processes 1 vector. However, it
599 // requires 4 vectors of input.
600 //
601 // To take advantage of this speed difference, we add some zero-valued
602 // vectors to the batch so that n_batch is a multiple of 4. Then we execute
603 // DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate on that padded batch,
604 // then extract just the results we want at the end (ignoring the extra padding
605 // outputs).
606 //
607 // The relative cost of the padding is large when the matrix is smaller than
608 // 128x128, so we don't use this code path on small matrices. On larger
609 // matrices, the computation cost dwarfs the padding cost, making this code
610 // viable.
611 //
612 // If we ignore the cost of padding, this kernel is:
613 //    1x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 1
614 //    2x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 2
615 //    3x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 3
616 //    ...
617 //
618 // We don't use this kernel when n_batch = 1 because the baseline kernel
619 // is fine for that case.
DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const int m_rows,const int m_cols,const int8_t * vectors,const float * scaling_factors,int n_batch,float * __restrict__ result,const float * per_channel_scale,const int32_t * input_offset,int32_t * row_sums)620 void DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
621     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
622     const int8_t* vectors, const float* scaling_factors, int n_batch,
623     float* __restrict__ result, const float* per_channel_scale,
624     const int32_t* input_offset, int32_t* row_sums) {
625   // Round to the nearest multiple of 4.
626   int batch_round_up = n_batch;
627   if (n_batch % 4 != 0) {
628     batch_round_up += (4 - n_batch % 4);
629   }
630   TFLITE_CHECK_LE(n_batch, batch_round_up);
631 
632   void* padded_vectors_free;
633   const int padded_vectors_size = batch_round_up * m_cols;
634   int8_t* padded_vectors = reinterpret_cast<int8_t*>(aligned_alloc(
635       kNeonVectorAlignment, padded_vectors_size, &padded_vectors_free));
636   memset(padded_vectors, 0, padded_vectors_size);
637 
638   void* padded_result_free;
639   const int result_size = n_batch * m_rows * sizeof(float);
640   const int padded_result_size = batch_round_up * m_rows * sizeof(float);
641   float* padded_result = reinterpret_cast<float*>(aligned_alloc(
642       kNeonVectorAlignment, padded_result_size, &padded_result_free));
643   memcpy(padded_result, result, result_size);
644   memset(reinterpret_cast<char*>(padded_result) + result_size, 0,
645          padded_result_size - result_size);
646 
647   // Copy the input into the padded data structure.
648   TFLITE_CHECK_LE(n_batch * m_cols, padded_vectors_size);
649   memcpy(padded_vectors, vectors, n_batch * m_cols);
650 
651   void* padded_scaling_factors_free;
652   const int padded_scaling_factors_size = batch_round_up * sizeof(float);
653   float* padded_scaling_factors = reinterpret_cast<float*>(
654       aligned_alloc(kNeonVectorAlignment, padded_scaling_factors_size,
655                     &padded_scaling_factors_free));
656   TFLITE_CHECK_LE(n_batch * sizeof(float), padded_scaling_factors_size);
657   TFLITE_CHECK_LE(batch_round_up * sizeof(float), padded_scaling_factors_size);
658   memset(padded_scaling_factors, 0, batch_round_up * sizeof(float));
659   memcpy(padded_scaling_factors, scaling_factors, n_batch * sizeof(float));
660 
661   if (input_offset != nullptr) {
662     void* padded_input_offset_free;
663     const int padded_input_offset_size = batch_round_up * sizeof(int32_t);
664     int32_t* padded_input_offset = reinterpret_cast<int32_t*>(
665         aligned_alloc(kNeonVectorAlignment, padded_input_offset_size,
666                       &padded_input_offset_free));
667     TFLITE_CHECK_LE(n_batch * sizeof(int32_t), padded_input_offset_size);
668     TFLITE_CHECK_LE(batch_round_up * sizeof(int32_t), padded_input_offset_size);
669     memset(padded_input_offset, 0, batch_round_up * sizeof(int32_t));
670     memcpy(padded_input_offset, input_offset, n_batch * sizeof(int32_t));
671 
672     // Call the main kernel.
673     DotprodMatrixBatchFourVectorMultiplyAccumulate(
674         matrix, m_rows, m_cols, padded_vectors, padded_scaling_factors,
675         batch_round_up, padded_result, per_channel_scale, padded_input_offset,
676         row_sums);
677 
678     free(padded_input_offset_free);
679   } else {
680     // Call the main kernel.
681     DotprodMatrixBatchFourVectorMultiplyAccumulate(
682         matrix, m_rows, m_cols, padded_vectors, padded_scaling_factors,
683         batch_round_up, padded_result);
684   }
685   memcpy(result, padded_result, result_size);
686 
687   free(padded_result_free);
688   free(padded_vectors_free);
689   free(padded_scaling_factors_free);
690 }
691 
DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const int m_rows,const int m_cols,const int8_t * vectors,const float * scaling_factors,int n_batch,float * __restrict__ result)692 void DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
693     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
694     const int8_t* vectors, const float* scaling_factors, int n_batch,
695     float* __restrict__ result) {
696   DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
697       matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
698       /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr,
699       /*row_sums=*/nullptr);
700 }
701 
DotprodSparseMatrixBatchVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const uint8_t * ledger,const int m_rows,const int m_cols,const int8_t * __restrict__ vectors,const float * scaling_factors,int n_batch,float * __restrict__ result)702 static void DotprodSparseMatrixBatchVectorMultiplyAccumulate(
703     const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
704     const int m_cols, const int8_t* __restrict__ vectors,
705     const float* scaling_factors, int n_batch, float* __restrict__ result) {
706   const uint8_t* ledger_ptr = ledger;
707   const int8* mat_ptr = matrix;
708 
709   for (int row = 0; row < m_rows; row++) {
710     int num_nonzero_chunks = *ledger_ptr;
711     ledger_ptr++;
712     const uint8* ledger_start = ledger_ptr;
713     const uint8* ledger_end = ledger_ptr + num_nonzero_chunks;
714     const int8* mat_start = mat_ptr;
715 
716     for (int batch = 0; batch < n_batch; batch++) {
717       const int8* vec_ptr = vectors + (batch * m_cols);
718       int64_t row_sum = 0;
719 
720       mat_ptr = mat_start;
721       ledger_ptr = ledger_start;
722 
723       if (ledger_ptr != ledger_end) {
724         asm volatile(
725             "movi v0.4s, #0\n"
726             "movi v1.4s, #0\n"
727             "movi v8.4s, #0\n"
728             "mov x7, 0\n"
729 
730             "1:\n"  // chunks_loop
731 
732             // Single matrix chunk, 16 bytes
733             "ld1 {v8.16b}, [%[mat_ptr]], #16\n"
734 
735             // Read the next ledger index and increment.
736             "ldrb w7, [%[ledger_ptr]], #1\n"
737 
738             // Read 16 bytes of vector data from (vec_ptr + (ledger_index * 16))
739             "add x8, %[vec_ptr], x7, lsl #4\n"
740             "ld1 {v9.16b}, [x8]\n"
741 
742             // Dot product of matrix row and vector.
743             ".word 0x4e889520  // sdot v0.4s, v9.16b, v8.16b\n"
744 
745             "cmp %[ledger_ptr], %[ledger_end]\n"
746             "blt 1b\n"  // chunks_loop
747 
748             // Sum the 4 vector components into a 32-bit value.
749             "addv s1, v0.4s\n"
750             // row_sum is 64-bit, so we copy 64 bits of v1 into it.
751             // We have to be careful to cast this value to 32 bits in order
752             // to interpret the sign bit properly.
753             "mov %[row_sum], v1.d[0]\n"
754             : [row_sum] "=r"(row_sum), [ledger_ptr] "+r"(ledger_ptr),
755               [mat_ptr] "+r"(mat_ptr), [vec_ptr] "+r"(vec_ptr)
756             : [ledger_end] "r"(ledger_end)
757             : "x0", "x1", "x7", "x8", "v0", "v1", "v8", "v9", "cc", "memory");
758       }
759       result[batch * m_rows + row] +=
760           static_cast<int32>(row_sum) * scaling_factors[batch];
761     }
762   }
763 }
764 
765 #endif  // __aarch64__
766 
NeonMatrixBatchVectorMultiplyImpl(const int8_t * input,const int32_t * bias,const int8_t * input_to_gate_weights,int32_t n_batch,int32_t n_input,int32_t n_output,int32_t output_zp,int32_t * scratch)767 void NeonMatrixBatchVectorMultiplyImpl(const int8_t* input, const int32_t* bias,
768                                        const int8_t* input_to_gate_weights,
769                                        int32_t n_batch, int32_t n_input,
770                                        int32_t n_output, int32_t output_zp,
771                                        int32_t* scratch) {
772   // Assuming *matrix is kNeonVectorAlignment-byte aligned, every row of the
773   // matrix is also kNeonVectorAlignment-byte aligned as long as cols is a
774   // multiple of kNeonVectorAlignment. The assumption is currently satisfied by
775   // TFLite's 16-byte memory alignment scheme.
776   //
777   // Otherwise, we allocate an aligned memory block and set
778   // a flag to later copy rows from matrix to the block
779   // for aligned multiplication.
780   bool unaligned = false;
781   int8_t* aligned_row = nullptr;
782   void* aligned_row_free = nullptr;
783   if ((n_input & (kNeonVectorAlignment - 1)) != 0) {
784     unaligned = true;
785     aligned_row =
786         (int8_t*)aligned_alloc(kNeonVectorAlignment, n_input,  // NOLINT
787                                &aligned_row_free);
788   }
789   void* aligned_vec_free = nullptr;
790   int8_t* aligned_vec =
791       (int8_t*)aligned_alloc(kNeonVectorAlignment, n_input,  // NOLINT
792                              &aligned_vec_free);
793 
794   // If m_cols is not at least kInt8ValuesPerNeonVector, we cannot use the main
795   // vectorized loop, and we need to process sequentially. postamble_half_start
796   // shows the start index where this should happen. Between postamble_start and
797   // postamble_half_start we can still process kInt8ValuesPerNeonVector/2 in a
798   // vectorized form.
799   const int postamble_half_start =
800       RoundDownVectors<kInt8ValuesPerNeonVector>(n_input);
801   const int postamble_start =
802       RoundDownVectors<(kInt8ValuesPerNeonVector / 2)>(n_input);
803 
804   for (int batch = 0; batch < n_batch; ++batch) {
805     // Copy the vector data to an aligned vector.
806     memcpy(aligned_vec, input + batch * n_input, sizeof(int8_t) * n_input);
807     // Compute dot-product for every column.
808     for (int row = 0; row < n_output; ++row) {
809       // Get the address of the first element of the row.
810       int8_t* row_ptr =
811           (int8_t*)input_to_gate_weights + row * n_input;  // NOLINT
812       if (unaligned) {
813         memcpy(aligned_row, row_ptr, sizeof(int8_t) * n_input);
814         row_ptr = aligned_row;
815       }
816 
817       // Initialize the dot product sum for the row to 0.
818       int32x4_t dotprod_32x4 = vmovq_n_s32(0);
819 
820       // For every block of 16 8-bit elements.
821       int col = 0;
822       for (; col < postamble_half_start; col += kInt8ValuesPerNeonVector) {
823         // Load 16 8-bit values from the row and vector, each, to operate on.
824         // Here the assumption is that each buffer is 4-byte aligned. Otherwise,
825         // performance may suffer significantly.
826         TFLITE_DCHECK_EQ(  // NOLINT
827             (uintptr_t)(&row_ptr[col]) & (kNeonVectorAlignment - 1), 0);
828         const int8x16_t s1_8x16 = vld1q_s8((const int8_t*)(aligned_vec + col));
829         const int8x16_t s2_8x16 = vld1q_s8((const int8_t*)(row_ptr + col));
830         // Multiply the low bits (i.e. the lower 8 8bit numbers in the
831         // registers).
832         int16x8_t prod_16x8 =
833             vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16));
834         // Multiply the high bits (i.e. the higher 8 8bit numbers in the
835         // registers), and accumulate with the result of the low bits product.
836         // The assumption here is that overflow will not happen as we quantize
837         // our values to be in the range [-127, 127]. As such the sum of the 2
838         // products is always strictly smaller than 15-bits (32767 in absolute
839         // value).
840         prod_16x8 =
841             vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
842 
843         dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
844       }  // for col
845 
846       // Half iteration dealing only 8 elements
847       if (TFLITE_UNLIKELY(col < postamble_start)) {
848         // Load 8 8-bit values from the row and column each to operate on.
849         // Here the assumption is that each buffer is 4-bytes aligned.
850         // Otherwise, performance may suffer significantly.
851         TFLITE_DCHECK_EQ(  // NOLINT
852             (uintptr_t)(&row_ptr[col]) & (kNeonVectorAlignment - 1), 0);
853         const int8x8_t s1_8x8 = vld1_s8((const int8_t*)(aligned_vec + col));
854         const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col));
855         const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
856         dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
857         col += (kInt8ValuesPerNeonVector >> 1);
858       }
859       // Add the 4 intermediate sum values to get the final dot-prod value for
860       // this row.
861       int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
862       // Postamble loop.
863       for (; TFLITE_UNLIKELY(col < n_input); ++col) {
864         dotprod += row_ptr[col] * aligned_vec[col];
865       }  // for col
866 
867       dotprod += bias[row];
868       scratch[batch * n_output + row] = dotprod;
869     }  // for row
870   }    // for batch
871 
872   if (unaligned) {
873     free(aligned_row_free);
874   }
875   free(aligned_vec_free);
876 }
877 
NeonMatrixBatchVectorAccumulateImpl(int32_t multiplier,int32_t shift,int32_t n_batch,int32_t n_output,int32_t output_zp,int32_t * scratch,int16_t * output)878 inline void NeonMatrixBatchVectorAccumulateImpl(
879     int32_t multiplier, int32_t shift, int32_t n_batch, int32_t n_output,
880     int32_t output_zp, int32_t* scratch, int16_t* output) {
881   int i = 0;
882   const int total_size = n_batch * n_output;
883 
884   const int32_t output_min = std::numeric_limits<int16_t>::min();
885   const int32_t output_max = std::numeric_limits<int16_t>::max();
886 
887   const int32x4_t output_zp_dup = vdupq_n_s32(output_zp);
888   const int32x4_t max_val_dup = vdupq_n_s32(output_max);
889   const int32x4_t min_val_dup = vdupq_n_s32(output_min);
890 
891   using gemmlowp::RoundingDivideByPOT;
892   using gemmlowp::SaturatingRoundingDoublingHighMul;
893 
894   for (; i <= total_size - 8; i += 8) {
895     int32x4x2_t scratch_val;
896     scratch_val.val[0] = vld1q_s32(scratch + i);
897     scratch_val.val[1] = vld1q_s32(scratch + i + 4);
898     const int16x8_t output_val = vld1q_s16(output + i);
899     const int32x4_t first_half = vmovl_s16(vget_low_s16(output_val));
900     const int32x4_t second_half = vmovl_s16(vget_high_s16(output_val));
901 
902     int32x4x2_t temp_val =
903         MultiplyByQuantizedMultiplier2Rows(scratch_val, multiplier, shift);
904 
905     temp_val.val[0] =
906         vaddq_s32(vaddq_s32(temp_val.val[0], first_half), output_zp_dup);
907     temp_val.val[1] =
908         vaddq_s32(vaddq_s32(temp_val.val[1], second_half), output_zp_dup);
909     temp_val.val[0] =
910         vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
911     temp_val.val[1] =
912         vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
913     const int16x8_t result =
914         vcombine_s16(vqmovn_s32(temp_val.val[0]), vqmovn_s32(temp_val.val[1]));
915     vst1q_s16(output + i, result);
916   }
917   for (; TFLITE_UNLIKELY(i < total_size); ++i) {
918     int32_t temp = MultiplyByQuantizedMultiplier(scratch[i], multiplier, shift);
919     temp += output_zp;
920     temp += output[i];
921     if (temp > output_max) {
922       temp = output_max;
923     }
924     if (temp < output_min) {
925       temp = output_min;
926     }
927     output[i] = static_cast<int16_t>(temp);
928   }
929 }
930 
NeonMatrixBatchVectorAccumulateImpl(int32_t multiplier,int32_t shift,int32_t n_batch,int32_t n_output,int32_t output_zp,int32_t * scratch,int8_t * output)931 inline void NeonMatrixBatchVectorAccumulateImpl(
932     int32_t multiplier, int32_t shift, int32_t n_batch, int32_t n_output,
933     int32_t output_zp, int32_t* scratch, int8_t* output) {
934   int i = 0;
935   const int total_size = n_batch * n_output;
936 
937   const int32_t output_min = std::numeric_limits<int8_t>::min();
938   const int32_t output_max = std::numeric_limits<int8_t>::max();
939 
940   const int32x4_t output_zp_dup = vdupq_n_s32(output_zp);
941   const int32x4_t max_val_dup = vdupq_n_s32(output_max);
942   const int32x4_t min_val_dup = vdupq_n_s32(output_min);
943 
944   using gemmlowp::RoundingDivideByPOT;
945   using gemmlowp::SaturatingRoundingDoublingHighMul;
946 
947   for (; i <= total_size - 16; i += 16) {
948     int32x4x4_t scratch_val;
949     scratch_val.val[0] = vld1q_s32(scratch + i);
950     scratch_val.val[1] = vld1q_s32(scratch + i + 4);
951     scratch_val.val[2] = vld1q_s32(scratch + i + 8);
952     scratch_val.val[3] = vld1q_s32(scratch + i + 12);
953 
954     const int8x16_t output_val = vld1q_s8(output + i);
955     const int16x8_t first_half = vmovl_s8(vget_low_s8(output_val));
956     const int16x8_t second_half = vmovl_s8(vget_high_s8(output_val));
957     const int32x4_t output_val_1 = vmovl_s16(vget_low_s16(first_half));
958     const int32x4_t output_val_2 = vmovl_s16(vget_high_s16(first_half));
959     const int32x4_t output_val_3 = vmovl_s16(vget_low_s16(second_half));
960     const int32x4_t output_val_4 = vmovl_s16(vget_high_s16(second_half));
961 
962     int32x4x4_t temp_val =
963         MultiplyByQuantizedMultiplier4Rows(scratch_val, multiplier, shift);
964 
965     temp_val.val[0] =
966         vaddq_s32(vaddq_s32(temp_val.val[0], output_val_1), output_zp_dup);
967     temp_val.val[1] =
968         vaddq_s32(vaddq_s32(temp_val.val[1], output_val_2), output_zp_dup);
969     temp_val.val[2] =
970         vaddq_s32(vaddq_s32(temp_val.val[2], output_val_3), output_zp_dup);
971     temp_val.val[3] =
972         vaddq_s32(vaddq_s32(temp_val.val[3], output_val_4), output_zp_dup);
973 
974     temp_val.val[0] =
975         vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
976     temp_val.val[1] =
977         vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
978     temp_val.val[2] =
979         vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
980     temp_val.val[3] =
981         vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
982 
983     const int16x8_t result_1 =
984         vcombine_s16(vqmovn_s32(temp_val.val[0]), vqmovn_s32(temp_val.val[1]));
985     const int16x8_t result_2 =
986         vcombine_s16(vqmovn_s32(temp_val.val[2]), vqmovn_s32(temp_val.val[3]));
987     const int8x16_t result =
988         vcombine_s8(vqmovn_s16(result_1), vqmovn_s16(result_2));
989     vst1q_s8(output + i, result);
990   }
991   for (; TFLITE_UNLIKELY(i < total_size); ++i) {
992     int32_t temp = MultiplyByQuantizedMultiplier(scratch[i], multiplier, shift);
993     temp += output_zp;
994     temp += output[i];
995     if (temp > output_max) {
996       temp = output_max;
997     }
998     if (temp < output_min) {
999       temp = output_min;
1000     }
1001     output[i] = static_cast<int8_t>(temp);
1002   }
1003 }
1004 
NeonCpuBackendGemm(const int8_t * input,const int32_t * bias,const int8_t * input_to_gate_weights,int32_t n_batch,int32_t n_input,int32_t n_output,int32_t output_zp,int32_t * scratch,CpuBackendContext * context)1005 void NeonCpuBackendGemm(const int8_t* input, const int32_t* bias,
1006                         const int8_t* input_to_gate_weights, int32_t n_batch,
1007                         int32_t n_input, int32_t n_output, int32_t output_zp,
1008                         int32_t* scratch, CpuBackendContext* context) {
1009   using ::tflite::cpu_backend_gemm::Gemm;
1010   using ::tflite::cpu_backend_gemm::GemmParams;
1011   using ::tflite::cpu_backend_gemm::MatrixParams;
1012 
1013   MatrixParams<int8_t> lhs_params;
1014   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1015   lhs_params.rows = n_output;
1016   lhs_params.cols = n_input;
1017   lhs_params.cache_policy = cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;
1018 
1019   MatrixParams<int8_t> rhs_params;
1020   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1021   rhs_params.rows = n_input;
1022   rhs_params.cols = n_batch;
1023 
1024   MatrixParams<int32_t> dst_params;
1025   dst_params.order = cpu_backend_gemm::Order::kColMajor;
1026   dst_params.rows = n_output;
1027   dst_params.cols = n_batch;
1028 
1029   GemmParams<int32, int32> gemm_params;
1030   if (bias) {
1031     gemm_params.bias = bias;
1032   }
1033   cpu_backend_gemm::Gemm(lhs_params, input_to_gate_weights, rhs_params, input,
1034                          dst_params, scratch, gemm_params, context);
1035 }
1036 
NeonMatrixBatchVectorMultiplyAccumulate(const int8_t * input,const int32_t * bias,const int8_t * input_to_gate_weights,int32_t multiplier,int32_t shift,int32_t n_batch,int32_t n_input,int32_t n_output,int32_t output_zp,int32_t * scratch,int16_t * output,CpuBackendContext * context)1037 void NeonMatrixBatchVectorMultiplyAccumulate(
1038     const int8_t* input, const int32_t* bias,
1039     const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
1040     int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
1041     int32_t* scratch, int16_t* output, CpuBackendContext* context) {
1042 #ifdef TFLITE_WITH_RUY_GEMV
1043   NeonCpuBackendGemm(input, bias, input_to_gate_weights, n_batch, n_input,
1044                      n_output, output_zp, scratch, context);
1045 #else
1046   NeonMatrixBatchVectorMultiplyImpl(input, bias, input_to_gate_weights, n_batch,
1047                                     n_input, n_output, output_zp, scratch);
1048 #endif
1049   NeonMatrixBatchVectorAccumulateImpl(multiplier, shift, n_batch, n_output,
1050                                       output_zp, scratch, output);
1051 }
1052 
NeonMatrixBatchVectorMultiplyAccumulate(const int8_t * input,const int32_t * bias,const int8_t * input_to_gate_weights,int32_t multiplier,int32_t shift,int32_t n_batch,int32_t n_input,int32_t n_output,int32_t output_zp,int32_t * scratch,int8_t * output,CpuBackendContext * context)1053 void NeonMatrixBatchVectorMultiplyAccumulate(
1054     const int8_t* input, const int32_t* bias,
1055     const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
1056     int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
1057     int32_t* scratch, int8_t* output, CpuBackendContext* context) {
1058 #ifdef TFLITE_WITH_RUY_GEMV
1059   NeonCpuBackendGemm(input, bias, input_to_gate_weights, n_batch, n_input,
1060                      n_output, output_zp, scratch, context);
1061 #else
1062   NeonMatrixBatchVectorMultiplyImpl(input, bias, input_to_gate_weights, n_batch,
1063                                     n_input, n_output, output_zp, scratch);
1064 #endif
1065   NeonMatrixBatchVectorAccumulateImpl(multiplier, shift, n_batch, n_output,
1066                                       output_zp, scratch, output);
1067 }
1068 
NeonMatrixBatchVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const int m_rows,const int m_cols,const int8_t * __restrict__ vectors,const float * scaling_factors,int n_batch,float * __restrict__ result)1069 void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
1070                                              const int m_rows, const int m_cols,
1071                                              const int8_t* __restrict__ vectors,
1072                                              const float* scaling_factors,
1073                                              int n_batch,
1074                                              float* __restrict__ result) {
1075 #ifdef __aarch64__
1076   if (HasSdotInstruction() && m_cols % 16 == 0 && m_rows % 2 == 0 &&
1077       m_rows >= n_batch) {
1078     if (n_batch % 4 == 0) {
1079       // Benchmarks suggest that it's always better to use the batch code
1080       // when we can, even on small matrices.
1081       DotprodMatrixBatchFourVectorMultiplyAccumulate(
1082           matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
1083       return;
1084     } else if (n_batch >= 2 && m_rows * m_cols >= 128 * 128) {
1085       DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
1086           matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
1087       return;
1088     }
1089   }
1090 #endif  // __aarch64__
1091 
1092   // Assuming *matrix is kNeonVectorAlignment-byte aligned, every row of the
1093   // matrix is also kNeonVectorAlignment-byte aligned as long as cols is a
1094   // multiple of kNeonVectorAlignment. The assumption is currently satisfied by
1095   // TFLite's 16-byte memory alignment scheme.
1096   //
1097   // Otherwise, we allocate an aligned memory block and set
1098   // a flag to later copy rows from matrix to the block
1099   // for aligned multiplication.
1100   bool unaligned = false;
1101   int8_t* aligned_row = nullptr;
1102   void* aligned_row_free = nullptr;
1103   if ((m_cols & (kNeonVectorAlignment - 1)) != 0) {
1104     unaligned = true;
1105     aligned_row =
1106         (int8_t*)aligned_alloc(kNeonVectorAlignment, m_cols,  // NOLINT
1107                                &aligned_row_free);
1108   }
1109   void* aligned_vec_free = nullptr;
1110   int8_t* aligned_vec =
1111       (int8_t*)aligned_alloc(kNeonVectorAlignment, m_cols,  // NOLINT
1112                              &aligned_vec_free);
1113 
1114   // If m_cols is not at least kInt8ValuesPerNeonVector, we cannot use the main
1115   // vectorized loop, and we need to process sequentially. postamble_half_start
1116   // shows the start index where this should happen. Between postamble_start and
1117   // postamble_half_start we can still process kInt8ValuesPerNeonVector/2 in a
1118   // vectorized form.
1119   const int postamble_half_start =
1120       RoundDownVectors<kInt8ValuesPerNeonVector>(m_cols);
1121   const int postamble_start =
1122       RoundDownVectors<(kInt8ValuesPerNeonVector / 2)>(m_cols);
1123 
1124   for (int batch = 0; batch < n_batch; ++batch) {
1125     const float batch_scaling_factor = scaling_factors[batch];
1126     // Copy the vector data to an aligned vector.
1127     memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8_t) * m_cols);
1128     // Compute dot-product for every column.
1129     for (int row = 0; row < m_rows; ++row) {
1130       // Get the address of the first element of the row.
1131       int8_t* row_ptr = (int8_t*)matrix + row * m_cols;  // NOLINT
1132       if (unaligned) {
1133         memcpy(aligned_row, row_ptr, sizeof(int8_t) * m_cols);
1134         row_ptr = aligned_row;
1135       }
1136 
1137       // Initialize the dot product sum for the row to 0.
1138       int32x4_t dotprod_32x4 = vmovq_n_s32(0);
1139 
1140       // Prefetch the row to cache.
1141       __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
1142                          3 /* temporal locality */);
1143 
1144       // For every block of 16 8-bit elements.
1145       int col = 0;
1146       for (; col < postamble_half_start; col += kInt8ValuesPerNeonVector) {
1147         // Load 16 8-bit values from the row and vector, each, to operate on.
1148         // Here the assumption is that each buffer is 4-byte aligned. Otherwise,
1149         // performance may suffer significantly.
1150         TFLITE_DCHECK_EQ(  // NOLINT
1151             (uintptr_t)(&row_ptr[col]) & (kNeonVectorAlignment - 1), 0);
1152         const int8x16_t s1_8x16 = vld1q_s8((const int8_t*)(aligned_vec + col));
1153         const int8x16_t s2_8x16 = vld1q_s8((const int8_t*)(row_ptr + col));
1154         // Multiply the low bits (i.e. the lower 8 8bit numbers in the
1155         // registers).
1156         int16x8_t prod_16x8 =
1157             vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16));
1158         // Multiply the high bits (i.e. the higher 8 8bit numbers in the
1159         // registers), and accumulate with the result of the low bits product.
1160         // The assumption here is that overflow will not happen as we quantize
1161         // our values to be in the range [-127, 127]. As such the sum of the 2
1162         // products is always strictly smaller than 15-bits (32767 in absolute
1163         // value).
1164         prod_16x8 =
1165             vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
1166 
1167         dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
1168       }  // for col
1169 
1170       // Half iteration dealing only 8 elements
1171       if (TFLITE_UNLIKELY(col < postamble_start)) {
1172         // Load 8 8-bit values from the row and column each to operate on.
1173         // Here the assumption is that each buffer is 4-bytes aligned.
1174         // Otherwise, performance may suffer significantly.
1175         TFLITE_DCHECK_EQ(  // NOLINT
1176             (uintptr_t)(&row_ptr[col]) & (kNeonVectorAlignment - 1), 0);
1177         const int8x8_t s1_8x8 = vld1_s8((const int8_t*)(aligned_vec + col));
1178         const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col));
1179         const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
1180         dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
1181         col += (kInt8ValuesPerNeonVector >> 1);
1182       }
1183       // Add the 4 intermediate sum values to get the final dot-prod value for
1184       // this row.
1185       int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
1186       // Postamble loop.
1187       for (; TFLITE_UNLIKELY(col < m_cols); ++col) {
1188         dotprod += row_ptr[col] * aligned_vec[col];
1189       }  // for col
1190 
1191       *result += dotprod * batch_scaling_factor;
1192       ++result;
1193     }  // for row
1194   }    // for batch
1195 
1196   if (unaligned) {
1197     free(aligned_row_free);
1198   }
1199   free(aligned_vec_free);
1200 }
1201 
NeonMatrixBatchVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const int m_rows,const int m_cols,const int8_t * __restrict__ vectors,const float * scaling_factors,int n_batch,int32_t * scratch,float * __restrict__ result,CpuBackendContext * context)1202 void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
1203                                              const int m_rows, const int m_cols,
1204                                              const int8_t* __restrict__ vectors,
1205                                              const float* scaling_factors,
1206                                              int n_batch, int32_t* scratch,
1207                                              float* __restrict__ result,
1208                                              CpuBackendContext* context) {
1209   if (m_rows % 4 == 0) {
1210     const int32_t* bias = static_cast<const int32_t*>(nullptr);
1211     NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows,
1212                        /*output_zp =*/0, scratch, context);
1213 
1214     // Multiply by float scaling factors and write to result
1215     const int total_size = n_batch * m_rows;
1216     int i = 0;
1217     for (; i <= total_size - 8; i += 8, result += 8) {
1218       const float batch_scaling_factor0 = scaling_factors[i / m_rows];
1219       const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
1220       const float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor0);
1221       const float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor1);
1222       const int32x4_t scratch_val0 = vld1q_s32(scratch + i);
1223       const int32x4_t scratch_val1 = vld1q_s32(scratch + i + 4);
1224       const float32x4_t float_val0 = vcvtq_f32_s32(scratch_val0);
1225       const float32x4_t float_val1 = vcvtq_f32_s32(scratch_val1);
1226       const float32x4_t result0 =
1227           vmlaq_f32(vld1q_f32(result), float_val0, scaling_factor0);
1228       const float32x4_t result1 =
1229           vmlaq_f32(vld1q_f32(result + 4), float_val1, scaling_factor1);
1230       vst1q_f32(result, result0);
1231       vst1q_f32(result + 4, result1);
1232     }
1233     scratch += i;
1234     for (; TFLITE_UNLIKELY(i < total_size); i++) {
1235       const float batch_scaling_factor = scaling_factors[i / m_rows];
1236       int32_t x = *(scratch++);
1237       *result += x * batch_scaling_factor;
1238       ++result;
1239     }
1240     return;
1241   }
1242   NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
1243                                           scaling_factors, n_batch, result);
1244 }
1245 
NeonMatrixScalarMultiplyAccumulate(const int8_t * matrix,int32_t scalar,int32_t n_row,int32_t n_col,int32_t * output)1246 void NeonMatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
1247                                         int32_t n_row, int32_t n_col,
1248                                         int32_t* output) {
1249   // Processing multiple rows at the same time actually makes it slower. :(
1250   for (int i = 0; i < n_row; ++i) {
1251     int32x4_t row_sum = vdupq_n_s32(0);
1252     int j = 0;
1253     const int8_t* row_ptr = matrix + i * n_col;
1254     for (; j <= n_col - kInt8ValuesPerNeonVector;
1255          j += kInt8ValuesPerNeonVector) {
1256       const int8x16_t input_value = vld1q_s8(row_ptr + j);
1257       int16x8_t temp = vmovl_s8(vget_low_s8(input_value));
1258       temp = vaddw_s8(temp, vget_high_s8(input_value));
1259       row_sum = vpadalq_s16(row_sum, temp);
1260     }
1261     int32_t sum = AccumulateNeonLane(row_sum);
1262     for (; TFLITE_UNLIKELY(j < n_col); ++j) {
1263       sum += *(row_ptr + j);
1264     }
1265     output[i] += sum * scalar;
1266   }
1267 }
1268 
// NEON fallback for the hybrid (int8 x int8 -> float) matrix * batched-vector
// multiply-accumulate with per-batch input zero points:
//
//   result[b * m_rows + r] += scale(b, r) *
//       (dot(matrix row r, vectors batch b) - row_sums[r] * input_offset[b])
//
// where scale(b, r) = scaling_factors[b], additionally multiplied by
// per_channel_scale[r] when per_channel_scale is non-null.
//
// On aarch64 with SDOT support and suitable shapes the work is delegated to
// the dot-product kernels. If row_sums is nullptr, the row sums are computed
// into a temporary heap buffer that is released before returning.
// input_offset is read unconditionally and must be non-null.
void NeonMatrixBatchVectorMultiplyAccumulateImpl(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors, const float* scaling_factors,
    int n_batch, float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, int32_t* row_sums) {
#ifdef __aarch64__
  const bool sdot_shape_ok = HasSdotInstruction() && m_cols % 16 == 0 &&
                             m_rows % 2 == 0 && m_rows >= n_batch;
  if (sdot_shape_ok && n_batch % 4 == 0) {
    DotprodMatrixBatchFourVectorMultiplyAccumulate(
        matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
        per_channel_scale, input_offset, row_sums);
    return;
  }
  if (sdot_shape_ok && n_batch >= 2 && m_rows * m_cols >= 128 * 128) {
    DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
        matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
        per_channel_scale, input_offset, row_sums);
    return;
  }
#endif  // __aarch64__

  // When the row length is not a multiple of the NEON alignment, successive
  // rows cannot all start aligned, so each row is first copied into an
  // aligned scratch buffer.
  const bool copy_rows = (m_cols & (kNeonVectorAlignment - 1)) != 0;
  void* row_copy_free = nullptr;
  int8_t* row_copy = nullptr;
  if (copy_rows) {
    row_copy = static_cast<int8_t*>(
        aligned_alloc(kNeonVectorAlignment, m_cols, &row_copy_free));
  }
  // The current batch vector is always copied into an aligned buffer.
  void* vec_copy_free = nullptr;
  int8_t* const vec_copy = static_cast<int8_t*>(
      aligned_alloc(kNeonVectorAlignment, m_cols, &vec_copy_free));

  // Column bounds for the 16-wide loop and the optional single 8-wide step.
  const int full_vec_end = RoundDownVectors<kInt8ValuesPerNeonVector>(m_cols);
  const int half_vec_end =
      RoundDownVectors<(kInt8ValuesPerNeonVector / 2)>(m_cols);

  int32_t* sums = row_sums;
  if (sums == nullptr) {
    sums = static_cast<int32_t*>(malloc(sizeof(int32_t) * m_rows));
    NeonReductionSumVector(matrix, sums, m_rows, m_cols);
  }

  // Integer dot product of one aligned matrix row with vec_copy.
  auto row_dot = [&](const int8_t* row) -> int32_t {
    int32x4_t acc_32x4 = vmovq_n_s32(0);
    int c = 0;
    for (; c < full_vec_end; c += kInt8ValuesPerNeonVector) {
      // Loads are expected to be aligned; misalignment only costs speed.
      TFLITE_DCHECK_EQ(  // NOLINT
          reinterpret_cast<uintptr_t>(row + c) & (kNeonVectorAlignment - 1),
          0);
      const int8x16_t v_8x16 = vld1q_s8(vec_copy + c);
      const int8x16_t m_8x16 = vld1q_s8(row + c);
      // Inputs are assumed quantized to [-127, 127], so the sum of the two
      // int8 products per lane stays within int16.
      int16x8_t prod_16x8 =
          vmull_s8(vget_low_s8(v_8x16), vget_low_s8(m_8x16));
      prod_16x8 =
          vmlal_s8(prod_16x8, vget_high_s8(v_8x16), vget_high_s8(m_8x16));
      acc_32x4 = vpadalq_s16(acc_32x4, prod_16x8);
    }
    // At most one 8-element step remains before the scalar tail.
    if (TFLITE_UNLIKELY(c < half_vec_end)) {
      TFLITE_DCHECK_EQ(  // NOLINT
          reinterpret_cast<uintptr_t>(row + c) & (kNeonVectorAlignment - 1),
          0);
      const int16x8_t prod_16x8 =
          vmull_s8(vld1_s8(vec_copy + c), vld1_s8(row + c));
      acc_32x4 = vpadalq_s16(acc_32x4, prod_16x8);
      c += kInt8ValuesPerNeonVector / 2;
    }
    int32_t dot = AccumulateNeonLane(acc_32x4);
    for (; TFLITE_UNLIKELY(c < m_cols); ++c) {
      dot += row[c] * vec_copy[c];
    }
    return dot;
  };

  float* out = result;
  for (int b = 0; b < n_batch; ++b) {
    const float batch_scale = scaling_factors[b];
    const int zero_point = input_offset[b];
    memcpy(vec_copy, vectors + b * m_cols, sizeof(int8_t) * m_cols);
    for (int r = 0; r < m_rows; ++r) {
      const int8_t* row = matrix + r * m_cols;
      if (copy_rows) {
        memcpy(row_copy, row, sizeof(int8_t) * m_cols);
        row = row_copy;
      }
      const float scale =
          per_channel_scale ? batch_scale * per_channel_scale[r] : batch_scale;

      __builtin_prefetch(row, 0 /* prefetch for read */,
                         3 /* temporal locality */);

      int32_t dotprod = row_dot(row);
      // Remove the contribution of the input zero point.
      dotprod -= sums[r] * zero_point;
      *out += dotprod * scale;
      ++out;
    }
  }

  if (row_sums == nullptr) {
    free(sums);
  }
  if (copy_rows) {
    free(row_copy_free);
  }
  free(vec_copy_free);
}
1396 
// Hybrid (int8 x int8 -> float) matrix * batched-vector multiply-accumulate
// with optional per-batch input zero points and optional per-channel scales.
//
// - input_offset == nullptr: forwards to the symmetric overloads (through the
//   CPU-backend GEMM when it is preferred and a context is available).
// - Otherwise row_sums is refreshed when compute_row_sums is null or points to
//   true (the flag is then cleared), and the result is accumulated either
//   from a CPU-backend GEMM into `scratch` followed by a NEON rescale, or
//   through NeonMatrixBatchVectorMultiplyAccumulateImpl.
// NOTE(review): the refresh path writes through row_sums, so callers
// presumably must pass a non-null row_sums here; confirm at call sites.
void NeonMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors, const float* scaling_factors,
    int n_batch, float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
    bool* compute_row_sums, CpuBackendContext* context) {
  const bool prefer_gemm = (context && context->use_caching()) ||
                           UseCpuBackendGemm(m_rows, m_cols, n_batch);

  // Symmetric inputs: no zero-point correction required.
  if (input_offset == nullptr) {
    if (prefer_gemm && context) {
      NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
                                              scaling_factors, n_batch, scratch,
                                              result, context);
    } else {
      NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
                                              scaling_factors, n_batch, result);
    }
    return;
  }

  if (compute_row_sums == nullptr || *compute_row_sums) {
    NeonReductionSumVector(matrix, row_sums, m_rows, m_cols);
    if (compute_row_sums != nullptr) {
      *compute_row_sums = false;
    }
  }

  // The GEMM epilogue below works in groups of four rows.
  if (!prefer_gemm || context == nullptr || m_rows % 4 != 0) {
    NeonMatrixBatchVectorMultiplyAccumulateImpl(
        matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
        per_channel_scale, input_offset, row_sums);
    return;
  }

  const int32_t* no_bias = nullptr;
  NeonCpuBackendGemm(vectors, no_bias, matrix, n_batch, m_cols, m_rows, 0,
                     scratch, context);

  const int total_size = n_batch * m_rows;

  // Rescales the four int32 accumulators starting at flat index `idx` and
  // accumulates them into result. Because m_rows % 4 == 0 and idx % 4 == 0,
  // the four lanes share one batch and cover four adjacent rows.
  const auto rescale_quad = [&](int idx) {
    const int batch = idx / m_rows;
    const int row = idx % m_rows;
    float32x4_t scale_f32x4 = vdupq_n_f32(scaling_factors[batch]);
    if (per_channel_scale) {
      scale_f32x4 =
          vmulq_f32(scale_f32x4, vld1q_f32(&per_channel_scale[row]));
    }
    const int32x4_t neg_zp_s32x4 = vdupq_n_s32(-input_offset[batch]);
    // dot - row_sum * zero_point
    const int32x4_t corrected = vmlaq_s32(
        vld1q_s32(scratch + idx), vld1q_s32(row_sums + row), neg_zp_s32x4);
    float* const out = result + idx;
    vst1q_f32(out, vmlaq_f32(vld1q_f32(out), vcvtq_f32_s32(corrected),
                             scale_f32x4));
  };

  int idx = 0;
  for (; idx <= total_size - 8; idx += 8) {
    rescale_quad(idx);
    rescale_quad(idx + 4);
  }

  // Scalar tail (at most one group of four).
  for (; TFLITE_UNLIKELY(idx < total_size); ++idx) {
    float factor = scaling_factors[idx / m_rows];
    if (per_channel_scale) {
      factor *= per_channel_scale[idx % m_rows];
    }
    int32_t dotprod = scratch[idx];
    dotprod -= row_sums[idx % m_rows] * input_offset[idx / m_rows];
    result[idx] += dotprod * factor;
  }
}
1489 
// Returns acc + lhs * rhs for each of the four int32 lanes, computed in int64
// so that neither the product nor the sum can overflow.
// result.val[0] holds lanes 0-1 and result.val[1] holds lanes 2-3.
//
// vmlal_s32 is the widening multiply-accumulate. It produces the exact 64-bit
// product of each int32 pair and adds it to the widened accumulator. This
// avoids extracting lanes to scalars and rebuilding vectors with brace
// initialization, which is a compiler vector extension rather than ACLE.
inline int64x2x2_t MulAdd(int32x4_t acc, int32x4_t lhs, int32x4_t rhs) {
  int64x2x2_t result;
  result.val[0] = vmlal_s32(vmovl_s32(vget_low_s32(acc)), vget_low_s32(lhs),
                            vget_low_s32(rhs));
  result.val[1] = vmlal_s32(vmovl_s32(vget_high_s32(acc)), vget_high_s32(lhs),
                            vget_high_s32(rhs));
  return result;
}
1513 
// Integer layer normalization of an int16 input laid out as n_batch rows of
// n_input values.
//
// For each row:
//  1. The sum and sum of squares are accumulated in int64.
//  2. A fixed-point mean (scaled by 2^10) and variance are derived. If the
//     variance is below 1, variance_limit is used instead.
//  3. The inverse standard deviation is obtained as a quantized multiplier
//     via GetInvSqrtQuantizedMultiplierExp.
//  4. Each element is centered ((x << 10) - mean) and rescaled by that
//     multiplier. It is then multiplied by layer_norm_weights[j], bias[j] is
//     added (in int64), the result is rounded and divided by 2^10, rescaled
//     by (layer_norm_scale_a, layer_norm_scale_b + 12), and saturated to
//     int16.
//
// bias and layer_norm_weights are indexed by column only (shared by all
// rows). Returns without writing anything when n_input <= 0.
void NeonApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
                        const int32_t* bias, int32_t layer_norm_scale_a,
                        int32_t layer_norm_scale_b, int32_t variance_limit,
                        int n_batch, int n_input, int16_t* output) {
  // Avoid dividing by zero below; there is nothing to normalize.
  if (n_input <= 0) {
    return;
  }
  const int32_t int16_max = std::numeric_limits<int16_t>::max();
  const int32_t int16_min = std::numeric_limits<int16_t>::min();
  const int32_t temp = 1048576 / n_input;

  for (int i = 0; i < n_batch; ++i) {
    int64_t sum = 0;
    int64_t sum_sq = 0;

    int j = 0;
    for (; j <= n_input - 8; j += 8) {
      const int32_t index = i * n_input + j;
      const int16x8_t val_s16 = vld1q_s16(input + index);
      const int16x4_t val_lo = vget_low_s16(val_s16);
      const int16x4_t val_hi = vget_high_s16(val_s16);

      sum += static_cast<int64_t>(AccumulateNeonLane(vmovl_s16(val_lo)));
      sum += static_cast<int64_t>(AccumulateNeonLane(vmovl_s16(val_hi)));

      // A single int16 square fits in int32 (at most 2^30), but four of them
      // summed do not (4 * 32767^2 > 2^31). Widen pairwise to int64 before
      // reducing so large activations cannot wrap the sum of squares.
      int64x2_t sq_s64x2 = vpaddlq_s32(vmull_s16(val_lo, val_lo));
      sq_s64x2 = vpadalq_s32(sq_s64x2, vmull_s16(val_hi, val_hi));
      sum_sq += vgetq_lane_s64(sq_s64x2, 0) + vgetq_lane_s64(sq_s64x2, 1);
    }
    for (; TFLITE_UNLIKELY(j < n_input); ++j) {
      const int32_t index = i * n_input + j;
      const int32_t val = static_cast<int32_t>(input[index]);
      sum += val;
      sum_sq += static_cast<int64_t>(val) * val;
    }

    // Divide by `n_input` first to avoid overflow but only works for POT
    // `n_input`.
    const int32_t mean =
        static_cast<int32_t>(static_cast<int64_t>(sum) * 1024 / n_input);
    const int64_t variance =
        sum_sq * temp - static_cast<int64_t>(mean) * static_cast<int64_t>(mean);
    int32_t variance2 = static_cast<int32_t>(variance / 1048576);
    if (variance2 < 1) {
      variance2 = variance_limit;
    }
    int32_t stddev_inverse_a;
    int stddev_inverse_b;
    GetInvSqrtQuantizedMultiplierExp(variance2, /*reverse_shift*/ -1,
                                     &stddev_inverse_a, &stddev_inverse_b);

    j = 0;
    const int32x4_t mean_dup = vdupq_n_s32(mean);
    for (; j <= n_input - 16; j += 16) {
      // Load 16 items at once.
      const int32_t index = i * n_input + j;
      const int16x8_t val_s16_0 = vld1q_s16(input + index);
      const int16x8_t val_s16_1 = vld1q_s16(input + index + 8);

      // Center: (x << 10) - mean.
      int32x4x4_t shifted;
      shifted.val[0] = vsubq_s32(
          vshlq_n_s32(vmovl_s16(vget_low_s16(val_s16_0)), 10), mean_dup);
      shifted.val[1] = vsubq_s32(
          vshlq_n_s32(vmovl_s16(vget_high_s16(val_s16_0)), 10), mean_dup);
      shifted.val[2] = vsubq_s32(
          vshlq_n_s32(vmovl_s16(vget_low_s16(val_s16_1)), 10), mean_dup);
      shifted.val[3] = vsubq_s32(
          vshlq_n_s32(vmovl_s16(vget_high_s16(val_s16_1)), 10), mean_dup);

      int32x4x4_t rescaled = MultiplyByQuantizedMultiplier4Rows(
          shifted, stddev_inverse_a, stddev_inverse_b);

      const int32x4_t bias_0 = vld1q_s32(bias + j);
      const int32x4_t bias_1 = vld1q_s32(bias + j + 4);
      const int32x4_t bias_2 = vld1q_s32(bias + j + 8);
      const int32x4_t bias_3 = vld1q_s32(bias + j + 12);

      const int16x8_t layer_norm_weights_s16_0 =
          vld1q_s16(layer_norm_weights + j);
      const int16x8_t layer_norm_weights_s16_1 =
          vld1q_s16(layer_norm_weights + j + 8);
      const int32x4_t layer_norm_weights_s32_0 =
          vmovl_s16(vget_low_s16(layer_norm_weights_s16_0));
      const int32x4_t layer_norm_weights_s32_1 =
          vmovl_s16(vget_high_s16(layer_norm_weights_s16_0));
      const int32x4_t layer_norm_weights_s32_2 =
          vmovl_s16(vget_low_s16(layer_norm_weights_s16_1));
      const int32x4_t layer_norm_weights_s32_3 =
          vmovl_s16(vget_high_s16(layer_norm_weights_s16_1));

      // bias + rescaled * weight, computed in int64.
      int64x2x2_t val3_0 =
          MulAdd(bias_0, rescaled.val[0], layer_norm_weights_s32_0);
      int64x2x2_t val3_1 =
          MulAdd(bias_1, rescaled.val[1], layer_norm_weights_s32_1);
      int64x2x2_t val3_2 =
          MulAdd(bias_2, rescaled.val[2], layer_norm_weights_s32_2);
      int64x2x2_t val3_3 =
          MulAdd(bias_3, rescaled.val[3], layer_norm_weights_s32_3);

      // Rounding right shift by 10, then narrow to int32.
      int32x4x4_t val4;
      val4.val[0] = vcombine_s32(vmovn_s64(vrshrq_n_s64(val3_0.val[0], 10)),
                                 vmovn_s64(vrshrq_n_s64(val3_0.val[1], 10)));
      val4.val[1] = vcombine_s32(vmovn_s64(vrshrq_n_s64(val3_1.val[0], 10)),
                                 vmovn_s64(vrshrq_n_s64(val3_1.val[1], 10)));
      val4.val[2] = vcombine_s32(vmovn_s64(vrshrq_n_s64(val3_2.val[0], 10)),
                                 vmovn_s64(vrshrq_n_s64(val3_2.val[1], 10)));
      val4.val[3] = vcombine_s32(vmovn_s64(vrshrq_n_s64(val3_3.val[0], 10)),
                                 vmovn_s64(vrshrq_n_s64(val3_3.val[1], 10)));

      int32x4x4_t val5_s32 = MultiplyByQuantizedMultiplier4Rows(
          val4, layer_norm_scale_a, layer_norm_scale_b + 12);
      // vqmovn saturates to int16, matching the scalar clamp below.
      vst1_s16(output + index, vqmovn_s32(val5_s32.val[0]));
      vst1_s16(output + index + 4, vqmovn_s32(val5_s32.val[1]));
      vst1_s16(output + index + 8, vqmovn_s32(val5_s32.val[2]));
      vst1_s16(output + index + 12, vqmovn_s32(val5_s32.val[3]));
    }
    for (; TFLITE_UNLIKELY(j < n_input); ++j) {
      const int32_t index = i * n_input + j;
      const int32_t val = static_cast<int32_t>(input[index]);
      const int32_t shifted = 1024 * val - mean;
      const int32_t rescaled = MultiplyByQuantizedMultiplier(
          shifted, stddev_inverse_a, stddev_inverse_b);
      // Widen before multiplying so the product is formed in int64 like the
      // vector path (MulAdd), instead of overflowing int32 first.
      // TODO(jianlijianli): Saturate this.
      const int64_t val3 =
          static_cast<int64_t>(rescaled) * layer_norm_weights[j] + bias[j];
      // NOTE(review): this rounds halves away from zero, while vrshrq_n_s64
      // in the vector path rounds halves toward +inf; they differ only on
      // exact negative ties.
      const int32_t val4 =
          static_cast<int32_t>((val3 > 0 ? val3 + 512 : val3 - 512) / 1024);
      int32_t val5 = MultiplyByQuantizedMultiplier(val4, layer_norm_scale_a,
                                                   layer_norm_scale_b + 12);
      val5 = std::min(std::max(int16_min, val5), int16_max);
      output[index] = static_cast<int16_t>(val5);
    }
  }
}
1645 
// Elementwise fixed-point logistic over n_batch rows of n_input int16 values.
//
// Each input is the raw value of a gemmlowp FixedPoint with 3 integer bits.
// Each output is the raw value of a FixedPoint with 0 integer bits (the
// logistic result). When GEMMLOWP_NEON is defined, 32 values of each row are
// handled per iteration using int16x8_t fixed-point math; the rest of the
// row goes through the scalar gemmlowp implementation.
void NeonApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
                      int16_t* output) {
  using ScalarIn = gemmlowp::FixedPoint<int16_t, 3>;
  using ScalarOut = gemmlowp::FixedPoint<int16_t, 0>;
  for (int batch = 0; batch < n_batch; ++batch) {
    const int16_t* in_row = input + batch * n_input;
    int16_t* out_row = output + batch * n_input;
    int c = 0;
#ifdef GEMMLOWP_NEON
    using VecIn = gemmlowp::FixedPoint<int16x8_t, 3>;
    using VecOut = gemmlowp::FixedPoint<int16x8_t, 0>;
    for (; c <= n_input - 32; c += 32) {
      // Load all four vectors before storing any results.
      const VecIn x0 = VecIn::FromRaw(vld1q_s16(in_row + c));
      const VecIn x1 = VecIn::FromRaw(vld1q_s16(in_row + c + 8));
      const VecIn x2 = VecIn::FromRaw(vld1q_s16(in_row + c + 16));
      const VecIn x3 = VecIn::FromRaw(vld1q_s16(in_row + c + 24));
      const VecOut y0 = gemmlowp::logistic(x0);
      const VecOut y1 = gemmlowp::logistic(x1);
      const VecOut y2 = gemmlowp::logistic(x2);
      const VecOut y3 = gemmlowp::logistic(x3);
      vst1q_s16(out_row + c, y0.raw());
      vst1q_s16(out_row + c + 8, y1.raw());
      vst1q_s16(out_row + c + 16, y2.raw());
      vst1q_s16(out_row + c + 24, y3.raw());
    }
#endif  // GEMMLOWP_NEON
    for (; c < n_input; ++c) {
      const ScalarOut y = gemmlowp::logistic(ScalarIn::FromRaw(in_row[c]));
      out_row[c] = y.raw();
    }
  }
}
1684 
// Elementwise fixed-point tanh over n_batch rows of n_input int16 values.
//
// Each input is the raw value of a gemmlowp FixedPoint with `IntegerBits`
// integer bits. Each output is the raw value of a FixedPoint with 0 integer
// bits (the tanh result). When GEMMLOWP_NEON is defined, 32 values of each
// row are handled per iteration with int16x8_t fixed-point math; the rest of
// the row goes through the scalar gemmlowp implementation.
template <int IntegerBits>
void NeonApplyTanhImpl(const int16_t* input, int32_t n_batch, int32_t n_input,
                       int16_t* output) {
  using ScalarIn = gemmlowp::FixedPoint<int16_t, IntegerBits>;
  using ScalarOut = gemmlowp::FixedPoint<int16_t, 0>;
  for (int batch = 0; batch < n_batch; ++batch) {
    const int16_t* in_row = input + batch * n_input;
    int16_t* out_row = output + batch * n_input;
    int c = 0;
#ifdef GEMMLOWP_NEON
    using VecIn = gemmlowp::FixedPoint<int16x8_t, IntegerBits>;
    using VecOut = gemmlowp::FixedPoint<int16x8_t, 0>;
    for (; c <= n_input - 32; c += 32) {
      // Load all four vectors before storing any results.
      const VecIn x0 = VecIn::FromRaw(vld1q_s16(in_row + c));
      const VecIn x1 = VecIn::FromRaw(vld1q_s16(in_row + c + 8));
      const VecIn x2 = VecIn::FromRaw(vld1q_s16(in_row + c + 16));
      const VecIn x3 = VecIn::FromRaw(vld1q_s16(in_row + c + 24));
      const VecOut y0 = gemmlowp::tanh(x0);
      const VecOut y1 = gemmlowp::tanh(x1);
      const VecOut y2 = gemmlowp::tanh(x2);
      const VecOut y3 = gemmlowp::tanh(x3);
      vst1q_s16(out_row + c, y0.raw());
      vst1q_s16(out_row + c + 8, y1.raw());
      vst1q_s16(out_row + c + 16, y2.raw());
      vst1q_s16(out_row + c + 24, y3.raw());
    }
#endif  // GEMMLOWP_NEON
    for (; c < n_input; ++c) {
      const ScalarOut y = gemmlowp::tanh(ScalarIn::FromRaw(in_row[c]));
      out_row[c] = y.raw();
    }
  }
}
1723 
NeonApplyTanh(int32_t integer_bits,const int16_t * input,int32_t n_batch,int32_t n_input,int16_t * output)1724 void NeonApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
1725                    int32_t n_input, int16_t* output) {
1726   assert(integer_bits <= 6);
1727 #define DISPATCH_TANH(i)                                   \
1728   case i:                                                  \
1729     NeonApplyTanhImpl<i>(input, n_batch, n_input, output); \
1730     break;
1731   switch (integer_bits) {
1732     DISPATCH_TANH(0);
1733     DISPATCH_TANH(1);
1734     DISPATCH_TANH(2);
1735     DISPATCH_TANH(3);
1736     DISPATCH_TANH(4);
1737     DISPATCH_TANH(5);
1738     DISPATCH_TANH(6);
1739     default:
1740       return;
1741   }
1742 #undef DISPATCH_TANH
1743 }
1744 
// Elementwise int16 * int16 product over n_batch rows of n_input values,
// divided by 2^shift with gemmlowp rounding and truncated (not saturated) to
// int16.
void NeonCwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
                  int n_input, int shift, int16_t* output) {
  for (int batch = 0; batch < n_batch; ++batch) {
    const int offset = batch * n_input;
    const int16_t* lhs_row = input_1 + offset;
    const int16_t* rhs_row = input_2 + offset;
    int16_t* out_row = output + offset;

    int k = 0;
    for (; k <= n_input - 8; k += 8) {
      const int16x8_t lhs = vld1q_s16(lhs_row + k);
      const int16x8_t rhs = vld1q_s16(rhs_row + k);
      // vmull_s16 gives the exact 32-bit product of each int16 pair.
      const int32x4_t lo = gemmlowp::RoundingDivideByPOT(
          vmull_s16(vget_low_s16(lhs), vget_low_s16(rhs)), shift);
      const int32x4_t hi = gemmlowp::RoundingDivideByPOT(
          vmull_s16(vget_high_s16(lhs), vget_high_s16(rhs)), shift);
      // vmovn truncates, like the static_cast in the scalar tail.
      vst1q_s16(out_row + k, vcombine_s16(vmovn_s32(lo), vmovn_s32(hi)));
    }
    for (; TFLITE_UNLIKELY(k < n_input); ++k) {
      const int64_t product = static_cast<int64_t>(lhs_row[k]) * rhs_row[k];
      const int32_t clamped = static_cast<int32_t>(std::min<int64_t>(
          product, std::numeric_limits<std::int32_t>::max()));
      out_row[k] =
          static_cast<int16_t>(gemmlowp::RoundingDivideByPOT(clamped, shift));
    }
  }
}
1780 
// Elementwise int16 * int16 product over n_batch rows of n_input values,
// requantized to int8: each product is scaled by the quantized multiplier
// (multiplier, shift), output_zp is subtracted, and the result is clamped to
// [-128, 127].
void NeonCwiseMul(const int16_t* input_1, const int16_t* input_2,
                  int32_t multiplier, int shift, int n_batch, int n_input,
                  int32_t output_zp, int8_t* output) {
  const int32_t output_min = std::numeric_limits<int8_t>::min();
  const int32_t output_max = std::numeric_limits<int8_t>::max();

  const int32x4_t neg_zp_s32x4 = vdupq_n_s32(-output_zp);
  const int32x4_t max_s32x4 = vdupq_n_s32(output_max);
  const int32x4_t min_s32x4 = vdupq_n_s32(output_min);

  for (int batch = 0; batch < n_batch; ++batch) {
    const int offset = batch * n_input;
    const int16_t* lhs_row = input_1 + offset;
    const int16_t* rhs_row = input_2 + offset;
    int8_t* out_row = output + offset;

    int k = 0;
    for (; k <= n_input - 8; k += 8) {
      const int16x8_t lhs = vld1q_s16(lhs_row + k);
      const int16x8_t rhs = vld1q_s16(rhs_row + k);

      // Exact 32-bit products of the int16 pairs.
      int32x4x2_t prod;
      prod.val[0] = vmull_s16(vget_low_s16(lhs), vget_low_s16(rhs));
      prod.val[1] = vmull_s16(vget_high_s16(lhs), vget_high_s16(rhs));
      prod = MultiplyByQuantizedMultiplier2Rows(prod, multiplier, shift);

      // Shift by the output zero point, then clamp into the int8 range.
      const int32x4_t lo = vmaxq_s32(
          vminq_s32(vaddq_s32(prod.val[0], neg_zp_s32x4), max_s32x4),
          min_s32x4);
      const int32x4_t hi = vmaxq_s32(
          vminq_s32(vaddq_s32(prod.val[1], neg_zp_s32x4), max_s32x4),
          min_s32x4);

      const int16x8_t narrowed = vcombine_s16(vmovn_s32(lo), vmovn_s32(hi));
      vst1_s8(out_row + k, vmovn_s16(narrowed));
    }
    for (; TFLITE_UNLIKELY(k < n_input); ++k) {
      int32_t value = static_cast<int32_t>(lhs_row[k]) *
                      static_cast<int32_t>(rhs_row[k]);
      value = MultiplyByQuantizedMultiplier(value, multiplier, shift) -
              output_zp;
      out_row[k] = static_cast<int8_t>(
          std::min(std::max(output_min, value), output_max));
    }
  }
}
1832 
// Elementwise saturating int16 addition over n_batch rows of n_input values.
// Sums are formed in int32 and saturated back to the int16 range.
void NeonCwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
                  int n_input, int16_t* output) {
  const int32_t int16_max = std::numeric_limits<int16_t>::max();
  const int32_t int16_min = std::numeric_limits<int16_t>::min();
  for (int batch = 0; batch < n_batch; ++batch) {
    const int offset = batch * n_input;
    const int16_t* lhs_row = input_1 + offset;
    const int16_t* rhs_row = input_2 + offset;
    int16_t* out_row = output + offset;

    int k = 0;
    for (; k <= n_input - 8; k += 8) {
      const int16x8_t lhs = vld1q_s16(lhs_row + k);
      const int16x8_t rhs = vld1q_s16(rhs_row + k);
      // Widening add cannot overflow; vqmovn saturates back to int16.
      const int32x4_t sum_lo = vaddl_s16(vget_low_s16(lhs), vget_low_s16(rhs));
      const int32x4_t sum_hi =
          vaddl_s16(vget_high_s16(lhs), vget_high_s16(rhs));
      vst1_s16(out_row + k, vqmovn_s32(sum_lo));
      vst1_s16(out_row + k + 4, vqmovn_s32(sum_hi));
    }
    for (; TFLITE_UNLIKELY(k < n_input); ++k) {
      const int32_t sum = lhs_row[k] + rhs_row[k];
      out_row[k] =
          static_cast<int16_t>(std::min(int16_max, std::max(int16_min, sum)));
    }
  }
}
1861 
// Clamps each element of `vector` in place to
// [-clipping_value, clipping_value]. Four floats are handled per NEON step and
// the remainder in a scalar loop.
void NeonCwiseClipping(float* vector, const int v_size,
                       const float clipping_value) {
  const float32x4_t upper_f32x4 = vmovq_n_f32(clipping_value);
  const float32x4_t lower_f32x4 = vmovq_n_f32(-clipping_value);

  int i = 0;
  while (i <= v_size - kFloatValuesPerNeonVector) {
    float* const chunk = vector + i;
    vst1q_f32(chunk, vmaxq_f32(lower_f32x4,
                               vminq_f32(upper_f32x4, vld1q_f32(chunk))));
    i += kFloatValuesPerNeonVector;
  }
  while (TFLITE_UNLIKELY(i < v_size)) {
    vector[i] = std::max(std::min(clipping_value, vector[i]), -clipping_value);
    ++i;
  }
}
1882 
// Clamps each int16 element of `vector` in place to
// [-clipping_value, clipping_value]. Two NEON registers (16 values) are
// processed per iteration and the remainder in a scalar loop.
void NeonCwiseClipping(int16_t* vector, const int v_size,
                       const int16_t clipping_value) {
  const int16x8_t upper_s16x8 = vdupq_n_s16(clipping_value);
  const int16x8_t lower_s16x8 = vdupq_n_s16(-clipping_value);
  const int16_t lower = static_cast<int16_t>(-clipping_value);
  const int step = kInt16ValuesPerNeonVector * 2;

  int i = 0;
  for (; i <= v_size - step; i += step) {
    int16_t* const first = vector + i;
    int16_t* const second = first + kInt16ValuesPerNeonVector;
    const int16x8_t first_clipped =
        vmaxq_s16(vminq_s16(vld1q_s16(first), upper_s16x8), lower_s16x8);
    const int16x8_t second_clipped =
        vmaxq_s16(vminq_s16(vld1q_s16(second), upper_s16x8), lower_s16x8);
    vst1q_s16(first, first_clipped);
    vst1q_s16(second, second_clipped);
  }
  for (; TFLITE_UNLIKELY(i < v_size); i++) {
    vector[i] = std::max(std::min(clipping_value, vector[i]), lower);
  }
}
1905 
// Clamps each int8 element of `vector` in place to
// [-clipping_value, clipping_value]. Two NEON registers (32 values) are
// processed per iteration and the remainder in a scalar loop.
void NeonCwiseClipping(int8_t* vector, const int v_size,
                       const int8_t clipping_value) {
  const int8x16_t max_dup = vdupq_n_s8(clipping_value);
  const int8x16_t min_dup = vdupq_n_s8(-clipping_value);

  int i = 0;
  // Use `<=` (as the int16 and float overloads do) so that a final full
  // 32-element chunk is vectorized rather than left to the scalar tail.
  for (; i <= v_size - kInt8ValuesPerNeonVector * 2;
       i += kInt8ValuesPerNeonVector * 2) {
    int8x16_t val_0 = vld1q_s8(vector + i);
    int8x16_t val_1 = vld1q_s8(vector + i + kInt8ValuesPerNeonVector);
    val_0 = vminq_s8(val_0, max_dup);
    val_1 = vminq_s8(val_1, max_dup);
    val_0 = vmaxq_s8(val_0, min_dup);
    val_1 = vmaxq_s8(val_1, min_dup);
    vst1q_s8(vector + i, val_0);
    vst1q_s8(vector + i + kInt8ValuesPerNeonVector, val_1);
  }
  for (; TFLITE_UNLIKELY(i < v_size); i++) {
    vector[i] = std::max(std::min(clipping_value, vector[i]),
                         static_cast<int8_t>(-clipping_value));
  }
}
1928 
// Sparse float matrix (1x4 blocks) * batched vector, accumulated into result.
//
// For row r, block entries segments[r] .. segments[r + 1] - 1 are consumed.
// Entry s pairs the next 4 contiguous floats of `matrix` with the vector
// elements starting at column indices[s] * 4. The matrix block pointer is
// rewound for every batch, since all batches share the same sparse matrix.
// result[batch * m_rows + r] receives the accumulated dot product.
void NeonSparseMatrixBatchVectorMultiplyAccumulate1x4(
    const float* __restrict__ matrix, const int32_t* __restrict__ segments,
    const int32_t* __restrict__ indices, int m_rows, int m_cols,
    const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
  constexpr int kBlockSize = kFloatValuesPerNeonVector;
  TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);

  for (int batch = 0; batch < n_batch; batch++) {
    const float* const batch_vector = vector + batch * m_cols;
    float* const batch_result = result + batch * m_rows;
    const float* block = matrix;
    for (int row = 0; row < m_rows; row++) {
      float32x4_t acc_f32x4 = vdupq_n_f32(0.0f);
      const int seg_end = segments[row + 1];
      for (int s = segments[row]; s < seg_end; s++, block += kBlockSize) {
        const float32x4_t vec_f32x4 =
            vld1q_f32(batch_vector + indices[s] * kBlockSize);
        acc_f32x4 = vmlaq_f32(acc_f32x4, vld1q_f32(block), vec_f32x4);
      }
      batch_result[row] += AccumulateNeonLane(acc_f32x4);
    }
  }
}
1958 
NeonSparseMatrixBatchVectorMultiplyAccumulate1x16(const int8_t * __restrict__ matrix,const int32_t * __restrict__ segments,const int32_t * __restrict__ indices,int m_rows,int m_cols,const int8_t * __restrict__ vector,const int32_t * __restrict__ bias_vector,int n_batch,const int32_t input_offset,const int32_t output_multiplier,const int32_t output_shift,const int32_t output_offset,const int32_t output_activation_min,const int32_t output_activation_max,int8_t * __restrict__ result)1959 void NeonSparseMatrixBatchVectorMultiplyAccumulate1x16(
1960     const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
1961     const int32_t* __restrict__ indices, int m_rows, int m_cols,
1962     const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
1963     int n_batch, const int32_t input_offset, const int32_t output_multiplier,
1964     const int32_t output_shift, const int32_t output_offset,
1965     const int32_t output_activation_min, const int32_t output_activation_max,
1966     int8_t* __restrict__ result) {
1967   constexpr int kBlockSize = kInt8ValuesPerNeonVector;
1968   TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
1969 
1970   for (int batch = 0; batch < n_batch; ++batch) {
1971     const int8_t* matrix_ptr = matrix;
1972     for (int row = 0; row < m_rows; ++row) {
1973       // Accumulation loop.
1974       int32x4_t acc_i32x4 = vmovq_n_s32(0);
1975       int32_t matrix_row_sum = 0;
1976       const int8_t* vector_in_batch = vector + batch * m_cols;
1977 
1978       for (int i = segments[row]; i < segments[row + 1]; ++i) {
1979         const int block_start_index = indices[i] * kBlockSize;
1980         const int8_t* vector_block_in_batch_ptr =
1981             vector_in_batch + block_start_index;
1982 
1983         // Load 16 int8 values from the vector and matrix row.
1984         int8x16_t vector_i8x16 = vld1q_s8(vector_block_in_batch_ptr);
1985         int8x16_t matrix_i8x16 = vld1q_s8(matrix_ptr);
1986 #ifdef __aarch64__
1987         int16_t matrix_block_sum = vaddlvq_s8(matrix_i8x16);
1988 #else
1989         int16_t matrix_block_sum =
1990             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 0)) +
1991             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 1)) +
1992             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 2)) +
1993             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 3)) +
1994             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 4)) +
1995             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 5)) +
1996             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 6)) +
1997             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 7)) +
1998             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 8)) +
1999             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 9)) +
2000             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 10)) +
2001             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 11)) +
2002             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 12)) +
2003             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 13)) +
2004             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 14)) +
2005             static_cast<int8_t>(vgetq_lane_s8(matrix_i8x16, 15));
2006 #endif
2007 
2008         // Multiply the vector and matrix row and add to accumulator.
2009         int16x8_t acc_i16x8 =
2010             vmull_s8(vget_low_s8(vector_i8x16), vget_low_s8(matrix_i8x16));
2011         acc_i16x8 = vmlal_s8(acc_i16x8, vget_high_s8(vector_i8x16),
2012                              vget_high_s8(matrix_i8x16));
2013         acc_i32x4 = vpadalq_s16(acc_i32x4, acc_i16x8);
2014         matrix_row_sum += matrix_block_sum;
2015         matrix_ptr += kBlockSize;
2016       }
2017 #ifdef __aarch64__
2018       int32_t acc = vaddvq_s32(acc_i32x4);
2019 #else
2020       int32_t acc = vgetq_lane_s32(acc_i32x4, 0) +
2021                     vgetq_lane_s32(acc_i32x4, 1) +
2022                     vgetq_lane_s32(acc_i32x4, 2) + vgetq_lane_s32(acc_i32x4, 3);
2023 #endif
2024       const int32_t bias_value = bias_vector != nullptr ? bias_vector[row] : 0;
2025       acc = acc + bias_value + input_offset * matrix_row_sum;
2026       acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
2027       acc += output_offset;
2028       result[batch * m_rows + row] =
2029           static_cast<int8_t>(ActivationFunctionWithMinMax(
2030               acc, output_activation_min, output_activation_max));
2031     }
2032   }
2033 }
2034 
NeonSparseMatrixBatchVectorMultiplyAccumulate(const float * __restrict__ matrix,const uint8_t * __restrict__ ledger,int m_rows,int m_cols,const float * __restrict__ vector,int n_batch,float * __restrict__ result)2035 void NeonSparseMatrixBatchVectorMultiplyAccumulate(
2036     const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
2037     int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
2038     float* __restrict__ result) {
2039   constexpr int kNeonVectorsPerBlock = 4;
2040   constexpr int kBlockSize = kNeonVectorsPerBlock * kFloatValuesPerNeonVector;
2041   TFLITE_DCHECK_EQ(  // NOLINT
2042       m_cols % kBlockSize, 0);
2043 
2044   for (int batch = 0; batch < n_batch; batch++) {
2045     const float* matrix_ptr = matrix;
2046     const uint8_t* ledger_ptr = ledger;
2047     for (int row = 0; row < m_rows; row++) {
2048       int num_nonzero_blocks = *ledger_ptr++;
2049       if (num_nonzero_blocks > 0) {
2050         float32x4_t acc_32x4 = vmovq_n_f32(0.0);
2051         const float* vector_in_batch = vector + batch * m_cols;
2052 
2053         for (int i = 0; i < num_nonzero_blocks; i++) {
2054           const int block_start_index = *ledger_ptr++ * kBlockSize;
2055           const float* vector_block_in_batch_ptr =
2056               vector_in_batch + block_start_index;
2057 
2058           for (int c = 0; c < kNeonVectorsPerBlock; c++) {
2059             // Load 4 float values from the vector and matrix row.
2060             float32x4_t vector_f32x4 = vld1q_f32(vector_block_in_batch_ptr +
2061                                                  c * kFloatValuesPerNeonVector);
2062             float32x4_t matrix_f32x4 =
2063                 vld1q_f32(matrix_ptr + c * kFloatValuesPerNeonVector);
2064             // Multiply the vector and matrix row and add to accumulator.
2065             acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
2066           }
2067           matrix_ptr += kBlockSize;
2068         }
2069         result[batch * m_rows + row] += AccumulateNeonLane(acc_32x4);
2070       }
2071     }
2072   }
2073 }
2074 
NeonSparseMatrixBatchVectorMultiplyAccumulate(const int8_t * __restrict__ matrix,const uint8_t * ledger,const int m_rows,const int m_cols,const int8_t * __restrict__ vectors,const float * scaling_factors,int n_batch,float * __restrict__ result)2075 void NeonSparseMatrixBatchVectorMultiplyAccumulate(
2076     const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
2077     const int m_cols, const int8_t* __restrict__ vectors,
2078     const float* scaling_factors, int n_batch, float* __restrict__ result) {
2079 #ifdef __aarch64__
2080   if (HasSdotInstruction() && m_cols % 16 == 0) {
2081     DotprodSparseMatrixBatchVectorMultiplyAccumulate(
2082         matrix, ledger, m_rows, m_cols, vectors, scaling_factors, n_batch,
2083         result);
2084     return;
2085   }
2086 #endif  // __aarch64__
2087 
2088   constexpr int kBlockSize = kInt8ValuesPerNeonVector;
2089   TFLITE_DCHECK_EQ(  // NOLINT
2090       m_cols % kBlockSize, 0);
2091   void* aligned_vec_free = nullptr;
2092   int8_t* aligned_vec =
2093       (int8_t*)aligned_alloc(kNeonVectorAlignment, m_cols,  // NOLINT
2094                              &aligned_vec_free);
2095 
2096   for (int batch = 0; batch < n_batch; ++batch) {
2097     const float batch_scaling_factor = scaling_factors[batch];
2098     // Copy the vector data to an aligned vector.
2099     memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8) * m_cols);
2100 
2101     const uint8_t* ledger_ptr = ledger;
2102     const int8_t* row_ptr = matrix;
2103     for (int row = 0; row < m_rows; ++row) {
2104       // Initialize the dot product sum for the row to 0.
2105       int32x4_t dotprod_32x4 = vmovq_n_s32(0);
2106       int num_nonzero_blocks = *ledger_ptr++;
2107       if (num_nonzero_blocks > 0) {
2108         // Prefetch the row to cache.
2109         __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
2110                            3 /* temporal locality */);
2111         for (int i = 0; i < num_nonzero_blocks; i++) {
2112           const int col_index = *ledger_ptr++ * kBlockSize;
2113           // Load 16 8-bit values from the row and vector, each, to operate on.
2114           // Here the assumption is that each buffer is 4-byte aligned.
2115           // Otherwise, performance may suffer significantly.
2116           TFLITE_DCHECK_EQ(  // NOLINT
2117               (uintptr_t)(&row_ptr) & (kNeonVectorAlignment - 1), 0);
2118           const int8x16_t s1_8x16 =
2119               vld1q_s8((const int8_t*)(aligned_vec + col_index));
2120           const int8x16_t s2_8x16 = vld1q_s8((const int8_t*)(row_ptr));
2121           // Multiply the low bits (i.e. the lower 8 8bit numbers in the
2122           // registers).
2123           int16x8_t prod_16x8 =
2124               vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16));
2125           // Multiply the high bits (i.e. the lower 8 8bit numbers in the
2126           // registers), and accumulate with the result of the low bits product.
2127           // The assumption here is that overflow will not happen as we quantize
2128           // our values to be in the range [-127, 127]. As such the sum of the 2
2129           // products is always strictly smaller than 15-bits (32767 in absolute
2130           // value).
2131           prod_16x8 =
2132               vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
2133 
2134           dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
2135           row_ptr += kBlockSize;
2136         }
2137         // Add the 4 intermediate sum values to get the final dot-prod value for
2138         // this row.
2139         int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
2140         result[batch * m_rows + row] += dotprod * batch_scaling_factor;
2141       }
2142     }  // for row
2143   }    // for batch
2144   free(aligned_vec_free);
2145 }
2146 
NeonSub1Vector(const float * vector,int v_size,float * result)2147 void NeonSub1Vector(const float* vector, int v_size, float* result) {
2148   // If v_size is not divisible by the vector size, then we need to process the
2149   // final few elements sequentially. postamble_start shows the start index
2150   // where this should happen.
2151   const int postamble_start =
2152       RoundDownVectors<kFloatValuesPerNeonVector>(v_size);
2153 
2154   float32x4_t one_f32x4 = vmovq_n_f32(1.0);
2155   int v = 0;
2156   for (; v < postamble_start; v += kFloatValuesPerNeonVector) {
2157     // Load 4 float values from the current pointers of the input column and
2158     // subtract from 1.
2159     float32x4_t v_f32x4 = vld1q_f32(vector + v);
2160     float32x4_t result_f32x4 = vsubq_f32(one_f32x4, v_f32x4);
2161     // Save to output.
2162     vst1q_f32(result + v, result_f32x4);
2163   }
2164   for (; TFLITE_UNLIKELY(v < v_size); v++) {
2165     result[v] = 1.0f - vector[v];
2166   }
2167 }
2168 
NeonSub1Vector(const int16_t * vector,int v_size,int16_t * result)2169 void NeonSub1Vector(const int16_t* vector, int v_size, int16_t* result) {
2170   int postamble_start = RoundDownVectors<kInt16ValuesPerNeonVector>(v_size);
2171   static const int16_t kOne = 32767;
2172   // Use xor to replace substract from 1 << 15 - 1.
2173   // Local benchmark shows it's slightly faster than pure substract.
2174   const int16x8_t one_dup = vdupq_n_s16(kOne);
2175   int i = 0;
2176   for (; i < postamble_start; i += kInt16ValuesPerNeonVector) {
2177     const int16x8_t input = vld1q_s16(vector + i);
2178     const int16x8_t sub1_result = veorq_s16(one_dup, input);
2179     vst1q_s16(result + i, sub1_result);
2180   }
2181   for (; TFLITE_UNLIKELY(i < v_size); i++) {
2182     result[i] = kOne ^ vector[i];
2183   }
2184 }
2185 
2186 namespace {
2187 
2188 #ifdef __aarch64__
IsAllZero(const int8x16_t v_s8x16)2189 inline bool IsAllZero(const int8x16_t v_s8x16) {
2190   const uint32_t u32 = vmaxvq_u32(vreinterpretq_u32_s8(v_s8x16));
2191   return !u32;
2192 }
2193 
IsAllZero(const float32x4_t v_f32x4)2194 inline bool IsAllZero(const float32x4_t v_f32x4) {
2195   const uint32x4_t cmp_result = vceqzq_f32(v_f32x4);
2196   const uint32_t u32 = vminvq_u32(cmp_result);
2197   return u32;
2198 }
2199 #else
2200 inline bool IsAllZero(const uint32x4_t u32x4) {
2201   const uint32x2_t u32x2 = vqadd_u32(vget_high_u32(u32x4), vget_low_u32(u32x4));
2202   const uint64x1_t u64 = vreinterpret_u64_u32(u32x2);
2203   return !vget_lane_u64(u64, 0);
2204 }
2205 
2206 #ifndef __SSE__
2207 // With Intel NEON-2-SSE translator library, this is a redefinition..
2208 inline bool IsAllZero(const int8x16_t v) {
2209   return IsAllZero(vreinterpretq_u32_s8(v));
2210 }
2211 #endif
2212 
2213 inline bool IsAllZero(const float32x4_t v_f32x4) {
2214   const float32x4_t zero_f32x4 = vmovq_n_f32(0.0f);
2215   // Compare-absolute greater-than, |v| > |0|, equivalently v != 0
2216   const uint32x4_t cmp_result = vcagtq_f32(v_f32x4, zero_f32x4);
2217   return IsAllZero(cmp_result);
2218 }
2219 #endif
2220 
2221 }  // namespace
2222 
NeonIsZeroVector(const float * vector,int v_size)2223 bool NeonIsZeroVector(const float* vector, int v_size) {
2224   // If v_size is not divisible by the vector size, then we need to process the
2225   // final few elements sequentially. postamble_start shows the start index
2226   // where this should happen.
2227   const int postamble_start =
2228       RoundDownVectors<kFloatValuesPerNeonVector>(v_size);
2229 
2230   int v = 0;
2231   for (; v < postamble_start; v += kFloatValuesPerNeonVector) {
2232     const float32x4_t v_f32x4 = vld1q_f32(vector + v);
2233     if (!IsAllZero(v_f32x4)) return false;
2234   }
2235   // Postamble loop
2236   for (; TFLITE_UNLIKELY(v < v_size); ++v) {
2237     if (vector[v] != 0.0) return false;
2238   }
2239   return true;
2240 }
2241 
NeonIsZeroVector(const int8_t * vector,int v_size)2242 bool NeonIsZeroVector(const int8_t* vector, int v_size) {
2243   // If v_size is not divisible by the vector size, then we need to process the
2244   // final few elements sequentially. postamble_start shows the start index
2245   // where this should happen.
2246   const int postamble_start =
2247       RoundDownVectors<kInt8ValuesPerNeonVector>(v_size);
2248 
2249   int v = 0;
2250   for (; v < postamble_start; v += kInt8ValuesPerNeonVector) {
2251     const int8x16_t v_s8x16 = vld1q_s8(vector + v);
2252     if (!IsAllZero(v_s8x16)) return false;
2253   }
2254   // Postamble loop
2255   for (; TFLITE_UNLIKELY(v < v_size); ++v) {
2256     if (vector[v] != 0) return false;
2257   }
2258   return true;
2259 }
2260 
NeonVectorScalarMultiply(const int8_t * vector,const int v_size,const float scale,float * result)2261 void NeonVectorScalarMultiply(const int8_t* vector, const int v_size,
2262                               const float scale, float* result) {
2263   // Here the assumption is that each buffer is 4-byte aligned.
2264   TFLITE_CHECK_EQ((intptr_t)(&vector[0]) & (kNeonVectorAlignment - 1), 0);
2265   // If v_size is not divisible by kInt8ValuesPerNeonVector, we cannot use the
2266   // main vectorized loop, and we need to process sequentially. postamble_start
2267   // shows the start index where this should happen.
2268   const int postamble_start =
2269       RoundDownVectors<kInt8ValuesPerNeonVector>(v_size);
2270 
2271   // Create a vector of 4 floats with the scale value.
2272   const float32x4_t scale_f32x4 = vdupq_n_f32(scale);
2273   int v = 0;
2274   for (; v < postamble_start; v += kInt8ValuesPerNeonVector) {
2275     // Load int8 values, sixteen at a time.
2276     const int8x16_t v_i8x16 = vld1q_s8(vector + v);
2277     // Split it into two components of size eight.
2278     const int8x8_t v0_i8x8 = vget_low_s8(v_i8x16);
2279     const int8x8_t v1_i8x8 = vget_high_s8(v_i8x16);
2280     // Convert both components to int16 first.
2281     const int16x8_t v0_i16x8 = vmovl_s8(v0_i8x8);
2282     const int16x8_t v1_i16x8 = vmovl_s8(v1_i8x8);
2283     // Split each of them into two components each.
2284     const int16x4_t v0_i16x4 = vget_low_s16(v0_i16x8);
2285     const int16x4_t v1_i16x4 = vget_high_s16(v0_i16x8);
2286     const int16x4_t v2_i16x4 = vget_low_s16(v1_i16x8);
2287     const int16x4_t v3_i16x4 = vget_high_s16(v1_i16x8);
2288     // Convert these to int32 and then to float.
2289     float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4));
2290     float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4));
2291     float32x4_t v2_f32x4 = vcvtq_f32_s32(vmovl_s16(v2_i16x4));
2292     float32x4_t v3_f32x4 = vcvtq_f32_s32(vmovl_s16(v3_i16x4));
2293     // Vector multiply four floats at a time.
2294     v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4);
2295     v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4);
2296     v2_f32x4 = vmulq_f32(v2_f32x4, scale_f32x4);
2297     v3_f32x4 = vmulq_f32(v3_f32x4, scale_f32x4);
2298     // Store the results.
2299     vst1q_f32(result + v, v0_f32x4);
2300     vst1q_f32(result + v + 4, v1_f32x4);
2301     vst1q_f32(result + v + 8, v2_f32x4);
2302     vst1q_f32(result + v + 12, v3_f32x4);
2303   }
2304 
2305   if (TFLITE_UNLIKELY(v_size - postamble_start >=
2306                       (kInt8ValuesPerNeonVector >> 1))) {
2307     // Load eight int8 values, if there is at least eight remaining.
2308     const int8x8_t v_i8x8 = vld1_s8(vector + v);
2309     // Convert them to int16 first.
2310     const int16x8_t v_i16x8 = vmovl_s8(v_i8x8);
2311     // Split it into two components.
2312     const int16x4_t v0_i16x4 = vget_low_s16(v_i16x8);
2313     const int16x4_t v1_i16x4 = vget_high_s16(v_i16x8);
2314     // Convert the components two floats.
2315     float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4));
2316     float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4));
2317     // Vector multiply four floats at a time.
2318     v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4);
2319     v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4);
2320     // Store the results.
2321     vst1q_f32(result + v, v0_f32x4);
2322     vst1q_f32(result + v + 4, v1_f32x4);
2323     v += (kInt8ValuesPerNeonVector >> 1);
2324   }
2325 
2326   // Postamble loop.
2327   for (; TFLITE_UNLIKELY(v < v_size); v++) {
2328     result[v] = scale * vector[v];
2329   }
2330 }
2331 
2332 // TODO(b/185850916): Consider changing the rounding stragey from "ties to away"
2333 // to "ties to even" since vcvtnq_s32_f32 is generally more available.
RoundToNearest(const float32x4_t input)2334 inline int32x4_t RoundToNearest(const float32x4_t input) {
2335 #if __ARM_ARCH >= 8
2336   return vcvtaq_s32_f32(input);
2337 #else
2338   static const float32x4_t zero_val_dup = vdupq_n_f32(0.0f);
2339   static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
2340 
2341   const int32x4_t mask = vreinterpretq_s32_u32(vcltq_f32(input, zero_val_dup));
2342   const float32x4_t casted_mask = vcvtq_f32_s32(mask);
2343   const float32x4_t round = vaddq_f32(casted_mask, point5_val_dup);
2344   return vcvtq_s32_f32(vaddq_f32(input, round));
2345 #endif
2346 }
2347 
2348 // Note: this function caps minimum and maximum at zero, unlike the true
2349 // minmax_element. This is intentional.
NeonMinMax(const float * values,const int size,float * min,float * max)2350 inline void NeonMinMax(const float* values, const int size, float* min,
2351                        float* max) {
2352   const int postamble_start = RoundDownVectors<kFloatValuesPerNeonVector>(size);
2353   float rmin = 0.0f, rmax = 0.0f;
2354   int i = 0;
2355   if (postamble_start) {
2356     float32x4_t min_f32x4 = vld1q_f32(values);
2357     float32x4_t max_f32x4 = min_f32x4;
2358     for (i = kFloatValuesPerNeonVector; i < postamble_start;
2359          i += kFloatValuesPerNeonVector) {
2360       const float32x4_t value0_f32x4 = vld1q_f32(&values[i]);
2361       min_f32x4 = vminq_f32(min_f32x4, value0_f32x4);
2362       max_f32x4 = vmaxq_f32(max_f32x4, value0_f32x4);
2363     }
2364 #ifdef __aarch64__
2365     rmin = std::min(rmin, vminvq_f32(min_f32x4));
2366     rmax = std::max(rmax, vmaxvq_f32(max_f32x4));
2367 #else
2368     float32x2_t min_f32x2 =
2369         vmin_f32(vget_low_f32(min_f32x4), vget_high_f32(min_f32x4));
2370     float32x2_t max_f32x2 =
2371         vmax_f32(vget_low_f32(max_f32x4), vget_high_f32(max_f32x4));
2372     min_f32x2 = vpmin_f32(min_f32x2, min_f32x2);
2373     max_f32x2 = vpmax_f32(max_f32x2, max_f32x2);
2374     rmin = std::min(rmin, vget_lane_f32(min_f32x2, 0));
2375     rmax = std::max(rmax, vget_lane_f32(max_f32x2, 0));
2376 #endif  // __aarch64__
2377   }
2378   if (TFLITE_UNLIKELY(i < size)) {
2379     const auto minmax =
2380         std::minmax_element(values + postamble_start, values + size);
2381     rmin = std::min(rmin, *minmax.first);
2382     rmax = std::max(rmax, *minmax.second);
2383   }
2384   *min = rmin;
2385   *max = rmax;
2386 }
2387 
NeonSymmetricQuantizeFloats(const float * values,const int size,int8_t * quantized_values,float * min,float * max,float * scaling_factor)2388 void NeonSymmetricQuantizeFloats(const float* values, const int size,
2389                                  int8_t* quantized_values, float* min,
2390                                  float* max, float* scaling_factor) {
2391   // TODO(raziel): vectorize min/max calculation.
2392   auto minmax = std::minmax_element(values, values + size);
2393   *min = *minmax.first;
2394   *max = *minmax.second;
2395   NeonSymmetricQuantizeFloats(values, size, quantized_values, *min, *max,
2396                               scaling_factor);
2397 }
2398 
NeonSymmetricQuantizeFloats(const float * values,const int size,int8_t * quantized_values,float min,float max,float * scaling_factor)2399 void NeonSymmetricQuantizeFloats(const float* values, const int size,
2400                                  int8_t* quantized_values, float min, float max,
2401                                  float* scaling_factor) {
2402   constexpr int kScale = 127;
2403   const float range = std::max(std::abs(min), std::abs(max));
2404   if (range == 0) {
2405     memset(quantized_values, 0, size * sizeof(int8_t));
2406     *scaling_factor = 1;
2407     return;
2408   }
2409   *scaling_factor = range / kScale;
2410   const float scaling_factor_inv = kScale / range;
2411 
2412   const int postamble_start =
2413       RoundDownVectors<(2 * kFloatValuesPerNeonVector)>(size);
2414 
2415   // Vectorized constants.
2416   const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
2417   const int32x4_t scale_i32x4 = vmovq_n_s32(kScale);
2418   const int32x4_t neg_scale_i32x4 = vmovq_n_s32(-kScale);
2419 
2420   int i = 0;
2421   for (; i < postamble_start; i += 2 * kFloatValuesPerNeonVector) {
2422     // Implements the vectorized version of the following:
2423     // const int32 quantized_value = static_cast<int32>(
2424     //    std::round(*scaling_factor * values[i]));
2425     float32x4_t value0_f32x4 = vld1q_f32(&values[i]);
2426     float32x4_t value1_f32x4 =
2427         vld1q_f32(&values[i + kFloatValuesPerNeonVector]);
2428     float32x4_t mul0_f32x4 = vmulq_f32(value0_f32x4, q_factor_f32x4);
2429     float32x4_t mul1_f32x4 = vmulq_f32(value1_f32x4, q_factor_f32x4);
2430 
2431     const int32x4_t f2i0_i32x4 = RoundToNearest(mul0_f32x4);
2432     const int32x4_t f2i1_i32x4 = RoundToNearest(mul1_f32x4);
2433 
2434     // Implements the vectorized version of the following block:
2435     //  quantized_values[i] = std::min(kScale, std::max(-kScale,
2436     //  quantized_value));
2437     int32x4_t max0_i32x4 = vmaxq_s32(f2i0_i32x4, neg_scale_i32x4);
2438     int32x4_t max1_i32x4 = vmaxq_s32(f2i1_i32x4, neg_scale_i32x4);
2439     int32x4_t min0_i32x4 = vminq_s32(max0_i32x4, scale_i32x4);
2440     int32x4_t min1_i32x4 = vminq_s32(max1_i32x4, scale_i32x4);
2441 
2442     int16x4_t min0_16x4 = vmovn_s32(min0_i32x4);
2443     int16x4_t min1_16x4 = vmovn_s32(min1_i32x4);
2444 
2445     int16x8_t min_16x8 = vcombine_s16(min0_16x4, min1_16x4);
2446     int8x8_t min_s8x8 = vqmovn_s16(min_16x8);
2447     vst1_s8(&quantized_values[i], min_s8x8);
2448   }
2449 
2450   for (; TFLITE_UNLIKELY(i < size); ++i) {
2451     const int32 quantized_value =
2452         static_cast<int32>(TfLiteRound(scaling_factor_inv * values[i]));
2453     quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
2454   }
2455 }
2456 
NeonAsymmetricQuantizeFloats(const float * values,const int size,int8_t * quantized_values,float * scaling_factor,int32_t * offset)2457 void NeonAsymmetricQuantizeFloats(const float* values, const int size,
2458                                   int8_t* quantized_values,
2459                                   float* scaling_factor, int32_t* offset) {
2460   float rmin, rmax;
2461   NeonMinMax(values, size, &rmin, &rmax);
2462 
2463   const int32_t kMinScale = -128;
2464   const int32_t kMaxScale = 127;
2465   const double qmin_double = kMinScale;
2466   const double qmax_double = kMaxScale;
2467   if (rmin == rmax) {
2468     memset(quantized_values, 0, size * sizeof(int8_t));
2469     *scaling_factor = 1;
2470     *offset = 0;
2471     return;
2472   } else {
2473     const double scale = (rmax - rmin) / (qmax_double - qmin_double);
2474     const double zero_point_from_min = qmin_double - rmin / scale;
2475     const double zero_point_from_max = qmax_double - rmax / scale;
2476     const double zero_point_from_min_error =
2477         std::abs(qmin_double) + std::abs(rmin / scale);
2478     const double zero_point_from_max_error =
2479         std::abs(qmax_double) + std::abs(rmax / scale);
2480     const double zero_point_double =
2481         zero_point_from_min_error < zero_point_from_max_error
2482             ? zero_point_from_min
2483             : zero_point_from_max;
2484     int8 nudged_zero_point = 0;
2485     if (zero_point_double <= qmin_double) {
2486       nudged_zero_point = kMinScale;
2487     } else if (zero_point_double >= qmax_double) {
2488       nudged_zero_point = kMaxScale;
2489     } else {
2490       nudged_zero_point = static_cast<int8>(round(zero_point_double));
2491     }
2492     *scaling_factor = scale;
2493     *offset = nudged_zero_point;
2494   }
2495 
2496   const int postamble_start =
2497       RoundDownVectors<(2 * kFloatValuesPerNeonVector)>(size);
2498   const float scaling_factor_inv =
2499       *scaling_factor == 0 ? 0 : 1.0 / *scaling_factor;
2500   const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
2501   const int32x4_t scale_i32x4 = vmovq_n_s32(kMaxScale);
2502   const int32x4_t neg_scale_i32x4 = vmovq_n_s32(kMinScale);
2503   const int32x4_t offset_i32x4 = vmovq_n_s32(*offset);
2504 
2505   int i = 0;
2506   for (; i < postamble_start; i += 2 * kFloatValuesPerNeonVector) {
2507     float32x4_t value0_f32x4 = vld1q_f32(&values[i]);
2508     float32x4_t value1_f32x4 =
2509         vld1q_f32(&values[i + kFloatValuesPerNeonVector]);
2510     float32x4_t mul0_f32x4 = vmulq_f32(value0_f32x4, q_factor_f32x4);
2511     float32x4_t mul1_f32x4 = vmulq_f32(value1_f32x4, q_factor_f32x4);
2512 
2513     const int32x4_t f2i0_i32x4 = RoundToNearest(mul0_f32x4);
2514     const int32x4_t f2i1_i32x4 = RoundToNearest(mul1_f32x4);
2515 
2516     // Add offset
2517     int32x4_t q0_i32x4 = vaddq_s32(f2i0_i32x4, offset_i32x4);
2518     int32x4_t q1_i32x4 = vaddq_s32(f2i1_i32x4, offset_i32x4);
2519 
2520     int32x4_t max0_i32x4 = vmaxq_s32(q0_i32x4, neg_scale_i32x4);
2521     int32x4_t max1_i32x4 = vmaxq_s32(q1_i32x4, neg_scale_i32x4);
2522     int32x4_t min0_i32x4 = vminq_s32(max0_i32x4, scale_i32x4);
2523     int32x4_t min1_i32x4 = vminq_s32(max1_i32x4, scale_i32x4);
2524 
2525     int16x4_t min0_16x4 = vmovn_s32(min0_i32x4);
2526     int16x4_t min1_16x4 = vmovn_s32(min1_i32x4);
2527 
2528     int16x8_t min_16x8 = vcombine_s16(min0_16x4, min1_16x4);
2529     int8x8_t min_s8x8 = vqmovn_s16(min_16x8);
2530     vst1_s8(&quantized_values[i], min_s8x8);
2531   }
2532 
2533   for (; TFLITE_UNLIKELY(i < size); ++i) {
2534     const int32 quantized_value = static_cast<int32>(
2535         *offset + TfLiteRound(scaling_factor_inv * values[i]));
2536     quantized_values[i] =
2537         std::min(kMaxScale, std::max(kMinScale, quantized_value));
2538   }
2539 }
2540 
NeonVectorVectorDotProduct(const float * vector1,const float * vector2,int v_size)2541 float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
2542                                  int v_size) {
2543   // If v_size is not divisible by the vector size, then we need to process the
2544   // final few elements sequentially. postamble_start shows the start index
2545   // where this should happen.
2546   const int postamble_start =
2547       RoundDownVectors<kFloatValuesPerNeonVector>(v_size);
2548   float32x4_t acc_32x4 = vmovq_n_f32(0.0);
2549   int v = 0;
2550   for (; v < postamble_start; v += kFloatValuesPerNeonVector) {
2551     // Load 4 float values from vector1 and vector2 and accumulator.
2552     float32x4_t v1_f32x4 = vld1q_f32(vector1 + v);
2553     float32x4_t v2_f32x4 = vld1q_f32(vector2 + v);
2554     // Vector multiply-accumulate 4 float
2555     acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4);
2556   }
2557   float result = AccumulateNeonLane(acc_32x4);
2558   // Postamble loop.
2559   for (; TFLITE_UNLIKELY(v < v_size); v++) {
2560     result += vector1[v] * vector2[v];
2561   }
2562   return result;
2563 }
2564 
NeonReductionSumVector(const float * input_vector,float * output_vector,int output_size,int reduction_size)2565 void NeonReductionSumVector(const float* input_vector, float* output_vector,
2566                             int output_size, int reduction_size) {
2567   for (int o = 0; o < output_size; o++) {
2568     // If v_size is not divisible by the vector size, then we need to process
2569     // the final few elements sequentially. postamble_start shows the start
2570     // index where this should happen.
2571     const int postamble_start =
2572         RoundDownVectors<kFloatValuesPerNeonVector>(reduction_size);
2573     float32x4_t sum_f32x4 = vmovq_n_f32(0.0);
2574     int r = 0;
2575     for (; r < postamble_start; r += kFloatValuesPerNeonVector) {
2576       float32x4_t v1_f32x4 = vld1q_f32(input_vector + r);
2577       sum_f32x4 = vaddq_f32(sum_f32x4, v1_f32x4);
2578     }
2579     float sum = AccumulateNeonLane(sum_f32x4);
2580     // Postamble loop.
2581     for (; TFLITE_UNLIKELY(r < reduction_size); r++) {
2582       sum += input_vector[r];
2583     }
2584     output_vector[o] = sum;
2585     input_vector += reduction_size;
2586   }
2587 }
2588 
NeonReductionSumVector(const int8_t * input_vector,int32_t * output_vector,const int output_size,const int reduction_size)2589 void NeonReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
2590                             const int output_size, const int reduction_size) {
2591   const int postamble_half_start =
2592       RoundDownVectors<kInt8ValuesPerNeonVector>(reduction_size);
2593   const int postamble_start =
2594       RoundDownVectors<(kInt8ValuesPerNeonVector / 2)>(reduction_size);
2595   for (int o = 0; o < output_size; ++o) {
2596     int32x4_t sum_32x4 = vmovq_n_s32(0);
2597     int r = 0;
2598     for (; r < postamble_half_start; r += kInt8ValuesPerNeonVector) {
2599       const int8x16_t s2_8x16 = vld1q_s8(input_vector + r);
2600       sum_32x4 = vpadalq_s16(sum_32x4, vpaddlq_s8(s2_8x16));
2601     }
2602     if (TFLITE_UNLIKELY(r < postamble_start)) {
2603       const int8x8_t s2_8x8 = vld1_s8(input_vector + r);
2604       sum_32x4 = vpadalq_s16(sum_32x4, vmovl_s8(s2_8x8));
2605       r += (kInt8ValuesPerNeonVector >> 1);
2606     }
2607     int32_t sum = AccumulateNeonLane(sum_32x4);
2608     for (; TFLITE_UNLIKELY(r < reduction_size); ++r) {
2609       sum += input_vector[r];
2610     }
2611     output_vector[o] = sum;
2612     input_vector += reduction_size;
2613   }
2614 }
2615 
NeonVectorBatchVectorCwiseProductAccumulate(const int16_t * vector,int v_size,const int16_t * batch_vector,int n_batch,int32_t multiplier,int shift,int16_t * result)2616 void NeonVectorBatchVectorCwiseProductAccumulate(
2617     const int16_t* vector, int v_size, const int16_t* batch_vector, int n_batch,
2618     int32_t multiplier, int shift, int16_t* result) {
2619   int32x4_t min_value_vector = vdupq_n_s32(-32768);
2620   int32x4_t max_value_vector = vdupq_n_s32(32767);
2621 
2622   for (int b = 0; b < n_batch; b++) {
2623     int v = 0;
2624     for (; v <= v_size - 16; v += 16) {
2625       int32x4x4_t prod;
2626       prod.val[0] = vmull_s16(vld1_s16(vector + v), vld1_s16(batch_vector));
2627       prod.val[1] =
2628           vmull_s16(vld1_s16(vector + v + 4), vld1_s16(batch_vector + 4));
2629       prod.val[2] =
2630           vmull_s16(vld1_s16(vector + v + 8), vld1_s16(batch_vector + 8));
2631       prod.val[3] =
2632           vmull_s16(vld1_s16(vector + v + 12), vld1_s16(batch_vector + 12));
2633       batch_vector += 16;
2634 
2635       prod = MultiplyByQuantizedMultiplier4Rows(prod, multiplier, shift);
2636 
2637       int16x4x4_t results;
2638       results.val[0] = vld1_s16(result);
2639       results.val[1] = vld1_s16(result + 4);
2640       results.val[2] = vld1_s16(result + 8);
2641       results.val[3] = vld1_s16(result + 12);
2642 
2643       prod.val[0] = vaddq_s32(prod.val[0], vmovl_s16(results.val[0]));
2644       prod.val[1] = vaddq_s32(prod.val[1], vmovl_s16(results.val[1]));
2645       prod.val[2] = vaddq_s32(prod.val[2], vmovl_s16(results.val[2]));
2646       prod.val[3] = vaddq_s32(prod.val[3], vmovl_s16(results.val[3]));
2647 
2648       prod.val[0] = vmaxq_s32(prod.val[0], min_value_vector);
2649       prod.val[1] = vmaxq_s32(prod.val[1], min_value_vector);
2650       prod.val[2] = vmaxq_s32(prod.val[2], min_value_vector);
2651       prod.val[3] = vmaxq_s32(prod.val[3], min_value_vector);
2652 
2653       prod.val[0] = vminq_s32(prod.val[0], max_value_vector);
2654       prod.val[1] = vminq_s32(prod.val[1], max_value_vector);
2655       prod.val[2] = vminq_s32(prod.val[2], max_value_vector);
2656       prod.val[3] = vminq_s32(prod.val[3], max_value_vector);
2657 
2658       vst1_s16(result, vmovn_s32(prod.val[0]));
2659       vst1_s16(result + 4, vmovn_s32(prod.val[1]));
2660       vst1_s16(result + 8, vmovn_s32(prod.val[2]));
2661       vst1_s16(result + 12, vmovn_s32(prod.val[3]));
2662 
2663       result += 16;
2664     }
2665 
2666     for (; TFLITE_UNLIKELY(v < v_size); v++) {
2667       int32_t prod = vector[v] * *batch_vector++;
2668       prod = MultiplyByQuantizedMultiplier(prod, multiplier, shift);
2669       int32_t output = prod + *result;
2670       output = std::max(std::min(32767, output), -32768);
2671       *result++ = output;
2672     }
2673   }
2674 }
2675 
NeonMeanStddevNormalization(const float * __restrict__ input_vector,float * __restrict__ output_vector,int v_size,int n_batch)2676 void NeonMeanStddevNormalization(const float* __restrict__ input_vector,
2677                                  float* __restrict__ output_vector, int v_size,
2678                                  int n_batch) {
2679   constexpr int kBlockSize = kFloatValuesPerNeonVector * 4;
2680 
2681   for (int batch = 0; batch < n_batch; ++batch) {
2682     // Calculate sum
2683     float32x4_t sum_f32x4_0 = vdupq_n_f32(0.0f);
2684     float32x4_t sum_f32x4_1 = vdupq_n_f32(0.0f);
2685     float32x4_t sum_f32x4_2 = vdupq_n_f32(0.0f);
2686     float32x4_t sum_f32x4_3 = vdupq_n_f32(0.0f);
2687     int i = 0;
2688     for (; i <= v_size - kBlockSize; i += kBlockSize) {
2689       const float32x4_t input_f32x4_0 =
2690           vld1q_f32(input_vector + i + 0 * kFloatValuesPerNeonVector);
2691       const float32x4_t input_f32x4_1 =
2692           vld1q_f32(input_vector + i + 1 * kFloatValuesPerNeonVector);
2693       const float32x4_t input_f32x4_2 =
2694           vld1q_f32(input_vector + i + 2 * kFloatValuesPerNeonVector);
2695       const float32x4_t input_f32x4_3 =
2696           vld1q_f32(input_vector + i + 3 * kFloatValuesPerNeonVector);
2697       sum_f32x4_0 = vaddq_f32(sum_f32x4_0, input_f32x4_0);
2698       sum_f32x4_1 = vaddq_f32(sum_f32x4_1, input_f32x4_1);
2699       sum_f32x4_2 = vaddq_f32(sum_f32x4_2, input_f32x4_2);
2700       sum_f32x4_3 = vaddq_f32(sum_f32x4_3, input_f32x4_3);
2701     }
2702     sum_f32x4_0 = vaddq_f32(sum_f32x4_0, sum_f32x4_2);
2703     sum_f32x4_1 = vaddq_f32(sum_f32x4_1, sum_f32x4_3);
2704     sum_f32x4_0 = vaddq_f32(sum_f32x4_0, sum_f32x4_1);
2705     float sum = AccumulateNeonLane(sum_f32x4_0);
2706     for (; TFLITE_UNLIKELY(i < v_size); ++i) {
2707       sum += input_vector[i];
2708     }
2709     // Calculate mean
2710     const float mean = sum / v_size;
2711     const float32x4_t mean_f32x4 = vdupq_n_f32(mean);
2712     // Calculate sum of squared differences
2713     float32x4_t sum_diff_sq_f32x4_0 = vdupq_n_f32(0.0f);
2714     float32x4_t sum_diff_sq_f32x4_1 = vdupq_n_f32(0.0f);
2715     float32x4_t sum_diff_sq_f32x4_2 = vdupq_n_f32(0.0f);
2716     float32x4_t sum_diff_sq_f32x4_3 = vdupq_n_f32(0.0f);
2717     i = 0;
2718     for (; i <= v_size - kBlockSize; i += kBlockSize) {
2719       const float32x4_t input_f32x4_0 =
2720           vld1q_f32(input_vector + i + 0 * kFloatValuesPerNeonVector);
2721       const float32x4_t input_f32x4_1 =
2722           vld1q_f32(input_vector + i + 1 * kFloatValuesPerNeonVector);
2723       const float32x4_t input_f32x4_2 =
2724           vld1q_f32(input_vector + i + 2 * kFloatValuesPerNeonVector);
2725       const float32x4_t input_f32x4_3 =
2726           vld1q_f32(input_vector + i + 3 * kFloatValuesPerNeonVector);
2727       const float32x4_t diff_f32x4_0 = vsubq_f32(input_f32x4_0, mean_f32x4);
2728       const float32x4_t diff_f32x4_1 = vsubq_f32(input_f32x4_1, mean_f32x4);
2729       const float32x4_t diff_f32x4_2 = vsubq_f32(input_f32x4_2, mean_f32x4);
2730       const float32x4_t diff_f32x4_3 = vsubq_f32(input_f32x4_3, mean_f32x4);
2731       sum_diff_sq_f32x4_0 =
2732           vmlaq_f32(sum_diff_sq_f32x4_0, diff_f32x4_0, diff_f32x4_0);
2733       sum_diff_sq_f32x4_1 =
2734           vmlaq_f32(sum_diff_sq_f32x4_1, diff_f32x4_1, diff_f32x4_1);
2735       sum_diff_sq_f32x4_2 =
2736           vmlaq_f32(sum_diff_sq_f32x4_2, diff_f32x4_2, diff_f32x4_2);
2737       sum_diff_sq_f32x4_3 =
2738           vmlaq_f32(sum_diff_sq_f32x4_3, diff_f32x4_3, diff_f32x4_3);
2739     }
2740     sum_diff_sq_f32x4_0 = vaddq_f32(sum_diff_sq_f32x4_0, sum_diff_sq_f32x4_2);
2741     sum_diff_sq_f32x4_1 = vaddq_f32(sum_diff_sq_f32x4_1, sum_diff_sq_f32x4_3);
2742     sum_diff_sq_f32x4_0 = vaddq_f32(sum_diff_sq_f32x4_0, sum_diff_sq_f32x4_1);
2743     float sum_diff_sq = AccumulateNeonLane(sum_diff_sq_f32x4_0);
2744     for (; TFLITE_UNLIKELY(i < v_size); ++i) {
2745       const float diff = input_vector[i] - mean;
2746       sum_diff_sq += diff * diff;
2747     }
2748     // Calculate 1/stddev
2749     const float variance = sum_diff_sq / v_size;
2750     constexpr float kNormalizationConstant = 1e-8f;
2751     const float stddev_inv =
2752         1.0f / std::sqrt(variance + kNormalizationConstant);
2753     // Do the normalization
2754     i = 0;
2755     for (; i <= v_size - kBlockSize; i += kBlockSize) {
2756       const float32x4_t input_f32x4_0 =
2757           vld1q_f32(input_vector + i + 0 * kFloatValuesPerNeonVector);
2758       const float32x4_t input_f32x4_1 =
2759           vld1q_f32(input_vector + i + 1 * kFloatValuesPerNeonVector);
2760       const float32x4_t input_f32x4_2 =
2761           vld1q_f32(input_vector + i + 2 * kFloatValuesPerNeonVector);
2762       const float32x4_t input_f32x4_3 =
2763           vld1q_f32(input_vector + i + 3 * kFloatValuesPerNeonVector);
2764       const float32x4_t tmp_0 = vsubq_f32(input_f32x4_0, mean_f32x4);
2765       const float32x4_t tmp_1 = vsubq_f32(input_f32x4_1, mean_f32x4);
2766       const float32x4_t tmp_2 = vsubq_f32(input_f32x4_2, mean_f32x4);
2767       const float32x4_t tmp_3 = vsubq_f32(input_f32x4_3, mean_f32x4);
2768       const float32x4_t output_f32x4_0 = vmulq_n_f32(tmp_0, stddev_inv);
2769       const float32x4_t output_f32x4_1 = vmulq_n_f32(tmp_1, stddev_inv);
2770       const float32x4_t output_f32x4_2 = vmulq_n_f32(tmp_2, stddev_inv);
2771       const float32x4_t output_f32x4_3 = vmulq_n_f32(tmp_3, stddev_inv);
2772       vst1q_f32(output_vector + i + 0 * kFloatValuesPerNeonVector,
2773                 output_f32x4_0);
2774       vst1q_f32(output_vector + i + 1 * kFloatValuesPerNeonVector,
2775                 output_f32x4_1);
2776       vst1q_f32(output_vector + i + 2 * kFloatValuesPerNeonVector,
2777                 output_f32x4_2);
2778       vst1q_f32(output_vector + i + 3 * kFloatValuesPerNeonVector,
2779                 output_f32x4_3);
2780     }
2781     for (; TFLITE_UNLIKELY(i < v_size); ++i) {
2782       output_vector[i] = (input_vector[i] - mean) * stddev_inv;
2783     }
2784     // Advance to next batch
2785     input_vector += v_size;
2786     output_vector += v_size;
2787   }
2788 }
2789 
2790 }  // namespace tensor_utils
2791 }  // namespace tflite
2792 
2793 #endif  // USE_NEON
2794