/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/aom_filter.h"
#include "aom_dsp/arm/aom_neon_sve_bridge.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/variance.h"

// Process a block of width 4 two rows at a time.
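// Note: the int16x8_t row-sum accumulator cannot overflow here. With 12-bit
// input each per-pixel difference lies in [-4095, 4095], and the tallest
// 4-wide block instantiated below is 4x16, so each lane accumulates at most
// 8 differences: 8 * 4095 = 32760 < INT16_MAX.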
static inline void highbd_variance_4xh_sve(const uint16_t *src_ptr,
                                           int src_stride,
                                           const uint16_t *ref_ptr,
                                           int ref_stride, int h, uint64_t *sse,
                                           int64_t *sum) {
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int64x2_t sse_s64 = vdupq_n_s64(0);

  do {
    const uint16x8_t s = load_unaligned_u16_4x2(src_ptr, src_stride);
    const uint16x8_t r = load_unaligned_u16_4x2(ref_ptr, ref_stride);

    int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r));
    sum_s16 = vaddq_s16(sum_s16, diff);

    sse_s64 = aom_sdotq_s16(sse_s64, diff, diff);

    src_ptr += 2 * src_stride;
    ref_ptr += 2 * ref_stride;
    h -= 2;
  } while (h != 0);

  *sum = vaddlvq_s16(sum_s16);
  *sse = vaddvq_s64(sse_s64);
}

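// aom_sdotq_s16() (from aom_neon_sve_bridge.h) uses the SVE SDOT instruction
// to compute a dot product of int16 lanes, accumulating each group of four
// 16x16-bit products directly into a 64-bit lane, so the squared differences
// need no intermediate 32-bit widening step.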
static inline void variance_8x1_sve(const uint16_t *src, const uint16_t *ref,
                                    int32x4_t *sum, int64x2_t *sse) {
  const uint16x8_t s = vld1q_u16(src);
  const uint16x8_t r = vld1q_u16(ref);

  const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r));
  *sum = vpadalq_s16(*sum, diff);

  *sse = aom_sdotq_s16(*sse, diff, diff);
}

static inline void highbd_variance_8xh_sve(const uint16_t *src_ptr,
                                           int src_stride,
                                           const uint16_t *ref_ptr,
                                           int ref_stride, int h, uint64_t *sse,
                                           int64_t *sum) {
  int32x4_t sum_s32 = vdupq_n_s32(0);
  int64x2_t sse_s64 = vdupq_n_s64(0);

  do {
    variance_8x1_sve(src_ptr, ref_ptr, &sum_s32, &sse_s64);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--h != 0);

  *sum = vaddlvq_s32(sum_s32);
  *sse = vaddvq_s64(sse_s64);
}

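// Wider blocks keep a separate accumulator pair per 8-wide column, which
// lets the per-row dot products form independent dependency chains rather
// than serializing on a single accumulator.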
static inline void highbd_variance_16xh_sve(const uint16_t *src_ptr,
                                            int src_stride,
                                            const uint16_t *ref_ptr,
                                            int ref_stride, int h,
                                            uint64_t *sse, int64_t *sum) {
  int32x4_t sum_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
  int64x2_t sse_s64[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };

  do {
    variance_8x1_sve(src_ptr, ref_ptr, &sum_s32[0], &sse_s64[0]);
    variance_8x1_sve(src_ptr + 8, ref_ptr + 8, &sum_s32[1], &sse_s64[1]);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--h != 0);

  *sum = vaddlvq_s32(vaddq_s32(sum_s32[0], sum_s32[1]));
  *sse = vaddvq_s64(vaddq_s64(sse_s64[0], sse_s64[1]));
}

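// All remaining widths (32, 64 and 128) are multiples of 32, so the large
// variant processes four 8-wide columns per inner-loop iteration, each with
// its own accumulator pair; the wrappers below just fix the width argument.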
static inline void highbd_variance_large_sve(const uint16_t *src_ptr,
                                             int src_stride,
                                             const uint16_t *ref_ptr,
                                             int ref_stride, int w, int h,
                                             uint64_t *sse, int64_t *sum) {
  int32x4_t sum_s32[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0),
                           vdupq_n_s32(0) };
  int64x2_t sse_s64[4] = { vdupq_n_s64(0), vdupq_n_s64(0), vdupq_n_s64(0),
                           vdupq_n_s64(0) };

  do {
    int j = 0;
    do {
      variance_8x1_sve(src_ptr + j, ref_ptr + j, &sum_s32[0], &sse_s64[0]);
      variance_8x1_sve(src_ptr + j + 8, ref_ptr + j + 8, &sum_s32[1],
                       &sse_s64[1]);
      variance_8x1_sve(src_ptr + j + 16, ref_ptr + j + 16, &sum_s32[2],
                       &sse_s64[2]);
      variance_8x1_sve(src_ptr + j + 24, ref_ptr + j + 24, &sum_s32[3],
                       &sse_s64[3]);

      j += 32;
    } while (j < w);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--h != 0);

  sum_s32[0] = vaddq_s32(sum_s32[0], sum_s32[1]);
  sum_s32[2] = vaddq_s32(sum_s32[2], sum_s32[3]);
  *sum = vaddlvq_s32(vaddq_s32(sum_s32[0], sum_s32[2]));
  sse_s64[0] = vaddq_s64(sse_s64[0], sse_s64[1]);
  sse_s64[2] = vaddq_s64(sse_s64[2], sse_s64[3]);
  *sse = vaddvq_s64(vaddq_s64(sse_s64[0], sse_s64[2]));
}

static inline void highbd_variance_32xh_sve(const uint16_t *src, int src_stride,
                                            const uint16_t *ref, int ref_stride,
                                            int h, uint64_t *sse,
                                            int64_t *sum) {
  highbd_variance_large_sve(src, src_stride, ref, ref_stride, 32, h, sse, sum);
}

static inline void highbd_variance_64xh_sve(const uint16_t *src, int src_stride,
                                            const uint16_t *ref, int ref_stride,
                                            int h, uint64_t *sse,
                                            int64_t *sum) {
  highbd_variance_large_sve(src, src_stride, ref, ref_stride, 64, h, sse, sum);
}

static inline void highbd_variance_128xh_sve(const uint16_t *src,
                                             int src_stride,
                                             const uint16_t *ref,
                                             int ref_stride, int h,
                                             uint64_t *sse, int64_t *sum) {
  highbd_variance_large_sve(src, src_stride, ref, ref_stride, 128, h, sse, sum);
}

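// Each wrapper below computes variance = sse - sum^2 / (w * h). For 10-bit
// and 12-bit input the accumulated sse and sum are first rounded down by
// 2 * (bd - 8) and (bd - 8) bits respectively, so all bit depths report on
// the 8-bit scale; the 10/12-bit variants also clamp a (rounding-induced)
// negative variance to zero.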
#define HBD_VARIANCE_WXH_8_SVE(w, h)                                  \
  uint32_t aom_highbd_8_variance##w##x##h##_sve(                      \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    int sum;                                                          \
    uint64_t sse_long = 0;                                            \
    int64_t sum_long = 0;                                             \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,  \
                                &sse_long, &sum_long);                \
    *sse = (uint32_t)sse_long;                                        \
    sum = (int)sum_long;                                              \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (w * h));         \
  }

#define HBD_VARIANCE_WXH_10_SVE(w, h)                                 \
  uint32_t aom_highbd_10_variance##w##x##h##_sve(                     \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    int sum;                                                          \
    int64_t var;                                                      \
    uint64_t sse_long = 0;                                            \
    int64_t sum_long = 0;                                             \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,  \
                                &sse_long, &sum_long);                \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 4);                 \
    sum = (int)ROUND_POWER_OF_TWO(sum_long, 2);                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h));         \
    return (var >= 0) ? (uint32_t)var : 0;                            \
  }

#define HBD_VARIANCE_WXH_12_SVE(w, h)                                 \
  uint32_t aom_highbd_12_variance##w##x##h##_sve(                     \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    int sum;                                                          \
    int64_t var;                                                      \
    uint64_t sse_long = 0;                                            \
    int64_t sum_long = 0;                                             \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,  \
                                &sse_long, &sum_long);                \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 8);                 \
    sum = (int)ROUND_POWER_OF_TWO(sum_long, 4);                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h));         \
    return (var >= 0) ? (uint32_t)var : 0;                            \
  }

// 8-bit
HBD_VARIANCE_WXH_8_SVE(4, 4)
HBD_VARIANCE_WXH_8_SVE(4, 8)

HBD_VARIANCE_WXH_8_SVE(8, 4)
HBD_VARIANCE_WXH_8_SVE(8, 8)
HBD_VARIANCE_WXH_8_SVE(8, 16)

HBD_VARIANCE_WXH_8_SVE(16, 8)
HBD_VARIANCE_WXH_8_SVE(16, 16)
HBD_VARIANCE_WXH_8_SVE(16, 32)

HBD_VARIANCE_WXH_8_SVE(32, 16)
HBD_VARIANCE_WXH_8_SVE(32, 32)
HBD_VARIANCE_WXH_8_SVE(32, 64)

HBD_VARIANCE_WXH_8_SVE(64, 32)
HBD_VARIANCE_WXH_8_SVE(64, 64)
HBD_VARIANCE_WXH_8_SVE(64, 128)

HBD_VARIANCE_WXH_8_SVE(128, 64)
HBD_VARIANCE_WXH_8_SVE(128, 128)

// 10-bit
HBD_VARIANCE_WXH_10_SVE(4, 4)
HBD_VARIANCE_WXH_10_SVE(4, 8)

HBD_VARIANCE_WXH_10_SVE(8, 4)
HBD_VARIANCE_WXH_10_SVE(8, 8)
HBD_VARIANCE_WXH_10_SVE(8, 16)

HBD_VARIANCE_WXH_10_SVE(16, 8)
HBD_VARIANCE_WXH_10_SVE(16, 16)
HBD_VARIANCE_WXH_10_SVE(16, 32)

HBD_VARIANCE_WXH_10_SVE(32, 16)
HBD_VARIANCE_WXH_10_SVE(32, 32)
HBD_VARIANCE_WXH_10_SVE(32, 64)

HBD_VARIANCE_WXH_10_SVE(64, 32)
HBD_VARIANCE_WXH_10_SVE(64, 64)
HBD_VARIANCE_WXH_10_SVE(64, 128)

HBD_VARIANCE_WXH_10_SVE(128, 64)
HBD_VARIANCE_WXH_10_SVE(128, 128)

// 12-bit
HBD_VARIANCE_WXH_12_SVE(4, 4)
HBD_VARIANCE_WXH_12_SVE(4, 8)

HBD_VARIANCE_WXH_12_SVE(8, 4)
HBD_VARIANCE_WXH_12_SVE(8, 8)
HBD_VARIANCE_WXH_12_SVE(8, 16)

HBD_VARIANCE_WXH_12_SVE(16, 8)
HBD_VARIANCE_WXH_12_SVE(16, 16)
HBD_VARIANCE_WXH_12_SVE(16, 32)

HBD_VARIANCE_WXH_12_SVE(32, 16)
HBD_VARIANCE_WXH_12_SVE(32, 32)
HBD_VARIANCE_WXH_12_SVE(32, 64)

HBD_VARIANCE_WXH_12_SVE(64, 32)
HBD_VARIANCE_WXH_12_SVE(64, 64)
HBD_VARIANCE_WXH_12_SVE(64, 128)

HBD_VARIANCE_WXH_12_SVE(128, 64)
HBD_VARIANCE_WXH_12_SVE(128, 128)

#if !CONFIG_REALTIME_ONLY
// 8-bit
HBD_VARIANCE_WXH_8_SVE(4, 16)

HBD_VARIANCE_WXH_8_SVE(8, 32)

HBD_VARIANCE_WXH_8_SVE(16, 4)
HBD_VARIANCE_WXH_8_SVE(16, 64)

HBD_VARIANCE_WXH_8_SVE(32, 8)

HBD_VARIANCE_WXH_8_SVE(64, 16)

// 10-bit
HBD_VARIANCE_WXH_10_SVE(4, 16)

HBD_VARIANCE_WXH_10_SVE(8, 32)

HBD_VARIANCE_WXH_10_SVE(16, 4)
HBD_VARIANCE_WXH_10_SVE(16, 64)

HBD_VARIANCE_WXH_10_SVE(32, 8)

HBD_VARIANCE_WXH_10_SVE(64, 16)

// 12-bit
HBD_VARIANCE_WXH_12_SVE(4, 16)

HBD_VARIANCE_WXH_12_SVE(8, 32)

HBD_VARIANCE_WXH_12_SVE(16, 4)
HBD_VARIANCE_WXH_12_SVE(16, 64)

HBD_VARIANCE_WXH_12_SVE(32, 8)

HBD_VARIANCE_WXH_12_SVE(64, 16)

#endif  // !CONFIG_REALTIME_ONLY

#undef HBD_VARIANCE_WXH_8_SVE
#undef HBD_VARIANCE_WXH_10_SVE
#undef HBD_VARIANCE_WXH_12_SVE

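// MSE kernel: vabdq_u16() gives |src - ref|, and since |a - b|^2 == (a - b)^2
// the squared differences can be accumulated with the unsigned SVE UDOT via
// aom_udotq_u16(), avoiding any signed widening of the differences.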
static inline uint32_t highbd_mse_wxh_sve(const uint16_t *src_ptr,
                                          int src_stride,
                                          const uint16_t *ref_ptr,
                                          int ref_stride, int w, int h,
                                          unsigned int *sse) {
  uint64x2_t sse_u64 = vdupq_n_u64(0);

  do {
    int j = 0;
    do {
      uint16x8_t s = vld1q_u16(src_ptr + j);
      uint16x8_t r = vld1q_u16(ref_ptr + j);

      uint16x8_t diff = vabdq_u16(s, r);

      sse_u64 = aom_udotq_u16(sse_u64, diff, diff);

      j += 8;
    } while (j < w);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--h != 0);

  *sse = (uint32_t)vaddvq_u64(sse_u64);
  return *sse;
}

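// Only 10-bit and 12-bit MSE entry points are defined here; each rounds the
// accumulated SSE down to the 8-bit scale (by 4 and 8 bits respectively).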
#define HIGHBD_MSE_WXH_SVE(w, h)                                      \
  uint32_t aom_highbd_10_mse##w##x##h##_sve(                          \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_mse_wxh_sve(src, src_stride, ref, ref_stride, w, h, sse);  \
    *sse = ROUND_POWER_OF_TWO(*sse, 4);                               \
    return *sse;                                                      \
  }                                                                   \
                                                                      \
  uint32_t aom_highbd_12_mse##w##x##h##_sve(                          \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_mse_wxh_sve(src, src_stride, ref, ref_stride, w, h, sse);  \
    *sse = ROUND_POWER_OF_TWO(*sse, 8);                               \
    return *sse;                                                      \
  }

HIGHBD_MSE_WXH_SVE(16, 16)
HIGHBD_MSE_WXH_SVE(16, 8)
HIGHBD_MSE_WXH_SVE(8, 16)
HIGHBD_MSE_WXH_SVE(8, 8)

#undef HIGHBD_MSE_WXH_SVE

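// Unlike the entry points above, this returns the raw (un-normalized) sum of
// squared errors as a 64-bit value. w and h are each 4 or 8 (see the assert),
// so the w == 8 path consumes two full rows per iteration and the w == 4 path
// packs four rows into two vectors.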
uint64_t aom_mse_wxh_16bit_highbd_sve(uint16_t *dst, int dstride, uint16_t *src,
                                      int sstride, int w, int h) {
  assert((w == 8 || w == 4) && (h == 8 || h == 4));

  uint64x2_t sum = vdupq_n_u64(0);

  if (w == 8) {
    do {
      uint16x8_t d0 = vld1q_u16(dst + 0 * dstride);
      uint16x8_t d1 = vld1q_u16(dst + 1 * dstride);
      uint16x8_t s0 = vld1q_u16(src + 0 * sstride);
      uint16x8_t s1 = vld1q_u16(src + 1 * sstride);

      uint16x8_t abs_diff0 = vabdq_u16(s0, d0);
      uint16x8_t abs_diff1 = vabdq_u16(s1, d1);

      sum = aom_udotq_u16(sum, abs_diff0, abs_diff0);
      sum = aom_udotq_u16(sum, abs_diff1, abs_diff1);

      dst += 2 * dstride;
      src += 2 * sstride;
      h -= 2;
    } while (h != 0);
  } else {  // w == 4
    do {
      uint16x8_t d0 = load_unaligned_u16_4x2(dst + 0 * dstride, dstride);
      uint16x8_t d1 = load_unaligned_u16_4x2(dst + 2 * dstride, dstride);
      uint16x8_t s0 = load_unaligned_u16_4x2(src + 0 * sstride, sstride);
      uint16x8_t s1 = load_unaligned_u16_4x2(src + 2 * sstride, sstride);

      uint16x8_t abs_diff0 = vabdq_u16(s0, d0);
      uint16x8_t abs_diff1 = vabdq_u16(s1, d1);

      sum = aom_udotq_u16(sum, abs_diff0, abs_diff0);
      sum = aom_udotq_u16(sum, abs_diff1, abs_diff1);

      dst += 4 * dstride;
      src += 4 * sstride;
      h -= 4;
    } while (h != 0);
  }

  return vaddvq_u64(sum);
}