xref: /aosp_15_r20/external/libaom/aom_dsp/arm/blk_sse_sum_sve.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 #include <assert.h>
14 
15 #include "config/aom_dsp_rtcd.h"
16 #include "config/aom_config.h"
17 
18 #include "aom_dsp/arm/aom_neon_sve_bridge.h"
19 #include "aom_dsp/arm/mem_neon.h"
20 
// Sum and sum of squares of a 4-wide block of int16_t samples whose rows are
// `stride` elements apart. Two 4-sample rows are packed into one 128-bit
// vector per iteration and the loop only exits when the row counter hits
// exactly zero, so bh must be a positive even number.
//   *x_sum  <- sum of all samples (widened pairwise into 32-bit lanes).
//   *x2_sum <- sum of squared samples (accumulated in 64-bit lanes through
//              aom_sdotq_s16, presumably an SVE SDOT wrapper from
//              aom_neon_sve_bridge.h).
static inline void get_blk_sse_sum_4xh_sve(const int16_t *data, int stride,
                                           int bh, int *x_sum,
                                           int64_t *x2_sum) {
  int32x4_t acc_sum = vdupq_n_s32(0);
  int64x2_t acc_sq = vdupq_n_s64(0);
  const int16_t *row = data;
  int rows_left = bh;

  do {
    const int16x4_t upper = vld1_s16(row);
    const int16x4_t lower = vld1_s16(row + stride);
    const int16x8_t both = vcombine_s16(upper, lower);

    // Pairwise-add adjacent int16 samples and accumulate into int32 lanes.
    acc_sum = vpadalq_s16(acc_sum, both);
    // Accumulate both * both into the 64-bit lanes.
    acc_sq = aom_sdotq_s16(acc_sq, both, both);

    row += 2 * stride;
    rows_left -= 2;
  } while (rows_left != 0);

  // Horizontal reductions of the vector accumulators.
  *x_sum = vaddvq_s32(acc_sum);
  *x2_sum = vaddvq_s64(acc_sq);
}
41 
// Sum and sum of squares of an 8-wide block of int16_t samples whose rows are
// `stride` elements apart. Each iteration loads two full rows (one 128-bit
// vector each) and the loop only exits when the row counter hits exactly
// zero, so bh must be a positive even number.
//   *x_sum  <- sum of all samples.
//   *x2_sum <- sum of squared samples (64-bit accumulation via aom_sdotq_s16).
static inline void get_blk_sse_sum_8xh_sve(const int16_t *data, int stride,
                                           int bh, int *x_sum,
                                           int64_t *x2_sum) {
  // The first and second row of each pair feed separate accumulators, so the
  // two rows loaded per iteration are accumulated independently.
  int32x4_t even_sum = vdupq_n_s32(0);
  int32x4_t odd_sum = vdupq_n_s32(0);
  int64x2_t even_sq = vdupq_n_s64(0);
  int64x2_t odd_sq = vdupq_n_s64(0);

  for (;;) {
    const int16x8_t even_row = vld1q_s16(data);
    const int16x8_t odd_row = vld1q_s16(data + stride);

    even_sum = vpadalq_s16(even_sum, even_row);
    even_sq = aom_sdotq_s16(even_sq, even_row, even_row);

    odd_sum = vpadalq_s16(odd_sum, odd_row);
    odd_sq = aom_sdotq_s16(odd_sq, odd_row, odd_row);

    data += 2 * stride;
    bh -= 2;
    if (bh == 0) {
      break;
    }
  }

  // Merge the two accumulator sets, then reduce across lanes.
  *x_sum = vaddvq_s32(vaddq_s32(even_sum, odd_sum));
  *x2_sum = vaddvq_s64(vaddq_s64(even_sq, odd_sq));
}
65 
// Sum and sum of squares of a bw x bh block of int16_t samples whose rows are
// `stride` elements apart, for widths handled 16 samples at a time. The inner
// loop exits only when the column offset equals bw exactly and the outer loop
// only when the row counter reaches zero, so bw must be a positive multiple of
// 16 and bh must be positive.
//   *x_sum  <- sum of all samples.
//   *x2_sum <- sum of squared samples (64-bit accumulation via aom_sdotq_s16).
static inline void get_blk_sse_sum_large_sve(const int16_t *data, int stride,
                                             int bw, int bh, int *x_sum,
                                             int64_t *x2_sum) {
  // "lo" accumulators take samples [0, 8) of each 16-sample chunk, "hi"
  // accumulators take samples [8, 16).
  int32x4_t sum_lo = vdupq_n_s32(0);
  int32x4_t sum_hi = vdupq_n_s32(0);
  int64x2_t sq_lo = vdupq_n_s64(0);
  int64x2_t sq_hi = vdupq_n_s64(0);
  int rows_left = bh;

  do {
    int col = 0;
    do {
      const int16x8_t lo = vld1q_s16(data + col);
      const int16x8_t hi = vld1q_s16(data + col + 8);

      sum_lo = vpadalq_s16(sum_lo, lo);
      sq_lo = aom_sdotq_s16(sq_lo, lo, lo);

      sum_hi = vpadalq_s16(sum_hi, hi);
      sq_hi = aom_sdotq_s16(sq_hi, hi, hi);

      col += 16;
    } while (col != bw);

    data += stride;
    rows_left -= 1;
  } while (rows_left != 0);

  // Merge the two accumulator sets, then reduce across lanes.
  *x_sum = vaddvq_s32(vaddq_s32(sum_lo, sum_hi));
  *x2_sum = vaddvq_s64(vaddq_s64(sq_lo, sq_hi));
}
95 
// Computes the sum (*x_sum) and the sum of squares (*x2_sum) of a bw x bh
// block of int16_t samples whose rows are `stride` elements apart.
//
// Dispatches on block width:
//   bw == 4         -> 4xh kernel (two rows per iteration; bh must be even)
//   bw == 8         -> 8xh kernel (two rows per iteration; bh must be even)
//   bw % 16 == 0    -> large kernel (16 samples per inner iteration)
//
// All kernels use do/while loops that stop only when their counters reach
// exactly zero. Invalid dimensions therefore do not fail gracefully: they
// cause infinite loops and out-of-bounds reads. The asserts below make these
// preconditions explicit in debug builds.
void aom_get_blk_sse_sum_sve(const int16_t *data, int stride, int bw, int bh,
                             int *x_sum, int64_t *x2_sum) {
  assert(bw > 0);
  assert(bh > 0);

  if (bw == 4) {
    assert(bh % 2 == 0);
    get_blk_sse_sum_4xh_sve(data, stride, bh, x_sum, x2_sum);
  } else if (bw == 8) {
    assert(bh % 2 == 0);
    get_blk_sse_sum_8xh_sve(data, stride, bh, x_sum, x2_sum);
  } else {
    assert(bw % 16 == 0);
    get_blk_sse_sum_large_sve(data, stride, bw, bh, x_sum, x2_sum);
  }
}
107