/*
 *
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>
#include <stdbool.h>

#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/blend.h"
#include "aom_ports/mem.h"
#include "config/av1_rtcd.h"

static inline void diffwtd_mask_highbd_neon(uint8_t *mask, bool inverse,
                                            const uint16_t *src0,
                                            int src0_stride,
                                            const uint16_t *src1,
                                            int src1_stride, int h, int w,
                                            const unsigned int bd) {
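  // For DIFFWTD_38, each mask value is
  // clamp(38 + |src0 - src1| / DIFF_FACTOR, 0, AOM_BLEND_A64_MAX_ALPHA); the
  // inverse mask is AOM_BLEND_A64_MAX_ALPHA minus that value. The absolute
  // difference is additionally shifted down by (bd - 8) bits so that every
  // bit depth produces the same mask range.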
  assert(DIFF_FACTOR > 0);
  uint8x16_t max_alpha = vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA);
  uint8x16_t mask_base = vdupq_n_u8(38);
  uint8x16_t mask_diff = vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA - 38);

  if (bd == 8) {
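    // 8-bit input: narrow the absolute difference with a right shift by
    // DIFF_FACTOR_LOG2 only; no extra bit-depth scaling is needed.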
    if (w >= 16) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0_lo = vld1q_u16(src0_ptr);
          uint16x8_t s0_hi = vld1q_u16(src0_ptr + 8);
          uint16x8_t s1_lo = vld1q_u16(src1_ptr);
          uint16x8_t s1_hi = vld1q_u16(src1_ptr + 8);

          uint16x8_t diff_lo_u16 = vabdq_u16(s0_lo, s1_lo);
          uint16x8_t diff_hi_u16 = vabdq_u16(s0_hi, s1_hi);
          uint8x8_t diff_lo_u8 = vshrn_n_u16(diff_lo_u16, DIFF_FACTOR_LOG2);
          uint8x8_t diff_hi_u8 = vshrn_n_u16(diff_hi_u16, DIFF_FACTOR_LOG2);
          uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);

          uint8x16_t m;
          if (inverse) {
            m = vqsubq_u8(mask_diff, diff);
          } else {
            m = vminq_u8(vaddq_u8(diff, mask_base), max_alpha);
          }

          vst1q_u8(mask_ptr, m);

          src0_ptr += 16;
          src1_ptr += 16;
          mask_ptr += 16;
          width -= 16;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0 = vld1q_u16(src0_ptr);
          uint16x8_t s1 = vld1q_u16(src1_ptr);

          uint16x8_t diff_u16 = vabdq_u16(s0, s1);
          uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, DIFF_FACTOR_LOG2);
          uint8x8_t m;
          if (inverse) {
            m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
          } else {
            m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                        vget_low_u8(max_alpha));
          }

          vst1_u8(mask_ptr, m);

          src0_ptr += 8;
          src1_ptr += 8;
          mask_ptr += 8;
          width -= 8;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 4) {
      do {
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t diff_u16 = vabdq_u16(s0, s1);
        uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, DIFF_FACTOR_LOG2);
        uint8x8_t m;
        if (inverse) {
          m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
        } else {
          m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                      vget_low_u8(max_alpha));
        }

        store_u8x4_strided_x2(mask, w, m);

        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        mask += 2 * w;
        h -= 2;
      } while (h != 0);
    }
  } else if (bd == 10) {
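    // 10-bit input: shift by an extra 2 bits so the difference is scaled to
    // the same range as the 8-bit path.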
    if (w >= 16) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0_lo = vld1q_u16(src0_ptr);
          uint16x8_t s0_hi = vld1q_u16(src0_ptr + 8);
          uint16x8_t s1_lo = vld1q_u16(src1_ptr);
          uint16x8_t s1_hi = vld1q_u16(src1_ptr + 8);

          uint16x8_t diff_lo_u16 = vabdq_u16(s0_lo, s1_lo);
          uint16x8_t diff_hi_u16 = vabdq_u16(s0_hi, s1_hi);
          uint8x8_t diff_lo_u8 = vshrn_n_u16(diff_lo_u16, 2 + DIFF_FACTOR_LOG2);
          uint8x8_t diff_hi_u8 = vshrn_n_u16(diff_hi_u16, 2 + DIFF_FACTOR_LOG2);
          uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);

          uint8x16_t m;
          if (inverse) {
            m = vqsubq_u8(mask_diff, diff);
          } else {
            m = vminq_u8(vaddq_u8(diff, mask_base), max_alpha);
          }

          vst1q_u8(mask_ptr, m);

          src0_ptr += 16;
          src1_ptr += 16;
          mask_ptr += 16;
          width -= 16;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0 = vld1q_u16(src0_ptr);
          uint16x8_t s1 = vld1q_u16(src1_ptr);

          uint16x8_t diff_u16 = vabdq_u16(s0, s1);
          uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 2 + DIFF_FACTOR_LOG2);
          uint8x8_t m;
          if (inverse) {
            m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
          } else {
            m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                        vget_low_u8(max_alpha));
          }

          vst1_u8(mask_ptr, m);

          src0_ptr += 8;
          src1_ptr += 8;
          mask_ptr += 8;
          width -= 8;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 4) {
      do {
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t diff_u16 = vabdq_u16(s0, s1);
        uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 2 + DIFF_FACTOR_LOG2);
        uint8x8_t m;
        if (inverse) {
          m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
        } else {
          m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                      vget_low_u8(max_alpha));
        }

        store_u8x4_strided_x2(mask, w, m);

        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        mask += 2 * w;
        h -= 2;
      } while (h != 0);
    }
  } else {
    assert(bd == 12);
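    // 12-bit input: shift by an extra 4 bits so the difference is scaled to
    // the same range as the 8-bit path.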
    if (w >= 16) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0_lo = vld1q_u16(src0_ptr);
          uint16x8_t s0_hi = vld1q_u16(src0_ptr + 8);
          uint16x8_t s1_lo = vld1q_u16(src1_ptr);
          uint16x8_t s1_hi = vld1q_u16(src1_ptr + 8);

          uint16x8_t diff_lo_u16 = vabdq_u16(s0_lo, s1_lo);
          uint16x8_t diff_hi_u16 = vabdq_u16(s0_hi, s1_hi);
          uint8x8_t diff_lo_u8 = vshrn_n_u16(diff_lo_u16, 4 + DIFF_FACTOR_LOG2);
          uint8x8_t diff_hi_u8 = vshrn_n_u16(diff_hi_u16, 4 + DIFF_FACTOR_LOG2);
          uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);

          uint8x16_t m;
          if (inverse) {
            m = vqsubq_u8(mask_diff, diff);
          } else {
            m = vminq_u8(vaddq_u8(diff, mask_base), max_alpha);
          }

          vst1q_u8(mask_ptr, m);

          src0_ptr += 16;
          src1_ptr += 16;
          mask_ptr += 16;
          width -= 16;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0 = vld1q_u16(src0_ptr);
          uint16x8_t s1 = vld1q_u16(src1_ptr);

          uint16x8_t diff_u16 = vabdq_u16(s0, s1);
          uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 4 + DIFF_FACTOR_LOG2);
          uint8x8_t m;
          if (inverse) {
            m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
          } else {
            m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                        vget_low_u8(max_alpha));
          }

          vst1_u8(mask_ptr, m);

          src0_ptr += 8;
          src1_ptr += 8;
          mask_ptr += 8;
          width -= 8;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 4) {
      do {
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t diff_u16 = vabdq_u16(s0, s1);
        uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 4 + DIFF_FACTOR_LOG2);
        uint8x8_t m;
        if (inverse) {
          m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
        } else {
          m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                      vget_low_u8(max_alpha));
        }

        store_u8x4_strided_x2(mask, w, m);

        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        mask += 2 * w;
        h -= 2;
      } while (h != 0);
    }
  }
}

void av1_build_compound_diffwtd_mask_highbd_neon(
    uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const uint8_t *src0,
    int src0_stride, const uint8_t *src1, int src1_stride, int h, int w,
    int bd) {
  assert(h % 4 == 0);
  assert(w % 4 == 0);
  assert(mask_type == DIFFWTD_38_INV || mask_type == DIFFWTD_38);

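  // High bit-depth frame buffers are addressed through uint8_t pointers;
  // CONVERT_TO_SHORTPTR recovers the underlying uint16_t pixel pointer before
  // calling the mask kernel.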
  if (mask_type == DIFFWTD_38) {
    diffwtd_mask_highbd_neon(mask, /*inverse=*/false, CONVERT_TO_SHORTPTR(src0),
                             src0_stride, CONVERT_TO_SHORTPTR(src1),
                             src1_stride, h, w, bd);
  } else {  // mask_type == DIFFWTD_38_INV
    diffwtd_mask_highbd_neon(mask, /*inverse=*/true, CONVERT_TO_SHORTPTR(src0),
                             src0_stride, CONVERT_TO_SHORTPTR(src1),
                             src1_stride, h, w, bd);
  }
}