/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_ports/mem.h"

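// Shift used to narrow the compound intermediate values back down to pixel
// precision once the compound offset has been removed.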
#define ROUND_SHIFT (2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS)

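// Average the compound intermediate values in src_ptr with the prediction
// already accumulated in conv_params->dst, remove the compound offset, and
// narrow the result to 12-bit pixels in dst_ptr.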
static inline void highbd_12_comp_avg_neon(const uint16_t *src_ptr,
                                           int src_stride, uint16_t *dst_ptr,
                                           int dst_stride, int w, int h,
                                           ConvolveParams *conv_params) {
  const int offset_bits = 12 + 2 * FILTER_BITS - ROUND0_BITS - 2;
  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint16x4_t offset_vec = vdup_n_u16((uint16_t)offset);
  const uint16x8_t max = vdupq_n_u16((1 << 12) - 1);

  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      // Halving add computes (src + ref) >> 1 without intermediate overflow.
      uint16x4_t avg = vhadd_u16(src, ref);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubl_u16(avg, offset_vec));

      // Round, narrow with unsigned saturation, then clamp to the 12-bit max.
      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT - 2);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint16x8_t avg = vhaddq_u16(s, r);
        int32x4_t d0_lo =
            vreinterpretq_s32_u32(vsubl_u16(vget_low_u16(avg), offset_vec));
        int32x4_t d0_hi =
            vreinterpretq_s32_u32(vsubl_u16(vget_high_u16(avg), offset_vec));

        uint16x8_t d0 = vcombine_u16(vqrshrun_n_s32(d0_lo, ROUND_SHIFT - 2),
                                     vqrshrun_n_s32(d0_hi, ROUND_SHIFT - 2));
        d0 = vminq_u16(d0, max);
        vst1q_u16(dst, d0);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  }
}

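// As highbd_12_comp_avg_neon, but parameterized on the bitdepth bd instead of
// being hard-coded for 12-bit input.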
static inline void highbd_comp_avg_neon(const uint16_t *src_ptr, int src_stride,
                                        uint16_t *dst_ptr, int dst_stride,
                                        int w, int h,
                                        ConvolveParams *conv_params,
                                        const int bd) {
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint16x4_t offset_vec = vdup_n_u16((uint16_t)offset);
  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);

  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      uint16x4_t avg = vhadd_u16(src, ref);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubl_u16(avg, offset_vec));

      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint16x8_t avg = vhaddq_u16(s, r);
        int32x4_t d0_lo =
            vreinterpretq_s32_u32(vsubl_u16(vget_low_u16(avg), offset_vec));
        int32x4_t d0_hi =
            vreinterpretq_s32_u32(vsubl_u16(vget_high_u16(avg), offset_vec));

        uint16x8_t d0 = vcombine_u16(vqrshrun_n_s32(d0_lo, ROUND_SHIFT),
                                     vqrshrun_n_s32(d0_hi, ROUND_SHIFT));
        d0 = vminq_u16(d0, max);
        vst1q_u16(dst, d0);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  }
}

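// Distance-weighted counterpart of highbd_12_comp_avg_neon: blend the two
// predictions with the fwd_offset/bck_offset weights from conv_params rather
// than averaging them equally.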
static inline void highbd_12_dist_wtd_comp_avg_neon(
    const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
    int w, int h, ConvolveParams *conv_params) {
  const int offset_bits = 12 + 2 * FILTER_BITS - ROUND0_BITS - 2;
  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint32x4_t offset_vec = vdupq_n_u32(offset);
  const uint16x8_t max = vdupq_n_u16((1 << 12) - 1);
  uint16x4_t fwd_offset = vdup_n_u16(conv_params->fwd_offset);
  uint16x4_t bck_offset = vdup_n_u16(conv_params->bck_offset);

  // Weighted averaging: (ref * fwd_offset + src * bck_offset) >>
  // DIST_PRECISION_BITS.
  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      uint32x4_t wtd_avg = vmull_u16(ref, fwd_offset);
      wtd_avg = vmlal_u16(wtd_avg, src, bck_offset);
      wtd_avg = vshrq_n_u32(wtd_avg, DIST_PRECISION_BITS);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg, offset_vec));

      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT - 2);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint32x4_t wtd_avg0 = vmull_u16(vget_low_u16(r), fwd_offset);
        wtd_avg0 = vmlal_u16(wtd_avg0, vget_low_u16(s), bck_offset);
        wtd_avg0 = vshrq_n_u32(wtd_avg0, DIST_PRECISION_BITS);
        int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg0, offset_vec));

        uint32x4_t wtd_avg1 = vmull_u16(vget_high_u16(r), fwd_offset);
        wtd_avg1 = vmlal_u16(wtd_avg1, vget_high_u16(s), bck_offset);
        wtd_avg1 = vshrq_n_u32(wtd_avg1, DIST_PRECISION_BITS);
        int32x4_t d1 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg1, offset_vec));

        uint16x8_t d01 = vcombine_u16(vqrshrun_n_s32(d0, ROUND_SHIFT - 2),
                                      vqrshrun_n_s32(d1, ROUND_SHIFT - 2));
        d01 = vminq_u16(d01, max);
        vst1q_u16(dst, d01);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  }
}

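// Distance-weighted counterpart of highbd_comp_avg_neon, parameterized on the
// bitdepth bd.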
static inline void highbd_dist_wtd_comp_avg_neon(
    const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
    int w, int h, ConvolveParams *conv_params, const int bd) {
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

  CONV_BUF_TYPE *ref_ptr = conv_params->dst;
  const int ref_stride = conv_params->dst_stride;
  const uint32x4_t offset_vec = vdupq_n_u32(offset);
  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
  uint16x4_t fwd_offset = vdup_n_u16(conv_params->fwd_offset);
  uint16x4_t bck_offset = vdup_n_u16(conv_params->bck_offset);

  // Weighted averaging
  if (w == 4) {
    do {
      const uint16x4_t src = vld1_u16(src_ptr);
      const uint16x4_t ref = vld1_u16(ref_ptr);

      uint32x4_t wtd_avg = vmull_u16(ref, fwd_offset);
      wtd_avg = vmlal_u16(wtd_avg, src, bck_offset);
      wtd_avg = vshrq_n_u32(wtd_avg, DIST_PRECISION_BITS);
      int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg, offset_vec));

      uint16x4_t d0_u16 = vqrshrun_n_s32(d0, ROUND_SHIFT);
      d0_u16 = vmin_u16(d0_u16, vget_low_u16(max));

      vst1_u16(dst_ptr, d0_u16);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  } else {
    do {
      int width = w;
      const uint16_t *src = src_ptr;
      const uint16_t *ref = ref_ptr;
      uint16_t *dst = dst_ptr;
      do {
        const uint16x8_t s = vld1q_u16(src);
        const uint16x8_t r = vld1q_u16(ref);

        uint32x4_t wtd_avg0 = vmull_u16(vget_low_u16(r), fwd_offset);
        wtd_avg0 = vmlal_u16(wtd_avg0, vget_low_u16(s), bck_offset);
        wtd_avg0 = vshrq_n_u32(wtd_avg0, DIST_PRECISION_BITS);
        int32x4_t d0 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg0, offset_vec));

        uint32x4_t wtd_avg1 = vmull_u16(vget_high_u16(r), fwd_offset);
        wtd_avg1 = vmlal_u16(wtd_avg1, vget_high_u16(s), bck_offset);
        wtd_avg1 = vshrq_n_u32(wtd_avg1, DIST_PRECISION_BITS);
        int32x4_t d1 = vreinterpretq_s32_u32(vsubq_u32(wtd_avg1, offset_vec));

        uint16x8_t d01 = vcombine_u16(vqrshrun_n_s32(d0, ROUND_SHIFT),
                                      vqrshrun_n_s32(d1, ROUND_SHIFT));
        d01 = vminq_u16(d01, max);
        vst1q_u16(dst, d01);

        src += 8;
        ref += 8;
        dst += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
      ref_ptr += ref_stride;
    } while (--h != 0);
  }
}
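
// Illustrative sketch of how a caller might dispatch to the helpers above at
// the end of a highbd compound convolve kernel. The intermediate buffer names
// (im_block, im_stride) and the surrounding dst/bd variables are hypothetical
// caller-side names; the ConvolveParams fields are assumed to carry their
// usual libaom meanings.
//
//   if (conv_params->use_dist_wtd_comp_avg) {
//     if (bd == 12) {
//       highbd_12_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride,
//                                        w, h, conv_params);
//     } else {
//       highbd_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w,
//                                     h, conv_params, bd);
//     }
//   } else {
//     if (bd == 12) {
//       highbd_12_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
//                               conv_params);
//     } else {
//       highbd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
//                            conv_params, bd);
//     }
//   }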