xref: /aosp_15_r20/external/libvpx/vpx_dsp/arm/highbd_variance_sve.c (revision fb1b10ab9aebc7c7068eedab379b749d7e3900be)
1 /*
2  * Copyright (c) 2024 The WebM project authors. All Rights Reserved.
3  *
4  * Use of this source code is governed by a BSD-style license
5  * that can be found in the LICENSE file in the root of the source
6  * tree. An additional intellectual property rights grant can be found
7  * in the file PATENTS.  All contributing project authors may
8  * be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include <arm_neon.h>
12 
13 #include "./vpx_dsp_rtcd.h"
14 #include "./vpx_config.h"
15 
16 #include "vpx_dsp/arm/mem_neon.h"
17 #include "vpx_dsp/arm/sum_neon.h"
18 #include "vpx_dsp/arm/vpx_neon_sve_bridge.h"
19 #include "vpx_ports/mem.h"
20 
// Accumulate the sum of squared differences between two w-by-h blocks of
// 16-bit pixels, using the SVE dot-product bridge to widen into 64-bit
// lanes. Requires w to be a multiple of 8 and h > 0.
static INLINE uint32_t highbd_mse_wxh_sve(const uint16_t *src_ptr,
                                          int src_stride,
                                          const uint16_t *ref_ptr,
                                          int ref_stride, int w, int h) {
  uint64x2_t acc = vdupq_n_u64(0);

  for (int rows = h; rows != 0; --rows) {
    for (int j = 0; j < w; j += 8) {
      const uint16x8_t s = vld1q_u16(src_ptr + j);
      const uint16x8_t r = vld1q_u16(ref_ptr + j);

      // vabdq keeps |s - r| in the unsigned domain before squaring.
      const uint16x8_t abs_diff = vabdq_u16(s, r);

      acc = vpx_dotq_u16(acc, abs_diff, abs_diff);
    }

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  }

  return (uint32_t)horizontal_add_uint64x2(acc);
}
46 
// Emits the 10- and 12-bit MSE entry points for a w-by-h block.  The shared
// helper returns the raw sum of squared differences; each wrapper then
// renormalizes it back to the 8-bit scale with rounding (squared values
// carry twice the bit-depth offset: >> 4 for 10-bit, >> 8 for 12-bit).
#define HIGHBD_MSE_WXH_SVE(w, h)                                        \
  uint32_t vpx_highbd_10_mse##w##x##h##_sve(                            \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,   \
      int ref_stride, uint32_t *sse) {                                  \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                       \
    *sse = ROUND_POWER_OF_TWO(                                          \
        highbd_mse_wxh_sve(src, src_stride, ref, ref_stride, w, h), 4); \
    return *sse;                                                        \
  }                                                                     \
                                                                        \
  uint32_t vpx_highbd_12_mse##w##x##h##_sve(                            \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,   \
      int ref_stride, uint32_t *sse) {                                  \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                       \
    *sse = ROUND_POWER_OF_TWO(                                          \
        highbd_mse_wxh_sve(src, src_stride, ref, ref_stride, w, h), 8); \
    return *sse;                                                        \
  }

HIGHBD_MSE_WXH_SVE(16, 16)
HIGHBD_MSE_WXH_SVE(16, 8)
HIGHBD_MSE_WXH_SVE(8, 16)
HIGHBD_MSE_WXH_SVE(8, 8)

#undef HIGHBD_MSE_WXH_SVE
78 
// Width-4 variance kernel: two rows are packed into each 8-lane vector, so
// the loop advances two rows per iteration and h must be even.  The running
// sum stays in narrow int16 lanes — assumes h is small enough (4 or 8 for
// the blocks instantiated below) that the lanes cannot overflow.
static INLINE void highbd_variance_4xh_sve(const uint16_t *src_ptr,
                                           int src_stride,
                                           const uint16_t *ref_ptr,
                                           int ref_stride, int h, uint64_t *sse,
                                           int64_t *sum) {
  int16x8_t sum_acc = vdupq_n_s16(0);
  int64x2_t sse_acc = vdupq_n_s64(0);

  for (int rows = h; rows > 0; rows -= 2) {
    const uint16x8_t s = load_unaligned_u16q(src_ptr, src_stride);
    const uint16x8_t r = load_unaligned_u16q(ref_ptr, ref_stride);

    const int16x8_t d = vreinterpretq_s16_u16(vsubq_u16(s, r));
    sum_acc = vaddq_s16(sum_acc, d);
    sse_acc = vpx_dotq_s16(sse_acc, d, d);

    src_ptr += 2 * src_stride;
    ref_ptr += 2 * ref_stride;
  }

  *sum = horizontal_add_int16x8(sum_acc);
  *sse = horizontal_add_int64x2(sse_acc);
}
104 
// Width-8 variance kernel: one 8-lane vector per row.  Row differences are
// pairwise-widened into 32-bit sum lanes; squared terms go straight into
// 64-bit lanes via the dot-product bridge.  h must be > 0.
static INLINE void highbd_variance_8xh_sve(const uint16_t *src_ptr,
                                           int src_stride,
                                           const uint16_t *ref_ptr,
                                           int ref_stride, int h, uint64_t *sse,
                                           int64_t *sum) {
  int32x4_t sum_acc = vdupq_n_s32(0);
  int64x2_t sse_acc = vdupq_n_s64(0);

  for (int rows = h; rows != 0; --rows) {
    const uint16x8_t s = vld1q_u16(src_ptr);
    const uint16x8_t r = vld1q_u16(ref_ptr);

    const int16x8_t d = vreinterpretq_s16_u16(vsubq_u16(s, r));
    sum_acc = vpadalq_s16(sum_acc, d);
    sse_acc = vpx_dotq_s16(sse_acc, d, d);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  }

  *sum = horizontal_add_int32x4(sum_acc);
  *sse = horizontal_add_int64x2(sse_acc);
}
128 
// Width-16 variance kernel.  Two independent accumulator pairs (one per
// 8-lane half of the row) keep the dependency chains short; they are merged
// once after the loop.  h must be > 0.
static INLINE void highbd_variance_16xh_sve(const uint16_t *src_ptr,
                                            int src_stride,
                                            const uint16_t *ref_ptr,
                                            int ref_stride, int h,
                                            uint64_t *sse, int64_t *sum) {
  int32x4_t sum_lo = vdupq_n_s32(0);
  int32x4_t sum_hi = vdupq_n_s32(0);
  int64x2_t sse_lo = vdupq_n_s64(0);
  int64x2_t sse_hi = vdupq_n_s64(0);

  do {
    const uint16x8_t s0 = vld1q_u16(src_ptr);
    const uint16x8_t s1 = vld1q_u16(src_ptr + 8);
    const uint16x8_t r0 = vld1q_u16(ref_ptr);
    const uint16x8_t r1 = vld1q_u16(ref_ptr + 8);

    const int16x8_t d0 = vreinterpretq_s16_u16(vsubq_u16(s0, r0));
    const int16x8_t d1 = vreinterpretq_s16_u16(vsubq_u16(s1, r1));

    sum_lo = vpadalq_s16(sum_lo, d0);
    sum_hi = vpadalq_s16(sum_hi, d1);
    sse_lo = vpx_dotq_s16(sse_lo, d0, d0);
    sse_hi = vpx_dotq_s16(sse_hi, d1, d1);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--h != 0);

  *sum = horizontal_add_int32x4(vaddq_s32(sum_lo, sum_hi));
  *sse = horizontal_add_int64x2(vaddq_s64(sse_lo, sse_hi));
}
163 
// Generic variance kernel for wide blocks.  Each row is consumed 32 pixels
// at a time as four 8-lane sub-vectors, each feeding its own accumulator
// pair to keep dependency chains short; the four pairs are reduced after
// the loop.  Requires w to be a multiple of 32 and h > 0.
static INLINE void highbd_variance_wxh_sve(const uint16_t *src_ptr,
                                           int src_stride,
                                           const uint16_t *ref_ptr,
                                           int ref_stride, int w, int h,
                                           uint64_t *sse, int64_t *sum) {
  int32x4_t sum_acc[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0),
                           vdupq_n_s32(0) };
  int64x2_t sse_acc[4] = { vdupq_n_s64(0), vdupq_n_s64(0), vdupq_n_s64(0),
                           vdupq_n_s64(0) };

  do {
    for (int i = 0; i < w; i += 32) {
      for (int k = 0; k < 4; ++k) {
        const uint16x8_t s = vld1q_u16(src_ptr + i + 8 * k);
        const uint16x8_t r = vld1q_u16(ref_ptr + i + 8 * k);

        const int16x8_t d = vreinterpretq_s16_u16(vsubq_u16(s, r));
        sum_acc[k] = vpadalq_s16(sum_acc[k], d);
        sse_acc[k] = vpx_dotq_s16(sse_acc[k], d, d);
      }
    }

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--h != 0);

  {
    const int32x4_t sum_total = vaddq_s32(vaddq_s32(sum_acc[0], sum_acc[1]),
                                          vaddq_s32(sum_acc[2], sum_acc[3]));
    const int64x2_t sse_total = vaddq_s64(vaddq_s64(sse_acc[0], sse_acc[1]),
                                          vaddq_s64(sse_acc[2], sse_acc[3]));

    *sum = horizontal_add_int32x4(sum_total);
    *sse = horizontal_add_int64x2(sse_total);
  }
}
220 
// Width-32 front end: delegates to the generic wide-block kernel.
static INLINE void highbd_variance_32xh_sve(const uint16_t *src, int src_stride,
                                            const uint16_t *ref, int ref_stride,
                                            int h, uint64_t *sse,
                                            int64_t *sum) {
  highbd_variance_wxh_sve(src, src_stride, ref, ref_stride, 32, h, sse, sum);
}
227 
// Width-64 front end: delegates to the generic wide-block kernel.
static INLINE void highbd_variance_64xh_sve(const uint16_t *src, int src_stride,
                                            const uint16_t *ref, int ref_stride,
                                            int h, uint64_t *sse,
                                            int64_t *sum) {
  highbd_variance_wxh_sve(src, src_stride, ref, ref_stride, 64, h, sse, sum);
}
234 
// Emits the 8-, 10- and 12-bit variance entry points for a w-by-h block.
// The high-bit-depth variants first renormalize the accumulated SSE and sum
// back to the 8-bit scale (squared terms carry twice the bit-depth offset of
// the linear sum), then clamp the variance at zero because the two rounding
// steps can push the subtraction slightly negative.
#define HBD_VARIANCE_WXH_SVE(w, h)                                     \
  uint32_t vpx_highbd_8_variance##w##x##h##_sve(                       \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,  \
      int ref_stride, uint32_t *sse) {                                 \
    uint64_t sse64 = 0;                                                \
    int64_t sum64 = 0;                                                 \
    int avg;                                                           \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                      \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                      \
    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,   \
                                &sse64, &sum64);                       \
    *sse = (uint32_t)sse64;                                            \
    avg = (int)sum64;                                                  \
    return *sse - (uint32_t)(((int64_t)avg * avg) / (w * h));          \
  }                                                                    \
                                                                       \
  uint32_t vpx_highbd_10_variance##w##x##h##_sve(                      \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,  \
      int ref_stride, uint32_t *sse) {                                 \
    uint64_t sse64 = 0;                                                \
    int64_t sum64 = 0;                                                 \
    int avg;                                                           \
    int64_t var;                                                       \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                      \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                      \
    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,   \
                                &sse64, &sum64);                       \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 4);                     \
    avg = (int)ROUND_POWER_OF_TWO(sum64, 2);                           \
    var = (int64_t)(*sse) - (((int64_t)avg * avg) / (w * h));          \
    return (var >= 0) ? (uint32_t)var : 0;                             \
  }                                                                    \
                                                                       \
  uint32_t vpx_highbd_12_variance##w##x##h##_sve(                      \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,  \
      int ref_stride, uint32_t *sse) {                                 \
    uint64_t sse64 = 0;                                                \
    int64_t sum64 = 0;                                                 \
    int avg;                                                           \
    int64_t var;                                                       \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                      \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                      \
    highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h,   \
                                &sse64, &sum64);                       \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 8);                     \
    avg = (int)ROUND_POWER_OF_TWO(sum64, 4);                           \
    var = (int64_t)(*sse) - (((int64_t)avg * avg) / (w * h));          \
    return (var >= 0) ? (uint32_t)var : 0;                             \
  }

HBD_VARIANCE_WXH_SVE(4, 4)
HBD_VARIANCE_WXH_SVE(4, 8)

HBD_VARIANCE_WXH_SVE(8, 4)
HBD_VARIANCE_WXH_SVE(8, 8)
HBD_VARIANCE_WXH_SVE(8, 16)

HBD_VARIANCE_WXH_SVE(16, 8)
HBD_VARIANCE_WXH_SVE(16, 16)
HBD_VARIANCE_WXH_SVE(16, 32)

HBD_VARIANCE_WXH_SVE(32, 16)
HBD_VARIANCE_WXH_SVE(32, 32)
HBD_VARIANCE_WXH_SVE(32, 64)

HBD_VARIANCE_WXH_SVE(64, 32)
HBD_VARIANCE_WXH_SVE(64, 64)
302 
// Emits the 8-, 10- and 12-bit get_var entry points for an s-by-s block.
// These return the raw (renormalized) SSE and sum separately rather than a
// finished variance; the shifts mirror the variance macro above.
#define HIGHBD_GET_VAR_SVE(s)                                          \
  void vpx_highbd_8_get##s##x##s##var_sve(                             \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,  \
      int ref_stride, uint32_t *sse, int *sum) {                       \
    uint64_t sse64 = 0;                                                \
    int64_t sum64 = 0;                                                 \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                      \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                      \
    highbd_variance_##s##xh_sve(src, src_stride, ref, ref_stride, s,   \
                                &sse64, &sum64);                       \
    *sse = (uint32_t)sse64;                                            \
    *sum = (int)sum64;                                                 \
  }                                                                    \
                                                                       \
  void vpx_highbd_10_get##s##x##s##var_sve(                            \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,  \
      int ref_stride, uint32_t *sse, int *sum) {                       \
    uint64_t sse64 = 0;                                                \
    int64_t sum64 = 0;                                                 \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                      \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                      \
    highbd_variance_##s##xh_sve(src, src_stride, ref, ref_stride, s,   \
                                &sse64, &sum64);                       \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 4);                     \
    *sum = (int)ROUND_POWER_OF_TWO(sum64, 2);                          \
  }                                                                    \
                                                                       \
  void vpx_highbd_12_get##s##x##s##var_sve(                            \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,  \
      int ref_stride, uint32_t *sse, int *sum) {                       \
    uint64_t sse64 = 0;                                                \
    int64_t sum64 = 0;                                                 \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                      \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                      \
    highbd_variance_##s##xh_sve(src, src_stride, ref, ref_stride, s,   \
                                &sse64, &sum64);                       \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 8);                     \
    *sum = (int)ROUND_POWER_OF_TWO(sum64, 4);                          \
  }

HIGHBD_GET_VAR_SVE(8)
HIGHBD_GET_VAR_SVE(16)
345