/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
11
#include <arm_neon.h>
#include <assert.h>

#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/arm/blend_neon.h"
#include "aom_dsp/arm/dist_wtd_avg_neon.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/blend.h"
21
// Average two 8-bit predictors into comp_pred with rounding:
// comp_pred[i] = (pred[i] + ref[i] + 1) >> 1. pred and comp_pred are
// contiguous (stride == width); ref uses ref_stride.
void aom_comp_avg_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width,
                            int height, const uint8_t *ref, int ref_stride) {
  if (width > 8) {
    // Wide blocks: width is a multiple of 16, so walk each row one full
    // 16-lane vector at a time.
    int y = height;
    do {
      const uint8_t *p_row = pred;
      const uint8_t *r_row = ref;
      uint8_t *out_row = comp_pred;
      int x = width;

      do {
        const uint8x16_t p_vec = vld1q_u8(p_row);
        const uint8x16_t r_vec = vld1q_u8(r_row);

        // vrhaddq_u8 is the rounding halving add: (a + b + 1) >> 1.
        vst1q_u8(out_row, vrhaddq_u8(p_vec, r_vec));

        p_row += 16;
        r_row += 16;
        out_row += 16;
        x -= 16;
      } while (x != 0);

      pred += width;
      ref += ref_stride;
      comp_pred += width;
    } while (--y != 0);
  } else if (width == 8) {
    // width == 8: fuse two consecutive rows into one 16-lane vector.
    // Assumes height is even (block heights here are multiples of 2).
    int y = height / 2;

    do {
      const uint8x16_t p_vec = vld1q_u8(pred);
      const uint8x16_t r_vec = load_u8_8x2(ref, ref_stride);

      vst1q_u8(comp_pred, vrhaddq_u8(p_vec, r_vec));

      pred += 16;
      ref += 2 * ref_stride;
      comp_pred += 16;
    } while (--y != 0);
  } else {
    // width == 4: fuse four consecutive rows into one 16-lane vector.
    // Assumes height is a multiple of 4.
    assert(width == 4);
    int y = height / 4;

    do {
      const uint8x16_t p_vec = vld1q_u8(pred);
      const uint8x16_t r_vec = load_unaligned_u8q(ref, ref_stride);

      vst1q_u8(comp_pred, vrhaddq_u8(p_vec, r_vec));

      pred += 16;
      ref += 4 * ref_stride;
      comp_pred += 16;
    } while (--y != 0);
  }
}
79
// Distance-weighted compound average of two 8-bit predictors:
// comp_pred[i] = (ref[i] * fwd_offset + pred[i] * bck_offset + 8) >> 4,
// with the weights taken from jcp_param. pred and comp_pred are
// contiguous (stride == width); ref uses ref_stride.
void aom_dist_wtd_comp_avg_pred_neon(uint8_t *comp_pred, const uint8_t *pred,
                                     int width, int height, const uint8_t *ref,
                                     int ref_stride,
                                     const DIST_WTD_COMP_PARAMS *jcp_param) {
  // Broadcast the two compound weights across all 16 lanes once, up front.
  const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset);
  const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset);

  if (width > 8) {
    // Wide blocks: width is a multiple of 16; process each row in full
    // 16-lane vectors.
    int y = height;
    do {
      const uint8_t *p_row = pred;
      const uint8_t *r_row = ref;
      uint8_t *out_row = comp_pred;
      int x = width;

      do {
        const uint8x16_t p_vec = vld1q_u8(p_row);
        const uint8x16_t r_vec = vld1q_u8(r_row);

        vst1q_u8(out_row, dist_wtd_avg_u8x16(r_vec, p_vec, fwd_offset,
                                             bck_offset));

        p_row += 16;
        r_row += 16;
        out_row += 16;
        x -= 16;
      } while (x != 0);

      pred += width;
      ref += ref_stride;
      comp_pred += width;
    } while (--y != 0);
  } else if (width == 8) {
    // width == 8: fuse two consecutive rows into one 16-lane vector.
    // Assumes height is even.
    int y = height / 2;

    do {
      const uint8x16_t p_vec = vld1q_u8(pred);
      const uint8x16_t r_vec = load_u8_8x2(ref, ref_stride);

      vst1q_u8(comp_pred,
               dist_wtd_avg_u8x16(r_vec, p_vec, fwd_offset, bck_offset));

      pred += 16;
      ref += 2 * ref_stride;
      comp_pred += 16;
    } while (--y != 0);
  } else {
    // width == 4: fuse two rows into one 8-lane vector. Assumes height
    // is even.
    assert(width == 4);
    int y = height / 2;

    do {
      const uint8x8_t p_vec = vld1_u8(pred);
      const uint8x8_t r_vec = load_unaligned_u8_4x2(ref, ref_stride);

      vst1_u8(comp_pred,
              dist_wtd_avg_u8x8(r_vec, p_vec, vget_low_u8(fwd_offset),
                                vget_low_u8(bck_offset)));

      pred += 8;
      ref += 2 * ref_stride;
      comp_pred += 8;
    } while (--y != 0);
  }
}
148
// Blend two 8-bit predictors with a per-pixel 0..64 mask:
// comp_pred[i] = (mask[i] * src0[i] + (64 - mask[i]) * src1[i] + 32) >> 6.
// invert_mask selects which of (ref, pred) the mask weights directly.
// pred and comp_pred are contiguous (stride == width); ref uses ref_stride.
void aom_comp_mask_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width,
                             int height, const uint8_t *ref, int ref_stride,
                             const uint8_t *mask, int mask_stride,
                             int invert_mask) {
  // Resolve the mask polarity once: src0 is weighted by m, src1 by (64 - m).
  const uint8_t *src0 = invert_mask ? pred : ref;
  const uint8_t *src1 = invert_mask ? ref : pred;
  const int src_stride0 = invert_mask ? width : ref_stride;
  const int src_stride1 = invert_mask ? ref_stride : width;

  if (width > 8) {
    // Wide blocks: width is a multiple of 16; process each row in full
    // 16-lane vectors.
    int y = height;
    do {
      const uint8_t *s0_row = src0;
      const uint8_t *s1_row = src1;
      const uint8_t *m_row = mask;
      uint8_t *out_row = comp_pred;
      int x = width;

      do {
        const uint8x16_t a = vld1q_u8(s0_row);
        const uint8x16_t b = vld1q_u8(s1_row);
        const uint8x16_t m = vld1q_u8(m_row);

        vst1q_u8(out_row, alpha_blend_a64_u8x16(m, a, b));

        s0_row += 16;
        s1_row += 16;
        m_row += 16;
        out_row += 16;
        x -= 16;
      } while (x != 0);

      src0 += src_stride0;
      src1 += src_stride1;
      mask += mask_stride;
      comp_pred += width;
    } while (--y != 0);
  } else if (width == 8) {
    // width == 8: one 8-lane vector per row.
    int y = height;
    do {
      const uint8x8_t a = vld1_u8(src0);
      const uint8x8_t b = vld1_u8(src1);
      const uint8x8_t m = vld1_u8(mask);

      vst1_u8(comp_pred, alpha_blend_a64_u8x8(m, a, b));

      src0 += src_stride0;
      src1 += src_stride1;
      mask += mask_stride;
      comp_pred += 8;
    } while (--y != 0);
  } else {
    // width == 4: fuse two rows into one 8-lane vector. Assumes height
    // is even.
    assert(width == 4);
    int y = height / 2;

    do {
      const uint8x8_t a = load_unaligned_u8(src0, src_stride0);
      const uint8x8_t b = load_unaligned_u8(src1, src_stride1);
      const uint8x8_t m = load_unaligned_u8(mask, mask_stride);

      vst1_u8(comp_pred, alpha_blend_a64_u8x8(m, a, b));

      src0 += 2 * src_stride0;
      src1 += 2 * src_stride1;
      mask += 2 * mask_stride;
      comp_pred += 8;
    } while (--y != 0);
  }
}
222