1 /*
2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <arm_neon.h>
13 #include <assert.h>
14
15 #include "config/aom_dsp_rtcd.h"
16 #include "config/aom_config.h"
17
18 #include "aom_dsp/arm/aom_neon_sve_bridge.h"
19 #include "aom_dsp/arm/mem_neon.h"
20
get_blk_sse_sum_4xh_sve(const int16_t * data,int stride,int bh,int * x_sum,int64_t * x2_sum)21 static inline void get_blk_sse_sum_4xh_sve(const int16_t *data, int stride,
22 int bh, int *x_sum,
23 int64_t *x2_sum) {
24 int32x4_t sum = vdupq_n_s32(0);
25 int64x2_t sse = vdupq_n_s64(0);
26
27 do {
28 int16x8_t d = vcombine_s16(vld1_s16(data), vld1_s16(data + stride));
29
30 sum = vpadalq_s16(sum, d);
31
32 sse = aom_sdotq_s16(sse, d, d);
33
34 data += 2 * stride;
35 bh -= 2;
36 } while (bh != 0);
37
38 *x_sum = vaddvq_s32(sum);
39 *x2_sum = vaddvq_s64(sse);
40 }
41
get_blk_sse_sum_8xh_sve(const int16_t * data,int stride,int bh,int * x_sum,int64_t * x2_sum)42 static inline void get_blk_sse_sum_8xh_sve(const int16_t *data, int stride,
43 int bh, int *x_sum,
44 int64_t *x2_sum) {
45 int32x4_t sum[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
46 int64x2_t sse[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };
47
48 do {
49 int16x8_t d0 = vld1q_s16(data);
50 int16x8_t d1 = vld1q_s16(data + stride);
51
52 sum[0] = vpadalq_s16(sum[0], d0);
53 sum[1] = vpadalq_s16(sum[1], d1);
54
55 sse[0] = aom_sdotq_s16(sse[0], d0, d0);
56 sse[1] = aom_sdotq_s16(sse[1], d1, d1);
57
58 data += 2 * stride;
59 bh -= 2;
60 } while (bh != 0);
61
62 *x_sum = vaddvq_s32(vaddq_s32(sum[0], sum[1]));
63 *x2_sum = vaddvq_s64(vaddq_s64(sse[0], sse[1]));
64 }
65
get_blk_sse_sum_large_sve(const int16_t * data,int stride,int bw,int bh,int * x_sum,int64_t * x2_sum)66 static inline void get_blk_sse_sum_large_sve(const int16_t *data, int stride,
67 int bw, int bh, int *x_sum,
68 int64_t *x2_sum) {
69 int32x4_t sum[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
70 int64x2_t sse[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };
71
72 do {
73 int j = bw;
74 const int16_t *data_ptr = data;
75 do {
76 int16x8_t d0 = vld1q_s16(data_ptr);
77 int16x8_t d1 = vld1q_s16(data_ptr + 8);
78
79 sum[0] = vpadalq_s16(sum[0], d0);
80 sum[1] = vpadalq_s16(sum[1], d1);
81
82 sse[0] = aom_sdotq_s16(sse[0], d0, d0);
83 sse[1] = aom_sdotq_s16(sse[1], d1, d1);
84
85 data_ptr += 16;
86 j -= 16;
87 } while (j != 0);
88
89 data += stride;
90 } while (--bh != 0);
91
92 *x_sum = vaddvq_s32(vaddq_s32(sum[0], sum[1]));
93 *x2_sum = vaddvq_s64(vaddq_s64(sse[0], sse[1]));
94 }
95
// Computes the sum (*x_sum) and the sum of squares (*x2_sum) of a bw x bh
// block of 16-bit values by dispatching on the block width: dedicated kernels
// handle widths 4 and 8, and every other width goes to the 16-column kernel,
// which requires bw to be a multiple of 16 (checked by assert).
void aom_get_blk_sse_sum_sve(const int16_t *data, int stride, int bw, int bh,
                             int *x_sum, int64_t *x2_sum) {
  switch (bw) {
    case 4:
      get_blk_sse_sum_4xh_sve(data, stride, bh, x_sum, x2_sum);
      break;
    case 8:
      get_blk_sse_sum_8xh_sve(data, stride, bh, x_sum, x2_sum);
      break;
    default:
      assert(bw % 16 == 0);
      get_blk_sse_sum_large_sve(data, stride, bw, bh, x_sum, x2_sum);
      break;
  }
}
107