/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
#ifndef AOM_AV1_COMMON_ARM_COMPOUND_CONVOLVE_NEON_H_
#define AOM_AV1_COMMON_ARM_COMPOUND_CONVOLVE_NEON_H_

#include <arm_neon.h>

#include "av1/common/convolve.h"
#include "av1/common/enums.h"
#include "av1/common/filter.h"

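// Helpers for the compound averaging stage: each blends a block already in
// the compound prediction buffer (dd*) with a newly computed prediction (d*),
// either as a distance-weighted average using fwd_offset/bck_offset or as a
// simple halving average, then removes the accumulated intermediate offset
// and rounds, saturates and narrows the result to 8-bit pixels.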
static inline void compute_dist_wtd_avg_4x1(uint16x4_t dd0, uint16x4_t d0,
                                            const uint16_t fwd_offset,
                                            const uint16_t bck_offset,
                                            const int16x4_t round_offset,
                                            uint8x8_t *d0_u8) {
  uint32x4_t blend0 = vmull_n_u16(dd0, fwd_offset);
  blend0 = vmlal_n_u16(blend0, d0, bck_offset);

  uint16x4_t avg0 = vshrn_n_u32(blend0, DIST_PRECISION_BITS);

  int16x4_t dst0 = vsub_s16(vreinterpret_s16_u16(avg0), round_offset);

  int16x8_t dst0q = vcombine_s16(dst0, vdup_n_s16(0));

  *d0_u8 = vqrshrun_n_s16(dst0q, FILTER_BITS - ROUND0_BITS);
}

static inline void compute_basic_avg_4x1(uint16x4_t dd0, uint16x4_t d0,
                                         const int16x4_t round_offset,
                                         uint8x8_t *d0_u8) {
  uint16x4_t avg0 = vhadd_u16(dd0, d0);

  int16x4_t dst0 = vsub_s16(vreinterpret_s16_u16(avg0), round_offset);

  int16x8_t dst0q = vcombine_s16(dst0, vdup_n_s16(0));

  *d0_u8 = vqrshrun_n_s16(dst0q, FILTER_BITS - ROUND0_BITS);
}

static inline void compute_dist_wtd_avg_8x1(uint16x8_t dd0, uint16x8_t d0,
                                            const uint16_t fwd_offset,
                                            const uint16_t bck_offset,
                                            const int16x8_t round_offset,
                                            uint8x8_t *d0_u8) {
  uint32x4_t blend0_lo = vmull_n_u16(vget_low_u16(dd0), fwd_offset);
  blend0_lo = vmlal_n_u16(blend0_lo, vget_low_u16(d0), bck_offset);
  uint32x4_t blend0_hi = vmull_n_u16(vget_high_u16(dd0), fwd_offset);
  blend0_hi = vmlal_n_u16(blend0_hi, vget_high_u16(d0), bck_offset);

  uint16x8_t avg0 = vcombine_u16(vshrn_n_u32(blend0_lo, DIST_PRECISION_BITS),
                                 vshrn_n_u32(blend0_hi, DIST_PRECISION_BITS));

  int16x8_t dst0 = vsubq_s16(vreinterpretq_s16_u16(avg0), round_offset);

  *d0_u8 = vqrshrun_n_s16(dst0, FILTER_BITS - ROUND0_BITS);
}

static inline void compute_basic_avg_8x1(uint16x8_t dd0, uint16x8_t d0,
                                         const int16x8_t round_offset,
                                         uint8x8_t *d0_u8) {
  uint16x8_t avg0 = vhaddq_u16(dd0, d0);

  int16x8_t dst0 = vsubq_s16(vreinterpretq_s16_u16(avg0), round_offset);

  *d0_u8 = vqrshrun_n_s16(dst0, FILTER_BITS - ROUND0_BITS);
}

static inline void compute_dist_wtd_avg_4x4(
    uint16x4_t dd0, uint16x4_t dd1, uint16x4_t dd2, uint16x4_t dd3,
    uint16x4_t d0, uint16x4_t d1, uint16x4_t d2, uint16x4_t d3,
    const uint16_t fwd_offset, const uint16_t bck_offset,
    const int16x8_t round_offset, uint8x8_t *d01_u8, uint8x8_t *d23_u8) {
  uint32x4_t blend0 = vmull_n_u16(dd0, fwd_offset);
  blend0 = vmlal_n_u16(blend0, d0, bck_offset);
  uint32x4_t blend1 = vmull_n_u16(dd1, fwd_offset);
  blend1 = vmlal_n_u16(blend1, d1, bck_offset);
  uint32x4_t blend2 = vmull_n_u16(dd2, fwd_offset);
  blend2 = vmlal_n_u16(blend2, d2, bck_offset);
  uint32x4_t blend3 = vmull_n_u16(dd3, fwd_offset);
  blend3 = vmlal_n_u16(blend3, d3, bck_offset);

  uint16x4_t avg0 = vshrn_n_u32(blend0, DIST_PRECISION_BITS);
  uint16x4_t avg1 = vshrn_n_u32(blend1, DIST_PRECISION_BITS);
  uint16x4_t avg2 = vshrn_n_u32(blend2, DIST_PRECISION_BITS);
  uint16x4_t avg3 = vshrn_n_u32(blend3, DIST_PRECISION_BITS);

  int16x8_t dst_01 = vreinterpretq_s16_u16(vcombine_u16(avg0, avg1));
  int16x8_t dst_23 = vreinterpretq_s16_u16(vcombine_u16(avg2, avg3));

  dst_01 = vsubq_s16(dst_01, round_offset);
  dst_23 = vsubq_s16(dst_23, round_offset);

  *d01_u8 = vqrshrun_n_s16(dst_01, FILTER_BITS - ROUND0_BITS);
  *d23_u8 = vqrshrun_n_s16(dst_23, FILTER_BITS - ROUND0_BITS);
}

static inline void compute_basic_avg_4x4(uint16x4_t dd0, uint16x4_t dd1,
                                         uint16x4_t dd2, uint16x4_t dd3,
                                         uint16x4_t d0, uint16x4_t d1,
                                         uint16x4_t d2, uint16x4_t d3,
                                         const int16x8_t round_offset,
                                         uint8x8_t *d01_u8, uint8x8_t *d23_u8) {
  uint16x4_t avg0 = vhadd_u16(dd0, d0);
  uint16x4_t avg1 = vhadd_u16(dd1, d1);
  uint16x4_t avg2 = vhadd_u16(dd2, d2);
  uint16x4_t avg3 = vhadd_u16(dd3, d3);

  int16x8_t dst_01 = vreinterpretq_s16_u16(vcombine_u16(avg0, avg1));
  int16x8_t dst_23 = vreinterpretq_s16_u16(vcombine_u16(avg2, avg3));

  dst_01 = vsubq_s16(dst_01, round_offset);
  dst_23 = vsubq_s16(dst_23, round_offset);

  *d01_u8 = vqrshrun_n_s16(dst_01, FILTER_BITS - ROUND0_BITS);
  *d23_u8 = vqrshrun_n_s16(dst_23, FILTER_BITS - ROUND0_BITS);
}

static inline void compute_dist_wtd_avg_8x4(
    uint16x8_t dd0, uint16x8_t dd1, uint16x8_t dd2, uint16x8_t dd3,
    uint16x8_t d0, uint16x8_t d1, uint16x8_t d2, uint16x8_t d3,
    const uint16_t fwd_offset, const uint16_t bck_offset,
    const int16x8_t round_offset, uint8x8_t *d0_u8, uint8x8_t *d1_u8,
    uint8x8_t *d2_u8, uint8x8_t *d3_u8) {
  uint32x4_t blend0_lo = vmull_n_u16(vget_low_u16(dd0), fwd_offset);
  blend0_lo = vmlal_n_u16(blend0_lo, vget_low_u16(d0), bck_offset);
  uint32x4_t blend0_hi = vmull_n_u16(vget_high_u16(dd0), fwd_offset);
  blend0_hi = vmlal_n_u16(blend0_hi, vget_high_u16(d0), bck_offset);

  uint32x4_t blend1_lo = vmull_n_u16(vget_low_u16(dd1), fwd_offset);
  blend1_lo = vmlal_n_u16(blend1_lo, vget_low_u16(d1), bck_offset);
  uint32x4_t blend1_hi = vmull_n_u16(vget_high_u16(dd1), fwd_offset);
  blend1_hi = vmlal_n_u16(blend1_hi, vget_high_u16(d1), bck_offset);

  uint32x4_t blend2_lo = vmull_n_u16(vget_low_u16(dd2), fwd_offset);
  blend2_lo = vmlal_n_u16(blend2_lo, vget_low_u16(d2), bck_offset);
  uint32x4_t blend2_hi = vmull_n_u16(vget_high_u16(dd2), fwd_offset);
  blend2_hi = vmlal_n_u16(blend2_hi, vget_high_u16(d2), bck_offset);

  uint32x4_t blend3_lo = vmull_n_u16(vget_low_u16(dd3), fwd_offset);
  blend3_lo = vmlal_n_u16(blend3_lo, vget_low_u16(d3), bck_offset);
  uint32x4_t blend3_hi = vmull_n_u16(vget_high_u16(dd3), fwd_offset);
  blend3_hi = vmlal_n_u16(blend3_hi, vget_high_u16(d3), bck_offset);

  uint16x8_t avg0 = vcombine_u16(vshrn_n_u32(blend0_lo, DIST_PRECISION_BITS),
                                 vshrn_n_u32(blend0_hi, DIST_PRECISION_BITS));
  uint16x8_t avg1 = vcombine_u16(vshrn_n_u32(blend1_lo, DIST_PRECISION_BITS),
                                 vshrn_n_u32(blend1_hi, DIST_PRECISION_BITS));
  uint16x8_t avg2 = vcombine_u16(vshrn_n_u32(blend2_lo, DIST_PRECISION_BITS),
                                 vshrn_n_u32(blend2_hi, DIST_PRECISION_BITS));
  uint16x8_t avg3 = vcombine_u16(vshrn_n_u32(blend3_lo, DIST_PRECISION_BITS),
                                 vshrn_n_u32(blend3_hi, DIST_PRECISION_BITS));

  int16x8_t dst0 = vsubq_s16(vreinterpretq_s16_u16(avg0), round_offset);
  int16x8_t dst1 = vsubq_s16(vreinterpretq_s16_u16(avg1), round_offset);
  int16x8_t dst2 = vsubq_s16(vreinterpretq_s16_u16(avg2), round_offset);
  int16x8_t dst3 = vsubq_s16(vreinterpretq_s16_u16(avg3), round_offset);

  *d0_u8 = vqrshrun_n_s16(dst0, FILTER_BITS - ROUND0_BITS);
  *d1_u8 = vqrshrun_n_s16(dst1, FILTER_BITS - ROUND0_BITS);
  *d2_u8 = vqrshrun_n_s16(dst2, FILTER_BITS - ROUND0_BITS);
  *d3_u8 = vqrshrun_n_s16(dst3, FILTER_BITS - ROUND0_BITS);
}

static inline void compute_basic_avg_8x4(uint16x8_t dd0, uint16x8_t dd1,
                                         uint16x8_t dd2, uint16x8_t dd3,
                                         uint16x8_t d0, uint16x8_t d1,
                                         uint16x8_t d2, uint16x8_t d3,
                                         const int16x8_t round_offset,
                                         uint8x8_t *d0_u8, uint8x8_t *d1_u8,
                                         uint8x8_t *d2_u8, uint8x8_t *d3_u8) {
  uint16x8_t avg0 = vhaddq_u16(dd0, d0);
  uint16x8_t avg1 = vhaddq_u16(dd1, d1);
  uint16x8_t avg2 = vhaddq_u16(dd2, d2);
  uint16x8_t avg3 = vhaddq_u16(dd3, d3);

  int16x8_t dst0 = vsubq_s16(vreinterpretq_s16_u16(avg0), round_offset);
  int16x8_t dst1 = vsubq_s16(vreinterpretq_s16_u16(avg1), round_offset);
  int16x8_t dst2 = vsubq_s16(vreinterpretq_s16_u16(avg2), round_offset);
  int16x8_t dst3 = vsubq_s16(vreinterpretq_s16_u16(avg3), round_offset);

  *d0_u8 = vqrshrun_n_s16(dst0, FILTER_BITS - ROUND0_BITS);
  *d1_u8 = vqrshrun_n_s16(dst1, FILTER_BITS - ROUND0_BITS);
  *d2_u8 = vqrshrun_n_s16(dst2, FILTER_BITS - ROUND0_BITS);
  *d3_u8 = vqrshrun_n_s16(dst3, FILTER_BITS - ROUND0_BITS);
}

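// 6-tap vertical pass helpers for the 2D compound convolution. The filter is
// passed in as an 8-tap y_filter whose taps 0 and 7 are zero, so only the
// middle six taps are accumulated; the 32-bit sum already contains the
// compound offset and is narrowed with rounding by COMPOUND_ROUND1_BITS.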
static inline uint16x4_t convolve6_4_2d_v(
    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
    const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
    const int16x8_t y_filter, const int32x4_t offset_const) {
  const int16x4_t y_filter_0_3 = vget_low_s16(y_filter);
  const int16x4_t y_filter_4_7 = vget_high_s16(y_filter);

  int32x4_t sum = offset_const;
  // Filter values at indices 0 and 7 are 0.
  sum = vmlal_lane_s16(sum, s0, y_filter_0_3, 1);
  sum = vmlal_lane_s16(sum, s1, y_filter_0_3, 2);
  sum = vmlal_lane_s16(sum, s2, y_filter_0_3, 3);
  sum = vmlal_lane_s16(sum, s3, y_filter_4_7, 0);
  sum = vmlal_lane_s16(sum, s4, y_filter_4_7, 1);
  sum = vmlal_lane_s16(sum, s5, y_filter_4_7, 2);

  return vqrshrun_n_s32(sum, COMPOUND_ROUND1_BITS);
}

static inline uint16x8_t convolve6_8_2d_v(
    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
    const int16x8_t y_filter, const int32x4_t offset_const) {
  const int16x4_t y_filter_0_3 = vget_low_s16(y_filter);
  const int16x4_t y_filter_4_7 = vget_high_s16(y_filter);

  int32x4_t sum0 = offset_const;
  // Filter values at indices 0 and 7 are 0.
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s0), y_filter_0_3, 1);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s1), y_filter_0_3, 2);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s2), y_filter_0_3, 3);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s3), y_filter_4_7, 0);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s4), y_filter_4_7, 1);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s5), y_filter_4_7, 2);

  int32x4_t sum1 = offset_const;
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s0), y_filter_0_3, 1);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s1), y_filter_0_3, 2);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s2), y_filter_0_3, 3);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s3), y_filter_4_7, 0);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s4), y_filter_4_7, 1);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s5), y_filter_4_7, 2);

  return vcombine_u16(vqrshrun_n_s32(sum0, COMPOUND_ROUND1_BITS),
                      vqrshrun_n_s32(sum1, COMPOUND_ROUND1_BITS));
}

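// Vertical (second) pass of the 6-tap compound 2D convolution. The three
// variants below differ only in how the result is combined with the compound
// buffer: distance-weighted average, basic average, or no averaging (the
// intermediate values are simply stored to the compound buffer).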
static inline void dist_wtd_convolve_2d_vert_6tap_dist_wtd_avg_neon(
    int16_t *src_ptr, const int src_stride, uint8_t *dst8_ptr, int dst8_stride,
    ConvolveParams *conv_params, const int16x8_t y_filter, int h, int w) {
  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int32x4_t offset_const = vdupq_n_s32(1 << offset_bits);
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);

  const uint16_t fwd_offset = conv_params->fwd_offset;
  const uint16_t bck_offset = conv_params->bck_offset;

  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  const int dst_stride = conv_params->dst_stride;

  if (w == 4) {
    int16x4_t s0, s1, s2, s3, s4;
    load_s16_4x5(src_ptr, src_stride, &s0, &s1, &s2, &s3, &s4);
    src_ptr += 5 * src_stride;

    do {
#if AOM_ARCH_AARCH64
      int16x4_t s5, s6, s7, s8;
      load_s16_4x4(src_ptr, src_stride, &s5, &s6, &s7, &s8);

      uint16x4_t d0 =
          convolve6_4_2d_v(s0, s1, s2, s3, s4, s5, y_filter, offset_const);
      uint16x4_t d1 =
          convolve6_4_2d_v(s1, s2, s3, s4, s5, s6, y_filter, offset_const);
      uint16x4_t d2 =
          convolve6_4_2d_v(s2, s3, s4, s5, s6, s7, y_filter, offset_const);
      uint16x4_t d3 =
          convolve6_4_2d_v(s3, s4, s5, s6, s7, s8, y_filter, offset_const);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_dist_wtd_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                               bck_offset, round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8_ptr + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8_ptr + 2 * dst8_stride, dst8_stride, d23_u8);
      dst8_ptr += 4 * dst8_stride;

      s0 = s4;
      s1 = s5;
      s2 = s6;
      s3 = s7;
      s4 = s8;
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
#else   // !AOM_ARCH_AARCH64
      int16x4_t s5 = vld1_s16(src_ptr);

      uint16x4_t d0 =
          convolve6_4_2d_v(s0, s1, s2, s3, s4, s5, y_filter, offset_const);

      uint16x4_t dd0 = vld1_u16(dst_ptr);

      uint8x8_t d01_u8;
      compute_dist_wtd_avg_4x1(dd0, d0, fwd_offset, bck_offset,
                               vget_low_s16(round_offset_vec), &d01_u8);

      store_u8_4x1(dst8_ptr, d01_u8);
      dst8_ptr += dst8_stride;

      s0 = s1;
      s1 = s2;
      s2 = s3;
      s3 = s4;
      s4 = s5;
      src_ptr += src_stride;
      dst_ptr += dst_stride;
      h--;
#endif  // AOM_ARCH_AARCH64
    } while (h != 0);
  } else {
    do {
      int16_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      uint8_t *d_u8 = dst8_ptr;
      int height = h;

      int16x8_t s0, s1, s2, s3, s4;
      load_s16_8x5(s, src_stride, &s0, &s1, &s2, &s3, &s4);
      s += 5 * src_stride;

      do {
#if AOM_ARCH_AARCH64
        int16x8_t s5, s6, s7, s8;
        load_s16_8x4(s, src_stride, &s5, &s6, &s7, &s8);

        uint16x8_t d0 =
            convolve6_8_2d_v(s0, s1, s2, s3, s4, s5, y_filter, offset_const);
        uint16x8_t d1 =
            convolve6_8_2d_v(s1, s2, s3, s4, s5, s6, y_filter, offset_const);
        uint16x8_t d2 =
            convolve6_8_2d_v(s2, s3, s4, s5, s6, s7, y_filter, offset_const);
        uint16x8_t d3 =
            convolve6_8_2d_v(s3, s4, s5, s6, s7, s8, y_filter, offset_const);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_dist_wtd_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                                 bck_offset, round_offset_vec, &d0_u8, &d1_u8,
                                 &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
        d_u8 += 4 * dst8_stride;

        s0 = s4;
        s1 = s5;
        s2 = s6;
        s3 = s7;
        s4 = s8;
        s += 4 * src_stride;
        d += 4 * dst_stride;
        height -= 4;
#else   // !AOM_ARCH_AARCH64
        int16x8_t s5 = vld1q_s16(s);

        uint16x8_t d0 =
            convolve6_8_2d_v(s0, s1, s2, s3, s4, s5, y_filter, offset_const);

        uint16x8_t dd0 = vld1q_u16(d);

        uint8x8_t d0_u8;
        compute_dist_wtd_avg_8x1(dd0, d0, fwd_offset, bck_offset,
                                 round_offset_vec, &d0_u8);

        vst1_u8(d_u8, d0_u8);
        d_u8 += dst8_stride;

        s0 = s1;
        s1 = s2;
        s2 = s3;
        s3 = s4;
        s4 = s5;
        s += src_stride;
        d += dst_stride;
        height--;
#endif  // AOM_ARCH_AARCH64
      } while (height != 0);
      src_ptr += 8;
      dst_ptr += 8;
      dst8_ptr += 8;
      w -= 8;
    } while (w != 0);
  }
}

static inline void dist_wtd_convolve_2d_vert_6tap_avg_neon(
    int16_t *src_ptr, const int src_stride, uint8_t *dst8_ptr, int dst8_stride,
    ConvolveParams *conv_params, const int16x8_t y_filter, int h, int w) {
  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int32x4_t offset_const = vdupq_n_s32(1 << offset_bits);
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);

  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  const int dst_stride = conv_params->dst_stride;

  if (w == 4) {
    int16x4_t s0, s1, s2, s3, s4;
    load_s16_4x5(src_ptr, src_stride, &s0, &s1, &s2, &s3, &s4);
    src_ptr += 5 * src_stride;

    do {
#if AOM_ARCH_AARCH64
      int16x4_t s5, s6, s7, s8;
      load_s16_4x4(src_ptr, src_stride, &s5, &s6, &s7, &s8);

      uint16x4_t d0 =
          convolve6_4_2d_v(s0, s1, s2, s3, s4, s5, y_filter, offset_const);
      uint16x4_t d1 =
          convolve6_4_2d_v(s1, s2, s3, s4, s5, s6, y_filter, offset_const);
      uint16x4_t d2 =
          convolve6_4_2d_v(s2, s3, s4, s5, s6, s7, y_filter, offset_const);
      uint16x4_t d3 =
          convolve6_4_2d_v(s3, s4, s5, s6, s7, s8, y_filter, offset_const);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_basic_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                            round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8_ptr + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8_ptr + 2 * dst8_stride, dst8_stride, d23_u8);
      dst8_ptr += 4 * dst8_stride;

      s0 = s4;
      s1 = s5;
      s2 = s6;
      s3 = s7;
      s4 = s8;
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
#else   // !AOM_ARCH_AARCH64
      int16x4_t s5 = vld1_s16(src_ptr);

      uint16x4_t d0 =
          convolve6_4_2d_v(s0, s1, s2, s3, s4, s5, y_filter, offset_const);

      uint16x4_t dd0 = vld1_u16(dst_ptr);

      uint8x8_t d01_u8;
      compute_basic_avg_4x1(dd0, d0, vget_low_s16(round_offset_vec), &d01_u8);

      store_u8_4x1(dst8_ptr, d01_u8);
      dst8_ptr += dst8_stride;

      s0 = s1;
      s1 = s2;
      s2 = s3;
      s3 = s4;
      s4 = s5;
      src_ptr += src_stride;
      dst_ptr += dst_stride;
      h--;
#endif  // AOM_ARCH_AARCH64
    } while (h != 0);
  } else {
    do {
      int16_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      uint8_t *d_u8 = dst8_ptr;
      int height = h;

      int16x8_t s0, s1, s2, s3, s4;
      load_s16_8x5(s, src_stride, &s0, &s1, &s2, &s3, &s4);
      s += 5 * src_stride;

      do {
#if AOM_ARCH_AARCH64
        int16x8_t s5, s6, s7, s8;
        load_s16_8x4(s, src_stride, &s5, &s6, &s7, &s8);

        uint16x8_t d0 =
            convolve6_8_2d_v(s0, s1, s2, s3, s4, s5, y_filter, offset_const);
        uint16x8_t d1 =
            convolve6_8_2d_v(s1, s2, s3, s4, s5, s6, y_filter, offset_const);
        uint16x8_t d2 =
            convolve6_8_2d_v(s2, s3, s4, s5, s6, s7, y_filter, offset_const);
        uint16x8_t d3 =
            convolve6_8_2d_v(s3, s4, s5, s6, s7, s8, y_filter, offset_const);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_basic_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                              round_offset_vec, &d0_u8, &d1_u8, &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
        d_u8 += 4 * dst8_stride;

        s0 = s4;
        s1 = s5;
        s2 = s6;
        s3 = s7;
        s4 = s8;
        s += 4 * src_stride;
        d += 4 * dst_stride;
        height -= 4;
#else   // !AOM_ARCH_AARCH64
        int16x8_t s5 = vld1q_s16(s);

        uint16x8_t d0 =
            convolve6_8_2d_v(s0, s1, s2, s3, s4, s5, y_filter, offset_const);

        uint16x8_t dd0 = vld1q_u16(d);

        uint8x8_t d0_u8;
        compute_basic_avg_8x1(dd0, d0, round_offset_vec, &d0_u8);

        vst1_u8(d_u8, d0_u8);
        d_u8 += dst8_stride;

        s0 = s1;
        s1 = s2;
        s2 = s3;
        s3 = s4;
        s4 = s5;
        s += src_stride;
        d += dst_stride;
        height--;
#endif  // AOM_ARCH_AARCH64
      } while (height != 0);
      src_ptr += 8;
      dst_ptr += 8;
      dst8_ptr += 8;
      w -= 8;
    } while (w != 0);
  }
}

static inline void dist_wtd_convolve_2d_vert_6tap_neon(
    int16_t *src_ptr, const int src_stride, ConvolveParams *conv_params,
    const int16x8_t y_filter, int h, int w) {
  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int32x4_t offset_const = vdupq_n_s32(1 << offset_bits);

  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  const int dst_stride = conv_params->dst_stride;

  if (w == 4) {
    int16x4_t s0, s1, s2, s3, s4;
    load_s16_4x5(src_ptr, src_stride, &s0, &s1, &s2, &s3, &s4);
    src_ptr += 5 * src_stride;

    do {
#if AOM_ARCH_AARCH64
      int16x4_t s5, s6, s7, s8;
      load_s16_4x4(src_ptr, src_stride, &s5, &s6, &s7, &s8);

      uint16x4_t d0 =
          convolve6_4_2d_v(s0, s1, s2, s3, s4, s5, y_filter, offset_const);
      uint16x4_t d1 =
          convolve6_4_2d_v(s1, s2, s3, s4, s5, s6, y_filter, offset_const);
      uint16x4_t d2 =
          convolve6_4_2d_v(s2, s3, s4, s5, s6, s7, y_filter, offset_const);
      uint16x4_t d3 =
          convolve6_4_2d_v(s3, s4, s5, s6, s7, s8, y_filter, offset_const);

      store_u16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      s0 = s4;
      s1 = s5;
      s2 = s6;
      s3 = s7;
      s4 = s8;
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
#else   // !AOM_ARCH_AARCH64
      int16x4_t s5 = vld1_s16(src_ptr);

      uint16x4_t d0 =
          convolve6_4_2d_v(s0, s1, s2, s3, s4, s5, y_filter, offset_const);

      vst1_u16(dst_ptr, d0);

      s0 = s1;
      s1 = s2;
      s2 = s3;
      s3 = s4;
      s4 = s5;
      src_ptr += src_stride;
      dst_ptr += dst_stride;
      h--;
#endif  // AOM_ARCH_AARCH64
    } while (h != 0);
  } else {
    do {
      int16_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      int height = h;

      int16x8_t s0, s1, s2, s3, s4;
      load_s16_8x5(s, src_stride, &s0, &s1, &s2, &s3, &s4);
      s += 5 * src_stride;

      do {
#if AOM_ARCH_AARCH64
        int16x8_t s5, s6, s7, s8;
        load_s16_8x4(s, src_stride, &s5, &s6, &s7, &s8);

        uint16x8_t d0 =
            convolve6_8_2d_v(s0, s1, s2, s3, s4, s5, y_filter, offset_const);
        uint16x8_t d1 =
            convolve6_8_2d_v(s1, s2, s3, s4, s5, s6, y_filter, offset_const);
        uint16x8_t d2 =
            convolve6_8_2d_v(s2, s3, s4, s5, s6, s7, y_filter, offset_const);
        uint16x8_t d3 =
            convolve6_8_2d_v(s3, s4, s5, s6, s7, s8, y_filter, offset_const);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s0 = s4;
        s1 = s5;
        s2 = s6;
        s3 = s7;
        s4 = s8;
        s += 4 * src_stride;
        d += 4 * dst_stride;
        height -= 4;
#else   // !AOM_ARCH_AARCH64
        int16x8_t s5 = vld1q_s16(s);

        uint16x8_t d0 =
            convolve6_8_2d_v(s0, s1, s2, s3, s4, s5, y_filter, offset_const);

        vst1q_u16(d, d0);

        s0 = s1;
        s1 = s2;
        s2 = s3;
        s3 = s4;
        s4 = s5;
        s += src_stride;
        d += dst_stride;
        height--;
#endif  // AOM_ARCH_AARCH64
      } while (height != 0);
      src_ptr += 8;
      dst_ptr += 8;
      w -= 8;
    } while (w != 0);
  }
}

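// 8-tap equivalents of the convolve6_*_2d_v helpers above; all eight filter
// taps are accumulated.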
static inline uint16x4_t convolve8_4_2d_v(
    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
    const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
    const int16x4_t s6, const int16x4_t s7, const int16x8_t y_filter,
    const int32x4_t offset_const) {
  const int16x4_t y_filter_0_3 = vget_low_s16(y_filter);
  const int16x4_t y_filter_4_7 = vget_high_s16(y_filter);

  int32x4_t sum = offset_const;
  sum = vmlal_lane_s16(sum, s0, y_filter_0_3, 0);
  sum = vmlal_lane_s16(sum, s1, y_filter_0_3, 1);
  sum = vmlal_lane_s16(sum, s2, y_filter_0_3, 2);
  sum = vmlal_lane_s16(sum, s3, y_filter_0_3, 3);
  sum = vmlal_lane_s16(sum, s4, y_filter_4_7, 0);
  sum = vmlal_lane_s16(sum, s5, y_filter_4_7, 1);
  sum = vmlal_lane_s16(sum, s6, y_filter_4_7, 2);
  sum = vmlal_lane_s16(sum, s7, y_filter_4_7, 3);

  return vqrshrun_n_s32(sum, COMPOUND_ROUND1_BITS);
}

static inline uint16x8_t convolve8_8_2d_v(
    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
    const int16x8_t s6, const int16x8_t s7, const int16x8_t y_filter,
    const int32x4_t offset_const) {
  const int16x4_t y_filter_0_3 = vget_low_s16(y_filter);
  const int16x4_t y_filter_4_7 = vget_high_s16(y_filter);

  int32x4_t sum0 = offset_const;
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s0), y_filter_0_3, 0);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s1), y_filter_0_3, 1);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s2), y_filter_0_3, 2);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s3), y_filter_0_3, 3);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s4), y_filter_4_7, 0);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s5), y_filter_4_7, 1);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s6), y_filter_4_7, 2);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s7), y_filter_4_7, 3);

  int32x4_t sum1 = offset_const;
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s0), y_filter_0_3, 0);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s1), y_filter_0_3, 1);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s2), y_filter_0_3, 2);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s3), y_filter_0_3, 3);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s4), y_filter_4_7, 0);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s5), y_filter_4_7, 1);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s6), y_filter_4_7, 2);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s7), y_filter_4_7, 3);

  return vcombine_u16(vqrshrun_n_s32(sum0, COMPOUND_ROUND1_BITS),
                      vqrshrun_n_s32(sum1, COMPOUND_ROUND1_BITS));
}

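// Vertical (second) pass of the 8-tap compound 2D convolution, with the same
// three averaging variants as the 6-tap versions above.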
static inline void dist_wtd_convolve_2d_vert_8tap_dist_wtd_avg_neon(
    int16_t *src_ptr, const int src_stride, uint8_t *dst8_ptr, int dst8_stride,
    ConvolveParams *conv_params, const int16x8_t y_filter, int h, int w) {
  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int32x4_t offset_const = vdupq_n_s32(1 << offset_bits);
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);

  const uint16_t fwd_offset = conv_params->fwd_offset;
  const uint16_t bck_offset = conv_params->bck_offset;

  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  const int dst_stride = conv_params->dst_stride;

  if (w == 4) {
    int16x4_t s0, s1, s2, s3, s4, s5, s6;
    load_s16_4x7(src_ptr, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
    src_ptr += 7 * src_stride;

    do {
#if AOM_ARCH_AARCH64
      int16x4_t s7, s8, s9, s10;
      load_s16_4x4(src_ptr, src_stride, &s7, &s8, &s9, &s10);

      uint16x4_t d0 = convolve8_4_2d_v(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
                                       offset_const);
      uint16x4_t d1 = convolve8_4_2d_v(s1, s2, s3, s4, s5, s6, s7, s8, y_filter,
                                       offset_const);
      uint16x4_t d2 = convolve8_4_2d_v(s2, s3, s4, s5, s6, s7, s8, s9, y_filter,
                                       offset_const);
      uint16x4_t d3 = convolve8_4_2d_v(s3, s4, s5, s6, s7, s8, s9, s10,
                                       y_filter, offset_const);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_dist_wtd_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                               bck_offset, round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8_ptr + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8_ptr + 2 * dst8_stride, dst8_stride, d23_u8);
      dst8_ptr += 4 * dst8_stride;

      s0 = s4;
      s1 = s5;
      s2 = s6;
      s3 = s7;
      s4 = s8;
      s5 = s9;
      s6 = s10;
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
#else   // !AOM_ARCH_AARCH64
      int16x4_t s7 = vld1_s16(src_ptr);

      uint16x4_t d0 = convolve8_4_2d_v(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
                                       offset_const);

      uint16x4_t dd0 = vld1_u16(dst_ptr);

      uint8x8_t d01_u8;
      compute_dist_wtd_avg_4x1(dd0, d0, fwd_offset, bck_offset,
                               vget_low_s16(round_offset_vec), &d01_u8);

      store_u8_4x1(dst8_ptr, d01_u8);
      dst8_ptr += dst8_stride;

      s0 = s1;
      s1 = s2;
      s2 = s3;
      s3 = s4;
      s4 = s5;
      s5 = s6;
      s6 = s7;
      src_ptr += src_stride;
      dst_ptr += dst_stride;
      h--;
#endif  // AOM_ARCH_AARCH64
    } while (h != 0);
  } else {
    do {
      int16_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      uint8_t *d_u8 = dst8_ptr;
      int height = h;

      int16x8_t s0, s1, s2, s3, s4, s5, s6;
      load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
      s += 7 * src_stride;

      do {
#if AOM_ARCH_AARCH64
        int16x8_t s7, s8, s9, s10;
        load_s16_8x4(s, src_stride, &s7, &s8, &s9, &s10);

        uint16x8_t d0 = convolve8_8_2d_v(s0, s1, s2, s3, s4, s5, s6, s7,
                                         y_filter, offset_const);
        uint16x8_t d1 = convolve8_8_2d_v(s1, s2, s3, s4, s5, s6, s7, s8,
                                         y_filter, offset_const);
        uint16x8_t d2 = convolve8_8_2d_v(s2, s3, s4, s5, s6, s7, s8, s9,
                                         y_filter, offset_const);
        uint16x8_t d3 = convolve8_8_2d_v(s3, s4, s5, s6, s7, s8, s9, s10,
                                         y_filter, offset_const);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_dist_wtd_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                                 bck_offset, round_offset_vec, &d0_u8, &d1_u8,
                                 &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
        d_u8 += 4 * dst8_stride;

        s0 = s4;
        s1 = s5;
        s2 = s6;
        s3 = s7;
        s4 = s8;
        s5 = s9;
        s6 = s10;
        s += 4 * src_stride;
        d += 4 * dst_stride;
        height -= 4;
#else   // !AOM_ARCH_AARCH64
        int16x8_t s7 = vld1q_s16(s);

        uint16x8_t d0 = convolve8_8_2d_v(s0, s1, s2, s3, s4, s5, s6, s7,
                                         y_filter, offset_const);

        uint16x8_t dd0 = vld1q_u16(d);

        uint8x8_t d0_u8;
        compute_dist_wtd_avg_8x1(dd0, d0, fwd_offset, bck_offset,
                                 round_offset_vec, &d0_u8);

        vst1_u8(d_u8, d0_u8);
        d_u8 += dst8_stride;

        s0 = s1;
        s1 = s2;
        s2 = s3;
        s3 = s4;
        s4 = s5;
        s5 = s6;
        s6 = s7;
        s += src_stride;
        d += dst_stride;
        height--;
#endif  // AOM_ARCH_AARCH64
      } while (height != 0);
      src_ptr += 8;
      dst_ptr += 8;
      dst8_ptr += 8;
      w -= 8;
    } while (w != 0);
  }
}

static inline void dist_wtd_convolve_2d_vert_8tap_avg_neon(
    int16_t *src_ptr, const int src_stride, uint8_t *dst8_ptr, int dst8_stride,
    ConvolveParams *conv_params, const int16x8_t y_filter, int h, int w) {
  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int32x4_t offset_const = vdupq_n_s32(1 << offset_bits);
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);

  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  const int dst_stride = conv_params->dst_stride;

  if (w == 4) {
    int16x4_t s0, s1, s2, s3, s4, s5, s6;
    load_s16_4x7(src_ptr, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
    src_ptr += 7 * src_stride;

    do {
#if AOM_ARCH_AARCH64
      int16x4_t s7, s8, s9, s10;
      load_s16_4x4(src_ptr, src_stride, &s7, &s8, &s9, &s10);

      uint16x4_t d0 = convolve8_4_2d_v(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
                                       offset_const);
      uint16x4_t d1 = convolve8_4_2d_v(s1, s2, s3, s4, s5, s6, s7, s8, y_filter,
                                       offset_const);
      uint16x4_t d2 = convolve8_4_2d_v(s2, s3, s4, s5, s6, s7, s8, s9, y_filter,
                                       offset_const);
      uint16x4_t d3 = convolve8_4_2d_v(s3, s4, s5, s6, s7, s8, s9, s10,
                                       y_filter, offset_const);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_basic_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                            round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8_ptr + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8_ptr + 2 * dst8_stride, dst8_stride, d23_u8);
      dst8_ptr += 4 * dst8_stride;

      s0 = s4;
      s1 = s5;
      s2 = s6;
      s3 = s7;
      s4 = s8;
      s5 = s9;
      s6 = s10;
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
#else   // !AOM_ARCH_AARCH64
      int16x4_t s7 = vld1_s16(src_ptr);

      uint16x4_t d0 = convolve8_4_2d_v(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
                                       offset_const);

      uint16x4_t dd0 = vld1_u16(dst_ptr);

      uint8x8_t d01_u8;
      compute_basic_avg_4x1(dd0, d0, vget_low_s16(round_offset_vec), &d01_u8);

      store_u8_4x1(dst8_ptr, d01_u8);
      dst8_ptr += dst8_stride;

      s0 = s1;
      s1 = s2;
      s2 = s3;
      s3 = s4;
      s4 = s5;
      s5 = s6;
      s6 = s7;
      src_ptr += src_stride;
      dst_ptr += dst_stride;
      h--;
#endif  // AOM_ARCH_AARCH64
    } while (h != 0);
  } else {
    do {
      int16_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      uint8_t *d_u8 = dst8_ptr;
      int height = h;

      int16x8_t s0, s1, s2, s3, s4, s5, s6;
      load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
      s += 7 * src_stride;

      do {
#if AOM_ARCH_AARCH64
        int16x8_t s7, s8, s9, s10;
        load_s16_8x4(s, src_stride, &s7, &s8, &s9, &s10);

        uint16x8_t d0 = convolve8_8_2d_v(s0, s1, s2, s3, s4, s5, s6, s7,
                                         y_filter, offset_const);
        uint16x8_t d1 = convolve8_8_2d_v(s1, s2, s3, s4, s5, s6, s7, s8,
                                         y_filter, offset_const);
        uint16x8_t d2 = convolve8_8_2d_v(s2, s3, s4, s5, s6, s7, s8, s9,
                                         y_filter, offset_const);
        uint16x8_t d3 = convolve8_8_2d_v(s3, s4, s5, s6, s7, s8, s9, s10,
                                         y_filter, offset_const);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_basic_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                              round_offset_vec, &d0_u8, &d1_u8, &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
        d_u8 += 4 * dst8_stride;

        s0 = s4;
        s1 = s5;
        s2 = s6;
        s3 = s7;
        s4 = s8;
        s5 = s9;
        s6 = s10;
        s += 4 * src_stride;
        d += 4 * dst_stride;
        height -= 4;
#else   // !AOM_ARCH_AARCH64
        int16x8_t s7 = vld1q_s16(s);

        uint16x8_t d0 = convolve8_8_2d_v(s0, s1, s2, s3, s4, s5, s6, s7,
                                         y_filter, offset_const);

        uint16x8_t dd0 = vld1q_u16(d);

        uint8x8_t d0_u8;
        compute_basic_avg_8x1(dd0, d0, round_offset_vec, &d0_u8);

        vst1_u8(d_u8, d0_u8);
        d_u8 += dst8_stride;

        s0 = s1;
        s1 = s2;
        s2 = s3;
        s3 = s4;
        s4 = s5;
        s5 = s6;
        s6 = s7;
        s += src_stride;
        d += dst_stride;
        height--;
#endif  // AOM_ARCH_AARCH64
      } while (height != 0);
      src_ptr += 8;
      dst_ptr += 8;
      dst8_ptr += 8;
      w -= 8;
    } while (w != 0);
  }
}

static inline void dist_wtd_convolve_2d_vert_8tap_neon(
    int16_t *src_ptr, const int src_stride, ConvolveParams *conv_params,
    const int16x8_t y_filter, int h, int w) {
  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int32x4_t offset_const = vdupq_n_s32(1 << offset_bits);

  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  const int dst_stride = conv_params->dst_stride;

  if (w == 4) {
    int16x4_t s0, s1, s2, s3, s4, s5, s6;
    load_s16_4x7(src_ptr, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
    src_ptr += 7 * src_stride;

    do {
#if AOM_ARCH_AARCH64
      int16x4_t s7, s8, s9, s10;
      load_s16_4x4(src_ptr, src_stride, &s7, &s8, &s9, &s10);

      uint16x4_t d0 = convolve8_4_2d_v(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
                                       offset_const);
      uint16x4_t d1 = convolve8_4_2d_v(s1, s2, s3, s4, s5, s6, s7, s8, y_filter,
                                       offset_const);
      uint16x4_t d2 = convolve8_4_2d_v(s2, s3, s4, s5, s6, s7, s8, s9, y_filter,
                                       offset_const);
      uint16x4_t d3 = convolve8_4_2d_v(s3, s4, s5, s6, s7, s8, s9, s10,
                                       y_filter, offset_const);

      store_u16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      s0 = s4;
      s1 = s5;
      s2 = s6;
      s3 = s7;
      s4 = s8;
      s5 = s9;
      s6 = s10;
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
#else   // !AOM_ARCH_AARCH64
      int16x4_t s7 = vld1_s16(src_ptr);

      uint16x4_t d0 = convolve8_4_2d_v(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
                                       offset_const);

      vst1_u16(dst_ptr, d0);

      s0 = s1;
      s1 = s2;
      s2 = s3;
      s3 = s4;
      s4 = s5;
      s5 = s6;
      s6 = s7;
      src_ptr += src_stride;
      dst_ptr += dst_stride;
      h--;
#endif  // AOM_ARCH_AARCH64
    } while (h != 0);
  } else {
    do {
      int16_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      int height = h;

      int16x8_t s0, s1, s2, s3, s4, s5, s6;
      load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
      s += 7 * src_stride;

      do {
#if AOM_ARCH_AARCH64
        int16x8_t s7, s8, s9, s10;
        load_s16_8x4(s, src_stride, &s7, &s8, &s9, &s10);

        uint16x8_t d0 = convolve8_8_2d_v(s0, s1, s2, s3, s4, s5, s6, s7,
                                         y_filter, offset_const);
        uint16x8_t d1 = convolve8_8_2d_v(s1, s2, s3, s4, s5, s6, s7, s8,
                                         y_filter, offset_const);
        uint16x8_t d2 = convolve8_8_2d_v(s2, s3, s4, s5, s6, s7, s8, s9,
                                         y_filter, offset_const);
        uint16x8_t d3 = convolve8_8_2d_v(s3, s4, s5, s6, s7, s8, s9, s10,
                                         y_filter, offset_const);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s0 = s4;
        s1 = s5;
        s2 = s6;
        s3 = s7;
        s4 = s8;
        s5 = s9;
        s6 = s10;
        s += 4 * src_stride;
        d += 4 * dst_stride;
        height -= 4;
#else   // !AOM_ARCH_AARCH64
        int16x8_t s7 = vld1q_s16(s);

        uint16x8_t d0 = convolve8_8_2d_v(s0, s1, s2, s3, s4, s5, s6, s7,
                                         y_filter, offset_const);

        vst1q_u16(d, d0);

        s0 = s1;
        s1 = s2;
        s2 = s3;
        s3 = s4;
        s4 = s5;
        s5 = s6;
        s6 = s7;
        s += src_stride;
        d += dst_stride;
        height--;
#endif  // AOM_ARCH_AARCH64
      } while (height != 0);
      src_ptr += 8;
      dst_ptr += 8;
      w -= 8;
    } while (w != 0);
  }
}

#endif  // AOM_AV1_COMMON_ARM_COMPOUND_CONVOLVE_NEON_H_