/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>
#include <stdint.h>

#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "av1/encoder/arm/pickrst_neon.h"
#include "av1/encoder/pickrst.h"

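// The highbd_calc_proj_params_*_neon helpers below accumulate the
// auto-correlation matrix H and the cross-correlation vector C used to solve
// for the self-guided restoration projection coefficients; every entry is
// normalized by the number of pixels processed.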
static inline void highbd_calc_proj_params_r0_r1_neon(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);

  int64x2_t h00_lo = vdupq_n_s64(0);
  int64x2_t h00_hi = vdupq_n_s64(0);
  int64x2_t h11_lo = vdupq_n_s64(0);
  int64x2_t h11_hi = vdupq_n_s64(0);
  int64x2_t h01_lo = vdupq_n_s64(0);
  int64x2_t h01_hi = vdupq_n_s64(0);
  int64x2_t c0_lo = vdupq_n_s64(0);
  int64x2_t c0_hi = vdupq_n_s64(0);
  int64x2_t c1_lo = vdupq_n_s64(0);
  int64x2_t c1_hi = vdupq_n_s64(0);

  do {
    const uint16_t *src_ptr = src;
    const uint16_t *dat_ptr = dat;
    int32_t *flt0_ptr = flt0;
    int32_t *flt1_ptr = flt1;
    int w = width;

    do {
      uint16x8_t s = vld1q_u16(src_ptr);
      uint16x8_t d = vld1q_u16(dat_ptr);
      int32x4_t f0_lo = vld1q_s32(flt0_ptr);
      int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4);
      int32x4_t f1_lo = vld1q_s32(flt1_ptr);
      int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4);

      int32x4_t u_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(d), SGRPROJ_RST_BITS));
      int32x4_t u_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(d), SGRPROJ_RST_BITS));
      int32x4_t s_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(s), SGRPROJ_RST_BITS));
      int32x4_t s_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(s), SGRPROJ_RST_BITS));
      s_lo = vsubq_s32(s_lo, u_lo);
      s_hi = vsubq_s32(s_hi, u_hi);

      f0_lo = vsubq_s32(f0_lo, u_lo);
      f0_hi = vsubq_s32(f0_hi, u_hi);
      f1_lo = vsubq_s32(f1_lo, u_lo);
      f1_hi = vsubq_s32(f1_hi, u_hi);

      h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo));
      h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo));
      h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi));
      h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi));

      h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo));
      h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo));
      h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi));
      h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi));

      h01_lo = vmlal_s32(h01_lo, vget_low_s32(f0_lo), vget_low_s32(f1_lo));
      h01_lo = vmlal_s32(h01_lo, vget_high_s32(f0_lo), vget_high_s32(f1_lo));
      h01_hi = vmlal_s32(h01_hi, vget_low_s32(f0_hi), vget_low_s32(f1_hi));
      h01_hi = vmlal_s32(h01_hi, vget_high_s32(f0_hi), vget_high_s32(f1_hi));

      c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo));
      c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo));
      c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi));
      c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi));

      c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo));
      c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo));
      c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi));
      c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt0_ptr += 8;
      flt1_ptr += 8;
      w -= 8;
    } while (w != 0);

    src += src_stride;
    dat += dat_stride;
    flt0 += flt0_stride;
    flt1 += flt1_stride;
  } while (--height != 0);

  H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size;
  H[0][1] = horizontal_add_s64x2(vaddq_s64(h01_lo, h01_hi)) / size;
  H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size;
  H[1][0] = H[0][1];
  C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size;
  C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size;
}

static inline void highbd_calc_proj_params_r0_neon(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);

  int64x2_t h00_lo = vdupq_n_s64(0);
  int64x2_t h00_hi = vdupq_n_s64(0);
  int64x2_t c0_lo = vdupq_n_s64(0);
  int64x2_t c0_hi = vdupq_n_s64(0);

  do {
    const uint16_t *src_ptr = src;
    const uint16_t *dat_ptr = dat;
    int32_t *flt0_ptr = flt0;
    int w = width;

    do {
      uint16x8_t s = vld1q_u16(src_ptr);
      uint16x8_t d = vld1q_u16(dat_ptr);
      int32x4_t f0_lo = vld1q_s32(flt0_ptr);
      int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4);

      int32x4_t u_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(d), SGRPROJ_RST_BITS));
      int32x4_t u_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(d), SGRPROJ_RST_BITS));
      int32x4_t s_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(s), SGRPROJ_RST_BITS));
      int32x4_t s_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(s), SGRPROJ_RST_BITS));
      s_lo = vsubq_s32(s_lo, u_lo);
      s_hi = vsubq_s32(s_hi, u_hi);

      f0_lo = vsubq_s32(f0_lo, u_lo);
      f0_hi = vsubq_s32(f0_hi, u_hi);

      h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo));
      h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo));
      h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi));
      h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi));

      c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo));
      c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo));
      c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi));
      c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt0_ptr += 8;
      w -= 8;
    } while (w != 0);

    src += src_stride;
    dat += dat_stride;
    flt0 += flt0_stride;
  } while (--height != 0);

  H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size;
  C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size;
}

static inline void highbd_calc_proj_params_r1_neon(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt1, int flt1_stride,
    int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);

  int64x2_t h11_lo = vdupq_n_s64(0);
  int64x2_t h11_hi = vdupq_n_s64(0);
  int64x2_t c1_lo = vdupq_n_s64(0);
  int64x2_t c1_hi = vdupq_n_s64(0);

  do {
    const uint16_t *src_ptr = src;
    const uint16_t *dat_ptr = dat;
    int32_t *flt1_ptr = flt1;
    int w = width;

    do {
      uint16x8_t s = vld1q_u16(src_ptr);
      uint16x8_t d = vld1q_u16(dat_ptr);
      int32x4_t f1_lo = vld1q_s32(flt1_ptr);
      int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4);

      int32x4_t u_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(d), SGRPROJ_RST_BITS));
      int32x4_t u_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(d), SGRPROJ_RST_BITS));
      int32x4_t s_lo =
          vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(s), SGRPROJ_RST_BITS));
      int32x4_t s_hi = vreinterpretq_s32_u32(
          vshll_n_u16(vget_high_u16(s), SGRPROJ_RST_BITS));
      s_lo = vsubq_s32(s_lo, u_lo);
      s_hi = vsubq_s32(s_hi, u_hi);

      f1_lo = vsubq_s32(f1_lo, u_lo);
      f1_hi = vsubq_s32(f1_hi, u_hi);

      h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo));
      h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo));
      h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi));
      h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi));

      c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo));
      c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo));
      c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi));
      c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt1_ptr += 8;
      w -= 8;
    } while (w != 0);

    src += src_stride;
    dat += dat_stride;
    flt1 += flt1_stride;
  } while (--height != 0);

  H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size;
  C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size;
}

// This function calls 3 subfunctions for the following cases:
// 1) When params->r[0] > 0 and params->r[1] > 0. In this case all elements
//    of C and H need to be computed.
// 2) When only params->r[0] > 0. In this case only H[0][0] and C[0] are
//    non-zero and need to be computed.
// 3) When only params->r[1] > 0. In this case only H[1][1] and C[1] are
//    non-zero and need to be computed.
void av1_calc_proj_params_high_bd_neon(const uint8_t *src8, int width,
                                       int height, int src_stride,
                                       const uint8_t *dat8, int dat_stride,
                                       int32_t *flt0, int flt0_stride,
                                       int32_t *flt1, int flt1_stride,
                                       int64_t H[2][2], int64_t C[2],
                                       const sgr_params_type *params) {
  if ((params->r[0] > 0) && (params->r[1] > 0)) {
    highbd_calc_proj_params_r0_r1_neon(src8, width, height, src_stride, dat8,
                                       dat_stride, flt0, flt0_stride, flt1,
                                       flt1_stride, H, C);
  } else if (params->r[0] > 0) {
    highbd_calc_proj_params_r0_neon(src8, width, height, src_stride, dat8,
                                    dat_stride, flt0, flt0_stride, H, C);
  } else if (params->r[1] > 0) {
    highbd_calc_proj_params_r1_neon(src8, width, height, src_stride, dat8,
                                    dat_stride, flt1, flt1_stride, H, C);
  }
}

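// Horizontally reduce four 32-bit delta accumulators to 64-bit scalars and
// add them to four consecutive stats. vpaddq_s64 is only available on
// AArch64, so 32-bit Arm targets fall back to scalar horizontal adds.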
static inline void hadd_update_4_stats_neon(const int64_t *const src,
                                            const int32x4_t *deltas,
                                            int64_t *const dst) {
  int64x2_t delta0_s64 = vpaddlq_s32(deltas[0]);
  int64x2_t delta1_s64 = vpaddlq_s32(deltas[1]);
  int64x2_t delta2_s64 = vpaddlq_s32(deltas[2]);
  int64x2_t delta3_s64 = vpaddlq_s32(deltas[3]);

#if AOM_ARCH_AARCH64
  int64x2_t delta01 = vpaddq_s64(delta0_s64, delta1_s64);
  int64x2_t delta23 = vpaddq_s64(delta2_s64, delta3_s64);

  int64x2_t src0 = vld1q_s64(src);
  int64x2_t src1 = vld1q_s64(src + 2);
  vst1q_s64(dst, vaddq_s64(src0, delta01));
  vst1q_s64(dst + 2, vaddq_s64(src1, delta23));
#else
  dst[0] = src[0] + horizontal_add_s64x2(delta0_s64);
  dst[1] = src[1] + horizontal_add_s64x2(delta1_s64);
  dst[2] = src[2] + horizontal_add_s64x2(delta2_s64);
  dst[3] = src[3] + horizontal_add_s64x2(delta3_s64);
#endif  // AOM_ARCH_AARCH64
}

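// Compute the Wiener filter statistics for the 5x5 (chroma) window: M
// accumulates cross-correlation terms between the degraded block d and the
// source block s, and H accumulates auto-correlation terms of d, one entry
// per pair of window taps.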
static inline void compute_stats_win5_highbd_neon(
    const int16_t *const d, const int32_t d_stride, const int16_t *const s,
    const int32_t s_stride, const int32_t width, const int32_t height,
    int64_t *const M, int64_t *const H, aom_bit_depth_t bit_depth) {
  const int32_t wiener_win = WIENER_WIN_CHROMA;
  const int32_t wiener_win2 = wiener_win * wiener_win;
  const int32_t w16 = width & ~15;
  const int32_t h8 = height & ~7;
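  // This load of mask_16bit yields width % 16 all-ones lanes followed by
  // zero lanes; ANDing the final, partial 16-wide column with it zeroes the
  // out-of-range pixels.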
  int16x8_t mask[2];
  mask[0] = vld1q_s16(&(mask_16bit[16]) - width % 16);
  mask[1] = vld1q_s16(&(mask_16bit[16]) - width % 16 + 8);
  int32_t i, j, x, y;

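  // Bound the number of rows accumulated per batch so that the 32-bit row
  // sums, whose per-pixel products occupy up to 2 * bit_depth bits, are
  // widened to 64 bits (vpadalq_s32) before they can overflow.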
  const int32_t num_bit_left =
      32 - 1 /* sign */ - 2 * bit_depth /* energy */ + 2 /* SIMD */;
  const int32_t h_allowed =
      (1 << num_bit_left) / (w16 + ((w16 != width) ? 16 : 0));

  // Step 1: Calculate the top edge of the whole matrix, i.e., the top
  // edge of each triangle and square on the top row.
  j = 0;
  do {
    const int16_t *s_t = s;
    const int16_t *d_t = d;
    int32_t height_t = 0;
    int64x2_t sum_m[WIENER_WIN_CHROMA] = { vdupq_n_s64(0) };
    int64x2_t sum_h[WIENER_WIN_CHROMA] = { vdupq_n_s64(0) };
    int16x8_t src[2], dgd[2];

    do {
      const int32_t h_t =
          ((height - height_t) < h_allowed) ? (height - height_t) : h_allowed;
      int32x4_t row_m[WIENER_WIN_CHROMA] = { vdupq_n_s32(0) };
      int32x4_t row_h[WIENER_WIN_CHROMA] = { vdupq_n_s32(0) };

      y = h_t;
      do {
        x = 0;
        while (x < w16) {
          src[0] = vld1q_s16(s_t + x + 0);
          src[1] = vld1q_s16(s_t + x + 8);
          dgd[0] = vld1q_s16(d_t + x + 0);
          dgd[1] = vld1q_s16(d_t + x + 8);
          stats_top_win5_neon(src, dgd, d_t + j + x, d_stride, row_m, row_h);
          x += 16;
        }

        if (w16 != width) {
          src[0] = vld1q_s16(s_t + w16 + 0);
          src[1] = vld1q_s16(s_t + w16 + 8);
          dgd[0] = vld1q_s16(d_t + w16 + 0);
          dgd[1] = vld1q_s16(d_t + w16 + 8);
          src[0] = vandq_s16(src[0], mask[0]);
          src[1] = vandq_s16(src[1], mask[1]);
          dgd[0] = vandq_s16(dgd[0], mask[0]);
          dgd[1] = vandq_s16(dgd[1], mask[1]);
          stats_top_win5_neon(src, dgd, d_t + j + w16, d_stride, row_m, row_h);
        }

        s_t += s_stride;
        d_t += d_stride;
      } while (--y);

      sum_m[0] = vpadalq_s32(sum_m[0], row_m[0]);
      sum_m[1] = vpadalq_s32(sum_m[1], row_m[1]);
      sum_m[2] = vpadalq_s32(sum_m[2], row_m[2]);
      sum_m[3] = vpadalq_s32(sum_m[3], row_m[3]);
      sum_m[4] = vpadalq_s32(sum_m[4], row_m[4]);
      sum_h[0] = vpadalq_s32(sum_h[0], row_h[0]);
      sum_h[1] = vpadalq_s32(sum_h[1], row_h[1]);
      sum_h[2] = vpadalq_s32(sum_h[2], row_h[2]);
      sum_h[3] = vpadalq_s32(sum_h[3], row_h[3]);
      sum_h[4] = vpadalq_s32(sum_h[4], row_h[4]);

      height_t += h_t;
    } while (height_t < height);

#if AOM_ARCH_AARCH64
    int64x2_t sum_m0 = vpaddq_s64(sum_m[0], sum_m[1]);
    int64x2_t sum_m2 = vpaddq_s64(sum_m[2], sum_m[3]);
    vst1q_s64(&M[wiener_win * j + 0], sum_m0);
    vst1q_s64(&M[wiener_win * j + 2], sum_m2);
    M[wiener_win * j + 4] = vaddvq_s64(sum_m[4]);

    int64x2_t sum_h0 = vpaddq_s64(sum_h[0], sum_h[1]);
    int64x2_t sum_h2 = vpaddq_s64(sum_h[2], sum_h[3]);
    vst1q_s64(&H[wiener_win * j + 0], sum_h0);
    vst1q_s64(&H[wiener_win * j + 2], sum_h2);
    H[wiener_win * j + 4] = vaddvq_s64(sum_h[4]);
#else
    M[wiener_win * j + 0] = horizontal_add_s64x2(sum_m[0]);
    M[wiener_win * j + 1] = horizontal_add_s64x2(sum_m[1]);
    M[wiener_win * j + 2] = horizontal_add_s64x2(sum_m[2]);
    M[wiener_win * j + 3] = horizontal_add_s64x2(sum_m[3]);
    M[wiener_win * j + 4] = horizontal_add_s64x2(sum_m[4]);

    H[wiener_win * j + 0] = horizontal_add_s64x2(sum_h[0]);
    H[wiener_win * j + 1] = horizontal_add_s64x2(sum_h[1]);
    H[wiener_win * j + 2] = horizontal_add_s64x2(sum_h[2]);
    H[wiener_win * j + 3] = horizontal_add_s64x2(sum_h[3]);
    H[wiener_win * j + 4] = horizontal_add_s64x2(sum_h[4]);
#endif  // AOM_ARCH_AARCH64
  } while (++j < wiener_win);

  // Step 2: Calculate the left edge of each square on the top row.
  j = 1;
  do {
    const int16_t *d_t = d;
    int32_t height_t = 0;
    int64x2_t sum_h[WIENER_WIN_CHROMA - 1] = { vdupq_n_s64(0) };
    int16x8_t dgd[2];

    do {
      const int32_t h_t =
          ((height - height_t) < h_allowed) ? (height - height_t) : h_allowed;
      int32x4_t row_h[WIENER_WIN_CHROMA - 1] = { vdupq_n_s32(0) };

      y = h_t;
      do {
        x = 0;
        while (x < w16) {
          dgd[0] = vld1q_s16(d_t + j + x + 0);
          dgd[1] = vld1q_s16(d_t + j + x + 8);
          stats_left_win5_neon(dgd, d_t + x, d_stride, row_h);
          x += 16;
        }

        if (w16 != width) {
          dgd[0] = vld1q_s16(d_t + j + x + 0);
          dgd[1] = vld1q_s16(d_t + j + x + 8);
          dgd[0] = vandq_s16(dgd[0], mask[0]);
          dgd[1] = vandq_s16(dgd[1], mask[1]);
          stats_left_win5_neon(dgd, d_t + x, d_stride, row_h);
        }

        d_t += d_stride;
      } while (--y);

      sum_h[0] = vpadalq_s32(sum_h[0], row_h[0]);
      sum_h[1] = vpadalq_s32(sum_h[1], row_h[1]);
      sum_h[2] = vpadalq_s32(sum_h[2], row_h[2]);
      sum_h[3] = vpadalq_s32(sum_h[3], row_h[3]);

      height_t += h_t;
    } while (height_t < height);

#if AOM_ARCH_AARCH64
    int64x2_t sum_h0 = vpaddq_s64(sum_h[0], sum_h[1]);
    int64x2_t sum_h1 = vpaddq_s64(sum_h[2], sum_h[3]);
    vst1_s64(&H[1 * wiener_win2 + j * wiener_win], vget_low_s64(sum_h0));
    vst1_s64(&H[2 * wiener_win2 + j * wiener_win], vget_high_s64(sum_h0));
    vst1_s64(&H[3 * wiener_win2 + j * wiener_win], vget_low_s64(sum_h1));
    vst1_s64(&H[4 * wiener_win2 + j * wiener_win], vget_high_s64(sum_h1));
#else
    H[1 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[0]);
    H[2 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[1]);
    H[3 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[2]);
    H[4 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[3]);
#endif  // AOM_ARCH_AARCH64
  } while (++j < wiener_win);

  // Step 3: Derive the top edge of each triangle along the diagonal. No
  // triangle in top row.
  {
    const int16_t *d_t = d;

    if (height % 2) {
      int32x4_t deltas[(WIENER_WIN + 1) * 2] = { vdupq_n_s32(0) };
      int32x4_t deltas_tr[(WIENER_WIN + 1) * 2] = { vdupq_n_s32(0) };
      int16x8_t ds[WIENER_WIN * 2];

      load_s16_8x4(d_t, d_stride, &ds[0], &ds[2], &ds[4], &ds[6]);
      load_s16_8x4(d_t + width, d_stride, &ds[1], &ds[3], &ds[5], &ds[7]);
      d_t += 4 * d_stride;

      step3_win5_oneline_neon(&d_t, d_stride, width, height, ds, deltas);
      transpose_arrays_s32_8x8(deltas, deltas_tr);

      update_5_stats_neon(H + 0 * wiener_win * wiener_win2 + 0 * wiener_win,
                          deltas_tr[0], vgetq_lane_s32(deltas_tr[4], 0),
                          H + 1 * wiener_win * wiener_win2 + 1 * wiener_win);

      update_5_stats_neon(H + 1 * wiener_win * wiener_win2 + 1 * wiener_win,
                          deltas_tr[1], vgetq_lane_s32(deltas_tr[5], 0),
                          H + 2 * wiener_win * wiener_win2 + 2 * wiener_win);

      update_5_stats_neon(H + 2 * wiener_win * wiener_win2 + 2 * wiener_win,
                          deltas_tr[2], vgetq_lane_s32(deltas_tr[6], 0),
                          H + 3 * wiener_win * wiener_win2 + 3 * wiener_win);

      update_5_stats_neon(H + 3 * wiener_win * wiener_win2 + 3 * wiener_win,
                          deltas_tr[3], vgetq_lane_s32(deltas_tr[7], 0),
                          H + 4 * wiener_win * wiener_win2 + 4 * wiener_win);

    } else {
      int32x4_t deltas[WIENER_WIN_CHROMA * 2] = { vdupq_n_s32(0) };
      int16x8_t ds[WIENER_WIN_CHROMA * 2];

      ds[0] = load_unaligned_s16_4x2(d_t + 0 * d_stride, width);
      ds[1] = load_unaligned_s16_4x2(d_t + 1 * d_stride, width);
      ds[2] = load_unaligned_s16_4x2(d_t + 2 * d_stride, width);
      ds[3] = load_unaligned_s16_4x2(d_t + 3 * d_stride, width);

      step3_win5_neon(d_t + 4 * d_stride, d_stride, width, height, ds, deltas);

      transpose_elems_inplace_s32_4x4(&deltas[0], &deltas[1], &deltas[2],
                                      &deltas[3]);

      update_5_stats_neon(H + 0 * wiener_win * wiener_win2 + 0 * wiener_win,
                          deltas[0], vgetq_lane_s32(deltas[4], 0),
                          H + 1 * wiener_win * wiener_win2 + 1 * wiener_win);

      update_5_stats_neon(H + 1 * wiener_win * wiener_win2 + 1 * wiener_win,
                          deltas[1], vgetq_lane_s32(deltas[4], 1),
                          H + 2 * wiener_win * wiener_win2 + 2 * wiener_win);

      update_5_stats_neon(H + 2 * wiener_win * wiener_win2 + 2 * wiener_win,
                          deltas[2], vgetq_lane_s32(deltas[4], 2),
                          H + 3 * wiener_win * wiener_win2 + 3 * wiener_win);

      update_5_stats_neon(H + 3 * wiener_win * wiener_win2 + 3 * wiener_win,
                          deltas[3], vgetq_lane_s32(deltas[4], 3),
                          H + 4 * wiener_win * wiener_win2 + 4 * wiener_win);
    }
  }

  // Step 4: Derive the top and left edge of each square. No square in top and
  // bottom row.

  {
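    // Accumulate only the difference between each square's statistic and the
    // statistic of the square diagonally above-left: products leaving through
    // the start columns are subtracted, products entering through the end
    // columns are added, and the final values are rebuilt from the upper-left
    // neighbours in the update loop at the end of this block.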
    y = h8;

    int16x4_t d_s[12];
    int16x4_t d_e[12];
    const int16_t *d_t = d;
    int16x4_t zeros = vdup_n_s16(0);
    load_s16_4x4(d_t, d_stride, &d_s[0], &d_s[1], &d_s[2], &d_s[3]);
    load_s16_4x4(d_t + width, d_stride, &d_e[0], &d_e[1], &d_e[2], &d_e[3]);
    int32x4_t deltas[6][18] = { { vdupq_n_s32(0) }, { vdupq_n_s32(0) } };

    while (y >= 8) {
      load_s16_4x8(d_t + 4 * d_stride, d_stride, &d_s[4], &d_s[5], &d_s[6],
                   &d_s[7], &d_s[8], &d_s[9], &d_s[10], &d_s[11]);
      load_s16_4x8(d_t + width + 4 * d_stride, d_stride, &d_e[4], &d_e[5],
                   &d_e[6], &d_e[7], &d_e[8], &d_e[9], &d_e[10], &d_e[11]);

      int16x8_t s_tr[8], e_tr[8];
      transpose_elems_s16_4x8(d_s[0], d_s[1], d_s[2], d_s[3], d_s[4], d_s[5],
                              d_s[6], d_s[7], &s_tr[0], &s_tr[1], &s_tr[2],
                              &s_tr[3]);
      transpose_elems_s16_4x8(d_s[8], d_s[9], d_s[10], d_s[11], zeros, zeros,
                              zeros, zeros, &s_tr[4], &s_tr[5], &s_tr[6],
                              &s_tr[7]);

      transpose_elems_s16_4x8(d_e[0], d_e[1], d_e[2], d_e[3], d_e[4], d_e[5],
                              d_e[6], d_e[7], &e_tr[0], &e_tr[1], &e_tr[2],
                              &e_tr[3]);
      transpose_elems_s16_4x8(d_e[8], d_e[9], d_e[10], d_e[11], zeros, zeros,
                              zeros, zeros, &e_tr[4], &e_tr[5], &e_tr[6],
                              &e_tr[7]);

      int16x8_t start_col0[5], start_col1[5], start_col2[5], start_col3[5];
      start_col0[0] = s_tr[0];
      start_col0[1] = vextq_s16(s_tr[0], s_tr[4], 1);
      start_col0[2] = vextq_s16(s_tr[0], s_tr[4], 2);
      start_col0[3] = vextq_s16(s_tr[0], s_tr[4], 3);
      start_col0[4] = vextq_s16(s_tr[0], s_tr[4], 4);

      start_col1[0] = s_tr[1];
      start_col1[1] = vextq_s16(s_tr[1], s_tr[5], 1);
      start_col1[2] = vextq_s16(s_tr[1], s_tr[5], 2);
      start_col1[3] = vextq_s16(s_tr[1], s_tr[5], 3);
      start_col1[4] = vextq_s16(s_tr[1], s_tr[5], 4);

      start_col2[0] = s_tr[2];
      start_col2[1] = vextq_s16(s_tr[2], s_tr[6], 1);
      start_col2[2] = vextq_s16(s_tr[2], s_tr[6], 2);
      start_col2[3] = vextq_s16(s_tr[2], s_tr[6], 3);
      start_col2[4] = vextq_s16(s_tr[2], s_tr[6], 4);

      start_col3[0] = s_tr[3];
      start_col3[1] = vextq_s16(s_tr[3], s_tr[7], 1);
      start_col3[2] = vextq_s16(s_tr[3], s_tr[7], 2);
      start_col3[3] = vextq_s16(s_tr[3], s_tr[7], 3);
      start_col3[4] = vextq_s16(s_tr[3], s_tr[7], 4);

      // i = 1, j = 2
      sub_deltas_step4(start_col0, start_col1, deltas[0]);

      // i = 1, j = 3
      sub_deltas_step4(start_col0, start_col2, deltas[1]);

      // i = 1, j = 4
      sub_deltas_step4(start_col0, start_col3, deltas[2]);

      // i = 2, j = 3
      sub_deltas_step4(start_col1, start_col2, deltas[3]);

      // i = 2, j = 4
      sub_deltas_step4(start_col1, start_col3, deltas[4]);

      // i = 3, j = 4
      sub_deltas_step4(start_col2, start_col3, deltas[5]);

      int16x8_t end_col0[5], end_col1[5], end_col2[5], end_col3[5];
      end_col0[0] = e_tr[0];
      end_col0[1] = vextq_s16(e_tr[0], e_tr[4], 1);
      end_col0[2] = vextq_s16(e_tr[0], e_tr[4], 2);
      end_col0[3] = vextq_s16(e_tr[0], e_tr[4], 3);
      end_col0[4] = vextq_s16(e_tr[0], e_tr[4], 4);

      end_col1[0] = e_tr[1];
      end_col1[1] = vextq_s16(e_tr[1], e_tr[5], 1);
      end_col1[2] = vextq_s16(e_tr[1], e_tr[5], 2);
      end_col1[3] = vextq_s16(e_tr[1], e_tr[5], 3);
      end_col1[4] = vextq_s16(e_tr[1], e_tr[5], 4);

      end_col2[0] = e_tr[2];
      end_col2[1] = vextq_s16(e_tr[2], e_tr[6], 1);
      end_col2[2] = vextq_s16(e_tr[2], e_tr[6], 2);
      end_col2[3] = vextq_s16(e_tr[2], e_tr[6], 3);
      end_col2[4] = vextq_s16(e_tr[2], e_tr[6], 4);

      end_col3[0] = e_tr[3];
      end_col3[1] = vextq_s16(e_tr[3], e_tr[7], 1);
      end_col3[2] = vextq_s16(e_tr[3], e_tr[7], 2);
      end_col3[3] = vextq_s16(e_tr[3], e_tr[7], 3);
      end_col3[4] = vextq_s16(e_tr[3], e_tr[7], 4);

      // i = 1, j = 2
      add_deltas_step4(end_col0, end_col1, deltas[0]);

      // i = 1, j = 3
      add_deltas_step4(end_col0, end_col2, deltas[1]);

      // i = 1, j = 4
      add_deltas_step4(end_col0, end_col3, deltas[2]);

      // i = 2, j = 3
      add_deltas_step4(end_col1, end_col2, deltas[3]);

      // i = 2, j = 4
      add_deltas_step4(end_col1, end_col3, deltas[4]);

      // i = 3, j = 4
      add_deltas_step4(end_col2, end_col3, deltas[5]);

      d_s[0] = d_s[8];
      d_s[1] = d_s[9];
      d_s[2] = d_s[10];
      d_s[3] = d_s[11];
      d_e[0] = d_e[8];
      d_e[1] = d_e[9];
      d_e[2] = d_e[10];
      d_e[3] = d_e[11];

      d_t += 8 * d_stride;
      y -= 8;
    }

    if (h8 != height) {
      const int16x8_t mask_h = vld1q_s16(&mask_16bit[16] - (height % 8));

      load_s16_4x8(d_t + 4 * d_stride, d_stride, &d_s[4], &d_s[5], &d_s[6],
                   &d_s[7], &d_s[8], &d_s[9], &d_s[10], &d_s[11]);
      load_s16_4x8(d_t + width + 4 * d_stride, d_stride, &d_e[4], &d_e[5],
                   &d_e[6], &d_e[7], &d_e[8], &d_e[9], &d_e[10], &d_e[11]);
      int16x8_t s_tr[8], e_tr[8];
      transpose_elems_s16_4x8(d_s[0], d_s[1], d_s[2], d_s[3], d_s[4], d_s[5],
                              d_s[6], d_s[7], &s_tr[0], &s_tr[1], &s_tr[2],
                              &s_tr[3]);
      transpose_elems_s16_4x8(d_s[8], d_s[9], d_s[10], d_s[11], zeros, zeros,
                              zeros, zeros, &s_tr[4], &s_tr[5], &s_tr[6],
                              &s_tr[7]);
      transpose_elems_s16_4x8(d_e[0], d_e[1], d_e[2], d_e[3], d_e[4], d_e[5],
                              d_e[6], d_e[7], &e_tr[0], &e_tr[1], &e_tr[2],
                              &e_tr[3]);
      transpose_elems_s16_4x8(d_e[8], d_e[9], d_e[10], d_e[11], zeros, zeros,
                              zeros, zeros, &e_tr[4], &e_tr[5], &e_tr[6],
                              &e_tr[7]);

      int16x8_t start_col0[5], start_col1[5], start_col2[5], start_col3[5];
      start_col0[0] = vandq_s16(s_tr[0], mask_h);
      start_col0[1] = vandq_s16(vextq_s16(s_tr[0], s_tr[4], 1), mask_h);
      start_col0[2] = vandq_s16(vextq_s16(s_tr[0], s_tr[4], 2), mask_h);
      start_col0[3] = vandq_s16(vextq_s16(s_tr[0], s_tr[4], 3), mask_h);
      start_col0[4] = vandq_s16(vextq_s16(s_tr[0], s_tr[4], 4), mask_h);

      start_col1[0] = vandq_s16(s_tr[1], mask_h);
      start_col1[1] = vandq_s16(vextq_s16(s_tr[1], s_tr[5], 1), mask_h);
      start_col1[2] = vandq_s16(vextq_s16(s_tr[1], s_tr[5], 2), mask_h);
      start_col1[3] = vandq_s16(vextq_s16(s_tr[1], s_tr[5], 3), mask_h);
      start_col1[4] = vandq_s16(vextq_s16(s_tr[1], s_tr[5], 4), mask_h);

      start_col2[0] = vandq_s16(s_tr[2], mask_h);
      start_col2[1] = vandq_s16(vextq_s16(s_tr[2], s_tr[6], 1), mask_h);
      start_col2[2] = vandq_s16(vextq_s16(s_tr[2], s_tr[6], 2), mask_h);
      start_col2[3] = vandq_s16(vextq_s16(s_tr[2], s_tr[6], 3), mask_h);
      start_col2[4] = vandq_s16(vextq_s16(s_tr[2], s_tr[6], 4), mask_h);

      start_col3[0] = vandq_s16(s_tr[3], mask_h);
      start_col3[1] = vandq_s16(vextq_s16(s_tr[3], s_tr[7], 1), mask_h);
      start_col3[2] = vandq_s16(vextq_s16(s_tr[3], s_tr[7], 2), mask_h);
      start_col3[3] = vandq_s16(vextq_s16(s_tr[3], s_tr[7], 3), mask_h);
      start_col3[4] = vandq_s16(vextq_s16(s_tr[3], s_tr[7], 4), mask_h);

      // i = 1, j = 2
      sub_deltas_step4(start_col0, start_col1, deltas[0]);

      // i = 1, j = 3
      sub_deltas_step4(start_col0, start_col2, deltas[1]);

      // i = 1, j = 4
      sub_deltas_step4(start_col0, start_col3, deltas[2]);

      // i = 2, j = 3
      sub_deltas_step4(start_col1, start_col2, deltas[3]);

      // i = 2, j = 4
      sub_deltas_step4(start_col1, start_col3, deltas[4]);

      // i = 3, j = 4
      sub_deltas_step4(start_col2, start_col3, deltas[5]);

      int16x8_t end_col0[5], end_col1[5], end_col2[5], end_col3[5];
      end_col0[0] = vandq_s16(e_tr[0], mask_h);
      end_col0[1] = vandq_s16(vextq_s16(e_tr[0], e_tr[4], 1), mask_h);
      end_col0[2] = vandq_s16(vextq_s16(e_tr[0], e_tr[4], 2), mask_h);
      end_col0[3] = vandq_s16(vextq_s16(e_tr[0], e_tr[4], 3), mask_h);
      end_col0[4] = vandq_s16(vextq_s16(e_tr[0], e_tr[4], 4), mask_h);

      end_col1[0] = vandq_s16(e_tr[1], mask_h);
      end_col1[1] = vandq_s16(vextq_s16(e_tr[1], e_tr[5], 1), mask_h);
      end_col1[2] = vandq_s16(vextq_s16(e_tr[1], e_tr[5], 2), mask_h);
      end_col1[3] = vandq_s16(vextq_s16(e_tr[1], e_tr[5], 3), mask_h);
      end_col1[4] = vandq_s16(vextq_s16(e_tr[1], e_tr[5], 4), mask_h);

      end_col2[0] = vandq_s16(e_tr[2], mask_h);
      end_col2[1] = vandq_s16(vextq_s16(e_tr[2], e_tr[6], 1), mask_h);
      end_col2[2] = vandq_s16(vextq_s16(e_tr[2], e_tr[6], 2), mask_h);
      end_col2[3] = vandq_s16(vextq_s16(e_tr[2], e_tr[6], 3), mask_h);
      end_col2[4] = vandq_s16(vextq_s16(e_tr[2], e_tr[6], 4), mask_h);

      end_col3[0] = vandq_s16(e_tr[3], mask_h);
      end_col3[1] = vandq_s16(vextq_s16(e_tr[3], e_tr[7], 1), mask_h);
      end_col3[2] = vandq_s16(vextq_s16(e_tr[3], e_tr[7], 2), mask_h);
      end_col3[3] = vandq_s16(vextq_s16(e_tr[3], e_tr[7], 3), mask_h);
      end_col3[4] = vandq_s16(vextq_s16(e_tr[3], e_tr[7], 4), mask_h);

      // i = 1, j = 2
      add_deltas_step4(end_col0, end_col1, deltas[0]);

      // i = 1, j = 3
      add_deltas_step4(end_col0, end_col2, deltas[1]);

      // i = 1, j = 4
      add_deltas_step4(end_col0, end_col3, deltas[2]);

      // i = 2, j = 3
      add_deltas_step4(end_col1, end_col2, deltas[3]);

      // i = 2, j = 4
      add_deltas_step4(end_col1, end_col3, deltas[4]);

      // i = 3, j = 4
      add_deltas_step4(end_col2, end_col3, deltas[5]);
    }

    int32x4_t delta[6][2];
    int32_t single_delta[6];

    delta[0][0] = horizontal_add_4d_s32x4(&deltas[0][0]);
    delta[1][0] = horizontal_add_4d_s32x4(&deltas[1][0]);
    delta[2][0] = horizontal_add_4d_s32x4(&deltas[2][0]);
    delta[3][0] = horizontal_add_4d_s32x4(&deltas[3][0]);
    delta[4][0] = horizontal_add_4d_s32x4(&deltas[4][0]);
    delta[5][0] = horizontal_add_4d_s32x4(&deltas[5][0]);

    delta[0][1] = horizontal_add_4d_s32x4(&deltas[0][5]);
    delta[1][1] = horizontal_add_4d_s32x4(&deltas[1][5]);
    delta[2][1] = horizontal_add_4d_s32x4(&deltas[2][5]);
    delta[3][1] = horizontal_add_4d_s32x4(&deltas[3][5]);
    delta[4][1] = horizontal_add_4d_s32x4(&deltas[4][5]);
    delta[5][1] = horizontal_add_4d_s32x4(&deltas[5][5]);

    single_delta[0] = horizontal_add_s32x4(deltas[0][4]);
    single_delta[1] = horizontal_add_s32x4(deltas[1][4]);
    single_delta[2] = horizontal_add_s32x4(deltas[2][4]);
    single_delta[3] = horizontal_add_s32x4(deltas[3][4]);
    single_delta[4] = horizontal_add_s32x4(deltas[4][4]);
    single_delta[5] = horizontal_add_s32x4(deltas[5][4]);

    int idx = 0;
    for (i = 1; i < wiener_win - 1; i++) {
      for (j = i + 1; j < wiener_win; j++) {
        update_4_stats_neon(
            H + (i - 1) * wiener_win * wiener_win2 + (j - 1) * wiener_win,
            delta[idx][0], H + i * wiener_win * wiener_win2 + j * wiener_win);
        H[i * wiener_win * wiener_win2 + j * wiener_win + 4] =
            H[(i - 1) * wiener_win * wiener_win2 + (j - 1) * wiener_win + 4] +
            single_delta[idx];

        H[(i * wiener_win + 1) * wiener_win2 + j * wiener_win] =
            H[((i - 1) * wiener_win + 1) * wiener_win2 + (j - 1) * wiener_win] +
            vgetq_lane_s32(delta[idx][1], 0);
        H[(i * wiener_win + 2) * wiener_win2 + j * wiener_win] =
            H[((i - 1) * wiener_win + 2) * wiener_win2 + (j - 1) * wiener_win] +
            vgetq_lane_s32(delta[idx][1], 1);
        H[(i * wiener_win + 3) * wiener_win2 + j * wiener_win] =
            H[((i - 1) * wiener_win + 3) * wiener_win2 + (j - 1) * wiener_win] +
            vgetq_lane_s32(delta[idx][1], 2);
        H[(i * wiener_win + 4) * wiener_win2 + j * wiener_win] =
            H[((i - 1) * wiener_win + 4) * wiener_win2 + (j - 1) * wiener_win] +
            vgetq_lane_s32(delta[idx][1], 3);

        idx++;
      }
    }
  }

  // Step 5: Derive other points of each square. No square in bottom row.
  i = 0;
  do {
    const int16_t *const di = d + i;

    j = i + 1;
    do {
      const int16_t *const dj = d + j;
      int32x4_t deltas[WIENER_WIN_CHROMA - 1][WIENER_WIN_CHROMA - 1] = {
        { vdupq_n_s32(0) }, { vdupq_n_s32(0) }
      };
      int16x8_t d_is[WIN_CHROMA], d_ie[WIN_CHROMA];
      int16x8_t d_js[WIN_CHROMA], d_je[WIN_CHROMA];

      x = 0;
      while (x < w16) {
        load_square_win5_neon(di + x, dj + x, d_stride, height, d_is, d_ie,
                              d_js, d_je);
        derive_square_win5_neon(d_is, d_ie, d_js, d_je, deltas);
        x += 16;
      }

      if (w16 != width) {
        load_square_win5_neon(di + x, dj + x, d_stride, height, d_is, d_ie,
                              d_js, d_je);
        d_is[0] = vandq_s16(d_is[0], mask[0]);
        d_is[1] = vandq_s16(d_is[1], mask[1]);
        d_is[2] = vandq_s16(d_is[2], mask[0]);
        d_is[3] = vandq_s16(d_is[3], mask[1]);
        d_is[4] = vandq_s16(d_is[4], mask[0]);
        d_is[5] = vandq_s16(d_is[5], mask[1]);
        d_is[6] = vandq_s16(d_is[6], mask[0]);
        d_is[7] = vandq_s16(d_is[7], mask[1]);
        d_ie[0] = vandq_s16(d_ie[0], mask[0]);
        d_ie[1] = vandq_s16(d_ie[1], mask[1]);
        d_ie[2] = vandq_s16(d_ie[2], mask[0]);
        d_ie[3] = vandq_s16(d_ie[3], mask[1]);
        d_ie[4] = vandq_s16(d_ie[4], mask[0]);
        d_ie[5] = vandq_s16(d_ie[5], mask[1]);
        d_ie[6] = vandq_s16(d_ie[6], mask[0]);
        d_ie[7] = vandq_s16(d_ie[7], mask[1]);
        derive_square_win5_neon(d_is, d_ie, d_js, d_je, deltas);
      }

      hadd_update_4_stats_neon(
          H + (i * wiener_win + 0) * wiener_win2 + j * wiener_win, deltas[0],
          H + (i * wiener_win + 1) * wiener_win2 + j * wiener_win + 1);
      hadd_update_4_stats_neon(
          H + (i * wiener_win + 1) * wiener_win2 + j * wiener_win, deltas[1],
          H + (i * wiener_win + 2) * wiener_win2 + j * wiener_win + 1);
      hadd_update_4_stats_neon(
          H + (i * wiener_win + 2) * wiener_win2 + j * wiener_win, deltas[2],
          H + (i * wiener_win + 3) * wiener_win2 + j * wiener_win + 1);
      hadd_update_4_stats_neon(
          H + (i * wiener_win + 3) * wiener_win2 + j * wiener_win, deltas[3],
          H + (i * wiener_win + 4) * wiener_win2 + j * wiener_win + 1);
    } while (++j < wiener_win);
  } while (++i < wiener_win - 1);

  // Step 6: Derive other points of each upper triangle along the diagonal.
  i = 0;
  do {
    const int16_t *const di = d + i;
    int32x4_t deltas[WIENER_WIN_CHROMA * 2 + 1] = { vdupq_n_s32(0) };
    int16x8_t d_is[WIN_CHROMA], d_ie[WIN_CHROMA];

    x = 0;
    while (x < w16) {
      load_triangle_win5_neon(di + x, d_stride, height, d_is, d_ie);
      derive_triangle_win5_neon(d_is, d_ie, deltas);
      x += 16;
    }

    if (w16 != width) {
      load_triangle_win5_neon(di + x, d_stride, height, d_is, d_ie);
      d_is[0] = vandq_s16(d_is[0], mask[0]);
      d_is[1] = vandq_s16(d_is[1], mask[1]);
      d_is[2] = vandq_s16(d_is[2], mask[0]);
      d_is[3] = vandq_s16(d_is[3], mask[1]);
      d_is[4] = vandq_s16(d_is[4], mask[0]);
      d_is[5] = vandq_s16(d_is[5], mask[1]);
      d_is[6] = vandq_s16(d_is[6], mask[0]);
      d_is[7] = vandq_s16(d_is[7], mask[1]);
      d_ie[0] = vandq_s16(d_ie[0], mask[0]);
      d_ie[1] = vandq_s16(d_ie[1], mask[1]);
      d_ie[2] = vandq_s16(d_ie[2], mask[0]);
      d_ie[3] = vandq_s16(d_ie[3], mask[1]);
      d_ie[4] = vandq_s16(d_ie[4], mask[0]);
      d_ie[5] = vandq_s16(d_ie[5], mask[1]);
      d_ie[6] = vandq_s16(d_ie[6], mask[0]);
      d_ie[7] = vandq_s16(d_ie[7], mask[1]);
      derive_triangle_win5_neon(d_is, d_ie, deltas);
    }

    // Row 1: 4 points
    hadd_update_4_stats_neon(
        H + (i * wiener_win + 0) * wiener_win2 + i * wiener_win, deltas,
        H + (i * wiener_win + 1) * wiener_win2 + i * wiener_win + 1);

    // Row 2: 3 points
    int64x2_t delta4_s64 = vpaddlq_s32(deltas[4]);
    int64x2_t delta5_s64 = vpaddlq_s32(deltas[5]);

#if AOM_ARCH_AARCH64
    int64x2_t deltas45 = vpaddq_s64(delta4_s64, delta5_s64);
    int64x2_t src =
        vld1q_s64(H + (i * wiener_win + 1) * wiener_win2 + i * wiener_win + 1);
    int64x2_t dst = vaddq_s64(src, deltas45);
    vst1q_s64(H + (i * wiener_win + 2) * wiener_win2 + i * wiener_win + 2, dst);
#else
    H[(i * wiener_win + 2) * wiener_win2 + i * wiener_win + 2 + 0] =
        H[(i * wiener_win + 1) * wiener_win2 + i * wiener_win + 1 + 0] +
        horizontal_add_s64x2(delta4_s64);
    H[(i * wiener_win + 2) * wiener_win2 + i * wiener_win + 2 + 1] =
        H[(i * wiener_win + 1) * wiener_win2 + i * wiener_win + 1 + 1] +
        horizontal_add_s64x2(delta5_s64);
#endif  // AOM_ARCH_AARCH64

    H[(i * wiener_win + 2) * wiener_win2 + i * wiener_win + 4] =
        H[(i * wiener_win + 1) * wiener_win2 + i * wiener_win + 3] +
        horizontal_long_add_s32x4(deltas[6]);

    // Row 3: 2 points
    int64x2_t delta7_s64 = vpaddlq_s32(deltas[7]);
    int64x2_t delta8_s64 = vpaddlq_s32(deltas[8]);

#if AOM_ARCH_AARCH64
    int64x2_t deltas78 = vpaddq_s64(delta7_s64, delta8_s64);
    vst1q_s64(H + (i * wiener_win + 3) * wiener_win2 + i * wiener_win + 3,
              vaddq_s64(dst, deltas78));
#else
    H[(i * wiener_win + 3) * wiener_win2 + i * wiener_win + 3 + 0] =
        H[(i * wiener_win + 2) * wiener_win2 + i * wiener_win + 2 + 0] +
        horizontal_add_s64x2(delta7_s64);
    H[(i * wiener_win + 3) * wiener_win2 + i * wiener_win + 3 + 1] =
        H[(i * wiener_win + 2) * wiener_win2 + i * wiener_win + 2 + 1] +
        horizontal_add_s64x2(delta8_s64);
#endif  // AOM_ARCH_AARCH64

    // Row 4: 1 point
    H[(i * wiener_win + 4) * wiener_win2 + i * wiener_win + 4] =
        H[(i * wiener_win + 3) * wiener_win2 + i * wiener_win + 3] +
        horizontal_long_add_s32x4(deltas[9]);
  } while (++i < wiener_win);
}

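// As hadd_update_4_stats_neon above, but reduces six delta accumulators into
// six consecutive stats.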
static inline void hadd_update_6_stats_neon(const int64_t *const src,
                                            const int32x4_t *deltas,
                                            int64_t *const dst) {
  int64x2_t delta0_s64 = vpaddlq_s32(deltas[0]);
  int64x2_t delta1_s64 = vpaddlq_s32(deltas[1]);
  int64x2_t delta2_s64 = vpaddlq_s32(deltas[2]);
  int64x2_t delta3_s64 = vpaddlq_s32(deltas[3]);
  int64x2_t delta4_s64 = vpaddlq_s32(deltas[4]);
  int64x2_t delta5_s64 = vpaddlq_s32(deltas[5]);

#if AOM_ARCH_AARCH64
  int64x2_t delta01 = vpaddq_s64(delta0_s64, delta1_s64);
  int64x2_t delta23 = vpaddq_s64(delta2_s64, delta3_s64);
  int64x2_t delta45 = vpaddq_s64(delta4_s64, delta5_s64);

  int64x2_t src0 = vld1q_s64(src);
  int64x2_t src1 = vld1q_s64(src + 2);
  int64x2_t src2 = vld1q_s64(src + 4);

  vst1q_s64(dst, vaddq_s64(src0, delta01));
  vst1q_s64(dst + 2, vaddq_s64(src1, delta23));
  vst1q_s64(dst + 4, vaddq_s64(src2, delta45));
#else
  dst[0] = src[0] + horizontal_add_s64x2(delta0_s64);
  dst[1] = src[1] + horizontal_add_s64x2(delta1_s64);
  dst[2] = src[2] + horizontal_add_s64x2(delta2_s64);
  dst[3] = src[3] + horizontal_add_s64x2(delta3_s64);
  dst[4] = src[4] + horizontal_add_s64x2(delta4_s64);
  dst[5] = src[5] + horizontal_add_s64x2(delta5_s64);
#endif  // AOM_ARCH_AARCH64
}

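// The 7x7 (luma) counterpart of compute_stats_win5_highbd_neon: the same
// step-by-step decomposition of M and H is used, only the window size
// differs.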
static inline void compute_stats_win7_highbd_neon(
    const int16_t *const d, const int32_t d_stride, const int16_t *const s,
    const int32_t s_stride, const int32_t width, const int32_t height,
    int64_t *const M, int64_t *const H, aom_bit_depth_t bit_depth) {
  const int32_t wiener_win = WIENER_WIN;
  const int32_t wiener_win2 = wiener_win * wiener_win;
  const int32_t w16 = width & ~15;
  const int32_t h8 = height & ~7;
  int16x8_t mask[2];
  mask[0] = vld1q_s16(&(mask_16bit[16]) - width % 16);
  mask[1] = vld1q_s16(&(mask_16bit[16]) - width % 16 + 8);
  int32_t i, j, x, y;

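  // Same row batching as in compute_stats_win5_highbd_neon: bound h_allowed
  // so the 32-bit accumulators cannot overflow before being widened to 64
  // bits.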
  const int32_t num_bit_left =
      32 - 1 /* sign */ - 2 * bit_depth /* energy */ + 2 /* SIMD */;
  const int32_t h_allowed =
      (1 << num_bit_left) / (w16 + ((w16 != width) ? 16 : 0));

  // Step 1: Calculate the top edge of the whole matrix, i.e., the top
  // edge of each triangle and square on the top row.
  j = 0;
  do {
    const int16_t *s_t = s;
    const int16_t *d_t = d;
    int32_t height_t = 0;
    int64x2_t sum_m[WIENER_WIN] = { vdupq_n_s64(0) };
    int64x2_t sum_h[WIENER_WIN] = { vdupq_n_s64(0) };
    int16x8_t src[2], dgd[2];

    do {
      const int32_t h_t =
          ((height - height_t) < h_allowed) ? (height - height_t) : h_allowed;
      int32x4_t row_m[WIENER_WIN * 2] = { vdupq_n_s32(0) };
      int32x4_t row_h[WIENER_WIN * 2] = { vdupq_n_s32(0) };

      y = h_t;
      do {
        x = 0;
        while (x < w16) {
          src[0] = vld1q_s16(s_t + x);
          src[1] = vld1q_s16(s_t + x + 8);
          dgd[0] = vld1q_s16(d_t + x);
          dgd[1] = vld1q_s16(d_t + x + 8);
          stats_top_win7_neon(src, dgd, d_t + j + x, d_stride, row_m, row_h);
          x += 16;
        }

        if (w16 != width) {
          src[0] = vld1q_s16(s_t + w16);
          src[1] = vld1q_s16(s_t + w16 + 8);
          dgd[0] = vld1q_s16(d_t + w16);
          dgd[1] = vld1q_s16(d_t + w16 + 8);
          src[0] = vandq_s16(src[0], mask[0]);
          src[1] = vandq_s16(src[1], mask[1]);
          dgd[0] = vandq_s16(dgd[0], mask[0]);
          dgd[1] = vandq_s16(dgd[1], mask[1]);
          stats_top_win7_neon(src, dgd, d_t + j + w16, d_stride, row_m, row_h);
        }

        s_t += s_stride;
        d_t += d_stride;
      } while (--y);

      sum_m[0] = vpadalq_s32(sum_m[0], row_m[0]);
      sum_m[1] = vpadalq_s32(sum_m[1], row_m[1]);
      sum_m[2] = vpadalq_s32(sum_m[2], row_m[2]);
      sum_m[3] = vpadalq_s32(sum_m[3], row_m[3]);
      sum_m[4] = vpadalq_s32(sum_m[4], row_m[4]);
      sum_m[5] = vpadalq_s32(sum_m[5], row_m[5]);
      sum_m[6] = vpadalq_s32(sum_m[6], row_m[6]);

      sum_h[0] = vpadalq_s32(sum_h[0], row_h[0]);
      sum_h[1] = vpadalq_s32(sum_h[1], row_h[1]);
      sum_h[2] = vpadalq_s32(sum_h[2], row_h[2]);
      sum_h[3] = vpadalq_s32(sum_h[3], row_h[3]);
      sum_h[4] = vpadalq_s32(sum_h[4], row_h[4]);
      sum_h[5] = vpadalq_s32(sum_h[5], row_h[5]);
      sum_h[6] = vpadalq_s32(sum_h[6], row_h[6]);

      height_t += h_t;
    } while (height_t < height);

#if AOM_ARCH_AARCH64
    vst1q_s64(M + wiener_win * j + 0, vpaddq_s64(sum_m[0], sum_m[1]));
    vst1q_s64(M + wiener_win * j + 2, vpaddq_s64(sum_m[2], sum_m[3]));
    vst1q_s64(M + wiener_win * j + 4, vpaddq_s64(sum_m[4], sum_m[5]));
    M[wiener_win * j + 6] = vaddvq_s64(sum_m[6]);

    vst1q_s64(H + wiener_win * j + 0, vpaddq_s64(sum_h[0], sum_h[1]));
    vst1q_s64(H + wiener_win * j + 2, vpaddq_s64(sum_h[2], sum_h[3]));
    vst1q_s64(H + wiener_win * j + 4, vpaddq_s64(sum_h[4], sum_h[5]));
    H[wiener_win * j + 6] = vaddvq_s64(sum_h[6]);
#else
    M[wiener_win * j + 0] = horizontal_add_s64x2(sum_m[0]);
    M[wiener_win * j + 1] = horizontal_add_s64x2(sum_m[1]);
    M[wiener_win * j + 2] = horizontal_add_s64x2(sum_m[2]);
    M[wiener_win * j + 3] = horizontal_add_s64x2(sum_m[3]);
    M[wiener_win * j + 4] = horizontal_add_s64x2(sum_m[4]);
    M[wiener_win * j + 5] = horizontal_add_s64x2(sum_m[5]);
    M[wiener_win * j + 6] = horizontal_add_s64x2(sum_m[6]);

    H[wiener_win * j + 0] = horizontal_add_s64x2(sum_h[0]);
    H[wiener_win * j + 1] = horizontal_add_s64x2(sum_h[1]);
    H[wiener_win * j + 2] = horizontal_add_s64x2(sum_h[2]);
    H[wiener_win * j + 3] = horizontal_add_s64x2(sum_h[3]);
    H[wiener_win * j + 4] = horizontal_add_s64x2(sum_h[4]);
    H[wiener_win * j + 5] = horizontal_add_s64x2(sum_h[5]);
    H[wiener_win * j + 6] = horizontal_add_s64x2(sum_h[6]);
#endif  // AOM_ARCH_AARCH64
  } while (++j < wiener_win);

  // Step 2: Calculate the left edge of each square on the top row.
  j = 1;
  do {
    const int16_t *d_t = d;
    int32_t height_t = 0;
    int64x2_t sum_h[WIENER_WIN - 1] = { vdupq_n_s64(0) };
    int16x8_t dgd[2];

    do {
      const int32_t h_t =
          ((height - height_t) < h_allowed) ? (height - height_t) : h_allowed;
      int32x4_t row_h[WIENER_WIN - 1] = { vdupq_n_s32(0) };

      y = h_t;
      do {
        x = 0;
        while (x < w16) {
          dgd[0] = vld1q_s16(d_t + j + x + 0);
          dgd[1] = vld1q_s16(d_t + j + x + 8);
          stats_left_win7_neon(dgd, d_t + x, d_stride, row_h);
          x += 16;
        }

        if (w16 != width) {
          dgd[0] = vld1q_s16(d_t + j + x + 0);
          dgd[1] = vld1q_s16(d_t + j + x + 8);
          dgd[0] = vandq_s16(dgd[0], mask[0]);
          dgd[1] = vandq_s16(dgd[1], mask[1]);
          stats_left_win7_neon(dgd, d_t + x, d_stride, row_h);
        }

        d_t += d_stride;
      } while (--y);

      sum_h[0] = vpadalq_s32(sum_h[0], row_h[0]);
      sum_h[1] = vpadalq_s32(sum_h[1], row_h[1]);
      sum_h[2] = vpadalq_s32(sum_h[2], row_h[2]);
      sum_h[3] = vpadalq_s32(sum_h[3], row_h[3]);
      sum_h[4] = vpadalq_s32(sum_h[4], row_h[4]);
      sum_h[5] = vpadalq_s32(sum_h[5], row_h[5]);

      height_t += h_t;
    } while (height_t < height);

#if AOM_ARCH_AARCH64
    int64x2_t sum_h0 = vpaddq_s64(sum_h[0], sum_h[1]);
    int64x2_t sum_h2 = vpaddq_s64(sum_h[2], sum_h[3]);
    int64x2_t sum_h4 = vpaddq_s64(sum_h[4], sum_h[5]);
    vst1_s64(&H[1 * wiener_win2 + j * wiener_win], vget_low_s64(sum_h0));
    vst1_s64(&H[2 * wiener_win2 + j * wiener_win], vget_high_s64(sum_h0));
    vst1_s64(&H[3 * wiener_win2 + j * wiener_win], vget_low_s64(sum_h2));
    vst1_s64(&H[4 * wiener_win2 + j * wiener_win], vget_high_s64(sum_h2));
    vst1_s64(&H[5 * wiener_win2 + j * wiener_win], vget_low_s64(sum_h4));
    vst1_s64(&H[6 * wiener_win2 + j * wiener_win], vget_high_s64(sum_h4));
#else
    H[1 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[0]);
    H[2 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[1]);
    H[3 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[2]);
    H[4 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[3]);
    H[5 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[4]);
    H[6 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[5]);
#endif  // AOM_ARCH_AARCH64

  } while (++j < wiener_win);

  // Step 3: Derive the top edge of each triangle along the diagonal. No
  // triangle in top row.
  {
    const int16_t *d_t = d;
    // Pad to call transpose function.
    int32x4_t deltas[(WIENER_WIN + 1) * 2] = { vdupq_n_s32(0) };
    int32x4_t deltas_tr[(WIENER_WIN + 1) * 2] = { vdupq_n_s32(0) };
    int16x8_t ds[WIENER_WIN * 2];

    load_s16_8x6(d_t, d_stride, &ds[0], &ds[2], &ds[4], &ds[6], &ds[8],
                 &ds[10]);
    load_s16_8x6(d_t + width, d_stride, &ds[1], &ds[3], &ds[5], &ds[7], &ds[9],
                 &ds[11]);

    d_t += 6 * d_stride;

    step3_win7_neon(d_t, d_stride, width, height, ds, deltas);
    transpose_arrays_s32_8x8(deltas, deltas_tr);

    update_8_stats_neon(H + 0 * wiener_win * wiener_win2 + 0 * wiener_win,
                        deltas_tr[0], deltas_tr[4],
                        H + 1 * wiener_win * wiener_win2 + 1 * wiener_win);
    update_8_stats_neon(H + 1 * wiener_win * wiener_win2 + 1 * wiener_win,
                        deltas_tr[1], deltas_tr[5],
                        H + 2 * wiener_win * wiener_win2 + 2 * wiener_win);
    update_8_stats_neon(H + 2 * wiener_win * wiener_win2 + 2 * wiener_win,
                        deltas_tr[2], deltas_tr[6],
                        H + 3 * wiener_win * wiener_win2 + 3 * wiener_win);
    update_8_stats_neon(H + 3 * wiener_win * wiener_win2 + 3 * wiener_win,
                        deltas_tr[3], deltas_tr[7],
                        H + 4 * wiener_win * wiener_win2 + 4 * wiener_win);
    update_8_stats_neon(H + 4 * wiener_win * wiener_win2 + 4 * wiener_win,
                        deltas_tr[8], deltas_tr[12],
                        H + 5 * wiener_win * wiener_win2 + 5 * wiener_win);
    update_8_stats_neon(H + 5 * wiener_win * wiener_win2 + 5 * wiener_win,
                        deltas_tr[9], deltas_tr[13],
                        H + 6 * wiener_win * wiener_win2 + 6 * wiener_win);
  }

  // Step 4: Derive the top and left edge of each square. No square in top and
  // bottom row.

  i = 1;
  do {
    j = i + 1;
    do {
      const int16_t *di = d + i - 1;
      const int16_t *dj = d + j - 1;
      int32x4_t deltas[(2 * WIENER_WIN - 1) * 2] = { vdupq_n_s32(0) };
      int16x8_t dd[WIENER_WIN * 2], ds[WIENER_WIN * 2];

      dd[5] = vdupq_n_s16(0);  // Initialize to avoid warning.
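      // Preload the first six rows of the start and end columns of di and
      // dj; lanes 6 and 7 are filled in at the top of the loop below.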
1226       const int16_t dd0_values[] = { di[0 * d_stride],
1227                                      di[1 * d_stride],
1228                                      di[2 * d_stride],
1229                                      di[3 * d_stride],
1230                                      di[4 * d_stride],
1231                                      di[5 * d_stride],
1232                                      0,
1233                                      0 };
1234       dd[0] = vld1q_s16(dd0_values);
1235       const int16_t dd1_values[] = { di[0 * d_stride + width],
1236                                      di[1 * d_stride + width],
1237                                      di[2 * d_stride + width],
1238                                      di[3 * d_stride + width],
1239                                      di[4 * d_stride + width],
1240                                      di[5 * d_stride + width],
1241                                      0,
1242                                      0 };
1243       dd[1] = vld1q_s16(dd1_values);
1244       const int16_t ds0_values[] = { dj[0 * d_stride],
1245                                      dj[1 * d_stride],
1246                                      dj[2 * d_stride],
1247                                      dj[3 * d_stride],
1248                                      dj[4 * d_stride],
1249                                      dj[5 * d_stride],
1250                                      0,
1251                                      0 };
1252       ds[0] = vld1q_s16(ds0_values);
1253       const int16_t ds1_values[] = { dj[0 * d_stride + width],
1254                                      dj[1 * d_stride + width],
1255                                      dj[2 * d_stride + width],
1256                                      dj[3 * d_stride + width],
1257                                      dj[4 * d_stride + width],
1258                                      dj[5 * d_stride + width],
1259                                      0,
1260                                      0 };
1261       ds[1] = vld1q_s16(ds1_values);
1262 
1263       y = 0;
1264       while (y < h8) {
1265         // 00s 10s 20s 30s 40s 50s 60s 70s  00e 10e 20e 30e 40e 50e 60e 70e
1266         dd[0] = vsetq_lane_s16(di[6 * d_stride], dd[0], 6);
1267         dd[0] = vsetq_lane_s16(di[7 * d_stride], dd[0], 7);
1268         dd[1] = vsetq_lane_s16(di[6 * d_stride + width], dd[1], 6);
1269         dd[1] = vsetq_lane_s16(di[7 * d_stride + width], dd[1], 7);
1270 
1271         // 00s 10s 20s 30s 40s 50s 60s 70s  00e 10e 20e 30e 40e 50e 60e 70e
1272         // 01s 11s 21s 31s 41s 51s 61s 71s  01e 11e 21e 31e 41e 51e 61e 71e
1273         ds[0] = vsetq_lane_s16(dj[6 * d_stride], ds[0], 6);
1274         ds[0] = vsetq_lane_s16(dj[7 * d_stride], ds[0], 7);
1275         ds[1] = vsetq_lane_s16(dj[6 * d_stride + width], ds[1], 6);
1276         ds[1] = vsetq_lane_s16(dj[7 * d_stride + width], ds[1], 7);
1277 
1278         load_more_16_neon(di + 8 * d_stride, width, &dd[0], &dd[2]);
1279         load_more_16_neon(dj + 8 * d_stride, width, &ds[0], &ds[2]);
1280         load_more_16_neon(di + 9 * d_stride, width, &dd[2], &dd[4]);
1281         load_more_16_neon(dj + 9 * d_stride, width, &ds[2], &ds[4]);
1282         load_more_16_neon(di + 10 * d_stride, width, &dd[4], &dd[6]);
1283         load_more_16_neon(dj + 10 * d_stride, width, &ds[4], &ds[6]);
1284         load_more_16_neon(di + 11 * d_stride, width, &dd[6], &dd[8]);
1285         load_more_16_neon(dj + 11 * d_stride, width, &ds[6], &ds[8]);
1286         load_more_16_neon(di + 12 * d_stride, width, &dd[8], &dd[10]);
1287         load_more_16_neon(dj + 12 * d_stride, width, &ds[8], &ds[10]);
1288         load_more_16_neon(di + 13 * d_stride, width, &dd[10], &dd[12]);
1289         load_more_16_neon(dj + 13 * d_stride, width, &ds[10], &ds[12]);
1290 
1291         madd_neon(&deltas[0], dd[0], ds[0]);
1292         madd_neon(&deltas[1], dd[1], ds[1]);
1293         madd_neon(&deltas[2], dd[0], ds[2]);
1294         madd_neon(&deltas[3], dd[1], ds[3]);
1295         madd_neon(&deltas[4], dd[0], ds[4]);
1296         madd_neon(&deltas[5], dd[1], ds[5]);
1297         madd_neon(&deltas[6], dd[0], ds[6]);
1298         madd_neon(&deltas[7], dd[1], ds[7]);
1299         madd_neon(&deltas[8], dd[0], ds[8]);
1300         madd_neon(&deltas[9], dd[1], ds[9]);
1301         madd_neon(&deltas[10], dd[0], ds[10]);
1302         madd_neon(&deltas[11], dd[1], ds[11]);
1303         madd_neon(&deltas[12], dd[0], ds[12]);
1304         madd_neon(&deltas[13], dd[1], ds[13]);
1305         madd_neon(&deltas[14], dd[2], ds[0]);
1306         madd_neon(&deltas[15], dd[3], ds[1]);
1307         madd_neon(&deltas[16], dd[4], ds[0]);
1308         madd_neon(&deltas[17], dd[5], ds[1]);
1309         madd_neon(&deltas[18], dd[6], ds[0]);
1310         madd_neon(&deltas[19], dd[7], ds[1]);
1311         madd_neon(&deltas[20], dd[8], ds[0]);
1312         madd_neon(&deltas[21], dd[9], ds[1]);
1313         madd_neon(&deltas[22], dd[10], ds[0]);
1314         madd_neon(&deltas[23], dd[11], ds[1]);
1315         madd_neon(&deltas[24], dd[12], ds[0]);
1316         madd_neon(&deltas[25], dd[13], ds[1]);
1317 
1318         dd[0] = vextq_s16(dd[12], vdupq_n_s16(0), 2);
1319         dd[1] = vextq_s16(dd[13], vdupq_n_s16(0), 2);
1320         ds[0] = vextq_s16(ds[12], vdupq_n_s16(0), 2);
1321         ds[1] = vextq_s16(ds[13], vdupq_n_s16(0), 2);
1322 
1323         di += 8 * d_stride;
1324         dj += 8 * d_stride;
1325         y += 8;
1326       }
1327 
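      // Even-indexed accumulators hold products sampled at the start column
      // (the ...s lanes above), odd-indexed ones at the end column offset by
      // `width` (the ...e lanes); subtracting start from end below gives the
      // net delta for one diagonal step. Where fewer than four distinct
      // accumulators remain, an argument to the 4-way horizontal add is
      // repeated; the duplicated lane only feeds an element that is
      // overwritten later anyway.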
1328       deltas[0] = hadd_four_32_neon(deltas[0], deltas[2], deltas[4], deltas[6]);
1329       deltas[1] = hadd_four_32_neon(deltas[1], deltas[3], deltas[5], deltas[7]);
1330       deltas[2] =
1331           hadd_four_32_neon(deltas[8], deltas[10], deltas[12], deltas[12]);
1332       deltas[3] =
1333           hadd_four_32_neon(deltas[9], deltas[11], deltas[13], deltas[13]);
1334       deltas[4] =
1335           hadd_four_32_neon(deltas[14], deltas[16], deltas[18], deltas[20]);
1336       deltas[5] =
1337           hadd_four_32_neon(deltas[15], deltas[17], deltas[19], deltas[21]);
1338       deltas[6] =
1339           hadd_four_32_neon(deltas[22], deltas[24], deltas[22], deltas[24]);
1340       deltas[7] =
1341           hadd_four_32_neon(deltas[23], deltas[25], deltas[23], deltas[25]);
1342       deltas[0] = vsubq_s32(deltas[1], deltas[0]);
1343       deltas[1] = vsubq_s32(deltas[3], deltas[2]);
1344       deltas[2] = vsubq_s32(deltas[5], deltas[4]);
1345       deltas[3] = vsubq_s32(deltas[7], deltas[6]);
1346 
1347       if (h8 != height) {
1348         const int16_t ds0_vals[] = {
1349           dj[0 * d_stride], dj[0 * d_stride + width],
1350           dj[1 * d_stride], dj[1 * d_stride + width],
1351           dj[2 * d_stride], dj[2 * d_stride + width],
1352           dj[3 * d_stride], dj[3 * d_stride + width]
1353         };
1354         ds[0] = vld1q_s16(ds0_vals);
1355 
1356         ds[1] = vsetq_lane_s16(dj[4 * d_stride], ds[1], 0);
1357         ds[1] = vsetq_lane_s16(dj[4 * d_stride + width], ds[1], 1);
1358         ds[1] = vsetq_lane_s16(dj[5 * d_stride], ds[1], 2);
1359         ds[1] = vsetq_lane_s16(dj[5 * d_stride + width], ds[1], 3);
1360         const int16_t dd4_vals[] = {
1361           -di[1 * d_stride], di[1 * d_stride + width],
1362           -di[2 * d_stride], di[2 * d_stride + width],
1363           -di[3 * d_stride], di[3 * d_stride + width],
1364           -di[4 * d_stride], di[4 * d_stride + width]
1365         };
1366         dd[4] = vld1q_s16(dd4_vals);
1367 
1368         dd[5] = vsetq_lane_s16(-di[5 * d_stride], dd[5], 0);
1369         dd[5] = vsetq_lane_s16(di[5 * d_stride + width], dd[5], 1);
1370         do {
1371           dd[0] = vdupq_n_s16(-di[0 * d_stride]);
1372           dd[2] = dd[3] = vdupq_n_s16(di[0 * d_stride + width]);
1373           dd[0] = dd[1] = vzipq_s16(dd[0], dd[2]).val[0];
1374 
1375           ds[4] = vdupq_n_s16(dj[0 * d_stride]);
1376           ds[6] = ds[7] = vdupq_n_s16(dj[0 * d_stride + width]);
1377           ds[4] = ds[5] = vzipq_s16(ds[4], ds[6]).val[0];
1378 
1379           dd[5] = vsetq_lane_s16(-di[6 * d_stride], dd[5], 2);
1380           dd[5] = vsetq_lane_s16(di[6 * d_stride + width], dd[5], 3);
1381           ds[1] = vsetq_lane_s16(dj[6 * d_stride], ds[1], 4);
1382           ds[1] = vsetq_lane_s16(dj[6 * d_stride + width], ds[1], 5);
1383 
1384           madd_neon_pairwise(&deltas[0], dd[0], ds[0]);
1385           madd_neon_pairwise(&deltas[1], dd[1], ds[1]);
1386           madd_neon_pairwise(&deltas[2], dd[4], ds[4]);
1387           madd_neon_pairwise(&deltas[3], dd[5], ds[5]);
1388 
1389           int32_t tmp0 = vgetq_lane_s32(vreinterpretq_s32_s16(ds[0]), 0);
1390           ds[0] = vextq_s16(ds[0], ds[1], 2);
1391           ds[1] = vextq_s16(ds[1], ds[0], 2);
1392           ds[1] = vreinterpretq_s16_s32(
1393               vsetq_lane_s32(tmp0, vreinterpretq_s32_s16(ds[1]), 3));
1394           int32_t tmp1 = vgetq_lane_s32(vreinterpretq_s32_s16(dd[4]), 0);
1395           dd[4] = vextq_s16(dd[4], dd[5], 2);
1396           dd[5] = vextq_s16(dd[5], dd[4], 2);
1397           dd[5] = vreinterpretq_s16_s32(
1398               vsetq_lane_s32(tmp1, vreinterpretq_s32_s16(dd[5]), 3));
1399           di += d_stride;
1400           dj += d_stride;
1401         } while (++y < height);
1402       }
1403 
1404       // Writing one extra element on the top edge of a square spills into
1405       // the next square in the same row (or the first element of the next
1406       // row); that value is simply overwritten later.
1407       update_8_stats_neon(
1408           H + (i - 1) * wiener_win * wiener_win2 + (j - 1) * wiener_win,
1409           deltas[0], deltas[1],
1410           H + i * wiener_win * wiener_win2 + j * wiener_win);
1411 
1412       H[(i * wiener_win + 1) * wiener_win2 + j * wiener_win] =
1413           H[((i - 1) * wiener_win + 1) * wiener_win2 + (j - 1) * wiener_win] +
1414           vgetq_lane_s32(deltas[2], 0);
1415       H[(i * wiener_win + 2) * wiener_win2 + j * wiener_win] =
1416           H[((i - 1) * wiener_win + 2) * wiener_win2 + (j - 1) * wiener_win] +
1417           vgetq_lane_s32(deltas[2], 1);
1418       H[(i * wiener_win + 3) * wiener_win2 + j * wiener_win] =
1419           H[((i - 1) * wiener_win + 3) * wiener_win2 + (j - 1) * wiener_win] +
1420           vgetq_lane_s32(deltas[2], 2);
1421       H[(i * wiener_win + 4) * wiener_win2 + j * wiener_win] =
1422           H[((i - 1) * wiener_win + 4) * wiener_win2 + (j - 1) * wiener_win] +
1423           vgetq_lane_s32(deltas[2], 3);
1424       H[(i * wiener_win + 5) * wiener_win2 + j * wiener_win] =
1425           H[((i - 1) * wiener_win + 5) * wiener_win2 + (j - 1) * wiener_win] +
1426           vgetq_lane_s32(deltas[3], 0);
1427       H[(i * wiener_win + 6) * wiener_win2 + j * wiener_win] =
1428           H[((i - 1) * wiener_win + 6) * wiener_win2 + (j - 1) * wiener_win] +
1429           vgetq_lane_s32(deltas[3], 1);
1430     } while (++j < wiener_win);
1431   } while (++i < wiener_win - 1);
1432 
1433   // Step 5: Derive other points of each square. No square in bottom row.
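  // These interior points are computed directly: 16 columns at a time while a
  // full 16 remain, then one masked pass in which mask[] zeroes the lanes
  // beyond `width` so partial vectors do not contaminate the sums.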
1434   i = 0;
1435   do {
1436     const int16_t *const di = d + i;
1437 
1438     j = i + 1;
1439     do {
1440       const int16_t *const dj = d + j;
1441       int32x4_t deltas[WIENER_WIN - 1][WIN_7] = { { vdupq_n_s32(0) },
1442                                                   { vdupq_n_s32(0) } };
1443       int16x8_t d_is[WIN_7];
1444       int16x8_t d_ie[WIN_7];
1445       int16x8_t d_js[WIN_7];
1446       int16x8_t d_je[WIN_7];
1447 
1448       x = 0;
1449       while (x < w16) {
1450         load_square_win7_neon(di + x, dj + x, d_stride, height, d_is, d_ie,
1451                               d_js, d_je);
1452         derive_square_win7_neon(d_is, d_ie, d_js, d_je, deltas);
1453         x += 16;
1454       }
1455 
1456       if (w16 != width) {
1457         load_square_win7_neon(di + x, dj + x, d_stride, height, d_is, d_ie,
1458                               d_js, d_je);
1459         d_is[0] = vandq_s16(d_is[0], mask[0]);
1460         d_is[1] = vandq_s16(d_is[1], mask[1]);
1461         d_is[2] = vandq_s16(d_is[2], mask[0]);
1462         d_is[3] = vandq_s16(d_is[3], mask[1]);
1463         d_is[4] = vandq_s16(d_is[4], mask[0]);
1464         d_is[5] = vandq_s16(d_is[5], mask[1]);
1465         d_is[6] = vandq_s16(d_is[6], mask[0]);
1466         d_is[7] = vandq_s16(d_is[7], mask[1]);
1467         d_is[8] = vandq_s16(d_is[8], mask[0]);
1468         d_is[9] = vandq_s16(d_is[9], mask[1]);
1469         d_is[10] = vandq_s16(d_is[10], mask[0]);
1470         d_is[11] = vandq_s16(d_is[11], mask[1]);
1471         d_ie[0] = vandq_s16(d_ie[0], mask[0]);
1472         d_ie[1] = vandq_s16(d_ie[1], mask[1]);
1473         d_ie[2] = vandq_s16(d_ie[2], mask[0]);
1474         d_ie[3] = vandq_s16(d_ie[3], mask[1]);
1475         d_ie[4] = vandq_s16(d_ie[4], mask[0]);
1476         d_ie[5] = vandq_s16(d_ie[5], mask[1]);
1477         d_ie[6] = vandq_s16(d_ie[6], mask[0]);
1478         d_ie[7] = vandq_s16(d_ie[7], mask[1]);
1479         d_ie[8] = vandq_s16(d_ie[8], mask[0]);
1480         d_ie[9] = vandq_s16(d_ie[9], mask[1]);
1481         d_ie[10] = vandq_s16(d_ie[10], mask[0]);
1482         d_ie[11] = vandq_s16(d_ie[11], mask[1]);
1483         derive_square_win7_neon(d_is, d_ie, d_js, d_je, deltas);
1484       }
1485 
1486       hadd_update_6_stats_neon(
1487           H + (i * wiener_win + 0) * wiener_win2 + j * wiener_win, deltas[0],
1488           H + (i * wiener_win + 1) * wiener_win2 + j * wiener_win + 1);
1489       hadd_update_6_stats_neon(
1490           H + (i * wiener_win + 1) * wiener_win2 + j * wiener_win, deltas[1],
1491           H + (i * wiener_win + 2) * wiener_win2 + j * wiener_win + 1);
1492       hadd_update_6_stats_neon(
1493           H + (i * wiener_win + 2) * wiener_win2 + j * wiener_win, deltas[2],
1494           H + (i * wiener_win + 3) * wiener_win2 + j * wiener_win + 1);
1495       hadd_update_6_stats_neon(
1496           H + (i * wiener_win + 3) * wiener_win2 + j * wiener_win, deltas[3],
1497           H + (i * wiener_win + 4) * wiener_win2 + j * wiener_win + 1);
1498       hadd_update_6_stats_neon(
1499           H + (i * wiener_win + 4) * wiener_win2 + j * wiener_win, deltas[4],
1500           H + (i * wiener_win + 5) * wiener_win2 + j * wiener_win + 1);
1501       hadd_update_6_stats_neon(
1502           H + (i * wiener_win + 5) * wiener_win2 + j * wiener_win, deltas[5],
1503           H + (i * wiener_win + 6) * wiener_win2 + j * wiener_win + 1);
1504     } while (++j < wiener_win);
1505   } while (++i < wiener_win - 1);
1506 
1507   // Step 6: Derive other points of each upper triangle along the diagonal.
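  // Same masking scheme as step 5, applied to the triangle above the
  // diagonal: each row contributes one fewer point than the previous one
  // (6, 5, 4, 3, 2, 1 below).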
1508   i = 0;
1509   do {
1510     const int16_t *const di = d + i;
1511     int32x4_t deltas[WIENER_WIN * (WIENER_WIN - 1)] = { vdupq_n_s32(0) };
1512     int16x8_t d_is[WIN_7], d_ie[WIN_7];
1513 
1514     x = 0;
1515     while (x < w16) {
1516       load_triangle_win7_neon(di + x, d_stride, height, d_is, d_ie);
1517       derive_triangle_win7_neon(d_is, d_ie, deltas);
1518       x += 16;
1519     }
1520 
1521     if (w16 != width) {
1522       load_triangle_win7_neon(di + x, d_stride, height, d_is, d_ie);
1523       d_is[0] = vandq_s16(d_is[0], mask[0]);
1524       d_is[1] = vandq_s16(d_is[1], mask[1]);
1525       d_is[2] = vandq_s16(d_is[2], mask[0]);
1526       d_is[3] = vandq_s16(d_is[3], mask[1]);
1527       d_is[4] = vandq_s16(d_is[4], mask[0]);
1528       d_is[5] = vandq_s16(d_is[5], mask[1]);
1529       d_is[6] = vandq_s16(d_is[6], mask[0]);
1530       d_is[7] = vandq_s16(d_is[7], mask[1]);
1531       d_is[8] = vandq_s16(d_is[8], mask[0]);
1532       d_is[9] = vandq_s16(d_is[9], mask[1]);
1533       d_is[10] = vandq_s16(d_is[10], mask[0]);
1534       d_is[11] = vandq_s16(d_is[11], mask[1]);
1535       d_ie[0] = vandq_s16(d_ie[0], mask[0]);
1536       d_ie[1] = vandq_s16(d_ie[1], mask[1]);
1537       d_ie[2] = vandq_s16(d_ie[2], mask[0]);
1538       d_ie[3] = vandq_s16(d_ie[3], mask[1]);
1539       d_ie[4] = vandq_s16(d_ie[4], mask[0]);
1540       d_ie[5] = vandq_s16(d_ie[5], mask[1]);
1541       d_ie[6] = vandq_s16(d_ie[6], mask[0]);
1542       d_ie[7] = vandq_s16(d_ie[7], mask[1]);
1543       d_ie[8] = vandq_s16(d_ie[8], mask[0]);
1544       d_ie[9] = vandq_s16(d_ie[9], mask[1]);
1545       d_ie[10] = vandq_s16(d_ie[10], mask[0]);
1546       d_ie[11] = vandq_s16(d_ie[11], mask[1]);
1547       derive_triangle_win7_neon(d_is, d_ie, deltas);
1548     }
1549 
1550     // Row 1: 6 points
1551     hadd_update_6_stats_neon(
1552         H + (i * wiener_win + 0) * wiener_win2 + i * wiener_win, deltas,
1553         H + (i * wiener_win + 1) * wiener_win2 + i * wiener_win + 1);
1554 
1555     // Row 2: 5 points
1556     hadd_update_4_stats_neon(
1557         H + (i * wiener_win + 1) * wiener_win2 + i * wiener_win + 1, deltas + 6,
1558         H + (i * wiener_win + 2) * wiener_win2 + i * wiener_win + 2);
1559     H[(i * wiener_win + 2) * wiener_win2 + i * wiener_win + 6] =
1560         H[(i * wiener_win + 1) * wiener_win2 + i * wiener_win + 5] +
1561         horizontal_long_add_s32x4(deltas[10]);
1562 
1563     // Row 3: 4 points
1564     hadd_update_4_stats_neon(
1565         H + (i * wiener_win + 2) * wiener_win2 + i * wiener_win + 2,
1566         deltas + 11,
1567         H + (i * wiener_win + 3) * wiener_win2 + i * wiener_win + 3);
1568 
1569     // Row 4: 3 points
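    // vpaddq_s64 is an AArch64-only intrinsic; the fallback below performs
    // the same pair of horizontal reductions with scalar stores.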
1570 #if AOM_ARCH_AARCH64
1571     int64x2_t delta15_s64 = vpaddlq_s32(deltas[15]);
1572     int64x2_t delta16_s64 = vpaddlq_s32(deltas[16]);
1573     int64x2_t delta1516 = vpaddq_s64(delta15_s64, delta16_s64);
1574 
1575     int64x2_t h0 =
1576         vld1q_s64(H + (i * wiener_win + 3) * wiener_win2 + i * wiener_win + 3);
1577     vst1q_s64(H + (i * wiener_win + 4) * wiener_win2 + i * wiener_win + 4,
1578               vaddq_s64(h0, delta1516));
1579 #else
1580     H[(i * wiener_win + 4) * wiener_win2 + i * wiener_win + 4 + 0] =
1581         H[(i * wiener_win + 3) * wiener_win2 + i * wiener_win + 3 + 0] +
1582         horizontal_long_add_s32x4(deltas[15]);
1583     H[(i * wiener_win + 4) * wiener_win2 + i * wiener_win + 4 + 1] =
1584         H[(i * wiener_win + 3) * wiener_win2 + i * wiener_win + 3 + 1] +
1585         horizontal_long_add_s32x4(deltas[16]);
1586 #endif  // AOM_ARCH_AARCH64
1587 
1588     H[(i * wiener_win + 4) * wiener_win2 + i * wiener_win + 6] =
1589         H[(i * wiener_win + 3) * wiener_win2 + i * wiener_win + 5] +
1590         horizontal_long_add_s32x4(deltas[17]);
1591 
1592     // Row 5: 2 points
1593     int64x2_t delta18_s64 = vpaddlq_s32(deltas[18]);
1594     int64x2_t delta19_s64 = vpaddlq_s32(deltas[19]);
1595 
1596 #if AOM_ARCH_AARCH64
1597     int64x2_t delta1819 = vpaddq_s64(delta18_s64, delta19_s64);
1598 
1599     int64x2_t h1 =
1600         vld1q_s64(H + (i * wiener_win + 4) * wiener_win2 + i * wiener_win + 4);
1601     vst1q_s64(H + (i * wiener_win + 5) * wiener_win2 + i * wiener_win + 5,
1602               vaddq_s64(h1, delta1819));
1603 #else
1604     H[(i * wiener_win + 5) * wiener_win2 + i * wiener_win + 5] =
1605         H[(i * wiener_win + 4) * wiener_win2 + i * wiener_win + 4] +
1606         horizontal_add_s64x2(delta18_s64);
1607     H[(i * wiener_win + 5) * wiener_win2 + i * wiener_win + 5 + 1] =
1608         H[(i * wiener_win + 4) * wiener_win2 + i * wiener_win + 4 + 1] +
1609         horizontal_add_s64x2(delta19_s64);
1610 #endif  // AOM_ARCH_AARCH64
1611 
1612     // Row 6: 1 point
1613     H[(i * wiener_win + 6) * wiener_win2 + i * wiener_win + 6] =
1614         H[(i * wiener_win + 5) * wiener_win2 + i * wiener_win + 5] +
1615         horizontal_long_add_s32x4(deltas[20]);
1616   } while (++i < wiener_win);
1617 }
1618 
1619 static inline void sub_avg_block_highbd_neon(const uint16_t *src,
1620                                              const int32_t src_stride,
1621                                              const uint16_t avg,
1622                                              const int32_t width,
1623                                              const int32_t height, int16_t *dst,
1624                                              const int32_t dst_stride) {
1625   const uint16x8_t a = vdupq_n_u16(avg);
1626 
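  // height + 1 rows are processed, presumably so that the padding row below
  // the block is initialized for kernels that read past the nominal height.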
1627   int32_t i = height + 1;
1628   do {
1629     int32_t j = 0;
1630     while (j < width) {
1631       const uint16x8_t s = vld1q_u16(src + j);
1632       const uint16x8_t d = vsubq_u16(s, a);
1633       vst1q_s16(dst + j, vreinterpretq_s16_u16(d));
1634       j += 8;
1635     }
1636 
1637     src += src_stride;
1638     dst += dst_stride;
1639   } while (--i);
1640 }
1641 
1642 static inline uint16_t highbd_find_average_neon(const uint16_t *src,
1643                                                 int src_stride, int width,
1644                                                 int height) {
1645   assert(width > 0);
1646   assert(height > 0);
1647 
1648   uint64x2_t sum_u64 = vdupq_n_u64(0);
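  // mask_16bit yields a vector whose first (width % 8) lanes are all-ones and
  // whose remaining lanes are zero, so the row tail can be loaded as a full
  // 8-lane vector without counting pixels beyond `width` into the average.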
1650   const uint16x8_t mask =
1651       vreinterpretq_u16_s16(vld1q_s16(&mask_16bit[16] - (width % 8)));
1652 
1653   int h = height;
1654   do {
1655     uint32x4_t sum_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
1656 
1657     int w = width;
1658     const uint16_t *row = src;
1659     while (w >= 32) {
1660       uint16x8_t s0 = vld1q_u16(row + 0);
1661       uint16x8_t s1 = vld1q_u16(row + 8);
1662       uint16x8_t s2 = vld1q_u16(row + 16);
1663       uint16x8_t s3 = vld1q_u16(row + 24);
1664 
1665       s0 = vaddq_u16(s0, s1);
1666       s2 = vaddq_u16(s2, s3);
1667       sum_u32[0] = vpadalq_u16(sum_u32[0], s0);
1668       sum_u32[1] = vpadalq_u16(sum_u32[1], s2);
1669 
1670       row += 32;
1671       w -= 32;
1672     }
1673 
1674     if (w >= 16) {
1675       uint16x8_t s0 = vld1q_u16(row + 0);
1676       uint16x8_t s1 = vld1q_u16(row + 8);
1677 
1678       s0 = vaddq_u16(s0, s1);
1679       sum_u32[0] = vpadalq_u16(sum_u32[0], s0);
1680 
1681       row += 16;
1682       w -= 16;
1683     }
1684 
1685     if (w >= 8) {
1686       uint16x8_t s0 = vld1q_u16(row);
1687       sum_u32[1] = vpadalq_u16(sum_u32[1], s0);
1688 
1689       row += 8;
1690       w -= 8;
1691     }
1692 
1693     if (w) {
1694       uint16x8_t s0 = vandq_u16(vld1q_u16(row), mask);
1695       sum_u32[1] = vpadalq_u16(sum_u32[1], s0);
1696 
1697       row += 8;
1698       w -= 8;
1699     }
1700 
1701     sum_u64 = vpadalq_u32(sum_u64, vaddq_u32(sum_u32[0], sum_u32[1]));
1702 
1703     src += src_stride;
1704   } while (--h != 0);
1705 
1706   return (uint16_t)(horizontal_add_u64x2(sum_u64) / (height * width));
1707 }
1708 
1709 void av1_compute_stats_highbd_neon(int32_t wiener_win, const uint8_t *dgd8,
1710                                    const uint8_t *src8, int16_t *dgd_avg,
1711                                    int16_t *src_avg, int32_t h_start,
1712                                    int32_t h_end, int32_t v_start,
1713                                    int32_t v_end, int32_t dgd_stride,
1714                                    int32_t src_stride, int64_t *M, int64_t *H,
1715                                    aom_bit_depth_t bit_depth) {
1716   const int32_t wiener_win2 = wiener_win * wiener_win;
1717   const int32_t wiener_halfwin = (wiener_win >> 1);
1718   const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
1719   const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
1720   const int32_t width = h_end - h_start;
1721   const int32_t height = v_end - v_start;
1722   const uint16_t *dgd_start = dgd + h_start + v_start * dgd_stride;
1723   const uint16_t avg =
1724       highbd_find_average_neon(dgd_start, dgd_stride, width, height);
1725   const int32_t d_stride = (width + 2 * wiener_halfwin + 15) & ~15;
1726   const int32_t s_stride = (width + 15) & ~15;
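  // Both strides are rounded up to a multiple of 16 so the win5/win7 kernels
  // can always load two full 8-lane vectors per step without per-row tail
  // handling.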
1727 
1728   sub_avg_block_highbd_neon(src + v_start * src_stride + h_start, src_stride,
1729                             avg, width, height, src_avg, s_stride);
1730   sub_avg_block_highbd_neon(
1731       dgd + (v_start - wiener_halfwin) * dgd_stride + h_start - wiener_halfwin,
1732       dgd_stride, avg, width + 2 * wiener_halfwin, height + 2 * wiener_halfwin,
1733       dgd_avg, d_stride);
1734 
1735   if (wiener_win == WIENER_WIN) {
1736     compute_stats_win7_highbd_neon(dgd_avg, d_stride, src_avg, s_stride, width,
1737                                    height, M, H, bit_depth);
1738   } else if (wiener_win == WIENER_WIN_CHROMA) {
1739     compute_stats_win5_highbd_neon(dgd_avg, d_stride, src_avg, s_stride, width,
1740                                    height, M, H, bit_depth);
1741   }
1742 
1743   // H is a symmetric matrix, so we only need to fill out the upper triangle.
1744   // We can copy it down to the lower triangle outside the (i, j) loops.
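  // For 10- and 12-bit input, M and H are additionally divided by 4 and 16
  // respectively, mirroring bit_depth_divider in the C implementation, so the
  // stats stay on the same scale as the 8-bit path.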
1745   if (bit_depth == AOM_BITS_8) {
1746     diagonal_copy_stats_neon(wiener_win2, H);
1747   } else if (bit_depth == AOM_BITS_10) {
1748     const int32_t k4 = wiener_win2 & ~3;
1749 
1750     int32_t k = 0;
1751     do {
1752       int64x2_t dst = div4_neon(vld1q_s64(M + k));
1753       vst1q_s64(M + k, dst);
1754       dst = div4_neon(vld1q_s64(M + k + 2));
1755       vst1q_s64(M + k + 2, dst);
1756       H[k * wiener_win2 + k] /= 4;
1757       k += 4;
1758     } while (k < k4);
1759 
1760     H[k * wiener_win2 + k] /= 4;
1761 
1762     for (; k < wiener_win2; ++k) {
1763       M[k] /= 4;
1764     }
1765 
1766     div4_diagonal_copy_stats_neon(wiener_win2, H);
1767   } else {  // bit_depth == AOM_BITS_12
1768     const int32_t k4 = wiener_win2 & ~3;
1769 
1770     int32_t k = 0;
1771     do {
1772       int64x2_t dst = div16_neon(vld1q_s64(M + k));
1773       vst1q_s64(M + k, dst);
1774       dst = div16_neon(vld1q_s64(M + k + 2));
1775       vst1q_s64(M + k + 2, dst);
1776       H[k * wiener_win2 + k] /= 16;
1777       k += 4;
1778     } while (k < k4);
1779 
1780     H[k * wiener_win2 + k] /= 16;
1781 
1782     for (; k < wiener_win2; ++k) {
1783       M[k] /= 16;
1784     }
1785 
1786     div16_diagonal_copy_stats_neon(wiener_win2, H);
1787   }
1788 }
1789 int64_t av1_highbd_pixel_proj_error_neon(
1790     const uint8_t *src8, int width, int height, int src_stride,
1791     const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
1792     int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
1793   const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
1794   const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
1795   int64_t sse = 0;
1796   int64x2_t sse_s64 = vdupq_n_s64(0);
1797 
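  // Three cases follow: both self-guided filters enabled, exactly one
  // enabled, or neither (where the error reduces to the plain SSE between
  // dat and src). The do/while vector loops execute at least once, so width
  // is assumed to be at least 8.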
1798   if (params->r[0] > 0 && params->r[1] > 0) {
1799     int32x2_t xq_v = vld1_s32(xq);
1800     int32x2_t xq_sum_v = vshl_n_s32(vpadd_s32(xq_v, xq_v), 4);
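    // Per pixel this computes, matching the scalar tail below:
    //   u = dat[k] << SGRPROJ_RST_BITS
    //   v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u) + rounding
    //   e = (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + dat[k] - src[k]
    // xq_sum_v pre-folds (xq[0] + xq[1]) << SGRPROJ_RST_BITS so both (... - u)
    // terms cost a single multiply by d.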
1801 
1802     do {
1803       int j = 0;
1804       int32x4_t sse_s32 = vdupq_n_s32(0);
1805 
1806       do {
1807         const uint16x8_t d = vld1q_u16(&dat[j]);
1808         const uint16x8_t s = vld1q_u16(&src[j]);
1809         int32x4_t flt0_0 = vld1q_s32(&flt0[j]);
1810         int32x4_t flt0_1 = vld1q_s32(&flt0[j + 4]);
1811         int32x4_t flt1_0 = vld1q_s32(&flt1[j]);
1812         int32x4_t flt1_1 = vld1q_s32(&flt1[j + 4]);
1813 
1814         int32x4_t d_s32_lo = vreinterpretq_s32_u32(
1815             vmull_lane_u16(vget_low_u16(d), vreinterpret_u16_s32(xq_sum_v), 0));
1816         int32x4_t d_s32_hi = vreinterpretq_s32_u32(vmull_lane_u16(
1817             vget_high_u16(d), vreinterpret_u16_s32(xq_sum_v), 0));
1818 
1819         int32x4_t v0 = vsubq_s32(
1820             vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)),
1821             d_s32_lo);
1822         int32x4_t v1 = vsubq_s32(
1823             vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)),
1824             d_s32_hi);
1825 
1826         v0 = vmlaq_lane_s32(v0, flt0_0, xq_v, 0);
1827         v1 = vmlaq_lane_s32(v1, flt0_1, xq_v, 0);
1828         v0 = vmlaq_lane_s32(v0, flt1_0, xq_v, 1);
1829         v1 = vmlaq_lane_s32(v1, flt1_1, xq_v, 1);
1830 
1831         int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
1832         int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
1833 
1834         int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1),
1835                                 vreinterpretq_s16_u16(vsubq_u16(d, s)));
1836         int16x4_t e_lo = vget_low_s16(e);
1837         int16x4_t e_hi = vget_high_s16(e);
1838 
1839         sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo);
1840         sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi);
1841 
1842         j += 8;
1843       } while (j <= width - 8);
1844 
1845       for (int k = j; k < width; ++k) {
1846         int32_t v = 1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1);
1847         v += xq[0] * (flt0[k]) + xq[1] * (flt1[k]);
1848         v -= (xq[1] + xq[0]) * (int32_t)(dat[k] << 4);
1849         int32_t e =
1850             (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + dat[k] - src[k];
1851         sse += ((int64_t)e * e);
1852       }
1853 
1854       sse_s64 = vpadalq_s32(sse_s64, sse_s32);
1855 
1856       dat += dat_stride;
1857       src += src_stride;
1858       flt0 += flt0_stride;
1859       flt1 += flt1_stride;
1860     } while (--height != 0);
1861   } else if (params->r[0] > 0 || params->r[1] > 0) {
1862     int xq_active = (params->r[0] > 0) ? xq[0] : xq[1];
1863     int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
1864     int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
1865     int32x4_t xq_v = vdupq_n_s32(xq_active);
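    // Same pipeline as above with a single filter: the rounding constant is
    // added via vmlaq_s32 on (flt[k] - (dat[k] << 4)) directly.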
1866 
1867     do {
1868       int j = 0;
1869       int32x4_t sse_s32 = vdupq_n_s32(0);
1870       do {
1871         const uint16x8_t d0 = vld1q_u16(&dat[j]);
1872         const uint16x8_t s0 = vld1q_u16(&src[j]);
1873         int32x4_t flt0_0 = vld1q_s32(&flt[j]);
1874         int32x4_t flt0_1 = vld1q_s32(&flt[j + 4]);
1875 
1876         uint16x8_t d_u16 = vshlq_n_u16(d0, 4);
1877         int32x4_t sub0 = vreinterpretq_s32_u32(
1878             vsubw_u16(vreinterpretq_u32_s32(flt0_0), vget_low_u16(d_u16)));
1879         int32x4_t sub1 = vreinterpretq_s32_u32(
1880             vsubw_u16(vreinterpretq_u32_s32(flt0_1), vget_high_u16(d_u16)));
1881 
1882         int32x4_t v0 = vmlaq_s32(
1883             vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)), sub0,
1884             xq_v);
1885         int32x4_t v1 = vmlaq_s32(
1886             vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)), sub1,
1887             xq_v);
1888 
1889         int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
1890         int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
1891 
1892         int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1),
1893                                 vreinterpretq_s16_u16(vsubq_u16(d0, s0)));
1894         int16x4_t e_lo = vget_low_s16(e);
1895         int16x4_t e_hi = vget_high_s16(e);
1896 
1897         sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo);
1898         sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi);
1899 
1900         j += 8;
1901       } while (j <= width - 8);
1902 
1903       for (int k = j; k < width; ++k) {
1904         int32_t v = 1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1);
1905         v += xq_active * (int32_t)((uint32_t)flt[k] - (uint16_t)(dat[k] << 4));
1906         const int32_t e =
1907             (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + dat[k] - src[k];
1908         sse += ((int64_t)e * e);
1909       }
1910 
1911       sse_s64 = vpadalq_s32(sse_s64, sse_s32);
1912 
1913       dat += dat_stride;
1914       flt += flt_stride;
1915       src += src_stride;
1916     } while (--height != 0);
1917   } else {
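    // Neither filter is active: accumulate the squared absolute difference
    // between dat and src, widening through u32 into the s64 accumulator.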
1918     do {
1919       int j = 0;
1920 
1921       do {
1922         const uint16x8_t d = vld1q_u16(&dat[j]);
1923         const uint16x8_t s = vld1q_u16(&src[j]);
1924 
1925         uint16x8_t diff = vabdq_u16(d, s);
1926         uint16x4_t diff_lo = vget_low_u16(diff);
1927         uint16x4_t diff_hi = vget_high_u16(diff);
1928 
1929         uint32x4_t sqr_lo = vmull_u16(diff_lo, diff_lo);
1930         uint32x4_t sqr_hi = vmull_u16(diff_hi, diff_hi);
1931 
1932         sse_s64 = vpadalq_s32(sse_s64, vreinterpretq_s32_u32(sqr_lo));
1933         sse_s64 = vpadalq_s32(sse_s64, vreinterpretq_s32_u32(sqr_hi));
1934 
1935         j += 8;
1936       } while (j <= width - 8);
1937 
1938       for (int k = j; k < width; ++k) {
1939         int32_t e = dat[k] - src[k];
1940         sse += ((int64_t)e * e);
1941       }
1942 
1943       dat += dat_stride;
1944       src += src_stride;
1945     } while (--height != 0);
1946   }
1947 
1948   sse += horizontal_add_s64x2(sse_s64);
1949   return sse;
1950 }
1951