/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "aom_dsp/arm/mem_neon.h"
#include "av1/common/arm/compound_convolve_neon.h"
#include "config/aom_config.h"
#include "config/av1_rtcd.h"

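// Permute indices gathering four overlapping 4-byte windows of the source
// into each 128-bit register, so that a single USDOT instruction computes
// four partial convolution sums at once.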
DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
  0, 1, 2,  3,  1, 2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,
  4, 5, 6,  7,  5, 6,  7,  8,  6,  7,  8,  9,  7,  8,  9,  10,
  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};

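// Permute indices gathering two overlapping 8-byte windows per 128-bit
// register. Combined with the staggered filter layout built below, a single
// USMMLA (2x8 by 8x2 matrix multiply) then yields four consecutive 6-tap
// outputs.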
DECLARE_ALIGNED(16, static const uint8_t, kMatMulPermuteTbl[32]) = {
  // clang-format off
  0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9,
  4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13
  // clang-format on
};

static inline int16x4_t convolve6_4_2d_h(uint8x16_t samples,
                                         const int8x16_t x_filter,
                                         const uint8x16_t permute_tbl,
                                         const int32x4_t horiz_const) {
  // Permute samples ready for matrix multiply.
  // { 0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
  uint8x16_t permuted_samples = vqtbl1q_u8(samples, permute_tbl);

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum = vusmmlaq_s32(horiz_const, permuted_samples, x_filter);
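  // With sample rows { x0, ..., x7 } and { x2, ..., x9 } and filter columns
  // { f0, ..., f5, 0, 0 } and { 0, f0, ..., f5, 0 }, the four lanes of the
  // 2x2 result are the 6-tap outputs at x0, x1, x2 and x3.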

  // We halved the convolution filter values so -1 from the right shift.
  return vshrn_n_s32(sum, ROUND0_BITS - 1);
}

static inline int16x8_t convolve6_8_2d_h(uint8x16_t samples,
                                         const int8x16_t x_filter,
                                         const uint8x16x2_t permute_tbl,
                                         const int32x4_t horiz_const) {
  // Permute samples ready for matrix multiply.
  // { 0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
  // { 4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13 }
  uint8x16_t permuted_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
                                     vqtbl1q_u8(samples, permute_tbl.val[1]) };

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum0123 = vusmmlaq_s32(horiz_const, permuted_samples[0], x_filter);
  int32x4_t sum4567 = vusmmlaq_s32(horiz_const, permuted_samples[1], x_filter);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  return vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS - 1),
                      vshrn_n_s32(sum4567, ROUND0_BITS - 1));
}

static inline void dist_wtd_convolve_2d_horiz_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
    const int16_t *x_filter_ptr, const int im_h, int w) {
  const int bd = 8;
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
                                            (1 << ((ROUND0_BITS - 1) - 1)));
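  // With bd == 8, FILTER_BITS == 7 and ROUND0_BITS == 3 this evaluates to
  // (1 << 13) + (1 << 1); the first term is the usual intermediate-precision
  // offset, already halved to match the halved filter values.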

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5,  0,  0,  0, f0, f1, f2, f3, f4, f5,  0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);
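  // The 6-tap kernels are stored zero-padded to eight taps as
  // { 0, f0, f1, f2, f3, f4, f5, 0 }, so rotating left by one byte with vext
  // produces the { f0, ..., f5, 0, 0 } row needed for the matrix multiply.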

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 = convolve6_4_2d_h(s0, x_filter, permute_tbl, horiz_const);
      int16x4_t d1 = convolve6_4_2d_h(s1, x_filter, permute_tbl, horiz_const);
      int16x4_t d2 = convolve6_4_2d_h(s2, x_filter, permute_tbl, horiz_const);
      int16x4_t d3 = convolve6_4_2d_h(s3, x_filter, permute_tbl, horiz_const);

      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);
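
    // Since h is a multiple of four, im_h (= h + filter_taps - 1) is odd, so
    // one or three rows remain here; process them one at a time.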
    do {
      uint8x16_t s0 = vld1q_u8(src_ptr);

      int16x4_t d0 = convolve6_4_2d_h(s0, x_filter, permute_tbl, horiz_const);

      vst1_s16(dst_ptr, d0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        int16x8_t d0 = convolve6_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
        int16x8_t d1 = convolve6_8_2d_h(s1, x_filter, permute_tbl, horiz_const);
        int16x8_t d2 = convolve6_8_2d_h(s2, x_filter, permute_tbl, horiz_const);
        int16x8_t d3 = convolve6_8_2d_h(s3, x_filter, permute_tbl, horiz_const);

        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0 = vld1q_u8(s);

        int16x8_t d0 = convolve6_8_2d_h(s0, x_filter, permute_tbl, horiz_const);

        vst1q_s16(d, d0);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  }
}

static inline int16x8_t convolve8_8_2d_h(uint8x16_t samples,
                                         const int8x8_t x_filter,
                                         const uint8x16x3_t permute_tbl,
                                         const int32x4_t horiz_const) {
  uint8x16_t permuted_samples[3];
  int32x4_t sum[2];

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);

  // First 4 output values.
  sum[0] = vusdotq_lane_s32(horiz_const, permuted_samples[0], x_filter, 0);
  sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vusdotq_lane_s32(horiz_const, permuted_samples[1], x_filter, 0);
  sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);
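  // Each vusdotq_lane_s32 call consumes four filter taps: lane 0 applies
  // taps 0-3 to one set of windows, lane 1 applies taps 4-7 to the windows
  // four bytes further on.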

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                      vshrn_n_s32(sum[1], ROUND0_BITS - 1));
}

static inline void dist_wtd_convolve_2d_horiz_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
    const int16_t *x_filter_ptr, const int im_h, int w) {
  const int bd = 8;
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
                                            (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  do {
    const uint8_t *s = src_ptr;
    int16_t *d = dst_ptr;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
      int16x8_t d1 = convolve8_8_2d_h(s1, x_filter, permute_tbl, horiz_const);
      int16x8_t d2 = convolve8_8_2d_h(s2, x_filter, permute_tbl, horiz_const);
      int16x8_t d3 = convolve8_8_2d_h(s3, x_filter, permute_tbl, horiz_const);

      store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

      s += 8;
      d += 8;
      width -= 8;
    } while (width > 0);
    src_ptr += 4 * src_stride;
    dst_ptr += 4 * dst_stride;
    height -= 4;
  } while (height > 4);

  do {
    const uint8_t *s = src_ptr;
    int16_t *d = dst_ptr;
    int width = w;

    do {
      uint8x16_t s0 = vld1q_u8(s);

      int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);

      vst1q_s16(d, d0);

      s += 8;
      d += 8;
      width -= 8;
    } while (width > 0);
    src_ptr += src_stride;
    dst_ptr += dst_stride;
  } while (--height != 0);
}

void av1_dist_wtd_convolve_2d_neon_i8mm(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_qn,
    const int subpel_y_qn, ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  DECLARE_ALIGNED(16, int16_t,
                  im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);
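  // The 2D filter runs in two passes: the horizontal pass writes
  // higher-precision intermediate results into im_block, and the vertical
  // pass narrows them into the compound output.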

  const int x_filter_taps = get_filter_tap(filter_params_x, subpel_x_qn);
  const int clamped_x_taps = x_filter_taps < 6 ? 6 : x_filter_taps;
  const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
  const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;
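  // Kernels with fewer than six taps are zero-padded, so they can safely be
  // routed through the six-tap paths.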

  const int im_h = h + clamped_y_taps - 1;
  const int im_stride = MAX_SB_SIZE;
  const int vert_offset = clamped_y_taps / 2 - 1;
  const int horiz_offset = clamped_x_taps / 2 - 1;
  const uint8_t *src_ptr = src - vert_offset * src_stride - horiz_offset;
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_y, subpel_y_qn & SUBPEL_MASK);

  const int16x8_t y_filter = vld1q_s16(y_filter_ptr);

  if (clamped_x_taps == 6) {
    dist_wtd_convolve_2d_horiz_6tap_neon_i8mm(src_ptr, src_stride, im_block,
                                              im_stride, x_filter_ptr, im_h, w);
  } else {
    dist_wtd_convolve_2d_horiz_8tap_neon_i8mm(src_ptr, src_stride, im_block,
                                              im_stride, x_filter_ptr, im_h, w);
  }

  if (clamped_y_taps == 6) {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_6tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_6tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_6tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  } else {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_8tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_8tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_8tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  }
}

static inline uint16x4_t convolve6_4_x(uint8x16_t samples,
                                       const int8x16_t x_filter,
                                       const uint8x16_t permute_tbl,
                                       const int32x4_t round_offset) {
  // Permute samples ready for matrix multiply.
  // { 0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
  uint8x16_t permuted_samples = vqtbl1q_u8(samples, permute_tbl);

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum = vusmmlaq_s32(round_offset, permuted_samples, x_filter);

  // We halved the convolution filter values so -1 from the right shift.
  return vreinterpret_u16_s16(vshrn_n_s32(sum, ROUND0_BITS - 1));
}

static inline uint16x8_t convolve6_8_x(uint8x16_t samples,
                                       const int8x16_t x_filter,
                                       const uint8x16x2_t permute_tbl,
                                       const int32x4_t round_offset) {
  // Permute samples ready for matrix multiply.
  // { 0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
  // { 4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13 }
  uint8x16_t permuted_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
                                     vqtbl1q_u8(samples, permute_tbl.val[1]) };

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum0123 = vusmmlaq_s32(round_offset, permuted_samples[0], x_filter);
  int32x4_t sum4567 = vusmmlaq_s32(round_offset, permuted_samples[1], x_filter);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  int16x8_t res = vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS - 1),
                               vshrn_n_s32(sum4567, ROUND0_BITS - 1));
  return vreinterpretq_u16_s16(res);
}

static inline uint16x8_t convolve8_8_x(uint8x16_t samples,
                                       const int8x8_t x_filter,
                                       const uint8x16x3_t permute_tbl,
                                       const int32x4_t round_offset) {
  uint8x16_t permuted_samples[3];
  int32x4_t sum[2];

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);

  // First 4 output values.
  sum[0] = vusdotq_lane_s32(round_offset, permuted_samples[0], x_filter, 0);
  sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vusdotq_lane_s32(round_offset, permuted_samples[1], x_filter, 0);
  sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  int16x8_t res = vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                               vshrn_n_s32(sum[1], ROUND0_BITS - 1));
  return vreinterpretq_u16_s16(res);
}

static inline void dist_wtd_convolve_x_dist_wtd_avg_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr,
    const uint16_t fwd_offset, const uint16_t bck_offset) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));
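  // With libaom's 8-bit constants (FILTER_BITS == 7, ROUND0_BITS == 3,
  // COMPOUND_ROUND1_BITS == 7), offset_bits is 19 and round_offset is
  // (1 << 12) + (1 << 11) = 6144. Pre-shifting it by ROUND0_BITS - 1 bakes
  // the compound offset into the accumulator ahead of the narrowing shift.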

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5,  0,  0,  0, f0, f1, f2, f3, f4, f5,  0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve6_4_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d1 =
          convolve6_4_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d2 =
          convolve6_4_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d3 =
          convolve6_4_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_dist_wtd_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                               bck_offset, round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8 + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8 + 2 * dst8_stride, dst8_stride, d23_u8);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src;
      uint16_t *d = dst;
      uint8_t *d_u8 = dst8;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve6_8_x(s0, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d1 =
            convolve6_8_x(s1, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d2 =
            convolve6_8_x(s2, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d3 =
            convolve6_8_x(s3, x_filter, permute_tbl, round_offset_shim);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_dist_wtd_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                                 bck_offset, round_offset_vec, &d0_u8, &d1_u8,
                                 &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  }
}

static inline void dist_wtd_convolve_x_dist_wtd_avg_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr,
    const uint16_t fwd_offset, const uint16_t bck_offset) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  do {
    const uint8_t *s = src;
    uint16_t *d = dst;
    uint8_t *d_u8 = dst8;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x8_t d0 =
          convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d1 =
          convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d2 =
          convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d3 =
          convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x8_t dd0, dd1, dd2, dd3;
      load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
      compute_dist_wtd_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                               bck_offset, round_offset_vec, &d0_u8, &d1_u8,
                               &d2_u8, &d3_u8);

      store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

      s += 8;
      d += 8;
      d_u8 += 8;
      width -= 8;
    } while (width != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    dst8 += 4 * dst8_stride;
    h -= 4;
  } while (h != 0);
}

static inline void dist_wtd_convolve_x_avg_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5,  0,  0,  0, f0, f1, f2, f3, f4, f5,  0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve6_4_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d1 =
          convolve6_4_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d2 =
          convolve6_4_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d3 =
          convolve6_4_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_basic_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                            round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8 + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8 + 2 * dst8_stride, dst8_stride, d23_u8);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src;
      uint16_t *d = dst;
      uint8_t *d_u8 = dst8;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve6_8_x(s0, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d1 =
            convolve6_8_x(s1, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d2 =
            convolve6_8_x(s2, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d3 =
            convolve6_8_x(s3, x_filter, permute_tbl, round_offset_shim);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_basic_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                              round_offset_vec, &d0_u8, &d1_u8, &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  }
}

static inline void dist_wtd_convolve_x_avg_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  do {
    const uint8_t *s = src;
    uint16_t *d = dst;
    uint8_t *d_u8 = dst8;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x8_t d0 =
          convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d1 =
          convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d2 =
          convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d3 =
          convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x8_t dd0, dd1, dd2, dd3;
      load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
      compute_basic_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                            round_offset_vec, &d0_u8, &d1_u8, &d2_u8, &d3_u8);

      store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

      s += 8;
      d += 8;
      d_u8 += 8;
      width -= 8;
    } while (width != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    dst8 += 4 * dst8_stride;
    h -= 4;
  } while (h != 0);
}

static inline void dist_wtd_convolve_x_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5,  0,  0,  0, f0, f1, f2, f3, f4, f5,  0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve6_4_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d1 =
          convolve6_4_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d2 =
          convolve6_4_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d3 =
          convolve6_4_x(s3, x_filter, permute_tbl, round_offset_shim);

      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src;
      uint16_t *d = dst;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve6_8_x(s0, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d1 =
            convolve6_8_x(s1, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d2 =
            convolve6_8_x(s2, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d3 =
            convolve6_8_x(s3, x_filter, permute_tbl, round_offset_shim);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  }
}

static inline void dist_wtd_convolve_x_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  do {
    const uint8_t *s = src;
    uint16_t *d = dst;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x8_t d0 =
          convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d1 =
          convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d2 =
          convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d3 =
          convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);

      store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

      s += 8;
      d += 8;
      width -= 8;
    } while (width != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    h -= 4;
  } while (h != 0);
}

void av1_dist_wtd_convolve_x_neon_i8mm(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int filter_taps =
      get_filter_tap(filter_params_x, subpel_x_qn & SUBPEL_MASK);

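  // Back up to the first tap of the 8-tap window. The 6-tap kernels sit at
  // indices 1-6 of the padded 8-tap array, so the 6-tap paths below are
  // passed src + 1 to skip the leading zero tap.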
  src -= (SUBPEL_TAPS / 2 - 1);

  if (conv_params->do_average) {
    if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
      if (filter_taps < 8) {
        dist_wtd_convolve_x_dist_wtd_avg_6tap_neon_i8mm(
            src + 1, src_stride, conv_params->dst, conv_params->dst_stride,
            dst8, dst8_stride, w, h, x_filter_ptr, conv_params->fwd_offset,
            conv_params->bck_offset);
        return;
      }

      dist_wtd_convolve_x_dist_wtd_avg_8tap_neon_i8mm(
          src, src_stride, conv_params->dst, conv_params->dst_stride, dst8,
          dst8_stride, w, h, x_filter_ptr, conv_params->fwd_offset,
          conv_params->bck_offset);
    } else {
      if (filter_taps < 8) {
        dist_wtd_convolve_x_avg_6tap_neon_i8mm(
            src + 1, src_stride, conv_params->dst, conv_params->dst_stride,
            dst8, dst8_stride, w, h, x_filter_ptr);
        return;
      }

      dist_wtd_convolve_x_avg_8tap_neon_i8mm(src, src_stride, conv_params->dst,
                                             conv_params->dst_stride, dst8,
                                             dst8_stride, w, h, x_filter_ptr);
    }
  } else {
    if (filter_taps < 8) {
      dist_wtd_convolve_x_6tap_neon_i8mm(src + 1, src_stride, conv_params->dst,
                                         conv_params->dst_stride, w, h,
                                         x_filter_ptr);
      return;
    }

    dist_wtd_convolve_x_8tap_neon_i8mm(src, src_stride, conv_params->dst,
                                       conv_params->dst_stride, w, h,
                                       x_filter_ptr);
  }
}