/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/aom_filter.h"
#include "aom_dsp/arm/aom_neon_sve_bridge.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/variance.h"

// Process a block of width 4 two rows at a time.
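// Two rows of four pixels fill a full 128-bit vector. Accumulating the sum in
// 16-bit lanes is safe for the 4-wide block sizes used here (height <= 16):
// each lane accumulates at most 8 differences of magnitude below 4096.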
static inline void highbd_variance_4xh_sve(const uint16_t *src_ptr,
                                           int src_stride,
                                           const uint16_t *ref_ptr,
                                           int ref_stride, int h, uint64_t *sse,
                                           int64_t *sum) {
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int64x2_t sse_s64 = vdupq_n_s64(0);

  do {
    const uint16x8_t s = load_unaligned_u16_4x2(src_ptr, src_stride);
    const uint16x8_t r = load_unaligned_u16_4x2(ref_ptr, ref_stride);

    int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r));
    sum_s16 = vaddq_s16(sum_s16, diff);

    sse_s64 = aom_sdotq_s16(sse_s64, diff, diff);

    src_ptr += 2 * src_stride;
    ref_ptr += 2 * ref_stride;
    h -= 2;
  } while (h != 0);

  *sum = vaddlvq_s16(sum_s16);
  *sse = vaddvq_s64(sse_s64);
}

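// Accumulate the sum and the sum of squared differences for one row of eight
// pixels. Scalar equivalent of the vector code below:
//   for (int i = 0; i < 8; i++) {
//     const int diff = src[i] - ref[i];
//     *sum += diff;
//     *sse += diff * diff;
//   }
// The squared differences are accumulated with the SVE signed dot product
// (via aom_sdotq_s16), which widens directly into 64-bit lanes.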
static inline void variance_8x1_sve(const uint16_t *src, const uint16_t *ref,
                                    int32x4_t *sum, int64x2_t *sse) {
  const uint16x8_t s = vld1q_u16(src);
  const uint16x8_t r = vld1q_u16(ref);

  const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r));
  *sum = vpadalq_s16(*sum, diff);

  *sse = aom_sdotq_s16(*sse, diff, diff);
}

static inline void highbd_variance_8xh_sve(const uint16_t *src_ptr,
                                           int src_stride,
                                           const uint16_t *ref_ptr,
                                           int ref_stride, int h, uint64_t *sse,
                                           int64_t *sum) {
  int32x4_t sum_s32 = vdupq_n_s32(0);
  int64x2_t sse_s64 = vdupq_n_s64(0);

  do {
    variance_8x1_sve(src_ptr, ref_ptr, &sum_s32, &sse_s64);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--h != 0);

  *sum = vaddlvq_s32(sum_s32);
  *sse = vaddvq_s64(sse_s64);
}

static inline void highbd_variance_16xh_sve(const uint16_t *src_ptr,
                                            int src_stride,
                                            const uint16_t *ref_ptr,
                                            int ref_stride, int h,
                                            uint64_t *sse, int64_t *sum) {
  int32x4_t sum_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
  int64x2_t sse_s64[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };

  do {
    variance_8x1_sve(src_ptr, ref_ptr, &sum_s32[0], &sse_s64[0]);
    variance_8x1_sve(src_ptr + 8, ref_ptr + 8, &sum_s32[1], &sse_s64[1]);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--h != 0);

  *sum = vaddlvq_s32(vaddq_s32(sum_s32[0], sum_s32[1]));
  *sse = vaddvq_s64(vaddq_s64(sse_s64[0], sse_s64[1]));
}

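// Process large blocks (width 32, 64 or 128) in chunks of 32 pixels, keeping
// four independent sum/SSE accumulator pairs that are only reduced once at
// the end. This keeps the accumulation dependency chains short.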
static inline void highbd_variance_large_sve(const uint16_t *src_ptr,
                                             int src_stride,
                                             const uint16_t *ref_ptr,
                                             int ref_stride, int w, int h,
                                             uint64_t *sse, int64_t *sum) {
  int32x4_t sum_s32[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0),
                           vdupq_n_s32(0) };
  int64x2_t sse_s64[4] = { vdupq_n_s64(0), vdupq_n_s64(0), vdupq_n_s64(0),
                           vdupq_n_s64(0) };

  do {
    int j = 0;
    do {
      variance_8x1_sve(src_ptr + j, ref_ptr + j, &sum_s32[0], &sse_s64[0]);
      variance_8x1_sve(src_ptr + j + 8, ref_ptr + j + 8, &sum_s32[1],
                       &sse_s64[1]);
      variance_8x1_sve(src_ptr + j + 16, ref_ptr + j + 16, &sum_s32[2],
                       &sse_s64[2]);
      variance_8x1_sve(src_ptr + j + 24, ref_ptr + j + 24, &sum_s32[3],
                       &sse_s64[3]);

      j += 32;
    } while (j < w);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--h != 0);

  sum_s32[0] = vaddq_s32(sum_s32[0], sum_s32[1]);
  sum_s32[2] = vaddq_s32(sum_s32[2], sum_s32[3]);
  *sum = vaddlvq_s32(vaddq_s32(sum_s32[0], sum_s32[2]));
  sse_s64[0] = vaddq_s64(sse_s64[0], sse_s64[1]);
  sse_s64[2] = vaddq_s64(sse_s64[2], sse_s64[3]);
  *sse = vaddvq_s64(vaddq_s64(sse_s64[0], sse_s64[2]));
}

static inline void highbd_variance_32xh_sve(const uint16_t *src, int src_stride,
                                            const uint16_t *ref, int ref_stride,
                                            int h, uint64_t *sse,
                                            int64_t *sum) {
  highbd_variance_large_sve(src, src_stride, ref, ref_stride, 32, h, sse, sum);
}

static inline void highbd_variance_64xh_sve(const uint16_t *src, int src_stride,
                                            const uint16_t *ref, int ref_stride,
                                            int h, uint64_t *sse,
                                            int64_t *sum) {
  highbd_variance_large_sve(src, src_stride, ref, ref_stride, 64, h, sse, sum);
}

static inline void highbd_variance_128xh_sve(const uint16_t *src,
                                             int src_stride,
                                             const uint16_t *ref,
                                             int ref_stride, int h,
                                             uint64_t *sse, int64_t *sum) {
  highbd_variance_large_sve(src, src_stride, ref, ref_stride, 128, h, sse, sum);
}

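// Block variance is derived from the accumulated values as
// SSE - sum^2 / (w * h). For 8-bit sources the accumulators are used at full
// precision.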
#define HBD_VARIANCE_WXH_8_SVE(w, h)                                  \
  uint32_t aom_highbd_8_variance##w##x##h##_sve(                      \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    int sum;                                                          \
    uint64_t sse_long = 0;                                            \
    int64_t sum_long = 0;                                             \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,  \
                                &sse_long, &sum_long);                \
    *sse = (uint32_t)sse_long;                                        \
    sum = (int)sum_long;                                              \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (w * h));         \
  }

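// For 10-bit and 12-bit sources, SSE and sum are first rounded down to 8-bit
// precision (by 4 and 2 bits, or 8 and 4 bits, respectively), as in the C
// reference implementation. The rounding can make the difference marginally
// negative, so the result is clamped at zero.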
#define HBD_VARIANCE_WXH_10_SVE(w, h)                                 \
  uint32_t aom_highbd_10_variance##w##x##h##_sve(                     \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    int sum;                                                          \
    int64_t var;                                                      \
    uint64_t sse_long = 0;                                            \
    int64_t sum_long = 0;                                             \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,  \
                                &sse_long, &sum_long);                \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 4);                 \
    sum = (int)ROUND_POWER_OF_TWO(sum_long, 2);                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h));         \
    return (var >= 0) ? (uint32_t)var : 0;                            \
  }

#define HBD_VARIANCE_WXH_12_SVE(w, h)                                 \
  uint32_t aom_highbd_12_variance##w##x##h##_sve(                     \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    int sum;                                                          \
    int64_t var;                                                      \
    uint64_t sse_long = 0;                                            \
    int64_t sum_long = 0;                                             \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,  \
                                &sse_long, &sum_long);                \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 8);                 \
    sum = (int)ROUND_POWER_OF_TWO(sum_long, 4);                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h));         \
    return (var >= 0) ? (uint32_t)var : 0;                            \
  }

// 8-bit
HBD_VARIANCE_WXH_8_SVE(4, 4)
HBD_VARIANCE_WXH_8_SVE(4, 8)

HBD_VARIANCE_WXH_8_SVE(8, 4)
HBD_VARIANCE_WXH_8_SVE(8, 8)
HBD_VARIANCE_WXH_8_SVE(8, 16)

HBD_VARIANCE_WXH_8_SVE(16, 8)
HBD_VARIANCE_WXH_8_SVE(16, 16)
HBD_VARIANCE_WXH_8_SVE(16, 32)

HBD_VARIANCE_WXH_8_SVE(32, 16)
HBD_VARIANCE_WXH_8_SVE(32, 32)
HBD_VARIANCE_WXH_8_SVE(32, 64)

HBD_VARIANCE_WXH_8_SVE(64, 32)
HBD_VARIANCE_WXH_8_SVE(64, 64)
HBD_VARIANCE_WXH_8_SVE(64, 128)

HBD_VARIANCE_WXH_8_SVE(128, 64)
HBD_VARIANCE_WXH_8_SVE(128, 128)

// 10-bit
HBD_VARIANCE_WXH_10_SVE(4, 4)
HBD_VARIANCE_WXH_10_SVE(4, 8)

HBD_VARIANCE_WXH_10_SVE(8, 4)
HBD_VARIANCE_WXH_10_SVE(8, 8)
HBD_VARIANCE_WXH_10_SVE(8, 16)

HBD_VARIANCE_WXH_10_SVE(16, 8)
HBD_VARIANCE_WXH_10_SVE(16, 16)
HBD_VARIANCE_WXH_10_SVE(16, 32)

HBD_VARIANCE_WXH_10_SVE(32, 16)
HBD_VARIANCE_WXH_10_SVE(32, 32)
HBD_VARIANCE_WXH_10_SVE(32, 64)

HBD_VARIANCE_WXH_10_SVE(64, 32)
HBD_VARIANCE_WXH_10_SVE(64, 64)
HBD_VARIANCE_WXH_10_SVE(64, 128)

HBD_VARIANCE_WXH_10_SVE(128, 64)
HBD_VARIANCE_WXH_10_SVE(128, 128)

// 12-bit
HBD_VARIANCE_WXH_12_SVE(4, 4)
HBD_VARIANCE_WXH_12_SVE(4, 8)

HBD_VARIANCE_WXH_12_SVE(8, 4)
HBD_VARIANCE_WXH_12_SVE(8, 8)
HBD_VARIANCE_WXH_12_SVE(8, 16)

HBD_VARIANCE_WXH_12_SVE(16, 8)
HBD_VARIANCE_WXH_12_SVE(16, 16)
HBD_VARIANCE_WXH_12_SVE(16, 32)

HBD_VARIANCE_WXH_12_SVE(32, 16)
HBD_VARIANCE_WXH_12_SVE(32, 32)
HBD_VARIANCE_WXH_12_SVE(32, 64)

HBD_VARIANCE_WXH_12_SVE(64, 32)
HBD_VARIANCE_WXH_12_SVE(64, 64)
HBD_VARIANCE_WXH_12_SVE(64, 128)

HBD_VARIANCE_WXH_12_SVE(128, 64)
HBD_VARIANCE_WXH_12_SVE(128, 128)

#if !CONFIG_REALTIME_ONLY
// 8-bit
HBD_VARIANCE_WXH_8_SVE(4, 16)

HBD_VARIANCE_WXH_8_SVE(8, 32)

HBD_VARIANCE_WXH_8_SVE(16, 4)
HBD_VARIANCE_WXH_8_SVE(16, 64)

HBD_VARIANCE_WXH_8_SVE(32, 8)

HBD_VARIANCE_WXH_8_SVE(64, 16)

// 10-bit
HBD_VARIANCE_WXH_10_SVE(4, 16)

HBD_VARIANCE_WXH_10_SVE(8, 32)

HBD_VARIANCE_WXH_10_SVE(16, 4)
HBD_VARIANCE_WXH_10_SVE(16, 64)

HBD_VARIANCE_WXH_10_SVE(32, 8)

HBD_VARIANCE_WXH_10_SVE(64, 16)

// 12-bit
HBD_VARIANCE_WXH_12_SVE(4, 16)

HBD_VARIANCE_WXH_12_SVE(8, 32)

HBD_VARIANCE_WXH_12_SVE(16, 4)
HBD_VARIANCE_WXH_12_SVE(16, 64)

HBD_VARIANCE_WXH_12_SVE(32, 8)

HBD_VARIANCE_WXH_12_SVE(64, 16)

#endif  // !CONFIG_REALTIME_ONLY

#undef HBD_VARIANCE_WXH_8_SVE
#undef HBD_VARIANCE_WXH_10_SVE
#undef HBD_VARIANCE_WXH_12_SVE

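// Sum of squared differences over a w x h block of 16-bit pixels, eight
// pixels at a time, accumulated with the unsigned SVE dot product
// (aom_udotq_u16). Bit-depth scaling is applied by the wrappers below.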
static inline uint32_t highbd_mse_wxh_sve(const uint16_t *src_ptr,
                                          int src_stride,
                                          const uint16_t *ref_ptr,
                                          int ref_stride, int w, int h,
                                          unsigned int *sse) {
  uint64x2_t sse_u64 = vdupq_n_u64(0);

  do {
    int j = 0;
    do {
      uint16x8_t s = vld1q_u16(src_ptr + j);
      uint16x8_t r = vld1q_u16(ref_ptr + j);

      uint16x8_t diff = vabdq_u16(s, r);

      sse_u64 = aom_udotq_u16(sse_u64, diff, diff);

      j += 8;
    } while (j < w);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--h != 0);

  *sse = (uint32_t)vaddvq_u64(sse_u64);
  return *sse;
}

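// The MSE wrappers scale the raw SSE down to 8-bit precision: by 4 bits for
// 10-bit input and by 8 bits for 12-bit input.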
#define HIGHBD_MSE_WXH_SVE(w, h)                                      \
  uint32_t aom_highbd_10_mse##w##x##h##_sve(                          \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_mse_wxh_sve(src, src_stride, ref, ref_stride, w, h, sse);  \
    *sse = ROUND_POWER_OF_TWO(*sse, 4);                               \
    return *sse;                                                      \
  }                                                                   \
                                                                      \
  uint32_t aom_highbd_12_mse##w##x##h##_sve(                          \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_mse_wxh_sve(src, src_stride, ref, ref_stride, w, h, sse);  \
    *sse = ROUND_POWER_OF_TWO(*sse, 8);                               \
    return *sse;                                                      \
  }

HIGHBD_MSE_WXH_SVE(16, 16)
HIGHBD_MSE_WXH_SVE(16, 8)
HIGHBD_MSE_WXH_SVE(8, 16)
HIGHBD_MSE_WXH_SVE(8, 8)

#undef HIGHBD_MSE_WXH_SVE

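// Unnormalized SSE over 4- or 8-wide blocks of 16-bit data, processing two
// rows (width 8) or four rows (width 4) per iteration.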
uint64_t aom_mse_wxh_16bit_highbd_sve(uint16_t *dst, int dstride, uint16_t *src,
                                      int sstride, int w, int h) {
  assert((w == 8 || w == 4) && (h == 8 || h == 4));

  uint64x2_t sum = vdupq_n_u64(0);

  if (w == 8) {
    do {
      uint16x8_t d0 = vld1q_u16(dst + 0 * dstride);
      uint16x8_t d1 = vld1q_u16(dst + 1 * dstride);
      uint16x8_t s0 = vld1q_u16(src + 0 * sstride);
      uint16x8_t s1 = vld1q_u16(src + 1 * sstride);

      uint16x8_t abs_diff0 = vabdq_u16(s0, d0);
      uint16x8_t abs_diff1 = vabdq_u16(s1, d1);

      sum = aom_udotq_u16(sum, abs_diff0, abs_diff0);
      sum = aom_udotq_u16(sum, abs_diff1, abs_diff1);

      dst += 2 * dstride;
      src += 2 * sstride;
      h -= 2;
    } while (h != 0);
  } else {  // w == 4
    do {
      uint16x8_t d0 = load_unaligned_u16_4x2(dst + 0 * dstride, dstride);
      uint16x8_t d1 = load_unaligned_u16_4x2(dst + 2 * dstride, dstride);
      uint16x8_t s0 = load_unaligned_u16_4x2(src + 0 * sstride, sstride);
      uint16x8_t s1 = load_unaligned_u16_4x2(src + 2 * sstride, sstride);

      uint16x8_t abs_diff0 = vabdq_u16(s0, d0);
      uint16x8_t abs_diff1 = vabdq_u16(s1, d1);

      sum = aom_udotq_u16(sum, abs_diff0, abs_diff0);
      sum = aom_udotq_u16(sum, abs_diff1, abs_diff1);

      dst += 4 * dstride;
      src += 4 * sstride;
      h -= 4;
    } while (h != 0);
  }

  return vaddvq_u64(sum);
}