// Copyright 2019 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/loop_restoration.h"
#include "src/utils/cpu.h"

#if LIBGAV1_ENABLE_NEON
#include <arm_neon.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "src/dsp/arm/common_neon.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/utils/common.h"
#include "src/utils/compiler_attributes.h"
#include "src/utils/constants.h"

namespace libgav1 {
namespace dsp {
namespace low_bitdepth {
namespace {

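// Emulates a 128-bit byte-wise right shift (like the x86 _mm_srli_si128() it
// is named after): returns the vector that starts |bytes| bytes into the
// concatenation of src[0] and src[1]. The uint16x8_t overloads still take
// |bytes| in bytes and therefore shift by |bytes| / 2 16-bit lanes.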
template <int bytes>
inline uint8x8_t VshrU128(const uint8x8x2_t src) {
  return vext_u8(src.val[0], src.val[1], bytes);
}

template <int bytes>
inline uint8x8_t VshrU128(const uint8x8_t src[2]) {
  return vext_u8(src[0], src[1], bytes);
}

template <int bytes>
inline uint8x16_t VshrU128(const uint8x16_t src[2]) {
  return vextq_u8(src[0], src[1], bytes);
}

template <int bytes>
inline uint16x8_t VshrU128(const uint16x8x2_t src) {
  return vextq_u16(src.val[0], src.val[1], bytes / 2);
}

template <int bytes>
inline uint16x8_t VshrU128(const uint16x8_t src[2]) {
  return vextq_u16(src[0], src[1], bytes / 2);
}

// Wiener

// Must make a local copy of the coefficients so the compiler knows that they
// have no overlap with other buffers. Using the 'const' keyword is not
// enough. In practice the compiler does not actually make a copy, since there
// are enough registers in this case.
inline void PopulateWienerCoefficients(
    const RestorationUnitInfo& restoration_info, const int direction,
    int16_t filter[4]) {
  // In order to keep the horizontal pass intermediate values within 16 bits we
  // offset |filter[3]| by 128. The 128 offset will be added back in the loop.
  for (int i = 0; i < 4; ++i) {
    filter[i] = restoration_info.wiener_info.filter[direction][i];
  }
  if (direction == WienerInfo::kHorizontal) {
    filter[3] -= 128;
  }
}

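// The Wiener filter is symmetric, so only the first 4 of its 7 taps are
// stored. Each tap is applied to a mirrored pair of samples (e.g. filter[0]
// to s[0] and s[6] in the 7-tap case); adding the pair first halves the
// number of multiplies.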
inline int16x8_t WienerHorizontal2(const uint8x8_t s0, const uint8x8_t s1,
                                   const int16_t filter, const int16x8_t sum) {
  const int16x8_t ss = vreinterpretq_s16_u16(vaddl_u8(s0, s1));
  return vmlaq_n_s16(sum, ss, filter);
}

inline int16x8x2_t WienerHorizontal2(const uint8x16_t s0, const uint8x16_t s1,
                                     const int16_t filter,
                                     const int16x8x2_t sum) {
  int16x8x2_t d;
  d.val[0] =
      WienerHorizontal2(vget_low_u8(s0), vget_low_u8(s1), filter, sum.val[0]);
  d.val[1] =
      WienerHorizontal2(vget_high_u8(s0), vget_high_u8(s1), filter, sum.val[1]);
  return d;
}

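// With kWienerFilterBits == 7 and kInterRoundBitsHorizontal == 3, |offset| is
// 1 << 11 and |limit| is (1 << 13) - 1, so the horizontal results stored
// below lie within 13 bits, matching the comment in WienerFilter_NEON().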
inline void WienerHorizontalSum(const uint8x8_t s[3], const int16_t filter[4],
                                int16x8_t sum, int16_t* const wiener_buffer) {
  constexpr int offset =
      1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
  constexpr int limit = (offset << 2) - 1;
  const int16x8_t s_0_2 = vreinterpretq_s16_u16(vaddl_u8(s[0], s[2]));
  const int16x8_t s_1 = ZeroExtend(s[1]);
  sum = vmlaq_n_s16(sum, s_0_2, filter[2]);
  sum = vmlaq_n_s16(sum, s_1, filter[3]);
  // Calculate the scaled-down offset correction, and add it to |sum| here to
  // prevent signed 16-bit overflow.
  sum = vrsraq_n_s16(vshlq_n_s16(s_1, 7 - kInterRoundBitsHorizontal), sum,
                     kInterRoundBitsHorizontal);
  sum = vmaxq_s16(sum, vdupq_n_s16(-offset));
  sum = vminq_s16(sum, vdupq_n_s16(limit - offset));
  vst1q_s16(wiener_buffer, sum);
}

inline void WienerHorizontalSum(const uint8x16_t src[3],
                                const int16_t filter[4], int16x8x2_t sum,
                                int16_t* const wiener_buffer) {
  uint8x8_t s[3];
  s[0] = vget_low_u8(src[0]);
  s[1] = vget_low_u8(src[1]);
  s[2] = vget_low_u8(src[2]);
  WienerHorizontalSum(s, filter, sum.val[0], wiener_buffer);
  s[0] = vget_high_u8(src[0]);
  s[1] = vget_high_u8(src[1]);
  s[2] = vget_high_u8(src[2]);
  WienerHorizontalSum(s, filter, sum.val[1], wiener_buffer + 8);
}

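// The tap-7 and tap-5 horizontal filters below process 16 pixels per
// iteration: they load the next 16 source bytes and build the shifted
// windows with vextq_u8(), so each byte is loaded only once per row.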
inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const int16_t filter[4],
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    const uint8_t* src_ptr = src;
    uint8x16_t s[8];
    s[0] = vld1q_u8(src_ptr);
    ptrdiff_t x = width;
    do {
      src_ptr += 16;
      s[7] = vld1q_u8(src_ptr);
      s[1] = vextq_u8(s[0], s[7], 1);
      s[2] = vextq_u8(s[0], s[7], 2);
      s[3] = vextq_u8(s[0], s[7], 3);
      s[4] = vextq_u8(s[0], s[7], 4);
      s[5] = vextq_u8(s[0], s[7], 5);
      s[6] = vextq_u8(s[0], s[7], 6);
      int16x8x2_t sum;
      sum.val[0] = sum.val[1] = vdupq_n_s16(0);
      sum = WienerHorizontal2(s[0], s[6], filter[0], sum);
      sum = WienerHorizontal2(s[1], s[5], filter[1], sum);
      WienerHorizontalSum(s + 2, filter, sum, *wiener_buffer);
      s[0] = s[7];
      *wiener_buffer += 16;
      x -= 16;
    } while (x != 0);
    src += src_stride;
  }
}

inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const int16_t filter[4],
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    const uint8_t* src_ptr = src;
    uint8x16_t s[6];
    s[0] = vld1q_u8(src_ptr);
    ptrdiff_t x = width;
    do {
      src_ptr += 16;
      s[5] = vld1q_u8(src_ptr);
      s[1] = vextq_u8(s[0], s[5], 1);
      s[2] = vextq_u8(s[0], s[5], 2);
      s[3] = vextq_u8(s[0], s[5], 3);
      s[4] = vextq_u8(s[0], s[5], 4);
      int16x8x2_t sum;
      sum.val[0] = sum.val[1] = vdupq_n_s16(0);
      sum = WienerHorizontal2(s[0], s[4], filter[1], sum);
      WienerHorizontalSum(s + 1, filter, sum, *wiener_buffer);
      s[0] = s[5];
      *wiener_buffer += 16;
      x -= 16;
    } while (x != 0);
    src += src_stride;
  }
}

inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const int16_t filter[4],
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    const uint8_t* src_ptr = src;
    uint8x16_t s[3];
    ptrdiff_t x = width;
    do {
      // Slightly faster than using vextq_u8().
      s[0] = vld1q_u8(src_ptr);
      s[1] = vld1q_u8(src_ptr + 1);
      s[2] = vld1q_u8(src_ptr + 2);
      int16x8x2_t sum;
      sum.val[0] = sum.val[1] = vdupq_n_s16(0);
      WienerHorizontalSum(s, filter, sum, *wiener_buffer);
      src_ptr += 16;
      *wiener_buffer += 16;
      x -= 16;
    } while (x != 0);
    src += src_stride;
  }
}

inline void WienerHorizontalTap1(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    const uint8_t* src_ptr = src;
    ptrdiff_t x = width;
    do {
      const uint8x16_t s = vld1q_u8(src_ptr);
      const uint8x8_t s0 = vget_low_u8(s);
      const uint8x8_t s1 = vget_high_u8(s);
      const int16x8_t d0 = vreinterpretq_s16_u16(vshll_n_u8(s0, 4));
      const int16x8_t d1 = vreinterpretq_s16_u16(vshll_n_u8(s1, 4));
      vst1q_s16(*wiener_buffer + 0, d0);
      vst1q_s16(*wiener_buffer + 8, d1);
      src_ptr += 16;
      *wiener_buffer += 16;
      x -= 16;
    } while (x != 0);
    src += src_stride;
  }
}

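// The vertical pass accumulates 32-bit sums and narrows them back to 8 bits
// with a rounding shift of 11, which matches kInterRoundBitsVertical.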
inline int32x4x2_t WienerVertical2(const int16x8_t a0, const int16x8_t a1,
                                   const int16_t filter,
                                   const int32x4x2_t sum) {
  const int16x8_t a = vaddq_s16(a0, a1);
  int32x4x2_t d;
  d.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(a), filter);
  d.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(a), filter);
  return d;
}

inline uint8x8_t WienerVertical(const int16x8_t a[3], const int16_t filter[4],
                                const int32x4x2_t sum) {
  int32x4x2_t d = WienerVertical2(a[0], a[2], filter[2], sum);
  d.val[0] = vmlal_n_s16(d.val[0], vget_low_s16(a[1]), filter[3]);
  d.val[1] = vmlal_n_s16(d.val[1], vget_high_s16(a[1]), filter[3]);
  const uint16x4_t sum_lo_16 = vqrshrun_n_s32(d.val[0], 11);
  const uint16x4_t sum_hi_16 = vqrshrun_n_s32(d.val[1], 11);
  return vqmovn_u16(vcombine_u16(sum_lo_16, sum_hi_16));
}

inline uint8x8_t WienerVerticalTap7Kernel(const int16_t* const wiener_buffer,
                                          const ptrdiff_t wiener_stride,
                                          const int16_t filter[4],
                                          int16x8_t a[7]) {
  int32x4x2_t sum;
  a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
  a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
  a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride);
  a[6] = vld1q_s16(wiener_buffer + 6 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  sum = WienerVertical2(a[0], a[6], filter[0], sum);
  sum = WienerVertical2(a[1], a[5], filter[1], sum);
  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
  a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride);
  return WienerVertical(a + 2, filter, sum);
}

inline uint8x8x2_t WienerVerticalTap7Kernel2(const int16_t* const wiener_buffer,
                                             const ptrdiff_t wiener_stride,
                                             const int16_t filter[4]) {
  int16x8_t a[8];
  int32x4x2_t sum;
  uint8x8x2_t d;
  d.val[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a);
  a[7] = vld1q_s16(wiener_buffer + 7 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  sum = WienerVertical2(a[1], a[7], filter[0], sum);
  sum = WienerVertical2(a[2], a[6], filter[1], sum);
  d.val[1] = WienerVertical(a + 3, filter, sum);
  return d;
}

inline void WienerVerticalTap7(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t filter[4], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y != 0; --y) {
    uint8_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      uint8x8x2_t d[2];
      d[0] = WienerVerticalTap7Kernel2(wiener_buffer + 0, width, filter);
      d[1] = WienerVerticalTap7Kernel2(wiener_buffer + 8, width, filter);
      vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
      vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = width;
    do {
      int16x8_t a[7];
      const uint8x8_t d0 =
          WienerVerticalTap7Kernel(wiener_buffer + 0, width, filter, a);
      const uint8x8_t d1 =
          WienerVerticalTap7Kernel(wiener_buffer + 8, width, filter, a);
      vst1q_u8(dst, vcombine_u8(d0, d1));
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}

inline uint8x8_t WienerVerticalTap5Kernel(const int16_t* const wiener_buffer,
                                          const ptrdiff_t wiener_stride,
                                          const int16_t filter[4],
                                          int16x8_t a[5]) {
  a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
  a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
  a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride);
  int32x4x2_t sum;
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  sum = WienerVertical2(a[0], a[4], filter[1], sum);
  return WienerVertical(a + 1, filter, sum);
}

inline uint8x8x2_t WienerVerticalTap5Kernel2(const int16_t* const wiener_buffer,
                                             const ptrdiff_t wiener_stride,
                                             const int16_t filter[4]) {
  int16x8_t a[6];
  int32x4x2_t sum;
  uint8x8x2_t d;
  d.val[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a);
  a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  sum = WienerVertical2(a[1], a[5], filter[1], sum);
  d.val[1] = WienerVertical(a + 2, filter, sum);
  return d;
}

inline void WienerVerticalTap5(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t filter[4], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y != 0; --y) {
    uint8_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      uint8x8x2_t d[2];
      d[0] = WienerVerticalTap5Kernel2(wiener_buffer + 0, width, filter);
      d[1] = WienerVerticalTap5Kernel2(wiener_buffer + 8, width, filter);
      vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
      vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = width;
    do {
      int16x8_t a[5];
      const uint8x8_t d0 =
          WienerVerticalTap5Kernel(wiener_buffer + 0, width, filter, a);
      const uint8x8_t d1 =
          WienerVerticalTap5Kernel(wiener_buffer + 8, width, filter, a);
      vst1q_u8(dst, vcombine_u8(d0, d1));
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}

inline uint8x8_t WienerVerticalTap3Kernel(const int16_t* const wiener_buffer,
                                          const ptrdiff_t wiener_stride,
                                          const int16_t filter[4],
                                          int16x8_t a[3]) {
  a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
  a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
  int32x4x2_t sum;
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  return WienerVertical(a, filter, sum);
}

inline uint8x8x2_t WienerVerticalTap3Kernel2(const int16_t* const wiener_buffer,
                                             const ptrdiff_t wiener_stride,
                                             const int16_t filter[4]) {
  int16x8_t a[4];
  int32x4x2_t sum;
  uint8x8x2_t d;
  d.val[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a);
  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  d.val[1] = WienerVertical(a + 1, filter, sum);
  return d;
}

inline void WienerVerticalTap3(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t filter[4], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y != 0; --y) {
    uint8_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      uint8x8x2_t d[2];
      d[0] = WienerVerticalTap3Kernel2(wiener_buffer + 0, width, filter);
      d[1] = WienerVerticalTap3Kernel2(wiener_buffer + 8, width, filter);
      vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
      vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = width;
    do {
      int16x8_t a[3];
      const uint8x8_t d0 =
          WienerVerticalTap3Kernel(wiener_buffer + 0, width, filter, a);
      const uint8x8_t d1 =
          WienerVerticalTap3Kernel(wiener_buffer + 8, width, filter, a);
      vst1q_u8(dst, vcombine_u8(d0, d1));
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}

inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer,
                                     uint8_t* const dst) {
  const int16x8_t a0 = vld1q_s16(wiener_buffer + 0);
  const int16x8_t a1 = vld1q_s16(wiener_buffer + 8);
  const uint8x8_t d0 = vqrshrun_n_s16(a0, 4);
  const uint8x8_t d1 = vqrshrun_n_s16(a1, 4);
  vst1q_u8(dst, vcombine_u8(d0, d1));
}

inline void WienerVerticalTap1(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               uint8_t* dst, const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y != 0; --y) {
    uint8_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      WienerVerticalTap1Kernel(wiener_buffer, dst_ptr);
      WienerVerticalTap1Kernel(wiener_buffer + width, dst_ptr + dst_stride);
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = width;
    do {
      WienerVerticalTap1Kernel(wiener_buffer, dst);
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}

// For width 16 and up, store the horizontal results, and then do the vertical
// filter row by row. This is faster than doing it column by column when
// considering cache issues.
void WienerFilter_NEON(
    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
    const void* LIBGAV1_RESTRICT const top_border,
    const ptrdiff_t top_border_stride,
    const void* LIBGAV1_RESTRICT const bottom_border,
    const ptrdiff_t bottom_border_stride, const int width, const int height,
    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
    void* LIBGAV1_RESTRICT const dest) {
  const int16_t* const number_leading_zero_coefficients =
      restoration_info.wiener_info.number_leading_zero_coefficients;
  const int number_rows_to_skip = std::max(
      static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]),
      1);
  const ptrdiff_t wiener_stride = Align(width, 16);
  int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer;
  // The values are saturated to 13 bits before storing.
  int16_t* wiener_buffer_horizontal =
      wiener_buffer_vertical + number_rows_to_skip * wiener_stride;
  int16_t filter_horizontal[(kWienerFilterTaps + 1) / 2];
  int16_t filter_vertical[(kWienerFilterTaps + 1) / 2];
  PopulateWienerCoefficients(restoration_info, WienerInfo::kHorizontal,
                             filter_horizontal);
  PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical,
                             filter_vertical);

  // Horizontal filtering.
  // Over-reads up to 15 - |kRestorationHorizontalBorder| values.
  const int height_horizontal =
      height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip;
  const int height_extra = (height_horizontal - height) >> 1;
  assert(height_extra <= 2);
  const auto* const src = static_cast<const uint8_t*>(source);
  const auto* const top = static_cast<const uint8_t*>(top_border);
  const auto* const bottom = static_cast<const uint8_t*>(bottom_border);
  if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) {
    WienerHorizontalTap7(top + (2 - height_extra) * top_border_stride - 3,
                         top_border_stride, wiener_stride, height_extra,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap7(src - 3, stride, wiener_stride, height,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap7(bottom - 3, bottom_border_stride, wiener_stride,
                         height_extra, filter_horizontal,
                         &wiener_buffer_horizontal);
  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) {
    WienerHorizontalTap5(top + (2 - height_extra) * top_border_stride - 2,
                         top_border_stride, wiener_stride, height_extra,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap5(src - 2, stride, wiener_stride, height,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap5(bottom - 2, bottom_border_stride, wiener_stride,
                         height_extra, filter_horizontal,
                         &wiener_buffer_horizontal);
  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) {
    // The maximum over-reads happen here.
    WienerHorizontalTap3(top + (2 - height_extra) * top_border_stride - 1,
                         top_border_stride, wiener_stride, height_extra,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap3(src - 1, stride, wiener_stride, height,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap3(bottom - 1, bottom_border_stride, wiener_stride,
                         height_extra, filter_horizontal,
                         &wiener_buffer_horizontal);
  } else {
    assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3);
    WienerHorizontalTap1(top + (2 - height_extra) * top_border_stride,
                         top_border_stride, wiener_stride, height_extra,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap1(src, stride, wiener_stride, height,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap1(bottom, bottom_border_stride, wiener_stride,
                         height_extra, &wiener_buffer_horizontal);
  }

  // Vertical filtering.
  // Over-writes up to 15 values.
  auto* dst = static_cast<uint8_t*>(dest);
  if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) {
    // Because the top row of |source| is a duplicate of the second row, and
    // the bottom row of |source| is a duplicate of the row above it, we can
    // duplicate the top and bottom rows of |wiener_buffer| accordingly.
    memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride,
           sizeof(*wiener_buffer_horizontal) * wiener_stride);
    memcpy(restoration_buffer->wiener_buffer,
           restoration_buffer->wiener_buffer + wiener_stride,
           sizeof(*restoration_buffer->wiener_buffer) * wiener_stride);
    WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height,
                       filter_vertical, dst, stride);
  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) {
    WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride,
                       height, filter_vertical, dst, stride);
  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) {
    WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride,
                       wiener_stride, height, filter_vertical, dst, stride);
  } else {
    assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3);
    WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride,
                       wiener_stride, height, dst, stride);
  }
}

//------------------------------------------------------------------------------
// SGR

// SIMD overreads 8 - (width % 8) - 2 * padding pixels, where padding is 3 for
// Pass 1 and 2 for Pass 2.
constexpr int kOverreadInBytesPass1 = 2;
constexpr int kOverreadInBytesPass2 = 4;

// SIMD overreads 16 - (width % 16) - 2 * padding pixels, where padding is 3
// for Pass 1 and 2 for Pass 2.
constexpr int kWideOverreadInBytesPass1 = 10;
constexpr int kWideOverreadInBytesPass2 = 12;
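// In other words, at worst the 8-pixel-wide loops over-read 8 - 2 * 3 = 2 or
// 8 - 2 * 2 = 4 bytes, and the 16-pixel-wide loops 16 - 2 * 3 = 10 or
// 16 - 2 * 2 = 12 bytes.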

inline void LoadAligned16x2U16(const uint16_t* const src[2], const ptrdiff_t x,
                               uint16x8_t dst[2]) {
  dst[0] = vld1q_u16(src[0] + x);
  dst[1] = vld1q_u16(src[1] + x);
}

inline void LoadAligned16x3U16(const uint16_t* const src[3], const ptrdiff_t x,
                               uint16x8_t dst[3]) {
  dst[0] = vld1q_u16(src[0] + x);
  dst[1] = vld1q_u16(src[1] + x);
  dst[2] = vld1q_u16(src[2] + x);
}

inline void LoadAligned32U32(const uint32_t* const src, uint32x4x2_t* dst) {
  (*dst).val[0] = vld1q_u32(src + 0);
  (*dst).val[1] = vld1q_u32(src + 4);
}

inline void LoadAligned32x2U32(const uint32_t* const src[2], const ptrdiff_t x,
                               uint32x4x2_t dst[2]) {
  LoadAligned32U32(src[0] + x, &dst[0]);
  LoadAligned32U32(src[1] + x, &dst[1]);
}

inline void LoadAligned32x3U32(const uint32_t* const src[3], const ptrdiff_t x,
                               uint32x4x2_t dst[3]) {
  LoadAligned32U32(src[0] + x, &dst[0]);
  LoadAligned32U32(src[1] + x, &dst[1]);
  LoadAligned32U32(src[2] + x, &dst[2]);
}

inline void StoreAligned32U16(uint16_t* const dst, const uint16x8_t src[2]) {
  vst1q_u16(dst + 0, src[0]);
  vst1q_u16(dst + 8, src[1]);
}

inline void StoreAligned32U32(uint32_t* const dst, const uint32x4x2_t src) {
  vst1q_u32(dst + 0, src.val[0]);
  vst1q_u32(dst + 4, src.val[1]);
}

inline void StoreAligned64U32(uint32_t* const dst, const uint32x4x2_t src[2]) {
  vst1q_u32(dst + 0, src[0].val[0]);
  vst1q_u32(dst + 4, src[0].val[1]);
  vst1q_u32(dst + 8, src[1].val[0]);
  vst1q_u32(dst + 12, src[1].val[1]);
}

inline uint16x8_t SquareLo8(const uint8x8_t src) { return vmull_u8(src, src); }

inline uint16x8_t SquareLo8(const uint8x16_t src) {
  return vmull_u8(vget_low_u8(src), vget_low_u8(src));
}

inline uint16x8_t SquareHi8(const uint8x16_t src) {
  return vmull_u8(vget_high_u8(src), vget_high_u8(src));
}

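// Prepare3_8()/Prepare5_8() and their 16-bit counterparts build the 3 or 5
// consecutive shifted windows of a vector pair that the Sum*Horizontal()
// helpers below reduce into box sums.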
inline void Prepare3_8(const uint8x8_t src[2], uint8x8_t dst[3]) {
  dst[0] = VshrU128<0>(src);
  dst[1] = VshrU128<1>(src);
  dst[2] = VshrU128<2>(src);
}

template <int offset>
inline void Prepare3_8(const uint8x16_t src[2], uint8x16_t dst[3]) {
  dst[0] = VshrU128<offset + 0>(src);
  dst[1] = VshrU128<offset + 1>(src);
  dst[2] = VshrU128<offset + 2>(src);
}

inline void Prepare3_16(const uint16x8_t src[2], uint16x4_t low[3],
                        uint16x4_t high[3]) {
  uint16x8_t s[3];
  s[0] = VshrU128<0>(src);
  s[1] = VshrU128<2>(src);
  s[2] = VshrU128<4>(src);
  low[0] = vget_low_u16(s[0]);
  low[1] = vget_low_u16(s[1]);
  low[2] = vget_low_u16(s[2]);
  high[0] = vget_high_u16(s[0]);
  high[1] = vget_high_u16(s[1]);
  high[2] = vget_high_u16(s[2]);
}

inline void Prepare5_8(const uint8x8_t src[2], uint8x8_t dst[5]) {
  dst[0] = VshrU128<0>(src);
  dst[1] = VshrU128<1>(src);
  dst[2] = VshrU128<2>(src);
  dst[3] = VshrU128<3>(src);
  dst[4] = VshrU128<4>(src);
}

template <int offset>
inline void Prepare5_8(const uint8x16_t src[2], uint8x16_t dst[5]) {
  dst[0] = VshrU128<offset + 0>(src);
  dst[1] = VshrU128<offset + 1>(src);
  dst[2] = VshrU128<offset + 2>(src);
  dst[3] = VshrU128<offset + 3>(src);
  dst[4] = VshrU128<offset + 4>(src);
}

inline void Prepare5_16(const uint16x8_t src[2], uint16x4_t low[5],
                        uint16x4_t high[5]) {
  Prepare3_16(src, low, high);
  const uint16x8_t s3 = VshrU128<6>(src);
  const uint16x8_t s4 = VshrU128<8>(src);
  low[3] = vget_low_u16(s3);
  low[4] = vget_low_u16(s4);
  high[3] = vget_high_u16(s3);
  high[4] = vget_high_u16(s4);
}

inline uint16x8_t Sum3_16(const uint16x8_t src0, const uint16x8_t src1,
                          const uint16x8_t src2) {
  const uint16x8_t sum = vaddq_u16(src0, src1);
  return vaddq_u16(sum, src2);
}

inline uint16x8_t Sum3_16(const uint16x8_t src[3]) {
  return Sum3_16(src[0], src[1], src[2]);
}

inline uint32x4_t Sum3_32(const uint32x4_t src0, const uint32x4_t src1,
                          const uint32x4_t src2) {
  const uint32x4_t sum = vaddq_u32(src0, src1);
  return vaddq_u32(sum, src2);
}

inline uint32x4x2_t Sum3_32(const uint32x4x2_t src[3]) {
  uint32x4x2_t d;
  d.val[0] = Sum3_32(src[0].val[0], src[1].val[0], src[2].val[0]);
  d.val[1] = Sum3_32(src[0].val[1], src[1].val[1], src[2].val[1]);
  return d;
}

inline uint16x8_t Sum3W_16(const uint8x8_t src[3]) {
  const uint16x8_t sum = vaddl_u8(src[0], src[1]);
  return vaddw_u8(sum, src[2]);
}

inline uint16x8_t Sum3WLo16(const uint8x16_t src[3]) {
  const uint16x8_t sum = vaddl_u8(vget_low_u8(src[0]), vget_low_u8(src[1]));
  return vaddw_u8(sum, vget_low_u8(src[2]));
}

inline uint16x8_t Sum3WHi16(const uint8x16_t src[3]) {
  const uint16x8_t sum = vaddl_u8(vget_high_u8(src[0]), vget_high_u8(src[1]));
  return vaddw_u8(sum, vget_high_u8(src[2]));
}

inline uint16x8_t Sum5WLo16(const uint8x16_t src[5]) {
  const uint16x8_t sum01 = vaddl_u8(vget_low_u8(src[0]), vget_low_u8(src[1]));
  const uint16x8_t sum23 = vaddl_u8(vget_low_u8(src[2]), vget_low_u8(src[3]));
  const uint16x8_t sum = vaddq_u16(sum01, sum23);
  return vaddw_u8(sum, vget_low_u8(src[4]));
}

inline uint16x8_t Sum5WHi16(const uint8x16_t src[5]) {
  const uint16x8_t sum01 = vaddl_u8(vget_high_u8(src[0]), vget_high_u8(src[1]));
  const uint16x8_t sum23 = vaddl_u8(vget_high_u8(src[2]), vget_high_u8(src[3]));
  const uint16x8_t sum = vaddq_u16(sum01, sum23);
  return vaddw_u8(sum, vget_high_u8(src[4]));
}

inline uint32x4_t Sum3W_32(const uint16x4_t src[3]) {
  const uint32x4_t sum = vaddl_u16(src[0], src[1]);
  return vaddw_u16(sum, src[2]);
}

inline uint16x8_t Sum5_16(const uint16x8_t src[5]) {
  const uint16x8_t sum01 = vaddq_u16(src[0], src[1]);
  const uint16x8_t sum23 = vaddq_u16(src[2], src[3]);
  const uint16x8_t sum = vaddq_u16(sum01, sum23);
  return vaddq_u16(sum, src[4]);
}

inline uint32x4_t Sum5_32(const uint32x4_t src0, const uint32x4_t src1,
                          const uint32x4_t src2, const uint32x4_t src3,
                          const uint32x4_t src4) {
  const uint32x4_t sum01 = vaddq_u32(src0, src1);
  const uint32x4_t sum23 = vaddq_u32(src2, src3);
  const uint32x4_t sum = vaddq_u32(sum01, sum23);
  return vaddq_u32(sum, src4);
}

inline uint32x4x2_t Sum5_32(const uint32x4x2_t src[5]) {
  uint32x4x2_t d;
  d.val[0] = Sum5_32(src[0].val[0], src[1].val[0], src[2].val[0], src[3].val[0],
                     src[4].val[0]);
  d.val[1] = Sum5_32(src[0].val[1], src[1].val[1], src[2].val[1], src[3].val[1],
                     src[4].val[1]);
  return d;
}

inline uint32x4_t Sum5W_32(const uint16x4_t src[5]) {
  const uint32x4_t sum01 = vaddl_u16(src[0], src[1]);
  const uint32x4_t sum23 = vaddl_u16(src[2], src[3]);
  const uint32x4_t sum0123 = vaddq_u32(sum01, sum23);
  return vaddw_u16(sum0123, src[4]);
}

inline uint16x8_t Sum3Horizontal(const uint8x8_t src[2]) {
  uint8x8_t s[3];
  Prepare3_8(src, s);
  return Sum3W_16(s);
}

inline uint16x8_t Sum3Horizontal(const uint8x16_t src) {
  uint8x8_t s[2];
  s[0] = vget_low_u8(src);
  s[1] = vget_high_u8(src);
  return Sum3Horizontal(s);
}

template <int offset>
inline void Sum3Horizontal(const uint8x16_t src[2], uint16x8_t dst[2]) {
  uint8x16_t s[3];
  Prepare3_8<offset>(src, s);
  dst[0] = Sum3WLo16(s);
  dst[1] = Sum3WHi16(s);
}

inline uint32x4x2_t Sum3WHorizontal(const uint16x8_t src[2]) {
  uint16x4_t low[3], high[3];
  uint32x4x2_t sum;
  Prepare3_16(src, low, high);
  sum.val[0] = Sum3W_32(low);
  sum.val[1] = Sum3W_32(high);
  return sum;
}

inline uint16x8_t Sum5Horizontal(const uint8x8_t src[2]) {
  uint8x8_t s[5];
  Prepare5_8(src, s);
  const uint16x8_t sum01 = vaddl_u8(s[0], s[1]);
  const uint16x8_t sum23 = vaddl_u8(s[2], s[3]);
  const uint16x8_t sum0123 = vaddq_u16(sum01, sum23);
  return vaddw_u8(sum0123, s[4]);
}

inline uint16x8_t Sum5Horizontal(const uint8x16_t src) {
  uint8x8_t s[2];
  s[0] = vget_low_u8(src);
  s[1] = vget_high_u8(src);
  return Sum5Horizontal(s);
}

template <int offset>
inline void Sum5Horizontal(const uint8x16_t src[2], uint16x8_t* const dst0,
                           uint16x8_t* const dst1) {
  uint8x16_t s[5];
  Prepare5_8<offset>(src, s);
  *dst0 = Sum5WLo16(s);
  *dst1 = Sum5WHi16(s);
}

inline uint32x4x2_t Sum5WHorizontal(const uint16x8_t src[2]) {
  uint16x4_t low[5], high[5];
  Prepare5_16(src, low, high);
  uint32x4x2_t sum;
  sum.val[0] = Sum5W_32(low);
  sum.val[1] = Sum5W_32(high);
  return sum;
}

template <int offset>
void SumHorizontal(const uint8x16_t src[2], uint16x8_t* const row3_0,
                   uint16x8_t* const row3_1, uint16x8_t* const row5_0,
                   uint16x8_t* const row5_1) {
  uint8x16_t s[5];
  Prepare5_8<offset>(src, s);
  const uint16x8_t sum04_lo = vaddl_u8(vget_low_u8(s[0]), vget_low_u8(s[4]));
  const uint16x8_t sum04_hi = vaddl_u8(vget_high_u8(s[0]), vget_high_u8(s[4]));
  *row3_0 = Sum3WLo16(s + 1);
  *row3_1 = Sum3WHi16(s + 1);
  *row5_0 = vaddq_u16(sum04_lo, *row3_0);
  *row5_1 = vaddq_u16(sum04_hi, *row3_1);
}

void SumHorizontal(const uint8x8_t src[2], uint16x8_t* const row3,
                   uint16x8_t* const row5) {
  uint8x8_t s[5];
  Prepare5_8(src, s);
  const uint16x8_t sum04 = vaddl_u8(s[0], s[4]);
  const uint16x8_t sum12 = vaddl_u8(s[1], s[2]);
  *row3 = vaddw_u8(sum12, s[3]);
  *row5 = vaddq_u16(sum04, *row3);
}

void SumHorizontal(const uint16x4_t src[5], uint32x4_t* const row_sq3,
                   uint32x4_t* const row_sq5) {
  const uint32x4_t sum04 = vaddl_u16(src[0], src[4]);
  const uint32x4_t sum12 = vaddl_u16(src[1], src[2]);
  *row_sq3 = vaddw_u16(sum12, src[3]);
  *row_sq5 = vaddq_u32(sum04, *row_sq3);
}

void SumHorizontal(const uint16x8_t sq[2], uint32x4x2_t* const row_sq3,
                   uint32x4x2_t* const row_sq5) {
  uint16x4_t low[5], high[5];
  Prepare5_16(sq, low, high);
  SumHorizontal(low, &row_sq3->val[0], &row_sq5->val[0]);
  SumHorizontal(high, &row_sq3->val[1], &row_sq5->val[1]);
}

void SumHorizontal(const uint8x8_t src[2], const uint16x8_t sq[2],
                   uint16x8_t* const row3, uint16x8_t* const row5,
                   uint32x4x2_t* const row_sq3, uint32x4x2_t* const row_sq5) {
  SumHorizontal(src, row3, row5);
  SumHorizontal(sq, row_sq3, row_sq5);
}

void SumHorizontal(const uint8x16_t src, const uint16x8_t sq[2],
                   uint16x8_t* const row3, uint16x8_t* const row5,
                   uint32x4x2_t* const row_sq3, uint32x4x2_t* const row_sq5) {
  uint8x8_t s[2];
  s[0] = vget_low_u8(src);
  s[1] = vget_high_u8(src);
  return SumHorizontal(s, sq, row3, row5, row_sq3, row_sq5);
}

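// Sum343() and Sum565() compute the weighted sums 3*a + 4*b + 3*c and
// 5*a + 6*b + 5*c of three neighboring values used by the SGR passes: three
// (or five) times the plain sum, plus the middle element once more.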
template <int offset>
inline uint16x8_t Sum343(const uint8x16_t ma3[2]) {
  const uint16x8_t sum = (offset == 0) ? Sum3WLo16(ma3) : Sum3WHi16(ma3);
  const uint16x8_t sum3 = Sum3_16(sum, sum, sum);
  return vaddw_u8(sum3,
                  (offset == 0) ? vget_low_u8(ma3[1]) : vget_high_u8(ma3[1]));
}

inline uint32x4_t Sum343W(const uint16x4_t src[3]) {
  const uint32x4_t sum = Sum3W_32(src);
  const uint32x4_t sum3 = Sum3_32(sum, sum, sum);
  return vaddw_u16(sum3, src[1]);
}

inline uint32x4x2_t Sum343W(const uint16x8_t src[2]) {
  uint16x4_t low[3], high[3];
  uint32x4x2_t d;
  Prepare3_16(src, low, high);
  d.val[0] = Sum343W(low);
  d.val[1] = Sum343W(high);
  return d;
}

template <int offset>
inline uint16x8_t Sum565(const uint8x16_t ma5[2]) {
  const uint16x8_t sum = (offset == 0) ? Sum3WLo16(ma5) : Sum3WHi16(ma5);
  const uint16x8_t sum4 = vshlq_n_u16(sum, 2);
  const uint16x8_t sum5 = vaddq_u16(sum4, sum);
  return vaddw_u8(sum5,
                  (offset == 0) ? vget_low_u8(ma5[1]) : vget_high_u8(ma5[1]));
}

inline uint32x4_t Sum565W(const uint16x4_t src[3]) {
  const uint32x4_t sum = Sum3W_32(src);
  const uint32x4_t sum4 = vshlq_n_u32(sum, 2);
  const uint32x4_t sum5 = vaddq_u32(sum4, sum);
  return vaddw_u16(sum5, src[1]);
}

inline uint32x4x2_t Sum565W(const uint16x8_t src[2]) {
  uint16x4_t low[3], high[3];
  uint32x4x2_t d;
  Prepare3_16(src, low, high);
  d.val[0] = Sum565W(low);
  d.val[1] = Sum565W(high);
  return d;
}

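// Computes, for each of the top two rows, the horizontal box sums of the
// source and of its squares for both the radius-2 (sum5) and radius-1 (sum3)
// windows in a single pass.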
inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
                   const ptrdiff_t width, const ptrdiff_t sum_stride,
                   const ptrdiff_t sum_width, uint16_t* sum3, uint16_t* sum5,
                   uint32_t* square_sum3, uint32_t* square_sum5) {
  const ptrdiff_t overread_in_bytes = kOverreadInBytesPass1 - width;
  int y = 2;
  // Don't change the loop width to 16, which is even slower.
  do {
    uint8x8_t s[2];
    uint16x8_t sq[2];
    s[0] = Load1MsanU8(src, overread_in_bytes);
    sq[0] = SquareLo8(s[0]);
    ptrdiff_t x = sum_width;
    do {
      uint16x8_t row3, row5;
      uint32x4x2_t row_sq3, row_sq5;
      x -= 8;
      src += 8;
      s[1] = Load1MsanU8(src, sum_width - x + overread_in_bytes);
      sq[1] = SquareLo8(s[1]);
      SumHorizontal(s, sq, &row3, &row5, &row_sq3, &row_sq5);
      vst1q_u16(sum3, row3);
      vst1q_u16(sum5, row5);
      StoreAligned32U32(square_sum3 + 0, row_sq3);
      StoreAligned32U32(square_sum5 + 0, row_sq5);
      s[0] = s[1];
      sq[0] = sq[1];
      sum3 += 8;
      sum5 += 8;
      square_sum3 += 8;
      square_sum5 += 8;
    } while (x != 0);
    src += src_stride - sum_width;
    sum3 += sum_stride - sum_width;
    sum5 += sum_stride - sum_width;
    square_sum3 += sum_stride - sum_width;
    square_sum5 += sum_stride - sum_width;
  } while (--y != 0);
}

template <int size>
inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
                   const ptrdiff_t width, const ptrdiff_t sum_stride,
                   const ptrdiff_t sum_width, uint16_t* sums,
                   uint32_t* square_sums) {
  static_assert(size == 3 || size == 5, "");
  const ptrdiff_t overread_in_bytes =
      ((size == 5) ? kOverreadInBytesPass1 : kOverreadInBytesPass2) -
      sizeof(*src) * width;
  int y = 2;
  // Don't change the loop width to 16, which is even slower.
  do {
    uint8x8_t s[2];
    uint16x8_t sq[2];
    s[0] = Load1MsanU8(src, overread_in_bytes);
    sq[0] = SquareLo8(s[0]);
    ptrdiff_t x = sum_width;
    do {
      uint16x8_t row;
      uint32x4x2_t row_sq;
      x -= 8;
      src += 8;
      s[1] = Load1MsanU8(src, sum_width - x + overread_in_bytes);
      sq[1] = SquareLo8(s[1]);
      if (size == 3) {
        row = Sum3Horizontal(s);
        row_sq = Sum3WHorizontal(sq);
      } else {
        row = Sum5Horizontal(s);
        row_sq = Sum5WHorizontal(sq);
      }
      vst1q_u16(sums, row);
      StoreAligned32U32(square_sums, row_sq);
      s[0] = s[1];
      sq[0] = sq[1];
      sums += 8;
      square_sums += 8;
    } while (x != 0);
    src += src_stride - sum_width;
    sums += sum_stride - sum_width;
    square_sums += sum_stride - sum_width;
  } while (--y != 0);
}

template <int n>
inline uint16x4_t CalculateMa(const uint16x4_t sum, const uint32x4_t sum_sq,
                              const uint32_t scale) {
  // a = |sum_sq|
  // d = |sum|
  // p = (a * n < d * d) ? 0 : a * n - d * d;
  const uint32x4_t dxd = vmull_u16(sum, sum);
  const uint32x4_t axn = vmulq_n_u32(sum_sq, n);
  // Ensure |p| does not underflow by using saturating subtraction.
  const uint32x4_t p = vqsubq_u32(axn, dxd);
  const uint32x4_t pxs = vmulq_n_u32(p, scale);
  // vrshrn_n_u32() (narrowing shift) can shift by at most 16, and
  // kSgrProjScaleBits is 20, so shift and narrow in two steps.
  const uint32x4_t shifted = vrshrq_n_u32(pxs, kSgrProjScaleBits);
  return vmovn_u32(shifted);
}

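// Subtracts 1 from |value| in every lane where |index| > |threshold|,
// exploiting that a true vcgt_u8() lane is 0xff (i.e. -1 modulo 256).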
inline uint8x8_t AdjustValue(const uint8x8_t value, const uint8x8_t index,
                             const int threshold) {
  const uint8x8_t thresholds = vdup_n_u8(threshold);
  const uint8x8_t offset = vcgt_u8(index, thresholds);
  // Adding 255 is equivalent to subtracting 1 for 8-bit data.
  return vadd_u8(value, offset);
}

template <int n, int offset>
inline void CalculateIntermediate(const uint16x8_t sum,
                                  const uint32x4x2_t sum_sq,
                                  const uint32_t scale, uint8x16_t* const ma,
                                  uint16x8_t* const b) {
  constexpr uint32_t one_over_n =
      ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n;
  const uint16x4_t z0 = CalculateMa<n>(vget_low_u16(sum), sum_sq.val[0], scale);
  const uint16x4_t z1 =
      CalculateMa<n>(vget_high_u16(sum), sum_sq.val[1], scale);
  const uint16x8_t z01 = vcombine_u16(z0, z1);
  const uint8x8_t idx = vqmovn_u16(z01);
  // Use table lookup to read elements whose indices are less than 48.
  // Using one uint8x8x4_t vector and one uint8x8x2_t vector is faster than
  // using two uint8x8x3_t vectors.
  uint8x8x4_t table0;
  uint8x8x2_t table1;
  table0.val[0] = vld1_u8(kSgrMaLookup + 0 * 8);
  table0.val[1] = vld1_u8(kSgrMaLookup + 1 * 8);
  table0.val[2] = vld1_u8(kSgrMaLookup + 2 * 8);
  table0.val[3] = vld1_u8(kSgrMaLookup + 3 * 8);
  table1.val[0] = vld1_u8(kSgrMaLookup + 4 * 8);
  table1.val[1] = vld1_u8(kSgrMaLookup + 5 * 8);
  // All elements whose indices are out of the range [0, 47] are set to 0.
  uint8x8_t val = vtbl4_u8(table0, idx);  // Range [0, 31].
  // Subtract 32 to shuffle the next index range.
  const uint8x8_t index = vsub_u8(idx, vdup_n_u8(32));
  const uint8x8_t res = vtbl2_u8(table1, index);  // Range [32, 47].
  // Use an OR instruction to combine the shuffle results together.
  val = vorr_u8(val, res);

  // For elements whose indices are larger than 47, since they seldom change
  // values with the increase of the index, we use comparison and arithmetic
  // operations to calculate their values.
  // Elements whose indices are larger than 47 (with value 0) are set to 5.
  val = vmax_u8(val, vdup_n_u8(5));
  val = AdjustValue(val, idx, 55);   // 55 is the last index whose value is 5.
  val = AdjustValue(val, idx, 72);   // 72 is the last index whose value is 4.
  val = AdjustValue(val, idx, 101);  // 101 is the last index whose value is 3.
  val = AdjustValue(val, idx, 169);  // 169 is the last index whose value is 2.
  val = AdjustValue(val, idx, 254);  // 254 is the last index whose value is 1.
  // offset == 0 is assumed to be the first call to this function. Note that
  // vget_high_u8(*ma) is not used in this case to avoid a -Wuninitialized
  // warning with some versions of gcc. vdup_n_u8(0) could work as well, but in
  // most cases clang and gcc generate better code with this version.
  *ma = (offset == 0) ? vcombine_u8(val, val)
                      : vcombine_u8(vget_low_u8(*ma), val);

  // b = ma * sum * one_over_n
  // |ma| = [0, 255]
  // |sum| is a box sum with radius 1 or 2.
  // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
  // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
  // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
  // When radius is 2 |n| is 25. |one_over_n| is 164.
  // When radius is 1 |n| is 9. |one_over_n| is 455.
  // |kSgrProjReciprocalBits| is 12.
  // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
  // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
  const uint16x8_t maq =
      vmovl_u8((offset == 0) ? vget_low_u8(*ma) : vget_high_u8(*ma));
  const uint32x4_t m0 = vmull_u16(vget_low_u16(maq), vget_low_u16(sum));
  const uint32x4_t m1 = vmull_u16(vget_high_u16(maq), vget_high_u16(sum));
  const uint32x4_t m2 = vmulq_n_u32(m0, one_over_n);
  const uint32x4_t m3 = vmulq_n_u32(m1, one_over_n);
  const uint16x4_t b_lo = vrshrn_n_u32(m2, kSgrProjReciprocalBits);
  const uint16x4_t b_hi = vrshrn_n_u32(m3, kSgrProjReciprocalBits);
  *b = vcombine_u16(b_lo, b_hi);
}

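// Wrappers feeding the 5x5 (n = 25) and 3x3 (n = 9) box sums into
// CalculateIntermediate().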
template <int offset>
inline void CalculateIntermediate5(const uint16x8_t s5[5],
                                   const uint32x4x2_t sq5[5],
                                   const uint32_t scale, uint8x16_t* const ma,
                                   uint16x8_t* const b) {
  const uint16x8_t sum = Sum5_16(s5);
  const uint32x4x2_t sum_sq = Sum5_32(sq5);
  CalculateIntermediate<25, offset>(sum, sum_sq, scale, ma, b);
}

template <int offset>
inline void CalculateIntermediate3(const uint16x8_t s3[3],
                                   const uint32x4x2_t sq3[3],
                                   const uint32_t scale, uint8x16_t* const ma,
                                   uint16x8_t* const b) {
  const uint16x8_t sum = Sum3_16(s3);
  const uint32x4x2_t sum_sq = Sum3_32(sq3);
  CalculateIntermediate<9, offset>(sum, sum_sq, scale, ma, b);
}

1177 template <int offset>
Store343_444(const uint8x16_t ma3[3],const uint16x8_t b3[2],const ptrdiff_t x,uint16x8_t * const sum_ma343,uint16x8_t * const sum_ma444,uint32x4x2_t * const sum_b343,uint32x4x2_t * const sum_b444,uint16_t * const ma343,uint16_t * const ma444,uint32_t * const b343,uint32_t * const b444)1178 inline void Store343_444(const uint8x16_t ma3[3], const uint16x8_t b3[2],
1179                          const ptrdiff_t x, uint16x8_t* const sum_ma343,
1180                          uint16x8_t* const sum_ma444,
1181                          uint32x4x2_t* const sum_b343,
1182                          uint32x4x2_t* const sum_b444, uint16_t* const ma343,
1183                          uint16_t* const ma444, uint32_t* const b343,
1184                          uint32_t* const b444) {
1185   const uint16x8_t sum_ma111 = (offset == 0) ? Sum3WLo16(ma3) : Sum3WHi16(ma3);
1186   *sum_ma444 = vshlq_n_u16(sum_ma111, 2);
1187   const uint16x8_t sum333 = vsubq_u16(*sum_ma444, sum_ma111);
1188   *sum_ma343 = vaddw_u8(
1189       sum333, (offset == 0) ? vget_low_u8(ma3[1]) : vget_high_u8(ma3[1]));
1190   uint16x4_t low[3], high[3];
1191   uint32x4x2_t sum_b111;
1192   Prepare3_16(b3, low, high);
1193   sum_b111.val[0] = Sum3W_32(low);
1194   sum_b111.val[1] = Sum3W_32(high);
1195   sum_b444->val[0] = vshlq_n_u32(sum_b111.val[0], 2);
1196   sum_b444->val[1] = vshlq_n_u32(sum_b111.val[1], 2);
1197   sum_b343->val[0] = vsubq_u32(sum_b444->val[0], sum_b111.val[0]);
1198   sum_b343->val[1] = vsubq_u32(sum_b444->val[1], sum_b111.val[1]);
1199   sum_b343->val[0] = vaddw_u16(sum_b343->val[0], low[1]);
1200   sum_b343->val[1] = vaddw_u16(sum_b343->val[1], high[1]);
1201   vst1q_u16(ma343 + x, *sum_ma343);
1202   vst1q_u16(ma444 + x, *sum_ma444);
1203   vst1q_u32(b343 + x + 0, sum_b343->val[0]);
1204   vst1q_u32(b343 + x + 4, sum_b343->val[1]);
1205   vst1q_u32(b444 + x + 0, sum_b444->val[0]);
1206   vst1q_u32(b444 + x + 4, sum_b444->val[1]);
1207 }
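// In scalar terms, for adjacent |ma| values a0, a1, a2, the stores above
// compute (a sketch):
//   ma444 = 4 * (a0 + a1 + a2)
//   ma343 = 3 * a0 + 4 * a1 + 3 * a2 = ma444 - (a0 + a1 + a2) + a1
// and apply the same 3-4-3 / 4-4-4 weights to the |b| values.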
1208 
1209 template <int offset>
1210 inline void Store343_444(const uint8x16_t ma3[3], const uint16x8_t b3[2],
1211                          const ptrdiff_t x, uint16x8_t* const sum_ma343,
1212                          uint32x4x2_t* const sum_b343, uint16_t* const ma343,
1213                          uint16_t* const ma444, uint32_t* const b343,
1214                          uint32_t* const b444) {
1215   uint16x8_t sum_ma444;
1216   uint32x4x2_t sum_b444;
1217   Store343_444<offset>(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, &sum_b444,
1218                        ma343, ma444, b343, b444);
1219 }
1220 
1221 template <int offset>
1222 inline void Store343_444(const uint8x16_t ma3[3], const uint16x8_t b3[2],
1223                          const ptrdiff_t x, uint16_t* const ma343,
1224                          uint16_t* const ma444, uint32_t* const b343,
1225                          uint32_t* const b444) {
1226   uint16x8_t sum_ma343;
1227   uint32x4x2_t sum_b343;
1228   Store343_444<offset>(ma3, b3, x, &sum_ma343, &sum_b343, ma343, ma444, b343,
1229                        b444);
1230 }
1231 
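// The BoxFilterPreProcess* helpers below compute row sums, squared sums and
// the (ma, b) intermediates for the 5x5 (pass 1) and 3x3 (pass 2) box
// filters. The *Lo variants handle the first 16 pixels of a row; the
// *LastRow variants only read previously stored sums (their sum arguments
// are const) since no new source row is available.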
1232 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5Lo(
1233     uint8x16_t s[2][2], const uint32_t scale, uint16_t* const sum5[5],
1234     uint32_t* const square_sum5[5], uint16x8_t sq[2][4], uint8x16_t* const ma,
1235     uint16x8_t* const b) {
1236   uint16x8_t s5[5];
1237   uint32x4x2_t sq5[5];
1238   sq[0][0] = SquareLo8(s[0][0]);
1239   sq[1][0] = SquareLo8(s[1][0]);
1240   sq[0][1] = SquareHi8(s[0][0]);
1241   sq[1][1] = SquareHi8(s[1][0]);
1242   s5[3] = Sum5Horizontal(s[0][0]);
1243   s5[4] = Sum5Horizontal(s[1][0]);
1244   sq5[3] = Sum5WHorizontal(sq[0]);
1245   sq5[4] = Sum5WHorizontal(sq[1]);
1246   vst1q_u16(sum5[3], s5[3]);
1247   vst1q_u16(sum5[4], s5[4]);
1248   StoreAligned32U32(square_sum5[3], sq5[3]);
1249   StoreAligned32U32(square_sum5[4], sq5[4]);
1250   LoadAligned16x3U16(sum5, 0, s5);
1251   LoadAligned32x3U32(square_sum5, 0, sq5);
1252   CalculateIntermediate5<0>(s5, sq5, scale, ma, b);
1253 }
1254 
1255 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5(
1256     uint8x16_t s[2][2], const ptrdiff_t x, const uint32_t scale,
1257     uint16_t* const sum5[5], uint32_t* const square_sum5[5],
1258     uint16x8_t sq[2][4], uint8x16_t ma[2], uint16x8_t b[2]) {
1259   uint16x8_t s5[2][5];
1260   uint32x4x2_t sq5[5];
1261   sq[0][2] = SquareLo8(s[0][1]);
1262   sq[1][2] = SquareLo8(s[1][1]);
1263   Sum5Horizontal<8>(s[0], &s5[0][3], &s5[1][3]);
1264   Sum5Horizontal<8>(s[1], &s5[0][4], &s5[1][4]);
1265   sq5[3] = Sum5WHorizontal(sq[0] + 1);
1266   sq5[4] = Sum5WHorizontal(sq[1] + 1);
1267   vst1q_u16(sum5[3] + x, s5[0][3]);
1268   vst1q_u16(sum5[4] + x, s5[0][4]);
1269   StoreAligned32U32(square_sum5[3] + x, sq5[3]);
1270   StoreAligned32U32(square_sum5[4] + x, sq5[4]);
1271   LoadAligned16x3U16(sum5, x, s5[0]);
1272   LoadAligned32x3U32(square_sum5, x, sq5);
1273   CalculateIntermediate5<8>(s5[0], sq5, scale, &ma[0], &b[0]);
1274 
1275   sq[0][3] = SquareHi8(s[0][1]);
1276   sq[1][3] = SquareHi8(s[1][1]);
1277   sq5[3] = Sum5WHorizontal(sq[0] + 2);
1278   sq5[4] = Sum5WHorizontal(sq[1] + 2);
1279   vst1q_u16(sum5[3] + x + 8, s5[1][3]);
1280   vst1q_u16(sum5[4] + x + 8, s5[1][4]);
1281   StoreAligned32U32(square_sum5[3] + x + 8, sq5[3]);
1282   StoreAligned32U32(square_sum5[4] + x + 8, sq5[4]);
1283   LoadAligned16x3U16(sum5, x + 8, s5[1]);
1284   LoadAligned32x3U32(square_sum5, x + 8, sq5);
1285   CalculateIntermediate5<0>(s5[1], sq5, scale, &ma[1], &b[1]);
1286 }
1287 
1288 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRowLo(
1289     uint8x16_t* const s, const uint32_t scale, const uint16_t* const sum5[5],
1290     const uint32_t* const square_sum5[5], uint16x8_t sq[2],
1291     uint8x16_t* const ma, uint16x8_t* const b) {
1292   uint16x8_t s5[5];
1293   uint32x4x2_t sq5[5];
1294   sq[0] = SquareLo8(s[0]);
1295   sq[1] = SquareHi8(s[0]);
1296   s5[3] = s5[4] = Sum5Horizontal(*s);
1297   sq5[3] = sq5[4] = Sum5WHorizontal(sq);
1298   LoadAligned16x3U16(sum5, 0, s5);
1299   LoadAligned32x3U32(square_sum5, 0, sq5);
1300   CalculateIntermediate5<0>(s5, sq5, scale, ma, b);
1301 }
1302 
1303 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow(
1304     uint8x16_t s[2], const ptrdiff_t x, const uint32_t scale,
1305     const uint16_t* const sum5[5], const uint32_t* const square_sum5[5],
1306     uint16x8_t sq[3], uint8x16_t ma[2], uint16x8_t b[2]) {
1307   uint16x8_t s5[2][5];
1308   uint32x4x2_t sq5[5];
1309   sq[1] = SquareLo8(s[1]);
1310   Sum5Horizontal<8>(s, &s5[0][3], &s5[1][3]);
1311   sq5[3] = sq5[4] = Sum5WHorizontal(sq);
1312   LoadAligned16x3U16(sum5, x, s5[0]);
1313   s5[0][4] = s5[0][3];
1314   LoadAligned32x3U32(square_sum5, x, sq5);
1315   CalculateIntermediate5<8>(s5[0], sq5, scale, &ma[0], &b[0]);
1316 
1317   sq[2] = SquareHi8(s[1]);
1318   sq5[3] = sq5[4] = Sum5WHorizontal(sq + 1);
1319   LoadAligned16x3U16(sum5, x + 8, s5[1]);
1320   s5[1][4] = s5[1][3];
1321   LoadAligned32x3U32(square_sum5, x + 8, sq5);
1322   CalculateIntermediate5<0>(s5[1], sq5, scale, &ma[1], &b[1]);
1323 }
1324 
1325 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3Lo(
1326     uint8x16_t* const s, const uint32_t scale, uint16_t* const sum3[3],
1327     uint32_t* const square_sum3[3], uint16x8_t sq[2], uint8x16_t* const ma,
1328     uint16x8_t* const b) {
1329   uint16x8_t s3[3];
1330   uint32x4x2_t sq3[3];
1331   sq[0] = SquareLo8(*s);
1332   sq[1] = SquareHi8(*s);
1333   s3[2] = Sum3Horizontal(*s);
1334   sq3[2] = Sum3WHorizontal(sq);
1335   vst1q_u16(sum3[2], s3[2]);
1336   StoreAligned32U32(square_sum3[2], sq3[2]);
1337   LoadAligned16x2U16(sum3, 0, s3);
1338   LoadAligned32x2U32(square_sum3, 0, sq3);
1339   CalculateIntermediate3<0>(s3, sq3, scale, ma, b);
1340 }
1341 
1342 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3(
1343     uint8x16_t s[2], const ptrdiff_t x, const uint32_t scale,
1344     uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint16x8_t sq[3],
1345     uint8x16_t ma[2], uint16x8_t b[2]) {
1346   uint16x8_t s3[4];
1347   uint32x4x2_t sq3[3];
1348   sq[1] = SquareLo8(s[1]);
1349   Sum3Horizontal<8>(s, s3 + 2);
1350   sq3[2] = Sum3WHorizontal(sq);
1351   vst1q_u16(sum3[2] + x, s3[2]);
1352   StoreAligned32U32(square_sum3[2] + x, sq3[2]);
1353   LoadAligned16x2U16(sum3, x, s3);
1354   LoadAligned32x2U32(square_sum3, x, sq3);
1355   CalculateIntermediate3<8>(s3, sq3, scale, &ma[0], &b[0]);
1356 
1357   sq[2] = SquareHi8(s[1]);
1358   sq3[2] = Sum3WHorizontal(sq + 1);
1359   vst1q_u16(sum3[2] + x + 8, s3[3]);
1360   StoreAligned32U32(square_sum3[2] + x + 8, sq3[2]);
1361   LoadAligned16x2U16(sum3, x + 8, s3 + 1);
1362   LoadAligned32x2U32(square_sum3, x + 8, sq3);
1363   CalculateIntermediate3<0>(s3 + 1, sq3, scale, &ma[1], &b[1]);
1364 }
1365 
1366 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLo(
1367     uint8x16_t s[2][2], const uint16_t scales[2], uint16_t* const sum3[4],
1368     uint16_t* const sum5[5], uint32_t* const square_sum3[4],
1369     uint32_t* const square_sum5[5], uint16x8_t sq[2][4], uint8x16_t ma3[2][2],
1370     uint16x8_t b3[2][3], uint8x16_t* const ma5, uint16x8_t* const b5) {
1371   uint16x8_t s3[4], s5[5];
1372   uint32x4x2_t sq3[4], sq5[5];
1373   sq[0][0] = SquareLo8(s[0][0]);
1374   sq[1][0] = SquareLo8(s[1][0]);
1375   sq[0][1] = SquareHi8(s[0][0]);
1376   sq[1][1] = SquareHi8(s[1][0]);
1377   SumHorizontal(s[0][0], sq[0], &s3[2], &s5[3], &sq3[2], &sq5[3]);
1378   SumHorizontal(s[1][0], sq[1], &s3[3], &s5[4], &sq3[3], &sq5[4]);
1379   vst1q_u16(sum3[2], s3[2]);
1380   vst1q_u16(sum3[3], s3[3]);
1381   StoreAligned32U32(square_sum3[2], sq3[2]);
1382   StoreAligned32U32(square_sum3[3], sq3[3]);
1383   vst1q_u16(sum5[3], s5[3]);
1384   vst1q_u16(sum5[4], s5[4]);
1385   StoreAligned32U32(square_sum5[3], sq5[3]);
1386   StoreAligned32U32(square_sum5[4], sq5[4]);
1387   LoadAligned16x2U16(sum3, 0, s3);
1388   LoadAligned32x2U32(square_sum3, 0, sq3);
1389   LoadAligned16x3U16(sum5, 0, s5);
1390   LoadAligned32x3U32(square_sum5, 0, sq5);
1391   CalculateIntermediate3<0>(s3, sq3, scales[1], ma3[0], b3[0]);
1392   CalculateIntermediate3<0>(s3 + 1, sq3 + 1, scales[1], ma3[1], b3[1]);
1393   CalculateIntermediate5<0>(s5, sq5, scales[0], ma5, b5);
1394 }
1395 
1396 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess(
1397     const uint8x16_t s[2][2], const ptrdiff_t x, const uint16_t scales[2],
1398     uint16_t* const sum3[4], uint16_t* const sum5[5],
1399     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
1400     uint16x8_t sq[2][4], uint8x16_t ma3[2][2], uint16x8_t b3[2][3],
1401     uint8x16_t ma5[2], uint16x8_t b5[2]) {
1402   uint16x8_t s3[2][4], s5[2][5];
1403   uint32x4x2_t sq3[4], sq5[5];
1404   sq[0][2] = SquareLo8(s[0][1]);
1405   sq[1][2] = SquareLo8(s[1][1]);
1406   SumHorizontal<8>(s[0], &s3[0][2], &s3[1][2], &s5[0][3], &s5[1][3]);
1407   SumHorizontal<8>(s[1], &s3[0][3], &s3[1][3], &s5[0][4], &s5[1][4]);
1408   SumHorizontal(sq[0] + 1, &sq3[2], &sq5[3]);
1409   SumHorizontal(sq[1] + 1, &sq3[3], &sq5[4]);
1410   vst1q_u16(sum3[2] + x, s3[0][2]);
1411   vst1q_u16(sum3[3] + x, s3[0][3]);
1412   StoreAligned32U32(square_sum3[2] + x, sq3[2]);
1413   StoreAligned32U32(square_sum3[3] + x, sq3[3]);
1414   vst1q_u16(sum5[3] + x, s5[0][3]);
1415   vst1q_u16(sum5[4] + x, s5[0][4]);
1416   StoreAligned32U32(square_sum5[3] + x, sq5[3]);
1417   StoreAligned32U32(square_sum5[4] + x, sq5[4]);
1418   LoadAligned16x2U16(sum3, x, s3[0]);
1419   LoadAligned32x2U32(square_sum3, x, sq3);
1420   LoadAligned16x3U16(sum5, x, s5[0]);
1421   LoadAligned32x3U32(square_sum5, x, sq5);
1422   CalculateIntermediate3<8>(s3[0], sq3, scales[1], &ma3[0][0], &b3[0][1]);
1423   CalculateIntermediate3<8>(s3[0] + 1, sq3 + 1, scales[1], &ma3[1][0],
1424                             &b3[1][1]);
1425   CalculateIntermediate5<8>(s5[0], sq5, scales[0], &ma5[0], &b5[0]);
1426 
1427   sq[0][3] = SquareHi8(s[0][1]);
1428   sq[1][3] = SquareHi8(s[1][1]);
1429   SumHorizontal(sq[0] + 2, &sq3[2], &sq5[3]);
1430   SumHorizontal(sq[1] + 2, &sq3[3], &sq5[4]);
1431   vst1q_u16(sum3[2] + x + 8, s3[1][2]);
1432   vst1q_u16(sum3[3] + x + 8, s3[1][3]);
1433   StoreAligned32U32(square_sum3[2] + x + 8, sq3[2]);
1434   StoreAligned32U32(square_sum3[3] + x + 8, sq3[3]);
1435   vst1q_u16(sum5[3] + x + 8, s5[1][3]);
1436   vst1q_u16(sum5[4] + x + 8, s5[1][4]);
1437   StoreAligned32U32(square_sum5[3] + x + 8, sq5[3]);
1438   StoreAligned32U32(square_sum5[4] + x + 8, sq5[4]);
1439   LoadAligned16x2U16(sum3, x + 8, s3[1]);
1440   LoadAligned32x2U32(square_sum3, x + 8, sq3);
1441   LoadAligned16x3U16(sum5, x + 8, s5[1]);
1442   LoadAligned32x3U32(square_sum5, x + 8, sq5);
1443   CalculateIntermediate3<0>(s3[1], sq3, scales[1], &ma3[0][1], &b3[0][2]);
1444   CalculateIntermediate3<0>(s3[1] + 1, sq3 + 1, scales[1], &ma3[1][1],
1445                             &b3[1][2]);
1446   CalculateIntermediate5<0>(s5[1], sq5, scales[0], &ma5[1], &b5[1]);
1447 }
1448 
1449 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRowLo(
1450     uint8x16_t* const s, const uint16_t scales[2],
1451     const uint16_t* const sum3[4], const uint16_t* const sum5[5],
1452     const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5],
1453     uint16x8_t sq[2], uint8x16_t* const ma3, uint8x16_t* const ma5,
1454     uint16x8_t* const b3, uint16x8_t* const b5) {
1455   uint16x8_t s3[3], s5[5];
1456   uint32x4x2_t sq3[3], sq5[5];
1457   sq[0] = SquareLo8(s[0]);
1458   sq[1] = SquareHi8(s[0]);
1459   SumHorizontal(*s, sq, &s3[2], &s5[3], &sq3[2], &sq5[3]);
1460   LoadAligned16x3U16(sum5, 0, s5);
1461   s5[4] = s5[3];
1462   LoadAligned32x3U32(square_sum5, 0, sq5);
1463   sq5[4] = sq5[3];
1464   CalculateIntermediate5<0>(s5, sq5, scales[0], ma5, b5);
1465   LoadAligned16x2U16(sum3, 0, s3);
1466   LoadAligned32x2U32(square_sum3, 0, sq3);
1467   CalculateIntermediate3<0>(s3, sq3, scales[1], ma3, b3);
1468 }
1469 
1470 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow(
1471     uint8x16_t s[2], const ptrdiff_t x, const uint16_t scales[2],
1472     const uint16_t* const sum3[4], const uint16_t* const sum5[5],
1473     const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5],
1474     uint16x8_t sq[3], uint8x16_t ma3[2], uint8x16_t ma5[2], uint16x8_t b3[2],
1475     uint16x8_t b5[2]) {
1476   uint16x8_t s3[2][3], s5[2][5];
1477   uint32x4x2_t sq3[3], sq5[5];
1478   sq[1] = SquareLo8(s[1]);
1479   SumHorizontal<8>(s, &s3[0][2], &s3[1][2], &s5[0][3], &s5[1][3]);
1480   SumHorizontal(sq, &sq3[2], &sq5[3]);
1481   LoadAligned16x3U16(sum5, x, s5[0]);
1482   s5[0][4] = s5[0][3];
1483   LoadAligned32x3U32(square_sum5, x, sq5);
1484   sq5[4] = sq5[3];
1485   CalculateIntermediate5<8>(s5[0], sq5, scales[0], &ma5[0], &b5[0]);
1486   LoadAligned16x2U16(sum3, x, s3[0]);
1487   LoadAligned32x2U32(square_sum3, x, sq3);
1488   CalculateIntermediate3<8>(s3[0], sq3, scales[1], &ma3[0], &b3[0]);
1489 
1490   sq[2] = SquareHi8(s[1]);
1491   SumHorizontal(sq + 1, &sq3[2], &sq5[3]);
1492   LoadAligned16x3U16(sum5, x + 8, s5[1]);
1493   s5[1][4] = s5[1][3];
1494   LoadAligned32x3U32(square_sum5, x + 8, sq5);
1495   sq5[4] = sq5[3];
1496   CalculateIntermediate5<0>(s5[1], sq5, scales[0], &ma5[1], &b5[1]);
1497   LoadAligned16x2U16(sum3, x + 8, s3[1]);
1498   LoadAligned32x2U32(square_sum3, x + 8, sq3);
1499   CalculateIntermediate3<0>(s3[1], sq3, scales[1], &ma3[1], &b3[1]);
1500 }
1501 
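// |overread_in_bytes| below is how many bytes the 16-byte loads may read past
// |width|; Load1QMsanU8 appears to mask that tail so MSAN does not flag the
// (otherwise benign) overread.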
1502 inline void BoxSumFilterPreProcess5(const uint8_t* const src0,
1503                                     const uint8_t* const src1, const int width,
1504                                     const uint32_t scale,
1505                                     uint16_t* const sum5[5],
1506                                     uint32_t* const square_sum5[5],
1507                                     uint16_t* ma565, uint32_t* b565) {
1508   const ptrdiff_t overread_in_bytes = kWideOverreadInBytesPass1 - width;
1509   uint8x16_t s[2][2], mas[2];
1510   uint16x8_t sq[2][4], bs[3];
1511   s[0][0] = vld1q_u8(src0);
1512   s[1][0] = vld1q_u8(src1);
1513 
1514   BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq, &mas[0], &bs[0]);
1515 
1516   int x = 0;
1517   do {
1518     uint16x8_t ma[2];
1519     uint8x16_t masx[3];
1520     uint32x4x2_t b[2];
1521     s[0][1] = Load1QMsanU8(src0 + x + 16, x + 16 + overread_in_bytes);
1522     s[1][1] = Load1QMsanU8(src1 + x + 16, x + 16 + overread_in_bytes);
1523     BoxFilterPreProcess5(s, x + 8, scale, sum5, square_sum5, sq, mas, bs + 1);
1524     Prepare3_8<0>(mas, masx);
1525     ma[0] = Sum565<0>(masx);
1526     b[0] = Sum565W(bs);
1527     vst1q_u16(ma565, ma[0]);
1528     vst1q_u32(b565 + 0, b[0].val[0]);
1529     vst1q_u32(b565 + 4, b[0].val[1]);
1530 
1531     ma[1] = Sum565<8>(masx);
1532     b[1] = Sum565W(bs + 1);
1533     vst1q_u16(ma565 + 8, ma[1]);
1534     vst1q_u32(b565 + 8, b[1].val[0]);
1535     vst1q_u32(b565 + 12, b[1].val[1]);
1536     s[0][0] = s[0][1];
1537     s[1][0] = s[1][1];
1538     sq[0][1] = sq[0][3];
1539     sq[1][1] = sq[1][3];
1540     mas[0] = mas[1];
1541     bs[0] = bs[2];
1542     ma565 += 16;
1543     b565 += 16;
1544     x += 16;
1545   } while (x < width);
1546 }
1547 
1548 template <bool calculate444>
1549 LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3(
1550     const uint8_t* const src, const int width, const uint32_t scale,
1551     uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint16_t* ma343,
1552     uint16_t* ma444, uint32_t* b343, uint32_t* b444) {
1553   const ptrdiff_t overread_in_bytes = kWideOverreadInBytesPass2 - width;
1554   uint8x16_t s[2], mas[2];
1555   uint16x8_t sq[4], bs[3];
1556   s[0] = Load1QMsanU8(src, overread_in_bytes);
1557   BoxFilterPreProcess3Lo(&s[0], scale, sum3, square_sum3, sq, &mas[0], &bs[0]);
1558 
1559   int x = 0;
1560   do {
1561     uint8x16_t ma3x[3];
1562     s[1] = Load1QMsanU8(src + x + 16, x + 16 + overread_in_bytes);
1563     BoxFilterPreProcess3(s, x + 8, scale, sum3, square_sum3, sq + 1, mas,
1564                          bs + 1);
1565     Prepare3_8<0>(mas, ma3x);
1566     if (calculate444) {
1567       Store343_444<0>(ma3x, bs + 0, 0, ma343, ma444, b343, b444);
1568       Store343_444<8>(ma3x, bs + 1, 0, ma343 + 8, ma444 + 8, b343 + 8,
1569                       b444 + 8);
1570       ma444 += 16;
1571       b444 += 16;
1572     } else {
1573       uint16x8_t ma[2];
1574       uint32x4x2_t b[2];
1575       ma[0] = Sum343<0>(ma3x);
1576       b[0] = Sum343W(bs);
1577       vst1q_u16(ma343, ma[0]);
1578       vst1q_u32(b343 + 0, b[0].val[0]);
1579       vst1q_u32(b343 + 4, b[0].val[1]);
1580       ma[1] = Sum343<8>(ma3x);
1581       b[1] = Sum343W(bs + 1);
1582       vst1q_u16(ma343 + 8, ma[1]);
1583       vst1q_u32(b343 + 8, b[1].val[0]);
1584       vst1q_u32(b343 + 12, b[1].val[1]);
1585     }
1586     s[0] = s[1];
1587     sq[1] = sq[3];
1588     mas[0] = mas[1];
1589     bs[0] = bs[2];
1590     ma343 += 16;
1591     b343 += 16;
1592     x += 16;
1593   } while (x < width);
1594 }
1595 
1596 inline void BoxSumFilterPreProcess(
1597     const uint8_t* const src0, const uint8_t* const src1, const int width,
1598     const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
1599     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
1600     uint16_t* const ma343[4], uint16_t* const ma444, uint16_t* ma565,
1601     uint32_t* const b343[4], uint32_t* const b444, uint32_t* b565) {
1602   const ptrdiff_t overread_in_bytes = kWideOverreadInBytesPass1 - width;
1603   uint8x16_t s[2][2], ma3[2][2], ma5[2];
1604   uint16x8_t sq[2][4], b3[2][3], b5[3];
1605   s[0][0] = vld1q_u8(src0);
1606   s[1][0] = vld1q_u8(src1);
1607 
1608   BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq,
1609                         ma3, b3, &ma5[0], &b5[0]);
1610 
1611   int x = 0;
1612   do {
1613     uint16x8_t ma[2];
1614     uint8x16_t ma3x[3], ma5x[3];
1615     uint32x4x2_t b[2];
1616 
1617     s[0][1] = Load1QMsanU8(src0 + x + 16, x + 16 + overread_in_bytes);
1618     s[1][1] = Load1QMsanU8(src1 + x + 16, x + 16 + overread_in_bytes);
1619     BoxFilterPreProcess(s, x + 8, scales, sum3, sum5, square_sum3, square_sum5,
1620                         sq, ma3, b3, ma5, b5 + 1);
1621     Prepare3_8<0>(ma3[0], ma3x);
1622     ma[0] = Sum343<0>(ma3x);
1623     ma[1] = Sum343<8>(ma3x);
1624     StoreAligned32U16(ma343[0] + x, ma);
1625     b[0] = Sum343W(b3[0] + 0);
1626     b[1] = Sum343W(b3[0] + 1);
1627     StoreAligned64U32(b343[0] + x, b);
1628     Prepare3_8<0>(ma3[1], ma3x);
1629     Store343_444<0>(ma3x, b3[1], x, ma343[1], ma444, b343[1], b444);
1630     Store343_444<8>(ma3x, b3[1] + 1, x + 8, ma343[1], ma444, b343[1], b444);
1631     Prepare3_8<0>(ma5, ma5x);
1632     ma[0] = Sum565<0>(ma5x);
1633     ma[1] = Sum565<8>(ma5x);
1634     StoreAligned32U16(ma565, ma);
1635     b[0] = Sum565W(b5);
1636     b[1] = Sum565W(b5 + 1);
1637     StoreAligned64U32(b565, b);
1638     s[0][0] = s[0][1];
1639     s[1][0] = s[1][1];
1640     sq[0][1] = sq[0][3];
1641     sq[1][1] = sq[1][3];
1642     ma3[0][0] = ma3[0][1];
1643     ma3[1][0] = ma3[1][1];
1644     b3[0][0] = b3[0][2];
1645     b3[1][0] = b3[1][2];
1646     ma5[0] = ma5[1];
1647     b5[0] = b5[2];
1648     ma565 += 16;
1649     b565 += 16;
1650     x += 16;
1651   } while (x < width);
1652 }
1653 
1654 template <int shift>
1655 inline int16x4_t FilterOutput(const uint16x4_t src, const uint16x4_t ma,
1656                               const uint32x4_t b) {
1657   // ma: 255 * 32 = 8160 (13 bits)
1658   // b: 65088 * 32 = 2082816 (21 bits)
1659   // v: b - ma * 255 (22 bits)
1660   const int32x4_t v = vreinterpretq_s32_u32(vmlsl_u16(b, ma, src));
1661   // kSgrProjSgrBits = 8
1662   // kSgrProjRestoreBits = 4
1663   // shift = 4 or 5
1664   // v >> 8 or 9 (13 bits)
1665   return vrshrn_n_s32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
1666 }
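// Scalar equivalent of FilterOutput, given the value ranges noted above (a
// sketch):
//   v = b - ma * src;  // 22-bit signed residual
//   p = RightShiftWithRounding(v, kSgrProjSgrBits + shift -
//                                 kSgrProjRestoreBits);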
1667 
1668 template <int shift>
1669 inline int16x8_t CalculateFilteredOutput(const uint8x8_t src,
1670                                          const uint16x8_t ma,
1671                                          const uint32x4x2_t b) {
1672   const uint16x8_t src_u16 = vmovl_u8(src);
1673   const int16x4_t dst_lo =
1674       FilterOutput<shift>(vget_low_u16(src_u16), vget_low_u16(ma), b.val[0]);
1675   const int16x4_t dst_hi =
1676       FilterOutput<shift>(vget_high_u16(src_u16), vget_high_u16(ma), b.val[1]);
1677   return vcombine_s16(dst_lo, dst_hi);  // 13 bits
1678 }
1679 
1680 inline int16x8_t CalculateFilteredOutputPass1(const uint8x8_t s,
1681                                               uint16x8_t ma[2],
1682                                               uint32x4x2_t b[2]) {
1683   const uint16x8_t ma_sum = vaddq_u16(ma[0], ma[1]);
1684   uint32x4x2_t b_sum;
1685   b_sum.val[0] = vaddq_u32(b[0].val[0], b[1].val[0]);
1686   b_sum.val[1] = vaddq_u32(b[0].val[1], b[1].val[1]);
1687   return CalculateFilteredOutput<5>(s, ma_sum, b_sum);
1688 }
1689 
1690 inline int16x8_t CalculateFilteredOutputPass2(const uint8x8_t s,
1691                                               uint16x8_t ma[3],
1692                                               uint32x4x2_t b[3]) {
1693   const uint16x8_t ma_sum = Sum3_16(ma);
1694   const uint32x4x2_t b_sum = Sum3_32(b);
1695   return CalculateFilteredOutput<5>(s, ma_sum, b_sum);
1696 }
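// The 565 weights of a single row sum to 16, so the two-row sum in pass 1 is
// normalized by 32 (shift 5) and the single-row case by 16 (shift 4). The
// pass 2 weights (3,4,3) + (4,4,4) + (3,4,3) across three rows also total 32,
// hence the same shift 5.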
1697 
1698 inline uint8x8_t SelfGuidedFinal(const uint8x8_t src, const int32x4_t v[2]) {
1699   const int16x4_t v_lo =
1700       vrshrn_n_s32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits);
1701   const int16x4_t v_hi =
1702       vrshrn_n_s32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits);
1703   const int16x8_t vv = vcombine_s16(v_lo, v_hi);
1704   const int16x8_t d =
1705       vreinterpretq_s16_u16(vaddw_u8(vreinterpretq_u16_s16(vv), src));
1706   return vqmovun_s16(d);
1707 }
1708 
1709 inline uint8x8_t SelfGuidedDoubleMultiplier(const uint8x8_t src,
1710                                             const int16x8_t filter[2],
1711                                             const int w0, const int w2) {
1712   int32x4_t v[2];
1713   v[0] = vmull_n_s16(vget_low_s16(filter[0]), w0);
1714   v[1] = vmull_n_s16(vget_high_s16(filter[0]), w0);
1715   v[0] = vmlal_n_s16(v[0], vget_low_s16(filter[1]), w2);
1716   v[1] = vmlal_n_s16(v[1], vget_high_s16(filter[1]), w2);
1717   return SelfGuidedFinal(src, v);
1718 }
1719 
1720 inline uint8x8_t SelfGuidedSingleMultiplier(const uint8x8_t src,
1721                                             const int16x8_t filter,
1722                                             const int w0) {
1723   // weight: -96 to 96 (Sgrproj_Xqd_Min/Max)
1724   int32x4_t v[2];
1725   v[0] = vmull_n_s16(vget_low_s16(filter), w0);
1726   v[1] = vmull_n_s16(vget_high_s16(filter), w0);
1727   return SelfGuidedFinal(src, v);
1728 }
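// In scalar form the final self-guided combine is (a sketch; the w2 term is
// present only in the two-pass case):
//   dst = Clip(src + RightShiftWithRounding(
//                        w0 * p0 + w2 * p2,
//                        kSgrProjRestoreBits + kSgrProjPrecisionBits));
// with vqmovun_s16 providing the clip to [0, 255] as a saturating narrow.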
1729 
1730 LIBGAV1_ALWAYS_INLINE void BoxFilterPass1(
1731     const uint8_t* const src, const uint8_t* const src0,
1732     const uint8_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5],
1733     uint32_t* const square_sum5[5], const int width, const uint32_t scale,
1734     const int16_t w0, uint16_t* const ma565[2], uint32_t* const b565[2],
1735     uint8_t* const dst) {
1736   const ptrdiff_t overread_in_bytes = kWideOverreadInBytesPass1 - width;
1737   uint8x16_t s[2][2], mas[2];
1738   uint16x8_t sq[2][4], bs[3];
1739   s[0][0] = Load1QMsanU8(src0, overread_in_bytes);
1740   s[1][0] = Load1QMsanU8(src1, overread_in_bytes);
1741 
1742   BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq, &mas[0], &bs[0]);
1743 
1744   int x = 0;
1745   do {
1746     uint16x8_t ma[2];
1747     uint8x16_t masx[3];
1748     uint32x4x2_t b[2];
1749     int16x8_t p0, p1;
1750     s[0][1] = Load1QMsanU8(src0 + x + 16, x + 16 + overread_in_bytes);
1751     s[1][1] = Load1QMsanU8(src1 + x + 16, x + 16 + overread_in_bytes);
1752     BoxFilterPreProcess5(s, x + 8, scale, sum5, square_sum5, sq, mas, bs + 1);
1753     Prepare3_8<0>(mas, masx);
1754     ma[1] = Sum565<0>(masx);
1755     b[1] = Sum565W(bs);
1756     vst1q_u16(ma565[1] + x, ma[1]);
1757     vst1q_u32(b565[1] + x + 0, b[1].val[0]);
1758     vst1q_u32(b565[1] + x + 4, b[1].val[1]);
1759     const uint8x16_t sr0 = vld1q_u8(src + x);
1760     const uint8x16_t sr1 = vld1q_u8(src + stride + x);
1761     const uint8x8_t sr00 = vget_low_u8(sr0);
1762     const uint8x8_t sr10 = vget_low_u8(sr1);
1763     ma[0] = vld1q_u16(ma565[0] + x);
1764     b[0].val[0] = vld1q_u32(b565[0] + x + 0);
1765     b[0].val[1] = vld1q_u32(b565[0] + x + 4);
1766     p0 = CalculateFilteredOutputPass1(sr00, ma, b);
1767     p1 = CalculateFilteredOutput<4>(sr10, ma[1], b[1]);
1768     const uint8x8_t d00 = SelfGuidedSingleMultiplier(sr00, p0, w0);
1769     const uint8x8_t d10 = SelfGuidedSingleMultiplier(sr10, p1, w0);
1770 
1771     ma[1] = Sum565<8>(masx);
1772     b[1] = Sum565W(bs + 1);
1773     vst1q_u16(ma565[1] + x + 8, ma[1]);
1774     vst1q_u32(b565[1] + x + 8, b[1].val[0]);
1775     vst1q_u32(b565[1] + x + 12, b[1].val[1]);
1776     const uint8x8_t sr01 = vget_high_u8(sr0);
1777     const uint8x8_t sr11 = vget_high_u8(sr1);
1778     ma[0] = vld1q_u16(ma565[0] + x + 8);
1779     b[0].val[0] = vld1q_u32(b565[0] + x + 8);
1780     b[0].val[1] = vld1q_u32(b565[0] + x + 12);
1781     p0 = CalculateFilteredOutputPass1(sr01, ma, b);
1782     p1 = CalculateFilteredOutput<4>(sr11, ma[1], b[1]);
1783     const uint8x8_t d01 = SelfGuidedSingleMultiplier(sr01, p0, w0);
1784     const uint8x8_t d11 = SelfGuidedSingleMultiplier(sr11, p1, w0);
1785     vst1q_u8(dst + x, vcombine_u8(d00, d01));
1786     vst1q_u8(dst + stride + x, vcombine_u8(d10, d11));
1787     s[0][0] = s[0][1];
1788     s[1][0] = s[1][1];
1789     sq[0][1] = sq[0][3];
1790     sq[1][1] = sq[1][3];
1791     mas[0] = mas[1];
1792     bs[0] = bs[2];
1793     x += 16;
1794   } while (x < width);
1795 }
1796 
1797 inline void BoxFilterPass1LastRow(const uint8_t* const src,
1798                                   const uint8_t* const src0, const int width,
1799                                   const uint32_t scale, const int16_t w0,
1800                                   uint16_t* const sum5[5],
1801                                   uint32_t* const square_sum5[5],
1802                                   uint16_t* ma565, uint32_t* b565,
1803                                   uint8_t* const dst) {
1804   uint8x16_t s[2], mas[2];
1805   uint16x8_t sq[4], bs[4];
1806   s[0] = vld1q_u8(src0);
1807 
1808   BoxFilterPreProcess5LastRowLo(s, scale, sum5, square_sum5, sq, &mas[0],
1809                                 &bs[0]);
1810 
1811   int x = 0;
1812   do {
1813     uint16x8_t ma[2];
1814     uint8x16_t masx[3];
1815     uint32x4x2_t b[2];
1816     s[1] = vld1q_u8(src0 + x + 16);
1817 
1818     BoxFilterPreProcess5LastRow(s, x + 8, scale, sum5, square_sum5, sq + 1, mas,
1819                                 bs + 1);
1820     Prepare3_8<0>(mas, masx);
1821     ma[1] = Sum565<0>(masx);
1822     b[1] = Sum565W(bs);
1823     ma[0] = vld1q_u16(ma565);
1824     b[0].val[0] = vld1q_u32(b565 + 0);
1825     b[0].val[1] = vld1q_u32(b565 + 4);
1826     const uint8x16_t sr = vld1q_u8(src + x);
1827     const uint8x8_t sr0 = vget_low_u8(sr);
1828     const int16x8_t p0 = CalculateFilteredOutputPass1(sr0, ma, b);
1829     const uint8x8_t d0 = SelfGuidedSingleMultiplier(sr0, p0, w0);
1830 
1831     ma[1] = Sum565<8>(masx);
1832     b[1] = Sum565W(bs + 1);
1833     bs[0] = bs[2];
1834     const uint8x8_t sr1 = vget_high_u8(sr);
1835     ma[0] = vld1q_u16(ma565 + 8);
1836     b[0].val[0] = vld1q_u32(b565 + 8);
1837     b[0].val[1] = vld1q_u32(b565 + 12);
1838     const int16x8_t p1 = CalculateFilteredOutputPass1(sr1, ma, b);
1839     const uint8x8_t d1 = SelfGuidedSingleMultiplier(sr1, p1, w0);
1840     vst1q_u8(dst + x, vcombine_u8(d0, d1));
1841     s[0] = s[1];
1842     sq[1] = sq[3];
1843     mas[0] = mas[1];
1844     ma565 += 16;
1845     b565 += 16;
1846     x += 16;
1847   } while (x < width);
1848 }
1849 
1850 LIBGAV1_ALWAYS_INLINE void BoxFilterPass2(
1851     const uint8_t* const src, const uint8_t* const src0, const int width,
1852     const uint32_t scale, const int16_t w0, uint16_t* const sum3[3],
1853     uint32_t* const square_sum3[3], uint16_t* const ma343[3],
1854     uint16_t* const ma444[2], uint32_t* const b343[3], uint32_t* const b444[2],
1855     uint8_t* const dst) {
1856   const ptrdiff_t overread_in_bytes = kWideOverreadInBytesPass2 - width;
1857   uint8x16_t s[2], mas[2];
1858   uint16x8_t sq[4], bs[3];
1859   s[0] = vld1q_u8(src0);
1860 
1861   BoxFilterPreProcess3Lo(&s[0], scale, sum3, square_sum3, sq, &mas[0], &bs[0]);
1862 
1863   int x = 0;
1864   do {
1865     uint16x8_t ma[3];
1866     uint8x16_t ma3x[3];
1867     uint32x4x2_t b[3];
1868     s[1] = Load1QMsanU8(src0 + x + 16, x + 16 + overread_in_bytes);
1869     BoxFilterPreProcess3(s, x + 8, scale, sum3, square_sum3, sq + 1, mas,
1870                          bs + 1);
1871     Prepare3_8<0>(mas, ma3x);
1872     Store343_444<0>(ma3x, bs, x, &ma[2], &b[2], ma343[2], ma444[1], b343[2],
1873                     b444[1]);
1874     const uint8x16_t sr = vld1q_u8(src + x);
1875     const uint8x8_t sr0 = vget_low_u8(sr);
1876     ma[0] = vld1q_u16(ma343[0] + x);
1877     ma[1] = vld1q_u16(ma444[0] + x);
1878     b[0].val[0] = vld1q_u32(b343[0] + x + 0);
1879     b[0].val[1] = vld1q_u32(b343[0] + x + 4);
1880     b[1].val[0] = vld1q_u32(b444[0] + x + 0);
1881     b[1].val[1] = vld1q_u32(b444[0] + x + 4);
1882     const int16x8_t p0 = CalculateFilteredOutputPass2(sr0, ma, b);
1883     const uint8x8_t d0 = SelfGuidedSingleMultiplier(sr0, p0, w0);
1884 
1885     Store343_444<8>(ma3x, bs + 1, x + 8, &ma[2], &b[2], ma343[2], ma444[1],
1886                     b343[2], b444[1]);
1887     const uint8x8_t sr1 = vget_high_u8(sr);
1888     ma[0] = vld1q_u16(ma343[0] + x + 8);
1889     ma[1] = vld1q_u16(ma444[0] + x + 8);
1890     b[0].val[0] = vld1q_u32(b343[0] + x + 8);
1891     b[0].val[1] = vld1q_u32(b343[0] + x + 12);
1892     b[1].val[0] = vld1q_u32(b444[0] + x + 8);
1893     b[1].val[1] = vld1q_u32(b444[0] + x + 12);
1894     const int16x8_t p1 = CalculateFilteredOutputPass2(sr1, ma, b);
1895     const uint8x8_t d1 = SelfGuidedSingleMultiplier(sr1, p1, w0);
1896     vst1q_u8(dst + x, vcombine_u8(d0, d1));
1897     s[0] = s[1];
1898     sq[1] = sq[3];
1899     mas[0] = mas[1];
1900     bs[0] = bs[2];
1901     x += 16;
1902   } while (x < width);
1903 }
1904 
1905 LIBGAV1_ALWAYS_INLINE void BoxFilter(
1906     const uint8_t* const src, const uint8_t* const src0,
1907     const uint8_t* const src1, const ptrdiff_t stride, const int width,
1908     const uint16_t scales[2], const int16_t w0, const int16_t w2,
1909     uint16_t* const sum3[4], uint16_t* const sum5[5],
1910     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
1911     uint16_t* const ma343[4], uint16_t* const ma444[3],
1912     uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3],
1913     uint32_t* const b565[2], uint8_t* const dst) {
1914   const ptrdiff_t overread_in_bytes = kWideOverreadInBytesPass1 - width;
1915   uint8x16_t s[2][2], ma3[2][2], ma5[2];
1916   uint16x8_t sq[2][4], b3[2][3], b5[3];
1917   s[0][0] = vld1q_u8(src0);
1918   s[1][0] = vld1q_u8(src1);
1919 
1920   BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq,
1921                         ma3, b3, &ma5[0], &b5[0]);
1922 
1923   int x = 0;
1924   do {
1925     uint16x8_t ma[3][3];
1926     uint8x16_t ma3x[2][3], ma5x[3];
1927     uint32x4x2_t b[3][3];
1928     int16x8_t p[2][2];
1929     s[0][1] = Load1QMsanU8(src0 + x + 16, x + 16 + overread_in_bytes);
1930     s[1][1] = Load1QMsanU8(src1 + x + 16, x + 16 + overread_in_bytes);
1931     BoxFilterPreProcess(s, x + 8, scales, sum3, sum5, square_sum3, square_sum5,
1932                         sq, ma3, b3, ma5, b5 + 1);
1933     Prepare3_8<0>(ma3[0], ma3x[0]);
1934     Prepare3_8<0>(ma3[1], ma3x[1]);
1935     Store343_444<0>(ma3x[0], b3[0], x, &ma[1][2], &ma[2][1], &b[1][2], &b[2][1],
1936                     ma343[2], ma444[1], b343[2], b444[1]);
1937     Store343_444<0>(ma3x[1], b3[1], x, &ma[2][2], &b[2][2], ma343[3], ma444[2],
1938                     b343[3], b444[2]);
1939     Prepare3_8<0>(ma5, ma5x);
1940     ma[0][1] = Sum565<0>(ma5x);
1941     b[0][1] = Sum565W(b5);
1942     vst1q_u16(ma565[1] + x, ma[0][1]);
1943     vst1q_u32(b565[1] + x, b[0][1].val[0]);
1944     vst1q_u32(b565[1] + x + 4, b[0][1].val[1]);
1945     const uint8x16_t sr0 = vld1q_u8(src + x);
1946     const uint8x16_t sr1 = vld1q_u8(src + stride + x);
1947     const uint8x8_t sr00 = vget_low_u8(sr0);
1948     const uint8x8_t sr10 = vget_low_u8(sr1);
1949     ma[0][0] = vld1q_u16(ma565[0] + x);
1950     b[0][0].val[0] = vld1q_u32(b565[0] + x);
1951     b[0][0].val[1] = vld1q_u32(b565[0] + x + 4);
1952     p[0][0] = CalculateFilteredOutputPass1(sr00, ma[0], b[0]);
1953     p[1][0] = CalculateFilteredOutput<4>(sr10, ma[0][1], b[0][1]);
1954     ma[1][0] = vld1q_u16(ma343[0] + x);
1955     ma[1][1] = vld1q_u16(ma444[0] + x);
1956     b[1][0].val[0] = vld1q_u32(b343[0] + x);
1957     b[1][0].val[1] = vld1q_u32(b343[0] + x + 4);
1958     b[1][1].val[0] = vld1q_u32(b444[0] + x);
1959     b[1][1].val[1] = vld1q_u32(b444[0] + x + 4);
1960     p[0][1] = CalculateFilteredOutputPass2(sr00, ma[1], b[1]);
1961     ma[2][0] = vld1q_u16(ma343[1] + x);
1962     b[2][0].val[0] = vld1q_u32(b343[1] + x);
1963     b[2][0].val[1] = vld1q_u32(b343[1] + x + 4);
1964     p[1][1] = CalculateFilteredOutputPass2(sr10, ma[2], b[2]);
1965     const uint8x8_t d00 = SelfGuidedDoubleMultiplier(sr00, p[0], w0, w2);
1966     const uint8x8_t d10 = SelfGuidedDoubleMultiplier(sr10, p[1], w0, w2);
1967 
1968     Store343_444<8>(ma3x[0], b3[0] + 1, x + 8, &ma[1][2], &ma[2][1], &b[1][2],
1969                     &b[2][1], ma343[2], ma444[1], b343[2], b444[1]);
1970     Store343_444<8>(ma3x[1], b3[1] + 1, x + 8, &ma[2][2], &b[2][2], ma343[3],
1971                     ma444[2], b343[3], b444[2]);
1972     ma[0][1] = Sum565<8>(ma5x);
1973     b[0][1] = Sum565W(b5 + 1);
1974     vst1q_u16(ma565[1] + x + 8, ma[0][1]);
1975     vst1q_u32(b565[1] + x + 8, b[0][1].val[0]);
1976     vst1q_u32(b565[1] + x + 12, b[0][1].val[1]);
1977     b3[0][0] = b3[0][2];
1978     b3[1][0] = b3[1][2];
1979     b5[0] = b5[2];
1980     const uint8x8_t sr01 = vget_high_u8(sr0);
1981     const uint8x8_t sr11 = vget_high_u8(sr1);
1982     ma[0][0] = vld1q_u16(ma565[0] + x + 8);
1983     b[0][0].val[0] = vld1q_u32(b565[0] + x + 8);
1984     b[0][0].val[1] = vld1q_u32(b565[0] + x + 12);
1985     p[0][0] = CalculateFilteredOutputPass1(sr01, ma[0], b[0]);
1986     p[1][0] = CalculateFilteredOutput<4>(sr11, ma[0][1], b[0][1]);
1987     ma[1][0] = vld1q_u16(ma343[0] + x + 8);
1988     ma[1][1] = vld1q_u16(ma444[0] + x + 8);
1989     b[1][0].val[0] = vld1q_u32(b343[0] + x + 8);
1990     b[1][0].val[1] = vld1q_u32(b343[0] + x + 12);
1991     b[1][1].val[0] = vld1q_u32(b444[0] + x + 8);
1992     b[1][1].val[1] = vld1q_u32(b444[0] + x + 12);
1993     p[0][1] = CalculateFilteredOutputPass2(sr01, ma[1], b[1]);
1994     ma[2][0] = vld1q_u16(ma343[1] + x + 8);
1995     b[2][0].val[0] = vld1q_u32(b343[1] + x + 8);
1996     b[2][0].val[1] = vld1q_u32(b343[1] + x + 12);
1997     p[1][1] = CalculateFilteredOutputPass2(sr11, ma[2], b[2]);
1998     const uint8x8_t d01 = SelfGuidedDoubleMultiplier(sr01, p[0], w0, w2);
1999     const uint8x8_t d11 = SelfGuidedDoubleMultiplier(sr11, p[1], w0, w2);
2000     vst1q_u8(dst + x, vcombine_u8(d00, d01));
2001     vst1q_u8(dst + stride + x, vcombine_u8(d10, d11));
2002     s[0][0] = s[0][1];
2003     s[1][0] = s[1][1];
2004     sq[0][1] = sq[0][3];
2005     sq[1][1] = sq[1][3];
2006     ma3[0][0] = ma3[0][1];
2007     ma3[1][0] = ma3[1][1];
2008     ma5[0] = ma5[1];
2009     x += 16;
2010   } while (x < width);
2011 }
2012 
2013 inline void BoxFilterLastRow(
2014     const uint8_t* const src, const uint8_t* const src0, const int width,
2015     const uint16_t scales[2], const int16_t w0, const int16_t w2,
2016     uint16_t* const sum3[4], uint16_t* const sum5[5],
2017     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
2018     uint16_t* const ma343, uint16_t* const ma444, uint16_t* const ma565,
2019     uint32_t* const b343, uint32_t* const b444, uint32_t* const b565,
2020     uint8_t* const dst) {
2021   uint8x16_t s[2], ma3[2], ma5[2];
2022   uint16x8_t sq[4], ma[3], b3[3], b5[3];
2023   uint32x4x2_t b[3];
2024   s[0] = vld1q_u8(src0);
2025 
2026   BoxFilterPreProcessLastRowLo(s, scales, sum3, sum5, square_sum3, square_sum5,
2027                                sq, &ma3[0], &ma5[0], &b3[0], &b5[0]);
2028 
2029   int x = 0;
2030   do {
2031     uint8x16_t ma3x[3], ma5x[3];
2032     int16x8_t p[2];
2033     s[1] = vld1q_u8(src0 + x + 16);
2034 
2035     BoxFilterPreProcessLastRow(s, x + 8, scales, sum3, sum5, square_sum3,
2036                                square_sum5, sq + 1, ma3, ma5, &b3[1], &b5[1]);
2037     Prepare3_8<0>(ma5, ma5x);
2038     ma[1] = Sum565<0>(ma5x);
2039     b[1] = Sum565W(b5);
2040     Prepare3_8<0>(ma3, ma3x);
2041     ma[2] = Sum343<0>(ma3x);
2042     b[2] = Sum343W(b3);
2043     const uint8x16_t sr = vld1q_u8(src + x);
2044     const uint8x8_t sr0 = vget_low_u8(sr);
2045     ma[0] = vld1q_u16(ma565 + x);
2046     b[0].val[0] = vld1q_u32(b565 + x + 0);
2047     b[0].val[1] = vld1q_u32(b565 + x + 4);
2048     p[0] = CalculateFilteredOutputPass1(sr0, ma, b);
2049     ma[0] = vld1q_u16(ma343 + x);
2050     ma[1] = vld1q_u16(ma444 + x);
2051     b[0].val[0] = vld1q_u32(b343 + x + 0);
2052     b[0].val[1] = vld1q_u32(b343 + x + 4);
2053     b[1].val[0] = vld1q_u32(b444 + x + 0);
2054     b[1].val[1] = vld1q_u32(b444 + x + 4);
2055     p[1] = CalculateFilteredOutputPass2(sr0, ma, b);
2056     const uint8x8_t d0 = SelfGuidedDoubleMultiplier(sr0, p, w0, w2);
2057 
2058     ma[1] = Sum565<8>(ma5x);
2059     b[1] = Sum565W(b5 + 1);
2060     b5[0] = b5[2];
2061     ma[2] = Sum343<8>(ma3x);
2062     b[2] = Sum343W(b3 + 1);
2063     b3[0] = b3[2];
2064     const uint8x8_t sr1 = vget_high_u8(sr);
2065     ma[0] = vld1q_u16(ma565 + x + 8);
2066     b[0].val[0] = vld1q_u32(b565 + x + 8);
2067     b[0].val[1] = vld1q_u32(b565 + x + 12);
2068     p[0] = CalculateFilteredOutputPass1(sr1, ma, b);
2069     ma[0] = vld1q_u16(ma343 + x + 8);
2070     ma[1] = vld1q_u16(ma444 + x + 8);
2071     b[0].val[0] = vld1q_u32(b343 + x + 8);
2072     b[0].val[1] = vld1q_u32(b343 + x + 12);
2073     b[1].val[0] = vld1q_u32(b444 + x + 8);
2074     b[1].val[1] = vld1q_u32(b444 + x + 12);
2075     p[1] = CalculateFilteredOutputPass2(sr1, ma, b);
2076     const uint8x8_t d1 = SelfGuidedDoubleMultiplier(sr1, p, w0, w2);
2077     vst1q_u8(dst + x, vcombine_u8(d0, d1));
2078     s[0] = s[1];
2079     sq[1] = sq[3];
2080     ma3[0] = ma3[1];
2081     ma5[0] = ma5[1];
2082     x += 16;
2083   } while (x < width);
2084 }
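// BoxFilterProcess runs both SGR passes, consuming two source rows per
// iteration. Row sums and intermediates live in ring buffers inside
// |sgr_buffer|; the Circulate*/std::swap calls below rotate row pointers
// instead of copying rows.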
2085 
2086 LIBGAV1_ALWAYS_INLINE void BoxFilterProcess(
2087     const RestorationUnitInfo& restoration_info, const uint8_t* src,
2088     const ptrdiff_t stride, const uint8_t* const top_border,
2089     const ptrdiff_t top_border_stride, const uint8_t* bottom_border,
2090     const ptrdiff_t bottom_border_stride, const int width, const int height,
2091     SgrBuffer* const sgr_buffer, uint8_t* dst) {
2092   const auto temp_stride = Align<ptrdiff_t>(width, 16);
2093   const auto sum_width = Align<ptrdiff_t>(width + 8, 16);
2094   const ptrdiff_t sum_stride = temp_stride + 8;
2095   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
2096   const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index];  // < 2^12.
2097   const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
2098   const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
2099   const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
2100   uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2];
2101   uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2];
2102   sum3[0] = sgr_buffer->sum3;
2103   square_sum3[0] = sgr_buffer->square_sum3;
2104   ma343[0] = sgr_buffer->ma343;
2105   b343[0] = sgr_buffer->b343;
2106   for (int i = 1; i <= 3; ++i) {
2107     sum3[i] = sum3[i - 1] + sum_stride;
2108     square_sum3[i] = square_sum3[i - 1] + sum_stride;
2109     ma343[i] = ma343[i - 1] + temp_stride;
2110     b343[i] = b343[i - 1] + temp_stride;
2111   }
2112   sum5[0] = sgr_buffer->sum5;
2113   square_sum5[0] = sgr_buffer->square_sum5;
2114   for (int i = 1; i <= 4; ++i) {
2115     sum5[i] = sum5[i - 1] + sum_stride;
2116     square_sum5[i] = square_sum5[i - 1] + sum_stride;
2117   }
2118   ma444[0] = sgr_buffer->ma444;
2119   b444[0] = sgr_buffer->b444;
2120   for (int i = 1; i <= 2; ++i) {
2121     ma444[i] = ma444[i - 1] + temp_stride;
2122     b444[i] = b444[i - 1] + temp_stride;
2123   }
2124   ma565[0] = sgr_buffer->ma565;
2125   ma565[1] = ma565[0] + temp_stride;
2126   b565[0] = sgr_buffer->b565;
2127   b565[1] = b565[0] + temp_stride;
2128   assert(scales[0] != 0);
2129   assert(scales[1] != 0);
2130   BoxSum(top_border, top_border_stride, width, sum_stride, sum_width, sum3[0],
2131          sum5[1], square_sum3[0], square_sum5[1]);
2132   sum5[0] = sum5[1];
2133   square_sum5[0] = square_sum5[1];
2134   const uint8_t* const s = (height > 1) ? src + stride : bottom_border;
2135   BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3,
2136                          square_sum5, ma343, ma444[0], ma565[0], b343, b444[0],
2137                          b565[0]);
2138   sum5[0] = sgr_buffer->sum5;
2139   square_sum5[0] = sgr_buffer->square_sum5;
2140 
2141   for (int y = (height >> 1) - 1; y > 0; --y) {
2142     Circulate4PointersBy2<uint16_t>(sum3);
2143     Circulate4PointersBy2<uint32_t>(square_sum3);
2144     Circulate5PointersBy2<uint16_t>(sum5);
2145     Circulate5PointersBy2<uint32_t>(square_sum5);
2146     BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width,
2147               scales, w0, w2, sum3, sum5, square_sum3, square_sum5, ma343,
2148               ma444, ma565, b343, b444, b565, dst);
2149     src += 2 * stride;
2150     dst += 2 * stride;
2151     Circulate4PointersBy2<uint16_t>(ma343);
2152     Circulate4PointersBy2<uint32_t>(b343);
2153     std::swap(ma444[0], ma444[2]);
2154     std::swap(b444[0], b444[2]);
2155     std::swap(ma565[0], ma565[1]);
2156     std::swap(b565[0], b565[1]);
2157   }
2158 
2159   Circulate4PointersBy2<uint16_t>(sum3);
2160   Circulate4PointersBy2<uint32_t>(square_sum3);
2161   Circulate5PointersBy2<uint16_t>(sum5);
2162   Circulate5PointersBy2<uint32_t>(square_sum5);
2163   if ((height & 1) == 0 || height > 1) {
2164     const uint8_t* sr[2];
2165     if ((height & 1) == 0) {
2166       sr[0] = bottom_border;
2167       sr[1] = bottom_border + bottom_border_stride;
2168     } else {
2169       sr[0] = src + 2 * stride;
2170       sr[1] = bottom_border;
2171     }
2172     BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5,
2173               square_sum3, square_sum5, ma343, ma444, ma565, b343, b444, b565,
2174               dst);
2175   }
2176   if ((height & 1) != 0) {
2177     if (height > 1) {
2178       src += 2 * stride;
2179       dst += 2 * stride;
2180       Circulate4PointersBy2<uint16_t>(sum3);
2181       Circulate4PointersBy2<uint32_t>(square_sum3);
2182       Circulate5PointersBy2<uint16_t>(sum5);
2183       Circulate5PointersBy2<uint32_t>(square_sum5);
2184       Circulate4PointersBy2<uint16_t>(ma343);
2185       Circulate4PointersBy2<uint32_t>(b343);
2186       std::swap(ma444[0], ma444[2]);
2187       std::swap(b444[0], b444[2]);
2188       std::swap(ma565[0], ma565[1]);
2189       std::swap(b565[0], b565[1]);
2190     }
2191     BoxFilterLastRow(src + 3, bottom_border + bottom_border_stride, width,
2192                      scales, w0, w2, sum3, sum5, square_sum3, square_sum5,
2193                      ma343[0], ma444[0], ma565[0], b343[0], b444[0], b565[0],
2194                      dst);
2195   }
2196 }
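// Circulate5PointersBy2 is assumed to rotate the pointer array left by two,
// roughly (a sketch; the helper is defined elsewhere):
//   template <typename T>
//   void Circulate5PointersBy2(T* p[5]) { std::rotate(p, p + 2, p + 5); }
// so the two oldest row buffers become the destinations for the next rows.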
2197 
2198 inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info,
2199                                   const uint8_t* src, const ptrdiff_t stride,
2200                                   const uint8_t* const top_border,
2201                                   const ptrdiff_t top_border_stride,
2202                                   const uint8_t* bottom_border,
2203                                   const ptrdiff_t bottom_border_stride,
2204                                   const int width, const int height,
2205                                   SgrBuffer* const sgr_buffer, uint8_t* dst) {
2206   const auto temp_stride = Align<ptrdiff_t>(width, 16);
2207   const auto sum_width = Align<ptrdiff_t>(width + 8, 16);
2208   const ptrdiff_t sum_stride = temp_stride + 8;
2209   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
2210   const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0];  // < 2^12.
2211   const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
2212   uint16_t *sum5[5], *ma565[2];
2213   uint32_t *square_sum5[5], *b565[2];
2214   sum5[0] = sgr_buffer->sum5;
2215   square_sum5[0] = sgr_buffer->square_sum5;
2216   for (int i = 1; i <= 4; ++i) {
2217     sum5[i] = sum5[i - 1] + sum_stride;
2218     square_sum5[i] = square_sum5[i - 1] + sum_stride;
2219   }
2220   ma565[0] = sgr_buffer->ma565;
2221   ma565[1] = ma565[0] + temp_stride;
2222   b565[0] = sgr_buffer->b565;
2223   b565[1] = b565[0] + temp_stride;
2224   assert(scale != 0);
2225   BoxSum<5>(top_border, top_border_stride, width, sum_stride, sum_width,
2226             sum5[1], square_sum5[1]);
2227   sum5[0] = sum5[1];
2228   square_sum5[0] = square_sum5[1];
2229   const uint8_t* const s = (height > 1) ? src + stride : bottom_border;
2230   BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, ma565[0],
2231                           b565[0]);
2232   sum5[0] = sgr_buffer->sum5;
2233   square_sum5[0] = sgr_buffer->square_sum5;
2234 
2235   for (int y = (height >> 1) - 1; y > 0; --y) {
2236     Circulate5PointersBy2<uint16_t>(sum5);
2237     Circulate5PointersBy2<uint32_t>(square_sum5);
2238     BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5,
2239                    square_sum5, width, scale, w0, ma565, b565, dst);
2240     src += 2 * stride;
2241     dst += 2 * stride;
2242     std::swap(ma565[0], ma565[1]);
2243     std::swap(b565[0], b565[1]);
2244   }
2245 
2246   Circulate5PointersBy2<uint16_t>(sum5);
2247   Circulate5PointersBy2<uint32_t>(square_sum5);
2248   if ((height & 1) == 0 || height > 1) {
2249     const uint8_t* sr[2];
2250     if ((height & 1) == 0) {
2251       sr[0] = bottom_border;
2252       sr[1] = bottom_border + bottom_border_stride;
2253     } else {
2254       sr[0] = src + 2 * stride;
2255       sr[1] = bottom_border;
2256     }
2257     BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width,
2258                    scale, w0, ma565, b565, dst);
2259   }
2260   if ((height & 1) != 0) {
2261     if (height > 1) {
2262       src += 2 * stride;
2263       dst += 2 * stride;
2264       std::swap(ma565[0], ma565[1]);
2265       std::swap(b565[0], b565[1]);
2266       Circulate5PointersBy2<uint16_t>(sum5);
2267       Circulate5PointersBy2<uint32_t>(square_sum5);
2268     }
2269     BoxFilterPass1LastRow(src + 3, bottom_border + bottom_border_stride, width,
2270                           scale, w0, sum5, square_sum5, ma565[0], b565[0], dst);
2271   }
2272 }
2273 
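// Pass-2-only case: multiplier[0] is zero (asserted below), so the remaining
// weight collapses to w0 = (1 << kSgrProjPrecisionBits) - w1 and a single
// SelfGuidedSingleMultiplier combine suffices.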
2274 inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info,
2275                                   const uint8_t* src, const ptrdiff_t stride,
2276                                   const uint8_t* const top_border,
2277                                   const ptrdiff_t top_border_stride,
2278                                   const uint8_t* bottom_border,
2279                                   const ptrdiff_t bottom_border_stride,
2280                                   const int width, const int height,
2281                                   SgrBuffer* const sgr_buffer, uint8_t* dst) {
2282   assert(restoration_info.sgr_proj_info.multiplier[0] == 0);
2283   const auto temp_stride = Align<ptrdiff_t>(width, 16);
2284   const auto sum_width = Align<ptrdiff_t>(width + 8, 16);
2285   const ptrdiff_t sum_stride = temp_stride + 8;
2286   const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
2287   const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1;
  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
  const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1];  // < 2^12.
  uint16_t *sum3[3], *ma343[3], *ma444[2];
  uint32_t *square_sum3[3], *b343[3], *b444[2];
  sum3[0] = sgr_buffer->sum3;
  square_sum3[0] = sgr_buffer->square_sum3;
  ma343[0] = sgr_buffer->ma343;
  b343[0] = sgr_buffer->b343;
  for (int i = 1; i <= 2; ++i) {
    sum3[i] = sum3[i - 1] + sum_stride;
    square_sum3[i] = square_sum3[i - 1] + sum_stride;
    ma343[i] = ma343[i - 1] + temp_stride;
    b343[i] = b343[i - 1] + temp_stride;
  }
  ma444[0] = sgr_buffer->ma444;
  ma444[1] = ma444[0] + temp_stride;
  b444[0] = sgr_buffer->b444;
  b444[1] = b444[0] + temp_stride;
  assert(scale != 0);
  BoxSum<3>(top_border, top_border_stride, width, sum_stride, sum_width,
            sum3[0], square_sum3[0]);
  BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, ma343[0],
                                 nullptr, b343[0], nullptr);
  Circulate3PointersBy1<uint16_t>(sum3);
  Circulate3PointersBy1<uint32_t>(square_sum3);
  const uint8_t* s;
  if (height > 1) {
    s = src + stride;
  } else {
    s = bottom_border;
    bottom_border += bottom_border_stride;
  }
  BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, ma343[1],
                                ma444[0], b343[1], b444[0]);

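  // Main loop: pass 2 works on a 3x3 box and emits one output row per
  // iteration. The three-deep ma343/b343 and two-deep ma444/b444 buffers
  // rotate so each output row can combine intermediates from neighboring
  // rows.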
  for (int y = height - 2; y > 0; --y) {
    Circulate3PointersBy1<uint16_t>(sum3);
    Circulate3PointersBy1<uint32_t>(square_sum3);
    BoxFilterPass2(src + 2, src + 2 * stride, width, scale, w0, sum3,
                   square_sum3, ma343, ma444, b343, b444, dst);
    src += stride;
    dst += stride;
    Circulate3PointersBy1<uint16_t>(ma343);
    Circulate3PointersBy1<uint32_t>(b343);
    std::swap(ma444[0], ma444[1]);
    std::swap(b444[0], b444[1]);
  }

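  // Filter the last std::min(height, 2) rows, pulling the trailing input row
  // from |bottom_border|.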
  src += 2;
  int y = std::min(height, 2);
  do {
    Circulate3PointersBy1<uint16_t>(sum3);
    Circulate3PointersBy1<uint32_t>(square_sum3);
    BoxFilterPass2(src, bottom_border, width, scale, w0, sum3, square_sum3,
                   ma343, ma444, b343, b444, dst);
    src += stride;
    dst += stride;
    bottom_border += bottom_border_stride;
    Circulate3PointersBy1<uint16_t>(ma343);
    Circulate3PointersBy1<uint32_t>(b343);
    std::swap(ma444[0], ma444[1]);
    std::swap(b444[0], b444[1]);
  } while (--y != 0);
}

// If |width| is not a multiple of 8, up to 7 extra pixels are written to
// |dest| at the end of each row. It is safe to overwrite the output as it
// will not be part of the visible frame.
void SelfGuidedFilter_NEON(
    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
    const void* LIBGAV1_RESTRICT const top_border,
    const ptrdiff_t top_border_stride,
    const void* LIBGAV1_RESTRICT const bottom_border,
    const ptrdiff_t bottom_border_stride, const int width, const int height,
    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
    void* LIBGAV1_RESTRICT const dest) {
  const int index = restoration_info.sgr_proj_info.index;
  const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
  const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
  const auto* const src = static_cast<const uint8_t*>(source);
  const auto* top = static_cast<const uint8_t*>(top_border);
  const auto* bottom = static_cast<const uint8_t*>(bottom_border);
  auto* const dst = static_cast<uint8_t*>(dest);
  SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer;

#if LIBGAV1_MSAN
  // Initialize to prevent msan warnings when intermediate overreads occur.
  memset(sgr_buffer, 0, sizeof(SgrBuffer));
#endif

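  // The source and border pointers are pre-adjusted to the left (by 3 for the
  // 5x5 pass, by 2 for the 3x3 pass) so that the horizontal box sums also
  // cover the pixels to the left of the restoration unit.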
  if (radius_pass_1 == 0) {
    // |radius_pass_0| and |radius_pass_1| cannot both be 0.
    assert(radius_pass_0 != 0);
    BoxFilterProcessPass1(restoration_info, src - 3, stride, top - 3,
                          top_border_stride, bottom - 3, bottom_border_stride,
                          width, height, sgr_buffer, dst);
  } else if (radius_pass_0 == 0) {
    BoxFilterProcessPass2(restoration_info, src - 2, stride, top - 2,
                          top_border_stride, bottom - 2, bottom_border_stride,
                          width, height, sgr_buffer, dst);
  } else {
    BoxFilterProcess(restoration_info, src - 3, stride, top - 3,
                     top_border_stride, bottom - 3, bottom_border_stride, width,
                     height, sgr_buffer, dst);
  }
}

void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
  dsp->loop_restorations[0] = WienerFilter_NEON;
  dsp->loop_restorations[1] = SelfGuidedFilter_NEON;
}

}  // namespace
}  // namespace low_bitdepth

void LoopRestorationInit_NEON() { low_bitdepth::Init8bpp(); }

}  // namespace dsp
}  // namespace libgav1

#else   // !LIBGAV1_ENABLE_NEON
namespace libgav1 {
namespace dsp {

void LoopRestorationInit_NEON() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_ENABLE_NEON