/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#ifndef AOM_AOM_DSP_ARM_AOM_CONVOLVE8_NEON_H_
#define AOM_AOM_DSP_ARM_AOM_CONVOLVE8_NEON_H_

#include <arm_neon.h>

#include "aom_dsp/aom_filter.h"
#include "aom_dsp/arm/mem_neon.h"
#include "config/aom_config.h"

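// Apply the eight filter taps to four columns of 16-bit source samples and
// return the unshifted accumulation; rounding and narrowing are left to the
// caller.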
static inline int16x4_t convolve8_4(const int16x4_t s0, const int16x4_t s1,
                                    const int16x4_t s2, const int16x4_t s3,
                                    const int16x4_t s4, const int16x4_t s5,
                                    const int16x4_t s6, const int16x4_t s7,
                                    const int16x8_t filter) {
  const int16x4_t filter_lo = vget_low_s16(filter);
  const int16x4_t filter_hi = vget_high_s16(filter);

  int16x4_t sum = vmul_lane_s16(s0, filter_lo, 0);
  sum = vmla_lane_s16(sum, s1, filter_lo, 1);
  sum = vmla_lane_s16(sum, s2, filter_lo, 2);
  sum = vmla_lane_s16(sum, s3, filter_lo, 3);
  sum = vmla_lane_s16(sum, s4, filter_hi, 0);
  sum = vmla_lane_s16(sum, s5, filter_hi, 1);
  sum = vmla_lane_s16(sum, s6, filter_hi, 2);
  sum = vmla_lane_s16(sum, s7, filter_hi, 3);

  return sum;
}

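// As convolve8_4, but produces eight output pixels and narrows back to 8 bits
// with saturating rounding. The filter values are assumed to have been halved
// by the caller, hence the (FILTER_BITS - 1) shift.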
static inline uint8x8_t convolve8_8(const int16x8_t s0, const int16x8_t s1,
                                    const int16x8_t s2, const int16x8_t s3,
                                    const int16x8_t s4, const int16x8_t s5,
                                    const int16x8_t s6, const int16x8_t s7,
                                    const int16x8_t filter) {
  const int16x4_t filter_lo = vget_low_s16(filter);
  const int16x4_t filter_hi = vget_high_s16(filter);

  int16x8_t sum = vmulq_lane_s16(s0, filter_lo, 0);
  sum = vmlaq_lane_s16(sum, s1, filter_lo, 1);
  sum = vmlaq_lane_s16(sum, s2, filter_lo, 2);
  sum = vmlaq_lane_s16(sum, s3, filter_lo, 3);
  sum = vmlaq_lane_s16(sum, s4, filter_hi, 0);
  sum = vmlaq_lane_s16(sum, s5, filter_hi, 1);
  sum = vmlaq_lane_s16(sum, s6, filter_hi, 2);
  sum = vmlaq_lane_s16(sum, s7, filter_hi, 3);

  // We halved the filter values so -1 from right shift.
  return vqrshrun_n_s16(sum, FILTER_BITS - 1);
}

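// Horizontal bilinear (2-tap) filter using taps 3 and 4 of the 8-tap filter
// array; each output pixel is a weighted average of two horizontally adjacent
// source pixels.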
static inline void convolve8_horiz_2tap_neon(const uint8_t *src,
                                             ptrdiff_t src_stride, uint8_t *dst,
                                             ptrdiff_t dst_stride,
                                             const int16_t *filter_x, int w,
                                             int h) {
  // Bilinear filter values are all positive.
  const uint8x8_t f0 = vdup_n_u8((uint8_t)filter_x[3]);
  const uint8x8_t f1 = vdup_n_u8((uint8_t)filter_x[4]);

  if (w == 4) {
    do {
      uint8x8_t s0 =
          load_unaligned_u8(src + 0 * src_stride + 0, (int)src_stride);
      uint8x8_t s1 =
          load_unaligned_u8(src + 0 * src_stride + 1, (int)src_stride);
      uint8x8_t s2 =
          load_unaligned_u8(src + 2 * src_stride + 0, (int)src_stride);
      uint8x8_t s3 =
          load_unaligned_u8(src + 2 * src_stride + 1, (int)src_stride);

      uint16x8_t sum0 = vmull_u8(s0, f0);
      sum0 = vmlal_u8(sum0, s1, f1);
      uint16x8_t sum1 = vmull_u8(s2, f0);
      sum1 = vmlal_u8(sum1, s3, f1);

      uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
      uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);

      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d0);
      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d1);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h > 0);
  } else if (w == 8) {
    do {
      uint8x8_t s0 = vld1_u8(src + 0 * src_stride + 0);
      uint8x8_t s1 = vld1_u8(src + 0 * src_stride + 1);
      uint8x8_t s2 = vld1_u8(src + 1 * src_stride + 0);
      uint8x8_t s3 = vld1_u8(src + 1 * src_stride + 1);

      uint16x8_t sum0 = vmull_u8(s0, f0);
      sum0 = vmlal_u8(sum0, s1, f1);
      uint16x8_t sum1 = vmull_u8(s2, f0);
      sum1 = vmlal_u8(sum1, s3, f1);

      uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
      uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);

      vst1_u8(dst + 0 * dst_stride, d0);
      vst1_u8(dst + 1 * dst_stride, d1);

      src += 2 * src_stride;
      dst += 2 * dst_stride;
      h -= 2;
    } while (h > 0);
  } else {
    do {
      int width = w;
      const uint8_t *s = src;
      uint8_t *d = dst;

      do {
        uint8x16_t s0 = vld1q_u8(s + 0);
        uint8x16_t s1 = vld1q_u8(s + 1);

        uint16x8_t sum0 = vmull_u8(vget_low_u8(s0), f0);
        sum0 = vmlal_u8(sum0, vget_low_u8(s1), f1);
        uint16x8_t sum1 = vmull_u8(vget_high_u8(s0), f0);
        sum1 = vmlal_u8(sum1, vget_high_u8(s1), f1);

        uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
        uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);

        vst1q_u8(d, vcombine_u8(d0, d1));

        s += 16;
        d += 16;
        width -= 16;
      } while (width != 0);
      src += src_stride;
      dst += dst_stride;
    } while (--h > 0);
  }
}

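// Apply a 4-tap filter to eight 16-bit source samples and narrow back to
// 8 bits with saturating rounding. As with convolve8_8, the filter values are
// assumed to have been halved, hence the (FILTER_BITS - 1) shift.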
static inline uint8x8_t convolve4_8(const int16x8_t s0, const int16x8_t s1,
                                    const int16x8_t s2, const int16x8_t s3,
                                    const int16x4_t filter) {
  int16x8_t sum = vmulq_lane_s16(s0, filter, 0);
  sum = vmlaq_lane_s16(sum, s1, filter, 1);
  sum = vmlaq_lane_s16(sum, s2, filter, 2);
  sum = vmlaq_lane_s16(sum, s3, filter, 3);

  // We halved the filter values so -1 from right shift.
  return vqrshrun_n_s16(sum, FILTER_BITS - 1);
}

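// Vertical 4-tap filter: uses the middle four taps (filter_y + 2) of the
// 8-tap filter array, halving them up front to reduce intermediate precision
// requirements.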
static inline void convolve8_vert_4tap_neon(const uint8_t *src,
                                            ptrdiff_t src_stride, uint8_t *dst,
                                            ptrdiff_t dst_stride,
                                            const int16_t *filter_y, int w,
                                            int h) {
  // All filter values are even, halve to reduce intermediate precision
  // requirements.
  const int16x4_t filter = vshr_n_s16(vld1_s16(filter_y + 2), 1);

  if (w == 4) {
    uint8x8_t t01 = load_unaligned_u8(src + 0 * src_stride, (int)src_stride);
    uint8x8_t t12 = load_unaligned_u8(src + 1 * src_stride, (int)src_stride);

    int16x8_t s01 = vreinterpretq_s16_u16(vmovl_u8(t01));
    int16x8_t s12 = vreinterpretq_s16_u16(vmovl_u8(t12));

    src += 2 * src_stride;

    do {
      uint8x8_t t23 = load_unaligned_u8(src + 0 * src_stride, (int)src_stride);
      uint8x8_t t34 = load_unaligned_u8(src + 1 * src_stride, (int)src_stride);
      uint8x8_t t45 = load_unaligned_u8(src + 2 * src_stride, (int)src_stride);
      uint8x8_t t56 = load_unaligned_u8(src + 3 * src_stride, (int)src_stride);

      int16x8_t s23 = vreinterpretq_s16_u16(vmovl_u8(t23));
      int16x8_t s34 = vreinterpretq_s16_u16(vmovl_u8(t34));
      int16x8_t s45 = vreinterpretq_s16_u16(vmovl_u8(t45));
      int16x8_t s56 = vreinterpretq_s16_u16(vmovl_u8(t56));

      uint8x8_t d01 = convolve4_8(s01, s12, s23, s34, filter);
      uint8x8_t d23 = convolve4_8(s23, s34, s45, s56, filter);

      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);

      s01 = s45;
      s12 = s56;

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  } else {
    do {
      uint8x8_t t0, t1, t2;
      load_u8_8x3(src, src_stride, &t0, &t1, &t2);

      int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
      int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
      int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2));

      int height = h;
      const uint8_t *s = src + 3 * src_stride;
      uint8_t *d = dst;

      do {
        uint8x8_t t3;
        load_u8_8x4(s, src_stride, &t0, &t1, &t2, &t3);

        int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t0));
        int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t1));
        int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t2));
        int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t3));

        uint8x8_t d0 = convolve4_8(s0, s1, s2, s3, filter);
        uint8x8_t d1 = convolve4_8(s1, s2, s3, s4, filter);
        uint8x8_t d2 = convolve4_8(s2, s3, s4, s5, filter);
        uint8x8_t d3 = convolve4_8(s3, s4, s5, s6, filter);

        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);

        s0 = s4;
        s1 = s5;
        s2 = s6;

        s += 4 * src_stride;
        d += 4 * dst_stride;
        height -= 4;
      } while (height != 0);
      src += 8;
      dst += 8;
      w -= 8;
    } while (w != 0);
  }
}

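// Vertical bilinear (2-tap) filter using taps 3 and 4 of the 8-tap filter
// array; each output pixel is a weighted average of two vertically adjacent
// source pixels.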
static inline void convolve8_vert_2tap_neon(const uint8_t *src,
                                            ptrdiff_t src_stride, uint8_t *dst,
                                            ptrdiff_t dst_stride,
                                            const int16_t *filter_y, int w,
                                            int h) {
  // Bilinear filter values are all positive.
  uint8x8_t f0 = vdup_n_u8((uint8_t)filter_y[3]);
  uint8x8_t f1 = vdup_n_u8((uint8_t)filter_y[4]);

  if (w == 4) {
    do {
      uint8x8_t s0 = load_unaligned_u8(src + 0 * src_stride, (int)src_stride);
      uint8x8_t s1 = load_unaligned_u8(src + 1 * src_stride, (int)src_stride);
      uint8x8_t s2 = load_unaligned_u8(src + 2 * src_stride, (int)src_stride);
      uint8x8_t s3 = load_unaligned_u8(src + 3 * src_stride, (int)src_stride);

      uint16x8_t sum0 = vmull_u8(s0, f0);
      sum0 = vmlal_u8(sum0, s1, f1);
      uint16x8_t sum1 = vmull_u8(s2, f0);
      sum1 = vmlal_u8(sum1, s3, f1);

      uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
      uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);

      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d0);
      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d1);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h > 0);
  } else if (w == 8) {
    do {
      uint8x8_t s0, s1, s2;
      load_u8_8x3(src, src_stride, &s0, &s1, &s2);

      uint16x8_t sum0 = vmull_u8(s0, f0);
      sum0 = vmlal_u8(sum0, s1, f1);
      uint16x8_t sum1 = vmull_u8(s1, f0);
      sum1 = vmlal_u8(sum1, s2, f1);

      uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
      uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);

      vst1_u8(dst + 0 * dst_stride, d0);
      vst1_u8(dst + 1 * dst_stride, d1);

      src += 2 * src_stride;
      dst += 2 * dst_stride;
      h -= 2;
    } while (h > 0);
  } else {
    do {
      int width = w;
      const uint8_t *s = src;
      uint8_t *d = dst;

      do {
        uint8x16_t s0 = vld1q_u8(s + 0 * src_stride);
        uint8x16_t s1 = vld1q_u8(s + 1 * src_stride);

        uint16x8_t sum0 = vmull_u8(vget_low_u8(s0), f0);
        sum0 = vmlal_u8(sum0, vget_low_u8(s1), f1);
        uint16x8_t sum1 = vmull_u8(vget_high_u8(s0), f0);
        sum1 = vmlal_u8(sum1, vget_high_u8(s1), f1);

        uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
        uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);

        vst1q_u8(d, vcombine_u8(d0, d1));

        s += 16;
        d += 16;
        width -= 16;
      } while (width != 0);
      src += src_stride;
      dst += dst_stride;
    } while (--h > 0);
  }
}

#endif  // AOM_AOM_DSP_ARM_AOM_CONVOLVE8_NEON_H_