xref: /aosp_15_r20/external/libaom/av1/common/arm/resize_neon_dotprod.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 #include <assert.h>
14 
15 #include "aom_dsp/arm/mem_neon.h"
16 #include "aom_dsp/arm/transpose_neon.h"
17 #include "av1/common/arm/resize_neon.h"
18 #include "av1/common/resize.h"
19 #include "config/aom_scale_rtcd.h"
20 #include "config/av1_rtcd.h"
21 
22 // clang-format off
23 DECLARE_ALIGNED(16, static const uint8_t, kScale2DotProdPermuteTbl[32]) = {
24   0, 1, 2, 3, 2, 3, 4, 5, 4, 5,  6,  7,  6,  7,  8,  9,
25   4, 5, 6, 7, 6, 7, 8, 9, 8, 9, 10, 11, 10, 11, 12, 13
26 };
27 DECLARE_ALIGNED(16, static const uint8_t, kScale4DotProdPermuteTbl[16]) = {
28   0, 1, 2, 3, 4, 5, 6, 7, 4, 5, 6, 7, 8, 9, 10, 11
29 };
30 // clang-format on
31 
scale_2_to_1_filter8_8(const uint8x16_t s0,const uint8x16_t s1,const uint8x16x2_t permute_tbl,const int8x8_t filter)32 static inline uint8x8_t scale_2_to_1_filter8_8(const uint8x16_t s0,
33                                                const uint8x16_t s1,
34                                                const uint8x16x2_t permute_tbl,
35                                                const int8x8_t filter) {
36   // Transform sample range to [-128, 127] for 8-bit signed dot product.
37   int8x16_t s0_128 = vreinterpretq_s8_u8(vsubq_u8(s0, vdupq_n_u8(128)));
38   int8x16_t s1_128 = vreinterpretq_s8_u8(vsubq_u8(s1, vdupq_n_u8(128)));
39 
40   // Permute samples ready for dot product.
41   int8x16_t perm_samples[4] = { vqtbl1q_s8(s0_128, permute_tbl.val[0]),
42                                 vqtbl1q_s8(s0_128, permute_tbl.val[1]),
43                                 vqtbl1q_s8(s1_128, permute_tbl.val[0]),
44                                 vqtbl1q_s8(s1_128, permute_tbl.val[1]) };
45 
46   // Dot product constant:
47   // The shim of 128 << FILTER_BITS is needed because we are subtracting 128
48   // from every source value. The additional right shift by one is needed
49   // because we halve the filter values.
50   const int32x4_t acc = vdupq_n_s32((128 << FILTER_BITS) >> 1);
51 
52   // First 4 output values.
53   int32x4_t sum0123 = vdotq_lane_s32(acc, perm_samples[0], filter, 0);
54   sum0123 = vdotq_lane_s32(sum0123, perm_samples[1], filter, 1);
55   // Second 4 output values.
56   int32x4_t sum4567 = vdotq_lane_s32(acc, perm_samples[2], filter, 0);
57   sum4567 = vdotq_lane_s32(sum4567, perm_samples[3], filter, 1);
58 
59   int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567));
60 
61   // We halved the filter values so -1 from right shift.
62   return vqrshrun_n_s16(sum, FILTER_BITS - 1);
63 }
64 
scale_2_to_1_horiz_8tap(const uint8_t * src,const int src_stride,int w,int h,uint8_t * dst,const int dst_stride,const int16x8_t filters)65 static inline void scale_2_to_1_horiz_8tap(const uint8_t *src,
66                                            const int src_stride, int w, int h,
67                                            uint8_t *dst, const int dst_stride,
68                                            const int16x8_t filters) {
69   const int8x8_t filter = vmovn_s16(filters);
70   const uint8x16x2_t permute_tbl = vld1q_u8_x2(kScale2DotProdPermuteTbl);
71 
72   do {
73     const uint8_t *s = src;
74     uint8_t *d = dst;
75     int width = w;
76     do {
77       uint8x16_t s0[2], s1[2], s2[2], s3[2], s4[2], s5[2], s6[2], s7[2];
78       load_u8_16x8(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0], &s4[0],
79                    &s5[0], &s6[0], &s7[0]);
80       load_u8_16x8(s + 8, src_stride, &s0[1], &s1[1], &s2[1], &s3[1], &s4[1],
81                    &s5[1], &s6[1], &s7[1]);
82 
83       uint8x8_t d0 = scale_2_to_1_filter8_8(s0[0], s0[1], permute_tbl, filter);
84       uint8x8_t d1 = scale_2_to_1_filter8_8(s1[0], s1[1], permute_tbl, filter);
85       uint8x8_t d2 = scale_2_to_1_filter8_8(s2[0], s2[1], permute_tbl, filter);
86       uint8x8_t d3 = scale_2_to_1_filter8_8(s3[0], s3[1], permute_tbl, filter);
87 
88       uint8x8_t d4 = scale_2_to_1_filter8_8(s4[0], s4[1], permute_tbl, filter);
89       uint8x8_t d5 = scale_2_to_1_filter8_8(s5[0], s5[1], permute_tbl, filter);
90       uint8x8_t d6 = scale_2_to_1_filter8_8(s6[0], s6[1], permute_tbl, filter);
91       uint8x8_t d7 = scale_2_to_1_filter8_8(s7[0], s7[1], permute_tbl, filter);
92 
93       store_u8_8x8(d, dst_stride, d0, d1, d2, d3, d4, d5, d6, d7);
94 
95       d += 8;
96       s += 16;
97       width -= 8;
98     } while (width > 0);
99 
100     dst += 8 * dst_stride;
101     src += 8 * src_stride;
102     h -= 8;
103   } while (h > 0);
104 }
105 
scale_plane_2_to_1_8tap(const uint8_t * src,const int src_stride,uint8_t * dst,const int dst_stride,const int w,const int h,const int16_t * const filter_ptr,uint8_t * const im_block)106 static inline void scale_plane_2_to_1_8tap(const uint8_t *src,
107                                            const int src_stride, uint8_t *dst,
108                                            const int dst_stride, const int w,
109                                            const int h,
110                                            const int16_t *const filter_ptr,
111                                            uint8_t *const im_block) {
112   assert(w > 0 && h > 0);
113 
114   const int im_h = 2 * h + SUBPEL_TAPS - 3;
115   const int im_stride = (w + 7) & ~7;
116   // All filter values are even, halve them to fit in int8_t when applying
117   // horizontal filter and stay in 16-bit elements when applying vertical
118   // filter.
119   const int16x8_t filters = vshrq_n_s16(vld1q_s16(filter_ptr), 1);
120 
121   const ptrdiff_t horiz_offset = SUBPEL_TAPS / 2 - 1;
122   const ptrdiff_t vert_offset = (SUBPEL_TAPS / 2 - 1) * src_stride;
123 
124   scale_2_to_1_horiz_8tap(src - horiz_offset - vert_offset, src_stride, w, im_h,
125                           im_block, im_stride, filters);
126 
127   // We can specialise the vertical filtering for 6-tap filters given that the
128   // EIGHTTAP_SMOOTH and EIGHTTAP_REGULAR filters are 0-padded.
129   scale_2_to_1_vert_6tap(im_block + im_stride, im_stride, w, h, dst, dst_stride,
130                          filters);
131 }
132 
scale_4_to_1_filter8_8(const uint8x16_t s0,const uint8x16_t s1,const uint8x16_t s2,const uint8x16_t s3,const uint8x16_t permute_tbl,const int8x8_t filter)133 static inline uint8x8_t scale_4_to_1_filter8_8(
134     const uint8x16_t s0, const uint8x16_t s1, const uint8x16_t s2,
135     const uint8x16_t s3, const uint8x16_t permute_tbl, const int8x8_t filter) {
136   int8x16_t filters = vcombine_s8(filter, filter);
137 
138   // Transform sample range to [-128, 127] for 8-bit signed dot product.
139   int8x16_t s0_128 = vreinterpretq_s8_u8(vsubq_u8(s0, vdupq_n_u8(128)));
140   int8x16_t s1_128 = vreinterpretq_s8_u8(vsubq_u8(s1, vdupq_n_u8(128)));
141   int8x16_t s2_128 = vreinterpretq_s8_u8(vsubq_u8(s2, vdupq_n_u8(128)));
142   int8x16_t s3_128 = vreinterpretq_s8_u8(vsubq_u8(s3, vdupq_n_u8(128)));
143 
144   int8x16_t perm_samples[4] = { vqtbl1q_s8(s0_128, permute_tbl),
145                                 vqtbl1q_s8(s1_128, permute_tbl),
146                                 vqtbl1q_s8(s2_128, permute_tbl),
147                                 vqtbl1q_s8(s3_128, permute_tbl) };
148 
149   // Dot product constant:
150   // The shim of 128 << FILTER_BITS is needed because we are subtracting 128
151   // from every source value. The additional right shift by one is needed
152   // because we halved the filter values and will use a pairwise add.
153   const int32x4_t acc = vdupq_n_s32((128 << FILTER_BITS) >> 2);
154 
155   int32x4_t sum0 = vdotq_s32(acc, perm_samples[0], filters);
156   int32x4_t sum1 = vdotq_s32(acc, perm_samples[1], filters);
157   int32x4_t sum2 = vdotq_s32(acc, perm_samples[2], filters);
158   int32x4_t sum3 = vdotq_s32(acc, perm_samples[3], filters);
159 
160   int32x4_t sum01 = vpaddq_s32(sum0, sum1);
161   int32x4_t sum23 = vpaddq_s32(sum2, sum3);
162 
163   int16x8_t sum = vcombine_s16(vmovn_s32(sum01), vmovn_s32(sum23));
164 
165   // We halved the filter values so -1 from right shift.
166   return vqrshrun_n_s16(sum, FILTER_BITS - 1);
167 }
168 
scale_4_to_1_horiz_8tap(const uint8_t * src,const int src_stride,int w,int h,uint8_t * dst,const int dst_stride,const int16x8_t filters)169 static inline void scale_4_to_1_horiz_8tap(const uint8_t *src,
170                                            const int src_stride, int w, int h,
171                                            uint8_t *dst, const int dst_stride,
172                                            const int16x8_t filters) {
173   const int8x8_t filter = vmovn_s16(filters);
174   const uint8x16_t permute_tbl = vld1q_u8(kScale4DotProdPermuteTbl);
175 
176   do {
177     const uint8_t *s = src;
178     uint8_t *d = dst;
179     int width = w;
180 
181     do {
182       uint8x16_t s0, s1, s2, s3, s4, s5, s6, s7;
183       load_u8_16x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);
184 
185       uint8x8_t d0 =
186           scale_4_to_1_filter8_8(s0, s1, s2, s3, permute_tbl, filter);
187       uint8x8_t d1 =
188           scale_4_to_1_filter8_8(s4, s5, s6, s7, permute_tbl, filter);
189 
190       store_u8x2_strided_x4(d + 0 * dst_stride, dst_stride, d0);
191       store_u8x2_strided_x4(d + 4 * dst_stride, dst_stride, d1);
192 
193       d += 2;
194       s += 8;
195       width -= 2;
196     } while (width > 0);
197 
198     dst += 8 * dst_stride;
199     src += 8 * src_stride;
200     h -= 8;
201   } while (h > 0);
202 }
203 
scale_plane_4_to_1_8tap(const uint8_t * src,const int src_stride,uint8_t * dst,const int dst_stride,const int w,const int h,const int16_t * const filter_ptr,uint8_t * const im_block)204 static inline void scale_plane_4_to_1_8tap(const uint8_t *src,
205                                            const int src_stride, uint8_t *dst,
206                                            const int dst_stride, const int w,
207                                            const int h,
208                                            const int16_t *const filter_ptr,
209                                            uint8_t *const im_block) {
210   assert(w > 0 && h > 0);
211   const int im_h = 4 * h + SUBPEL_TAPS - 2;
212   const int im_stride = (w + 1) & ~1;
213   // All filter values are even, halve them to fit in int8_t when applying
214   // horizontal filter and stay in 16-bit elements when applying vertical
215   // filter.
216   const int16x8_t filters = vshrq_n_s16(vld1q_s16(filter_ptr), 1);
217 
218   const ptrdiff_t horiz_offset = SUBPEL_TAPS / 2 - 1;
219   const ptrdiff_t vert_offset = (SUBPEL_TAPS / 2 - 1) * src_stride;
220 
221   scale_4_to_1_horiz_8tap(src - horiz_offset - vert_offset, src_stride, w, im_h,
222                           im_block, im_stride, filters);
223 
224   // We can specialise the vertical filtering for 6-tap filters given that the
225   // EIGHTTAP_SMOOTH and EIGHTTAP_REGULAR filters are 0-padded.
226   scale_4_to_1_vert_6tap(im_block + im_stride, im_stride, w, h, dst, dst_stride,
227                          filters);
228 }
229 
// Returns true when dst is an exact 2:1 or 4:1 downscale of src in BOTH
// dimensions. These are the only ratios with a dedicated dot-product path in
// this file.
static inline bool has_normative_scaler_neon_dotprod(const int src_width,
                                                     const int src_height,
                                                     const int dst_width,
                                                     const int dst_height) {
  for (int factor = 2; factor <= 4; factor *= 2) {
    const bool width_matches = factor * dst_width == src_width;
    if (width_matches && factor * dst_height == src_height) {
      return true;
    }
  }
  return false;
}
237 
// Resizes `src` into `dst` and then extends the borders of `dst`.
//
// The dot-product kernels in this file are used only when all of these hold:
// - every checked plane is an exact 2:1 or 4:1 downscale,
// - the filter is not BILINEAR, and
// - phase is non-zero.
// Every other case is delegated to av1_resize_and_extend_frame_neon(). If a
// temporary buffer cannot be allocated, the whole frame is redone with
// av1_resize_and_extend_frame_c().
void av1_resize_and_extend_frame_neon_dotprod(const YV12_BUFFER_CONFIG *src,
                                              YV12_BUFFER_CONFIG *dst,
                                              const InterpFilter filter,
                                              const int phase,
                                              const int num_planes) {
  assert(filter == BILINEAR || filter == EIGHTTAP_SMOOTH ||
         filter == EIGHTTAP_REGULAR);

  bool has_normative_scaler =
      has_normative_scaler_neon_dotprod(src->y_crop_width, src->y_crop_height,
                                        dst->y_crop_width, dst->y_crop_height);

  if (num_planes > 1) {
    has_normative_scaler =
        has_normative_scaler && has_normative_scaler_neon_dotprod(
                                    src->uv_crop_width, src->uv_crop_height,
                                    dst->uv_crop_width, dst->uv_crop_height);
  }

  if (!has_normative_scaler || filter == BILINEAR || phase == 0) {
    av1_resize_and_extend_frame_neon(src, dst, filter, phase, num_planes);
    return;
  }

  // Loop-invariant values:
  // - the kernel table for `filter`;
  // - the temporary-buffer sizing, taken from the luma plane rounded up to
  //   even dimensions (presumably never smaller than any chroma plane;
  //   NOTE(review): confirm).
  const InterpKernel *const interp_kernel =
      (const InterpKernel *)av1_interp_filter_params_list[filter].filter_ptr;
  const int dst_y_w = (dst->crop_widths[0] + 1) & ~1;
  const int dst_y_h = (dst->crop_heights[0] + 1) & ~1;

  // We use AOMMIN(num_planes, MAX_MB_PLANE) instead of num_planes to quiet
  // the static analysis warnings.
  int malloc_failed = 0;
  for (int i = 0; i < AOMMIN(num_planes, MAX_MB_PLANE); ++i) {
    const int is_uv = i > 0;
    const int src_w = src->crop_widths[is_uv];
    const int src_h = src->crop_heights[is_uv];
    const int dst_w = dst->crop_widths[is_uv];
    const int dst_h = dst->crop_heights[is_uv];

    if (2 * dst_w == src_w && 2 * dst_h == src_h) {
      // Multiply in size_t. For very large frames (e.g. a 65536x65536
      // source) the int product of stride and height exceeds INT_MAX.
      const size_t buffer_stride = (size_t)((dst_y_w + 7) & ~7);
      const size_t buffer_height =
          (size_t)((2 * dst_y_h + SUBPEL_TAPS - 2 + 7) & ~7);
      uint8_t *const temp_buffer =
          (uint8_t *)malloc(buffer_stride * buffer_height);
      if (!temp_buffer) {
        malloc_failed = 1;
        break;
      }
      scale_plane_2_to_1_8tap(src->buffers[i], src->strides[is_uv],
                              dst->buffers[i], dst->strides[is_uv], dst_w,
                              dst_h, interp_kernel[phase], temp_buffer);
      free(temp_buffer);
    } else if (4 * dst_w == src_w && 4 * dst_h == src_h) {
      const size_t buffer_stride = (size_t)((dst_y_w + 1) & ~1);
      const size_t buffer_height =
          (size_t)((4 * dst_y_h + SUBPEL_TAPS - 2 + 7) & ~7);
      uint8_t *const temp_buffer =
          (uint8_t *)malloc(buffer_stride * buffer_height);
      if (!temp_buffer) {
        malloc_failed = 1;
        break;
      }
      scale_plane_4_to_1_8tap(src->buffers[i], src->strides[is_uv],
                              dst->buffers[i], dst->strides[is_uv], dst_w,
                              dst_h, interp_kernel[phase], temp_buffer);
      free(temp_buffer);
    }
  }

  if (malloc_failed) {
    // Some planes may already be written; the C path recomputes them all.
    av1_resize_and_extend_frame_c(src, dst, filter, phase, num_planes);
  } else {
    aom_extend_frame_borders(dst, num_planes);
  }
}
315