xref: /aosp_15_r20/external/libaom/aom_dsp/arm/avg_pred_neon.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 #include <assert.h>
14 
15 #include "config/aom_dsp_rtcd.h"
16 
17 #include "aom_dsp/arm/blend_neon.h"
18 #include "aom_dsp/arm/dist_wtd_avg_neon.h"
19 #include "aom_dsp/arm/mem_neon.h"
20 #include "aom_dsp/blend.h"
21 
// Writes the rounding average of `pred` and `ref` into `comp_pred`
// (vrhaddq_u8 is the NEON rounding halving add).
//
// `pred` and `comp_pred` are walked as tightly packed buffers: each row
// advances them by `width` bytes. `ref` is walked with `ref_stride` bytes per
// row. Three layouts are handled:
//   - width > 8:  16 bytes per inner step, one row per outer step.
//   - width == 8: two rows per step.
//   - otherwise:  width is asserted to be 4, four rows per step.
// All loops are do/while, so each executes at least once.
void aom_comp_avg_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width,
                            int height, const uint8_t *ref, int ref_stride) {
  if (width == 8) {
    // A single 16-byte vector covers two 8-pixel rows of the packed buffers.
    int row_pairs_left = height / 2;

    do {
      const uint8x16_t pred_rows = vld1q_u8(pred);
      // NOTE(review): load_u8_8x2 is expected to gather two 8-byte rows that
      // are ref_stride apart -- confirm in mem_neon.h.
      const uint8x16_t ref_rows = load_u8_8x2(ref, ref_stride);

      vst1q_u8(comp_pred, vrhaddq_u8(pred_rows, ref_rows));

      comp_pred += 16;
      pred += 16;
      ref += 2 * ref_stride;
    } while (--row_pairs_left != 0);
    return;
  }

  if (width > 8) {
    do {
      // Column offset into the current row; advances in 16-byte steps until
      // it reaches `width`.
      int x = 0;

      do {
        const uint8x16_t pred_px = vld1q_u8(pred + x);
        const uint8x16_t ref_px = vld1q_u8(ref + x);

        vst1q_u8(comp_pred + x, vrhaddq_u8(pred_px, ref_px));

        x += 16;
      } while (x != width);

      comp_pred += width;
      pred += width;
      ref += ref_stride;
    } while (--height != 0);
    return;
  }

  assert(width == 4);

  // A single 16-byte vector covers four 4-pixel rows of the packed buffers.
  int row_quads_left = height / 4;

  do {
    const uint8x16_t pred_rows = vld1q_u8(pred);
    // NOTE(review): load_unaligned_u8q is expected to gather four 4-byte rows
    // that are ref_stride apart -- confirm in mem_neon.h.
    const uint8x16_t ref_rows = load_unaligned_u8q(ref, ref_stride);

    vst1q_u8(comp_pred, vrhaddq_u8(pred_rows, ref_rows));

    comp_pred += 16;
    pred += 16;
    ref += 4 * ref_stride;
  } while (--row_quads_left != 0);
}
79 
// Combines `ref` and `pred` into `comp_pred` using the weights carried in
// `jcp_param` (fwd_offset / bck_offset), via the dist_wtd_avg_* helpers.
//
// `pred` and `comp_pred` are walked as tightly packed buffers: each row
// advances them by `width` bytes. `ref` is walked with `ref_stride` bytes per
// row. Three layouts are handled:
//   - width > 8:  16 bytes per inner step, one row per outer step.
//   - width == 8: two rows per step.
//   - otherwise:  width is asserted to be 4, two rows per step.
// All loops are do/while, so each executes at least once.
//
// NOTE(review): the exact weighting/rounding is defined by
// dist_wtd_avg_u8x16 / dist_wtd_avg_u8x8 in dist_wtd_avg_neon.h; `ref` is
// passed as the first operand and `pred` as the second.
void aom_dist_wtd_comp_avg_pred_neon(uint8_t *comp_pred, const uint8_t *pred,
                                     int width, int height, const uint8_t *ref,
                                     int ref_stride,
                                     const DIST_WTD_COMP_PARAMS *jcp_param) {
  // Broadcast both weights to every byte lane once, up front.
  const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset);
  const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset);

  if (width == 8) {
    // A single 16-byte vector covers two 8-pixel rows of the packed buffers.
    int row_pairs_left = height / 2;

    do {
      const uint8x16_t pred_rows = vld1q_u8(pred);
      const uint8x16_t ref_rows = load_u8_8x2(ref, ref_stride);

      vst1q_u8(comp_pred,
               dist_wtd_avg_u8x16(ref_rows, pred_rows, fwd_offset, bck_offset));

      comp_pred += 16;
      pred += 16;
      ref += 2 * ref_stride;
    } while (--row_pairs_left != 0);
    return;
  }

  if (width > 8) {
    do {
      // Column offset into the current row; advances in 16-byte steps until
      // it reaches `width`.
      int x = 0;

      do {
        const uint8x16_t pred_px = vld1q_u8(pred + x);
        const uint8x16_t ref_px = vld1q_u8(ref + x);

        vst1q_u8(comp_pred + x,
                 dist_wtd_avg_u8x16(ref_px, pred_px, fwd_offset, bck_offset));

        x += 16;
      } while (x != width);

      comp_pred += width;
      pred += width;
      ref += ref_stride;
    } while (--height != 0);
    return;
  }

  assert(width == 4);

  // The 4-wide path works on 8-byte vectors (two 4-pixel rows), so it only
  // needs the low halves of the broadcast weights.
  const uint8x8_t fwd_offset_lo = vget_low_u8(fwd_offset);
  const uint8x8_t bck_offset_lo = vget_low_u8(bck_offset);
  int row_pairs_left = height / 2;

  do {
    const uint8x8_t pred_rows = vld1_u8(pred);
    const uint8x8_t ref_rows = load_unaligned_u8_4x2(ref, ref_stride);

    vst1_u8(comp_pred, dist_wtd_avg_u8x8(ref_rows, pred_rows, fwd_offset_lo,
                                         bck_offset_lo));

    comp_pred += 8;
    pred += 8;
    ref += 2 * ref_stride;
  } while (--row_pairs_left != 0);
}
148 
// Blends `ref` and `pred` under `mask` into `comp_pred` using the
// alpha_blend_a64_* helpers.
//
// `comp_pred` is written as a tightly packed buffer (`width` bytes per row).
// `pred` is read with a row pitch of `width`, `ref` with `ref_stride`, and
// `mask` with `mask_stride`. By default `ref` is the first blend source and
// `pred` the second; a non-zero `invert_mask` swaps the two (together with
// their row pitches).
//
// Three layouts are handled:
//   - width > 8:  16 bytes per inner step, one row per outer step.
//   - width == 8: one row per step.
//   - otherwise:  width is asserted to be 4, two rows per step.
// All loops are do/while, so each executes at least once.
//
// NOTE(review): how the mask weights the first vs. second source is defined
// by alpha_blend_a64_u8x16 / alpha_blend_a64_u8x8 in blend_neon.h.
void aom_comp_mask_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width,
                             int height, const uint8_t *ref, int ref_stride,
                             const uint8_t *mask, int mask_stride,
                             int invert_mask) {
  // Default operand order: ref first, pred second.
  const uint8_t *first = ref;
  const uint8_t *second = pred;
  int first_stride = ref_stride;
  int second_stride = width;

  if (invert_mask) {
    // Swap the blend operands instead of altering the mask values.
    first = pred;
    second = ref;
    first_stride = width;
    second_stride = ref_stride;
  }

  if (width == 8) {
    do {
      const uint8x8_t a = vld1_u8(first);
      const uint8x8_t b = vld1_u8(second);
      const uint8x8_t m = vld1_u8(mask);

      vst1_u8(comp_pred, alpha_blend_a64_u8x8(m, a, b));

      first += first_stride;
      second += second_stride;
      mask += mask_stride;
      comp_pred += 8;
    } while (--height != 0);
    return;
  }

  if (width > 8) {
    do {
      // Column offset into the current row; advances in 16-byte steps until
      // it reaches `width`.
      int x = 0;

      do {
        const uint8x16_t a = vld1q_u8(first + x);
        const uint8x16_t b = vld1q_u8(second + x);
        const uint8x16_t m = vld1q_u8(mask + x);

        vst1q_u8(comp_pred + x, alpha_blend_a64_u8x16(m, a, b));

        x += 16;
      } while (x != width);

      first += first_stride;
      second += second_stride;
      mask += mask_stride;
      comp_pred += width;
    } while (--height != 0);
    return;
  }

  assert(width == 4);

  // An 8-byte vector covers two 4-pixel rows; every input is gathered with
  // its own row pitch.
  int row_pairs_left = height / 2;

  do {
    const uint8x8_t a = load_unaligned_u8(first, first_stride);
    const uint8x8_t b = load_unaligned_u8(second, second_stride);
    const uint8x8_t m = load_unaligned_u8(mask, mask_stride);

    vst1_u8(comp_pred, alpha_blend_a64_u8x8(m, a, b));

    first += 2 * first_stride;
    second += 2 * second_stride;
    mask += 2 * mask_stride;
    comp_pred += 8;
  } while (--row_pairs_left != 0);
}
222