/*
 *  Copyright (c) 2022 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#ifndef VPX_VPX_DSP_ARM_FDCT_NEON_H_
#define VPX_VPX_DSP_ARM_FDCT_NEON_H_

#include <arm_neon.h>

// fdct_round_shift((a +/- b) * c)
// Variant that performs the fast vqrdmulh_s16 operation on a half vector.
// Can be slightly less accurate; adequate for pass 1.
static INLINE void butterfly_one_coeff_s16_fast_half(const int16x4_t a,
                                                     const int16x4_t b,
                                                     const tran_coef_t constant,
                                                     int16x4_t *add,
                                                     int16x4_t *sub) {
  int16x4_t c = vdup_n_s16(2 * constant);
  *add = vqrdmulh_s16(vadd_s16(a, b), c);
  *sub = vqrdmulh_s16(vsub_s16(a, b), c);
}
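
// Why the "fast" s16 variants work (a sketch of the arithmetic, assuming
// DCT_CONST_BITS == 14 as in vpx_dsp/txfm_common.h):
//   vqrdmulh_s16(x, 2 * c) = (2 * x * (2 * c) + (1 << 15)) >> 16
//                          = (x * c + (1 << 13)) >> 14
//                          = fdct_round_shift(x * c)
// so the result matches the scalar reference as long as neither the a +/- b
// addition nor the saturating multiply overflows int16, which is why these
// variants are only recommended for the small pass-1 inputs.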

// fdct_round_shift((a +/- b) * c)
// Variant that performs the fast vqrdmulh_s16 operation on a full vector.
// Can be slightly less accurate; adequate for pass 1.
static INLINE void butterfly_one_coeff_s16_fast(const int16x8_t a,
                                                const int16x8_t b,
                                                const tran_coef_t constant,
                                                int16x8_t *add,
                                                int16x8_t *sub) {
  int16x8_t c = vdupq_n_s16(2 * constant);
  *add = vqrdmulhq_s16(vaddq_s16(a, b), c);
  *sub = vqrdmulhq_s16(vsubq_s16(a, b), c);
}

// fdct_round_shift((a +/- b) * c)
// Variant that performs the fast vqrdmulhq_s32 operation on a full vector.
// More accurate: does 32-bit processing, takes 16-bit input values,
// returns full 32-bit values, high/low.
static INLINE void butterfly_one_coeff_s16_s32_fast(
    const int16x8_t a, const int16x8_t b, const tran_coef_t constant,
    int32x4_t *add_lo, int32x4_t *add_hi, int32x4_t *sub_lo,
    int32x4_t *sub_hi) {
  int32x4_t c = vdupq_n_s32(constant << 17);
  const int16x4_t a_lo = vget_low_s16(a);
  const int16x4_t a_hi = vget_high_s16(a);
  const int16x4_t b_lo = vget_low_s16(b);
  const int16x4_t b_hi = vget_high_s16(b);
  *add_lo = vqrdmulhq_s32(vaddl_s16(a_lo, b_lo), c);
  *add_hi = vqrdmulhq_s32(vaddl_s16(a_hi, b_hi), c);
  *sub_lo = vqrdmulhq_s32(vsubl_s16(a_lo, b_lo), c);
  *sub_hi = vqrdmulhq_s32(vsubl_s16(a_hi, b_hi), c);
}
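
// Why the "fast" s32 variants scale the constant by << 17 (a sketch of the
// arithmetic, again assuming DCT_CONST_BITS == 14):
//   vqrdmulhq_s32(x, c << 17) = (2 * x * (c << 17) + (1 << 31)) >> 32
//                             = (x * c + (1 << 13)) >> 14
//                             = fdct_round_shift(x * c)
// up to the saturating behaviour of vqrdmulhq_s32. Because the a +/- b sums
// are widened to 32 bits first, they cannot wrap, which is why these variants
// are described as more accurate than the s16 ones.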

// fdct_round_shift((a +/- b) * c)
// Variant that performs the fast vqrdmulhq_s32 operation on a full vector.
// More accurate: does 32-bit processing, takes 16-bit input values,
// returns narrowed down 16-bit values.
static INLINE void butterfly_one_coeff_s16_s32_fast_narrow(
    const int16x8_t a, const int16x8_t b, const tran_coef_t constant,
    int16x8_t *add, int16x8_t *sub) {
  int32x4_t add_lo, add_hi, sub_lo, sub_hi;
  butterfly_one_coeff_s16_s32_fast(a, b, constant, &add_lo, &add_hi, &sub_lo,
                                   &sub_hi);
  *add = vcombine_s16(vmovn_s32(add_lo), vmovn_s32(add_hi));
  *sub = vcombine_s16(vmovn_s32(sub_lo), vmovn_s32(sub_hi));
}

// fdct_round_shift((a +/- b) * c)
// Variant that performs the fast vqrdmulhq_s32 operation on a half vector.
// More accurate: does 32-bit processing, takes 16-bit input values,
// returns full 32-bit values.
static INLINE void butterfly_one_coeff_s16_s32_fast_half(
    const int16x4_t a, const int16x4_t b, const tran_coef_t constant,
    int32x4_t *add, int32x4_t *sub) {
  int32x4_t c = vdupq_n_s32(constant << 17);
  *add = vqrdmulhq_s32(vaddl_s16(a, b), c);
  *sub = vqrdmulhq_s32(vsubl_s16(a, b), c);
}

// fdct_round_shift((a +/- b) * c)
// Variant that performs the fast vqrdmulhq_s32 operation on a half vector.
// More accurate: does 32-bit processing, takes 16-bit input values,
// returns narrowed down 16-bit values.
static INLINE void butterfly_one_coeff_s16_s32_fast_narrow_half(
    const int16x4_t a, const int16x4_t b, const tran_coef_t constant,
    int16x4_t *add, int16x4_t *sub) {
  int32x4_t add32, sub32;
  butterfly_one_coeff_s16_s32_fast_half(a, b, constant, &add32, &sub32);
  *add = vmovn_s32(add32);
  *sub = vmovn_s32(sub32);
}

// fdct_round_shift((a +/- b) * c)
// Original variant that performs the normal implementation on a full vector.
// Fully accurate: does 32-bit processing, takes 16-bit values,
// returns full 32-bit values, high/low.
static INLINE void butterfly_one_coeff_s16_s32(
    const int16x8_t a, const int16x8_t b, const tran_coef_t constant,
    int32x4_t *add_lo, int32x4_t *add_hi, int32x4_t *sub_lo,
    int32x4_t *sub_hi) {
  const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), constant);
  const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), constant);
  const int32x4_t sum0 = vmlal_n_s16(a0, vget_low_s16(b), constant);
  const int32x4_t sum1 = vmlal_n_s16(a1, vget_high_s16(b), constant);
  const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), constant);
  const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), constant);
  *add_lo = vrshrq_n_s32(sum0, DCT_CONST_BITS);
  *add_hi = vrshrq_n_s32(sum1, DCT_CONST_BITS);
  *sub_lo = vrshrq_n_s32(diff0, DCT_CONST_BITS);
  *sub_hi = vrshrq_n_s32(diff1, DCT_CONST_BITS);
}
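
// Note on rounding: vrshrq_n_s32(x, DCT_CONST_BITS) computes
// (x + (1 << (DCT_CONST_BITS - 1))) >> DCT_CONST_BITS, which is exactly
// ROUND_POWER_OF_TWO(x, DCT_CONST_BITS), i.e. fdct_round_shift(x) from the
// scalar reference code.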

// fdct_round_shift((a +/- b) * c)
// Original variant that performs the normal implementation on a full vector.
// Fully accurate: does 32-bit processing, takes 16-bit values,
// returns narrowed down 16-bit values.
static INLINE void butterfly_one_coeff_s16_s32_narrow(
    const int16x8_t a, const int16x8_t b, const tran_coef_t constant,
    int16x8_t *add, int16x8_t *sub) {
  int32x4_t add32_lo, add32_hi, sub32_lo, sub32_hi;
  butterfly_one_coeff_s16_s32(a, b, constant, &add32_lo, &add32_hi, &sub32_lo,
                              &sub32_hi);
  *add = vcombine_s16(vmovn_s32(add32_lo), vmovn_s32(add32_hi));
  *sub = vcombine_s16(vmovn_s32(sub32_lo), vmovn_s32(sub32_hi));
}
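
// Illustrative usage sketch (not part of the libvpx API): the one-coefficient
// butterflies above compute
//   *add = fdct_round_shift((a + b) * constant)
//   *sub = fdct_round_shift((a - b) * constant)
// For example, assuming the usual transform constants from
// vpx_dsp/txfm_common.h (e.g. cospi_16_64) are in scope, as they are in the
// fdct sources that include this header, a cospi_16_64 stage of a forward DCT
// could be sketched as:
static INLINE void example_one_coeff_stage(const int16x8_t s0,
                                           const int16x8_t s1,
                                           int16x8_t *out_sum,
                                           int16x8_t *out_diff) {
  // Hypothetical helper for illustration only.
  butterfly_one_coeff_s16_fast(s0, s1, cospi_16_64, out_sum, out_diff);
}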

// (a +/- b) * c
// Variant that performs the normal multiply-accumulate on a full vector,
// without the final rounding shift.
// Does 32-bit processing, takes and returns 32-bit values, high/low.
static INLINE void butterfly_one_coeff_s32_noround(
    const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
    const int32x4_t b_hi, const tran_coef_t constant, int32x4_t *add_lo,
    int32x4_t *add_hi, int32x4_t *sub_lo, int32x4_t *sub_hi) {
  const int32x4_t a1 = vmulq_n_s32(a_lo, constant);
  const int32x4_t a2 = vmulq_n_s32(a_hi, constant);
  const int32x4_t a3 = vmulq_n_s32(a_lo, constant);
  const int32x4_t a4 = vmulq_n_s32(a_hi, constant);
  *add_lo = vmlaq_n_s32(a1, b_lo, constant);
  *add_hi = vmlaq_n_s32(a2, b_hi, constant);
  *sub_lo = vmlsq_n_s32(a3, b_lo, constant);
  *sub_hi = vmlsq_n_s32(a4, b_hi, constant);
}

// fdct_round_shift((a +/- b) * c)
// Variant that performs the fast vqrdmulhq_s32 operation on a half vector.
// More accurate: does 32-bit processing, takes and returns 32-bit values.
static INLINE void butterfly_one_coeff_s32_fast_half(const int32x4_t a,
                                                     const int32x4_t b,
                                                     const tran_coef_t constant,
                                                     int32x4_t *add,
                                                     int32x4_t *sub) {
  const int32x4_t c = vdupq_n_s32(constant << 17);
  *add = vqrdmulhq_s32(vaddq_s32(a, b), c);
  *sub = vqrdmulhq_s32(vsubq_s32(a, b), c);
}

// fdct_round_shift((a +/- b) * c)
// Variant that performs the fast vqrdmulhq_s32 operation on a full vector.
// More accurate: does 32-bit processing, takes and returns 32-bit values,
// high/low.
static INLINE void butterfly_one_coeff_s32_fast(
    const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
    const int32x4_t b_hi, const tran_coef_t constant, int32x4_t *add_lo,
    int32x4_t *add_hi, int32x4_t *sub_lo, int32x4_t *sub_hi) {
  const int32x4_t c = vdupq_n_s32(constant << 17);
  *add_lo = vqrdmulhq_s32(vaddq_s32(a_lo, b_lo), c);
  *add_hi = vqrdmulhq_s32(vaddq_s32(a_hi, b_hi), c);
  *sub_lo = vqrdmulhq_s32(vsubq_s32(a_lo, b_lo), c);
  *sub_hi = vqrdmulhq_s32(vsubq_s32(a_hi, b_hi), c);
}

// fdct_round_shift((a +/- b) * c)
// Variant that performs the normal implementation on a full vector.
// More accurate: does 64-bit processing, takes and returns 32-bit values,
// returns narrowed results.
static INLINE void butterfly_one_coeff_s32_s64_narrow(
    const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
    const int32x4_t b_hi, const tran_coef_t constant, int32x4_t *add_lo,
    int32x4_t *add_hi, int32x4_t *sub_lo, int32x4_t *sub_hi) {
  // ac holds the following values:
  // ac: vget_low_s32(a_lo) * c, vget_high_s32(a_lo) * c,
  //     vget_low_s32(a_hi) * c, vget_high_s32(a_hi) * c
  int64x2_t ac[4];
  int64x2_t sum[4];
  int64x2_t diff[4];

  ac[0] = vmull_n_s32(vget_low_s32(a_lo), constant);
  ac[1] = vmull_n_s32(vget_high_s32(a_lo), constant);
  ac[2] = vmull_n_s32(vget_low_s32(a_hi), constant);
  ac[3] = vmull_n_s32(vget_high_s32(a_hi), constant);

  sum[0] = vmlal_n_s32(ac[0], vget_low_s32(b_lo), constant);
  sum[1] = vmlal_n_s32(ac[1], vget_high_s32(b_lo), constant);
  sum[2] = vmlal_n_s32(ac[2], vget_low_s32(b_hi), constant);
  sum[3] = vmlal_n_s32(ac[3], vget_high_s32(b_hi), constant);
  *add_lo = vcombine_s32(vrshrn_n_s64(sum[0], DCT_CONST_BITS),
                         vrshrn_n_s64(sum[1], DCT_CONST_BITS));
  *add_hi = vcombine_s32(vrshrn_n_s64(sum[2], DCT_CONST_BITS),
                         vrshrn_n_s64(sum[3], DCT_CONST_BITS));

  diff[0] = vmlsl_n_s32(ac[0], vget_low_s32(b_lo), constant);
  diff[1] = vmlsl_n_s32(ac[1], vget_high_s32(b_lo), constant);
  diff[2] = vmlsl_n_s32(ac[2], vget_low_s32(b_hi), constant);
  diff[3] = vmlsl_n_s32(ac[3], vget_high_s32(b_hi), constant);
  *sub_lo = vcombine_s32(vrshrn_n_s64(diff[0], DCT_CONST_BITS),
                         vrshrn_n_s64(diff[1], DCT_CONST_BITS));
  *sub_hi = vcombine_s32(vrshrn_n_s64(diff[2], DCT_CONST_BITS),
                         vrshrn_n_s64(diff[3], DCT_CONST_BITS));
}

// fdct_round_shift(a * c1 +/- b * c2)
// Variant that performs the normal implementation on a half vector.
// More accurate: does 64-bit processing, takes and returns 32-bit values,
// returns narrowed results.
static INLINE void butterfly_two_coeff_s32_s64_narrow_half(
    const int32x4_t a, const int32x4_t b, const tran_coef_t constant1,
    const tran_coef_t constant2, int32x4_t *add, int32x4_t *sub) {
  const int32x2_t a_lo = vget_low_s32(a);
  const int32x2_t a_hi = vget_high_s32(a);
  const int32x2_t b_lo = vget_low_s32(b);
  const int32x2_t b_hi = vget_high_s32(b);

  const int64x2_t axc0_64_lo = vmull_n_s32(a_lo, constant1);
  const int64x2_t axc0_64_hi = vmull_n_s32(a_hi, constant1);
  const int64x2_t axc1_64_lo = vmull_n_s32(a_lo, constant2);
  const int64x2_t axc1_64_hi = vmull_n_s32(a_hi, constant2);

  const int64x2_t sum_lo = vmlal_n_s32(axc0_64_lo, b_lo, constant2);
  const int64x2_t sum_hi = vmlal_n_s32(axc0_64_hi, b_hi, constant2);
  const int64x2_t diff_lo = vmlsl_n_s32(axc1_64_lo, b_lo, constant1);
  const int64x2_t diff_hi = vmlsl_n_s32(axc1_64_hi, b_hi, constant1);

  *add = vcombine_s32(vrshrn_n_s64(sum_lo, DCT_CONST_BITS),
                      vrshrn_n_s64(sum_hi, DCT_CONST_BITS));
  *sub = vcombine_s32(vrshrn_n_s64(diff_lo, DCT_CONST_BITS),
                      vrshrn_n_s64(diff_hi, DCT_CONST_BITS));
}

// a * c1 +/- b * c2
// Variant that performs the normal implementation on a full vector, without
// the final rounding shift.
// Does 64-bit processing, takes 32-bit values and writes the raw 64-bit
// products into the two-element output arrays.
static INLINE void butterfly_two_coeff_s32_s64_noround(
    const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
    const int32x4_t b_hi, const tran_coef_t constant1,
    const tran_coef_t constant2, int64x2_t *add_lo /*[2]*/,
    int64x2_t *add_hi /*[2]*/, int64x2_t *sub_lo /*[2]*/,
    int64x2_t *sub_hi /*[2]*/) {
  // ac1/ac2 hold the following values:
  // ac1: vget_low_s32(a_lo) * c1, vget_high_s32(a_lo) * c1,
  //      vget_low_s32(a_hi) * c1, vget_high_s32(a_hi) * c1
  // ac2: vget_low_s32(a_lo) * c2, vget_high_s32(a_lo) * c2,
  //      vget_low_s32(a_hi) * c2, vget_high_s32(a_hi) * c2
  int64x2_t ac1[4];
  int64x2_t ac2[4];

  ac1[0] = vmull_n_s32(vget_low_s32(a_lo), constant1);
  ac1[1] = vmull_n_s32(vget_high_s32(a_lo), constant1);
  ac1[2] = vmull_n_s32(vget_low_s32(a_hi), constant1);
  ac1[3] = vmull_n_s32(vget_high_s32(a_hi), constant1);
  ac2[0] = vmull_n_s32(vget_low_s32(a_lo), constant2);
  ac2[1] = vmull_n_s32(vget_high_s32(a_lo), constant2);
  ac2[2] = vmull_n_s32(vget_low_s32(a_hi), constant2);
  ac2[3] = vmull_n_s32(vget_high_s32(a_hi), constant2);

  add_lo[0] = vmlal_n_s32(ac1[0], vget_low_s32(b_lo), constant2);
  add_lo[1] = vmlal_n_s32(ac1[1], vget_high_s32(b_lo), constant2);
  add_hi[0] = vmlal_n_s32(ac1[2], vget_low_s32(b_hi), constant2);
  add_hi[1] = vmlal_n_s32(ac1[3], vget_high_s32(b_hi), constant2);

  sub_lo[0] = vmlsl_n_s32(ac2[0], vget_low_s32(b_lo), constant1);
  sub_lo[1] = vmlsl_n_s32(ac2[1], vget_high_s32(b_lo), constant1);
  sub_hi[0] = vmlsl_n_s32(ac2[2], vget_low_s32(b_hi), constant1);
  sub_hi[1] = vmlsl_n_s32(ac2[3], vget_high_s32(b_hi), constant1);
}

// fdct_round_shift(a * c1 +/- b * c2)
// Variant that performs the normal implementation on a full vector.
// More accurate: does 64-bit processing, takes and returns 32-bit values,
// returns narrowed results.
static INLINE void butterfly_two_coeff_s32_s64_narrow(
    const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
    const int32x4_t b_hi, const tran_coef_t constant1,
    const tran_coef_t constant2, int32x4_t *add_lo, int32x4_t *add_hi,
    int32x4_t *sub_lo, int32x4_t *sub_hi) {
  // ac1/ac2 hold the following values:
  // ac1: vget_low_s32(a_lo) * c1, vget_high_s32(a_lo) * c1,
  //      vget_low_s32(a_hi) * c1, vget_high_s32(a_hi) * c1
  // ac2: vget_low_s32(a_lo) * c2, vget_high_s32(a_lo) * c2,
  //      vget_low_s32(a_hi) * c2, vget_high_s32(a_hi) * c2
  int64x2_t ac1[4];
  int64x2_t ac2[4];
  int64x2_t sum[4];
  int64x2_t diff[4];

  ac1[0] = vmull_n_s32(vget_low_s32(a_lo), constant1);
  ac1[1] = vmull_n_s32(vget_high_s32(a_lo), constant1);
  ac1[2] = vmull_n_s32(vget_low_s32(a_hi), constant1);
  ac1[3] = vmull_n_s32(vget_high_s32(a_hi), constant1);
  ac2[0] = vmull_n_s32(vget_low_s32(a_lo), constant2);
  ac2[1] = vmull_n_s32(vget_high_s32(a_lo), constant2);
  ac2[2] = vmull_n_s32(vget_low_s32(a_hi), constant2);
  ac2[3] = vmull_n_s32(vget_high_s32(a_hi), constant2);

  sum[0] = vmlal_n_s32(ac1[0], vget_low_s32(b_lo), constant2);
  sum[1] = vmlal_n_s32(ac1[1], vget_high_s32(b_lo), constant2);
  sum[2] = vmlal_n_s32(ac1[2], vget_low_s32(b_hi), constant2);
  sum[3] = vmlal_n_s32(ac1[3], vget_high_s32(b_hi), constant2);
  *add_lo = vcombine_s32(vrshrn_n_s64(sum[0], DCT_CONST_BITS),
                         vrshrn_n_s64(sum[1], DCT_CONST_BITS));
  *add_hi = vcombine_s32(vrshrn_n_s64(sum[2], DCT_CONST_BITS),
                         vrshrn_n_s64(sum[3], DCT_CONST_BITS));

  diff[0] = vmlsl_n_s32(ac2[0], vget_low_s32(b_lo), constant1);
  diff[1] = vmlsl_n_s32(ac2[1], vget_high_s32(b_lo), constant1);
  diff[2] = vmlsl_n_s32(ac2[2], vget_low_s32(b_hi), constant1);
  diff[3] = vmlsl_n_s32(ac2[3], vget_high_s32(b_hi), constant1);
  *sub_lo = vcombine_s32(vrshrn_n_s64(diff[0], DCT_CONST_BITS),
                         vrshrn_n_s64(diff[1], DCT_CONST_BITS));
  *sub_hi = vcombine_s32(vrshrn_n_s64(diff[2], DCT_CONST_BITS),
                         vrshrn_n_s64(diff[3], DCT_CONST_BITS));
}

// a * c1 +/- b * c2
// Original variant that performs the normal implementation on a full vector
// (supplied as low/high halves), without the final rounding shift.
// Does 32-bit processing, takes 16-bit values, returns 32-bit values,
// high/low.
static INLINE void butterfly_two_coeff_s16_s32_noround(
    const int16x4_t a_lo, const int16x4_t a_hi, const int16x4_t b_lo,
    const int16x4_t b_hi, const tran_coef_t constant1,
    const tran_coef_t constant2, int32x4_t *add_lo, int32x4_t *add_hi,
    int32x4_t *sub_lo, int32x4_t *sub_hi) {
  const int32x4_t a1 = vmull_n_s16(a_lo, constant1);
  const int32x4_t a2 = vmull_n_s16(a_hi, constant1);
  const int32x4_t a3 = vmull_n_s16(a_lo, constant2);
  const int32x4_t a4 = vmull_n_s16(a_hi, constant2);
  *add_lo = vmlal_n_s16(a1, b_lo, constant2);
  *add_hi = vmlal_n_s16(a2, b_hi, constant2);
  *sub_lo = vmlsl_n_s16(a3, b_lo, constant1);
  *sub_hi = vmlsl_n_s16(a4, b_hi, constant1);
}

// a * c1 +/- b * c2
// Original variant that performs the normal implementation on a full vector,
// without the final rounding shift.
// Does 32-bit processing, takes and returns 32-bit values, high/low.
static INLINE void butterfly_two_coeff_s32_noround(
    const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
    const int32x4_t b_hi, const tran_coef_t constant1,
    const tran_coef_t constant2, int32x4_t *add_lo, int32x4_t *add_hi,
    int32x4_t *sub_lo, int32x4_t *sub_hi) {
  const int32x4_t a1 = vmulq_n_s32(a_lo, constant1);
  const int32x4_t a2 = vmulq_n_s32(a_hi, constant1);
  const int32x4_t a3 = vmulq_n_s32(a_lo, constant2);
  const int32x4_t a4 = vmulq_n_s32(a_hi, constant2);
  *add_lo = vmlaq_n_s32(a1, b_lo, constant2);
  *add_hi = vmlaq_n_s32(a2, b_hi, constant2);
  *sub_lo = vmlsq_n_s32(a3, b_lo, constant1);
  *sub_hi = vmlsq_n_s32(a4, b_hi, constant1);
}

// fdct_round_shift(a * c1 +/- b * c2)
// Variant that performs the normal implementation on a half vector.
// More accurate: does 32-bit processing, takes and returns 16-bit values,
// returns narrowed results.
static INLINE void butterfly_two_coeff_half(const int16x4_t a,
                                            const int16x4_t b,
                                            const tran_coef_t constant1,
                                            const tran_coef_t constant2,
                                            int16x4_t *add, int16x4_t *sub) {
  const int32x4_t a1 = vmull_n_s16(a, constant1);
  const int32x4_t a2 = vmull_n_s16(a, constant2);
  const int32x4_t sum = vmlal_n_s16(a1, b, constant2);
  const int32x4_t diff = vmlsl_n_s16(a2, b, constant1);
  *add = vqrshrn_n_s32(sum, DCT_CONST_BITS);
  *sub = vqrshrn_n_s32(diff, DCT_CONST_BITS);
}

// fdct_round_shift(a * c1 +/- b * c2)
// Original variant that performs the normal implementation on a full vector.
// More accurate: does 32-bit processing, takes and returns 16-bit values,
// returns narrowed results.
static INLINE void butterfly_two_coeff(const int16x8_t a, const int16x8_t b,
                                       const tran_coef_t constant1,
                                       const tran_coef_t constant2,
                                       int16x8_t *add, int16x8_t *sub) {
  const int32x4_t a1 = vmull_n_s16(vget_low_s16(a), constant1);
  const int32x4_t a2 = vmull_n_s16(vget_high_s16(a), constant1);
  const int32x4_t a3 = vmull_n_s16(vget_low_s16(a), constant2);
  const int32x4_t a4 = vmull_n_s16(vget_high_s16(a), constant2);
  const int32x4_t sum0 = vmlal_n_s16(a1, vget_low_s16(b), constant2);
  const int32x4_t sum1 = vmlal_n_s16(a2, vget_high_s16(b), constant2);
  const int32x4_t diff0 = vmlsl_n_s16(a3, vget_low_s16(b), constant1);
  const int32x4_t diff1 = vmlsl_n_s16(a4, vget_high_s16(b), constant1);
  const int16x4_t rounded0 = vqrshrn_n_s32(sum0, DCT_CONST_BITS);
  const int16x4_t rounded1 = vqrshrn_n_s32(sum1, DCT_CONST_BITS);
  const int16x4_t rounded2 = vqrshrn_n_s32(diff0, DCT_CONST_BITS);
  const int16x4_t rounded3 = vqrshrn_n_s32(diff1, DCT_CONST_BITS);
  *add = vcombine_s16(rounded0, rounded1);
  *sub = vcombine_s16(rounded2, rounded3);
}
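
// Illustrative usage sketch (not part of the libvpx API): the two-coefficient
// butterflies above compute
//   *add = fdct_round_shift(a * c1 + b * c2)
//   *sub = fdct_round_shift(a * c2 - b * c1)
// For example, assuming cospi_8_64 and cospi_24_64 from
// vpx_dsp/txfm_common.h are in scope, a typical rotation pair of a forward
// DCT stage could be sketched as:
static INLINE void example_two_coeff_stage(const int16x8_t s2,
                                           const int16x8_t s3,
                                           int16x8_t *out_a,
                                           int16x8_t *out_b) {
  // Hypothetical helper for illustration only:
  //   *out_a = fdct_round_shift(s2 * cospi_24_64 + s3 * cospi_8_64)
  //   *out_b = fdct_round_shift(s2 * cospi_8_64 - s3 * cospi_24_64)
  butterfly_two_coeff(s2, s3, cospi_24_64, cospi_8_64, out_a, out_b);
}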

// fdct_round_shift(a * c1 +/- b * c2)
// Original variant that performs the normal implementation on a full vector.
// More accurate: does 32-bit processing, takes and returns 32-bit values,
// high/low, with the final rounding shift applied.
static INLINE void butterfly_two_coeff_s32(
    const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
    const int32x4_t b_hi, const tran_coef_t constant1,
    const tran_coef_t constant2, int32x4_t *add_lo, int32x4_t *add_hi,
    int32x4_t *sub_lo, int32x4_t *sub_hi) {
  const int32x4_t a1 = vmulq_n_s32(a_lo, constant1);
  const int32x4_t a2 = vmulq_n_s32(a_hi, constant1);
  const int32x4_t a3 = vmulq_n_s32(a_lo, constant2);
  const int32x4_t a4 = vmulq_n_s32(a_hi, constant2);
  const int32x4_t sum0 = vmlaq_n_s32(a1, b_lo, constant2);
  const int32x4_t sum1 = vmlaq_n_s32(a2, b_hi, constant2);
  const int32x4_t diff0 = vmlsq_n_s32(a3, b_lo, constant1);
  const int32x4_t diff1 = vmlsq_n_s32(a4, b_hi, constant1);
  *add_lo = vrshrq_n_s32(sum0, DCT_CONST_BITS);
  *add_hi = vrshrq_n_s32(sum1, DCT_CONST_BITS);
  *sub_lo = vrshrq_n_s32(diff0, DCT_CONST_BITS);
  *sub_hi = vrshrq_n_s32(diff1, DCT_CONST_BITS);
}

// Add 1 if positive, 2 if negative, and shift by 2.
// In practice, add 1, then add the sign bit, then shift without rounding.
static INLINE int16x8_t add_round_shift_s16(const int16x8_t a) {
  const int16x8_t one = vdupq_n_s16(1);
  const uint16x8_t a_u16 = vreinterpretq_u16_s16(a);
  const uint16x8_t a_sign_u16 = vshrq_n_u16(a_u16, 15);
  const int16x8_t a_sign_s16 = vreinterpretq_s16_u16(a_sign_u16);
  return vshrq_n_s16(vaddq_s16(vaddq_s16(a, a_sign_s16), one), 2);
}
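
// Worked example of the rounding above (informal):
//   a =  5:  (5 + 0 + 1) >> 2 =  6 >> 2 =  1
//   a = -5: (-5 + 1 + 1) >> 2 = -3 >> 2 = -1   (arithmetic shift)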

// Add 1 if positive, 2 if negative, and shift by 2.
// In practice, add 1, then add the sign bit, then shift without rounding and
// narrow the result to 16 bits.
static INLINE int16x8_t add_round_shift_s32_narrow(const int32x4_t a_lo,
                                                   const int32x4_t a_hi) {
  const int32x4_t one = vdupq_n_s32(1);
  const uint32x4_t a_lo_u32 = vreinterpretq_u32_s32(a_lo);
  const uint32x4_t a_lo_sign_u32 = vshrq_n_u32(a_lo_u32, 31);
  const int32x4_t a_lo_sign_s32 = vreinterpretq_s32_u32(a_lo_sign_u32);
  const int16x4_t b_lo =
      vshrn_n_s32(vqaddq_s32(vqaddq_s32(a_lo, a_lo_sign_s32), one), 2);
  const uint32x4_t a_hi_u32 = vreinterpretq_u32_s32(a_hi);
  const uint32x4_t a_hi_sign_u32 = vshrq_n_u32(a_hi_u32, 31);
  const int32x4_t a_hi_sign_s32 = vreinterpretq_s32_u32(a_hi_sign_u32);
  const int16x4_t b_hi =
      vshrn_n_s32(vqaddq_s32(vqaddq_s32(a_hi, a_hi_sign_s32), one), 2);
  return vcombine_s16(b_lo, b_hi);
}

// Add 1 if negative, and shift by 1.
// In practice, add the sign bit, then shift without rounding.
static INLINE int32x4_t add_round_shift_half_s32(const int32x4_t a) {
  const uint32x4_t a_u32 = vreinterpretq_u32_s32(a);
  const uint32x4_t a_sign_u32 = vshrq_n_u32(a_u32, 31);
  const int32x4_t a_sign_s32 = vreinterpretq_s32_u32(a_sign_u32);
  return vshrq_n_s32(vaddq_s32(a, a_sign_s32), 1);
}

// Add 1 if positive, 2 if negative, and shift by 2.
// In practice, add 1, then add the sign bit, then shift without rounding.
static INLINE int32x4_t add_round_shift_s32(const int32x4_t a) {
  const int32x4_t one = vdupq_n_s32(1);
  const uint32x4_t a_u32 = vreinterpretq_u32_s32(a);
  const uint32x4_t a_sign_u32 = vshrq_n_u32(a_u32, 31);
  const int32x4_t a_sign_s32 = vreinterpretq_s32_u32(a_sign_u32);
  return vshrq_n_s32(vaddq_s32(vaddq_s32(a, a_sign_s32), one), 2);
}

// Add 2 if positive, 1 if negative, and shift by 2.
// In practice, subtract the sign bit, then shift with rounding.
static INLINE int16x8_t sub_round_shift_s16(const int16x8_t a) {
  const uint16x8_t a_u16 = vreinterpretq_u16_s16(a);
  const uint16x8_t a_sign_u16 = vshrq_n_u16(a_u16, 15);
  const int16x8_t a_sign_s16 = vreinterpretq_s16_u16(a_sign_u16);
  return vrshrq_n_s16(vsubq_s16(a, a_sign_s16), 2);
}

// Add 2 if positive, 1 if negative, and shift by 2.
// In practice, subtract the sign bit, then shift with rounding.
static INLINE int32x4_t sub_round_shift_s32(const int32x4_t a) {
  const uint32x4_t a_u32 = vreinterpretq_u32_s32(a);
  const uint32x4_t a_sign_u32 = vshrq_n_u32(a_u32, 31);
  const int32x4_t a_sign_s32 = vreinterpretq_s32_u32(a_sign_u32);
  return vrshrq_n_s32(vsubq_s32(a, a_sign_s32), 2);
}
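
// Worked example of the sub_round_shift behaviour (informal), with vrshr
// contributing a rounding constant of 2 for a shift by 2:
//   a =  5:  (5 - 0 + 2) >> 2 =  7 >> 2 =  1
//   a = -5: (-5 - 1 + 2) >> 2 = -4 >> 2 = -1
// It mirrors add_round_shift_*: the bias is +2/+1 here versus +1/+2 there.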

// (a + b), with fdct_round_shift applied to the 64-bit sums and the results
// narrowed back to 32 bits. a and b are two-element arrays.
static INLINE int32x4_t add_s64_round_narrow(const int64x2_t *a /*[2]*/,
                                             const int64x2_t *b /*[2]*/) {
  int64x2_t result[2];
  result[0] = vaddq_s64(a[0], b[0]);
  result[1] = vaddq_s64(a[1], b[1]);
  return vcombine_s32(vrshrn_n_s64(result[0], DCT_CONST_BITS),
                      vrshrn_n_s64(result[1], DCT_CONST_BITS));
}

// (a - b), with fdct_round_shift applied to the 64-bit differences and the
// results narrowed back to 32 bits. a and b are two-element arrays.
static INLINE int32x4_t sub_s64_round_narrow(const int64x2_t *a /*[2]*/,
                                             const int64x2_t *b /*[2]*/) {
  int64x2_t result[2];
  result[0] = vsubq_s64(a[0], b[0]);
  result[1] = vsubq_s64(a[1], b[1]);
  return vcombine_s32(vrshrn_n_s64(result[0], DCT_CONST_BITS),
                      vrshrn_n_s64(result[1], DCT_CONST_BITS));
}

// (a + b), computed in 64 bits to avoid intermediate overflow, then truncated
// back to 32 bits.
static INLINE int32x4_t add_s32_s64_narrow(const int32x4_t a,
                                           const int32x4_t b) {
  int64x2_t a64[2], b64[2], result[2];
  a64[0] = vmovl_s32(vget_low_s32(a));
  a64[1] = vmovl_s32(vget_high_s32(a));
  b64[0] = vmovl_s32(vget_low_s32(b));
  b64[1] = vmovl_s32(vget_high_s32(b));
  result[0] = vaddq_s64(a64[0], b64[0]);
  result[1] = vaddq_s64(a64[1], b64[1]);
  return vcombine_s32(vmovn_s64(result[0]), vmovn_s64(result[1]));
}

// (a - b), computed in 64 bits to avoid intermediate overflow, then truncated
// back to 32 bits.
static INLINE int32x4_t sub_s32_s64_narrow(const int32x4_t a,
                                           const int32x4_t b) {
  int64x2_t a64[2], b64[2], result[2];
  a64[0] = vmovl_s32(vget_low_s32(a));
  a64[1] = vmovl_s32(vget_high_s32(a));
  b64[0] = vmovl_s32(vget_low_s32(b));
  b64[1] = vmovl_s32(vget_high_s32(b));
  result[0] = vsubq_s64(a64[0], b64[0]);
  result[1] = vsubq_s64(a64[1], b64[1]);
  return vcombine_s32(vmovn_s64(result[0]), vmovn_s64(result[1]));
}

#endif  // VPX_VPX_DSP_ARM_FDCT_NEON_H_