/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "aom_dsp/arm/mem_neon.h"
#include "av1/common/convolve.h"
#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#define HBD_WIENER_5TAP_HORIZ(name, shift)                                    \
  static inline uint16x8_t name##_wiener_convolve5_8_2d_h(                    \
      const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,             \
      const int16x8_t s3, const int16x8_t s4, const int16x4_t x_filter,       \
      const int32x4_t round_vec, const uint16x8_t im_max_val) {               \
    /* Wiener filter is symmetric so add mirrored source elements. */         \
    int16x8_t s04 = vaddq_s16(s0, s4);                                        \
    int16x8_t s13 = vaddq_s16(s1, s3);                                        \
                                                                               \
    /* x_filter[0] = 0. (5-tap filters are 0-padded to 7 taps.) */            \
    int32x4_t sum_lo =                                                         \
        vmlal_lane_s16(round_vec, vget_low_s16(s04), x_filter, 1);            \
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s13), x_filter, 2);          \
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s2), x_filter, 3);           \
                                                                               \
    int32x4_t sum_hi =                                                         \
        vmlal_lane_s16(round_vec, vget_high_s16(s04), x_filter, 1);           \
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s13), x_filter, 2);         \
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s2), x_filter, 3);          \
                                                                               \
    uint16x4_t res_lo = vqrshrun_n_s32(sum_lo, shift);                        \
    uint16x4_t res_hi = vqrshrun_n_s32(sum_hi, shift);                        \
                                                                               \
    return vminq_u16(vcombine_u16(res_lo, res_hi), im_max_val);               \
  }                                                                            \
                                                                               \
  static inline void name##_convolve_add_src_5tap_horiz(                      \
      const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,       \
      ptrdiff_t dst_stride, int w, int h, const int16x4_t x_filter,           \
      const int32x4_t round_vec, const uint16x8_t im_max_val) {               \
    do {                                                                       \
      const int16_t *s = (int16_t *)src_ptr;                                   \
      uint16_t *d = dst_ptr;                                                   \
      int width = w;                                                           \
                                                                               \
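      /* Filter and store 8 output pixels per iteration across the row. */    \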
      do {                                                                     \
        int16x8_t s0, s1, s2, s3, s4;                                          \
        load_s16_8x5(s, 1, &s0, &s1, &s2, &s3, &s4);                           \
                                                                               \
        uint16x8_t d0 = name##_wiener_convolve5_8_2d_h(                       \
            s0, s1, s2, s3, s4, x_filter, round_vec, im_max_val);             \
                                                                               \
        vst1q_u16(d, d0);                                                      \
                                                                               \
        s += 8;                                                                \
        d += 8;                                                                \
        width -= 8;                                                            \
      } while (width != 0);                                                    \
      src_ptr += src_stride;                                                   \
      dst_ptr += dst_stride;                                                   \
    } while (--h != 0);                                                        \
  }

HBD_WIENER_5TAP_HORIZ(highbd, WIENER_ROUND0_BITS)
HBD_WIENER_5TAP_HORIZ(highbd_12, WIENER_ROUND0_BITS + 2)

#undef HBD_WIENER_5TAP_HORIZ

#define HBD_WIENER_7TAP_HORIZ(name, shift)                                    \
  static inline uint16x8_t name##_wiener_convolve7_8_2d_h(                    \
      const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,             \
      const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,             \
      const int16x8_t s6, const int16x4_t x_filter, const int32x4_t round_vec, \
      const uint16x8_t im_max_val) {                                          \
    /* Wiener filter is symmetric so add mirrored source elements. */         \
    int16x8_t s06 = vaddq_s16(s0, s6);                                        \
    int16x8_t s15 = vaddq_s16(s1, s5);                                        \
    int16x8_t s24 = vaddq_s16(s2, s4);                                        \
                                                                               \
    int32x4_t sum_lo =                                                         \
        vmlal_lane_s16(round_vec, vget_low_s16(s06), x_filter, 0);            \
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s15), x_filter, 1);          \
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s24), x_filter, 2);          \
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s3), x_filter, 3);           \
                                                                               \
    int32x4_t sum_hi =                                                         \
        vmlal_lane_s16(round_vec, vget_high_s16(s06), x_filter, 0);           \
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s15), x_filter, 1);         \
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s24), x_filter, 2);         \
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s3), x_filter, 3);          \
                                                                               \
    uint16x4_t res_lo = vqrshrun_n_s32(sum_lo, shift);                        \
    uint16x4_t res_hi = vqrshrun_n_s32(sum_hi, shift);                        \
                                                                               \
    return vminq_u16(vcombine_u16(res_lo, res_hi), im_max_val);               \
  }                                                                            \
                                                                               \
  static inline void name##_convolve_add_src_7tap_horiz(                      \
      const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,       \
      ptrdiff_t dst_stride, int w, int h, const int16x4_t x_filter,           \
      const int32x4_t round_vec, const uint16x8_t im_max_val) {               \
    do {                                                                       \
      const int16_t *s = (int16_t *)src_ptr;                                   \
      uint16_t *d = dst_ptr;                                                   \
      int width = w;                                                           \
                                                                               \
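      /* Filter and store 8 output pixels per iteration across the row. */    \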
      do {                                                                     \
        int16x8_t s0, s1, s2, s3, s4, s5, s6;                                  \
        load_s16_8x7(s, 1, &s0, &s1, &s2, &s3, &s4, &s5, &s6);                 \
                                                                               \
        uint16x8_t d0 = name##_wiener_convolve7_8_2d_h(                       \
            s0, s1, s2, s3, s4, s5, s6, x_filter, round_vec, im_max_val);     \
                                                                               \
        vst1q_u16(d, d0);                                                      \
                                                                               \
        s += 8;                                                                \
        d += 8;                                                                \
        width -= 8;                                                            \
      } while (width != 0);                                                    \
      src_ptr += src_stride;                                                   \
      dst_ptr += dst_stride;                                                   \
    } while (--h != 0);                                                        \
  }

HBD_WIENER_7TAP_HORIZ(highbd, WIENER_ROUND0_BITS)
HBD_WIENER_7TAP_HORIZ(highbd_12, WIENER_ROUND0_BITS + 2)

#undef HBD_WIENER_7TAP_HORIZ

#define HBD_WIENER_5TAP_VERT(name, shift)                                     \
  static inline uint16x8_t name##_wiener_convolve5_8_2d_v(                    \
      const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,             \
      const int16x8_t s3, const int16x8_t s4, const int16x4_t y_filter,       \
      const int32x4_t round_vec, const uint16x8_t res_max_val) {              \
    const int32x2_t y_filter_lo = vget_low_s32(vmovl_s16(y_filter));          \
    const int32x2_t y_filter_hi = vget_high_s32(vmovl_s16(y_filter));         \
    /* Wiener filter is symmetric so add mirrored source elements. */         \
    int32x4_t s04_lo = vaddl_s16(vget_low_s16(s0), vget_low_s16(s4));         \
    int32x4_t s13_lo = vaddl_s16(vget_low_s16(s1), vget_low_s16(s3));         \
                                                                               \
    /* y_filter[0] = 0. (5-tap filters are 0-padded to 7 taps.) */            \
    int32x4_t sum_lo = vmlaq_lane_s32(round_vec, s04_lo, y_filter_lo, 1);     \
    sum_lo = vmlaq_lane_s32(sum_lo, s13_lo, y_filter_hi, 0);                  \
    sum_lo =                                                                   \
        vmlaq_lane_s32(sum_lo, vmovl_s16(vget_low_s16(s2)), y_filter_hi, 1);  \
                                                                               \
    int32x4_t s04_hi = vaddl_s16(vget_high_s16(s0), vget_high_s16(s4));       \
    int32x4_t s13_hi = vaddl_s16(vget_high_s16(s1), vget_high_s16(s3));       \
                                                                               \
    int32x4_t sum_hi = vmlaq_lane_s32(round_vec, s04_hi, y_filter_lo, 1);     \
    sum_hi = vmlaq_lane_s32(sum_hi, s13_hi, y_filter_hi, 0);                  \
    sum_hi =                                                                   \
        vmlaq_lane_s32(sum_hi, vmovl_s16(vget_high_s16(s2)), y_filter_hi, 1); \
                                                                               \
    uint16x4_t res_lo = vqrshrun_n_s32(sum_lo, shift);                        \
    uint16x4_t res_hi = vqrshrun_n_s32(sum_hi, shift);                        \
                                                                               \
    return vminq_u16(vcombine_u16(res_lo, res_hi), res_max_val);              \
  }                                                                            \
                                                                               \
  static inline void name##_convolve_add_src_5tap_vert(                       \
      const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,       \
      ptrdiff_t dst_stride, int w, int h, const int16x4_t y_filter,           \
      const int32x4_t round_vec, const uint16x8_t res_max_val) {              \
    do {                                                                       \
      const int16_t *s = (int16_t *)src_ptr;                                   \
      uint16_t *d = dst_ptr;                                                   \
      int height = h;                                                          \
                                                                               \
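      /* Process 4 rows per iteration while at least 4 rows remain. */        \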
      while (height > 3) {                                                     \
        int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;                              \
        load_s16_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);   \
                                                                               \
        uint16x8_t d0 = name##_wiener_convolve5_8_2d_v(                       \
            s0, s1, s2, s3, s4, y_filter, round_vec, res_max_val);            \
        uint16x8_t d1 = name##_wiener_convolve5_8_2d_v(                       \
            s1, s2, s3, s4, s5, y_filter, round_vec, res_max_val);            \
        uint16x8_t d2 = name##_wiener_convolve5_8_2d_v(                       \
            s2, s3, s4, s5, s6, y_filter, round_vec, res_max_val);            \
        uint16x8_t d3 = name##_wiener_convolve5_8_2d_v(                       \
            s3, s4, s5, s6, s7, y_filter, round_vec, res_max_val);            \
                                                                               \
        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);                          \
                                                                               \
        s += 4 * src_stride;                                                   \
        d += 4 * dst_stride;                                                   \
        height -= 4;                                                           \
      }                                                                        \
                                                                               \
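      /* Process the remaining rows (at most 3). */                           \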
      while (height-- != 0) {                                                  \
        int16x8_t s0, s1, s2, s3, s4;                                          \
        load_s16_8x5(s, src_stride, &s0, &s1, &s2, &s3, &s4);                  \
                                                                               \
        uint16x8_t d0 = name##_wiener_convolve5_8_2d_v(                       \
            s0, s1, s2, s3, s4, y_filter, round_vec, res_max_val);            \
                                                                               \
        vst1q_u16(d, d0);                                                      \
                                                                               \
        s += src_stride;                                                       \
        d += dst_stride;                                                       \
      }                                                                        \
                                                                               \
      src_ptr += 8;                                                            \
      dst_ptr += 8;                                                            \
      w -= 8;                                                                  \
    } while (w != 0);                                                          \
  }

HBD_WIENER_5TAP_VERT(highbd, 2 * FILTER_BITS - WIENER_ROUND0_BITS)
HBD_WIENER_5TAP_VERT(highbd_12, 2 * FILTER_BITS - WIENER_ROUND0_BITS - 2)

#undef HBD_WIENER_5TAP_VERT

#define HBD_WIENER_7TAP_VERT(name, shift)                                     \
  static inline uint16x8_t name##_wiener_convolve7_8_2d_v(                    \
      const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,             \
      const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,             \
      const int16x8_t s6, const int16x4_t y_filter, const int32x4_t round_vec, \
      const uint16x8_t res_max_val) {                                         \
    const int32x2_t y_filter_lo = vget_low_s32(vmovl_s16(y_filter));          \
    const int32x2_t y_filter_hi = vget_high_s32(vmovl_s16(y_filter));         \
    /* Wiener filter is symmetric so add mirrored source elements. */         \
    int32x4_t s06_lo = vaddl_s16(vget_low_s16(s0), vget_low_s16(s6));         \
    int32x4_t s15_lo = vaddl_s16(vget_low_s16(s1), vget_low_s16(s5));         \
    int32x4_t s24_lo = vaddl_s16(vget_low_s16(s2), vget_low_s16(s4));         \
                                                                               \
    int32x4_t sum_lo = vmlaq_lane_s32(round_vec, s06_lo, y_filter_lo, 0);     \
    sum_lo = vmlaq_lane_s32(sum_lo, s15_lo, y_filter_lo, 1);                  \
    sum_lo = vmlaq_lane_s32(sum_lo, s24_lo, y_filter_hi, 0);                  \
    sum_lo =                                                                   \
        vmlaq_lane_s32(sum_lo, vmovl_s16(vget_low_s16(s3)), y_filter_hi, 1);  \
                                                                               \
    int32x4_t s06_hi = vaddl_s16(vget_high_s16(s0), vget_high_s16(s6));       \
    int32x4_t s15_hi = vaddl_s16(vget_high_s16(s1), vget_high_s16(s5));       \
    int32x4_t s24_hi = vaddl_s16(vget_high_s16(s2), vget_high_s16(s4));       \
                                                                               \
    int32x4_t sum_hi = vmlaq_lane_s32(round_vec, s06_hi, y_filter_lo, 0);     \
    sum_hi = vmlaq_lane_s32(sum_hi, s15_hi, y_filter_lo, 1);                  \
    sum_hi = vmlaq_lane_s32(sum_hi, s24_hi, y_filter_hi, 0);                  \
    sum_hi =                                                                   \
        vmlaq_lane_s32(sum_hi, vmovl_s16(vget_high_s16(s3)), y_filter_hi, 1); \
                                                                               \
    uint16x4_t res_lo = vqrshrun_n_s32(sum_lo, shift);                        \
    uint16x4_t res_hi = vqrshrun_n_s32(sum_hi, shift);                        \
                                                                               \
    return vminq_u16(vcombine_u16(res_lo, res_hi), res_max_val);              \
  }                                                                            \
                                                                               \
  static inline void name##_convolve_add_src_7tap_vert(                       \
      const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,       \
      ptrdiff_t dst_stride, int w, int h, const int16x4_t y_filter,           \
      const int32x4_t round_vec, const uint16x8_t res_max_val) {              \
    do {                                                                       \
      const int16_t *s = (int16_t *)src_ptr;                                   \
      uint16_t *d = dst_ptr;                                                   \
      int height = h;                                                          \
                                                                               \
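      /* Process 4 rows per iteration while at least 4 rows remain. */        \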
      while (height > 3) {                                                     \
        int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9;                      \
        load_s16_8x10(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7,   \
                      &s8, &s9);                                               \
                                                                               \
        uint16x8_t d0 = name##_wiener_convolve7_8_2d_v(                       \
            s0, s1, s2, s3, s4, s5, s6, y_filter, round_vec, res_max_val);    \
        uint16x8_t d1 = name##_wiener_convolve7_8_2d_v(                       \
            s1, s2, s3, s4, s5, s6, s7, y_filter, round_vec, res_max_val);    \
        uint16x8_t d2 = name##_wiener_convolve7_8_2d_v(                       \
            s2, s3, s4, s5, s6, s7, s8, y_filter, round_vec, res_max_val);    \
        uint16x8_t d3 = name##_wiener_convolve7_8_2d_v(                       \
            s3, s4, s5, s6, s7, s8, s9, y_filter, round_vec, res_max_val);    \
                                                                               \
        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);                          \
                                                                               \
        s += 4 * src_stride;                                                   \
        d += 4 * dst_stride;                                                   \
        height -= 4;                                                           \
      }                                                                        \
                                                                               \
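      /* Process the remaining rows (at most 3). */                           \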
      while (height-- != 0) {                                                  \
        int16x8_t s0, s1, s2, s3, s4, s5, s6;                                  \
        load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);        \
                                                                               \
        uint16x8_t d0 = name##_wiener_convolve7_8_2d_v(                       \
            s0, s1, s2, s3, s4, s5, s6, y_filter, round_vec, res_max_val);    \
                                                                               \
        vst1q_u16(d, d0);                                                      \
                                                                               \
        s += src_stride;                                                       \
        d += dst_stride;                                                       \
      }                                                                        \
                                                                               \
      src_ptr += 8;                                                            \
      dst_ptr += 8;                                                            \
      w -= 8;                                                                  \
    } while (w != 0);                                                          \
  }

HBD_WIENER_7TAP_VERT(highbd, 2 * FILTER_BITS - WIENER_ROUND0_BITS)
HBD_WIENER_7TAP_VERT(highbd_12, 2 * FILTER_BITS - WIENER_ROUND0_BITS - 2)

#undef HBD_WIENER_7TAP_VERT

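// Wiener filter coefficients are always passed as 7-tap arrays; 5-tap filters
// are zero-padded to 7 taps, so the effective tap count can be inferred from
// the outermost coefficients.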
static inline int get_wiener_filter_taps(const int16_t *filter) {
  assert(filter[7] == 0);
  if (filter[0] == 0 && filter[6] == 0) {
    return WIENER_WIN_REDUCED;
  }
  return WIENER_WIN;
}

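// The 2D Wiener filter is applied as two 1D passes: a horizontal pass that
// filters the source into the intermediate buffer im_block, followed by a
// vertical pass that filters im_block into the destination. Separate 5-tap
// and 7-tap paths are taken depending on the filter coefficients, and the
// highbd_12_* variants use different rounding shifts for 12-bit input.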
void av1_highbd_wiener_convolve_add_src_neon(
    const uint8_t *src8, ptrdiff_t src_stride, uint8_t *dst8,
    ptrdiff_t dst_stride, const int16_t *x_filter, int x_step_q4,
    const int16_t *y_filter, int y_step_q4, int w, int h,
    const WienerConvolveParams *conv_params, int bd) {
  (void)x_step_q4;
  (void)y_step_q4;

  assert(w % 8 == 0);
  assert(w <= MAX_SB_SIZE && h <= MAX_SB_SIZE);
  assert(x_step_q4 == 16 && y_step_q4 == 16);
  assert(x_filter[7] == 0 && y_filter[7] == 0);

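  // Intermediate buffer between the horizontal and vertical passes. The
  // horizontal pass produces h + y_filter_taps - 1 rows so that the vertical
  // filter has enough rows of context above and below the output block.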
  DECLARE_ALIGNED(16, uint16_t,
                  im_block[(MAX_SB_SIZE + WIENER_WIN - 1) * MAX_SB_SIZE]);

  const int x_filter_taps = get_wiener_filter_taps(x_filter);
  const int y_filter_taps = get_wiener_filter_taps(y_filter);
  int16x4_t x_filter_s16 = vld1_s16(x_filter);
  int16x4_t y_filter_s16 = vld1_s16(y_filter);
  // Add 128 to tap 3. (Needed for rounding.)
  x_filter_s16 = vadd_s16(x_filter_s16, vcreate_s16(128ULL << 48));
  y_filter_s16 = vadd_s16(y_filter_s16, vcreate_s16(128ULL << 48));

  const int im_stride = MAX_SB_SIZE;
  const int im_h = h + y_filter_taps - 1;
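  // Offset the source pointer so that each filter is centered on the
  // corresponding output sample.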
  const int horiz_offset = x_filter_taps / 2;
  const int vert_offset = (y_filter_taps / 2) * (int)src_stride;

  const int extraprec_clamp_limit =
      WIENER_CLAMP_LIMIT(conv_params->round_0, bd);
  const uint16x8_t im_max_val = vdupq_n_u16(extraprec_clamp_limit - 1);
  const int32x4_t horiz_round_vec = vdupq_n_s32(1 << (bd + FILTER_BITS - 1));

  const uint16x8_t res_max_val = vdupq_n_u16((1 << bd) - 1);
  const int32x4_t vert_round_vec =
      vdupq_n_s32(-(1 << (bd + conv_params->round_1 - 1)));

  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);

  if (bd == 12) {
    if (x_filter_taps == WIENER_WIN_REDUCED) {
      highbd_12_convolve_add_src_5tap_horiz(
          src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
          im_h, x_filter_s16, horiz_round_vec, im_max_val);
    } else {
      highbd_12_convolve_add_src_7tap_horiz(
          src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
          im_h, x_filter_s16, horiz_round_vec, im_max_val);
    }

    if (y_filter_taps == WIENER_WIN_REDUCED) {
      highbd_12_convolve_add_src_5tap_vert(im_block, im_stride, dst, dst_stride,
                                           w, h, y_filter_s16, vert_round_vec,
                                           res_max_val);
    } else {
      highbd_12_convolve_add_src_7tap_vert(im_block, im_stride, dst, dst_stride,
                                           w, h, y_filter_s16, vert_round_vec,
                                           res_max_val);
    }

  } else {
    if (x_filter_taps == WIENER_WIN_REDUCED) {
      highbd_convolve_add_src_5tap_horiz(
          src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
          im_h, x_filter_s16, horiz_round_vec, im_max_val);
    } else {
      highbd_convolve_add_src_7tap_horiz(
          src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
          im_h, x_filter_s16, horiz_round_vec, im_max_val);
    }

    if (y_filter_taps == WIENER_WIN_REDUCED) {
      highbd_convolve_add_src_5tap_vert(im_block, im_stride, dst, dst_stride, w,
                                        h, y_filter_s16, vert_round_vec,
                                        res_max_val);
    } else {
      highbd_convolve_add_src_7tap_vert(im_block, im_stride, dst, dst_stride, w,
                                        h, y_filter_s16, vert_round_vec,
                                        res_max_val);
    }
  }
}