xref: /aosp_15_r20/external/libaom/av1/encoder/arm/quantize_neon.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
11 
12 #include <arm_neon.h>
13 
14 #include <assert.h>
15 #include <math.h>
16 
17 #include "config/aom_config.h"
18 
19 #include "aom_dsp/arm/mem_neon.h"
20 #include "aom_dsp/arm/sum_neon.h"
21 #include "aom_mem/aom_mem.h"
22 
23 #include "av1/common/quant_common.h"
24 #include "av1/common/seg_common.h"
25 
26 #include "av1/encoder/av1_quantize.h"
27 #include "av1/encoder/encoder.h"
28 #include "av1/encoder/rd.h"
29 
// Reduces the 8 lanes of v_eobmax to their maximum and returns it as the
// end-of-block value.
static inline uint16_t get_max_eob(int16x8_t v_eobmax) {
#if AOM_ARCH_AARCH64
  // AArch64 has a single across-vector max instruction.
  return (uint16_t)vmaxvq_s16(v_eobmax);
#else
  // Armv7 has no across-vector max: reduce 8 -> 4 lanes, then fold the top
  // half of the remaining 64 bits onto the bottom twice (shift by 32 then
  // by 16), taking the lane-wise max at each step. Lane 0 then holds the
  // overall maximum.
  const int16x4_t v_eobmax_3210 =
      vmax_s16(vget_low_s16(v_eobmax), vget_high_s16(v_eobmax));
  const int64x1_t v_eobmax_xx32 =
      vshr_n_s64(vreinterpret_s64_s16(v_eobmax_3210), 32);
  const int16x4_t v_eobmax_tmp =
      vmax_s16(v_eobmax_3210, vreinterpret_s16_s64(v_eobmax_xx32));
  const int64x1_t v_eobmax_xxx3 =
      vshr_n_s64(vreinterpret_s64_s16(v_eobmax_tmp), 16);
  const int16x4_t v_eobmax_final =
      vmax_s16(v_eobmax_tmp, vreinterpret_s16_s64(v_eobmax_xxx3));
  return (uint16_t)vget_lane_s16(v_eobmax_final, 0);
#endif
}
47 
// Folds one batch of 8 lanes into the running eob maximum: lanes whose
// v_mask bit is set contribute iscan[lane] + 1 (1-based, so that a final
// result of 0 can mean "no non-zero coefficient"); masked-off lanes
// contribute 0.
static inline int16x8_t get_max_lane_eob(const int16_t *iscan,
                                         int16x8_t v_eobmax,
                                         uint16x8_t v_mask) {
  const int16x8_t v_iscan = vld1q_s16(&iscan[0]);
  const int16x8_t v_iscan_plus1 = vaddq_s16(v_iscan, vdupq_n_s16(1));
  const int16x8_t v_nz_iscan = vbslq_s16(v_mask, v_iscan_plus1, vdupq_n_s16(0));
  return vmaxq_s16(v_eobmax, v_nz_iscan);
}
56 
// Quantizes 8 coefficients with the "fp" rule:
//   qcoeff  = sign(coeff) * (((|coeff| + round) * quant) >> 16)
//   dqcoeff = qcoeff * dequant
// and returns a per-lane mask of the non-zero quantized outputs.
static inline uint16x8_t quantize_fp_8(const tran_low_t *coeff_ptr,
                                       tran_low_t *qcoeff_ptr,
                                       tran_low_t *dqcoeff_ptr,
                                       int16x8_t v_quant, int16x8_t v_dequant,
                                       int16x8_t v_round, int16x8_t v_zero) {
  const int16x8_t v_coeff = load_tran_low_to_s16q(&coeff_ptr[0]);
  // Arithmetic shift by 15 gives 0 for non-negative lanes, -1 for negative.
  const int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15);
  const int16x8_t v_abs = vabsq_s16(v_coeff);
  // Saturating add keeps |coeff| + round inside int16 range.
  const int16x8_t v_tmp = vqaddq_s16(v_abs, v_round);
  // vqdmulhq computes (a*b*2) >> 16; the extra >> 1 undoes the doubling.
  const int16x8_t v_tmp2 = vshrq_n_s16(vqdmulhq_s16(v_tmp, v_quant), 1);
  const uint16x8_t v_nz_mask = vcgtq_s16(v_tmp2, v_zero);
  // (x ^ s) - s restores the sign: identity when s == 0, negation when -1.
  const int16x8_t v_qcoeff_a = veorq_s16(v_tmp2, v_coeff_sign);
  const int16x8_t v_qcoeff = vsubq_s16(v_qcoeff_a, v_coeff_sign);
  const int16x8_t v_dqcoeff = vmulq_s16(v_qcoeff, v_dequant);
  store_s16q_to_tran_low(&qcoeff_ptr[0], v_qcoeff);
  store_s16q_to_tran_low(&dqcoeff_ptr[0], v_dqcoeff);
  return v_nz_mask;
}
75 
// NEON "fp" quantizer for log_scale == 0 transforms. Processes the DC
// coefficient plus the first 7 ACs with the DC constants in lane 0, then
// broadcasts the AC constants and handles the remaining coefficients in
// batches of 8. Writes qcoeff/dqcoeff and the 1-based end-of-block index.
void av1_quantize_fp_neon(const tran_low_t *coeff_ptr, intptr_t count,
                          const int16_t *zbin_ptr, const int16_t *round_ptr,
                          const int16_t *quant_ptr,
                          const int16_t *quant_shift_ptr,
                          tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr,
                          const int16_t *dequant_ptr, uint16_t *eob_ptr,
                          const int16_t *scan, const int16_t *iscan) {
  // TODO(jingning) Decide the need of these arguments after the
  // quantization process is completed.
  (void)zbin_ptr;
  (void)quant_shift_ptr;
  (void)scan;

  // The trailing do/while unconditionally consumes a second batch of 8
  // coefficients, so callers must supply more than 8 coefficients, in
  // multiples of 8 (the smallest AV1 transform is 4x4 == 16 coefficients).
  // Violating this would read and write past the end of the buffers.
  assert(count % 8 == 0);
  assert(count > 8);

  // Quantization pass: All coefficients with index >= zero_flag are
  // skippable. Note: zero_flag can be zero.
  const int16x8_t v_zero = vdupq_n_s16(0);
  int16x8_t v_quant = vld1q_s16(quant_ptr);
  int16x8_t v_dequant = vld1q_s16(dequant_ptr);
  int16x8_t v_round = vld1q_s16(round_ptr);
  // All lanes -1 so any non-zero lane (iscan + 1 >= 1) beats the initial max.
  int16x8_t v_eobmax_76543210 = vdupq_n_s16(-1);
  uint16x8_t v_nz_mask;
  // Process dc and the first seven ac coeffs with the DC constants in lane 0.
  v_nz_mask = quantize_fp_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant,
                            v_dequant, v_round, v_zero);
  v_eobmax_76543210 = get_max_lane_eob(&iscan[0], v_eobmax_76543210, v_nz_mask);
  // Overwrite the dc constants with ac constants (broadcast lane 1).
  v_quant = vdupq_lane_s16(vget_low_s16(v_quant), 1);
  v_dequant = vdupq_lane_s16(vget_low_s16(v_dequant), 1);
  v_round = vdupq_lane_s16(vget_low_s16(v_round), 1);

  count -= 8;
  // Now process the rest of the ac coeffs.
  do {
    coeff_ptr += 8;
    qcoeff_ptr += 8;
    dqcoeff_ptr += 8;
    iscan += 8;
    v_nz_mask = quantize_fp_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant,
                              v_dequant, v_round, v_zero);
    v_eobmax_76543210 = get_max_lane_eob(iscan, v_eobmax_76543210, v_nz_mask);
    count -= 8;
  } while (count > 0);
  *eob_ptr = get_max_eob(v_eobmax_76543210);
}
120 
// Low-precision variant of quantize_fp_8 operating directly on int16
// buffers instead of tran_low_t:
//   qcoeff  = sign(coeff) * (((|coeff| + round) * quant) >> 16)
//   dqcoeff = qcoeff * dequant
// Returns a per-lane mask of the non-zero quantized outputs.
static inline uint16x8_t quantize_lp_8(const int16_t *coeff_ptr,
                                       int16_t *qcoeff_ptr,
                                       int16_t *dqcoeff_ptr, int16x8_t v_quant,
                                       int16x8_t v_dequant, int16x8_t v_round,
                                       int16x8_t v_zero) {
  // Split the 8 input coefficients into sign (0 or -1 per lane) and
  // magnitude.
  const int16x8_t coeff = vld1q_s16(coeff_ptr);
  const int16x8_t sign = vshrq_n_s16(coeff, 15);
  const int16x8_t abs_coeff = vabsq_s16(coeff);
  // Saturating round, then a doubling high-half multiply compensated with
  // one extra right shift gives ((|coeff| + round) * quant) >> 16.
  const int16x8_t rounded = vqaddq_s16(abs_coeff, v_round);
  const int16x8_t abs_q = vshrq_n_s16(vqdmulhq_s16(rounded, v_quant), 1);
  const uint16x8_t nz_mask = vcgtq_s16(abs_q, v_zero);
  // Re-apply the sign with the (x ^ s) - s identity.
  const int16x8_t qcoeff = vsubq_s16(veorq_s16(abs_q, sign), sign);
  const int16x8_t dqcoeff = vmulq_s16(qcoeff, v_dequant);
  vst1q_s16(qcoeff_ptr, qcoeff);
  vst1q_s16(dqcoeff_ptr, dqcoeff);
  return nz_mask;
}
139 
// Low-precision NEON quantizer operating on int16 coefficient buffers.
// Mirrors av1_quantize_fp_neon: the first batch of 8 uses the DC constants
// in lane 0, the rest use broadcast AC constants. Writes qcoeff/dqcoeff and
// the 1-based end-of-block index.
void av1_quantize_lp_neon(const int16_t *coeff_ptr, intptr_t n_coeffs,
                          const int16_t *round_ptr, const int16_t *quant_ptr,
                          int16_t *qcoeff_ptr, int16_t *dqcoeff_ptr,
                          const int16_t *dequant_ptr, uint16_t *eob_ptr,
                          const int16_t *scan, const int16_t *iscan) {
  (void)scan;

  // The trailing do/while unconditionally consumes a second batch of 8
  // coefficients, so callers must supply more than 8 coefficients, in
  // multiples of 8. Violating this would read and write out of bounds.
  assert(n_coeffs % 8 == 0);
  assert(n_coeffs > 8);

  // Quantization pass: All coefficients with index >= zero_flag are
  // skippable. Note: zero_flag can be zero.
  const int16x8_t v_zero = vdupq_n_s16(0);
  int16x8_t v_quant = vld1q_s16(quant_ptr);
  int16x8_t v_dequant = vld1q_s16(dequant_ptr);
  int16x8_t v_round = vld1q_s16(round_ptr);
  // All lanes -1 so any non-zero lane (iscan + 1 >= 1) beats the initial max.
  int16x8_t v_eobmax_76543210 = vdupq_n_s16(-1);
  uint16x8_t v_nz_mask;
  intptr_t count = n_coeffs;

  // Process dc and the first seven ac coeffs with the DC constants in lane 0.
  v_nz_mask = quantize_lp_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant,
                            v_dequant, v_round, v_zero);
  v_eobmax_76543210 = get_max_lane_eob(iscan, v_eobmax_76543210, v_nz_mask);
  // Overwrite the dc constants with ac constants (broadcast lane 1).
  v_quant = vdupq_lane_s16(vget_low_s16(v_quant), 1);
  v_dequant = vdupq_lane_s16(vget_low_s16(v_dequant), 1);
  v_round = vdupq_lane_s16(vget_low_s16(v_round), 1);

  count -= 8;
  // Now process the rest of the ac coeffs.
  do {
    coeff_ptr += 8;
    qcoeff_ptr += 8;
    dqcoeff_ptr += 8;
    iscan += 8;
    v_nz_mask = quantize_lp_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant,
                              v_dequant, v_round, v_zero);
    v_eobmax_76543210 = get_max_lane_eob(iscan, v_eobmax_76543210, v_nz_mask);
    count -= 8;
    // `count > 0` (rather than the previous `count != 0`) matches
    // av1_quantize_fp_neon and guarantees termination even if the
    // multiple-of-8 precondition were ever violated in release builds.
  } while (count > 0);
  *eob_ptr = get_max_eob(v_eobmax_76543210);
}
179 
// Quantizes 8 coefficients for scaled transforms (log_scale 1 or 2):
// coefficients below the scaled dequant threshold are zeroed, the rest are
// quantized as ((|coeff| + round) << (log_scale - 1)) * quant >> 15 and
// dequantized with a >> log_scale. Returns the per-lane non-zero mask.
static AOM_FORCE_INLINE uint16x8_t quantize_fp_logscale_8(
    const tran_low_t *coeff_ptr, tran_low_t *qcoeff_ptr,
    tran_low_t *dqcoeff_ptr, int16x8_t v_quant, int16x8_t v_dequant,
    int16x8_t v_round, int16x8_t v_zero, int log_scale) {
  const int16x8_t v_log_scale_minus_1 = vdupq_n_s16(log_scale - 1);
  // Negative shift amount: vshlq with a negative count shifts right.
  const int16x8_t v_neg_log_scale_plus_1 = vdupq_n_s16(-(1 + log_scale));
  const int16x8_t v_coeff = load_tran_low_to_s16q(coeff_ptr);
  // 0 for non-negative lanes, -1 for negative lanes.
  const int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15);
  const int16x8_t v_abs_coeff = vabsq_s16(v_coeff);
  // Keep only coefficients at or above dequant >> (1 + log_scale).
  const uint16x8_t v_mask =
      vcgeq_s16(v_abs_coeff, vshlq_s16(v_dequant, v_neg_log_scale_plus_1));
  // const int64_t tmp = vmask ? (int64_t)abs_coeff + log_scaled_round : 0
  const int16x8_t v_tmp = vandq_s16(vqaddq_s16(v_abs_coeff, v_round),
                                    vreinterpretq_s16_u16(v_mask));
  // (tmp << (log_scale - 1)) * quant >> 15, via the doubling high-half
  // multiply (which contributes the remaining factor of 2).
  const int16x8_t v_tmp2 =
      vqdmulhq_s16(vshlq_s16(v_tmp, v_log_scale_minus_1), v_quant);
  const uint16x8_t v_nz_mask = vcgtq_s16(v_tmp2, v_zero);
  // Restore the sign with the (x ^ s) - s identity.
  const int16x8_t v_qcoeff =
      vsubq_s16(veorq_s16(v_tmp2, v_coeff_sign), v_coeff_sign);
  // Multiplying by dequant here will use all 16 bits. Cast to unsigned before
  // shifting right. (vshlq_s16 will shift right if shift value is negative)
  const uint16x8_t v_abs_dqcoeff =
      vshlq_u16(vreinterpretq_u16_s16(vmulq_s16(v_tmp2, v_dequant)),
                vdupq_n_s16(-log_scale));
  const int16x8_t v_dqcoeff =
      vsubq_s16(veorq_s16(vreinterpretq_s16_u16(v_abs_dqcoeff), v_coeff_sign),
                v_coeff_sign);
  store_s16q_to_tran_low(qcoeff_ptr, v_qcoeff);
  store_s16q_to_tran_low(dqcoeff_ptr, v_dqcoeff);
  return v_nz_mask;
}
211 
// Specialization of quantize_fp_logscale_8 for log_scale == 2 (64x64
// transforms). The >> 14 products overflow a single 16-bit high-half
// multiply, so each product is stitched together from a high part
// (vqdmulhq, realigned with a left shift) OR'd with the low 16-bit
// product shifted right.
static AOM_FORCE_INLINE uint16x8_t quantize_fp_logscale2_8(
    const tran_low_t *coeff_ptr, tran_low_t *qcoeff_ptr,
    tran_low_t *dqcoeff_ptr, int16x8_t v_quant, int16x8_t v_dequant,
    int16x8_t v_round, int16x8_t v_zero) {
  const int16x8_t v_coeff = load_tran_low_to_s16q(coeff_ptr);
  // 0 for non-negative lanes, -1 for negative lanes.
  const int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15);
  const int16x8_t v_abs_coeff = vabsq_s16(v_coeff);
  // Threshold test: |coeff| * 2 >= dequant >> 2, i.e. the log_scale == 2
  // form of the zero-bin check, done in unsigned to keep the shifted bits.
  const uint16x8_t v_mask =
      vcgeq_u16(vshlq_n_u16(vreinterpretq_u16_s16(v_abs_coeff), 1),
                vshrq_n_u16(vreinterpretq_u16_s16(v_dequant), 2));
  // abs_coeff = vmask ? (int64_t)abs_coeff + log_scaled_round : 0
  const int16x8_t v_tmp = vandq_s16(vqaddq_s16(v_abs_coeff, v_round),
                                    vreinterpretq_s16_u16(v_mask));
  // tmp32 = (int)((abs_coeff * quant_ptr[rc != 0]) >> (16 - log_scale));
  const int16x8_t v_tmp2 =
      vorrq_s16(vshlq_n_s16(vqdmulhq_s16(v_tmp, v_quant), 1),
                vreinterpretq_s16_u16(vshrq_n_u16(
                    vreinterpretq_u16_s16(vmulq_s16(v_tmp, v_quant)), 14)));
  const uint16x8_t v_nz_mask = vcgtq_s16(v_tmp2, v_zero);
  // Restore the sign with the (x ^ s) - s identity.
  const int16x8_t v_qcoeff =
      vsubq_s16(veorq_s16(v_tmp2, v_coeff_sign), v_coeff_sign);
  // const tran_low_t abs_dqcoeff = (tmp32 * dequant_ptr[rc != 0]) >> log_scale;
  const int16x8_t v_abs_dqcoeff =
      vorrq_s16(vshlq_n_s16(vqdmulhq_s16(v_tmp2, v_dequant), 13),
                vreinterpretq_s16_u16(vshrq_n_u16(
                    vreinterpretq_u16_s16(vmulq_s16(v_tmp2, v_dequant)), 2)));
  const int16x8_t v_dqcoeff =
      vsubq_s16(veorq_s16(v_abs_dqcoeff, v_coeff_sign), v_coeff_sign);
  store_s16q_to_tran_low(qcoeff_ptr, v_qcoeff);
  store_s16q_to_tran_low(dqcoeff_ptr, v_dqcoeff);
  return v_nz_mask;
}
244 
// Shared implementation for the scaled (32x32 / 64x64) fp quantizers without
// a quantization matrix. First pre-scans from the tail of the coefficient
// buffer in batches of 16 to find the trailing run that falls entirely below
// the zero-bin, zeroes those outputs, then quantizes the remaining prefix in
// batches of 8 through the appropriate logscale kernel.
static AOM_FORCE_INLINE void quantize_fp_no_qmatrix_neon(
    const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *round_ptr,
    const int16_t *quant_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr,
    const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *iscan,
    int log_scale) {
  const int16x8_t v_zero = vdupq_n_s16(0);
  int16x8_t v_quant = vld1q_s16(quant_ptr);
  int16x8_t v_dequant = vld1q_s16(dequant_ptr);
  const int16x8_t v_round_no_scale = vld1q_s16(round_ptr);
  // Scale round by 1 / (1 << log_scale) with a rounding high-half multiply.
  int16x8_t v_round =
      vqrdmulhq_n_s16(v_round_no_scale, (int16_t)(1 << (15 - log_scale)));
  int16x8_t v_eobmax_76543210 = vdupq_n_s16(-1);
  intptr_t non_zero_count = n_coeffs;

  // The pre-scan consumes 16 coefficients per iteration.
  assert(n_coeffs > 16);
  // Pre-scan pass: the zbin threshold is dequant >> (1 + log_scale), and the
  // AC value (lane 1) is used for the whole scan.
  const int16x8_t v_dequant_scaled =
      vshlq_s16(v_dequant, vdupq_n_s16(-(1 + log_scale)));
  const int16x8_t v_zbin_s16 =
      vdupq_lane_s16(vget_low_s16(v_dequant_scaled), 1);
  intptr_t i = n_coeffs;
  do {
    const int16x8_t v_coeff_a = load_tran_low_to_s16q(coeff_ptr + i - 8);
    const int16x8_t v_coeff_b = load_tran_low_to_s16q(coeff_ptr + i - 16);
    const int16x8_t v_abs_coeff_a = vabsq_s16(v_coeff_a);
    const int16x8_t v_abs_coeff_b = vabsq_s16(v_coeff_b);
    const uint16x8_t v_mask_a = vcgeq_s16(v_abs_coeff_a, v_zbin_s16);
    const uint16x8_t v_mask_b = vcgeq_s16(v_abs_coeff_b, v_zbin_s16);
    // If the coefficient is in the base ZBIN range, then discard.
    if (horizontal_long_add_u16x8(v_mask_a, v_mask_b) == 0) {
      non_zero_count -= 16;
    } else {
      break;
    }
    i -= 16;
  } while (i > 0);

  // Zero out the skippable tail of both output buffers.
  const intptr_t remaining_zcoeffs = n_coeffs - non_zero_count;
  memset(qcoeff_ptr + non_zero_count, 0,
         remaining_zcoeffs * sizeof(*qcoeff_ptr));
  memset(dqcoeff_ptr + non_zero_count, 0,
         remaining_zcoeffs * sizeof(*dqcoeff_ptr));

  // Process dc and the first seven ac coeffs (DC constants in lane 0).
  uint16x8_t v_nz_mask;
  if (log_scale == 2) {
    v_nz_mask = quantize_fp_logscale2_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr,
                                        v_quant, v_dequant, v_round, v_zero);
  } else {
    v_nz_mask =
        quantize_fp_logscale_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant,
                               v_dequant, v_round, v_zero, log_scale);
  }
  v_eobmax_76543210 = get_max_lane_eob(iscan, v_eobmax_76543210, v_nz_mask);
  // Overwrite the dc constants with ac constants (broadcast lane 1).
  v_quant = vdupq_lane_s16(vget_low_s16(v_quant), 1);
  v_dequant = vdupq_lane_s16(vget_low_s16(v_dequant), 1);
  v_round = vdupq_lane_s16(vget_low_s16(v_round), 1);

  // Only the prefix that survived the pre-scan needs full quantization.
  for (intptr_t count = non_zero_count - 8; count > 0; count -= 8) {
    coeff_ptr += 8;
    qcoeff_ptr += 8;
    dqcoeff_ptr += 8;
    iscan += 8;
    if (log_scale == 2) {
      v_nz_mask = quantize_fp_logscale2_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr,
                                          v_quant, v_dequant, v_round, v_zero);
    } else {
      v_nz_mask =
          quantize_fp_logscale_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant,
                                 v_dequant, v_round, v_zero, log_scale);
    }
    v_eobmax_76543210 = get_max_lane_eob(iscan, v_eobmax_76543210, v_nz_mask);
  }
  *eob_ptr = get_max_eob(v_eobmax_76543210);
}
321 
// 32x32 fp quantizer: thin wrapper over the shared no-qmatrix kernel.
void av1_quantize_fp_32x32_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs,
                                const int16_t *zbin_ptr,
                                const int16_t *round_ptr,
                                const int16_t *quant_ptr,
                                const int16_t *quant_shift_ptr,
                                tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr,
                                const int16_t *dequant_ptr, uint16_t *eob_ptr,
                                const int16_t *scan, const int16_t *iscan) {
  // The fp quantizer ignores zbin, quant_shift and the forward scan.
  (void)zbin_ptr;
  (void)quant_shift_ptr;
  (void)scan;
  // 32x32 transforms quantize with log_scale == 1.
  const int log_scale = 1;
  quantize_fp_no_qmatrix_neon(coeff_ptr, n_coeffs, round_ptr, quant_ptr,
                              qcoeff_ptr, dqcoeff_ptr, dequant_ptr, eob_ptr,
                              iscan, log_scale);
}
337 
// 64x64 fp quantizer: thin wrapper over the shared no-qmatrix kernel.
void av1_quantize_fp_64x64_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs,
                                const int16_t *zbin_ptr,
                                const int16_t *round_ptr,
                                const int16_t *quant_ptr,
                                const int16_t *quant_shift_ptr,
                                tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr,
                                const int16_t *dequant_ptr, uint16_t *eob_ptr,
                                const int16_t *scan, const int16_t *iscan) {
  // The fp quantizer ignores zbin, quant_shift and the forward scan.
  (void)zbin_ptr;
  (void)quant_shift_ptr;
  (void)scan;
  // 64x64 transforms quantize with log_scale == 2.
  const int log_scale = 2;
  quantize_fp_no_qmatrix_neon(coeff_ptr, n_coeffs, round_ptr, quant_ptr,
                              qcoeff_ptr, dqcoeff_ptr, dequant_ptr, eob_ptr,
                              iscan, log_scale);
}
353 
// NEON "b" quantizer: coefficients below the zero-bin threshold stay zero;
// the rest are quantized with round/quant/quant_shift and dequantized with
// dequant. The first batch of 8 carries the DC constants in lane 0; all
// later batches use the AC constants in every lane. eob is written 1-based.
void aom_quantize_b_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs,
                         const int16_t *zbin_ptr, const int16_t *round_ptr,
                         const int16_t *quant_ptr,
                         const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr,
                         tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr,
                         uint16_t *eob_ptr, const int16_t *scan,
                         const int16_t *iscan) {
  // NOTE(review): this cast is redundant — quant_shift_ptr is read below.
  (void)quant_shift_ptr;
  (void)scan;

  const int zbins[2] = { zbin_ptr[0], zbin_ptr[1] };

  // Outputs start zeroed; only lanes that pass the zbin test are rewritten.
  memset(qcoeff_ptr, 0, n_coeffs * sizeof(*qcoeff_ptr));
  memset(dqcoeff_ptr, 0, n_coeffs * sizeof(*dqcoeff_ptr));

  const int16x8_t zero = vdupq_n_s16(0);
  // All lanes -1: eob accumulator, so the final "+ 1" maps "none" to 0.
  int16x8_t v_eobmax_76543210 = vreinterpretq_s16_u16(vceqq_s16(zero, zero));

  // Broadcast the AC (index 1) constants; lane 0 is patched to DC below.
  int16x8_t vzbins = vdupq_n_s16(zbins[1]), vround = vdupq_n_s16(round_ptr[1]);
  int16x8_t vdequant = vdupq_n_s16(dequant_ptr[1]);
  int16x8_t vquant = vdupq_n_s16(quant_ptr[1]);
  int16x8_t vquant_shift = vdupq_n_s16(quant_shift_ptr[1]);

  int16x8_t v_coeff = load_tran_low_to_s16q(&coeff_ptr[0]);
  // 0 for non-negative lanes, -1 for negative lanes.
  int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15);
  int16x8_t v_abs = vabsq_s16(v_coeff);

  // Lane 0 of the first batch is the DC coefficient: use the DC zbin.
  vzbins = vsetq_lane_s16(zbins[0], vzbins, 0);

  uint16x8_t vcond = vcgeq_s16(v_abs, vzbins);
  // Narrow the mask to bytes so one 64-bit extract tests all 8 lanes.
  uint64_t nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0);
  if (nz_check) {
    // Patch lane 0 with the DC quantization constants.
    vround = vsetq_lane_s16(round_ptr[0], vround, 0);
    vquant = vsetq_lane_s16(quant_ptr[0], vquant, 0);
    vdequant = vsetq_lane_s16(dequant_ptr[0], vdequant, 0);
    vquant_shift = vsetq_lane_s16(quant_shift_ptr[0], vquant_shift, 0);

    int16x8_t vtmp = vqaddq_s16(v_abs, vround);
    // tmp2 = tmp + ((tmp * quant) >> 16): the doubling high-half multiply
    // is folded into a shift-right-accumulate by 1.
    int16x8_t vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1);
    // Final (tmp2 * quant_shift) >> 16, again compensating the doubling.
    vtmp2 = vshrq_n_s16(vqdmulhq_s16(vtmp2, vquant_shift), 1);

    // Restore the sign with the (x ^ s) - s identity, and merge only the
    // lanes that passed the zbin test into the (zeroed) output.
    int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign);
    int16x8_t coeff_nz_mask =
        vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[0]));
    store_s16q_to_tran_low(&qcoeff_ptr[0], coeff_nz_mask);
    int16x8_t v_deq_abs = vmulq_s16(vtmp2, vdequant);

    vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign);
    coeff_nz_mask =
        vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[0]));
    store_s16q_to_tran_low(&dqcoeff_ptr[0], coeff_nz_mask);

    // Restore the AC constants in lane 0 for the remaining batches.
    vround = vsetq_lane_s16(round_ptr[1], vround, 0);
    vquant = vsetq_lane_s16(quant_ptr[1], vquant, 0);
    vdequant = vsetq_lane_s16(dequant_ptr[1], vdequant, 0);
    vquant_shift = vsetq_lane_s16(quant_shift_ptr[1], vquant_shift, 0);

    // Fold the largest iscan index of the non-zero lanes into the eob max.
    uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero);
    const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond);
    int16x8_t v_iscan = vld1q_s16(&iscan[0]);
    vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210));
    v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210);
  }
  // Restore the AC zbin in lane 0 for the remaining batches.
  vzbins = vsetq_lane_s16(zbins[1], vzbins, 0);

  for (int i = 8; i < n_coeffs; i += 8) {
    v_coeff = load_tran_low_to_s16q(&coeff_ptr[i]);
    v_coeff_sign = vshrq_n_s16(v_coeff, 15);
    v_abs = vabsq_s16(v_coeff);
    vcond = vcgeq_s16(v_abs, vzbins);

    nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0);
    if (nz_check) {
      // Same quantize / sign-restore / masked-store sequence as above, but
      // with the pure-AC constants.
      int16x8_t vtmp = vqaddq_s16(v_abs, vround);
      int16x8_t vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1);

      vtmp2 = vshrq_n_s16(vqdmulhq_s16(vtmp2, vquant_shift), 1);
      int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign);
      int16x8_t coeff_nz_mask =
          vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[i]));
      store_s16q_to_tran_low(&qcoeff_ptr[i], coeff_nz_mask);
      int16x8_t v_deq_abs = vmulq_s16(vtmp2, vdequant);
      vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign);
      coeff_nz_mask =
          vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[i]));
      store_s16q_to_tran_low(&dqcoeff_ptr[i], coeff_nz_mask);

      uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero);
      const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond);
      int16x8_t v_iscan = vld1q_s16(&iscan[i]);
      vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210));
      v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210);
    }
  }
  // The accumulator held iscan indices (not iscan + 1), hence the + 1 here.
  *eob_ptr = get_max_eob(v_eobmax_76543210) + 1;
}
450 
// Per-lane (x0 * x1) >> AOM_QM_BITS for 16-bit x0 against unsigned qm
// weights x1: the high bits come from the doubling high-half multiply
// (realigned with a left shift by 15 - AOM_QM_BITS) OR'd with the low
// 16-bit product shifted right by AOM_QM_BITS.
// NOTE: both macro arguments are evaluated twice — avoid side effects.
#define QM_MULL_SHIFT(x0, x1)                                              \
  vreinterpretq_s16_u16(vorrq_u16(                                         \
      vreinterpretq_u16_s16(vshlq_n_s16(                                   \
          vqdmulhq_s16(x0, vreinterpretq_s16_u16(x1)), 15 - AOM_QM_BITS)), \
      vshrq_n_u16(vmulq_u16(vreinterpretq_u16_s16(x0), x1), AOM_QM_BITS)))
456 
// "b" quantizer with optional quantization matrices (16x16-class sizes,
// log_scale == 0). When qm_ptr/iqm_ptr are non-NULL, the zbin test and the
// quantize/dequantize steps are weighted per coefficient through
// QM_MULL_SHIFT. Structure mirrors aom_quantize_b_neon: DC constants in
// lane 0 for the first batch, AC constants thereafter; eob written 1-based.
static void aom_quantize_b_helper_16x16_neon(
    const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr,
    const int16_t *round_ptr, const int16_t *quant_ptr,
    const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr,
    tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr,
    const int16_t *scan, const int16_t *iscan, const qm_val_t *qm_ptr,
    const qm_val_t *iqm_ptr) {
  (void)scan;

  uint16x8_t vwt, viwt;
  const int zbins[2] = { zbin_ptr[0], zbin_ptr[1] };

  // Outputs start zeroed; only lanes that pass the zbin test are rewritten.
  memset(qcoeff_ptr, 0, n_coeffs * sizeof(*qcoeff_ptr));
  memset(dqcoeff_ptr, 0, n_coeffs * sizeof(*dqcoeff_ptr));

  const int16x8_t zero = vdupq_n_s16(0);
  // All lanes -1: eob accumulator, so the final "+ 1" maps "none" to 0.
  int16x8_t v_eobmax_76543210 = vreinterpretq_s16_u16(vceqq_s16(zero, zero));

  // Broadcast the AC (index 1) constants; lane 0 is patched to DC below.
  int16x8_t vzbins = vdupq_n_s16(zbins[1]), vround = vdupq_n_s16(round_ptr[1]);
  int16x8_t vdequant = vdupq_n_s16(dequant_ptr[1]);
  int16x8_t vquant = vdupq_n_s16(quant_ptr[1]);
  int16x8_t vquant_shift = vdupq_n_s16(quant_shift_ptr[1]);

  int16x8_t v_coeff = load_tran_low_to_s16q(&coeff_ptr[0]);
  // 0 for non-negative lanes, -1 for negative lanes.
  int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15);
  int16x8_t v_abs = vabsq_s16(v_coeff);
  // Lane 0 of the first batch is the DC coefficient: use the DC zbin.
  vzbins = vsetq_lane_s16(zbins[0], vzbins, 0);
  uint16x8_t vcond;
  if (qm_ptr == NULL) {
    vcond = vcgeq_s16(v_abs, vzbins);
  } else {
    // Widen the 8-bit qm weights and apply them before the zbin test.
    vwt = vmovl_u8(vld1_u8(&qm_ptr[0]));
    vcond = vcgeq_s16(QM_MULL_SHIFT(v_abs, vwt), vzbins);
  }
  // Narrow the mask to bytes so one 64-bit extract tests all 8 lanes.
  uint64_t nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0);
  if (nz_check) {
    // Patch lane 0 with the DC quantization constants.
    vround = vsetq_lane_s16(round_ptr[0], vround, 0);
    vquant = vsetq_lane_s16(quant_ptr[0], vquant, 0);
    vdequant = vsetq_lane_s16(dequant_ptr[0], vdequant, 0);
    vquant_shift = vsetq_lane_s16(quant_shift_ptr[0], vquant_shift, 0);

    int16x8_t vtmp = vqaddq_s16(v_abs, vround);

    // tmp2 = tmp + ((tmp * wt * quant) >> 16); without a qm the weight is
    // implicit and the accumulate is a shift-right-accumulate.
    int16x8_t vtmp2;
    if (qm_ptr == NULL) {
      vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1);
    } else {
      vtmp2 = QM_MULL_SHIFT(vtmp, vwt);
      vtmp2 = vaddq_s16(vtmp2, vtmp);
    }

    vtmp2 = vshrq_n_s16(vqdmulhq_s16(vtmp2, vquant_shift), 1);
    // Restore the sign with the (x ^ s) - s identity; merge only the lanes
    // that passed the zbin test into the (zeroed) output.
    int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign);
    int16x8_t coeff_nz_mask =
        vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[0]));
    store_s16q_to_tran_low(&qcoeff_ptr[0], coeff_nz_mask);

    // Weight the dequant step by the inverse qm, when present.
    if (iqm_ptr != NULL) {
      viwt = vmovl_u8(vld1_u8(&iqm_ptr[0]));
      vdequant = QM_MULL_SHIFT(vdequant, viwt);
    }
    int16x8_t v_deq_abs = vmulq_s16(vtmp2, vdequant);
    vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign);
    coeff_nz_mask =
        vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[0]));
    store_s16q_to_tran_low(&dqcoeff_ptr[0], coeff_nz_mask);

    // Restore the AC constants in lane 0 for the remaining batches.
    vround = vsetq_lane_s16(round_ptr[1], vround, 0);
    vquant = vsetq_lane_s16(quant_ptr[1], vquant, 0);
    vdequant = vsetq_lane_s16(dequant_ptr[1], vdequant, 0);
    vquant_shift = vsetq_lane_s16(quant_shift_ptr[1], vquant_shift, 0);

    // Fold the largest iscan index of the non-zero lanes into the eob max.
    uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero);
    const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond);
    int16x8_t v_iscan = vld1q_s16(&iscan[0]);
    vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210));
    v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210);
  }
  // Restore the AC zbin in lane 0 for the remaining batches.
  vzbins = vsetq_lane_s16(zbins[1], vzbins, 0);

  for (int i = 8; i < n_coeffs; i += 8) {
    v_coeff = load_tran_low_to_s16q(&coeff_ptr[i]);
    v_coeff_sign = vshrq_n_s16(v_coeff, 15);
    v_abs = vabsq_s16(v_coeff);

    if (qm_ptr == NULL) {
      vcond = vcgeq_s16(v_abs, vzbins);
    } else {
      vwt = vmovl_u8(vld1_u8(&qm_ptr[i]));
      vcond = vcgeq_s16(QM_MULL_SHIFT(v_abs, vwt), vzbins);
    }
    nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0);
    if (nz_check) {
      // Same quantize / sign-restore / masked-store sequence as above, with
      // the pure-AC constants.
      int16x8_t vtmp = vqaddq_s16(v_abs, vround);

      int16x8_t vtmp2;
      if (qm_ptr == NULL) {
        vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1);
      } else {
        vtmp2 = QM_MULL_SHIFT(vtmp, vwt);
        vtmp2 = vaddq_s16(vtmp2, vtmp);
      }

      vtmp2 = vshrq_n_s16(vqdmulhq_s16(vtmp2, vquant_shift), 1);
      int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign);
      int16x8_t coeff_nz_mask =
          vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[i]));
      store_s16q_to_tran_low(&qcoeff_ptr[i], coeff_nz_mask);

      if (iqm_ptr != NULL) {
        viwt = vmovl_u8(vld1_u8(&iqm_ptr[i]));
        vdequant = QM_MULL_SHIFT(vdequant, viwt);
      }
      int16x8_t v_deq_abs = vmulq_s16(vtmp2, vdequant);
      vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign);
      coeff_nz_mask =
          vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[i]));
      store_s16q_to_tran_low(&dqcoeff_ptr[i], coeff_nz_mask);

      uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero);
      const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond);
      int16x8_t v_iscan = vld1q_s16(&iscan[i]);
      vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210));
      v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210);
    }
  }
  // The accumulator held iscan indices (not iscan + 1), hence the + 1 here.
  *eob_ptr = get_max_eob(v_eobmax_76543210) + 1;
}
585 
/*
 * Quantizes and dequantizes n_coeffs coefficients of a 32x32 transform block
 * (log_scale == 1), with optional forward (qm_ptr) and inverse (iqm_ptr)
 * quantization-matrix weights. Writes quantized coefficients to qcoeff_ptr,
 * dequantized coefficients to dqcoeff_ptr, and the end-of-block index (index
 * of the last nonzero coefficient in scan order, plus one) to *eob_ptr.
 * Lane 0 of the first 8-lane vector is the DC coefficient and uses the [0]
 * entries of the zbin/round/quant/dequant/shift tables; all other
 * coefficients use the AC [1] entries.
 */
static void aom_quantize_b_helper_32x32_neon(
    const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr,
    const int16_t *round_ptr, const int16_t *quant_ptr,
    const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr,
    tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr,
    const int16_t *scan, const int16_t *iscan, const qm_val_t *qm_ptr,
    const qm_val_t *iqm_ptr) {
  (void)scan;

  uint16x8_t vwt, viwt;
  const int log_scale = 1;
  const int zbins[2] = { ROUND_POWER_OF_TWO(zbin_ptr[0], log_scale),
                         ROUND_POWER_OF_TWO(zbin_ptr[1], log_scale) };

  // Outputs default to zero; only lanes passing the zbin test are written.
  memset(qcoeff_ptr, 0, n_coeffs * sizeof(*qcoeff_ptr));
  memset(dqcoeff_ptr, 0, n_coeffs * sizeof(*dqcoeff_ptr));

  const int16x8_t zero = vdupq_n_s16(0);
  // Start the eob tracker at -1 in every lane; iscan values are >= 0, so any
  // quantized-nonzero coefficient raises the running maximum.
  int16x8_t v_eobmax_76543210 = vreinterpretq_s16_u16(vceqq_s16(zero, zero));
  // All-ones is -1 per lane: vshlq_u16 with a negative count performs the
  // right shift by log_scale (== 1) on the dequantized magnitudes.
  const int16x8_t v_log_scale = v_eobmax_76543210;

  int16x8_t vzbins = vdupq_n_s16(zbins[1]),
            vround = vdupq_n_s16(ROUND_POWER_OF_TWO(round_ptr[1], log_scale));
  int16x8_t vdequant = vdupq_n_s16(dequant_ptr[1]);
  int16x8_t vquant = vdupq_n_s16(quant_ptr[1]);
  int16x8_t vquant_shift = vdupq_n_s16(quant_shift_ptr[1]);

  // First vector of 8: lane 0 (DC) uses the [0] table entries.
  int16x8_t v_coeff = load_tran_low_to_s16q(&coeff_ptr[0]);
  int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15);
  int16x8_t v_abs = vabsq_s16(v_coeff);
  vzbins = vsetq_lane_s16(zbins[0], vzbins, 0);
  uint16x8_t vcond;
  if (qm_ptr == NULL) {
    vcond = vcgeq_s16(v_abs, vzbins);
  } else {
    vwt = vmovl_u8(vld1_u8(&qm_ptr[0]));
    vcond = vcgeq_s16(QM_MULL_SHIFT(v_abs, vwt), vzbins);
  }
  uint64_t nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0);
  if (nz_check) {
    vround =
        vsetq_lane_s16(ROUND_POWER_OF_TWO(round_ptr[0], log_scale), vround, 0);
    vquant = vsetq_lane_s16(quant_ptr[0], vquant, 0);
    vdequant = vsetq_lane_s16(dequant_ptr[0], vdequant, 0);
    vquant_shift = vsetq_lane_s16(quant_shift_ptr[0], vquant_shift, 0);

    int16x8_t vtmp = vqaddq_s16(v_abs, vround);

    int16x8_t vtmp2;
    if (qm_ptr == NULL) {
      // tmp2 = tmp + ((tmp * quant) >> 16), via doubling-high-mul plus
      // a shift-right-accumulate by 1.
      vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1);
    } else {
      vtmp2 = QM_MULL_SHIFT(vtmp, vwt);
      vtmp2 = vaddq_s16(vtmp2, vtmp);
    }

    vtmp2 = vqdmulhq_s16(vtmp2, vquant_shift);
    // Restore the sign of the quantized magnitude.
    int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign);
    int16x8_t coeff_nz_mask =
        vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[0]));
    store_s16q_to_tran_low(&qcoeff_ptr[0], coeff_nz_mask);

    // Apply the inverse quant matrix to a local copy so that vdequant keeps
    // the unweighted table values for the remaining iterations (overwriting
    // vdequant would compound the iqm weights across vectors).
    int16x8_t vdeq = vdequant;
    if (iqm_ptr != NULL) {
      viwt = vmovl_u8(vld1_u8(&iqm_ptr[0]));
      vdeq = QM_MULL_SHIFT(vdequant, viwt);
    }
    int16x8_t v_deq_abs = vreinterpretq_s16_u16(
        vshlq_u16(vreinterpretq_u16_s16(vmulq_s16(vtmp2, vdeq)), v_log_scale));
    vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign);
    coeff_nz_mask =
        vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[0]));
    store_s16q_to_tran_low(&dqcoeff_ptr[0], coeff_nz_mask);

    // Restore lane 0 to the AC table values for the remaining vectors.
    vround =
        vsetq_lane_s16(ROUND_POWER_OF_TWO(round_ptr[1], log_scale), vround, 0);
    vquant = vsetq_lane_s16(quant_ptr[1], vquant, 0);
    vdequant = vsetq_lane_s16(dequant_ptr[1], vdequant, 0);
    vquant_shift = vsetq_lane_s16(quant_shift_ptr[1], vquant_shift, 0);

    uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero);
    const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond);
    int16x8_t v_iscan = vld1q_s16(&iscan[0]);
    vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210));
    v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210);
  }
  vzbins = vsetq_lane_s16(zbins[1], vzbins, 0);

  for (int i = 8; i < n_coeffs; i += 8) {
    v_coeff = load_tran_low_to_s16q(&coeff_ptr[i]);
    v_coeff_sign = vshrq_n_s16(v_coeff, 15);
    v_abs = vabsq_s16(v_coeff);

    if (qm_ptr == NULL) {
      vcond = vcgeq_s16(v_abs, vzbins);
    } else {
      vwt = vmovl_u8(vld1_u8(&qm_ptr[i]));
      vcond = vcgeq_s16(QM_MULL_SHIFT(v_abs, vwt), vzbins);
    }
    nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0);
    if (nz_check) {
      int16x8_t vtmp = vqaddq_s16(v_abs, vround);

      int16x8_t vtmp2;
      if (qm_ptr == NULL) {
        vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1);
      } else {
        vtmp2 = QM_MULL_SHIFT(vtmp, vwt);
        vtmp2 = vaddq_s16(vtmp2, vtmp);
      }
      vtmp2 = vqdmulhq_s16(vtmp2, vquant_shift);

      int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign);
      int16x8_t coeff_nz_mask =
          vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[i]));
      store_s16q_to_tran_low(&qcoeff_ptr[i], coeff_nz_mask);

      // Weight a local copy of the dequant step; clobbering vdequant here
      // would carry this vector's iqm weights into later iterations.
      int16x8_t vdeq = vdequant;
      if (iqm_ptr != NULL) {
        viwt = vmovl_u8(vld1_u8(&iqm_ptr[i]));
        vdeq = QM_MULL_SHIFT(vdequant, viwt);
      }
      int16x8_t v_deq_abs = vreinterpretq_s16_u16(vshlq_u16(
          vreinterpretq_u16_s16(vmulq_s16(vtmp2, vdeq)), v_log_scale));
      vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign);
      coeff_nz_mask =
          vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[i]));
      store_s16q_to_tran_low(&dqcoeff_ptr[i], coeff_nz_mask);

      uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero);
      const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond);
      int16x8_t v_iscan = vld1q_s16(&iscan[i]);
      vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210));
      v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210);
    }
  }
  // iscan indices are zero-based; eob is one past the last nonzero index.
  *eob_ptr = get_max_eob(v_eobmax_76543210) + 1;
}
723 
/*
 * Quantizes and dequantizes n_coeffs coefficients of a 64x64 transform block
 * (log_scale == 2), with optional forward (qm_ptr) and inverse (iqm_ptr)
 * quantization-matrix weights. Writes quantized coefficients to qcoeff_ptr,
 * dequantized coefficients to dqcoeff_ptr, and the end-of-block index (index
 * of the last nonzero coefficient in scan order, plus one) to *eob_ptr.
 * Lane 0 of the first 8-lane vector is the DC coefficient and uses the [0]
 * entries of the quantization tables; all other coefficients use the AC [1]
 * entries. Because log_scale == 2, products are effectively shifted right by
 * 14 (quant) / 2 (dequant) instead of 16 / 0, which requires reassembling
 * bits from both the low and high halves of each 16x16 multiply below.
 */
static void aom_quantize_b_helper_64x64_neon(
    const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr,
    const int16_t *round_ptr, const int16_t *quant_ptr,
    const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr,
    tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr,
    const int16_t *scan, const int16_t *iscan, const qm_val_t *qm_ptr,
    const qm_val_t *iqm_ptr) {
  (void)scan;

  uint16x8_t vwt, viwt;
  const int log_scale = 2;
  // Every 16-bit lane holds -2: vshlq_u16 with a negative count right-shifts
  // the dequantized magnitudes by log_scale.
  const int16x8_t v_log_scale =
      vreinterpretq_s16_s64(vdupq_n_s64(0xFFFEFFFEFFFEFFFE));

  const int zbins[2] = { ROUND_POWER_OF_TWO(zbin_ptr[0], log_scale),
                         ROUND_POWER_OF_TWO(zbin_ptr[1], log_scale) };

  // Outputs default to zero; only lanes passing the zbin test are written.
  memset(qcoeff_ptr, 0, n_coeffs * sizeof(*qcoeff_ptr));
  memset(dqcoeff_ptr, 0, n_coeffs * sizeof(*dqcoeff_ptr));

  const int16x8_t zero = vdupq_n_s16(0);
  // Start the eob tracker at -1 in every lane; iscan values are >= 0.
  int16x8_t v_eobmax_76543210 = vreinterpretq_s16_u16(vceqq_s16(zero, zero));
  // +1 in every lane: used both as a left-shift count and as a bit mask.
  int16x8_t v_ones = vnegq_s16(v_eobmax_76543210);

  int16x8_t vzbins = vdupq_n_s16(zbins[1]),
            vround = vdupq_n_s16(ROUND_POWER_OF_TWO(round_ptr[1], log_scale));
  int16x8_t vdequant = vdupq_n_s16(dequant_ptr[1]);
  int16x8_t vquant = vdupq_n_s16(quant_ptr[1]);
  int16x8_t vquant_shift = vdupq_n_s16(quant_shift_ptr[1]);

  // First vector of 8: lane 0 (DC) uses the [0] table entries.
  int16x8_t v_coeff = load_tran_low_to_s16q(&coeff_ptr[0]);
  int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15);
  int16x8_t v_abs = vabsq_s16(v_coeff);
  vzbins = vsetq_lane_s16(zbins[0], vzbins, 0);
  uint16x8_t vcond;
  if (qm_ptr == NULL) {
    vcond = vcgeq_s16(v_abs, vzbins);
  } else {
    vwt = vmovl_u8(vld1_u8(&qm_ptr[0]));
    vcond = vcgeq_s16(QM_MULL_SHIFT(v_abs, vwt), vzbins);
  }
  uint64_t nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0);
  if (nz_check) {
    vround =
        vsetq_lane_s16(ROUND_POWER_OF_TWO(round_ptr[0], log_scale), vround, 0);
    vquant = vsetq_lane_s16(quant_ptr[0], vquant, 0);
    vdequant = vsetq_lane_s16(dequant_ptr[0], vdequant, 0);
    vquant_shift = vsetq_lane_s16(quant_shift_ptr[0], vquant_shift, 0);
    int16x8_t vtmp = vqaddq_s16(v_abs, vround);

    int16x8_t vtmp2;
    if (qm_ptr == NULL) {
      // tmp2 = tmp + ((tmp * quant) >> 16).
      vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1);
    } else {
      vtmp2 = QM_MULL_SHIFT(vtmp, vwt);
      vtmp2 = vaddq_s16(vtmp2, vtmp);
    }

    // (tmp2 * quant_shift) >> 14: the doubling-high-mul yields bits 16..31 of
    // 2*product; shift left by one and OR back in the bit it dropped
    // (bit 14 of the low product half).
    int16x8_t ones =
        vandq_s16(vshrq_n_s16(vmulq_s16(vtmp2, vquant_shift), 14), v_ones);
    vtmp2 =
        vaddq_s16(vshlq_s16(vqdmulhq_s16(vtmp2, vquant_shift), v_ones), ones);
    // Restore the sign of the quantized magnitude.
    int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign);
    int16x8_t coeff_nz_mask =
        vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[0]));
    store_s16q_to_tran_low(&qcoeff_ptr[0], coeff_nz_mask);

    // Apply the inverse quant matrix to a local copy so that vdequant keeps
    // the unweighted table values for the remaining iterations (overwriting
    // vdequant would compound the iqm weights across vectors).
    int16x8_t vdeq = vdequant;
    if (iqm_ptr != NULL) {
      viwt = vmovl_u8(vld1_u8(&iqm_ptr[0]));
      vdeq = QM_MULL_SHIFT(vdequant, viwt);
    }
    // (tmp2 * dequant) >> 2, reassembled from the low half shifted right by
    // log_scale and the high product bits shifted into position.
    int16x8_t v_deq_abs = vreinterpretq_s16_u16(
        vshlq_u16(vreinterpretq_u16_s16(vmulq_s16(vtmp2, vdeq)), v_log_scale));
    v_deq_abs =
        vorrq_s16(vshlq_n_s16(vqdmulhq_s16(vtmp2, vdeq), 13), v_deq_abs);
    vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign);
    coeff_nz_mask =
        vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[0]));
    store_s16q_to_tran_low(&dqcoeff_ptr[0], coeff_nz_mask);

    // Restore lane 0 to the AC table values for the remaining vectors.
    vround =
        vsetq_lane_s16(ROUND_POWER_OF_TWO(round_ptr[1], log_scale), vround, 0);
    vquant = vsetq_lane_s16(quant_ptr[1], vquant, 0);
    vdequant = vsetq_lane_s16(dequant_ptr[1], vdequant, 0);
    vquant_shift = vsetq_lane_s16(quant_shift_ptr[1], vquant_shift, 0);

    uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero);
    const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond);
    int16x8_t v_iscan = vld1q_s16(&iscan[0]);
    vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210));
    v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210);
  }
  vzbins = vsetq_lane_s16(zbins[1], vzbins, 0);

  for (int i = 8; i < n_coeffs; i += 8) {
    v_coeff = load_tran_low_to_s16q(&coeff_ptr[i]);
    v_coeff_sign = vshrq_n_s16(v_coeff, 15);
    v_abs = vabsq_s16(v_coeff);

    if (qm_ptr == NULL) {
      vcond = vcgeq_s16(v_abs, vzbins);
    } else {
      vwt = vmovl_u8(vld1_u8(&qm_ptr[i]));
      vcond = vcgeq_s16(QM_MULL_SHIFT(v_abs, vwt), vzbins);
    }
    nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0);
    if (nz_check) {
      int16x8_t vtmp = vqaddq_s16(v_abs, vround);

      int16x8_t vtmp2;
      if (qm_ptr == NULL) {
        vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1);
      } else {
        vtmp2 = QM_MULL_SHIFT(vtmp, vwt);
        vtmp2 = vaddq_s16(vtmp2, vtmp);
      }

      // (tmp2 * quant_shift) >> 14, see the first-vector block above.
      int16x8_t ones =
          vandq_s16(vshrq_n_s16(vmulq_s16(vtmp2, vquant_shift), 14), v_ones);
      vtmp2 =
          vaddq_s16(vshlq_s16(vqdmulhq_s16(vtmp2, vquant_shift), v_ones), ones);
      int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign);
      int16x8_t coeff_nz_mask =
          vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[i]));
      store_s16q_to_tran_low(&qcoeff_ptr[i], coeff_nz_mask);

      // Weight a local copy of the dequant step; clobbering vdequant here
      // would carry this vector's iqm weights into later iterations.
      int16x8_t vdeq = vdequant;
      if (iqm_ptr != NULL) {
        viwt = vmovl_u8(vld1_u8(&iqm_ptr[i]));
        vdeq = QM_MULL_SHIFT(vdequant, viwt);
      }
      int16x8_t v_deq_abs = vreinterpretq_s16_u16(vshlq_u16(
          vreinterpretq_u16_s16(vmulq_s16(vtmp2, vdeq)), v_log_scale));
      v_deq_abs =
          vorrq_s16(vshlq_n_s16(vqdmulhq_s16(vtmp2, vdeq), 13), v_deq_abs);
      vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign);
      coeff_nz_mask =
          vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[i]));
      store_s16q_to_tran_low(&dqcoeff_ptr[i], coeff_nz_mask);

      uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero);
      const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond);
      int16x8_t v_iscan = vld1q_s16(&iscan[i]);
      vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210));
      v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210);
    }
  }
  // iscan indices are zero-based; eob is one past the last nonzero index.
  *eob_ptr = get_max_eob(v_eobmax_76543210) + 1;
}
872 
/*
 * Dispatches quantize_b to the NEON helper matching the transform size's
 * log_scale (0: up to 16x16, 1: 32x32, 2: 64x64). All pointer arguments are
 * forwarded unchanged; qm_ptr/iqm_ptr may be NULL when quantization matrices
 * are not in use.
 */
void aom_quantize_b_helper_neon(
    const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr,
    const int16_t *round_ptr, const int16_t *quant_ptr,
    const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr,
    tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr,
    const int16_t *scan, const int16_t *iscan, const qm_val_t *qm_ptr,
    const qm_val_t *iqm_ptr, const int log_scale) {
  switch (log_scale) {  // log_scale for AV1 encoder can be only 0, 1, 2
    case 0:
      aom_quantize_b_helper_16x16_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr,
                                       quant_ptr, quant_shift_ptr, qcoeff_ptr,
                                       dqcoeff_ptr, dequant_ptr, eob_ptr, scan,
                                       iscan, qm_ptr, iqm_ptr);
      break;
    case 1:
      aom_quantize_b_helper_32x32_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr,
                                       quant_ptr, quant_shift_ptr, qcoeff_ptr,
                                       dqcoeff_ptr, dequant_ptr, eob_ptr, scan,
                                       iscan, qm_ptr, iqm_ptr);
      break;
    case 2:
      aom_quantize_b_helper_64x64_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr,
                                       quant_ptr, quant_shift_ptr, qcoeff_ptr,
                                       dqcoeff_ptr, dequant_ptr, eob_ptr, scan,
                                       iscan, qm_ptr, iqm_ptr);
      break;
    default:
      // Fail fast in debug builds: an invalid log_scale would otherwise
      // silently leave qcoeff/dqcoeff and *eob_ptr unwritten.
      assert(0 && "log_scale must be 0, 1 or 2");
      break;
  }
}
901 
// Quantizes a 32x32 transform block. Thin wrapper that forwards to the
// generic helper with log_scale == 1 and no quantization matrices.
void aom_quantize_b_32x32_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs,
                               const int16_t *zbin_ptr,
                               const int16_t *round_ptr,
                               const int16_t *quant_ptr,
                               const int16_t *quant_shift_ptr,
                               tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr,
                               const int16_t *dequant_ptr, uint16_t *eob_ptr,
                               const int16_t *scan, const int16_t *iscan) {
  const int log_scale = 1;  // 32x32 transforms use one extra scaling bit.
  aom_quantize_b_helper_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr,
                             quant_ptr, quant_shift_ptr, qcoeff_ptr,
                             dqcoeff_ptr, dequant_ptr, eob_ptr, scan, iscan,
                             /*qm_ptr=*/NULL, /*iqm_ptr=*/NULL, log_scale);
}
915 
// Quantizes a 64x64 transform block. Thin wrapper that forwards to the
// generic helper with log_scale == 2 and no quantization matrices.
void aom_quantize_b_64x64_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs,
                               const int16_t *zbin_ptr,
                               const int16_t *round_ptr,
                               const int16_t *quant_ptr,
                               const int16_t *quant_shift_ptr,
                               tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr,
                               const int16_t *dequant_ptr, uint16_t *eob_ptr,
                               const int16_t *scan, const int16_t *iscan) {
  const int log_scale = 2;  // 64x64 transforms use two extra scaling bits.
  aom_quantize_b_helper_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr,
                             quant_ptr, quant_shift_ptr, qcoeff_ptr,
                             dqcoeff_ptr, dequant_ptr, eob_ptr, scan, iscan,
                             /*qm_ptr=*/NULL, /*iqm_ptr=*/NULL, log_scale);
}
929