xref: /aosp_15_r20/external/libgav1/src/dsp/x86/inverse_transform_sse4.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/inverse_transform.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_TARGETING_SSE4_1
19 
20 #include <smmintrin.h>
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <cstdint>
25 #include <cstring>
26 
27 #include "src/dsp/constants.h"
28 #include "src/dsp/dsp.h"
29 #include "src/dsp/x86/common_sse4.h"
30 #include "src/dsp/x86/transpose_sse4.h"
31 #include "src/utils/array_2d.h"
32 #include "src/utils/common.h"
33 #include "src/utils/compiler_attributes.h"
34 
35 namespace libgav1 {
36 namespace dsp {
37 namespace low_bitdepth {
38 namespace {
39 
40 // Include the constants and utility functions inside the anonymous namespace.
41 #include "src/dsp/inverse_transform.inc"
42 
43 template <int store_width, int store_count>
StoreDst(int16_t * LIBGAV1_RESTRICT dst,int32_t stride,int32_t idx,const __m128i * s)44 LIBGAV1_ALWAYS_INLINE void StoreDst(int16_t* LIBGAV1_RESTRICT dst,
45                                     int32_t stride, int32_t idx,
46                                     const __m128i* s) {
47   // NOTE: It is expected that the compiler will unroll these loops.
48   if (store_width == 16) {
49     for (int i = 0; i < store_count; i += 4) {
50       StoreUnaligned16(&dst[i * stride + idx], s[i]);
51       StoreUnaligned16(&dst[(i + 1) * stride + idx], s[i + 1]);
52       StoreUnaligned16(&dst[(i + 2) * stride + idx], s[i + 2]);
53       StoreUnaligned16(&dst[(i + 3) * stride + idx], s[i + 3]);
54     }
55   }
56   if (store_width == 8) {
57     for (int i = 0; i < store_count; i += 4) {
58       StoreLo8(&dst[i * stride + idx], s[i]);
59       StoreLo8(&dst[(i + 1) * stride + idx], s[i + 1]);
60       StoreLo8(&dst[(i + 2) * stride + idx], s[i + 2]);
61       StoreLo8(&dst[(i + 3) * stride + idx], s[i + 3]);
62     }
63   }
64 }
65 
66 template <int load_width, int load_count>
LoadSrc(const int16_t * LIBGAV1_RESTRICT src,int32_t stride,int32_t idx,__m128i * x)67 LIBGAV1_ALWAYS_INLINE void LoadSrc(const int16_t* LIBGAV1_RESTRICT src,
68                                    int32_t stride, int32_t idx, __m128i* x) {
69   // NOTE: It is expected that the compiler will unroll these loops.
70   if (load_width == 16) {
71     for (int i = 0; i < load_count; i += 4) {
72       x[i] = LoadUnaligned16(&src[i * stride + idx]);
73       x[i + 1] = LoadUnaligned16(&src[(i + 1) * stride + idx]);
74       x[i + 2] = LoadUnaligned16(&src[(i + 2) * stride + idx]);
75       x[i + 3] = LoadUnaligned16(&src[(i + 3) * stride + idx]);
76     }
77   }
78   if (load_width == 8) {
79     for (int i = 0; i < load_count; i += 4) {
80       x[i] = LoadLo8(&src[i * stride + idx]);
81       x[i + 1] = LoadLo8(&src[(i + 1) * stride + idx]);
82       x[i + 2] = LoadLo8(&src[(i + 2) * stride + idx]);
83       x[i + 3] = LoadLo8(&src[(i + 3) * stride + idx]);
84     }
85   }
86 }
87 
88 // Butterfly rotate 4 values.
ButterflyRotation_4(__m128i * a,__m128i * b,const int angle,const bool flip)89 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_4(__m128i* a, __m128i* b,
90                                                const int angle,
91                                                const bool flip) {
92   const int16_t cos128 = Cos128(angle);
93   const int16_t sin128 = Sin128(angle);
94   const __m128i psin_pcos = _mm_set1_epi32(
95       static_cast<uint16_t>(cos128) | (static_cast<uint32_t>(sin128) << 16));
96   const __m128i ba = _mm_unpacklo_epi16(*a, *b);
97   const __m128i ab = _mm_unpacklo_epi16(*b, *a);
98   const __m128i sign = _mm_set1_epi32(static_cast<int>(0x80000001));
99   // -sin cos, -sin cos, -sin cos, -sin cos
100   const __m128i msin_pcos = _mm_sign_epi16(psin_pcos, sign);
101   const __m128i x0 = _mm_madd_epi16(ba, msin_pcos);
102   const __m128i y0 = _mm_madd_epi16(ab, psin_pcos);
103   const __m128i x1 = RightShiftWithRounding_S32(x0, 12);
104   const __m128i y1 = RightShiftWithRounding_S32(y0, 12);
105   const __m128i x = _mm_packs_epi32(x1, x1);
106   const __m128i y = _mm_packs_epi32(y1, y1);
107   if (flip) {
108     *a = y;
109     *b = x;
110   } else {
111     *a = x;
112     *b = y;
113   }
114 }
115 
116 // Butterfly rotate 8 values.
ButterflyRotation_8(__m128i * a,__m128i * b,const int angle,const bool flip)117 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_8(__m128i* a, __m128i* b,
118                                                const int angle,
119                                                const bool flip) {
120   const int16_t cos128 = Cos128(angle);
121   const int16_t sin128 = Sin128(angle);
122   const __m128i psin_pcos = _mm_set1_epi32(
123       static_cast<uint16_t>(cos128) | (static_cast<uint32_t>(sin128) << 16));
124   const __m128i sign = _mm_set1_epi32(static_cast<int>(0x80000001));
125   // -sin cos, -sin cos, -sin cos, -sin cos
126   const __m128i msin_pcos = _mm_sign_epi16(psin_pcos, sign);
127   const __m128i ba = _mm_unpacklo_epi16(*a, *b);
128   const __m128i ab = _mm_unpacklo_epi16(*b, *a);
129   const __m128i ba_hi = _mm_unpackhi_epi16(*a, *b);
130   const __m128i ab_hi = _mm_unpackhi_epi16(*b, *a);
131   const __m128i x0 = _mm_madd_epi16(ba, msin_pcos);
132   const __m128i y0 = _mm_madd_epi16(ab, psin_pcos);
133   const __m128i x0_hi = _mm_madd_epi16(ba_hi, msin_pcos);
134   const __m128i y0_hi = _mm_madd_epi16(ab_hi, psin_pcos);
135   const __m128i x1 = RightShiftWithRounding_S32(x0, 12);
136   const __m128i y1 = RightShiftWithRounding_S32(y0, 12);
137   const __m128i x1_hi = RightShiftWithRounding_S32(x0_hi, 12);
138   const __m128i y1_hi = RightShiftWithRounding_S32(y0_hi, 12);
139   const __m128i x = _mm_packs_epi32(x1, x1_hi);
140   const __m128i y = _mm_packs_epi32(y1, y1_hi);
141   if (flip) {
142     *a = y;
143     *b = x;
144   } else {
145     *a = x;
146     *b = y;
147   }
148 }
149 
ButterflyRotation_FirstIsZero(__m128i * a,__m128i * b,const int angle,const bool flip)150 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_FirstIsZero(__m128i* a, __m128i* b,
151                                                          const int angle,
152                                                          const bool flip) {
153   const int16_t cos128 = Cos128(angle);
154   const int16_t sin128 = Sin128(angle);
155   const __m128i pcos = _mm_set1_epi16(cos128 << 3);
156   const __m128i psin = _mm_set1_epi16(-(sin128 << 3));
157   const __m128i x = _mm_mulhrs_epi16(*b, psin);
158   const __m128i y = _mm_mulhrs_epi16(*b, pcos);
159   if (flip) {
160     *a = y;
161     *b = x;
162   } else {
163     *a = x;
164     *b = y;
165   }
166 }
167 
ButterflyRotation_SecondIsZero(__m128i * a,__m128i * b,const int angle,const bool flip)168 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_SecondIsZero(__m128i* a,
169                                                           __m128i* b,
170                                                           const int angle,
171                                                           const bool flip) {
172   const int16_t cos128 = Cos128(angle);
173   const int16_t sin128 = Sin128(angle);
174   const __m128i pcos = _mm_set1_epi16(cos128 << 3);
175   const __m128i psin = _mm_set1_epi16(sin128 << 3);
176   const __m128i x = _mm_mulhrs_epi16(*a, pcos);
177   const __m128i y = _mm_mulhrs_epi16(*a, psin);
178   if (flip) {
179     *a = y;
180     *b = x;
181   } else {
182     *a = x;
183     *b = y;
184   }
185 }
186 
HadamardRotation(__m128i * a,__m128i * b,bool flip)187 LIBGAV1_ALWAYS_INLINE void HadamardRotation(__m128i* a, __m128i* b, bool flip) {
188   __m128i x, y;
189   if (flip) {
190     y = _mm_adds_epi16(*b, *a);
191     x = _mm_subs_epi16(*b, *a);
192   } else {
193     x = _mm_adds_epi16(*a, *b);
194     y = _mm_subs_epi16(*a, *b);
195   }
196   *a = x;
197   *b = y;
198 }
199 
200 using ButterflyRotationFunc = void (*)(__m128i* a, __m128i* b, int angle,
201                                        bool flip);
202 
ShiftResidual(const __m128i residual,const __m128i v_row_shift_add,const __m128i v_row_shift)203 LIBGAV1_ALWAYS_INLINE __m128i ShiftResidual(const __m128i residual,
204                                             const __m128i v_row_shift_add,
205                                             const __m128i v_row_shift) {
206   const __m128i k7ffd = _mm_set1_epi16(0x7ffd);
207   // The max row_shift is 2, so int16_t values greater than 0x7ffd may
208   // overflow.  Generate a mask for this case.
209   const __m128i mask = _mm_cmpgt_epi16(residual, k7ffd);
210   const __m128i x = _mm_add_epi16(residual, v_row_shift_add);
211   // Assume int16_t values.
212   const __m128i a = _mm_sra_epi16(x, v_row_shift);
213   // Assume uint16_t values.
214   const __m128i b = _mm_srl_epi16(x, v_row_shift);
215   // Select the correct shifted value.
216   return _mm_blendv_epi8(a, b, mask);
217 }
218 
219 //------------------------------------------------------------------------------
220 // Discrete Cosine Transforms (DCT).
221 
222 template <int width>
DctDcOnly(void * dest,int adjusted_tx_height,bool should_round,int row_shift)223 LIBGAV1_ALWAYS_INLINE bool DctDcOnly(void* dest, int adjusted_tx_height,
224                                      bool should_round, int row_shift) {
225   if (adjusted_tx_height > 1) return false;
226 
227   auto* dst = static_cast<int16_t*>(dest);
228   const __m128i v_src_lo = _mm_shufflelo_epi16(_mm_cvtsi32_si128(dst[0]), 0);
229   const __m128i v_src =
230       (width == 4) ? v_src_lo : _mm_shuffle_epi32(v_src_lo, 0);
231   const __m128i v_mask =
232       _mm_set1_epi16(should_round ? static_cast<int16_t>(0xffff) : 0);
233   const __m128i v_kTransformRowMultiplier =
234       _mm_set1_epi16(kTransformRowMultiplier << 3);
235   const __m128i v_src_round =
236       _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier);
237   const __m128i s0 = _mm_blendv_epi8(v_src, v_src_round, v_mask);
238   const int16_t cos128 = Cos128(32);
239   const __m128i xy = _mm_mulhrs_epi16(s0, _mm_set1_epi16(cos128 << 3));
240 
241   // Expand to 32 bits to prevent int16_t overflows during the shift add.
242   const __m128i v_row_shift_add = _mm_set1_epi32(row_shift);
243   const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add);
244   const __m128i a = _mm_cvtepi16_epi32(xy);
245   const __m128i a1 = _mm_cvtepi16_epi32(_mm_srli_si128(xy, 8));
246   const __m128i b = _mm_add_epi32(a, v_row_shift_add);
247   const __m128i b1 = _mm_add_epi32(a1, v_row_shift_add);
248   const __m128i c = _mm_sra_epi32(b, v_row_shift);
249   const __m128i c1 = _mm_sra_epi32(b1, v_row_shift);
250   const __m128i xy_shifted = _mm_packs_epi32(c, c1);
251 
252   if (width == 4) {
253     StoreLo8(dst, xy_shifted);
254   } else {
255     for (int i = 0; i < width; i += 8) {
256       StoreUnaligned16(dst, xy_shifted);
257       dst += 8;
258     }
259   }
260   return true;
261 }
262 
263 template <int height>
DctDcOnlyColumn(void * dest,int adjusted_tx_height,int width)264 LIBGAV1_ALWAYS_INLINE bool DctDcOnlyColumn(void* dest, int adjusted_tx_height,
265                                            int width) {
266   if (adjusted_tx_height > 1) return false;
267 
268   auto* dst = static_cast<int16_t*>(dest);
269   const int16_t cos128 = Cos128(32);
270 
271   // Calculate dc values for first row.
272   if (width == 4) {
273     const __m128i v_src = LoadLo8(dst);
274     const __m128i xy = _mm_mulhrs_epi16(v_src, _mm_set1_epi16(cos128 << 3));
275     StoreLo8(dst, xy);
276   } else {
277     int i = 0;
278     do {
279       const __m128i v_src = LoadUnaligned16(&dst[i]);
280       const __m128i xy = _mm_mulhrs_epi16(v_src, _mm_set1_epi16(cos128 << 3));
281       StoreUnaligned16(&dst[i], xy);
282       i += 8;
283     } while (i < width);
284   }
285 
286   // Copy first row to the rest of the block.
287   for (int y = 1; y < height; ++y) {
288     memcpy(&dst[y * width], dst, width * sizeof(dst[0]));
289   }
290   return true;
291 }
292 
293 template <ButterflyRotationFunc butterfly_rotation,
294           bool is_fast_butterfly = false>
Dct4Stages(__m128i * s)295 LIBGAV1_ALWAYS_INLINE void Dct4Stages(__m128i* s) {
296   // stage 12.
297   if (is_fast_butterfly) {
298     ButterflyRotation_SecondIsZero(&s[0], &s[1], 32, true);
299     ButterflyRotation_SecondIsZero(&s[2], &s[3], 48, false);
300   } else {
301     butterfly_rotation(&s[0], &s[1], 32, true);
302     butterfly_rotation(&s[2], &s[3], 48, false);
303   }
304 
305   // stage 17.
306   HadamardRotation(&s[0], &s[3], false);
307   HadamardRotation(&s[1], &s[2], false);
308 }
309 
310 // Process 4 dct4 rows or columns, depending on the transpose flag.
311 template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
Dct4_SSE4_1(void * dest,int32_t step,bool transpose)312 LIBGAV1_ALWAYS_INLINE void Dct4_SSE4_1(void* dest, int32_t step,
313                                        bool transpose) {
314   auto* const dst = static_cast<int16_t*>(dest);
315   __m128i s[4], x[4];
316 
317   if (stage_is_rectangular) {
318     if (transpose) {
319       __m128i input[8];
320       LoadSrc<8, 8>(dst, step, 0, input);
321       Transpose4x8To8x4_U16(input, x);
322     } else {
323       LoadSrc<16, 4>(dst, step, 0, x);
324     }
325   } else {
326     LoadSrc<8, 4>(dst, step, 0, x);
327     if (transpose) {
328       Transpose4x4_U16(x, x);
329     }
330   }
331   // stage 1.
332   // kBitReverseLookup 0, 2, 1, 3
333   s[0] = x[0];
334   s[1] = x[2];
335   s[2] = x[1];
336   s[3] = x[3];
337 
338   Dct4Stages<butterfly_rotation>(s);
339 
340   if (stage_is_rectangular) {
341     if (transpose) {
342       __m128i output[8];
343       Transpose8x4To4x8_U16(s, output);
344       StoreDst<8, 8>(dst, step, 0, output);
345     } else {
346       StoreDst<16, 4>(dst, step, 0, s);
347     }
348   } else {
349     if (transpose) {
350       Transpose4x4_U16(s, s);
351     }
352     StoreDst<8, 4>(dst, step, 0, s);
353   }
354 }
355 
356 template <ButterflyRotationFunc butterfly_rotation,
357           bool is_fast_butterfly = false>
Dct8Stages(__m128i * s)358 LIBGAV1_ALWAYS_INLINE void Dct8Stages(__m128i* s) {
359   // stage 8.
360   if (is_fast_butterfly) {
361     ButterflyRotation_SecondIsZero(&s[4], &s[7], 56, false);
362     ButterflyRotation_FirstIsZero(&s[5], &s[6], 24, false);
363   } else {
364     butterfly_rotation(&s[4], &s[7], 56, false);
365     butterfly_rotation(&s[5], &s[6], 24, false);
366   }
367 
368   // stage 13.
369   HadamardRotation(&s[4], &s[5], false);
370   HadamardRotation(&s[6], &s[7], true);
371 
372   // stage 18.
373   butterfly_rotation(&s[6], &s[5], 32, true);
374 
375   // stage 22.
376   HadamardRotation(&s[0], &s[7], false);
377   HadamardRotation(&s[1], &s[6], false);
378   HadamardRotation(&s[2], &s[5], false);
379   HadamardRotation(&s[3], &s[4], false);
380 }
381 
382 // Process dct8 rows or columns, depending on the transpose flag.
383 template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
Dct8_SSE4_1(void * dest,int32_t step,bool transpose)384 LIBGAV1_ALWAYS_INLINE void Dct8_SSE4_1(void* dest, int32_t step,
385                                        bool transpose) {
386   auto* const dst = static_cast<int16_t*>(dest);
387   __m128i s[8], x[8];
388 
389   if (stage_is_rectangular) {
390     if (transpose) {
391       __m128i input[4];
392       LoadSrc<16, 4>(dst, step, 0, input);
393       Transpose8x4To4x8_U16(input, x);
394     } else {
395       LoadSrc<8, 8>(dst, step, 0, x);
396     }
397   } else {
398     if (transpose) {
399       __m128i input[8];
400       LoadSrc<16, 8>(dst, step, 0, input);
401       Transpose8x8_U16(input, x);
402     } else {
403       LoadSrc<16, 8>(dst, step, 0, x);
404     }
405   }
406 
407   // stage 1.
408   // kBitReverseLookup 0, 4, 2, 6, 1, 5, 3, 7,
409   s[0] = x[0];
410   s[1] = x[4];
411   s[2] = x[2];
412   s[3] = x[6];
413   s[4] = x[1];
414   s[5] = x[5];
415   s[6] = x[3];
416   s[7] = x[7];
417 
418   Dct4Stages<butterfly_rotation>(s);
419   Dct8Stages<butterfly_rotation>(s);
420 
421   if (stage_is_rectangular) {
422     if (transpose) {
423       __m128i output[4];
424       Transpose4x8To8x4_U16(s, output);
425       StoreDst<16, 4>(dst, step, 0, output);
426     } else {
427       StoreDst<8, 8>(dst, step, 0, s);
428     }
429   } else {
430     if (transpose) {
431       __m128i output[8];
432       Transpose8x8_U16(s, output);
433       StoreDst<16, 8>(dst, step, 0, output);
434     } else {
435       StoreDst<16, 8>(dst, step, 0, s);
436     }
437   }
438 }
439 
440 template <ButterflyRotationFunc butterfly_rotation,
441           bool is_fast_butterfly = false>
Dct16Stages(__m128i * s)442 LIBGAV1_ALWAYS_INLINE void Dct16Stages(__m128i* s) {
443   // stage 5.
444   if (is_fast_butterfly) {
445     ButterflyRotation_SecondIsZero(&s[8], &s[15], 60, false);
446     ButterflyRotation_FirstIsZero(&s[9], &s[14], 28, false);
447     ButterflyRotation_SecondIsZero(&s[10], &s[13], 44, false);
448     ButterflyRotation_FirstIsZero(&s[11], &s[12], 12, false);
449   } else {
450     butterfly_rotation(&s[8], &s[15], 60, false);
451     butterfly_rotation(&s[9], &s[14], 28, false);
452     butterfly_rotation(&s[10], &s[13], 44, false);
453     butterfly_rotation(&s[11], &s[12], 12, false);
454   }
455 
456   // stage 9.
457   HadamardRotation(&s[8], &s[9], false);
458   HadamardRotation(&s[10], &s[11], true);
459   HadamardRotation(&s[12], &s[13], false);
460   HadamardRotation(&s[14], &s[15], true);
461 
462   // stage 14.
463   butterfly_rotation(&s[14], &s[9], 48, true);
464   butterfly_rotation(&s[13], &s[10], 112, true);
465 
466   // stage 19.
467   HadamardRotation(&s[8], &s[11], false);
468   HadamardRotation(&s[9], &s[10], false);
469   HadamardRotation(&s[12], &s[15], true);
470   HadamardRotation(&s[13], &s[14], true);
471 
472   // stage 23.
473   butterfly_rotation(&s[13], &s[10], 32, true);
474   butterfly_rotation(&s[12], &s[11], 32, true);
475 
476   // stage 26.
477   HadamardRotation(&s[0], &s[15], false);
478   HadamardRotation(&s[1], &s[14], false);
479   HadamardRotation(&s[2], &s[13], false);
480   HadamardRotation(&s[3], &s[12], false);
481   HadamardRotation(&s[4], &s[11], false);
482   HadamardRotation(&s[5], &s[10], false);
483   HadamardRotation(&s[6], &s[9], false);
484   HadamardRotation(&s[7], &s[8], false);
485 }
486 
487 // Process dct16 rows or columns, depending on the transpose flag.
488 template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
Dct16_SSE4_1(void * dest,int32_t step,bool transpose)489 LIBGAV1_ALWAYS_INLINE void Dct16_SSE4_1(void* dest, int32_t step,
490                                         bool transpose) {
491   auto* const dst = static_cast<int16_t*>(dest);
492   __m128i s[16], x[16];
493 
494   if (stage_is_rectangular) {
495     if (transpose) {
496       __m128i input[4];
497       LoadSrc<16, 4>(dst, step, 0, input);
498       Transpose8x4To4x8_U16(input, x);
499       LoadSrc<16, 4>(dst, step, 8, input);
500       Transpose8x4To4x8_U16(input, &x[8]);
501     } else {
502       LoadSrc<8, 16>(dst, step, 0, x);
503     }
504   } else {
505     if (transpose) {
506       for (int idx = 0; idx < 16; idx += 8) {
507         __m128i input[8];
508         LoadSrc<16, 8>(dst, step, idx, input);
509         Transpose8x8_U16(input, &x[idx]);
510       }
511     } else {
512       LoadSrc<16, 16>(dst, step, 0, x);
513     }
514   }
515 
516   // stage 1
517   // kBitReverseLookup 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15,
518   s[0] = x[0];
519   s[1] = x[8];
520   s[2] = x[4];
521   s[3] = x[12];
522   s[4] = x[2];
523   s[5] = x[10];
524   s[6] = x[6];
525   s[7] = x[14];
526   s[8] = x[1];
527   s[9] = x[9];
528   s[10] = x[5];
529   s[11] = x[13];
530   s[12] = x[3];
531   s[13] = x[11];
532   s[14] = x[7];
533   s[15] = x[15];
534 
535   Dct4Stages<butterfly_rotation>(s);
536   Dct8Stages<butterfly_rotation>(s);
537   Dct16Stages<butterfly_rotation>(s);
538 
539   if (stage_is_rectangular) {
540     if (transpose) {
541       __m128i output[4];
542       Transpose4x8To8x4_U16(s, output);
543       StoreDst<16, 4>(dst, step, 0, output);
544       Transpose4x8To8x4_U16(&s[8], output);
545       StoreDst<16, 4>(dst, step, 8, output);
546     } else {
547       StoreDst<8, 16>(dst, step, 0, s);
548     }
549   } else {
550     if (transpose) {
551       for (int idx = 0; idx < 16; idx += 8) {
552         __m128i output[8];
553         Transpose8x8_U16(&s[idx], output);
554         StoreDst<16, 8>(dst, step, idx, output);
555       }
556     } else {
557       StoreDst<16, 16>(dst, step, 0, s);
558     }
559   }
560 }
561 
562 template <ButterflyRotationFunc butterfly_rotation,
563           bool is_fast_butterfly = false>
Dct32Stages(__m128i * s)564 LIBGAV1_ALWAYS_INLINE void Dct32Stages(__m128i* s) {
565   // stage 3
566   if (is_fast_butterfly) {
567     ButterflyRotation_SecondIsZero(&s[16], &s[31], 62, false);
568     ButterflyRotation_FirstIsZero(&s[17], &s[30], 30, false);
569     ButterflyRotation_SecondIsZero(&s[18], &s[29], 46, false);
570     ButterflyRotation_FirstIsZero(&s[19], &s[28], 14, false);
571     ButterflyRotation_SecondIsZero(&s[20], &s[27], 54, false);
572     ButterflyRotation_FirstIsZero(&s[21], &s[26], 22, false);
573     ButterflyRotation_SecondIsZero(&s[22], &s[25], 38, false);
574     ButterflyRotation_FirstIsZero(&s[23], &s[24], 6, false);
575   } else {
576     butterfly_rotation(&s[16], &s[31], 62, false);
577     butterfly_rotation(&s[17], &s[30], 30, false);
578     butterfly_rotation(&s[18], &s[29], 46, false);
579     butterfly_rotation(&s[19], &s[28], 14, false);
580     butterfly_rotation(&s[20], &s[27], 54, false);
581     butterfly_rotation(&s[21], &s[26], 22, false);
582     butterfly_rotation(&s[22], &s[25], 38, false);
583     butterfly_rotation(&s[23], &s[24], 6, false);
584   }
585   // stage 6.
586   HadamardRotation(&s[16], &s[17], false);
587   HadamardRotation(&s[18], &s[19], true);
588   HadamardRotation(&s[20], &s[21], false);
589   HadamardRotation(&s[22], &s[23], true);
590   HadamardRotation(&s[24], &s[25], false);
591   HadamardRotation(&s[26], &s[27], true);
592   HadamardRotation(&s[28], &s[29], false);
593   HadamardRotation(&s[30], &s[31], true);
594 
595   // stage 10.
596   butterfly_rotation(&s[30], &s[17], 24 + 32, true);
597   butterfly_rotation(&s[29], &s[18], 24 + 64 + 32, true);
598   butterfly_rotation(&s[26], &s[21], 24, true);
599   butterfly_rotation(&s[25], &s[22], 24 + 64, true);
600 
601   // stage 15.
602   HadamardRotation(&s[16], &s[19], false);
603   HadamardRotation(&s[17], &s[18], false);
604   HadamardRotation(&s[20], &s[23], true);
605   HadamardRotation(&s[21], &s[22], true);
606   HadamardRotation(&s[24], &s[27], false);
607   HadamardRotation(&s[25], &s[26], false);
608   HadamardRotation(&s[28], &s[31], true);
609   HadamardRotation(&s[29], &s[30], true);
610 
611   // stage 20.
612   butterfly_rotation(&s[29], &s[18], 48, true);
613   butterfly_rotation(&s[28], &s[19], 48, true);
614   butterfly_rotation(&s[27], &s[20], 48 + 64, true);
615   butterfly_rotation(&s[26], &s[21], 48 + 64, true);
616 
617   // stage 24.
618   HadamardRotation(&s[16], &s[23], false);
619   HadamardRotation(&s[17], &s[22], false);
620   HadamardRotation(&s[18], &s[21], false);
621   HadamardRotation(&s[19], &s[20], false);
622   HadamardRotation(&s[24], &s[31], true);
623   HadamardRotation(&s[25], &s[30], true);
624   HadamardRotation(&s[26], &s[29], true);
625   HadamardRotation(&s[27], &s[28], true);
626 
627   // stage 27.
628   butterfly_rotation(&s[27], &s[20], 32, true);
629   butterfly_rotation(&s[26], &s[21], 32, true);
630   butterfly_rotation(&s[25], &s[22], 32, true);
631   butterfly_rotation(&s[24], &s[23], 32, true);
632 
633   // stage 29.
634   HadamardRotation(&s[0], &s[31], false);
635   HadamardRotation(&s[1], &s[30], false);
636   HadamardRotation(&s[2], &s[29], false);
637   HadamardRotation(&s[3], &s[28], false);
638   HadamardRotation(&s[4], &s[27], false);
639   HadamardRotation(&s[5], &s[26], false);
640   HadamardRotation(&s[6], &s[25], false);
641   HadamardRotation(&s[7], &s[24], false);
642   HadamardRotation(&s[8], &s[23], false);
643   HadamardRotation(&s[9], &s[22], false);
644   HadamardRotation(&s[10], &s[21], false);
645   HadamardRotation(&s[11], &s[20], false);
646   HadamardRotation(&s[12], &s[19], false);
647   HadamardRotation(&s[13], &s[18], false);
648   HadamardRotation(&s[14], &s[17], false);
649   HadamardRotation(&s[15], &s[16], false);
650 }
651 
652 // Process dct32 rows or columns, depending on the transpose flag.
Dct32_SSE4_1(void * dest,const int32_t step,const bool transpose)653 LIBGAV1_ALWAYS_INLINE void Dct32_SSE4_1(void* dest, const int32_t step,
654                                         const bool transpose) {
655   auto* const dst = static_cast<int16_t*>(dest);
656   __m128i s[32], x[32];
657 
658   if (transpose) {
659     for (int idx = 0; idx < 32; idx += 8) {
660       __m128i input[8];
661       LoadSrc<16, 8>(dst, step, idx, input);
662       Transpose8x8_U16(input, &x[idx]);
663     }
664   } else {
665     LoadSrc<16, 32>(dst, step, 0, x);
666   }
667 
668   // stage 1
669   // kBitReverseLookup
670   // 0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30,
671   s[0] = x[0];
672   s[1] = x[16];
673   s[2] = x[8];
674   s[3] = x[24];
675   s[4] = x[4];
676   s[5] = x[20];
677   s[6] = x[12];
678   s[7] = x[28];
679   s[8] = x[2];
680   s[9] = x[18];
681   s[10] = x[10];
682   s[11] = x[26];
683   s[12] = x[6];
684   s[13] = x[22];
685   s[14] = x[14];
686   s[15] = x[30];
687 
688   // 1, 17, 9, 25, 5, 21, 13, 29, 3, 19, 11, 27, 7, 23, 15, 31,
689   s[16] = x[1];
690   s[17] = x[17];
691   s[18] = x[9];
692   s[19] = x[25];
693   s[20] = x[5];
694   s[21] = x[21];
695   s[22] = x[13];
696   s[23] = x[29];
697   s[24] = x[3];
698   s[25] = x[19];
699   s[26] = x[11];
700   s[27] = x[27];
701   s[28] = x[7];
702   s[29] = x[23];
703   s[30] = x[15];
704   s[31] = x[31];
705 
706   Dct4Stages<ButterflyRotation_8>(s);
707   Dct8Stages<ButterflyRotation_8>(s);
708   Dct16Stages<ButterflyRotation_8>(s);
709   Dct32Stages<ButterflyRotation_8>(s);
710 
711   if (transpose) {
712     for (int idx = 0; idx < 32; idx += 8) {
713       __m128i output[8];
714       Transpose8x8_U16(&s[idx], output);
715       StoreDst<16, 8>(dst, step, idx, output);
716     }
717   } else {
718     StoreDst<16, 32>(dst, step, 0, s);
719   }
720 }
721 
722 // Allow the compiler to call this function instead of force inlining. Tests
723 // show the performance is slightly faster.
Dct64_SSE4_1(void * dest,int32_t step,bool transpose)724 void Dct64_SSE4_1(void* dest, int32_t step, bool transpose) {
725   auto* const dst = static_cast<int16_t*>(dest);
726   __m128i s[64], x[32];
727 
728   if (transpose) {
729     // The last 32 values of every row are always zero if the |tx_width| is
730     // 64.
731     for (int idx = 0; idx < 32; idx += 8) {
732       __m128i input[8];
733       LoadSrc<16, 8>(dst, step, idx, input);
734       Transpose8x8_U16(input, &x[idx]);
735     }
736   } else {
737     // The last 32 values of every column are always zero if the |tx_height| is
738     // 64.
739     LoadSrc<16, 32>(dst, step, 0, x);
740   }
741 
742   // stage 1
743   // kBitReverseLookup
744   // 0, 32, 16, 48, 8, 40, 24, 56, 4, 36, 20, 52, 12, 44, 28, 60,
745   s[0] = x[0];
746   s[2] = x[16];
747   s[4] = x[8];
748   s[6] = x[24];
749   s[8] = x[4];
750   s[10] = x[20];
751   s[12] = x[12];
752   s[14] = x[28];
753 
754   // 2, 34, 18, 50, 10, 42, 26, 58, 6, 38, 22, 54, 14, 46, 30, 62,
755   s[16] = x[2];
756   s[18] = x[18];
757   s[20] = x[10];
758   s[22] = x[26];
759   s[24] = x[6];
760   s[26] = x[22];
761   s[28] = x[14];
762   s[30] = x[30];
763 
764   // 1, 33, 17, 49, 9, 41, 25, 57, 5, 37, 21, 53, 13, 45, 29, 61,
765   s[32] = x[1];
766   s[34] = x[17];
767   s[36] = x[9];
768   s[38] = x[25];
769   s[40] = x[5];
770   s[42] = x[21];
771   s[44] = x[13];
772   s[46] = x[29];
773 
774   // 3, 35, 19, 51, 11, 43, 27, 59, 7, 39, 23, 55, 15, 47, 31, 63
775   s[48] = x[3];
776   s[50] = x[19];
777   s[52] = x[11];
778   s[54] = x[27];
779   s[56] = x[7];
780   s[58] = x[23];
781   s[60] = x[15];
782   s[62] = x[31];
783 
784   Dct4Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
785   Dct8Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
786   Dct16Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
787   Dct32Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
788 
789   //-- start dct 64 stages
790   // stage 2.
791   ButterflyRotation_SecondIsZero(&s[32], &s[63], 63 - 0, false);
792   ButterflyRotation_FirstIsZero(&s[33], &s[62], 63 - 32, false);
793   ButterflyRotation_SecondIsZero(&s[34], &s[61], 63 - 16, false);
794   ButterflyRotation_FirstIsZero(&s[35], &s[60], 63 - 48, false);
795   ButterflyRotation_SecondIsZero(&s[36], &s[59], 63 - 8, false);
796   ButterflyRotation_FirstIsZero(&s[37], &s[58], 63 - 40, false);
797   ButterflyRotation_SecondIsZero(&s[38], &s[57], 63 - 24, false);
798   ButterflyRotation_FirstIsZero(&s[39], &s[56], 63 - 56, false);
799   ButterflyRotation_SecondIsZero(&s[40], &s[55], 63 - 4, false);
800   ButterflyRotation_FirstIsZero(&s[41], &s[54], 63 - 36, false);
801   ButterflyRotation_SecondIsZero(&s[42], &s[53], 63 - 20, false);
802   ButterflyRotation_FirstIsZero(&s[43], &s[52], 63 - 52, false);
803   ButterflyRotation_SecondIsZero(&s[44], &s[51], 63 - 12, false);
804   ButterflyRotation_FirstIsZero(&s[45], &s[50], 63 - 44, false);
805   ButterflyRotation_SecondIsZero(&s[46], &s[49], 63 - 28, false);
806   ButterflyRotation_FirstIsZero(&s[47], &s[48], 63 - 60, false);
807 
808   // stage 4.
809   HadamardRotation(&s[32], &s[33], false);
810   HadamardRotation(&s[34], &s[35], true);
811   HadamardRotation(&s[36], &s[37], false);
812   HadamardRotation(&s[38], &s[39], true);
813   HadamardRotation(&s[40], &s[41], false);
814   HadamardRotation(&s[42], &s[43], true);
815   HadamardRotation(&s[44], &s[45], false);
816   HadamardRotation(&s[46], &s[47], true);
817   HadamardRotation(&s[48], &s[49], false);
818   HadamardRotation(&s[50], &s[51], true);
819   HadamardRotation(&s[52], &s[53], false);
820   HadamardRotation(&s[54], &s[55], true);
821   HadamardRotation(&s[56], &s[57], false);
822   HadamardRotation(&s[58], &s[59], true);
823   HadamardRotation(&s[60], &s[61], false);
824   HadamardRotation(&s[62], &s[63], true);
825 
826   // stage 7.
827   ButterflyRotation_8(&s[62], &s[33], 60 - 0, true);
828   ButterflyRotation_8(&s[61], &s[34], 60 - 0 + 64, true);
829   ButterflyRotation_8(&s[58], &s[37], 60 - 32, true);
830   ButterflyRotation_8(&s[57], &s[38], 60 - 32 + 64, true);
831   ButterflyRotation_8(&s[54], &s[41], 60 - 16, true);
832   ButterflyRotation_8(&s[53], &s[42], 60 - 16 + 64, true);
833   ButterflyRotation_8(&s[50], &s[45], 60 - 48, true);
834   ButterflyRotation_8(&s[49], &s[46], 60 - 48 + 64, true);
835 
836   // stage 11.
837   HadamardRotation(&s[32], &s[35], false);
838   HadamardRotation(&s[33], &s[34], false);
839   HadamardRotation(&s[36], &s[39], true);
840   HadamardRotation(&s[37], &s[38], true);
841   HadamardRotation(&s[40], &s[43], false);
842   HadamardRotation(&s[41], &s[42], false);
843   HadamardRotation(&s[44], &s[47], true);
844   HadamardRotation(&s[45], &s[46], true);
845   HadamardRotation(&s[48], &s[51], false);
846   HadamardRotation(&s[49], &s[50], false);
847   HadamardRotation(&s[52], &s[55], true);
848   HadamardRotation(&s[53], &s[54], true);
849   HadamardRotation(&s[56], &s[59], false);
850   HadamardRotation(&s[57], &s[58], false);
851   HadamardRotation(&s[60], &s[63], true);
852   HadamardRotation(&s[61], &s[62], true);
853 
854   // stage 16.
855   ButterflyRotation_8(&s[61], &s[34], 56, true);
856   ButterflyRotation_8(&s[60], &s[35], 56, true);
857   ButterflyRotation_8(&s[59], &s[36], 56 + 64, true);
858   ButterflyRotation_8(&s[58], &s[37], 56 + 64, true);
859   ButterflyRotation_8(&s[53], &s[42], 56 - 32, true);
860   ButterflyRotation_8(&s[52], &s[43], 56 - 32, true);
861   ButterflyRotation_8(&s[51], &s[44], 56 - 32 + 64, true);
862   ButterflyRotation_8(&s[50], &s[45], 56 - 32 + 64, true);
863 
864   // stage 21.
865   HadamardRotation(&s[32], &s[39], false);
866   HadamardRotation(&s[33], &s[38], false);
867   HadamardRotation(&s[34], &s[37], false);
868   HadamardRotation(&s[35], &s[36], false);
869   HadamardRotation(&s[40], &s[47], true);
870   HadamardRotation(&s[41], &s[46], true);
871   HadamardRotation(&s[42], &s[45], true);
872   HadamardRotation(&s[43], &s[44], true);
873   HadamardRotation(&s[48], &s[55], false);
874   HadamardRotation(&s[49], &s[54], false);
875   HadamardRotation(&s[50], &s[53], false);
876   HadamardRotation(&s[51], &s[52], false);
877   HadamardRotation(&s[56], &s[63], true);
878   HadamardRotation(&s[57], &s[62], true);
879   HadamardRotation(&s[58], &s[61], true);
880   HadamardRotation(&s[59], &s[60], true);
881 
882   // stage 25.
883   ButterflyRotation_8(&s[59], &s[36], 48, true);
884   ButterflyRotation_8(&s[58], &s[37], 48, true);
885   ButterflyRotation_8(&s[57], &s[38], 48, true);
886   ButterflyRotation_8(&s[56], &s[39], 48, true);
887   ButterflyRotation_8(&s[55], &s[40], 112, true);
888   ButterflyRotation_8(&s[54], &s[41], 112, true);
889   ButterflyRotation_8(&s[53], &s[42], 112, true);
890   ButterflyRotation_8(&s[52], &s[43], 112, true);
891 
892   // stage 28.
893   HadamardRotation(&s[32], &s[47], false);
894   HadamardRotation(&s[33], &s[46], false);
895   HadamardRotation(&s[34], &s[45], false);
896   HadamardRotation(&s[35], &s[44], false);
897   HadamardRotation(&s[36], &s[43], false);
898   HadamardRotation(&s[37], &s[42], false);
899   HadamardRotation(&s[38], &s[41], false);
900   HadamardRotation(&s[39], &s[40], false);
901   HadamardRotation(&s[48], &s[63], true);
902   HadamardRotation(&s[49], &s[62], true);
903   HadamardRotation(&s[50], &s[61], true);
904   HadamardRotation(&s[51], &s[60], true);
905   HadamardRotation(&s[52], &s[59], true);
906   HadamardRotation(&s[53], &s[58], true);
907   HadamardRotation(&s[54], &s[57], true);
908   HadamardRotation(&s[55], &s[56], true);
909 
910   // stage 30.
911   ButterflyRotation_8(&s[55], &s[40], 32, true);
912   ButterflyRotation_8(&s[54], &s[41], 32, true);
913   ButterflyRotation_8(&s[53], &s[42], 32, true);
914   ButterflyRotation_8(&s[52], &s[43], 32, true);
915   ButterflyRotation_8(&s[51], &s[44], 32, true);
916   ButterflyRotation_8(&s[50], &s[45], 32, true);
917   ButterflyRotation_8(&s[49], &s[46], 32, true);
918   ButterflyRotation_8(&s[48], &s[47], 32, true);
919 
920   // stage 31.
921   for (int i = 0; i < 32; i += 4) {
922     HadamardRotation(&s[i], &s[63 - i], false);
923     HadamardRotation(&s[i + 1], &s[63 - i - 1], false);
924     HadamardRotation(&s[i + 2], &s[63 - i - 2], false);
925     HadamardRotation(&s[i + 3], &s[63 - i - 3], false);
926   }
927   //-- end dct 64 stages
928 
929   if (transpose) {
930     for (int idx = 0; idx < 64; idx += 8) {
931       __m128i output[8];
932       Transpose8x8_U16(&s[idx], output);
933       StoreDst<16, 8>(dst, step, idx, output);
934     }
935   } else {
936     StoreDst<16, 64>(dst, step, 0, s);
937   }
938 }
939 
940 //------------------------------------------------------------------------------
941 // Asymmetric Discrete Sine Transforms (ADST).
942 
943 template <bool stage_is_rectangular>
Adst4_SSE4_1(void * dest,int32_t step,bool transpose)944 LIBGAV1_ALWAYS_INLINE void Adst4_SSE4_1(void* dest, int32_t step,
945                                         bool transpose) {
946   auto* const dst = static_cast<int16_t*>(dest);
947   __m128i s[8], x[4];
948 
949   if (stage_is_rectangular) {
950     if (transpose) {
951       __m128i input[8];
952       LoadSrc<8, 8>(dst, step, 0, input);
953       Transpose4x8To8x4_U16(input, x);
954     } else {
955       LoadSrc<16, 4>(dst, step, 0, x);
956     }
957   } else {
958     LoadSrc<8, 4>(dst, step, 0, x);
959     if (transpose) {
960       Transpose4x4_U16(x, x);
961     }
962   }
963 
964   const __m128i kAdst4Multiplier_1 = _mm_set1_epi16(kAdst4Multiplier[1]);
965   const __m128i kAdst4Multiplier_2 = _mm_set1_epi16(kAdst4Multiplier[2]);
966   const __m128i kAdst4Multiplier_3 = _mm_set1_epi16(kAdst4Multiplier[3]);
967   const __m128i kAdst4Multiplier_m0_1 =
968       _mm_set1_epi32(static_cast<uint16_t>(kAdst4Multiplier[1]) |
969                      (static_cast<uint32_t>(-kAdst4Multiplier[0]) << 16));
970   const __m128i kAdst4Multiplier_3_0 =
971       _mm_set1_epi32(static_cast<uint16_t>(kAdst4Multiplier[0]) |
972                      (static_cast<uint32_t>(kAdst4Multiplier[3]) << 16));
973 
974   // stage 1.
975   const __m128i x3_x0 = _mm_unpacklo_epi16(x[0], x[3]);
976   const __m128i x2_x0 = _mm_unpacklo_epi16(x[0], x[2]);
977   const __m128i zero_x1 = _mm_cvtepu16_epi32(x[1]);
978   const __m128i zero_x2 = _mm_cvtepu16_epi32(x[2]);
979   const __m128i zero_x3 = _mm_cvtepu16_epi32(x[3]);
980 
981   s[5] = _mm_madd_epi16(zero_x3, kAdst4Multiplier_1);
982   s[6] = _mm_madd_epi16(zero_x3, kAdst4Multiplier_3);
983 
984   // stage 2.
985   // ((src[0] - src[2]) + src[3]) * kAdst4Multiplier[2]
986   const __m128i k2_x3_x0 = _mm_madd_epi16(x3_x0, kAdst4Multiplier_2);
987   const __m128i k2_zero_x2 = _mm_madd_epi16(zero_x2, kAdst4Multiplier_2);
988   const __m128i b7 = _mm_sub_epi32(k2_x3_x0, k2_zero_x2);
989 
990   // stage 3.
991   s[0] = _mm_madd_epi16(x2_x0, kAdst4Multiplier_3_0);
992   s[1] = _mm_madd_epi16(x2_x0, kAdst4Multiplier_m0_1);
993   s[2] = b7;
994   s[3] = _mm_madd_epi16(zero_x1, kAdst4Multiplier_2);
995 
996   // stage 4.
997   s[0] = _mm_add_epi32(s[0], s[5]);
998   s[1] = _mm_sub_epi32(s[1], s[6]);
999 
1000   // stages 5 and 6.
1001   x[0] = _mm_add_epi32(s[0], s[3]);
1002   x[1] = _mm_add_epi32(s[1], s[3]);
1003   x[2] = _mm_add_epi32(s[0], s[1]);
1004   x[3] = _mm_sub_epi32(x[2], s[3]);
1005 
1006   x[0] = RightShiftWithRounding_S32(x[0], 12);
1007   x[1] = RightShiftWithRounding_S32(x[1], 12);
1008   x[2] = RightShiftWithRounding_S32(s[2], 12);
1009   x[3] = RightShiftWithRounding_S32(x[3], 12);
1010 
1011   x[0] = _mm_packs_epi32(x[0], x[1]);
1012   x[2] = _mm_packs_epi32(x[2], x[3]);
1013   x[1] = _mm_srli_si128(x[0], 8);
1014   x[3] = _mm_srli_si128(x[2], 8);
1015 
1016   if (stage_is_rectangular) {
1017     if (transpose) {
1018       __m128i output[8];
1019       Transpose8x4To4x8_U16(x, output);
1020       StoreDst<8, 8>(dst, step, 0, output);
1021     } else {
1022       StoreDst<16, 4>(dst, step, 0, x);
1023     }
1024   } else {
1025     if (transpose) {
1026       Transpose4x4_U16(x, x);
1027     }
1028     StoreDst<8, 4>(dst, step, 0, x);
1029   }
1030 }
1031 
1032 constexpr int16_t kAdst4DcOnlyMultiplier[8] = {1321, 0, 2482, 0,
1033                                                3344, 0, 2482, 1321};
1034 
Adst4DcOnly(void * dest,int adjusted_tx_height,bool should_round,int row_shift)1035 LIBGAV1_ALWAYS_INLINE bool Adst4DcOnly(void* dest, int adjusted_tx_height,
1036                                        bool should_round, int row_shift) {
1037   if (adjusted_tx_height > 1) return false;
1038 
1039   auto* dst = static_cast<int16_t*>(dest);
1040   const __m128i v_src =
1041       _mm_shuffle_epi32(_mm_shufflelo_epi16(_mm_cvtsi32_si128(dst[0]), 0), 0);
1042   const __m128i v_mask =
1043       _mm_set1_epi16(should_round ? static_cast<int16_t>(0xffff) : 0);
1044   const __m128i v_kTransformRowMultiplier =
1045       _mm_set1_epi16(kTransformRowMultiplier << 3);
1046   const __m128i v_src_round =
1047       _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier);
1048   const __m128i s0 = _mm_blendv_epi8(v_src, v_src_round, v_mask);
1049   const __m128i v_kAdst4DcOnlyMultipliers =
1050       LoadUnaligned16(kAdst4DcOnlyMultiplier);
1051   // s0*k0 s0*k1 s0*k2 s0*k1
1052   // +
1053   // s0*0  s0*0  s0*0  s0*k0
1054   const __m128i x3 = _mm_madd_epi16(s0, v_kAdst4DcOnlyMultipliers);
1055   const __m128i dst_0 = RightShiftWithRounding_S32(x3, 12);
1056   const __m128i v_row_shift_add = _mm_set1_epi32(row_shift);
1057   const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add);
1058   const __m128i a = _mm_add_epi32(dst_0, v_row_shift_add);
1059   const __m128i b = _mm_sra_epi32(a, v_row_shift);
1060   const __m128i c = _mm_packs_epi32(b, b);
1061   StoreLo8(dst, c);
1062 
1063   return true;
1064 }
1065 
Adst4DcOnlyColumn(void * dest,int adjusted_tx_height,int width)1066 LIBGAV1_ALWAYS_INLINE bool Adst4DcOnlyColumn(void* dest, int adjusted_tx_height,
1067                                              int width) {
1068   if (adjusted_tx_height > 1) return false;
1069 
1070   auto* dst = static_cast<int16_t*>(dest);
1071   int i = 0;
1072   do {
1073     const __m128i v_src = _mm_cvtepi16_epi32(LoadLo8(&dst[i]));
1074     const __m128i kAdst4Multiplier_0 = _mm_set1_epi32(kAdst4Multiplier[0]);
1075     const __m128i kAdst4Multiplier_1 = _mm_set1_epi32(kAdst4Multiplier[1]);
1076     const __m128i kAdst4Multiplier_2 = _mm_set1_epi32(kAdst4Multiplier[2]);
1077     const __m128i s0 = _mm_mullo_epi32(kAdst4Multiplier_0, v_src);
1078     const __m128i s1 = _mm_mullo_epi32(kAdst4Multiplier_1, v_src);
1079     const __m128i s2 = _mm_mullo_epi32(kAdst4Multiplier_2, v_src);
1080     const __m128i x0 = s0;
1081     const __m128i x1 = s1;
1082     const __m128i x2 = s2;
1083     const __m128i x3 = _mm_add_epi32(s0, s1);
1084     const __m128i dst_0 = RightShiftWithRounding_S32(x0, 12);
1085     const __m128i dst_1 = RightShiftWithRounding_S32(x1, 12);
1086     const __m128i dst_2 = RightShiftWithRounding_S32(x2, 12);
1087     const __m128i dst_3 = RightShiftWithRounding_S32(x3, 12);
1088     const __m128i dst_0_1 = _mm_packs_epi32(dst_0, dst_1);
1089     const __m128i dst_2_3 = _mm_packs_epi32(dst_2, dst_3);
1090     StoreLo8(&dst[i], dst_0_1);
1091     StoreHi8(&dst[i + width * 1], dst_0_1);
1092     StoreLo8(&dst[i + width * 2], dst_2_3);
1093     StoreHi8(&dst[i + width * 3], dst_2_3);
1094     i += 4;
1095   } while (i < width);
1096 
1097   return true;
1098 }
1099 
1100 template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
Adst8_SSE4_1(void * dest,int32_t step,bool transpose)1101 LIBGAV1_ALWAYS_INLINE void Adst8_SSE4_1(void* dest, int32_t step,
1102                                         bool transpose) {
1103   auto* const dst = static_cast<int16_t*>(dest);
1104   __m128i s[8], x[8];
1105 
1106   if (stage_is_rectangular) {
1107     if (transpose) {
1108       __m128i input[4];
1109       LoadSrc<16, 4>(dst, step, 0, input);
1110       Transpose8x4To4x8_U16(input, x);
1111     } else {
1112       LoadSrc<8, 8>(dst, step, 0, x);
1113     }
1114   } else {
1115     if (transpose) {
1116       __m128i input[8];
1117       LoadSrc<16, 8>(dst, step, 0, input);
1118       Transpose8x8_U16(input, x);
1119     } else {
1120       LoadSrc<16, 8>(dst, step, 0, x);
1121     }
1122   }
1123 
1124   // stage 1.
1125   s[0] = x[7];
1126   s[1] = x[0];
1127   s[2] = x[5];
1128   s[3] = x[2];
1129   s[4] = x[3];
1130   s[5] = x[4];
1131   s[6] = x[1];
1132   s[7] = x[6];
1133 
1134   // stage 2.
1135   butterfly_rotation(&s[0], &s[1], 60 - 0, true);
1136   butterfly_rotation(&s[2], &s[3], 60 - 16, true);
1137   butterfly_rotation(&s[4], &s[5], 60 - 32, true);
1138   butterfly_rotation(&s[6], &s[7], 60 - 48, true);
1139 
1140   // stage 3.
1141   HadamardRotation(&s[0], &s[4], false);
1142   HadamardRotation(&s[1], &s[5], false);
1143   HadamardRotation(&s[2], &s[6], false);
1144   HadamardRotation(&s[3], &s[7], false);
1145 
1146   // stage 4.
1147   butterfly_rotation(&s[4], &s[5], 48 - 0, true);
1148   butterfly_rotation(&s[7], &s[6], 48 - 32, true);
1149 
1150   // stage 5.
1151   HadamardRotation(&s[0], &s[2], false);
1152   HadamardRotation(&s[4], &s[6], false);
1153   HadamardRotation(&s[1], &s[3], false);
1154   HadamardRotation(&s[5], &s[7], false);
1155 
1156   // stage 6.
1157   butterfly_rotation(&s[2], &s[3], 32, true);
1158   butterfly_rotation(&s[6], &s[7], 32, true);
1159 
1160   // stage 7.
1161   const __m128i v_zero = _mm_setzero_si128();
1162   x[0] = s[0];
1163   x[1] = _mm_subs_epi16(v_zero, s[4]);
1164   x[2] = s[6];
1165   x[3] = _mm_subs_epi16(v_zero, s[2]);
1166   x[4] = s[3];
1167   x[5] = _mm_subs_epi16(v_zero, s[7]);
1168   x[6] = s[5];
1169   x[7] = _mm_subs_epi16(v_zero, s[1]);
1170 
1171   if (stage_is_rectangular) {
1172     if (transpose) {
1173       __m128i output[4];
1174       Transpose4x8To8x4_U16(x, output);
1175       StoreDst<16, 4>(dst, step, 0, output);
1176     } else {
1177       StoreDst<8, 8>(dst, step, 0, x);
1178     }
1179   } else {
1180     if (transpose) {
1181       __m128i output[8];
1182       Transpose8x8_U16(x, output);
1183       StoreDst<16, 8>(dst, step, 0, output);
1184     } else {
1185       StoreDst<16, 8>(dst, step, 0, x);
1186     }
1187   }
1188 }
1189 
Adst8DcOnly(void * dest,int adjusted_tx_height,bool should_round,int row_shift)1190 LIBGAV1_ALWAYS_INLINE bool Adst8DcOnly(void* dest, int adjusted_tx_height,
1191                                        bool should_round, int row_shift) {
1192   if (adjusted_tx_height > 1) return false;
1193 
1194   auto* dst = static_cast<int16_t*>(dest);
1195   __m128i s[8];
1196 
1197   const __m128i v_src = _mm_shufflelo_epi16(_mm_cvtsi32_si128(dst[0]), 0);
1198   const __m128i v_mask =
1199       _mm_set1_epi16(should_round ? static_cast<int16_t>(0xffff) : 0);
1200   const __m128i v_kTransformRowMultiplier =
1201       _mm_set1_epi16(kTransformRowMultiplier << 3);
1202   const __m128i v_src_round =
1203       _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier);
1204   // stage 1.
1205   s[1] = _mm_blendv_epi8(v_src, v_src_round, v_mask);
1206 
1207   // stage 2.
1208   ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
1209 
1210   // stage 3.
1211   s[4] = s[0];
1212   s[5] = s[1];
1213 
1214   // stage 4.
1215   ButterflyRotation_4(&s[4], &s[5], 48, true);
1216 
1217   // stage 5.
1218   s[2] = s[0];
1219   s[3] = s[1];
1220   s[6] = s[4];
1221   s[7] = s[5];
1222 
1223   // stage 6.
1224   ButterflyRotation_4(&s[2], &s[3], 32, true);
1225   ButterflyRotation_4(&s[6], &s[7], 32, true);
1226 
1227   // stage 7.
1228   __m128i x[8];
1229   const __m128i v_zero = _mm_setzero_si128();
1230   x[0] = s[0];
1231   x[1] = _mm_subs_epi16(v_zero, s[4]);
1232   x[2] = s[6];
1233   x[3] = _mm_subs_epi16(v_zero, s[2]);
1234   x[4] = s[3];
1235   x[5] = _mm_subs_epi16(v_zero, s[7]);
1236   x[6] = s[5];
1237   x[7] = _mm_subs_epi16(v_zero, s[1]);
1238 
1239   const __m128i x1_x0 = _mm_unpacklo_epi16(x[0], x[1]);
1240   const __m128i x3_x2 = _mm_unpacklo_epi16(x[2], x[3]);
1241   const __m128i x5_x4 = _mm_unpacklo_epi16(x[4], x[5]);
1242   const __m128i x7_x6 = _mm_unpacklo_epi16(x[6], x[7]);
1243   const __m128i x3_x0 = _mm_unpacklo_epi32(x1_x0, x3_x2);
1244   const __m128i x7_x4 = _mm_unpacklo_epi32(x5_x4, x7_x6);
1245 
1246   const __m128i v_row_shift_add = _mm_set1_epi32(row_shift);
1247   const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add);
1248   const __m128i a = _mm_add_epi32(_mm_cvtepi16_epi32(x3_x0), v_row_shift_add);
1249   const __m128i a1 = _mm_add_epi32(_mm_cvtepi16_epi32(x7_x4), v_row_shift_add);
1250   const __m128i b = _mm_sra_epi32(a, v_row_shift);
1251   const __m128i b1 = _mm_sra_epi32(a1, v_row_shift);
1252   StoreUnaligned16(dst, _mm_packs_epi32(b, b1));
1253 
1254   return true;
1255 }
1256 
Adst8DcOnlyColumn(void * dest,int adjusted_tx_height,int width)1257 LIBGAV1_ALWAYS_INLINE bool Adst8DcOnlyColumn(void* dest, int adjusted_tx_height,
1258                                              int width) {
1259   if (adjusted_tx_height > 1) return false;
1260 
1261   auto* dst = static_cast<int16_t*>(dest);
1262   __m128i s[8];
1263 
1264   int i = 0;
1265   do {
1266     const __m128i v_src = LoadLo8(dst);
1267     // stage 1.
1268     s[1] = v_src;
1269 
1270     // stage 2.
1271     ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
1272 
1273     // stage 3.
1274     s[4] = s[0];
1275     s[5] = s[1];
1276 
1277     // stage 4.
1278     ButterflyRotation_4(&s[4], &s[5], 48, true);
1279 
1280     // stage 5.
1281     s[2] = s[0];
1282     s[3] = s[1];
1283     s[6] = s[4];
1284     s[7] = s[5];
1285 
1286     // stage 6.
1287     ButterflyRotation_4(&s[2], &s[3], 32, true);
1288     ButterflyRotation_4(&s[6], &s[7], 32, true);
1289 
1290     // stage 7.
1291     __m128i x[8];
1292     const __m128i v_zero = _mm_setzero_si128();
1293     x[0] = s[0];
1294     x[1] = _mm_subs_epi16(v_zero, s[4]);
1295     x[2] = s[6];
1296     x[3] = _mm_subs_epi16(v_zero, s[2]);
1297     x[4] = s[3];
1298     x[5] = _mm_subs_epi16(v_zero, s[7]);
1299     x[6] = s[5];
1300     x[7] = _mm_subs_epi16(v_zero, s[1]);
1301 
1302     for (int j = 0; j < 8; ++j) {
1303       StoreLo8(&dst[j * width], x[j]);
1304     }
1305     i += 4;
1306     dst += 4;
1307   } while (i < width);
1308 
1309   return true;
1310 }
1311 
1312 template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
Adst16_SSE4_1(void * dest,int32_t step,bool transpose)1313 LIBGAV1_ALWAYS_INLINE void Adst16_SSE4_1(void* dest, int32_t step,
1314                                          bool transpose) {
1315   auto* const dst = static_cast<int16_t*>(dest);
1316   __m128i s[16], x[16];
1317 
1318   if (stage_is_rectangular) {
1319     if (transpose) {
1320       __m128i input[4];
1321       LoadSrc<16, 4>(dst, step, 0, input);
1322       Transpose8x4To4x8_U16(input, x);
1323       LoadSrc<16, 4>(dst, step, 8, input);
1324       Transpose8x4To4x8_U16(input, &x[8]);
1325     } else {
1326       LoadSrc<8, 16>(dst, step, 0, x);
1327     }
1328   } else {
1329     if (transpose) {
1330       for (int idx = 0; idx < 16; idx += 8) {
1331         __m128i input[8];
1332         LoadSrc<16, 8>(dst, step, idx, input);
1333         Transpose8x8_U16(input, &x[idx]);
1334       }
1335     } else {
1336       LoadSrc<16, 16>(dst, step, 0, x);
1337     }
1338   }
1339 
1340   // stage 1.
1341   s[0] = x[15];
1342   s[1] = x[0];
1343   s[2] = x[13];
1344   s[3] = x[2];
1345   s[4] = x[11];
1346   s[5] = x[4];
1347   s[6] = x[9];
1348   s[7] = x[6];
1349   s[8] = x[7];
1350   s[9] = x[8];
1351   s[10] = x[5];
1352   s[11] = x[10];
1353   s[12] = x[3];
1354   s[13] = x[12];
1355   s[14] = x[1];
1356   s[15] = x[14];
1357 
1358   // stage 2.
1359   butterfly_rotation(&s[0], &s[1], 62 - 0, true);
1360   butterfly_rotation(&s[2], &s[3], 62 - 8, true);
1361   butterfly_rotation(&s[4], &s[5], 62 - 16, true);
1362   butterfly_rotation(&s[6], &s[7], 62 - 24, true);
1363   butterfly_rotation(&s[8], &s[9], 62 - 32, true);
1364   butterfly_rotation(&s[10], &s[11], 62 - 40, true);
1365   butterfly_rotation(&s[12], &s[13], 62 - 48, true);
1366   butterfly_rotation(&s[14], &s[15], 62 - 56, true);
1367 
1368   // stage 3.
1369   HadamardRotation(&s[0], &s[8], false);
1370   HadamardRotation(&s[1], &s[9], false);
1371   HadamardRotation(&s[2], &s[10], false);
1372   HadamardRotation(&s[3], &s[11], false);
1373   HadamardRotation(&s[4], &s[12], false);
1374   HadamardRotation(&s[5], &s[13], false);
1375   HadamardRotation(&s[6], &s[14], false);
1376   HadamardRotation(&s[7], &s[15], false);
1377 
1378   // stage 4.
1379   butterfly_rotation(&s[8], &s[9], 56 - 0, true);
1380   butterfly_rotation(&s[13], &s[12], 8 + 0, true);
1381   butterfly_rotation(&s[10], &s[11], 56 - 32, true);
1382   butterfly_rotation(&s[15], &s[14], 8 + 32, true);
1383 
1384   // stage 5.
1385   HadamardRotation(&s[0], &s[4], false);
1386   HadamardRotation(&s[8], &s[12], false);
1387   HadamardRotation(&s[1], &s[5], false);
1388   HadamardRotation(&s[9], &s[13], false);
1389   HadamardRotation(&s[2], &s[6], false);
1390   HadamardRotation(&s[10], &s[14], false);
1391   HadamardRotation(&s[3], &s[7], false);
1392   HadamardRotation(&s[11], &s[15], false);
1393 
1394   // stage 6.
1395   butterfly_rotation(&s[4], &s[5], 48 - 0, true);
1396   butterfly_rotation(&s[12], &s[13], 48 - 0, true);
1397   butterfly_rotation(&s[7], &s[6], 48 - 32, true);
1398   butterfly_rotation(&s[15], &s[14], 48 - 32, true);
1399 
1400   // stage 7.
1401   HadamardRotation(&s[0], &s[2], false);
1402   HadamardRotation(&s[4], &s[6], false);
1403   HadamardRotation(&s[8], &s[10], false);
1404   HadamardRotation(&s[12], &s[14], false);
1405   HadamardRotation(&s[1], &s[3], false);
1406   HadamardRotation(&s[5], &s[7], false);
1407   HadamardRotation(&s[9], &s[11], false);
1408   HadamardRotation(&s[13], &s[15], false);
1409 
1410   // stage 8.
1411   butterfly_rotation(&s[2], &s[3], 32, true);
1412   butterfly_rotation(&s[6], &s[7], 32, true);
1413   butterfly_rotation(&s[10], &s[11], 32, true);
1414   butterfly_rotation(&s[14], &s[15], 32, true);
1415 
1416   // stage 9.
1417   const __m128i v_zero = _mm_setzero_si128();
1418   x[0] = s[0];
1419   x[1] = _mm_subs_epi16(v_zero, s[8]);
1420   x[2] = s[12];
1421   x[3] = _mm_subs_epi16(v_zero, s[4]);
1422   x[4] = s[6];
1423   x[5] = _mm_subs_epi16(v_zero, s[14]);
1424   x[6] = s[10];
1425   x[7] = _mm_subs_epi16(v_zero, s[2]);
1426   x[8] = s[3];
1427   x[9] = _mm_subs_epi16(v_zero, s[11]);
1428   x[10] = s[15];
1429   x[11] = _mm_subs_epi16(v_zero, s[7]);
1430   x[12] = s[5];
1431   x[13] = _mm_subs_epi16(v_zero, s[13]);
1432   x[14] = s[9];
1433   x[15] = _mm_subs_epi16(v_zero, s[1]);
1434 
1435   if (stage_is_rectangular) {
1436     if (transpose) {
1437       __m128i output[4];
1438       Transpose4x8To8x4_U16(x, output);
1439       StoreDst<16, 4>(dst, step, 0, output);
1440       Transpose4x8To8x4_U16(&x[8], output);
1441       StoreDst<16, 4>(dst, step, 8, output);
1442     } else {
1443       StoreDst<8, 16>(dst, step, 0, x);
1444     }
1445   } else {
1446     if (transpose) {
1447       for (int idx = 0; idx < 16; idx += 8) {
1448         __m128i output[8];
1449         Transpose8x8_U16(&x[idx], output);
1450         StoreDst<16, 8>(dst, step, idx, output);
1451       }
1452     } else {
1453       StoreDst<16, 16>(dst, step, 0, x);
1454     }
1455   }
1456 }
1457 
Adst16DcOnlyInternal(__m128i * s,__m128i * x)1458 LIBGAV1_ALWAYS_INLINE void Adst16DcOnlyInternal(__m128i* s, __m128i* x) {
1459   // stage 2.
1460   ButterflyRotation_FirstIsZero(&s[0], &s[1], 62, true);
1461 
1462   // stage 3.
1463   s[8] = s[0];
1464   s[9] = s[1];
1465 
1466   // stage 4.
1467   ButterflyRotation_4(&s[8], &s[9], 56, true);
1468 
1469   // stage 5.
1470   s[4] = s[0];
1471   s[12] = s[8];
1472   s[5] = s[1];
1473   s[13] = s[9];
1474 
1475   // stage 6.
1476   ButterflyRotation_4(&s[4], &s[5], 48, true);
1477   ButterflyRotation_4(&s[12], &s[13], 48, true);
1478 
1479   // stage 7.
1480   s[2] = s[0];
1481   s[6] = s[4];
1482   s[10] = s[8];
1483   s[14] = s[12];
1484   s[3] = s[1];
1485   s[7] = s[5];
1486   s[11] = s[9];
1487   s[15] = s[13];
1488 
1489   // stage 8.
1490   ButterflyRotation_4(&s[2], &s[3], 32, true);
1491   ButterflyRotation_4(&s[6], &s[7], 32, true);
1492   ButterflyRotation_4(&s[10], &s[11], 32, true);
1493   ButterflyRotation_4(&s[14], &s[15], 32, true);
1494 
1495   // stage 9.
1496   const __m128i v_zero = _mm_setzero_si128();
1497   x[0] = s[0];
1498   x[1] = _mm_subs_epi16(v_zero, s[8]);
1499   x[2] = s[12];
1500   x[3] = _mm_subs_epi16(v_zero, s[4]);
1501   x[4] = s[6];
1502   x[5] = _mm_subs_epi16(v_zero, s[14]);
1503   x[6] = s[10];
1504   x[7] = _mm_subs_epi16(v_zero, s[2]);
1505   x[8] = s[3];
1506   x[9] = _mm_subs_epi16(v_zero, s[11]);
1507   x[10] = s[15];
1508   x[11] = _mm_subs_epi16(v_zero, s[7]);
1509   x[12] = s[5];
1510   x[13] = _mm_subs_epi16(v_zero, s[13]);
1511   x[14] = s[9];
1512   x[15] = _mm_subs_epi16(v_zero, s[1]);
1513 }
1514 
Adst16DcOnly(void * dest,int adjusted_tx_height,bool should_round,int row_shift)1515 LIBGAV1_ALWAYS_INLINE bool Adst16DcOnly(void* dest, int adjusted_tx_height,
1516                                         bool should_round, int row_shift) {
1517   if (adjusted_tx_height > 1) return false;
1518 
1519   auto* dst = static_cast<int16_t*>(dest);
1520   __m128i s[16];
1521   __m128i x[16];
1522 
1523   const __m128i v_src = _mm_shufflelo_epi16(_mm_cvtsi32_si128(dst[0]), 0);
1524   const __m128i v_mask =
1525       _mm_set1_epi16(should_round ? static_cast<int16_t>(0xffff) : 0);
1526   const __m128i v_kTransformRowMultiplier =
1527       _mm_set1_epi16(kTransformRowMultiplier << 3);
1528   const __m128i v_src_round =
1529       _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier);
1530   // stage 1.
1531   s[1] = _mm_blendv_epi8(v_src, v_src_round, v_mask);
1532 
1533   Adst16DcOnlyInternal(s, x);
1534 
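  // Only the low lane of each x[j] is needed here. The unpacks below gather
  // x[0]..x[7] (then x[8]..x[15]) into one register, widen to 32 bits, apply
  // the row shift with rounding, and pack 8 results back to int16 per store.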
1535   for (int i = 0; i < 2; ++i) {
1536     const __m128i x1_x0 = _mm_unpacklo_epi16(x[0 + i * 8], x[1 + i * 8]);
1537     const __m128i x3_x2 = _mm_unpacklo_epi16(x[2 + i * 8], x[3 + i * 8]);
1538     const __m128i x5_x4 = _mm_unpacklo_epi16(x[4 + i * 8], x[5 + i * 8]);
1539     const __m128i x7_x6 = _mm_unpacklo_epi16(x[6 + i * 8], x[7 + i * 8]);
1540     const __m128i x3_x0 = _mm_unpacklo_epi32(x1_x0, x3_x2);
1541     const __m128i x7_x4 = _mm_unpacklo_epi32(x5_x4, x7_x6);
1542 
1543     const __m128i v_row_shift_add = _mm_set1_epi32(row_shift);
1544     const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add);
1545     const __m128i a = _mm_add_epi32(_mm_cvtepi16_epi32(x3_x0), v_row_shift_add);
1546     const __m128i a1 =
1547         _mm_add_epi32(_mm_cvtepi16_epi32(x7_x4), v_row_shift_add);
1548     const __m128i b = _mm_sra_epi32(a, v_row_shift);
1549     const __m128i b1 = _mm_sra_epi32(a1, v_row_shift);
1550     StoreUnaligned16(&dst[i * 8], _mm_packs_epi32(b, b1));
1551   }
1552   return true;
1553 }
1554 
1555 LIBGAV1_ALWAYS_INLINE bool Adst16DcOnlyColumn(void* dest,
1556                                               int adjusted_tx_height,
1557                                               int width) {
1558   if (adjusted_tx_height > 1) return false;
1559 
1560   auto* dst = static_cast<int16_t*>(dest);
1561   int i = 0;
1562   do {
1563     __m128i s[16];
1564     __m128i x[16];
1565     const __m128i v_src = LoadUnaligned16(dst);
1566     // stage 1.
1567     s[1] = v_src;
1568 
1569     Adst16DcOnlyInternal(s, x);
1570 
1571     for (int j = 0; j < 16; ++j) {
1572       StoreLo8(&dst[j * width], x[j]);
1573     }
1574     i += 4;
1575     dst += 4;
1576   } while (i < width);
1577 
1578   return true;
1579 }
1580 
1581 //------------------------------------------------------------------------------
1582 // Identity Transforms.
1583 
1584 template <bool is_row_shift>
1585 LIBGAV1_ALWAYS_INLINE void Identity4_SSE4_1(void* dest, int32_t step) {
1586   auto* const dst = static_cast<int16_t*>(dest);
1587 
1588   if (is_row_shift) {
1589     const int shift = 1;
1590     const __m128i v_dual_round = _mm_set1_epi16((1 + (shift << 1)) << 11);
1591     const __m128i v_multiplier_one =
1592         _mm_set1_epi32((kIdentity4Multiplier << 16) | 0x0001);
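    // A sketch of the trick: unpacking |v_dual_round| with the source forms
    // 16-bit pairs {round, src}, and _mm_madd_epi16 against {1,
    // kIdentity4Multiplier} pairs yields round + src * kIdentity4Multiplier in
    // each 32-bit lane. |v_dual_round| folds the multiplier rounding (1 << 11)
    // and the row-shift rounding (shift << 12) into one constant ahead of the
    // combined shift by 12 + shift.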
1593     for (int i = 0; i < 4; i += 2) {
1594       const __m128i v_src = LoadUnaligned16(&dst[i * step]);
1595       const __m128i v_src_round = _mm_unpacklo_epi16(v_dual_round, v_src);
1596       const __m128i v_src_round_hi = _mm_unpackhi_epi16(v_dual_round, v_src);
1597       const __m128i a = _mm_madd_epi16(v_src_round, v_multiplier_one);
1598       const __m128i a_hi = _mm_madd_epi16(v_src_round_hi, v_multiplier_one);
1599       const __m128i b = _mm_srai_epi32(a, 12 + shift);
1600       const __m128i b_hi = _mm_srai_epi32(a_hi, 12 + shift);
1601       StoreUnaligned16(&dst[i * step], _mm_packs_epi32(b, b_hi));
1602     }
1603   } else {
1604     const __m128i v_multiplier =
1605         _mm_set1_epi16(kIdentity4MultiplierFraction << 3);
1606     for (int i = 0; i < 4; i += 2) {
1607       const __m128i v_src = LoadUnaligned16(&dst[i * step]);
1608       const __m128i a = _mm_mulhrs_epi16(v_src, v_multiplier);
1609       const __m128i b = _mm_adds_epi16(a, v_src);
1610       StoreUnaligned16(&dst[i * step], b);
1611     }
1612   }
1613 }
1614 
1615 LIBGAV1_ALWAYS_INLINE bool Identity4DcOnly(void* dest, int adjusted_tx_height,
1616                                            bool should_round, int tx_height) {
1617   if (adjusted_tx_height > 1) return false;
1618 
1619   auto* dst = static_cast<int16_t*>(dest);
1620   const __m128i v_src0 = _mm_cvtsi32_si128(dst[0]);
1621   const __m128i v_mask =
1622       _mm_set1_epi16(should_round ? static_cast<int16_t>(0xffff) : 0);
1623   const __m128i v_kTransformRowMultiplier =
1624       _mm_set1_epi16(kTransformRowMultiplier << 3);
1625   const __m128i v_src_round =
1626       _mm_mulhrs_epi16(v_src0, v_kTransformRowMultiplier);
1627   const __m128i v_src = _mm_blendv_epi8(v_src0, v_src_round, v_mask);
1628 
1629   const int shift = (tx_height < 16) ? 0 : 1;
1630   const __m128i v_dual_round = _mm_set1_epi16((1 + (shift << 1)) << 11);
1631   const __m128i v_multiplier_one =
1632       _mm_set1_epi32((kIdentity4Multiplier << 16) | 0x0001);
1633   const __m128i v_src_round_lo = _mm_unpacklo_epi16(v_dual_round, v_src);
1634   const __m128i a = _mm_madd_epi16(v_src_round_lo, v_multiplier_one);
1635   const __m128i b = _mm_srai_epi32(a, 12 + shift);
1636   dst[0] = _mm_extract_epi16(_mm_packs_epi32(b, b), 0);
1637   return true;
1638 }
1639 
1640 LIBGAV1_ALWAYS_INLINE void Identity4ColumnStoreToFrame(
1641     Array2DView<uint8_t> frame, const int start_x, const int start_y,
1642     const int tx_width, const int tx_height,
1643     const int16_t* LIBGAV1_RESTRICT source) {
1644   const int stride = frame.columns();
1645   uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
1646 
1647   const __m128i v_multiplier_fraction =
1648       _mm_set1_epi16(static_cast<int16_t>(kIdentity4MultiplierFraction << 3));
1649   const __m128i v_eight = _mm_set1_epi16(8);
1650 
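  // Per pixel this computes, in effect,
  //   dst = Clip8(frame + Round2(src + Round2(src * kIdentity4MultiplierFraction, 12), 4))
  // i.e. the identity4 column scaling (src plus its fractional part) followed
  // by the column rounding shift and an add to the frame, with the final
  // _mm_packus_epi16 providing the unsigned 8-bit clamp.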
1651   if (tx_width == 4) {
1652     int i = 0;
1653     do {
1654       const __m128i v_src = LoadLo8(&source[i * tx_width]);
1655       const __m128i v_src_mult = _mm_mulhrs_epi16(v_src, v_multiplier_fraction);
1656       const __m128i frame_data = Load4(dst);
1657       const __m128i v_dst_i = _mm_adds_epi16(v_src_mult, v_src);
1658       const __m128i a = _mm_adds_epi16(v_dst_i, v_eight);
1659       const __m128i b = _mm_srai_epi16(a, 4);
1660       const __m128i c = _mm_cvtepu8_epi16(frame_data);
1661       const __m128i d = _mm_adds_epi16(c, b);
1662       Store4(dst, _mm_packus_epi16(d, d));
1663       dst += stride;
1664     } while (++i < tx_height);
1665   } else {
1666     int i = 0;
1667     do {
1668       const int row = i * tx_width;
1669       int j = 0;
1670       do {
1671         const __m128i v_src = LoadUnaligned16(&source[row + j]);
1672         const __m128i v_src_mult =
1673             _mm_mulhrs_epi16(v_src, v_multiplier_fraction);
1674         const __m128i frame_data = LoadLo8(dst + j);
1675         const __m128i v_dst_i = _mm_adds_epi16(v_src_mult, v_src);
1676         const __m128i a = _mm_adds_epi16(v_dst_i, v_eight);
1677         const __m128i b = _mm_srai_epi16(a, 4);
1678         const __m128i c = _mm_cvtepu8_epi16(frame_data);
1679         const __m128i d = _mm_adds_epi16(c, b);
1680         StoreLo8(dst + j, _mm_packus_epi16(d, d));
1681         j += 8;
1682       } while (j < tx_width);
1683       dst += stride;
1684     } while (++i < tx_height);
1685   }
1686 }
1687 
1688 LIBGAV1_ALWAYS_INLINE void Identity4RowColumnStoreToFrame(
1689     Array2DView<uint8_t> frame, const int start_x, const int start_y,
1690     const int tx_width, const int tx_height,
1691     const int16_t* LIBGAV1_RESTRICT source) {
1692   const int stride = frame.columns();
1693   uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
1694 
1695   const __m128i v_multiplier_fraction =
1696       _mm_set1_epi16(static_cast<int16_t>(kIdentity4MultiplierFraction << 3));
1697   const __m128i v_eight = _mm_set1_epi16(8);
1698   const __m128i v_kTransformRowMultiplier =
1699       _mm_set1_epi16(kTransformRowMultiplier << 3);
1700 
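  // This path is only reached for the IdentityIdentity 4x4 and 8x4 cases where
  // the row call returned early: the row pass (plain identity4 for tx_width 4,
  // or the row rounding plus identity8 doubling for tx_width 8) is applied
  // here first, followed by the identity4 column pass and the add to the
  // frame.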
1701   if (tx_width == 4) {
1702     int i = 0;
1703     do {
1704       const __m128i v_src = LoadLo8(&source[i * tx_width]);
1705       const __m128i v_src_mult = _mm_mulhrs_epi16(v_src, v_multiplier_fraction);
1706       const __m128i frame_data = Load4(dst);
1707       const __m128i v_dst_row = _mm_adds_epi16(v_src_mult, v_src);
1708       const __m128i v_src_mult2 =
1709           _mm_mulhrs_epi16(v_dst_row, v_multiplier_fraction);
1710       const __m128i frame_data16 = _mm_cvtepu8_epi16(frame_data);
1711       const __m128i v_dst_col = _mm_adds_epi16(v_src_mult2, v_dst_row);
1712       const __m128i a = _mm_adds_epi16(v_dst_col, v_eight);
1713       const __m128i b = _mm_srai_epi16(a, 4);
1714       const __m128i c = _mm_adds_epi16(frame_data16, b);
1715       Store4(dst, _mm_packus_epi16(c, c));
1716       dst += stride;
1717     } while (++i < tx_height);
1718   } else {
1719     int i = 0;
1720     do {
1721       const int row = i * tx_width;
1722       int j = 0;
1723       do {
1724         const __m128i v_src = LoadUnaligned16(&source[row + j]);
1725         const __m128i v_src_round =
1726             _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier);
1727         const __m128i v_dst_row = _mm_adds_epi16(v_src_round, v_src_round);
1728         const __m128i v_src_mult2 =
1729             _mm_mulhrs_epi16(v_dst_row, v_multiplier_fraction);
1730         const __m128i frame_data = LoadLo8(dst + j);
1731         const __m128i frame_data16 = _mm_cvtepu8_epi16(frame_data);
1732         const __m128i v_dst_col = _mm_adds_epi16(v_src_mult2, v_dst_row);
1733         const __m128i a = _mm_adds_epi16(v_dst_col, v_eight);
1734         const __m128i b = _mm_srai_epi16(a, 4);
1735         const __m128i c = _mm_adds_epi16(frame_data16, b);
1736         StoreLo8(dst + j, _mm_packus_epi16(c, c));
1737         j += 8;
1738       } while (j < tx_width);
1739       dst += stride;
1740     } while (++i < tx_height);
1741   }
1742 }
1743 
1744 LIBGAV1_ALWAYS_INLINE void Identity8Row32_SSE4_1(void* dest, int32_t step) {
1745   auto* const dst = static_cast<int16_t*>(dest);
1746 
1747   // When combining the identity8 multiplier with the row shift, the
1748   // calculations for tx_height equal to 32 can be simplified from
1749   // (((A * 2) + 2) >> 2) to ((A + 1) >> 1).
1750   const __m128i v_row_multiplier = _mm_set1_epi16(1 << 14);
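  // With a multiplier of 1 << 14, _mm_mulhrs_epi16 computes
  // (x * (1 << 14) + (1 << 14)) >> 15 == (x + 1) >> 1, which is exactly the
  // simplified form described above.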
1751   for (int h = 0; h < 4; ++h) {
1752     const __m128i v_src = LoadUnaligned16(&dst[h * step]);
1753     const __m128i v_src_mult = _mm_mulhrs_epi16(v_src, v_row_multiplier);
1754     StoreUnaligned16(&dst[h * step], v_src_mult);
1755   }
1756 }
1757 
1758 LIBGAV1_ALWAYS_INLINE void Identity8Row4_SSE4_1(void* dest, int32_t step) {
1759   auto* const dst = static_cast<int16_t*>(dest);
1760 
1761   for (int h = 0; h < 4; ++h) {
1762     const __m128i v_src = LoadUnaligned16(&dst[h * step]);
1763     // For bitdepth == 8, the identity row clamps to a signed 16bit value, so
1764     // saturating add here is ok.
1765     const __m128i a = _mm_adds_epi16(v_src, v_src);
1766     StoreUnaligned16(&dst[h * step], a);
1767   }
1768 }
1769 
1770 LIBGAV1_ALWAYS_INLINE bool Identity8DcOnly(void* dest, int adjusted_tx_height,
1771                                            bool should_round, int row_shift) {
1772   if (adjusted_tx_height > 1) return false;
1773 
1774   auto* dst = static_cast<int16_t*>(dest);
1775   const __m128i v_src0 = _mm_cvtsi32_si128(dst[0]);
1776   const __m128i v_mask =
1777       _mm_set1_epi16(should_round ? static_cast<int16_t>(0xffff) : 0);
1778   const __m128i v_kTransformRowMultiplier =
1779       _mm_set1_epi16(kTransformRowMultiplier << 3);
1780   const __m128i v_src_round =
1781       _mm_mulhrs_epi16(v_src0, v_kTransformRowMultiplier);
1782   const __m128i v_src =
1783       _mm_cvtepi16_epi32(_mm_blendv_epi8(v_src0, v_src_round, v_mask));
1784   const __m128i v_srcx2 = _mm_add_epi32(v_src, v_src);
1785   const __m128i v_row_shift_add = _mm_set1_epi32(row_shift);
1786   const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add);
1787   const __m128i a = _mm_add_epi32(v_srcx2, v_row_shift_add);
1788   const __m128i b = _mm_sra_epi32(a, v_row_shift);
1789   dst[0] = _mm_extract_epi16(_mm_packs_epi32(b, b), 0);
1790   return true;
1791 }
1792 
1793 LIBGAV1_ALWAYS_INLINE void Identity8ColumnStoreToFrame_SSE4_1(
1794     Array2DView<uint8_t> frame, const int start_x, const int start_y,
1795     const int tx_width, const int tx_height,
1796     const int16_t* LIBGAV1_RESTRICT source) {
1797   const int stride = frame.columns();
1798   uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
1799   const __m128i v_eight = _mm_set1_epi16(8);
1800   if (tx_width == 4) {
1801     int i = 0;
1802     do {
1803       const int row = i * tx_width;
1804       const __m128i v_src = LoadLo8(&source[row]);
1805       const __m128i v_dst_i = _mm_adds_epi16(v_src, v_src);
1806       const __m128i frame_data = Load4(dst);
1807       const __m128i a = _mm_adds_epi16(v_dst_i, v_eight);
1808       const __m128i b = _mm_srai_epi16(a, 4);
1809       const __m128i c = _mm_cvtepu8_epi16(frame_data);
1810       const __m128i d = _mm_adds_epi16(c, b);
1811       Store4(dst, _mm_packus_epi16(d, d));
1812       dst += stride;
1813     } while (++i < tx_height);
1814   } else {
1815     int i = 0;
1816     do {
1817       const int row = i * tx_width;
1818       int j = 0;
1819       do {
1820         const __m128i v_src = LoadUnaligned16(&source[row + j]);
1821         const __m128i v_dst_i = _mm_adds_epi16(v_src, v_src);
1822         const __m128i frame_data = LoadLo8(dst + j);
1823         const __m128i a = _mm_adds_epi16(v_dst_i, v_eight);
1824         const __m128i b = _mm_srai_epi16(a, 4);
1825         const __m128i c = _mm_cvtepu8_epi16(frame_data);
1826         const __m128i d = _mm_adds_epi16(c, b);
1827         StoreLo8(dst + j, _mm_packus_epi16(d, d));
1828         j += 8;
1829       } while (j < tx_width);
1830       dst += stride;
1831     } while (++i < tx_height);
1832   }
1833 }
1834 
1835 LIBGAV1_ALWAYS_INLINE void Identity16Row_SSE4_1(void* dest, int32_t step,
1836                                                 int shift) {
1837   auto* const dst = static_cast<int16_t*>(dest);
1838 
1839   const __m128i v_dual_round = _mm_set1_epi16((1 + (shift << 1)) << 11);
1840   const __m128i v_multiplier_one =
1841       _mm_set1_epi32((kIdentity16Multiplier << 16) | 0x0001);
1842   const __m128i v_shift = _mm_set_epi64x(0, 12 + shift);
1843 
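  // The loop below uses the same dual-round + _mm_madd_epi16 pattern as the
  // Identity4 row path above, here with kIdentity16Multiplier and 16
  // coefficients (two registers) per row.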
1844   for (int h = 0; h < 4; ++h) {
1845     const __m128i v_src = LoadUnaligned16(&dst[h * step]);
1846     const __m128i v_src2 = LoadUnaligned16(&dst[h * step + 8]);
1847     const __m128i v_src_round0 = _mm_unpacklo_epi16(v_dual_round, v_src);
1848     const __m128i v_src_round1 = _mm_unpackhi_epi16(v_dual_round, v_src);
1849     const __m128i v_src2_round0 = _mm_unpacklo_epi16(v_dual_round, v_src2);
1850     const __m128i v_src2_round1 = _mm_unpackhi_epi16(v_dual_round, v_src2);
1851     const __m128i madd0 = _mm_madd_epi16(v_src_round0, v_multiplier_one);
1852     const __m128i madd1 = _mm_madd_epi16(v_src_round1, v_multiplier_one);
1853     const __m128i madd20 = _mm_madd_epi16(v_src2_round0, v_multiplier_one);
1854     const __m128i madd21 = _mm_madd_epi16(v_src2_round1, v_multiplier_one);
1855     const __m128i shift0 = _mm_sra_epi32(madd0, v_shift);
1856     const __m128i shift1 = _mm_sra_epi32(madd1, v_shift);
1857     const __m128i shift20 = _mm_sra_epi32(madd20, v_shift);
1858     const __m128i shift21 = _mm_sra_epi32(madd21, v_shift);
1859     StoreUnaligned16(&dst[h * step], _mm_packs_epi32(shift0, shift1));
1860     StoreUnaligned16(&dst[h * step + 8], _mm_packs_epi32(shift20, shift21));
1861   }
1862 }
1863 
1864 LIBGAV1_ALWAYS_INLINE bool Identity16DcOnly(void* dest, int adjusted_tx_height,
1865                                             bool should_round, int shift) {
1866   if (adjusted_tx_height > 1) return false;
1867 
1868   auto* dst = static_cast<int16_t*>(dest);
1869   const __m128i v_src0 = _mm_cvtsi32_si128(dst[0]);
1870   const __m128i v_mask =
1871       _mm_set1_epi16(should_round ? static_cast<int16_t>(0xffff) : 0);
1872   const __m128i v_kTransformRowMultiplier =
1873       _mm_set1_epi16(kTransformRowMultiplier << 3);
1874   const __m128i v_src_round0 =
1875       _mm_mulhrs_epi16(v_src0, v_kTransformRowMultiplier);
1876   const __m128i v_src = _mm_blendv_epi8(v_src0, v_src_round0, v_mask);
1877   const __m128i v_dual_round = _mm_set1_epi16((1 + (shift << 1)) << 11);
1878   const __m128i v_multiplier_one =
1879       _mm_set1_epi32((kIdentity16Multiplier << 16) | 0x0001);
1880   const __m128i v_shift = _mm_set_epi64x(0, 12 + shift);
1881   const __m128i v_src_round = _mm_unpacklo_epi16(v_dual_round, v_src);
1882   const __m128i a = _mm_madd_epi16(v_src_round, v_multiplier_one);
1883   const __m128i b = _mm_sra_epi32(a, v_shift);
1884   dst[0] = _mm_extract_epi16(_mm_packs_epi32(b, b), 0);
1885   return true;
1886 }
1887 
1888 LIBGAV1_ALWAYS_INLINE void Identity16ColumnStoreToFrame_SSE4_1(
1889     Array2DView<uint8_t> frame, const int start_x, const int start_y,
1890     const int tx_width, const int tx_height,
1891     const int16_t* LIBGAV1_RESTRICT source) {
1892   const int stride = frame.columns();
1893   uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
1894   const __m128i v_eight = _mm_set1_epi16(8);
1895   const __m128i v_multiplier =
1896       _mm_set1_epi16(static_cast<int16_t>(kIdentity4MultiplierFraction << 4));
1897 
1898   if (tx_width == 4) {
1899     int i = 0;
1900     do {
1901       const __m128i v_src = LoadLo8(&source[i * tx_width]);
1902       const __m128i v_src_mult = _mm_mulhrs_epi16(v_src, v_multiplier);
1903       const __m128i frame_data = Load4(dst);
1904       const __m128i v_srcx2 = _mm_adds_epi16(v_src, v_src);
1905       const __m128i v_dst_i = _mm_adds_epi16(v_src_mult, v_srcx2);
1906       const __m128i a = _mm_adds_epi16(v_dst_i, v_eight);
1907       const __m128i b = _mm_srai_epi16(a, 4);
1908       const __m128i c = _mm_cvtepu8_epi16(frame_data);
1909       const __m128i d = _mm_adds_epi16(c, b);
1910       Store4(dst, _mm_packus_epi16(d, d));
1911       dst += stride;
1912     } while (++i < tx_height);
1913   } else {
1914     int i = 0;
1915     do {
1916       const int row = i * tx_width;
1917       int j = 0;
1918       do {
1919         const __m128i v_src = LoadUnaligned16(&source[row + j]);
1920         const __m128i v_src_mult = _mm_mulhrs_epi16(v_src, v_multiplier);
1921         const __m128i frame_data = LoadLo8(dst + j);
1922         const __m128i v_srcx2 = _mm_adds_epi16(v_src, v_src);
1923         const __m128i v_dst_i = _mm_adds_epi16(v_src_mult, v_srcx2);
1924         const __m128i a = _mm_adds_epi16(v_dst_i, v_eight);
1925         const __m128i b = _mm_srai_epi16(a, 4);
1926         const __m128i c = _mm_cvtepu8_epi16(frame_data);
1927         const __m128i d = _mm_adds_epi16(c, b);
1928         StoreLo8(dst + j, _mm_packus_epi16(d, d));
1929         j += 8;
1930       } while (j < tx_width);
1931       dst += stride;
1932     } while (++i < tx_height);
1933   }
1934 }
1935 
1936 LIBGAV1_ALWAYS_INLINE void Identity32Row16_SSE4_1(void* dest,
1937                                                   const int32_t step) {
1938   auto* const dst = static_cast<int16_t*>(dest);
1939 
1940   // When combining the identity32 multiplier with the row shift, the
1941   // calculation for tx_height equal to 16 can be simplified from
1942   // (((A * 4) + 1) >> 1) to (A * 2).
1943   for (int h = 0; h < 4; ++h) {
1944     for (int i = 0; i < 32; i += 8) {
1945       const __m128i v_src = LoadUnaligned16(&dst[h * step + i]);
1946       // For bitdepth == 8, the identity row clamps to a signed 16bit value, so
1947       // saturating add here is ok.
1948       const __m128i v_dst_i = _mm_adds_epi16(v_src, v_src);
1949       StoreUnaligned16(&dst[h * step + i], v_dst_i);
1950     }
1951   }
1952 }
1953 
1954 LIBGAV1_ALWAYS_INLINE bool Identity32DcOnly(void* dest,
1955                                             int adjusted_tx_height) {
1956   if (adjusted_tx_height > 1) return false;
1957 
1958   auto* dst = static_cast<int16_t*>(dest);
1959   const __m128i v_src0 = _mm_cvtsi32_si128(dst[0]);
1960   const __m128i v_kTransformRowMultiplier =
1961       _mm_set1_epi16(kTransformRowMultiplier << 3);
1962   const __m128i v_src = _mm_mulhrs_epi16(v_src0, v_kTransformRowMultiplier);
1963 
1964   // When combining the identity32 multiplier with the row shift, the
1965   // calculation for tx_height equal to 16 can be simplified from
1966   // (((A * 4) + 1) >> 1) to (A * 2).
1967   const __m128i v_dst_0 = _mm_adds_epi16(v_src, v_src);
1968   dst[0] = _mm_extract_epi16(v_dst_0, 0);
1969   return true;
1970 }
1971 
1972 LIBGAV1_ALWAYS_INLINE void Identity32ColumnStoreToFrame(
1973     Array2DView<uint8_t> frame, const int start_x, const int start_y,
1974     const int tx_width, const int tx_height,
1975     const int16_t* LIBGAV1_RESTRICT source) {
1976   const int stride = frame.columns();
1977   uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
1978   const __m128i v_two = _mm_set1_epi16(2);
1979 
1980   int i = 0;
1981   do {
1982     const int row = i * tx_width;
1983     int j = 0;
1984     do {
1985       const __m128i v_dst_i = LoadUnaligned16(&source[row + j]);
1986       const __m128i frame_data = LoadLo8(dst + j);
1987       const __m128i a = _mm_adds_epi16(v_dst_i, v_two);
1988       const __m128i b = _mm_srai_epi16(a, 2);
1989       const __m128i c = _mm_cvtepu8_epi16(frame_data);
1990       const __m128i d = _mm_adds_epi16(c, b);
1991       StoreLo8(dst + j, _mm_packus_epi16(d, d));
1992       j += 8;
1993     } while (j < tx_width);
1994     dst += stride;
1995   } while (++i < tx_height);
1996 }
1997 
1998 //------------------------------------------------------------------------------
1999 // Walsh Hadamard Transform.
2000 
2001 // Process 4 wht4 rows and columns.
2002 LIBGAV1_ALWAYS_INLINE void Wht4_SSE4_1(Array2DView<uint8_t> frame,
2003                                        const int start_x, const int start_y,
2004                                        const void* LIBGAV1_RESTRICT source,
2005                                        const int adjusted_tx_height) {
2006   const auto* const src = static_cast<const int16_t*>(source);
2007   __m128i s[4], x[4];
2008 
2009   if (adjusted_tx_height == 1) {
2010     // Special case: only src[0] is nonzero.
2011     //   src[0]  0   0   0
2012     //       0   0   0   0
2013     //       0   0   0   0
2014     //       0   0   0   0
2015     //
2016     // After the row and column transforms are applied, we have:
2017     //       f   h   h   h
2018     //       g   i   i   i
2019     //       g   i   i   i
2020     //       g   i   i   i
2021     // where f, g, h, i are computed as follows.
2022     int16_t f = (src[0] >> 2) - (src[0] >> 3);
2023     const int16_t g = f >> 1;
2024     f = f - (f >> 1);
2025     const int16_t h = (src[0] >> 3) - (src[0] >> 4);
2026     const int16_t i = (src[0] >> 4);
2027     s[0] = _mm_set1_epi16(h);
2028     s[0] = _mm_insert_epi16(s[0], f, 0);
2029     s[1] = _mm_set1_epi16(i);
2030     s[1] = _mm_insert_epi16(s[1], g, 0);
2031     s[2] = s[3] = s[1];
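    // Illustrative example (not from the spec): for src[0] == 100,
    // f = 7, g = 6, h = 6 and i = 6, so row 0 of the block becomes
    // {7, 6, 6, 6} and rows 1..3 become {6, 6, 6, 6} before being added to
    // the frame below.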
2032   } else {
2033     x[0] = LoadLo8(&src[0 * 4]);
2034     x[2] = LoadLo8(&src[1 * 4]);
2035     x[3] = LoadLo8(&src[2 * 4]);
2036     x[1] = LoadLo8(&src[3 * 4]);
2037 
2038     // Row transforms.
2039     Transpose4x4_U16(x, x);
2040     s[0] = _mm_srai_epi16(x[0], 2);
2041     s[2] = _mm_srai_epi16(x[1], 2);
2042     s[3] = _mm_srai_epi16(x[2], 2);
2043     s[1] = _mm_srai_epi16(x[3], 2);
2044     s[0] = _mm_add_epi16(s[0], s[2]);
2045     s[3] = _mm_sub_epi16(s[3], s[1]);
2046     __m128i e = _mm_sub_epi16(s[0], s[3]);
2047     e = _mm_srai_epi16(e, 1);
2048     s[1] = _mm_sub_epi16(e, s[1]);
2049     s[2] = _mm_sub_epi16(e, s[2]);
2050     s[0] = _mm_sub_epi16(s[0], s[1]);
2051     s[3] = _mm_add_epi16(s[3], s[2]);
2052     Transpose4x4_U16(s, s);
2053 
2054     // Column transforms.
2055     s[0] = _mm_add_epi16(s[0], s[2]);
2056     s[3] = _mm_sub_epi16(s[3], s[1]);
2057     e = _mm_sub_epi16(s[0], s[3]);
2058     e = _mm_srai_epi16(e, 1);
2059     s[1] = _mm_sub_epi16(e, s[1]);
2060     s[2] = _mm_sub_epi16(e, s[2]);
2061     s[0] = _mm_sub_epi16(s[0], s[1]);
2062     s[3] = _mm_add_epi16(s[3], s[2]);
2063   }
2064 
2065   // Store to frame.
2066   const int stride = frame.columns();
2067   uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
2068   for (int row = 0; row < 4; ++row) {
2069     const __m128i frame_data = Load4(dst);
2070     const __m128i a = _mm_cvtepu8_epi16(frame_data);
2071     const __m128i b = _mm_add_epi16(a, s[row]);
2072     Store4(dst, _mm_packus_epi16(b, b));
2073     dst += stride;
2074   }
2075 }
2076 
2077 //------------------------------------------------------------------------------
2078 // row/column transform loops
2079 
2080 template <bool enable_flip_rows = false>
2081 LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound(
2082     Array2DView<uint8_t> frame, const int start_x, const int start_y,
2083     const int tx_width, const int tx_height,
2084     const int16_t* LIBGAV1_RESTRICT source, TransformType tx_type) {
2085   const bool flip_rows =
2086       enable_flip_rows ? kTransformFlipRowsMask.Contains(tx_type) : false;
2087   const __m128i v_eight = _mm_set1_epi16(8);
2088   const int stride = frame.columns();
2089   uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
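  // For every row this adds Round2(residual, 4) to the frame pixels and clamps
  // to [0, 255] via _mm_packus_epi16. When |flip_rows| is set (types in
  // kTransformFlipRowsMask with enable_flip_rows), the residual rows are read
  // bottom-up. The three branches below only differ in how many pixels are
  // handled per iteration.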
2090   if (tx_width == 4) {
2091     for (int i = 0; i < tx_height; ++i) {
2092       const int row = flip_rows ? (tx_height - i - 1) * 4 : i * 4;
2093       const __m128i residual = LoadLo8(&source[row]);
2094       const __m128i frame_data = Load4(dst);
2095       // Saturate to prevent overflowing int16_t
2096       const __m128i a = _mm_adds_epi16(residual, v_eight);
2097       const __m128i b = _mm_srai_epi16(a, 4);
2098       const __m128i c = _mm_cvtepu8_epi16(frame_data);
2099       const __m128i d = _mm_adds_epi16(c, b);
2100       Store4(dst, _mm_packus_epi16(d, d));
2101       dst += stride;
2102     }
2103   } else if (tx_width == 8) {
2104     for (int i = 0; i < tx_height; ++i) {
2105       const int row = flip_rows ? (tx_height - i - 1) * 8 : i * 8;
2106       const __m128i residual = LoadUnaligned16(&source[row]);
2107       const __m128i frame_data = LoadLo8(dst);
2108       // Saturate to prevent overflowing int16_t
2109       const __m128i b = _mm_adds_epi16(residual, v_eight);
2110       const __m128i c = _mm_srai_epi16(b, 4);
2111       const __m128i d = _mm_cvtepu8_epi16(frame_data);
2112       const __m128i e = _mm_adds_epi16(d, c);
2113       StoreLo8(dst, _mm_packus_epi16(e, e));
2114       dst += stride;
2115     }
2116   } else {
2117     for (int i = 0; i < tx_height; ++i) {
2118       const int y = start_y + i;
2119       const int row = flip_rows ? (tx_height - i - 1) * tx_width : i * tx_width;
2120       int j = 0;
2121       do {
2122         const int x = start_x + j;
2123         const __m128i residual = LoadUnaligned16(&source[row + j]);
2124         const __m128i residual_hi = LoadUnaligned16(&source[row + j + 8]);
2125         const __m128i frame_data = LoadUnaligned16(frame[y] + x);
2126         const __m128i b = _mm_adds_epi16(residual, v_eight);
2127         const __m128i b_hi = _mm_adds_epi16(residual_hi, v_eight);
2128         const __m128i c = _mm_srai_epi16(b, 4);
2129         const __m128i c_hi = _mm_srai_epi16(b_hi, 4);
2130         const __m128i d = _mm_cvtepu8_epi16(frame_data);
2131         const __m128i d_hi = _mm_cvtepu8_epi16(_mm_srli_si128(frame_data, 8));
2132         const __m128i e = _mm_adds_epi16(d, c);
2133         const __m128i e_hi = _mm_adds_epi16(d_hi, c_hi);
2134         StoreUnaligned16(frame[y] + x, _mm_packus_epi16(e, e_hi));
2135         j += 16;
2136       } while (j < tx_width);
2137     }
2138   }
2139 }
2140 
2141 template <int tx_height>
2142 LIBGAV1_ALWAYS_INLINE void FlipColumns(int16_t* source, int tx_width) {
2143   const __m128i word_reverse_8 =
2144       _mm_set_epi32(0x01000302, 0x05040706, 0x09080b0a, 0x0d0c0f0e);
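  // The mask reverses the eight 16-bit words within a register (output word n
  // takes input word 7 - n). For tx_width >= 16 the two halves of a row are
  // also swapped when stored so the full 16 values are reversed; the 4-wide
  // path below reverses two 4-word rows independently.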
2145   if (tx_width >= 16) {
2146     int i = 0;
2147     do {
2148       // read 16 shorts
2149       const __m128i v3210 = LoadUnaligned16(&source[i]);
2150       const __m128i v7654 = LoadUnaligned16(&source[i + 8]);
2151       const __m128i v0123 = _mm_shuffle_epi8(v3210, word_reverse_8);
2152       const __m128i v4567 = _mm_shuffle_epi8(v7654, word_reverse_8);
2153       StoreUnaligned16(&source[i], v4567);
2154       StoreUnaligned16(&source[i + 8], v0123);
2155       i += 16;
2156     } while (i < tx_width * tx_height);
2157   } else if (tx_width == 8) {
2158     for (int i = 0; i < 8 * tx_height; i += 8) {
2159       const __m128i a = LoadUnaligned16(&source[i]);
2160       const __m128i b = _mm_shuffle_epi8(a, word_reverse_8);
2161       StoreUnaligned16(&source[i], b);
2162     }
2163   } else {
2164     const __m128i dual_word_reverse_4 =
2165         _mm_set_epi32(0x09080b0a, 0x0d0c0f0e, 0x01000302, 0x05040706);
2166     // Process two rows per iteration.
2167     for (int i = 0; i < 4 * tx_height; i += 8) {
2168       const __m128i a = LoadUnaligned16(&source[i]);
2169       const __m128i b = _mm_shuffle_epi8(a, dual_word_reverse_4);
2170       StoreUnaligned16(&source[i], b);
2171     }
2172   }
2173 }
2174 
2175 template <int tx_width>
2176 LIBGAV1_ALWAYS_INLINE void ApplyRounding(int16_t* source, int num_rows) {
2177   const __m128i v_kTransformRowMultiplier =
2178       _mm_set1_epi16(kTransformRowMultiplier << 3);
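  // Scaling the multiplier by 8 makes _mm_mulhrs_epi16 (which computes
  // (a * b + (1 << 14)) >> 15) yield
  // (x * (kTransformRowMultiplier << 3) + (1 << 14)) >> 15, which equals
  // Round2(x * kTransformRowMultiplier, 12).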
2179   if (tx_width == 4) {
2180     // Process two rows per iteration.
2181     int i = 0;
2182     do {
2183       const __m128i a = LoadUnaligned16(&source[i]);
2184       const __m128i b = _mm_mulhrs_epi16(a, v_kTransformRowMultiplier);
2185       StoreUnaligned16(&source[i], b);
2186       i += 8;
2187     } while (i < tx_width * num_rows);
2188   } else {
2189     int i = 0;
2190     do {
2191       // The last 32 values of every row are always zero if the |tx_width| is
2192       // 64.
2193       const int non_zero_width = (tx_width < 64) ? tx_width : 32;
2194       int j = 0;
2195       do {
2196         const __m128i a = LoadUnaligned16(&source[i * tx_width + j]);
2197         const __m128i b = _mm_mulhrs_epi16(a, v_kTransformRowMultiplier);
2198         StoreUnaligned16(&source[i * tx_width + j], b);
2199         j += 8;
2200       } while (j < non_zero_width);
2201     } while (++i < num_rows);
2202   }
2203 }
2204 
2205 template <int tx_width>
2206 LIBGAV1_ALWAYS_INLINE void RowShift(int16_t* source, int num_rows,
2207                                     int row_shift) {
2208   const __m128i v_row_shift_add = _mm_set1_epi16(row_shift);
2209   const __m128i v_row_shift = _mm_cvtepu16_epi64(v_row_shift_add);
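  // A note on the bias: the nonzero row shifts used by the callers appear to
  // be 1 and 2, for which the usual rounding term 1 << (row_shift - 1) equals
  // row_shift itself, so the shift amount doubles as the bias passed to
  // ShiftResidual.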
2210   if (tx_width == 4) {
2211     // Process two rows per iteration.
2212     int i = 0;
2213     do {
2214       const __m128i residual = LoadUnaligned16(&source[i]);
2215       const __m128i shifted_residual =
2216           ShiftResidual(residual, v_row_shift_add, v_row_shift);
2217       StoreUnaligned16(&source[i], shifted_residual);
2218       i += 8;
2219     } while (i < tx_width * num_rows);
2220   } else {
2221     int i = 0;
2222     do {
2223       for (int j = 0; j < tx_width; j += 8) {
2224         const __m128i residual = LoadUnaligned16(&source[i * tx_width + j]);
2225         const __m128i shifted_residual =
2226             ShiftResidual(residual, v_row_shift_add, v_row_shift);
2227         StoreUnaligned16(&source[i * tx_width + j], shifted_residual);
2228       }
2229     } while (++i < num_rows);
2230   }
2231 }
2232 
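// Each *TransformLoopRow_SSE4_1 below follows the same pattern: try the
// DC-only shortcut, apply the optional row rounding, run the 1-D transform on
// batches of 4 or 8 rows at a time with the transpose enabled, and finish
// with the row shift. The column variants flip columns when required, run the
// transform without the transpose, and add the result to the frame via
// StoreToFrameWithRound.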
2233 void Dct4TransformLoopRow_SSE4_1(TransformType /*tx_type*/,
2234                                  TransformSize tx_size, int adjusted_tx_height,
2235                                  void* src_buffer, int /*start_x*/,
2236                                  int /*start_y*/, void* /*dst_frame*/) {
2237   auto* src = static_cast<int16_t*>(src_buffer);
2238   const int tx_height = kTransformHeight[tx_size];
2239   const bool should_round = (tx_height == 8);
2240   const int row_shift = static_cast<int>(tx_height == 16);
2241 
2242   if (DctDcOnly<4>(src, adjusted_tx_height, should_round, row_shift)) {
2243     return;
2244   }
2245 
2246   if (should_round) {
2247     ApplyRounding<4>(src, adjusted_tx_height);
2248   }
2249 
2250   if (adjusted_tx_height <= 4) {
2251     // Process 4 1d dct4 rows in parallel.
2252     Dct4_SSE4_1<ButterflyRotation_4, false>(src, /*step=*/4,
2253                                             /*transpose=*/true);
2254   } else {
2255     // Process 8 1d dct4 rows in parallel per iteration.
2256     int i = 0;
2257     do {
2258       Dct4_SSE4_1<ButterflyRotation_8, true>(&src[i * 4], /*step=*/4,
2259                                              /*transpose=*/true);
2260       i += 8;
2261     } while (i < adjusted_tx_height);
2262   }
2263   if (tx_height == 16) {
2264     RowShift<4>(src, adjusted_tx_height, 1);
2265   }
2266 }
2267 
2268 void Dct4TransformLoopColumn_SSE4_1(TransformType tx_type,
2269                                     TransformSize tx_size,
2270                                     int adjusted_tx_height,
2271                                     void* LIBGAV1_RESTRICT src_buffer,
2272                                     int start_x, int start_y,
2273                                     void* LIBGAV1_RESTRICT dst_frame) {
2274   auto* src = static_cast<int16_t*>(src_buffer);
2275   const int tx_width = kTransformWidth[tx_size];
2276 
2277   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2278     FlipColumns<4>(src, tx_width);
2279   }
2280 
2281   if (!DctDcOnlyColumn<4>(src, adjusted_tx_height, tx_width)) {
2282     if (tx_width == 4) {
2283       // Process 4 1d dct4 columns in parallel.
2284       Dct4_SSE4_1<ButterflyRotation_4, false>(src, tx_width,
2285                                               /*transpose=*/false);
2286     } else {
2287       // Process 8 1d dct4 columns in parallel per iteration.
2288       int i = 0;
2289       do {
2290         Dct4_SSE4_1<ButterflyRotation_8, true>(&src[i], tx_width,
2291                                                /*transpose=*/false);
2292         i += 8;
2293       } while (i < tx_width);
2294     }
2295   }
2296   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2297   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 4, src, tx_type);
2298 }
2299 
2300 void Dct8TransformLoopRow_SSE4_1(TransformType /*tx_type*/,
2301                                  TransformSize tx_size, int adjusted_tx_height,
2302                                  void* src_buffer, int /*start_x*/,
2303                                  int /*start_y*/, void* /*dst_frame*/) {
2304   auto* src = static_cast<int16_t*>(src_buffer);
2305   const bool should_round = kShouldRound[tx_size];
2306   const uint8_t row_shift = kTransformRowShift[tx_size];
2307 
2308   if (DctDcOnly<8>(src, adjusted_tx_height, should_round, row_shift)) {
2309     return;
2310   }
2311 
2312   if (should_round) {
2313     ApplyRounding<8>(src, adjusted_tx_height);
2314   }
2315 
2316   if (adjusted_tx_height <= 4) {
2317     // Process 4 1d dct8 rows in parallel.
2318     Dct8_SSE4_1<ButterflyRotation_4, true>(src, /*step=*/8, /*transpose=*/true);
2319   } else {
2320     // Process 8 1d dct8 rows in parallel per iteration.
2321     int i = 0;
2322     do {
2323       Dct8_SSE4_1<ButterflyRotation_8, false>(&src[i * 8], /*step=*/8,
2324                                               /*transpose=*/true);
2325       i += 8;
2326     } while (i < adjusted_tx_height);
2327   }
2328   if (row_shift > 0) {
2329     RowShift<8>(src, adjusted_tx_height, row_shift);
2330   }
2331 }
2332 
2333 void Dct8TransformLoopColumn_SSE4_1(TransformType tx_type,
2334                                     TransformSize tx_size,
2335                                     int adjusted_tx_height,
2336                                     void* LIBGAV1_RESTRICT src_buffer,
2337                                     int start_x, int start_y,
2338                                     void* LIBGAV1_RESTRICT dst_frame) {
2339   auto* src = static_cast<int16_t*>(src_buffer);
2340   const int tx_width = kTransformWidth[tx_size];
2341 
2342   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2343     FlipColumns<8>(src, tx_width);
2344   }
2345 
2346   if (!DctDcOnlyColumn<8>(src, adjusted_tx_height, tx_width)) {
2347     if (tx_width == 4) {
2348       // Process 4 1d dct8 columns in parallel.
2349       Dct8_SSE4_1<ButterflyRotation_4, true>(src, 4, /*transpose=*/false);
2350     } else {
2351       // Process 8 1d dct8 columns in parallel per iteration.
2352       int i = 0;
2353       do {
2354         Dct8_SSE4_1<ButterflyRotation_8, false>(&src[i], tx_width,
2355                                                 /*transpose=*/false);
2356         i += 8;
2357       } while (i < tx_width);
2358     }
2359   }
2360   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2361   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 8, src, tx_type);
2362 }
2363 
2364 void Dct16TransformLoopRow_SSE4_1(TransformType /*tx_type*/,
2365                                   TransformSize tx_size, int adjusted_tx_height,
2366                                   void* src_buffer, int /*start_x*/,
2367                                   int /*start_y*/, void* /*dst_frame*/) {
2368   auto* src = static_cast<int16_t*>(src_buffer);
2369   const bool should_round = kShouldRound[tx_size];
2370   const uint8_t row_shift = kTransformRowShift[tx_size];
2371 
2372   if (DctDcOnly<16>(src, adjusted_tx_height, should_round, row_shift)) {
2373     return;
2374   }
2375 
2376   if (should_round) {
2377     ApplyRounding<16>(src, adjusted_tx_height);
2378   }
2379 
2380   if (adjusted_tx_height <= 4) {
2381     // Process 4 1d dct16 rows in parallel.
2382     Dct16_SSE4_1<ButterflyRotation_4, true>(src, 16, /*transpose=*/true);
2383   } else {
2384     int i = 0;
2385     do {
2386       // Process 8 1d dct16 rows in parallel per iteration.
2387       Dct16_SSE4_1<ButterflyRotation_8, false>(&src[i * 16], 16,
2388                                                /*transpose=*/true);
2389       i += 8;
2390     } while (i < adjusted_tx_height);
2391   }
2392   // row_shift is always nonzero here.
2393   RowShift<16>(src, adjusted_tx_height, row_shift);
2394 }
2395 
2396 void Dct16TransformLoopColumn_SSE4_1(TransformType tx_type,
2397                                      TransformSize tx_size,
2398                                      int adjusted_tx_height,
2399                                      void* LIBGAV1_RESTRICT src_buffer,
2400                                      int start_x, int start_y,
2401                                      void* LIBGAV1_RESTRICT dst_frame) {
2402   auto* src = static_cast<int16_t*>(src_buffer);
2403   const int tx_width = kTransformWidth[tx_size];
2404 
2405   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2406     FlipColumns<16>(src, tx_width);
2407   }
2408 
2409   if (!DctDcOnlyColumn<16>(src, adjusted_tx_height, tx_width)) {
2410     if (tx_width == 4) {
2411       // Process 4 1d dct16 columns in parallel.
2412       Dct16_SSE4_1<ButterflyRotation_4, true>(src, 4, /*transpose=*/false);
2413     } else {
2414       int i = 0;
2415       do {
2416         // Process 8 1d dct16 columns in parallel per iteration.
2417         Dct16_SSE4_1<ButterflyRotation_8, false>(&src[i], tx_width,
2418                                                  /*transpose=*/false);
2419         i += 8;
2420       } while (i < tx_width);
2421     }
2422   }
2423   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2424   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 16, src, tx_type);
2425 }
2426 
2427 void Dct32TransformLoopRow_SSE4_1(TransformType /*tx_type*/,
2428                                   TransformSize tx_size, int adjusted_tx_height,
2429                                   void* src_buffer, int /*start_x*/,
2430                                   int /*start_y*/, void* /*dst_frame*/) {
2431   auto* src = static_cast<int16_t*>(src_buffer);
2432   const bool should_round = kShouldRound[tx_size];
2433   const uint8_t row_shift = kTransformRowShift[tx_size];
2434 
2435   if (DctDcOnly<32>(src, adjusted_tx_height, should_round, row_shift)) {
2436     return;
2437   }
2438 
2439   if (should_round) {
2440     ApplyRounding<32>(src, adjusted_tx_height);
2441   }
2442   // Process 8 1d dct32 rows in parallel per iteration.
2443   int i = 0;
2444   do {
2445     Dct32_SSE4_1(&src[i * 32], 32, /*transpose=*/true);
2446     i += 8;
2447   } while (i < adjusted_tx_height);
2448   // row_shift is always nonzero here.
2449   RowShift<32>(src, adjusted_tx_height, row_shift);
2450 }
2451 
2452 void Dct32TransformLoopColumn_SSE4_1(TransformType tx_type,
2453                                      TransformSize tx_size,
2454                                      int adjusted_tx_height,
2455                                      void* LIBGAV1_RESTRICT src_buffer,
2456                                      int start_x, int start_y,
2457                                      void* LIBGAV1_RESTRICT dst_frame) {
2458   auto* src = static_cast<int16_t*>(src_buffer);
2459   const int tx_width = kTransformWidth[tx_size];
2460 
2461   if (!DctDcOnlyColumn<32>(src, adjusted_tx_height, tx_width)) {
2462     // Process 8 1d dct32 columns in parallel per iteration.
2463     int i = 0;
2464     do {
2465       Dct32_SSE4_1(&src[i], tx_width, /*transpose=*/false);
2466       i += 8;
2467     } while (i < tx_width);
2468   }
2469   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2470   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 32, src, tx_type);
2471 }
2472 
2473 void Dct64TransformLoopRow_SSE4_1(TransformType /*tx_type*/,
2474                                   TransformSize tx_size, int adjusted_tx_height,
2475                                   void* src_buffer, int /*start_x*/,
2476                                   int /*start_y*/, void* /*dst_frame*/) {
2477   auto* src = static_cast<int16_t*>(src_buffer);
2478   const bool should_round = kShouldRound[tx_size];
2479   const uint8_t row_shift = kTransformRowShift[tx_size];
2480 
2481   if (DctDcOnly<64>(src, adjusted_tx_height, should_round, row_shift)) {
2482     return;
2483   }
2484 
2485   if (should_round) {
2486     ApplyRounding<64>(src, adjusted_tx_height);
2487   }
2488   // Process 8 1d dct64 rows in parallel per iteration.
2489   int i = 0;
2490   do {
2491     Dct64_SSE4_1(&src[i * 64], 64, /*transpose=*/true);
2492     i += 8;
2493   } while (i < adjusted_tx_height);
2494   // row_shift is always nonzero here.
2495   RowShift<64>(src, adjusted_tx_height, row_shift);
2496 }
2497 
2498 void Dct64TransformLoopColumn_SSE4_1(TransformType tx_type,
2499                                      TransformSize tx_size,
2500                                      int adjusted_tx_height,
2501                                      void* LIBGAV1_RESTRICT src_buffer,
2502                                      int start_x, int start_y,
2503                                      void* LIBGAV1_RESTRICT dst_frame) {
2504   auto* src = static_cast<int16_t*>(src_buffer);
2505   const int tx_width = kTransformWidth[tx_size];
2506 
2507   if (!DctDcOnlyColumn<64>(src, adjusted_tx_height, tx_width)) {
2508     // Process 8 1d dct64 columns in parallel per iteration.
2509     int i = 0;
2510     do {
2511       Dct64_SSE4_1(&src[i], tx_width, /*transpose=*/false);
2512       i += 8;
2513     } while (i < tx_width);
2514   }
2515   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2516   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 64, src, tx_type);
2517 }
2518 
2519 void Adst4TransformLoopRow_SSE4_1(TransformType /*tx_type*/,
2520                                   TransformSize tx_size, int adjusted_tx_height,
2521                                   void* src_buffer, int /*start_x*/,
2522                                   int /*start_y*/, void* /*dst_frame*/) {
2523   auto* src = static_cast<int16_t*>(src_buffer);
2524   const int tx_height = kTransformHeight[tx_size];
2525   const int row_shift = static_cast<int>(tx_height == 16);
2526   const bool should_round = (tx_height == 8);
2527 
2528   if (Adst4DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2529     return;
2530   }
2531 
2532   if (should_round) {
2533     ApplyRounding<4>(src, adjusted_tx_height);
2534   }
2535 
2536   // Process 4 1d adst4 rows in parallel per iteration.
2537   int i = 0;
2538   do {
2539     Adst4_SSE4_1<false>(&src[i * 4], /*step=*/4, /*transpose=*/true);
2540     i += 4;
2541   } while (i < adjusted_tx_height);
2542 
2543   if (row_shift != 0) {
2544     RowShift<4>(src, adjusted_tx_height, 1);
2545   }
2546 }
2547 
2548 void Adst4TransformLoopColumn_SSE4_1(TransformType tx_type,
2549                                      TransformSize tx_size,
2550                                      int adjusted_tx_height,
2551                                      void* LIBGAV1_RESTRICT src_buffer,
2552                                      int start_x, int start_y,
2553                                      void* LIBGAV1_RESTRICT dst_frame) {
2554   auto* src = static_cast<int16_t*>(src_buffer);
2555   const int tx_width = kTransformWidth[tx_size];
2556 
2557   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2558     FlipColumns<4>(src, tx_width);
2559   }
2560 
2561   if (!Adst4DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
2562     // Process 4 1d adst4 columns in parallel per iteration.
2563     int i = 0;
2564     do {
2565       Adst4_SSE4_1<false>(&src[i], tx_width, /*transpose=*/false);
2566       i += 4;
2567     } while (i < tx_width);
2568   }
2569   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2570   StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
2571                                                    tx_width, 4, src, tx_type);
2572 }
2573 
2574 void Adst8TransformLoopRow_SSE4_1(TransformType /*tx_type*/,
2575                                   TransformSize tx_size, int adjusted_tx_height,
2576                                   void* src_buffer, int /*start_x*/,
2577                                   int /*start_y*/, void* /*dst_frame*/) {
2578   auto* src = static_cast<int16_t*>(src_buffer);
2579   const bool should_round = kShouldRound[tx_size];
2580   const uint8_t row_shift = kTransformRowShift[tx_size];
2581 
2582   if (Adst8DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2583     return;
2584   }
2585 
2586   if (should_round) {
2587     ApplyRounding<8>(src, adjusted_tx_height);
2588   }
2589 
2590   if (adjusted_tx_height <= 4) {
2591     // Process 4 1d adst8 rows in parallel.
2592     Adst8_SSE4_1<ButterflyRotation_4, true>(src, /*step=*/8,
2593                                             /*transpose=*/true);
2594   } else {
2595     // Process 8 1d adst8 rows in parallel per iteration.
2596     int i = 0;
2597     do {
2598       Adst8_SSE4_1<ButterflyRotation_8, false>(&src[i * 8], /*step=*/8,
2599                                                /*transpose=*/true);
2600       i += 8;
2601     } while (i < adjusted_tx_height);
2602   }
2603   if (row_shift > 0) {
2604     RowShift<8>(src, adjusted_tx_height, row_shift);
2605   }
2606 }
2607 
2608 void Adst8TransformLoopColumn_SSE4_1(TransformType tx_type,
2609                                      TransformSize tx_size,
2610                                      int adjusted_tx_height,
2611                                      void* LIBGAV1_RESTRICT src_buffer,
2612                                      int start_x, int start_y,
2613                                      void* LIBGAV1_RESTRICT dst_frame) {
2614   auto* src = static_cast<int16_t*>(src_buffer);
2615   const int tx_width = kTransformWidth[tx_size];
2616 
2617   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2618     FlipColumns<8>(src, tx_width);
2619   }
2620 
2621   if (!Adst8DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
2622     if (tx_width == 4) {
2623       // Process 4 1d adst8 columns in parallel.
2624       Adst8_SSE4_1<ButterflyRotation_4, true>(src, 4, /*transpose=*/false);
2625     } else {
2626       // Process 8 1d adst8 columns in parallel per iteration.
2627       int i = 0;
2628       do {
2629         Adst8_SSE4_1<ButterflyRotation_8, false>(&src[i], tx_width,
2630                                                  /*transpose=*/false);
2631         i += 8;
2632       } while (i < tx_width);
2633     }
2634   }
2635   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2636   StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
2637                                                    tx_width, 8, src, tx_type);
2638 }
2639 
2640 void Adst16TransformLoopRow_SSE4_1(TransformType /*tx_type*/,
2641                                    TransformSize tx_size,
2642                                    int adjusted_tx_height, void* src_buffer,
2643                                    int /*start_x*/, int /*start_y*/,
2644                                    void* /*dst_frame*/) {
2645   auto* src = static_cast<int16_t*>(src_buffer);
2646   const bool should_round = kShouldRound[tx_size];
2647   const uint8_t row_shift = kTransformRowShift[tx_size];
2648 
2649   if (Adst16DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2650     return;
2651   }
2652 
2653   if (should_round) {
2654     ApplyRounding<16>(src, adjusted_tx_height);
2655   }
2656 
2657   if (adjusted_tx_height <= 4) {
2658     // Process 4 1d adst16 rows in parallel.
2659     Adst16_SSE4_1<ButterflyRotation_4, true>(src, 16, /*transpose=*/true);
2660   } else {
2661     int i = 0;
2662     do {
2663       // Process 8 1d adst16 rows in parallel per iteration.
2664       Adst16_SSE4_1<ButterflyRotation_8, false>(&src[i * 16], 16,
2665                                                 /*transpose=*/true);
2666       i += 8;
2667     } while (i < adjusted_tx_height);
2668   }
2669   // row_shift is always nonzero here.
2670   RowShift<16>(src, adjusted_tx_height, row_shift);
2671 }
2672 
2673 void Adst16TransformLoopColumn_SSE4_1(TransformType tx_type,
2674                                       TransformSize tx_size,
2675                                       int adjusted_tx_height,
2676                                       void* LIBGAV1_RESTRICT src_buffer,
2677                                       int start_x, int start_y,
2678                                       void* LIBGAV1_RESTRICT dst_frame) {
2679   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2680   auto* src = static_cast<int16_t*>(src_buffer);
2681   const int tx_width = kTransformWidth[tx_size];
2682 
2683   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2684     FlipColumns<16>(src, tx_width);
2685   }
2686 
2687   if (!Adst16DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
2688     if (tx_width == 4) {
2689       // Process 4 1d adst16 columns in parallel.
2690       Adst16_SSE4_1<ButterflyRotation_4, true>(src, 4, /*transpose=*/false);
2691     } else {
2692       int i = 0;
2693       do {
2694         // Process 8 1d adst16 columns in parallel per iteration.
2695         Adst16_SSE4_1<ButterflyRotation_8, false>(&src[i], tx_width,
2696                                                   /*transpose=*/false);
2697         i += 8;
2698       } while (i < tx_width);
2699     }
2700   }
2701   StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
2702                                                    tx_width, 16, src, tx_type);
2703 }
2704 
2705 void Identity4TransformLoopRow_SSE4_1(TransformType tx_type,
2706                                       TransformSize tx_size,
2707                                       int adjusted_tx_height, void* src_buffer,
2708                                       int /*start_x*/, int /*start_y*/,
2709                                       void* /*dst_frame*/) {
2710   // Special case: Process row calculations during column transform call.
2711   // Improves performance.
2712   if (tx_type == kTransformTypeIdentityIdentity &&
2713       tx_size == kTransformSize4x4) {
2714     return;
2715   }
2716 
2717   auto* src = static_cast<int16_t*>(src_buffer);
2718   const int tx_height = kTransformHeight[tx_size];
2719   const bool should_round = (tx_height == 8);
2720   if (Identity4DcOnly(src, adjusted_tx_height, should_round, tx_height)) {
2721     return;
2722   }
2723 
2724   if (should_round) {
2725     ApplyRounding<4>(src, adjusted_tx_height);
2726   }
2727   if (tx_height < 16) {
2728     int i = 0;
2729     do {
2730       Identity4_SSE4_1<false>(&src[i * 4], /*step=*/4);
2731       i += 4;
2732     } while (i < adjusted_tx_height);
2733   } else {
2734     int i = 0;
2735     do {
2736       Identity4_SSE4_1<true>(&src[i * 4], /*step=*/4);
2737       i += 4;
2738     } while (i < adjusted_tx_height);
2739   }
2740 }
2741 
void Identity4TransformLoopColumn_SSE4_1(TransformType tx_type,
                                         TransformSize tx_size,
                                         int adjusted_tx_height,
                                         void* LIBGAV1_RESTRICT src_buffer,
                                         int start_x, int start_y,
                                         void* LIBGAV1_RESTRICT dst_frame) {
  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
  auto* src = static_cast<int16_t*>(src_buffer);
  const int tx_width = kTransformWidth[tx_size];

  // Special case: Process row calculations during column transform call.
  if (tx_type == kTransformTypeIdentityIdentity &&
      (tx_size == kTransformSize4x4 || tx_size == kTransformSize8x4)) {
    Identity4RowColumnStoreToFrame(frame, start_x, start_y, tx_width,
                                   adjusted_tx_height, src);
    return;
  }

  if (kTransformFlipColumnsMask.Contains(tx_type)) {
    FlipColumns<4>(src, tx_width);
  }

  Identity4ColumnStoreToFrame(frame, start_x, start_y, tx_width,
                              adjusted_tx_height, src);
}

void Identity8TransformLoopRow_SSE4_1(TransformType tx_type,
                                      TransformSize tx_size,
                                      int adjusted_tx_height, void* src_buffer,
                                      int /*start_x*/, int /*start_y*/,
                                      void* /*dst_frame*/) {
  // Special case: Process row calculations during column transform call.
  // Improves performance.
  if (tx_type == kTransformTypeIdentityIdentity &&
      tx_size == kTransformSize8x4) {
    return;
  }

  auto* src = static_cast<int16_t*>(src_buffer);
  const int tx_height = kTransformHeight[tx_size];
  const bool should_round = kShouldRound[tx_size];
  const uint8_t row_shift = kTransformRowShift[tx_size];
  if (Identity8DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
    return;
  }

  if (should_round) {
    ApplyRounding<8>(src, adjusted_tx_height);
  }

  // When combining the identity8 multiplier with the row shift, the
  // calculations for tx_height == 8 and tx_height == 16 can be simplified
  // from ((A * 2) + 1) >> 1 to A.
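  // For example (an illustrative note, not from the upstream source): with an
  // arithmetic right shift, ((A * 2) + 1) >> 1 == A for any integer A, e.g.
  // A = 5 gives (10 + 1) >> 1 == 5 and A = -3 gives (-6 + 1) >> 1 == -3, so
  // these row passes need no further work and the function returns early
  // below.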
  if ((tx_height & 0x18) != 0) {
    return;
  }
  if (tx_height == 32) {
    int i = 0;
    do {
      Identity8Row32_SSE4_1(&src[i * 8], /*step=*/8);
      i += 4;
    } while (i < adjusted_tx_height);
    return;
  }

  assert(tx_size == kTransformSize8x4);
  int i = 0;
  do {
    Identity8Row4_SSE4_1(&src[i * 8], /*step=*/8);
    i += 4;
  } while (i < adjusted_tx_height);
}

void Identity8TransformLoopColumn_SSE4_1(TransformType tx_type,
                                         TransformSize tx_size,
                                         int adjusted_tx_height,
                                         void* LIBGAV1_RESTRICT src_buffer,
                                         int start_x, int start_y,
                                         void* LIBGAV1_RESTRICT dst_frame) {
  auto* src = static_cast<int16_t*>(src_buffer);
  const int tx_width = kTransformWidth[tx_size];

  if (kTransformFlipColumnsMask.Contains(tx_type)) {
    FlipColumns<8>(src, tx_width);
  }

  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
  Identity8ColumnStoreToFrame_SSE4_1(frame, start_x, start_y, tx_width,
                                     adjusted_tx_height, src);
}

void Identity16TransformLoopRow_SSE4_1(TransformType /*tx_type*/,
                                       TransformSize tx_size,
                                       int adjusted_tx_height, void* src_buffer,
                                       int /*start_x*/, int /*start_y*/,
                                       void* /*dst_frame*/) {
  auto* src = static_cast<int16_t*>(src_buffer);
  const bool should_round = kShouldRound[tx_size];
  const uint8_t row_shift = kTransformRowShift[tx_size];
  if (Identity16DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
    return;
  }

  if (should_round) {
    ApplyRounding<16>(src, adjusted_tx_height);
  }
  int i = 0;
  do {
    Identity16Row_SSE4_1(&src[i * 16], /*step=*/16,
                         kTransformRowShift[tx_size]);
    i += 4;
  } while (i < adjusted_tx_height);
}

void Identity16TransformLoopColumn_SSE4_1(TransformType tx_type,
                                          TransformSize tx_size,
                                          int adjusted_tx_height,
                                          void* LIBGAV1_RESTRICT src_buffer,
                                          int start_x, int start_y,
                                          void* LIBGAV1_RESTRICT dst_frame) {
  auto* src = static_cast<int16_t*>(src_buffer);
  const int tx_width = kTransformWidth[tx_size];

  if (kTransformFlipColumnsMask.Contains(tx_type)) {
    FlipColumns<16>(src, tx_width);
  }
  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
  Identity16ColumnStoreToFrame_SSE4_1(frame, start_x, start_y, tx_width,
                                      adjusted_tx_height, src);
}

void Identity32TransformLoopRow_SSE4_1(TransformType /*tx_type*/,
                                       TransformSize tx_size,
                                       int adjusted_tx_height, void* src_buffer,
                                       int /*start_x*/, int /*start_y*/,
                                       void* /*dst_frame*/) {
  const int tx_height = kTransformHeight[tx_size];
  // When combining the identity32 multiplier with the row shift, the
  // calculations for tx_height == 8 and tx_height == 32 can be simplified
  // from ((A * 4) + 2) >> 2 to A.
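  // As with identity8 above (illustrative note): this is the usual rounding
  // identity, e.g. A = 5 gives (20 + 2) >> 2 == 5, so those heights are
  // returned from early below.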
  if ((tx_height & 0x28) != 0) {
    return;
  }

  // Process kTransformSize32x16. The src is always rounded before the
  // identity transform and shifted by 1 afterwards.
  auto* src = static_cast<int16_t*>(src_buffer);
  if (Identity32DcOnly(src, adjusted_tx_height)) {
    return;
  }

  assert(tx_size == kTransformSize32x16);
  ApplyRounding<32>(src, adjusted_tx_height);
  int i = 0;
  do {
    Identity32Row16_SSE4_1(&src[i * 32], /*step=*/32);
    i += 4;
  } while (i < adjusted_tx_height);
}

void Identity32TransformLoopColumn_SSE4_1(TransformType /*tx_type*/,
                                          TransformSize tx_size,
                                          int adjusted_tx_height,
                                          void* LIBGAV1_RESTRICT src_buffer,
                                          int start_x, int start_y,
                                          void* LIBGAV1_RESTRICT dst_frame) {
  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
  auto* src = static_cast<int16_t*>(src_buffer);
  const int tx_width = kTransformWidth[tx_size];

  Identity32ColumnStoreToFrame(frame, start_x, start_y, tx_width,
                               adjusted_tx_height, src);
}

void Wht4TransformLoopRow_SSE4_1(TransformType tx_type, TransformSize tx_size,
                                 int /*adjusted_tx_height*/,
                                 void* /*src_buffer*/, int /*start_x*/,
                                 int /*start_y*/, void* /*dst_frame*/) {
  assert(tx_type == kTransformTypeDctDct);
  assert(tx_size == kTransformSize4x4);
  static_cast<void>(tx_type);
  static_cast<void>(tx_size);
  // Do both row and column transforms in the column-transform pass.
}

void Wht4TransformLoopColumn_SSE4_1(TransformType tx_type,
                                    TransformSize tx_size,
                                    int adjusted_tx_height,
                                    void* LIBGAV1_RESTRICT src_buffer,
                                    int start_x, int start_y,
                                    void* LIBGAV1_RESTRICT dst_frame) {
  assert(tx_type == kTransformTypeDctDct);
  assert(tx_size == kTransformSize4x4);
  static_cast<void>(tx_type);
  static_cast<void>(tx_size);

  // Do both row and column transforms in the column-transform pass.
  // Process 4 1d wht4 rows and columns in parallel.
  const auto* src = static_cast<int16_t*>(src_buffer);
  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
  Wht4_SSE4_1(frame, start_x, start_y, src, adjusted_tx_height);
}

//------------------------------------------------------------------------------

void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);

  // Maximum transform size for Dct is 64.
#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize4_Transform1dDct)
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kRow] =
      Dct4TransformLoopRow_SSE4_1;
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kColumn] =
      Dct4TransformLoopColumn_SSE4_1;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize8_Transform1dDct)
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kRow] =
      Dct8TransformLoopRow_SSE4_1;
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kColumn] =
      Dct8TransformLoopColumn_SSE4_1;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize16_Transform1dDct)
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kRow] =
      Dct16TransformLoopRow_SSE4_1;
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kColumn] =
      Dct16TransformLoopColumn_SSE4_1;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize32_Transform1dDct)
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kRow] =
      Dct32TransformLoopRow_SSE4_1;
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kColumn] =
      Dct32TransformLoopColumn_SSE4_1;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize64_Transform1dDct)
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kRow] =
      Dct64TransformLoopRow_SSE4_1;
  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kColumn] =
      Dct64TransformLoopColumn_SSE4_1;
#endif

  // Maximum transform size for Adst is 16.
#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize4_Transform1dAdst)
  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kRow] =
      Adst4TransformLoopRow_SSE4_1;
  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kColumn] =
      Adst4TransformLoopColumn_SSE4_1;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize8_Transform1dAdst)
  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kRow] =
      Adst8TransformLoopRow_SSE4_1;
  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kColumn] =
      Adst8TransformLoopColumn_SSE4_1;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize16_Transform1dAdst)
  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kRow] =
      Adst16TransformLoopRow_SSE4_1;
  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kColumn] =
      Adst16TransformLoopColumn_SSE4_1;
#endif

  // Maximum transform size for Identity transform is 32.
#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize4_Transform1dIdentity)
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kRow] =
      Identity4TransformLoopRow_SSE4_1;
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kColumn] =
      Identity4TransformLoopColumn_SSE4_1;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize8_Transform1dIdentity)
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kRow] =
      Identity8TransformLoopRow_SSE4_1;
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kColumn] =
      Identity8TransformLoopColumn_SSE4_1;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize16_Transform1dIdentity)
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kRow] =
      Identity16TransformLoopRow_SSE4_1;
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kColumn] =
      Identity16TransformLoopColumn_SSE4_1;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize32_Transform1dIdentity)
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kRow] =
      Identity32TransformLoopRow_SSE4_1;
  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kColumn] =
      Identity32TransformLoopColumn_SSE4_1;
#endif

  // Maximum transform size for Wht is 4.
#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize4_Transform1dWht)
  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kRow] =
      Wht4TransformLoopRow_SSE4_1;
  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kColumn] =
      Wht4TransformLoopColumn_SSE4_1;
#endif
}
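
// A minimal usage sketch (illustrative only, not part of the upstream file):
// once Init8bpp() has populated the table, a caller holding the 8-bit Dsp
// table would dispatch through these pointers roughly as
//   const Dsp* const dsp = GetDspTable(kBitdepth8);  // assumed lookup helper
//   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kRow](
//       tx_type, kTransformSize4x4, adjusted_tx_height, src_buffer,
//       start_x, start_y, dst_frame);
// with the [kRow] entry run first and the matching [kColumn] entry run on the
// same buffer afterwards, mirroring the row/column split used above.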

}  // namespace
}  // namespace low_bitdepth

void InverseTransformInit_SSE4_1() { low_bitdepth::Init8bpp(); }

}  // namespace dsp
}  // namespace libgav1
#else   // !LIBGAV1_TARGETING_SSE4_1
namespace libgav1 {
namespace dsp {

void InverseTransformInit_SSE4_1() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_TARGETING_SSE4_1