/*
 *
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>
#include <stdbool.h>

#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/blend.h"
#include "aom_ports/mem.h"
#include "config/av1_rtcd.h"

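// Computes a DIFFWTD blend mask from two high-bitdepth prediction blocks.
// Each mask value is min(38 + (|src0 - src1| >> shift), AOM_BLEND_A64_MAX_ALPHA),
// where the shift is DIFF_FACTOR_LOG2 plus (bd - 8); the inverse variant
// stores AOM_BLEND_A64_MAX_ALPHA minus that value.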
static inline void diffwtd_mask_highbd_neon(uint8_t *mask, bool inverse,
                                            const uint16_t *src0,
                                            int src0_stride,
                                            const uint16_t *src1,
                                            int src1_stride, int h, int w,
                                            const unsigned int bd) {
  assert(DIFF_FACTOR > 0);
  uint8x16_t max_alpha = vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA);
  uint8x16_t mask_base = vdupq_n_u8(38);
  uint8x16_t mask_diff = vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA - 38);

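  // The three bit-depth paths below are identical apart from the narrowing
  // shift: the absolute difference is shifted right by DIFF_FACTOR_LOG2 plus
  // (bd - 8) so that all bit depths are normalized to the 8-bit scale.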
  if (bd == 8) {
    if (w >= 16) {
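      // Wide blocks: compute 16 mask values per inner-loop iteration.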
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0_lo = vld1q_u16(src0_ptr);
          uint16x8_t s0_hi = vld1q_u16(src0_ptr + 8);
          uint16x8_t s1_lo = vld1q_u16(src1_ptr);
          uint16x8_t s1_hi = vld1q_u16(src1_ptr + 8);

          uint16x8_t diff_lo_u16 = vabdq_u16(s0_lo, s1_lo);
          uint16x8_t diff_hi_u16 = vabdq_u16(s0_hi, s1_hi);
          uint8x8_t diff_lo_u8 = vshrn_n_u16(diff_lo_u16, DIFF_FACTOR_LOG2);
          uint8x8_t diff_hi_u8 = vshrn_n_u16(diff_hi_u16, DIFF_FACTOR_LOG2);
          uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);

          uint8x16_t m;
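          // The inverse mask uses a saturating subtract so it clamps at zero;
          // the forward mask adds the base offset and clamps the result to
          // AOM_BLEND_A64_MAX_ALPHA.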
          if (inverse) {
            m = vqsubq_u8(mask_diff, diff);
          } else {
            m = vminq_u8(vaddq_u8(diff, mask_base), max_alpha);
          }

          vst1q_u8(mask_ptr, m);

          src0_ptr += 16;
          src1_ptr += 16;
          mask_ptr += 16;
          width -= 16;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0 = vld1q_u16(src0_ptr);
          uint16x8_t s1 = vld1q_u16(src1_ptr);

          uint16x8_t diff_u16 = vabdq_u16(s0, s1);
          uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, DIFF_FACTOR_LOG2);
          uint8x8_t m;
          if (inverse) {
            m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
          } else {
            m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                        vget_low_u8(max_alpha));
          }

          vst1_u8(mask_ptr, m);

          src0_ptr += 8;
          src1_ptr += 8;
          mask_ptr += 8;
          width -= 8;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 4) {
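      // Narrow blocks: load two rows of four pixels into a single vector and
      // store two rows of the mask per iteration.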
      do {
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t diff_u16 = vabdq_u16(s0, s1);
        uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, DIFF_FACTOR_LOG2);
        uint8x8_t m;
        if (inverse) {
          m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
        } else {
          m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                      vget_low_u8(max_alpha));
        }

        store_u8x4_strided_x2(mask, w, m);

        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        mask += 2 * w;
        h -= 2;
      } while (h != 0);
    }
  } else if (bd == 10) {
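    // 10-bit input: shift by an extra 2 bits so the difference lands on the
    // same 8-bit scale as the bd == 8 path.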
    if (w >= 16) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0_lo = vld1q_u16(src0_ptr);
          uint16x8_t s0_hi = vld1q_u16(src0_ptr + 8);
          uint16x8_t s1_lo = vld1q_u16(src1_ptr);
          uint16x8_t s1_hi = vld1q_u16(src1_ptr + 8);

          uint16x8_t diff_lo_u16 = vabdq_u16(s0_lo, s1_lo);
          uint16x8_t diff_hi_u16 = vabdq_u16(s0_hi, s1_hi);
          uint8x8_t diff_lo_u8 = vshrn_n_u16(diff_lo_u16, 2 + DIFF_FACTOR_LOG2);
          uint8x8_t diff_hi_u8 = vshrn_n_u16(diff_hi_u16, 2 + DIFF_FACTOR_LOG2);
          uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);

          uint8x16_t m;
          if (inverse) {
            m = vqsubq_u8(mask_diff, diff);
          } else {
            m = vminq_u8(vaddq_u8(diff, mask_base), max_alpha);
          }

          vst1q_u8(mask_ptr, m);

          src0_ptr += 16;
          src1_ptr += 16;
          mask_ptr += 16;
          width -= 16;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0 = vld1q_u16(src0_ptr);
          uint16x8_t s1 = vld1q_u16(src1_ptr);

          uint16x8_t diff_u16 = vabdq_u16(s0, s1);
          uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 2 + DIFF_FACTOR_LOG2);
          uint8x8_t m;
          if (inverse) {
            m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
          } else {
            m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                        vget_low_u8(max_alpha));
          }

          vst1_u8(mask_ptr, m);

          src0_ptr += 8;
          src1_ptr += 8;
          mask_ptr += 8;
          width -= 8;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 4) {
      do {
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t diff_u16 = vabdq_u16(s0, s1);
        uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 2 + DIFF_FACTOR_LOG2);
        uint8x8_t m;
        if (inverse) {
          m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
        } else {
          m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                      vget_low_u8(max_alpha));
        }

        store_u8x4_strided_x2(mask, w, m);

        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        mask += 2 * w;
        h -= 2;
      } while (h != 0);
    }
  } else {
    assert(bd == 12);
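    // 12-bit input: shift by an extra 4 bits to reach the 8-bit scale.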
    if (w >= 16) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0_lo = vld1q_u16(src0_ptr);
          uint16x8_t s0_hi = vld1q_u16(src0_ptr + 8);
          uint16x8_t s1_lo = vld1q_u16(src1_ptr);
          uint16x8_t s1_hi = vld1q_u16(src1_ptr + 8);

          uint16x8_t diff_lo_u16 = vabdq_u16(s0_lo, s1_lo);
          uint16x8_t diff_hi_u16 = vabdq_u16(s0_hi, s1_hi);
          uint8x8_t diff_lo_u8 = vshrn_n_u16(diff_lo_u16, 4 + DIFF_FACTOR_LOG2);
          uint8x8_t diff_hi_u8 = vshrn_n_u16(diff_hi_u16, 4 + DIFF_FACTOR_LOG2);
          uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);

          uint8x16_t m;
          if (inverse) {
            m = vqsubq_u8(mask_diff, diff);
          } else {
            m = vminq_u8(vaddq_u8(diff, mask_base), max_alpha);
          }

          vst1q_u8(mask_ptr, m);

          src0_ptr += 16;
          src1_ptr += 16;
          mask_ptr += 16;
          width -= 16;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0 = vld1q_u16(src0_ptr);
          uint16x8_t s1 = vld1q_u16(src1_ptr);

          uint16x8_t diff_u16 = vabdq_u16(s0, s1);
          uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 4 + DIFF_FACTOR_LOG2);
          uint8x8_t m;
          if (inverse) {
            m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
          } else {
            m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                        vget_low_u8(max_alpha));
          }

          vst1_u8(mask_ptr, m);

          src0_ptr += 8;
          src1_ptr += 8;
          mask_ptr += 8;
          width -= 8;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 4) {
      do {
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t diff_u16 = vabdq_u16(s0, s1);
        uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 4 + DIFF_FACTOR_LOG2);
        uint8x8_t m;
        if (inverse) {
          m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
        } else {
          m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                      vget_low_u8(max_alpha));
        }

        store_u8x4_strided_x2(mask, w, m);

        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        mask += 2 * w;
        h -= 2;
      } while (h != 0);
    }
  }
}

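// Builds the compound DIFFWTD blend mask for high-bitdepth sources. src0 and
// src1 carry 16-bit samples and are unwrapped with CONVERT_TO_SHORTPTR before
// dispatching to the forward or inverse mask kernel based on mask_type.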
void av1_build_compound_diffwtd_mask_highbd_neon(
    uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const uint8_t *src0,
    int src0_stride, const uint8_t *src1, int src1_stride, int h, int w,
    int bd) {
  assert(h % 4 == 0);
  assert(w % 4 == 0);
  assert(mask_type == DIFFWTD_38_INV || mask_type == DIFFWTD_38);

  if (mask_type == DIFFWTD_38) {
    diffwtd_mask_highbd_neon(mask, /*inverse=*/false, CONVERT_TO_SHORTPTR(src0),
                             src0_stride, CONVERT_TO_SHORTPTR(src1),
                             src1_stride, h, w, bd);
  } else {  // mask_type == DIFFWTD_38_INV
    diffwtd_mask_highbd_neon(mask, /*inverse=*/true, CONVERT_TO_SHORTPTR(src0),
                             src0_stride, CONVERT_TO_SHORTPTR(src1),
                             src1_stride, h, w, bd);
  }
}