xref: /aosp_15_r20/external/libgav1/src/dsp/arm/inverse_transform_10bit_neon.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2021 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/inverse_transform.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10
19 
20 #include <arm_neon.h>
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <cstdint>
25 
26 #include "src/dsp/arm/common_neon.h"
27 #include "src/dsp/constants.h"
28 #include "src/dsp/dsp.h"
29 #include "src/utils/array_2d.h"
30 #include "src/utils/common.h"
31 #include "src/utils/compiler_attributes.h"
32 #include "src/utils/constants.h"
33 
34 namespace libgav1 {
35 namespace dsp {
36 namespace {
37 
38 // Include the constants and utility functions inside the anonymous namespace.
39 #include "src/dsp/inverse_transform.inc"
40 
41 //------------------------------------------------------------------------------
42 
Transpose4x4(const int32x4_t in[4],int32x4_t out[4])43 LIBGAV1_ALWAYS_INLINE void Transpose4x4(const int32x4_t in[4],
44                                         int32x4_t out[4]) {
45   // in:
46   // 00 01 02 03
47   // 10 11 12 13
48   // 20 21 22 23
49   // 30 31 32 33
50 
51   // 00 10 02 12   a.val[0]
52   // 01 11 03 13   a.val[1]
53   // 20 30 22 32   b.val[0]
54   // 21 31 23 33   b.val[1]
55   const int32x4x2_t a = vtrnq_s32(in[0], in[1]);
56   const int32x4x2_t b = vtrnq_s32(in[2], in[3]);
57   out[0] = vextq_s32(vextq_s32(a.val[0], a.val[0], 2), b.val[0], 2);
58   out[1] = vextq_s32(vextq_s32(a.val[1], a.val[1], 2), b.val[1], 2);
59   out[2] = vextq_s32(a.val[0], vextq_s32(b.val[0], b.val[0], 2), 2);
60   out[3] = vextq_s32(a.val[1], vextq_s32(b.val[1], b.val[1], 2), 2);
61   // out:
62   // 00 10 20 30
63   // 01 11 21 31
64   // 02 12 22 32
65   // 03 13 23 33
66 }
67 
68 //------------------------------------------------------------------------------
69 template <int store_count>
StoreDst(int32_t * LIBGAV1_RESTRICT dst,int32_t stride,int32_t idx,const int32x4_t * const s)70 LIBGAV1_ALWAYS_INLINE void StoreDst(int32_t* LIBGAV1_RESTRICT dst,
71                                     int32_t stride, int32_t idx,
72                                     const int32x4_t* const s) {
73   assert(store_count % 4 == 0);
74   for (int i = 0; i < store_count; i += 4) {
75     vst1q_s32(&dst[i * stride + idx], s[i]);
76     vst1q_s32(&dst[(i + 1) * stride + idx], s[i + 1]);
77     vst1q_s32(&dst[(i + 2) * stride + idx], s[i + 2]);
78     vst1q_s32(&dst[(i + 3) * stride + idx], s[i + 3]);
79   }
80 }
81 
82 template <int load_count>
LoadSrc(const int32_t * LIBGAV1_RESTRICT src,int32_t stride,int32_t idx,int32x4_t * x)83 LIBGAV1_ALWAYS_INLINE void LoadSrc(const int32_t* LIBGAV1_RESTRICT src,
84                                    int32_t stride, int32_t idx, int32x4_t* x) {
85   assert(load_count % 4 == 0);
86   for (int i = 0; i < load_count; i += 4) {
87     x[i] = vld1q_s32(&src[i * stride + idx]);
88     x[i + 1] = vld1q_s32(&src[(i + 1) * stride + idx]);
89     x[i + 2] = vld1q_s32(&src[(i + 2) * stride + idx]);
90     x[i + 3] = vld1q_s32(&src[(i + 3) * stride + idx]);
91   }
92 }
93 
94 // Butterfly rotate 4 values.
ButterflyRotation_4(int32x4_t * a,int32x4_t * b,const int angle,const bool flip)95 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_4(int32x4_t* a, int32x4_t* b,
96                                                const int angle,
97                                                const bool flip) {
98   const int32_t cos128 = Cos128(angle);
99   const int32_t sin128 = Sin128(angle);
100   const int32x4_t acc_x = vmulq_n_s32(*a, cos128);
101   const int32x4_t acc_y = vmulq_n_s32(*a, sin128);
102   // The max range for the input is 18 bits. The cos128/sin128 is 13 bits,
103   // which leaves 1 bit for the add/subtract. For 10bpp, x/y will fit in a 32
104   // bit lane.
105   const int32x4_t x0 = vmlsq_n_s32(acc_x, *b, sin128);
106   const int32x4_t y0 = vmlaq_n_s32(acc_y, *b, cos128);
107   const int32x4_t x = vrshrq_n_s32(x0, 12);
108   const int32x4_t y = vrshrq_n_s32(y0, 12);
109   if (flip) {
110     *a = y;
111     *b = x;
112   } else {
113     *a = x;
114     *b = y;
115   }
116 }
117 
ButterflyRotation_FirstIsZero(int32x4_t * a,int32x4_t * b,const int angle,const bool flip)118 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_FirstIsZero(int32x4_t* a,
119                                                          int32x4_t* b,
120                                                          const int angle,
121                                                          const bool flip) {
122   const int32_t cos128 = Cos128(angle);
123   const int32_t sin128 = Sin128(angle);
124   assert(sin128 <= 0xfff);
125   const int32x4_t x0 = vmulq_n_s32(*b, -sin128);
126   const int32x4_t y0 = vmulq_n_s32(*b, cos128);
127   const int32x4_t x = vrshrq_n_s32(x0, 12);
128   const int32x4_t y = vrshrq_n_s32(y0, 12);
129   if (flip) {
130     *a = y;
131     *b = x;
132   } else {
133     *a = x;
134     *b = y;
135   }
136 }
137 
ButterflyRotation_SecondIsZero(int32x4_t * a,int32x4_t * b,const int angle,const bool flip)138 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_SecondIsZero(int32x4_t* a,
139                                                           int32x4_t* b,
140                                                           const int angle,
141                                                           const bool flip) {
142   const int32_t cos128 = Cos128(angle);
143   const int32_t sin128 = Sin128(angle);
144   const int32x4_t x0 = vmulq_n_s32(*a, cos128);
145   const int32x4_t y0 = vmulq_n_s32(*a, sin128);
146   const int32x4_t x = vrshrq_n_s32(x0, 12);
147   const int32x4_t y = vrshrq_n_s32(y0, 12);
148   if (flip) {
149     *a = y;
150     *b = x;
151   } else {
152     *a = x;
153     *b = y;
154   }
155 }
156 
HadamardRotation(int32x4_t * a,int32x4_t * b,bool flip)157 LIBGAV1_ALWAYS_INLINE void HadamardRotation(int32x4_t* a, int32x4_t* b,
158                                             bool flip) {
159   int32x4_t x, y;
160   if (flip) {
161     y = vqaddq_s32(*b, *a);
162     x = vqsubq_s32(*b, *a);
163   } else {
164     x = vqaddq_s32(*a, *b);
165     y = vqsubq_s32(*a, *b);
166   }
167   *a = x;
168   *b = y;
169 }
170 
HadamardRotation(int32x4_t * a,int32x4_t * b,bool flip,const int32x4_t min,const int32x4_t max)171 LIBGAV1_ALWAYS_INLINE void HadamardRotation(int32x4_t* a, int32x4_t* b,
172                                             bool flip, const int32x4_t min,
173                                             const int32x4_t max) {
174   int32x4_t x, y;
175   if (flip) {
176     y = vqaddq_s32(*b, *a);
177     x = vqsubq_s32(*b, *a);
178   } else {
179     x = vqaddq_s32(*a, *b);
180     y = vqsubq_s32(*a, *b);
181   }
182   *a = vmaxq_s32(vminq_s32(x, max), min);
183   *b = vmaxq_s32(vminq_s32(y, max), min);
184 }
185 
186 using ButterflyRotationFunc = void (*)(int32x4_t* a, int32x4_t* b, int angle,
187                                        bool flip);
188 
189 //------------------------------------------------------------------------------
190 // Discrete Cosine Transforms (DCT).
191 
192 template <int width>
DctDcOnly(void * dest,int adjusted_tx_height,bool should_round,int row_shift)193 LIBGAV1_ALWAYS_INLINE bool DctDcOnly(void* dest, int adjusted_tx_height,
194                                      bool should_round, int row_shift) {
195   if (adjusted_tx_height > 1) return false;
196 
197   auto* dst = static_cast<int32_t*>(dest);
198   const int32x4_t v_src = vdupq_n_s32(dst[0]);
199   const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
200   const int32x4_t v_src_round =
201       vqrdmulhq_n_s32(v_src, kTransformRowMultiplier << (31 - 12));
202   const int32x4_t s0 = vbslq_s32(v_mask, v_src_round, v_src);
203   const int32_t cos128 = Cos128(32);
204   const int32x4_t xy = vqrdmulhq_n_s32(s0, cos128 << (31 - 12));
205   // vqrshlq_s32 will shift right if shift value is negative.
206   const int32x4_t xy_shifted = vqrshlq_s32(xy, vdupq_n_s32(-row_shift));
207   // Clamp result to signed 16 bits.
208   const int32x4_t result = vmovl_s16(vqmovn_s32(xy_shifted));
209   if (width == 4) {
210     vst1q_s32(dst, result);
211   } else {
212     for (int i = 0; i < width; i += 4) {
213       vst1q_s32(dst, result);
214       dst += 4;
215     }
216   }
217   return true;
218 }
219 
220 template <int height>
DctDcOnlyColumn(void * dest,int adjusted_tx_height,int width)221 LIBGAV1_ALWAYS_INLINE bool DctDcOnlyColumn(void* dest, int adjusted_tx_height,
222                                            int width) {
223   if (adjusted_tx_height > 1) return false;
224 
225   auto* dst = static_cast<int32_t*>(dest);
226   const int32_t cos128 = Cos128(32);
227 
228   // Calculate dc values for first row.
229   if (width == 4) {
230     const int32x4_t v_src = vld1q_s32(dst);
231     const int32x4_t xy = vqrdmulhq_n_s32(v_src, cos128 << (31 - 12));
232     vst1q_s32(dst, xy);
233   } else {
234     int i = 0;
235     do {
236       const int32x4_t v_src = vld1q_s32(&dst[i]);
237       const int32x4_t xy = vqrdmulhq_n_s32(v_src, cos128 << (31 - 12));
238       vst1q_s32(&dst[i], xy);
239       i += 4;
240     } while (i < width);
241   }
242 
243   // Copy first row to the rest of the block.
244   for (int y = 1; y < height; ++y) {
245     memcpy(&dst[y * width], dst, width * sizeof(dst[0]));
246   }
247   return true;
248 }
249 
250 template <ButterflyRotationFunc butterfly_rotation,
251           bool is_fast_butterfly = false>
Dct4Stages(int32x4_t * s,const int32x4_t min,const int32x4_t max,const bool is_last_stage)252 LIBGAV1_ALWAYS_INLINE void Dct4Stages(int32x4_t* s, const int32x4_t min,
253                                       const int32x4_t max,
254                                       const bool is_last_stage) {
255   // stage 12.
256   if (is_fast_butterfly) {
257     ButterflyRotation_SecondIsZero(&s[0], &s[1], 32, true);
258     ButterflyRotation_SecondIsZero(&s[2], &s[3], 48, false);
259   } else {
260     butterfly_rotation(&s[0], &s[1], 32, true);
261     butterfly_rotation(&s[2], &s[3], 48, false);
262   }
263 
264   // stage 17.
265   if (is_last_stage) {
266     HadamardRotation(&s[0], &s[3], false);
267     HadamardRotation(&s[1], &s[2], false);
268   } else {
269     HadamardRotation(&s[0], &s[3], false, min, max);
270     HadamardRotation(&s[1], &s[2], false, min, max);
271   }
272 }
273 
274 template <ButterflyRotationFunc butterfly_rotation>
Dct4_NEON(void * dest,int32_t step,bool is_row,int row_shift)275 LIBGAV1_ALWAYS_INLINE void Dct4_NEON(void* dest, int32_t step, bool is_row,
276                                      int row_shift) {
277   auto* const dst = static_cast<int32_t*>(dest);
278   // When |is_row| is true, set range to the row range, otherwise, set to the
279   // column range.
280   const int32_t range = is_row ? kBitdepth10 + 7 : 15;
281   const int32x4_t min = vdupq_n_s32(-(1 << range));
282   const int32x4_t max = vdupq_n_s32((1 << range) - 1);
283   int32x4_t s[4], x[4];
284 
285   if (is_row) {
286     assert(step == 4);
287     int32x4x4_t y = vld4q_s32(dst);
288     for (int i = 0; i < 4; ++i) x[i] = y.val[i];
289   } else {
290     LoadSrc<4>(dst, step, 0, x);
291   }
292 
293   // stage 1.
294   // kBitReverseLookup 0, 2, 1, 3
295   s[0] = x[0];
296   s[1] = x[2];
297   s[2] = x[1];
298   s[3] = x[3];
299 
300   Dct4Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/true);
301 
302   if (is_row) {
303     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
304     for (auto& i : s) {
305       i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
306     }
307     int32x4x4_t y;
308     for (int i = 0; i < 4; ++i) y.val[i] = s[i];
309     vst4q_s32(dst, y);
310   } else {
311     StoreDst<4>(dst, step, 0, s);
312   }
313 }
314 
315 template <ButterflyRotationFunc butterfly_rotation,
316           bool is_fast_butterfly = false>
Dct8Stages(int32x4_t * s,const int32x4_t min,const int32x4_t max,const bool is_last_stage)317 LIBGAV1_ALWAYS_INLINE void Dct8Stages(int32x4_t* s, const int32x4_t min,
318                                       const int32x4_t max,
319                                       const bool is_last_stage) {
320   // stage 8.
321   if (is_fast_butterfly) {
322     ButterflyRotation_SecondIsZero(&s[4], &s[7], 56, false);
323     ButterflyRotation_FirstIsZero(&s[5], &s[6], 24, false);
324   } else {
325     butterfly_rotation(&s[4], &s[7], 56, false);
326     butterfly_rotation(&s[5], &s[6], 24, false);
327   }
328 
329   // stage 13.
330   HadamardRotation(&s[4], &s[5], false, min, max);
331   HadamardRotation(&s[6], &s[7], true, min, max);
332 
333   // stage 18.
334   butterfly_rotation(&s[6], &s[5], 32, true);
335 
336   // stage 22.
337   if (is_last_stage) {
338     HadamardRotation(&s[0], &s[7], false);
339     HadamardRotation(&s[1], &s[6], false);
340     HadamardRotation(&s[2], &s[5], false);
341     HadamardRotation(&s[3], &s[4], false);
342   } else {
343     HadamardRotation(&s[0], &s[7], false, min, max);
344     HadamardRotation(&s[1], &s[6], false, min, max);
345     HadamardRotation(&s[2], &s[5], false, min, max);
346     HadamardRotation(&s[3], &s[4], false, min, max);
347   }
348 }
349 
350 // Process dct8 rows or columns, depending on the |is_row| flag.
351 template <ButterflyRotationFunc butterfly_rotation>
Dct8_NEON(void * dest,int32_t step,bool is_row,int row_shift)352 LIBGAV1_ALWAYS_INLINE void Dct8_NEON(void* dest, int32_t step, bool is_row,
353                                      int row_shift) {
354   auto* const dst = static_cast<int32_t*>(dest);
355   const int32_t range = is_row ? kBitdepth10 + 7 : 15;
356   const int32x4_t min = vdupq_n_s32(-(1 << range));
357   const int32x4_t max = vdupq_n_s32((1 << range) - 1);
358   int32x4_t s[8], x[8];
359 
360   if (is_row) {
361     LoadSrc<4>(dst, step, 0, &x[0]);
362     LoadSrc<4>(dst, step, 4, &x[4]);
363     Transpose4x4(&x[0], &x[0]);
364     Transpose4x4(&x[4], &x[4]);
365   } else {
366     LoadSrc<8>(dst, step, 0, &x[0]);
367   }
368 
369   // stage 1.
370   // kBitReverseLookup 0, 4, 2, 6, 1, 5, 3, 7,
371   s[0] = x[0];
372   s[1] = x[4];
373   s[2] = x[2];
374   s[3] = x[6];
375   s[4] = x[1];
376   s[5] = x[5];
377   s[6] = x[3];
378   s[7] = x[7];
379 
380   Dct4Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/false);
381   Dct8Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/true);
382 
383   if (is_row) {
384     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
385     for (auto& i : s) {
386       i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
387     }
388     Transpose4x4(&s[0], &s[0]);
389     Transpose4x4(&s[4], &s[4]);
390     StoreDst<4>(dst, step, 0, &s[0]);
391     StoreDst<4>(dst, step, 4, &s[4]);
392   } else {
393     StoreDst<8>(dst, step, 0, &s[0]);
394   }
395 }
396 
397 template <ButterflyRotationFunc butterfly_rotation,
398           bool is_fast_butterfly = false>
Dct16Stages(int32x4_t * s,const int32x4_t min,const int32x4_t max,const bool is_last_stage)399 LIBGAV1_ALWAYS_INLINE void Dct16Stages(int32x4_t* s, const int32x4_t min,
400                                        const int32x4_t max,
401                                        const bool is_last_stage) {
402   // stage 5.
403   if (is_fast_butterfly) {
404     ButterflyRotation_SecondIsZero(&s[8], &s[15], 60, false);
405     ButterflyRotation_FirstIsZero(&s[9], &s[14], 28, false);
406     ButterflyRotation_SecondIsZero(&s[10], &s[13], 44, false);
407     ButterflyRotation_FirstIsZero(&s[11], &s[12], 12, false);
408   } else {
409     butterfly_rotation(&s[8], &s[15], 60, false);
410     butterfly_rotation(&s[9], &s[14], 28, false);
411     butterfly_rotation(&s[10], &s[13], 44, false);
412     butterfly_rotation(&s[11], &s[12], 12, false);
413   }
414 
415   // stage 9.
416   HadamardRotation(&s[8], &s[9], false, min, max);
417   HadamardRotation(&s[10], &s[11], true, min, max);
418   HadamardRotation(&s[12], &s[13], false, min, max);
419   HadamardRotation(&s[14], &s[15], true, min, max);
420 
421   // stage 14.
422   butterfly_rotation(&s[14], &s[9], 48, true);
423   butterfly_rotation(&s[13], &s[10], 112, true);
424 
425   // stage 19.
426   HadamardRotation(&s[8], &s[11], false, min, max);
427   HadamardRotation(&s[9], &s[10], false, min, max);
428   HadamardRotation(&s[12], &s[15], true, min, max);
429   HadamardRotation(&s[13], &s[14], true, min, max);
430 
431   // stage 23.
432   butterfly_rotation(&s[13], &s[10], 32, true);
433   butterfly_rotation(&s[12], &s[11], 32, true);
434 
435   // stage 26.
436   if (is_last_stage) {
437     HadamardRotation(&s[0], &s[15], false);
438     HadamardRotation(&s[1], &s[14], false);
439     HadamardRotation(&s[2], &s[13], false);
440     HadamardRotation(&s[3], &s[12], false);
441     HadamardRotation(&s[4], &s[11], false);
442     HadamardRotation(&s[5], &s[10], false);
443     HadamardRotation(&s[6], &s[9], false);
444     HadamardRotation(&s[7], &s[8], false);
445   } else {
446     HadamardRotation(&s[0], &s[15], false, min, max);
447     HadamardRotation(&s[1], &s[14], false, min, max);
448     HadamardRotation(&s[2], &s[13], false, min, max);
449     HadamardRotation(&s[3], &s[12], false, min, max);
450     HadamardRotation(&s[4], &s[11], false, min, max);
451     HadamardRotation(&s[5], &s[10], false, min, max);
452     HadamardRotation(&s[6], &s[9], false, min, max);
453     HadamardRotation(&s[7], &s[8], false, min, max);
454   }
455 }
456 
457 // Process dct16 rows or columns, depending on the |is_row| flag.
458 template <ButterflyRotationFunc butterfly_rotation>
Dct16_NEON(void * dest,int32_t step,bool is_row,int row_shift)459 LIBGAV1_ALWAYS_INLINE void Dct16_NEON(void* dest, int32_t step, bool is_row,
460                                       int row_shift) {
461   auto* const dst = static_cast<int32_t*>(dest);
462   const int32_t range = is_row ? kBitdepth10 + 7 : 15;
463   const int32x4_t min = vdupq_n_s32(-(1 << range));
464   const int32x4_t max = vdupq_n_s32((1 << range) - 1);
465   int32x4_t s[16], x[16];
466 
467   if (is_row) {
468     for (int idx = 0; idx < 16; idx += 8) {
469       LoadSrc<4>(dst, step, idx, &x[idx]);
470       LoadSrc<4>(dst, step, idx + 4, &x[idx + 4]);
471       Transpose4x4(&x[idx], &x[idx]);
472       Transpose4x4(&x[idx + 4], &x[idx + 4]);
473     }
474   } else {
475     LoadSrc<16>(dst, step, 0, &x[0]);
476   }
477 
478   // stage 1
479   // kBitReverseLookup 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15,
480   s[0] = x[0];
481   s[1] = x[8];
482   s[2] = x[4];
483   s[3] = x[12];
484   s[4] = x[2];
485   s[5] = x[10];
486   s[6] = x[6];
487   s[7] = x[14];
488   s[8] = x[1];
489   s[9] = x[9];
490   s[10] = x[5];
491   s[11] = x[13];
492   s[12] = x[3];
493   s[13] = x[11];
494   s[14] = x[7];
495   s[15] = x[15];
496 
497   Dct4Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/false);
498   Dct8Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/false);
499   Dct16Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/true);
500 
501   if (is_row) {
502     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
503     for (auto& i : s) {
504       i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
505     }
506     for (int idx = 0; idx < 16; idx += 8) {
507       Transpose4x4(&s[idx], &s[idx]);
508       Transpose4x4(&s[idx + 4], &s[idx + 4]);
509       StoreDst<4>(dst, step, idx, &s[idx]);
510       StoreDst<4>(dst, step, idx + 4, &s[idx + 4]);
511     }
512   } else {
513     StoreDst<16>(dst, step, 0, &s[0]);
514   }
515 }
516 
517 template <ButterflyRotationFunc butterfly_rotation,
518           bool is_fast_butterfly = false>
Dct32Stages(int32x4_t * s,const int32x4_t min,const int32x4_t max,const bool is_last_stage)519 LIBGAV1_ALWAYS_INLINE void Dct32Stages(int32x4_t* s, const int32x4_t min,
520                                        const int32x4_t max,
521                                        const bool is_last_stage) {
522   // stage 3
523   if (is_fast_butterfly) {
524     ButterflyRotation_SecondIsZero(&s[16], &s[31], 62, false);
525     ButterflyRotation_FirstIsZero(&s[17], &s[30], 30, false);
526     ButterflyRotation_SecondIsZero(&s[18], &s[29], 46, false);
527     ButterflyRotation_FirstIsZero(&s[19], &s[28], 14, false);
528     ButterflyRotation_SecondIsZero(&s[20], &s[27], 54, false);
529     ButterflyRotation_FirstIsZero(&s[21], &s[26], 22, false);
530     ButterflyRotation_SecondIsZero(&s[22], &s[25], 38, false);
531     ButterflyRotation_FirstIsZero(&s[23], &s[24], 6, false);
532   } else {
533     butterfly_rotation(&s[16], &s[31], 62, false);
534     butterfly_rotation(&s[17], &s[30], 30, false);
535     butterfly_rotation(&s[18], &s[29], 46, false);
536     butterfly_rotation(&s[19], &s[28], 14, false);
537     butterfly_rotation(&s[20], &s[27], 54, false);
538     butterfly_rotation(&s[21], &s[26], 22, false);
539     butterfly_rotation(&s[22], &s[25], 38, false);
540     butterfly_rotation(&s[23], &s[24], 6, false);
541   }
542 
543   // stage 6.
544   HadamardRotation(&s[16], &s[17], false, min, max);
545   HadamardRotation(&s[18], &s[19], true, min, max);
546   HadamardRotation(&s[20], &s[21], false, min, max);
547   HadamardRotation(&s[22], &s[23], true, min, max);
548   HadamardRotation(&s[24], &s[25], false, min, max);
549   HadamardRotation(&s[26], &s[27], true, min, max);
550   HadamardRotation(&s[28], &s[29], false, min, max);
551   HadamardRotation(&s[30], &s[31], true, min, max);
552 
553   // stage 10.
554   butterfly_rotation(&s[30], &s[17], 24 + 32, true);
555   butterfly_rotation(&s[29], &s[18], 24 + 64 + 32, true);
556   butterfly_rotation(&s[26], &s[21], 24, true);
557   butterfly_rotation(&s[25], &s[22], 24 + 64, true);
558 
559   // stage 15.
560   HadamardRotation(&s[16], &s[19], false, min, max);
561   HadamardRotation(&s[17], &s[18], false, min, max);
562   HadamardRotation(&s[20], &s[23], true, min, max);
563   HadamardRotation(&s[21], &s[22], true, min, max);
564   HadamardRotation(&s[24], &s[27], false, min, max);
565   HadamardRotation(&s[25], &s[26], false, min, max);
566   HadamardRotation(&s[28], &s[31], true, min, max);
567   HadamardRotation(&s[29], &s[30], true, min, max);
568 
569   // stage 20.
570   butterfly_rotation(&s[29], &s[18], 48, true);
571   butterfly_rotation(&s[28], &s[19], 48, true);
572   butterfly_rotation(&s[27], &s[20], 48 + 64, true);
573   butterfly_rotation(&s[26], &s[21], 48 + 64, true);
574 
575   // stage 24.
576   HadamardRotation(&s[16], &s[23], false, min, max);
577   HadamardRotation(&s[17], &s[22], false, min, max);
578   HadamardRotation(&s[18], &s[21], false, min, max);
579   HadamardRotation(&s[19], &s[20], false, min, max);
580   HadamardRotation(&s[24], &s[31], true, min, max);
581   HadamardRotation(&s[25], &s[30], true, min, max);
582   HadamardRotation(&s[26], &s[29], true, min, max);
583   HadamardRotation(&s[27], &s[28], true, min, max);
584 
585   // stage 27.
586   butterfly_rotation(&s[27], &s[20], 32, true);
587   butterfly_rotation(&s[26], &s[21], 32, true);
588   butterfly_rotation(&s[25], &s[22], 32, true);
589   butterfly_rotation(&s[24], &s[23], 32, true);
590 
591   // stage 29.
592   if (is_last_stage) {
593     HadamardRotation(&s[0], &s[31], false);
594     HadamardRotation(&s[1], &s[30], false);
595     HadamardRotation(&s[2], &s[29], false);
596     HadamardRotation(&s[3], &s[28], false);
597     HadamardRotation(&s[4], &s[27], false);
598     HadamardRotation(&s[5], &s[26], false);
599     HadamardRotation(&s[6], &s[25], false);
600     HadamardRotation(&s[7], &s[24], false);
601     HadamardRotation(&s[8], &s[23], false);
602     HadamardRotation(&s[9], &s[22], false);
603     HadamardRotation(&s[10], &s[21], false);
604     HadamardRotation(&s[11], &s[20], false);
605     HadamardRotation(&s[12], &s[19], false);
606     HadamardRotation(&s[13], &s[18], false);
607     HadamardRotation(&s[14], &s[17], false);
608     HadamardRotation(&s[15], &s[16], false);
609   } else {
610     HadamardRotation(&s[0], &s[31], false, min, max);
611     HadamardRotation(&s[1], &s[30], false, min, max);
612     HadamardRotation(&s[2], &s[29], false, min, max);
613     HadamardRotation(&s[3], &s[28], false, min, max);
614     HadamardRotation(&s[4], &s[27], false, min, max);
615     HadamardRotation(&s[5], &s[26], false, min, max);
616     HadamardRotation(&s[6], &s[25], false, min, max);
617     HadamardRotation(&s[7], &s[24], false, min, max);
618     HadamardRotation(&s[8], &s[23], false, min, max);
619     HadamardRotation(&s[9], &s[22], false, min, max);
620     HadamardRotation(&s[10], &s[21], false, min, max);
621     HadamardRotation(&s[11], &s[20], false, min, max);
622     HadamardRotation(&s[12], &s[19], false, min, max);
623     HadamardRotation(&s[13], &s[18], false, min, max);
624     HadamardRotation(&s[14], &s[17], false, min, max);
625     HadamardRotation(&s[15], &s[16], false, min, max);
626   }
627 }
628 
629 // Process dct32 rows or columns, depending on the |is_row| flag.
Dct32_NEON(void * dest,const int32_t step,const bool is_row,int row_shift)630 LIBGAV1_ALWAYS_INLINE void Dct32_NEON(void* dest, const int32_t step,
631                                       const bool is_row, int row_shift) {
632   auto* const dst = static_cast<int32_t*>(dest);
633   const int32_t range = is_row ? kBitdepth10 + 7 : 15;
634   const int32x4_t min = vdupq_n_s32(-(1 << range));
635   const int32x4_t max = vdupq_n_s32((1 << range) - 1);
636   int32x4_t s[32], x[32];
637 
638   if (is_row) {
639     for (int idx = 0; idx < 32; idx += 8) {
640       LoadSrc<4>(dst, step, idx, &x[idx]);
641       LoadSrc<4>(dst, step, idx + 4, &x[idx + 4]);
642       Transpose4x4(&x[idx], &x[idx]);
643       Transpose4x4(&x[idx + 4], &x[idx + 4]);
644     }
645   } else {
646     LoadSrc<32>(dst, step, 0, &x[0]);
647   }
648 
649   // stage 1
650   // kBitReverseLookup
651   // 0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30,
652   s[0] = x[0];
653   s[1] = x[16];
654   s[2] = x[8];
655   s[3] = x[24];
656   s[4] = x[4];
657   s[5] = x[20];
658   s[6] = x[12];
659   s[7] = x[28];
660   s[8] = x[2];
661   s[9] = x[18];
662   s[10] = x[10];
663   s[11] = x[26];
664   s[12] = x[6];
665   s[13] = x[22];
666   s[14] = x[14];
667   s[15] = x[30];
668 
669   // 1, 17, 9, 25, 5, 21, 13, 29, 3, 19, 11, 27, 7, 23, 15, 31,
670   s[16] = x[1];
671   s[17] = x[17];
672   s[18] = x[9];
673   s[19] = x[25];
674   s[20] = x[5];
675   s[21] = x[21];
676   s[22] = x[13];
677   s[23] = x[29];
678   s[24] = x[3];
679   s[25] = x[19];
680   s[26] = x[11];
681   s[27] = x[27];
682   s[28] = x[7];
683   s[29] = x[23];
684   s[30] = x[15];
685   s[31] = x[31];
686 
687   Dct4Stages<ButterflyRotation_4>(s, min, max, /*is_last_stage=*/false);
688   Dct8Stages<ButterflyRotation_4>(s, min, max, /*is_last_stage=*/false);
689   Dct16Stages<ButterflyRotation_4>(s, min, max, /*is_last_stage=*/false);
690   Dct32Stages<ButterflyRotation_4>(s, min, max, /*is_last_stage=*/true);
691 
692   if (is_row) {
693     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
694     for (int idx = 0; idx < 32; idx += 8) {
695       int32x4_t output[8];
696       Transpose4x4(&s[idx], &output[0]);
697       Transpose4x4(&s[idx + 4], &output[4]);
698       for (auto& o : output) {
699         o = vmovl_s16(vqmovn_s32(vqrshlq_s32(o, v_row_shift)));
700       }
701       StoreDst<4>(dst, step, idx, &output[0]);
702       StoreDst<4>(dst, step, idx + 4, &output[4]);
703     }
704   } else {
705     StoreDst<32>(dst, step, 0, &s[0]);
706   }
707 }
708 
Dct64_NEON(void * dest,int32_t step,bool is_row,int row_shift)709 void Dct64_NEON(void* dest, int32_t step, bool is_row, int row_shift) {
710   auto* const dst = static_cast<int32_t*>(dest);
711   const int32_t range = is_row ? kBitdepth10 + 7 : 15;
712   const int32x4_t min = vdupq_n_s32(-(1 << range));
713   const int32x4_t max = vdupq_n_s32((1 << range) - 1);
714   int32x4_t s[64], x[32];
715 
716   if (is_row) {
717     // The last 32 values of every row are always zero if the |tx_width| is
718     // 64.
719     for (int idx = 0; idx < 32; idx += 8) {
720       LoadSrc<4>(dst, step, idx, &x[idx]);
721       LoadSrc<4>(dst, step, idx + 4, &x[idx + 4]);
722       Transpose4x4(&x[idx], &x[idx]);
723       Transpose4x4(&x[idx + 4], &x[idx + 4]);
724     }
725   } else {
726     // The last 32 values of every column are always zero if the |tx_height| is
727     // 64.
728     LoadSrc<32>(dst, step, 0, &x[0]);
729   }
730 
731   // stage 1
732   // kBitReverseLookup
733   // 0, 32, 16, 48, 8, 40, 24, 56, 4, 36, 20, 52, 12, 44, 28, 60,
734   s[0] = x[0];
735   s[2] = x[16];
736   s[4] = x[8];
737   s[6] = x[24];
738   s[8] = x[4];
739   s[10] = x[20];
740   s[12] = x[12];
741   s[14] = x[28];
742 
743   // 2, 34, 18, 50, 10, 42, 26, 58, 6, 38, 22, 54, 14, 46, 30, 62,
744   s[16] = x[2];
745   s[18] = x[18];
746   s[20] = x[10];
747   s[22] = x[26];
748   s[24] = x[6];
749   s[26] = x[22];
750   s[28] = x[14];
751   s[30] = x[30];
752 
753   // 1, 33, 17, 49, 9, 41, 25, 57, 5, 37, 21, 53, 13, 45, 29, 61,
754   s[32] = x[1];
755   s[34] = x[17];
756   s[36] = x[9];
757   s[38] = x[25];
758   s[40] = x[5];
759   s[42] = x[21];
760   s[44] = x[13];
761   s[46] = x[29];
762 
763   // 3, 35, 19, 51, 11, 43, 27, 59, 7, 39, 23, 55, 15, 47, 31, 63
764   s[48] = x[3];
765   s[50] = x[19];
766   s[52] = x[11];
767   s[54] = x[27];
768   s[56] = x[7];
769   s[58] = x[23];
770   s[60] = x[15];
771   s[62] = x[31];
772 
773   Dct4Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
774       s, min, max, /*is_last_stage=*/false);
775   Dct8Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
776       s, min, max, /*is_last_stage=*/false);
777   Dct16Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
778       s, min, max, /*is_last_stage=*/false);
779   Dct32Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
780       s, min, max, /*is_last_stage=*/false);
781 
782   //-- start dct 64 stages
783   // stage 2.
784   ButterflyRotation_SecondIsZero(&s[32], &s[63], 63 - 0, false);
785   ButterflyRotation_FirstIsZero(&s[33], &s[62], 63 - 32, false);
786   ButterflyRotation_SecondIsZero(&s[34], &s[61], 63 - 16, false);
787   ButterflyRotation_FirstIsZero(&s[35], &s[60], 63 - 48, false);
788   ButterflyRotation_SecondIsZero(&s[36], &s[59], 63 - 8, false);
789   ButterflyRotation_FirstIsZero(&s[37], &s[58], 63 - 40, false);
790   ButterflyRotation_SecondIsZero(&s[38], &s[57], 63 - 24, false);
791   ButterflyRotation_FirstIsZero(&s[39], &s[56], 63 - 56, false);
792   ButterflyRotation_SecondIsZero(&s[40], &s[55], 63 - 4, false);
793   ButterflyRotation_FirstIsZero(&s[41], &s[54], 63 - 36, false);
794   ButterflyRotation_SecondIsZero(&s[42], &s[53], 63 - 20, false);
795   ButterflyRotation_FirstIsZero(&s[43], &s[52], 63 - 52, false);
796   ButterflyRotation_SecondIsZero(&s[44], &s[51], 63 - 12, false);
797   ButterflyRotation_FirstIsZero(&s[45], &s[50], 63 - 44, false);
798   ButterflyRotation_SecondIsZero(&s[46], &s[49], 63 - 28, false);
799   ButterflyRotation_FirstIsZero(&s[47], &s[48], 63 - 60, false);
800 
801   // stage 4.
802   HadamardRotation(&s[32], &s[33], false, min, max);
803   HadamardRotation(&s[34], &s[35], true, min, max);
804   HadamardRotation(&s[36], &s[37], false, min, max);
805   HadamardRotation(&s[38], &s[39], true, min, max);
806   HadamardRotation(&s[40], &s[41], false, min, max);
807   HadamardRotation(&s[42], &s[43], true, min, max);
808   HadamardRotation(&s[44], &s[45], false, min, max);
809   HadamardRotation(&s[46], &s[47], true, min, max);
810   HadamardRotation(&s[48], &s[49], false, min, max);
811   HadamardRotation(&s[50], &s[51], true, min, max);
812   HadamardRotation(&s[52], &s[53], false, min, max);
813   HadamardRotation(&s[54], &s[55], true, min, max);
814   HadamardRotation(&s[56], &s[57], false, min, max);
815   HadamardRotation(&s[58], &s[59], true, min, max);
816   HadamardRotation(&s[60], &s[61], false, min, max);
817   HadamardRotation(&s[62], &s[63], true, min, max);
818 
819   // stage 7.
820   ButterflyRotation_4(&s[62], &s[33], 60 - 0, true);
821   ButterflyRotation_4(&s[61], &s[34], 60 - 0 + 64, true);
822   ButterflyRotation_4(&s[58], &s[37], 60 - 32, true);
823   ButterflyRotation_4(&s[57], &s[38], 60 - 32 + 64, true);
824   ButterflyRotation_4(&s[54], &s[41], 60 - 16, true);
825   ButterflyRotation_4(&s[53], &s[42], 60 - 16 + 64, true);
826   ButterflyRotation_4(&s[50], &s[45], 60 - 48, true);
827   ButterflyRotation_4(&s[49], &s[46], 60 - 48 + 64, true);
828 
829   // stage 11.
830   HadamardRotation(&s[32], &s[35], false, min, max);
831   HadamardRotation(&s[33], &s[34], false, min, max);
832   HadamardRotation(&s[36], &s[39], true, min, max);
833   HadamardRotation(&s[37], &s[38], true, min, max);
834   HadamardRotation(&s[40], &s[43], false, min, max);
835   HadamardRotation(&s[41], &s[42], false, min, max);
836   HadamardRotation(&s[44], &s[47], true, min, max);
837   HadamardRotation(&s[45], &s[46], true, min, max);
838   HadamardRotation(&s[48], &s[51], false, min, max);
839   HadamardRotation(&s[49], &s[50], false, min, max);
840   HadamardRotation(&s[52], &s[55], true, min, max);
841   HadamardRotation(&s[53], &s[54], true, min, max);
842   HadamardRotation(&s[56], &s[59], false, min, max);
843   HadamardRotation(&s[57], &s[58], false, min, max);
844   HadamardRotation(&s[60], &s[63], true, min, max);
845   HadamardRotation(&s[61], &s[62], true, min, max);
846 
847   // stage 16.
848   ButterflyRotation_4(&s[61], &s[34], 56, true);
849   ButterflyRotation_4(&s[60], &s[35], 56, true);
850   ButterflyRotation_4(&s[59], &s[36], 56 + 64, true);
851   ButterflyRotation_4(&s[58], &s[37], 56 + 64, true);
852   ButterflyRotation_4(&s[53], &s[42], 56 - 32, true);
853   ButterflyRotation_4(&s[52], &s[43], 56 - 32, true);
854   ButterflyRotation_4(&s[51], &s[44], 56 - 32 + 64, true);
855   ButterflyRotation_4(&s[50], &s[45], 56 - 32 + 64, true);
856 
857   // stage 21.
858   HadamardRotation(&s[32], &s[39], false, min, max);
859   HadamardRotation(&s[33], &s[38], false, min, max);
860   HadamardRotation(&s[34], &s[37], false, min, max);
861   HadamardRotation(&s[35], &s[36], false, min, max);
862   HadamardRotation(&s[40], &s[47], true, min, max);
863   HadamardRotation(&s[41], &s[46], true, min, max);
864   HadamardRotation(&s[42], &s[45], true, min, max);
865   HadamardRotation(&s[43], &s[44], true, min, max);
866   HadamardRotation(&s[48], &s[55], false, min, max);
867   HadamardRotation(&s[49], &s[54], false, min, max);
868   HadamardRotation(&s[50], &s[53], false, min, max);
869   HadamardRotation(&s[51], &s[52], false, min, max);
870   HadamardRotation(&s[56], &s[63], true, min, max);
871   HadamardRotation(&s[57], &s[62], true, min, max);
872   HadamardRotation(&s[58], &s[61], true, min, max);
873   HadamardRotation(&s[59], &s[60], true, min, max);
874 
875   // stage 25.
876   ButterflyRotation_4(&s[59], &s[36], 48, true);
877   ButterflyRotation_4(&s[58], &s[37], 48, true);
878   ButterflyRotation_4(&s[57], &s[38], 48, true);
879   ButterflyRotation_4(&s[56], &s[39], 48, true);
880   ButterflyRotation_4(&s[55], &s[40], 112, true);
881   ButterflyRotation_4(&s[54], &s[41], 112, true);
882   ButterflyRotation_4(&s[53], &s[42], 112, true);
883   ButterflyRotation_4(&s[52], &s[43], 112, true);
884 
885   // stage 28.
886   HadamardRotation(&s[32], &s[47], false, min, max);
887   HadamardRotation(&s[33], &s[46], false, min, max);
888   HadamardRotation(&s[34], &s[45], false, min, max);
889   HadamardRotation(&s[35], &s[44], false, min, max);
890   HadamardRotation(&s[36], &s[43], false, min, max);
891   HadamardRotation(&s[37], &s[42], false, min, max);
892   HadamardRotation(&s[38], &s[41], false, min, max);
893   HadamardRotation(&s[39], &s[40], false, min, max);
894   HadamardRotation(&s[48], &s[63], true, min, max);
895   HadamardRotation(&s[49], &s[62], true, min, max);
896   HadamardRotation(&s[50], &s[61], true, min, max);
897   HadamardRotation(&s[51], &s[60], true, min, max);
898   HadamardRotation(&s[52], &s[59], true, min, max);
899   HadamardRotation(&s[53], &s[58], true, min, max);
900   HadamardRotation(&s[54], &s[57], true, min, max);
901   HadamardRotation(&s[55], &s[56], true, min, max);
902 
903   // stage 30.
904   ButterflyRotation_4(&s[55], &s[40], 32, true);
905   ButterflyRotation_4(&s[54], &s[41], 32, true);
906   ButterflyRotation_4(&s[53], &s[42], 32, true);
907   ButterflyRotation_4(&s[52], &s[43], 32, true);
908   ButterflyRotation_4(&s[51], &s[44], 32, true);
909   ButterflyRotation_4(&s[50], &s[45], 32, true);
910   ButterflyRotation_4(&s[49], &s[46], 32, true);
911   ButterflyRotation_4(&s[48], &s[47], 32, true);
912 
913   // stage 31.
914   for (int i = 0; i < 32; i += 4) {
915     HadamardRotation(&s[i], &s[63 - i], false, min, max);
916     HadamardRotation(&s[i + 1], &s[63 - i - 1], false, min, max);
917     HadamardRotation(&s[i + 2], &s[63 - i - 2], false, min, max);
918     HadamardRotation(&s[i + 3], &s[63 - i - 3], false, min, max);
919   }
920   //-- end dct 64 stages
921   if (is_row) {
922     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
923     for (int idx = 0; idx < 64; idx += 8) {
924       int32x4_t output[8];
925       Transpose4x4(&s[idx], &output[0]);
926       Transpose4x4(&s[idx + 4], &output[4]);
927       for (auto& o : output) {
928         o = vmovl_s16(vqmovn_s32(vqrshlq_s32(o, v_row_shift)));
929       }
930       StoreDst<4>(dst, step, idx, &output[0]);
931       StoreDst<4>(dst, step, idx + 4, &output[4]);
932     }
933   } else {
934     StoreDst<64>(dst, step, 0, &s[0]);
935   }
936 }
937 
938 //------------------------------------------------------------------------------
939 // Asymmetric Discrete Sine Transforms (ADST).
Adst4_NEON(void * dest,int32_t step,bool is_row,int row_shift)940 LIBGAV1_ALWAYS_INLINE void Adst4_NEON(void* dest, int32_t step, bool is_row,
941                                       int row_shift) {
942   auto* const dst = static_cast<int32_t*>(dest);
943   int32x4_t s[8];
944   int32x4_t x[4];
945 
946   if (is_row) {
947     assert(step == 4);
948     int32x4x4_t y = vld4q_s32(dst);
949     for (int i = 0; i < 4; ++i) x[i] = y.val[i];
950   } else {
951     LoadSrc<4>(dst, step, 0, x);
952   }
953 
954   // stage 1.
955   s[5] = vmulq_n_s32(x[3], kAdst4Multiplier[1]);
956   s[6] = vmulq_n_s32(x[3], kAdst4Multiplier[3]);
957 
958   // stage 2.
959   const int32x4_t a7 = vsubq_s32(x[0], x[2]);
960   const int32x4_t b7 = vaddq_s32(a7, x[3]);
961 
962   // stage 3.
963   s[0] = vmulq_n_s32(x[0], kAdst4Multiplier[0]);
964   s[1] = vmulq_n_s32(x[0], kAdst4Multiplier[1]);
965   // s[0] = s[0] + s[3]
966   s[0] = vmlaq_n_s32(s[0], x[2], kAdst4Multiplier[3]);
967   // s[1] = s[1] - s[4]
968   s[1] = vmlsq_n_s32(s[1], x[2], kAdst4Multiplier[0]);
969 
970   s[3] = vmulq_n_s32(x[1], kAdst4Multiplier[2]);
971   s[2] = vmulq_n_s32(b7, kAdst4Multiplier[2]);
972 
973   // stage 4.
974   s[0] = vaddq_s32(s[0], s[5]);
975   s[1] = vsubq_s32(s[1], s[6]);
976 
977   // stages 5 and 6.
978   const int32x4_t x0 = vaddq_s32(s[0], s[3]);
979   const int32x4_t x1 = vaddq_s32(s[1], s[3]);
980   const int32x4_t x3_a = vaddq_s32(s[0], s[1]);
981   const int32x4_t x3 = vsubq_s32(x3_a, s[3]);
982   x[0] = vrshrq_n_s32(x0, 12);
983   x[1] = vrshrq_n_s32(x1, 12);
984   x[2] = vrshrq_n_s32(s[2], 12);
985   x[3] = vrshrq_n_s32(x3, 12);
986 
987   if (is_row) {
988     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
989     x[0] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[0], v_row_shift)));
990     x[1] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[1], v_row_shift)));
991     x[2] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[2], v_row_shift)));
992     x[3] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[3], v_row_shift)));
993     int32x4x4_t y;
994     for (int i = 0; i < 4; ++i) y.val[i] = x[i];
995     vst4q_s32(dst, y);
996   } else {
997     StoreDst<4>(dst, step, 0, x);
998   }
999 }
1000 
1001 alignas(16) constexpr int32_t kAdst4DcOnlyMultiplier[4] = {1321, 2482, 3344,
1002                                                            2482};
1003 
Adst4DcOnly(void * dest,int adjusted_tx_height,bool should_round,int row_shift)1004 LIBGAV1_ALWAYS_INLINE bool Adst4DcOnly(void* dest, int adjusted_tx_height,
1005                                        bool should_round, int row_shift) {
1006   if (adjusted_tx_height > 1) return false;
1007 
1008   auto* dst = static_cast<int32_t*>(dest);
1009   int32x4_t s[2];
1010 
1011   const int32x4_t v_src0 = vdupq_n_s32(dst[0]);
1012   const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
1013   const int32x4_t v_src0_round =
1014       vqrdmulhq_n_s32(v_src0, kTransformRowMultiplier << (31 - 12));
1015 
1016   const int32x4_t v_src = vbslq_s32(v_mask, v_src0_round, v_src0);
1017   const int32x4_t kAdst4DcOnlyMultipliers = vld1q_s32(kAdst4DcOnlyMultiplier);
1018   s[1] = vdupq_n_s32(0);
1019 
1020   // s0*k0 s0*k1 s0*k2 s0*k1
1021   s[0] = vmulq_s32(kAdst4DcOnlyMultipliers, v_src);
1022   // 0     0     0     s0*k0
1023   s[1] = vextq_s32(s[1], s[0], 1);
1024 
1025   const int32x4_t x3 = vaddq_s32(s[0], s[1]);
1026   const int32x4_t dst_0 = vrshrq_n_s32(x3, 12);
1027 
1028   // vqrshlq_s32 will shift right if shift value is negative.
1029   vst1q_s32(dst,
1030             vmovl_s16(vqmovn_s32(vqrshlq_s32(dst_0, vdupq_n_s32(-row_shift)))));
1031 
1032   return true;
1033 }
1034 
Adst4DcOnlyColumn(void * dest,int adjusted_tx_height,int width)1035 LIBGAV1_ALWAYS_INLINE bool Adst4DcOnlyColumn(void* dest, int adjusted_tx_height,
1036                                              int width) {
1037   if (adjusted_tx_height > 1) return false;
1038 
1039   auto* dst = static_cast<int32_t*>(dest);
1040   int32x4_t s[4];
1041 
1042   int i = 0;
1043   do {
1044     const int32x4_t v_src = vld1q_s32(&dst[i]);
1045 
1046     s[0] = vmulq_n_s32(v_src, kAdst4Multiplier[0]);
1047     s[1] = vmulq_n_s32(v_src, kAdst4Multiplier[1]);
1048     s[2] = vmulq_n_s32(v_src, kAdst4Multiplier[2]);
1049 
1050     const int32x4_t x0 = s[0];
1051     const int32x4_t x1 = s[1];
1052     const int32x4_t x2 = s[2];
1053     const int32x4_t x3 = vaddq_s32(s[0], s[1]);
1054     const int32x4_t dst_0 = vrshrq_n_s32(x0, 12);
1055     const int32x4_t dst_1 = vrshrq_n_s32(x1, 12);
1056     const int32x4_t dst_2 = vrshrq_n_s32(x2, 12);
1057     const int32x4_t dst_3 = vrshrq_n_s32(x3, 12);
1058 
1059     vst1q_s32(&dst[i], dst_0);
1060     vst1q_s32(&dst[i + width * 1], dst_1);
1061     vst1q_s32(&dst[i + width * 2], dst_2);
1062     vst1q_s32(&dst[i + width * 3], dst_3);
1063 
1064     i += 4;
1065   } while (i < width);
1066 
1067   return true;
1068 }
1069 
1070 template <ButterflyRotationFunc butterfly_rotation>
Adst8_NEON(void * dest,int32_t step,bool is_row,int row_shift)1071 LIBGAV1_ALWAYS_INLINE void Adst8_NEON(void* dest, int32_t step, bool is_row,
1072                                       int row_shift) {
1073   auto* const dst = static_cast<int32_t*>(dest);
1074   const int32_t range = is_row ? kBitdepth10 + 7 : 15;
1075   const int32x4_t min = vdupq_n_s32(-(1 << range));
1076   const int32x4_t max = vdupq_n_s32((1 << range) - 1);
1077   int32x4_t s[8], x[8];
1078 
1079   if (is_row) {
1080     LoadSrc<4>(dst, step, 0, &x[0]);
1081     LoadSrc<4>(dst, step, 4, &x[4]);
1082     Transpose4x4(&x[0], &x[0]);
1083     Transpose4x4(&x[4], &x[4]);
1084   } else {
1085     LoadSrc<8>(dst, step, 0, &x[0]);
1086   }
1087 
1088   // stage 1.
1089   s[0] = x[7];
1090   s[1] = x[0];
1091   s[2] = x[5];
1092   s[3] = x[2];
1093   s[4] = x[3];
1094   s[5] = x[4];
1095   s[6] = x[1];
1096   s[7] = x[6];
1097 
1098   // stage 2.
1099   butterfly_rotation(&s[0], &s[1], 60 - 0, true);
1100   butterfly_rotation(&s[2], &s[3], 60 - 16, true);
1101   butterfly_rotation(&s[4], &s[5], 60 - 32, true);
1102   butterfly_rotation(&s[6], &s[7], 60 - 48, true);
1103 
1104   // stage 3.
1105   HadamardRotation(&s[0], &s[4], false, min, max);
1106   HadamardRotation(&s[1], &s[5], false, min, max);
1107   HadamardRotation(&s[2], &s[6], false, min, max);
1108   HadamardRotation(&s[3], &s[7], false, min, max);
1109 
1110   // stage 4.
1111   butterfly_rotation(&s[4], &s[5], 48 - 0, true);
1112   butterfly_rotation(&s[7], &s[6], 48 - 32, true);
1113 
1114   // stage 5.
1115   HadamardRotation(&s[0], &s[2], false, min, max);
1116   HadamardRotation(&s[4], &s[6], false, min, max);
1117   HadamardRotation(&s[1], &s[3], false, min, max);
1118   HadamardRotation(&s[5], &s[7], false, min, max);
1119 
1120   // stage 6.
1121   butterfly_rotation(&s[2], &s[3], 32, true);
1122   butterfly_rotation(&s[6], &s[7], 32, true);
1123 
1124   // stage 7.
1125   x[0] = s[0];
1126   x[1] = vqnegq_s32(s[4]);
1127   x[2] = s[6];
1128   x[3] = vqnegq_s32(s[2]);
1129   x[4] = s[3];
1130   x[5] = vqnegq_s32(s[7]);
1131   x[6] = s[5];
1132   x[7] = vqnegq_s32(s[1]);
1133 
1134   if (is_row) {
1135     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
1136     for (auto& i : x) {
1137       i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
1138     }
1139     Transpose4x4(&x[0], &x[0]);
1140     Transpose4x4(&x[4], &x[4]);
1141     StoreDst<4>(dst, step, 0, &x[0]);
1142     StoreDst<4>(dst, step, 4, &x[4]);
1143   } else {
1144     StoreDst<8>(dst, step, 0, &x[0]);
1145   }
1146 }
1147 
Adst8DcOnly(void * dest,int adjusted_tx_height,bool should_round,int row_shift)1148 LIBGAV1_ALWAYS_INLINE bool Adst8DcOnly(void* dest, int adjusted_tx_height,
1149                                        bool should_round, int row_shift) {
1150   if (adjusted_tx_height > 1) return false;
1151 
1152   auto* dst = static_cast<int32_t*>(dest);
1153   int32x4_t s[8];
1154 
1155   const int32x4_t v_src = vdupq_n_s32(dst[0]);
1156   const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
1157   const int32x4_t v_src_round =
1158       vqrdmulhq_n_s32(v_src, kTransformRowMultiplier << (31 - 12));
1159   // stage 1.
1160   s[1] = vbslq_s32(v_mask, v_src_round, v_src);
1161 
1162   // stage 2.
1163   ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
1164 
1165   // stage 3.
1166   s[4] = s[0];
1167   s[5] = s[1];
1168 
1169   // stage 4.
1170   ButterflyRotation_4(&s[4], &s[5], 48, true);
1171 
1172   // stage 5.
1173   s[2] = s[0];
1174   s[3] = s[1];
1175   s[6] = s[4];
1176   s[7] = s[5];
1177 
1178   // stage 6.
1179   ButterflyRotation_4(&s[2], &s[3], 32, true);
1180   ButterflyRotation_4(&s[6], &s[7], 32, true);
1181 
1182   // stage 7.
1183   int32x4_t x[8];
1184   x[0] = s[0];
1185   x[1] = vqnegq_s32(s[4]);
1186   x[2] = s[6];
1187   x[3] = vqnegq_s32(s[2]);
1188   x[4] = s[3];
1189   x[5] = vqnegq_s32(s[7]);
1190   x[6] = s[5];
1191   x[7] = vqnegq_s32(s[1]);
1192 
1193   for (int i = 0; i < 8; ++i) {
1194     // vqrshlq_s32 will shift right if shift value is negative.
1195     x[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[i], vdupq_n_s32(-row_shift))));
1196     vst1q_lane_s32(&dst[i], x[i], 0);
1197   }
1198 
1199   return true;
1200 }
1201 
Adst8DcOnlyColumn(void * dest,int adjusted_tx_height,int width)1202 LIBGAV1_ALWAYS_INLINE bool Adst8DcOnlyColumn(void* dest, int adjusted_tx_height,
1203                                              int width) {
1204   if (adjusted_tx_height > 1) return false;
1205 
1206   auto* dst = static_cast<int32_t*>(dest);
1207   int32x4_t s[8];
1208 
1209   int i = 0;
1210   do {
1211     const int32x4_t v_src = vld1q_s32(dst);
1212     // stage 1.
1213     s[1] = v_src;
1214 
1215     // stage 2.
1216     ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
1217 
1218     // stage 3.
1219     s[4] = s[0];
1220     s[5] = s[1];
1221 
1222     // stage 4.
1223     ButterflyRotation_4(&s[4], &s[5], 48, true);
1224 
1225     // stage 5.
1226     s[2] = s[0];
1227     s[3] = s[1];
1228     s[6] = s[4];
1229     s[7] = s[5];
1230 
1231     // stage 6.
1232     ButterflyRotation_4(&s[2], &s[3], 32, true);
1233     ButterflyRotation_4(&s[6], &s[7], 32, true);
1234 
1235     // stage 7.
1236     int32x4_t x[8];
1237     x[0] = s[0];
1238     x[1] = vqnegq_s32(s[4]);
1239     x[2] = s[6];
1240     x[3] = vqnegq_s32(s[2]);
1241     x[4] = s[3];
1242     x[5] = vqnegq_s32(s[7]);
1243     x[6] = s[5];
1244     x[7] = vqnegq_s32(s[1]);
1245 
1246     for (int j = 0; j < 8; ++j) {
1247       vst1q_s32(&dst[j * width], x[j]);
1248     }
1249     i += 4;
1250     dst += 4;
1251   } while (i < width);
1252 
1253   return true;
1254 }
1255 
1256 template <ButterflyRotationFunc butterfly_rotation>
Adst16_NEON(void * dest,int32_t step,bool is_row,int row_shift)1257 LIBGAV1_ALWAYS_INLINE void Adst16_NEON(void* dest, int32_t step, bool is_row,
1258                                        int row_shift) {
1259   auto* const dst = static_cast<int32_t*>(dest);
1260   const int32_t range = is_row ? kBitdepth10 + 7 : 15;
1261   const int32x4_t min = vdupq_n_s32(-(1 << range));
1262   const int32x4_t max = vdupq_n_s32((1 << range) - 1);
1263   int32x4_t s[16], x[16];
1264 
1265   if (is_row) {
1266     for (int idx = 0; idx < 16; idx += 8) {
1267       LoadSrc<4>(dst, step, idx, &x[idx]);
1268       LoadSrc<4>(dst, step, idx + 4, &x[idx + 4]);
1269       Transpose4x4(&x[idx], &x[idx]);
1270       Transpose4x4(&x[idx + 4], &x[idx + 4]);
1271     }
1272   } else {
1273     LoadSrc<16>(dst, step, 0, &x[0]);
1274   }
1275 
1276   // stage 1.
1277   s[0] = x[15];
1278   s[1] = x[0];
1279   s[2] = x[13];
1280   s[3] = x[2];
1281   s[4] = x[11];
1282   s[5] = x[4];
1283   s[6] = x[9];
1284   s[7] = x[6];
1285   s[8] = x[7];
1286   s[9] = x[8];
1287   s[10] = x[5];
1288   s[11] = x[10];
1289   s[12] = x[3];
1290   s[13] = x[12];
1291   s[14] = x[1];
1292   s[15] = x[14];
1293 
1294   // stage 2.
1295   butterfly_rotation(&s[0], &s[1], 62 - 0, true);
1296   butterfly_rotation(&s[2], &s[3], 62 - 8, true);
1297   butterfly_rotation(&s[4], &s[5], 62 - 16, true);
1298   butterfly_rotation(&s[6], &s[7], 62 - 24, true);
1299   butterfly_rotation(&s[8], &s[9], 62 - 32, true);
1300   butterfly_rotation(&s[10], &s[11], 62 - 40, true);
1301   butterfly_rotation(&s[12], &s[13], 62 - 48, true);
1302   butterfly_rotation(&s[14], &s[15], 62 - 56, true);
1303 
1304   // stage 3.
1305   HadamardRotation(&s[0], &s[8], false, min, max);
1306   HadamardRotation(&s[1], &s[9], false, min, max);
1307   HadamardRotation(&s[2], &s[10], false, min, max);
1308   HadamardRotation(&s[3], &s[11], false, min, max);
1309   HadamardRotation(&s[4], &s[12], false, min, max);
1310   HadamardRotation(&s[5], &s[13], false, min, max);
1311   HadamardRotation(&s[6], &s[14], false, min, max);
1312   HadamardRotation(&s[7], &s[15], false, min, max);
1313 
1314   // stage 4.
1315   butterfly_rotation(&s[8], &s[9], 56 - 0, true);
1316   butterfly_rotation(&s[13], &s[12], 8 + 0, true);
1317   butterfly_rotation(&s[10], &s[11], 56 - 32, true);
1318   butterfly_rotation(&s[15], &s[14], 8 + 32, true);
1319 
1320   // stage 5.
1321   HadamardRotation(&s[0], &s[4], false, min, max);
1322   HadamardRotation(&s[8], &s[12], false, min, max);
1323   HadamardRotation(&s[1], &s[5], false, min, max);
1324   HadamardRotation(&s[9], &s[13], false, min, max);
1325   HadamardRotation(&s[2], &s[6], false, min, max);
1326   HadamardRotation(&s[10], &s[14], false, min, max);
1327   HadamardRotation(&s[3], &s[7], false, min, max);
1328   HadamardRotation(&s[11], &s[15], false, min, max);
1329 
1330   // stage 6.
1331   butterfly_rotation(&s[4], &s[5], 48 - 0, true);
1332   butterfly_rotation(&s[12], &s[13], 48 - 0, true);
1333   butterfly_rotation(&s[7], &s[6], 48 - 32, true);
1334   butterfly_rotation(&s[15], &s[14], 48 - 32, true);
1335 
1336   // stage 7.
1337   HadamardRotation(&s[0], &s[2], false, min, max);
1338   HadamardRotation(&s[4], &s[6], false, min, max);
1339   HadamardRotation(&s[8], &s[10], false, min, max);
1340   HadamardRotation(&s[12], &s[14], false, min, max);
1341   HadamardRotation(&s[1], &s[3], false, min, max);
1342   HadamardRotation(&s[5], &s[7], false, min, max);
1343   HadamardRotation(&s[9], &s[11], false, min, max);
1344   HadamardRotation(&s[13], &s[15], false, min, max);
1345 
1346   // stage 8.
1347   butterfly_rotation(&s[2], &s[3], 32, true);
1348   butterfly_rotation(&s[6], &s[7], 32, true);
1349   butterfly_rotation(&s[10], &s[11], 32, true);
1350   butterfly_rotation(&s[14], &s[15], 32, true);
1351 
1352   // stage 9.
1353   x[0] = s[0];
1354   x[1] = vqnegq_s32(s[8]);
1355   x[2] = s[12];
1356   x[3] = vqnegq_s32(s[4]);
1357   x[4] = s[6];
1358   x[5] = vqnegq_s32(s[14]);
1359   x[6] = s[10];
1360   x[7] = vqnegq_s32(s[2]);
1361   x[8] = s[3];
1362   x[9] = vqnegq_s32(s[11]);
1363   x[10] = s[15];
1364   x[11] = vqnegq_s32(s[7]);
1365   x[12] = s[5];
1366   x[13] = vqnegq_s32(s[13]);
1367   x[14] = s[9];
1368   x[15] = vqnegq_s32(s[1]);
1369 
1370   if (is_row) {
1371     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
1372     for (auto& i : x) {
1373       i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
1374     }
1375     for (int idx = 0; idx < 16; idx += 8) {
1376       Transpose4x4(&x[idx], &x[idx]);
1377       Transpose4x4(&x[idx + 4], &x[idx + 4]);
1378       StoreDst<4>(dst, step, idx, &x[idx]);
1379       StoreDst<4>(dst, step, idx + 4, &x[idx + 4]);
1380     }
1381   } else {
1382     StoreDst<16>(dst, step, 0, &x[0]);
1383   }
1384 }
1385 
Adst16DcOnlyInternal(int32x4_t * s,int32x4_t * x)1386 LIBGAV1_ALWAYS_INLINE void Adst16DcOnlyInternal(int32x4_t* s, int32x4_t* x) {
1387   // stage 2.
1388   ButterflyRotation_FirstIsZero(&s[0], &s[1], 62, true);
1389 
1390   // stage 3.
1391   s[8] = s[0];
1392   s[9] = s[1];
1393 
1394   // stage 4.
1395   ButterflyRotation_4(&s[8], &s[9], 56, true);
1396 
1397   // stage 5.
1398   s[4] = s[0];
1399   s[12] = s[8];
1400   s[5] = s[1];
1401   s[13] = s[9];
1402 
1403   // stage 6.
1404   ButterflyRotation_4(&s[4], &s[5], 48, true);
1405   ButterflyRotation_4(&s[12], &s[13], 48, true);
1406 
1407   // stage 7.
1408   s[2] = s[0];
1409   s[6] = s[4];
1410   s[10] = s[8];
1411   s[14] = s[12];
1412   s[3] = s[1];
1413   s[7] = s[5];
1414   s[11] = s[9];
1415   s[15] = s[13];
1416 
1417   // stage 8.
1418   ButterflyRotation_4(&s[2], &s[3], 32, true);
1419   ButterflyRotation_4(&s[6], &s[7], 32, true);
1420   ButterflyRotation_4(&s[10], &s[11], 32, true);
1421   ButterflyRotation_4(&s[14], &s[15], 32, true);
1422 
1423   // stage 9.
1424   x[0] = s[0];
1425   x[1] = vqnegq_s32(s[8]);
1426   x[2] = s[12];
1427   x[3] = vqnegq_s32(s[4]);
1428   x[4] = s[6];
1429   x[5] = vqnegq_s32(s[14]);
1430   x[6] = s[10];
1431   x[7] = vqnegq_s32(s[2]);
1432   x[8] = s[3];
1433   x[9] = vqnegq_s32(s[11]);
1434   x[10] = s[15];
1435   x[11] = vqnegq_s32(s[7]);
1436   x[12] = s[5];
1437   x[13] = vqnegq_s32(s[13]);
1438   x[14] = s[9];
1439   x[15] = vqnegq_s32(s[1]);
1440 }
1441 
Adst16DcOnly(void * dest,int adjusted_tx_height,bool should_round,int row_shift)1442 LIBGAV1_ALWAYS_INLINE bool Adst16DcOnly(void* dest, int adjusted_tx_height,
1443                                         bool should_round, int row_shift) {
1444   if (adjusted_tx_height > 1) return false;
1445 
1446   auto* dst = static_cast<int32_t*>(dest);
1447   int32x4_t s[16];
1448   int32x4_t x[16];
1449   const int32x4_t v_src = vdupq_n_s32(dst[0]);
1450   const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
1451   const int32x4_t v_src_round =
1452       vqrdmulhq_n_s32(v_src, kTransformRowMultiplier << (31 - 12));
1453   // stage 1.
1454   s[1] = vbslq_s32(v_mask, v_src_round, v_src);
1455 
1456   Adst16DcOnlyInternal(s, x);
1457 
1458   for (int i = 0; i < 16; ++i) {
1459     // vqrshlq_s32 will shift right if shift value is negative.
1460     x[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[i], vdupq_n_s32(-row_shift))));
1461     vst1q_lane_s32(&dst[i], x[i], 0);
1462   }
1463 
1464   return true;
1465 }
1466 
Adst16DcOnlyColumn(void * dest,int adjusted_tx_height,int width)1467 LIBGAV1_ALWAYS_INLINE bool Adst16DcOnlyColumn(void* dest,
1468                                               int adjusted_tx_height,
1469                                               int width) {
1470   if (adjusted_tx_height > 1) return false;
1471 
1472   auto* dst = static_cast<int32_t*>(dest);
1473   int i = 0;
1474   do {
1475     int32x4_t s[16];
1476     int32x4_t x[16];
1477     const int32x4_t v_src = vld1q_s32(dst);
1478     // stage 1.
1479     s[1] = v_src;
1480 
1481     Adst16DcOnlyInternal(s, x);
1482 
1483     for (int j = 0; j < 16; ++j) {
1484       vst1q_s32(&dst[j * width], x[j]);
1485     }
1486     i += 4;
1487     dst += 4;
1488   } while (i < width);
1489 
1490   return true;
1491 }
1492 
1493 //------------------------------------------------------------------------------
1494 // Identity Transforms.
1495 
1496 LIBGAV1_ALWAYS_INLINE void Identity4_NEON(void* dest, int32_t step, int shift) {
1497   auto* const dst = static_cast<int32_t*>(dest);
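       // v_dual_round folds the rounding of the >> 12 multiplier stage and the
       // optional row shift into one constant: (1 << 11) when shift is 0, plus
       // an extra (1 << 12) when shift is 1.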
1498   const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
1499   const int32x4_t v_multiplier = vdupq_n_s32(kIdentity4Multiplier);
1500   const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
1501   for (int i = 0; i < 4; ++i) {
1502     const int32x4_t v_src = vld1q_s32(&dst[i * step]);
1503     const int32x4_t v_src_mult_lo =
1504         vmlaq_s32(v_dual_round, v_src, v_multiplier);
1505     const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift);
1506     vst1q_s32(&dst[i * step], vmovl_s16(vqmovn_s32(shift_lo)));
1507   }
1508 }
1509 
1510 LIBGAV1_ALWAYS_INLINE bool Identity4DcOnly(void* dest, int adjusted_tx_height,
1511                                            bool should_round, int tx_height) {
1512   if (adjusted_tx_height > 1) return false;
1513 
1514   auto* dst = static_cast<int32_t*>(dest);
1515   const int32x4_t v_src0 = vdupq_n_s32(dst[0]);
1516   const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
1517   const int32x4_t v_src_round =
1518       vqrdmulhq_n_s32(v_src0, kTransformRowMultiplier << (31 - 12));
1519   const int32x4_t v_src = vbslq_s32(v_mask, v_src_round, v_src0);
1520   const int shift = tx_height < 16 ? 0 : 1;
1521   const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
1522   const int32x4_t v_multiplier = vdupq_n_s32(kIdentity4Multiplier);
1523   const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
1524   const int32x4_t v_src_mult_lo = vmlaq_s32(v_dual_round, v_src, v_multiplier);
1525   const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, v_shift);
1526   vst1q_lane_s32(dst, vmovl_s16(vqmovn_s32(dst_0)), 0);
1527   return true;
1528 }
1529 
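     // Applies the identity column transform of |identity_size| to |source|,
     // adds the result to |frame|, and clamps each pixel to the 10-bit range.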
1530 template <int identity_size>
1531 LIBGAV1_ALWAYS_INLINE void IdentityColumnStoreToFrame(
1532     Array2DView<uint16_t> frame, const int start_x, const int start_y,
1533     const int tx_width, const int tx_height,
1534     const int32_t* LIBGAV1_RESTRICT source) {
1535   static_assert(identity_size == 4 || identity_size == 8 ||
1536                     identity_size == 16 || identity_size == 32,
1537                 "Invalid identity_size.");
1538   const int stride = frame.columns();
1539   uint16_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
1540   const int32x4_t v_dual_round = vdupq_n_s32((1 + (1 << 4)) << 11);
1541   const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
1542 
1543   if (identity_size < 32) {
1544     if (tx_width == 4) {
1545       int i = 0;
1546       do {
1547         int32x4x2_t v_src, v_dst_i, a, b;
1548         v_src.val[0] = vld1q_s32(&source[i * 4]);
1549         v_src.val[1] = vld1q_s32(&source[(i * 4) + 4]);
1550         if (identity_size == 4) {
1551           v_dst_i.val[0] =
1552               vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity4Multiplier);
1553           v_dst_i.val[1] =
1554               vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity4Multiplier);
1555           a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
1556           a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
1557         } else if (identity_size == 8) {
1558           v_dst_i.val[0] = vaddq_s32(v_src.val[0], v_src.val[0]);
1559           v_dst_i.val[1] = vaddq_s32(v_src.val[1], v_src.val[1]);
1560           a.val[0] = vrshrq_n_s32(v_dst_i.val[0], 4);
1561           a.val[1] = vrshrq_n_s32(v_dst_i.val[1], 4);
1562         } else {  // identity_size == 16
1563           v_dst_i.val[0] =
1564               vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity16Multiplier);
1565           v_dst_i.val[1] =
1566               vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity16Multiplier);
1567           a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
1568           a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
1569         }
1570         uint16x4x2_t frame_data;
1571         frame_data.val[0] = vld1_u16(dst);
1572         frame_data.val[1] = vld1_u16(dst + stride);
1573         b.val[0] = vaddw_s16(a.val[0], vreinterpret_s16_u16(frame_data.val[0]));
1574         b.val[1] = vaddw_s16(a.val[1], vreinterpret_s16_u16(frame_data.val[1]));
1575         vst1_u16(dst, vmin_u16(vqmovun_s32(b.val[0]), v_max_bitdepth));
1576         vst1_u16(dst + stride, vmin_u16(vqmovun_s32(b.val[1]), v_max_bitdepth));
1577         dst += stride << 1;
1578         i += 2;
1579       } while (i < tx_height);
1580     } else {
1581       int i = 0;
1582       do {
1583         const int row = i * tx_width;
1584         int j = 0;
1585         do {
1586           int32x4x2_t v_src, v_dst_i, a, b;
1587           v_src.val[0] = vld1q_s32(&source[row + j]);
1588           v_src.val[1] = vld1q_s32(&source[row + j + 4]);
1589           if (identity_size == 4) {
1590             v_dst_i.val[0] =
1591                 vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity4Multiplier);
1592             v_dst_i.val[1] =
1593                 vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity4Multiplier);
1594             a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
1595             a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
1596           } else if (identity_size == 8) {
1597             v_dst_i.val[0] = vaddq_s32(v_src.val[0], v_src.val[0]);
1598             v_dst_i.val[1] = vaddq_s32(v_src.val[1], v_src.val[1]);
1599             a.val[0] = vrshrq_n_s32(v_dst_i.val[0], 4);
1600             a.val[1] = vrshrq_n_s32(v_dst_i.val[1], 4);
1601           } else {  // identity_size == 16
1602             v_dst_i.val[0] =
1603                 vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity16Multiplier);
1604             v_dst_i.val[1] =
1605                 vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity16Multiplier);
1606             a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
1607             a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
1608           }
1609           uint16x4x2_t frame_data;
1610           frame_data.val[0] = vld1_u16(dst + j);
1611           frame_data.val[1] = vld1_u16(dst + j + 4);
1612           b.val[0] =
1613               vaddw_s16(a.val[0], vreinterpret_s16_u16(frame_data.val[0]));
1614           b.val[1] =
1615               vaddw_s16(a.val[1], vreinterpret_s16_u16(frame_data.val[1]));
1616           vst1_u16(dst + j, vmin_u16(vqmovun_s32(b.val[0]), v_max_bitdepth));
1617           vst1_u16(dst + j + 4,
1618                    vmin_u16(vqmovun_s32(b.val[1]), v_max_bitdepth));
1619           j += 8;
1620         } while (j < tx_width);
1621         dst += stride;
1622       } while (++i < tx_height);
1623     }
1624   } else {
1625     int i = 0;
1626     do {
1627       const int row = i * tx_width;
1628       int j = 0;
1629       do {
1630         const int32x4_t v_dst_i = vld1q_s32(&source[row + j]);
1631         const uint16x4_t frame_data = vld1_u16(dst + j);
1632         const int32x4_t a = vrshrq_n_s32(v_dst_i, 2);
1633         const int32x4_t b = vaddw_s16(a, vreinterpret_s16_u16(frame_data));
1634         const uint16x4_t d = vmin_u16(vqmovun_s32(b), v_max_bitdepth);
1635         vst1_u16(dst + j, d);
1636         j += 4;
1637       } while (j < tx_width);
1638       dst += stride;
1639     } while (++i < tx_height);
1640   }
1641 }
1642 
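     // Fast path for kTransformTypeIdentityIdentity 4x4 and 8x4 blocks: the
     // row and column identity passes are applied together and the result is
     // written directly to |frame|.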
1643 LIBGAV1_ALWAYS_INLINE void Identity4RowColumnStoreToFrame(
1644     Array2DView<uint16_t> frame, const int start_x, const int start_y,
1645     const int tx_width, const int tx_height,
1646     const int32_t* LIBGAV1_RESTRICT source) {
1647   const int stride = frame.columns();
1648   uint16_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
1649   const int32x4_t v_round = vdupq_n_s32((1 + (0)) << 11);
1650   const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
1651 
1652   if (tx_width == 4) {
1653     int i = 0;
1654     do {
1655       const int32x4_t v_src = vld1q_s32(&source[i * 4]);
1656       const int32x4_t v_dst_row =
1657           vshrq_n_s32(vmlaq_n_s32(v_round, v_src, kIdentity4Multiplier), 12);
1658       const int32x4_t v_dst_col =
1659           vmlaq_n_s32(v_round, v_dst_row, kIdentity4Multiplier);
1660       const uint16x4_t frame_data = vld1_u16(dst);
1661       const int32x4_t a = vrshrq_n_s32(v_dst_col, 4 + 12);
1662       const int32x4_t b = vaddw_s16(a, vreinterpret_s16_u16(frame_data));
1663       vst1_u16(dst, vmin_u16(vqmovun_s32(b), v_max_bitdepth));
1664       dst += stride;
1665     } while (++i < tx_height);
1666   } else {
1667     int i = 0;
1668     do {
1669       const int row = i * tx_width;
1670       int j = 0;
1671       do {
1672         int32x4x2_t v_src, v_src_round, v_dst_row, v_dst_col, a, b;
1673         v_src.val[0] = vld1q_s32(&source[row + j]);
1674         v_src.val[1] = vld1q_s32(&source[row + j + 4]);
1675         v_src_round.val[0] = vshrq_n_s32(
1676             vmlaq_n_s32(v_round, v_src.val[0], kTransformRowMultiplier), 12);
1677         v_src_round.val[1] = vshrq_n_s32(
1678             vmlaq_n_s32(v_round, v_src.val[1], kTransformRowMultiplier), 12);
1679         v_dst_row.val[0] = vqaddq_s32(v_src_round.val[0], v_src_round.val[0]);
1680         v_dst_row.val[1] = vqaddq_s32(v_src_round.val[1], v_src_round.val[1]);
1681         v_dst_col.val[0] =
1682             vmlaq_n_s32(v_round, v_dst_row.val[0], kIdentity4Multiplier);
1683         v_dst_col.val[1] =
1684             vmlaq_n_s32(v_round, v_dst_row.val[1], kIdentity4Multiplier);
1685         uint16x4x2_t frame_data;
1686         frame_data.val[0] = vld1_u16(dst + j);
1687         frame_data.val[1] = vld1_u16(dst + j + 4);
1688         a.val[0] = vrshrq_n_s32(v_dst_col.val[0], 4 + 12);
1689         a.val[1] = vrshrq_n_s32(v_dst_col.val[1], 4 + 12);
1690         b.val[0] = vaddw_s16(a.val[0], vreinterpret_s16_u16(frame_data.val[0]));
1691         b.val[1] = vaddw_s16(a.val[1], vreinterpret_s16_u16(frame_data.val[1]));
1692         vst1_u16(dst + j, vmin_u16(vqmovun_s32(b.val[0]), v_max_bitdepth));
1693         vst1_u16(dst + j + 4, vmin_u16(vqmovun_s32(b.val[1]), v_max_bitdepth));
1694         j += 8;
1695       } while (j < tx_width);
1696       dst += stride;
1697     } while (++i < tx_height);
1698   }
1699 }
1700 
1701 LIBGAV1_ALWAYS_INLINE void Identity8Row32_NEON(void* dest, int32_t step) {
1702   auto* const dst = static_cast<int32_t*>(dest);
1703 
1704   // When combining the identity8 multiplier with the row shift, the
1705   // calculations for tx_height equal to 32 can be simplified from
1706   // (((A * 2) + 2) >> 2) to ((A + 1) >> 1).
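       // For example, A = 11 gives ((11 * 2) + 2) >> 2 = 6, matching
       // (11 + 1) >> 1 = 6.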
1707   for (int i = 0; i < 4; ++i) {
1708     const int32x4_t v_src_lo = vld1q_s32(&dst[i * step]);
1709     const int32x4_t v_src_hi = vld1q_s32(&dst[(i * step) + 4]);
1710     const int32x4_t a_lo = vrshrq_n_s32(v_src_lo, 1);
1711     const int32x4_t a_hi = vrshrq_n_s32(v_src_hi, 1);
1712     vst1q_s32(&dst[i * step], vmovl_s16(vqmovn_s32(a_lo)));
1713     vst1q_s32(&dst[(i * step) + 4], vmovl_s16(vqmovn_s32(a_hi)));
1714   }
1715 }
1716 
1717 LIBGAV1_ALWAYS_INLINE void Identity8Row4_NEON(void* dest, int32_t step) {
1718   auto* const dst = static_cast<int32_t*>(dest);
1719 
1720   for (int i = 0; i < 4; ++i) {
1721     const int32x4_t v_src_lo = vld1q_s32(&dst[i * step]);
1722     const int32x4_t v_src_hi = vld1q_s32(&dst[(i * step) + 4]);
1723     const int32x4_t v_srcx2_lo = vqaddq_s32(v_src_lo, v_src_lo);
1724     const int32x4_t v_srcx2_hi = vqaddq_s32(v_src_hi, v_src_hi);
1725     vst1q_s32(&dst[i * step], vmovl_s16(vqmovn_s32(v_srcx2_lo)));
1726     vst1q_s32(&dst[(i * step) + 4], vmovl_s16(vqmovn_s32(v_srcx2_hi)));
1727   }
1728 }
1729 
1730 LIBGAV1_ALWAYS_INLINE bool Identity8DcOnly(void* dest, int adjusted_tx_height,
1731                                            bool should_round, int row_shift) {
1732   if (adjusted_tx_height > 1) return false;
1733 
1734   auto* dst = static_cast<int32_t*>(dest);
1735   const int32x4_t v_src0 = vdupq_n_s32(dst[0]);
1736   const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
1737   const int32x4_t v_src_round =
1738       vqrdmulhq_n_s32(v_src0, kTransformRowMultiplier << (31 - 12));
1739   const int32x4_t v_src = vbslq_s32(v_mask, v_src_round, v_src0);
1740   const int32x4_t v_srcx2 = vaddq_s32(v_src, v_src);
1741   const int32x4_t dst_0 = vqrshlq_s32(v_srcx2, vdupq_n_s32(-row_shift));
1742   vst1q_lane_s32(dst, vmovl_s16(vqmovn_s32(dst_0)), 0);
1743   return true;
1744 }
1745 
1746 LIBGAV1_ALWAYS_INLINE void Identity16Row_NEON(void* dest, int32_t step,
1747                                               int shift) {
1748   auto* const dst = static_cast<int32_t*>(dest);
1749   const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
1750   const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
1751 
1752   for (int i = 0; i < 4; ++i) {
1753     for (int j = 0; j < 2; ++j) {
1754       int32x4x2_t v_src;
1755       v_src.val[0] = vld1q_s32(&dst[i * step + j * 8]);
1756       v_src.val[1] = vld1q_s32(&dst[i * step + j * 8 + 4]);
1757       const int32x4_t v_src_mult_lo =
1758           vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity16Multiplier);
1759       const int32x4_t v_src_mult_hi =
1760           vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity16Multiplier);
1761       const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift);
1762       const int32x4_t shift_hi = vqshlq_s32(v_src_mult_hi, v_shift);
1763       vst1q_s32(&dst[i * step + j * 8], vmovl_s16(vqmovn_s32(shift_lo)));
1764       vst1q_s32(&dst[i * step + j * 8 + 4], vmovl_s16(vqmovn_s32(shift_hi)));
1765     }
1766   }
1767 }
1768 
1769 LIBGAV1_ALWAYS_INLINE bool Identity16DcOnly(void* dest, int adjusted_tx_height,
1770                                             bool should_round, int shift) {
1771   if (adjusted_tx_height > 1) return false;
1772 
1773   auto* dst = static_cast<int32_t*>(dest);
1774   const int32x4_t v_src0 = vdupq_n_s32(dst[0]);
1775   const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
1776   const int32x4_t v_src_round =
1777       vqrdmulhq_n_s32(v_src0, kTransformRowMultiplier << (31 - 12));
1778   const int32x4_t v_src = vbslq_s32(v_mask, v_src_round, v_src0);
1779   const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
1780   const int32x4_t v_src_mult_lo =
1781       vmlaq_n_s32(v_dual_round, v_src, kIdentity16Multiplier);
1782   const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, vdupq_n_s32(-(12 + shift)));
1783   vst1q_lane_s32(dst, vmovl_s16(vqmovn_s32(dst_0)), 0);
1784   return true;
1785 }
1786 
1787 LIBGAV1_ALWAYS_INLINE void Identity32Row16_NEON(void* dest,
1788                                                 const int32_t step) {
1789   auto* const dst = static_cast<int32_t*>(dest);
1790 
1791   // When combining the identity32 multiplier with the row shift, the
1792   // calculation for tx_height equal to 16 can be simplified from
1793   // (((A * 4) + 1) >> 1) to (A * 2).
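       // For example, A = 11 gives ((11 * 4) + 1) >> 1 = 22 = 11 * 2.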
1794   for (int i = 0; i < 4; ++i) {
1795     for (int j = 0; j < 32; j += 4) {
1796       const int32x4_t v_src = vld1q_s32(&dst[i * step + j]);
1797       const int32x4_t v_dst_i = vqaddq_s32(v_src, v_src);
1798       vst1q_s32(&dst[i * step + j], v_dst_i);
1799     }
1800   }
1801 }
1802 
1803 LIBGAV1_ALWAYS_INLINE bool Identity32DcOnly(void* dest,
1804                                             int adjusted_tx_height) {
1805   if (adjusted_tx_height > 1) return false;
1806 
1807   auto* dst = static_cast<int32_t*>(dest);
1808   const int32x2_t v_src0 = vdup_n_s32(dst[0]);
1809   const int32x2_t v_src =
1810       vqrdmulh_n_s32(v_src0, kTransformRowMultiplier << (31 - 12));
1811   // When combining the identity32 multiplier with the row shift, the
1812   // calculation for tx_height equal to 16 can be simplified from
1813   // (((A * 4) + 1) >> 1) to (A * 2).
1814   const int32x2_t v_dst_0 = vqadd_s32(v_src, v_src);
1815   vst1_lane_s32(dst, v_dst_0, 0);
1816   return true;
1817 }
1818 
1819 //------------------------------------------------------------------------------
1820 // Walsh Hadamard Transform.
1821 
1822 // Process 4 wht4 rows and columns.
1823 LIBGAV1_ALWAYS_INLINE void Wht4_NEON(uint16_t* LIBGAV1_RESTRICT dst,
1824                                      const int dst_stride,
1825                                      const void* LIBGAV1_RESTRICT source,
1826                                      const int adjusted_tx_height) {
1827   const auto* const src = static_cast<const int32_t*>(source);
1828   int32x4_t s[4];
1829 
1830   if (adjusted_tx_height == 1) {
1831     // Special case: only src[0] is nonzero.
1832     //   src[0]  0   0   0
1833     //       0   0   0   0
1834     //       0   0   0   0
1835     //       0   0   0   0
1836     //
1837     // After the row and column transforms are applied, we have:
1838     //       f   h   h   h
1839     //       g   i   i   i
1840     //       g   i   i   i
1841     //       g   i   i   i
1842     // where f, g, h, i are computed as follows.
1843     int32_t f = (src[0] >> 2) - (src[0] >> 3);
1844     const int32_t g = f >> 1;
1845     f = f - (f >> 1);
1846     const int32_t h = (src[0] >> 3) - (src[0] >> 4);
1847     const int32_t i = (src[0] >> 4);
1848     s[0] = vdupq_n_s32(h);
1849     s[0] = vsetq_lane_s32(f, s[0], 0);
1850     s[1] = vdupq_n_s32(i);
1851     s[1] = vsetq_lane_s32(g, s[1], 0);
1852     s[2] = s[3] = s[1];
1853   } else {
1854     // Load the 4x4 source in transposed form.
1855     int32x4x4_t columns = vld4q_s32(src);
1856 
1857     // Shift right and permute the columns for the WHT.
1858     s[0] = vshrq_n_s32(columns.val[0], 2);
1859     s[2] = vshrq_n_s32(columns.val[1], 2);
1860     s[3] = vshrq_n_s32(columns.val[2], 2);
1861     s[1] = vshrq_n_s32(columns.val[3], 2);
1862 
1863     // Row transforms.
1864     s[0] = vaddq_s32(s[0], s[2]);
1865     s[3] = vsubq_s32(s[3], s[1]);
1866     int32x4_t e = vhsubq_s32(s[0], s[3]);  // e = (s[0] - s[3]) >> 1
1867     s[1] = vsubq_s32(e, s[1]);
1868     s[2] = vsubq_s32(e, s[2]);
1869     s[0] = vsubq_s32(s[0], s[1]);
1870     s[3] = vaddq_s32(s[3], s[2]);
1871 
1872     int32x4_t x[4];
1873     Transpose4x4(s, x);
1874 
1875     s[0] = x[0];
1876     s[2] = x[1];
1877     s[3] = x[2];
1878     s[1] = x[3];
1879 
1880     // Column transforms.
1881     s[0] = vaddq_s32(s[0], s[2]);
1882     s[3] = vsubq_s32(s[3], s[1]);
1883     e = vhsubq_s32(s[0], s[3]);  // e = (s[0] - s[3]) >> 1
1884     s[1] = vsubq_s32(e, s[1]);
1885     s[2] = vsubq_s32(e, s[2]);
1886     s[0] = vsubq_s32(s[0], s[1]);
1887     s[3] = vaddq_s32(s[3], s[2]);
1888   }
1889 
1890   // Store to frame.
1891   const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
1892   for (int row = 0; row < 4; row += 1) {
1893     const uint16x4_t frame_data = vld1_u16(dst);
1894     const int32x4_t b = vaddw_s16(s[row], vreinterpret_s16_u16(frame_data));
1895     vst1_u16(dst, vmin_u16(vqmovun_s32(b), v_max_bitdepth));
1896     dst += dst_stride;
1897   }
1898 }
1899 
1900 //------------------------------------------------------------------------------
1901 // row/column transform loops
1902 
1903 template <int tx_height>
1904 LIBGAV1_ALWAYS_INLINE void FlipColumns(int32_t* source, int tx_width) {
1905   if (tx_width >= 16) {
1906     int i = 0;
1907     do {
1908       // 00 01 02 03
1909       const int32x4_t a = vld1q_s32(&source[i]);
1910       const int32x4_t b = vld1q_s32(&source[i + 4]);
1911       const int32x4_t c = vld1q_s32(&source[i + 8]);
1912       const int32x4_t d = vld1q_s32(&source[i + 12]);
1913       // 01 00 03 02
1914       const int32x4_t a_rev = vrev64q_s32(a);
1915       const int32x4_t b_rev = vrev64q_s32(b);
1916       const int32x4_t c_rev = vrev64q_s32(c);
1917       const int32x4_t d_rev = vrev64q_s32(d);
1918       // 03 02 01 00
1919       vst1q_s32(&source[i], vextq_s32(d_rev, d_rev, 2));
1920       vst1q_s32(&source[i + 4], vextq_s32(c_rev, c_rev, 2));
1921       vst1q_s32(&source[i + 8], vextq_s32(b_rev, b_rev, 2));
1922       vst1q_s32(&source[i + 12], vextq_s32(a_rev, a_rev, 2));
1923       i += 16;
1924     } while (i < tx_width * tx_height);
1925   } else if (tx_width == 8) {
1926     for (int i = 0; i < 8 * tx_height; i += 8) {
1927       // 00 01 02 03
1928       const int32x4_t a = vld1q_s32(&source[i]);
1929       const int32x4_t b = vld1q_s32(&source[i + 4]);
1930       // 01 00 03 02
1931       const int32x4_t a_rev = vrev64q_s32(a);
1932       const int32x4_t b_rev = vrev64q_s32(b);
1933       // 03 02 01 00
1934       vst1q_s32(&source[i], vextq_s32(b_rev, b_rev, 2));
1935       vst1q_s32(&source[i + 4], vextq_s32(a_rev, a_rev, 2));
1936     }
1937   } else {
1938     // Process two rows per iteration.
1939     for (int i = 0; i < 4 * tx_height; i += 8) {
1940       // 00 01 02 03
1941       const int32x4_t a = vld1q_s32(&source[i]);
1942       const int32x4_t b = vld1q_s32(&source[i + 4]);
1943       // 01 00 03 02
1944       const int32x4_t a_rev = vrev64q_s32(a);
1945       const int32x4_t b_rev = vrev64q_s32(b);
1946       // 03 02 01 00
1947       vst1q_s32(&source[i], vextq_s32(a_rev, a_rev, 2));
1948       vst1q_s32(&source[i + 4], vextq_s32(b_rev, b_rev, 2));
1949     }
1950   }
1951 }
1952 
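     // Reverses the values within each row of |source| (a horizontal flip),
     // used for transform types whose columns are flipped.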
1953 template <int tx_width>
1954 LIBGAV1_ALWAYS_INLINE void ApplyRounding(int32_t* source, int num_rows) {
1955   // Process two rows per iteration.
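       // kTransformRowMultiplier is a Q12 constant; shifting it by (31 - 12)
       // promotes it to Q31 so vqrdmulhq_n_s32 produces the rounded product
       // RightShiftWithRounding(src * kTransformRowMultiplier, 12).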
1956   int i = 0;
1957   do {
1958     const int32x4_t a_lo = vld1q_s32(&source[i]);
1959     const int32x4_t a_hi = vld1q_s32(&source[i + 4]);
1960     const int32x4_t b_lo =
1961         vqrdmulhq_n_s32(a_lo, kTransformRowMultiplier << (31 - 12));
1962     const int32x4_t b_hi =
1963         vqrdmulhq_n_s32(a_hi, kTransformRowMultiplier << (31 - 12));
1964     vst1q_s32(&source[i], b_lo);
1965     vst1q_s32(&source[i + 4], b_hi);
1966     i += 8;
1967   } while (i < tx_width * num_rows);
1968 }
1969 
1970 template <int tx_width>
1971 LIBGAV1_ALWAYS_INLINE void RowShift(int32_t* source, int num_rows,
1972                                     int row_shift) {
1973   // vqrshlq_s32 will shift right if shift value is negative.
1974   row_shift = -row_shift;
1975 
1976   // Process two rows per iteration.
1977   int i = 0;
1978   do {
1979     const int32x4_t residual0 = vld1q_s32(&source[i]);
1980     const int32x4_t residual1 = vld1q_s32(&source[i + 4]);
1981     vst1q_s32(&source[i], vqrshlq_s32(residual0, vdupq_n_s32(row_shift)));
1982     vst1q_s32(&source[i + 4], vqrshlq_s32(residual1, vdupq_n_s32(row_shift)));
1983     i += 8;
1984   } while (i < tx_width * num_rows);
1985 }
1986 
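     // Rounds the residual by 4 bits, adds it to the reconstructed frame, and
     // clamps the result to the 10-bit range. With |enable_flip_rows| set, rows
     // are written in reverse order for transform types that flip rows.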
1987 template <int tx_height, bool enable_flip_rows = false>
1988 LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound(
1989     Array2DView<uint16_t> frame, const int start_x, const int start_y,
1990     const int tx_width, const int32_t* LIBGAV1_RESTRICT source,
1991     TransformType tx_type) {
1992   const bool flip_rows =
1993       enable_flip_rows ? kTransformFlipRowsMask.Contains(tx_type) : false;
1994   const int stride = frame.columns();
1995   uint16_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
1996 
1997   if (tx_width == 4) {
1998     for (int i = 0; i < tx_height; ++i) {
1999       const int row = flip_rows ? (tx_height - i - 1) * 4 : i * 4;
2000       const int32x4_t residual = vld1q_s32(&source[row]);
2001       const uint16x4_t frame_data = vld1_u16(dst);
2002       const int32x4_t a = vrshrq_n_s32(residual, 4);
2003       const uint32x4_t b = vaddw_u16(vreinterpretq_u32_s32(a), frame_data);
2004       const uint16x4_t d = vqmovun_s32(vreinterpretq_s32_u32(b));
2005       vst1_u16(dst, vmin_u16(d, vdup_n_u16((1 << kBitdepth10) - 1)));
2006       dst += stride;
2007     }
2008   } else {
2009     for (int i = 0; i < tx_height; ++i) {
2010       const int y = start_y + i;
2011       const int row = flip_rows ? (tx_height - i - 1) * tx_width : i * tx_width;
2012       int j = 0;
2013       do {
2014         const int x = start_x + j;
2015         const int32x4_t residual = vld1q_s32(&source[row + j]);
2016         const int32x4_t residual_hi = vld1q_s32(&source[row + j + 4]);
2017         const uint16x8_t frame_data = vld1q_u16(frame[y] + x);
2018         const int32x4_t a = vrshrq_n_s32(residual, 4);
2019         const int32x4_t a_hi = vrshrq_n_s32(residual_hi, 4);
2020         const uint32x4_t b =
2021             vaddw_u16(vreinterpretq_u32_s32(a), vget_low_u16(frame_data));
2022         const uint32x4_t b_hi =
2023             vaddw_u16(vreinterpretq_u32_s32(a_hi), vget_high_u16(frame_data));
2024         const uint16x4_t d = vqmovun_s32(vreinterpretq_s32_u32(b));
2025         const uint16x4_t d_hi = vqmovun_s32(vreinterpretq_s32_u32(b_hi));
2026         vst1q_u16(frame[y] + x, vminq_u16(vcombine_u16(d, d_hi),
2027                                           vdupq_n_u16((1 << kBitdepth10) - 1)));
2028         j += 8;
2029       } while (j < tx_width);
2030     }
2031   }
2032 }
2033 
2034 void Dct4TransformLoopRow_NEON(TransformType /*tx_type*/, TransformSize tx_size,
2035                                int adjusted_tx_height, void* src_buffer,
2036                                int /*start_x*/, int /*start_y*/,
2037                                void* /*dst_frame*/) {
2038   auto* src = static_cast<int32_t*>(src_buffer);
2039   const int tx_height = kTransformHeight[tx_size];
2040   const bool should_round = (tx_height == 8);
2041   const int row_shift = static_cast<int>(tx_height == 16);
2042 
2043   if (DctDcOnly<4>(src, adjusted_tx_height, should_round, row_shift)) {
2044     return;
2045   }
2046 
2047   if (should_round) {
2048     ApplyRounding<4>(src, adjusted_tx_height);
2049   }
2050 
2051   // Process 4 1d dct4 rows in parallel per iteration.
2052   int i = adjusted_tx_height;
2053   auto* data = src;
2054   do {
2055     Dct4_NEON<ButterflyRotation_4>(data, /*step=*/4, /*is_row=*/true,
2056                                    row_shift);
2057     data += 16;
2058     i -= 4;
2059   } while (i != 0);
2060 }
2061 
2062 void Dct4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2063                                   int adjusted_tx_height,
2064                                   void* LIBGAV1_RESTRICT src_buffer,
2065                                   int start_x, int start_y,
2066                                   void* LIBGAV1_RESTRICT dst_frame) {
2067   auto* src = static_cast<int32_t*>(src_buffer);
2068   const int tx_width = kTransformWidth[tx_size];
2069 
2070   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2071     FlipColumns<4>(src, tx_width);
2072   }
2073 
2074   if (!DctDcOnlyColumn<4>(src, adjusted_tx_height, tx_width)) {
2075     // Process 4 1d dct4 columns in parallel per iteration.
2076     int i = tx_width;
2077     auto* data = src;
2078     do {
2079       Dct4_NEON<ButterflyRotation_4>(data, tx_width, /*is_row=*/false,
2080                                      /*row_shift=*/0);
2081       data += 4;
2082       i -= 4;
2083     } while (i != 0);
2084   }
2085 
2086   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2087   StoreToFrameWithRound<4>(frame, start_x, start_y, tx_width, src, tx_type);
2088 }
2089 
2090 void Dct8TransformLoopRow_NEON(TransformType /*tx_type*/, TransformSize tx_size,
2091                                int adjusted_tx_height, void* src_buffer,
2092                                int /*start_x*/, int /*start_y*/,
2093                                void* /*dst_frame*/) {
2094   auto* src = static_cast<int32_t*>(src_buffer);
2095   const bool should_round = kShouldRound[tx_size];
2096   const uint8_t row_shift = kTransformRowShift[tx_size];
2097 
2098   if (DctDcOnly<8>(src, adjusted_tx_height, should_round, row_shift)) {
2099     return;
2100   }
2101 
2102   if (should_round) {
2103     ApplyRounding<8>(src, adjusted_tx_height);
2104   }
2105 
2106   // Process 4 1d dct8 rows in parallel per iteration.
2107   int i = adjusted_tx_height;
2108   auto* data = src;
2109   do {
2110     Dct8_NEON<ButterflyRotation_4>(data, /*step=*/8, /*is_row=*/true,
2111                                    row_shift);
2112     data += 32;
2113     i -= 4;
2114   } while (i != 0);
2115 }
2116 
2117 void Dct8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2118                                   int adjusted_tx_height,
2119                                   void* LIBGAV1_RESTRICT src_buffer,
2120                                   int start_x, int start_y,
2121                                   void* LIBGAV1_RESTRICT dst_frame) {
2122   auto* src = static_cast<int32_t*>(src_buffer);
2123   const int tx_width = kTransformWidth[tx_size];
2124 
2125   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2126     FlipColumns<8>(src, tx_width);
2127   }
2128 
2129   if (!DctDcOnlyColumn<8>(src, adjusted_tx_height, tx_width)) {
2130     // Process 4 1d dct8 columns in parallel per iteration.
2131     int i = tx_width;
2132     auto* data = src;
2133     do {
2134       Dct8_NEON<ButterflyRotation_4>(data, tx_width, /*is_row=*/false,
2135                                      /*row_shift=*/0);
2136       data += 4;
2137       i -= 4;
2138     } while (i != 0);
2139   }
2140   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2141   StoreToFrameWithRound<8>(frame, start_x, start_y, tx_width, src, tx_type);
2142 }
2143 
2144 void Dct16TransformLoopRow_NEON(TransformType /*tx_type*/,
2145                                 TransformSize tx_size, int adjusted_tx_height,
2146                                 void* src_buffer, int /*start_x*/,
2147                                 int /*start_y*/, void* /*dst_frame*/) {
2148   auto* src = static_cast<int32_t*>(src_buffer);
2149   const bool should_round = kShouldRound[tx_size];
2150   const uint8_t row_shift = kTransformRowShift[tx_size];
2151 
2152   if (DctDcOnly<16>(src, adjusted_tx_height, should_round, row_shift)) {
2153     return;
2154   }
2155 
2156   if (should_round) {
2157     ApplyRounding<16>(src, adjusted_tx_height);
2158   }
2159 
2160   assert(adjusted_tx_height % 4 == 0);
2161   int i = adjusted_tx_height;
2162   auto* data = src;
2163   do {
2164     // Process 4 1d dct16 rows in parallel per iteration.
2165     Dct16_NEON<ButterflyRotation_4>(data, 16, /*is_row=*/true, row_shift);
2166     data += 64;
2167     i -= 4;
2168   } while (i != 0);
2169 }
2170 
2171 void Dct16TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2172                                    int adjusted_tx_height,
2173                                    void* LIBGAV1_RESTRICT src_buffer,
2174                                    int start_x, int start_y,
2175                                    void* LIBGAV1_RESTRICT dst_frame) {
2176   auto* src = static_cast<int32_t*>(src_buffer);
2177   const int tx_width = kTransformWidth[tx_size];
2178 
2179   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2180     FlipColumns<16>(src, tx_width);
2181   }
2182 
2183   if (!DctDcOnlyColumn<16>(src, adjusted_tx_height, tx_width)) {
2184     // Process 4 1d dct16 columns in parallel per iteration.
2185     int i = tx_width;
2186     auto* data = src;
2187     do {
2188       Dct16_NEON<ButterflyRotation_4>(data, tx_width, /*is_row=*/false,
2189                                       /*row_shift=*/0);
2190       data += 4;
2191       i -= 4;
2192     } while (i != 0);
2193   }
2194   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2195   StoreToFrameWithRound<16>(frame, start_x, start_y, tx_width, src, tx_type);
2196 }
2197 
2198 void Dct32TransformLoopRow_NEON(TransformType /*tx_type*/,
2199                                 TransformSize tx_size, int adjusted_tx_height,
2200                                 void* src_buffer, int /*start_x*/,
2201                                 int /*start_y*/, void* /*dst_frame*/) {
2202   auto* src = static_cast<int32_t*>(src_buffer);
2203   const bool should_round = kShouldRound[tx_size];
2204   const uint8_t row_shift = kTransformRowShift[tx_size];
2205 
2206   if (DctDcOnly<32>(src, adjusted_tx_height, should_round, row_shift)) {
2207     return;
2208   }
2209 
2210   if (should_round) {
2211     ApplyRounding<32>(src, adjusted_tx_height);
2212   }
2213 
2214   assert(adjusted_tx_height % 4 == 0);
2215   int i = adjusted_tx_height;
2216   auto* data = src;
2217   do {
2218     // Process 4 1d dct32 rows in parallel per iteration.
2219     Dct32_NEON(data, 32, /*is_row=*/true, row_shift);
2220     data += 128;
2221     i -= 4;
2222   } while (i != 0);
2223 }
2224 
2225 void Dct32TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2226                                    int adjusted_tx_height,
2227                                    void* LIBGAV1_RESTRICT src_buffer,
2228                                    int start_x, int start_y,
2229                                    void* LIBGAV1_RESTRICT dst_frame) {
2230   auto* src = static_cast<int32_t*>(src_buffer);
2231   const int tx_width = kTransformWidth[tx_size];
2232 
2233   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2234     FlipColumns<32>(src, tx_width);
2235   }
2236 
2237   if (!DctDcOnlyColumn<32>(src, adjusted_tx_height, tx_width)) {
2238     // Process 4 1d dct32 columns in parallel per iteration.
2239     int i = tx_width;
2240     auto* data = src;
2241     do {
2242       Dct32_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0);
2243       data += 4;
2244       i -= 4;
2245     } while (i != 0);
2246   }
2247   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2248   StoreToFrameWithRound<32>(frame, start_x, start_y, tx_width, src, tx_type);
2249 }
2250 
2251 void Dct64TransformLoopRow_NEON(TransformType /*tx_type*/,
2252                                 TransformSize tx_size, int adjusted_tx_height,
2253                                 void* src_buffer, int /*start_x*/,
2254                                 int /*start_y*/, void* /*dst_frame*/) {
2255   auto* src = static_cast<int32_t*>(src_buffer);
2256   const bool should_round = kShouldRound[tx_size];
2257   const uint8_t row_shift = kTransformRowShift[tx_size];
2258 
2259   if (DctDcOnly<64>(src, adjusted_tx_height, should_round, row_shift)) {
2260     return;
2261   }
2262 
2263   if (should_round) {
2264     ApplyRounding<64>(src, adjusted_tx_height);
2265   }
2266 
2267   assert(adjusted_tx_height % 4 == 0);
2268   int i = adjusted_tx_height;
2269   auto* data = src;
2270   do {
2271     // Process 4 1d dct64 rows in parallel per iteration.
2272     Dct64_NEON(data, 64, /*is_row=*/true, row_shift);
2273     data += 128 * 2;
2274     i -= 4;
2275   } while (i != 0);
2276 }
2277 
2278 void Dct64TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2279                                    int adjusted_tx_height,
2280                                    void* LIBGAV1_RESTRICT src_buffer,
2281                                    int start_x, int start_y,
2282                                    void* LIBGAV1_RESTRICT dst_frame) {
2283   auto* src = static_cast<int32_t*>(src_buffer);
2284   const int tx_width = kTransformWidth[tx_size];
2285 
2286   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2287     FlipColumns<64>(src, tx_width);
2288   }
2289 
2290   if (!DctDcOnlyColumn<64>(src, adjusted_tx_height, tx_width)) {
2291     // Process 4 1d dct64 columns in parallel per iteration.
2292     int i = tx_width;
2293     auto* data = src;
2294     do {
2295       Dct64_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0);
2296       data += 4;
2297       i -= 4;
2298     } while (i != 0);
2299   }
2300   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2301   StoreToFrameWithRound<64>(frame, start_x, start_y, tx_width, src, tx_type);
2302 }
2303 
2304 void Adst4TransformLoopRow_NEON(TransformType /*tx_type*/,
2305                                 TransformSize tx_size, int adjusted_tx_height,
2306                                 void* src_buffer, int /*start_x*/,
2307                                 int /*start_y*/, void* /*dst_frame*/) {
2308   auto* src = static_cast<int32_t*>(src_buffer);
2309   const int tx_height = kTransformHeight[tx_size];
2310   const int row_shift = static_cast<int>(tx_height == 16);
2311   const bool should_round = (tx_height == 8);
2312 
2313   if (Adst4DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2314     return;
2315   }
2316 
2317   if (should_round) {
2318     ApplyRounding<4>(src, adjusted_tx_height);
2319   }
2320 
2321   // Process 4 1d adst4 rows in parallel per iteration.
2322   int i = adjusted_tx_height;
2323   auto* data = src;
2324   do {
2325     Adst4_NEON(data, /*step=*/4, /*is_row=*/true, row_shift);
2326     data += 16;
2327     i -= 4;
2328   } while (i != 0);
2329 }
2330 
2331 void Adst4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2332                                    int adjusted_tx_height,
2333                                    void* LIBGAV1_RESTRICT src_buffer,
2334                                    int start_x, int start_y,
2335                                    void* LIBGAV1_RESTRICT dst_frame) {
2336   auto* src = static_cast<int32_t*>(src_buffer);
2337   const int tx_width = kTransformWidth[tx_size];
2338 
2339   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2340     FlipColumns<4>(src, tx_width);
2341   }
2342 
2343   if (!Adst4DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
2344     // Process 4 1d adst4 columns in parallel per iteration.
2345     int i = tx_width;
2346     auto* data = src;
2347     do {
2348       Adst4_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0);
2349       data += 4;
2350       i -= 4;
2351     } while (i != 0);
2352   }
2353 
2354   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2355   StoreToFrameWithRound<4, /*enable_flip_rows=*/true>(frame, start_x, start_y,
2356                                                       tx_width, src, tx_type);
2357 }
2358 
2359 void Adst8TransformLoopRow_NEON(TransformType /*tx_type*/,
2360                                 TransformSize tx_size, int adjusted_tx_height,
2361                                 void* src_buffer, int /*start_x*/,
2362                                 int /*start_y*/, void* /*dst_frame*/) {
2363   auto* src = static_cast<int32_t*>(src_buffer);
2364   const bool should_round = kShouldRound[tx_size];
2365   const uint8_t row_shift = kTransformRowShift[tx_size];
2366 
2367   if (Adst8DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2368     return;
2369   }
2370 
2371   if (should_round) {
2372     ApplyRounding<8>(src, adjusted_tx_height);
2373   }
2374 
2375   // Process 4 1d adst8 rows in parallel per iteration.
2376   assert(adjusted_tx_height % 4 == 0);
2377   int i = adjusted_tx_height;
2378   auto* data = src;
2379   do {
2380     Adst8_NEON<ButterflyRotation_4>(data, /*step=*/8,
2381                                     /*transpose=*/true, row_shift);
2382     data += 32;
2383     i -= 4;
2384   } while (i != 0);
2385 }
2386 
2387 void Adst8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2388                                    int adjusted_tx_height,
2389                                    void* LIBGAV1_RESTRICT src_buffer,
2390                                    int start_x, int start_y,
2391                                    void* LIBGAV1_RESTRICT dst_frame) {
2392   auto* src = static_cast<int32_t*>(src_buffer);
2393   const int tx_width = kTransformWidth[tx_size];
2394 
2395   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2396     FlipColumns<8>(src, tx_width);
2397   }
2398 
2399   if (!Adst8DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
2400     // Process 4 1d adst8 columns in parallel per iteration.
2401     int i = tx_width;
2402     auto* data = src;
2403     do {
2404       Adst8_NEON<ButterflyRotation_4>(data, tx_width, /*transpose=*/false,
2405                                       /*row_shift=*/0);
2406       data += 4;
2407       i -= 4;
2408     } while (i != 0);
2409   }
2410   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2411   StoreToFrameWithRound<8, /*enable_flip_rows=*/true>(frame, start_x, start_y,
2412                                                       tx_width, src, tx_type);
2413 }
2414 
2415 void Adst16TransformLoopRow_NEON(TransformType /*tx_type*/,
2416                                  TransformSize tx_size, int adjusted_tx_height,
2417                                  void* src_buffer, int /*start_x*/,
2418                                  int /*start_y*/, void* /*dst_frame*/) {
2419   auto* src = static_cast<int32_t*>(src_buffer);
2420   const bool should_round = kShouldRound[tx_size];
2421   const uint8_t row_shift = kTransformRowShift[tx_size];
2422 
2423   if (Adst16DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2424     return;
2425   }
2426 
2427   if (should_round) {
2428     ApplyRounding<16>(src, adjusted_tx_height);
2429   }
2430 
2431   assert(adjusted_tx_height % 4 == 0);
2432   int i = adjusted_tx_height;
2433   do {
2434     // Process 4 1d adst16 rows in parallel per iteration.
2435     Adst16_NEON<ButterflyRotation_4>(src, 16, /*is_row=*/true, row_shift);
2436     src += 64;
2437     i -= 4;
2438   } while (i != 0);
2439 }
2440 
2441 void Adst16TransformLoopColumn_NEON(TransformType tx_type,
2442                                     TransformSize tx_size,
2443                                     int adjusted_tx_height,
2444                                     void* LIBGAV1_RESTRICT src_buffer,
2445                                     int start_x, int start_y,
2446                                     void* LIBGAV1_RESTRICT dst_frame) {
2447   auto* src = static_cast<int32_t*>(src_buffer);
2448   const int tx_width = kTransformWidth[tx_size];
2449 
2450   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2451     FlipColumns<16>(src, tx_width);
2452   }
2453 
2454   if (!Adst16DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
2455     int i = tx_width;
2456     auto* data = src;
2457     do {
2458       // Process 4 1d adst16 columns in parallel per iteration.
2459       Adst16_NEON<ButterflyRotation_4>(data, tx_width, /*is_row=*/false,
2460                                        /*row_shift=*/0);
2461       data += 4;
2462       i -= 4;
2463     } while (i != 0);
2464   }
2465   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2466   StoreToFrameWithRound<16, /*enable_flip_rows=*/true>(frame, start_x, start_y,
2467                                                        tx_width, src, tx_type);
2468 }
2469 
2470 void Identity4TransformLoopRow_NEON(TransformType tx_type,
2471                                     TransformSize tx_size,
2472                                     int adjusted_tx_height, void* src_buffer,
2473                                     int /*start_x*/, int /*start_y*/,
2474                                     void* /*dst_frame*/) {
2475   // Special case: Process row calculations during column transform call.
2476   // Improves performance.
2477   if (tx_type == kTransformTypeIdentityIdentity &&
2478       tx_size == kTransformSize4x4) {
2479     return;
2480   }
2481 
2482   auto* src = static_cast<int32_t*>(src_buffer);
2483   const int tx_height = kTransformHeight[tx_size];
2484   const bool should_round = (tx_height == 8);
2485 
2486   if (Identity4DcOnly(src, adjusted_tx_height, should_round, tx_height)) {
2487     return;
2488   }
2489 
2490   if (should_round) {
2491     ApplyRounding<4>(src, adjusted_tx_height);
2492   }
2493 
2494   const int shift = tx_height > 8 ? 1 : 0;
2495   int i = adjusted_tx_height;
2496   do {
2497     Identity4_NEON(src, /*step=*/4, shift);
2498     src += 16;
2499     i -= 4;
2500   } while (i != 0);
2501 }
2502 
2503 void Identity4TransformLoopColumn_NEON(TransformType tx_type,
2504                                        TransformSize tx_size,
2505                                        int adjusted_tx_height,
2506                                        void* LIBGAV1_RESTRICT src_buffer,
2507                                        int start_x, int start_y,
2508                                        void* LIBGAV1_RESTRICT dst_frame) {
2509   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2510   auto* src = static_cast<int32_t*>(src_buffer);
2511   const int tx_width = kTransformWidth[tx_size];
2512 
2513   // Special case: Process row calculations during column transform call.
2514   if (tx_type == kTransformTypeIdentityIdentity &&
2515       (tx_size == kTransformSize4x4 || tx_size == kTransformSize8x4)) {
2516     Identity4RowColumnStoreToFrame(frame, start_x, start_y, tx_width,
2517                                    adjusted_tx_height, src);
2518     return;
2519   }
2520 
2521   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2522     FlipColumns<4>(src, tx_width);
2523   }
2524 
2525   IdentityColumnStoreToFrame<4>(frame, start_x, start_y, tx_width,
2526                                 adjusted_tx_height, src);
2527 }
2528 
2529 void Identity8TransformLoopRow_NEON(TransformType tx_type,
2530                                     TransformSize tx_size,
2531                                     int adjusted_tx_height, void* src_buffer,
2532                                     int /*start_x*/, int /*start_y*/,
2533                                     void* /*dst_frame*/) {
2534   // Special case: Process row calculations during column transform call.
2535   // Improves performance.
2536   if (tx_type == kTransformTypeIdentityIdentity &&
2537       tx_size == kTransformSize8x4) {
2538     return;
2539   }
2540 
2541   auto* src = static_cast<int32_t*>(src_buffer);
2542   const int tx_height = kTransformHeight[tx_size];
2543   const bool should_round = kShouldRound[tx_size];
2544   const uint8_t row_shift = kTransformRowShift[tx_size];
2545 
2546   if (Identity8DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2547     return;
2548   }
2549   if (should_round) {
2550     ApplyRounding<8>(src, adjusted_tx_height);
2551   }
2552 
2553   // When combining the identity8 multiplier with the row shift, the
2554   // calculations for tx_height == 8 and tx_height == 16 can be simplified
2555   // from (((A * 2) + 1) >> 1) to A. For 10bpp, A must be clamped to a signed
2556   // 16-bit value.
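       // For example, A = 11 gives ((11 * 2) + 1) >> 1 = 11. The check below
       // ((tx_height & 0x18) != 0) is true exactly for tx_height 8 or 16.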
2557   if ((tx_height & 0x18) != 0) {
2558     for (int i = 0; i < tx_height; ++i) {
2559       const int32x4_t v_src_lo = vld1q_s32(&src[i * 8]);
2560       const int32x4_t v_src_hi = vld1q_s32(&src[(i * 8) + 4]);
2561       vst1q_s32(&src[i * 8], vmovl_s16(vqmovn_s32(v_src_lo)));
2562       vst1q_s32(&src[(i * 8) + 4], vmovl_s16(vqmovn_s32(v_src_hi)));
2563     }
2564     return;
2565   }
2566   if (tx_height == 32) {
2567     int i = adjusted_tx_height;
2568     do {
2569       Identity8Row32_NEON(src, /*step=*/8);
2570       src += 32;
2571       i -= 4;
2572     } while (i != 0);
2573     return;
2574   }
2575 
2576   assert(tx_size == kTransformSize8x4);
2577   int i = adjusted_tx_height;
2578   do {
2579     Identity8Row4_NEON(src, /*step=*/8);
2580     src += 32;
2581     i -= 4;
2582   } while (i != 0);
2583 }
2584 
2585 void Identity8TransformLoopColumn_NEON(TransformType tx_type,
2586                                        TransformSize tx_size,
2587                                        int adjusted_tx_height,
2588                                        void* LIBGAV1_RESTRICT src_buffer,
2589                                        int start_x, int start_y,
2590                                        void* LIBGAV1_RESTRICT dst_frame) {
2591   auto* src = static_cast<int32_t*>(src_buffer);
2592   const int tx_width = kTransformWidth[tx_size];
2593 
2594   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2595     FlipColumns<8>(src, tx_width);
2596   }
2597   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2598   IdentityColumnStoreToFrame<8>(frame, start_x, start_y, tx_width,
2599                                 adjusted_tx_height, src);
2600 }
2601 
2602 void Identity16TransformLoopRow_NEON(TransformType /*tx_type*/,
2603                                      TransformSize tx_size,
2604                                      int adjusted_tx_height, void* src_buffer,
2605                                      int /*start_x*/, int /*start_y*/,
2606                                      void* /*dst_frame*/) {
2607   auto* src = static_cast<int32_t*>(src_buffer);
2608   const bool should_round = kShouldRound[tx_size];
2609   const uint8_t row_shift = kTransformRowShift[tx_size];
2610 
2611   if (Identity16DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2612     return;
2613   }
2614 
2615   if (should_round) {
2616     ApplyRounding<16>(src, adjusted_tx_height);
2617   }
2618   int i = adjusted_tx_height;
2619   do {
2620     Identity16Row_NEON(src, /*step=*/16, row_shift);
2621     src += 64;
2622     i -= 4;
2623   } while (i != 0);
2624 }
2625 
2626 void Identity16TransformLoopColumn_NEON(TransformType tx_type,
2627                                         TransformSize tx_size,
2628                                         int adjusted_tx_height,
2629                                         void* LIBGAV1_RESTRICT src_buffer,
2630                                         int start_x, int start_y,
2631                                         void* LIBGAV1_RESTRICT dst_frame) {
2632   auto* src = static_cast<int32_t*>(src_buffer);
2633   const int tx_width = kTransformWidth[tx_size];
2634 
2635   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2636     FlipColumns<16>(src, tx_width);
2637   }
2638   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2639   IdentityColumnStoreToFrame<16>(frame, start_x, start_y, tx_width,
2640                                  adjusted_tx_height, src);
2641 }
2642 
2643 void Identity32TransformLoopRow_NEON(TransformType /*tx_type*/,
2644                                      TransformSize tx_size,
2645                                      int adjusted_tx_height, void* src_buffer,
2646                                      int /*start_x*/, int /*start_y*/,
2647                                      void* /*dst_frame*/) {
2648   const int tx_height = kTransformHeight[tx_size];
2649 
2650   // When combining the identity32 multiplier with the row shift, the
2651   // calculations for tx_height == 8 and tx_height == 32 can be simplified
2652   // from (((A * 4) + 2) >> 2) to A.
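       // For example, A = 11 gives ((11 * 4) + 2) >> 2 = 11. The check below
       // ((tx_height & 0x28) != 0) is true exactly for tx_height 8 or 32.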
2653   if ((tx_height & 0x28) != 0) {
2654     return;
2655   }
2656 
2657   // Process kTransformSize32x16. The src is always rounded before the identity
2658   // transform and shifted by 1 afterwards.
2659   auto* src = static_cast<int32_t*>(src_buffer);
2660   if (Identity32DcOnly(src, adjusted_tx_height)) {
2661     return;
2662   }
2663 
2664   assert(tx_size == kTransformSize32x16);
2665   ApplyRounding<32>(src, adjusted_tx_height);
2666   int i = adjusted_tx_height;
2667   do {
2668     Identity32Row16_NEON(src, /*step=*/32);
2669     src += 128;
2670     i -= 4;
2671   } while (i != 0);
2672 }
2673 
2674 void Identity32TransformLoopColumn_NEON(TransformType /*tx_type*/,
2675                                         TransformSize tx_size,
2676                                         int adjusted_tx_height,
2677                                         void* LIBGAV1_RESTRICT src_buffer,
2678                                         int start_x, int start_y,
2679                                         void* LIBGAV1_RESTRICT dst_frame) {
2680   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2681   auto* src = static_cast<int32_t*>(src_buffer);
2682   const int tx_width = kTransformWidth[tx_size];
2683 
2684   IdentityColumnStoreToFrame<32>(frame, start_x, start_y, tx_width,
2685                                  adjusted_tx_height, src);
2686 }
2687 
Wht4TransformLoopRow_NEON(TransformType tx_type,TransformSize tx_size,int,void *,int,int,void *)2688 void Wht4TransformLoopRow_NEON(TransformType tx_type, TransformSize tx_size,
2689                                int /*adjusted_tx_height*/, void* /*src_buffer*/,
2690                                int /*start_x*/, int /*start_y*/,
2691                                void* /*dst_frame*/) {
2692   assert(tx_type == kTransformTypeDctDct);
2693   assert(tx_size == kTransformSize4x4);
2694   static_cast<void>(tx_type);
2695   static_cast<void>(tx_size);
2696   // Do both row and column transforms in the column-transform pass.
2697 }
2698 
Wht4TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2699 void Wht4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2700                                   int adjusted_tx_height,
2701                                   void* LIBGAV1_RESTRICT src_buffer,
2702                                   int start_x, int start_y,
2703                                   void* LIBGAV1_RESTRICT dst_frame) {
2704   assert(tx_type == kTransformTypeDctDct);
2705   assert(tx_size == kTransformSize4x4);
2706   static_cast<void>(tx_type);
2707   static_cast<void>(tx_size);
2708 
2709   // Process 4 1d wht4 rows and columns in parallel.
2710   const auto* src = static_cast<int32_t*>(src_buffer);
2711   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2712   uint16_t* dst = frame[start_y] + start_x;
2713   const int dst_stride = frame.columns();
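  // Per the note in Wht4TransformLoopRow_NEON above, both 1-D WHT passes
  // happen inside Wht4_NEON, which writes its output straight to the frame
  // buffer at (start_x, start_y).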
  Wht4_NEON(dst, dst_stride, src, adjusted_tx_height);
}

//------------------------------------------------------------------------------

void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);
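  // Only the 1-D transform / size combinations implemented in this file are
  // assigned here; any entry not listed below keeps whatever implementation
  // was registered previously.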
  // Maximum transform size for Dct is 64.
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kRow] =
      Dct4TransformLoopRow_NEON;
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kColumn] =
      Dct4TransformLoopColumn_NEON;
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kRow] =
      Dct8TransformLoopRow_NEON;
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kColumn] =
      Dct8TransformLoopColumn_NEON;
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kRow] =
      Dct16TransformLoopRow_NEON;
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kColumn] =
      Dct16TransformLoopColumn_NEON;
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kRow] =
      Dct32TransformLoopRow_NEON;
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kColumn] =
      Dct32TransformLoopColumn_NEON;
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kRow] =
      Dct64TransformLoopRow_NEON;
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kColumn] =
      Dct64TransformLoopColumn_NEON;

  // Maximum transform size for Adst is 16.
  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kRow] =
      Adst4TransformLoopRow_NEON;
  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kColumn] =
      Adst4TransformLoopColumn_NEON;
  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kRow] =
      Adst8TransformLoopRow_NEON;
  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kColumn] =
      Adst8TransformLoopColumn_NEON;
  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kRow] =
      Adst16TransformLoopRow_NEON;
  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kColumn] =
      Adst16TransformLoopColumn_NEON;

  // Maximum transform size for Identity transform is 32.
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kRow] =
      Identity4TransformLoopRow_NEON;
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kColumn] =
      Identity4TransformLoopColumn_NEON;
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kRow] =
      Identity8TransformLoopRow_NEON;
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kColumn] =
      Identity8TransformLoopColumn_NEON;
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kRow] =
      Identity16TransformLoopRow_NEON;
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kColumn] =
      Identity16TransformLoopColumn_NEON;
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kRow] =
      Identity32TransformLoopRow_NEON;
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kColumn] =
      Identity32TransformLoopColumn_NEON;

  // Maximum transform size for Wht is 4.
  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kRow] =
      Wht4TransformLoopRow_NEON;
  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kColumn] =
      Wht4TransformLoopColumn_NEON;
}

}  // namespace

void InverseTransformInit10bpp_NEON() { Init10bpp(); }

}  // namespace dsp
}  // namespace libgav1
#else   // !LIBGAV1_ENABLE_NEON || LIBGAV1_MAX_BITDEPTH < 10
namespace libgav1 {
namespace dsp {

void InverseTransformInit10bpp_NEON() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10
2798