1 /*
2  * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 #include <assert.h>
14 
15 #include "config/aom_config.h"
16 #include "config/av1_rtcd.h"
17 
18 #include "aom_dsp/aom_dsp_common.h"
19 #include "aom_dsp/txfm_common.h"
20 #include "aom_dsp/arm/mem_neon.h"
21 #include "aom_dsp/arm/transpose_neon.h"
22 #include "aom_mem/aom_mem.h"
23 #include "aom_ports/mem.h"
24 #include "av1/common/av1_common_int.h"
25 #include "av1/common/common.h"
26 #include "av1/common/resize.h"
27 #include "av1/common/restoration.h"
28 
29 // Constants used for right shift in final_filter calculation.
30 #define NB_EVEN 5
31 #define NB_ODD 4
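// NB_EVEN is log2(32), the total weight of the 3x3 cross sums used by the
// full filter and on even rows of the fast filter; NB_ODD is log2(16), the
// total weight used on odd rows of the fast filter (see the cross_sum_*
// helpers below).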
32 
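// Vectorised a/b computation for the fast (radius 2) filter. Per lane it is
// equivalent to the scalar reference, roughly (a sketch, not the exact code):
//   p = max(0, n * sum_sq - sum * sum)               // proportional to variance
//   z = ROUND_POWER_OF_TWO(p * s, SGRPROJ_MTABLE_BITS)
//   a = av1_x_by_xplus1[min(z, 255)]                  // ~= 256 * z / (z + 1)
//   b = ROUND_POWER_OF_TWO((SGRPROJ_SGR - a) * sum * av1_one_by_x[n - 1],
//                          SGRPROJ_RECIP_BITS)
// where sum/sum_sq are the box sums of pixels and squared pixels, n is the
// box size and s the SGR strength parameter.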
33 static inline void calc_ab_fast_internal_common(
34     uint32x4_t s0, uint32x4_t s1, uint32x4_t s2, uint32x4_t s3, uint32x4_t s4,
35     uint32x4_t s5, uint32x4_t s6, uint32x4_t s7, int32x4_t sr4, int32x4_t sr5,
36     int32x4_t sr6, int32x4_t sr7, uint32x4_t const_n_val, uint32x4_t s_vec,
37     uint32x4_t const_val, uint32x4_t one_by_n_minus_1_vec,
38     uint16x4_t sgrproj_sgr, int32_t *src1, uint16_t *dst_A16, int32_t *src2,
39     const int buf_stride) {
40   uint32x4_t q0, q1, q2, q3;
41   uint32x4_t p0, p1, p2, p3;
42   uint16x4_t d0, d1, d2, d3;
43 
44   s0 = vmulq_u32(s0, const_n_val);
45   s1 = vmulq_u32(s1, const_n_val);
46   s2 = vmulq_u32(s2, const_n_val);
47   s3 = vmulq_u32(s3, const_n_val);
48 
49   q0 = vmulq_u32(s4, s4);
50   q1 = vmulq_u32(s5, s5);
51   q2 = vmulq_u32(s6, s6);
52   q3 = vmulq_u32(s7, s7);
53 
54   p0 = vcleq_u32(q0, s0);
55   p1 = vcleq_u32(q1, s1);
56   p2 = vcleq_u32(q2, s2);
57   p3 = vcleq_u32(q3, s3);
58 
59   q0 = vsubq_u32(s0, q0);
60   q1 = vsubq_u32(s1, q1);
61   q2 = vsubq_u32(s2, q2);
62   q3 = vsubq_u32(s3, q3);
63 
64   p0 = vandq_u32(p0, q0);
65   p1 = vandq_u32(p1, q1);
66   p2 = vandq_u32(p2, q2);
67   p3 = vandq_u32(p3, q3);
68 
69   p0 = vmulq_u32(p0, s_vec);
70   p1 = vmulq_u32(p1, s_vec);
71   p2 = vmulq_u32(p2, s_vec);
72   p3 = vmulq_u32(p3, s_vec);
73 
74   p0 = vrshrq_n_u32(p0, SGRPROJ_MTABLE_BITS);
75   p1 = vrshrq_n_u32(p1, SGRPROJ_MTABLE_BITS);
76   p2 = vrshrq_n_u32(p2, SGRPROJ_MTABLE_BITS);
77   p3 = vrshrq_n_u32(p3, SGRPROJ_MTABLE_BITS);
78 
79   p0 = vminq_u32(p0, const_val);
80   p1 = vminq_u32(p1, const_val);
81   p2 = vminq_u32(p2, const_val);
82   p3 = vminq_u32(p3, const_val);
83 
84   {
85     store_u32_4x4((uint32_t *)src1, buf_stride, p0, p1, p2, p3);
86 
87     for (int x = 0; x < 4; x++) {
88       for (int y = 0; y < 4; y++) {
89         dst_A16[x * buf_stride + y] = av1_x_by_xplus1[src1[x * buf_stride + y]];
90       }
91     }
92     load_u16_4x4(dst_A16, buf_stride, &d0, &d1, &d2, &d3);
93   }
94   p0 = vsubl_u16(sgrproj_sgr, d0);
95   p1 = vsubl_u16(sgrproj_sgr, d1);
96   p2 = vsubl_u16(sgrproj_sgr, d2);
97   p3 = vsubl_u16(sgrproj_sgr, d3);
98 
99   s4 = vmulq_u32(vreinterpretq_u32_s32(sr4), one_by_n_minus_1_vec);
100   s5 = vmulq_u32(vreinterpretq_u32_s32(sr5), one_by_n_minus_1_vec);
101   s6 = vmulq_u32(vreinterpretq_u32_s32(sr6), one_by_n_minus_1_vec);
102   s7 = vmulq_u32(vreinterpretq_u32_s32(sr7), one_by_n_minus_1_vec);
103 
104   s4 = vmulq_u32(s4, p0);
105   s5 = vmulq_u32(s5, p1);
106   s6 = vmulq_u32(s6, p2);
107   s7 = vmulq_u32(s7, p3);
108 
109   p0 = vrshrq_n_u32(s4, SGRPROJ_RECIP_BITS);
110   p1 = vrshrq_n_u32(s5, SGRPROJ_RECIP_BITS);
111   p2 = vrshrq_n_u32(s6, SGRPROJ_RECIP_BITS);
112   p3 = vrshrq_n_u32(s7, SGRPROJ_RECIP_BITS);
113 
114   store_s32_4x4(src2, buf_stride, vreinterpretq_s32_u32(p0),
115                 vreinterpretq_s32_u32(p1), vreinterpretq_s32_u32(p2),
116                 vreinterpretq_s32_u32(p3));
117 }
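// Same a/b computation as above, but used by the full (radius 1) filter: it
// processes 8 columns per call and receives the plain box sums as 16-bit
// vectors (s16_0..s16_7).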
118 static inline void calc_ab_internal_common(
119     uint32x4_t s0, uint32x4_t s1, uint32x4_t s2, uint32x4_t s3, uint32x4_t s4,
120     uint32x4_t s5, uint32x4_t s6, uint32x4_t s7, uint16x8_t s16_0,
121     uint16x8_t s16_1, uint16x8_t s16_2, uint16x8_t s16_3, uint16x8_t s16_4,
122     uint16x8_t s16_5, uint16x8_t s16_6, uint16x8_t s16_7,
123     uint32x4_t const_n_val, uint32x4_t s_vec, uint32x4_t const_val,
124     uint16x4_t one_by_n_minus_1_vec, uint16x8_t sgrproj_sgr, int32_t *src1,
125     uint16_t *dst_A16, int32_t *dst2, const int buf_stride) {
126   uint16x4_t d0, d1, d2, d3, d4, d5, d6, d7;
127   uint32x4_t q0, q1, q2, q3, q4, q5, q6, q7;
128   uint32x4_t p0, p1, p2, p3, p4, p5, p6, p7;
129 
130   s0 = vmulq_u32(s0, const_n_val);
131   s1 = vmulq_u32(s1, const_n_val);
132   s2 = vmulq_u32(s2, const_n_val);
133   s3 = vmulq_u32(s3, const_n_val);
134   s4 = vmulq_u32(s4, const_n_val);
135   s5 = vmulq_u32(s5, const_n_val);
136   s6 = vmulq_u32(s6, const_n_val);
137   s7 = vmulq_u32(s7, const_n_val);
138 
139   d0 = vget_low_u16(s16_4);
140   d1 = vget_low_u16(s16_5);
141   d2 = vget_low_u16(s16_6);
142   d3 = vget_low_u16(s16_7);
143   d4 = vget_high_u16(s16_4);
144   d5 = vget_high_u16(s16_5);
145   d6 = vget_high_u16(s16_6);
146   d7 = vget_high_u16(s16_7);
147 
148   q0 = vmull_u16(d0, d0);
149   q1 = vmull_u16(d1, d1);
150   q2 = vmull_u16(d2, d2);
151   q3 = vmull_u16(d3, d3);
152   q4 = vmull_u16(d4, d4);
153   q5 = vmull_u16(d5, d5);
154   q6 = vmull_u16(d6, d6);
155   q7 = vmull_u16(d7, d7);
156 
157   p0 = vcleq_u32(q0, s0);
158   p1 = vcleq_u32(q1, s1);
159   p2 = vcleq_u32(q2, s2);
160   p3 = vcleq_u32(q3, s3);
161   p4 = vcleq_u32(q4, s4);
162   p5 = vcleq_u32(q5, s5);
163   p6 = vcleq_u32(q6, s6);
164   p7 = vcleq_u32(q7, s7);
165 
166   q0 = vsubq_u32(s0, q0);
167   q1 = vsubq_u32(s1, q1);
168   q2 = vsubq_u32(s2, q2);
169   q3 = vsubq_u32(s3, q3);
170   q4 = vsubq_u32(s4, q4);
171   q5 = vsubq_u32(s5, q5);
172   q6 = vsubq_u32(s6, q6);
173   q7 = vsubq_u32(s7, q7);
174 
175   p0 = vandq_u32(p0, q0);
176   p1 = vandq_u32(p1, q1);
177   p2 = vandq_u32(p2, q2);
178   p3 = vandq_u32(p3, q3);
179   p4 = vandq_u32(p4, q4);
180   p5 = vandq_u32(p5, q5);
181   p6 = vandq_u32(p6, q6);
182   p7 = vandq_u32(p7, q7);
183 
184   p0 = vmulq_u32(p0, s_vec);
185   p1 = vmulq_u32(p1, s_vec);
186   p2 = vmulq_u32(p2, s_vec);
187   p3 = vmulq_u32(p3, s_vec);
188   p4 = vmulq_u32(p4, s_vec);
189   p5 = vmulq_u32(p5, s_vec);
190   p6 = vmulq_u32(p6, s_vec);
191   p7 = vmulq_u32(p7, s_vec);
192 
193   p0 = vrshrq_n_u32(p0, SGRPROJ_MTABLE_BITS);
194   p1 = vrshrq_n_u32(p1, SGRPROJ_MTABLE_BITS);
195   p2 = vrshrq_n_u32(p2, SGRPROJ_MTABLE_BITS);
196   p3 = vrshrq_n_u32(p3, SGRPROJ_MTABLE_BITS);
197   p4 = vrshrq_n_u32(p4, SGRPROJ_MTABLE_BITS);
198   p5 = vrshrq_n_u32(p5, SGRPROJ_MTABLE_BITS);
199   p6 = vrshrq_n_u32(p6, SGRPROJ_MTABLE_BITS);
200   p7 = vrshrq_n_u32(p7, SGRPROJ_MTABLE_BITS);
201 
202   p0 = vminq_u32(p0, const_val);
203   p1 = vminq_u32(p1, const_val);
204   p2 = vminq_u32(p2, const_val);
205   p3 = vminq_u32(p3, const_val);
206   p4 = vminq_u32(p4, const_val);
207   p5 = vminq_u32(p5, const_val);
208   p6 = vminq_u32(p6, const_val);
209   p7 = vminq_u32(p7, const_val);
210 
211   {
212     store_u32_4x4((uint32_t *)src1, buf_stride, p0, p1, p2, p3);
213     store_u32_4x4((uint32_t *)src1 + 4, buf_stride, p4, p5, p6, p7);
214 
215     for (int x = 0; x < 4; x++) {
216       for (int y = 0; y < 8; y++) {
217         dst_A16[x * buf_stride + y] = av1_x_by_xplus1[src1[x * buf_stride + y]];
218       }
219     }
220     load_u16_8x4(dst_A16, buf_stride, &s16_4, &s16_5, &s16_6, &s16_7);
221   }
222 
223   s16_4 = vsubq_u16(sgrproj_sgr, s16_4);
224   s16_5 = vsubq_u16(sgrproj_sgr, s16_5);
225   s16_6 = vsubq_u16(sgrproj_sgr, s16_6);
226   s16_7 = vsubq_u16(sgrproj_sgr, s16_7);
227 
228   s0 = vmull_u16(vget_low_u16(s16_0), one_by_n_minus_1_vec);
229   s1 = vmull_u16(vget_low_u16(s16_1), one_by_n_minus_1_vec);
230   s2 = vmull_u16(vget_low_u16(s16_2), one_by_n_minus_1_vec);
231   s3 = vmull_u16(vget_low_u16(s16_3), one_by_n_minus_1_vec);
232   s4 = vmull_u16(vget_high_u16(s16_0), one_by_n_minus_1_vec);
233   s5 = vmull_u16(vget_high_u16(s16_1), one_by_n_minus_1_vec);
234   s6 = vmull_u16(vget_high_u16(s16_2), one_by_n_minus_1_vec);
235   s7 = vmull_u16(vget_high_u16(s16_3), one_by_n_minus_1_vec);
236 
237   s0 = vmulq_u32(s0, vmovl_u16(vget_low_u16(s16_4)));
238   s1 = vmulq_u32(s1, vmovl_u16(vget_low_u16(s16_5)));
239   s2 = vmulq_u32(s2, vmovl_u16(vget_low_u16(s16_6)));
240   s3 = vmulq_u32(s3, vmovl_u16(vget_low_u16(s16_7)));
241   s4 = vmulq_u32(s4, vmovl_u16(vget_high_u16(s16_4)));
242   s5 = vmulq_u32(s5, vmovl_u16(vget_high_u16(s16_5)));
243   s6 = vmulq_u32(s6, vmovl_u16(vget_high_u16(s16_6)));
244   s7 = vmulq_u32(s7, vmovl_u16(vget_high_u16(s16_7)));
245 
246   p0 = vrshrq_n_u32(s0, SGRPROJ_RECIP_BITS);
247   p1 = vrshrq_n_u32(s1, SGRPROJ_RECIP_BITS);
248   p2 = vrshrq_n_u32(s2, SGRPROJ_RECIP_BITS);
249   p3 = vrshrq_n_u32(s3, SGRPROJ_RECIP_BITS);
250   p4 = vrshrq_n_u32(s4, SGRPROJ_RECIP_BITS);
251   p5 = vrshrq_n_u32(s5, SGRPROJ_RECIP_BITS);
252   p6 = vrshrq_n_u32(s6, SGRPROJ_RECIP_BITS);
253   p7 = vrshrq_n_u32(s7, SGRPROJ_RECIP_BITS);
254 
255   store_s32_4x4(dst2, buf_stride, vreinterpretq_s32_u32(p0),
256                 vreinterpretq_s32_u32(p1), vreinterpretq_s32_u32(p2),
257                 vreinterpretq_s32_u32(p3));
258   store_s32_4x4(dst2 + 4, buf_stride, vreinterpretq_s32_u32(p4),
259                 vreinterpretq_s32_u32(p5), vreinterpretq_s32_u32(p6),
260                 vreinterpretq_s32_u32(p7));
261 }
262 
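// Square eleven rows of 16-bit samples and form the vertical 5-tap sums of
// squares for four windows spaced two rows apart:
//   r0 = t1^2 + ... + t5^2,   r1 = t3^2 + ... + t7^2,
//   r2 = t5^2 + ... + t9^2,   r3 = t7^2 + ... + t11^2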
263 static inline void boxsum2_square_sum_calc(
264     int16x4_t t1, int16x4_t t2, int16x4_t t3, int16x4_t t4, int16x4_t t5,
265     int16x4_t t6, int16x4_t t7, int16x4_t t8, int16x4_t t9, int16x4_t t10,
266     int16x4_t t11, int32x4_t *r0, int32x4_t *r1, int32x4_t *r2, int32x4_t *r3) {
267   int32x4_t d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11;
268   int32x4_t r12, r34, r67, r89, r1011;
269   int32x4_t r345, r6789, r789;
270 
271   d1 = vmull_s16(t1, t1);
272   d2 = vmull_s16(t2, t2);
273   d3 = vmull_s16(t3, t3);
274   d4 = vmull_s16(t4, t4);
275   d5 = vmull_s16(t5, t5);
276   d6 = vmull_s16(t6, t6);
277   d7 = vmull_s16(t7, t7);
278   d8 = vmull_s16(t8, t8);
279   d9 = vmull_s16(t9, t9);
280   d10 = vmull_s16(t10, t10);
281   d11 = vmull_s16(t11, t11);
282 
283   r12 = vaddq_s32(d1, d2);
284   r34 = vaddq_s32(d3, d4);
285   r67 = vaddq_s32(d6, d7);
286   r89 = vaddq_s32(d8, d9);
287   r1011 = vaddq_s32(d10, d11);
288   r345 = vaddq_s32(r34, d5);
289   r6789 = vaddq_s32(r67, r89);
290   r789 = vsubq_s32(r6789, d6);
291   *r0 = vaddq_s32(r12, r345);
292   *r1 = vaddq_s32(r67, r345);
293   *r2 = vaddq_s32(d5, r6789);
294   *r3 = vaddq_s32(r789, r1011);
295 }
296 
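// Radius-2 (5x5) box sums of the 16-bit input and of its squares, as needed
// by the fast filter. A first pass forms vertical 5-tap sums for every other
// row (pixels into dst16, squares into dst2); a second pass transposes 4x4
// tiles and forms the horizontal 5-tap sums (pixels into dst32, squares back
// into dst2). Rows and columns that the two passes never write are zeroed so
// that later reads are well defined.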
297 static inline void boxsum2(int16_t *src, const int src_stride, int16_t *dst16,
298                            int32_t *dst32, int32_t *dst2, const int dst_stride,
299                            const int width, const int height) {
300   assert(width > 2 * SGRPROJ_BORDER_HORZ);
301   assert(height > 2 * SGRPROJ_BORDER_VERT);
302 
303   int16_t *dst1_16_ptr, *src_ptr;
304   int32_t *dst2_ptr;
305   int h, w, count = 0;
306   const int dst_stride_2 = (dst_stride << 1);
307   const int dst_stride_8 = (dst_stride << 3);
308 
309   dst1_16_ptr = dst16;
310   dst2_ptr = dst2;
311   src_ptr = src;
312   w = width;
313   {
314     int16x8_t t1, t2, t3, t4, t5, t6, t7;
315     int16x8_t t8, t9, t10, t11, t12;
316 
317     int16x8_t q12345, q56789, q34567, q7891011;
318     int16x8_t q12, q34, q67, q89, q1011;
319     int16x8_t q345, q6789, q789;
320 
321     int32x4_t r12345, r56789, r34567, r7891011;
322 
323     do {
324       h = height;
325       dst1_16_ptr = dst16 + (count << 3);
326       dst2_ptr = dst2 + (count << 3);
327       src_ptr = src + (count << 3);
328 
329       dst1_16_ptr += dst_stride_2;
330       dst2_ptr += dst_stride_2;
331       do {
332         load_s16_8x4(src_ptr, src_stride, &t1, &t2, &t3, &t4);
333         src_ptr += 4 * src_stride;
334         load_s16_8x4(src_ptr, src_stride, &t5, &t6, &t7, &t8);
335         src_ptr += 4 * src_stride;
336         load_s16_8x4(src_ptr, src_stride, &t9, &t10, &t11, &t12);
337 
338         q12 = vaddq_s16(t1, t2);
339         q34 = vaddq_s16(t3, t4);
340         q67 = vaddq_s16(t6, t7);
341         q89 = vaddq_s16(t8, t9);
342         q1011 = vaddq_s16(t10, t11);
343         q345 = vaddq_s16(q34, t5);
344         q6789 = vaddq_s16(q67, q89);
345         q789 = vaddq_s16(q89, t7);
346         q12345 = vaddq_s16(q12, q345);
347         q34567 = vaddq_s16(q67, q345);
348         q56789 = vaddq_s16(t5, q6789);
349         q7891011 = vaddq_s16(q789, q1011);
350 
351         store_s16_8x4(dst1_16_ptr, dst_stride_2, q12345, q34567, q56789,
352                       q7891011);
353         dst1_16_ptr += dst_stride_8;
354 
355         boxsum2_square_sum_calc(
356             vget_low_s16(t1), vget_low_s16(t2), vget_low_s16(t3),
357             vget_low_s16(t4), vget_low_s16(t5), vget_low_s16(t6),
358             vget_low_s16(t7), vget_low_s16(t8), vget_low_s16(t9),
359             vget_low_s16(t10), vget_low_s16(t11), &r12345, &r34567, &r56789,
360             &r7891011);
361 
362         store_s32_4x4(dst2_ptr, dst_stride_2, r12345, r34567, r56789, r7891011);
363 
364         boxsum2_square_sum_calc(
365             vget_high_s16(t1), vget_high_s16(t2), vget_high_s16(t3),
366             vget_high_s16(t4), vget_high_s16(t5), vget_high_s16(t6),
367             vget_high_s16(t7), vget_high_s16(t8), vget_high_s16(t9),
368             vget_high_s16(t10), vget_high_s16(t11), &r12345, &r34567, &r56789,
369             &r7891011);
370 
371         store_s32_4x4(dst2_ptr + 4, dst_stride_2, r12345, r34567, r56789,
372                       r7891011);
373         dst2_ptr += (dst_stride_8);
374         h -= 8;
375       } while (h > 0);
376       w -= 8;
377       count++;
378     } while (w > 0);
379 
380     // The 2nd stage of the boxsum filter reads the first 2 rows of the
381     // dst16 and dst2 buffers, which the 1st stage does not fill, so zero them.
382     for (int x = 0; x < 2; x++) {
383       memset(dst16 + x * dst_stride, 0, (width + 4) * sizeof(*dst16));
384       memset(dst2 + x * dst_stride, 0, (width + 4) * sizeof(*dst2));
385     }
386 
387     // The 2nd stage of the boxsum filter reads the last 2 columns of the
388     // dst16 and dst2 buffers, which the 1st stage does not fill, so zero them.
389     for (int x = 2; x < height + 2; x++) {
390       int dst_offset = x * dst_stride + width + 2;
391       memset(dst16 + dst_offset, 0, 3 * sizeof(*dst16));
392       memset(dst2 + dst_offset, 0, 3 * sizeof(*dst2));
393     }
394   }
395 
396   {
397     int16x4_t s1, s2, s3, s4, s5, s6, s7, s8;
398     int32x4_t d1, d2, d3, d4, d5, d6, d7, d8;
399     int32x4_t q12345, q34567, q23456, q45678;
400     int32x4_t q23, q45, q67;
401     int32x4_t q2345, q4567;
402 
403     int32x4_t r12345, r34567, r23456, r45678;
404     int32x4_t r23, r45, r67;
405     int32x4_t r2345, r4567;
406 
407     int32_t *src2_ptr, *dst1_32_ptr;
408     int16_t *src1_ptr;
409     count = 0;
410     h = height;
411     do {
412       dst1_32_ptr = dst32 + count * dst_stride_8 + (dst_stride_2);
413       dst2_ptr = dst2 + count * dst_stride_8 + (dst_stride_2);
414       src1_ptr = dst16 + count * dst_stride_8 + (dst_stride_2);
415       src2_ptr = dst2 + count * dst_stride_8 + (dst_stride_2);
416       w = width;
417 
418       dst1_32_ptr += 2;
419       dst2_ptr += 2;
420       load_s16_4x4(src1_ptr, dst_stride_2, &s1, &s2, &s3, &s4);
421       transpose_elems_inplace_s16_4x4(&s1, &s2, &s3, &s4);
422       load_s32_4x4(src2_ptr, dst_stride_2, &d1, &d2, &d3, &d4);
423       transpose_elems_inplace_s32_4x4(&d1, &d2, &d3, &d4);
424       do {
425         src1_ptr += 4;
426         src2_ptr += 4;
427         load_s16_4x4(src1_ptr, dst_stride_2, &s5, &s6, &s7, &s8);
428         transpose_elems_inplace_s16_4x4(&s5, &s6, &s7, &s8);
429         load_s32_4x4(src2_ptr, dst_stride_2, &d5, &d6, &d7, &d8);
430         transpose_elems_inplace_s32_4x4(&d5, &d6, &d7, &d8);
431         q23 = vaddl_s16(s2, s3);
432         q45 = vaddl_s16(s4, s5);
433         q67 = vaddl_s16(s6, s7);
434         q2345 = vaddq_s32(q23, q45);
435         q4567 = vaddq_s32(q45, q67);
436         q12345 = vaddq_s32(vmovl_s16(s1), q2345);
437         q23456 = vaddq_s32(q2345, vmovl_s16(s6));
438         q34567 = vaddq_s32(q4567, vmovl_s16(s3));
439         q45678 = vaddq_s32(q4567, vmovl_s16(s8));
440 
441         transpose_elems_inplace_s32_4x4(&q12345, &q23456, &q34567, &q45678);
442         store_s32_4x4(dst1_32_ptr, dst_stride_2, q12345, q23456, q34567,
443                       q45678);
444         dst1_32_ptr += 4;
445         s1 = s5;
446         s2 = s6;
447         s3 = s7;
448         s4 = s8;
449 
450         r23 = vaddq_s32(d2, d3);
451         r45 = vaddq_s32(d4, d5);
452         r67 = vaddq_s32(d6, d7);
453         r2345 = vaddq_s32(r23, r45);
454         r4567 = vaddq_s32(r45, r67);
455         r12345 = vaddq_s32(d1, r2345);
456         r23456 = vaddq_s32(r2345, d6);
457         r34567 = vaddq_s32(r4567, d3);
458         r45678 = vaddq_s32(r4567, d8);
459 
460         transpose_elems_inplace_s32_4x4(&r12345, &r23456, &r34567, &r45678);
461         store_s32_4x4(dst2_ptr, dst_stride_2, r12345, r23456, r34567, r45678);
462         dst2_ptr += 4;
463         d1 = d5;
464         d2 = d6;
465         d3 = d7;
466         d4 = d8;
467         w -= 4;
468       } while (w > 0);
469       h -= 8;
470       count++;
471     } while (h > 0);
472   }
473 }
474 
475 static inline void calc_ab_internal_lbd(int32_t *A, uint16_t *A16,
476                                         uint16_t *B16, int32_t *B,
477                                         const int buf_stride, const int width,
478                                         const int height, const int r,
479                                         const int s, const int ht_inc) {
480   int32_t *src1, *dst2, count = 0;
481   uint16_t *dst_A16, *src2;
482   const uint32_t n = (2 * r + 1) * (2 * r + 1);
483   const uint32x4_t const_n_val = vdupq_n_u32(n);
484   const uint16x8_t sgrproj_sgr = vdupq_n_u16(SGRPROJ_SGR);
485   const uint16x4_t one_by_n_minus_1_vec = vdup_n_u16(av1_one_by_x[n - 1]);
486   const uint32x4_t const_val = vdupq_n_u32(255);
487 
488   uint16x8_t s16_0, s16_1, s16_2, s16_3, s16_4, s16_5, s16_6, s16_7;
489 
490   uint32x4_t s0, s1, s2, s3, s4, s5, s6, s7;
491 
492   const uint32x4_t s_vec = vdupq_n_u32(s);
493   int w, h = height;
494 
495   do {
496     dst_A16 = A16 + (count << 2) * buf_stride;
497     src1 = A + (count << 2) * buf_stride;
498     src2 = B16 + (count << 2) * buf_stride;
499     dst2 = B + (count << 2) * buf_stride;
500     w = width;
501     do {
502       load_u32_4x4((uint32_t *)src1, buf_stride, &s0, &s1, &s2, &s3);
503       load_u32_4x4((uint32_t *)src1 + 4, buf_stride, &s4, &s5, &s6, &s7);
504       load_u16_8x4(src2, buf_stride, &s16_0, &s16_1, &s16_2, &s16_3);
505 
506       s16_4 = s16_0;
507       s16_5 = s16_1;
508       s16_6 = s16_2;
509       s16_7 = s16_3;
510 
511       calc_ab_internal_common(
512           s0, s1, s2, s3, s4, s5, s6, s7, s16_0, s16_1, s16_2, s16_3, s16_4,
513           s16_5, s16_6, s16_7, const_n_val, s_vec, const_val,
514           one_by_n_minus_1_vec, sgrproj_sgr, src1, dst_A16, dst2, buf_stride);
515 
516       w -= 8;
517       dst2 += 8;
518       src1 += 8;
519       src2 += 8;
520       dst_A16 += 8;
521     } while (w > 0);
522     count++;
523     h -= (ht_inc * 4);
524   } while (h > 0);
525 }
526 
527 #if CONFIG_AV1_HIGHBITDEPTH
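// For bit depths above 8 the box sums are first normalised back to an 8-bit
// scale: the sums of squares are rounding-shifted right by 2 * (bit_depth - 8)
// and the plain sums by (bit_depth - 8), so the a/b arithmetic below matches
// the low-bit-depth path.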
528 static inline void calc_ab_internal_hbd(int32_t *A, uint16_t *A16,
529                                         uint16_t *B16, int32_t *B,
530                                         const int buf_stride, const int width,
531                                         const int height, const int bit_depth,
532                                         const int r, const int s,
533                                         const int ht_inc) {
534   int32_t *src1, *dst2, count = 0;
535   uint16_t *dst_A16, *src2;
536   const uint32_t n = (2 * r + 1) * (2 * r + 1);
537   const int16x8_t bd_min_2_vec = vdupq_n_s16(-(bit_depth - 8));
538   const int32x4_t bd_min_1_vec = vdupq_n_s32(-((bit_depth - 8) << 1));
539   const uint32x4_t const_n_val = vdupq_n_u32(n);
540   const uint16x8_t sgrproj_sgr = vdupq_n_u16(SGRPROJ_SGR);
541   const uint16x4_t one_by_n_minus_1_vec = vdup_n_u16(av1_one_by_x[n - 1]);
542   const uint32x4_t const_val = vdupq_n_u32(255);
543 
544   int32x4_t sr0, sr1, sr2, sr3, sr4, sr5, sr6, sr7;
545   uint16x8_t s16_0, s16_1, s16_2, s16_3;
546   uint16x8_t s16_4, s16_5, s16_6, s16_7;
547   uint32x4_t s0, s1, s2, s3, s4, s5, s6, s7;
548 
549   const uint32x4_t s_vec = vdupq_n_u32(s);
550   int w, h = height;
551 
552   do {
553     src1 = A + (count << 2) * buf_stride;
554     src2 = B16 + (count << 2) * buf_stride;
555     dst2 = B + (count << 2) * buf_stride;
556     dst_A16 = A16 + (count << 2) * buf_stride;
557     w = width;
558     do {
559       load_s32_4x4(src1, buf_stride, &sr0, &sr1, &sr2, &sr3);
560       load_s32_4x4(src1 + 4, buf_stride, &sr4, &sr5, &sr6, &sr7);
561       load_u16_8x4(src2, buf_stride, &s16_0, &s16_1, &s16_2, &s16_3);
562 
563       s0 = vrshlq_u32(vreinterpretq_u32_s32(sr0), bd_min_1_vec);
564       s1 = vrshlq_u32(vreinterpretq_u32_s32(sr1), bd_min_1_vec);
565       s2 = vrshlq_u32(vreinterpretq_u32_s32(sr2), bd_min_1_vec);
566       s3 = vrshlq_u32(vreinterpretq_u32_s32(sr3), bd_min_1_vec);
567       s4 = vrshlq_u32(vreinterpretq_u32_s32(sr4), bd_min_1_vec);
568       s5 = vrshlq_u32(vreinterpretq_u32_s32(sr5), bd_min_1_vec);
569       s6 = vrshlq_u32(vreinterpretq_u32_s32(sr6), bd_min_1_vec);
570       s7 = vrshlq_u32(vreinterpretq_u32_s32(sr7), bd_min_1_vec);
571 
572       s16_4 = vrshlq_u16(s16_0, bd_min_2_vec);
573       s16_5 = vrshlq_u16(s16_1, bd_min_2_vec);
574       s16_6 = vrshlq_u16(s16_2, bd_min_2_vec);
575       s16_7 = vrshlq_u16(s16_3, bd_min_2_vec);
576 
577       calc_ab_internal_common(
578           s0, s1, s2, s3, s4, s5, s6, s7, s16_0, s16_1, s16_2, s16_3, s16_4,
579           s16_5, s16_6, s16_7, const_n_val, s_vec, const_val,
580           one_by_n_minus_1_vec, sgrproj_sgr, src1, dst_A16, dst2, buf_stride);
581 
582       w -= 8;
583       dst2 += 8;
584       src1 += 8;
585       src2 += 8;
586       dst_A16 += 8;
587     } while (w > 0);
588     count++;
589     h -= (ht_inc * 4);
590   } while (h > 0);
591 }
592 #endif  // CONFIG_AV1_HIGHBITDEPTH
593 
594 static inline void calc_ab_fast_internal_lbd(int32_t *A, uint16_t *A16,
595                                              int32_t *B, const int buf_stride,
596                                              const int width, const int height,
597                                              const int r, const int s,
598                                              const int ht_inc) {
599   int32_t *src1, *src2, count = 0;
600   uint16_t *dst_A16;
601   const uint32_t n = (2 * r + 1) * (2 * r + 1);
602   const uint32x4_t const_n_val = vdupq_n_u32(n);
603   const uint16x4_t sgrproj_sgr = vdup_n_u16(SGRPROJ_SGR);
604   const uint32x4_t one_by_n_minus_1_vec = vdupq_n_u32(av1_one_by_x[n - 1]);
605   const uint32x4_t const_val = vdupq_n_u32(255);
606 
607   int32x4_t sr0, sr1, sr2, sr3, sr4, sr5, sr6, sr7;
608   uint32x4_t s0, s1, s2, s3, s4, s5, s6, s7;
609 
610   const uint32x4_t s_vec = vdupq_n_u32(s);
611   int w, h = height;
612 
613   do {
614     src1 = A + (count << 2) * buf_stride;
615     src2 = B + (count << 2) * buf_stride;
616     dst_A16 = A16 + (count << 2) * buf_stride;
617     w = width;
618     do {
619       load_s32_4x4(src1, buf_stride, &sr0, &sr1, &sr2, &sr3);
620       load_s32_4x4(src2, buf_stride, &sr4, &sr5, &sr6, &sr7);
621 
622       s0 = vreinterpretq_u32_s32(sr0);
623       s1 = vreinterpretq_u32_s32(sr1);
624       s2 = vreinterpretq_u32_s32(sr2);
625       s3 = vreinterpretq_u32_s32(sr3);
626       s4 = vreinterpretq_u32_s32(sr4);
627       s5 = vreinterpretq_u32_s32(sr5);
628       s6 = vreinterpretq_u32_s32(sr6);
629       s7 = vreinterpretq_u32_s32(sr7);
630 
631       calc_ab_fast_internal_common(s0, s1, s2, s3, s4, s5, s6, s7, sr4, sr5,
632                                    sr6, sr7, const_n_val, s_vec, const_val,
633                                    one_by_n_minus_1_vec, sgrproj_sgr, src1,
634                                    dst_A16, src2, buf_stride);
635 
636       w -= 4;
637       src1 += 4;
638       src2 += 4;
639       dst_A16 += 4;
640     } while (w > 0);
641     count++;
642     h -= (ht_inc * 4);
643   } while (h > 0);
644 }
645 
646 #if CONFIG_AV1_HIGHBITDEPTH
647 static inline void calc_ab_fast_internal_hbd(int32_t *A, uint16_t *A16,
648                                              int32_t *B, const int buf_stride,
649                                              const int width, const int height,
650                                              const int bit_depth, const int r,
651                                              const int s, const int ht_inc) {
652   int32_t *src1, *src2, count = 0;
653   uint16_t *dst_A16;
654   const uint32_t n = (2 * r + 1) * (2 * r + 1);
655   const int32x4_t bd_min_2_vec = vdupq_n_s32(-(bit_depth - 8));
656   const int32x4_t bd_min_1_vec = vdupq_n_s32(-((bit_depth - 8) << 1));
657   const uint32x4_t const_n_val = vdupq_n_u32(n);
658   const uint16x4_t sgrproj_sgr = vdup_n_u16(SGRPROJ_SGR);
659   const uint32x4_t one_by_n_minus_1_vec = vdupq_n_u32(av1_one_by_x[n - 1]);
660   const uint32x4_t const_val = vdupq_n_u32(255);
661 
662   int32x4_t sr0, sr1, sr2, sr3, sr4, sr5, sr6, sr7;
663   uint32x4_t s0, s1, s2, s3, s4, s5, s6, s7;
664 
665   const uint32x4_t s_vec = vdupq_n_u32(s);
666   int w, h = height;
667 
668   do {
669     src1 = A + (count << 2) * buf_stride;
670     src2 = B + (count << 2) * buf_stride;
671     dst_A16 = A16 + (count << 2) * buf_stride;
672     w = width;
673     do {
674       load_s32_4x4(src1, buf_stride, &sr0, &sr1, &sr2, &sr3);
675       load_s32_4x4(src2, buf_stride, &sr4, &sr5, &sr6, &sr7);
676 
677       s0 = vrshlq_u32(vreinterpretq_u32_s32(sr0), bd_min_1_vec);
678       s1 = vrshlq_u32(vreinterpretq_u32_s32(sr1), bd_min_1_vec);
679       s2 = vrshlq_u32(vreinterpretq_u32_s32(sr2), bd_min_1_vec);
680       s3 = vrshlq_u32(vreinterpretq_u32_s32(sr3), bd_min_1_vec);
681       s4 = vrshlq_u32(vreinterpretq_u32_s32(sr4), bd_min_2_vec);
682       s5 = vrshlq_u32(vreinterpretq_u32_s32(sr5), bd_min_2_vec);
683       s6 = vrshlq_u32(vreinterpretq_u32_s32(sr6), bd_min_2_vec);
684       s7 = vrshlq_u32(vreinterpretq_u32_s32(sr7), bd_min_2_vec);
685 
686       calc_ab_fast_internal_common(s0, s1, s2, s3, s4, s5, s6, s7, sr4, sr5,
687                                    sr6, sr7, const_n_val, s_vec, const_val,
688                                    one_by_n_minus_1_vec, sgrproj_sgr, src1,
689                                    dst_A16, src2, buf_stride);
690 
691       w -= 4;
692       src1 += 4;
693       src2 += 4;
694       dst_A16 += 4;
695     } while (w > 0);
696     count++;
697     h -= (ht_inc * 4);
698   } while (h > 0);
699 }
700 #endif  // CONFIG_AV1_HIGHBITDEPTH
701 
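// Radius-1 (3x3) box sums of the 16-bit input: pixels into dst1, squares into
// dst2. A first pass forms vertical 3-tap sums for every row; a second pass
// transposes 4x4 tiles and forms the horizontal 3-tap sums. Rows and columns
// the two passes never write are zeroed so that later reads are well defined.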
702 static inline void boxsum1(int16_t *src, const int src_stride, uint16_t *dst1,
703                            int32_t *dst2, const int dst_stride, const int width,
704                            const int height) {
705   assert(width > 2 * SGRPROJ_BORDER_HORZ);
706   assert(height > 2 * SGRPROJ_BORDER_VERT);
707 
708   int16_t *src_ptr;
709   int32_t *dst2_ptr;
710   uint16_t *dst1_ptr;
711   int h, w, count = 0;
712 
713   w = width;
714   {
715     int16x8_t s1, s2, s3, s4, s5, s6, s7, s8;
716     int16x8_t q23, q34, q56, q234, q345, q456, q567;
717     int32x4_t r23, r56, r345, r456, r567, r78, r678;
718     int32x4_t r4_low, r4_high, r34_low, r34_high, r234_low, r234_high;
719     int32x4_t r2, r3, r5, r6, r7, r8;
720     int16x8_t q678, q78;
721 
722     do {
723       dst1_ptr = dst1 + (count << 3);
724       dst2_ptr = dst2 + (count << 3);
725       src_ptr = src + (count << 3);
726       h = height;
727 
728       load_s16_8x4(src_ptr, src_stride, &s1, &s2, &s3, &s4);
729       src_ptr += 4 * src_stride;
730 
731       q23 = vaddq_s16(s2, s3);
732       q234 = vaddq_s16(q23, s4);
733       q34 = vaddq_s16(s3, s4);
734       dst1_ptr += (dst_stride << 1);
735 
736       r2 = vmull_s16(vget_low_s16(s2), vget_low_s16(s2));
737       r3 = vmull_s16(vget_low_s16(s3), vget_low_s16(s3));
738       r4_low = vmull_s16(vget_low_s16(s4), vget_low_s16(s4));
739       r23 = vaddq_s32(r2, r3);
740       r234_low = vaddq_s32(r23, r4_low);
741       r34_low = vaddq_s32(r3, r4_low);
742 
743       r2 = vmull_s16(vget_high_s16(s2), vget_high_s16(s2));
744       r3 = vmull_s16(vget_high_s16(s3), vget_high_s16(s3));
745       r4_high = vmull_s16(vget_high_s16(s4), vget_high_s16(s4));
746       r23 = vaddq_s32(r2, r3);
747       r234_high = vaddq_s32(r23, r4_high);
748       r34_high = vaddq_s32(r3, r4_high);
749 
750       dst2_ptr += (dst_stride << 1);
751 
752       do {
753         load_s16_8x4(src_ptr, src_stride, &s5, &s6, &s7, &s8);
754         src_ptr += 4 * src_stride;
755 
756         q345 = vaddq_s16(s5, q34);
757         q56 = vaddq_s16(s5, s6);
758         q456 = vaddq_s16(s4, q56);
759         q567 = vaddq_s16(s7, q56);
760         q78 = vaddq_s16(s7, s8);
761         q678 = vaddq_s16(s6, q78);
762 
763         store_s16_8x4((int16_t *)dst1_ptr, dst_stride, q234, q345, q456, q567);
764         dst1_ptr += (dst_stride << 2);
765 
766         s4 = s8;
767         q34 = q78;
768         q234 = q678;
769 
770         r5 = vmull_s16(vget_low_s16(s5), vget_low_s16(s5));
771         r6 = vmull_s16(vget_low_s16(s6), vget_low_s16(s6));
772         r7 = vmull_s16(vget_low_s16(s7), vget_low_s16(s7));
773         r8 = vmull_s16(vget_low_s16(s8), vget_low_s16(s8));
774 
775         r345 = vaddq_s32(r5, r34_low);
776         r56 = vaddq_s32(r5, r6);
777         r456 = vaddq_s32(r4_low, r56);
778         r567 = vaddq_s32(r7, r56);
779         r78 = vaddq_s32(r7, r8);
780         r678 = vaddq_s32(r6, r78);
781         store_s32_4x4(dst2_ptr, dst_stride, r234_low, r345, r456, r567);
782 
783         r4_low = r8;
784         r34_low = r78;
785         r234_low = r678;
786 
787         r5 = vmull_s16(vget_high_s16(s5), vget_high_s16(s5));
788         r6 = vmull_s16(vget_high_s16(s6), vget_high_s16(s6));
789         r7 = vmull_s16(vget_high_s16(s7), vget_high_s16(s7));
790         r8 = vmull_s16(vget_high_s16(s8), vget_high_s16(s8));
791 
792         r345 = vaddq_s32(r5, r34_high);
793         r56 = vaddq_s32(r5, r6);
794         r456 = vaddq_s32(r4_high, r56);
795         r567 = vaddq_s32(r7, r56);
796         r78 = vaddq_s32(r7, r8);
797         r678 = vaddq_s32(r6, r78);
798         store_s32_4x4((dst2_ptr + 4), dst_stride, r234_high, r345, r456, r567);
799         dst2_ptr += (dst_stride << 2);
800 
801         r4_high = r8;
802         r34_high = r78;
803         r234_high = r678;
804 
805         h -= 4;
806       } while (h > 0);
807       w -= 8;
808       count++;
809     } while (w > 0);
810 
811     // The 2nd stage of the boxsum filter reads the first 2 rows of the
812     // dst1 and dst2 buffers, which the 1st stage does not fill, so zero them.
813     for (int x = 0; x < 2; x++) {
814       memset(dst1 + x * dst_stride, 0, (width + 4) * sizeof(*dst1));
815       memset(dst2 + x * dst_stride, 0, (width + 4) * sizeof(*dst2));
816     }
817 
818     // The 2nd stage of the boxsum filter reads the last 2 columns of the
819     // dst1 and dst2 buffers, which the 1st stage does not fill, so zero them.
820     for (int x = 2; x < height + 2; x++) {
821       int dst_offset = x * dst_stride + width + 2;
822       memset(dst1 + dst_offset, 0, 3 * sizeof(*dst1));
823       memset(dst2 + dst_offset, 0, 3 * sizeof(*dst2));
824     }
825   }
826 
827   {
828     int16x4_t d1, d2, d3, d4, d5, d6, d7, d8;
829     int16x4_t q23, q34, q56, q234, q345, q456, q567;
830     int32x4_t r23, r56, r234, r345, r456, r567, r34, r78, r678;
831     int32x4_t r1, r2, r3, r4, r5, r6, r7, r8;
832     int16x4_t q678, q78;
833 
834     int32_t *src2_ptr;
835     uint16_t *src1_ptr;
836     count = 0;
837     h = height;
838     w = width;
839     do {
840       dst1_ptr = dst1 + (count << 2) * dst_stride;
841       dst2_ptr = dst2 + (count << 2) * dst_stride;
842       src1_ptr = dst1 + (count << 2) * dst_stride;
843       src2_ptr = dst2 + (count << 2) * dst_stride;
844       w = width;
845 
846       load_s16_4x4((int16_t *)src1_ptr, dst_stride, &d1, &d2, &d3, &d4);
847       transpose_elems_inplace_s16_4x4(&d1, &d2, &d3, &d4);
848       load_s32_4x4(src2_ptr, dst_stride, &r1, &r2, &r3, &r4);
849       transpose_elems_inplace_s32_4x4(&r1, &r2, &r3, &r4);
850       src1_ptr += 4;
851       src2_ptr += 4;
852 
853       q23 = vadd_s16(d2, d3);
854       q234 = vadd_s16(q23, d4);
855       q34 = vadd_s16(d3, d4);
856       dst1_ptr += 2;
857       r23 = vaddq_s32(r2, r3);
858       r234 = vaddq_s32(r23, r4);
859       r34 = vaddq_s32(r3, r4);
860       dst2_ptr += 2;
861 
862       do {
863         load_s16_4x4((int16_t *)src1_ptr, dst_stride, &d5, &d6, &d7, &d8);
864         transpose_elems_inplace_s16_4x4(&d5, &d6, &d7, &d8);
865         load_s32_4x4(src2_ptr, dst_stride, &r5, &r6, &r7, &r8);
866         transpose_elems_inplace_s32_4x4(&r5, &r6, &r7, &r8);
867         src1_ptr += 4;
868         src2_ptr += 4;
869 
870         q345 = vadd_s16(d5, q34);
871         q56 = vadd_s16(d5, d6);
872         q456 = vadd_s16(d4, q56);
873         q567 = vadd_s16(d7, q56);
874         q78 = vadd_s16(d7, d8);
875         q678 = vadd_s16(d6, q78);
876         transpose_elems_inplace_s16_4x4(&q234, &q345, &q456, &q567);
877         store_s16_4x4((int16_t *)dst1_ptr, dst_stride, q234, q345, q456, q567);
878         dst1_ptr += 4;
879 
880         d4 = d8;
881         q34 = q78;
882         q234 = q678;
883 
884         r345 = vaddq_s32(r5, r34);
885         r56 = vaddq_s32(r5, r6);
886         r456 = vaddq_s32(r4, r56);
887         r567 = vaddq_s32(r7, r56);
888         r78 = vaddq_s32(r7, r8);
889         r678 = vaddq_s32(r6, r78);
890         transpose_elems_inplace_s32_4x4(&r234, &r345, &r456, &r567);
891         store_s32_4x4(dst2_ptr, dst_stride, r234, r345, r456, r567);
892         dst2_ptr += 4;
893 
894         r4 = r8;
895         r34 = r78;
896         r234 = r678;
897         w -= 4;
898       } while (w > 0);
899       h -= 4;
900       count++;
901     } while (h > 0);
902   }
903 }
904 
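// 3x3 cross sum used by the full filter, with total weight 32 = 1 << NB_EVEN:
//   3 4 3
//   4 4 4
//   3 4 3
// i.e. 4 * (centre + left + right + top + bottom) + 3 * (the four corners).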
905 static inline int32x4_t cross_sum_inp_s32(int32_t *buf, int buf_stride) {
906   int32x4_t xtr, xt, xtl, xl, x, xr, xbr, xb, xbl;
907   int32x4_t fours, threes, res;
908 
909   xtl = vld1q_s32(buf - buf_stride - 1);
910   xt = vld1q_s32(buf - buf_stride);
911   xtr = vld1q_s32(buf - buf_stride + 1);
912   xl = vld1q_s32(buf - 1);
913   x = vld1q_s32(buf);
914   xr = vld1q_s32(buf + 1);
915   xbl = vld1q_s32(buf + buf_stride - 1);
916   xb = vld1q_s32(buf + buf_stride);
917   xbr = vld1q_s32(buf + buf_stride + 1);
918 
919   fours = vaddq_s32(xl, vaddq_s32(xt, vaddq_s32(xr, vaddq_s32(xb, x))));
920   threes = vaddq_s32(xtl, vaddq_s32(xtr, vaddq_s32(xbr, xbl)));
921   res = vsubq_s32(vshlq_n_s32(vaddq_s32(fours, threes), 2), threes);
922   return res;
923 }
924 
925 static inline void cross_sum_inp_u16(uint16_t *buf, int buf_stride,
926                                      int32x4_t *a0, int32x4_t *a1) {
927   uint16x8_t xtr, xt, xtl, xl, x, xr, xbr, xb, xbl;
928   uint16x8_t r0, r1;
929 
930   xtl = vld1q_u16(buf - buf_stride - 1);
931   xt = vld1q_u16(buf - buf_stride);
932   xtr = vld1q_u16(buf - buf_stride + 1);
933   xl = vld1q_u16(buf - 1);
934   x = vld1q_u16(buf);
935   xr = vld1q_u16(buf + 1);
936   xbl = vld1q_u16(buf + buf_stride - 1);
937   xb = vld1q_u16(buf + buf_stride);
938   xbr = vld1q_u16(buf + buf_stride + 1);
939 
940   xb = vaddq_u16(xb, x);
941   xt = vaddq_u16(xt, xr);
942   xl = vaddq_u16(xl, xb);
943   xl = vaddq_u16(xl, xt);
944 
945   r0 = vshlq_n_u16(xl, 2);
946 
947   xbl = vaddq_u16(xbl, xbr);
948   xtl = vaddq_u16(xtl, xtr);
949   xtl = vaddq_u16(xtl, xbl);
950 
951   r1 = vshlq_n_u16(xtl, 2);
952   r1 = vsubq_u16(r1, xtl);
953 
954   *a0 = vreinterpretq_s32_u32(
955       vaddq_u32(vmovl_u16(vget_low_u16(r0)), vmovl_u16(vget_low_u16(r1))));
956   *a1 = vreinterpretq_s32_u32(
957       vaddq_u32(vmovl_u16(vget_high_u16(r0)), vmovl_u16(vget_high_u16(r1))));
958 }
959 
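// Fast-filter cross sum for even rows, which reads only the rows above and
// below (total weight 32 = 1 << NB_EVEN):
//   5 6 5
//   . . .
//   5 6 5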
960 static inline int32x4_t cross_sum_fast_even_row(int32_t *buf, int buf_stride) {
961   int32x4_t xtr, xt, xtl, xbr, xb, xbl;
962   int32x4_t fives, sixes, fives_plus_sixes;
963 
964   xtl = vld1q_s32(buf - buf_stride - 1);
965   xt = vld1q_s32(buf - buf_stride);
966   xtr = vld1q_s32(buf - buf_stride + 1);
967   xbl = vld1q_s32(buf + buf_stride - 1);
968   xb = vld1q_s32(buf + buf_stride);
969   xbr = vld1q_s32(buf + buf_stride + 1);
970 
971   fives = vaddq_s32(xtl, vaddq_s32(xtr, vaddq_s32(xbr, xbl)));
972   sixes = vaddq_s32(xt, xb);
973   fives_plus_sixes = vaddq_s32(fives, sixes);
974 
975   return vaddq_s32(
976       vaddq_s32(vshlq_n_s32(fives_plus_sixes, 2), fives_plus_sixes), sixes);
977 }
978 
979 static inline void cross_sum_fast_even_row_inp16(uint16_t *buf, int buf_stride,
980                                                  int32x4_t *a0, int32x4_t *a1) {
981   uint16x8_t xtr, xt, xtl, xbr, xb, xbl, xb0;
982 
983   xtl = vld1q_u16(buf - buf_stride - 1);
984   xt = vld1q_u16(buf - buf_stride);
985   xtr = vld1q_u16(buf - buf_stride + 1);
986   xbl = vld1q_u16(buf + buf_stride - 1);
987   xb = vld1q_u16(buf + buf_stride);
988   xbr = vld1q_u16(buf + buf_stride + 1);
989 
990   xbr = vaddq_u16(xbr, xbl);
991   xtr = vaddq_u16(xtr, xtl);
992   xbr = vaddq_u16(xbr, xtr);
993   xtl = vshlq_n_u16(xbr, 2);
994   xbr = vaddq_u16(xtl, xbr);
995 
996   xb = vaddq_u16(xb, xt);
997   xb0 = vshlq_n_u16(xb, 1);
998   xb = vshlq_n_u16(xb, 2);
999   xb = vaddq_u16(xb, xb0);
1000 
1001   *a0 = vreinterpretq_s32_u32(
1002       vaddq_u32(vmovl_u16(vget_low_u16(xbr)), vmovl_u16(vget_low_u16(xb))));
1003   *a1 = vreinterpretq_s32_u32(
1004       vaddq_u32(vmovl_u16(vget_high_u16(xbr)), vmovl_u16(vget_high_u16(xb))));
1005 }
1006 
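// Fast-filter cross sum for odd rows, which reads only the current row
// (total weight 16 = 1 << NB_ODD):  5 6 5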
1007 static inline int32x4_t cross_sum_fast_odd_row(int32_t *buf) {
1008   int32x4_t xl, x, xr;
1009   int32x4_t fives, sixes, fives_plus_sixes;
1010 
1011   xl = vld1q_s32(buf - 1);
1012   x = vld1q_s32(buf);
1013   xr = vld1q_s32(buf + 1);
1014   fives = vaddq_s32(xl, xr);
1015   sixes = x;
1016   fives_plus_sixes = vaddq_s32(fives, sixes);
1017 
1018   return vaddq_s32(
1019       vaddq_s32(vshlq_n_s32(fives_plus_sixes, 2), fives_plus_sixes), sixes);
1020 }
1021 
1022 static inline void cross_sum_fast_odd_row_inp16(uint16_t *buf, int32x4_t *a0,
1023                                                 int32x4_t *a1) {
1024   uint16x8_t xl, x, xr;
1025   uint16x8_t x0;
1026 
1027   xl = vld1q_u16(buf - 1);
1028   x = vld1q_u16(buf);
1029   xr = vld1q_u16(buf + 1);
1030   xl = vaddq_u16(xl, xr);
1031   x0 = vshlq_n_u16(xl, 2);
1032   xl = vaddq_u16(xl, x0);
1033 
1034   x0 = vshlq_n_u16(x, 1);
1035   x = vshlq_n_u16(x, 2);
1036   x = vaddq_u16(x, x0);
1037 
1038   *a0 = vreinterpretq_s32_u32(
1039       vaddq_u32(vmovl_u16(vget_low_u16(xl)), vmovl_u16(vget_low_u16(x))));
1040   *a1 = vreinterpretq_s32_u32(
1041       vaddq_u32(vmovl_u16(vget_high_u16(xl)), vmovl_u16(vget_high_u16(x))));
1042 }
1043 
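// Final SGR filter for the fast path, applied row by row: for each pixel
//   v   = cross_sum(A) * src + cross_sum(B)
//   dst = ROUND_POWER_OF_TWO(v, SGRPROJ_SGR_BITS + NB - SGRPROJ_RST_BITS)
// where NB is NB_EVEN on even rows and NB_ODD on odd rows.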
1044 static void final_filter_fast_internal(uint16_t *A, int32_t *B,
1045                                        const int buf_stride, int16_t *src,
1046                                        const int src_stride, int32_t *dst,
1047                                        const int dst_stride, const int width,
1048                                        const int height) {
1049   int16x8_t s0;
1050   int32_t *B_tmp, *dst_ptr;
1051   uint16_t *A_tmp;
1052   int16_t *src_ptr;
1053   int32x4_t a_res0, a_res1, b_res0, b_res1;
1054   int w, h, count = 0;
1055   assert(SGRPROJ_SGR_BITS == 8);
1056   assert(SGRPROJ_RST_BITS == 4);
1057 
1058   A_tmp = A;
1059   B_tmp = B;
1060   src_ptr = src;
1061   dst_ptr = dst;
1062   h = height;
1063   do {
1064     A_tmp = (A + count * buf_stride);
1065     B_tmp = (B + count * buf_stride);
1066     src_ptr = (src + count * src_stride);
1067     dst_ptr = (dst + count * dst_stride);
1068     w = width;
1069     if (!(count & 1)) {
1070       do {
1071         s0 = vld1q_s16(src_ptr);
1072         cross_sum_fast_even_row_inp16(A_tmp, buf_stride, &a_res0, &a_res1);
1073         a_res0 = vmulq_s32(vmovl_s16(vget_low_s16(s0)), a_res0);
1074         a_res1 = vmulq_s32(vmovl_s16(vget_high_s16(s0)), a_res1);
1075 
1076         b_res0 = cross_sum_fast_even_row(B_tmp, buf_stride);
1077         b_res1 = cross_sum_fast_even_row(B_tmp + 4, buf_stride);
1078         a_res0 = vaddq_s32(a_res0, b_res0);
1079         a_res1 = vaddq_s32(a_res1, b_res1);
1080 
1081         a_res0 =
1082             vrshrq_n_s32(a_res0, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);
1083         a_res1 =
1084             vrshrq_n_s32(a_res1, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);
1085 
1086         vst1q_s32(dst_ptr, a_res0);
1087         vst1q_s32(dst_ptr + 4, a_res1);
1088 
1089         A_tmp += 8;
1090         B_tmp += 8;
1091         src_ptr += 8;
1092         dst_ptr += 8;
1093         w -= 8;
1094       } while (w > 0);
1095     } else {
1096       do {
1097         s0 = vld1q_s16(src_ptr);
1098         cross_sum_fast_odd_row_inp16(A_tmp, &a_res0, &a_res1);
1099         a_res0 = vmulq_s32(vmovl_s16(vget_low_s16(s0)), a_res0);
1100         a_res1 = vmulq_s32(vmovl_s16(vget_high_s16(s0)), a_res1);
1101 
1102         b_res0 = cross_sum_fast_odd_row(B_tmp);
1103         b_res1 = cross_sum_fast_odd_row(B_tmp + 4);
1104         a_res0 = vaddq_s32(a_res0, b_res0);
1105         a_res1 = vaddq_s32(a_res1, b_res1);
1106 
1107         a_res0 =
1108             vrshrq_n_s32(a_res0, SGRPROJ_SGR_BITS + NB_ODD - SGRPROJ_RST_BITS);
1109         a_res1 =
1110             vrshrq_n_s32(a_res1, SGRPROJ_SGR_BITS + NB_ODD - SGRPROJ_RST_BITS);
1111 
1112         vst1q_s32(dst_ptr, a_res0);
1113         vst1q_s32(dst_ptr + 4, a_res1);
1114 
1115         A_tmp += 8;
1116         B_tmp += 8;
1117         src_ptr += 8;
1118         dst_ptr += 8;
1119         w -= 8;
1120       } while (w > 0);
1121     }
1122     count++;
1123     h -= 1;
1124   } while (h > 0);
1125 }
1126 
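// Final SGR filter for the full (radius 1) path: every row uses the 3x3 cross
// sums and the NB_EVEN shift.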
1127 static void final_filter_internal(uint16_t *A, int32_t *B, const int buf_stride,
1128                                   int16_t *src, const int src_stride,
1129                                   int32_t *dst, const int dst_stride,
1130                                   const int width, const int height) {
1131   int16x8_t s0;
1132   int32_t *B_tmp, *dst_ptr;
1133   uint16_t *A_tmp;
1134   int16_t *src_ptr;
1135   int32x4_t a_res0, a_res1, b_res0, b_res1;
1136   int w, h, count = 0;
1137 
1138   assert(SGRPROJ_SGR_BITS == 8);
1139   assert(SGRPROJ_RST_BITS == 4);
1140   h = height;
1141 
1142   do {
1143     A_tmp = (A + count * buf_stride);
1144     B_tmp = (B + count * buf_stride);
1145     src_ptr = (src + count * src_stride);
1146     dst_ptr = (dst + count * dst_stride);
1147     w = width;
1148     do {
1149       s0 = vld1q_s16(src_ptr);
1150       cross_sum_inp_u16(A_tmp, buf_stride, &a_res0, &a_res1);
1151       a_res0 = vmulq_s32(vmovl_s16(vget_low_s16(s0)), a_res0);
1152       a_res1 = vmulq_s32(vmovl_s16(vget_high_s16(s0)), a_res1);
1153 
1154       b_res0 = cross_sum_inp_s32(B_tmp, buf_stride);
1155       b_res1 = cross_sum_inp_s32(B_tmp + 4, buf_stride);
1156       a_res0 = vaddq_s32(a_res0, b_res0);
1157       a_res1 = vaddq_s32(a_res1, b_res1);
1158 
1159       a_res0 =
1160           vrshrq_n_s32(a_res0, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);
1161       a_res1 =
1162           vrshrq_n_s32(a_res1, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);
1163       vst1q_s32(dst_ptr, a_res0);
1164       vst1q_s32(dst_ptr + 4, a_res1);
1165 
1166       A_tmp += 8;
1167       B_tmp += 8;
1168       src_ptr += 8;
1169       dst_ptr += 8;
1170       w -= 8;
1171     } while (w > 0);
1172     count++;
1173     h -= 1;
1174   } while (h > 0);
1175 }
1176 
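// One pass of the fast (radius 2) self-guided filter: boxsum2 builds the 5x5
// box sums of pixels and squares, calc_ab_fast_* derives the per-pixel a/b
// surfaces on alternate rows, and final_filter_fast_internal writes the
// filtered output to dst. Returns 0 on success, -1 on allocation failure.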
1177 static inline int restoration_fast_internal(uint16_t *dgd16, int width,
1178                                             int height, int dgd_stride,
1179                                             int32_t *dst, int dst_stride,
1180                                             int bit_depth, int sgr_params_idx,
1181                                             int radius_idx) {
1182   const sgr_params_type *const params = &av1_sgr_params[sgr_params_idx];
1183   const int r = params->r[radius_idx];
1184   const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
1185   const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
1186   const int buf_stride = ((width_ext + 3) & ~3) + 16;
1187 
1188   const size_t buf_size = 3 * sizeof(int32_t) * RESTORATION_PROC_UNIT_PELS;
1189   int32_t *buf = aom_memalign(8, buf_size);
1190   if (!buf) return -1;
1191 
1192   int32_t *square_sum_buf = buf;
1193   int32_t *sum_buf = square_sum_buf + RESTORATION_PROC_UNIT_PELS;
1194   uint16_t *tmp16_buf = (uint16_t *)(sum_buf + RESTORATION_PROC_UNIT_PELS);
1195   assert((char *)(sum_buf + RESTORATION_PROC_UNIT_PELS) <=
1196              (char *)buf + buf_size &&
1197          "Allocated buffer is too small. Resize the buffer.");
1198 
1199   assert(r <= MAX_RADIUS && "Need MAX_RADIUS >= r");
1200   assert(r <= SGRPROJ_BORDER_VERT - 1 && r <= SGRPROJ_BORDER_HORZ - 1 &&
1201          "Need SGRPROJ_BORDER_* >= r+1");
1202 
1203   assert(radius_idx == 0);
1204   assert(r == 2);
1205 
1206   // The input (dgd16) is 16-bit. The 1st-stage sum-of-pixels output is
1207   // 16-bit (tmp16_buf) and the final sum output is 32-bit (sum_buf).
1208   // The sum-of-squares output is kept in a 32-bit buffer
1209   // (square_sum_buf).
1210   boxsum2((int16_t *)(dgd16 - dgd_stride * SGRPROJ_BORDER_VERT -
1211                       SGRPROJ_BORDER_HORZ),
1212           dgd_stride, (int16_t *)tmp16_buf, sum_buf, square_sum_buf, buf_stride,
1213           width_ext, height_ext);
1214 
1215   square_sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
1216   sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
1217   tmp16_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
1218 
1219   // Calculate a and b. The a output is 16-bit (tmp16_buf) and lies in the
1220   // range [1, 256] for all bit depths; the b output is kept in a 32-bit buffer.
1221 
1222 #if CONFIG_AV1_HIGHBITDEPTH
1223   if (bit_depth > 8) {
1224     calc_ab_fast_internal_hbd(
1225         (square_sum_buf - buf_stride - 1), (tmp16_buf - buf_stride - 1),
1226         (sum_buf - buf_stride - 1), buf_stride * 2, width + 2, height + 2,
1227         bit_depth, r, params->s[radius_idx], 2);
1228   } else {
1229     calc_ab_fast_internal_lbd(
1230         (square_sum_buf - buf_stride - 1), (tmp16_buf - buf_stride - 1),
1231         (sum_buf - buf_stride - 1), buf_stride * 2, width + 2, height + 2, r,
1232         params->s[radius_idx], 2);
1233   }
1234 #else
1235   (void)bit_depth;
1236   calc_ab_fast_internal_lbd((square_sum_buf - buf_stride - 1),
1237                             (tmp16_buf - buf_stride - 1),
1238                             (sum_buf - buf_stride - 1), buf_stride * 2,
1239                             width + 2, height + 2, r, params->s[radius_idx], 2);
1240 #endif
1241   final_filter_fast_internal(tmp16_buf, sum_buf, buf_stride, (int16_t *)dgd16,
1242                              dgd_stride, dst, dst_stride, width, height);
1243   aom_free(buf);
1244   return 0;
1245 }
1246 
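// One pass of the radius-1 self-guided filter: boxsum1 builds the 3x3 box
// sums, calc_ab_internal_* derives the per-pixel a/b surfaces for every row,
// and final_filter_internal writes the filtered output to dst. Returns 0 on
// success, -1 on allocation failure.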
1247 static inline int restoration_internal(uint16_t *dgd16, int width, int height,
1248                                        int dgd_stride, int32_t *dst,
1249                                        int dst_stride, int bit_depth,
1250                                        int sgr_params_idx, int radius_idx) {
1251   const sgr_params_type *const params = &av1_sgr_params[sgr_params_idx];
1252   const int r = params->r[radius_idx];
1253   const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
1254   const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
1255   const int buf_stride = ((width_ext + 3) & ~3) + 16;
1256 
1257   const size_t buf_size = 3 * sizeof(int32_t) * RESTORATION_PROC_UNIT_PELS;
1258   int32_t *buf = aom_memalign(8, buf_size);
1259   if (!buf) return -1;
1260 
1261   int32_t *square_sum_buf = buf;
1262   int32_t *B = square_sum_buf + RESTORATION_PROC_UNIT_PELS;
1263   uint16_t *A16 = (uint16_t *)(B + RESTORATION_PROC_UNIT_PELS);
1264   uint16_t *sum_buf = A16 + RESTORATION_PROC_UNIT_PELS;
1265 
1266   assert((char *)(sum_buf + RESTORATION_PROC_UNIT_PELS) <=
1267              (char *)buf + buf_size &&
1268          "Allocated buffer is too small. Resize the buffer.");
1269 
1270   assert(r <= MAX_RADIUS && "Need MAX_RADIUS >= r");
1271   assert(r <= SGRPROJ_BORDER_VERT - 1 && r <= SGRPROJ_BORDER_HORZ - 1 &&
1272          "Need SGRPROJ_BORDER_* >= r+1");
1273 
1274   assert(radius_idx == 1);
1275   assert(r == 1);
1276 
1277   // The input (dgd16) is 16-bit.
1278   // The sum-of-pixels output is 16-bit (sum_buf).
1279   // The sum-of-squares output is kept in a 32-bit buffer (square_sum_buf).
1280   boxsum1((int16_t *)(dgd16 - dgd_stride * SGRPROJ_BORDER_VERT -
1281                       SGRPROJ_BORDER_HORZ),
1282           dgd_stride, sum_buf, square_sum_buf, buf_stride, width_ext,
1283           height_ext);
1284 
1285   square_sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
1286   B += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
1287   A16 += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
1288   sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
1289 
1290 #if CONFIG_AV1_HIGHBITDEPTH
1291   // Calculate a and b. The a output is 16-bit (A16) and lies in the range
1292   // [1, 256] for all bit depths; the b output is kept in a 32-bit buffer (B).
1293   if (bit_depth > 8) {
1294     calc_ab_internal_hbd((square_sum_buf - buf_stride - 1),
1295                          (A16 - buf_stride - 1), (sum_buf - buf_stride - 1),
1296                          (B - buf_stride - 1), buf_stride, width + 2,
1297                          height + 2, bit_depth, r, params->s[radius_idx], 1);
1298   } else {
1299     calc_ab_internal_lbd((square_sum_buf - buf_stride - 1),
1300                          (A16 - buf_stride - 1), (sum_buf - buf_stride - 1),
1301                          (B - buf_stride - 1), buf_stride, width + 2,
1302                          height + 2, r, params->s[radius_idx], 1);
1303   }
1304 #else
1305   (void)bit_depth;
1306   calc_ab_internal_lbd((square_sum_buf - buf_stride - 1),
1307                        (A16 - buf_stride - 1), (sum_buf - buf_stride - 1),
1308                        (B - buf_stride - 1), buf_stride, width + 2, height + 2,
1309                        r, params->s[radius_idx], 1);
1310 #endif
1311   final_filter_internal(A16, B, buf_stride, (int16_t *)dgd16, dgd_stride, dst,
1312                         dst_stride, width, height);
1313   aom_free(buf);
1314   return 0;
1315 }
1316 
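// Widen the 8-bit source into the 16-bit working buffer, 4 rows at a time
// (8 columns per step, with scalar handling of any leftover columns and
// rows), then zero the rows below the copied region that the boxsum filter
// will read.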
1317 static inline void src_convert_u8_to_u16(const uint8_t *src,
1318                                          const int src_stride, uint16_t *dst,
1319                                          const int dst_stride, const int width,
1320                                          const int height) {
1321   const uint8_t *src_ptr;
1322   uint16_t *dst_ptr;
1323   int h, w, count = 0;
1324 
1325   uint8x8_t t1, t2, t3, t4;
1326   uint16x8_t s1, s2, s3, s4;
1327   h = height;
1328   do {
1329     src_ptr = src + (count << 2) * src_stride;
1330     dst_ptr = dst + (count << 2) * dst_stride;
1331     w = width;
1332     if (w >= 7) {
1333       do {
1334         load_u8_8x4(src_ptr, src_stride, &t1, &t2, &t3, &t4);
1335         s1 = vmovl_u8(t1);
1336         s2 = vmovl_u8(t2);
1337         s3 = vmovl_u8(t3);
1338         s4 = vmovl_u8(t4);
1339         store_u16_8x4(dst_ptr, dst_stride, s1, s2, s3, s4);
1340 
1341         src_ptr += 8;
1342         dst_ptr += 8;
1343         w -= 8;
1344       } while (w > 7);
1345     }
1346 
1347     for (int y = 0; y < w; y++) {
1348       dst_ptr[y] = src_ptr[y];
1349       dst_ptr[y + 1 * dst_stride] = src_ptr[y + 1 * src_stride];
1350       dst_ptr[y + 2 * dst_stride] = src_ptr[y + 2 * src_stride];
1351       dst_ptr[y + 3 * dst_stride] = src_ptr[y + 3 * src_stride];
1352     }
1353     count++;
1354     h -= 4;
1355   } while (h > 3);
1356 
1357   src_ptr = src + (count << 2) * src_stride;
1358   dst_ptr = dst + (count << 2) * dst_stride;
1359   for (int x = 0; x < h; x++) {
1360     for (int y = 0; y < width; y++) {
1361       dst_ptr[y + x * dst_stride] = src_ptr[y + x * src_stride];
1362     }
1363   }
1364 
1365   // Zero the uninitialized rows below the copied region of dst (the boxsum
1366   // input), since the boxsum filter calculation reads them.
1367   for (int x = height; x < height + 5; x++)
1368     memset(dst + x * dst_stride, 0, (width + 2) * sizeof(*dst));
1369 }
1370 
1371 #if CONFIG_AV1_HIGHBITDEPTH
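     // High-bitdepth variant of the conversion above: the source is already
     // 16-bit, so rows are simply copied (eight columns at a time with NEON,
     // scalar/memcpy for the remainder) and the same five guard rows are
     // zeroed afterwards.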
1372 static inline void src_convert_hbd_copy(const uint16_t *src, int src_stride,
1373                                         uint16_t *dst, const int dst_stride,
1374                                         int width, int height) {
1375   const uint16_t *src_ptr;
1376   uint16_t *dst_ptr;
1377   int h, w, count = 0;
1378   uint16x8_t s1, s2, s3, s4;
1379 
1380   h = height;
1381   do {
1382     src_ptr = src + (count << 2) * src_stride;
1383     dst_ptr = dst + (count << 2) * dst_stride;
1384     w = width;
1385     do {
1386       load_u16_8x4(src_ptr, src_stride, &s1, &s2, &s3, &s4);
1387       store_u16_8x4(dst_ptr, dst_stride, s1, s2, s3, s4);
1388       src_ptr += 8;
1389       dst_ptr += 8;
1390       w -= 8;
1391     } while (w > 7);
1392 
1393     for (int y = 0; y < w; y++) {
1394       dst_ptr[y] = src_ptr[y];
1395       dst_ptr[y + 1 * dst_stride] = src_ptr[y + 1 * src_stride];
1396       dst_ptr[y + 2 * dst_stride] = src_ptr[y + 2 * src_stride];
1397       dst_ptr[y + 3 * dst_stride] = src_ptr[y + 3 * src_stride];
1398     }
1399     count++;
1400     h -= 4;
1401   } while (h > 3);
1402 
1403   src_ptr = src + (count << 2) * src_stride;
1404   dst_ptr = dst + (count << 2) * dst_stride;
1405 
1406   for (int x = 0; x < h; x++) {
1407     memcpy((dst_ptr + x * dst_stride), (src_ptr + x * src_stride),
1408            sizeof(uint16_t) * width);
1409   }
1410   // Zero the otherwise-uninitialized rows just below the copied block in
1411   // the dst buffer, since the boxsum filter calculation reads them.
1412   for (int x = height; x < height + 5; x++)
1413     memset(dst + x * dst_stride, 0, (width + 2) * sizeof(*dst));
1414 }
1415 #endif  // CONFIG_AV1_HIGHBITDEPTH
1416 
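     // Self-guided restoration entry point. The source (including
     // SGRPROJ_BORDER_* pixels on every side) is first copied/widened into a
     // local 16-bit buffer, then the r[0] filter is run through the fast path
     // into flt0 and the r[1] filter through the full path into flt1, as
     // selected by sgr_params_idx.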
1417 int av1_selfguided_restoration_neon(const uint8_t *dat8, int width, int height,
1418                                     int stride, int32_t *flt0, int32_t *flt1,
1419                                     int flt_stride, int sgr_params_idx,
1420                                     int bit_depth, int highbd) {
1421   const sgr_params_type *const params = &av1_sgr_params[sgr_params_idx];
1422   assert(!(params->r[0] == 0 && params->r[1] == 0));
1423 
1424   uint16_t dgd16_[RESTORATION_PROC_UNIT_PELS];
1425   const int dgd16_stride = width + 2 * SGRPROJ_BORDER_HORZ;
1426   uint16_t *dgd16 =
1427       dgd16_ + dgd16_stride * SGRPROJ_BORDER_VERT + SGRPROJ_BORDER_HORZ;
1428   const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
1429   const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
1430   const int dgd_stride = stride;
1431 
1432 #if CONFIG_AV1_HIGHBITDEPTH
1433   if (highbd) {
1434     const uint16_t *dgd16_tmp = CONVERT_TO_SHORTPTR(dat8);
1435     src_convert_hbd_copy(
1436         dgd16_tmp - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
1437         dgd_stride,
1438         dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
1439         dgd16_stride, width_ext, height_ext);
1440   } else {
1441     src_convert_u8_to_u16(
1442         dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
1443         dgd_stride,
1444         dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
1445         dgd16_stride, width_ext, height_ext);
1446   }
1447 #else
1448   (void)highbd;
1449   src_convert_u8_to_u16(
1450       dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ, dgd_stride,
1451       dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
1452       dgd16_stride, width_ext, height_ext);
1453 #endif
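       // The conversion above starts SGRPROJ_BORDER_VERT rows and
       // SGRPROJ_BORDER_HORZ columns before the unit so that the box sums
       // along the block edges read valid, initialized samples.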
1454 
1455   if (params->r[0] > 0) {
1456     int ret =
1457         restoration_fast_internal(dgd16, width, height, dgd16_stride, flt0,
1458                                   flt_stride, bit_depth, sgr_params_idx, 0);
1459     if (ret != 0) return ret;
1460   }
1461   if (params->r[1] > 0) {
1462     int ret = restoration_internal(dgd16, width, height, dgd16_stride, flt1,
1463                                    flt_stride, bit_depth, sgr_params_idx, 1);
1464     if (ret != 0) return ret;
1465   }
1466   return 0;
1467 }
1468 
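     // Full apply path: run the self-guided filters into flt0/flt1 (inside
     // tmpbuf), decode the xqd projection weights, and blend the filter
     // outputs back onto the source, clamping to the output bit depth.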
1469 int av1_apply_selfguided_restoration_neon(const uint8_t *dat8, int width,
1470                                           int height, int stride, int eps,
1471                                           const int *xqd, uint8_t *dst8,
1472                                           int dst_stride, int32_t *tmpbuf,
1473                                           int bit_depth, int highbd) {
1474   int32_t *flt0 = tmpbuf;
1475   int32_t *flt1 = flt0 + RESTORATION_UNITPELS_MAX;
1476   assert(width * height <= RESTORATION_UNITPELS_MAX);
1477   uint16_t dgd16_[RESTORATION_PROC_UNIT_PELS];
1478   const int dgd16_stride = width + 2 * SGRPROJ_BORDER_HORZ;
1479   uint16_t *dgd16 =
1480       dgd16_ + dgd16_stride * SGRPROJ_BORDER_VERT + SGRPROJ_BORDER_HORZ;
1481   const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
1482   const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
1483   const int dgd_stride = stride;
1484   const sgr_params_type *const params = &av1_sgr_params[eps];
1485   int xq[2];
1486 
1487   assert(!(params->r[0] == 0 && params->r[1] == 0));
1488 
1489 #if CONFIG_AV1_HIGHBITDEPTH
1490   if (highbd) {
1491     const uint16_t *dgd16_tmp = CONVERT_TO_SHORTPTR(dat8);
1492     src_convert_hbd_copy(
1493         dgd16_tmp - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
1494         dgd_stride,
1495         dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
1496         dgd16_stride, width_ext, height_ext);
1497   } else {
1498     src_convert_u8_to_u16(
1499         dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
1500         dgd_stride,
1501         dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
1502         dgd16_stride, width_ext, height_ext);
1503   }
1504 #else
1505   (void)highbd;
1506   src_convert_u8_to_u16(
1507       dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ, dgd_stride,
1508       dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
1509       dgd16_stride, width_ext, height_ext);
1510 #endif
1511   if (params->r[0] > 0) {
1512     int ret = restoration_fast_internal(dgd16, width, height, dgd16_stride,
1513                                         flt0, width, bit_depth, eps, 0);
1514     if (ret != 0) return ret;
1515   }
1516   if (params->r[1] > 0) {
1517     int ret = restoration_internal(dgd16, width, height, dgd16_stride, flt1,
1518                                    width, bit_depth, eps, 1);
1519     if (ret != 0) return ret;
1520   }
1521 
1522   av1_decode_xq(xqd, xq, params);
1523 
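       // Per-pixel combination (scalar sketch of what the NEON loop below
       // does; see av1_apply_selfguided_restoration_c for the reference):
       //   u = dgd16[k] << SGRPROJ_RST_BITS;
       //   v = u << SGRPROJ_PRJ_BITS;
       //   if (params->r[0] > 0) v += xq[0] * (flt0[k] - u);
       //   if (params->r[1] > 0) v += xq[1] * (flt1[k] - u);
       //   out = clamp(ROUND_POWER_OF_TWO(v, SGRPROJ_PRJ_BITS +
       //                                     SGRPROJ_RST_BITS));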
1524   {
1525     int16_t *src_ptr;
1526     uint8_t *dst_ptr;
1527 #if CONFIG_AV1_HIGHBITDEPTH
1528     uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst8);
1529     uint16_t *dst16_ptr;
1530 #endif
1531     int16x4_t d0, d4;
1532     int16x8_t r0, s0;
1533     uint16x8_t r4;
1534     int32x4_t u0, u4, v0, v4, f00, f10;
1535     uint8x8_t t0;
1536     int count = 0, w = width, h = height, rc = 0;
1537 
1538     const int32x4_t xq0_vec = vdupq_n_s32(xq[0]);
1539     const int32x4_t xq1_vec = vdupq_n_s32(xq[1]);
1540     const int16x8_t zero = vdupq_n_s16(0);
1541     const uint16x8_t max = vdupq_n_u16((1 << bit_depth) - 1);
1542     src_ptr = (int16_t *)dgd16;
1543     do {
1544       w = width;
1545       count = 0;
1546       dst_ptr = dst8 + rc * dst_stride;
1547 #if CONFIG_AV1_HIGHBITDEPTH
1548       dst16_ptr = dst16 + rc * dst_stride;
1549 #endif
1550       do {
1551         s0 = vld1q_s16(src_ptr + count);
1552 
1553         u0 = vshll_n_s16(vget_low_s16(s0), SGRPROJ_RST_BITS);
1554         u4 = vshll_n_s16(vget_high_s16(s0), SGRPROJ_RST_BITS);
1555 
1556         v0 = vshlq_n_s32(u0, SGRPROJ_PRJ_BITS);
1557         v4 = vshlq_n_s32(u4, SGRPROJ_PRJ_BITS);
1558 
1559         if (params->r[0] > 0) {
1560           f00 = vld1q_s32(flt0 + count);
1561           f10 = vld1q_s32(flt0 + count + 4);
1562 
1563           f00 = vsubq_s32(f00, u0);
1564           f10 = vsubq_s32(f10, u4);
1565 
1566           v0 = vmlaq_s32(v0, xq0_vec, f00);
1567           v4 = vmlaq_s32(v4, xq0_vec, f10);
1568         }
1569 
1570         if (params->r[1] > 0) {
1571           f00 = vld1q_s32(flt1 + count);
1572           f10 = vld1q_s32(flt1 + count + 4);
1573 
1574           f00 = vsubq_s32(f00, u0);
1575           f10 = vsubq_s32(f10, u4);
1576 
1577           v0 = vmlaq_s32(v0, xq1_vec, f00);
1578           v4 = vmlaq_s32(v4, xq1_vec, f10);
1579         }
1580 
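             // Saturating rounding narrow: this performs the final
             // ROUND_POWER_OF_TWO(v, SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS)
             // step; the max/min (or vqmovn) below clamp the result to the
             // valid pixel range.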
1581         d0 = vqrshrn_n_s32(v0, SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);
1582         d4 = vqrshrn_n_s32(v4, SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);
1583 
1584         r0 = vcombine_s16(d0, d4);
1585 
1586         r4 = vreinterpretq_u16_s16(vmaxq_s16(r0, zero));
1587 
1588 #if CONFIG_AV1_HIGHBITDEPTH
1589         if (highbd) {
1590           r4 = vminq_u16(r4, max);
1591           vst1q_u16(dst16_ptr, r4);
1592           dst16_ptr += 8;
1593         } else {
1594           t0 = vqmovn_u16(r4);
1595           vst1_u8(dst_ptr, t0);
1596           dst_ptr += 8;
1597         }
1598 #else
1599         (void)max;
1600         t0 = vqmovn_u16(r4);
1601         vst1_u8(dst_ptr, t0);
1602         dst_ptr += 8;
1603 #endif
1604         w -= 8;
1605         count += 8;
1606       } while (w > 0);
1607 
1608       src_ptr += dgd16_stride;
1609       flt1 += width;
1610       flt0 += width;
1611       rc++;
1612       h--;
1613     } while (h > 0);
1614   }
1615   return 0;
1616 }
1617