/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>
#include <stdint.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/aom_filter.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "av1/common/arm/convolve_scale_neon.h"
#include "av1/common/convolve.h"
#include "av1/common/filter.h"
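// Horizontal pass of av1_convolve_2d_scale for Arm Neon. The output column
// position steps through the source in fixed-point units of
// 1 / (1 << SCALE_SUBPEL_BITS) of a pixel (x_qn), so each output column may
// use a different sub-pixel filter phase. Results are kept at intermediate
// precision for the vertical pass.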
static inline int16x4_t convolve8_4_h(const int16x4_t s0, const int16x4_t s1,
                                      const int16x4_t s2, const int16x4_t s3,
                                      const int16x4_t s4, const int16x4_t s5,
                                      const int16x4_t s6, const int16x4_t s7,
                                      const int16x8_t filter,
                                      const int32x4_t horiz_const) {
  int16x4_t filter_lo = vget_low_s16(filter);
  int16x4_t filter_hi = vget_high_s16(filter);

  int32x4_t sum = horiz_const;
  sum = vmlal_lane_s16(sum, s0, filter_lo, 0);
  sum = vmlal_lane_s16(sum, s1, filter_lo, 1);
  sum = vmlal_lane_s16(sum, s2, filter_lo, 2);
  sum = vmlal_lane_s16(sum, s3, filter_lo, 3);
  sum = vmlal_lane_s16(sum, s4, filter_hi, 0);
  sum = vmlal_lane_s16(sum, s5, filter_hi, 1);
  sum = vmlal_lane_s16(sum, s6, filter_hi, 2);
  sum = vmlal_lane_s16(sum, s7, filter_hi, 3);

  return vshrn_n_s32(sum, ROUND0_BITS);
}

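// 8-wide variant of convolve8_4_h that accumulates in 16-bit lanes instead of
// widening to 32-bit. This is safe because callers halve the (all-even)
// filter taps first; the final shift is ROUND0_BITS - 1 to compensate.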
static inline int16x8_t convolve8_8_h(const int16x8_t s0, const int16x8_t s1,
                                      const int16x8_t s2, const int16x8_t s3,
                                      const int16x8_t s4, const int16x8_t s5,
                                      const int16x8_t s6, const int16x8_t s7,
                                      const int16x8_t filter,
                                      const int16x8_t horiz_const) {
  int16x4_t filter_lo = vget_low_s16(filter);
  int16x4_t filter_hi = vget_high_s16(filter);

  int16x8_t sum = horiz_const;
  sum = vmlaq_lane_s16(sum, s0, filter_lo, 0);
  sum = vmlaq_lane_s16(sum, s1, filter_lo, 1);
  sum = vmlaq_lane_s16(sum, s2, filter_lo, 2);
  sum = vmlaq_lane_s16(sum, s3, filter_lo, 3);
  sum = vmlaq_lane_s16(sum, s4, filter_hi, 0);
  sum = vmlaq_lane_s16(sum, s5, filter_hi, 1);
  sum = vmlaq_lane_s16(sum, s6, filter_hi, 2);
  sum = vmlaq_lane_s16(sum, s7, filter_hi, 3);

  // We halved the filter values so -1 from right shift.
  return vshrq_n_s16(sum, ROUND0_BITS - 1);
}

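// Since every output column can have a different filter phase, each kernel
// below loads a small tile, transposes it so that registers hold source
// columns, convolves one output column at a time, then transposes the
// results back before storing.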
static inline void convolve_horiz_scale_8tap_neon(const uint8_t *src,
                                                  int src_stride, int16_t *dst,
                                                  int dst_stride, int w, int h,
                                                  const int16_t *x_filter,
                                                  const int subpel_x_qn,
                                                  const int x_step_qn) {
  DECLARE_ALIGNED(16, int16_t, temp[8 * 8]);
  const int bd = 8;

  if (w == 4) {
    // The shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts.
    const int32x4_t horiz_offset =
        vdupq_n_s32((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));

    do {
      int x_qn = subpel_x_qn;

      // Process a 4x4 tile.
      for (int r = 0; r < 4; ++r) {
        const uint8_t *const s = &src[x_qn >> SCALE_SUBPEL_BITS];

        const ptrdiff_t filter_offset =
            SUBPEL_TAPS * ((x_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
        const int16x8_t filter = vld1q_s16(x_filter + filter_offset);

        uint8x8_t t0, t1, t2, t3;
        load_u8_8x4(s, src_stride, &t0, &t1, &t2, &t3);

        transpose_elems_inplace_u8_8x4(&t0, &t1, &t2, &t3);

        int16x4_t s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
        int16x4_t s1 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
        int16x4_t s2 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
        int16x4_t s3 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));
        int16x4_t s4 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
        int16x4_t s5 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
        int16x4_t s6 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
        int16x4_t s7 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));

        int16x4_t d0 =
            convolve8_4_h(s0, s1, s2, s3, s4, s5, s6, s7, filter, horiz_offset);

        vst1_s16(&temp[r * 4], d0);
        x_qn += x_step_qn;
      }

      // Transpose the 4x4 result tile and store.
      int16x4_t d0, d1, d2, d3;
      load_s16_4x4(temp, 4, &d0, &d1, &d2, &d3);

      transpose_elems_inplace_s16_4x4(&d0, &d1, &d2, &d3);

      store_s16_4x4(dst, dst_stride, d0, d1, d2, d3);

      dst += 4 * dst_stride;
      src += 4 * src_stride;
      h -= 4;
    } while (h > 0);
  } else {
    // The shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts.
    // The additional -1 is needed because we are halving the filter values.
    const int16x8_t horiz_offset =
        vdupq_n_s16((1 << (bd + FILTER_BITS - 2)) + (1 << (ROUND0_BITS - 2)));

    do {
      int x_qn = subpel_x_qn;
      int16_t *d = dst;
      int width = w;

      do {
        // Process an 8x8 tile.
        for (int r = 0; r < 8; ++r) {
          const uint8_t *const s = &src[(x_qn >> SCALE_SUBPEL_BITS)];

          const ptrdiff_t filter_offset =
              SUBPEL_TAPS * ((x_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
          int16x8_t filter = vld1q_s16(x_filter + filter_offset);
          // Filter values are all even so halve them to allow convolution
          // kernel computations to stay in 16-bit element types.
          filter = vshrq_n_s16(filter, 1);

          uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
          load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);

          transpose_elems_u8_8x8(t0, t1, t2, t3, t4, t5, t6, t7, &t0, &t1, &t2,
                                 &t3, &t4, &t5, &t6, &t7);

          int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
          int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
          int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
          int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t3));
          int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t4));
          int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t5));
          int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t6));
          int16x8_t s7 = vreinterpretq_s16_u16(vmovl_u8(t7));

          int16x8_t d0 = convolve8_8_h(s0, s1, s2, s3, s4, s5, s6, s7, filter,
                                       horiz_offset);

          vst1q_s16(&temp[r * 8], d0);

          x_qn += x_step_qn;
        }

        // Transpose the 8x8 result tile and store.
        int16x8_t d0, d1, d2, d3, d4, d5, d6, d7;
        load_s16_8x8(temp, 8, &d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7);

        transpose_elems_inplace_s16_8x8(&d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7);

        store_s16_8x8(d, dst_stride, d0, d1, d2, d3, d4, d5, d6, d7);

        d += 8;
        width -= 8;
      } while (width != 0);

      dst += 8 * dst_stride;
      src += 8 * src_stride;
      h -= 8;
    } while (h > 0);
  }
}

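// 6-tap filters are stored in SUBPEL_TAPS (8) wide arrays with zero taps at
// indices 0 and 7, so the 6-tap kernels use filter lanes 1-6 and skip the
// outermost source pixels.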
static inline int16x4_t convolve6_4_h(const int16x4_t s0, const int16x4_t s1,
                                      const int16x4_t s2, const int16x4_t s3,
                                      const int16x4_t s4, const int16x4_t s5,
                                      const int16x8_t filter,
                                      const int32x4_t horiz_const) {
  int16x4_t filter_lo = vget_low_s16(filter);
  int16x4_t filter_hi = vget_high_s16(filter);

  int32x4_t sum = horiz_const;
  // Filter values at indices 0 and 7 are 0.
  sum = vmlal_lane_s16(sum, s0, filter_lo, 1);
  sum = vmlal_lane_s16(sum, s1, filter_lo, 2);
  sum = vmlal_lane_s16(sum, s2, filter_lo, 3);
  sum = vmlal_lane_s16(sum, s3, filter_hi, 0);
  sum = vmlal_lane_s16(sum, s4, filter_hi, 1);
  sum = vmlal_lane_s16(sum, s5, filter_hi, 2);

  return vshrn_n_s32(sum, ROUND0_BITS);
}

static inline int16x8_t convolve6_8_h(const int16x8_t s0, const int16x8_t s1,
                                      const int16x8_t s2, const int16x8_t s3,
                                      const int16x8_t s4, const int16x8_t s5,
                                      const int16x8_t filter,
                                      const int16x8_t horiz_const) {
  int16x4_t filter_lo = vget_low_s16(filter);
  int16x4_t filter_hi = vget_high_s16(filter);

  int16x8_t sum = horiz_const;
  // Filter values at indices 0 and 7 are 0.
  sum = vmlaq_lane_s16(sum, s0, filter_lo, 1);
  sum = vmlaq_lane_s16(sum, s1, filter_lo, 2);
  sum = vmlaq_lane_s16(sum, s2, filter_lo, 3);
  sum = vmlaq_lane_s16(sum, s3, filter_hi, 0);
  sum = vmlaq_lane_s16(sum, s4, filter_hi, 1);
  sum = vmlaq_lane_s16(sum, s5, filter_hi, 2);

  // We halved the filter values so -1 from right shift.
  return vshrq_n_s16(sum, ROUND0_BITS - 1);
}

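// Identical in structure to convolve_horiz_scale_8tap_neon, but the source
// registers start one column in, matching the zero taps at indices 0 and 7
// of the padded 6-tap filters.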
static inline void convolve_horiz_scale_6tap_neon(const uint8_t *src,
                                                  int src_stride, int16_t *dst,
                                                  int dst_stride, int w, int h,
                                                  const int16_t *x_filter,
                                                  const int subpel_x_qn,
                                                  const int x_step_qn) {
  DECLARE_ALIGNED(16, int16_t, temp[8 * 8]);
  const int bd = 8;

  if (w == 4) {
    // The shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts.
    const int32x4_t horiz_offset =
        vdupq_n_s32((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));

    do {
      int x_qn = subpel_x_qn;

      // Process a 4x4 tile.
      for (int r = 0; r < 4; ++r) {
        const uint8_t *const s = &src[x_qn >> SCALE_SUBPEL_BITS];

        const ptrdiff_t filter_offset =
            SUBPEL_TAPS * ((x_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
        const int16x8_t filter = vld1q_s16(x_filter + filter_offset);

        uint8x8_t t0, t1, t2, t3;
        load_u8_8x4(s, src_stride, &t0, &t1, &t2, &t3);

        transpose_elems_inplace_u8_8x4(&t0, &t1, &t2, &t3);

        int16x4_t s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
        int16x4_t s1 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
        int16x4_t s2 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));
        int16x4_t s3 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
        int16x4_t s4 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
        int16x4_t s5 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));

        int16x4_t d0 =
            convolve6_4_h(s0, s1, s2, s3, s4, s5, filter, horiz_offset);

        vst1_s16(&temp[r * 4], d0);
        x_qn += x_step_qn;
      }

      // Transpose the 4x4 result tile and store.
      int16x4_t d0, d1, d2, d3;
      load_s16_4x4(temp, 4, &d0, &d1, &d2, &d3);

      transpose_elems_inplace_s16_4x4(&d0, &d1, &d2, &d3);

      store_s16_4x4(dst, dst_stride, d0, d1, d2, d3);

      dst += 4 * dst_stride;
      src += 4 * src_stride;
      h -= 4;
    } while (h > 0);
  } else {
    // The shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts.
    // The additional -1 is needed because we are halving the filter values.
    const int16x8_t horiz_offset =
        vdupq_n_s16((1 << (bd + FILTER_BITS - 2)) + (1 << (ROUND0_BITS - 2)));

    do {
      int x_qn = subpel_x_qn;
      int16_t *d = dst;
      int width = w;

      do {
        // Process an 8x8 tile.
        for (int r = 0; r < 8; ++r) {
          const uint8_t *const s = &src[(x_qn >> SCALE_SUBPEL_BITS)];

          const ptrdiff_t filter_offset =
              SUBPEL_TAPS * ((x_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
          int16x8_t filter = vld1q_s16(x_filter + filter_offset);
          // Filter values are all even so halve them to allow convolution
          // kernel computations to stay in 16-bit element types.
          filter = vshrq_n_s16(filter, 1);

          uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
          load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);

          transpose_elems_u8_8x8(t0, t1, t2, t3, t4, t5, t6, t7, &t0, &t1, &t2,
                                 &t3, &t4, &t5, &t6, &t7);

          int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t1));
          int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t2));
          int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t3));
          int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t4));
          int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t5));
          int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t6));

          int16x8_t d0 =
              convolve6_8_h(s0, s1, s2, s3, s4, s5, filter, horiz_offset);

          vst1q_s16(&temp[r * 8], d0);

          x_qn += x_step_qn;
        }

        // Transpose the 8x8 result tile and store.
        int16x8_t d0, d1, d2, d3, d4, d5, d6, d7;
        load_s16_8x8(temp, 8, &d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7);

        transpose_elems_inplace_s16_8x8(&d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7);

        store_s16_8x8(d, dst_stride, d0, d1, d2, d3, d4, d5, d6, d7);

        d += 8;
        width -= 8;
      } while (width != 0);

      dst += 8 * dst_stride;
      src += 8 * src_stride;
      h -= 8;
    } while (h > 0);
  }
}

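// Specialization for x_step_qn == 2 * (1 << SCALE_SUBPEL_BITS): the filter
// phase is constant across the row (the caller pre-selects x_filter) and
// each output pixel advances the source by exactly two pixels, so the
// per-column filter reload of the generic path is not needed.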
static inline void convolve_horiz_scale_2_8tap_neon(
    const uint8_t *src, int src_stride, int16_t *dst, int dst_stride, int w,
    int h, const int16_t *x_filter) {
  const int bd = 8;

  if (w == 4) {
    // A shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
    // shifts - which are generally faster than rounding shifts on modern CPUs.
    const int32x4_t horiz_offset =
        vdupq_n_s32((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
    const int16x8_t filter = vld1q_s16(x_filter);

    do {
      uint8x16_t t0, t1, t2, t3;
      load_u8_16x4(src, src_stride, &t0, &t1, &t2, &t3);
      transpose_elems_inplace_u8_16x4(&t0, &t1, &t2, &t3);

      int16x8_t tt0 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t0)));
      int16x8_t tt1 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t1)));
      int16x8_t tt2 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t2)));
      int16x8_t tt3 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t3)));
      int16x8_t tt4 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t0)));
      int16x8_t tt5 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t1)));
      int16x8_t tt6 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t2)));
      int16x8_t tt7 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t3)));

      int16x4_t s0 = vget_low_s16(tt0);
      int16x4_t s1 = vget_low_s16(tt1);
      int16x4_t s2 = vget_low_s16(tt2);
      int16x4_t s3 = vget_low_s16(tt3);
      int16x4_t s4 = vget_high_s16(tt0);
      int16x4_t s5 = vget_high_s16(tt1);
      int16x4_t s6 = vget_high_s16(tt2);
      int16x4_t s7 = vget_high_s16(tt3);
      int16x4_t s8 = vget_low_s16(tt4);
      int16x4_t s9 = vget_low_s16(tt5);
      int16x4_t s10 = vget_low_s16(tt6);
      int16x4_t s11 = vget_low_s16(tt7);
      int16x4_t s12 = vget_high_s16(tt4);
      int16x4_t s13 = vget_high_s16(tt5);

      int16x4_t d0 =
          convolve8_4_h(s0, s1, s2, s3, s4, s5, s6, s7, filter, horiz_offset);
      int16x4_t d1 =
          convolve8_4_h(s2, s3, s4, s5, s6, s7, s8, s9, filter, horiz_offset);
      int16x4_t d2 =
          convolve8_4_h(s4, s5, s6, s7, s8, s9, s10, s11, filter, horiz_offset);
      int16x4_t d3 = convolve8_4_h(s6, s7, s8, s9, s10, s11, s12, s13, filter,
                                   horiz_offset);

      transpose_elems_inplace_s16_4x4(&d0, &d1, &d2, &d3);

      store_s16_4x4(dst, dst_stride, d0, d1, d2, d3);

      dst += 4 * dst_stride;
      src += 4 * src_stride;
      h -= 4;
    } while (h > 0);
  } else {
    // A shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
    // shifts - which are generally faster than rounding shifts on modern CPUs.
    // The additional -1 is needed because we are halving the filter values.
    const int16x8_t horiz_offset =
        vdupq_n_s16((1 << (bd + FILTER_BITS - 2)) + (1 << (ROUND0_BITS - 2)));
    // Filter values are all even so halve them to allow convolution
    // kernel computations to stay in 16-bit element types.
    const int16x8_t filter = vshrq_n_s16(vld1q_s16(x_filter), 1);

    do {
      const uint8_t *s = src;
      int16_t *d = dst;
      int width = w;

      uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
      load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
      transpose_elems_u8_8x8(t0, t1, t2, t3, t4, t5, t6, t7, &t0, &t1, &t2, &t3,
                             &t4, &t5, &t6, &t7);

      s += 8;

      int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
      int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
      int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
      int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t3));
      int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t4));
      int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t5));
      int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t6));
      int16x8_t s7 = vreinterpretq_s16_u16(vmovl_u8(t7));

      do {
        uint8x8_t t8, t9, t10, t11, t12, t13, t14, t15;
        load_u8_8x8(s, src_stride, &t8, &t9, &t10, &t11, &t12, &t13, &t14,
                    &t15);
        transpose_elems_u8_8x8(t8, t9, t10, t11, t12, t13, t14, t15, &t8, &t9,
                               &t10, &t11, &t12, &t13, &t14, &t15);

        int16x8_t s8 = vreinterpretq_s16_u16(vmovl_u8(t8));
        int16x8_t s9 = vreinterpretq_s16_u16(vmovl_u8(t9));
        int16x8_t s10 = vreinterpretq_s16_u16(vmovl_u8(t10));
        int16x8_t s11 = vreinterpretq_s16_u16(vmovl_u8(t11));
        int16x8_t s12 = vreinterpretq_s16_u16(vmovl_u8(t12));
        int16x8_t s13 = vreinterpretq_s16_u16(vmovl_u8(t13));
        int16x8_t s14 = vreinterpretq_s16_u16(vmovl_u8(t14));
        int16x8_t s15 = vreinterpretq_s16_u16(vmovl_u8(t15));

        int16x8_t d0 =
            convolve8_8_h(s0, s1, s2, s3, s4, s5, s6, s7, filter, horiz_offset);
        int16x8_t d1 =
            convolve8_8_h(s2, s3, s4, s5, s6, s7, s8, s9, filter, horiz_offset);
        int16x8_t d2 = convolve8_8_h(s4, s5, s6, s7, s8, s9, s10, s11, filter,
                                     horiz_offset);
        int16x8_t d3 = convolve8_8_h(s6, s7, s8, s9, s10, s11, s12, s13, filter,
                                     horiz_offset);

        transpose_elems_inplace_s16_8x4(&d0, &d1, &d2, &d3);

        store_s16_4x8(d, dst_stride, vget_low_s16(d0), vget_low_s16(d1),
                      vget_low_s16(d2), vget_low_s16(d3), vget_high_s16(d0),
                      vget_high_s16(d1), vget_high_s16(d2), vget_high_s16(d3));

        s0 = s8;
        s1 = s9;
        s2 = s10;
        s3 = s11;
        s4 = s12;
        s5 = s13;
        s6 = s14;
        s7 = s15;

        s += 8;
        d += 4;
        width -= 4;
      } while (width != 0);

      dst += 8 * dst_stride;
      src += 8 * src_stride;
      h -= 8;
    } while (h > 0);
  }
}

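// As above for the 2:1 case, but using the padded 6-tap filters: the source
// registers start one column in so that only filter lanes 1-6 contribute.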
static inline void convolve_horiz_scale_2_6tap_neon(
    const uint8_t *src, int src_stride, int16_t *dst, int dst_stride, int w,
    int h, const int16_t *x_filter) {
  const int bd = 8;

  if (w == 4) {
    // A shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
    // shifts - which are generally faster than rounding shifts on modern CPUs.
    const int32x4_t horiz_offset =
        vdupq_n_s32((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
    const int16x8_t filter = vld1q_s16(x_filter);

    do {
      uint8x16_t t0, t1, t2, t3;
      load_u8_16x4(src, src_stride, &t0, &t1, &t2, &t3);
      transpose_elems_inplace_u8_16x4(&t0, &t1, &t2, &t3);

      int16x8_t tt0 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t1)));
      int16x8_t tt1 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t2)));
      int16x8_t tt2 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t3)));
      int16x8_t tt3 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t0)));
      int16x8_t tt4 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t0)));
      int16x8_t tt5 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t1)));
      int16x8_t tt6 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t2)));
      int16x8_t tt7 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t3)));

      int16x4_t s0 = vget_low_s16(tt0);
      int16x4_t s1 = vget_low_s16(tt1);
      int16x4_t s2 = vget_low_s16(tt2);
      int16x4_t s3 = vget_high_s16(tt3);
      int16x4_t s4 = vget_high_s16(tt0);
      int16x4_t s5 = vget_high_s16(tt1);
      int16x4_t s6 = vget_high_s16(tt2);
      int16x4_t s7 = vget_low_s16(tt4);
      int16x4_t s8 = vget_low_s16(tt5);
      int16x4_t s9 = vget_low_s16(tt6);
      int16x4_t s10 = vget_low_s16(tt7);
      int16x4_t s11 = vget_high_s16(tt4);

      int16x4_t d0 =
          convolve6_4_h(s0, s1, s2, s3, s4, s5, filter, horiz_offset);
      int16x4_t d1 =
          convolve6_4_h(s2, s3, s4, s5, s6, s7, filter, horiz_offset);
      int16x4_t d2 =
          convolve6_4_h(s4, s5, s6, s7, s8, s9, filter, horiz_offset);
      int16x4_t d3 =
          convolve6_4_h(s6, s7, s8, s9, s10, s11, filter, horiz_offset);

      transpose_elems_inplace_s16_4x4(&d0, &d1, &d2, &d3);

      store_s16_4x4(dst, dst_stride, d0, d1, d2, d3);

      dst += 4 * dst_stride;
      src += 4 * src_stride;
      h -= 4;
    } while (h > 0);
  } else {
    // A shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
    // shifts - which are generally faster than rounding shifts on modern CPUs.
    // The additional -1 is needed because we are halving the filter values.
    const int16x8_t horiz_offset =
        vdupq_n_s16((1 << (bd + FILTER_BITS - 2)) + (1 << (ROUND0_BITS - 2)));
    // Filter values are all even so halve them to allow convolution
    // kernel computations to stay in 16-bit element types.
    const int16x8_t filter = vshrq_n_s16(vld1q_s16(x_filter), 1);

    do {
      const uint8_t *s = src;
      int16_t *d = dst;
      int width = w;

      uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
      load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
      transpose_elems_u8_8x8(t0, t1, t2, t3, t4, t5, t6, t7, &t0, &t1, &t2, &t3,
                             &t4, &t5, &t6, &t7);

      s += 8;

      int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t1));
      int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t2));
      int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t3));
      int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t4));
      int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t5));
      int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t6));
      int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t7));

      do {
        uint8x8_t t8, t9, t10, t11, t12, t13, t14, t15;
        load_u8_8x8(s, src_stride, &t8, &t9, &t10, &t11, &t12, &t13, &t14,
                    &t15);
        transpose_elems_u8_8x8(t8, t9, t10, t11, t12, t13, t14, t15, &t8, &t9,
                               &t10, &t11, &t12, &t13, &t14, &t15);

        int16x8_t s7 = vreinterpretq_s16_u16(vmovl_u8(t8));
        int16x8_t s8 = vreinterpretq_s16_u16(vmovl_u8(t9));
        int16x8_t s9 = vreinterpretq_s16_u16(vmovl_u8(t10));
        int16x8_t s10 = vreinterpretq_s16_u16(vmovl_u8(t11));
        int16x8_t s11 = vreinterpretq_s16_u16(vmovl_u8(t12));
        int16x8_t s12 = vreinterpretq_s16_u16(vmovl_u8(t13));
        int16x8_t s13 = vreinterpretq_s16_u16(vmovl_u8(t14));
        int16x8_t s14 = vreinterpretq_s16_u16(vmovl_u8(t15));

        int16x8_t d0 =
            convolve6_8_h(s0, s1, s2, s3, s4, s5, filter, horiz_offset);
        int16x8_t d1 =
            convolve6_8_h(s2, s3, s4, s5, s6, s7, filter, horiz_offset);
        int16x8_t d2 =
            convolve6_8_h(s4, s5, s6, s7, s8, s9, filter, horiz_offset);
        int16x8_t d3 =
            convolve6_8_h(s6, s7, s8, s9, s10, s11, filter, horiz_offset);

        transpose_elems_inplace_s16_8x4(&d0, &d1, &d2, &d3);

        store_s16_4x8(d, dst_stride, vget_low_s16(d0), vget_low_s16(d1),
                      vget_low_s16(d2), vget_low_s16(d3), vget_high_s16(d0),
                      vget_high_s16(d1), vget_high_s16(d2), vget_high_s16(d3));

        s0 = s8;
        s1 = s9;
        s2 = s10;
        s3 = s11;
        s4 = s12;
        s5 = s13;
        s6 = s14;

        s += 8;
        d += 4;
        width -= 4;
      } while (width != 0);

      dst += 8 * dst_stride;
      src += 8 * src_stride;
      h -= 8;
    } while (h > 0);
  }
}

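// 2D scaled convolution: filter horizontally at intermediate precision into
// im_block, then filter im_block vertically into the destination. The shared
// vertical kernels (see convolve_scale_neon.h) also handle the compound
// (CONV_BUF_TYPE) output modes.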
void av1_convolve_2d_scale_neon(const uint8_t *src, int src_stride,
                                uint8_t *dst, int dst_stride, int w, int h,
                                const InterpFilterParams *filter_params_x,
                                const InterpFilterParams *filter_params_y,
                                const int subpel_x_qn, const int x_step_qn,
                                const int subpel_y_qn, const int y_step_qn,
                                ConvolveParams *conv_params) {
  if (w < 4 || h < 4) {
    av1_convolve_2d_scale_c(src, src_stride, dst, dst_stride, w, h,
                            filter_params_x, filter_params_y, subpel_x_qn,
                            x_step_qn, subpel_y_qn, y_step_qn, conv_params);
    return;
  }

  // 8-tap filters are used for the interpolation.
  assert(filter_params_y->taps <= 8 && filter_params_x->taps <= 8);

  DECLARE_ALIGNED(32, int16_t,
                  im_block[(2 * MAX_SB_SIZE + MAX_FILTER_TAP) * MAX_SB_SIZE]);
  int im_h = (((h - 1) * y_step_qn + subpel_y_qn) >> SCALE_SUBPEL_BITS) +
             filter_params_y->taps;
  int im_stride = MAX_SB_SIZE;
  CONV_BUF_TYPE *dst16 = conv_params->dst;
  const int dst16_stride = conv_params->dst_stride;

  // Account for needing filter_taps / 2 - 1 lines prior and filter_taps / 2
  // lines post both horizontally and vertically.
  const ptrdiff_t horiz_offset = filter_params_x->taps / 2 - 1;
  const ptrdiff_t vert_offset = (filter_params_y->taps / 2 - 1) * src_stride;

  // Horizontal filter

  if (x_step_qn != 2 * (1 << SCALE_SUBPEL_BITS)) {
    if (filter_params_x->interp_filter == MULTITAP_SHARP) {
      convolve_horiz_scale_8tap_neon(
          src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
          im_h, filter_params_x->filter_ptr, subpel_x_qn, x_step_qn);
    } else {
      convolve_horiz_scale_6tap_neon(
          src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
          im_h, filter_params_x->filter_ptr, subpel_x_qn, x_step_qn);
    }
  } else {
    assert(subpel_x_qn < (1 << SCALE_SUBPEL_BITS));
    // The filter index is calculated using the
    // ((subpel_x_qn + x * x_step_qn) & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS
    // equation, where the values of x are from 0 to w. Since x_step_qn is a
    // multiple of (1 << SCALE_SUBPEL_BITS) here, the x * x_step_qn term never
    // contributes to the masked value and can be left out of the equation.
    const ptrdiff_t filter_offset =
        SUBPEL_TAPS * ((subpel_x_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS);
    const int16_t *x_filter = filter_params_x->filter_ptr + filter_offset;

    // The source index is calculated using the expression
    // (subpel_x_qn + x * x_step_qn) >> SCALE_SUBPEL_BITS, where the values of
    // x are from 0 to w. Since subpel_x_qn < (1 << SCALE_SUBPEL_BITS) and
    // x_step_qn % (1 << SCALE_SUBPEL_BITS) == 0, the source index reduces to
    // x * (x_step_qn / (1 << SCALE_SUBPEL_BITS)).
    if (filter_params_x->interp_filter == MULTITAP_SHARP) {
      convolve_horiz_scale_2_8tap_neon(src - horiz_offset - vert_offset,
                                       src_stride, im_block, im_stride, w, im_h,
                                       x_filter);
    } else {
      convolve_horiz_scale_2_6tap_neon(src - horiz_offset - vert_offset,
                                       src_stride, im_block, im_stride, w, im_h,
                                       x_filter);
    }
  }

  // Vertical filter
  if (filter_params_y->interp_filter == MULTITAP_SHARP) {
    if (UNLIKELY(conv_params->is_compound)) {
      if (conv_params->do_average) {
        if (conv_params->use_dist_wtd_comp_avg) {
          compound_dist_wtd_convolve_vert_scale_8tap_neon(
              im_block, im_stride, dst, dst_stride, dst16, dst16_stride, w, h,
              filter_params_y->filter_ptr, conv_params, subpel_y_qn, y_step_qn);
        } else {
          compound_avg_convolve_vert_scale_8tap_neon(
              im_block, im_stride, dst, dst_stride, dst16, dst16_stride, w, h,
              filter_params_y->filter_ptr, subpel_y_qn, y_step_qn);
        }
      } else {
        compound_convolve_vert_scale_8tap_neon(
            im_block, im_stride, dst16, dst16_stride, w, h,
            filter_params_y->filter_ptr, subpel_y_qn, y_step_qn);
      }
    } else {
      convolve_vert_scale_8tap_neon(im_block, im_stride, dst, dst_stride, w, h,
                                    filter_params_y->filter_ptr, subpel_y_qn,
                                    y_step_qn);
    }
  } else {
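    // The padded 6-tap vertical filters leave the outermost tap rows unused,
    // so start one row into the intermediate buffer (im_block + im_stride).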
    if (UNLIKELY(conv_params->is_compound)) {
      if (conv_params->do_average) {
        if (conv_params->use_dist_wtd_comp_avg) {
          compound_dist_wtd_convolve_vert_scale_6tap_neon(
              im_block + im_stride, im_stride, dst, dst_stride, dst16,
              dst16_stride, w, h, filter_params_y->filter_ptr, conv_params,
              subpel_y_qn, y_step_qn);
        } else {
          compound_avg_convolve_vert_scale_6tap_neon(
              im_block + im_stride, im_stride, dst, dst_stride, dst16,
              dst16_stride, w, h, filter_params_y->filter_ptr, subpel_y_qn,
              y_step_qn);
        }
      } else {
        compound_convolve_vert_scale_6tap_neon(
            im_block + im_stride, im_stride, dst16, dst16_stride, w, h,
            filter_params_y->filter_ptr, subpel_y_qn, y_step_qn);
      }
    } else {
      convolve_vert_scale_6tap_neon(
          im_block + im_stride, im_stride, dst, dst_stride, w, h,
          filter_params_y->filter_ptr, subpel_y_qn, y_step_qn);
    }
  }
}
749