/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "aom_dsp/arm/mem_neon.h"
#include "av1/common/arm/compound_convolve_neon.h"
#include "config/aom_config.h"
#include "config/av1_rtcd.h"

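// Permute indices for the dot-product (USDOT) path: each 16-byte row gathers
// four overlapping 4-byte windows of the source samples.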
DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
  0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
  4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10,
  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};

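// Permute indices for the matrix-multiply (USMMLA) path: each 16-byte row
// gathers two overlapping 8-byte windows, forming a 2x8 sample matrix.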
DECLARE_ALIGNED(16, static const uint8_t, kMatMulPermuteTbl[32]) = {
  // clang-format off
  0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9,
  4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11, 12, 13
  // clang-format on
};

static inline int16x4_t convolve6_4_2d_h(uint8x16_t samples,
                                         const int8x16_t x_filter,
                                         const uint8x16_t permute_tbl,
                                         const int32x4_t horiz_const) {
  // Permute samples ready for matrix multiply.
  // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
  uint8x16_t permuted_samples = vqtbl1q_u8(samples, permute_tbl);

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
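  // The 2x2 result is stored row-major, so the four lanes hold four
  // consecutive output pixels computed from samples 0-5, 1-6, 2-7 and 3-8.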
  int32x4_t sum = vusmmlaq_s32(horiz_const, permuted_samples, x_filter);

  // We halved the convolution filter values so -1 from the right shift.
  return vshrn_n_s32(sum, ROUND0_BITS - 1);
}

static inline int16x8_t convolve6_8_2d_h(uint8x16_t samples,
                                         const int8x16_t x_filter,
                                         const uint8x16x2_t permute_tbl,
                                         const int32x4_t horiz_const) {
  // Permute samples ready for matrix multiply.
  // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
  // { 4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11, 12, 13 }
  uint8x16_t permuted_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
                                     vqtbl1q_u8(samples, permute_tbl.val[1]) };

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum0123 = vusmmlaq_s32(horiz_const, permuted_samples[0], x_filter);
  int32x4_t sum4567 = vusmmlaq_s32(horiz_const, permuted_samples[1], x_filter);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  return vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS - 1),
                      vshrn_n_s32(sum4567, ROUND0_BITS - 1));
}

static inline void dist_wtd_convolve_2d_horiz_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
    const int16_t *x_filter_ptr, const int im_h, int w) {
  const int bd = 8;
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
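  // The 1 << (bd + FILTER_BITS - 2) term is the horizontal-pass offset of
  // 1 << (bd + FILTER_BITS - 1), also halved to match the halved filter.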
  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
                                            (1 << ((ROUND0_BITS - 1) - 1)));

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5, 0, 0, 0, f0, f1, f2, f3, f4, f5, 0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 = convolve6_4_2d_h(s0, x_filter, permute_tbl, horiz_const);
      int16x4_t d1 = convolve6_4_2d_h(s1, x_filter, permute_tbl, horiz_const);
      int16x4_t d2 = convolve6_4_2d_h(s2, x_filter, permute_tbl, horiz_const);
      int16x4_t d3 = convolve6_4_2d_h(s3, x_filter, permute_tbl, horiz_const);

      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    do {
      uint8x16_t s0 = vld1q_u8(src_ptr);

      int16x4_t d0 = convolve6_4_2d_h(s0, x_filter, permute_tbl, horiz_const);

      vst1_s16(dst_ptr, d0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        int16x8_t d0 = convolve6_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
        int16x8_t d1 = convolve6_8_2d_h(s1, x_filter, permute_tbl, horiz_const);
        int16x8_t d2 = convolve6_8_2d_h(s2, x_filter, permute_tbl, horiz_const);
        int16x8_t d3 = convolve6_8_2d_h(s3, x_filter, permute_tbl, horiz_const);

        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0 = vld1q_u8(s);

        int16x8_t d0 = convolve6_8_2d_h(s0, x_filter, permute_tbl, horiz_const);

        vst1q_s16(d, d0);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  }
}

static inline int16x8_t convolve8_8_2d_h(uint8x16_t samples,
                                         const int8x8_t x_filter,
                                         const uint8x16x3_t permute_tbl,
                                         const int32x4_t horiz_const) {
  uint8x16_t permuted_samples[3];
  int32x4_t sum[2];

  // Permute samples ready for dot product.
  // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
  // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
  // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);

  // First 4 output values.
  sum[0] = vusdotq_lane_s32(horiz_const, permuted_samples[0], x_filter, 0);
  sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vusdotq_lane_s32(horiz_const, permuted_samples[1], x_filter, 0);
  sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                      vshrn_n_s32(sum[1], ROUND0_BITS - 1));
}

static inline void dist_wtd_convolve_2d_horiz_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
    const int16_t *x_filter_ptr, const int im_h, int w) {
  const int bd = 8;
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
                                            (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  do {
    const uint8_t *s = src_ptr;
    int16_t *d = dst_ptr;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
      int16x8_t d1 = convolve8_8_2d_h(s1, x_filter, permute_tbl, horiz_const);
      int16x8_t d2 = convolve8_8_2d_h(s2, x_filter, permute_tbl, horiz_const);
      int16x8_t d3 = convolve8_8_2d_h(s3, x_filter, permute_tbl, horiz_const);

      store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

      s += 8;
      d += 8;
      width -= 8;
    } while (width > 0);
    src_ptr += 4 * src_stride;
    dst_ptr += 4 * dst_stride;
    height -= 4;
  } while (height > 4);

  do {
    const uint8_t *s = src_ptr;
    int16_t *d = dst_ptr;
    int width = w;

    do {
      uint8x16_t s0 = vld1q_u8(s);

      int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);

      vst1q_s16(d, d0);

      s += 8;
      d += 8;
      width -= 8;
    } while (width > 0);
    src_ptr += src_stride;
    dst_ptr += dst_stride;
  } while (--height != 0);
}

void av1_dist_wtd_convolve_2d_neon_i8mm(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_qn,
    const int subpel_y_qn, ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  DECLARE_ALIGNED(16, int16_t,
                  im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);

  const int x_filter_taps = get_filter_tap(filter_params_x, subpel_x_qn);
  const int clamped_x_taps = x_filter_taps < 6 ? 6 : x_filter_taps;
  const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
  const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;

  const int im_h = h + clamped_y_taps - 1;
  const int im_stride = MAX_SB_SIZE;
  const int vert_offset = clamped_y_taps / 2 - 1;
  const int horiz_offset = clamped_x_taps / 2 - 1;
  const uint8_t *src_ptr = src - vert_offset * src_stride - horiz_offset;
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_y, subpel_y_qn & SUBPEL_MASK);

  const int16x8_t y_filter = vld1q_s16(y_filter_ptr);

  if (clamped_x_taps == 6) {
    dist_wtd_convolve_2d_horiz_6tap_neon_i8mm(src_ptr, src_stride, im_block,
                                              im_stride, x_filter_ptr, im_h, w);
  } else {
    dist_wtd_convolve_2d_horiz_8tap_neon_i8mm(src_ptr, src_stride, im_block,
                                              im_stride, x_filter_ptr, im_h, w);
  }

  if (clamped_y_taps == 6) {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_6tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_6tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_6tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  } else {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_8tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_8tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_8tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  }
}

static inline uint16x4_t convolve6_4_x(uint8x16_t samples,
                                       const int8x16_t x_filter,
                                       const uint8x16_t permute_tbl,
                                       const int32x4_t round_offset) {
  // Permute samples ready for matrix multiply.
  // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
  uint8x16_t permuted_samples = vqtbl1q_u8(samples, permute_tbl);

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum = vusmmlaq_s32(round_offset, permuted_samples, x_filter);

  // We halved the convolution filter values so -1 from the right shift.
  return vreinterpret_u16_s16(vshrn_n_s32(sum, ROUND0_BITS - 1));
}

static inline uint16x8_t convolve6_8_x(uint8x16_t samples,
                                       const int8x16_t x_filter,
                                       const uint8x16x2_t permute_tbl,
                                       const int32x4_t round_offset) {
  // Permute samples ready for matrix multiply.
  // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
  // { 4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11, 12, 13 }
  uint8x16_t permuted_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
                                     vqtbl1q_u8(samples, permute_tbl.val[1]) };

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum0123 = vusmmlaq_s32(round_offset, permuted_samples[0], x_filter);
  int32x4_t sum4567 = vusmmlaq_s32(round_offset, permuted_samples[1], x_filter);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  int16x8_t res = vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS - 1),
                               vshrn_n_s32(sum4567, ROUND0_BITS - 1));
  return vreinterpretq_u16_s16(res);
}

static inline uint16x8_t convolve8_8_x(uint8x16_t samples,
                                       const int8x8_t x_filter,
                                       const uint8x16x3_t permute_tbl,
                                       const int32x4_t round_offset) {
  uint8x16_t permuted_samples[3];
  int32x4_t sum[2];

  // Permute samples ready for dot product.
  // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
  // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
  // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);

  // First 4 output values.
  sum[0] = vusdotq_lane_s32(round_offset, permuted_samples[0], x_filter, 0);
  sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vusdotq_lane_s32(round_offset, permuted_samples[1], x_filter, 0);
  sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  int16x8_t res = vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                               vshrn_n_s32(sum[1], ROUND0_BITS - 1));
  return vreinterpretq_u16_s16(res);
}

static inline void dist_wtd_convolve_x_dist_wtd_avg_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr,
    const uint16_t fwd_offset, const uint16_t bck_offset) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
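  // round_offset is pre-shifted left by (ROUND0_BITS - 1) so that it survives
  // the narrowing right shift and ends up added to the convolution result.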
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5, 0, 0, 0, f0, f1, f2, f3, f4, f5, 0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve6_4_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d1 =
          convolve6_4_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d2 =
          convolve6_4_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d3 =
          convolve6_4_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_dist_wtd_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                               bck_offset, round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8 + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8 + 2 * dst8_stride, dst8_stride, d23_u8);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src;
      uint16_t *d = dst;
      uint8_t *d_u8 = dst8;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve6_8_x(s0, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d1 =
            convolve6_8_x(s1, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d2 =
            convolve6_8_x(s2, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d3 =
            convolve6_8_x(s3, x_filter, permute_tbl, round_offset_shim);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_dist_wtd_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                                 bck_offset, round_offset_vec, &d0_u8, &d1_u8,
                                 &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  }
}

static inline void dist_wtd_convolve_x_dist_wtd_avg_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr,
    const uint16_t fwd_offset, const uint16_t bck_offset) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  do {
    const uint8_t *s = src;
    uint16_t *d = dst;
    uint8_t *d_u8 = dst8;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x8_t d0 =
          convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d1 =
          convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d2 =
          convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d3 =
          convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x8_t dd0, dd1, dd2, dd3;
      load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
      compute_dist_wtd_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                               bck_offset, round_offset_vec, &d0_u8, &d1_u8,
                               &d2_u8, &d3_u8);

      store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

      s += 8;
      d += 8;
      d_u8 += 8;
      width -= 8;
    } while (width != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    dst8 += 4 * dst8_stride;
    h -= 4;
  } while (h != 0);
}

static inline void dist_wtd_convolve_x_avg_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5, 0, 0, 0, f0, f1, f2, f3, f4, f5, 0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve6_4_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d1 =
          convolve6_4_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d2 =
          convolve6_4_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d3 =
          convolve6_4_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_basic_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                            round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8 + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8 + 2 * dst8_stride, dst8_stride, d23_u8);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src;
      uint16_t *d = dst;
      uint8_t *d_u8 = dst8;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve6_8_x(s0, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d1 =
            convolve6_8_x(s1, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d2 =
            convolve6_8_x(s2, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d3 =
            convolve6_8_x(s3, x_filter, permute_tbl, round_offset_shim);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_basic_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                              round_offset_vec, &d0_u8, &d1_u8, &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      dst8 += 4 * dst8_stride;
      h -= 4;
    } while (h != 0);
  }
}

static inline void dist_wtd_convolve_x_avg_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride,
    uint8_t *dst8, int dst8_stride, int w, int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  do {
    const uint8_t *s = src;
    uint16_t *d = dst;
    uint8_t *d_u8 = dst8;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x8_t d0 =
          convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d1 =
          convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d2 =
          convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d3 =
          convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);

      uint16x8_t dd0, dd1, dd2, dd3;
      load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
      compute_basic_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                            round_offset_vec, &d0_u8, &d1_u8, &d2_u8, &d3_u8);

      store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

      s += 8;
      d += 8;
      d_u8 += 8;
      width -= 8;
    } while (width != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    dst8 += 4 * dst8_stride;
    h -= 4;
  } while (h != 0);
}

static inline void dist_wtd_convolve_x_6tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
  // Stagger the filter for use with the matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5, 0, 0, 0, f0, f1, f2, f3, f4, f5, 0 }
  const int8x16_t x_filter =
      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve6_4_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d1 =
          convolve6_4_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d2 =
          convolve6_4_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x4_t d3 =
          convolve6_4_x(s3, x_filter, permute_tbl, round_offset_shim);

      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
    do {
      const uint8_t *s = src;
      uint16_t *d = dst;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve6_8_x(s0, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d1 =
            convolve6_8_x(s1, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d2 =
            convolve6_8_x(s2, x_filter, permute_tbl, round_offset_shim);
        uint16x8_t d3 =
            convolve6_8_x(s3, x_filter, permute_tbl, round_offset_shim);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  }
}

static inline void dist_wtd_convolve_x_8tap_neon_i8mm(
    const uint8_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const int16_t *x_filter_ptr) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // (The extra -1 is needed because we halved the filter values.)
  const int32x4_t round_offset_shim = vdupq_n_s32(
      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));

  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
  // Filter values are even, so halve to reduce intermediate precision reqs.
  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);

  do {
    const uint8_t *s = src;
    uint16_t *d = dst;
    int width = w;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x8_t d0 =
          convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d1 =
          convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d2 =
          convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
      uint16x8_t d3 =
          convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);

      store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

      s += 8;
      d += 8;
      width -= 8;
    } while (width != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    h -= 4;
  } while (h != 0);
}

void av1_dist_wtd_convolve_x_neon_i8mm(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int filter_taps =
      get_filter_tap(filter_params_x, subpel_x_qn & SUBPEL_MASK);

  src -= (SUBPEL_TAPS / 2 - 1);
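  // src is now positioned for an 8-tap filter. The staggered 6-tap kernels
  // use taps 1-6 of the 8-tap layout, so the 6-tap paths start at src + 1.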

  if (conv_params->do_average) {
    if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
      if (filter_taps < 8) {
        dist_wtd_convolve_x_dist_wtd_avg_6tap_neon_i8mm(
            src + 1, src_stride, conv_params->dst, conv_params->dst_stride,
            dst8, dst8_stride, w, h, x_filter_ptr, conv_params->fwd_offset,
            conv_params->bck_offset);
        return;
      }

      dist_wtd_convolve_x_dist_wtd_avg_8tap_neon_i8mm(
          src, src_stride, conv_params->dst, conv_params->dst_stride, dst8,
          dst8_stride, w, h, x_filter_ptr, conv_params->fwd_offset,
          conv_params->bck_offset);
    } else {
      if (filter_taps < 8) {
        dist_wtd_convolve_x_avg_6tap_neon_i8mm(
            src + 1, src_stride, conv_params->dst, conv_params->dst_stride,
            dst8, dst8_stride, w, h, x_filter_ptr);
        return;
      }

      dist_wtd_convolve_x_avg_8tap_neon_i8mm(src, src_stride, conv_params->dst,
                                             conv_params->dst_stride, dst8,
                                             dst8_stride, w, h, x_filter_ptr);
    }
  } else {
    if (filter_taps < 8) {
      dist_wtd_convolve_x_6tap_neon_i8mm(src + 1, src_stride, conv_params->dst,
                                         conv_params->dst_stride, w, h,
                                         x_filter_ptr);
      return;
    }

    dist_wtd_convolve_x_8tap_neon_i8mm(src, src_stride, conv_params->dst,
                                       conv_params->dst_stride, w, h,
                                       x_filter_ptr);
  }
}