/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"

#define MAX_UPSAMPLE_SZ 16

// TODO(aomedia:349436249): enable for armv7 after SIGBUS is fixed.
#if AOM_ARCH_AARCH64

// These kernels are a transposed version of those defined in reconintra.c,
// with the absolute value of the negatives taken in the top row.
DECLARE_ALIGNED(16, const uint8_t,
                av1_filter_intra_taps_neon[FILTER_INTRA_MODES][7][8]) = {
  // clang-format off
  {
    {  6,  5,  3,  3,  4,  3,  3,  3 },
    { 10,  2,  1,  1,  6,  2,  2,  1 },
    {  0, 10,  1,  1,  0,  6,  2,  2 },
    {  0,  0, 10,  2,  0,  0,  6,  2 },
    {  0,  0,  0, 10,  0,  0,  0,  6 },
    { 12,  9,  7,  5,  2,  2,  2,  3 },
    {  0,  0,  0,  0, 12,  9,  7,  5 }
  },
  {
    { 10,  6,  4,  2, 10,  6,  4,  2 },
    { 16,  0,  0,  0, 16,  0,  0,  0 },
    {  0, 16,  0,  0,  0, 16,  0,  0 },
    {  0,  0, 16,  0,  0,  0, 16,  0 },
    {  0,  0,  0, 16,  0,  0,  0, 16 },
    { 10,  6,  4,  2,  0,  0,  0,  0 },
    {  0,  0,  0,  0, 10,  6,  4,  2 }
  },
  {
    {  8,  8,  8,  8,  4,  4,  4,  4 },
    {  8,  0,  0,  0,  4,  0,  0,  0 },
    {  0,  8,  0,  0,  0,  4,  0,  0 },
    {  0,  0,  8,  0,  0,  0,  4,  0 },
    {  0,  0,  0,  8,  0,  0,  0,  4 },
    { 16, 16, 16, 16,  0,  0,  0,  0 },
    {  0,  0,  0,  0, 16, 16, 16, 16 }
  },
  {
    {  2,  1,  1,  0,  1,  1,  1,  1 },
    {  8,  3,  2,  1,  4,  3,  2,  2 },
    {  0,  8,  3,  2,  0,  4,  3,  2 },
    {  0,  0,  8,  3,  0,  0,  4,  3 },
    {  0,  0,  0,  8,  0,  0,  0,  4 },
    { 10,  6,  4,  2,  3,  4,  4,  3 },
    {  0,  0,  0,  0, 10,  6,  4,  3 }
  },
  {
    { 12, 10,  9,  8, 10,  9,  8,  7 },
    { 14,  0,  0,  0, 12,  1,  0,  0 },
    {  0, 14,  0,  0,  0, 12,  0,  0 },
    {  0,  0, 14,  0,  0,  0, 12,  1 },
    {  0,  0,  0, 14,  0,  0,  0, 12 },
    { 14, 12, 11, 10,  0,  0,  1,  1 },
    {  0,  0,  0,  0, 14, 12, 11,  9 }
  }
  // clang-format on
};

#define FILTER_INTRA_SCALE_BITS 4
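// Each predicted pixel is a 7-tap weighted sum of its context pixels; the
// accumulated sum is rounded, narrowed by FILTER_INTRA_SCALE_BITS and
// saturated to 8 bits, i.e. result = clamp_u8((sum + 8) >> 4).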

void av1_filter_intra_predictor_neon(uint8_t *dst, ptrdiff_t stride,
                                     TX_SIZE tx_size, const uint8_t *above,
                                     const uint8_t *left, int mode) {
  const int width = tx_size_wide[tx_size];
  const int height = tx_size_high[tx_size];
  assert(width <= 32 && height <= 32);

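  // f0..f6 hold the filter taps applied to the seven context pixels: the
  // top-left pixel (f0), the four pixels above (f1..f4) and the two pixels to
  // the left (f5, f6). Each of the eight lanes corresponds to one pixel of the
  // 4x2 output block predicted per iteration.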
  const uint8x8_t f0 = vld1_u8(av1_filter_intra_taps_neon[mode][0]);
  const uint8x8_t f1 = vld1_u8(av1_filter_intra_taps_neon[mode][1]);
  const uint8x8_t f2 = vld1_u8(av1_filter_intra_taps_neon[mode][2]);
  const uint8x8_t f3 = vld1_u8(av1_filter_intra_taps_neon[mode][3]);
  const uint8x8_t f4 = vld1_u8(av1_filter_intra_taps_neon[mode][4]);
  const uint8x8_t f5 = vld1_u8(av1_filter_intra_taps_neon[mode][5]);
  const uint8x8_t f6 = vld1_u8(av1_filter_intra_taps_neon[mode][6]);

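  // Scratch buffer: one row/column of border context plus up to 32x32
  // predicted pixels. Predicted values are written back so that they can act
  // as context for the rows below.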
  uint8_t buffer[33][33];
  // Populate the top row in the scratch buffer with data from above.
  memcpy(buffer[0], &above[-1], (width + 1) * sizeof(uint8_t));
  // Populate the first column in the scratch buffer with data from the left.
  int r = 0;
  do {
    buffer[r + 1][0] = left[r];
  } while (++r < height);

  // Computing 4 cols per iteration (instead of 8) for 8x<h> blocks is faster.
  if (width <= 8) {
    r = 1;
    do {
      int c = 1;
      uint8x8_t s0 = vld1_dup_u8(&buffer[r - 1][c - 1]);
      uint8x8_t s5 = vld1_dup_u8(&buffer[r + 0][c - 1]);
      uint8x8_t s6 = vld1_dup_u8(&buffer[r + 1][c - 1]);

      do {
        uint8x8_t s1234 = load_u8_4x1(&buffer[r - 1][c - 1] + 1);
        uint8x8_t s1 = vdup_lane_u8(s1234, 0);
        uint8x8_t s2 = vdup_lane_u8(s1234, 1);
        uint8x8_t s3 = vdup_lane_u8(s1234, 2);
        uint8x8_t s4 = vdup_lane_u8(s1234, 3);

        uint16x8_t sum = vmull_u8(s1, f1);
        // First row of each filter has all negative values so subtract.
        sum = vmlsl_u8(sum, s0, f0);
        sum = vmlal_u8(sum, s2, f2);
        sum = vmlal_u8(sum, s3, f3);
        sum = vmlal_u8(sum, s4, f4);
        sum = vmlal_u8(sum, s5, f5);
        sum = vmlal_u8(sum, s6, f6);

        uint8x8_t res =
            vqrshrun_n_s16(vreinterpretq_s16_u16(sum), FILTER_INTRA_SCALE_BITS);

        // Store buffer[r + 0][c] and buffer[r + 1][c].
        store_u8x4_strided_x2(&buffer[r][c], 33, res);

        store_u8x4_strided_x2(dst + (r - 1) * stride + c - 1, stride, res);

        s0 = s4;
        s5 = vdup_lane_u8(res, 3);
        s6 = vdup_lane_u8(res, 7);
        c += 4;
      } while (c < width + 1);

      r += 2;
    } while (r < height + 1);
  } else {
    r = 1;
    do {
      int c = 1;
      uint8x8_t s0_lo = vld1_dup_u8(&buffer[r - 1][c - 1]);
      uint8x8_t s5_lo = vld1_dup_u8(&buffer[r + 0][c - 1]);
      uint8x8_t s6_lo = vld1_dup_u8(&buffer[r + 1][c - 1]);

      do {
        uint8x8_t s1234 = vld1_u8(&buffer[r - 1][c - 1] + 1);
        uint8x8_t s1_lo = vdup_lane_u8(s1234, 0);
        uint8x8_t s2_lo = vdup_lane_u8(s1234, 1);
        uint8x8_t s3_lo = vdup_lane_u8(s1234, 2);
        uint8x8_t s4_lo = vdup_lane_u8(s1234, 3);

        uint16x8_t sum_lo = vmull_u8(s1_lo, f1);
        // First row of each filter has all negative values so subtract.
        sum_lo = vmlsl_u8(sum_lo, s0_lo, f0);
        sum_lo = vmlal_u8(sum_lo, s2_lo, f2);
        sum_lo = vmlal_u8(sum_lo, s3_lo, f3);
        sum_lo = vmlal_u8(sum_lo, s4_lo, f4);
        sum_lo = vmlal_u8(sum_lo, s5_lo, f5);
        sum_lo = vmlal_u8(sum_lo, s6_lo, f6);

        uint8x8_t res_lo = vqrshrun_n_s16(vreinterpretq_s16_u16(sum_lo),
                                          FILTER_INTRA_SCALE_BITS);

        uint8x8_t s0_hi = s4_lo;
        uint8x8_t s1_hi = vdup_lane_u8(s1234, 4);
        uint8x8_t s2_hi = vdup_lane_u8(s1234, 5);
        uint8x8_t s3_hi = vdup_lane_u8(s1234, 6);
        uint8x8_t s4_hi = vdup_lane_u8(s1234, 7);
        uint8x8_t s5_hi = vdup_lane_u8(res_lo, 3);
        uint8x8_t s6_hi = vdup_lane_u8(res_lo, 7);

        uint16x8_t sum_hi = vmull_u8(s1_hi, f1);
        // First row of each filter has all negative values so subtract.
        sum_hi = vmlsl_u8(sum_hi, s0_hi, f0);
        sum_hi = vmlal_u8(sum_hi, s2_hi, f2);
        sum_hi = vmlal_u8(sum_hi, s3_hi, f3);
        sum_hi = vmlal_u8(sum_hi, s4_hi, f4);
        sum_hi = vmlal_u8(sum_hi, s5_hi, f5);
        sum_hi = vmlal_u8(sum_hi, s6_hi, f6);

        uint8x8_t res_hi = vqrshrun_n_s16(vreinterpretq_s16_u16(sum_hi),
                                          FILTER_INTRA_SCALE_BITS);

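        // Interleave the low halves (row r) and high halves (row r + 1) of the
        // two 4x2 results into two full 8-pixel rows.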
        uint32x2x2_t res =
            vzip_u32(vreinterpret_u32_u8(res_lo), vreinterpret_u32_u8(res_hi));

        vst1_u8(&buffer[r + 0][c], vreinterpret_u8_u32(res.val[0]));
        vst1_u8(&buffer[r + 1][c], vreinterpret_u8_u32(res.val[1]));

        vst1_u8(dst + (r - 1) * stride + c - 1,
                vreinterpret_u8_u32(res.val[0]));
        vst1_u8(dst + (r + 0) * stride + c - 1,
                vreinterpret_u8_u32(res.val[1]));

        s0_lo = s4_hi;
        s5_lo = vdup_lane_u8(res_hi, 3);
        s6_lo = vdup_lane_u8(res_hi, 7);
        c += 8;
      } while (c < width + 1);

      r += 2;
    } while (r < height + 1);
  }
}
#endif  // AOM_ARCH_AARCH64

void av1_filter_intra_edge_neon(uint8_t *p, int sz, int strength) {
  if (!strength) return;
  assert(sz >= 0 && sz <= 129);

  uint8_t edge[160];  // Max value of sz + enough padding for vector accesses.
  memcpy(edge + 1, p, sz * sizeof(*p));

  // Pad both ends by duplicating the boundary pixels so that filter taps which
  // index past the copied edge pick up the clamped boundary values.
  edge[0] = edge[1];
  edge[sz + 1] = edge[sz];
  edge[sz + 2] = edge[sz];

  // Don't overwrite first pixel.
  uint8_t *dst = p + 1;
  sz--;

  if (strength == 1) {  // Filter: {4, 8, 4}.
    const uint8_t *src = edge + 1;

    while (sz >= 8) {
      uint8x8_t s0 = vld1_u8(src);
      uint8x8_t s1 = vld1_u8(src + 1);
      uint8x8_t s2 = vld1_u8(src + 2);

      // Make use of the identity:
      // (4*a + 8*b + 4*c) >> 4 == (a + (b << 1) + c) >> 2
      uint16x8_t t0 = vaddl_u8(s0, s2);
      uint16x8_t t1 = vaddl_u8(s1, s1);
      uint16x8_t sum = vaddq_u16(t0, t1);
      uint8x8_t res = vrshrn_n_u16(sum, 2);

      vst1_u8(dst, res);

      src += 8;
      dst += 8;
      sz -= 8;
    }

    if (sz > 0) {  // Handle sz < 8 to avoid modifying out-of-bounds values.
      uint8x8_t s0 = vld1_u8(src);
      uint8x8_t s1 = vld1_u8(src + 1);
      uint8x8_t s2 = vld1_u8(src + 2);

      uint16x8_t t0 = vaddl_u8(s0, s2);
      uint16x8_t t1 = vaddl_u8(s1, s1);
      uint16x8_t sum = vaddq_u16(t0, t1);
      uint8x8_t res = vrshrn_n_u16(sum, 2);

      // Mask off out-of-bounds indices.
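      // vcreate_u8(0x0706050403020100) packs the lane indices {0, 1, ..., 7};
      // lanes with index < sz take the filtered value, the rest keep dst.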
      uint8x8_t current_dst = vld1_u8(dst);
      uint8x8_t mask = vcgt_u8(vdup_n_u8(sz), vcreate_u8(0x0706050403020100));
      res = vbsl_u8(mask, res, current_dst);

      vst1_u8(dst, res);
    }
  } else if (strength == 2) {  // Filter: {5, 6, 5}.
    const uint8_t *src = edge + 1;

    const uint8x8x3_t filter = { { vdup_n_u8(5), vdup_n_u8(6), vdup_n_u8(5) } };

    while (sz >= 8) {
      uint8x8_t s0 = vld1_u8(src);
      uint8x8_t s1 = vld1_u8(src + 1);
      uint8x8_t s2 = vld1_u8(src + 2);

      uint16x8_t accum = vmull_u8(s0, filter.val[0]);
      accum = vmlal_u8(accum, s1, filter.val[1]);
      accum = vmlal_u8(accum, s2, filter.val[2]);
      uint8x8_t res = vrshrn_n_u16(accum, 4);

      vst1_u8(dst, res);

      src += 8;
      dst += 8;
      sz -= 8;
    }

    if (sz > 0) {  // Handle sz < 8 to avoid modifying out-of-bounds values.
      uint8x8_t s0 = vld1_u8(src);
      uint8x8_t s1 = vld1_u8(src + 1);
      uint8x8_t s2 = vld1_u8(src + 2);

      uint16x8_t accum = vmull_u8(s0, filter.val[0]);
      accum = vmlal_u8(accum, s1, filter.val[1]);
      accum = vmlal_u8(accum, s2, filter.val[2]);
      uint8x8_t res = vrshrn_n_u16(accum, 4);

      // Mask off out-of-bounds indices.
      uint8x8_t current_dst = vld1_u8(dst);
      uint8x8_t mask = vcgt_u8(vdup_n_u8(sz), vcreate_u8(0x0706050403020100));
      res = vbsl_u8(mask, res, current_dst);

      vst1_u8(dst, res);
    }
  } else {  // Filter {2, 4, 4, 4, 2}.
    const uint8_t *src = edge;

    while (sz >= 8) {
      uint8x8_t s0 = vld1_u8(src);
      uint8x8_t s1 = vld1_u8(src + 1);
      uint8x8_t s2 = vld1_u8(src + 2);
      uint8x8_t s3 = vld1_u8(src + 3);
      uint8x8_t s4 = vld1_u8(src + 4);

      // Make use of the identity:
      // (2*a + 4*b + 4*c + 4*d + 2*e) >> 4 == (a + ((b + c + d) << 1) + e) >> 3
      uint16x8_t t0 = vaddl_u8(s0, s4);
      uint16x8_t t1 = vaddl_u8(s1, s2);
      t1 = vaddw_u8(t1, s3);
      t1 = vaddq_u16(t1, t1);
      uint16x8_t sum = vaddq_u16(t0, t1);
      uint8x8_t res = vrshrn_n_u16(sum, 3);

      vst1_u8(dst, res);

      src += 8;
      dst += 8;
      sz -= 8;
    }

    if (sz > 0) {  // Handle sz < 8 to avoid modifying out-of-bounds values.
      uint8x8_t s0 = vld1_u8(src);
      uint8x8_t s1 = vld1_u8(src + 1);
      uint8x8_t s2 = vld1_u8(src + 2);
      uint8x8_t s3 = vld1_u8(src + 3);
      uint8x8_t s4 = vld1_u8(src + 4);

      uint16x8_t t0 = vaddl_u8(s0, s4);
      uint16x8_t t1 = vaddl_u8(s1, s2);
      t1 = vaddw_u8(t1, s3);
      t1 = vaddq_u16(t1, t1);
      uint16x8_t sum = vaddq_u16(t0, t1);
      uint8x8_t res = vrshrn_n_u16(sum, 3);

      // Mask off out-of-bounds indices.
      uint8x8_t current_dst = vld1_u8(dst);
      uint8x8_t mask = vcgt_u8(vdup_n_u8(sz), vcreate_u8(0x0706050403020100));
      res = vbsl_u8(mask, res, current_dst);

      vst1_u8(dst, res);
    }
  }
}

void av1_upsample_intra_edge_neon(uint8_t *p, int sz) {
  if (!sz) return;

  assert(sz <= MAX_UPSAMPLE_SZ);

  uint8_t edge[MAX_UPSAMPLE_SZ + 3];
  const uint8_t *src = edge;

  // Copy p[-1..(sz-1)] and pad out both ends.
  edge[0] = p[-1];
  edge[1] = p[-1];
  memcpy(edge + 2, p, sz);
  edge[sz + 2] = p[sz - 1];
  p[-2] = p[-1];

  uint8_t *dst = p - 1;

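  // 2x upsample: each original sample s2 is preceded by the interpolated value
  // clamp_u8((-s0 + 9 * s1 + 9 * s2 - s3 + 8) >> 4), written out as
  // interleaved pairs.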
  do {
    uint8x8_t s0 = vld1_u8(src);
    uint8x8_t s1 = vld1_u8(src + 1);
    uint8x8_t s2 = vld1_u8(src + 2);
    uint8x8_t s3 = vld1_u8(src + 3);

    int16x8_t t0 = vreinterpretq_s16_u16(vaddl_u8(s0, s3));
    int16x8_t t1 = vreinterpretq_s16_u16(vaddl_u8(s1, s2));
    t1 = vmulq_n_s16(t1, 9);
    t1 = vsubq_s16(t1, t0);

    uint8x8x2_t res = { { vqrshrun_n_s16(t1, 4), s2 } };

    vst2_u8(dst, res);

    src += 8;
    dst += 16;
    sz -= 8;
  } while (sz > 0);
}