xref: /aosp_15_r20/external/libaom/aom_dsp/arm/sse_neon.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 
14 #include "config/aom_dsp_rtcd.h"
15 #include "aom_dsp/arm/mem_neon.h"
16 #include "aom_dsp/arm/sum_neon.h"
17 
// Adds the squared differences of 16 adjacent 8-bit samples to *sse.
//
// src, ref: pointers to at least 16 readable bytes each.
// sse:      running accumulator; the squared differences are pairwise-added
//           into its four 32-bit lanes.
static inline void sse_16x1_neon(const uint8_t *src, const uint8_t *ref,
                                 uint32x4_t *sse) {
  const uint8x16_t diff = vabdq_u8(vld1q_u8(src), vld1q_u8(ref));
  const uint8x8_t diff_lo = vget_low_u8(diff);
  const uint8x8_t diff_hi = vget_high_u8(diff);

  // An absolute difference of two bytes is at most 255, so every square fits
  // in an unsigned 16-bit lane of the widening multiply.
  const uint16x8_t sq_lo = vmull_u8(diff_lo, diff_lo);
  const uint16x8_t sq_hi = vmull_u8(diff_hi, diff_hi);

  uint32x4_t acc = *sse;
  acc = vpadalq_u16(acc, sq_lo);
  acc = vpadalq_u16(acc, sq_hi);
  *sse = acc;
}
30 
// Adds the squared differences of 8 adjacent 8-bit samples to *sse.
//
// src, ref: pointers to at least 8 readable bytes each.
// sse:      running accumulator; the eight 16-bit squares are pairwise-added
//           into its four 32-bit lanes.
static inline void sse_8x1_neon(const uint8_t *src, const uint8_t *ref,
                                uint32x4_t *sse) {
  const uint8x8_t diff = vabd_u8(vld1_u8(src), vld1_u8(ref));

  // Widening multiply: each |src - ref|^2 is held in a 16-bit lane.
  const uint16x8_t squares = vmull_u8(diff, diff);

  *sse = vpadalq_u16(*sse, squares);
}
40 
// Adds the squared differences of one 8-lane vector of samples, loaded with
// load_unaligned_u8(), to *sse.
//
// NOTE(review): load_unaligned_u8() is defined outside this file; it
// presumably packs 4 bytes from each of two consecutive rows (row pitch given
// by the stride argument) into one 8-byte vector -- confirm in mem_neon.h.
static inline void sse_4x2_neon(const uint8_t *src, int src_stride,
                                const uint8_t *ref, int ref_stride,
                                uint32x4_t *sse) {
  const uint8x8_t src_px = load_unaligned_u8(src, src_stride);
  const uint8x8_t ref_px = load_unaligned_u8(ref, ref_stride);
  const uint8x8_t diff = vabd_u8(src_px, ref_px);

  // Square with a widening multiply, then pairwise-accumulate into 32 bits.
  const uint16x8_t squares = vmull_u8(diff, diff);
  *sse = vpadalq_u16(*sse, squares);
}
51 
sse_wxh_neon(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int width,int height)52 static inline uint32_t sse_wxh_neon(const uint8_t *src, int src_stride,
53                                     const uint8_t *ref, int ref_stride,
54                                     int width, int height) {
55   uint32x4_t sse = vdupq_n_u32(0);
56 
57   if ((width & 0x07) && ((width & 0x07) < 5)) {
58     int i = height;
59     do {
60       int j = 0;
61       do {
62         sse_8x1_neon(src + j, ref + j, &sse);
63         sse_8x1_neon(src + j + src_stride, ref + j + ref_stride, &sse);
64         j += 8;
65       } while (j + 4 < width);
66 
67       sse_4x2_neon(src + j, src_stride, ref + j, ref_stride, &sse);
68       src += 2 * src_stride;
69       ref += 2 * ref_stride;
70       i -= 2;
71     } while (i != 0);
72   } else {
73     int i = height;
74     do {
75       int j = 0;
76       do {
77         sse_8x1_neon(src + j, ref + j, &sse);
78         j += 8;
79       } while (j < width);
80 
81       src += src_stride;
82       ref += ref_stride;
83     } while (--i != 0);
84   }
85   return horizontal_add_u32x4(sse);
86 }
87 
sse_128xh_neon(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)88 static inline uint32_t sse_128xh_neon(const uint8_t *src, int src_stride,
89                                       const uint8_t *ref, int ref_stride,
90                                       int height) {
91   uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
92 
93   int i = height;
94   do {
95     sse_16x1_neon(src, ref, &sse[0]);
96     sse_16x1_neon(src + 16, ref + 16, &sse[1]);
97     sse_16x1_neon(src + 32, ref + 32, &sse[0]);
98     sse_16x1_neon(src + 48, ref + 48, &sse[1]);
99     sse_16x1_neon(src + 64, ref + 64, &sse[0]);
100     sse_16x1_neon(src + 80, ref + 80, &sse[1]);
101     sse_16x1_neon(src + 96, ref + 96, &sse[0]);
102     sse_16x1_neon(src + 112, ref + 112, &sse[1]);
103 
104     src += src_stride;
105     ref += ref_stride;
106   } while (--i != 0);
107 
108   return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
109 }
110 
sse_64xh_neon(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)111 static inline uint32_t sse_64xh_neon(const uint8_t *src, int src_stride,
112                                      const uint8_t *ref, int ref_stride,
113                                      int height) {
114   uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
115 
116   int i = height;
117   do {
118     sse_16x1_neon(src, ref, &sse[0]);
119     sse_16x1_neon(src + 16, ref + 16, &sse[1]);
120     sse_16x1_neon(src + 32, ref + 32, &sse[0]);
121     sse_16x1_neon(src + 48, ref + 48, &sse[1]);
122 
123     src += src_stride;
124     ref += ref_stride;
125   } while (--i != 0);
126 
127   return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
128 }
129 
sse_32xh_neon(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)130 static inline uint32_t sse_32xh_neon(const uint8_t *src, int src_stride,
131                                      const uint8_t *ref, int ref_stride,
132                                      int height) {
133   uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
134 
135   int i = height;
136   do {
137     sse_16x1_neon(src, ref, &sse[0]);
138     sse_16x1_neon(src + 16, ref + 16, &sse[1]);
139 
140     src += src_stride;
141     ref += ref_stride;
142   } while (--i != 0);
143 
144   return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
145 }
146 
sse_16xh_neon(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)147 static inline uint32_t sse_16xh_neon(const uint8_t *src, int src_stride,
148                                      const uint8_t *ref, int ref_stride,
149                                      int height) {
150   uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
151 
152   int i = height;
153   do {
154     sse_16x1_neon(src, ref, &sse[0]);
155     src += src_stride;
156     ref += ref_stride;
157     sse_16x1_neon(src, ref, &sse[1]);
158     src += src_stride;
159     ref += ref_stride;
160     i -= 2;
161   } while (i != 0);
162 
163   return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
164 }
165 
sse_8xh_neon(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)166 static inline uint32_t sse_8xh_neon(const uint8_t *src, int src_stride,
167                                     const uint8_t *ref, int ref_stride,
168                                     int height) {
169   uint32x4_t sse = vdupq_n_u32(0);
170 
171   int i = height;
172   do {
173     sse_8x1_neon(src, ref, &sse);
174 
175     src += src_stride;
176     ref += ref_stride;
177   } while (--i != 0);
178 
179   return horizontal_add_u32x4(sse);
180 }
181 
sse_4xh_neon(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)182 static inline uint32_t sse_4xh_neon(const uint8_t *src, int src_stride,
183                                     const uint8_t *ref, int ref_stride,
184                                     int height) {
185   uint32x4_t sse = vdupq_n_u32(0);
186 
187   int i = height;
188   do {
189     sse_4x2_neon(src, src_stride, ref, ref_stride, &sse);
190 
191     src += 2 * src_stride;
192     ref += 2 * ref_stride;
193     i -= 2;
194   } while (i != 0);
195 
196   return horizontal_add_u32x4(sse);
197 }
198 
// Returns the sum of squared differences between two width x height blocks of
// 8-bit samples.
//
// src, ref:               top-left sample of each block.
// src_stride, ref_stride: distance in bytes between successive rows.
// width, height:          block dimensions in samples.
//
// Widths 4, 8, 16, 32, 64 and 128 are routed to fixed-width kernels; any
// other width goes to the generic sse_wxh_neon() path. Every kernel returns a
// 32-bit unsigned total, which is widened to int64_t on return.
int64_t aom_sse_neon(const uint8_t *src, int src_stride, const uint8_t *ref,
                     int ref_stride, int width, int height) {
  uint32_t total;

  if (width == 4) {
    total = sse_4xh_neon(src, src_stride, ref, ref_stride, height);
  } else if (width == 8) {
    total = sse_8xh_neon(src, src_stride, ref, ref_stride, height);
  } else if (width == 16) {
    total = sse_16xh_neon(src, src_stride, ref, ref_stride, height);
  } else if (width == 32) {
    total = sse_32xh_neon(src, src_stride, ref, ref_stride, height);
  } else if (width == 64) {
    total = sse_64xh_neon(src, src_stride, ref, ref_stride, height);
  } else if (width == 128) {
    total = sse_128xh_neon(src, src_stride, ref, ref_stride, height);
  } else {
    total = sse_wxh_neon(src, src_stride, ref, ref_stride, width, height);
  }

  return total;
}
212