/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/arm/dist_wtd_avg_neon.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"

static inline unsigned int sad128xh_neon(const uint8_t *src_ptr, int src_stride,
                                         const uint8_t *ref_ptr, int ref_stride,
                                         int h) {
  // We use 8 accumulators to prevent overflow for large values of 'h', as well
  // as enabling optimal UADALP instruction throughput on CPUs that have either
  // 2 or 4 Neon pipes.
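  // Each UADALP adds at most 2 * 255 = 510 to a 16-bit lane per row, so even
  // at the largest block height handled here (h = 128) an accumulator lane
  // reaches at most 65280 and cannot overflow.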
  uint16x8_t sum[8] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                        vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                        vdupq_n_u16(0), vdupq_n_u16(0) };

  int i = h;
  do {
    uint8x16_t s0, s1, s2, s3, s4, s5, s6, s7;
    uint8x16_t r0, r1, r2, r3, r4, r5, r6, r7;
    uint8x16_t diff0, diff1, diff2, diff3, diff4, diff5, diff6, diff7;

    s0 = vld1q_u8(src_ptr);
    r0 = vld1q_u8(ref_ptr);
    diff0 = vabdq_u8(s0, r0);
    sum[0] = vpadalq_u8(sum[0], diff0);

    s1 = vld1q_u8(src_ptr + 16);
    r1 = vld1q_u8(ref_ptr + 16);
    diff1 = vabdq_u8(s1, r1);
    sum[1] = vpadalq_u8(sum[1], diff1);

    s2 = vld1q_u8(src_ptr + 32);
    r2 = vld1q_u8(ref_ptr + 32);
    diff2 = vabdq_u8(s2, r2);
    sum[2] = vpadalq_u8(sum[2], diff2);

    s3 = vld1q_u8(src_ptr + 48);
    r3 = vld1q_u8(ref_ptr + 48);
    diff3 = vabdq_u8(s3, r3);
    sum[3] = vpadalq_u8(sum[3], diff3);

    s4 = vld1q_u8(src_ptr + 64);
    r4 = vld1q_u8(ref_ptr + 64);
    diff4 = vabdq_u8(s4, r4);
    sum[4] = vpadalq_u8(sum[4], diff4);

    s5 = vld1q_u8(src_ptr + 80);
    r5 = vld1q_u8(ref_ptr + 80);
    diff5 = vabdq_u8(s5, r5);
    sum[5] = vpadalq_u8(sum[5], diff5);

    s6 = vld1q_u8(src_ptr + 96);
    r6 = vld1q_u8(ref_ptr + 96);
    diff6 = vabdq_u8(s6, r6);
    sum[6] = vpadalq_u8(sum[6], diff6);

    s7 = vld1q_u8(src_ptr + 112);
    r7 = vld1q_u8(ref_ptr + 112);
    diff7 = vabdq_u8(s7, r7);
    sum[7] = vpadalq_u8(sum[7], diff7);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  uint32x4_t sum_u32 = vpaddlq_u16(sum[0]);
  sum_u32 = vpadalq_u16(sum_u32, sum[1]);
  sum_u32 = vpadalq_u16(sum_u32, sum[2]);
  sum_u32 = vpadalq_u16(sum_u32, sum[3]);
  sum_u32 = vpadalq_u16(sum_u32, sum[4]);
  sum_u32 = vpadalq_u16(sum_u32, sum[5]);
  sum_u32 = vpadalq_u16(sum_u32, sum[6]);
  sum_u32 = vpadalq_u16(sum_u32, sum[7]);

  return horizontal_add_u32x4(sum_u32);
}

static inline unsigned int sad64xh_neon(const uint8_t *src_ptr, int src_stride,
                                        const uint8_t *ref_ptr, int ref_stride,
                                        int h) {
  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                        vdupq_n_u16(0) };

  int i = h;
  do {
    uint8x16_t s0, s1, s2, s3, r0, r1, r2, r3;
    uint8x16_t diff0, diff1, diff2, diff3;

    s0 = vld1q_u8(src_ptr);
    r0 = vld1q_u8(ref_ptr);
    diff0 = vabdq_u8(s0, r0);
    sum[0] = vpadalq_u8(sum[0], diff0);

    s1 = vld1q_u8(src_ptr + 16);
    r1 = vld1q_u8(ref_ptr + 16);
    diff1 = vabdq_u8(s1, r1);
    sum[1] = vpadalq_u8(sum[1], diff1);

    s2 = vld1q_u8(src_ptr + 32);
    r2 = vld1q_u8(ref_ptr + 32);
    diff2 = vabdq_u8(s2, r2);
    sum[2] = vpadalq_u8(sum[2], diff2);

    s3 = vld1q_u8(src_ptr + 48);
    r3 = vld1q_u8(ref_ptr + 48);
    diff3 = vabdq_u8(s3, r3);
    sum[3] = vpadalq_u8(sum[3], diff3);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  uint32x4_t sum_u32 = vpaddlq_u16(sum[0]);
  sum_u32 = vpadalq_u16(sum_u32, sum[1]);
  sum_u32 = vpadalq_u16(sum_u32, sum[2]);
  sum_u32 = vpadalq_u16(sum_u32, sum[3]);

  return horizontal_add_u32x4(sum_u32);
}

static inline unsigned int sad32xh_neon(const uint8_t *src_ptr, int src_stride,
                                        const uint8_t *ref_ptr, int ref_stride,
                                        int h) {
  uint16x8_t sum[2] = { vdupq_n_u16(0), vdupq_n_u16(0) };

  int i = h;
  do {
    uint8x16_t s0 = vld1q_u8(src_ptr);
    uint8x16_t r0 = vld1q_u8(ref_ptr);
    uint8x16_t diff0 = vabdq_u8(s0, r0);
    sum[0] = vpadalq_u8(sum[0], diff0);

    uint8x16_t s1 = vld1q_u8(src_ptr + 16);
    uint8x16_t r1 = vld1q_u8(ref_ptr + 16);
    uint8x16_t diff1 = vabdq_u8(s1, r1);
    sum[1] = vpadalq_u8(sum[1], diff1);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  return horizontal_add_u16x8(vaddq_u16(sum[0], sum[1]));
}

static inline unsigned int sad16xh_neon(const uint8_t *src_ptr, int src_stride,
                                        const uint8_t *ref_ptr, int ref_stride,
                                        int h) {
  uint16x8_t sum = vdupq_n_u16(0);

  int i = h;
  do {
    uint8x16_t s = vld1q_u8(src_ptr);
    uint8x16_t r = vld1q_u8(ref_ptr);

    uint8x16_t diff = vabdq_u8(s, r);
    sum = vpadalq_u8(sum, diff);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  return horizontal_add_u16x8(sum);
}

static inline unsigned int sad8xh_neon(const uint8_t *src_ptr, int src_stride,
                                       const uint8_t *ref_ptr, int ref_stride,
                                       int h) {
  uint16x8_t sum = vdupq_n_u16(0);

  int i = h;
  do {
    uint8x8_t s = vld1_u8(src_ptr);
    uint8x8_t r = vld1_u8(ref_ptr);

    sum = vabal_u8(sum, s, r);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  return horizontal_add_u16x8(sum);
}

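// load_unaligned_u8 packs two 4-byte rows (ptr and ptr + stride) into a
// single 8-byte vector, so each loop iteration below processes two rows at
// once; hence the h / 2 trip count and the 2 * stride pointer increments
// (h is assumed to be even).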
static inline unsigned int sad4xh_neon(const uint8_t *src_ptr, int src_stride,
                                       const uint8_t *ref_ptr, int ref_stride,
                                       int h) {
  uint16x8_t sum = vdupq_n_u16(0);

  int i = h / 2;
  do {
    uint8x8_t s = load_unaligned_u8(src_ptr, src_stride);
    uint8x8_t r = load_unaligned_u8(ref_ptr, ref_stride);

    sum = vabal_u8(sum, s, r);

    src_ptr += 2 * src_stride;
    ref_ptr += 2 * ref_stride;
  } while (--i != 0);

  return horizontal_add_u16x8(sum);
}

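// Defines the exported SAD entry point for a w x h block. For example,
// SAD_WXH_NEON(16, 16) expands to a definition of:
//   unsigned int aom_sad16x16_neon(const uint8_t *src, int src_stride,
//                                  const uint8_t *ref, int ref_stride);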
#define SAD_WXH_NEON(w, h)                                                   \
  unsigned int aom_sad##w##x##h##_neon(const uint8_t *src, int src_stride,   \
                                       const uint8_t *ref, int ref_stride) { \
    return sad##w##xh_neon(src, src_stride, ref, ref_stride, (h));           \
  }

SAD_WXH_NEON(4, 4)
SAD_WXH_NEON(4, 8)

SAD_WXH_NEON(8, 4)
SAD_WXH_NEON(8, 8)
SAD_WXH_NEON(8, 16)

SAD_WXH_NEON(16, 8)
SAD_WXH_NEON(16, 16)
SAD_WXH_NEON(16, 32)

SAD_WXH_NEON(32, 16)
SAD_WXH_NEON(32, 32)
SAD_WXH_NEON(32, 64)

SAD_WXH_NEON(64, 32)
SAD_WXH_NEON(64, 64)
SAD_WXH_NEON(64, 128)

SAD_WXH_NEON(128, 64)
SAD_WXH_NEON(128, 128)

#if !CONFIG_REALTIME_ONLY
SAD_WXH_NEON(4, 16)
SAD_WXH_NEON(8, 32)
SAD_WXH_NEON(16, 4)
SAD_WXH_NEON(16, 64)
SAD_WXH_NEON(32, 8)
SAD_WXH_NEON(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef SAD_WXH_NEON

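// The skip variants subsample the block vertically: they compute the SAD over
// every other row (doubling both strides and halving h) and scale the result
// by 2 to approximate the SAD of the full block.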
#define SAD_SKIP_WXH_NEON(w, h)                                                \
  unsigned int aom_sad_skip_##w##x##h##_neon(                                  \
      const uint8_t *src, int src_stride, const uint8_t *ref,                  \
      int ref_stride) {                                                        \
    return 2 *                                                                 \
           sad##w##xh_neon(src, 2 * src_stride, ref, 2 * ref_stride, (h) / 2); \
  }

SAD_SKIP_WXH_NEON(4, 4)
SAD_SKIP_WXH_NEON(4, 8)

SAD_SKIP_WXH_NEON(8, 4)
SAD_SKIP_WXH_NEON(8, 8)
SAD_SKIP_WXH_NEON(8, 16)

SAD_SKIP_WXH_NEON(16, 8)
SAD_SKIP_WXH_NEON(16, 16)
SAD_SKIP_WXH_NEON(16, 32)

SAD_SKIP_WXH_NEON(32, 16)
SAD_SKIP_WXH_NEON(32, 32)
SAD_SKIP_WXH_NEON(32, 64)

SAD_SKIP_WXH_NEON(64, 32)
SAD_SKIP_WXH_NEON(64, 64)
SAD_SKIP_WXH_NEON(64, 128)

SAD_SKIP_WXH_NEON(128, 64)
SAD_SKIP_WXH_NEON(128, 128)

#if !CONFIG_REALTIME_ONLY
SAD_SKIP_WXH_NEON(4, 16)
SAD_SKIP_WXH_NEON(8, 32)
SAD_SKIP_WXH_NEON(16, 4)
SAD_SKIP_WXH_NEON(16, 64)
SAD_SKIP_WXH_NEON(32, 8)
SAD_SKIP_WXH_NEON(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef SAD_SKIP_WXH_NEON

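// The _avg variants below compute the SAD of src against the rounding average
// of ref and second_pred ((r + p + 1) >> 1, via VRHADD) instead of against
// ref directly. second_pred is a contiguous w x h block, so it advances by
// the block width on every row.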
static inline unsigned int sad128xh_avg_neon(const uint8_t *src_ptr,
                                             int src_stride,
                                             const uint8_t *ref_ptr,
                                             int ref_stride, int h,
                                             const uint8_t *second_pred) {
  // We use 8 accumulators to prevent overflow for large values of 'h', as well
  // as enabling optimal UADALP instruction throughput on CPUs that have either
  // 2 or 4 Neon pipes.
  uint16x8_t sum[8] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                        vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                        vdupq_n_u16(0), vdupq_n_u16(0) };

  int i = h;
  do {
    uint8x16_t s0, s1, s2, s3, s4, s5, s6, s7;
    uint8x16_t r0, r1, r2, r3, r4, r5, r6, r7;
    uint8x16_t p0, p1, p2, p3, p4, p5, p6, p7;
    uint8x16_t avg0, avg1, avg2, avg3, avg4, avg5, avg6, avg7;
    uint8x16_t diff0, diff1, diff2, diff3, diff4, diff5, diff6, diff7;

    s0 = vld1q_u8(src_ptr);
    r0 = vld1q_u8(ref_ptr);
    p0 = vld1q_u8(second_pred);
    avg0 = vrhaddq_u8(r0, p0);
    diff0 = vabdq_u8(s0, avg0);
    sum[0] = vpadalq_u8(sum[0], diff0);

    s1 = vld1q_u8(src_ptr + 16);
    r1 = vld1q_u8(ref_ptr + 16);
    p1 = vld1q_u8(second_pred + 16);
    avg1 = vrhaddq_u8(r1, p1);
    diff1 = vabdq_u8(s1, avg1);
    sum[1] = vpadalq_u8(sum[1], diff1);

    s2 = vld1q_u8(src_ptr + 32);
    r2 = vld1q_u8(ref_ptr + 32);
    p2 = vld1q_u8(second_pred + 32);
    avg2 = vrhaddq_u8(r2, p2);
    diff2 = vabdq_u8(s2, avg2);
    sum[2] = vpadalq_u8(sum[2], diff2);

    s3 = vld1q_u8(src_ptr + 48);
    r3 = vld1q_u8(ref_ptr + 48);
    p3 = vld1q_u8(second_pred + 48);
    avg3 = vrhaddq_u8(r3, p3);
    diff3 = vabdq_u8(s3, avg3);
    sum[3] = vpadalq_u8(sum[3], diff3);

    s4 = vld1q_u8(src_ptr + 64);
    r4 = vld1q_u8(ref_ptr + 64);
    p4 = vld1q_u8(second_pred + 64);
    avg4 = vrhaddq_u8(r4, p4);
    diff4 = vabdq_u8(s4, avg4);
    sum[4] = vpadalq_u8(sum[4], diff4);

    s5 = vld1q_u8(src_ptr + 80);
    r5 = vld1q_u8(ref_ptr + 80);
    p5 = vld1q_u8(second_pred + 80);
    avg5 = vrhaddq_u8(r5, p5);
    diff5 = vabdq_u8(s5, avg5);
    sum[5] = vpadalq_u8(sum[5], diff5);

    s6 = vld1q_u8(src_ptr + 96);
    r6 = vld1q_u8(ref_ptr + 96);
    p6 = vld1q_u8(second_pred + 96);
    avg6 = vrhaddq_u8(r6, p6);
    diff6 = vabdq_u8(s6, avg6);
    sum[6] = vpadalq_u8(sum[6], diff6);

    s7 = vld1q_u8(src_ptr + 112);
    r7 = vld1q_u8(ref_ptr + 112);
    p7 = vld1q_u8(second_pred + 112);
    avg7 = vrhaddq_u8(r7, p7);
    diff7 = vabdq_u8(s7, avg7);
    sum[7] = vpadalq_u8(sum[7], diff7);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 128;
  } while (--i != 0);

  uint32x4_t sum_u32 = vpaddlq_u16(sum[0]);
  sum_u32 = vpadalq_u16(sum_u32, sum[1]);
  sum_u32 = vpadalq_u16(sum_u32, sum[2]);
  sum_u32 = vpadalq_u16(sum_u32, sum[3]);
  sum_u32 = vpadalq_u16(sum_u32, sum[4]);
  sum_u32 = vpadalq_u16(sum_u32, sum[5]);
  sum_u32 = vpadalq_u16(sum_u32, sum[6]);
  sum_u32 = vpadalq_u16(sum_u32, sum[7]);

  return horizontal_add_u32x4(sum_u32);
}

static inline unsigned int sad64xh_avg_neon(const uint8_t *src_ptr,
                                            int src_stride,
                                            const uint8_t *ref_ptr,
                                            int ref_stride, int h,
                                            const uint8_t *second_pred) {
  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                        vdupq_n_u16(0) };

  int i = h;
  do {
    uint8x16_t s0, s1, s2, s3, r0, r1, r2, r3, p0, p1, p2, p3;
    uint8x16_t avg0, avg1, avg2, avg3, diff0, diff1, diff2, diff3;

    s0 = vld1q_u8(src_ptr);
    r0 = vld1q_u8(ref_ptr);
    p0 = vld1q_u8(second_pred);
    avg0 = vrhaddq_u8(r0, p0);
    diff0 = vabdq_u8(s0, avg0);
    sum[0] = vpadalq_u8(sum[0], diff0);

    s1 = vld1q_u8(src_ptr + 16);
    r1 = vld1q_u8(ref_ptr + 16);
    p1 = vld1q_u8(second_pred + 16);
    avg1 = vrhaddq_u8(r1, p1);
    diff1 = vabdq_u8(s1, avg1);
    sum[1] = vpadalq_u8(sum[1], diff1);

    s2 = vld1q_u8(src_ptr + 32);
    r2 = vld1q_u8(ref_ptr + 32);
    p2 = vld1q_u8(second_pred + 32);
    avg2 = vrhaddq_u8(r2, p2);
    diff2 = vabdq_u8(s2, avg2);
    sum[2] = vpadalq_u8(sum[2], diff2);

    s3 = vld1q_u8(src_ptr + 48);
    r3 = vld1q_u8(ref_ptr + 48);
    p3 = vld1q_u8(second_pred + 48);
    avg3 = vrhaddq_u8(r3, p3);
    diff3 = vabdq_u8(s3, avg3);
    sum[3] = vpadalq_u8(sum[3], diff3);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 64;
  } while (--i != 0);

  uint32x4_t sum_u32 = vpaddlq_u16(sum[0]);
  sum_u32 = vpadalq_u16(sum_u32, sum[1]);
  sum_u32 = vpadalq_u16(sum_u32, sum[2]);
  sum_u32 = vpadalq_u16(sum_u32, sum[3]);

  return horizontal_add_u32x4(sum_u32);
}

static inline unsigned int sad32xh_avg_neon(const uint8_t *src_ptr,
                                            int src_stride,
                                            const uint8_t *ref_ptr,
                                            int ref_stride, int h,
                                            const uint8_t *second_pred) {
  uint16x8_t sum[2] = { vdupq_n_u16(0), vdupq_n_u16(0) };

  int i = h;
  do {
    uint8x16_t s0 = vld1q_u8(src_ptr);
    uint8x16_t r0 = vld1q_u8(ref_ptr);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t avg0 = vrhaddq_u8(r0, p0);
    uint8x16_t diff0 = vabdq_u8(s0, avg0);
    sum[0] = vpadalq_u8(sum[0], diff0);

    uint8x16_t s1 = vld1q_u8(src_ptr + 16);
    uint8x16_t r1 = vld1q_u8(ref_ptr + 16);
    uint8x16_t p1 = vld1q_u8(second_pred + 16);
    uint8x16_t avg1 = vrhaddq_u8(r1, p1);
    uint8x16_t diff1 = vabdq_u8(s1, avg1);
    sum[1] = vpadalq_u8(sum[1], diff1);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 32;
  } while (--i != 0);

  return horizontal_add_u16x8(vaddq_u16(sum[0], sum[1]));
}

static inline unsigned int sad16xh_avg_neon(const uint8_t *src_ptr,
                                            int src_stride,
                                            const uint8_t *ref_ptr,
                                            int ref_stride, int h,
                                            const uint8_t *second_pred) {
  uint16x8_t sum = vdupq_n_u16(0);

  int i = h;
  do {
    uint8x16_t s = vld1q_u8(src_ptr);
    uint8x16_t r = vld1q_u8(ref_ptr);
    uint8x16_t p = vld1q_u8(second_pred);

    uint8x16_t avg = vrhaddq_u8(r, p);
    uint8x16_t diff = vabdq_u8(s, avg);
    sum = vpadalq_u8(sum, diff);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 16;
  } while (--i != 0);

  return horizontal_add_u16x8(sum);
}

static inline unsigned int sad8xh_avg_neon(const uint8_t *src_ptr,
                                           int src_stride,
                                           const uint8_t *ref_ptr,
                                           int ref_stride, int h,
                                           const uint8_t *second_pred) {
  uint16x8_t sum = vdupq_n_u16(0);

  int i = h;
  do {
    uint8x8_t s = vld1_u8(src_ptr);
    uint8x8_t r = vld1_u8(ref_ptr);
    uint8x8_t p = vld1_u8(second_pred);

    uint8x8_t avg = vrhadd_u8(r, p);
    sum = vabal_u8(sum, s, avg);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 8;
  } while (--i != 0);

  return horizontal_add_u16x8(sum);
}

static inline unsigned int sad4xh_avg_neon(const uint8_t *src_ptr,
                                           int src_stride,
                                           const uint8_t *ref_ptr,
                                           int ref_stride, int h,
                                           const uint8_t *second_pred) {
  uint16x8_t sum = vdupq_n_u16(0);

  int i = h / 2;
  do {
    uint8x8_t s = load_unaligned_u8(src_ptr, src_stride);
    uint8x8_t r = load_unaligned_u8(ref_ptr, ref_stride);
    uint8x8_t p = vld1_u8(second_pred);

    uint8x8_t avg = vrhadd_u8(r, p);
    sum = vabal_u8(sum, s, avg);

    src_ptr += 2 * src_stride;
    ref_ptr += 2 * ref_stride;
    second_pred += 8;
  } while (--i != 0);

  return horizontal_add_u16x8(sum);
}

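// Defines the exported averaging SAD entry points, aom_sad{w}x{h}_avg_neon,
// which take a contiguous w x h second_pred block in addition to src/ref.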
#define SAD_WXH_AVG_NEON(w, h)                                                 \
  unsigned int aom_sad##w##x##h##_avg_neon(const uint8_t *src, int src_stride, \
                                           const uint8_t *ref, int ref_stride, \
                                           const uint8_t *second_pred) {       \
    return sad##w##xh_avg_neon(src, src_stride, ref, ref_stride, (h),          \
                               second_pred);                                   \
  }

SAD_WXH_AVG_NEON(4, 4)
SAD_WXH_AVG_NEON(4, 8)

SAD_WXH_AVG_NEON(8, 4)
SAD_WXH_AVG_NEON(8, 8)
SAD_WXH_AVG_NEON(8, 16)

SAD_WXH_AVG_NEON(16, 8)
SAD_WXH_AVG_NEON(16, 16)
SAD_WXH_AVG_NEON(16, 32)

SAD_WXH_AVG_NEON(32, 16)
SAD_WXH_AVG_NEON(32, 32)
SAD_WXH_AVG_NEON(32, 64)

SAD_WXH_AVG_NEON(64, 32)
SAD_WXH_AVG_NEON(64, 64)
SAD_WXH_AVG_NEON(64, 128)

SAD_WXH_AVG_NEON(128, 64)
SAD_WXH_AVG_NEON(128, 128)

#if !CONFIG_REALTIME_ONLY
SAD_WXH_AVG_NEON(4, 16)
SAD_WXH_AVG_NEON(8, 32)
SAD_WXH_AVG_NEON(16, 4)
SAD_WXH_AVG_NEON(16, 64)
SAD_WXH_AVG_NEON(32, 8)
SAD_WXH_AVG_NEON(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef SAD_WXH_AVG_NEON

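// The dist_wtd variants replace the simple rounding average with a
// distance-weighted average of ref and second_pred, computed by
// dist_wtd_avg_u8x16/dist_wtd_avg_u8x8 using the fwd_offset/bck_offset
// weights carried in jcp_param (see dist_wtd_avg_neon.h).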
static inline unsigned int dist_wtd_sad128xh_avg_neon(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred,
    const DIST_WTD_COMP_PARAMS *jcp_param) {
  const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset);
  const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset);
  // We use 8 accumulators to prevent overflow for large values of 'h', as well
  // as enabling optimal UADALP instruction throughput on CPUs that have either
  // 2 or 4 Neon pipes.
  uint16x8_t sum[8] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                        vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                        vdupq_n_u16(0), vdupq_n_u16(0) };

  do {
    uint8x16_t s0 = vld1q_u8(src_ptr);
    uint8x16_t r0 = vld1q_u8(ref_ptr);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset);
    uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0);
    sum[0] = vpadalq_u8(sum[0], diff0);

    uint8x16_t s1 = vld1q_u8(src_ptr + 16);
    uint8x16_t r1 = vld1q_u8(ref_ptr + 16);
    uint8x16_t p1 = vld1q_u8(second_pred + 16);
    uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset);
    uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1);
    sum[1] = vpadalq_u8(sum[1], diff1);

    uint8x16_t s2 = vld1q_u8(src_ptr + 32);
    uint8x16_t r2 = vld1q_u8(ref_ptr + 32);
    uint8x16_t p2 = vld1q_u8(second_pred + 32);
    uint8x16_t wtd_avg2 = dist_wtd_avg_u8x16(p2, r2, bck_offset, fwd_offset);
    uint8x16_t diff2 = vabdq_u8(s2, wtd_avg2);
    sum[2] = vpadalq_u8(sum[2], diff2);

    uint8x16_t s3 = vld1q_u8(src_ptr + 48);
    uint8x16_t r3 = vld1q_u8(ref_ptr + 48);
    uint8x16_t p3 = vld1q_u8(second_pred + 48);
    uint8x16_t wtd_avg3 = dist_wtd_avg_u8x16(p3, r3, bck_offset, fwd_offset);
    uint8x16_t diff3 = vabdq_u8(s3, wtd_avg3);
    sum[3] = vpadalq_u8(sum[3], diff3);

    uint8x16_t s4 = vld1q_u8(src_ptr + 64);
    uint8x16_t r4 = vld1q_u8(ref_ptr + 64);
    uint8x16_t p4 = vld1q_u8(second_pred + 64);
    uint8x16_t wtd_avg4 = dist_wtd_avg_u8x16(p4, r4, bck_offset, fwd_offset);
    uint8x16_t diff4 = vabdq_u8(s4, wtd_avg4);
    sum[4] = vpadalq_u8(sum[4], diff4);

    uint8x16_t s5 = vld1q_u8(src_ptr + 80);
    uint8x16_t r5 = vld1q_u8(ref_ptr + 80);
    uint8x16_t p5 = vld1q_u8(second_pred + 80);
    uint8x16_t wtd_avg5 = dist_wtd_avg_u8x16(p5, r5, bck_offset, fwd_offset);
    uint8x16_t diff5 = vabdq_u8(s5, wtd_avg5);
    sum[5] = vpadalq_u8(sum[5], diff5);

    uint8x16_t s6 = vld1q_u8(src_ptr + 96);
    uint8x16_t r6 = vld1q_u8(ref_ptr + 96);
    uint8x16_t p6 = vld1q_u8(second_pred + 96);
    uint8x16_t wtd_avg6 = dist_wtd_avg_u8x16(p6, r6, bck_offset, fwd_offset);
    uint8x16_t diff6 = vabdq_u8(s6, wtd_avg6);
    sum[6] = vpadalq_u8(sum[6], diff6);

    uint8x16_t s7 = vld1q_u8(src_ptr + 112);
    uint8x16_t r7 = vld1q_u8(ref_ptr + 112);
    uint8x16_t p7 = vld1q_u8(second_pred + 112);
    uint8x16_t wtd_avg7 = dist_wtd_avg_u8x16(p7, r7, bck_offset, fwd_offset);
    uint8x16_t diff7 = vabdq_u8(s7, wtd_avg7);
    sum[7] = vpadalq_u8(sum[7], diff7);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 128;
  } while (--h != 0);

  uint32x4_t sum_u32 = vpaddlq_u16(sum[0]);
  sum_u32 = vpadalq_u16(sum_u32, sum[1]);
  sum_u32 = vpadalq_u16(sum_u32, sum[2]);
  sum_u32 = vpadalq_u16(sum_u32, sum[3]);
  sum_u32 = vpadalq_u16(sum_u32, sum[4]);
  sum_u32 = vpadalq_u16(sum_u32, sum[5]);
  sum_u32 = vpadalq_u16(sum_u32, sum[6]);
  sum_u32 = vpadalq_u16(sum_u32, sum[7]);

  return horizontal_add_u32x4(sum_u32);
}

static inline unsigned int dist_wtd_sad64xh_avg_neon(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred,
    const DIST_WTD_COMP_PARAMS *jcp_param) {
  const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset);
  const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset);
  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                        vdupq_n_u16(0) };

  do {
    uint8x16_t s0 = vld1q_u8(src_ptr);
    uint8x16_t r0 = vld1q_u8(ref_ptr);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset);
    uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0);
    sum[0] = vpadalq_u8(sum[0], diff0);

    uint8x16_t s1 = vld1q_u8(src_ptr + 16);
    uint8x16_t r1 = vld1q_u8(ref_ptr + 16);
    uint8x16_t p1 = vld1q_u8(second_pred + 16);
    uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset);
    uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1);
    sum[1] = vpadalq_u8(sum[1], diff1);

    uint8x16_t s2 = vld1q_u8(src_ptr + 32);
    uint8x16_t r2 = vld1q_u8(ref_ptr + 32);
    uint8x16_t p2 = vld1q_u8(second_pred + 32);
    uint8x16_t wtd_avg2 = dist_wtd_avg_u8x16(p2, r2, bck_offset, fwd_offset);
    uint8x16_t diff2 = vabdq_u8(s2, wtd_avg2);
    sum[2] = vpadalq_u8(sum[2], diff2);

    uint8x16_t s3 = vld1q_u8(src_ptr + 48);
    uint8x16_t r3 = vld1q_u8(ref_ptr + 48);
    uint8x16_t p3 = vld1q_u8(second_pred + 48);
    uint8x16_t wtd_avg3 = dist_wtd_avg_u8x16(p3, r3, bck_offset, fwd_offset);
    uint8x16_t diff3 = vabdq_u8(s3, wtd_avg3);
    sum[3] = vpadalq_u8(sum[3], diff3);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 64;
  } while (--h != 0);

  uint32x4_t sum_u32 = vpaddlq_u16(sum[0]);
  sum_u32 = vpadalq_u16(sum_u32, sum[1]);
  sum_u32 = vpadalq_u16(sum_u32, sum[2]);
  sum_u32 = vpadalq_u16(sum_u32, sum[3]);

  return horizontal_add_u32x4(sum_u32);
}

static inline unsigned int dist_wtd_sad32xh_avg_neon(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred,
    const DIST_WTD_COMP_PARAMS *jcp_param) {
  const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset);
  const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset);
  uint16x8_t sum[2] = { vdupq_n_u16(0), vdupq_n_u16(0) };

  do {
    uint8x16_t s0 = vld1q_u8(src_ptr);
    uint8x16_t r0 = vld1q_u8(ref_ptr);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset);
    uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0);
    sum[0] = vpadalq_u8(sum[0], diff0);

    uint8x16_t s1 = vld1q_u8(src_ptr + 16);
    uint8x16_t r1 = vld1q_u8(ref_ptr + 16);
    uint8x16_t p1 = vld1q_u8(second_pred + 16);
    uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset);
    uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1);
    sum[1] = vpadalq_u8(sum[1], diff1);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 32;
  } while (--h != 0);

  return horizontal_add_u16x8(vaddq_u16(sum[0], sum[1]));
}

static inline unsigned int dist_wtd_sad16xh_avg_neon(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred,
    const DIST_WTD_COMP_PARAMS *jcp_param) {
  const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset);
  const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset);
  uint16x8_t sum = vdupq_n_u16(0);

  do {
    uint8x16_t s = vld1q_u8(src_ptr);
    uint8x16_t r = vld1q_u8(ref_ptr);
    uint8x16_t p = vld1q_u8(second_pred);

    uint8x16_t wtd_avg = dist_wtd_avg_u8x16(p, r, bck_offset, fwd_offset);
    uint8x16_t diff = vabdq_u8(s, wtd_avg);
    sum = vpadalq_u8(sum, diff);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 16;
  } while (--h != 0);

  return horizontal_add_u16x8(sum);
}

static inline unsigned int dist_wtd_sad8xh_avg_neon(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred,
    const DIST_WTD_COMP_PARAMS *jcp_param) {
  const uint8x8_t fwd_offset = vdup_n_u8(jcp_param->fwd_offset);
  const uint8x8_t bck_offset = vdup_n_u8(jcp_param->bck_offset);
  uint16x8_t sum = vdupq_n_u16(0);

  do {
    uint8x8_t s = vld1_u8(src_ptr);
    uint8x8_t r = vld1_u8(ref_ptr);
    uint8x8_t p = vld1_u8(second_pred);

    uint8x8_t wtd_avg = dist_wtd_avg_u8x8(p, r, bck_offset, fwd_offset);
    sum = vabal_u8(sum, s, wtd_avg);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 8;
  } while (--h != 0);

  return horizontal_add_u16x8(sum);
}

static inline unsigned int dist_wtd_sad4xh_avg_neon(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred,
    const DIST_WTD_COMP_PARAMS *jcp_param) {
  const uint8x8_t fwd_offset = vdup_n_u8(jcp_param->fwd_offset);
  const uint8x8_t bck_offset = vdup_n_u8(jcp_param->bck_offset);
  uint16x8_t sum = vdupq_n_u16(0);

  int i = h / 2;
  do {
    uint8x8_t s = load_unaligned_u8(src_ptr, src_stride);
    uint8x8_t r = load_unaligned_u8(ref_ptr, ref_stride);
    uint8x8_t p = vld1_u8(second_pred);

    uint8x8_t wtd_avg = dist_wtd_avg_u8x8(p, r, bck_offset, fwd_offset);
    sum = vabal_u8(sum, s, wtd_avg);

    src_ptr += 2 * src_stride;
    ref_ptr += 2 * ref_stride;
    second_pred += 8;
  } while (--i != 0);

  return horizontal_add_u16x8(sum);
}

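// Defines the exported distance-weighted SAD entry points,
// aom_dist_wtd_sad{w}x{h}_avg_neon, which additionally take the
// DIST_WTD_COMP_PARAMS used to weight the ref/second_pred average.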
#define DIST_WTD_SAD_WXH_AVG_NEON(w, h)                                        \
  unsigned int aom_dist_wtd_sad##w##x##h##_avg_neon(                           \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,  \
      const uint8_t *second_pred, const DIST_WTD_COMP_PARAMS *jcp_param) {     \
    return dist_wtd_sad##w##xh_avg_neon(src, src_stride, ref, ref_stride, (h), \
                                        second_pred, jcp_param);               \
  }

DIST_WTD_SAD_WXH_AVG_NEON(4, 4)
DIST_WTD_SAD_WXH_AVG_NEON(4, 8)

DIST_WTD_SAD_WXH_AVG_NEON(8, 4)
DIST_WTD_SAD_WXH_AVG_NEON(8, 8)
DIST_WTD_SAD_WXH_AVG_NEON(8, 16)

DIST_WTD_SAD_WXH_AVG_NEON(16, 8)
DIST_WTD_SAD_WXH_AVG_NEON(16, 16)
DIST_WTD_SAD_WXH_AVG_NEON(16, 32)

DIST_WTD_SAD_WXH_AVG_NEON(32, 16)
DIST_WTD_SAD_WXH_AVG_NEON(32, 32)
DIST_WTD_SAD_WXH_AVG_NEON(32, 64)

DIST_WTD_SAD_WXH_AVG_NEON(64, 32)
DIST_WTD_SAD_WXH_AVG_NEON(64, 64)
DIST_WTD_SAD_WXH_AVG_NEON(64, 128)

DIST_WTD_SAD_WXH_AVG_NEON(128, 64)
DIST_WTD_SAD_WXH_AVG_NEON(128, 128)

#if !CONFIG_REALTIME_ONLY
DIST_WTD_SAD_WXH_AVG_NEON(4, 16)
DIST_WTD_SAD_WXH_AVG_NEON(8, 32)
DIST_WTD_SAD_WXH_AVG_NEON(16, 4)
DIST_WTD_SAD_WXH_AVG_NEON(16, 64)
DIST_WTD_SAD_WXH_AVG_NEON(32, 8)
DIST_WTD_SAD_WXH_AVG_NEON(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef DIST_WTD_SAD_WXH_AVG_NEON