xref: /aosp_15_r20/external/libaom/av1/common/arm/resize_neon_i8mm.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 #include <assert.h>
14 
15 #include "aom_dsp/arm/mem_neon.h"
16 #include "aom_dsp/arm/transpose_neon.h"
17 #include "av1/common/arm/resize_neon.h"
18 #include "av1/common/resize.h"
19 #include "config/aom_scale_rtcd.h"
20 #include "config/av1_rtcd.h"
21 
22 // clang-format off
23 DECLARE_ALIGNED(16, static const uint8_t, kScalePermuteTbl[16]) = {
24   0, 1, 2, 3, 4, 5, 6, 7, 4, 5, 6, 7, 8, 9, 10, 11
25 };
26 // clang-format on
27 
scale_2_to_1_filter8_8(const uint8x16_t s0,const uint8x16_t s1,const uint8x16_t permute_tbl,const int8x16_t filters)28 static inline uint8x8_t scale_2_to_1_filter8_8(const uint8x16_t s0,
29                                                const uint8x16_t s1,
30                                                const uint8x16_t permute_tbl,
31                                                const int8x16_t filters) {
32   // Permute samples ready for matrix multiply.
33   // { 0, 1, 2, 3, 4, 5, 6, 7, 4, 5, 6, 7, 8, 9, 10, 11 }
34   uint8x16_t perm_samples[2] = { vqtbl1q_u8(s0, permute_tbl),
35                                  vqtbl1q_u8(s1, permute_tbl) };
36 
37   // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
38   // (filter), destructively accumulating into the destination register.
39   int32x4_t sum0123 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[0], filters);
40   int32x4_t sum4567 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[1], filters);
41 
42   int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567));
43 
44   // We halved the filter values so -1 from right shift.
45   return vqrshrun_n_s16(sum, FILTER_BITS - 1);
46 }
47 
scale_2_to_1_horiz_6tap(const uint8_t * src,const int src_stride,int w,int h,uint8_t * dst,const int dst_stride,const int16x8_t filter)48 static inline void scale_2_to_1_horiz_6tap(const uint8_t *src,
49                                            const int src_stride, int w, int h,
50                                            uint8_t *dst, const int dst_stride,
51                                            const int16x8_t filter) {
52   const int8x8_t filter_s8 = vmovn_s16(filter);
53   // Stagger the filter for use with the matrix multiply instructions.
54   // { f1, f2, f3, f4, f5, f6, 0, 0, 0, 0, f1, f2, f3, f4, f5, f6 }
55   const int8x16_t filters = vcombine_s8(vext_s8(filter_s8, filter_s8, 1),
56                                         vext_s8(filter_s8, filter_s8, 7));
57   const uint8x16_t permute_tbl = vld1q_u8(kScalePermuteTbl);
58 
59   do {
60     const uint8_t *s = src;
61     uint8_t *d = dst;
62     int width = w;
63 
64     do {
65       uint8x16_t s0[2], s1[2], s2[2], s3[2], s4[2], s5[2], s6[2], s7[2];
66       load_u8_16x8(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0], &s4[0],
67                    &s5[0], &s6[0], &s7[0]);
68       load_u8_16x8(s + 8, src_stride, &s0[1], &s1[1], &s2[1], &s3[1], &s4[1],
69                    &s5[1], &s6[1], &s7[1]);
70 
71       uint8x8_t d0 = scale_2_to_1_filter8_8(s0[0], s0[1], permute_tbl, filters);
72       uint8x8_t d1 = scale_2_to_1_filter8_8(s1[0], s1[1], permute_tbl, filters);
73       uint8x8_t d2 = scale_2_to_1_filter8_8(s2[0], s2[1], permute_tbl, filters);
74       uint8x8_t d3 = scale_2_to_1_filter8_8(s3[0], s3[1], permute_tbl, filters);
75 
76       uint8x8_t d4 = scale_2_to_1_filter8_8(s4[0], s4[1], permute_tbl, filters);
77       uint8x8_t d5 = scale_2_to_1_filter8_8(s5[0], s5[1], permute_tbl, filters);
78       uint8x8_t d6 = scale_2_to_1_filter8_8(s6[0], s6[1], permute_tbl, filters);
79       uint8x8_t d7 = scale_2_to_1_filter8_8(s7[0], s7[1], permute_tbl, filters);
80 
81       store_u8_8x8(d, dst_stride, d0, d1, d2, d3, d4, d5, d6, d7);
82 
83       d += 8;
84       s += 16;
85       width -= 8;
86     } while (width > 0);
87 
88     dst += 8 * dst_stride;
89     src += 8 * src_stride;
90     h -= 8;
91   } while (h > 0);
92 }
93 
scale_plane_2_to_1_6tap(const uint8_t * src,const int src_stride,uint8_t * dst,const int dst_stride,const int w,const int h,const int16_t * const filter_ptr,uint8_t * const im_block)94 static inline void scale_plane_2_to_1_6tap(const uint8_t *src,
95                                            const int src_stride, uint8_t *dst,
96                                            const int dst_stride, const int w,
97                                            const int h,
98                                            const int16_t *const filter_ptr,
99                                            uint8_t *const im_block) {
100   assert(w > 0 && h > 0);
101 
102   const int im_h = 2 * h + SUBPEL_TAPS - 3;
103   const int im_stride = (w + 7) & ~7;
104   // All filter values are even, halve them to fit in int8_t when applying
105   // horizontal filter and stay in 16-bit elements when applying vertical
106   // filter.
107   const int16x8_t filters = vshrq_n_s16(vld1q_s16(filter_ptr), 1);
108 
109   const ptrdiff_t horiz_offset = SUBPEL_TAPS / 2 - 2;
110   const ptrdiff_t vert_offset = (SUBPEL_TAPS / 2 - 2) * src_stride;
111 
112   scale_2_to_1_horiz_6tap(src - horiz_offset - vert_offset, src_stride, w, im_h,
113                           im_block, im_stride, filters);
114 
115   scale_2_to_1_vert_6tap(im_block, im_stride, w, h, dst, dst_stride, filters);
116 }
117 
scale_4_to_1_filter8_8(const uint8x16_t s0,const uint8x16_t s1,const uint8x16_t s2,const uint8x16_t s3,const uint8x16_t permute_tbl,const int8x8_t filter)118 static inline uint8x8_t scale_4_to_1_filter8_8(
119     const uint8x16_t s0, const uint8x16_t s1, const uint8x16_t s2,
120     const uint8x16_t s3, const uint8x16_t permute_tbl, const int8x8_t filter) {
121   int8x16_t filters = vcombine_s8(filter, filter);
122 
123   uint8x16_t perm_samples[4] = { vqtbl1q_u8(s0, permute_tbl),
124                                  vqtbl1q_u8(s1, permute_tbl),
125                                  vqtbl1q_u8(s2, permute_tbl),
126                                  vqtbl1q_u8(s3, permute_tbl) };
127 
128   int32x4_t sum0 = vusdotq_s32(vdupq_n_s32(0), perm_samples[0], filters);
129   int32x4_t sum1 = vusdotq_s32(vdupq_n_s32(0), perm_samples[1], filters);
130   int32x4_t sum2 = vusdotq_s32(vdupq_n_s32(0), perm_samples[2], filters);
131   int32x4_t sum3 = vusdotq_s32(vdupq_n_s32(0), perm_samples[3], filters);
132 
133   int32x4_t sum01 = vpaddq_s32(sum0, sum1);
134   int32x4_t sum23 = vpaddq_s32(sum2, sum3);
135 
136   int16x8_t sum = vcombine_s16(vmovn_s32(sum01), vmovn_s32(sum23));
137 
138   // We halved the filter values so -1 from right shift.
139   return vqrshrun_n_s16(sum, FILTER_BITS - 1);
140 }
141 
scale_4_to_1_horiz_8tap(const uint8_t * src,const int src_stride,int w,int h,uint8_t * dst,const int dst_stride,const int16x8_t filters)142 static inline void scale_4_to_1_horiz_8tap(const uint8_t *src,
143                                            const int src_stride, int w, int h,
144                                            uint8_t *dst, const int dst_stride,
145                                            const int16x8_t filters) {
146   const int8x8_t filter = vmovn_s16(filters);
147   const uint8x16_t permute_tbl = vld1q_u8(kScalePermuteTbl);
148 
149   do {
150     const uint8_t *s = src;
151     uint8_t *d = dst;
152     int width = w;
153 
154     do {
155       uint8x16_t s0, s1, s2, s3, s4, s5, s6, s7;
156       load_u8_16x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);
157 
158       uint8x8_t d0 =
159           scale_4_to_1_filter8_8(s0, s1, s2, s3, permute_tbl, filter);
160       uint8x8_t d1 =
161           scale_4_to_1_filter8_8(s4, s5, s6, s7, permute_tbl, filter);
162 
163       store_u8x2_strided_x4(d + 0 * dst_stride, dst_stride, d0);
164       store_u8x2_strided_x4(d + 4 * dst_stride, dst_stride, d1);
165 
166       d += 2;
167       s += 8;
168       width -= 2;
169     } while (width > 0);
170 
171     dst += 8 * dst_stride;
172     src += 8 * src_stride;
173     h -= 8;
174   } while (h > 0);
175 }
176 
scale_plane_4_to_1_8tap(const uint8_t * src,const int src_stride,uint8_t * dst,const int dst_stride,const int w,const int h,const int16_t * const filter_ptr,uint8_t * const im_block)177 static inline void scale_plane_4_to_1_8tap(const uint8_t *src,
178                                            const int src_stride, uint8_t *dst,
179                                            const int dst_stride, const int w,
180                                            const int h,
181                                            const int16_t *const filter_ptr,
182                                            uint8_t *const im_block) {
183   assert(w > 0 && h > 0);
184   const int im_h = 4 * h + SUBPEL_TAPS - 3;
185   const int im_stride = (w + 1) & ~1;
186   // All filter values are even, halve them to fit in int8_t when applying
187   // horizontal filter and stay in 16-bit elements when applying vertical
188   // filter.
189   const int16x8_t filters = vshrq_n_s16(vld1q_s16(filter_ptr), 1);
190 
191   const ptrdiff_t horiz_offset = SUBPEL_TAPS / 2 - 1;
192   const ptrdiff_t vert_offset = (SUBPEL_TAPS / 2 - 2) * src_stride;
193 
194   scale_4_to_1_horiz_8tap(src - horiz_offset - vert_offset, src_stride, w, im_h,
195                           im_block, im_stride, filters);
196 
197   // We can specialise the vertical filtering for 6-tap filters given that the
198   // EIGHTTAP_SMOOTH and EIGHTTAP_REGULAR filters are 0-padded.
199   scale_4_to_1_vert_6tap(im_block, im_stride, w, h, dst, dst_stride, filters);
200 }
201 
// Returns true when the source dimensions are exactly 2x or exactly 4x the
// destination dimensions, with the same ratio in width and height.
static inline bool has_normative_scaler_neon_i8mm(const int src_width,
                                                  const int src_height,
                                                  const int dst_width,
                                                  const int dst_height) {
  // Exact 2:1 downscale in both dimensions.
  if (2 * dst_width == src_width && 2 * dst_height == src_height) {
    return true;
  }
  // Otherwise only an exact 4:1 downscale qualifies.
  return 4 * dst_width == src_width && 4 * dst_height == src_height;
}
209 
// Resizes src into dst and extends the borders of dst.
//   src, dst:   source and destination frame buffers.
//   filter:     must be BILINEAR, EIGHTTAP_SMOOTH or EIGHTTAP_REGULAR.
//   phase:      index of the kernel selected from the filter's kernel table.
//   num_planes: number of planes to process (clamped to MAX_MB_PLANE).
//
// The i8mm kernels in this file are used only when every processed plane is
// an exact 2:1 or 4:1 downscale, the filter is not BILINEAR and phase is
// non-zero; every other case is forwarded to av1_resize_and_extend_frame_neon.
// If a scratch buffer cannot be allocated, the whole frame is redone with
// av1_resize_and_extend_frame_c.
void av1_resize_and_extend_frame_neon_i8mm(const YV12_BUFFER_CONFIG *src,
                                           YV12_BUFFER_CONFIG *dst,
                                           const InterpFilter filter,
                                           const int phase,
                                           const int num_planes) {
  assert(filter == BILINEAR || filter == EIGHTTAP_SMOOTH ||
         filter == EIGHTTAP_REGULAR);

  bool has_normative_scaler =
      has_normative_scaler_neon_i8mm(src->y_crop_width, src->y_crop_height,
                                     dst->y_crop_width, dst->y_crop_height);

  if (num_planes > 1) {
    has_normative_scaler =
        has_normative_scaler &&
        has_normative_scaler_neon_i8mm(src->uv_crop_width, src->uv_crop_height,
                                       dst->uv_crop_width, dst->uv_crop_height);
  }

  if (!has_normative_scaler || filter == BILINEAR || phase == 0) {
    av1_resize_and_extend_frame_neon(src, dst, filter, phase, num_planes);
    return;
  }

  // Loop-invariant values: the kernel table for this filter and the luma
  // destination size rounded up to even. The scratch buffer for every plane
  // is sized from the luma dimensions.
  const InterpKernel *const interp_kernel =
      (const InterpKernel *)av1_interp_filter_params_list[filter].filter_ptr;
  const int dst_y_w = (dst->crop_widths[0] + 1) & ~1;
  const int dst_y_h = (dst->crop_heights[0] + 1) & ~1;

  // We use AOMMIN(num_planes, MAX_MB_PLANE) instead of num_planes to quiet
  // the static analysis warnings.
  bool malloc_failed = false;
  for (int i = 0; i < AOMMIN(num_planes, MAX_MB_PLANE); ++i) {
    const int is_uv = i > 0;
    const int src_w = src->crop_widths[is_uv];
    const int src_h = src->crop_heights[is_uv];
    const int dst_w = dst->crop_widths[is_uv];
    const int dst_h = dst->crop_heights[is_uv];

    const bool is_2_to_1 = 2 * dst_w == src_w && 2 * dst_h == src_h;
    const bool is_4_to_1 =
        !is_2_to_1 && 4 * dst_w == src_w && 4 * dst_h == src_h;
    if (!is_2_to_1 && !is_4_to_1) {
      continue;
    }

    // The 2:1 horizontal pass writes 8 columns at a time, the 4:1 pass writes
    // 2 columns at a time; both write 8 rows at a time.
    const int scale = is_2_to_1 ? 2 : 4;
    const int col_mask = is_2_to_1 ? 7 : 1;
    const int buffer_stride = (dst_y_w + col_mask) & ~col_mask;
    const int buffer_height = (scale * dst_y_h + SUBPEL_TAPS - 2 + 7) & ~7;

    // Multiply in size_t: the product of two ints can exceed INT_MAX for very
    // large frames, which would be undefined behaviour and yield a bogus size.
    uint8_t *const temp_buffer =
        (uint8_t *)malloc((size_t)buffer_stride * (size_t)buffer_height);
    if (!temp_buffer) {
      malloc_failed = true;
      break;
    }

    if (is_2_to_1) {
      scale_plane_2_to_1_6tap(src->buffers[i], src->strides[is_uv],
                              dst->buffers[i], dst->strides[is_uv], dst_w,
                              dst_h, interp_kernel[phase], temp_buffer);
    } else {
      scale_plane_4_to_1_8tap(src->buffers[i], src->strides[is_uv],
                              dst->buffers[i], dst->strides[is_uv], dst_w,
                              dst_h, interp_kernel[phase], temp_buffer);
    }
    free(temp_buffer);
  }

  if (malloc_failed) {
    // Planes already written by the loop above are overwritten by the C path.
    av1_resize_and_extend_frame_c(src, dst, filter, phase, num_planes);
  } else {
    aom_extend_frame_borders(dst, num_planes);
  }
}
287