xref: /aosp_15_r20/external/libgav1/src/dsp/arm/loop_restoration_10bit_neon.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
// Copyright 2021 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/loop_restoration.h"
#include "src/utils/cpu.h"

#if LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10
#include <arm_neon.h>

#include <algorithm>
#include <cassert>
#include <cstdint>

#include "src/dsp/arm/common_neon.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/utils/common.h"
#include "src/utils/compiler_attributes.h"
#include "src/utils/constants.h"

namespace libgav1 {
namespace dsp {
namespace {

//------------------------------------------------------------------------------
// Wiener

// Make a local copy of the coefficients to help the compiler see that they
// have no overlap with other buffers. Using the 'const' keyword alone is not
// enough. In practice the compiler does not emit a copy here, since there are
// enough registers in this case.
inline void PopulateWienerCoefficients(
    const RestorationUnitInfo& restoration_info, const int direction,
    int16_t filter[4]) {
  for (int i = 0; i < 4; ++i) {
    filter[i] = restoration_info.wiener_info.filter[direction][i];
  }
}

inline int32x4x2_t WienerHorizontal2(const uint16x8_t s0, const uint16x8_t s1,
                                     const int16_t filter,
                                     const int32x4x2_t sum) {
  const int16x8_t ss = vreinterpretq_s16_u16(vaddq_u16(s0, s1));
  int32x4x2_t res;
  res.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(ss), filter);
  res.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(ss), filter);
  return res;
}

inline void WienerHorizontalSum(const uint16x8_t s[3], const int16_t filter[4],
                                int32x4x2_t sum, int16_t* const wiener_buffer) {
  constexpr int offset =
      1 << (kBitdepth10 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
  constexpr int limit = (offset << 2) - 1;
  const int16x8_t s_0_2 = vreinterpretq_s16_u16(vaddq_u16(s[0], s[2]));
  const int16x8_t s_1 = vreinterpretq_s16_u16(s[1]);
  int16x4x2_t sum16;
  sum.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(s_0_2), filter[2]);
  sum.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(s_1), filter[3]);
  sum16.val[0] = vqshrn_n_s32(sum.val[0], kInterRoundBitsHorizontal);
  sum16.val[0] = vmax_s16(sum16.val[0], vdup_n_s16(-offset));
  sum16.val[0] = vmin_s16(sum16.val[0], vdup_n_s16(limit - offset));
  vst1_s16(wiener_buffer, sum16.val[0]);
  sum.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(s_0_2), filter[2]);
  sum.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(s_1), filter[3]);
  sum16.val[1] = vqshrn_n_s32(sum.val[1], kInterRoundBitsHorizontal);
  sum16.val[1] = vmax_s16(sum16.val[1], vdup_n_s16(-offset));
  sum16.val[1] = vmin_s16(sum16.val[1], vdup_n_s16(limit - offset));
  vst1_s16(wiener_buffer + 4, sum16.val[1]);
}
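
// With kBitdepth10 == 10, kWienerFilterBits == 7 and, for this bitdepth,
// kInterRoundBitsHorizontal == 3, the clamp above works out to
//   offset = 1 << 13 = 8192 and limit = (8192 << 2) - 1 = 32767,
// so the stored intermediates lie in [-8192, 24575] and fit int16_t.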

inline void WienerHorizontalTap7(const uint16_t* src,
                                 const ptrdiff_t src_stride,
                                 const ptrdiff_t wiener_stride,
                                 const ptrdiff_t width, const int height,
                                 const int16_t filter[4],
                                 int16_t** const wiener_buffer) {
  const ptrdiff_t src_width =
      width + ((kRestorationHorizontalBorder - 1) * sizeof(*src));
  for (int y = height; y != 0; --y) {
    const uint16_t* src_ptr = src;
    uint16x8_t s[8];
    s[0] = vld1q_u16(src_ptr);
    ptrdiff_t x = wiener_stride;
    ptrdiff_t valid_bytes = src_width * 2;
    do {
      src_ptr += 8;
      valid_bytes -= 16;
      s[7] = Load1QMsanU16(src_ptr, 16 - valid_bytes);
      s[1] = vextq_u16(s[0], s[7], 1);
      s[2] = vextq_u16(s[0], s[7], 2);
      s[3] = vextq_u16(s[0], s[7], 3);
      s[4] = vextq_u16(s[0], s[7], 4);
      s[5] = vextq_u16(s[0], s[7], 5);
      s[6] = vextq_u16(s[0], s[7], 6);
      int32x4x2_t sum;
      sum.val[0] = sum.val[1] =
          vdupq_n_s32(1 << (kInterRoundBitsHorizontal - 1));
      sum = WienerHorizontal2(s[0], s[6], filter[0], sum);
      sum = WienerHorizontal2(s[1], s[5], filter[1], sum);
      WienerHorizontalSum(s + 2, filter, sum, *wiener_buffer);
      s[0] = s[7];
      *wiener_buffer += 8;
      x -= 8;
    } while (x != 0);
    src += src_stride;
  }
}
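
// WienerHorizontalTap7() above keeps the sliding window in registers: only
// s[7] is freshly loaded each iteration, s[0] carries over, and s[1]..s[6]
// are derived with vextq_u16(), so each batch of 8 outputs costs a single
// vector load. The symmetric taps are folded first (s[0]+s[6], s[1]+s[5])
// before WienerHorizontalSum() applies the center taps.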

inline void WienerHorizontalTap5(const uint16_t* src,
                                 const ptrdiff_t src_stride,
                                 const ptrdiff_t wiener_stride,
                                 const ptrdiff_t width, const int height,
                                 const int16_t filter[4],
                                 int16_t** const wiener_buffer) {
  const ptrdiff_t src_width =
      width + ((kRestorationHorizontalBorder - 1) * sizeof(*src));
  for (int y = height; y != 0; --y) {
    const uint16_t* src_ptr = src;
    uint16x8_t s[6];
    s[0] = vld1q_u16(src_ptr);
    ptrdiff_t x = wiener_stride;
    ptrdiff_t valid_bytes = src_width * 2;
    do {
      src_ptr += 8;
      valid_bytes -= 16;
      s[5] = Load1QMsanU16(src_ptr, 16 - valid_bytes);
      s[1] = vextq_u16(s[0], s[5], 1);
      s[2] = vextq_u16(s[0], s[5], 2);
      s[3] = vextq_u16(s[0], s[5], 3);
      s[4] = vextq_u16(s[0], s[5], 4);

      int32x4x2_t sum;
      sum.val[0] = sum.val[1] =
          vdupq_n_s32(1 << (kInterRoundBitsHorizontal - 1));
      sum = WienerHorizontal2(s[0], s[4], filter[1], sum);
      WienerHorizontalSum(s + 1, filter, sum, *wiener_buffer);
      s[0] = s[5];
      *wiener_buffer += 8;
      x -= 8;
    } while (x != 0);
    src += src_stride;
  }
}

inline void WienerHorizontalTap3(const uint16_t* src,
                                 const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const int16_t filter[4],
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    const uint16_t* src_ptr = src;
    uint16x8_t s[3];
    ptrdiff_t x = width;
    do {
      s[0] = vld1q_u16(src_ptr);
      s[1] = vld1q_u16(src_ptr + 1);
      s[2] = vld1q_u16(src_ptr + 2);

      int32x4x2_t sum;
      sum.val[0] = sum.val[1] =
          vdupq_n_s32(1 << (kInterRoundBitsHorizontal - 1));
      WienerHorizontalSum(s, filter, sum, *wiener_buffer);
      src_ptr += 8;
      *wiener_buffer += 8;
      x -= 8;
    } while (x != 0);
    src += src_stride;
  }
}

inline void WienerHorizontalTap1(const uint16_t* src,
                                 const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    ptrdiff_t x = 0;
    do {
      const uint16x8_t s = vld1q_u16(src + x);
      const int16x8_t d = vreinterpretq_s16_u16(vshlq_n_u16(s, 4));
      vst1q_s16(*wiener_buffer + x, d);
      x += 8;
    } while (x < width);
    src += src_stride;
    *wiener_buffer += width;
  }
}
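
// The single-tap case reduces to a shift: the Wiener taps sum to
// 1 << kWienerFilterBits (128), so filtering with only the center tap is a
// multiply by 128 followed by >> kInterRoundBitsHorizontal, i.e. a net << 4
// for 10-bit, which is all WienerHorizontalTap1() needs to store.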

inline int32x4x2_t WienerVertical2(const int16x8_t a0, const int16x8_t a1,
                                   const int16_t filter,
                                   const int32x4x2_t sum) {
  int32x4x2_t d;
  d.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(a0), filter);
  d.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(a0), filter);
  d.val[0] = vmlal_n_s16(d.val[0], vget_low_s16(a1), filter);
  d.val[1] = vmlal_n_s16(d.val[1], vget_high_s16(a1), filter);
  return d;
}

inline uint16x8_t WienerVertical(const int16x8_t a[3], const int16_t filter[4],
                                 const int32x4x2_t sum) {
  int32x4x2_t d = WienerVertical2(a[0], a[2], filter[2], sum);
  d.val[0] = vmlal_n_s16(d.val[0], vget_low_s16(a[1]), filter[3]);
  d.val[1] = vmlal_n_s16(d.val[1], vget_high_s16(a[1]), filter[3]);
  const uint16x4_t sum_lo_16 = vqrshrun_n_s32(d.val[0], 11);
  const uint16x4_t sum_hi_16 = vqrshrun_n_s32(d.val[1], 11);
  return vcombine_u16(sum_lo_16, sum_hi_16);
}
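
// The rounding shift by 11 here is kInterRoundBitsVertical, and
// vqrshrun_n_s32() additionally saturates the narrowed result to unsigned
// 16 bits; the final clamp to the 10-bit maximum happens at the store sites.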

inline uint16x8_t WienerVerticalTap7Kernel(const int16_t* const wiener_buffer,
                                           const ptrdiff_t wiener_stride,
                                           const int16_t filter[4],
                                           int16x8_t a[7]) {
  int32x4x2_t sum;
  a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
  a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
  a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride);
  a[6] = vld1q_s16(wiener_buffer + 6 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  sum = WienerVertical2(a[0], a[6], filter[0], sum);
  sum = WienerVertical2(a[1], a[5], filter[1], sum);
  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
  a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride);
  return WienerVertical(a + 2, filter, sum);
}

inline uint16x8x2_t WienerVerticalTap7Kernel2(
    const int16_t* const wiener_buffer, const ptrdiff_t wiener_stride,
    const int16_t filter[4]) {
  int16x8_t a[8];
  int32x4x2_t sum;
  uint16x8x2_t d;
  d.val[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a);
  a[7] = vld1q_s16(wiener_buffer + 7 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  sum = WienerVertical2(a[1], a[7], filter[0], sum);
  sum = WienerVertical2(a[2], a[6], filter[1], sum);
  d.val[1] = WienerVertical(a + 3, filter, sum);
  return d;
}

inline void WienerVerticalTap7(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t filter[4], uint16_t* dst,
                               const ptrdiff_t dst_stride) {
  const uint16x8_t v_max_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);
  for (int y = height >> 1; y != 0; --y) {
    uint16_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      uint16x8x2_t d[2];
      d[0] = WienerVerticalTap7Kernel2(wiener_buffer + 0, width, filter);
      d[1] = WienerVerticalTap7Kernel2(wiener_buffer + 8, width, filter);
      vst1q_u16(dst_ptr, vminq_u16(d[0].val[0], v_max_bitdepth));
      vst1q_u16(dst_ptr + 8, vminq_u16(d[1].val[0], v_max_bitdepth));
      vst1q_u16(dst_ptr + dst_stride, vminq_u16(d[0].val[1], v_max_bitdepth));
      vst1q_u16(dst_ptr + 8 + dst_stride,
                vminq_u16(d[1].val[1], v_max_bitdepth));
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = width;
    do {
      int16x8_t a[7];
      const uint16x8_t d0 =
          WienerVerticalTap7Kernel(wiener_buffer + 0, width, filter, a);
      const uint16x8_t d1 =
          WienerVerticalTap7Kernel(wiener_buffer + 8, width, filter, a);
      vst1q_u16(dst, vminq_u16(d0, v_max_bitdepth));
      vst1q_u16(dst + 8, vminq_u16(d1, v_max_bitdepth));
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}

inline uint16x8_t WienerVerticalTap5Kernel(const int16_t* const wiener_buffer,
                                           const ptrdiff_t wiener_stride,
                                           const int16_t filter[4],
                                           int16x8_t a[5]) {
  a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
  a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
  a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride);
  int32x4x2_t sum;
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  sum = WienerVertical2(a[0], a[4], filter[1], sum);
  return WienerVertical(a + 1, filter, sum);
}

inline uint16x8x2_t WienerVerticalTap5Kernel2(
    const int16_t* const wiener_buffer, const ptrdiff_t wiener_stride,
    const int16_t filter[4]) {
  int16x8_t a[6];
  int32x4x2_t sum;
  uint16x8x2_t d;
  d.val[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a);
  a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  sum = WienerVertical2(a[1], a[5], filter[1], sum);
  d.val[1] = WienerVertical(a + 2, filter, sum);
  return d;
}

inline void WienerVerticalTap5(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t filter[4], uint16_t* dst,
                               const ptrdiff_t dst_stride) {
  const uint16x8_t v_max_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);
  for (int y = height >> 1; y != 0; --y) {
    uint16_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      uint16x8x2_t d[2];
      d[0] = WienerVerticalTap5Kernel2(wiener_buffer + 0, width, filter);
      d[1] = WienerVerticalTap5Kernel2(wiener_buffer + 8, width, filter);
      vst1q_u16(dst_ptr, vminq_u16(d[0].val[0], v_max_bitdepth));
      vst1q_u16(dst_ptr + 8, vminq_u16(d[1].val[0], v_max_bitdepth));
      vst1q_u16(dst_ptr + dst_stride, vminq_u16(d[0].val[1], v_max_bitdepth));
      vst1q_u16(dst_ptr + 8 + dst_stride,
                vminq_u16(d[1].val[1], v_max_bitdepth));
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = width;
    do {
      int16x8_t a[5];
      const uint16x8_t d0 =
          WienerVerticalTap5Kernel(wiener_buffer + 0, width, filter, a);
      const uint16x8_t d1 =
          WienerVerticalTap5Kernel(wiener_buffer + 8, width, filter, a);
      vst1q_u16(dst, vminq_u16(d0, v_max_bitdepth));
      vst1q_u16(dst + 8, vminq_u16(d1, v_max_bitdepth));
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}

inline uint16x8_t WienerVerticalTap3Kernel(const int16_t* const wiener_buffer,
                                           const ptrdiff_t wiener_stride,
                                           const int16_t filter[4],
                                           int16x8_t a[3]) {
  a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
  a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
  int32x4x2_t sum;
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  return WienerVertical(a, filter, sum);
}

inline uint16x8x2_t WienerVerticalTap3Kernel2(
    const int16_t* const wiener_buffer, const ptrdiff_t wiener_stride,
    const int16_t filter[4]) {
  int16x8_t a[4];
  int32x4x2_t sum;
  uint16x8x2_t d;
  d.val[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a);
  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  d.val[1] = WienerVertical(a + 1, filter, sum);
  return d;
}

inline void WienerVerticalTap3(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t filter[4], uint16_t* dst,
                               const ptrdiff_t dst_stride) {
  const uint16x8_t v_max_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);

  for (int y = height >> 1; y != 0; --y) {
    uint16_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      uint16x8x2_t d[2];
      d[0] = WienerVerticalTap3Kernel2(wiener_buffer + 0, width, filter);
      d[1] = WienerVerticalTap3Kernel2(wiener_buffer + 8, width, filter);

      vst1q_u16(dst_ptr, vminq_u16(d[0].val[0], v_max_bitdepth));
      vst1q_u16(dst_ptr + 8, vminq_u16(d[1].val[0], v_max_bitdepth));
      vst1q_u16(dst_ptr + dst_stride, vminq_u16(d[0].val[1], v_max_bitdepth));
      vst1q_u16(dst_ptr + 8 + dst_stride,
                vminq_u16(d[1].val[1], v_max_bitdepth));

      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = width;
    do {
      int16x8_t a[3];
      const uint16x8_t d0 =
          WienerVerticalTap3Kernel(wiener_buffer + 0, width, filter, a);
      const uint16x8_t d1 =
          WienerVerticalTap3Kernel(wiener_buffer + 8, width, filter, a);
      vst1q_u16(dst, vminq_u16(d0, v_max_bitdepth));
      vst1q_u16(dst + 8, vminq_u16(d1, v_max_bitdepth));
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}

inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer,
                                     uint16_t* const dst) {
  const uint16x8_t v_max_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);
  const int16x8_t a0 = vld1q_s16(wiener_buffer + 0);
  const int16x8_t a1 = vld1q_s16(wiener_buffer + 8);
  const int16x8_t d0 = vrshrq_n_s16(a0, 4);
  const int16x8_t d1 = vrshrq_n_s16(a1, 4);
  vst1q_u16(dst, vminq_u16(vreinterpretq_u16_s16(vmaxq_s16(d0, vdupq_n_s16(0))),
                           v_max_bitdepth));
  vst1q_u16(dst + 8,
            vminq_u16(vreinterpretq_u16_s16(vmaxq_s16(d1, vdupq_n_s16(0))),
                      v_max_bitdepth));
}
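
// As in the horizontal pass, the vertical single-tap filter degenerates to a
// shift: a multiply by the tap sum 128 followed by the rounding shift of 11
// is a net rounding shift right by 4, after which the result is clamped to
// [0, (1 << kBitdepth10) - 1].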

inline void WienerVerticalTap1(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               uint16_t* dst, const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y != 0; --y) {
    uint16_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      WienerVerticalTap1Kernel(wiener_buffer, dst_ptr);
      WienerVerticalTap1Kernel(wiener_buffer + width, dst_ptr + dst_stride);
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = width;
    do {
      WienerVerticalTap1Kernel(wiener_buffer, dst);
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}

// For width 16 and up, store the horizontal results, and then do the vertical
// filter row by row. This is faster than doing it column by column when
// considering cache issues.
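
// The flow below: copy the filter taps locally, run the matching horizontal
// tap kernel over the top border rows, the unit itself and the bottom border
// rows into |wiener_buffer|, then run the matching vertical tap kernel from
// that buffer into |dest|. The number of leading zero coefficients selects
// the 7/5/3/1-tap kernel independently for each direction.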
void WienerFilter_NEON(
    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
    const void* LIBGAV1_RESTRICT const top_border,
    const ptrdiff_t top_border_stride,
    const void* LIBGAV1_RESTRICT const bottom_border,
    const ptrdiff_t bottom_border_stride, const int width, const int height,
    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
    void* LIBGAV1_RESTRICT const dest) {
  const int16_t* const number_leading_zero_coefficients =
      restoration_info.wiener_info.number_leading_zero_coefficients;
  const int number_rows_to_skip = std::max(
      static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]),
      1);
  const ptrdiff_t wiener_stride = Align(width, 16);
  int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer;
  // The values are saturated to 13 bits before storing.
  int16_t* wiener_buffer_horizontal =
      wiener_buffer_vertical + number_rows_to_skip * wiener_stride;
  int16_t filter_horizontal[(kWienerFilterTaps + 1) / 2];
  int16_t filter_vertical[(kWienerFilterTaps + 1) / 2];
  PopulateWienerCoefficients(restoration_info, WienerInfo::kHorizontal,
                             filter_horizontal);
  PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical,
                             filter_vertical);
  // Horizontal filtering.
  const int height_horizontal =
      height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip;
  const int height_extra = (height_horizontal - height) >> 1;
  assert(height_extra <= 2);
  const auto* const src = static_cast<const uint16_t*>(source);
  const auto* const top = static_cast<const uint16_t*>(top_border);
  const auto* const bottom = static_cast<const uint16_t*>(bottom_border);
  if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) {
    WienerHorizontalTap7(top + (2 - height_extra) * top_border_stride - 3,
                         top_border_stride, wiener_stride, width, height_extra,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap7(src - 3, stride, wiener_stride, width, height,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap7(bottom - 3, bottom_border_stride, wiener_stride, width,
                         height_extra, filter_horizontal,
                         &wiener_buffer_horizontal);
  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) {
    WienerHorizontalTap5(top + (2 - height_extra) * top_border_stride - 2,
                         top_border_stride, wiener_stride, width, height_extra,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap5(src - 2, stride, wiener_stride, width, height,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap5(bottom - 2, bottom_border_stride, wiener_stride, width,
                         height_extra, filter_horizontal,
                         &wiener_buffer_horizontal);
  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) {
    WienerHorizontalTap3(top + (2 - height_extra) * top_border_stride - 1,
                         top_border_stride, wiener_stride, height_extra,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap3(src - 1, stride, wiener_stride, height,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap3(bottom - 1, bottom_border_stride, wiener_stride,
                         height_extra, filter_horizontal,
                         &wiener_buffer_horizontal);
  } else {
    assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3);
    WienerHorizontalTap1(top + (2 - height_extra) * top_border_stride,
                         top_border_stride, wiener_stride, height_extra,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap1(src, stride, wiener_stride, height,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap1(bottom, bottom_border_stride, wiener_stride,
                         height_extra, &wiener_buffer_horizontal);
  }

  // Vertical filtering.
  auto* dst = static_cast<uint16_t*>(dest);
  if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) {
    // Because the top row of |source| is a duplicate of the second row, and
    // the bottom row of |source| is a duplicate of the row above it, we can
    // duplicate the top and bottom rows of |wiener_buffer| accordingly.
    memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride,
           sizeof(*wiener_buffer_horizontal) * wiener_stride);
    memcpy(restoration_buffer->wiener_buffer,
           restoration_buffer->wiener_buffer + wiener_stride,
           sizeof(*restoration_buffer->wiener_buffer) * wiener_stride);
    WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height,
                       filter_vertical, dst, stride);
  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) {
    WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride,
                       height, filter_vertical, dst, stride);
  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) {
    WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride,
                       wiener_stride, height, filter_vertical, dst, stride);
  } else {
    assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3);
    WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride,
                       wiener_stride, height, dst, stride);
  }
}

//------------------------------------------------------------------------------
// SGR

// SIMD overreads 8 - (width % 8) - 2 * padding pixels, where padding is 3 for
// Pass 1 and 2 for Pass 2.
constexpr int kOverreadInBytesPass1 = 4;
constexpr int kOverreadInBytesPass2 = 8;
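
// For example, when width is a multiple of 8 the worst-case overread is
//   Pass 1: 8 - 0 - 2 * 3 = 2 pixels = 4 bytes,
//   Pass 2: 8 - 0 - 2 * 2 = 4 pixels = 8 bytes,
// which is where the two constants come from. They are used by the
// Load1QMsan* wrappers below to mark how many trailing bytes of each load
// may be uninitialized.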

inline void LoadAligned16x2U16(const uint16_t* const src[2], const ptrdiff_t x,
                               uint16x8_t dst[2]) {
  dst[0] = vld1q_u16(src[0] + x);
  dst[1] = vld1q_u16(src[1] + x);
}

inline void LoadAligned16x2U16Msan(const uint16_t* const src[2],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   uint16x8_t dst[2]) {
  dst[0] = Load1QMsanU16(src[0] + x, sizeof(**src) * (x + 8 - border));
  dst[1] = Load1QMsanU16(src[1] + x, sizeof(**src) * (x + 8 - border));
}

inline void LoadAligned16x3U16(const uint16_t* const src[3], const ptrdiff_t x,
                               uint16x8_t dst[3]) {
  dst[0] = vld1q_u16(src[0] + x);
  dst[1] = vld1q_u16(src[1] + x);
  dst[2] = vld1q_u16(src[2] + x);
}

inline void LoadAligned16x3U16Msan(const uint16_t* const src[3],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   uint16x8_t dst[3]) {
  dst[0] = Load1QMsanU16(src[0] + x, sizeof(**src) * (x + 8 - border));
  dst[1] = Load1QMsanU16(src[1] + x, sizeof(**src) * (x + 8 - border));
  dst[2] = Load1QMsanU16(src[2] + x, sizeof(**src) * (x + 8 - border));
}

inline void LoadAligned32U32(const uint32_t* const src, uint32x4_t dst[2]) {
  dst[0] = vld1q_u32(src + 0);
  dst[1] = vld1q_u32(src + 4);
}

inline void LoadAligned32U32Msan(const uint32_t* const src, const ptrdiff_t x,
                                 const ptrdiff_t border, uint32x4_t dst[2]) {
  dst[0] = Load1QMsanU32(src + x + 0, sizeof(*src) * (x + 4 - border));
  dst[1] = Load1QMsanU32(src + x + 4, sizeof(*src) * (x + 8 - border));
}

inline void LoadAligned32x2U32(const uint32_t* const src[2], const ptrdiff_t x,
                               uint32x4_t dst[2][2]) {
  LoadAligned32U32(src[0] + x, dst[0]);
  LoadAligned32U32(src[1] + x, dst[1]);
}

inline void LoadAligned32x2U32Msan(const uint32_t* const src[2],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   uint32x4_t dst[2][2]) {
  LoadAligned32U32Msan(src[0], x, border, dst[0]);
  LoadAligned32U32Msan(src[1], x, border, dst[1]);
}

inline void LoadAligned32x3U32(const uint32_t* const src[3], const ptrdiff_t x,
                               uint32x4_t dst[3][2]) {
  LoadAligned32U32(src[0] + x, dst[0]);
  LoadAligned32U32(src[1] + x, dst[1]);
  LoadAligned32U32(src[2] + x, dst[2]);
}

inline void LoadAligned32x3U32Msan(const uint32_t* const src[3],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   uint32x4_t dst[3][2]) {
  LoadAligned32U32Msan(src[0], x, border, dst[0]);
  LoadAligned32U32Msan(src[1], x, border, dst[1]);
  LoadAligned32U32Msan(src[2], x, border, dst[2]);
}

inline void StoreAligned32U16(uint16_t* const dst, const uint16x8_t src[2]) {
  vst1q_u16(dst + 0, src[0]);
  vst1q_u16(dst + 8, src[1]);
}

inline void StoreAligned32U32(uint32_t* const dst, const uint32x4_t src[2]) {
  vst1q_u32(dst + 0, src[0]);
  vst1q_u32(dst + 4, src[1]);
}

inline void StoreAligned64U32(uint32_t* const dst, const uint32x4_t src[4]) {
  StoreAligned32U32(dst + 0, src + 0);
  StoreAligned32U32(dst + 8, src + 2);
}

inline uint16x8_t VaddwLo8(const uint16x8_t src0, const uint8x16_t src1) {
  const uint8x8_t s1 = vget_low_u8(src1);
  return vaddw_u8(src0, s1);
}

inline uint16x8_t VaddwHi8(const uint16x8_t src0, const uint8x16_t src1) {
  const uint8x8_t s1 = vget_high_u8(src1);
  return vaddw_u8(src0, s1);
}

inline uint32x4_t VmullLo16(const uint16x8_t src0, const uint16x8_t src1) {
  return vmull_u16(vget_low_u16(src0), vget_low_u16(src1));
}

inline uint32x4_t VmullHi16(const uint16x8_t src0, const uint16x8_t src1) {
  return vmull_u16(vget_high_u16(src0), vget_high_u16(src1));
}

template <int bytes>
inline uint8x8_t VshrU128(const uint8x8x2_t src) {
  return vext_u8(src.val[0], src.val[1], bytes);
}

template <int bytes>
inline uint8x8_t VshrU128(const uint8x8_t src[2]) {
  return vext_u8(src[0], src[1], bytes);
}

template <int bytes>
inline uint8x16_t VshrU128(const uint8x16_t src[2]) {
  return vextq_u8(src[0], src[1], bytes);
}

template <int bytes>
inline uint16x8_t VshrU128(const uint16x8x2_t src) {
  return vextq_u16(src.val[0], src.val[1], bytes / 2);
}

template <int bytes>
inline uint16x8_t VshrU128(const uint16x8_t src[2]) {
  return vextq_u16(src[0], src[1], bytes / 2);
}

inline uint32x4_t Square(uint16x4_t s) { return vmull_u16(s, s); }

inline void Square(const uint16x8_t src, uint32x4_t dst[2]) {
  const uint16x4_t s_lo = vget_low_u16(src);
  const uint16x4_t s_hi = vget_high_u16(src);
  dst[0] = Square(s_lo);
  dst[1] = Square(s_hi);
}

template <int offset>
inline void Prepare3_8(const uint8x16_t src[2], uint8x16_t dst[3]) {
  dst[0] = VshrU128<offset + 0>(src);
  dst[1] = VshrU128<offset + 1>(src);
  dst[2] = VshrU128<offset + 2>(src);
}

inline void Prepare3_16(const uint16x8_t src[2], uint16x8_t dst[3]) {
  dst[0] = src[0];
  dst[1] = vextq_u16(src[0], src[1], 1);
  dst[2] = vextq_u16(src[0], src[1], 2);
}

template <int offset>
inline void Prepare5_8(const uint8x16_t src[2], uint8x16_t dst[5]) {
  dst[0] = VshrU128<offset + 0>(src);
  dst[1] = VshrU128<offset + 1>(src);
  dst[2] = VshrU128<offset + 2>(src);
  dst[3] = VshrU128<offset + 3>(src);
  dst[4] = VshrU128<offset + 4>(src);
}

inline void Prepare5_16(const uint16x8_t src[2], uint16x8_t dst[5]) {
  dst[0] = src[0];
  dst[1] = vextq_u16(src[0], src[1], 1);
  dst[2] = vextq_u16(src[0], src[1], 2);
  dst[3] = vextq_u16(src[0], src[1], 3);
  dst[4] = vextq_u16(src[0], src[1], 4);
}

inline void Prepare3_32(const uint32x4_t src[2], uint32x4_t dst[3]) {
  dst[0] = src[0];
  dst[1] = vextq_u32(src[0], src[1], 1);
  dst[2] = vextq_u32(src[0], src[1], 2);
}

inline void Prepare5_32(const uint32x4_t src[2], uint32x4_t dst[5]) {
  Prepare3_32(src, dst);
  dst[3] = vextq_u32(src[0], src[1], 3);
  dst[4] = src[1];
}

inline uint16x8_t Sum3WLo16(const uint8x16_t src[3]) {
  const uint16x8_t sum = vaddl_u8(vget_low_u8(src[0]), vget_low_u8(src[1]));
  return vaddw_u8(sum, vget_low_u8(src[2]));
}

inline uint16x8_t Sum3WHi16(const uint8x16_t src[3]) {
  const uint16x8_t sum = vaddl_u8(vget_high_u8(src[0]), vget_high_u8(src[1]));
  return vaddw_u8(sum, vget_high_u8(src[2]));
}

inline uint16x8_t Sum3_16(const uint16x8_t src0, const uint16x8_t src1,
                          const uint16x8_t src2) {
  const uint16x8_t sum = vaddq_u16(src0, src1);
  return vaddq_u16(sum, src2);
}

inline uint16x8_t Sum3_16(const uint16x8_t src[3]) {
  return Sum3_16(src[0], src[1], src[2]);
}

inline uint32x4_t Sum3_32(const uint32x4_t src0, const uint32x4_t src1,
                          const uint32x4_t src2) {
  const uint32x4_t sum = vaddq_u32(src0, src1);
  return vaddq_u32(sum, src2);
}

inline uint32x4_t Sum3_32(const uint32x4_t src[3]) {
  return Sum3_32(src[0], src[1], src[2]);
}

inline void Sum3_32(const uint32x4_t src[3][2], uint32x4_t dst[2]) {
  dst[0] = Sum3_32(src[0][0], src[1][0], src[2][0]);
  dst[1] = Sum3_32(src[0][1], src[1][1], src[2][1]);
}

inline uint16x8_t Sum5_16(const uint16x8_t src[5]) {
  const uint16x8_t sum01 = vaddq_u16(src[0], src[1]);
  const uint16x8_t sum23 = vaddq_u16(src[2], src[3]);
  const uint16x8_t sum = vaddq_u16(sum01, sum23);
  return vaddq_u16(sum, src[4]);
}

inline uint32x4_t Sum5_32(const uint32x4_t* src0, const uint32x4_t* src1,
                          const uint32x4_t* src2, const uint32x4_t* src3,
                          const uint32x4_t* src4) {
  const uint32x4_t sum01 = vaddq_u32(*src0, *src1);
  const uint32x4_t sum23 = vaddq_u32(*src2, *src3);
  const uint32x4_t sum = vaddq_u32(sum01, sum23);
  return vaddq_u32(sum, *src4);
}

inline uint32x4_t Sum5_32(const uint32x4_t src[5]) {
  return Sum5_32(&src[0], &src[1], &src[2], &src[3], &src[4]);
}

inline void Sum5_32(const uint32x4_t src[5][2], uint32x4_t dst[2]) {
  dst[0] = Sum5_32(&src[0][0], &src[1][0], &src[2][0], &src[3][0], &src[4][0]);
  dst[1] = Sum5_32(&src[0][1], &src[1][1], &src[2][1], &src[3][1], &src[4][1]);
}

inline uint16x8_t Sum3Horizontal16(const uint16x8_t src[2]) {
  uint16x8_t s[3];
  Prepare3_16(src, s);
  return Sum3_16(s);
}

inline void Sum3Horizontal32(const uint32x4_t src[3], uint32x4_t dst[2]) {
  uint32x4_t s[3];
  Prepare3_32(src + 0, s);
  dst[0] = Sum3_32(s);
  Prepare3_32(src + 1, s);
  dst[1] = Sum3_32(s);
}

inline uint16x8_t Sum5Horizontal16(const uint16x8_t src[2]) {
  uint16x8_t s[5];
  Prepare5_16(src, s);
  return Sum5_16(s);
}

inline void Sum5Horizontal32(const uint32x4_t src[3], uint32x4_t dst[2]) {
  uint32x4_t s[5];
  Prepare5_32(src + 0, s);
  dst[0] = Sum5_32(s);
  Prepare5_32(src + 1, s);
  dst[1] = Sum5_32(s);
}

void SumHorizontal16(const uint16x8_t src[2], uint16x8_t* const row3,
                     uint16x8_t* const row5) {
  uint16x8_t s[5];
  Prepare5_16(src, s);
  const uint16x8_t sum04 = vaddq_u16(s[0], s[4]);
  *row3 = Sum3_16(s + 1);
  *row5 = vaddq_u16(sum04, *row3);
}

inline void SumHorizontal16(const uint16x8_t src[3], uint16x8_t* const row3_0,
                            uint16x8_t* const row3_1, uint16x8_t* const row5_0,
                            uint16x8_t* const row5_1) {
  SumHorizontal16(src + 0, row3_0, row5_0);
  SumHorizontal16(src + 1, row3_1, row5_1);
}

void SumHorizontal32(const uint32x4_t src[5], uint32x4_t* const row_sq3,
                     uint32x4_t* const row_sq5) {
  const uint32x4_t sum04 = vaddq_u32(src[0], src[4]);
  *row_sq3 = Sum3_32(src + 1);
  *row_sq5 = vaddq_u32(sum04, *row_sq3);
}

inline void SumHorizontal32(const uint32x4_t src[3],
                            uint32x4_t* const row_sq3_0,
                            uint32x4_t* const row_sq3_1,
                            uint32x4_t* const row_sq5_0,
                            uint32x4_t* const row_sq5_1) {
  uint32x4_t s[5];
  Prepare5_32(src + 0, s);
  SumHorizontal32(s, row_sq3_0, row_sq5_0);
  Prepare5_32(src + 1, s);
  SumHorizontal32(s, row_sq3_1, row_sq5_1);
}

inline uint16x8_t Sum343Lo(const uint8x16_t ma3[3]) {
  const uint16x8_t sum = Sum3WLo16(ma3);
  const uint16x8_t sum3 = Sum3_16(sum, sum, sum);
  return VaddwLo8(sum3, ma3[1]);
}

inline uint16x8_t Sum343Hi(const uint8x16_t ma3[3]) {
  const uint16x8_t sum = Sum3WHi16(ma3);
  const uint16x8_t sum3 = Sum3_16(sum, sum, sum);
  return VaddwHi8(sum3, ma3[1]);
}

inline uint32x4_t Sum343(const uint32x4_t src[3]) {
  const uint32x4_t sum = Sum3_32(src);
  const uint32x4_t sum3 = Sum3_32(sum, sum, sum);
  return vaddq_u32(sum3, src[1]);
}

inline void Sum343(const uint32x4_t src[3], uint32x4_t dst[2]) {
  uint32x4_t s[3];
  Prepare3_32(src + 0, s);
  dst[0] = Sum343(s);
  Prepare3_32(src + 1, s);
  dst[1] = Sum343(s);
}
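
// The "343" weighting follows from the arithmetic above: for neighboring
// sums (a, b, c), 3 * (a + b + c) + b = 3a + 4b + 3c.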

inline uint16x8_t Sum565Lo(const uint8x16_t src[3]) {
  const uint16x8_t sum = Sum3WLo16(src);
  const uint16x8_t sum4 = vshlq_n_u16(sum, 2);
  const uint16x8_t sum5 = vaddq_u16(sum4, sum);
  return VaddwLo8(sum5, src[1]);
}

inline uint16x8_t Sum565Hi(const uint8x16_t src[3]) {
  const uint16x8_t sum = Sum3WHi16(src);
  const uint16x8_t sum4 = vshlq_n_u16(sum, 2);
  const uint16x8_t sum5 = vaddq_u16(sum4, sum);
  return VaddwHi8(sum5, src[1]);
}

inline uint32x4_t Sum565(const uint32x4_t src[3]) {
  const uint32x4_t sum = Sum3_32(src);
  const uint32x4_t sum4 = vshlq_n_u32(sum, 2);
  const uint32x4_t sum5 = vaddq_u32(sum4, sum);
  return vaddq_u32(sum5, src[1]);
}

inline void Sum565(const uint32x4_t src[3], uint32x4_t dst[2]) {
  uint32x4_t s[3];
  Prepare3_32(src + 0, s);
  dst[0] = Sum565(s);
  Prepare3_32(src + 1, s);
  dst[1] = Sum565(s);
}
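
// Likewise "565": 5 * (a + b + c) + b = 5a + 6b + 5c, with the multiply by 5
// implemented as (sum << 2) + sum.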

inline void BoxSum(const uint16_t* src, const ptrdiff_t src_stride,
                   const ptrdiff_t width, const ptrdiff_t sum_stride,
                   const ptrdiff_t sum_width, uint16_t* sum3, uint16_t* sum5,
                   uint32_t* square_sum3, uint32_t* square_sum5) {
  const ptrdiff_t overread_in_bytes =
      kOverreadInBytesPass1 - sizeof(*src) * width;
  int y = 2;
  do {
    uint16x8_t s[3];
    uint32x4_t sq[6];
    s[0] = Load1QMsanU16(src, overread_in_bytes);
    Square(s[0], sq);
    ptrdiff_t x = sum_width;
    do {
      uint16x8_t row3[2], row5[2];
      uint32x4_t row_sq3[2], row_sq5[2];
      s[1] = Load1QMsanU16(
          src + 8, overread_in_bytes + sizeof(*src) * (sum_width - x + 8));
      x -= 16;
      src += 16;
      s[2] = Load1QMsanU16(src,
                           overread_in_bytes + sizeof(*src) * (sum_width - x));
      Square(s[1], sq + 2);
      Square(s[2], sq + 4);
      SumHorizontal16(s, &row3[0], &row3[1], &row5[0], &row5[1]);
      StoreAligned32U16(sum3, row3);
      StoreAligned32U16(sum5, row5);
      SumHorizontal32(sq + 0, &row_sq3[0], &row_sq3[1], &row_sq5[0],
                      &row_sq5[1]);
      StoreAligned32U32(square_sum3 + 0, row_sq3);
      StoreAligned32U32(square_sum5 + 0, row_sq5);
      SumHorizontal32(sq + 2, &row_sq3[0], &row_sq3[1], &row_sq5[0],
                      &row_sq5[1]);
      StoreAligned32U32(square_sum3 + 8, row_sq3);
      StoreAligned32U32(square_sum5 + 8, row_sq5);
      s[0] = s[2];
      sq[0] = sq[4];
      sq[1] = sq[5];
      sum3 += 16;
      sum5 += 16;
      square_sum3 += 16;
      square_sum5 += 16;
    } while (x != 0);
    src += src_stride - sum_width;
    sum3 += sum_stride - sum_width;
    sum5 += sum_stride - sum_width;
    square_sum3 += sum_stride - sum_width;
    square_sum5 += sum_stride - sum_width;
  } while (--y != 0);
}

template <int size>
inline void BoxSum(const uint16_t* src, const ptrdiff_t src_stride,
                   const ptrdiff_t width, const ptrdiff_t sum_stride,
                   const ptrdiff_t sum_width, uint16_t* sums,
                   uint32_t* square_sums) {
  static_assert(size == 3 || size == 5, "");
  const ptrdiff_t overread_in_bytes =
      ((size == 5) ? kOverreadInBytesPass1 : kOverreadInBytesPass2) -
      sizeof(*src) * width;
  int y = 2;
  do {
    uint16x8_t s[3];
    uint32x4_t sq[6];
    s[0] = Load1QMsanU16(src, overread_in_bytes);
    Square(s[0], sq);
    ptrdiff_t x = sum_width;
    do {
      uint16x8_t row[2];
      uint32x4_t row_sq[4];
      s[1] = Load1QMsanU16(
          src + 8, overread_in_bytes + sizeof(*src) * (sum_width - x + 8));
      x -= 16;
      src += 16;
      s[2] = Load1QMsanU16(src,
                           overread_in_bytes + sizeof(*src) * (sum_width - x));
      Square(s[1], sq + 2);
      Square(s[2], sq + 4);
      if (size == 3) {
        row[0] = Sum3Horizontal16(s + 0);
        row[1] = Sum3Horizontal16(s + 1);
        Sum3Horizontal32(sq + 0, row_sq + 0);
        Sum3Horizontal32(sq + 2, row_sq + 2);
      } else {
        row[0] = Sum5Horizontal16(s + 0);
        row[1] = Sum5Horizontal16(s + 1);
        Sum5Horizontal32(sq + 0, row_sq + 0);
        Sum5Horizontal32(sq + 2, row_sq + 2);
      }
      StoreAligned32U16(sums, row);
      StoreAligned64U32(square_sums, row_sq);
      s[0] = s[2];
      sq[0] = sq[4];
      sq[1] = sq[5];
      sums += 16;
      square_sums += 16;
    } while (x != 0);
    src += src_stride - sum_width;
    sums += sum_stride - sum_width;
    square_sums += sum_stride - sum_width;
  } while (--y != 0);
}

template <int n>
inline uint16x4_t CalculateMa(const uint16x4_t sum, const uint32x4_t sum_sq,
                              const uint32_t scale) {
  // a = |sum_sq|
  // d = |sum|
  // p = (a * n < d * d) ? 0 : a * n - d * d;
  const uint32x4_t dxd = vmull_u16(sum, sum);
  const uint32x4_t axn = vmulq_n_u32(sum_sq, n);
  // Ensure |p| does not underflow by using saturating subtraction.
  const uint32x4_t p = vqsubq_u32(axn, dxd);
  const uint32x4_t pxs = vmulq_n_u32(p, scale);
  // vrshrn_n_u32() (narrowing shift) can shift by at most 16, but
  // kSgrProjScaleBits is 20, so shift and narrow in two steps.
  const uint32x4_t shifted = vrshrq_n_u32(pxs, kSgrProjScaleBits);
  return vmovn_u32(shifted);
}

template <int n>
inline uint16x8_t CalculateMa(const uint16x8_t sum, const uint32x4_t sum_sq[2],
                              const uint32_t scale) {
  static_assert(n == 9 || n == 25, "");
  const uint16x8_t b = vrshrq_n_u16(sum, 2);
  const uint16x4_t sum_lo = vget_low_u16(b);
  const uint16x4_t sum_hi = vget_high_u16(b);
  const uint16x4_t z0 =
      CalculateMa<n>(sum_lo, vrshrq_n_u32(sum_sq[0], 4), scale);
  const uint16x4_t z1 =
      CalculateMa<n>(sum_hi, vrshrq_n_u32(sum_sq[1], 4), scale);
  return vcombine_u16(z0, z1);
}
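
// The 8-lane version pre-scales its inputs: |sum| is rounded down by 2 bits
// and |sum_sq| by 4 bits, so d * d and a * n stay consistently scaled (both
// by 1/16), which also helps keep the later p * scale product within 32 bits
// for 10-bit input.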

inline void CalculateB5(const uint16x8_t sum, const uint16x8_t ma,
                        uint32x4_t b[2]) {
  // one_over_n == 164.
  constexpr uint32_t one_over_n =
      ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25;
  // one_over_n_quarter == 41.
  constexpr uint32_t one_over_n_quarter = one_over_n >> 2;
  static_assert(one_over_n == one_over_n_quarter << 2, "");
  // |ma| is in range [0, 255].
  const uint32x4_t m2 = VmullLo16(ma, sum);
  const uint32x4_t m3 = VmullHi16(ma, sum);
  const uint32x4_t m0 = vmulq_n_u32(m2, one_over_n_quarter);
  const uint32x4_t m1 = vmulq_n_u32(m3, one_over_n_quarter);
  b[0] = vrshrq_n_u32(m0, kSgrProjReciprocalBits - 2);
  b[1] = vrshrq_n_u32(m1, kSgrProjReciprocalBits - 2);
}
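
// CalculateB5() relies on the identity
//   (m * 164 + (1 << 11)) >> 12 == (m * 41 + (1 << 9)) >> 10,
// which holds because 164 == 41 << 2; multiplying by the smaller constant
// leaves two more bits of headroom in the 32-bit intermediates.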

inline void CalculateB3(const uint16x8_t sum, const uint16x8_t ma,
                        uint32x4_t b[2]) {
  // one_over_n == 455.
  constexpr uint32_t one_over_n =
      ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9;
  const uint32x4_t m0 = VmullLo16(ma, sum);
  const uint32x4_t m1 = VmullHi16(ma, sum);
  const uint32x4_t m2 = vmulq_n_u32(m0, one_over_n);
  const uint32x4_t m3 = vmulq_n_u32(m1, one_over_n);
  b[0] = vrshrq_n_u32(m2, kSgrProjReciprocalBits);
  b[1] = vrshrq_n_u32(m3, kSgrProjReciprocalBits);
}

inline void CalculateSumAndIndex3(const uint16x8_t s3[3],
                                  const uint32x4_t sq3[3][2],
                                  const uint32_t scale, uint16x8_t* const sum,
                                  uint16x8_t* const index) {
  uint32x4_t sum_sq[2];
  *sum = Sum3_16(s3);
  Sum3_32(sq3, sum_sq);
  *index = CalculateMa<9>(*sum, sum_sq, scale);
}

inline void CalculateSumAndIndex5(const uint16x8_t s5[5],
                                  const uint32x4_t sq5[5][2],
                                  const uint32_t scale, uint16x8_t* const sum,
                                  uint16x8_t* const index) {
  uint32x4_t sum_sq[2];
  *sum = Sum5_16(s5);
  Sum5_32(sq5, sum_sq);
  *index = CalculateMa<25>(*sum, sum_sq, scale);
}

template <int n, int offset>
inline void LookupIntermediate(const uint16x8_t sum, const uint16x8_t index,
                               uint8x16_t* const ma, uint32x4_t b[2]) {
  static_assert(n == 9 || n == 25, "");
  static_assert(offset == 0 || offset == 8, "");

  const uint8x8_t idx = vqmovn_u16(index);
  uint8_t temp[8];
  vst1_u8(temp, idx);
  // offset == 0 is assumed to be the first call to this function. The value is
  // duplicated to avoid -Wuninitialized warnings under gcc.
  if (offset == 0) {
    *ma = vdupq_n_u8(kSgrMaLookup[temp[0]]);
  } else {
    *ma = vsetq_lane_u8(kSgrMaLookup[temp[0]], *ma, offset + 0);
  }
  *ma = vsetq_lane_u8(kSgrMaLookup[temp[1]], *ma, offset + 1);
  *ma = vsetq_lane_u8(kSgrMaLookup[temp[2]], *ma, offset + 2);
  *ma = vsetq_lane_u8(kSgrMaLookup[temp[3]], *ma, offset + 3);
  *ma = vsetq_lane_u8(kSgrMaLookup[temp[4]], *ma, offset + 4);
  *ma = vsetq_lane_u8(kSgrMaLookup[temp[5]], *ma, offset + 5);
  *ma = vsetq_lane_u8(kSgrMaLookup[temp[6]], *ma, offset + 6);
  *ma = vsetq_lane_u8(kSgrMaLookup[temp[7]], *ma, offset + 7);
  // b = ma * b * one_over_n
  // |ma| = [0, 255]
  // |sum| is a box sum with radius 1 or 2.
  // For the first pass radius is 2. Maximum value is 5x5x1023 = 25575.
  // For the second pass radius is 1. Maximum value is 3x3x1023 = 9207.
  // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
  // When radius is 2 |n| is 25. |one_over_n| is 164.
  // When radius is 1 |n| is 9. |one_over_n| is 455.
  // |kSgrProjReciprocalBits| is 12.
  // Radius 2: 255 * 25575 * 164 >> 12 = 261119 (18 bits).
  // Radius 1: 255 * 9207 * 455 >> 12 = 260801 (18 bits).
  const uint16x8_t maq =
      vmovl_u8((offset == 0) ? vget_low_u8(*ma) : vget_high_u8(*ma));
  if (n == 9) {
    CalculateB3(sum, maq, b);
  } else {
    CalculateB5(sum, maq, b);
  }
}
1166 
AdjustValue(const uint8x8_t value,const uint8x8_t index,const int threshold)1167 inline uint8x8_t AdjustValue(const uint8x8_t value, const uint8x8_t index,
1168                              const int threshold) {
1169   const uint8x8_t thresholds = vdup_n_u8(threshold);
1170   const uint8x8_t offset = vcgt_u8(index, thresholds);
1171   // Adding 255 is equivalent to subtracting 1 for 8-bit data.
1172   return vadd_u8(value, offset);
1173 }
1174 
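// Worked example for AdjustValue(): with threshold 55, an index of 60 gives
// offset = 0xFF, and val + 0xFF (mod 256) == val - 1, stepping the
// looked-up value from 5 down to 4; an index of 50 gives offset = 0 and
// leaves val unchanged.
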
inline uint8x8_t MaLookupAndAdjust(const uint8x8x4_t table0,
                                   const uint8x8x2_t table1,
                                   const uint16x8_t index) {
  const uint8x8_t idx = vqmovn_u16(index);
  // All elements whose indices are outside the range [0, 47] are set to 0.
  uint8x8_t val = vtbl4_u8(table0, idx);  // Range [0, 31].
  // Subtract 32 to shuffle the next index range.
  const uint8x8_t sub_idx = vsub_u8(idx, vdup_n_u8(32));
  const uint8x8_t res = vtbl2_u8(table1, sub_idx);  // Range [32, 47].
  // Use an OR instruction to combine the two shuffle results.
  val = vorr_u8(val, res);

  // For elements whose indices are larger than 47, the table values change
  // only rarely as the index increases, so comparison and arithmetic
  // operations are enough to compute them.
  // Elements whose indices are larger than 47 (with value 0) are set to 5.
  val = vmax_u8(val, vdup_n_u8(5));
  val = AdjustValue(val, idx, 55);   // 55 is the last index whose value is 5.
  val = AdjustValue(val, idx, 72);   // 72 is the last index whose value is 4.
  val = AdjustValue(val, idx, 101);  // 101 is the last index whose value is 3.
  val = AdjustValue(val, idx, 169);  // 169 is the last index whose value is 2.
  val = AdjustValue(val, idx, 254);  // 254 is the last index whose value is 1.
  return val;
}

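// Editor's note: lane for lane, MaLookupAndAdjust() is equivalent to the
// scalar lookup below (vqmovn_u16 saturates the index to 255, the vtbl pair
// covers [0, 47], and the AdjustValue() chain reproduces the table's slow
// decay over [48, 255]).
inline uint8_t MaLookupScalar(const uint16_t index) {
  return kSgrMaLookup[std::min<int>(index, 255)];
}
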
inline void CalculateIntermediate(const uint16x8_t sum[2],
                                  const uint16x8_t index[2],
                                  uint8x16_t* const ma, uint32x4_t b0[2],
                                  uint32x4_t b1[2]) {
  // Use table lookup to read elements whose indices are less than 48.
  // Using one uint8x8x4_t vector and one uint8x8x2_t vector is faster than
  // using two uint8x8x3_t vectors.
  uint8x8x4_t table0;
  uint8x8x2_t table1;
  table0.val[0] = vld1_u8(kSgrMaLookup + 0 * 8);
  table0.val[1] = vld1_u8(kSgrMaLookup + 1 * 8);
  table0.val[2] = vld1_u8(kSgrMaLookup + 2 * 8);
  table0.val[3] = vld1_u8(kSgrMaLookup + 3 * 8);
  table1.val[0] = vld1_u8(kSgrMaLookup + 4 * 8);
  table1.val[1] = vld1_u8(kSgrMaLookup + 5 * 8);
  const uint8x8_t ma_lo = MaLookupAndAdjust(table0, table1, index[0]);
  const uint8x8_t ma_hi = MaLookupAndAdjust(table0, table1, index[1]);
  *ma = vcombine_u8(ma_lo, ma_hi);
  // b = ma * b * one_over_n
  // |ma| = [0, 255]
  // |sum| is a box sum with radius 1 or 2.
  // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
  // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
  // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
  // When radius is 2 |n| is 25. |one_over_n| is 164.
  // When radius is 1 |n| is 9. |one_over_n| is 455.
  // |kSgrProjReciprocalBits| is 12.
  // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
  // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
  const uint16x8_t maq0 = vmovl_u8(vget_low_u8(*ma));
  CalculateB3(sum[0], maq0, b0);
  const uint16x8_t maq1 = vmovl_u8(vget_high_u8(*ma));
  CalculateB3(sum[1], maq1, b1);
}

inline void CalculateIntermediate(const uint16x8_t sum[2],
                                  const uint16x8_t index[2], uint8x16_t ma[2],
                                  uint32x4_t b[4]) {
  uint8x16_t mas;
  CalculateIntermediate(sum, index, &mas, b + 0, b + 2);
  ma[0] = vcombine_u8(vget_low_u8(ma[0]), vget_low_u8(mas));
  ma[1] = vextq_u8(mas, vdupq_n_u8(0), 8);
}

template <int offset>
inline void CalculateIntermediate5(const uint16x8_t s5[5],
                                   const uint32x4_t sq5[5][2],
                                   const uint32_t scale, uint8x16_t* const ma,
                                   uint32x4_t b[2]) {
  static_assert(offset == 0 || offset == 8, "");
  uint16x8_t sum, index;
  CalculateSumAndIndex5(s5, sq5, scale, &sum, &index);
  LookupIntermediate<25, offset>(sum, index, ma, b);
}

inline void CalculateIntermediate3(const uint16x8_t s3[3],
                                   const uint32x4_t sq3[3][2],
                                   const uint32_t scale, uint8x16_t* const ma,
                                   uint32x4_t b[2]) {
  uint16x8_t sum, index;
  CalculateSumAndIndex3(s3, sq3, scale, &sum, &index);
  LookupIntermediate<9, 0>(sum, index, ma, b);
}

inline void Store343_444(const uint32x4_t b3[3], const ptrdiff_t x,
                         uint32x4_t sum_b343[2], uint32x4_t sum_b444[2],
                         uint32_t* const b343, uint32_t* const b444) {
  uint32x4_t b[3], sum_b111[2];
  Prepare3_32(b3 + 0, b);
  sum_b111[0] = Sum3_32(b);
  sum_b444[0] = vshlq_n_u32(sum_b111[0], 2);
  sum_b343[0] = vsubq_u32(sum_b444[0], sum_b111[0]);
  sum_b343[0] = vaddq_u32(sum_b343[0], b[1]);
  Prepare3_32(b3 + 1, b);
  sum_b111[1] = Sum3_32(b);
  sum_b444[1] = vshlq_n_u32(sum_b111[1], 2);
  sum_b343[1] = vsubq_u32(sum_b444[1], sum_b111[1]);
  sum_b343[1] = vaddq_u32(sum_b343[1], b[1]);
  StoreAligned32U32(b444 + x, sum_b444);
  StoreAligned32U32(b343 + x, sum_b343);
}

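// Editor's note: with a three-element sum s111 = b[0] + b[1] + b[2],
// Store343_444() stores b444 = 4 * s111 and
// b343 = 4 * s111 - s111 + b[1] = 3 * b[0] + 4 * b[1] + 3 * b[2],
// i.e. the 4-4-4 and 3-4-3 weighted sums used by pass 2. The
// Store343_444Lo()/Store343_444Hi() helpers below apply the same identity
// to ma.
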
inline void Store343_444Lo(const uint8x16_t ma3[3], const uint32x4_t b3[3],
                           const ptrdiff_t x, uint16x8_t* const sum_ma343,
                           uint16x8_t* const sum_ma444, uint32x4_t sum_b343[2],
                           uint32x4_t sum_b444[2], uint16_t* const ma343,
                           uint16_t* const ma444, uint32_t* const b343,
                           uint32_t* const b444) {
  const uint16x8_t sum_ma111 = Sum3WLo16(ma3);
  *sum_ma444 = vshlq_n_u16(sum_ma111, 2);
  vst1q_u16(ma444 + x, *sum_ma444);
  const uint16x8_t sum333 = vsubq_u16(*sum_ma444, sum_ma111);
  *sum_ma343 = VaddwLo8(sum333, ma3[1]);
  vst1q_u16(ma343 + x, *sum_ma343);
  Store343_444(b3, x, sum_b343, sum_b444, b343, b444);
}

inline void Store343_444Hi(const uint8x16_t ma3[3], const uint32x4_t b3[2],
                           const ptrdiff_t x, uint16x8_t* const sum_ma343,
                           uint16x8_t* const sum_ma444, uint32x4_t sum_b343[2],
                           uint32x4_t sum_b444[2], uint16_t* const ma343,
                           uint16_t* const ma444, uint32_t* const b343,
                           uint32_t* const b444) {
  const uint16x8_t sum_ma111 = Sum3WHi16(ma3);
  *sum_ma444 = vshlq_n_u16(sum_ma111, 2);
  vst1q_u16(ma444 + x, *sum_ma444);
  const uint16x8_t sum333 = vsubq_u16(*sum_ma444, sum_ma111);
  *sum_ma343 = VaddwHi8(sum333, ma3[1]);
  vst1q_u16(ma343 + x, *sum_ma343);
  Store343_444(b3, x, sum_b343, sum_b444, b343, b444);
}

inline void Store343_444Lo(const uint8x16_t ma3[3], const uint32x4_t b3[2],
                           const ptrdiff_t x, uint16x8_t* const sum_ma343,
                           uint32x4_t sum_b343[2], uint16_t* const ma343,
                           uint16_t* const ma444, uint32_t* const b343,
                           uint32_t* const b444) {
  uint16x8_t sum_ma444;
  uint32x4_t sum_b444[2];
  Store343_444Lo(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343,
                 ma444, b343, b444);
}

inline void Store343_444Hi(const uint8x16_t ma3[3], const uint32x4_t b3[2],
                           const ptrdiff_t x, uint16x8_t* const sum_ma343,
                           uint32x4_t sum_b343[2], uint16_t* const ma343,
                           uint16_t* const ma444, uint32_t* const b343,
                           uint32_t* const b444) {
  uint16x8_t sum_ma444;
  uint32x4_t sum_b444[2];
  Store343_444Hi(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343,
                 ma444, b343, b444);
}

inline void Store343_444Lo(const uint8x16_t ma3[3], const uint32x4_t b3[2],
                           const ptrdiff_t x, uint16_t* const ma343,
                           uint16_t* const ma444, uint32_t* const b343,
                           uint32_t* const b444) {
  uint16x8_t sum_ma343;
  uint32x4_t sum_b343[2];
  Store343_444Lo(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444);
}

inline void Store343_444Hi(const uint8x16_t ma3[3], const uint32x4_t b3[2],
                           const ptrdiff_t x, uint16_t* const ma343,
                           uint16_t* const ma444, uint32_t* const b343,
                           uint32_t* const b444) {
  uint16x8_t sum_ma343;
  uint32x4_t sum_b343[2];
  Store343_444Hi(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444);
}

LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5Lo(
    const uint16x8_t s[2][4], const uint32_t scale, uint16_t* const sum5[5],
    uint32_t* const square_sum5[5], uint32x4_t sq[2][8], uint8x16_t* const ma,
    uint32x4_t b[2]) {
  uint16x8_t s5[2][5];
  uint32x4_t sq5[5][2];
  Square(s[0][1], sq[0] + 2);
  Square(s[1][1], sq[1] + 2);
  s5[0][3] = Sum5Horizontal16(s[0]);
  vst1q_u16(sum5[3], s5[0][3]);
  s5[0][4] = Sum5Horizontal16(s[1]);
  vst1q_u16(sum5[4], s5[0][4]);
  Sum5Horizontal32(sq[0], sq5[3]);
  StoreAligned32U32(square_sum5[3], sq5[3]);
  Sum5Horizontal32(sq[1], sq5[4]);
  StoreAligned32U32(square_sum5[4], sq5[4]);
  LoadAligned16x3U16(sum5, 0, s5[0]);
  LoadAligned32x3U32(square_sum5, 0, sq5);
  CalculateIntermediate5<0>(s5[0], sq5, scale, ma, b);
}

LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5(
    const uint16x8_t s[2][4], const ptrdiff_t sum_width, const ptrdiff_t x,
    const uint32_t scale, uint16_t* const sum5[5],
    uint32_t* const square_sum5[5], uint32x4_t sq[2][8], uint8x16_t ma[2],
    uint32x4_t b[6]) {
  uint16x8_t s5[2][5];
  uint32x4_t sq5[5][2];
  Square(s[0][2], sq[0] + 4);
  Square(s[1][2], sq[1] + 4);
  s5[0][3] = Sum5Horizontal16(s[0] + 1);
  s5[1][3] = Sum5Horizontal16(s[0] + 2);
  vst1q_u16(sum5[3] + x + 0, s5[0][3]);
  vst1q_u16(sum5[3] + x + 8, s5[1][3]);
  s5[0][4] = Sum5Horizontal16(s[1] + 1);
  s5[1][4] = Sum5Horizontal16(s[1] + 2);
  vst1q_u16(sum5[4] + x + 0, s5[0][4]);
  vst1q_u16(sum5[4] + x + 8, s5[1][4]);
  Sum5Horizontal32(sq[0] + 2, sq5[3]);
  StoreAligned32U32(square_sum5[3] + x, sq5[3]);
  Sum5Horizontal32(sq[1] + 2, sq5[4]);
  StoreAligned32U32(square_sum5[4] + x, sq5[4]);
  LoadAligned16x3U16(sum5, x, s5[0]);
  LoadAligned32x3U32(square_sum5, x, sq5);
  CalculateIntermediate5<8>(s5[0], sq5, scale, &ma[0], b + 2);

  Square(s[0][3], sq[0] + 6);
  Square(s[1][3], sq[1] + 6);
  Sum5Horizontal32(sq[0] + 4, sq5[3]);
  StoreAligned32U32(square_sum5[3] + x + 8, sq5[3]);
  Sum5Horizontal32(sq[1] + 4, sq5[4]);
  StoreAligned32U32(square_sum5[4] + x + 8, sq5[4]);
  LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]);
  LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5);
  CalculateIntermediate5<0>(s5[1], sq5, scale, &ma[1], b + 4);
}

LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRowLo(
    const uint16x8_t s[2], const uint32_t scale, const uint16_t* const sum5[5],
    const uint32_t* const square_sum5[5], uint32x4_t sq[4],
    uint8x16_t* const ma, uint32x4_t b[2]) {
  uint16x8_t s5[5];
  uint32x4_t sq5[5][2];
  Square(s[1], sq + 2);
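  // There is no row below the last row, so the bottom row of the 5x5 window
  // simply reuses the sums of the row above it (s5[4]/sq5[4] below).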
  s5[3] = s5[4] = Sum5Horizontal16(s);
  Sum5Horizontal32(sq, sq5[3]);
  sq5[4][0] = sq5[3][0];
  sq5[4][1] = sq5[3][1];
  LoadAligned16x3U16(sum5, 0, s5);
  LoadAligned32x3U32(square_sum5, 0, sq5);
  CalculateIntermediate5<0>(s5, sq5, scale, ma, b);
}

LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow(
    const uint16x8_t s[4], const ptrdiff_t sum_width, const ptrdiff_t x,
    const uint32_t scale, const uint16_t* const sum5[5],
    const uint32_t* const square_sum5[5], uint32x4_t sq[8], uint8x16_t ma[2],
    uint32x4_t b[6]) {
  uint16x8_t s5[2][5];
  uint32x4_t sq5[5][2];
  Square(s[2], sq + 4);
  s5[0][3] = Sum5Horizontal16(s + 1);
  s5[1][3] = Sum5Horizontal16(s + 2);
  s5[0][4] = s5[0][3];
  s5[1][4] = s5[1][3];
  Sum5Horizontal32(sq + 2, sq5[3]);
  sq5[4][0] = sq5[3][0];
  sq5[4][1] = sq5[3][1];
  LoadAligned16x3U16(sum5, x, s5[0]);
  LoadAligned32x3U32(square_sum5, x, sq5);
  CalculateIntermediate5<8>(s5[0], sq5, scale, &ma[0], b + 2);

  Square(s[3], sq + 6);
  Sum5Horizontal32(sq + 4, sq5[3]);
  sq5[4][0] = sq5[3][0];
  sq5[4][1] = sq5[3][1];
  LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]);
  LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5);
  CalculateIntermediate5<0>(s5[1], sq5, scale, &ma[1], b + 4);
}

LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3Lo(
    const uint16x8_t s[2], const uint32_t scale, uint16_t* const sum3[3],
    uint32_t* const square_sum3[3], uint32x4_t sq[4], uint8x16_t* const ma,
    uint32x4_t b[2]) {
  uint16x8_t s3[3];
  uint32x4_t sq3[3][2];
  Square(s[1], sq + 2);
  s3[2] = Sum3Horizontal16(s);
  vst1q_u16(sum3[2], s3[2]);
  Sum3Horizontal32(sq, sq3[2]);
  StoreAligned32U32(square_sum3[2], sq3[2]);
  LoadAligned16x2U16(sum3, 0, s3);
  LoadAligned32x2U32(square_sum3, 0, sq3);
  CalculateIntermediate3(s3, sq3, scale, ma, b);
}

LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3(
    const uint16x8_t s[4], const ptrdiff_t x, const ptrdiff_t sum_width,
    const uint32_t scale, uint16_t* const sum3[3],
    uint32_t* const square_sum3[3], uint32x4_t sq[8], uint8x16_t ma[2],
    uint32x4_t b[6]) {
  uint16x8_t s3[4], sum[2], index[2];
  uint32x4_t sq3[3][2];

  Square(s[2], sq + 4);
  s3[2] = Sum3Horizontal16(s + 1);
  s3[3] = Sum3Horizontal16(s + 2);
  StoreAligned32U16(sum3[2] + x, s3 + 2);
  Sum3Horizontal32(sq + 2, sq3[2]);
  StoreAligned32U32(square_sum3[2] + x + 0, sq3[2]);
  LoadAligned16x2U16(sum3, x, s3);
  LoadAligned32x2U32(square_sum3, x, sq3);
  CalculateSumAndIndex3(s3, sq3, scale, &sum[0], &index[0]);

  Square(s[3], sq + 6);
  Sum3Horizontal32(sq + 4, sq3[2]);
  StoreAligned32U32(square_sum3[2] + x + 8, sq3[2]);
  LoadAligned16x2U16Msan(sum3, x + 8, sum_width, s3 + 1);
  LoadAligned32x2U32Msan(square_sum3, x + 8, sum_width, sq3);
  CalculateSumAndIndex3(s3 + 1, sq3, scale, &sum[1], &index[1]);
  CalculateIntermediate(sum, index, ma, b + 2);
}

LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLo(
    const uint16x8_t s[2][4], const uint16_t scales[2], uint16_t* const sum3[4],
    uint16_t* const sum5[5], uint32_t* const square_sum3[4],
    uint32_t* const square_sum5[5], uint32x4_t sq[2][8], uint8x16_t ma3[2][2],
    uint32x4_t b3[2][6], uint8x16_t* const ma5, uint32x4_t b5[2]) {
  uint16x8_t s3[4], s5[5], sum[2], index[2];
  uint32x4_t sq3[4][2], sq5[5][2];

  Square(s[0][1], sq[0] + 2);
  Square(s[1][1], sq[1] + 2);
  SumHorizontal16(s[0], &s3[2], &s5[3]);
  SumHorizontal16(s[1], &s3[3], &s5[4]);
  vst1q_u16(sum3[2], s3[2]);
  vst1q_u16(sum3[3], s3[3]);
  vst1q_u16(sum5[3], s5[3]);
  vst1q_u16(sum5[4], s5[4]);
  SumHorizontal32(sq[0], &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
  StoreAligned32U32(square_sum3[2], sq3[2]);
  StoreAligned32U32(square_sum5[3], sq5[3]);
  SumHorizontal32(sq[1], &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
  StoreAligned32U32(square_sum3[3], sq3[3]);
  StoreAligned32U32(square_sum5[4], sq5[4]);
  LoadAligned16x2U16(sum3, 0, s3);
  LoadAligned32x2U32(square_sum3, 0, sq3);
  LoadAligned16x3U16(sum5, 0, s5);
  LoadAligned32x3U32(square_sum5, 0, sq5);
  CalculateSumAndIndex3(s3 + 0, sq3 + 0, scales[1], &sum[0], &index[0]);
  CalculateSumAndIndex3(s3 + 1, sq3 + 1, scales[1], &sum[1], &index[1]);
  CalculateIntermediate(sum, index, &ma3[0][0], b3[0], b3[1]);
  ma3[1][0] = vextq_u8(ma3[0][0], vdupq_n_u8(0), 8);
  CalculateIntermediate5<0>(s5, sq5, scales[0], ma5, b5);
}

LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess(
    const uint16x8_t s[2][4], const ptrdiff_t x, const uint16_t scales[2],
    uint16_t* const sum3[4], uint16_t* const sum5[5],
    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
    const ptrdiff_t sum_width, uint32x4_t sq[2][8], uint8x16_t ma3[2][2],
    uint32x4_t b3[2][6], uint8x16_t ma5[2], uint32x4_t b5[6]) {
  uint16x8_t s3[2][4], s5[2][5], sum[2][2], index[2][2];
  uint32x4_t sq3[4][2], sq5[5][2];

  SumHorizontal16(s[0] + 1, &s3[0][2], &s3[1][2], &s5[0][3], &s5[1][3]);
  vst1q_u16(sum3[2] + x + 0, s3[0][2]);
  vst1q_u16(sum3[2] + x + 8, s3[1][2]);
  vst1q_u16(sum5[3] + x + 0, s5[0][3]);
  vst1q_u16(sum5[3] + x + 8, s5[1][3]);
  SumHorizontal16(s[1] + 1, &s3[0][3], &s3[1][3], &s5[0][4], &s5[1][4]);
  vst1q_u16(sum3[3] + x + 0, s3[0][3]);
  vst1q_u16(sum3[3] + x + 8, s3[1][3]);
  vst1q_u16(sum5[4] + x + 0, s5[0][4]);
  vst1q_u16(sum5[4] + x + 8, s5[1][4]);
  Square(s[0][2], sq[0] + 4);
  Square(s[1][2], sq[1] + 4);
  SumHorizontal32(sq[0] + 2, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
  StoreAligned32U32(square_sum3[2] + x, sq3[2]);
  StoreAligned32U32(square_sum5[3] + x, sq5[3]);
  SumHorizontal32(sq[1] + 2, &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
  StoreAligned32U32(square_sum3[3] + x, sq3[3]);
  StoreAligned32U32(square_sum5[4] + x, sq5[4]);
  LoadAligned16x2U16(sum3, x, s3[0]);
  LoadAligned32x2U32(square_sum3, x, sq3);
  CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum[0][0], &index[0][0]);
  CalculateSumAndIndex3(s3[0] + 1, sq3 + 1, scales[1], &sum[1][0],
                        &index[1][0]);
  LoadAligned16x3U16(sum5, x, s5[0]);
  LoadAligned32x3U32(square_sum5, x, sq5);
  CalculateIntermediate5<8>(s5[0], sq5, scales[0], &ma5[0], b5 + 2);

  Square(s[0][3], sq[0] + 6);
  Square(s[1][3], sq[1] + 6);
  SumHorizontal32(sq[0] + 4, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
  StoreAligned32U32(square_sum3[2] + x + 8, sq3[2]);
  StoreAligned32U32(square_sum5[3] + x + 8, sq5[3]);
  SumHorizontal32(sq[1] + 4, &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
  StoreAligned32U32(square_sum3[3] + x + 8, sq3[3]);
  StoreAligned32U32(square_sum5[4] + x + 8, sq5[4]);
  LoadAligned16x2U16Msan(sum3, x + 8, sum_width, s3[1]);
  LoadAligned32x2U32Msan(square_sum3, x + 8, sum_width, sq3);
  CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum[0][1], &index[0][1]);
  CalculateSumAndIndex3(s3[1] + 1, sq3 + 1, scales[1], &sum[1][1],
                        &index[1][1]);
  CalculateIntermediate(sum[0], index[0], ma3[0], b3[0] + 2);
  CalculateIntermediate(sum[1], index[1], ma3[1], b3[1] + 2);
  LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]);
  LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5);
  CalculateIntermediate5<0>(s5[1], sq5, scales[0], &ma5[1], b5 + 4);
}

LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRowLo(
    const uint16x8_t s[2], const uint16_t scales[2],
    const uint16_t* const sum3[4], const uint16_t* const sum5[5],
    const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5],
    uint32x4_t sq[4], uint8x16_t* const ma3, uint8x16_t* const ma5,
    uint32x4_t b3[2], uint32x4_t b5[2]) {
  uint16x8_t s3[3], s5[5];
  uint32x4_t sq3[3][2], sq5[5][2];

  Square(s[1], sq + 2);
  SumHorizontal16(s, &s3[2], &s5[3]);
  SumHorizontal32(sq, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
  LoadAligned16x3U16(sum5, 0, s5);
  s5[4] = s5[3];
  LoadAligned32x3U32(square_sum5, 0, sq5);
  sq5[4][0] = sq5[3][0];
  sq5[4][1] = sq5[3][1];
  CalculateIntermediate5<0>(s5, sq5, scales[0], ma5, b5);
  LoadAligned16x2U16(sum3, 0, s3);
  LoadAligned32x2U32(square_sum3, 0, sq3);
  CalculateIntermediate3(s3, sq3, scales[1], ma3, b3);
}

LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow(
    const uint16x8_t s[4], const ptrdiff_t sum_width, const ptrdiff_t x,
    const uint16_t scales[2], const uint16_t* const sum3[4],
    const uint16_t* const sum5[5], const uint32_t* const square_sum3[4],
    const uint32_t* const square_sum5[5], uint32x4_t sq[8], uint8x16_t ma3[2],
    uint8x16_t ma5[2], uint32x4_t b3[6], uint32x4_t b5[6]) {
  uint16x8_t s3[2][3], s5[2][5], sum[2], index[2];
  uint32x4_t sq3[3][2], sq5[5][2];

  Square(s[2], sq + 4);
  SumHorizontal16(s + 1, &s3[0][2], &s3[1][2], &s5[0][3], &s5[1][3]);
  SumHorizontal32(sq + 2, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
  LoadAligned16x3U16(sum5, x, s5[0]);
  s5[0][4] = s5[0][3];
  LoadAligned32x3U32(square_sum5, x, sq5);
  sq5[4][0] = sq5[3][0];
  sq5[4][1] = sq5[3][1];
  CalculateIntermediate5<8>(s5[0], sq5, scales[0], ma5, b5 + 2);
  LoadAligned16x2U16(sum3, x, s3[0]);
  LoadAligned32x2U32(square_sum3, x, sq3);
  CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum[0], &index[0]);

  Square(s[3], sq + 6);
  SumHorizontal32(sq + 4, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
  LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]);
  s5[1][4] = s5[1][3];
  LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5);
  sq5[4][0] = sq5[3][0];
  sq5[4][1] = sq5[3][1];
  CalculateIntermediate5<0>(s5[1], sq5, scales[0], ma5 + 1, b5 + 4);
  LoadAligned16x2U16Msan(sum3, x + 8, sum_width, s3[1]);
  LoadAligned32x2U32Msan(square_sum3, x + 8, sum_width, sq3);
  CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum[1], &index[1]);
  CalculateIntermediate(sum, index, ma3, b3 + 2);
}

inline void BoxSumFilterPreProcess5(const uint16_t* const src0,
                                    const uint16_t* const src1, const int width,
                                    const uint32_t scale,
                                    uint16_t* const sum5[5],
                                    uint32_t* const square_sum5[5],
                                    const ptrdiff_t sum_width, uint16_t* ma565,
                                    uint32_t* b565) {
  const ptrdiff_t overread_in_bytes =
      kOverreadInBytesPass1 - sizeof(*src0) * width;
  uint16x8_t s[2][4];
  uint8x16_t mas[2];
  uint32x4_t sq[2][8], bs[6];

  s[0][0] = Load1QMsanU16(src0 + 0, overread_in_bytes + 0);
  s[0][1] = Load1QMsanU16(src0 + 8, overread_in_bytes + 16);
  s[1][0] = Load1QMsanU16(src1 + 0, overread_in_bytes + 0);
  s[1][1] = Load1QMsanU16(src1 + 8, overread_in_bytes + 16);
  Square(s[0][0], sq[0]);
  Square(s[1][0], sq[1]);
  BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq, &mas[0], bs);

  int x = 0;
  do {
    uint8x16_t ma5[3];
    uint16x8_t ma[2];
    uint32x4_t b[4];

    s[0][2] = Load1QMsanU16(src0 + x + 16,
                            overread_in_bytes + sizeof(*src0) * (x + 16));
    s[0][3] = Load1QMsanU16(src0 + x + 24,
                            overread_in_bytes + sizeof(*src0) * (x + 24));
    s[1][2] = Load1QMsanU16(src1 + x + 16,
                            overread_in_bytes + sizeof(*src1) * (x + 16));
    s[1][3] = Load1QMsanU16(src1 + x + 24,
                            overread_in_bytes + sizeof(*src1) * (x + 24));

    BoxFilterPreProcess5(s, sum_width, x + 8, scale, sum5, square_sum5, sq, mas,
                         bs);
    Prepare3_8<0>(mas, ma5);
    ma[0] = Sum565Lo(ma5);
    ma[1] = Sum565Hi(ma5);
    StoreAligned32U16(ma565, ma);
    Sum565(bs + 0, b + 0);
    Sum565(bs + 2, b + 2);
    StoreAligned64U32(b565, b);
    s[0][0] = s[0][2];
    s[0][1] = s[0][3];
    s[1][0] = s[1][2];
    s[1][1] = s[1][3];
    sq[0][2] = sq[0][6];
    sq[0][3] = sq[0][7];
    sq[1][2] = sq[1][6];
    sq[1][3] = sq[1][7];
    mas[0] = mas[1];
    bs[0] = bs[4];
    bs[1] = bs[5];
    ma565 += 16;
    b565 += 16;
    x += 16;
  } while (x < width);
}

template <bool calculate444>
LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3(
    const uint16_t* const src, const int width, const uint32_t scale,
    uint16_t* const sum3[3], uint32_t* const square_sum3[3],
    const ptrdiff_t sum_width, uint16_t* ma343, uint16_t* ma444, uint32_t* b343,
    uint32_t* b444) {
  const ptrdiff_t overread_in_bytes =
      kOverreadInBytesPass2 - sizeof(*src) * width;
  uint16x8_t s[4];
  uint8x16_t mas[2];
  uint32x4_t sq[8], bs[6];

  s[0] = Load1QMsanU16(src + 0, overread_in_bytes + 0);
  s[1] = Load1QMsanU16(src + 8, overread_in_bytes + 16);
  Square(s[0], sq);
  BoxFilterPreProcess3Lo(s, scale, sum3, square_sum3, sq, &mas[0], bs);

  int x = 0;
  do {
    s[2] = Load1QMsanU16(src + x + 16,
                         overread_in_bytes + sizeof(*src) * (x + 16));
    s[3] = Load1QMsanU16(src + x + 24,
                         overread_in_bytes + sizeof(*src) * (x + 24));
    BoxFilterPreProcess3(s, x + 8, sum_width, scale, sum3, square_sum3, sq, mas,
                         bs);
    uint8x16_t ma3[3];
    Prepare3_8<0>(mas, ma3);
    if (calculate444) {  // NOLINT(readability-simplify-boolean-expr)
      Store343_444Lo(ma3, bs + 0, 0, ma343, ma444, b343, b444);
      Store343_444Hi(ma3, bs + 2, 8, ma343, ma444, b343, b444);
      ma444 += 16;
      b444 += 16;
    } else {
      uint16x8_t ma[2];
      uint32x4_t b[4];
      ma[0] = Sum343Lo(ma3);
      ma[1] = Sum343Hi(ma3);
      StoreAligned32U16(ma343, ma);
      Sum343(bs + 0, b + 0);
      Sum343(bs + 2, b + 2);
      StoreAligned64U32(b343, b);
    }
    s[1] = s[3];
    sq[2] = sq[6];
    sq[3] = sq[7];
    mas[0] = mas[1];
    bs[0] = bs[4];
    bs[1] = bs[5];
    ma343 += 16;
    b343 += 16;
    x += 16;
  } while (x < width);
}

inline void BoxSumFilterPreProcess(
    const uint16_t* const src0, const uint16_t* const src1, const int width,
    const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
    const ptrdiff_t sum_width, uint16_t* const ma343[4], uint16_t* const ma444,
    uint16_t* ma565, uint32_t* const b343[4], uint32_t* const b444,
    uint32_t* b565) {
  const ptrdiff_t overread_in_bytes =
      kOverreadInBytesPass1 - sizeof(*src0) * width;
  uint16x8_t s[2][4];
  uint8x16_t ma3[2][2], ma5[2];
  uint32x4_t sq[2][8], b3[2][6], b5[6];

  s[0][0] = Load1QMsanU16(src0 + 0, overread_in_bytes + 0);
  s[0][1] = Load1QMsanU16(src0 + 8, overread_in_bytes + 16);
  s[1][0] = Load1QMsanU16(src1 + 0, overread_in_bytes + 0);
  s[1][1] = Load1QMsanU16(src1 + 8, overread_in_bytes + 16);
  Square(s[0][0], sq[0]);
  Square(s[1][0], sq[1]);
  BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq,
                        ma3, b3, &ma5[0], b5);

  int x = 0;
  do {
    uint16x8_t ma[2];
    uint32x4_t b[4];
    uint8x16_t ma3x[3], ma5x[3];

    s[0][2] = Load1QMsanU16(src0 + x + 16,
                            overread_in_bytes + sizeof(*src0) * (x + 16));
    s[0][3] = Load1QMsanU16(src0 + x + 24,
                            overread_in_bytes + sizeof(*src0) * (x + 24));
    s[1][2] = Load1QMsanU16(src1 + x + 16,
                            overread_in_bytes + sizeof(*src1) * (x + 16));
    s[1][3] = Load1QMsanU16(src1 + x + 24,
                            overread_in_bytes + sizeof(*src1) * (x + 24));
    BoxFilterPreProcess(s, x + 8, scales, sum3, sum5, square_sum3, square_sum5,
                        sum_width, sq, ma3, b3, ma5, b5);

    Prepare3_8<0>(ma3[0], ma3x);
    ma[0] = Sum343Lo(ma3x);
    ma[1] = Sum343Hi(ma3x);
    StoreAligned32U16(ma343[0] + x, ma);
    Sum343(b3[0] + 0, b + 0);
    Sum343(b3[0] + 2, b + 2);
    StoreAligned64U32(b343[0] + x, b);
    Sum565(b5 + 0, b + 0);
    Sum565(b5 + 2, b + 2);
    StoreAligned64U32(b565, b);
    Prepare3_8<0>(ma3[1], ma3x);
    Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444, b343[1], b444);
    Store343_444Hi(ma3x, b3[1] + 2, x + 8, ma343[1], ma444, b343[1], b444);
    Prepare3_8<0>(ma5, ma5x);
    ma[0] = Sum565Lo(ma5x);
    ma[1] = Sum565Hi(ma5x);
    StoreAligned32U16(ma565, ma);
    s[0][0] = s[0][2];
    s[0][1] = s[0][3];
    s[1][0] = s[1][2];
    s[1][1] = s[1][3];
    sq[0][2] = sq[0][6];
    sq[0][3] = sq[0][7];
    sq[1][2] = sq[1][6];
    sq[1][3] = sq[1][7];
    ma3[0][0] = ma3[0][1];
    ma3[1][0] = ma3[1][1];
    ma5[0] = ma5[1];
    b3[0][0] = b3[0][4];
    b3[0][1] = b3[0][5];
    b3[1][0] = b3[1][4];
    b3[1][1] = b3[1][5];
    b5[0] = b5[4];
    b5[1] = b5[5];
    ma565 += 16;
    b565 += 16;
    x += 16;
  } while (x < width);
}

template <int shift>
inline int16x4_t FilterOutput(const uint32x4_t ma_x_src, const uint32x4_t b) {
  // ma: 255 * 32 = 8160 (13 bits)
  // b: 65088 * 32 = 2082816 (21 bits)
  // v: b - ma * 255 (22 bits)
  const int32x4_t v = vreinterpretq_s32_u32(vsubq_u32(b, ma_x_src));
  // kSgrProjSgrBits = 8
  // kSgrProjRestoreBits = 4
  // shift = 4 or 5
  // v >> 8 or 9 (13 bits)
  return vqrshrn_n_s32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
}

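// Editor's sketch (illustrative only) of FilterOutput() in scalar form: the
// filtered value is b - ma * src, scaled back down with rounding. With
// shift = 4 the net right shift is 8 + 4 - 4 = 8; with shift = 5 it is 9.
// The saturation performed by vqrshrn_n_s32 is omitted here.
inline int16_t FilterOutputScalar(const uint32_t ma_x_src, const uint32_t b,
                                  const int shift) {
  const auto v = static_cast<int32_t>(b - ma_x_src);
  const int n = kSgrProjSgrBits + shift - kSgrProjRestoreBits;
  return static_cast<int16_t>((v + (1 << (n - 1))) >> n);
}
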
template <int shift>
inline int16x8_t CalculateFilteredOutput(const uint16x8_t src,
                                         const uint16x8_t ma,
                                         const uint32x4_t b[2]) {
  const uint32x4_t ma_x_src_lo = VmullLo16(ma, src);
  const uint32x4_t ma_x_src_hi = VmullHi16(ma, src);
  const int16x4_t dst_lo = FilterOutput<shift>(ma_x_src_lo, b[0]);
  const int16x4_t dst_hi = FilterOutput<shift>(ma_x_src_hi, b[1]);
  return vcombine_s16(dst_lo, dst_hi);  // 13 bits
}

inline int16x8_t CalculateFilteredOutputPass1(const uint16x8_t src,
                                              const uint16x8_t ma[2],
                                              const uint32x4_t b[2][2]) {
  const uint16x8_t ma_sum = vaddq_u16(ma[0], ma[1]);
  uint32x4_t b_sum[2];
  b_sum[0] = vaddq_u32(b[0][0], b[1][0]);
  b_sum[1] = vaddq_u32(b[0][1], b[1][1]);
  return CalculateFilteredOutput<5>(src, ma_sum, b_sum);
}

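// Editor's note on the shifts: CalculateFilteredOutputPass1() above feeds
// the sum of two 565 window rows into CalculateFilteredOutput<5>(), one more
// doubling than the single-window CalculateFilteredOutput<4>() path used for
// the in-between rows of pass 1, and CalculateFilteredOutputPass2() below
// likewise combines three rows (343 + 444 + 343) before using shift 5.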
inline int16x8_t CalculateFilteredOutputPass2(const uint16x8_t src,
                                              const uint16x8_t ma[3],
                                              const uint32x4_t b[3][2]) {
  const uint16x8_t ma_sum = Sum3_16(ma);
  uint32x4_t b_sum[2];
  Sum3_32(b, b_sum);
  return CalculateFilteredOutput<5>(src, ma_sum, b_sum);
}

inline int16x8_t SelfGuidedFinal(const uint16x8_t src, const int32x4_t v[2]) {
  const int16x4_t v_lo =
      vqrshrn_n_s32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits);
  const int16x4_t v_hi =
      vqrshrn_n_s32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits);
  const int16x8_t vv = vcombine_s16(v_lo, v_hi);
  return vaddq_s16(vreinterpretq_s16_u16(src), vv);
}

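// Editor's sketch (illustrative only) of SelfGuidedFinal() in scalar form,
// ignoring the saturation in vqrshrn_n_s32: v is the weighted filter output,
// and the result is the source pixel plus the rounded, rescaled correction.
inline int16_t SelfGuidedFinalScalar(const uint16_t src, const int32_t v) {
  const int shift = kSgrProjRestoreBits + kSgrProjPrecisionBits;
  return static_cast<int16_t>(src + ((v + (1 << (shift - 1))) >> shift));
}
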
inline int16x8_t SelfGuidedDoubleMultiplier(const uint16x8_t src,
                                            const int16x8_t filter[2],
                                            const int w0, const int w2) {
  int32x4_t v[2];
  v[0] = vmull_n_s16(vget_low_s16(filter[0]), w0);
  v[1] = vmull_n_s16(vget_high_s16(filter[0]), w0);
  v[0] = vmlal_n_s16(v[0], vget_low_s16(filter[1]), w2);
  v[1] = vmlal_n_s16(v[1], vget_high_s16(filter[1]), w2);
  return SelfGuidedFinal(src, v);
}

inline int16x8_t SelfGuidedSingleMultiplier(const uint16x8_t src,
                                            const int16x8_t filter,
                                            const int w0) {
  // weight: -96 to 96 (Sgrproj_Xqd_Min/Max)
  int32x4_t v[2];
  v[0] = vmull_n_s16(vget_low_s16(filter), w0);
  v[1] = vmull_n_s16(vget_high_s16(filter), w0);
  return SelfGuidedFinal(src, v);
}

inline void ClipAndStore(uint16_t* const dst, const int16x8_t val) {
  const uint16x8_t val0 = vreinterpretq_u16_s16(vmaxq_s16(val, vdupq_n_s16(0)));
  const uint16x8_t val1 = vminq_u16(val0, vdupq_n_u16((1 << kBitdepth10) - 1));
  vst1q_u16(dst, val1);
}

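// Worked example for ClipAndStore(): the clamp range is
// [0, (1 << kBitdepth10) - 1] = [0, 1023], so a filtered value of -3 stores
// as 0 and 1050 stores as 1023.

// Editor's note: BoxFilterPass1() filters two output rows per call. |src0|
// and |src1| supply the two new rows entering the running 5x5 box sums,
// while |src| and |src + stride| are the rows being filtered.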
LIBGAV1_ALWAYS_INLINE void BoxFilterPass1(
    const uint16_t* const src, const uint16_t* const src0,
    const uint16_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5],
    uint32_t* const square_sum5[5], const int width, const ptrdiff_t sum_width,
    const uint32_t scale, const int16_t w0, uint16_t* const ma565[2],
    uint32_t* const b565[2], uint16_t* const dst) {
  const ptrdiff_t overread_in_bytes =
      kOverreadInBytesPass1 - sizeof(*src0) * width;
  uint16x8_t s[2][4];
  uint8x16_t mas[2];
  uint32x4_t sq[2][8], bs[6];

  s[0][0] = Load1QMsanU16(src0 + 0, overread_in_bytes + 0);
  s[0][1] = Load1QMsanU16(src0 + 8, overread_in_bytes + 16);
  s[1][0] = Load1QMsanU16(src1 + 0, overread_in_bytes + 0);
  s[1][1] = Load1QMsanU16(src1 + 8, overread_in_bytes + 16);

  Square(s[0][0], sq[0]);
  Square(s[1][0], sq[1]);
  BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq, &mas[0], bs);

  int x = 0;
  do {
    uint16x8_t ma[2];
    uint32x4_t b[2][2];
    uint8x16_t ma5[3];
    int16x8_t p[2];

    s[0][2] = Load1QMsanU16(src0 + x + 16,
                            overread_in_bytes + sizeof(*src0) * (x + 16));
    s[0][3] = Load1QMsanU16(src0 + x + 24,
                            overread_in_bytes + sizeof(*src0) * (x + 24));
    s[1][2] = Load1QMsanU16(src1 + x + 16,
                            overread_in_bytes + sizeof(*src1) * (x + 16));
    s[1][3] = Load1QMsanU16(src1 + x + 24,
                            overread_in_bytes + sizeof(*src1) * (x + 24));
    BoxFilterPreProcess5(s, sum_width, x + 8, scale, sum5, square_sum5, sq, mas,
                         bs);
    Prepare3_8<0>(mas, ma5);
    ma[1] = Sum565Lo(ma5);
    vst1q_u16(ma565[1] + x, ma[1]);
    Sum565(bs, b[1]);
    StoreAligned32U32(b565[1] + x, b[1]);
    const uint16x8_t sr0_lo = vld1q_u16(src + x + 0);
    const uint16x8_t sr1_lo = vld1q_u16(src + stride + x + 0);
    ma[0] = vld1q_u16(ma565[0] + x);
    LoadAligned32U32(b565[0] + x, b[0]);
    p[0] = CalculateFilteredOutputPass1(sr0_lo, ma, b);
    p[1] = CalculateFilteredOutput<4>(sr1_lo, ma[1], b[1]);
    const int16x8_t d00 = SelfGuidedSingleMultiplier(sr0_lo, p[0], w0);
    const int16x8_t d10 = SelfGuidedSingleMultiplier(sr1_lo, p[1], w0);

    ma[1] = Sum565Hi(ma5);
    vst1q_u16(ma565[1] + x + 8, ma[1]);
    Sum565(bs + 2, b[1]);
    StoreAligned32U32(b565[1] + x + 8, b[1]);
    const uint16x8_t sr0_hi = vld1q_u16(src + x + 8);
    const uint16x8_t sr1_hi = vld1q_u16(src + stride + x + 8);
    ma[0] = vld1q_u16(ma565[0] + x + 8);
    LoadAligned32U32(b565[0] + x + 8, b[0]);
    p[0] = CalculateFilteredOutputPass1(sr0_hi, ma, b);
    p[1] = CalculateFilteredOutput<4>(sr1_hi, ma[1], b[1]);
    const int16x8_t d01 = SelfGuidedSingleMultiplier(sr0_hi, p[0], w0);
    ClipAndStore(dst + x + 0, d00);
    ClipAndStore(dst + x + 8, d01);
    const int16x8_t d11 = SelfGuidedSingleMultiplier(sr1_hi, p[1], w0);
    ClipAndStore(dst + stride + x + 0, d10);
    ClipAndStore(dst + stride + x + 8, d11);
    s[0][0] = s[0][2];
    s[0][1] = s[0][3];
    s[1][0] = s[1][2];
    s[1][1] = s[1][3];
    sq[0][2] = sq[0][6];
    sq[0][3] = sq[0][7];
    sq[1][2] = sq[1][6];
    sq[1][3] = sq[1][7];
    mas[0] = mas[1];
    bs[0] = bs[4];
    bs[1] = bs[5];
    x += 16;
  } while (x < width);
}

inline void BoxFilterPass1LastRow(
    const uint16_t* const src, const uint16_t* const src0, const int width,
    const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0,
    uint16_t* const sum5[5], uint32_t* const square_sum5[5], uint16_t* ma565,
    uint32_t* b565, uint16_t* const dst) {
  const ptrdiff_t overread_in_bytes =
      kOverreadInBytesPass1 - sizeof(*src0) * width;
  uint16x8_t s[4];
  uint8x16_t mas[2];
  uint32x4_t sq[8], bs[6];

  s[0] = Load1QMsanU16(src0 + 0, overread_in_bytes + 0);
  s[1] = Load1QMsanU16(src0 + 8, overread_in_bytes + 16);
  Square(s[0], sq);
  BoxFilterPreProcess5LastRowLo(s, scale, sum5, square_sum5, sq, &mas[0], bs);

  int x = 0;
  do {
    uint16x8_t ma[2];
    uint32x4_t b[2][2];
    uint8x16_t ma5[3];

    s[2] = Load1QMsanU16(src0 + x + 16,
                         overread_in_bytes + sizeof(*src0) * (x + 16));
    s[3] = Load1QMsanU16(src0 + x + 24,
                         overread_in_bytes + sizeof(*src0) * (x + 24));
    BoxFilterPreProcess5LastRow(s, sum_width, x + 8, scale, sum5, square_sum5,
                                sq, mas, bs);
    Prepare3_8<0>(mas, ma5);
    ma[1] = Sum565Lo(ma5);
    Sum565(bs, b[1]);
    ma[0] = vld1q_u16(ma565);
    LoadAligned32U32(b565, b[0]);
    const uint16x8_t sr_lo = vld1q_u16(src + x + 0);
    int16x8_t p = CalculateFilteredOutputPass1(sr_lo, ma, b);
    const int16x8_t d0 = SelfGuidedSingleMultiplier(sr_lo, p, w0);

    ma[1] = Sum565Hi(ma5);
    Sum565(bs + 2, b[1]);
    ma[0] = vld1q_u16(ma565 + 8);
    LoadAligned32U32(b565 + 8, b[0]);
    const uint16x8_t sr_hi = vld1q_u16(src + x + 8);
    p = CalculateFilteredOutputPass1(sr_hi, ma, b);
    const int16x8_t d1 = SelfGuidedSingleMultiplier(sr_hi, p, w0);
    ClipAndStore(dst + x + 0, d0);
    ClipAndStore(dst + x + 8, d1);
    s[1] = s[3];
    sq[2] = sq[6];
    sq[3] = sq[7];
    mas[0] = mas[1];
    bs[0] = bs[4];
    bs[1] = bs[5];
    ma565 += 16;
    b565 += 16;
    x += 16;
  } while (x < width);
}

LIBGAV1_ALWAYS_INLINE void BoxFilterPass2(
    const uint16_t* const src, const uint16_t* const src0, const int width,
    const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0,
    uint16_t* const sum3[3], uint32_t* const square_sum3[3],
    uint16_t* const ma343[3], uint16_t* const ma444[2], uint32_t* const b343[3],
    uint32_t* const b444[2], uint16_t* const dst) {
  const ptrdiff_t overread_in_bytes =
      kOverreadInBytesPass2 - sizeof(*src0) * width;
  uint16x8_t s[4];
  uint8x16_t mas[2];
  uint32x4_t sq[8], bs[6];

  s[0] = Load1QMsanU16(src0 + 0, overread_in_bytes + 0);
  s[1] = Load1QMsanU16(src0 + 8, overread_in_bytes + 16);
  Square(s[0], sq);
  BoxFilterPreProcess3Lo(s, scale, sum3, square_sum3, sq, &mas[0], bs);

  int x = 0;
  do {
    s[2] = Load1QMsanU16(src0 + x + 16,
                         overread_in_bytes + sizeof(*src0) * (x + 16));
    s[3] = Load1QMsanU16(src0 + x + 24,
                         overread_in_bytes + sizeof(*src0) * (x + 24));
    BoxFilterPreProcess3(s, x + 8, sum_width, scale, sum3, square_sum3, sq, mas,
                         bs);
    uint16x8_t ma[3];
    uint32x4_t b[3][2];
    uint8x16_t ma3[3];

    Prepare3_8<0>(mas, ma3);
    Store343_444Lo(ma3, bs + 0, x, &ma[2], b[2], ma343[2], ma444[1], b343[2],
                   b444[1]);
    const uint16x8_t sr_lo = vld1q_u16(src + x + 0);
    ma[0] = vld1q_u16(ma343[0] + x);
    ma[1] = vld1q_u16(ma444[0] + x);
    LoadAligned32U32(b343[0] + x, b[0]);
    LoadAligned32U32(b444[0] + x, b[1]);
    const int16x8_t p0 = CalculateFilteredOutputPass2(sr_lo, ma, b);

    Store343_444Hi(ma3, bs + 2, x + 8, &ma[2], b[2], ma343[2], ma444[1],
                   b343[2], b444[1]);
    const uint16x8_t sr_hi = vld1q_u16(src + x + 8);
    ma[0] = vld1q_u16(ma343[0] + x + 8);
    ma[1] = vld1q_u16(ma444[0] + x + 8);
    LoadAligned32U32(b343[0] + x + 8, b[0]);
    LoadAligned32U32(b444[0] + x + 8, b[1]);
    const int16x8_t p1 = CalculateFilteredOutputPass2(sr_hi, ma, b);
    const int16x8_t d0 = SelfGuidedSingleMultiplier(sr_lo, p0, w0);
    const int16x8_t d1 = SelfGuidedSingleMultiplier(sr_hi, p1, w0);
    ClipAndStore(dst + x + 0, d0);
    ClipAndStore(dst + x + 8, d1);
    s[1] = s[3];
    sq[2] = sq[6];
    sq[3] = sq[7];
    mas[0] = mas[1];
    bs[0] = bs[4];
    bs[1] = bs[5];
    x += 16;
  } while (x < width);
}

BoxFilter(const uint16_t * const src,const uint16_t * const src0,const uint16_t * const src1,const ptrdiff_t stride,const int width,const uint16_t scales[2],const int16_t w0,const int16_t w2,uint16_t * const sum3[4],uint16_t * const sum5[5],uint32_t * const square_sum3[4],uint32_t * const square_sum5[5],const ptrdiff_t sum_width,uint16_t * const ma343[4],uint16_t * const ma444[3],uint16_t * const ma565[2],uint32_t * const b343[4],uint32_t * const b444[3],uint32_t * const b565[2],uint16_t * const dst)2120 LIBGAV1_ALWAYS_INLINE void BoxFilter(
2121     const uint16_t* const src, const uint16_t* const src0,
2122     const uint16_t* const src1, const ptrdiff_t stride, const int width,
2123     const uint16_t scales[2], const int16_t w0, const int16_t w2,
2124     uint16_t* const sum3[4], uint16_t* const sum5[5],
2125     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
2126     const ptrdiff_t sum_width, uint16_t* const ma343[4],
2127     uint16_t* const ma444[3], uint16_t* const ma565[2], uint32_t* const b343[4],
2128     uint32_t* const b444[3], uint32_t* const b565[2], uint16_t* const dst) {
2129   const ptrdiff_t overread_in_bytes =
2130       kOverreadInBytesPass1 - sizeof(*src0) * width;
2131   uint16x8_t s[2][4];
2132   uint8x16_t ma3[2][2], ma5[2];
2133   uint32x4_t sq[2][8], b3[2][6], b5[6];
2134 
2135   s[0][0] = Load1QMsanU16(src0 + 0, overread_in_bytes + 0);
2136   s[0][1] = Load1QMsanU16(src0 + 8, overread_in_bytes + 16);
2137   s[1][0] = Load1QMsanU16(src1 + 0, overread_in_bytes + 0);
2138   s[1][1] = Load1QMsanU16(src1 + 8, overread_in_bytes + 16);
2139   Square(s[0][0], sq[0]);
2140   Square(s[1][0], sq[1]);
2141   BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq,
2142                         ma3, b3, &ma5[0], b5);
2143 
2144   int x = 0;
2145   do {
2146     uint16x8_t ma[3][3];
2147     uint32x4_t b[3][3][2];
2148     uint8x16_t ma3x[2][3], ma5x[3];
2149     int16x8_t p[2][2];
2150 
2151     s[0][2] = Load1QMsanU16(src0 + x + 16,
2152                             overread_in_bytes + sizeof(*src0) * (x + 16));
2153     s[0][3] = Load1QMsanU16(src0 + x + 24,
2154                             overread_in_bytes + sizeof(*src0) * (x + 24));
2155     s[1][2] = Load1QMsanU16(src1 + x + 16,
2156                             overread_in_bytes + sizeof(*src1) * (x + 16));
2157     s[1][3] = Load1QMsanU16(src1 + x + 24,
2158                             overread_in_bytes + sizeof(*src1) * (x + 24));
2159 
2160     BoxFilterPreProcess(s, x + 8, scales, sum3, sum5, square_sum3, square_sum5,
2161                         sum_width, sq, ma3, b3, ma5, b5);
2162     Prepare3_8<0>(ma3[0], ma3x[0]);
2163     Prepare3_8<0>(ma3[1], ma3x[1]);
2164     Prepare3_8<0>(ma5, ma5x);
2165     Store343_444Lo(ma3x[0], b3[0], x, &ma[1][2], &ma[2][1], b[1][2], b[2][1],
2166                    ma343[2], ma444[1], b343[2], b444[1]);
2167     Store343_444Lo(ma3x[1], b3[1], x, &ma[2][2], b[2][2], ma343[3], ma444[2],
2168                    b343[3], b444[2]);
2169     ma[0][1] = Sum565Lo(ma5x);
2170     vst1q_u16(ma565[1] + x, ma[0][1]);
2171     Sum565(b5, b[0][1]);
2172     StoreAligned32U32(b565[1] + x, b[0][1]);
2173     const uint16x8_t sr0_lo = vld1q_u16(src + x);
2174     const uint16x8_t sr1_lo = vld1q_u16(src + stride + x);
2175     ma[0][0] = vld1q_u16(ma565[0] + x);
2176     LoadAligned32U32(b565[0] + x, b[0][0]);
2177     p[0][0] = CalculateFilteredOutputPass1(sr0_lo, ma[0], b[0]);
2178     p[1][0] = CalculateFilteredOutput<4>(sr1_lo, ma[0][1], b[0][1]);
2179     ma[1][0] = vld1q_u16(ma343[0] + x);
2180     ma[1][1] = vld1q_u16(ma444[0] + x);
2181     LoadAligned32U32(b343[0] + x, b[1][0]);
2182     LoadAligned32U32(b444[0] + x, b[1][1]);
2183     p[0][1] = CalculateFilteredOutputPass2(sr0_lo, ma[1], b[1]);
2184     const int16x8_t d00 = SelfGuidedDoubleMultiplier(sr0_lo, p[0], w0, w2);
2185     ma[2][0] = vld1q_u16(ma343[1] + x);
2186     LoadAligned32U32(b343[1] + x, b[2][0]);
2187     p[1][1] = CalculateFilteredOutputPass2(sr1_lo, ma[2], b[2]);
2188     const int16x8_t d10 = SelfGuidedDoubleMultiplier(sr1_lo, p[1], w0, w2);
2189 
2190     Store343_444Hi(ma3x[0], b3[0] + 2, x + 8, &ma[1][2], &ma[2][1], b[1][2],
2191                    b[2][1], ma343[2], ma444[1], b343[2], b444[1]);
2192     Store343_444Hi(ma3x[1], b3[1] + 2, x + 8, &ma[2][2], b[2][2], ma343[3],
2193                    ma444[2], b343[3], b444[2]);
2194     ma[0][1] = Sum565Hi(ma5x);
2195     vst1q_u16(ma565[1] + x + 8, ma[0][1]);
2196     Sum565(b5 + 2, b[0][1]);
2197     StoreAligned32U32(b565[1] + x + 8, b[0][1]);
2198     const uint16x8_t sr0_hi = Load1QMsanU16(
2199         src + x + 8, overread_in_bytes + 4 + sizeof(*src) * (x + 8));
2200     const uint16x8_t sr1_hi = Load1QMsanU16(
2201         src + stride + x + 8, overread_in_bytes + 4 + sizeof(*src) * (x + 8));
2202     ma[0][0] = vld1q_u16(ma565[0] + x + 8);
2203     LoadAligned32U32(b565[0] + x + 8, b[0][0]);
2204     p[0][0] = CalculateFilteredOutputPass1(sr0_hi, ma[0], b[0]);
2205     p[1][0] = CalculateFilteredOutput<4>(sr1_hi, ma[0][1], b[0][1]);
2206     ma[1][0] = vld1q_u16(ma343[0] + x + 8);
2207     ma[1][1] = vld1q_u16(ma444[0] + x + 8);
2208     LoadAligned32U32(b343[0] + x + 8, b[1][0]);
2209     LoadAligned32U32(b444[0] + x + 8, b[1][1]);
2210     p[0][1] = CalculateFilteredOutputPass2(sr0_hi, ma[1], b[1]);
2211     const int16x8_t d01 = SelfGuidedDoubleMultiplier(sr0_hi, p[0], w0, w2);
2212     ClipAndStore(dst + x + 0, d00);
2213     ClipAndStore(dst + x + 8, d01);
2214     ma[2][0] = vld1q_u16(ma343[1] + x + 8);
2215     LoadAligned32U32(b343[1] + x + 8, b[2][0]);
2216     p[1][1] = CalculateFilteredOutputPass2(sr1_hi, ma[2], b[2]);
2217     const int16x8_t d11 = SelfGuidedDoubleMultiplier(sr1_hi, p[1], w0, w2);
2218     ClipAndStore(dst + stride + x + 0, d10);
2219     ClipAndStore(dst + stride + x + 8, d11);
    s[0][0] = s[0][2];
    s[0][1] = s[0][3];
    s[1][0] = s[1][2];
    s[1][1] = s[1][3];
    sq[0][2] = sq[0][6];
    sq[0][3] = sq[0][7];
    sq[1][2] = sq[1][6];
    sq[1][3] = sq[1][7];
    ma3[0][0] = ma3[0][1];
    ma3[1][0] = ma3[1][1];
    ma5[0] = ma5[1];
    b3[0][0] = b3[0][4];
    b3[0][1] = b3[0][5];
    b3[1][0] = b3[1][4];
    b3[1][1] = b3[1][5];
    b5[0] = b5[4];
    b5[1] = b5[5];
    x += 16;
  } while (x < width);
}

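// Filters the final row when |height| is odd (see the call sites in
// BoxFilterProcess() below). |src0| points at a bottom border row, which
// supplies the last input row for the box sums.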
inline void BoxFilterLastRow(
    const uint16_t* const src, const uint16_t* const src0, const int width,
    const ptrdiff_t sum_width, const uint16_t scales[2], const int16_t w0,
    const int16_t w2, uint16_t* const sum3[4], uint16_t* const sum5[5],
    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
    uint16_t* const ma343, uint16_t* const ma444, uint16_t* const ma565,
    uint32_t* const b343, uint32_t* const b444, uint32_t* const b565,
    uint16_t* const dst) {
  const ptrdiff_t overread_in_bytes =
      kOverreadInBytesPass1 - sizeof(*src0) * width;
  uint16x8_t s[4];
  uint8x16_t ma3[2], ma5[2];
  uint32x4_t sq[8], b3[6], b5[6];
  uint16x8_t ma[3];
  uint32x4_t b[3][2];

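  // Prologue: load the first 16 pixels of |src0| and compute the lower-half
  // box sums before entering the main loop.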
  s[0] = Load1QMsanU16(src0 + 0, overread_in_bytes + 0);
  s[1] = Load1QMsanU16(src0 + 8, overread_in_bytes + 16);
  Square(s[0], sq);
  BoxFilterPreProcessLastRowLo(s, scales, sum3, sum5, square_sum3, square_sum5,
                               sq, &ma3[0], &ma5[0], b3, b5);

  int x = 0;
  do {
    uint8x16_t ma3x[3], ma5x[3];
    int16x8_t p[2];

    s[2] = Load1QMsanU16(src0 + x + 16,
                         overread_in_bytes + sizeof(*src0) * (x + 16));
    s[3] = Load1QMsanU16(src0 + x + 24,
                         overread_in_bytes + sizeof(*src0) * (x + 24));
    BoxFilterPreProcessLastRow(s, sum_width, x + 8, scales, sum3, sum5,
                               square_sum3, square_sum5, sq, ma3, ma5, b3, b5);
    Prepare3_8<0>(ma3, ma3x);
    Prepare3_8<0>(ma5, ma5x);
    ma[1] = Sum565Lo(ma5x);
    Sum565(b5, b[1]);
    ma[2] = Sum343Lo(ma3x);
    Sum343(b3, b[2]);
    const uint16x8_t sr_lo = vld1q_u16(src + x + 0);
    ma[0] = vld1q_u16(ma565 + x);
    LoadAligned32U32(b565 + x, b[0]);
    p[0] = CalculateFilteredOutputPass1(sr_lo, ma, b);
    ma[0] = vld1q_u16(ma343 + x);
    ma[1] = vld1q_u16(ma444 + x);
    LoadAligned32U32(b343 + x, b[0]);
    LoadAligned32U32(b444 + x, b[1]);
    p[1] = CalculateFilteredOutputPass2(sr_lo, ma, b);
    const int16x8_t d0 = SelfGuidedDoubleMultiplier(sr_lo, p, w0, w2);

    ma[1] = Sum565Hi(ma5x);
    Sum565(b5 + 2, b[1]);
    ma[2] = Sum343Hi(ma3x);
    Sum343(b3 + 2, b[2]);
    const uint16x8_t sr_hi = Load1QMsanU16(
        src + x + 8, overread_in_bytes + 4 + sizeof(*src) * (x + 8));
    ma[0] = vld1q_u16(ma565 + x + 8);
    LoadAligned32U32(b565 + x + 8, b[0]);
    p[0] = CalculateFilteredOutputPass1(sr_hi, ma, b);
    ma[0] = vld1q_u16(ma343 + x + 8);
    ma[1] = vld1q_u16(ma444 + x + 8);
    LoadAligned32U32(b343 + x + 8, b[0]);
    LoadAligned32U32(b444 + x + 8, b[1]);
    p[1] = CalculateFilteredOutputPass2(sr_hi, ma, b);
    const int16x8_t d1 = SelfGuidedDoubleMultiplier(sr_hi, p, w0, w2);
    ClipAndStore(dst + x + 0, d0);
    ClipAndStore(dst + x + 8, d1);
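    // Slide the window by 16 pixels for the next iteration.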
    s[1] = s[3];
    sq[2] = sq[6];
    sq[3] = sq[7];
    ma3[0] = ma3[1];
    ma5[0] = ma5[1];
    b3[0] = b3[4];
    b3[1] = b3[5];
    b5[0] = b5[4];
    b5[1] = b5[5];
    x += 16;
  } while (x < width);
}

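// Runs both SGR passes (the 5x5 pass 1 and the 3x3 pass 2) over the unit,
// producing two output rows per main-loop iteration.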
LIBGAV1_ALWAYS_INLINE void BoxFilterProcess(
    const RestorationUnitInfo& restoration_info, const uint16_t* src,
    const ptrdiff_t stride, const uint16_t* const top_border,
    const ptrdiff_t top_border_stride, const uint16_t* bottom_border,
    const ptrdiff_t bottom_border_stride, const int width, const int height,
    SgrBuffer* const sgr_buffer, uint16_t* dst) {
  const auto temp_stride = Align<ptrdiff_t>(width, 16);
  const auto sum_width = Align<ptrdiff_t>(width + 8, 16);
  const auto sum_stride = temp_stride + 16;
  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
  const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index];  // < 2^12.
  const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
  const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
  const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
  uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2];
  uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2];
  sum3[0] = sgr_buffer->sum3;
  square_sum3[0] = sgr_buffer->square_sum3;
  ma343[0] = sgr_buffer->ma343;
  b343[0] = sgr_buffer->b343;
  for (int i = 1; i <= 3; ++i) {
    sum3[i] = sum3[i - 1] + sum_stride;
    square_sum3[i] = square_sum3[i - 1] + sum_stride;
    ma343[i] = ma343[i - 1] + temp_stride;
    b343[i] = b343[i - 1] + temp_stride;
  }
  sum5[0] = sgr_buffer->sum5;
  square_sum5[0] = sgr_buffer->square_sum5;
  for (int i = 1; i <= 4; ++i) {
    sum5[i] = sum5[i - 1] + sum_stride;
    square_sum5[i] = square_sum5[i - 1] + sum_stride;
  }
  ma444[0] = sgr_buffer->ma444;
  b444[0] = sgr_buffer->b444;
  for (int i = 1; i <= 2; ++i) {
    ma444[i] = ma444[i - 1] + temp_stride;
    b444[i] = b444[i - 1] + temp_stride;
  }
  ma565[0] = sgr_buffer->ma565;
  ma565[1] = ma565[0] + temp_stride;
  b565[0] = sgr_buffer->b565;
  b565[1] = b565[0] + temp_stride;
  assert(scales[0] != 0);
  assert(scales[1] != 0);
  BoxSum(top_border, top_border_stride, width, sum_stride, sum_width, sum3[0],
         sum5[1], square_sum3[0], square_sum5[1]);
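  // Alias row 0 of the 5-row ring to row 1, so the top border sums are
  // counted twice by the 5x5 window above the first frame row.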
  sum5[0] = sum5[1];
  square_sum5[0] = square_sum5[1];
  const uint16_t* const s = (height > 1) ? src + stride : bottom_border;
  BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3,
                         square_sum5, sum_width, ma343, ma444[0], ma565[0],
                         b343, b444[0], b565[0]);
  sum5[0] = sgr_buffer->sum5;
  square_sum5[0] = sgr_buffer->square_sum5;

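  // Main loop: each iteration consumes two new source rows, writes two
  // filtered rows, and rotates the ring buffers for the next row pair.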
  for (int y = (height >> 1) - 1; y > 0; --y) {
    Circulate4PointersBy2<uint16_t>(sum3);
    Circulate4PointersBy2<uint32_t>(square_sum3);
    Circulate5PointersBy2<uint16_t>(sum5);
    Circulate5PointersBy2<uint32_t>(square_sum5);
    BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width,
              scales, w0, w2, sum3, sum5, square_sum3, square_sum5, sum_width,
              ma343, ma444, ma565, b343, b444, b565, dst);
    src += 2 * stride;
    dst += 2 * stride;
    Circulate4PointersBy2<uint16_t>(ma343);
    Circulate4PointersBy2<uint32_t>(b343);
    std::swap(ma444[0], ma444[2]);
    std::swap(b444[0], b444[2]);
    std::swap(ma565[0], ma565[1]);
    std::swap(b565[0], b565[1]);
  }

  Circulate4PointersBy2<uint16_t>(sum3);
  Circulate4PointersBy2<uint32_t>(square_sum3);
  Circulate5PointersBy2<uint16_t>(sum5);
  Circulate5PointersBy2<uint32_t>(square_sum5);
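  // Filter the last full row pair, substituting bottom border rows for
  // source rows that fall below the frame.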
  if ((height & 1) == 0 || height > 1) {
    const uint16_t* sr[2];
    if ((height & 1) == 0) {
      sr[0] = bottom_border;
      sr[1] = bottom_border + bottom_border_stride;
    } else {
      sr[0] = src + 2 * stride;
      sr[1] = bottom_border;
    }
    BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5,
              square_sum3, square_sum5, sum_width, ma343, ma444, ma565, b343,
              b444, b565, dst);
  }
  if ((height & 1) != 0) {
    if (height > 1) {
      src += 2 * stride;
      dst += 2 * stride;
      Circulate4PointersBy2<uint16_t>(sum3);
      Circulate4PointersBy2<uint32_t>(square_sum3);
      Circulate5PointersBy2<uint16_t>(sum5);
      Circulate5PointersBy2<uint32_t>(square_sum5);
      Circulate4PointersBy2<uint16_t>(ma343);
      Circulate4PointersBy2<uint32_t>(b343);
      std::swap(ma444[0], ma444[2]);
      std::swap(b444[0], b444[2]);
      std::swap(ma565[0], ma565[1]);
      std::swap(b565[0], b565[1]);
    }
    BoxFilterLastRow(src + 3, bottom_border + bottom_border_stride, width,
                     sum_width, scales, w0, w2, sum3, sum5, square_sum3,
                     square_sum5, ma343[0], ma444[0], ma565[0], b343[0],
                     b444[0], b565[0], dst);
  }
}

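// Runs only pass 1 (the 5x5 box filter); used when the pass 2 radius is 0.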
inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info,
                                  const uint16_t* src, const ptrdiff_t stride,
                                  const uint16_t* const top_border,
                                  const ptrdiff_t top_border_stride,
                                  const uint16_t* bottom_border,
                                  const ptrdiff_t bottom_border_stride,
                                  const int width, const int height,
                                  SgrBuffer* const sgr_buffer, uint16_t* dst) {
  const auto temp_stride = Align<ptrdiff_t>(width, 16);
  const auto sum_width = Align<ptrdiff_t>(width + 8, 16);
  const auto sum_stride = temp_stride + 16;
  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
  const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0];  // < 2^12.
  const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
  uint16_t *sum5[5], *ma565[2];
  uint32_t *square_sum5[5], *b565[2];
  sum5[0] = sgr_buffer->sum5;
  square_sum5[0] = sgr_buffer->square_sum5;
  for (int i = 1; i <= 4; ++i) {
    sum5[i] = sum5[i - 1] + sum_stride;
    square_sum5[i] = square_sum5[i - 1] + sum_stride;
  }
  ma565[0] = sgr_buffer->ma565;
  ma565[1] = ma565[0] + temp_stride;
  b565[0] = sgr_buffer->b565;
  b565[1] = b565[0] + temp_stride;
  assert(scale != 0);

  BoxSum<5>(top_border, top_border_stride, width, sum_stride, sum_width,
            sum5[1], square_sum5[1]);
  sum5[0] = sum5[1];
  square_sum5[0] = square_sum5[1];
  const uint16_t* const s = (height > 1) ? src + stride : bottom_border;
  BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, sum_width,
                          ma565[0], b565[0]);
  sum5[0] = sgr_buffer->sum5;
  square_sum5[0] = sgr_buffer->square_sum5;

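  // Main loop: two rows per iteration, ping-ponging the ma565/b565 row
  // buffers between the current and the next row pair.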
  for (int y = (height >> 1) - 1; y > 0; --y) {
    Circulate5PointersBy2<uint16_t>(sum5);
    Circulate5PointersBy2<uint32_t>(square_sum5);
    BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5,
                   square_sum5, width, sum_width, scale, w0, ma565, b565, dst);
    src += 2 * stride;
    dst += 2 * stride;
    std::swap(ma565[0], ma565[1]);
    std::swap(b565[0], b565[1]);
  }

  Circulate5PointersBy2<uint16_t>(sum5);
  Circulate5PointersBy2<uint32_t>(square_sum5);
  if ((height & 1) == 0 || height > 1) {
    const uint16_t* sr[2];
    if ((height & 1) == 0) {
      sr[0] = bottom_border;
      sr[1] = bottom_border + bottom_border_stride;
    } else {
      sr[0] = src + 2 * stride;
      sr[1] = bottom_border;
    }
    BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width,
                   sum_width, scale, w0, ma565, b565, dst);
  }
  if ((height & 1) != 0) {
    src += 3;
    if (height > 1) {
      src += 2 * stride;
      dst += 2 * stride;
      std::swap(ma565[0], ma565[1]);
      std::swap(b565[0], b565[1]);
      Circulate5PointersBy2<uint16_t>(sum5);
      Circulate5PointersBy2<uint32_t>(square_sum5);
    }
    BoxFilterPass1LastRow(src, bottom_border + bottom_border_stride, width,
                          sum_width, scale, w0, sum5, square_sum5, ma565[0],
                          b565[0], dst);
  }
}

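// Runs only pass 2 (the 3x3 box filter); used when the pass 1 radius is 0.
// Unlike pass 1, this filters a single row per iteration.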
inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info,
                                  const uint16_t* src, const ptrdiff_t stride,
                                  const uint16_t* const top_border,
                                  const ptrdiff_t top_border_stride,
                                  const uint16_t* bottom_border,
                                  const ptrdiff_t bottom_border_stride,
                                  const int width, const int height,
                                  SgrBuffer* const sgr_buffer, uint16_t* dst) {
  assert(restoration_info.sgr_proj_info.multiplier[0] == 0);
  const auto temp_stride = Align<ptrdiff_t>(width, 16);
  const auto sum_width = Align<ptrdiff_t>(width + 8, 16);
  const auto sum_stride = temp_stride + 16;
  const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
  const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1;
  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
  const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1];  // < 2^12.
  uint16_t *sum3[3], *ma343[3], *ma444[2];
  uint32_t *square_sum3[3], *b343[3], *b444[2];
  sum3[0] = sgr_buffer->sum3;
  square_sum3[0] = sgr_buffer->square_sum3;
  ma343[0] = sgr_buffer->ma343;
  b343[0] = sgr_buffer->b343;
  for (int i = 1; i <= 2; ++i) {
    sum3[i] = sum3[i - 1] + sum_stride;
    square_sum3[i] = square_sum3[i - 1] + sum_stride;
    ma343[i] = ma343[i - 1] + temp_stride;
    b343[i] = b343[i - 1] + temp_stride;
  }
  ma444[0] = sgr_buffer->ma444;
  ma444[1] = ma444[0] + temp_stride;
  b444[0] = sgr_buffer->b444;
  b444[1] = b444[0] + temp_stride;
  assert(scale != 0);
  BoxSum<3>(top_border, top_border_stride, width, sum_stride, sum_width,
            sum3[0], square_sum3[0]);
  BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3,
                                 sum_width, ma343[0], nullptr, b343[0],
                                 nullptr);
  Circulate3PointersBy1<uint16_t>(sum3);
  Circulate3PointersBy1<uint32_t>(square_sum3);
  const uint16_t* s;
  if (height > 1) {
    s = src + stride;
  } else {
    s = bottom_border;
    bottom_border += bottom_border_stride;
  }
  BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, sum_width,
                                ma343[1], ma444[0], b343[1], b444[0]);

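  // Main loop: one row per iteration, rotating the 3-row sum and ma343/b343
  // rings by one row each time.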
  for (int y = height - 2; y > 0; --y) {
    Circulate3PointersBy1<uint16_t>(sum3);
    Circulate3PointersBy1<uint32_t>(square_sum3);
    BoxFilterPass2(src + 2, src + 2 * stride, width, sum_width, scale, w0, sum3,
                   square_sum3, ma343, ma444, b343, b444, dst);
    src += stride;
    dst += stride;
    Circulate3PointersBy1<uint16_t>(ma343);
    Circulate3PointersBy1<uint32_t>(b343);
    std::swap(ma444[0], ma444[1]);
    std::swap(b444[0], b444[1]);
  }

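  // The final one or two rows take their lookahead row from the bottom
  // border.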
  int y = std::min(height, 2);
  src += 2;
  do {
    Circulate3PointersBy1<uint16_t>(sum3);
    Circulate3PointersBy1<uint32_t>(square_sum3);
    BoxFilterPass2(src, bottom_border, width, sum_width, scale, w0, sum3,
                   square_sum3, ma343, ma444, b343, b444, dst);
    src += stride;
    dst += stride;
    bottom_border += bottom_border_stride;
    Circulate3PointersBy1<uint16_t>(ma343);
    Circulate3PointersBy1<uint32_t>(b343);
    std::swap(ma444[0], ma444[1]);
    std::swap(b444[0], b444[1]);
  } while (--y != 0);
}

// If |width| is not a multiple of 8, up to 7 extra pixels are written to
// |dest| at the end of each row. It is safe to overwrite the output as they
// will not be part of the visible frame.
void SelfGuidedFilter_NEON(
    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
    const void* LIBGAV1_RESTRICT const top_border,
    const ptrdiff_t top_border_stride,
    const void* LIBGAV1_RESTRICT const bottom_border,
    const ptrdiff_t bottom_border_stride, const int width, const int height,
    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
    void* LIBGAV1_RESTRICT const dest) {
  const int index = restoration_info.sgr_proj_info.index;
  const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
  const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
  const auto* const src = static_cast<const uint16_t*>(source);
  const auto* top = static_cast<const uint16_t*>(top_border);
  const auto* bottom = static_cast<const uint16_t*>(bottom_border);
  auto* const dst = static_cast<uint16_t*>(dest);
  SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer;
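  // The box sums read columns to the left of the unit, so the source and
  // border pointers are stepped left: 3 columns for the 5x5 pass, 2 for the
  // 3x3 pass.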
  if (radius_pass_1 == 0) {
    // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the
    // following assertion.
    assert(radius_pass_0 != 0);
    BoxFilterProcessPass1(restoration_info, src - 3, stride, top - 3,
                          top_border_stride, bottom - 3, bottom_border_stride,
                          width, height, sgr_buffer, dst);
  } else if (radius_pass_0 == 0) {
    BoxFilterProcessPass2(restoration_info, src - 2, stride, top - 2,
                          top_border_stride, bottom - 2, bottom_border_stride,
                          width, height, sgr_buffer, dst);
  } else {
    BoxFilterProcess(restoration_info, src - 3, stride, top - 3,
                     top_border_stride, bottom - 3, bottom_border_stride, width,
                     height, sgr_buffer, dst);
  }
}

void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);
  dsp->loop_restorations[0] = WienerFilter_NEON;
  dsp->loop_restorations[1] = SelfGuidedFilter_NEON;
}

}  // namespace

void LoopRestorationInit10bpp_NEON() { Init10bpp(); }

}  // namespace dsp
}  // namespace libgav1

#else   // !(LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10)
namespace libgav1 {
namespace dsp {

void LoopRestorationInit10bpp_NEON() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10