1 /*
2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <arm_neon.h>
13
14 #include "config/aom_dsp_rtcd.h"
15 #include "aom_dsp/arm/sum_neon.h"
16
// Starts a pair of SSE accumulators from 8 consecutive 16-bit samples.
//
// Loads 8 samples from each of src and ref, takes their lane-wise absolute
// difference, and overwrites *sse_acc0 / *sse_acc1 with the widened 32-bit
// squares of the low four and the high four differences respectively. Whatever
// the accumulators held before is discarded, so they need not be initialised
// by the caller.
static inline void highbd_sse_8x1_init_neon(const uint16_t *src,
                                            const uint16_t *ref,
                                            uint32x4_t *sse_acc0,
                                            uint32x4_t *sse_acc1) {
  const uint16x8_t diff = vabdq_u16(vld1q_u16(src), vld1q_u16(ref));
  const uint16x4_t diff_lo = vget_low_u16(diff);
  const uint16x4_t diff_hi = vget_high_u16(diff);

  // Widening multiply: 16-bit x 16-bit -> 32-bit per lane.
  *sse_acc0 = vmull_u16(diff_lo, diff_lo);
  *sse_acc1 = vmull_u16(diff_hi, diff_hi);
}
31
// Adds the squared differences of 8 consecutive 16-bit samples to a pair of
// running SSE accumulators.
//
// Loads 8 samples from each of src and ref, takes their lane-wise absolute
// difference, then accumulates the widened 32-bit squares of the low four
// differences into *sse_acc0 and of the high four into *sse_acc1. Both
// accumulators must already hold valid values on entry.
static inline void highbd_sse_8x1_neon(const uint16_t *src, const uint16_t *ref,
                                       uint32x4_t *sse_acc0,
                                       uint32x4_t *sse_acc1) {
  const uint16x8_t diff = vabdq_u16(vld1q_u16(src), vld1q_u16(ref));
  const uint16x4_t diff_lo = vget_low_u16(diff);
  const uint16x4_t diff_hi = vget_high_u16(diff);

  // Widening multiply-accumulate: acc += (uint32_t)d * d per lane.
  *sse_acc0 = vmlal_u16(*sse_acc0, diff_lo, diff_lo);
  *sse_acc1 = vmlal_u16(*sse_acc1, diff_hi, diff_hi);
}
45
// Accumulates the squared differences of 64 consecutive samples (eight 8-lane
// vectors) into the 16 accumulators in acc. Vector i feeds acc[2 * i] with its
// low four lanes and acc[2 * i + 1] with its high four lanes.
static inline void highbd_sse_128xh_accum_64x1_neon(const uint16_t *src,
                                                    const uint16_t *ref,
                                                    uint32x4_t *acc) {
  highbd_sse_8x1_neon(src + 0, ref + 0, &acc[0], &acc[1]);
  highbd_sse_8x1_neon(src + 8, ref + 8, &acc[2], &acc[3]);
  highbd_sse_8x1_neon(src + 16, ref + 16, &acc[4], &acc[5]);
  highbd_sse_8x1_neon(src + 24, ref + 24, &acc[6], &acc[7]);
  highbd_sse_8x1_neon(src + 32, ref + 32, &acc[8], &acc[9]);
  highbd_sse_8x1_neon(src + 40, ref + 40, &acc[10], &acc[11]);
  highbd_sse_8x1_neon(src + 48, ref + 48, &acc[12], &acc[13]);
  highbd_sse_8x1_neon(src + 56, ref + 56, &acc[14], &acc[15]);
}

// Sum of squared differences between two 128-sample-wide blocks of 16-bit
// samples.
//
// src_stride and ref_stride are the distances, in samples, between the starts
// of consecutive rows. height must be at least 1: the first row is always
// read, and the row loop counts height - 1 down to exactly zero.
//
// Each row is handled as two 64-sample halves that share the same 16
// accumulators, so every 32-bit accumulator lane receives two squared
// differences per row.
// NOTE(review): presumably the sample range and height are bounded by callers
// so that the 32-bit lanes cannot wrap - confirm.
static inline int64_t highbd_sse_128xh_neon(const uint16_t *src, int src_stride,
                                            const uint16_t *ref, int ref_stride,
                                            int height) {
  uint32x4_t acc[16];

  // First row, left half: initialise every accumulator instead of zeroing it.
  highbd_sse_8x1_init_neon(src + 0, ref + 0, &acc[0], &acc[1]);
  highbd_sse_8x1_init_neon(src + 8, ref + 8, &acc[2], &acc[3]);
  highbd_sse_8x1_init_neon(src + 16, ref + 16, &acc[4], &acc[5]);
  highbd_sse_8x1_init_neon(src + 24, ref + 24, &acc[6], &acc[7]);
  highbd_sse_8x1_init_neon(src + 32, ref + 32, &acc[8], &acc[9]);
  highbd_sse_8x1_init_neon(src + 40, ref + 40, &acc[10], &acc[11]);
  highbd_sse_8x1_init_neon(src + 48, ref + 48, &acc[12], &acc[13]);
  highbd_sse_8x1_init_neon(src + 56, ref + 56, &acc[14], &acc[15]);
  // First row, right half.
  highbd_sse_128xh_accum_64x1_neon(src + 64, ref + 64, acc);

  src += src_stride;
  ref += ref_stride;

  // Remaining rows.
  for (int rows_left = height - 1; rows_left != 0; --rows_left) {
    highbd_sse_128xh_accum_64x1_neon(src, ref, acc);
    highbd_sse_128xh_accum_64x1_neon(src + 64, ref + 64, acc);

    src += src_stride;
    ref += ref_stride;
  }

  // Presumably reduces all lanes of the 16 vectors to one 64-bit total; see
  // the helper's definition in sum_neon.h.
  return horizontal_long_add_u32x4_x16(acc);
}
94
// Accumulates the squared differences of 32 consecutive samples (four 8-lane
// vectors) into the 8 accumulators in acc. Vector i feeds acc[2 * i] with its
// low four lanes and acc[2 * i + 1] with its high four lanes.
static inline void highbd_sse_64xh_accum_32x1_neon(const uint16_t *src,
                                                   const uint16_t *ref,
                                                   uint32x4_t *acc) {
  highbd_sse_8x1_neon(src + 0, ref + 0, &acc[0], &acc[1]);
  highbd_sse_8x1_neon(src + 8, ref + 8, &acc[2], &acc[3]);
  highbd_sse_8x1_neon(src + 16, ref + 16, &acc[4], &acc[5]);
  highbd_sse_8x1_neon(src + 24, ref + 24, &acc[6], &acc[7]);
}

// Sum of squared differences between two 64-sample-wide blocks of 16-bit
// samples.
//
// src_stride and ref_stride are the distances, in samples, between the starts
// of consecutive rows. height must be at least 1: the first row is always
// read, and the row loop counts height - 1 down to exactly zero.
//
// Each row is handled as two 32-sample halves that share the same 8
// accumulators, so every 32-bit accumulator lane receives two squared
// differences per row.
// NOTE(review): presumably the sample range and height are bounded by callers
// so that the 32-bit lanes cannot wrap - confirm.
static inline int64_t highbd_sse_64xh_neon(const uint16_t *src, int src_stride,
                                           const uint16_t *ref, int ref_stride,
                                           int height) {
  uint32x4_t acc[8];

  // First row, left half: initialise every accumulator instead of zeroing it.
  highbd_sse_8x1_init_neon(src + 0, ref + 0, &acc[0], &acc[1]);
  highbd_sse_8x1_init_neon(src + 8, ref + 8, &acc[2], &acc[3]);
  highbd_sse_8x1_init_neon(src + 16, ref + 16, &acc[4], &acc[5]);
  highbd_sse_8x1_init_neon(src + 24, ref + 24, &acc[6], &acc[7]);
  // First row, right half.
  highbd_sse_64xh_accum_32x1_neon(src + 32, ref + 32, acc);

  src += src_stride;
  ref += ref_stride;

  // Remaining rows.
  for (int rows_left = height - 1; rows_left != 0; --rows_left) {
    highbd_sse_64xh_accum_32x1_neon(src, ref, acc);
    highbd_sse_64xh_accum_32x1_neon(src + 32, ref + 32, acc);

    src += src_stride;
    ref += ref_stride;
  }

  // Presumably reduces all lanes of the 8 vectors to one 64-bit total; see the
  // helper's definition in sum_neon.h.
  return horizontal_long_add_u32x4_x8(acc);
}
127
// Sum of squared differences between two 32-sample-wide blocks of 16-bit
// samples.
//
// src_stride and ref_stride are the distances, in samples, between the starts
// of consecutive rows. height must be at least 1: the first row is always
// read, and the row loop counts height - 1 down to exactly zero.
//
// Each of the four 8-lane vectors in a row owns its own pair of accumulators,
// so every 32-bit accumulator lane receives one squared difference per row.
static inline int64_t highbd_sse_32xh_neon(const uint16_t *src, int src_stride,
                                           const uint16_t *ref, int ref_stride,
                                           int height) {
  uint32x4_t acc[8];

  // First row initialises the accumulators instead of zeroing them.
  highbd_sse_8x1_init_neon(src, ref, &acc[0], &acc[1]);
  highbd_sse_8x1_init_neon(src + 8, ref + 8, &acc[2], &acc[3]);
  highbd_sse_8x1_init_neon(src + 16, ref + 16, &acc[4], &acc[5]);
  highbd_sse_8x1_init_neon(src + 24, ref + 24, &acc[6], &acc[7]);

  src += src_stride;
  ref += ref_stride;

  // Remaining rows.
  for (int rows_left = height - 1; rows_left != 0; --rows_left) {
    highbd_sse_8x1_neon(src, ref, &acc[0], &acc[1]);
    highbd_sse_8x1_neon(src + 8, ref + 8, &acc[2], &acc[3]);
    highbd_sse_8x1_neon(src + 16, ref + 16, &acc[4], &acc[5]);
    highbd_sse_8x1_neon(src + 24, ref + 24, &acc[6], &acc[7]);

    src += src_stride;
    ref += ref_stride;
  }

  // Presumably reduces all lanes of the 8 vectors to one 64-bit total; see the
  // helper's definition in sum_neon.h.
  return horizontal_long_add_u32x4_x8(acc);
}
152
// Sum of squared differences between two 16-sample-wide blocks of 16-bit
// samples.
//
// src_stride and ref_stride are the distances, in samples, between the starts
// of consecutive rows. height must be at least 1: the first row is always
// read, and the row loop counts height - 1 down to exactly zero.
//
// The left and right 8-lane vectors of a row each own a pair of accumulators.
static inline int64_t highbd_sse_16xh_neon(const uint16_t *src, int src_stride,
                                           const uint16_t *ref, int ref_stride,
                                           int height) {
  uint32x4_t acc[4];

  // First row initialises the accumulators instead of zeroing them.
  highbd_sse_8x1_init_neon(src, ref, &acc[0], &acc[1]);
  highbd_sse_8x1_init_neon(src + 8, ref + 8, &acc[2], &acc[3]);

  src += src_stride;
  ref += ref_stride;

  // Remaining rows.
  for (int rows_left = height - 1; rows_left != 0; --rows_left) {
    highbd_sse_8x1_neon(src, ref, &acc[0], &acc[1]);
    highbd_sse_8x1_neon(src + 8, ref + 8, &acc[2], &acc[3]);

    src += src_stride;
    ref += ref_stride;
  }

  // Presumably reduces all lanes of the 4 vectors to one 64-bit total; see the
  // helper's definition in sum_neon.h.
  return horizontal_long_add_u32x4_x4(acc);
}
173
// Sum of squared differences between two 8-sample-wide blocks of 16-bit
// samples.
//
// src_stride and ref_stride are the distances, in samples, between the starts
// of consecutive rows. height must be at least 1: the first row is always
// read, and the row loop counts height - 1 down to exactly zero.
static inline int64_t highbd_sse_8xh_neon(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int height) {
  // acc[0] collects the low four lanes of every row, acc[1] the high four.
  uint32x4_t acc[2];

  // First row initialises the accumulators instead of zeroing them.
  highbd_sse_8x1_init_neon(src, ref, &acc[0], &acc[1]);
  src += src_stride;
  ref += ref_stride;

  // Remaining rows.
  for (int rows_left = height - 1; rows_left != 0; --rows_left) {
    highbd_sse_8x1_neon(src, ref, &acc[0], &acc[1]);
    src += src_stride;
    ref += ref_stride;
  }

  // Presumably reduces all lanes of the 2 vectors to one 64-bit total; see the
  // helper's definition in sum_neon.h.
  return horizontal_long_add_u32x4_x2(acc);
}
192
// Sum of squared differences between two 4-sample-wide blocks of 16-bit
// samples, using 64-bit (4-lane) loads.
//
// src_stride and ref_stride are the distances, in samples, between the starts
// of consecutive rows. height must be at least 1: the first row is always
// read, and the row loop counts height - 1 down to exactly zero.
static inline int64_t highbd_sse_4xh_neon(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int height) {
  // The first row is peeled out of the loop: it initialises the accumulator
  // with a widening multiply, so the loop body can always multiply-accumulate.
  uint16x4_t diff = vabd_u16(vld1_u16(src), vld1_u16(ref));
  uint32x4_t acc = vmull_u16(diff, diff);

  src += src_stride;
  ref += ref_stride;

  for (int rows_left = height - 1; rows_left != 0; --rows_left) {
    diff = vabd_u16(vld1_u16(src), vld1_u16(ref));
    acc = vmlal_u16(acc, diff, diff);

    src += src_stride;
    ref += ref_stride;
  }

  // Presumably reduces the four 32-bit lanes to one 64-bit total; see the
  // helper's definition in sum_neon.h.
  return horizontal_long_add_u32x4(acc);
}
219
// Squared differences of one pair of 8-lane vectors, reduced to a scalar.
//
// The low and high halves of the absolute difference are squared into 32-bit
// lanes and added together there; the four lanes are then reduced by
// horizontal_long_add_u32x4().
static inline uint64_t highbd_sse_wxh_8x1_neon(uint16x8_t s, uint16x8_t r) {
  const uint16x8_t abs_diff = vabdq_u16(s, r);
  const uint16x4_t abs_diff_lo = vget_low_u16(abs_diff);
  const uint16x4_t abs_diff_hi = vget_high_u16(abs_diff);

  uint32x4_t sse_u32 = vmull_u16(abs_diff_lo, abs_diff_lo);
  sse_u32 = vmlal_u16(sse_u32, abs_diff_hi, abs_diff_hi);

  return horizontal_long_add_u32x4(sse_u32);
}

// Sum of squared differences between two blocks of 16-bit samples of
// arbitrary width.
//
// src_stride and ref_stride are the distances, in samples, between the starts
// of consecutive rows. height must be at least 1: the row loop is a do/while
// that counts height down to exactly zero.
//
// Each row is consumed in whole 8-sample vectors followed, when width is not a
// multiple of 8, by one partial vector. The partial vector is staged through
// zero-padded local buffers so that exactly `width` samples per row are read
// from src and ref; a direct 8-lane load of the tail would read up to 7
// samples past the end of the row (and, on the last row, possibly past the end
// of the buffer).
static inline int64_t highbd_sse_wxh_neon(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int width, int height) {
  // Staging buffers for the partial vector. The tail length is the same on
  // every row, so lanes at or beyond it are never written: they stay zero in
  // both buffers and contribute a zero difference.
  uint16_t src_tail[8] = { 0 };
  uint16_t ref_tail[8] = { 0 };
  uint64_t sse = 0;

  do {
    int w = width;
    int offset = 0;

    // Whole vectors of 8 samples.
    while (w >= 8) {
      sse += highbd_sse_wxh_8x1_neon(vld1q_u16(src + offset),
                                     vld1q_u16(ref + offset));
      offset += 8;
      w -= 8;
    }

    // Remaining 1..7 samples, copied individually to stay inside the row.
    if (w > 0) {
      for (int i = 0; i < w; ++i) {
        src_tail[i] = src[offset + i];
        ref_tail[i] = ref[offset + i];
      }
      sse += highbd_sse_wxh_8x1_neon(vld1q_u16(src_tail), vld1q_u16(ref_tail));
    }

    src += src_stride;
    ref += ref_stride;
  } while (--height != 0);

  return (int64_t)sse;
}
261
// Sum of squared differences between two blocks of 16-bit samples.
//
// src8 and ref8 are passed through CONVERT_TO_SHORTPTR() to obtain the
// uint16_t sample pointers (presumably undoing the caller's disguise of a
// 16-bit buffer as uint8_t * - see the macro's definition). The strides are
// forwarded unchanged to kernels that apply them to uint16_t pointers, i.e.
// they are counted in samples.
//
// Widths 4, 8, 16, 32, 64 and 128 dispatch to a kernel specialised for that
// width; every other width goes to the generic width-by-height kernel.
int64_t aom_highbd_sse_neon(const uint8_t *src8, int src_stride,
                            const uint8_t *ref8, int ref_stride, int width,
                            int height) {
  const uint16_t *const src16 = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *const ref16 = CONVERT_TO_SHORTPTR(ref8);
  int64_t total;

  switch (width) {
    case 4:
      total = highbd_sse_4xh_neon(src16, src_stride, ref16, ref_stride, height);
      break;
    case 8:
      total = highbd_sse_8xh_neon(src16, src_stride, ref16, ref_stride, height);
      break;
    case 16:
      total =
          highbd_sse_16xh_neon(src16, src_stride, ref16, ref_stride, height);
      break;
    case 32:
      total =
          highbd_sse_32xh_neon(src16, src_stride, ref16, ref_stride, height);
      break;
    case 64:
      total =
          highbd_sse_64xh_neon(src16, src_stride, ref16, ref_stride, height);
      break;
    case 128:
      total =
          highbd_sse_128xh_neon(src16, src_stride, ref16, ref_stride, height);
      break;
    default:
      total = highbd_sse_wxh_neon(src16, src_stride, ref16, ref_stride, width,
                                  height);
      break;
  }

  return total;
}
286