xref: /aosp_15_r20/external/libaom/aom_dsp/x86/aom_quantize_avx.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <immintrin.h>
13 
14 #include "config/aom_dsp_rtcd.h"
15 #include "aom/aom_integer.h"
16 #include "aom_dsp/x86/bitdepth_conversion_sse2.h"
17 #include "aom_dsp/x86/quantize_x86.h"
18 
/*
 * Dequantizes eight 16-bit quantized coefficients and stores the results as
 * 32-bit values.
 *
 * Each lane's signed 16x16-bit product is formed from its low 16 bits
 * (_mm_mullo_epi16) and its high 16 bits (_mm_mulhi_epi16). Interleaving the
 * two halves rebuilds the full 32-bit product: lanes 0-3 are written to
 * dqcoeff[0..3] and lanes 4-7 to dqcoeff[4..7].
 *
 * qcoeff:  eight signed 16-bit quantized coefficients.
 * dequant: eight signed 16-bit dequantization factors, lane-matched to qcoeff.
 * dqcoeff: destination. It is written with aligned 128-bit stores, so it must
 *          be 16-byte aligned.
 *
 * NOTE(review): this assumes tran_low_t is 32 bits wide (four products per
 * 128-bit store); confirm against its typedef.
 */
static inline void calculate_dqcoeff_and_store(__m128i qcoeff, __m128i dequant,
                                               tran_low_t *dqcoeff) {
  const __m128i prod_hi16 = _mm_mulhi_epi16(qcoeff, dequant);
  const __m128i prod_lo16 = _mm_mullo_epi16(qcoeff, dequant);

  // (lo16, hi16) pairs interleave into little-endian 32-bit products.
  _mm_store_si128((__m128i *)dqcoeff,
                  _mm_unpacklo_epi16(prod_lo16, prod_hi16));
  _mm_store_si128((__m128i *)(dqcoeff + 4),
                  _mm_unpackhi_epi16(prod_lo16, prod_hi16));
}
30 
/*
 * Low-bit-depth quantizer (SSSE3/AVX intrinsics) for the regular, non-32x32
 * path.
 *
 * Coefficients are processed in groups of 16. Each group is loaded as two
 * vectors of eight 16-bit values via load_tran_low(). For every coefficient c:
 *   - if |c| is not greater than the zero-bin vector prepared by
 *     load_b_values(), qcoeff and dqcoeff receive 0;
 *   - otherwise, |c| is run through calculate_qcoeff() with the round, quant
 *     and shift vectors and the sign of c is restored. dqcoeff is then
 *     qcoeff * dequant, formed as a full 32-bit product.
 *
 * Parameter lanes:
 *   - Coefficients 0..7 use the zbin/round/quant/shift/dequant lanes exactly
 *     as loaded.
 *   - Every later coefficient uses the high 64 bits of each vector, broadcast
 *     to both halves. The original code labels this "Switch DC to AC";
 *     presumably lane 0 holds the DC value (confirm against the table layout).
 *
 * Preconditions visible from the code:
 *   - n_coeffs >= 16 and a multiple of 16. The first group is always
 *     processed and the loop steps by 16.
 *   - qcoeff_ptr and dqcoeff_ptr are 32-byte aligned (aligned 256-bit stores).
 *   - iscan is readable for the first n_coeffs entries (consumed by
 *     scan_for_eob()).
 *
 * `scan` is unused. *eob_ptr is set to 0 up front. It is then left at 0 by the
 * n_coeffs == 16 all-zero early return, or set to accumulate_eob() of the
 * running lane-wise maximum of scan_for_eob().
 */
void aom_quantize_b_avx(const tran_low_t *coeff_ptr, intptr_t n_coeffs,
                        const int16_t *zbin_ptr, const int16_t *round_ptr,
                        const int16_t *quant_ptr,
                        const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr,
                        tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr,
                        uint16_t *eob_ptr, const int16_t *scan,
                        const int16_t *iscan) {
  const __m128i vzero = _mm_setzero_si128();
  const __m256i vzero256 = _mm256_setzero_si256();
  __m128i thresh, rnd, qmul, dqmul, qshift;
  __m128i src_lo, src_hi, q_lo, q_hi, keep_lo, keep_hi, any_kept;
  __m128i eob_max = vzero;
  int pos;

  (void)scan;

  *eob_ptr = 0;

  load_b_values(zbin_ptr, &thresh, round_ptr, &rnd, quant_ptr, &qmul,
                dequant_ptr, &dqmul, quant_shift_ptr, &qshift);

  /* ---- Group 0: coefficients 0..15 (DC and the first 15 AC). ---- */
  src_lo = load_tran_low(coeff_ptr);
  src_hi = load_tran_low(coeff_ptr + 8);
  q_lo = _mm_abs_epi16(src_lo);
  q_hi = _mm_abs_epi16(src_hi);

  // The low vector is compared against the thresholds as loaded. The high
  // vector uses the broadcast high 64 bits ("Switch DC to AC").
  keep_lo = _mm_cmpgt_epi16(q_lo, thresh);
  thresh = _mm_unpackhi_epi64(thresh, thresh);
  keep_hi = _mm_cmpgt_epi16(q_hi, thresh);
  any_kept = _mm_or_si128(keep_lo, keep_hi);

  if (!_mm_test_all_zeros(any_kept, any_kept)) {
    // Quantize the low vector with the loaded parameters. Then switch the
    // quantizer parameters before quantizing the high vector.
    calculate_qcoeff(&q_lo, rnd, qmul, qshift);
    rnd = _mm_unpackhi_epi64(rnd, rnd);
    qmul = _mm_unpackhi_epi64(qmul, qmul);
    qshift = _mm_unpackhi_epi64(qshift, qshift);
    calculate_qcoeff(&q_hi, rnd, qmul, qshift);

    // Restore input signs, then clear lanes that failed the zero-bin test.
    q_lo = _mm_and_si128(_mm_sign_epi16(q_lo, src_lo), keep_lo);
    q_hi = _mm_and_si128(_mm_sign_epi16(q_hi, src_hi), keep_hi);

    store_tran_low(q_lo, qcoeff_ptr);
    store_tran_low(q_hi, qcoeff_ptr + 8);

    // The dequantizer is switched the same way: loaded lanes first, then the
    // broadcast high half.
    calculate_dqcoeff_and_store(q_lo, dqmul, dqcoeff_ptr);
    dqmul = _mm_unpackhi_epi64(dqmul, dqmul);
    calculate_dqcoeff_and_store(q_hi, dqmul, dqcoeff_ptr + 8);

    eob_max = scan_for_eob(&q_lo, &q_hi, keep_lo, keep_hi, iscan, 0, vzero);
  } else {
    // Whole group is inside the zero bin: write zeros and skip the math.
    _mm256_store_si256((__m256i *)qcoeff_ptr, vzero256);
    _mm256_store_si256((__m256i *)(qcoeff_ptr + 8), vzero256);
    _mm256_store_si256((__m256i *)dqcoeff_ptr, vzero256);
    _mm256_store_si256((__m256i *)(dqcoeff_ptr + 8), vzero256);

    // A 16-coefficient block is finished; *eob_ptr already holds 0.
    if (n_coeffs == 16) return;

    // Move every parameter to its broadcast high half for the AC loop.
    rnd = _mm_unpackhi_epi64(rnd, rnd);
    qmul = _mm_unpackhi_epi64(qmul, qmul);
    qshift = _mm_unpackhi_epi64(qshift, qshift);
    dqmul = _mm_unpackhi_epi64(dqmul, dqmul);
  }

  /* ---- Groups 1..: AC only; parameters are now uniform across lanes. ---- */
  for (pos = 16; pos < n_coeffs; pos += 16) {
    tran_low_t *const qdst = qcoeff_ptr + pos;
    tran_low_t *const dqdst = dqcoeff_ptr + pos;

    src_lo = load_tran_low(coeff_ptr + pos);
    src_hi = load_tran_low(coeff_ptr + pos + 8);
    q_lo = _mm_abs_epi16(src_lo);
    q_hi = _mm_abs_epi16(src_hi);
    keep_lo = _mm_cmpgt_epi16(q_lo, thresh);
    keep_hi = _mm_cmpgt_epi16(q_hi, thresh);
    any_kept = _mm_or_si128(keep_lo, keep_hi);

    if (_mm_test_all_zeros(any_kept, any_kept)) {
      _mm256_store_si256((__m256i *)qdst, vzero256);
      _mm256_store_si256((__m256i *)(qdst + 8), vzero256);
      _mm256_store_si256((__m256i *)dqdst, vzero256);
      _mm256_store_si256((__m256i *)(dqdst + 8), vzero256);
    } else {
      calculate_qcoeff(&q_lo, rnd, qmul, qshift);
      calculate_qcoeff(&q_hi, rnd, qmul, qshift);

      q_lo = _mm_and_si128(_mm_sign_epi16(q_lo, src_lo), keep_lo);
      q_hi = _mm_and_si128(_mm_sign_epi16(q_hi, src_hi), keep_hi);

      store_tran_low(q_lo, qdst);
      store_tran_low(q_hi, qdst + 8);

      calculate_dqcoeff_and_store(q_lo, dqmul, dqdst);
      calculate_dqcoeff_and_store(q_hi, dqmul, dqdst + 8);

      eob_max = _mm_max_epi16(
          eob_max,
          scan_for_eob(&q_lo, &q_hi, keep_lo, keep_hi, iscan, pos, vzero));
    }
  }

  *eob_ptr = accumulate_eob(eob_max);
}
148 
/*
 * Low-bit-depth quantizer (SSSE3/AVX intrinsics) for 32x32-scaled transforms
 * (log scale 1).
 *
 * Differences from the regular path:
 *   - zbin and round are halved with rounding, i.e. (x + 1) >> 1, using a
 *     logical 16-bit shift.
 *   - zbin is then lowered by one, so the strict "greater than" compare keeps
 *     coefficients with |c| >= the halved zbin.
 *   - Quantization and dequantization go through the *_log_scale helpers with
 *     a log scale of 1. NOTE(review): their exact arithmetic lives in
 *     quantize_x86.h; presumably dqcoeff is divided by 2^log_scale — confirm.
 *
 * Lane switching, the zero-bin skip, sign restoration and EOB tracking are
 * the same as in the regular path:
 *   - Coefficients 0..7 use the parameter lanes as loaded.
 *   - Later coefficients use the high 64 bits broadcast to both halves
 *     ("Switch DC to AC").
 *
 * Preconditions visible from the code:
 *   - zbin_ptr, round_ptr, quant_ptr, dequant_ptr and quant_shift_ptr are
 *     16-byte aligned (aligned 128-bit loads).
 *   - qcoeff_ptr and dqcoeff_ptr are 32-byte aligned (aligned 256-bit stores).
 *   - n_coeffs >= 16 and a multiple of 16.
 *
 * `scan` is unused. *eob_ptr is written once, at the end, with
 * accumulate_eob() of the running lane-wise maximum of scan_for_eob().
 */
void aom_quantize_b_32x32_avx(const tran_low_t *coeff_ptr, intptr_t n_coeffs,
                              const int16_t *zbin_ptr, const int16_t *round_ptr,
                              const int16_t *quant_ptr,
                              const int16_t *quant_shift_ptr,
                              tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr,
                              const int16_t *dequant_ptr, uint16_t *eob_ptr,
                              const int16_t *scan, const int16_t *iscan) {
  const __m128i vzero = _mm_setzero_si128();
  const __m128i vone = _mm_set1_epi16(1);
  const __m256i vzero256 = _mm256_setzero_si256();
  const int scale_bits = 1;  // log scale passed to the *_log_scale helpers
  __m128i thresh, rnd, qmul, dqmul, qshift;
  __m128i src_lo, src_hi, q_lo, q_hi, keep_lo, keep_hi, any_kept;
  __m128i eob_max = vzero;
  int pos;

  (void)scan;

  // Halve zbin with rounding. Then subtract one: x86 has no "greater or
  // equal" compare, so the strict cmpgt below acts as >= the halved zbin.
  thresh = _mm_srli_epi16(
      _mm_add_epi16(_mm_load_si128((const __m128i *)zbin_ptr), vone), 1);
  thresh = _mm_sub_epi16(thresh, vone);

  // Round is halved with rounding as well.
  rnd = _mm_srli_epi16(
      _mm_add_epi16(_mm_load_si128((const __m128i *)round_ptr), vone), 1);

  qmul = _mm_load_si128((const __m128i *)quant_ptr);
  dqmul = _mm_load_si128((const __m128i *)dequant_ptr);
  qshift = _mm_load_si128((const __m128i *)quant_shift_ptr);

  /* ---- Group 0: coefficients 0..15 (DC and the first 15 AC). ---- */
  src_lo = load_tran_low(coeff_ptr);
  src_hi = load_tran_low(coeff_ptr + 8);
  q_lo = _mm_abs_epi16(src_lo);
  q_hi = _mm_abs_epi16(src_hi);

  keep_lo = _mm_cmpgt_epi16(q_lo, thresh);
  thresh = _mm_unpackhi_epi64(thresh, thresh);  // Switch DC to AC.
  keep_hi = _mm_cmpgt_epi16(q_hi, thresh);
  any_kept = _mm_or_si128(keep_lo, keep_hi);

  if (!_mm_test_all_zeros(any_kept, any_kept)) {
    // Low vector uses the loaded parameters. Switch before the high vector.
    calculate_qcoeff_log_scale(&q_lo, rnd, qmul, &qshift, &scale_bits);
    rnd = _mm_unpackhi_epi64(rnd, rnd);
    qmul = _mm_unpackhi_epi64(qmul, qmul);
    qshift = _mm_unpackhi_epi64(qshift, qshift);
    calculate_qcoeff_log_scale(&q_hi, rnd, qmul, &qshift, &scale_bits);

    // Restore input signs, then clear lanes that failed the zero-bin test.
    q_lo = _mm_and_si128(_mm_sign_epi16(q_lo, src_lo), keep_lo);
    q_hi = _mm_and_si128(_mm_sign_epi16(q_hi, src_hi), keep_hi);

    store_tran_low(q_lo, qcoeff_ptr);
    store_tran_low(q_hi, qcoeff_ptr + 8);

    calculate_dqcoeff_and_store_log_scale(q_lo, dqmul, vzero, dqcoeff_ptr,
                                          &scale_bits);
    dqmul = _mm_unpackhi_epi64(dqmul, dqmul);
    calculate_dqcoeff_and_store_log_scale(q_hi, dqmul, vzero, dqcoeff_ptr + 8,
                                          &scale_bits);

    eob_max = scan_for_eob(&q_lo, &q_hi, keep_lo, keep_hi, iscan, 0, vzero);
  } else {
    // Whole group is inside the zero bin: write zeros and skip the math.
    _mm256_store_si256((__m256i *)qcoeff_ptr, vzero256);
    _mm256_store_si256((__m256i *)(qcoeff_ptr + 8), vzero256);
    _mm256_store_si256((__m256i *)dqcoeff_ptr, vzero256);
    _mm256_store_si256((__m256i *)(dqcoeff_ptr + 8), vzero256);

    // Move every parameter to its broadcast high half for the AC loop.
    rnd = _mm_unpackhi_epi64(rnd, rnd);
    qmul = _mm_unpackhi_epi64(qmul, qmul);
    qshift = _mm_unpackhi_epi64(qshift, qshift);
    dqmul = _mm_unpackhi_epi64(dqmul, dqmul);
  }

  /* ---- Groups 1..: AC only; parameters are now uniform across lanes. ---- */
  for (pos = 16; pos < n_coeffs; pos += 16) {
    tran_low_t *const qdst = qcoeff_ptr + pos;
    tran_low_t *const dqdst = dqcoeff_ptr + pos;

    src_lo = load_tran_low(coeff_ptr + pos);
    src_hi = load_tran_low(coeff_ptr + pos + 8);
    q_lo = _mm_abs_epi16(src_lo);
    q_hi = _mm_abs_epi16(src_hi);
    keep_lo = _mm_cmpgt_epi16(q_lo, thresh);
    keep_hi = _mm_cmpgt_epi16(q_hi, thresh);
    any_kept = _mm_or_si128(keep_lo, keep_hi);

    if (_mm_test_all_zeros(any_kept, any_kept)) {
      _mm256_store_si256((__m256i *)qdst, vzero256);
      _mm256_store_si256((__m256i *)(qdst + 8), vzero256);
      _mm256_store_si256((__m256i *)dqdst, vzero256);
      _mm256_store_si256((__m256i *)(dqdst + 8), vzero256);
    } else {
      calculate_qcoeff_log_scale(&q_lo, rnd, qmul, &qshift, &scale_bits);
      calculate_qcoeff_log_scale(&q_hi, rnd, qmul, &qshift, &scale_bits);

      q_lo = _mm_and_si128(_mm_sign_epi16(q_lo, src_lo), keep_lo);
      q_hi = _mm_and_si128(_mm_sign_epi16(q_hi, src_hi), keep_hi);

      store_tran_low(q_lo, qdst);
      store_tran_low(q_hi, qdst + 8);

      calculate_dqcoeff_and_store_log_scale(q_lo, dqmul, vzero, dqdst,
                                            &scale_bits);
      calculate_dqcoeff_and_store_log_scale(q_hi, dqmul, vzero, dqdst + 8,
                                            &scale_bits);

      eob_max = _mm_max_epi16(
          eob_max,
          scan_for_eob(&q_lo, &q_hi, keep_lo, keep_hi, iscan, pos, vzero));
    }
  }

  *eob_ptr = accumulate_eob(eob_max);
}
283