xref: /aosp_15_r20/external/libaom/aom_dsp/arm/highbd_sse_neon.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
11 
12 #include <arm_neon.h>
13 
14 #include "config/aom_dsp_rtcd.h"
15 #include "aom_dsp/arm/sum_neon.h"
16 
// Starts a pair of squared-error accumulators from one vector of samples.
//
// Eight 16-bit samples are loaded from each of `src` and `ref`, their
// per-lane absolute difference is taken, and each difference is squared with
// a widening (16-bit -> 32-bit) multiply. The squares of lanes 0-3 are
// written to `*sse_acc0` and those of lanes 4-7 to `*sse_acc1`; whatever the
// accumulators held before is overwritten, so they need not be initialised.
static inline void highbd_sse_8x1_init_neon(const uint16_t *src,
                                            const uint16_t *ref,
                                            uint32x4_t *sse_acc0,
                                            uint32x4_t *sse_acc1) {
  const uint16x8_t diff = vabdq_u16(vld1q_u16(src), vld1q_u16(ref));
  const uint16x4_t diff_lo = vget_low_u16(diff);
  const uint16x4_t diff_hi = vget_high_u16(diff);

  // vmull (no accumulate) is what makes this the "init" variant.
  *sse_acc0 = vmull_u16(diff_lo, diff_lo);
  *sse_acc1 = vmull_u16(diff_hi, diff_hi);
}
31 
// Adds one vector's worth of squared differences to a pair of accumulators.
//
// Eight 16-bit samples are loaded from each of `src` and `ref`. The squared
// absolute differences of lanes 0-3 are added to `*sse_acc0` and those of
// lanes 4-7 to `*sse_acc1`. Both accumulators must already hold valid
// values. The running sums stay in 32-bit lanes; nothing here widens them.
static inline void highbd_sse_8x1_neon(const uint16_t *src, const uint16_t *ref,
                                       uint32x4_t *sse_acc0,
                                       uint32x4_t *sse_acc1) {
  const uint16x8_t diff = vabdq_u16(vld1q_u16(src), vld1q_u16(ref));
  const uint16x4_t diff_lo = vget_low_u16(diff);
  const uint16x4_t diff_hi = vget_high_u16(diff);

  // Widening multiply-accumulate: acc += diff * diff, per 32-bit lane.
  *sse_acc0 = vmlal_u16(*sse_acc0, diff_lo, diff_lo);
  *sse_acc1 = vmlal_u16(*sse_acc1, diff_hi, diff_hi);
}
45 
// Sum of squared differences over a block 128 samples wide and `height` rows
// tall.
//
// Each row is consumed as sixteen vectors of eight samples. Sixteen 32-bit
// accumulator vectors are kept: vectors v and v + 8 of a row share the
// accumulator pair { 2 * (v % 8), 2 * (v % 8) + 1 }. Strides are in 16-bit
// samples. The first row is processed unconditionally, so `height` must be
// at least 1.
//
// NOTE(review): every 32-bit lane receives two squared differences per row
// and is only widened in the final reduction - confirm callers bound the
// sample range and `height` so the lanes cannot wrap.
static inline int64_t highbd_sse_128xh_neon(const uint16_t *src, int src_stride,
                                            const uint16_t *ref, int ref_stride,
                                            int height) {
  uint32x4_t acc[16];

  // Row 0, vectors 0-7: these seed all sixteen accumulators.
  for (int v = 0; v < 8; v++) {
    highbd_sse_8x1_init_neon(src + v * 8, ref + v * 8, &acc[2 * v],
                             &acc[2 * v + 1]);
  }
  // Row 0, vectors 8-15: folded onto the accumulators seeded above.
  for (int v = 8; v < 16; v++) {
    const int slot = 2 * (v - 8);
    highbd_sse_8x1_neon(src + v * 8, ref + v * 8, &acc[slot], &acc[slot + 1]);
  }

  // Remaining height - 1 rows (row 0 was peeled above).
  for (int rows_left = height - 1; rows_left != 0; rows_left--) {
    src += src_stride;
    ref += ref_stride;

    for (int v = 0; v < 16; v++) {
      const int slot = 2 * (v & 7);
      highbd_sse_8x1_neon(src + v * 8, ref + v * 8, &acc[slot],
                          &acc[slot + 1]);
    }
  }

  return horizontal_long_add_u32x4_x16(acc);
}
94 
// Sum of squared differences over a block 64 samples wide and `height` rows
// tall.
//
// Each row is consumed as eight vectors of eight samples. Eight 32-bit
// accumulator vectors are kept: vectors v and v + 4 of a row share the
// accumulator pair { 2 * (v % 4), 2 * (v % 4) + 1 }. Strides are in 16-bit
// samples. The first row is processed unconditionally, so `height` must be
// at least 1.
//
// NOTE(review): every 32-bit lane receives two squared differences per row
// and is only widened in the final reduction - confirm callers bound the
// sample range and `height` so the lanes cannot wrap.
static inline int64_t highbd_sse_64xh_neon(const uint16_t *src, int src_stride,
                                           const uint16_t *ref, int ref_stride,
                                           int height) {
  uint32x4_t acc[8];

  // Row 0, vectors 0-3: these seed all eight accumulators.
  for (int v = 0; v < 4; v++) {
    highbd_sse_8x1_init_neon(src + v * 8, ref + v * 8, &acc[2 * v],
                             &acc[2 * v + 1]);
  }
  // Row 0, vectors 4-7: folded onto the accumulators seeded above.
  for (int v = 4; v < 8; v++) {
    const int slot = 2 * (v - 4);
    highbd_sse_8x1_neon(src + v * 8, ref + v * 8, &acc[slot], &acc[slot + 1]);
  }

  // Remaining height - 1 rows (row 0 was peeled above).
  for (int rows_left = height - 1; rows_left != 0; rows_left--) {
    src += src_stride;
    ref += ref_stride;

    for (int v = 0; v < 8; v++) {
      const int slot = 2 * (v & 3);
      highbd_sse_8x1_neon(src + v * 8, ref + v * 8, &acc[slot],
                          &acc[slot + 1]);
    }
  }

  return horizontal_long_add_u32x4_x8(acc);
}
127 
// Sum of squared differences over a block 32 samples wide and `height` rows
// tall.
//
// Each row is consumed as four vectors of eight samples, and vector v owns
// the accumulator pair { 2 * v, 2 * v + 1 }. Strides are in 16-bit samples.
// The first row is processed unconditionally, so `height` must be at least 1.
//
// NOTE(review): the 32-bit lanes are only widened in the final reduction -
// confirm callers bound the sample range and `height` so they cannot wrap.
static inline int64_t highbd_sse_32xh_neon(const uint16_t *src, int src_stride,
                                           const uint16_t *ref, int ref_stride,
                                           int height) {
  uint32x4_t acc[8];

  // Row 0 seeds the accumulators.
  for (int v = 0; v < 4; v++) {
    highbd_sse_8x1_init_neon(src + v * 8, ref + v * 8, &acc[2 * v],
                             &acc[2 * v + 1]);
  }

  // Remaining height - 1 rows.
  for (int rows_left = height - 1; rows_left != 0; rows_left--) {
    src += src_stride;
    ref += ref_stride;

    for (int v = 0; v < 4; v++) {
      highbd_sse_8x1_neon(src + v * 8, ref + v * 8, &acc[2 * v],
                          &acc[2 * v + 1]);
    }
  }

  return horizontal_long_add_u32x4_x8(acc);
}
152 
// Sum of squared differences over a block 16 samples wide and `height` rows
// tall.
//
// Each row is two vectors of eight samples: the left one feeds acc[0..1] and
// the right one acc[2..3]. Strides are in 16-bit samples. The first row is
// processed unconditionally, so `height` must be at least 1.
//
// NOTE(review): the 32-bit lanes are only widened in the final reduction -
// confirm callers bound the sample range and `height` so they cannot wrap.
static inline int64_t highbd_sse_16xh_neon(const uint16_t *src, int src_stride,
                                           const uint16_t *ref, int ref_stride,
                                           int height) {
  uint32x4_t acc[4];

  // Row 0 seeds the accumulators.
  highbd_sse_8x1_init_neon(src, ref, &acc[0], &acc[1]);
  highbd_sse_8x1_init_neon(src + 8, ref + 8, &acc[2], &acc[3]);

  // Remaining height - 1 rows.
  for (int rows_left = height - 1; rows_left != 0; rows_left--) {
    src += src_stride;
    ref += ref_stride;

    highbd_sse_8x1_neon(src, ref, &acc[0], &acc[1]);
    highbd_sse_8x1_neon(src + 8, ref + 8, &acc[2], &acc[3]);
  }

  return horizontal_long_add_u32x4_x4(acc);
}
173 
// Sum of squared differences over a block 8 samples wide and `height` rows
// tall.
//
// Each row is a single vector of eight samples whose low and high halves
// accumulate into acc[0] and acc[1]. Strides are in 16-bit samples. The
// first row is processed unconditionally, so `height` must be at least 1.
//
// NOTE(review): the 32-bit lanes are only widened in the final reduction -
// confirm callers bound the sample range and `height` so they cannot wrap.
static inline int64_t highbd_sse_8xh_neon(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int height) {
  uint32x4_t acc[2];

  // Row 0 seeds the accumulators.
  highbd_sse_8x1_init_neon(src, ref, &acc[0], &acc[1]);

  // Remaining height - 1 rows.
  for (int rows_left = height - 1; rows_left != 0; rows_left--) {
    src += src_stride;
    ref += ref_stride;

    highbd_sse_8x1_neon(src, ref, &acc[0], &acc[1]);
  }

  return horizontal_long_add_u32x4_x2(acc);
}
192 
// Sum of squared differences over a block 4 samples wide and `height` rows
// tall.
//
// Each row is one 64-bit vector of four samples; the squared differences are
// accumulated in a single vector of four 32-bit lanes. Strides are in 16-bit
// samples. The first row is processed unconditionally, so `height` must be
// at least 1.
//
// NOTE(review): the 32-bit lanes are only widened in the final reduction -
// confirm callers bound the sample range and `height` so they cannot wrap.
static inline int64_t highbd_sse_4xh_neon(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int height) {
  // Row 0 is handled with a plain widening multiply so that the accumulator
  // never needs a separate zero initialisation.
  uint16x4_t diff = vabd_u16(vld1_u16(src), vld1_u16(ref));
  uint32x4_t acc = vmull_u16(diff, diff);

  // Remaining height - 1 rows.
  for (int rows_left = height - 1; rows_left != 0; rows_left--) {
    src += src_stride;
    ref += ref_stride;

    diff = vabd_u16(vld1_u16(src), vld1_u16(ref));
    acc = vmlal_u16(acc, diff, diff);
  }

  return horizontal_long_add_u32x4(acc);
}
219 
// Sum of squared differences over a block of arbitrary `width` and `height`.
//
// Rows are walked eight samples at a time. When fewer than eight columns
// remain, the lanes at index >= (width & 7) are forced to zero so that they
// contribute nothing to the sum. The vector loads themselves are still full
// eight-lane loads, so up to seven samples past the end of the row are read
// from both `src` and `ref`.
//
// Every group of eight squared differences is reduced to 64 bits straight
// away and added to a scalar total. Strides are in 16-bit samples. Both
// loops are do-while loops, so at least one vector of one row is always
// processed; `height` must be at least 1.
static inline int64_t highbd_sse_wxh_neon(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int width, int height) {
  // Lane indices { 0, 1, 2, 3, 4, 5, 6, 7 }, widened from bytes.
  const uint16x8_t lane_index = vmovl_u8(vcreate_u8(0x0706050403020100));
  // All-ones in the lanes that are still inside a partial final vector.
  const uint16x8_t tail_mask = vcltq_u16(lane_index, vdupq_n_u16(width & 7));
  uint64_t total = 0;

  do {
    int col = 0;

    do {
      uint16x8_t diff =
          vabdq_u16(vld1q_u16(src + col), vld1q_u16(ref + col));

      if (width - col < 8) {
        // Discard out-of-range lanes. Each mask lane is all-ones or zero, so
        // masking the difference equals masking both inputs beforehand.
        diff = vandq_u16(diff, tail_mask);
      }

      const uint16x4_t diff_lo = vget_low_u16(diff);
      const uint16x4_t diff_hi = vget_high_u16(diff);
      const uint32x4_t sq = vmlal_u16(vmull_u16(diff_lo, diff_lo), diff_hi,
                                      diff_hi);

      total += horizontal_long_add_u32x4(sq);

      col += 8;
    } while (col < width);

    src += src_stride;
    ref += ref_stride;
  } while (--height != 0);

  return total;
}
261 
// Sum of squared differences between two high bit-depth blocks.
//
// `src8` and `ref8` are converted with CONVERT_TO_SHORTPTR() and from then on
// read as 16-bit samples. `src_stride` and `ref_stride` are the distances
// between consecutive rows, counted in those 16-bit samples. Widths of 4, 8,
// 16, 32, 64 and 128 are routed to dedicated fixed-width kernels; every
// other width goes through the generic kernel, which masks the partial
// vector at the end of each row.
//
// Returns the total as a 64-bit value. An empty block (`width` <= 0 or
// `height` <= 0) yields 0.
int64_t aom_highbd_sse_neon(const uint8_t *src8, int src_stride,
                            const uint8_t *ref8, int ref_stride, int width,
                            int height) {
  // Every kernel below handles one row before it first tests its row
  // counter, and the generic kernel likewise loads one vector before testing
  // its column counter. A non-positive dimension would therefore read
  // memory outside the block (and, for height, keep decrementing a counter
  // that started at or below zero), so empty blocks are answered here.
  if (width <= 0 || height <= 0) {
    return 0;
  }

  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);

  switch (width) {
    case 4:
      return highbd_sse_4xh_neon(src, src_stride, ref, ref_stride, height);
    case 8:
      return highbd_sse_8xh_neon(src, src_stride, ref, ref_stride, height);
    case 16:
      return highbd_sse_16xh_neon(src, src_stride, ref, ref_stride, height);
    case 32:
      return highbd_sse_32xh_neon(src, src_stride, ref, ref_stride, height);
    case 64:
      return highbd_sse_64xh_neon(src, src_stride, ref, ref_stride, height);
    case 128:
      return highbd_sse_128xh_neon(src, src_stride, ref, ref_stride, height);
    default:
      return highbd_sse_wxh_neon(src, src_stride, ref, ref_stride, width,
                                 height);
  }
}
286