/*
 *  Copyright (c) 2022 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <arm_neon.h>

#include "./vpx_config.h"
#include "./vpx_dsp_rtcd.h"

#include "vpx/vpx_integer.h"
#include "vpx_dsp/arm/mem_neon.h"
#include "vpx_dsp/arm/sum_neon.h"

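// High bit-depth frame buffers carry 16-bit samples behind the uint8_t*
// interface pointers, so CONVERT_TO_SHORTPTR recovers the real uint16_t
// pointers before loading. For 4-wide blocks each row is a single 64-bit
// load, and vabal_u16 computes the absolute differences and widens the
// accumulation to 32 bits in one step.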
static INLINE uint32_t highbd_sad4xh_neon(const uint8_t *src_ptr,
                                          int src_stride,
                                          const uint8_t *ref_ptr,
                                          int ref_stride, int h) {
  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
  uint32x4_t sum = vdupq_n_u32(0);

  int i = h;
  do {
    uint16x4_t s = vld1_u16(src16_ptr);
    uint16x4_t r = vld1_u16(ref16_ptr);
    sum = vabal_u16(sum, s, r);

    src16_ptr += src_stride;
    ref16_ptr += ref_stride;
  } while (--i != 0);

  return horizontal_add_uint32x4(sum);
}

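// For 8-wide blocks the per-row absolute differences are accumulated in a
// single 16-bit vector. The largest 8-wide block handled here is 8x16, so
// even with 12-bit input (per-lane differences of at most 4095) each lane
// sums to at most 16 * 4095 = 65520, which still fits in a uint16_t.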
static INLINE uint32_t highbd_sad8xh_neon(const uint8_t *src_ptr,
                                          int src_stride,
                                          const uint8_t *ref_ptr,
                                          int ref_stride, int h) {
  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
  uint16x8_t sum = vdupq_n_u16(0);

  int i = h;
  do {
    uint16x8_t s = vld1q_u16(src16_ptr);
    uint16x8_t r = vld1q_u16(ref16_ptr);
    sum = vabaq_u16(sum, s, r);

    src16_ptr += src_stride;
    ref16_ptr += ref_stride;
  } while (--i != 0);

  return horizontal_add_uint16x8(sum);
}

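// 16-wide rows are split into two 8-lane halves. vabdq_u16 produces the
// absolute differences and vpadalq_u16 pairwise-adds them into 32-bit
// accumulators, so no intermediate 16-bit overflow is possible.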
static INLINE uint32_t highbd_sad16xh_neon(const uint8_t *src_ptr,
                                           int src_stride,
                                           const uint8_t *ref_ptr,
                                           int ref_stride, int h) {
  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h;
  do {
    uint16x8_t s0, s1, r0, r1;
    uint16x8_t diff0, diff1;

    s0 = vld1q_u16(src16_ptr);
    r0 = vld1q_u16(ref16_ptr);
    diff0 = vabdq_u16(s0, r0);
    sum[0] = vpadalq_u16(sum[0], diff0);

    s1 = vld1q_u16(src16_ptr + 8);
    r1 = vld1q_u16(ref16_ptr + 8);
    diff1 = vabdq_u16(s1, r1);
    sum[1] = vpadalq_u16(sum[1], diff1);

    src16_ptr += src_stride;
    ref16_ptr += ref_stride;
  } while (--i != 0);

  sum[0] = vaddq_u32(sum[0], sum[1]);
  return horizontal_add_uint32x4(sum[0]);
}

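// Generic kernel for widths that are multiples of 32: each inner iteration
// consumes 32 pixels using four independent 32-bit accumulators, which are
// only combined once both loops have finished.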
static INLINE uint32_t highbd_sadwxh_neon(const uint8_t *src_ptr,
                                          int src_stride,
                                          const uint8_t *ref_ptr,
                                          int ref_stride, int w, int h) {
  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                        vdupq_n_u32(0) };

  int i = h;
  do {
    int j = 0;
    do {
      uint16x8_t s0, s1, s2, s3, r0, r1, r2, r3;
      uint16x8_t diff0, diff1, diff2, diff3;

      s0 = vld1q_u16(src16_ptr + j);
      r0 = vld1q_u16(ref16_ptr + j);
      diff0 = vabdq_u16(s0, r0);
      sum[0] = vpadalq_u16(sum[0], diff0);

      s1 = vld1q_u16(src16_ptr + j + 8);
      r1 = vld1q_u16(ref16_ptr + j + 8);
      diff1 = vabdq_u16(s1, r1);
      sum[1] = vpadalq_u16(sum[1], diff1);

      s2 = vld1q_u16(src16_ptr + j + 16);
      r2 = vld1q_u16(ref16_ptr + j + 16);
      diff2 = vabdq_u16(s2, r2);
      sum[2] = vpadalq_u16(sum[2], diff2);

      s3 = vld1q_u16(src16_ptr + j + 24);
      r3 = vld1q_u16(ref16_ptr + j + 24);
      diff3 = vabdq_u16(s3, r3);
      sum[3] = vpadalq_u16(sum[3], diff3);

      j += 32;
    } while (j < w);

    src16_ptr += src_stride;
    ref16_ptr += ref_stride;
  } while (--i != 0);

  sum[0] = vaddq_u32(sum[0], sum[1]);
  sum[2] = vaddq_u32(sum[2], sum[3]);
  sum[0] = vaddq_u32(sum[0], sum[2]);

  return horizontal_add_uint32x4(sum[0]);
}

static INLINE unsigned int highbd_sad64xh_neon(const uint8_t *src_ptr,
                                               int src_stride,
                                               const uint8_t *ref_ptr,
                                               int ref_stride, int h) {
  return highbd_sadwxh_neon(src_ptr, src_stride, ref_ptr, ref_stride, 64, h);
}

static INLINE unsigned int highbd_sad32xh_neon(const uint8_t *src_ptr,
                                               int src_stride,
                                               const uint8_t *ref_ptr,
                                               int ref_stride, int h) {
  return highbd_sadwxh_neon(src_ptr, src_stride, ref_ptr, ref_stride, 32, h);
}

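// Instantiate the public vpx_highbd_sad<w>x<h>_neon() entry points expected
// by vpx_dsp_rtcd.h from the block-width kernels above.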
#define HBD_SAD_WXH_NEON(w, h)                                            \
  unsigned int vpx_highbd_sad##w##x##h##_neon(                            \
      const uint8_t *src, int src_stride, const uint8_t *ref,             \
      int ref_stride) {                                                   \
    return highbd_sad##w##xh_neon(src, src_stride, ref, ref_stride, (h)); \
  }

HBD_SAD_WXH_NEON(4, 4)
HBD_SAD_WXH_NEON(4, 8)

HBD_SAD_WXH_NEON(8, 4)
HBD_SAD_WXH_NEON(8, 8)
HBD_SAD_WXH_NEON(8, 16)

HBD_SAD_WXH_NEON(16, 8)
HBD_SAD_WXH_NEON(16, 16)
HBD_SAD_WXH_NEON(16, 32)

HBD_SAD_WXH_NEON(32, 16)
HBD_SAD_WXH_NEON(32, 32)
HBD_SAD_WXH_NEON(32, 64)

HBD_SAD_WXH_NEON(64, 32)
HBD_SAD_WXH_NEON(64, 64)

#undef HBD_SAD_WXH_NEON

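// The _skip variants estimate the SAD from every other row: the strides are
// doubled, half the rows are visited, and the partial sum is doubled to
// approximate the full-block SAD.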
#define HBD_SAD_SKIP_WXH_NEON(w, h)                             \
  unsigned int vpx_highbd_sad_skip_##w##x##h##_neon(            \
      const uint8_t *src, int src_stride, const uint8_t *ref,   \
      int ref_stride) {                                         \
    return 2 * highbd_sad##w##xh_neon(src, 2 * src_stride, ref, \
                                      2 * ref_stride, (h) / 2); \
  }

HBD_SAD_SKIP_WXH_NEON(4, 4)
HBD_SAD_SKIP_WXH_NEON(4, 8)

HBD_SAD_SKIP_WXH_NEON(8, 4)
HBD_SAD_SKIP_WXH_NEON(8, 8)
HBD_SAD_SKIP_WXH_NEON(8, 16)

HBD_SAD_SKIP_WXH_NEON(16, 8)
HBD_SAD_SKIP_WXH_NEON(16, 16)
HBD_SAD_SKIP_WXH_NEON(16, 32)

HBD_SAD_SKIP_WXH_NEON(32, 16)
HBD_SAD_SKIP_WXH_NEON(32, 32)
HBD_SAD_SKIP_WXH_NEON(32, 64)

HBD_SAD_SKIP_WXH_NEON(64, 32)
HBD_SAD_SKIP_WXH_NEON(64, 64)

#undef HBD_SAD_SKIP_WXH_NEON

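// The _avg kernels compare the source against the average of the reference
// and a second predictor. vrhadd_u16 is a rounding halving add, i.e.
// (r + p + 1) >> 1 per lane, the same rounding used by the comp-avg
// prediction path.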
static INLINE uint32_t highbd_sad4xh_avg_neon(const uint8_t *src_ptr,
                                              int src_stride,
                                              const uint8_t *ref_ptr,
                                              int ref_stride, int h,
                                              const uint8_t *second_pred) {
  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
  const uint16_t *pred16_ptr = CONVERT_TO_SHORTPTR(second_pred);
  uint32x4_t sum = vdupq_n_u32(0);

  int i = h;
  do {
    uint16x4_t s = vld1_u16(src16_ptr);
    uint16x4_t r = vld1_u16(ref16_ptr);
    uint16x4_t p = vld1_u16(pred16_ptr);

    uint16x4_t avg = vrhadd_u16(r, p);
    sum = vabal_u16(sum, s, avg);

    src16_ptr += src_stride;
    ref16_ptr += ref_stride;
    pred16_ptr += 4;
  } while (--i != 0);

  return horizontal_add_uint32x4(sum);
}

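// Unlike the plain 8-wide kernel, the averaging variant accumulates into
// 32-bit lanes via vpadalq_u16, so it does not rely on a small h to avoid
// 16-bit overflow.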
static INLINE uint32_t highbd_sad8xh_avg_neon(const uint8_t *src_ptr,
                                              int src_stride,
                                              const uint8_t *ref_ptr,
                                              int ref_stride, int h,
                                              const uint8_t *second_pred) {
  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
  const uint16_t *pred16_ptr = CONVERT_TO_SHORTPTR(second_pred);
  uint32x4_t sum = vdupq_n_u32(0);

  int i = h;
  do {
    uint16x8_t s = vld1q_u16(src16_ptr);
    uint16x8_t r = vld1q_u16(ref16_ptr);
    uint16x8_t p = vld1q_u16(pred16_ptr);

    uint16x8_t avg = vrhaddq_u16(r, p);
    uint16x8_t diff = vabdq_u16(s, avg);
    sum = vpadalq_u16(sum, diff);

    src16_ptr += src_stride;
    ref16_ptr += ref_stride;
    pred16_ptr += 8;
  } while (--i != 0);

  return horizontal_add_uint32x4(sum);
}

static INLINE uint32_t highbd_sad16xh_avg_neon(const uint8_t *src_ptr,
                                               int src_stride,
                                               const uint8_t *ref_ptr,
                                               int ref_stride, int h,
                                               const uint8_t *second_pred) {
  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
  const uint16_t *pred16_ptr = CONVERT_TO_SHORTPTR(second_pred);
  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h;
  do {
    uint16x8_t s0, s1, r0, r1, p0, p1;
    uint16x8_t avg0, avg1, diff0, diff1;

    s0 = vld1q_u16(src16_ptr);
    r0 = vld1q_u16(ref16_ptr);
    p0 = vld1q_u16(pred16_ptr);
    avg0 = vrhaddq_u16(r0, p0);
    diff0 = vabdq_u16(s0, avg0);
    sum[0] = vpadalq_u16(sum[0], diff0);

    s1 = vld1q_u16(src16_ptr + 8);
    r1 = vld1q_u16(ref16_ptr + 8);
    p1 = vld1q_u16(pred16_ptr + 8);
    avg1 = vrhaddq_u16(r1, p1);
    diff1 = vabdq_u16(s1, avg1);
    sum[1] = vpadalq_u16(sum[1], diff1);

    src16_ptr += src_stride;
    ref16_ptr += ref_stride;
    pred16_ptr += 16;
  } while (--i != 0);

  sum[0] = vaddq_u32(sum[0], sum[1]);
  return horizontal_add_uint32x4(sum[0]);
}

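// Wide averaging kernel for w a multiple of 32. second_pred is a contiguous
// w x h block, so pred16_ptr advances by w rather than by a stride after
// each row.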
static INLINE uint32_t highbd_sadwxh_avg_neon(const uint8_t *src_ptr,
                                              int src_stride,
                                              const uint8_t *ref_ptr,
                                              int ref_stride, int w, int h,
                                              const uint8_t *second_pred) {
  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
  const uint16_t *pred16_ptr = CONVERT_TO_SHORTPTR(second_pred);
  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                        vdupq_n_u32(0) };

  int i = h;
  do {
    int j = 0;
    do {
      uint16x8_t s0, s1, s2, s3, r0, r1, r2, r3, p0, p1, p2, p3;
      uint16x8_t avg0, avg1, avg2, avg3, diff0, diff1, diff2, diff3;

      s0 = vld1q_u16(src16_ptr + j);
      r0 = vld1q_u16(ref16_ptr + j);
      p0 = vld1q_u16(pred16_ptr + j);
      avg0 = vrhaddq_u16(r0, p0);
      diff0 = vabdq_u16(s0, avg0);
      sum[0] = vpadalq_u16(sum[0], diff0);

      s1 = vld1q_u16(src16_ptr + j + 8);
      r1 = vld1q_u16(ref16_ptr + j + 8);
      p1 = vld1q_u16(pred16_ptr + j + 8);
      avg1 = vrhaddq_u16(r1, p1);
      diff1 = vabdq_u16(s1, avg1);
      sum[1] = vpadalq_u16(sum[1], diff1);

      s2 = vld1q_u16(src16_ptr + j + 16);
      r2 = vld1q_u16(ref16_ptr + j + 16);
      p2 = vld1q_u16(pred16_ptr + j + 16);
      avg2 = vrhaddq_u16(r2, p2);
      diff2 = vabdq_u16(s2, avg2);
      sum[2] = vpadalq_u16(sum[2], diff2);

      s3 = vld1q_u16(src16_ptr + j + 24);
      r3 = vld1q_u16(ref16_ptr + j + 24);
      p3 = vld1q_u16(pred16_ptr + j + 24);
      avg3 = vrhaddq_u16(r3, p3);
      diff3 = vabdq_u16(s3, avg3);
      sum[3] = vpadalq_u16(sum[3], diff3);

      j += 32;
    } while (j < w);

    src16_ptr += src_stride;
    ref16_ptr += ref_stride;
    pred16_ptr += w;
  } while (--i != 0);

  sum[0] = vaddq_u32(sum[0], sum[1]);
  sum[2] = vaddq_u32(sum[2], sum[3]);
  sum[0] = vaddq_u32(sum[0], sum[2]);

  return horizontal_add_uint32x4(sum[0]);
}

static INLINE unsigned int highbd_sad64xh_avg_neon(const uint8_t *src_ptr,
                                                   int src_stride,
                                                   const uint8_t *ref_ptr,
                                                   int ref_stride, int h,
                                                   const uint8_t *second_pred) {
  return highbd_sadwxh_avg_neon(src_ptr, src_stride, ref_ptr, ref_stride, 64, h,
                                second_pred);
}

static INLINE unsigned int highbd_sad32xh_avg_neon(const uint8_t *src_ptr,
                                                   int src_stride,
                                                   const uint8_t *ref_ptr,
                                                   int ref_stride, int h,
                                                   const uint8_t *second_pred) {
  return highbd_sadwxh_avg_neon(src_ptr, src_stride, ref_ptr, ref_stride, 32, h,
                                second_pred);
}

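// Instantiate the public vpx_highbd_sad<w>x<h>_avg_neon() entry points,
// which take the SAD against the average of the reference block and a
// second predictor.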
#define HBD_SAD_WXH_AVG_NEON(w, h)                                            \
  uint32_t vpx_highbd_sad##w##x##h##_avg_neon(                                \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred) {                                           \
    return highbd_sad##w##xh_avg_neon(src, src_stride, ref, ref_stride, (h),  \
                                      second_pred);                           \
  }

HBD_SAD_WXH_AVG_NEON(4, 4)
HBD_SAD_WXH_AVG_NEON(4, 8)

HBD_SAD_WXH_AVG_NEON(8, 4)
HBD_SAD_WXH_AVG_NEON(8, 8)
HBD_SAD_WXH_AVG_NEON(8, 16)

HBD_SAD_WXH_AVG_NEON(16, 8)
HBD_SAD_WXH_AVG_NEON(16, 16)
HBD_SAD_WXH_AVG_NEON(16, 32)

HBD_SAD_WXH_AVG_NEON(32, 16)
HBD_SAD_WXH_AVG_NEON(32, 32)
HBD_SAD_WXH_AVG_NEON(32, 64)

HBD_SAD_WXH_AVG_NEON(64, 32)
HBD_SAD_WXH_AVG_NEON(64, 64)