xref: /aosp_15_r20/external/libaom/aom_dsp/arm/sse_neon_dotprod.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 
14 #include "config/aom_dsp_rtcd.h"
15 #include "aom_dsp/arm/mem_neon.h"
16 #include "aom_dsp/arm/sum_neon.h"
17 
// Accumulates the squared differences of one 16-pixel row into *sse.
// |s - r| always fits in a byte and |s - r|^2 == (s - r)^2, so an unsigned dot
// product of the absolute-difference vector with itself produces the squares
// directly; lane i of *sse gains the sum of the squares of bytes 4i..4i+3.
static inline void sse_16x1_neon_dotprod(const uint8_t *src, const uint8_t *ref,
                                         uint32x4_t *sse) {
  const uint8x16_t ad = vabdq_u8(vld1q_u8(src), vld1q_u8(ref));
  *sse = vdotq_u32(*sse, ad, ad);
}
27 
// Accumulates the squared differences of one 8-pixel row into the two 32-bit
// lanes of *sse (lane i gains the squares of bytes 4i..4i+3). Uses the same
// |s - r| self-dot-product trick as the 16-wide variant.
static inline void sse_8x1_neon_dotprod(const uint8_t *src, const uint8_t *ref,
                                        uint32x2_t *sse) {
  const uint8x8_t ad = vabd_u8(vld1_u8(src), vld1_u8(ref));
  *sse = vdot_u32(*sse, ad, ad);
}
37 
// Accumulates the squared differences of a 4-wide, 2-tall block into the two
// 32-bit lanes of *sse. The second row starts `src_stride` / `ref_stride`
// bytes after the first.
// NOTE(review): assumes load_unaligned_u8() (mem_neon.h) gathers 4 bytes from
// each of the two rows into a single 8-lane vector, as its use here implies.
static inline void sse_4x2_neon_dotprod(const uint8_t *src, int src_stride,
                                        const uint8_t *ref, int ref_stride,
                                        uint32x2_t *sse) {
  const uint8x8_t ad = vabd_u8(load_unaligned_u8(src, src_stride),
                               load_unaligned_u8(ref, ref_stride));
  *sse = vdot_u32(*sse, ad, ad);
}
48 
// Generic SSE kernel for widths with no dedicated implementation.
//
// Rows are consumed two at a time. Within a row pair, 8-pixel columns go
// through sse_8x1_neon_dotprod(), one accumulator per row. When width % 8 is
// in 1..4, the 8-wide loop stops early and a final 4x2 chunk covers the tail.
//
// Preconditions implied by the loops:
//  - width should be a multiple of 4. Other widths read, and count, bytes past
//    `width` on each row.
//  - height must be a positive even number, because the row counter has to
//    reach exactly zero.
static inline uint32_t sse_wxh_neon_dotprod(const uint8_t *src, int src_stride,
                                            const uint8_t *ref, int ref_stride,
                                            int width, int height) {
  uint32x2_t acc[2] = { vdup_n_u32(0), vdup_n_u32(0) };

  const int rem = width & 7;
  const int has_4_tail = rem != 0 && rem < 5;
  // With a 4-wide tail pending, the 8-wide loop must stop while at least 4
  // columns remain; otherwise it runs until it reaches the row width.
  const int tail_reserve = has_4_tail ? 4 : 0;

  int rows_left = height;
  do {
    int col = 0;
    do {
      sse_8x1_neon_dotprod(src + col, ref + col, &acc[0]);
      sse_8x1_neon_dotprod(src + col + src_stride, ref + col + ref_stride,
                           &acc[1]);
      col += 8;
    } while (col + tail_reserve < width);

    if (has_4_tail) {
      sse_4x2_neon_dotprod(src + col, src_stride, ref + col, ref_stride,
                           &acc[0]);
    }

    src += 2 * src_stride;
    ref += 2 * ref_stride;
    rows_left -= 2;
  } while (rows_left != 0);

  return horizontal_add_u32x4(vcombine_u32(acc[0], acc[1]));
}
88 
// SSE of a 128-pixel-wide block of `height` rows (height must be >= 1).
// Each row is split into eight 16-byte chunks. Even chunks feed one
// accumulator and odd chunks feed the other, presumably so the two
// dot-product dependency chains can overlap.
static inline uint32_t sse_128xh_neon_dotprod(const uint8_t *src,
                                              int src_stride,
                                              const uint8_t *ref,
                                              int ref_stride, int height) {
  uint32x4_t acc_even = vdupq_n_u32(0);
  uint32x4_t acc_odd = vdupq_n_u32(0);

  int rows_left = height;
  do {
    for (int col = 0; col < 128; col += 32) {
      sse_16x1_neon_dotprod(src + col, ref + col, &acc_even);
      sse_16x1_neon_dotprod(src + col + 16, ref + col + 16, &acc_odd);
    }
    src += src_stride;
    ref += ref_stride;
  } while (--rows_left != 0);

  return horizontal_add_u32x4(vaddq_u32(acc_even, acc_odd));
}
112 
// SSE of a 64-pixel-wide block of `height` rows (height must be >= 1).
// Each row is four 16-byte chunks, alternated between two accumulators.
static inline uint32_t sse_64xh_neon_dotprod(const uint8_t *src, int src_stride,
                                             const uint8_t *ref, int ref_stride,
                                             int height) {
  uint32x4_t acc_even = vdupq_n_u32(0);
  uint32x4_t acc_odd = vdupq_n_u32(0);

  int rows_left = height;
  do {
    for (int col = 0; col < 64; col += 32) {
      sse_16x1_neon_dotprod(src + col, ref + col, &acc_even);
      sse_16x1_neon_dotprod(src + col + 16, ref + col + 16, &acc_odd);
    }
    src += src_stride;
    ref += ref_stride;
  } while (--rows_left != 0);

  return horizontal_add_u32x4(vaddq_u32(acc_even, acc_odd));
}
131 
// SSE of a 32-pixel-wide block of `height` rows (height must be >= 1).
// The left and right 16-byte halves of each row use separate accumulators.
static inline uint32_t sse_32xh_neon_dotprod(const uint8_t *src, int src_stride,
                                             const uint8_t *ref, int ref_stride,
                                             int height) {
  uint32x4_t acc_left = vdupq_n_u32(0);
  uint32x4_t acc_right = vdupq_n_u32(0);

  int rows_left = height;
  do {
    sse_16x1_neon_dotprod(src, ref, &acc_left);
    sse_16x1_neon_dotprod(src + 16, ref + 16, &acc_right);
    src += src_stride;
    ref += ref_stride;
  } while (--rows_left != 0);

  return horizontal_add_u32x4(vaddq_u32(acc_left, acc_right));
}
148 
// SSE of a 16-pixel-wide block. Two rows are handled per iteration, the
// upper row into one accumulator and the lower row into the other. height
// must be a positive even number for the row counter to reach zero.
static inline uint32_t sse_16xh_neon_dotprod(const uint8_t *src, int src_stride,
                                             const uint8_t *ref, int ref_stride,
                                             int height) {
  uint32x4_t acc_top = vdupq_n_u32(0);
  uint32x4_t acc_bottom = vdupq_n_u32(0);

  int rows_left = height;
  do {
    sse_16x1_neon_dotprod(src, ref, &acc_top);
    sse_16x1_neon_dotprod(src + src_stride, ref + ref_stride, &acc_bottom);
    src += 2 * src_stride;
    ref += 2 * ref_stride;
    rows_left -= 2;
  } while (rows_left != 0);

  return horizontal_add_u32x4(vaddq_u32(acc_top, acc_bottom));
}
167 
// SSE of an 8-pixel-wide block. Two rows are handled per iteration, each into
// its own 2-lane accumulator; the two are joined into one 4-lane vector for
// the final reduction. height must be a positive even number.
static inline uint32_t sse_8xh_neon_dotprod(const uint8_t *src, int src_stride,
                                            const uint8_t *ref, int ref_stride,
                                            int height) {
  uint32x2_t acc_top = vdup_n_u32(0);
  uint32x2_t acc_bottom = vdup_n_u32(0);

  int rows_left = height;
  do {
    sse_8x1_neon_dotprod(src, ref, &acc_top);
    sse_8x1_neon_dotprod(src + src_stride, ref + ref_stride, &acc_bottom);
    src += 2 * src_stride;
    ref += 2 * ref_stride;
    rows_left -= 2;
  } while (rows_left != 0);

  return horizontal_add_u32x4(vcombine_u32(acc_top, acc_bottom));
}
186 
// SSE of a 4-pixel-wide block. Each iteration folds a 4x2 chunk (two rows)
// into a single 2-lane accumulator. height must be a positive even number.
static inline uint32_t sse_4xh_neon_dotprod(const uint8_t *src, int src_stride,
                                            const uint8_t *ref, int ref_stride,
                                            int height) {
  uint32x2_t acc = vdup_n_u32(0);

  int rows_left = height;
  do {
    sse_4x2_neon_dotprod(src, src_stride, ref, ref_stride, &acc);
    src += 2 * src_stride;
    ref += 2 * ref_stride;
    rows_left -= 2;
  } while (rows_left != 0);

  return horizontal_add_u32x2(acc);
}
203 
// Sum of squared differences between the width x height blocks at `src` and
// `ref`, using the Armv8.2 dot-product kernels above.
//
// Widths 4, 8, 16, 32, 64 and 128 use dedicated kernels. Any other width
// falls back to sse_wxh_neon_dotprod().
//
// NOTE(review): every kernel returns uint32_t, so the int64_t result never
// exceeds UINT32_MAX. It would wrap for a block whose true SSE is larger.
// The 128x128 worst case, 128 * 128 * 255^2 = 1,065,369,600, still fits.
int64_t aom_sse_neon_dotprod(const uint8_t *src, int src_stride,
                             const uint8_t *ref, int ref_stride, int width,
                             int height) {
  uint32_t sse;

  switch (width) {
    case 4:
      sse = sse_4xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    case 8:
      sse = sse_8xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    case 16:
      sse = sse_16xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    case 32:
      sse = sse_32xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    case 64:
      sse = sse_64xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    case 128:
      sse = sse_128xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    default:
      sse = sse_wxh_neon_dotprod(src, src_stride, ref, ref_stride, width,
                                 height);
      break;
  }

  return sse;
}
225