xref: /aosp_15_r20/external/libaom/av1/common/arm/highbd_wiener_convolve_neon.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 #include <assert.h>
14 
15 #include "aom_dsp/arm/mem_neon.h"
16 #include "av1/common/convolve.h"
17 #include "config/aom_config.h"
18 #include "config/av1_rtcd.h"
19 
20 #define HBD_WIENER_5TAP_HORIZ(name, shift)                              \
21   static inline uint16x8_t name##_wiener_convolve5_8_2d_h(              \
22       const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,       \
23       const int16x8_t s3, const int16x8_t s4, const int16x4_t x_filter, \
24       const int32x4_t round_vec, const uint16x8_t im_max_val) {         \
25     /* Wiener filter is symmetric so add mirrored source elements. */   \
26     int16x8_t s04 = vaddq_s16(s0, s4);                                  \
27     int16x8_t s13 = vaddq_s16(s1, s3);                                  \
28                                                                         \
29     /* x_filter[0] = 0. (5-tap filters are 0-padded to 7 taps.) */      \
30     int32x4_t sum_lo =                                                  \
31         vmlal_lane_s16(round_vec, vget_low_s16(s04), x_filter, 1);      \
32     sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s13), x_filter, 2);    \
33     sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s2), x_filter, 3);     \
34                                                                         \
35     int32x4_t sum_hi =                                                  \
36         vmlal_lane_s16(round_vec, vget_high_s16(s04), x_filter, 1);     \
37     sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s13), x_filter, 2);   \
38     sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s2), x_filter, 3);    \
39                                                                         \
40     uint16x4_t res_lo = vqrshrun_n_s32(sum_lo, shift);                  \
41     uint16x4_t res_hi = vqrshrun_n_s32(sum_hi, shift);                  \
42                                                                         \
43     return vminq_u16(vcombine_u16(res_lo, res_hi), im_max_val);         \
44   }                                                                     \
45                                                                         \
46   static inline void name##_convolve_add_src_5tap_horiz(                \
47       const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr, \
48       ptrdiff_t dst_stride, int w, int h, const int16x4_t x_filter,     \
49       const int32x4_t round_vec, const uint16x8_t im_max_val) {         \
50     do {                                                                \
51       const int16_t *s = (int16_t *)src_ptr;                            \
52       uint16_t *d = dst_ptr;                                            \
53       int width = w;                                                    \
54                                                                         \
55       do {                                                              \
56         int16x8_t s0, s1, s2, s3, s4;                                   \
57         load_s16_8x5(s, 1, &s0, &s1, &s2, &s3, &s4);                    \
58                                                                         \
59         uint16x8_t d0 = name##_wiener_convolve5_8_2d_h(                 \
60             s0, s1, s2, s3, s4, x_filter, round_vec, im_max_val);       \
61                                                                         \
62         vst1q_u16(d, d0);                                               \
63                                                                         \
64         s += 8;                                                         \
65         d += 8;                                                         \
66         width -= 8;                                                     \
67       } while (width != 0);                                             \
68       src_ptr += src_stride;                                            \
69       dst_ptr += dst_stride;                                            \
70     } while (--h != 0);                                                 \
71   }
72 
HBD_WIENER_5TAP_HORIZ(highbd,WIENER_ROUND0_BITS)73 HBD_WIENER_5TAP_HORIZ(highbd, WIENER_ROUND0_BITS)
74 HBD_WIENER_5TAP_HORIZ(highbd_12, WIENER_ROUND0_BITS + 2)
75 
76 #undef HBD_WIENER_5TAP_HORIZ
77 
/* Horizontal pass of the separable high-bitdepth Wiener filter for full
 * 7-tap filters. Expands to two functions:
 *  - name##_wiener_convolve7_8_2d_h: filters 8 adjacent pixels. Symmetry is
 *    exploited by summing mirrored source elements (s0+s6, s1+s5, s2+s4)
 *    so only filter lanes 0..3 are multiplied. The sum is rounded/narrowed
 *    by `shift` (vqrshrun) and clamped to `im_max_val`, the
 *    intermediate-precision maximum.
 *  - name##_convolve_add_src_7tap_horiz: applies the above over a w x h
 *    block, 8 pixels at a time. Requires w to be a multiple of 8 (the
 *    caller asserts this) and h > 0.
 * Instantiated twice below; the two variants differ only in the rounding
 * shift (the 12-bit path shifts by 2 extra bits). */
#define HBD_WIENER_7TAP_HORIZ(name, shift)                                     \
  static inline uint16x8_t name##_wiener_convolve7_8_2d_h(                     \
      const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,              \
      const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,              \
      const int16x8_t s6, const int16x4_t x_filter, const int32x4_t round_vec, \
      const uint16x8_t im_max_val) {                                           \
    /* Wiener filter is symmetric so add mirrored source elements. */          \
    int16x8_t s06 = vaddq_s16(s0, s6);                                         \
    int16x8_t s15 = vaddq_s16(s1, s5);                                         \
    int16x8_t s24 = vaddq_s16(s2, s4);                                         \
                                                                               \
    int32x4_t sum_lo =                                                         \
        vmlal_lane_s16(round_vec, vget_low_s16(s06), x_filter, 0);             \
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s15), x_filter, 1);           \
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s24), x_filter, 2);           \
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s3), x_filter, 3);            \
                                                                               \
    int32x4_t sum_hi =                                                         \
        vmlal_lane_s16(round_vec, vget_high_s16(s06), x_filter, 0);            \
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s15), x_filter, 1);          \
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s24), x_filter, 2);          \
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s3), x_filter, 3);           \
                                                                               \
    uint16x4_t res_lo = vqrshrun_n_s32(sum_lo, shift);                         \
    uint16x4_t res_hi = vqrshrun_n_s32(sum_hi, shift);                         \
                                                                               \
    return vminq_u16(vcombine_u16(res_lo, res_hi), im_max_val);                \
  }                                                                            \
                                                                               \
  static inline void name##_convolve_add_src_7tap_horiz(                       \
      const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,        \
      ptrdiff_t dst_stride, int w, int h, const int16x4_t x_filter,            \
      const int32x4_t round_vec, const uint16x8_t im_max_val) {                \
    do {                                                                       \
      const int16_t *s = (int16_t *)src_ptr;                                   \
      uint16_t *d = dst_ptr;                                                   \
      int width = w;                                                           \
                                                                               \
      do {                                                                     \
        int16x8_t s0, s1, s2, s3, s4, s5, s6;                                  \
        load_s16_8x7(s, 1, &s0, &s1, &s2, &s3, &s4, &s5, &s6);                 \
                                                                               \
        uint16x8_t d0 = name##_wiener_convolve7_8_2d_h(                        \
            s0, s1, s2, s3, s4, s5, s6, x_filter, round_vec, im_max_val);      \
                                                                               \
        vst1q_u16(d, d0);                                                      \
                                                                               \
        s += 8;                                                                \
        d += 8;                                                                \
        width -= 8;                                                            \
      } while (width != 0);                                                    \
      src_ptr += src_stride;                                                   \
      dst_ptr += dst_stride;                                                   \
    } while (--h != 0);                                                        \
  }

HBD_WIENER_7TAP_HORIZ(highbd, WIENER_ROUND0_BITS)
HBD_WIENER_7TAP_HORIZ(highbd_12, WIENER_ROUND0_BITS + 2)

#undef HBD_WIENER_7TAP_HORIZ
138 
/* Vertical pass of the separable high-bitdepth Wiener filter for 5-tap
 * ("reduced") filters. Expands to two functions:
 *  - name##_wiener_convolve5_8_2d_v: filters one 8-wide output row from 5
 *    input rows. Unlike the horizontal pass, mirrored-row pairs are summed
 *    with widening adds (vaddl_s16) and accumulated with 32-bit multiplies
 *    (vmlaq_lane_s32), since the intermediate horizontal outputs can exceed
 *    the range a 16x16->32 multiply-accumulate chain could take as input.
 *    Filter lane 0 is the zero pad; lanes 1..3 are used. The result is
 *    rounded/narrowed by `shift` and clamped to `res_max_val`.
 *  - name##_convolve_add_src_5tap_vert: processes the block one 8-wide
 *    column strip at a time, 4 output rows per iteration, then a scalar-row
 *    tail for the remaining h % 4 rows. Requires w to be a multiple of 8.
 * Instantiated twice below; the two variants differ only in the rounding
 * shift (the 12-bit path shifts by 2 fewer bits, mirroring the extra 2
 * bits added in its horizontal pass). */
#define HBD_WIENER_5TAP_VERT(name, shift)                                     \
  static inline uint16x8_t name##_wiener_convolve5_8_2d_v(                    \
      const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,             \
      const int16x8_t s3, const int16x8_t s4, const int16x4_t y_filter,       \
      const int32x4_t round_vec, const uint16x8_t res_max_val) {              \
    const int32x2_t y_filter_lo = vget_low_s32(vmovl_s16(y_filter));          \
    const int32x2_t y_filter_hi = vget_high_s32(vmovl_s16(y_filter));         \
    /* Wiener filter is symmetric so add mirrored source elements. */         \
    int32x4_t s04_lo = vaddl_s16(vget_low_s16(s0), vget_low_s16(s4));         \
    int32x4_t s13_lo = vaddl_s16(vget_low_s16(s1), vget_low_s16(s3));         \
                                                                              \
    /* y_filter[0] = 0. (5-tap filters are 0-padded to 7 taps.) */            \
    int32x4_t sum_lo = vmlaq_lane_s32(round_vec, s04_lo, y_filter_lo, 1);     \
    sum_lo = vmlaq_lane_s32(sum_lo, s13_lo, y_filter_hi, 0);                  \
    sum_lo =                                                                  \
        vmlaq_lane_s32(sum_lo, vmovl_s16(vget_low_s16(s2)), y_filter_hi, 1);  \
                                                                              \
    int32x4_t s04_hi = vaddl_s16(vget_high_s16(s0), vget_high_s16(s4));       \
    int32x4_t s13_hi = vaddl_s16(vget_high_s16(s1), vget_high_s16(s3));       \
                                                                              \
    int32x4_t sum_hi = vmlaq_lane_s32(round_vec, s04_hi, y_filter_lo, 1);     \
    sum_hi = vmlaq_lane_s32(sum_hi, s13_hi, y_filter_hi, 0);                  \
    sum_hi =                                                                  \
        vmlaq_lane_s32(sum_hi, vmovl_s16(vget_high_s16(s2)), y_filter_hi, 1); \
                                                                              \
    uint16x4_t res_lo = vqrshrun_n_s32(sum_lo, shift);                        \
    uint16x4_t res_hi = vqrshrun_n_s32(sum_hi, shift);                        \
                                                                              \
    return vminq_u16(vcombine_u16(res_lo, res_hi), res_max_val);              \
  }                                                                           \
                                                                              \
  static inline void name##_convolve_add_src_5tap_vert(                       \
      const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,       \
      ptrdiff_t dst_stride, int w, int h, const int16x4_t y_filter,           \
      const int32x4_t round_vec, const uint16x8_t res_max_val) {              \
    do {                                                                      \
      const int16_t *s = (int16_t *)src_ptr;                                  \
      uint16_t *d = dst_ptr;                                                  \
      int height = h;                                                         \
                                                                              \
      /* 4 output rows per iteration: 8 input rows are loaded once and     */ \
      /* reused across the four overlapping 5-row windows.                 */ \
      while (height > 3) {                                                    \
        int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;                             \
        load_s16_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);  \
                                                                              \
        uint16x8_t d0 = name##_wiener_convolve5_8_2d_v(                       \
            s0, s1, s2, s3, s4, y_filter, round_vec, res_max_val);            \
        uint16x8_t d1 = name##_wiener_convolve5_8_2d_v(                       \
            s1, s2, s3, s4, s5, y_filter, round_vec, res_max_val);            \
        uint16x8_t d2 = name##_wiener_convolve5_8_2d_v(                       \
            s2, s3, s4, s5, s6, y_filter, round_vec, res_max_val);            \
        uint16x8_t d3 = name##_wiener_convolve5_8_2d_v(                       \
            s3, s4, s5, s6, s7, y_filter, round_vec, res_max_val);            \
                                                                              \
        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);                         \
                                                                              \
        s += 4 * src_stride;                                                  \
        d += 4 * dst_stride;                                                  \
        height -= 4;                                                          \
      }                                                                       \
                                                                              \
      /* Tail: remaining (h % 4) rows, one at a time. */                      \
      while (height-- != 0) {                                                 \
        int16x8_t s0, s1, s2, s3, s4;                                         \
        load_s16_8x5(s, src_stride, &s0, &s1, &s2, &s3, &s4);                 \
                                                                              \
        uint16x8_t d0 = name##_wiener_convolve5_8_2d_v(                       \
            s0, s1, s2, s3, s4, y_filter, round_vec, res_max_val);            \
                                                                              \
        vst1q_u16(d, d0);                                                     \
                                                                              \
        s += src_stride;                                                      \
        d += dst_stride;                                                      \
      }                                                                       \
                                                                              \
      src_ptr += 8;                                                           \
      dst_ptr += 8;                                                           \
      w -= 8;                                                                 \
    } while (w != 0);                                                         \
  }

HBD_WIENER_5TAP_VERT(highbd, 2 * FILTER_BITS - WIENER_ROUND0_BITS)
HBD_WIENER_5TAP_VERT(highbd_12, 2 * FILTER_BITS - WIENER_ROUND0_BITS - 2)

#undef HBD_WIENER_5TAP_VERT
222 
/* Vertical pass of the separable high-bitdepth Wiener filter for full
 * 7-tap filters. Expands to two functions:
 *  - name##_wiener_convolve7_8_2d_v: filters one 8-wide output row from 7
 *    input rows. Mirrored-row pairs (s0+s6, s1+s5, s2+s4) are summed with
 *    widening adds and accumulated with 32-bit multiplies, so filter lanes
 *    0..3 cover all 7 taps. The result is rounded/narrowed by `shift` and
 *    clamped to `res_max_val`.
 *  - name##_convolve_add_src_7tap_vert: processes the block one 8-wide
 *    column strip at a time, 4 output rows per iteration (10 input rows
 *    shared across the four overlapping 7-row windows), then a
 *    one-row-at-a-time tail for the remaining h % 4 rows. Requires w to be
 *    a multiple of 8.
 * Instantiated twice below; the two variants differ only in the rounding
 * shift (the 12-bit path shifts by 2 fewer bits, mirroring the extra 2
 * bits added in its horizontal pass). */
#define HBD_WIENER_7TAP_VERT(name, shift)                                      \
  static inline uint16x8_t name##_wiener_convolve7_8_2d_v(                     \
      const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,              \
      const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,              \
      const int16x8_t s6, const int16x4_t y_filter, const int32x4_t round_vec, \
      const uint16x8_t res_max_val) {                                          \
    const int32x2_t y_filter_lo = vget_low_s32(vmovl_s16(y_filter));           \
    const int32x2_t y_filter_hi = vget_high_s32(vmovl_s16(y_filter));          \
    /* Wiener filter is symmetric so add mirrored source elements. */          \
    int32x4_t s06_lo = vaddl_s16(vget_low_s16(s0), vget_low_s16(s6));          \
    int32x4_t s15_lo = vaddl_s16(vget_low_s16(s1), vget_low_s16(s5));          \
    int32x4_t s24_lo = vaddl_s16(vget_low_s16(s2), vget_low_s16(s4));          \
                                                                               \
    int32x4_t sum_lo = vmlaq_lane_s32(round_vec, s06_lo, y_filter_lo, 0);      \
    sum_lo = vmlaq_lane_s32(sum_lo, s15_lo, y_filter_lo, 1);                   \
    sum_lo = vmlaq_lane_s32(sum_lo, s24_lo, y_filter_hi, 0);                   \
    sum_lo =                                                                   \
        vmlaq_lane_s32(sum_lo, vmovl_s16(vget_low_s16(s3)), y_filter_hi, 1);   \
                                                                               \
    int32x4_t s06_hi = vaddl_s16(vget_high_s16(s0), vget_high_s16(s6));        \
    int32x4_t s15_hi = vaddl_s16(vget_high_s16(s1), vget_high_s16(s5));        \
    int32x4_t s24_hi = vaddl_s16(vget_high_s16(s2), vget_high_s16(s4));        \
                                                                               \
    int32x4_t sum_hi = vmlaq_lane_s32(round_vec, s06_hi, y_filter_lo, 0);      \
    sum_hi = vmlaq_lane_s32(sum_hi, s15_hi, y_filter_lo, 1);                   \
    sum_hi = vmlaq_lane_s32(sum_hi, s24_hi, y_filter_hi, 0);                   \
    sum_hi =                                                                   \
        vmlaq_lane_s32(sum_hi, vmovl_s16(vget_high_s16(s3)), y_filter_hi, 1);  \
                                                                               \
    uint16x4_t res_lo = vqrshrun_n_s32(sum_lo, shift);                         \
    uint16x4_t res_hi = vqrshrun_n_s32(sum_hi, shift);                         \
                                                                               \
    return vminq_u16(vcombine_u16(res_lo, res_hi), res_max_val);               \
  }                                                                            \
                                                                               \
  static inline void name##_convolve_add_src_7tap_vert(                        \
      const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,        \
      ptrdiff_t dst_stride, int w, int h, const int16x4_t y_filter,            \
      const int32x4_t round_vec, const uint16x8_t res_max_val) {               \
    do {                                                                       \
      const int16_t *s = (int16_t *)src_ptr;                                   \
      uint16_t *d = dst_ptr;                                                   \
      int height = h;                                                          \
                                                                               \
      /* 4 output rows per iteration from 10 shared input rows. */             \
      while (height > 3) {                                                     \
        int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9;                      \
        load_s16_8x10(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7,   \
                      &s8, &s9);                                               \
                                                                               \
        uint16x8_t d0 = name##_wiener_convolve7_8_2d_v(                        \
            s0, s1, s2, s3, s4, s5, s6, y_filter, round_vec, res_max_val);     \
        uint16x8_t d1 = name##_wiener_convolve7_8_2d_v(                        \
            s1, s2, s3, s4, s5, s6, s7, y_filter, round_vec, res_max_val);     \
        uint16x8_t d2 = name##_wiener_convolve7_8_2d_v(                        \
            s2, s3, s4, s5, s6, s7, s8, y_filter, round_vec, res_max_val);     \
        uint16x8_t d3 = name##_wiener_convolve7_8_2d_v(                        \
            s3, s4, s5, s6, s7, s8, s9, y_filter, round_vec, res_max_val);     \
                                                                               \
        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);                          \
                                                                               \
        s += 4 * src_stride;                                                   \
        d += 4 * dst_stride;                                                   \
        height -= 4;                                                           \
      }                                                                        \
                                                                               \
      /* Tail: remaining (h % 4) rows, one at a time. */                       \
      while (height-- != 0) {                                                  \
        int16x8_t s0, s1, s2, s3, s4, s5, s6;                                  \
        load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);        \
                                                                               \
        uint16x8_t d0 = name##_wiener_convolve7_8_2d_v(                        \
            s0, s1, s2, s3, s4, s5, s6, y_filter, round_vec, res_max_val);     \
                                                                               \
        vst1q_u16(d, d0);                                                      \
                                                                               \
        s += src_stride;                                                       \
        d += dst_stride;                                                       \
      }                                                                        \
                                                                               \
      src_ptr += 8;                                                            \
      dst_ptr += 8;                                                            \
      w -= 8;                                                                  \
    } while (w != 0);                                                          \
  }

HBD_WIENER_7TAP_VERT(highbd, 2 * FILTER_BITS - WIENER_ROUND0_BITS)
HBD_WIENER_7TAP_VERT(highbd_12, 2 * FILTER_BITS - WIENER_ROUND0_BITS - 2)

#undef HBD_WIENER_7TAP_VERT
311 
312 static inline int get_wiener_filter_taps(const int16_t *filter) {
313   assert(filter[7] == 0);
314   if (filter[0] == 0 && filter[6] == 0) {
315     return WIENER_WIN_REDUCED;
316   }
317   return WIENER_WIN;
318 }
319 
// 2D separable high-bitdepth Wiener "convolve-add-source" filter: runs a
// horizontal pass into the intermediate buffer im_block, then a vertical
// pass into dst. Dispatches on effective tap count (5 vs 7) per direction
// and on bit depth (the 12-bit variants use 2 extra bits of intermediate
// precision). x_step_q4/y_step_q4 are unused: the step is asserted to be
// exactly 16 (no subpel scaling).
void av1_highbd_wiener_convolve_add_src_neon(
    const uint8_t *src8, ptrdiff_t src_stride, uint8_t *dst8,
    ptrdiff_t dst_stride, const int16_t *x_filter, int x_step_q4,
    const int16_t *y_filter, int y_step_q4, int w, int h,
    const WienerConvolveParams *conv_params, int bd) {
  (void)x_step_q4;
  (void)y_step_q4;

  assert(w % 8 == 0);
  assert(w <= MAX_SB_SIZE && h <= MAX_SB_SIZE);
  assert(x_step_q4 == 16 && y_step_q4 == 16);
  assert(x_filter[7] == 0 && y_filter[7] == 0);

  // Intermediate buffer for the horizontal pass; tall enough for the extra
  // rows the vertical filter taps need above and below the block.
  DECLARE_ALIGNED(16, uint16_t,
                  im_block[(MAX_SB_SIZE + WIENER_WIN - 1) * MAX_SB_SIZE]);

  const int x_filter_taps = get_wiener_filter_taps(x_filter);
  const int y_filter_taps = get_wiener_filter_taps(y_filter);
  // Only the first 4 taps are loaded: the filters are symmetric, and the
  // convolve helpers sum mirrored source samples before multiplying.
  int16x4_t x_filter_s16 = vld1_s16(x_filter);
  int16x4_t y_filter_s16 = vld1_s16(y_filter);
  // Add 128 to tap 3. (Needed for rounding.)
  x_filter_s16 = vadd_s16(x_filter_s16, vcreate_s16(128ULL << 48));
  y_filter_s16 = vadd_s16(y_filter_s16, vcreate_s16(128ULL << 48));

  const int im_stride = MAX_SB_SIZE;
  // The vertical pass consumes (y_filter_taps - 1) extra intermediate rows.
  const int im_h = h + y_filter_taps - 1;
  // Offsets that center the filters on the source pixel.
  const int horiz_offset = x_filter_taps / 2;
  const int vert_offset = (y_filter_taps / 2) * (int)src_stride;

  // Clamp limit for the extra-precision intermediate values of the
  // horizontal pass.
  const int extraprec_clamp_limit =
      WIENER_CLAMP_LIMIT(conv_params->round_0, bd);
  const uint16x8_t im_max_val = vdupq_n_u16(extraprec_clamp_limit - 1);
  const int32x4_t horiz_round_vec = vdupq_n_s32(1 << (bd + FILTER_BITS - 1));

  // Final output is clamped to the bit-depth maximum; the vertical rounding
  // offset is negative, cancelling the positive horizontal offset.
  const uint16x8_t res_max_val = vdupq_n_u16((1 << bd) - 1);
  const int32x4_t vert_round_vec =
      vdupq_n_s32(-(1 << (bd + conv_params->round_1 - 1)));

  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);

  // bd == 12 selects the highbd_12 function variants, which shift by 2 more
  // bits in the horizontal pass and 2 fewer in the vertical pass.
  if (bd == 12) {
    if (x_filter_taps == WIENER_WIN_REDUCED) {
      highbd_12_convolve_add_src_5tap_horiz(
          src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
          im_h, x_filter_s16, horiz_round_vec, im_max_val);
    } else {
      highbd_12_convolve_add_src_7tap_horiz(
          src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
          im_h, x_filter_s16, horiz_round_vec, im_max_val);
    }

    if (y_filter_taps == WIENER_WIN_REDUCED) {
      highbd_12_convolve_add_src_5tap_vert(im_block, im_stride, dst, dst_stride,
                                           w, h, y_filter_s16, vert_round_vec,
                                           res_max_val);
    } else {
      highbd_12_convolve_add_src_7tap_vert(im_block, im_stride, dst, dst_stride,
                                           w, h, y_filter_s16, vert_round_vec,
                                           res_max_val);
    }

  } else {
    if (x_filter_taps == WIENER_WIN_REDUCED) {
      highbd_convolve_add_src_5tap_horiz(
          src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
          im_h, x_filter_s16, horiz_round_vec, im_max_val);
    } else {
      highbd_convolve_add_src_7tap_horiz(
          src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
          im_h, x_filter_s16, horiz_round_vec, im_max_val);
    }

    if (y_filter_taps == WIENER_WIN_REDUCED) {
      highbd_convolve_add_src_5tap_vert(im_block, im_stride, dst, dst_stride, w,
                                        h, y_filter_s16, vert_round_vec,
                                        res_max_val);
    } else {
      highbd_convolve_add_src_7tap_vert(im_block, im_stride, dst, dst_stride, w,
                                        h, y_filter_s16, vert_round_vec,
                                        res_max_val);
    }
  }
}
404