1 /*
2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <arm_neon.h>
13
14 #include "config/aom_dsp_rtcd.h"
15 #include "aom_dsp/arm/mem_neon.h"
16 #include "aom_dsp/arm/sum_neon.h"
17
// Accumulates the squared differences of 16 consecutive bytes of src and ref
// into *sse.
//
// The absolute byte differences are multiplied by themselves with the
// dot-product intrinsic: each 32-bit lane of *sse gains the sum of four
// adjacent squared differences.
static inline void sse_16x1_neon_dotprod(const uint8_t *src, const uint8_t *ref,
                                         uint32x4_t *sse) {
  const uint8x16_t diff = vabdq_u8(vld1q_u8(src), vld1q_u8(ref));

  // diff . diff == sum of squares, added on top of the running total.
  *sse = vdotq_u32(*sse, diff, diff);
}
27
// Accumulates the squared differences of 8 consecutive bytes of src and ref
// into *sse.
//
// Same scheme as the 16-byte variant, on 64-bit vectors: each of the two
// 32-bit lanes of *sse gains the sum of four adjacent squared differences.
static inline void sse_8x1_neon_dotprod(const uint8_t *src, const uint8_t *ref,
                                        uint32x2_t *sse) {
  const uint8x8_t diff = vabd_u8(vld1_u8(src), vld1_u8(ref));

  // diff . diff == sum of squares, added on top of the running total.
  *sse = vdot_u32(*sse, diff, diff);
}
37
// Accumulates the squared differences of a 4-wide, 2-row patch of src and ref
// into *sse.
//
// NOTE(review): load_unaligned_u8() is defined outside this file; it is
// presumably packing 4 bytes from each of two rows (stride apart) into one
// 8-byte vector - confirm against mem_neon.h.
static inline void sse_4x2_neon_dotprod(const uint8_t *src, int src_stride,
                                        const uint8_t *ref, int ref_stride,
                                        uint32x2_t *sse) {
  const uint8x8_t src_px = load_unaligned_u8(src, src_stride);
  const uint8x8_t ref_px = load_unaligned_u8(ref, ref_stride);
  const uint8x8_t diff = vabd_u8(src_px, ref_px);

  // diff . diff == sum of squares, added on top of the running total.
  *sse = vdot_u32(*sse, diff, diff);
}
48
sse_wxh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int width,int height)49 static inline uint32_t sse_wxh_neon_dotprod(const uint8_t *src, int src_stride,
50 const uint8_t *ref, int ref_stride,
51 int width, int height) {
52 uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) };
53
54 if ((width & 0x07) && ((width & 0x07) < 5)) {
55 int i = height;
56 do {
57 int j = 0;
58 do {
59 sse_8x1_neon_dotprod(src + j, ref + j, &sse[0]);
60 sse_8x1_neon_dotprod(src + j + src_stride, ref + j + ref_stride,
61 &sse[1]);
62 j += 8;
63 } while (j + 4 < width);
64
65 sse_4x2_neon_dotprod(src + j, src_stride, ref + j, ref_stride, &sse[0]);
66 src += 2 * src_stride;
67 ref += 2 * ref_stride;
68 i -= 2;
69 } while (i != 0);
70 } else {
71 int i = height;
72 do {
73 int j = 0;
74 do {
75 sse_8x1_neon_dotprod(src + j, ref + j, &sse[0]);
76 sse_8x1_neon_dotprod(src + j + src_stride, ref + j + ref_stride,
77 &sse[1]);
78 j += 8;
79 } while (j < width);
80
81 src += 2 * src_stride;
82 ref += 2 * ref_stride;
83 i -= 2;
84 } while (i != 0);
85 }
86 return horizontal_add_u32x4(vcombine_u32(sse[0], sse[1]));
87 }
88
sse_128xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)89 static inline uint32_t sse_128xh_neon_dotprod(const uint8_t *src,
90 int src_stride,
91 const uint8_t *ref,
92 int ref_stride, int height) {
93 uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
94
95 int i = height;
96 do {
97 sse_16x1_neon_dotprod(src, ref, &sse[0]);
98 sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]);
99 sse_16x1_neon_dotprod(src + 32, ref + 32, &sse[0]);
100 sse_16x1_neon_dotprod(src + 48, ref + 48, &sse[1]);
101 sse_16x1_neon_dotprod(src + 64, ref + 64, &sse[0]);
102 sse_16x1_neon_dotprod(src + 80, ref + 80, &sse[1]);
103 sse_16x1_neon_dotprod(src + 96, ref + 96, &sse[0]);
104 sse_16x1_neon_dotprod(src + 112, ref + 112, &sse[1]);
105
106 src += src_stride;
107 ref += ref_stride;
108 } while (--i != 0);
109
110 return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
111 }
112
sse_64xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)113 static inline uint32_t sse_64xh_neon_dotprod(const uint8_t *src, int src_stride,
114 const uint8_t *ref, int ref_stride,
115 int height) {
116 uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
117
118 int i = height;
119 do {
120 sse_16x1_neon_dotprod(src, ref, &sse[0]);
121 sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]);
122 sse_16x1_neon_dotprod(src + 32, ref + 32, &sse[0]);
123 sse_16x1_neon_dotprod(src + 48, ref + 48, &sse[1]);
124
125 src += src_stride;
126 ref += ref_stride;
127 } while (--i != 0);
128
129 return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
130 }
131
sse_32xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)132 static inline uint32_t sse_32xh_neon_dotprod(const uint8_t *src, int src_stride,
133 const uint8_t *ref, int ref_stride,
134 int height) {
135 uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
136
137 int i = height;
138 do {
139 sse_16x1_neon_dotprod(src, ref, &sse[0]);
140 sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]);
141
142 src += src_stride;
143 ref += ref_stride;
144 } while (--i != 0);
145
146 return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
147 }
148
sse_16xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)149 static inline uint32_t sse_16xh_neon_dotprod(const uint8_t *src, int src_stride,
150 const uint8_t *ref, int ref_stride,
151 int height) {
152 uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
153
154 int i = height;
155 do {
156 sse_16x1_neon_dotprod(src, ref, &sse[0]);
157 src += src_stride;
158 ref += ref_stride;
159 sse_16x1_neon_dotprod(src, ref, &sse[1]);
160 src += src_stride;
161 ref += ref_stride;
162 i -= 2;
163 } while (i != 0);
164
165 return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
166 }
167
sse_8xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)168 static inline uint32_t sse_8xh_neon_dotprod(const uint8_t *src, int src_stride,
169 const uint8_t *ref, int ref_stride,
170 int height) {
171 uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) };
172
173 int i = height;
174 do {
175 sse_8x1_neon_dotprod(src, ref, &sse[0]);
176 src += src_stride;
177 ref += ref_stride;
178 sse_8x1_neon_dotprod(src, ref, &sse[1]);
179 src += src_stride;
180 ref += ref_stride;
181 i -= 2;
182 } while (i != 0);
183
184 return horizontal_add_u32x4(vcombine_u32(sse[0], sse[1]));
185 }
186
sse_4xh_neon_dotprod(const uint8_t * src,int src_stride,const uint8_t * ref,int ref_stride,int height)187 static inline uint32_t sse_4xh_neon_dotprod(const uint8_t *src, int src_stride,
188 const uint8_t *ref, int ref_stride,
189 int height) {
190 uint32x2_t sse = vdup_n_u32(0);
191
192 int i = height;
193 do {
194 sse_4x2_neon_dotprod(src, src_stride, ref, ref_stride, &sse);
195
196 src += 2 * src_stride;
197 ref += 2 * ref_stride;
198 i -= 2;
199 } while (i != 0);
200
201 return horizontal_add_u32x2(sse);
202 }
203
// Sum of squared differences between two 8-bit blocks of width x height.
//
// Widths 4, 8, 16, 32, 64 and 128 are routed to their dedicated kernels; any
// other width goes to the generic routine. The kernels return uint32_t, which
// is widened to the int64_t return type here.
int64_t aom_sse_neon_dotprod(const uint8_t *src, int src_stride,
                             const uint8_t *ref, int ref_stride, int width,
                             int height) {
  if (width == 4) {
    return sse_4xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
  }
  if (width == 8) {
    return sse_8xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
  }
  if (width == 16) {
    return sse_16xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
  }
  if (width == 32) {
    return sse_32xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
  }
  if (width == 64) {
    return sse_64xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
  }
  if (width == 128) {
    return sse_128xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
  }

  // No dedicated kernel for this width.
  return sse_wxh_neon_dotprod(src, src_stride, ref, ref_stride, width, height);
}
225