1 /*
2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <arm_neon.h>
13 #include <assert.h>
14
15 #include "config/aom_config.h"
16 #include "config/av1_rtcd.h"
17
18 #include "aom_dsp/aom_dsp_common.h"
19 #include "aom_dsp/txfm_common.h"
20 #include "aom_dsp/arm/mem_neon.h"
21 #include "aom_dsp/arm/transpose_neon.h"
22 #include "aom_mem/aom_mem.h"
23 #include "aom_ports/mem.h"
24 #include "av1/common/av1_common_int.h"
25 #include "av1/common/common.h"
26 #include "av1/common/resize.h"
27 #include "av1/common/restoration.h"
28
29 // Constants used for right shift in final_filter calculation.
30 #define NB_EVEN 5
31 #define NB_ODD 4
32
// Computes one 4x4 tile of the A (filter weight) and B (offset) intermediate
// surfaces for the "fast" pass of the self-guided restoration filter.
//   s0..s3:    box sums of squares (unsigned view of the values in src1)
//   s4..s7:    box sums (unsigned view of sr4..sr7)
//   sr4..sr7:  raw signed box sums, reused for the B computation
//   const_n_val: n = (2r+1)^2, the pixel count of the box
//   s_vec:       SGR strength parameter, const_val: 255 clamp for the index
//   one_by_n_minus_1_vec: av1_one_by_x[n - 1] reciprocal table value
//   sgrproj_sgr: SGRPROJ_SGR constant
// Outputs: src1 is used as scratch for the clamped lookup indices, dst_A16
// receives the looked-up A values, src2 receives the final B values.
// All three buffers use buf_stride.
static inline void calc_ab_fast_internal_common(
    uint32x4_t s0, uint32x4_t s1, uint32x4_t s2, uint32x4_t s3, uint32x4_t s4,
    uint32x4_t s5, uint32x4_t s6, uint32x4_t s7, int32x4_t sr4, int32x4_t sr5,
    int32x4_t sr6, int32x4_t sr7, uint32x4_t const_n_val, uint32x4_t s_vec,
    uint32x4_t const_val, uint32x4_t one_by_n_minus_1_vec,
    uint16x4_t sgrproj_sgr, int32_t *src1, uint16_t *dst_A16, int32_t *src2,
    const int buf_stride) {
  uint32x4_t q0, q1, q2, q3;
  uint32x4_t p0, p1, p2, p3;
  uint16x4_t d0, d1, d2, d3;

  // s = n * sum_of_squares.
  s0 = vmulq_u32(s0, const_n_val);
  s1 = vmulq_u32(s1, const_n_val);
  s2 = vmulq_u32(s2, const_n_val);
  s3 = vmulq_u32(s3, const_n_val);

  // q = (sum)^2.
  q0 = vmulq_u32(s4, s4);
  q1 = vmulq_u32(s5, s5);
  q2 = vmulq_u32(s6, s6);
  q3 = vmulq_u32(s7, s7);

  // Lane mask: all-ones where sum^2 <= n * sum_sq, i.e. where the scaled
  // variance below is non-negative.
  p0 = vcleq_u32(q0, s0);
  p1 = vcleq_u32(q1, s1);
  p2 = vcleq_u32(q2, s2);
  p3 = vcleq_u32(q3, s3);

  // Scaled variance: n * sum_sq - sum^2 (wraps in lanes the mask rejects).
  q0 = vsubq_u32(s0, q0);
  q1 = vsubq_u32(s1, q1);
  q2 = vsubq_u32(s2, q2);
  q3 = vsubq_u32(s3, q3);

  // Zero the lanes where the subtraction wrapped around.
  p0 = vandq_u32(p0, q0);
  p1 = vandq_u32(p1, q1);
  p2 = vandq_u32(p2, q2);
  p3 = vandq_u32(p3, q3);

  // Scale the variance by the strength parameter and round down to the
  // table-index range.
  p0 = vmulq_u32(p0, s_vec);
  p1 = vmulq_u32(p1, s_vec);
  p2 = vmulq_u32(p2, s_vec);
  p3 = vmulq_u32(p3, s_vec);

  p0 = vrshrq_n_u32(p0, SGRPROJ_MTABLE_BITS);
  p1 = vrshrq_n_u32(p1, SGRPROJ_MTABLE_BITS);
  p2 = vrshrq_n_u32(p2, SGRPROJ_MTABLE_BITS);
  p3 = vrshrq_n_u32(p3, SGRPROJ_MTABLE_BITS);

  // Clamp the index to the 255 table bound (const_val).
  p0 = vminq_u32(p0, const_val);
  p1 = vminq_u32(p1, const_val);
  p2 = vminq_u32(p2, const_val);
  p3 = vminq_u32(p3, const_val);

  {
    // NEON has no gather: spill the indices to the src1 scratch buffer,
    // perform the av1_x_by_xplus1 lookups scalar-wise into dst_A16, then
    // reload the results as vectors.
    store_u32_4x4((uint32_t *)src1, buf_stride, p0, p1, p2, p3);

    for (int x = 0; x < 4; x++) {
      for (int y = 0; y < 4; y++) {
        dst_A16[x * buf_stride + y] = av1_x_by_xplus1[src1[x * buf_stride + y]];
      }
    }
    load_u16_4x4(dst_A16, buf_stride, &d0, &d1, &d2, &d3);
  }
  // Widening subtract: p = SGRPROJ_SGR - A.
  p0 = vsubl_u16(sgrproj_sgr, d0);
  p1 = vsubl_u16(sgrproj_sgr, d1);
  p2 = vsubl_u16(sgrproj_sgr, d2);
  p3 = vsubl_u16(sgrproj_sgr, d3);

  // B = round(sum * (1/n) * (SGRPROJ_SGR - A) >> SGRPROJ_RECIP_BITS).
  s4 = vmulq_u32(vreinterpretq_u32_s32(sr4), one_by_n_minus_1_vec);
  s5 = vmulq_u32(vreinterpretq_u32_s32(sr5), one_by_n_minus_1_vec);
  s6 = vmulq_u32(vreinterpretq_u32_s32(sr6), one_by_n_minus_1_vec);
  s7 = vmulq_u32(vreinterpretq_u32_s32(sr7), one_by_n_minus_1_vec);

  s4 = vmulq_u32(s4, p0);
  s5 = vmulq_u32(s5, p1);
  s6 = vmulq_u32(s6, p2);
  s7 = vmulq_u32(s7, p3);

  p0 = vrshrq_n_u32(s4, SGRPROJ_RECIP_BITS);
  p1 = vrshrq_n_u32(s5, SGRPROJ_RECIP_BITS);
  p2 = vrshrq_n_u32(s6, SGRPROJ_RECIP_BITS);
  p3 = vrshrq_n_u32(s7, SGRPROJ_RECIP_BITS);

  store_s32_4x4(src2, buf_stride, vreinterpretq_s32_u32(p0),
                vreinterpretq_s32_u32(p1), vreinterpretq_s32_u32(p2),
                vreinterpretq_s32_u32(p3));
}
// Computes one 8x4 tile of the A (filter weight) and B (offset) intermediate
// surfaces for the self-guided restoration filter.
//   s0..s7:        box sums of squares (caller pre-normalizes these for
//                  high bit depth; plain sums of squares for 8-bit)
//   s16_0..s16_3:  original 16-bit box sums, used for the B computation
//   s16_4..s16_7:  16-bit box sums used for squaring (equal to s16_0..s16_3
//                  in the 8-bit path; bit-depth-normalized in the HBD path)
//   const_n_val:   n = (2r+1)^2, s_vec: SGR strength parameter
//   const_val:     255 clamp for the lookup index
//   one_by_n_minus_1_vec: av1_one_by_x[n - 1], sgrproj_sgr: SGRPROJ_SGR
// Outputs: src1 is scratch for the clamped lookup indices, dst_A16 receives
// the looked-up A values, dst2 receives the final B values (buf_stride
// applies to all three).
static inline void calc_ab_internal_common(
    uint32x4_t s0, uint32x4_t s1, uint32x4_t s2, uint32x4_t s3, uint32x4_t s4,
    uint32x4_t s5, uint32x4_t s6, uint32x4_t s7, uint16x8_t s16_0,
    uint16x8_t s16_1, uint16x8_t s16_2, uint16x8_t s16_3, uint16x8_t s16_4,
    uint16x8_t s16_5, uint16x8_t s16_6, uint16x8_t s16_7,
    uint32x4_t const_n_val, uint32x4_t s_vec, uint32x4_t const_val,
    uint16x4_t one_by_n_minus_1_vec, uint16x8_t sgrproj_sgr, int32_t *src1,
    uint16_t *dst_A16, int32_t *dst2, const int buf_stride) {
  uint16x4_t d0, d1, d2, d3, d4, d5, d6, d7;
  uint32x4_t q0, q1, q2, q3, q4, q5, q6, q7;
  uint32x4_t p0, p1, p2, p3, p4, p5, p6, p7;

  // s = n * sum_of_squares for each 4-lane group.
  s0 = vmulq_u32(s0, const_n_val);
  s1 = vmulq_u32(s1, const_n_val);
  s2 = vmulq_u32(s2, const_n_val);
  s3 = vmulq_u32(s3, const_n_val);
  s4 = vmulq_u32(s4, const_n_val);
  s5 = vmulq_u32(s5, const_n_val);
  s6 = vmulq_u32(s6, const_n_val);
  s7 = vmulq_u32(s7, const_n_val);

  // Split the 8-lane sums into 4-lane halves for widening multiplies.
  d0 = vget_low_u16(s16_4);
  d1 = vget_low_u16(s16_5);
  d2 = vget_low_u16(s16_6);
  d3 = vget_low_u16(s16_7);
  d4 = vget_high_u16(s16_4);
  d5 = vget_high_u16(s16_5);
  d6 = vget_high_u16(s16_6);
  d7 = vget_high_u16(s16_7);

  // q = (sum)^2, widened to 32 bits.
  q0 = vmull_u16(d0, d0);
  q1 = vmull_u16(d1, d1);
  q2 = vmull_u16(d2, d2);
  q3 = vmull_u16(d3, d3);
  q4 = vmull_u16(d4, d4);
  q5 = vmull_u16(d5, d5);
  q6 = vmull_u16(d6, d6);
  q7 = vmull_u16(d7, d7);

  // Lane mask: all-ones where sum^2 <= n * sum_sq, i.e. where the scaled
  // variance below is non-negative.
  p0 = vcleq_u32(q0, s0);
  p1 = vcleq_u32(q1, s1);
  p2 = vcleq_u32(q2, s2);
  p3 = vcleq_u32(q3, s3);
  p4 = vcleq_u32(q4, s4);
  p5 = vcleq_u32(q5, s5);
  p6 = vcleq_u32(q6, s6);
  p7 = vcleq_u32(q7, s7);

  // Scaled variance: n * sum_sq - sum^2 (wraps in lanes the mask rejects).
  q0 = vsubq_u32(s0, q0);
  q1 = vsubq_u32(s1, q1);
  q2 = vsubq_u32(s2, q2);
  q3 = vsubq_u32(s3, q3);
  q4 = vsubq_u32(s4, q4);
  q5 = vsubq_u32(s5, q5);
  q6 = vsubq_u32(s6, q6);
  q7 = vsubq_u32(s7, q7);

  // Zero the lanes where the subtraction wrapped around.
  p0 = vandq_u32(p0, q0);
  p1 = vandq_u32(p1, q1);
  p2 = vandq_u32(p2, q2);
  p3 = vandq_u32(p3, q3);
  p4 = vandq_u32(p4, q4);
  p5 = vandq_u32(p5, q5);
  p6 = vandq_u32(p6, q6);
  p7 = vandq_u32(p7, q7);

  // Scale by the strength parameter and round down to the index range.
  p0 = vmulq_u32(p0, s_vec);
  p1 = vmulq_u32(p1, s_vec);
  p2 = vmulq_u32(p2, s_vec);
  p3 = vmulq_u32(p3, s_vec);
  p4 = vmulq_u32(p4, s_vec);
  p5 = vmulq_u32(p5, s_vec);
  p6 = vmulq_u32(p6, s_vec);
  p7 = vmulq_u32(p7, s_vec);

  p0 = vrshrq_n_u32(p0, SGRPROJ_MTABLE_BITS);
  p1 = vrshrq_n_u32(p1, SGRPROJ_MTABLE_BITS);
  p2 = vrshrq_n_u32(p2, SGRPROJ_MTABLE_BITS);
  p3 = vrshrq_n_u32(p3, SGRPROJ_MTABLE_BITS);
  p4 = vrshrq_n_u32(p4, SGRPROJ_MTABLE_BITS);
  p5 = vrshrq_n_u32(p5, SGRPROJ_MTABLE_BITS);
  p6 = vrshrq_n_u32(p6, SGRPROJ_MTABLE_BITS);
  p7 = vrshrq_n_u32(p7, SGRPROJ_MTABLE_BITS);

  // Clamp the index to the 255 table bound (const_val).
  p0 = vminq_u32(p0, const_val);
  p1 = vminq_u32(p1, const_val);
  p2 = vminq_u32(p2, const_val);
  p3 = vminq_u32(p3, const_val);
  p4 = vminq_u32(p4, const_val);
  p5 = vminq_u32(p5, const_val);
  p6 = vminq_u32(p6, const_val);
  p7 = vminq_u32(p7, const_val);

  {
    // NEON has no gather: spill the indices to the src1 scratch buffer,
    // perform the av1_x_by_xplus1 lookups scalar-wise into dst_A16, then
    // reload the results as vectors (reusing s16_4..s16_7).
    store_u32_4x4((uint32_t *)src1, buf_stride, p0, p1, p2, p3);
    store_u32_4x4((uint32_t *)src1 + 4, buf_stride, p4, p5, p6, p7);

    for (int x = 0; x < 4; x++) {
      for (int y = 0; y < 8; y++) {
        dst_A16[x * buf_stride + y] = av1_x_by_xplus1[src1[x * buf_stride + y]];
      }
    }
    load_u16_8x4(dst_A16, buf_stride, &s16_4, &s16_5, &s16_6, &s16_7);
  }

  // s16 = SGRPROJ_SGR - A.
  s16_4 = vsubq_u16(sgrproj_sgr, s16_4);
  s16_5 = vsubq_u16(sgrproj_sgr, s16_5);
  s16_6 = vsubq_u16(sgrproj_sgr, s16_6);
  s16_7 = vsubq_u16(sgrproj_sgr, s16_7);

  // B = round(sum * (1/n) * (SGRPROJ_SGR - A) >> SGRPROJ_RECIP_BITS),
  // computed on the original (unnormalized) sums s16_0..s16_3.
  s0 = vmull_u16(vget_low_u16(s16_0), one_by_n_minus_1_vec);
  s1 = vmull_u16(vget_low_u16(s16_1), one_by_n_minus_1_vec);
  s2 = vmull_u16(vget_low_u16(s16_2), one_by_n_minus_1_vec);
  s3 = vmull_u16(vget_low_u16(s16_3), one_by_n_minus_1_vec);
  s4 = vmull_u16(vget_high_u16(s16_0), one_by_n_minus_1_vec);
  s5 = vmull_u16(vget_high_u16(s16_1), one_by_n_minus_1_vec);
  s6 = vmull_u16(vget_high_u16(s16_2), one_by_n_minus_1_vec);
  s7 = vmull_u16(vget_high_u16(s16_3), one_by_n_minus_1_vec);

  s0 = vmulq_u32(s0, vmovl_u16(vget_low_u16(s16_4)));
  s1 = vmulq_u32(s1, vmovl_u16(vget_low_u16(s16_5)));
  s2 = vmulq_u32(s2, vmovl_u16(vget_low_u16(s16_6)));
  s3 = vmulq_u32(s3, vmovl_u16(vget_low_u16(s16_7)));
  s4 = vmulq_u32(s4, vmovl_u16(vget_high_u16(s16_4)));
  s5 = vmulq_u32(s5, vmovl_u16(vget_high_u16(s16_5)));
  s6 = vmulq_u32(s6, vmovl_u16(vget_high_u16(s16_6)));
  s7 = vmulq_u32(s7, vmovl_u16(vget_high_u16(s16_7)));

  p0 = vrshrq_n_u32(s0, SGRPROJ_RECIP_BITS);
  p1 = vrshrq_n_u32(s1, SGRPROJ_RECIP_BITS);
  p2 = vrshrq_n_u32(s2, SGRPROJ_RECIP_BITS);
  p3 = vrshrq_n_u32(s3, SGRPROJ_RECIP_BITS);
  p4 = vrshrq_n_u32(s4, SGRPROJ_RECIP_BITS);
  p5 = vrshrq_n_u32(s5, SGRPROJ_RECIP_BITS);
  p6 = vrshrq_n_u32(s6, SGRPROJ_RECIP_BITS);
  p7 = vrshrq_n_u32(s7, SGRPROJ_RECIP_BITS);

  store_s32_4x4(dst2, buf_stride, vreinterpretq_s32_u32(p0),
                vreinterpretq_s32_u32(p1), vreinterpretq_s32_u32(p2),
                vreinterpretq_s32_u32(p3));
  store_s32_4x4(dst2 + 4, buf_stride, vreinterpretq_s32_u32(p4),
                vreinterpretq_s32_u32(p5), vreinterpretq_s32_u32(p6),
                vreinterpretq_s32_u32(p7));
}
262
// Squares eleven 4-lane inputs and produces four overlapping 5-tap box sums
// of the squares: *r0 = taps 1-5, *r1 = taps 3-7, *r2 = taps 5-9 and
// *r3 = taps 7-11.
static inline void boxsum2_square_sum_calc(
    int16x4_t t1, int16x4_t t2, int16x4_t t3, int16x4_t t4, int16x4_t t5,
    int16x4_t t6, int16x4_t t7, int16x4_t t8, int16x4_t t9, int16x4_t t10,
    int16x4_t t11, int32x4_t *r0, int32x4_t *r1, int32x4_t *r2, int32x4_t *r3) {
  // Widening squares of every tap.
  const int32x4_t sq1 = vmull_s16(t1, t1);
  const int32x4_t sq2 = vmull_s16(t2, t2);
  const int32x4_t sq3 = vmull_s16(t3, t3);
  const int32x4_t sq4 = vmull_s16(t4, t4);
  const int32x4_t sq5 = vmull_s16(t5, t5);
  const int32x4_t sq6 = vmull_s16(t6, t6);
  const int32x4_t sq7 = vmull_s16(t7, t7);
  const int32x4_t sq8 = vmull_s16(t8, t8);
  const int32x4_t sq9 = vmull_s16(t9, t9);
  const int32x4_t sq10 = vmull_s16(t10, t10);
  const int32x4_t sq11 = vmull_s16(t11, t11);

  // Shared partial sums keep the total number of adds low.
  const int32x4_t sum12 = vaddq_s32(sq1, sq2);
  const int32x4_t sum345 = vaddq_s32(vaddq_s32(sq3, sq4), sq5);
  const int32x4_t sum67 = vaddq_s32(sq6, sq7);
  const int32x4_t sum89 = vaddq_s32(sq8, sq9);
  const int32x4_t sum1011 = vaddq_s32(sq10, sq11);

  *r0 = vaddq_s32(sum12, sum345);
  *r1 = vaddq_s32(sum345, sum67);
  *r2 = vaddq_s32(sq5, vaddq_s32(sum67, sum89));
  *r3 = vaddq_s32(vaddq_s32(sq7, sum89), sum1011);
}
296
// Computes 5-tap box sums for the radius-2 self-guided filter using a
// separable two-stage scheme:
//   Stage 1: vertical 5-row sums of pixels (into dst16) and of squared
//            pixels (into dst2), processed in 8-column strips. Results are
//            written at a vertical stride of two rows (dst_stride_2).
//   Stage 2: horizontal 5-tap sums via 4x4 transposes. Pixel sums are read
//            from dst16 and written widened to 32 bits into dst32; squared
//            sums are filtered in place in dst2 (the write trails the read
//            by two columns).
// Between the stages the first two rows and last columns of dst16/dst2 are
// zeroed because stage 2 reads them but stage 1 never fills them.
static inline void boxsum2(int16_t *src, const int src_stride, int16_t *dst16,
                           int32_t *dst32, int32_t *dst2, const int dst_stride,
                           const int width, const int height) {
  assert(width > 2 * SGRPROJ_BORDER_HORZ);
  assert(height > 2 * SGRPROJ_BORDER_VERT);

  int16_t *dst1_16_ptr, *src_ptr;
  int32_t *dst2_ptr;
  int h, w, count = 0;
  const int dst_stride_2 = (dst_stride << 1);
  const int dst_stride_8 = (dst_stride << 3);

  dst1_16_ptr = dst16;
  dst2_ptr = dst2;
  src_ptr = src;
  w = width;
  {
    int16x8_t t1, t2, t3, t4, t5, t6, t7;
    int16x8_t t8, t9, t10, t11, t12;

    int16x8_t q12345, q56789, q34567, q7891011;
    int16x8_t q12, q34, q67, q89, q1011;
    int16x8_t q345, q6789, q789;

    int32x4_t r12345, r56789, r34567, r7891011;

    // Stage 1: one iteration of the outer loop covers an 8-column strip.
    do {
      h = height;
      dst1_16_ptr = dst16 + (count << 3);
      dst2_ptr = dst2 + (count << 3);
      src_ptr = src + (count << 3);

      // Skip the 2-row top border of the destination buffers.
      dst1_16_ptr += dst_stride_2;
      dst2_ptr += dst_stride_2;
      do {
        // Twelve input rows t1..t12 yield four overlapping 5-row sums
        // (rows 1-5, 3-7, 5-9, 7-11).
        load_s16_8x4(src_ptr, src_stride, &t1, &t2, &t3, &t4);
        src_ptr += 4 * src_stride;
        load_s16_8x4(src_ptr, src_stride, &t5, &t6, &t7, &t8);
        src_ptr += 4 * src_stride;
        load_s16_8x4(src_ptr, src_stride, &t9, &t10, &t11, &t12);

        q12 = vaddq_s16(t1, t2);
        q34 = vaddq_s16(t3, t4);
        q67 = vaddq_s16(t6, t7);
        q89 = vaddq_s16(t8, t9);
        q1011 = vaddq_s16(t10, t11);
        q345 = vaddq_s16(q34, t5);
        q6789 = vaddq_s16(q67, q89);
        q789 = vaddq_s16(q89, t7);
        q12345 = vaddq_s16(q12, q345);
        q34567 = vaddq_s16(q67, q345);
        q56789 = vaddq_s16(t5, q6789);
        q7891011 = vaddq_s16(q789, q1011);

        store_s16_8x4(dst1_16_ptr, dst_stride_2, q12345, q34567, q56789,
                      q7891011);
        dst1_16_ptr += dst_stride_8;

        // Same four 5-row windows on the squared pixels, low halves first.
        boxsum2_square_sum_calc(
            vget_low_s16(t1), vget_low_s16(t2), vget_low_s16(t3),
            vget_low_s16(t4), vget_low_s16(t5), vget_low_s16(t6),
            vget_low_s16(t7), vget_low_s16(t8), vget_low_s16(t9),
            vget_low_s16(t10), vget_low_s16(t11), &r12345, &r34567, &r56789,
            &r7891011);

        store_s32_4x4(dst2_ptr, dst_stride_2, r12345, r34567, r56789, r7891011);

        // ... then the high halves of the same rows.
        boxsum2_square_sum_calc(
            vget_high_s16(t1), vget_high_s16(t2), vget_high_s16(t3),
            vget_high_s16(t4), vget_high_s16(t5), vget_high_s16(t6),
            vget_high_s16(t7), vget_high_s16(t8), vget_high_s16(t9),
            vget_high_s16(t10), vget_high_s16(t11), &r12345, &r34567, &r56789,
            &r7891011);

        store_s32_4x4(dst2_ptr + 4, dst_stride_2, r12345, r34567, r56789,
                      r7891011);
        dst2_ptr += (dst_stride_8);
        h -= 8;
      } while (h > 0);
      w -= 8;
      count++;
    } while (w > 0);

    // memset needed for row pixels as 2nd stage of boxsum filter uses
    // first 2 rows of dst16, dst2 buffer which is not filled in first stage.
    for (int x = 0; x < 2; x++) {
      memset(dst16 + x * dst_stride, 0, (width + 4) * sizeof(*dst16));
      memset(dst2 + x * dst_stride, 0, (width + 4) * sizeof(*dst2));
    }

    // memset needed for extra columns as 2nd stage of boxsum filter uses
    // last 2 columns of dst16, dst2 buffer which is not filled in first stage.
    for (int x = 2; x < height + 2; x++) {
      int dst_offset = x * dst_stride + width + 2;
      memset(dst16 + dst_offset, 0, 3 * sizeof(*dst16));
      memset(dst2 + dst_offset, 0, 3 * sizeof(*dst2));
    }
  }

  {
    int16x4_t s1, s2, s3, s4, s5, s6, s7, s8;
    int32x4_t d1, d2, d3, d4, d5, d6, d7, d8;
    int32x4_t q12345, q34567, q23456, q45678;
    int32x4_t q23, q45, q67;
    int32x4_t q2345, q4567;

    int32x4_t r12345, r34567, r23456, r45678;
    int32x4_t r23, r45, r67;
    int32x4_t r2345, r4567;

    int32_t *src2_ptr, *dst1_32_ptr;
    int16_t *src1_ptr;
    count = 0;
    h = height;
    // Stage 2: horizontal 5-tap filter, 4 output rows per outer iteration.
    // Columns become lanes via 4x4 transposes; 4 new columns are consumed
    // per inner iteration while the previous 4 are carried over.
    do {
      dst1_32_ptr = dst32 + count * dst_stride_8 + (dst_stride_2);
      dst2_ptr = dst2 + count * dst_stride_8 + (dst_stride_2);
      src1_ptr = dst16 + count * dst_stride_8 + (dst_stride_2);
      src2_ptr = dst2 + count * dst_stride_8 + (dst_stride_2);
      w = width;

      // dst2 is filtered in place: the write pointer trails the read
      // pointer by two columns.
      dst1_32_ptr += 2;
      dst2_ptr += 2;
      load_s16_4x4(src1_ptr, dst_stride_2, &s1, &s2, &s3, &s4);
      transpose_elems_inplace_s16_4x4(&s1, &s2, &s3, &s4);
      load_s32_4x4(src2_ptr, dst_stride_2, &d1, &d2, &d3, &d4);
      transpose_elems_inplace_s32_4x4(&d1, &d2, &d3, &d4);
      do {
        src1_ptr += 4;
        src2_ptr += 4;
        load_s16_4x4(src1_ptr, dst_stride_2, &s5, &s6, &s7, &s8);
        transpose_elems_inplace_s16_4x4(&s5, &s6, &s7, &s8);
        load_s32_4x4(src2_ptr, dst_stride_2, &d5, &d6, &d7, &d8);
        transpose_elems_inplace_s32_4x4(&d5, &d6, &d7, &d8);
        // Four overlapping 5-column sums of the pixel sums (widened).
        q23 = vaddl_s16(s2, s3);
        q45 = vaddl_s16(s4, s5);
        q67 = vaddl_s16(s6, s7);
        q2345 = vaddq_s32(q23, q45);
        q4567 = vaddq_s32(q45, q67);
        q12345 = vaddq_s32(vmovl_s16(s1), q2345);
        q23456 = vaddq_s32(q2345, vmovl_s16(s6));
        q34567 = vaddq_s32(q4567, vmovl_s16(s3));
        q45678 = vaddq_s32(q4567, vmovl_s16(s8));

        // Transpose back to row-major before storing.
        transpose_elems_inplace_s32_4x4(&q12345, &q23456, &q34567, &q45678);
        store_s32_4x4(dst1_32_ptr, dst_stride_2, q12345, q23456, q34567,
                      q45678);
        dst1_32_ptr += 4;
        // Carry the last four columns into the next iteration.
        s1 = s5;
        s2 = s6;
        s3 = s7;
        s4 = s8;

        // Same four 5-column windows on the squared-pixel sums.
        r23 = vaddq_s32(d2, d3);
        r45 = vaddq_s32(d4, d5);
        r67 = vaddq_s32(d6, d7);
        r2345 = vaddq_s32(r23, r45);
        r4567 = vaddq_s32(r45, r67);
        r12345 = vaddq_s32(d1, r2345);
        r23456 = vaddq_s32(r2345, d6);
        r34567 = vaddq_s32(r4567, d3);
        r45678 = vaddq_s32(r4567, d8);

        transpose_elems_inplace_s32_4x4(&r12345, &r23456, &r34567, &r45678);
        store_s32_4x4(dst2_ptr, dst_stride_2, r12345, r23456, r34567, r45678);
        dst2_ptr += 4;
        d1 = d5;
        d2 = d6;
        d3 = d7;
        d4 = d8;
        w -= 4;
      } while (w > 0);
      h -= 8;
      count++;
    } while (h > 0);
  }
}
474
// Low-bit-depth driver for the A/B surface computation: walks the box-sum
// buffers in 8-column x 4-row tiles and delegates the per-tile math to
// calc_ab_internal_common(). A holds sums of squares, B16 holds 16-bit sums;
// A16 and B receive the resulting weight and offset surfaces.
static inline void calc_ab_internal_lbd(int32_t *A, uint16_t *A16,
                                        uint16_t *B16, int32_t *B,
                                        const int buf_stride, const int width,
                                        const int height, const int r,
                                        const int s, const int ht_inc) {
  // n is the pixel count of the (2r+1)x(2r+1) box.
  const uint32_t n = (2 * r + 1) * (2 * r + 1);
  const uint32x4_t n_vec = vdupq_n_u32(n);
  const uint16x8_t sgr_vec = vdupq_n_u16(SGRPROJ_SGR);
  const uint16x4_t recip_n_minus_1 = vdup_n_u16(av1_one_by_x[n - 1]);
  const uint32x4_t max_a = vdupq_n_u32(255);
  const uint32x4_t strength = vdupq_n_u32(s);

  int remaining_rows = height;
  int row_group = 0;

  do {
    const int offset = (row_group << 2) * buf_stride;
    int32_t *a_ptr = A + offset;
    uint16_t *a16_ptr = A16 + offset;
    uint16_t *b16_ptr = B16 + offset;
    int32_t *b_ptr = B + offset;
    int cols = width;
    do {
      uint32x4_t s0, s1, s2, s3, s4, s5, s6, s7;
      uint16x8_t t0, t1, t2, t3;

      load_u32_4x4((uint32_t *)a_ptr, buf_stride, &s0, &s1, &s2, &s3);
      load_u32_4x4((uint32_t *)a_ptr + 4, buf_stride, &s4, &s5, &s6, &s7);
      load_u16_8x4(b16_ptr, buf_stride, &t0, &t1, &t2, &t3);

      // 8-bit path: the sums need no bit-depth normalization, so the
      // "normalized" sum arguments are simply the originals again.
      calc_ab_internal_common(s0, s1, s2, s3, s4, s5, s6, s7, t0, t1, t2, t3,
                              t0, t1, t2, t3, n_vec, strength, max_a,
                              recip_n_minus_1, sgr_vec, a_ptr, a16_ptr, b_ptr,
                              buf_stride);

      cols -= 8;
      b_ptr += 8;
      a_ptr += 8;
      b16_ptr += 8;
      a16_ptr += 8;
    } while (cols > 0);
    row_group++;
    remaining_rows -= (ht_inc * 4);
  } while (remaining_rows > 0);
}
526
527 #if CONFIG_AV1_HIGHBITDEPTH
// High-bit-depth driver for the A/B surface computation: identical tiling to
// calc_ab_internal_lbd(), but the box sums are first normalized back to
// 8-bit scale before the per-tile math in calc_ab_internal_common().
static inline void calc_ab_internal_hbd(int32_t *A, uint16_t *A16,
                                        uint16_t *B16, int32_t *B,
                                        const int buf_stride, const int width,
                                        const int height, const int bit_depth,
                                        const int r, const int s,
                                        const int ht_inc) {
  // n is the pixel count of the (2r+1)x(2r+1) box.
  const uint32_t n = (2 * r + 1) * (2 * r + 1);
  // Negative shift counts make vrshlq perform a rounding right shift:
  // sums shift by (bit_depth - 8), sums of squares by twice that.
  const int16x8_t sum_shift = vdupq_n_s16(-(bit_depth - 8));
  const int32x4_t sq_shift = vdupq_n_s32(-((bit_depth - 8) << 1));
  const uint32x4_t n_vec = vdupq_n_u32(n);
  const uint16x8_t sgr_vec = vdupq_n_u16(SGRPROJ_SGR);
  const uint16x4_t recip_n_minus_1 = vdup_n_u16(av1_one_by_x[n - 1]);
  const uint32x4_t max_a = vdupq_n_u32(255);
  const uint32x4_t strength = vdupq_n_u32(s);

  int remaining_rows = height;
  int row_group = 0;

  do {
    const int offset = (row_group << 2) * buf_stride;
    int32_t *a_ptr = A + offset;
    uint16_t *b16_ptr = B16 + offset;
    int32_t *b_ptr = B + offset;
    uint16_t *a16_ptr = A16 + offset;
    int cols = width;
    do {
      int32x4_t sr0, sr1, sr2, sr3, sr4, sr5, sr6, sr7;
      uint16x8_t t0, t1, t2, t3;

      load_s32_4x4(a_ptr, buf_stride, &sr0, &sr1, &sr2, &sr3);
      load_s32_4x4(a_ptr + 4, buf_stride, &sr4, &sr5, &sr6, &sr7);
      load_u16_8x4(b16_ptr, buf_stride, &t0, &t1, &t2, &t3);

      // Normalize the sums of squares to 8-bit scale.
      const uint32x4_t s0 = vrshlq_u32(vreinterpretq_u32_s32(sr0), sq_shift);
      const uint32x4_t s1 = vrshlq_u32(vreinterpretq_u32_s32(sr1), sq_shift);
      const uint32x4_t s2 = vrshlq_u32(vreinterpretq_u32_s32(sr2), sq_shift);
      const uint32x4_t s3 = vrshlq_u32(vreinterpretq_u32_s32(sr3), sq_shift);
      const uint32x4_t s4 = vrshlq_u32(vreinterpretq_u32_s32(sr4), sq_shift);
      const uint32x4_t s5 = vrshlq_u32(vreinterpretq_u32_s32(sr5), sq_shift);
      const uint32x4_t s6 = vrshlq_u32(vreinterpretq_u32_s32(sr6), sq_shift);
      const uint32x4_t s7 = vrshlq_u32(vreinterpretq_u32_s32(sr7), sq_shift);

      // Normalize the plain sums to 8-bit scale.
      const uint16x8_t t4 = vrshlq_u16(t0, sum_shift);
      const uint16x8_t t5 = vrshlq_u16(t1, sum_shift);
      const uint16x8_t t6 = vrshlq_u16(t2, sum_shift);
      const uint16x8_t t7 = vrshlq_u16(t3, sum_shift);

      calc_ab_internal_common(s0, s1, s2, s3, s4, s5, s6, s7, t0, t1, t2, t3,
                              t4, t5, t6, t7, n_vec, strength, max_a,
                              recip_n_minus_1, sgr_vec, a_ptr, a16_ptr, b_ptr,
                              buf_stride);

      cols -= 8;
      b_ptr += 8;
      a_ptr += 8;
      b16_ptr += 8;
      a16_ptr += 8;
    } while (cols > 0);
    row_group++;
    remaining_rows -= (ht_inc * 4);
  } while (remaining_rows > 0);
}
592 #endif // CONFIG_AV1_HIGHBITDEPTH
593
// Low-bit-depth driver for the fast-pass A/B surface computation: walks the
// box-sum buffers in 4-column x 4-row tiles and delegates the per-tile math
// to calc_ab_fast_internal_common(). A holds sums of squares, B holds sums;
// A16 and B receive the resulting weight and offset surfaces (B in place).
static inline void calc_ab_fast_internal_lbd(int32_t *A, uint16_t *A16,
                                             int32_t *B, const int buf_stride,
                                             const int width, const int height,
                                             const int r, const int s,
                                             const int ht_inc) {
  // n is the pixel count of the (2r+1)x(2r+1) box.
  const uint32_t n = (2 * r + 1) * (2 * r + 1);
  const uint32x4_t n_vec = vdupq_n_u32(n);
  const uint16x4_t sgr_vec = vdup_n_u16(SGRPROJ_SGR);
  const uint32x4_t recip_n_minus_1 = vdupq_n_u32(av1_one_by_x[n - 1]);
  const uint32x4_t max_a = vdupq_n_u32(255);
  const uint32x4_t strength = vdupq_n_u32(s);

  int remaining_rows = height;
  int row_group = 0;

  do {
    const int offset = (row_group << 2) * buf_stride;
    int32_t *a_ptr = A + offset;
    int32_t *b_ptr = B + offset;
    uint16_t *a16_ptr = A16 + offset;
    int cols = width;
    do {
      int32x4_t sr0, sr1, sr2, sr3, sr4, sr5, sr6, sr7;

      load_s32_4x4(a_ptr, buf_stride, &sr0, &sr1, &sr2, &sr3);
      load_s32_4x4(b_ptr, buf_stride, &sr4, &sr5, &sr6, &sr7);

      // 8-bit path: no bit-depth normalization needed, just view the
      // signed sums as unsigned lanes.
      calc_ab_fast_internal_common(
          vreinterpretq_u32_s32(sr0), vreinterpretq_u32_s32(sr1),
          vreinterpretq_u32_s32(sr2), vreinterpretq_u32_s32(sr3),
          vreinterpretq_u32_s32(sr4), vreinterpretq_u32_s32(sr5),
          vreinterpretq_u32_s32(sr6), vreinterpretq_u32_s32(sr7), sr4, sr5,
          sr6, sr7, n_vec, strength, max_a, recip_n_minus_1, sgr_vec, a_ptr,
          a16_ptr, b_ptr, buf_stride);

      cols -= 4;
      a_ptr += 4;
      b_ptr += 4;
      a16_ptr += 4;
    } while (cols > 0);
    row_group++;
    remaining_rows -= (ht_inc * 4);
  } while (remaining_rows > 0);
}
645
646 #if CONFIG_AV1_HIGHBITDEPTH
// High-bit-depth driver for the fast-pass A/B surface computation: identical
// tiling to calc_ab_fast_internal_lbd(), but the box sums are normalized
// back to 8-bit scale before the per-tile math in
// calc_ab_fast_internal_common().
static inline void calc_ab_fast_internal_hbd(int32_t *A, uint16_t *A16,
                                             int32_t *B, const int buf_stride,
                                             const int width, const int height,
                                             const int bit_depth, const int r,
                                             const int s, const int ht_inc) {
  // n is the pixel count of the (2r+1)x(2r+1) box.
  const uint32_t n = (2 * r + 1) * (2 * r + 1);
  // Negative shift counts make vrshlq perform a rounding right shift:
  // sums shift by (bit_depth - 8), sums of squares by twice that.
  const int32x4_t sum_shift = vdupq_n_s32(-(bit_depth - 8));
  const int32x4_t sq_shift = vdupq_n_s32(-((bit_depth - 8) << 1));
  const uint32x4_t n_vec = vdupq_n_u32(n);
  const uint16x4_t sgr_vec = vdup_n_u16(SGRPROJ_SGR);
  const uint32x4_t recip_n_minus_1 = vdupq_n_u32(av1_one_by_x[n - 1]);
  const uint32x4_t max_a = vdupq_n_u32(255);
  const uint32x4_t strength = vdupq_n_u32(s);

  int remaining_rows = height;
  int row_group = 0;

  do {
    const int offset = (row_group << 2) * buf_stride;
    int32_t *a_ptr = A + offset;
    int32_t *b_ptr = B + offset;
    uint16_t *a16_ptr = A16 + offset;
    int cols = width;
    do {
      int32x4_t sr0, sr1, sr2, sr3, sr4, sr5, sr6, sr7;

      load_s32_4x4(a_ptr, buf_stride, &sr0, &sr1, &sr2, &sr3);
      load_s32_4x4(b_ptr, buf_stride, &sr4, &sr5, &sr6, &sr7);

      // Normalize sums of squares (sr0..sr3) and plain sums (sr4..sr7).
      const uint32x4_t s0 = vrshlq_u32(vreinterpretq_u32_s32(sr0), sq_shift);
      const uint32x4_t s1 = vrshlq_u32(vreinterpretq_u32_s32(sr1), sq_shift);
      const uint32x4_t s2 = vrshlq_u32(vreinterpretq_u32_s32(sr2), sq_shift);
      const uint32x4_t s3 = vrshlq_u32(vreinterpretq_u32_s32(sr3), sq_shift);
      const uint32x4_t s4 = vrshlq_u32(vreinterpretq_u32_s32(sr4), sum_shift);
      const uint32x4_t s5 = vrshlq_u32(vreinterpretq_u32_s32(sr5), sum_shift);
      const uint32x4_t s6 = vrshlq_u32(vreinterpretq_u32_s32(sr6), sum_shift);
      const uint32x4_t s7 = vrshlq_u32(vreinterpretq_u32_s32(sr7), sum_shift);

      calc_ab_fast_internal_common(s0, s1, s2, s3, s4, s5, s6, s7, sr4, sr5,
                                   sr6, sr7, n_vec, strength, max_a,
                                   recip_n_minus_1, sgr_vec, a_ptr, a16_ptr,
                                   b_ptr, buf_stride);

      cols -= 4;
      a_ptr += 4;
      b_ptr += 4;
      a16_ptr += 4;
    } while (cols > 0);
    row_group++;
    remaining_rows -= (ht_inc * 4);
  } while (remaining_rows > 0);
}
700 #endif // CONFIG_AV1_HIGHBITDEPTH
701
// Computes 3-tap box sums for the radius-1 self-guided filter using a
// separable two-stage scheme:
//   Stage 1: vertical 3-row sums of pixels (into dst1) and of squared
//            pixels (into dst2), processed in 8-column strips.
//   Stage 2: horizontal 3-tap sums via 4x4 transposes, filtering both
//            buffers in place (the write trails the read by two columns).
// Between the stages the first two rows and last columns of dst1/dst2 are
// zeroed because stage 2 reads them but stage 1 never fills them.
static inline void boxsum1(int16_t *src, const int src_stride, uint16_t *dst1,
                           int32_t *dst2, const int dst_stride, const int width,
                           const int height) {
  assert(width > 2 * SGRPROJ_BORDER_HORZ);
  assert(height > 2 * SGRPROJ_BORDER_VERT);

  int16_t *src_ptr;
  int32_t *dst2_ptr;
  uint16_t *dst1_ptr;
  int h, w, count = 0;

  w = width;
  {
    int16x8_t s1, s2, s3, s4, s5, s6, s7, s8;
    int16x8_t q23, q34, q56, q234, q345, q456, q567;
    int32x4_t r23, r56, r345, r456, r567, r78, r678;
    int32x4_t r4_low, r4_high, r34_low, r34_high, r234_low, r234_high;
    int32x4_t r2, r3, r5, r6, r7, r8;
    int16x8_t q678, q78;

    // Stage 1: one iteration of the outer loop covers an 8-column strip.
    do {
      dst1_ptr = dst1 + (count << 3);
      dst2_ptr = dst2 + (count << 3);
      src_ptr = src + (count << 3);
      h = height;

      // Prologue: rows 1-4 seed the first 3-row window (rows 2-4).
      load_s16_8x4(src_ptr, src_stride, &s1, &s2, &s3, &s4);
      src_ptr += 4 * src_stride;

      q23 = vaddq_s16(s2, s3);
      q234 = vaddq_s16(q23, s4);
      q34 = vaddq_s16(s3, s4);
      // Skip the 2-row top border of the destination buffers.
      dst1_ptr += (dst_stride << 1);

      // Seed the squared-pixel windows, low halves of the rows first.
      r2 = vmull_s16(vget_low_s16(s2), vget_low_s16(s2));
      r3 = vmull_s16(vget_low_s16(s3), vget_low_s16(s3));
      r4_low = vmull_s16(vget_low_s16(s4), vget_low_s16(s4));
      r23 = vaddq_s32(r2, r3);
      r234_low = vaddq_s32(r23, r4_low);
      r34_low = vaddq_s32(r3, r4_low);

      // ... then the high halves.
      r2 = vmull_s16(vget_high_s16(s2), vget_high_s16(s2));
      r3 = vmull_s16(vget_high_s16(s3), vget_high_s16(s3));
      r4_high = vmull_s16(vget_high_s16(s4), vget_high_s16(s4));
      r23 = vaddq_s32(r2, r3);
      r234_high = vaddq_s32(r23, r4_high);
      r34_high = vaddq_s32(r3, r4_high);

      dst2_ptr += (dst_stride << 1);

      // Steady state: four new rows per iteration produce four 3-row sums
      // (rows 2-4, 3-5, 4-6, 5-7), carrying partial sums across iterations.
      do {
        load_s16_8x4(src_ptr, src_stride, &s5, &s6, &s7, &s8);
        src_ptr += 4 * src_stride;

        q345 = vaddq_s16(s5, q34);
        q56 = vaddq_s16(s5, s6);
        q456 = vaddq_s16(s4, q56);
        q567 = vaddq_s16(s7, q56);
        q78 = vaddq_s16(s7, s8);
        q678 = vaddq_s16(s6, q78);

        store_s16_8x4((int16_t *)dst1_ptr, dst_stride, q234, q345, q456, q567);
        dst1_ptr += (dst_stride << 2);

        // Carry the trailing rows/windows into the next iteration.
        s4 = s8;
        q34 = q78;
        q234 = q678;

        // Same windows on the squared pixels, low halves.
        r5 = vmull_s16(vget_low_s16(s5), vget_low_s16(s5));
        r6 = vmull_s16(vget_low_s16(s6), vget_low_s16(s6));
        r7 = vmull_s16(vget_low_s16(s7), vget_low_s16(s7));
        r8 = vmull_s16(vget_low_s16(s8), vget_low_s16(s8));

        r345 = vaddq_s32(r5, r34_low);
        r56 = vaddq_s32(r5, r6);
        r456 = vaddq_s32(r4_low, r56);
        r567 = vaddq_s32(r7, r56);
        r78 = vaddq_s32(r7, r8);
        r678 = vaddq_s32(r6, r78);
        store_s32_4x4(dst2_ptr, dst_stride, r234_low, r345, r456, r567);

        r4_low = r8;
        r34_low = r78;
        r234_low = r678;

        // Same windows on the squared pixels, high halves.
        r5 = vmull_s16(vget_high_s16(s5), vget_high_s16(s5));
        r6 = vmull_s16(vget_high_s16(s6), vget_high_s16(s6));
        r7 = vmull_s16(vget_high_s16(s7), vget_high_s16(s7));
        r8 = vmull_s16(vget_high_s16(s8), vget_high_s16(s8));

        r345 = vaddq_s32(r5, r34_high);
        r56 = vaddq_s32(r5, r6);
        r456 = vaddq_s32(r4_high, r56);
        r567 = vaddq_s32(r7, r56);
        r78 = vaddq_s32(r7, r8);
        r678 = vaddq_s32(r6, r78);
        store_s32_4x4((dst2_ptr + 4), dst_stride, r234_high, r345, r456, r567);
        dst2_ptr += (dst_stride << 2);

        r4_high = r8;
        r34_high = r78;
        r234_high = r678;

        h -= 4;
      } while (h > 0);
      w -= 8;
      count++;
    } while (w > 0);

    // memset needed for row pixels as 2nd stage of boxsum filter uses
    // first 2 rows of dst1, dst2 buffer which is not filled in first stage.
    for (int x = 0; x < 2; x++) {
      memset(dst1 + x * dst_stride, 0, (width + 4) * sizeof(*dst1));
      memset(dst2 + x * dst_stride, 0, (width + 4) * sizeof(*dst2));
    }

    // memset needed for extra columns as 2nd stage of boxsum filter uses
    // last 2 columns of dst1, dst2 buffer which is not filled in first stage.
    for (int x = 2; x < height + 2; x++) {
      int dst_offset = x * dst_stride + width + 2;
      memset(dst1 + dst_offset, 0, 3 * sizeof(*dst1));
      memset(dst2 + dst_offset, 0, 3 * sizeof(*dst2));
    }
  }

  {
    int16x4_t d1, d2, d3, d4, d5, d6, d7, d8;
    int16x4_t q23, q34, q56, q234, q345, q456, q567;
    int32x4_t r23, r56, r234, r345, r456, r567, r34, r78, r678;
    int32x4_t r1, r2, r3, r4, r5, r6, r7, r8;
    int16x4_t q678, q78;

    int32_t *src2_ptr;
    uint16_t *src1_ptr;
    count = 0;
    h = height;
    w = width;
    // Stage 2: horizontal 3-tap filter, 4 rows per outer iteration.
    // Columns become lanes via 4x4 transposes; 4 new columns are consumed
    // per inner iteration while the previous 4 are carried over. Both
    // buffers are filtered in place (write trails read by two columns).
    do {
      dst1_ptr = dst1 + (count << 2) * dst_stride;
      dst2_ptr = dst2 + (count << 2) * dst_stride;
      src1_ptr = dst1 + (count << 2) * dst_stride;
      src2_ptr = dst2 + (count << 2) * dst_stride;
      w = width;

      // Prologue: columns 1-4 seed the first 3-column window.
      load_s16_4x4((int16_t *)src1_ptr, dst_stride, &d1, &d2, &d3, &d4);
      transpose_elems_inplace_s16_4x4(&d1, &d2, &d3, &d4);
      load_s32_4x4(src2_ptr, dst_stride, &r1, &r2, &r3, &r4);
      transpose_elems_inplace_s32_4x4(&r1, &r2, &r3, &r4);
      src1_ptr += 4;
      src2_ptr += 4;

      q23 = vadd_s16(d2, d3);
      q234 = vadd_s16(q23, d4);
      q34 = vadd_s16(d3, d4);
      dst1_ptr += 2;
      r23 = vaddq_s32(r2, r3);
      r234 = vaddq_s32(r23, r4);
      r34 = vaddq_s32(r3, r4);
      dst2_ptr += 2;

      do {
        load_s16_4x4((int16_t *)src1_ptr, dst_stride, &d5, &d6, &d7, &d8);
        transpose_elems_inplace_s16_4x4(&d5, &d6, &d7, &d8);
        load_s32_4x4(src2_ptr, dst_stride, &r5, &r6, &r7, &r8);
        transpose_elems_inplace_s32_4x4(&r5, &r6, &r7, &r8);
        src1_ptr += 4;
        src2_ptr += 4;

        // Four overlapping 3-column sums of the pixel sums.
        q345 = vadd_s16(d5, q34);
        q56 = vadd_s16(d5, d6);
        q456 = vadd_s16(d4, q56);
        q567 = vadd_s16(d7, q56);
        q78 = vadd_s16(d7, d8);
        q678 = vadd_s16(d6, q78);
        transpose_elems_inplace_s16_4x4(&q234, &q345, &q456, &q567);
        store_s16_4x4((int16_t *)dst1_ptr, dst_stride, q234, q345, q456, q567);
        dst1_ptr += 4;

        // Carry the trailing columns/windows into the next iteration.
        d4 = d8;
        q34 = q78;
        q234 = q678;

        // Same windows on the squared-pixel sums.
        r345 = vaddq_s32(r5, r34);
        r56 = vaddq_s32(r5, r6);
        r456 = vaddq_s32(r4, r56);
        r567 = vaddq_s32(r7, r56);
        r78 = vaddq_s32(r7, r8);
        r678 = vaddq_s32(r6, r78);
        transpose_elems_inplace_s32_4x4(&r234, &r345, &r456, &r567);
        store_s32_4x4(dst2_ptr, dst_stride, r234, r345, r456, r567);
        dst2_ptr += 4;

        r4 = r8;
        r34 = r78;
        r234 = r678;
        w -= 4;
      } while (w > 0);
      h -= 4;
      count++;
    } while (h > 0);
  }
}
904
// Weighted 3x3 cross sum of 32-bit entries centred on buf: the centre and the
// four axis-aligned neighbours are weighted 4, the four diagonal neighbours 3.
// Computed as 4 * (fours + threes) - threes == 4 * fours + 3 * threes.
static inline int32x4_t cross_sum_inp_s32(int32_t *buf, int buf_stride) {
  const int32x4_t top_l = vld1q_s32(buf - buf_stride - 1);
  const int32x4_t top = vld1q_s32(buf - buf_stride);
  const int32x4_t top_r = vld1q_s32(buf - buf_stride + 1);
  const int32x4_t left = vld1q_s32(buf - 1);
  const int32x4_t centre = vld1q_s32(buf);
  const int32x4_t right = vld1q_s32(buf + 1);
  const int32x4_t bot_l = vld1q_s32(buf + buf_stride - 1);
  const int32x4_t bot = vld1q_s32(buf + buf_stride);
  const int32x4_t bot_r = vld1q_s32(buf + buf_stride + 1);

  // Entries weighted 4: centre plus the axis-aligned neighbours.
  const int32x4_t fours = vaddq_s32(
      vaddq_s32(left, right), vaddq_s32(vaddq_s32(top, bot), centre));
  // Entries weighted 3: the diagonal neighbours.
  const int32x4_t threes =
      vaddq_s32(vaddq_s32(top_l, top_r), vaddq_s32(bot_l, bot_r));
  return vsubq_s32(vshlq_n_s32(vaddq_s32(fours, threes), 2), threes);
}
924
// Weighted 3x3 cross sum of 16-bit entries centred on buf, widened to 32 bits.
// Centre and axis-aligned neighbours are weighted 4, diagonal neighbours 3.
// The low four lanes of the result go to *a0, the high four lanes to *a1.
static inline void cross_sum_inp_u16(uint16_t *buf, int buf_stride,
                                     int32x4_t *a0, int32x4_t *a1) {
  const uint16x8_t top_l = vld1q_u16(buf - buf_stride - 1);
  const uint16x8_t top = vld1q_u16(buf - buf_stride);
  const uint16x8_t top_r = vld1q_u16(buf - buf_stride + 1);
  const uint16x8_t left = vld1q_u16(buf - 1);
  const uint16x8_t centre = vld1q_u16(buf);
  const uint16x8_t right = vld1q_u16(buf + 1);
  const uint16x8_t bot_l = vld1q_u16(buf + buf_stride - 1);
  const uint16x8_t bot = vld1q_u16(buf + buf_stride);
  const uint16x8_t bot_r = vld1q_u16(buf + buf_stride + 1);

  // 4 * (centre + axis-aligned neighbours), still in 16 bits.
  const uint16x8_t fours = vaddq_u16(
      vaddq_u16(top, right), vaddq_u16(vaddq_u16(bot, centre), left));
  const uint16x8_t fours4 = vshlq_n_u16(fours, 2);

  // 3 * (diagonal neighbours): 4x - x.
  const uint16x8_t threes =
      vaddq_u16(vaddq_u16(top_l, top_r), vaddq_u16(bot_l, bot_r));
  const uint16x8_t threes3 = vsubq_u16(vshlq_n_u16(threes, 2), threes);

  // Widen each half and combine the two weighted sums.
  *a0 = vreinterpretq_s32_u32(vaddq_u32(vmovl_u16(vget_low_u16(fours4)),
                                        vmovl_u16(vget_low_u16(threes3))));
  *a1 = vreinterpretq_s32_u32(vaddq_u32(vmovl_u16(vget_high_u16(fours4)),
                                        vmovl_u16(vget_high_u16(threes3))));
}
959
// Cross sum for an even row of the fast (radius-2) pass, 32-bit input. Only
// the rows above and below are read: corner entries are weighted 5, the
// entries directly above/below are weighted 6.
// 5 * fives + 6 * sixes == 5 * (fives + sixes) + sixes.
static inline int32x4_t cross_sum_fast_even_row(int32_t *buf, int buf_stride) {
  const int32x4_t top_l = vld1q_s32(buf - buf_stride - 1);
  const int32x4_t top = vld1q_s32(buf - buf_stride);
  const int32x4_t top_r = vld1q_s32(buf - buf_stride + 1);
  const int32x4_t bot_l = vld1q_s32(buf + buf_stride - 1);
  const int32x4_t bot = vld1q_s32(buf + buf_stride);
  const int32x4_t bot_r = vld1q_s32(buf + buf_stride + 1);

  const int32x4_t fives =
      vaddq_s32(vaddq_s32(top_l, top_r), vaddq_s32(bot_l, bot_r));
  const int32x4_t sixes = vaddq_s32(top, bot);
  const int32x4_t both = vaddq_s32(fives, sixes);

  return vaddq_s32(vaddq_s32(vshlq_n_s32(both, 2), both), sixes);
}
978
// Cross sum for an even row of the fast (radius-2) pass, 16-bit input widened
// to 32 bits. Corners are weighted 5, the entries directly above/below 6.
// Low four lanes go to *a0, high four lanes to *a1.
static inline void cross_sum_fast_even_row_inp16(uint16_t *buf, int buf_stride,
                                                 int32x4_t *a0, int32x4_t *a1) {
  const uint16x8_t top_l = vld1q_u16(buf - buf_stride - 1);
  const uint16x8_t top = vld1q_u16(buf - buf_stride);
  const uint16x8_t top_r = vld1q_u16(buf - buf_stride + 1);
  const uint16x8_t bot_l = vld1q_u16(buf + buf_stride - 1);
  const uint16x8_t bot = vld1q_u16(buf + buf_stride);
  const uint16x8_t bot_r = vld1q_u16(buf + buf_stride + 1);

  // 5 * (sum of the four corner entries): 4x + x.
  const uint16x8_t fives =
      vaddq_u16(vaddq_u16(bot_r, bot_l), vaddq_u16(top_r, top_l));
  const uint16x8_t fives5 = vaddq_u16(vshlq_n_u16(fives, 2), fives);

  // 6 * (entries directly above and below): 4x + 2x.
  const uint16x8_t sixes = vaddq_u16(bot, top);
  const uint16x8_t sixes6 =
      vaddq_u16(vshlq_n_u16(sixes, 2), vshlq_n_u16(sixes, 1));

  *a0 = vreinterpretq_s32_u32(vaddq_u32(vmovl_u16(vget_low_u16(fives5)),
                                        vmovl_u16(vget_low_u16(sixes6))));
  *a1 = vreinterpretq_s32_u32(vaddq_u32(vmovl_u16(vget_high_u16(fives5)),
                                        vmovl_u16(vget_high_u16(sixes6))));
}
1006
// Cross sum for an odd row of the fast (radius-2) pass, 32-bit input. Only
// the current row is read: left/right neighbours are weighted 5, the centre 6.
// 5 * (l + r) + 6 * c == 5 * (l + r + c) + c.
static inline int32x4_t cross_sum_fast_odd_row(int32_t *buf) {
  const int32x4_t left = vld1q_s32(buf - 1);
  const int32x4_t centre = vld1q_s32(buf);
  const int32x4_t right = vld1q_s32(buf + 1);

  const int32x4_t both = vaddq_s32(vaddq_s32(left, right), centre);
  return vaddq_s32(vaddq_s32(vshlq_n_s32(both, 2), both), centre);
}
1021
// Cross sum for an odd row of the fast (radius-2) pass, 16-bit input widened
// to 32 bits. Left/right neighbours are weighted 5, the centre 6. Low four
// lanes go to *a0, high four lanes to *a1.
static inline void cross_sum_fast_odd_row_inp16(uint16_t *buf, int32x4_t *a0,
                                                int32x4_t *a1) {
  const uint16x8_t left = vld1q_u16(buf - 1);
  const uint16x8_t centre = vld1q_u16(buf);
  const uint16x8_t right = vld1q_u16(buf + 1);

  // 5 * (left + right): x + 4x.
  const uint16x8_t fives = vaddq_u16(left, right);
  const uint16x8_t fives5 = vaddq_u16(fives, vshlq_n_u16(fives, 2));

  // 6 * centre: 4x + 2x.
  const uint16x8_t sixes6 =
      vaddq_u16(vshlq_n_u16(centre, 2), vshlq_n_u16(centre, 1));

  *a0 = vreinterpretq_s32_u32(vaddq_u32(vmovl_u16(vget_low_u16(fives5)),
                                        vmovl_u16(vget_low_u16(sixes6))));
  *a1 = vreinterpretq_s32_u32(vaddq_u32(vmovl_u16(vget_high_u16(fives5)),
                                        vmovl_u16(vget_high_u16(sixes6))));
}
1043
// Final filter of the fast (radius-2) self-guided pass, where the a/b
// coefficients were computed on every other row. Even rows combine the rows
// above/below with weights 5 (corners) and 6 (directly above/below); odd rows
// use the current row only, weights 5 (left/right) and 6 (centre). For each
// pixel: dst = round((a * src + b) >> shift) with
// shift = SGRPROJ_SGR_BITS + NB_EVEN/NB_ODD - SGRPROJ_RST_BITS.
// A: 16-bit a-coefficients, B: 32-bit b-coefficients, both using buf_stride.
// NOTE(review): both vector loops step 8 pixels per iteration, so the last
// step may touch up to 7 entries past `width` — assumes caller-padded
// buffers; confirm against callers.
static void final_filter_fast_internal(uint16_t *A, int32_t *B,
                                       const int buf_stride, int16_t *src,
                                       const int src_stride, int32_t *dst,
                                       const int dst_stride, const int width,
                                       const int height) {
  int16x8_t s0;
  int32_t *B_tmp, *dst_ptr;
  uint16_t *A_tmp;
  int16_t *src_ptr;
  int32x4_t a_res0, a_res1, b_res0, b_res1;
  int w, h, count = 0;
  assert(SGRPROJ_SGR_BITS == 8);
  assert(SGRPROJ_RST_BITS == 4);

  // Row pointers are recomputed from the row counter at the top of every
  // iteration, so no up-front initialization is needed (the previous initial
  // assignments were dead stores).
  h = height;
  do {
    A_tmp = (A + count * buf_stride);
    B_tmp = (B + count * buf_stride);
    src_ptr = (src + count * src_stride);
    dst_ptr = (dst + count * dst_stride);
    w = width;
    if (!(count & 1)) {
      // Even row: cross sums over the rows above and below.
      do {
        s0 = vld1q_s16(src_ptr);
        cross_sum_fast_even_row_inp16(A_tmp, buf_stride, &a_res0, &a_res1);
        // a * src for the low and high four lanes.
        a_res0 = vmulq_s32(vmovl_s16(vget_low_s16(s0)), a_res0);
        a_res1 = vmulq_s32(vmovl_s16(vget_high_s16(s0)), a_res1);

        // ... + b.
        b_res0 = cross_sum_fast_even_row(B_tmp, buf_stride);
        b_res1 = cross_sum_fast_even_row(B_tmp + 4, buf_stride);
        a_res0 = vaddq_s32(a_res0, b_res0);
        a_res1 = vaddq_s32(a_res1, b_res1);

        // Rounding right shift back to SGRPROJ_RST_BITS precision.
        a_res0 =
            vrshrq_n_s32(a_res0, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);
        a_res1 =
            vrshrq_n_s32(a_res1, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);

        vst1q_s32(dst_ptr, a_res0);
        vst1q_s32(dst_ptr + 4, a_res1);

        A_tmp += 8;
        B_tmp += 8;
        src_ptr += 8;
        dst_ptr += 8;
        w -= 8;
      } while (w > 0);
    } else {
      // Odd row: horizontal 1x3 cross sums on the current row.
      do {
        s0 = vld1q_s16(src_ptr);
        cross_sum_fast_odd_row_inp16(A_tmp, &a_res0, &a_res1);
        a_res0 = vmulq_s32(vmovl_s16(vget_low_s16(s0)), a_res0);
        a_res1 = vmulq_s32(vmovl_s16(vget_high_s16(s0)), a_res1);

        b_res0 = cross_sum_fast_odd_row(B_tmp);
        b_res1 = cross_sum_fast_odd_row(B_tmp + 4);
        a_res0 = vaddq_s32(a_res0, b_res0);
        a_res1 = vaddq_s32(a_res1, b_res1);

        a_res0 =
            vrshrq_n_s32(a_res0, SGRPROJ_SGR_BITS + NB_ODD - SGRPROJ_RST_BITS);
        a_res1 =
            vrshrq_n_s32(a_res1, SGRPROJ_SGR_BITS + NB_ODD - SGRPROJ_RST_BITS);

        vst1q_s32(dst_ptr, a_res0);
        vst1q_s32(dst_ptr + 4, a_res1);

        A_tmp += 8;
        B_tmp += 8;
        src_ptr += 8;
        dst_ptr += 8;
        w -= 8;
      } while (w > 0);
    }
    count++;
    h -= 1;
  } while (h > 0);
}
1126
// Final filter of the radius-1 self-guided pass: every row uses the full 3x3
// cross sum (weights 4 for centre/axis entries, 3 for diagonals). For each
// pixel: dst = round((a * src + b) >> shift) with
// shift = SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS.
// A: 16-bit a-coefficients, B: 32-bit b-coefficients, both using buf_stride.
// NOTE(review): the vector loop steps 8 pixels per iteration, so the last
// step may touch up to 7 entries past `width` — assumes caller-padded
// buffers; confirm against callers.
static void final_filter_internal(uint16_t *A, int32_t *B, const int buf_stride,
                                  int16_t *src, const int src_stride,
                                  int32_t *dst, const int dst_stride,
                                  const int width, const int height) {
  int16x8_t s0;
  int32_t *B_tmp, *dst_ptr;
  uint16_t *A_tmp;
  int16_t *src_ptr;
  int32x4_t a_res0, a_res1, b_res0, b_res1;
  int w, h, count = 0;

  assert(SGRPROJ_SGR_BITS == 8);
  assert(SGRPROJ_RST_BITS == 4);
  h = height;

  do {
    // Row pointers derived from the row counter.
    A_tmp = (A + count * buf_stride);
    B_tmp = (B + count * buf_stride);
    src_ptr = (src + count * src_stride);
    dst_ptr = (dst + count * dst_stride);
    w = width;
    do {
      s0 = vld1q_s16(src_ptr);
      // a * src for the low and high four lanes.
      cross_sum_inp_u16(A_tmp, buf_stride, &a_res0, &a_res1);
      a_res0 = vmulq_s32(vmovl_s16(vget_low_s16(s0)), a_res0);
      a_res1 = vmulq_s32(vmovl_s16(vget_high_s16(s0)), a_res1);

      // ... + b.
      b_res0 = cross_sum_inp_s32(B_tmp, buf_stride);
      b_res1 = cross_sum_inp_s32(B_tmp + 4, buf_stride);
      a_res0 = vaddq_s32(a_res0, b_res0);
      a_res1 = vaddq_s32(a_res1, b_res1);

      // Rounding right shift back to SGRPROJ_RST_BITS precision.
      a_res0 =
          vrshrq_n_s32(a_res0, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);
      a_res1 =
          vrshrq_n_s32(a_res1, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);
      vst1q_s32(dst_ptr, a_res0);
      vst1q_s32(dst_ptr + 4, a_res1);

      A_tmp += 8;
      B_tmp += 8;
      src_ptr += 8;
      dst_ptr += 8;
      w -= 8;
    } while (w > 0);
    count++;
    h -= 1;
  } while (h > 0);
}
1176
// Runs the fast (radius-2) self-guided restoration pass on a 16-bit source
// block: box sums over a 5x5 window, a/b coefficient calculation on every
// other row, then the fast final filter into dst. Returns 0 on success, -1
// if the scratch-buffer allocation fails.
static inline int restoration_fast_internal(uint16_t *dgd16, int width,
                                            int height, int dgd_stride,
                                            int32_t *dst, int dst_stride,
                                            int bit_depth, int sgr_params_idx,
                                            int radius_idx) {
  const sgr_params_type *const params = &av1_sgr_params[sgr_params_idx];
  const int r = params->r[radius_idx];
  const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
  const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
  // Round the stride up to a multiple of 4 and add slack for vector
  // over-reads/writes.
  const int buf_stride = ((width_ext + 3) & ~3) + 16;

  // One aligned allocation partitioned into three scratch planes below.
  const size_t buf_size = 3 * sizeof(int32_t) * RESTORATION_PROC_UNIT_PELS;
  int32_t *buf = aom_memalign(8, buf_size);
  if (!buf) return -1;

  int32_t *square_sum_buf = buf;
  int32_t *sum_buf = square_sum_buf + RESTORATION_PROC_UNIT_PELS;
  uint16_t *tmp16_buf = (uint16_t *)(sum_buf + RESTORATION_PROC_UNIT_PELS);
  assert((char *)(sum_buf + RESTORATION_PROC_UNIT_PELS) <=
             (char *)buf + buf_size &&
         "Allocated buffer is too small. Resize the buffer.");

  assert(r <= MAX_RADIUS && "Need MAX_RADIUS >= r");
  assert(r <= SGRPROJ_BORDER_VERT - 1 && r <= SGRPROJ_BORDER_HORZ - 1 &&
         "Need SGRPROJ_BORDER_* >= r+1");

  // This pass only handles the first (radius-2) filter configuration.
  assert(radius_idx == 0);
  assert(r == 2);

  // input(dgd16) is 16bit.
  // sum of pixels 1st stage output will be in 16bit(tmp16_buf). End output is
  // kept in 32bit [sum_buf]. sum of squares output is kept in 32bit
  // buffer(square_sum_buf).
  boxsum2((int16_t *)(dgd16 - dgd_stride * SGRPROJ_BORDER_VERT -
                      SGRPROJ_BORDER_HORZ),
          dgd_stride, (int16_t *)tmp16_buf, sum_buf, square_sum_buf, buf_stride,
          width_ext, height_ext);

  // Advance past the border region so the pointers address pixel (0, 0).
  square_sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
  sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
  tmp16_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;

  // Calculation of a, b. a output is in 16bit tmp_buf which is in range of
  // [1, 256] for all bit depths. b output is kept in 32bit buffer.
  // The stride is doubled because a/b are only computed on every other row.

#if CONFIG_AV1_HIGHBITDEPTH
  if (bit_depth > 8) {
    calc_ab_fast_internal_hbd(
        (square_sum_buf - buf_stride - 1), (tmp16_buf - buf_stride - 1),
        (sum_buf - buf_stride - 1), buf_stride * 2, width + 2, height + 2,
        bit_depth, r, params->s[radius_idx], 2);
  } else {
    calc_ab_fast_internal_lbd(
        (square_sum_buf - buf_stride - 1), (tmp16_buf - buf_stride - 1),
        (sum_buf - buf_stride - 1), buf_stride * 2, width + 2, height + 2, r,
        params->s[radius_idx], 2);
  }
#else
  (void)bit_depth;
  calc_ab_fast_internal_lbd((square_sum_buf - buf_stride - 1),
                            (tmp16_buf - buf_stride - 1),
                            (sum_buf - buf_stride - 1), buf_stride * 2,
                            width + 2, height + 2, r, params->s[radius_idx], 2);
#endif
  final_filter_fast_internal(tmp16_buf, sum_buf, buf_stride, (int16_t *)dgd16,
                             dgd_stride, dst, dst_stride, width, height);
  aom_free(buf);
  return 0;
}
1246
// Runs the radius-1 self-guided restoration pass on a 16-bit source block:
// box sums over a 3x3 window, a/b coefficient calculation on every row, then
// the full final filter into dst. Returns 0 on success, -1 if the
// scratch-buffer allocation fails.
static inline int restoration_internal(uint16_t *dgd16, int width, int height,
                                       int dgd_stride, int32_t *dst,
                                       int dst_stride, int bit_depth,
                                       int sgr_params_idx, int radius_idx) {
  const sgr_params_type *const params = &av1_sgr_params[sgr_params_idx];
  const int r = params->r[radius_idx];
  const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
  const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
  // Round the stride up to a multiple of 4 and add slack for vector
  // over-reads/writes.
  const int buf_stride = ((width_ext + 3) & ~3) + 16;

  // One aligned allocation partitioned into four scratch planes below
  // (two 32-bit planes plus two 16-bit planes sharing the third third).
  const size_t buf_size = 3 * sizeof(int32_t) * RESTORATION_PROC_UNIT_PELS;
  int32_t *buf = aom_memalign(8, buf_size);
  if (!buf) return -1;

  int32_t *square_sum_buf = buf;
  int32_t *B = square_sum_buf + RESTORATION_PROC_UNIT_PELS;
  uint16_t *A16 = (uint16_t *)(B + RESTORATION_PROC_UNIT_PELS);
  uint16_t *sum_buf = A16 + RESTORATION_PROC_UNIT_PELS;

  assert((char *)(sum_buf + RESTORATION_PROC_UNIT_PELS) <=
             (char *)buf + buf_size &&
         "Allocated buffer is too small. Resize the buffer.");

  assert(r <= MAX_RADIUS && "Need MAX_RADIUS >= r");
  assert(r <= SGRPROJ_BORDER_VERT - 1 && r <= SGRPROJ_BORDER_HORZ - 1 &&
         "Need SGRPROJ_BORDER_* >= r+1");

  // This pass only handles the second (radius-1) filter configuration.
  assert(radius_idx == 1);
  assert(r == 1);

  // input(dgd16) is 16bit.
  // sum of pixels output will be in 16bit(sum_buf).
  // sum of squares output is kept in 32bit buffer(square_sum_buf).
  boxsum1((int16_t *)(dgd16 - dgd_stride * SGRPROJ_BORDER_VERT -
                      SGRPROJ_BORDER_HORZ),
          dgd_stride, sum_buf, square_sum_buf, buf_stride, width_ext,
          height_ext);

  // Advance past the border region so the pointers address pixel (0, 0).
  square_sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
  B += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
  A16 += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
  sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;

#if CONFIG_AV1_HIGHBITDEPTH
  // Calculation of a, b. a output is in 16bit tmp_buf which is in range of
  // [1, 256] for all bit depths. b output is kept in 32bit buffer.
  if (bit_depth > 8) {
    calc_ab_internal_hbd((square_sum_buf - buf_stride - 1),
                         (A16 - buf_stride - 1), (sum_buf - buf_stride - 1),
                         (B - buf_stride - 1), buf_stride, width + 2,
                         height + 2, bit_depth, r, params->s[radius_idx], 1);
  } else {
    calc_ab_internal_lbd((square_sum_buf - buf_stride - 1),
                         (A16 - buf_stride - 1), (sum_buf - buf_stride - 1),
                         (B - buf_stride - 1), buf_stride, width + 2,
                         height + 2, r, params->s[radius_idx], 1);
  }
#else
  (void)bit_depth;
  calc_ab_internal_lbd((square_sum_buf - buf_stride - 1),
                       (A16 - buf_stride - 1), (sum_buf - buf_stride - 1),
                       (B - buf_stride - 1), buf_stride, width + 2, height + 2,
                       r, params->s[radius_idx], 1);
#endif
  final_filter_internal(A16, B, buf_stride, (int16_t *)dgd16, dgd_stride, dst,
                        dst_stride, width, height);
  aom_free(buf);
  return 0;
}
1316
// Widens an 8-bit source block to 16 bits into dst. Rows are processed four
// at a time: a NEON loop handles 8 columns per step, scalar loops handle the
// leftover columns and the final (height % 4) rows. Finally zeroes 5 extra
// rows below the block, which the box-sum filters read.
// NOTE(review): when w == 7 the vector loop still loads 8 bytes per row,
// reading one element past `width` — presumably covered by source padding;
// confirm against callers.
static inline void src_convert_u8_to_u16(const uint8_t *src,
                                         const int src_stride, uint16_t *dst,
                                         const int dst_stride, const int width,
                                         const int height) {
  const uint8_t *src_ptr;
  uint16_t *dst_ptr;
  int h, w, count = 0;

  uint8x8_t t1, t2, t3, t4;
  uint16x8_t s1, s2, s3, s4;
  h = height;
  do {
    // Point at the top of the current 4-row band.
    src_ptr = src + (count << 2) * src_stride;
    dst_ptr = dst + (count << 2) * dst_stride;
    w = width;
    if (w >= 7) {
      do {
        // Widen 8 pixels of each of the four rows.
        load_u8_8x4(src_ptr, src_stride, &t1, &t2, &t3, &t4);
        s1 = vmovl_u8(t1);
        s2 = vmovl_u8(t2);
        s3 = vmovl_u8(t3);
        s4 = vmovl_u8(t4);
        store_u16_8x4(dst_ptr, dst_stride, s1, s2, s3, s4);

        src_ptr += 8;
        dst_ptr += 8;
        w -= 8;
      } while (w > 7);
    }

    // Scalar tail for the remaining columns of the same four rows.
    for (int y = 0; y < w; y++) {
      dst_ptr[y] = src_ptr[y];
      dst_ptr[y + 1 * dst_stride] = src_ptr[y + 1 * src_stride];
      dst_ptr[y + 2 * dst_stride] = src_ptr[y + 2 * src_stride];
      dst_ptr[y + 3 * dst_stride] = src_ptr[y + 3 * src_stride];
    }
    count++;
    h -= 4;
  } while (h > 3);

  // Scalar copy of the final (height % 4) rows.
  src_ptr = src + (count << 2) * src_stride;
  dst_ptr = dst + (count << 2) * dst_stride;
  for (int x = 0; x < h; x++) {
    for (int y = 0; y < width; y++) {
      dst_ptr[y + x * dst_stride] = src_ptr[y + x * src_stride];
    }
  }

  // memset uninitialized rows of src buffer as they are needed for the
  // boxsum filter calculation.
  for (int x = height; x < height + 5; x++)
    memset(dst + x * dst_stride, 0, (width + 2) * sizeof(*dst));
}
1370
1371 #if CONFIG_AV1_HIGHBITDEPTH
// Copies a high-bit-depth (already 16-bit) source block into dst. Rows are
// processed four at a time: a NEON loop copies 8 columns per step, scalar
// loops handle the leftover columns and the final (height % 4) rows. Finally
// zeroes 5 extra rows below the block, which the box-sum filters read.
// NOTE(review): the vector loop runs at least once per band regardless of
// width, so widths < 8 over-read/write — presumably covered by buffer
// padding; confirm against callers.
static inline void src_convert_hbd_copy(const uint16_t *src, int src_stride,
                                        uint16_t *dst, const int dst_stride,
                                        int width, int height) {
  const uint16_t *src_ptr;
  uint16_t *dst_ptr;
  int h, w, count = 0;
  uint16x8_t s1, s2, s3, s4;

  h = height;
  do {
    // Point at the top of the current 4-row band.
    src_ptr = src + (count << 2) * src_stride;
    dst_ptr = dst + (count << 2) * dst_stride;
    w = width;
    do {
      // Copy 8 pixels of each of the four rows.
      load_u16_8x4(src_ptr, src_stride, &s1, &s2, &s3, &s4);
      store_u16_8x4(dst_ptr, dst_stride, s1, s2, s3, s4);
      src_ptr += 8;
      dst_ptr += 8;
      w -= 8;
    } while (w > 7);

    // Scalar tail for the remaining columns of the same four rows.
    for (int y = 0; y < w; y++) {
      dst_ptr[y] = src_ptr[y];
      dst_ptr[y + 1 * dst_stride] = src_ptr[y + 1 * src_stride];
      dst_ptr[y + 2 * dst_stride] = src_ptr[y + 2 * src_stride];
      dst_ptr[y + 3 * dst_stride] = src_ptr[y + 3 * src_stride];
    }
    count++;
    h -= 4;
  } while (h > 3);

  // Scalar copy of the final (height % 4) rows.
  src_ptr = src + (count << 2) * src_stride;
  dst_ptr = dst + (count << 2) * dst_stride;

  for (int x = 0; x < h; x++) {
    memcpy((dst_ptr + x * dst_stride), (src_ptr + x * src_stride),
           sizeof(uint16_t) * width);
  }
  // memset uninitialized rows of src buffer as they are needed for the
  // boxsum filter calculation.
  for (int x = height; x < height + 5; x++)
    memset(dst + x * dst_stride, 0, (width + 2) * sizeof(*dst));
}
1415 #endif // CONFIG_AV1_HIGHBITDEPTH
1416
// NEON implementation of av1_selfguided_restoration: copies/widens the source
// into a bordered 16-bit working buffer, then runs the fast (radius-2) pass
// into flt0 and/or the radius-1 pass into flt1, depending on which radii the
// selected sgr parameter set enables. Returns 0 on success, or the negative
// value propagated from a failed internal pass (scratch allocation failure).
int av1_selfguided_restoration_neon(const uint8_t *dat8, int width, int height,
                                    int stride, int32_t *flt0, int32_t *flt1,
                                    int flt_stride, int sgr_params_idx,
                                    int bit_depth, int highbd) {
  const sgr_params_type *const params = &av1_sgr_params[sgr_params_idx];
  assert(!(params->r[0] == 0 && params->r[1] == 0));

  // Bordered 16-bit working copy of the source; dgd16 points at pixel (0, 0).
  uint16_t dgd16_[RESTORATION_PROC_UNIT_PELS];
  const int dgd16_stride = width + 2 * SGRPROJ_BORDER_HORZ;
  uint16_t *dgd16 =
      dgd16_ + dgd16_stride * SGRPROJ_BORDER_VERT + SGRPROJ_BORDER_HORZ;
  const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
  const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
  const int dgd_stride = stride;

#if CONFIG_AV1_HIGHBITDEPTH
  if (highbd) {
    // High bit depth: the source is already 16-bit; straight copy.
    const uint16_t *dgd16_tmp = CONVERT_TO_SHORTPTR(dat8);
    src_convert_hbd_copy(
        dgd16_tmp - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
        dgd_stride,
        dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
        dgd16_stride, width_ext, height_ext);
  } else {
    // 8-bit source: widen each pixel to 16 bits.
    src_convert_u8_to_u16(
        dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
        dgd_stride,
        dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
        dgd16_stride, width_ext, height_ext);
  }
#else
  (void)highbd;
  src_convert_u8_to_u16(
      dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ, dgd_stride,
      dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
      dgd16_stride, width_ext, height_ext);
#endif

  // Run whichever of the two filter passes the parameter set enables.
  if (params->r[0] > 0) {
    int ret =
        restoration_fast_internal(dgd16, width, height, dgd16_stride, flt0,
                                  flt_stride, bit_depth, sgr_params_idx, 0);
    if (ret != 0) return ret;
  }
  if (params->r[1] > 0) {
    int ret = restoration_internal(dgd16, width, height, dgd16_stride, flt1,
                                   flt_stride, bit_depth, sgr_params_idx, 1);
    if (ret != 0) return ret;
  }
  return 0;
}
1468
// NEON implementation of av1_apply_selfguided_restoration: runs the enabled
// self-guided filter pass(es) into flt0/flt1 (inside tmpbuf), then blends
// them with the source using the decoded projection coefficients xq and
// writes the clamped result to dst8 (16-bit when highbd). Returns 0 on
// success, or the negative value propagated from a failed filter pass.
// NOTE(review): the blend loop steps 8 pixels per iteration, so the last
// step may write up to 7 pixels past `width` in each row — assumes padded
// destination rows; confirm against callers.
int av1_apply_selfguided_restoration_neon(const uint8_t *dat8, int width,
                                          int height, int stride, int eps,
                                          const int *xqd, uint8_t *dst8,
                                          int dst_stride, int32_t *tmpbuf,
                                          int bit_depth, int highbd) {
  // tmpbuf holds both intermediate filter outputs back to back.
  int32_t *flt0 = tmpbuf;
  int32_t *flt1 = flt0 + RESTORATION_UNITPELS_MAX;
  assert(width * height <= RESTORATION_UNITPELS_MAX);
  // Bordered 16-bit working copy of the source; dgd16 points at pixel (0, 0).
  uint16_t dgd16_[RESTORATION_PROC_UNIT_PELS];
  const int dgd16_stride = width + 2 * SGRPROJ_BORDER_HORZ;
  uint16_t *dgd16 =
      dgd16_ + dgd16_stride * SGRPROJ_BORDER_VERT + SGRPROJ_BORDER_HORZ;
  const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
  const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
  const int dgd_stride = stride;
  const sgr_params_type *const params = &av1_sgr_params[eps];
  int xq[2];

  assert(!(params->r[0] == 0 && params->r[1] == 0));

#if CONFIG_AV1_HIGHBITDEPTH
  if (highbd) {
    // High bit depth: the source is already 16-bit; straight copy.
    const uint16_t *dgd16_tmp = CONVERT_TO_SHORTPTR(dat8);
    src_convert_hbd_copy(
        dgd16_tmp - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
        dgd_stride,
        dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
        dgd16_stride, width_ext, height_ext);
  } else {
    // 8-bit source: widen each pixel to 16 bits.
    src_convert_u8_to_u16(
        dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
        dgd_stride,
        dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
        dgd16_stride, width_ext, height_ext);
  }
#else
  (void)highbd;
  src_convert_u8_to_u16(
      dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ, dgd_stride,
      dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
      dgd16_stride, width_ext, height_ext);
#endif
  // Run whichever of the two filter passes the parameter set enables; the
  // intermediate outputs use `width` as their stride.
  if (params->r[0] > 0) {
    int ret = restoration_fast_internal(dgd16, width, height, dgd16_stride,
                                        flt0, width, bit_depth, eps, 0);
    if (ret != 0) return ret;
  }
  if (params->r[1] > 0) {
    int ret = restoration_internal(dgd16, width, height, dgd16_stride, flt1,
                                   width, bit_depth, eps, 1);
    if (ret != 0) return ret;
  }

  // Decode the projection coefficients for the enabled passes.
  av1_decode_xq(xqd, xq, params);

  {
    int16_t *src_ptr;
    uint8_t *dst_ptr;
#if CONFIG_AV1_HIGHBITDEPTH
    uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst8);
    uint16_t *dst16_ptr;
#endif
    int16x4_t d0, d4;
    int16x8_t r0, s0;
    uint16x8_t r4;
    int32x4_t u0, u4, v0, v4, f00, f10;
    uint8x8_t t0;
    int count = 0, w = width, h = height, rc = 0;

    const int32x4_t xq0_vec = vdupq_n_s32(xq[0]);
    const int32x4_t xq1_vec = vdupq_n_s32(xq[1]);
    const int16x8_t zero = vdupq_n_s16(0);
    const uint16x8_t max = vdupq_n_u16((1 << bit_depth) - 1);
    src_ptr = (int16_t *)dgd16;
    do {
      w = width;
      count = 0;
      dst_ptr = dst8 + rc * dst_stride;
#if CONFIG_AV1_HIGHBITDEPTH
      dst16_ptr = dst16 + rc * dst_stride;
#endif
      do {
        s0 = vld1q_s16(src_ptr + count);

        // u = src << SGRPROJ_RST_BITS; v starts as u << SGRPROJ_PRJ_BITS.
        u0 = vshll_n_s16(vget_low_s16(s0), SGRPROJ_RST_BITS);
        u4 = vshll_n_s16(vget_high_s16(s0), SGRPROJ_RST_BITS);

        v0 = vshlq_n_s32(u0, SGRPROJ_PRJ_BITS);
        v4 = vshlq_n_s32(u4, SGRPROJ_PRJ_BITS);

        // Accumulate xq[0] * (flt0 - u) if the fast pass ran.
        if (params->r[0] > 0) {
          f00 = vld1q_s32(flt0 + count);
          f10 = vld1q_s32(flt0 + count + 4);

          f00 = vsubq_s32(f00, u0);
          f10 = vsubq_s32(f10, u4);

          v0 = vmlaq_s32(v0, xq0_vec, f00);
          v4 = vmlaq_s32(v4, xq0_vec, f10);
        }

        // Accumulate xq[1] * (flt1 - u) if the radius-1 pass ran.
        if (params->r[1] > 0) {
          f00 = vld1q_s32(flt1 + count);
          f10 = vld1q_s32(flt1 + count + 4);

          f00 = vsubq_s32(f00, u0);
          f10 = vsubq_s32(f10, u4);

          v0 = vmlaq_s32(v0, xq1_vec, f00);
          v4 = vmlaq_s32(v4, xq1_vec, f10);
        }

        // Saturating rounding shift back to pixel precision.
        d0 = vqrshrn_n_s32(v0, SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);
        d4 = vqrshrn_n_s32(v4, SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);

        r0 = vcombine_s16(d0, d4);

        // Clamp below at 0.
        r4 = vreinterpretq_u16_s16(vmaxq_s16(r0, zero));

#if CONFIG_AV1_HIGHBITDEPTH
        if (highbd) {
          // Clamp above at the bit-depth maximum and store 16-bit pixels.
          r4 = vminq_u16(r4, max);
          vst1q_u16(dst16_ptr, r4);
          dst16_ptr += 8;
        } else {
          // Saturating narrow to 8-bit pixels.
          t0 = vqmovn_u16(r4);
          vst1_u8(dst_ptr, t0);
          dst_ptr += 8;
        }
#else
        (void)max;
        t0 = vqmovn_u16(r4);
        vst1_u8(dst_ptr, t0);
        dst_ptr += 8;
#endif
        w -= 8;
        count += 8;
      } while (w > 0);

      // Next row of source and both intermediate filter outputs.
      src_ptr += dgd16_stride;
      flt1 += width;
      flt0 += width;
      rc++;
      h--;
    } while (h > 0);
  }
  return 0;
}
1617