xref: /aosp_15_r20/external/libgav1/src/dsp/x86/convolve_avx2.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2020 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/convolve.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_TARGETING_AVX2
19 #include <immintrin.h>
20 
21 #include <algorithm>
22 #include <cassert>
23 #include <cstdint>
24 #include <cstring>
25 
26 #include "src/dsp/constants.h"
27 #include "src/dsp/dsp.h"
28 #include "src/dsp/x86/common_avx2.h"
29 #include "src/utils/common.h"
30 #include "src/utils/compiler_attributes.h"
31 #include "src/utils/constants.h"
32 
33 namespace libgav1 {
34 namespace dsp {
35 namespace low_bitdepth {
36 namespace {
37 
38 #include "src/dsp/x86/convolve_sse4.inc"
39 
40 // Multiply every entry in |src[]| by the corresponding entry in |taps[]| and
41 // sum. The filters in |taps[]| are pre-shifted by 1. This prevents the final
42 // sum from outranging int16_t.
43 template <int num_taps>
SumOnePassTaps(const __m256i * const src,const __m256i * const taps)44 __m256i SumOnePassTaps(const __m256i* const src, const __m256i* const taps) {
45   __m256i sum;
46   if (num_taps == 6) {
47     // 6 taps.
48     const __m256i v_madd_21 = _mm256_maddubs_epi16(src[0], taps[0]);  // k2k1
49     const __m256i v_madd_43 = _mm256_maddubs_epi16(src[1], taps[1]);  // k4k3
50     const __m256i v_madd_65 = _mm256_maddubs_epi16(src[2], taps[2]);  // k6k5
51     sum = _mm256_add_epi16(v_madd_21, v_madd_43);
52     sum = _mm256_add_epi16(sum, v_madd_65);
53   } else if (num_taps == 8) {
54     // 8 taps.
55     const __m256i v_madd_10 = _mm256_maddubs_epi16(src[0], taps[0]);  // k1k0
56     const __m256i v_madd_32 = _mm256_maddubs_epi16(src[1], taps[1]);  // k3k2
57     const __m256i v_madd_54 = _mm256_maddubs_epi16(src[2], taps[2]);  // k5k4
58     const __m256i v_madd_76 = _mm256_maddubs_epi16(src[3], taps[3]);  // k7k6
59     const __m256i v_sum_3210 = _mm256_add_epi16(v_madd_10, v_madd_32);
60     const __m256i v_sum_7654 = _mm256_add_epi16(v_madd_54, v_madd_76);
61     sum = _mm256_add_epi16(v_sum_7654, v_sum_3210);
62   } else if (num_taps == 2) {
63     // 2 taps.
64     sum = _mm256_maddubs_epi16(src[0], taps[0]);  // k4k3
65   } else {
66     // 4 taps.
67     const __m256i v_madd_32 = _mm256_maddubs_epi16(src[0], taps[0]);  // k3k2
68     const __m256i v_madd_54 = _mm256_maddubs_epi16(src[1], taps[1]);  // k5k4
69     sum = _mm256_add_epi16(v_madd_32, v_madd_54);
70   }
71   return sum;
72 }
73 
74 template <int num_taps>
SumHorizontalTaps(const __m256i * const src,const __m256i * const v_tap)75 __m256i SumHorizontalTaps(const __m256i* const src,
76                           const __m256i* const v_tap) {
77   __m256i v_src[4];
78   const __m256i src_long = *src;
79   const __m256i src_long_dup_lo = _mm256_unpacklo_epi8(src_long, src_long);
80   const __m256i src_long_dup_hi = _mm256_unpackhi_epi8(src_long, src_long);
81 
82   if (num_taps == 6) {
83     // 6 taps.
84     v_src[0] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 3);   // _21
85     v_src[1] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 7);   // _43
86     v_src[2] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 11);  // _65
87   } else if (num_taps == 8) {
88     // 8 taps.
89     v_src[0] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 1);   // _10
90     v_src[1] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 5);   // _32
91     v_src[2] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 9);   // _54
92     v_src[3] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 13);  // _76
93   } else if (num_taps == 2) {
94     // 2 taps.
95     v_src[0] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 7);  // _43
96   } else {
97     // 4 taps.
98     v_src[0] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 5);  // _32
99     v_src[1] = _mm256_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 9);  // _54
100   }
101   return SumOnePassTaps<num_taps>(v_src, v_tap);
102 }
103 
104 template <int num_taps>
SimpleHorizontalTaps(const __m256i * const src,const __m256i * const v_tap)105 __m256i SimpleHorizontalTaps(const __m256i* const src,
106                              const __m256i* const v_tap) {
107   __m256i sum = SumHorizontalTaps<num_taps>(src, v_tap);
108 
109   // Normally the Horizontal pass does the downshift in two passes:
110   // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
111   // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them
112   // requires adding the rounding offset from the skipped shift.
113   constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2);
114 
115   sum = _mm256_add_epi16(sum, _mm256_set1_epi16(first_shift_rounding_bit));
116   sum = RightShiftWithRounding_S16(sum, kFilterBits - 1);
117   return _mm256_packus_epi16(sum, sum);
118 }
119 
120 template <int num_taps>
HorizontalTaps8To16(const __m256i * const src,const __m256i * const v_tap)121 __m256i HorizontalTaps8To16(const __m256i* const src,
122                             const __m256i* const v_tap) {
123   const __m256i sum = SumHorizontalTaps<num_taps>(src, v_tap);
124 
125   return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1);
126 }
127 
128 // Filter 2xh sizes.
129 template <int num_taps, bool is_2d = false, bool is_compound = false>
FilterHorizontal(const uint8_t * LIBGAV1_RESTRICT src,const ptrdiff_t src_stride,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t pred_stride,const int,const int height,const __m128i * const v_tap)130 void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src,
131                       const ptrdiff_t src_stride,
132                       void* LIBGAV1_RESTRICT const dest,
133                       const ptrdiff_t pred_stride, const int /*width*/,
134                       const int height, const __m128i* const v_tap) {
135   auto* dest8 = static_cast<uint8_t*>(dest);
136   auto* dest16 = static_cast<uint16_t*>(dest);
137 
138   // Horizontal passes only need to account for |num_taps| 2 and 4 when
139   // |width| <= 4.
140   assert(num_taps <= 4);
141   if (num_taps <= 4) {
142     if (!is_compound) {
143       int y = height;
144       if (is_2d) y -= 1;
145       do {
146         if (is_2d) {
147           const __m128i sum =
148               HorizontalTaps8To16_2x2<num_taps>(src, src_stride, v_tap);
149           Store4(&dest16[0], sum);
150           dest16 += pred_stride;
151           Store4(&dest16[0], _mm_srli_si128(sum, 8));
152           dest16 += pred_stride;
153         } else {
154           const __m128i sum =
155               SimpleHorizontalTaps2x2<num_taps>(src, src_stride, v_tap);
156           Store2(dest8, sum);
157           dest8 += pred_stride;
158           Store2(dest8, _mm_srli_si128(sum, 4));
159           dest8 += pred_stride;
160         }
161 
162         src += src_stride << 1;
163         y -= 2;
164       } while (y != 0);
165 
166       // The 2d filters have an odd |height| because the horizontal pass
167       // generates context for the vertical pass.
168       if (is_2d) {
169         assert(height % 2 == 1);
170         __m128i sum;
171         const __m128i input = LoadLo8(&src[2]);
172         if (num_taps == 2) {
173           // 03 04 04 05 05 06 06 07 ....
174           const __m128i v_src_43 =
175               _mm_srli_si128(_mm_unpacklo_epi8(input, input), 3);
176           sum = _mm_maddubs_epi16(v_src_43, v_tap[0]);  // k4k3
177         } else {
178           // 02 03 03 04 04 05 05 06 06 07 ....
179           const __m128i v_src_32 =
180               _mm_srli_si128(_mm_unpacklo_epi8(input, input), 1);
181           // 04 05 05 06 06 07 07 08 ...
182           const __m128i v_src_54 = _mm_srli_si128(v_src_32, 4);
183           const __m128i v_madd_32 =
184               _mm_maddubs_epi16(v_src_32, v_tap[0]);  // k3k2
185           const __m128i v_madd_54 =
186               _mm_maddubs_epi16(v_src_54, v_tap[1]);  // k5k4
187           sum = _mm_add_epi16(v_madd_54, v_madd_32);
188         }
189         sum = RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1);
190         Store4(dest16, sum);
191       }
192     }
193   }
194 }
195 
196 // Filter widths >= 4.
197 template <int num_taps, bool is_2d = false, bool is_compound = false>
FilterHorizontal(const uint8_t * LIBGAV1_RESTRICT src,const ptrdiff_t src_stride,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t pred_stride,const int width,const int height,const __m256i * const v_tap)198 void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src,
199                       const ptrdiff_t src_stride,
200                       void* LIBGAV1_RESTRICT const dest,
201                       const ptrdiff_t pred_stride, const int width,
202                       const int height, const __m256i* const v_tap) {
203   auto* dest8 = static_cast<uint8_t*>(dest);
204   auto* dest16 = static_cast<uint16_t*>(dest);
205 
206   if (width >= 32) {
207     int y = height;
208     do {
209       int x = 0;
210       do {
211         if (is_2d || is_compound) {
212           // Load into 2 128 bit lanes.
213           const __m256i src_long =
214               SetrM128i(LoadUnaligned16(&src[x]), LoadUnaligned16(&src[x + 8]));
215           const __m256i result =
216               HorizontalTaps8To16<num_taps>(&src_long, v_tap);
217           const __m256i src_long2 = SetrM128i(LoadUnaligned16(&src[x + 16]),
218                                               LoadUnaligned16(&src[x + 24]));
219           const __m256i result2 =
220               HorizontalTaps8To16<num_taps>(&src_long2, v_tap);
221           if (is_2d) {
222             StoreAligned32(&dest16[x], result);
223             StoreAligned32(&dest16[x + 16], result2);
224           } else {
225             StoreUnaligned32(&dest16[x], result);
226             StoreUnaligned32(&dest16[x + 16], result2);
227           }
228         } else {
229           // Load src used to calculate dest8[7:0] and dest8[23:16].
230           const __m256i src_long = LoadUnaligned32(&src[x]);
231           const __m256i result =
232               SimpleHorizontalTaps<num_taps>(&src_long, v_tap);
233           // Load src used to calculate dest8[15:8] and dest8[31:24].
234           const __m256i src_long2 = LoadUnaligned32(&src[x + 8]);
235           const __m256i result2 =
236               SimpleHorizontalTaps<num_taps>(&src_long2, v_tap);
237           // Combine results and store.
238           StoreUnaligned32(&dest8[x], _mm256_unpacklo_epi64(result, result2));
239         }
240         x += 32;
241       } while (x < width);
242       src += src_stride;
243       dest8 += pred_stride;
244       dest16 += pred_stride;
245     } while (--y != 0);
246   } else if (width == 16) {
247     int y = height;
248     if (is_2d) y -= 1;
249     do {
250       if (is_2d || is_compound) {
251         // Load into 2 128 bit lanes.
252         const __m256i src_long =
253             SetrM128i(LoadUnaligned16(&src[0]), LoadUnaligned16(&src[8]));
254         const __m256i result = HorizontalTaps8To16<num_taps>(&src_long, v_tap);
255         const __m256i src_long2 =
256             SetrM128i(LoadUnaligned16(&src[src_stride]),
257                       LoadUnaligned16(&src[8 + src_stride]));
258         const __m256i result2 =
259             HorizontalTaps8To16<num_taps>(&src_long2, v_tap);
260         if (is_2d) {
261           StoreAligned32(&dest16[0], result);
262           StoreAligned32(&dest16[pred_stride], result2);
263         } else {
264           StoreUnaligned32(&dest16[0], result);
265           StoreUnaligned32(&dest16[pred_stride], result2);
266         }
267       } else {
268         // Load into 2 128 bit lanes.
269         const __m256i src_long = SetrM128i(LoadUnaligned16(&src[0]),
270                                            LoadUnaligned16(&src[src_stride]));
271         const __m256i result = SimpleHorizontalTaps<num_taps>(&src_long, v_tap);
272         const __m256i src_long2 = SetrM128i(
273             LoadUnaligned16(&src[8]), LoadUnaligned16(&src[8 + src_stride]));
274         const __m256i result2 =
275             SimpleHorizontalTaps<num_taps>(&src_long2, v_tap);
276         const __m256i packed_result = _mm256_unpacklo_epi64(result, result2);
277         StoreUnaligned16(&dest8[0], _mm256_castsi256_si128(packed_result));
278         StoreUnaligned16(&dest8[pred_stride],
279                          _mm256_extracti128_si256(packed_result, 1));
280       }
281       src += src_stride * 2;
282       dest8 += pred_stride * 2;
283       dest16 += pred_stride * 2;
284       y -= 2;
285     } while (y != 0);
286 
287     // The 2d filters have an odd |height| during the horizontal pass, so
288     // filter the remaining row.
289     if (is_2d) {
290       const __m256i src_long =
291           SetrM128i(LoadUnaligned16(&src[0]), LoadUnaligned16(&src[8]));
292       const __m256i result = HorizontalTaps8To16<num_taps>(&src_long, v_tap);
293       StoreAligned32(&dest16[0], result);
294     }
295 
296   } else if (width == 8) {
297     int y = height;
298     if (is_2d) y -= 1;
299     do {
300       // Load into 2 128 bit lanes.
301       const __m128i this_row = LoadUnaligned16(&src[0]);
302       const __m128i next_row = LoadUnaligned16(&src[src_stride]);
303       const __m256i src_long = SetrM128i(this_row, next_row);
304       if (is_2d || is_compound) {
305         const __m256i result = HorizontalTaps8To16<num_taps>(&src_long, v_tap);
306         if (is_2d) {
307           StoreAligned16(&dest16[0], _mm256_castsi256_si128(result));
308           StoreAligned16(&dest16[pred_stride],
309                          _mm256_extracti128_si256(result, 1));
310         } else {
311           StoreUnaligned16(&dest16[0], _mm256_castsi256_si128(result));
312           StoreUnaligned16(&dest16[pred_stride],
313                            _mm256_extracti128_si256(result, 1));
314         }
315       } else {
316         const __m128i this_row = LoadUnaligned16(&src[0]);
317         const __m128i next_row = LoadUnaligned16(&src[src_stride]);
318         // Load into 2 128 bit lanes.
319         const __m256i src_long = SetrM128i(this_row, next_row);
320         const __m256i result = SimpleHorizontalTaps<num_taps>(&src_long, v_tap);
321         StoreLo8(&dest8[0], _mm256_castsi256_si128(result));
322         StoreLo8(&dest8[pred_stride], _mm256_extracti128_si256(result, 1));
323       }
324       src += src_stride * 2;
325       dest8 += pred_stride * 2;
326       dest16 += pred_stride * 2;
327       y -= 2;
328     } while (y != 0);
329 
330     // The 2d filters have an odd |height| during the horizontal pass, so
331     // filter the remaining row.
332     if (is_2d) {
333       const __m256i src_long = _mm256_castsi128_si256(LoadUnaligned16(&src[0]));
334       const __m256i result = HorizontalTaps8To16<num_taps>(&src_long, v_tap);
335       StoreAligned16(&dest16[0], _mm256_castsi256_si128(result));
336     }
337 
338   } else {  // width == 4
339     int y = height;
340     if (is_2d) y -= 1;
341     do {
342       // Load into 2 128 bit lanes.
343       const __m128i this_row = LoadUnaligned16(&src[0]);
344       const __m128i next_row = LoadUnaligned16(&src[src_stride]);
345       const __m256i src_long = SetrM128i(this_row, next_row);
346       if (is_2d || is_compound) {
347         const __m256i result = HorizontalTaps8To16<num_taps>(&src_long, v_tap);
348         StoreLo8(&dest16[0], _mm256_castsi256_si128(result));
349         StoreLo8(&dest16[pred_stride], _mm256_extracti128_si256(result, 1));
350       } else {
351         const __m128i this_row = LoadUnaligned16(&src[0]);
352         const __m128i next_row = LoadUnaligned16(&src[src_stride]);
353         // Load into 2 128 bit lanes.
354         const __m256i src_long = SetrM128i(this_row, next_row);
355         const __m256i result = SimpleHorizontalTaps<num_taps>(&src_long, v_tap);
356         Store4(&dest8[0], _mm256_castsi256_si128(result));
357         Store4(&dest8[pred_stride], _mm256_extracti128_si256(result, 1));
358       }
359       src += src_stride * 2;
360       dest8 += pred_stride * 2;
361       dest16 += pred_stride * 2;
362       y -= 2;
363     } while (y != 0);
364 
365     // The 2d filters have an odd |height| during the horizontal pass, so
366     // filter the remaining row.
367     if (is_2d) {
368       const __m256i src_long = _mm256_castsi128_si256(LoadUnaligned16(&src[0]));
369       const __m256i result = HorizontalTaps8To16<num_taps>(&src_long, v_tap);
370       StoreLo8(&dest16[0], _mm256_castsi256_si128(result));
371     }
372   }
373 }
374 
375 template <int num_taps, bool is_2d_vertical = false>
SetupTaps(const __m128i * const filter,__m256i * v_tap)376 LIBGAV1_ALWAYS_INLINE void SetupTaps(const __m128i* const filter,
377                                      __m256i* v_tap) {
378   if (num_taps == 8) {
379     if (is_2d_vertical) {
380       v_tap[0] = _mm256_broadcastd_epi32(*filter);                      // k1k0
381       v_tap[1] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 4));   // k3k2
382       v_tap[2] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 8));   // k5k4
383       v_tap[3] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 12));  // k7k6
384     } else {
385       v_tap[0] = _mm256_broadcastw_epi16(*filter);                     // k1k0
386       v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 2));  // k3k2
387       v_tap[2] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 4));  // k5k4
388       v_tap[3] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 6));  // k7k6
389     }
390   } else if (num_taps == 6) {
391     if (is_2d_vertical) {
392       v_tap[0] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 2));   // k2k1
393       v_tap[1] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 6));   // k4k3
394       v_tap[2] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 10));  // k6k5
395     } else {
396       v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 1));  // k2k1
397       v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 3));  // k4k3
398       v_tap[2] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 5));  // k6k5
399     }
400   } else if (num_taps == 4) {
401     if (is_2d_vertical) {
402       v_tap[0] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 4));  // k3k2
403       v_tap[1] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 8));  // k5k4
404     } else {
405       v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 2));  // k3k2
406       v_tap[1] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 4));  // k5k4
407     }
408   } else {  // num_taps == 2
409     if (is_2d_vertical) {
410       v_tap[0] = _mm256_broadcastd_epi32(_mm_srli_si128(*filter, 6));  // k4k3
411     } else {
412       v_tap[0] = _mm256_broadcastw_epi16(_mm_srli_si128(*filter, 3));  // k4k3
413     }
414   }
415 }
416 
417 template <int num_taps, bool is_compound>
SimpleSum2DVerticalTaps(const __m256i * const src,const __m256i * const taps)418 __m256i SimpleSum2DVerticalTaps(const __m256i* const src,
419                                 const __m256i* const taps) {
420   __m256i sum_lo =
421       _mm256_madd_epi16(_mm256_unpacklo_epi16(src[0], src[1]), taps[0]);
422   __m256i sum_hi =
423       _mm256_madd_epi16(_mm256_unpackhi_epi16(src[0], src[1]), taps[0]);
424   if (num_taps >= 4) {
425     __m256i madd_lo =
426         _mm256_madd_epi16(_mm256_unpacklo_epi16(src[2], src[3]), taps[1]);
427     __m256i madd_hi =
428         _mm256_madd_epi16(_mm256_unpackhi_epi16(src[2], src[3]), taps[1]);
429     sum_lo = _mm256_add_epi32(sum_lo, madd_lo);
430     sum_hi = _mm256_add_epi32(sum_hi, madd_hi);
431     if (num_taps >= 6) {
432       madd_lo =
433           _mm256_madd_epi16(_mm256_unpacklo_epi16(src[4], src[5]), taps[2]);
434       madd_hi =
435           _mm256_madd_epi16(_mm256_unpackhi_epi16(src[4], src[5]), taps[2]);
436       sum_lo = _mm256_add_epi32(sum_lo, madd_lo);
437       sum_hi = _mm256_add_epi32(sum_hi, madd_hi);
438       if (num_taps == 8) {
439         madd_lo =
440             _mm256_madd_epi16(_mm256_unpacklo_epi16(src[6], src[7]), taps[3]);
441         madd_hi =
442             _mm256_madd_epi16(_mm256_unpackhi_epi16(src[6], src[7]), taps[3]);
443         sum_lo = _mm256_add_epi32(sum_lo, madd_lo);
444         sum_hi = _mm256_add_epi32(sum_hi, madd_hi);
445       }
446     }
447   }
448 
449   if (is_compound) {
450     return _mm256_packs_epi32(
451         RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1),
452         RightShiftWithRounding_S32(sum_hi,
453                                    kInterRoundBitsCompoundVertical - 1));
454   }
455 
456   return _mm256_packs_epi32(
457       RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1),
458       RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1));
459 }
460 
461 template <int num_taps, bool is_compound = false>
Filter2DVertical16xH(const uint16_t * LIBGAV1_RESTRICT src,void * LIBGAV1_RESTRICT const dst,const ptrdiff_t dst_stride,const int width,const int height,const __m256i * const taps)462 void Filter2DVertical16xH(const uint16_t* LIBGAV1_RESTRICT src,
463                           void* LIBGAV1_RESTRICT const dst,
464                           const ptrdiff_t dst_stride, const int width,
465                           const int height, const __m256i* const taps) {
466   assert(width >= 8);
467   constexpr int next_row = num_taps - 1;
468   // The Horizontal pass uses |width| as |stride| for the intermediate buffer.
469   const ptrdiff_t src_stride = width;
470 
471   auto* dst8 = static_cast<uint8_t*>(dst);
472   auto* dst16 = static_cast<uint16_t*>(dst);
473 
474   int x = 0;
475   do {
476     __m256i srcs[8];
477     const uint16_t* src_x = src + x;
478     srcs[0] = LoadAligned32(src_x);
479     src_x += src_stride;
480     if (num_taps >= 4) {
481       srcs[1] = LoadAligned32(src_x);
482       src_x += src_stride;
483       srcs[2] = LoadAligned32(src_x);
484       src_x += src_stride;
485       if (num_taps >= 6) {
486         srcs[3] = LoadAligned32(src_x);
487         src_x += src_stride;
488         srcs[4] = LoadAligned32(src_x);
489         src_x += src_stride;
490         if (num_taps == 8) {
491           srcs[5] = LoadAligned32(src_x);
492           src_x += src_stride;
493           srcs[6] = LoadAligned32(src_x);
494           src_x += src_stride;
495         }
496       }
497     }
498 
499     auto* dst8_x = dst8 + x;
500     auto* dst16_x = dst16 + x;
501     int y = height;
502     do {
503       srcs[next_row] = LoadAligned32(src_x);
504       src_x += src_stride;
505 
506       const __m256i sum =
507           SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps);
508       if (is_compound) {
509         StoreUnaligned32(dst16_x, sum);
510         dst16_x += dst_stride;
511       } else {
512         const __m128i packed_sum = _mm_packus_epi16(
513             _mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
514         StoreUnaligned16(dst8_x, packed_sum);
515         dst8_x += dst_stride;
516       }
517 
518       srcs[0] = srcs[1];
519       if (num_taps >= 4) {
520         srcs[1] = srcs[2];
521         srcs[2] = srcs[3];
522         if (num_taps >= 6) {
523           srcs[3] = srcs[4];
524           srcs[4] = srcs[5];
525           if (num_taps == 8) {
526             srcs[5] = srcs[6];
527             srcs[6] = srcs[7];
528           }
529         }
530       }
531     } while (--y != 0);
532     x += 16;
533   } while (x < width);
534 }
535 
536 template <bool is_2d = false, bool is_compound = false>
DoHorizontalPass2xH(const uint8_t * LIBGAV1_RESTRICT const src,const ptrdiff_t src_stride,void * LIBGAV1_RESTRICT const dst,const ptrdiff_t dst_stride,const int width,const int height,const int filter_id,const int filter_index)537 LIBGAV1_ALWAYS_INLINE void DoHorizontalPass2xH(
538     const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
539     void* LIBGAV1_RESTRICT const dst, const ptrdiff_t dst_stride,
540     const int width, const int height, const int filter_id,
541     const int filter_index) {
542   assert(filter_id != 0);
543   __m128i v_tap[4];
544   const __m128i v_horizontal_filter =
545       LoadLo8(kHalfSubPixelFilters[filter_index][filter_id]);
546 
547   if ((filter_index & 0x4) != 0) {  // 4 tap.
548     // ((filter_index == 4) | (filter_index == 5))
549     SetupTaps<4>(&v_horizontal_filter, v_tap);
550     FilterHorizontal<4, is_2d, is_compound>(src, src_stride, dst, dst_stride,
551                                             width, height, v_tap);
552   } else {  // 2 tap.
553     SetupTaps<2>(&v_horizontal_filter, v_tap);
554     FilterHorizontal<2, is_2d, is_compound>(src, src_stride, dst, dst_stride,
555                                             width, height, v_tap);
556   }
557 }
558 
559 template <bool is_2d = false, bool is_compound = false>
DoHorizontalPass(const uint8_t * LIBGAV1_RESTRICT const src,const ptrdiff_t src_stride,void * LIBGAV1_RESTRICT const dst,const ptrdiff_t dst_stride,const int width,const int height,const int filter_id,const int filter_index)560 LIBGAV1_ALWAYS_INLINE void DoHorizontalPass(
561     const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
562     void* LIBGAV1_RESTRICT const dst, const ptrdiff_t dst_stride,
563     const int width, const int height, const int filter_id,
564     const int filter_index) {
565   assert(filter_id != 0);
566   __m256i v_tap[4];
567   const __m128i v_horizontal_filter =
568       LoadLo8(kHalfSubPixelFilters[filter_index][filter_id]);
569 
570   if (filter_index == 2) {  // 8 tap.
571     SetupTaps<8>(&v_horizontal_filter, v_tap);
572     FilterHorizontal<8, is_2d, is_compound>(src, src_stride, dst, dst_stride,
573                                             width, height, v_tap);
574   } else if (filter_index == 1) {  // 6 tap.
575     SetupTaps<6>(&v_horizontal_filter, v_tap);
576     FilterHorizontal<6, is_2d, is_compound>(src, src_stride, dst, dst_stride,
577                                             width, height, v_tap);
578   } else if (filter_index == 0) {  // 6 tap.
579     SetupTaps<6>(&v_horizontal_filter, v_tap);
580     FilterHorizontal<6, is_2d, is_compound>(src, src_stride, dst, dst_stride,
581                                             width, height, v_tap);
582   } else if ((filter_index & 0x4) != 0) {  // 4 tap.
583     // ((filter_index == 4) | (filter_index == 5))
584     SetupTaps<4>(&v_horizontal_filter, v_tap);
585     FilterHorizontal<4, is_2d, is_compound>(src, src_stride, dst, dst_stride,
586                                             width, height, v_tap);
587   } else {  // 2 tap.
588     SetupTaps<2>(&v_horizontal_filter, v_tap);
589     FilterHorizontal<2, is_2d, is_compound>(src, src_stride, dst, dst_stride,
590                                             width, height, v_tap);
591   }
592 }
593 
Convolve2D_AVX2(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int horizontal_filter_index,const int vertical_filter_index,const int horizontal_filter_id,const int vertical_filter_id,const int width,const int height,void * LIBGAV1_RESTRICT prediction,const ptrdiff_t pred_stride)594 void Convolve2D_AVX2(const void* LIBGAV1_RESTRICT const reference,
595                      const ptrdiff_t reference_stride,
596                      const int horizontal_filter_index,
597                      const int vertical_filter_index,
598                      const int horizontal_filter_id,
599                      const int vertical_filter_id, const int width,
600                      const int height, void* LIBGAV1_RESTRICT prediction,
601                      const ptrdiff_t pred_stride) {
602   const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
603   const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
604   const int vertical_taps =
605       GetNumTapsInFilter(vert_filter_index, vertical_filter_id);
606 
607   // The output of the horizontal filter is guaranteed to fit in 16 bits.
608   alignas(32) uint16_t
609       intermediate_result[kMaxSuperBlockSizeInPixels *
610                           (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
611 #if LIBGAV1_MSAN
612   // Quiet msan warnings. Set with random non-zero value to aid in debugging.
613   memset(intermediate_result, 0x33, sizeof(intermediate_result));
614 #endif
615   const int intermediate_height = height + vertical_taps - 1;
616 
617   const ptrdiff_t src_stride = reference_stride;
618   const auto* src = static_cast<const uint8_t*>(reference) -
619                     (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset;
620   if (width > 2) {
621     DoHorizontalPass</*is_2d=*/true>(src, src_stride, intermediate_result,
622                                      width, width, intermediate_height,
623                                      horizontal_filter_id, horiz_filter_index);
624   } else {
625     // Use non avx2 version for smaller widths.
626     DoHorizontalPass2xH</*is_2d=*/true>(
627         src, src_stride, intermediate_result, width, width, intermediate_height,
628         horizontal_filter_id, horiz_filter_index);
629   }
630 
631   // Vertical filter.
632   auto* dest = static_cast<uint8_t*>(prediction);
633   const ptrdiff_t dest_stride = pred_stride;
634   assert(vertical_filter_id != 0);
635 
636   const __m128i v_filter =
637       LoadLo8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]);
638 
639   // Use 256 bits for width > 8.
640   if (width > 8) {
641     __m256i taps_256[4];
642     const __m128i v_filter_ext = _mm_cvtepi8_epi16(v_filter);
643 
644     if (vertical_taps == 8) {
645       SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256);
646       Filter2DVertical16xH<8>(intermediate_result, dest, dest_stride, width,
647                               height, taps_256);
648     } else if (vertical_taps == 6) {
649       SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256);
650       Filter2DVertical16xH<6>(intermediate_result, dest, dest_stride, width,
651                               height, taps_256);
652     } else if (vertical_taps == 4) {
653       SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256);
654       Filter2DVertical16xH<4>(intermediate_result, dest, dest_stride, width,
655                               height, taps_256);
656     } else {  // |vertical_taps| == 2
657       SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256);
658       Filter2DVertical16xH<2>(intermediate_result, dest, dest_stride, width,
659                               height, taps_256);
660     }
661   } else {  // width <= 8
662     __m128i taps[4];
663     // Use 128 bit code.
664     if (vertical_taps == 8) {
665       SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter, taps);
666       if (width == 2) {
667         Filter2DVertical2xH<8>(intermediate_result, dest, dest_stride, height,
668                                taps);
669       } else if (width == 4) {
670         Filter2DVertical4xH<8>(intermediate_result, dest, dest_stride, height,
671                                taps);
672       } else {
673         Filter2DVertical<8>(intermediate_result, dest, dest_stride, width,
674                             height, taps);
675       }
676     } else if (vertical_taps == 6) {
677       SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter, taps);
678       if (width == 2) {
679         Filter2DVertical2xH<6>(intermediate_result, dest, dest_stride, height,
680                                taps);
681       } else if (width == 4) {
682         Filter2DVertical4xH<6>(intermediate_result, dest, dest_stride, height,
683                                taps);
684       } else {
685         Filter2DVertical<6>(intermediate_result, dest, dest_stride, width,
686                             height, taps);
687       }
688     } else if (vertical_taps == 4) {
689       SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter, taps);
690       if (width == 2) {
691         Filter2DVertical2xH<4>(intermediate_result, dest, dest_stride, height,
692                                taps);
693       } else if (width == 4) {
694         Filter2DVertical4xH<4>(intermediate_result, dest, dest_stride, height,
695                                taps);
696       } else {
697         Filter2DVertical<4>(intermediate_result, dest, dest_stride, width,
698                             height, taps);
699       }
700     } else {  // |vertical_taps| == 2
701       SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter, taps);
702       if (width == 2) {
703         Filter2DVertical2xH<2>(intermediate_result, dest, dest_stride, height,
704                                taps);
705       } else if (width == 4) {
706         Filter2DVertical4xH<2>(intermediate_result, dest, dest_stride, height,
707                                taps);
708       } else {
709         Filter2DVertical<2>(intermediate_result, dest, dest_stride, width,
710                             height, taps);
711       }
712     }
713   }
714 }
715 
716 // The 1D compound shift is always |kInterRoundBitsHorizontal|, even for 1D
717 // Vertical calculations.
Compound1DShift(const __m256i sum)718 __m256i Compound1DShift(const __m256i sum) {
719   return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1);
720 }
721 
722 template <int num_taps, bool unpack_high = false>
SumVerticalTaps(const __m256i * const srcs,const __m256i * const v_tap)723 __m256i SumVerticalTaps(const __m256i* const srcs, const __m256i* const v_tap) {
724   __m256i v_src[4];
725 
726   if (!unpack_high) {
727     if (num_taps == 6) {
728       // 6 taps.
729       v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]);
730       v_src[1] = _mm256_unpacklo_epi8(srcs[2], srcs[3]);
731       v_src[2] = _mm256_unpacklo_epi8(srcs[4], srcs[5]);
732     } else if (num_taps == 8) {
733       // 8 taps.
734       v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]);
735       v_src[1] = _mm256_unpacklo_epi8(srcs[2], srcs[3]);
736       v_src[2] = _mm256_unpacklo_epi8(srcs[4], srcs[5]);
737       v_src[3] = _mm256_unpacklo_epi8(srcs[6], srcs[7]);
738     } else if (num_taps == 2) {
739       // 2 taps.
740       v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]);
741     } else {
742       // 4 taps.
743       v_src[0] = _mm256_unpacklo_epi8(srcs[0], srcs[1]);
744       v_src[1] = _mm256_unpacklo_epi8(srcs[2], srcs[3]);
745     }
746   } else {
747     if (num_taps == 6) {
748       // 6 taps.
749       v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]);
750       v_src[1] = _mm256_unpackhi_epi8(srcs[2], srcs[3]);
751       v_src[2] = _mm256_unpackhi_epi8(srcs[4], srcs[5]);
752     } else if (num_taps == 8) {
753       // 8 taps.
754       v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]);
755       v_src[1] = _mm256_unpackhi_epi8(srcs[2], srcs[3]);
756       v_src[2] = _mm256_unpackhi_epi8(srcs[4], srcs[5]);
757       v_src[3] = _mm256_unpackhi_epi8(srcs[6], srcs[7]);
758     } else if (num_taps == 2) {
759       // 2 taps.
760       v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]);
761     } else {
762       // 4 taps.
763       v_src[0] = _mm256_unpackhi_epi8(srcs[0], srcs[1]);
764       v_src[1] = _mm256_unpackhi_epi8(srcs[2], srcs[3]);
765     }
766   }
767   return SumOnePassTaps<num_taps>(v_src, v_tap);
768 }
769 
770 template <int num_taps, bool is_compound = false>
FilterVertical32xH(const uint8_t * LIBGAV1_RESTRICT src,const ptrdiff_t src_stride,void * LIBGAV1_RESTRICT const dst,const ptrdiff_t dst_stride,const int width,const int height,const __m256i * const v_tap)771 void FilterVertical32xH(const uint8_t* LIBGAV1_RESTRICT src,
772                         const ptrdiff_t src_stride,
773                         void* LIBGAV1_RESTRICT const dst,
774                         const ptrdiff_t dst_stride, const int width,
775                         const int height, const __m256i* const v_tap) {
776   const int next_row = num_taps - 1;
777   auto* dst8 = static_cast<uint8_t*>(dst);
778   auto* dst16 = static_cast<uint16_t*>(dst);
779   assert(width >= 32);
780   int x = 0;
781   do {
782     const uint8_t* src_x = src + x;
783     __m256i srcs[8];
784     srcs[0] = LoadUnaligned32(src_x);
785     src_x += src_stride;
786     if (num_taps >= 4) {
787       srcs[1] = LoadUnaligned32(src_x);
788       src_x += src_stride;
789       srcs[2] = LoadUnaligned32(src_x);
790       src_x += src_stride;
791       if (num_taps >= 6) {
792         srcs[3] = LoadUnaligned32(src_x);
793         src_x += src_stride;
794         srcs[4] = LoadUnaligned32(src_x);
795         src_x += src_stride;
796         if (num_taps == 8) {
797           srcs[5] = LoadUnaligned32(src_x);
798           src_x += src_stride;
799           srcs[6] = LoadUnaligned32(src_x);
800           src_x += src_stride;
801         }
802       }
803     }
804 
805     auto* dst8_x = dst8 + x;
806     auto* dst16_x = dst16 + x;
807     int y = height;
808     do {
809       srcs[next_row] = LoadUnaligned32(src_x);
810       src_x += src_stride;
811 
812       const __m256i sums = SumVerticalTaps<num_taps>(srcs, v_tap);
813       const __m256i sums_hi =
814           SumVerticalTaps<num_taps, /*unpack_high=*/true>(srcs, v_tap);
815       if (is_compound) {
816         const __m256i results =
817             Compound1DShift(_mm256_permute2x128_si256(sums, sums_hi, 0x20));
818         const __m256i results_hi =
819             Compound1DShift(_mm256_permute2x128_si256(sums, sums_hi, 0x31));
820         StoreUnaligned32(dst16_x, results);
821         StoreUnaligned32(dst16_x + 16, results_hi);
822         dst16_x += dst_stride;
823       } else {
824         const __m256i results =
825             RightShiftWithRounding_S16(sums, kFilterBits - 1);
826         const __m256i results_hi =
827             RightShiftWithRounding_S16(sums_hi, kFilterBits - 1);
828         const __m256i packed_results = _mm256_packus_epi16(results, results_hi);
829 
830         StoreUnaligned32(dst8_x, packed_results);
831         dst8_x += dst_stride;
832       }
833 
834       srcs[0] = srcs[1];
835       if (num_taps >= 4) {
836         srcs[1] = srcs[2];
837         srcs[2] = srcs[3];
838         if (num_taps >= 6) {
839           srcs[3] = srcs[4];
840           srcs[4] = srcs[5];
841           if (num_taps == 8) {
842             srcs[5] = srcs[6];
843             srcs[6] = srcs[7];
844           }
845         }
846       }
847     } while (--y != 0);
848     x += 32;
849   } while (x < width);
850 }
851 
852 template <int num_taps, bool is_compound = false>
FilterVertical16xH(const uint8_t * LIBGAV1_RESTRICT src,const ptrdiff_t src_stride,void * LIBGAV1_RESTRICT const dst,const ptrdiff_t dst_stride,const int,const int height,const __m256i * const v_tap)853 void FilterVertical16xH(const uint8_t* LIBGAV1_RESTRICT src,
854                         const ptrdiff_t src_stride,
855                         void* LIBGAV1_RESTRICT const dst,
856                         const ptrdiff_t dst_stride, const int /*width*/,
857                         const int height, const __m256i* const v_tap) {
858   const int next_row = num_taps;
859   auto* dst8 = static_cast<uint8_t*>(dst);
860   auto* dst16 = static_cast<uint16_t*>(dst);
861 
862   const uint8_t* src_x = src;
863   __m256i srcs[8 + 1];
864   // The upper 128 bits hold the filter data for the next row.
865   srcs[0] = _mm256_castsi128_si256(LoadUnaligned16(src_x));
866   src_x += src_stride;
867   if (num_taps >= 4) {
868     srcs[1] = _mm256_castsi128_si256(LoadUnaligned16(src_x));
869     src_x += src_stride;
870     srcs[0] =
871         _mm256_inserti128_si256(srcs[0], _mm256_castsi256_si128(srcs[1]), 1);
872     srcs[2] = _mm256_castsi128_si256(LoadUnaligned16(src_x));
873     src_x += src_stride;
874     srcs[1] =
875         _mm256_inserti128_si256(srcs[1], _mm256_castsi256_si128(srcs[2]), 1);
876     if (num_taps >= 6) {
877       srcs[3] = _mm256_castsi128_si256(LoadUnaligned16(src_x));
878       src_x += src_stride;
879       srcs[2] =
880           _mm256_inserti128_si256(srcs[2], _mm256_castsi256_si128(srcs[3]), 1);
881       srcs[4] = _mm256_castsi128_si256(LoadUnaligned16(src_x));
882       src_x += src_stride;
883       srcs[3] =
884           _mm256_inserti128_si256(srcs[3], _mm256_castsi256_si128(srcs[4]), 1);
885       if (num_taps == 8) {
886         srcs[5] = _mm256_castsi128_si256(LoadUnaligned16(src_x));
887         src_x += src_stride;
888         srcs[4] = _mm256_inserti128_si256(srcs[4],
889                                           _mm256_castsi256_si128(srcs[5]), 1);
890         srcs[6] = _mm256_castsi128_si256(LoadUnaligned16(src_x));
891         src_x += src_stride;
892         srcs[5] = _mm256_inserti128_si256(srcs[5],
893                                           _mm256_castsi256_si128(srcs[6]), 1);
894       }
895     }
896   }
897 
898   int y = height;
899   do {
900     srcs[next_row - 1] = _mm256_castsi128_si256(LoadUnaligned16(src_x));
901     src_x += src_stride;
902 
903     srcs[next_row - 2] = _mm256_inserti128_si256(
904         srcs[next_row - 2], _mm256_castsi256_si128(srcs[next_row - 1]), 1);
905 
906     srcs[next_row] = _mm256_castsi128_si256(LoadUnaligned16(src_x));
907     src_x += src_stride;
908 
909     srcs[next_row - 1] = _mm256_inserti128_si256(
910         srcs[next_row - 1], _mm256_castsi256_si128(srcs[next_row]), 1);
911 
912     const __m256i sums = SumVerticalTaps<num_taps>(srcs, v_tap);
913     const __m256i sums_hi =
914         SumVerticalTaps<num_taps, /*unpack_high=*/true>(srcs, v_tap);
915     if (is_compound) {
916       const __m256i results =
917           Compound1DShift(_mm256_permute2x128_si256(sums, sums_hi, 0x20));
918       const __m256i results_hi =
919           Compound1DShift(_mm256_permute2x128_si256(sums, sums_hi, 0x31));
920 
921       StoreUnaligned32(dst16, results);
922       StoreUnaligned32(dst16 + dst_stride, results_hi);
923       dst16 += dst_stride << 1;
924     } else {
925       const __m256i results = RightShiftWithRounding_S16(sums, kFilterBits - 1);
926       const __m256i results_hi =
927           RightShiftWithRounding_S16(sums_hi, kFilterBits - 1);
928       const __m256i packed_results = _mm256_packus_epi16(results, results_hi);
929       const __m128i this_dst = _mm256_castsi256_si128(packed_results);
930       const auto next_dst = _mm256_extracti128_si256(packed_results, 1);
931 
932       StoreUnaligned16(dst8, this_dst);
933       StoreUnaligned16(dst8 + dst_stride, next_dst);
934       dst8 += dst_stride << 1;
935     }
936 
937     srcs[0] = srcs[2];
938     if (num_taps >= 4) {
939       srcs[1] = srcs[3];
940       srcs[2] = srcs[4];
941       if (num_taps >= 6) {
942         srcs[3] = srcs[5];
943         srcs[4] = srcs[6];
944         if (num_taps == 8) {
945           srcs[5] = srcs[7];
946           srcs[6] = srcs[8];
947         }
948       }
949     }
950     y -= 2;
951   } while (y != 0);
952 }
953 
954 template <int num_taps, bool is_compound = false>
FilterVertical8xH(const uint8_t * LIBGAV1_RESTRICT src,const ptrdiff_t src_stride,void * LIBGAV1_RESTRICT const dst,const ptrdiff_t dst_stride,const int,const int height,const __m256i * const v_tap)955 void FilterVertical8xH(const uint8_t* LIBGAV1_RESTRICT src,
956                        const ptrdiff_t src_stride,
957                        void* LIBGAV1_RESTRICT const dst,
958                        const ptrdiff_t dst_stride, const int /*width*/,
959                        const int height, const __m256i* const v_tap) {
960   const int next_row = num_taps;
961   auto* dst8 = static_cast<uint8_t*>(dst);
962   auto* dst16 = static_cast<uint16_t*>(dst);
963 
964   const uint8_t* src_x = src;
965   __m256i srcs[8 + 1];
966   // The upper 128 bits hold the filter data for the next row.
967   srcs[0] = _mm256_castsi128_si256(LoadLo8(src_x));
968   src_x += src_stride;
969   if (num_taps >= 4) {
970     srcs[1] = _mm256_castsi128_si256(LoadLo8(src_x));
971     src_x += src_stride;
972     srcs[0] =
973         _mm256_inserti128_si256(srcs[0], _mm256_castsi256_si128(srcs[1]), 1);
974     srcs[2] = _mm256_castsi128_si256(LoadLo8(src_x));
975     src_x += src_stride;
976     srcs[1] =
977         _mm256_inserti128_si256(srcs[1], _mm256_castsi256_si128(srcs[2]), 1);
978     if (num_taps >= 6) {
979       srcs[3] = _mm256_castsi128_si256(LoadLo8(src_x));
980       src_x += src_stride;
981       srcs[2] =
982           _mm256_inserti128_si256(srcs[2], _mm256_castsi256_si128(srcs[3]), 1);
983       srcs[4] = _mm256_castsi128_si256(LoadLo8(src_x));
984       src_x += src_stride;
985       srcs[3] =
986           _mm256_inserti128_si256(srcs[3], _mm256_castsi256_si128(srcs[4]), 1);
987       if (num_taps == 8) {
988         srcs[5] = _mm256_castsi128_si256(LoadLo8(src_x));
989         src_x += src_stride;
990         srcs[4] = _mm256_inserti128_si256(srcs[4],
991                                           _mm256_castsi256_si128(srcs[5]), 1);
992         srcs[6] = _mm256_castsi128_si256(LoadLo8(src_x));
993         src_x += src_stride;
994         srcs[5] = _mm256_inserti128_si256(srcs[5],
995                                           _mm256_castsi256_si128(srcs[6]), 1);
996       }
997     }
998   }
999 
1000   int y = height;
1001   do {
1002     srcs[next_row - 1] = _mm256_castsi128_si256(LoadLo8(src_x));
1003     src_x += src_stride;
1004 
1005     srcs[next_row - 2] = _mm256_inserti128_si256(
1006         srcs[next_row - 2], _mm256_castsi256_si128(srcs[next_row - 1]), 1);
1007 
1008     srcs[next_row] = _mm256_castsi128_si256(LoadLo8(src_x));
1009     src_x += src_stride;
1010 
1011     srcs[next_row - 1] = _mm256_inserti128_si256(
1012         srcs[next_row - 1], _mm256_castsi256_si128(srcs[next_row]), 1);
1013 
1014     const __m256i sums = SumVerticalTaps<num_taps>(srcs, v_tap);
1015     if (is_compound) {
1016       const __m256i results = Compound1DShift(sums);
1017       const __m128i this_dst = _mm256_castsi256_si128(results);
1018       const auto next_dst = _mm256_extracti128_si256(results, 1);
1019 
1020       StoreUnaligned16(dst16, this_dst);
1021       StoreUnaligned16(dst16 + dst_stride, next_dst);
1022       dst16 += dst_stride << 1;
1023     } else {
1024       const __m256i results = RightShiftWithRounding_S16(sums, kFilterBits - 1);
1025       const __m256i packed_results = _mm256_packus_epi16(results, results);
1026       const __m128i this_dst = _mm256_castsi256_si128(packed_results);
1027       const auto next_dst = _mm256_extracti128_si256(packed_results, 1);
1028 
1029       StoreLo8(dst8, this_dst);
1030       StoreLo8(dst8 + dst_stride, next_dst);
1031       dst8 += dst_stride << 1;
1032     }
1033 
1034     srcs[0] = srcs[2];
1035     if (num_taps >= 4) {
1036       srcs[1] = srcs[3];
1037       srcs[2] = srcs[4];
1038       if (num_taps >= 6) {
1039         srcs[3] = srcs[5];
1040         srcs[4] = srcs[6];
1041         if (num_taps == 8) {
1042           srcs[5] = srcs[7];
1043           srcs[6] = srcs[8];
1044         }
1045       }
1046     }
1047     y -= 2;
1048   } while (y != 0);
1049 }
1050 
1051 template <int num_taps, bool is_compound = false>
FilterVertical8xH(const uint8_t * LIBGAV1_RESTRICT src,const ptrdiff_t src_stride,void * LIBGAV1_RESTRICT const dst,const ptrdiff_t dst_stride,const int,const int height,const __m128i * const v_tap)1052 void FilterVertical8xH(const uint8_t* LIBGAV1_RESTRICT src,
1053                        const ptrdiff_t src_stride,
1054                        void* LIBGAV1_RESTRICT const dst,
1055                        const ptrdiff_t dst_stride, const int /*width*/,
1056                        const int height, const __m128i* const v_tap) {
1057   const int next_row = num_taps - 1;
1058   auto* dst8 = static_cast<uint8_t*>(dst);
1059   auto* dst16 = static_cast<uint16_t*>(dst);
1060 
1061   const uint8_t* src_x = src;
1062   __m128i srcs[8];
1063   srcs[0] = LoadLo8(src_x);
1064   src_x += src_stride;
1065   if (num_taps >= 4) {
1066     srcs[1] = LoadLo8(src_x);
1067     src_x += src_stride;
1068     srcs[2] = LoadLo8(src_x);
1069     src_x += src_stride;
1070     if (num_taps >= 6) {
1071       srcs[3] = LoadLo8(src_x);
1072       src_x += src_stride;
1073       srcs[4] = LoadLo8(src_x);
1074       src_x += src_stride;
1075       if (num_taps == 8) {
1076         srcs[5] = LoadLo8(src_x);
1077         src_x += src_stride;
1078         srcs[6] = LoadLo8(src_x);
1079         src_x += src_stride;
1080       }
1081     }
1082   }
1083 
1084   int y = height;
1085   do {
1086     srcs[next_row] = LoadLo8(src_x);
1087     src_x += src_stride;
1088 
1089     const __m128i sums = SumVerticalTaps<num_taps>(srcs, v_tap);
1090     if (is_compound) {
1091       const __m128i results = Compound1DShift(sums);
1092       StoreUnaligned16(dst16, results);
1093       dst16 += dst_stride;
1094     } else {
1095       const __m128i results = RightShiftWithRounding_S16(sums, kFilterBits - 1);
1096       StoreLo8(dst8, _mm_packus_epi16(results, results));
1097       dst8 += dst_stride;
1098     }
1099 
1100     srcs[0] = srcs[1];
1101     if (num_taps >= 4) {
1102       srcs[1] = srcs[2];
1103       srcs[2] = srcs[3];
1104       if (num_taps >= 6) {
1105         srcs[3] = srcs[4];
1106         srcs[4] = srcs[5];
1107         if (num_taps == 8) {
1108           srcs[5] = srcs[6];
1109           srcs[6] = srcs[7];
1110         }
1111       }
1112     }
1113   } while (--y != 0);
1114 }
1115 
ConvolveVertical_AVX2(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int,const int vertical_filter_index,const int,const int vertical_filter_id,const int width,const int height,void * LIBGAV1_RESTRICT prediction,const ptrdiff_t pred_stride)1116 void ConvolveVertical_AVX2(const void* LIBGAV1_RESTRICT const reference,
1117                            const ptrdiff_t reference_stride,
1118                            const int /*horizontal_filter_index*/,
1119                            const int vertical_filter_index,
1120                            const int /*horizontal_filter_id*/,
1121                            const int vertical_filter_id, const int width,
1122                            const int height, void* LIBGAV1_RESTRICT prediction,
1123                            const ptrdiff_t pred_stride) {
1124   const int filter_index = GetFilterIndex(vertical_filter_index, height);
1125   const int vertical_taps =
1126       GetNumTapsInFilter(filter_index, vertical_filter_id);
1127   const ptrdiff_t src_stride = reference_stride;
1128   const auto* src = static_cast<const uint8_t*>(reference) -
1129                     (vertical_taps / 2 - 1) * src_stride;
1130   auto* dest = static_cast<uint8_t*>(prediction);
1131   const ptrdiff_t dest_stride = pred_stride;
1132   assert(vertical_filter_id != 0);
1133 
1134   const __m128i v_filter =
1135       LoadLo8(kHalfSubPixelFilters[filter_index][vertical_filter_id]);
1136 
1137   // Use 256 bits for width > 4.
1138   if (width > 4) {
1139     __m256i taps_256[4];
1140     if (vertical_taps == 6) {  // 6 tap.
1141       SetupTaps<6>(&v_filter, taps_256);
1142       if (width == 8) {
1143         FilterVertical8xH<6>(src, src_stride, dest, dest_stride, width, height,
1144                              taps_256);
1145       } else if (width == 16) {
1146         FilterVertical16xH<6>(src, src_stride, dest, dest_stride, width, height,
1147                               taps_256);
1148       } else {
1149         FilterVertical32xH<6>(src, src_stride, dest, dest_stride, width, height,
1150                               taps_256);
1151       }
1152     } else if (vertical_taps == 8) {  // 8 tap.
1153       SetupTaps<8>(&v_filter, taps_256);
1154       if (width == 8) {
1155         FilterVertical8xH<8>(src, src_stride, dest, dest_stride, width, height,
1156                              taps_256);
1157       } else if (width == 16) {
1158         FilterVertical16xH<8>(src, src_stride, dest, dest_stride, width, height,
1159                               taps_256);
1160       } else {
1161         FilterVertical32xH<8>(src, src_stride, dest, dest_stride, width, height,
1162                               taps_256);
1163       }
1164     } else if (vertical_taps == 2) {  // 2 tap.
1165       SetupTaps<2>(&v_filter, taps_256);
1166       if (width == 8) {
1167         FilterVertical8xH<2>(src, src_stride, dest, dest_stride, width, height,
1168                              taps_256);
1169       } else if (width == 16) {
1170         FilterVertical16xH<2>(src, src_stride, dest, dest_stride, width, height,
1171                               taps_256);
1172       } else {
1173         FilterVertical32xH<2>(src, src_stride, dest, dest_stride, width, height,
1174                               taps_256);
1175       }
1176     } else {  // 4 tap.
1177       SetupTaps<4>(&v_filter, taps_256);
1178       if (width == 8) {
1179         FilterVertical8xH<4>(src, src_stride, dest, dest_stride, width, height,
1180                              taps_256);
1181       } else if (width == 16) {
1182         FilterVertical16xH<4>(src, src_stride, dest, dest_stride, width, height,
1183                               taps_256);
1184       } else {
1185         FilterVertical32xH<4>(src, src_stride, dest, dest_stride, width, height,
1186                               taps_256);
1187       }
1188     }
1189   } else {  // width <= 8
1190     // Use 128 bit code.
1191     __m128i taps[4];
1192 
1193     if (vertical_taps == 6) {  // 6 tap.
1194       SetupTaps<6>(&v_filter, taps);
1195       if (width == 2) {
1196         FilterVertical2xH<6>(src, src_stride, dest, dest_stride, height, taps);
1197       } else {
1198         FilterVertical4xH<6>(src, src_stride, dest, dest_stride, height, taps);
1199       }
1200     } else if (vertical_taps == 8) {  // 8 tap.
1201       SetupTaps<8>(&v_filter, taps);
1202       if (width == 2) {
1203         FilterVertical2xH<8>(src, src_stride, dest, dest_stride, height, taps);
1204       } else {
1205         FilterVertical4xH<8>(src, src_stride, dest, dest_stride, height, taps);
1206       }
1207     } else if (vertical_taps == 2) {  // 2 tap.
1208       SetupTaps<2>(&v_filter, taps);
1209       if (width == 2) {
1210         FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps);
1211       } else {
1212         FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps);
1213       }
1214     } else {  // 4 tap.
1215       SetupTaps<4>(&v_filter, taps);
1216       if (width == 2) {
1217         FilterVertical2xH<4>(src, src_stride, dest, dest_stride, height, taps);
1218       } else {
1219         FilterVertical4xH<4>(src, src_stride, dest, dest_stride, height, taps);
1220       }
1221     }
1222   }
1223 }
1224 
ConvolveCompoundVertical_AVX2(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int,const int vertical_filter_index,const int,const int vertical_filter_id,const int width,const int height,void * LIBGAV1_RESTRICT prediction,const ptrdiff_t)1225 void ConvolveCompoundVertical_AVX2(
1226     const void* LIBGAV1_RESTRICT const reference,
1227     const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
1228     const int vertical_filter_index, const int /*horizontal_filter_id*/,
1229     const int vertical_filter_id, const int width, const int height,
1230     void* LIBGAV1_RESTRICT prediction, const ptrdiff_t /*pred_stride*/) {
1231   const int filter_index = GetFilterIndex(vertical_filter_index, height);
1232   const int vertical_taps =
1233       GetNumTapsInFilter(filter_index, vertical_filter_id);
1234   const ptrdiff_t src_stride = reference_stride;
1235   const auto* src = static_cast<const uint8_t*>(reference) -
1236                     (vertical_taps / 2 - 1) * src_stride;
1237   auto* dest = static_cast<uint8_t*>(prediction);
1238   const ptrdiff_t dest_stride = width;
1239   assert(vertical_filter_id != 0);
1240 
1241   const __m128i v_filter =
1242       LoadLo8(kHalfSubPixelFilters[filter_index][vertical_filter_id]);
1243 
1244   // Use 256 bits for width > 4.
1245   if (width > 4) {
1246     __m256i taps_256[4];
1247     if (vertical_taps == 6) {  // 6 tap.
1248       SetupTaps<6>(&v_filter, taps_256);
1249       if (width == 8) {
1250         FilterVertical8xH<6, /*is_compound=*/true>(
1251             src, src_stride, dest, dest_stride, width, height, taps_256);
1252       } else if (width == 16) {
1253         FilterVertical16xH<6, /*is_compound=*/true>(
1254             src, src_stride, dest, dest_stride, width, height, taps_256);
1255       } else {
1256         FilterVertical32xH<6, /*is_compound=*/true>(
1257             src, src_stride, dest, dest_stride, width, height, taps_256);
1258       }
1259     } else if (vertical_taps == 8) {  // 8 tap.
1260       SetupTaps<8>(&v_filter, taps_256);
1261       if (width == 8) {
1262         FilterVertical8xH<8, /*is_compound=*/true>(
1263             src, src_stride, dest, dest_stride, width, height, taps_256);
1264       } else if (width == 16) {
1265         FilterVertical16xH<8, /*is_compound=*/true>(
1266             src, src_stride, dest, dest_stride, width, height, taps_256);
1267       } else {
1268         FilterVertical32xH<8, /*is_compound=*/true>(
1269             src, src_stride, dest, dest_stride, width, height, taps_256);
1270       }
1271     } else if (vertical_taps == 2) {  // 2 tap.
1272       SetupTaps<2>(&v_filter, taps_256);
1273       if (width == 8) {
1274         FilterVertical8xH<2, /*is_compound=*/true>(
1275             src, src_stride, dest, dest_stride, width, height, taps_256);
1276       } else if (width == 16) {
1277         FilterVertical16xH<2, /*is_compound=*/true>(
1278             src, src_stride, dest, dest_stride, width, height, taps_256);
1279       } else {
1280         FilterVertical32xH<2, /*is_compound=*/true>(
1281             src, src_stride, dest, dest_stride, width, height, taps_256);
1282       }
1283     } else {  // 4 tap.
1284       SetupTaps<4>(&v_filter, taps_256);
1285       if (width == 8) {
1286         FilterVertical8xH<4, /*is_compound=*/true>(
1287             src, src_stride, dest, dest_stride, width, height, taps_256);
1288       } else if (width == 16) {
1289         FilterVertical16xH<4, /*is_compound=*/true>(
1290             src, src_stride, dest, dest_stride, width, height, taps_256);
1291       } else {
1292         FilterVertical32xH<4, /*is_compound=*/true>(
1293             src, src_stride, dest, dest_stride, width, height, taps_256);
1294       }
1295     }
1296   } else {  // width <= 4
1297     // Use 128 bit code.
1298     __m128i taps[4];
1299 
1300     if (vertical_taps == 6) {  // 6 tap.
1301       SetupTaps<6>(&v_filter, taps);
1302       FilterVertical4xH<6, /*is_compound=*/true>(src, src_stride, dest,
1303                                                  dest_stride, height, taps);
1304     } else if (vertical_taps == 8) {  // 8 tap.
1305       SetupTaps<8>(&v_filter, taps);
1306       FilterVertical4xH<8, /*is_compound=*/true>(src, src_stride, dest,
1307                                                  dest_stride, height, taps);
1308     } else if (vertical_taps == 2) {  // 2 tap.
1309       SetupTaps<2>(&v_filter, taps);
1310       FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest,
1311                                                  dest_stride, height, taps);
1312     } else {  // 4 tap.
1313       SetupTaps<4>(&v_filter, taps);
1314       FilterVertical4xH<4, /*is_compound=*/true>(src, src_stride, dest,
1315                                                  dest_stride, height, taps);
1316     }
1317   }
1318 }
1319 
ConvolveHorizontal_AVX2(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int horizontal_filter_index,const int,const int horizontal_filter_id,const int,const int width,const int height,void * LIBGAV1_RESTRICT prediction,const ptrdiff_t pred_stride)1320 void ConvolveHorizontal_AVX2(
1321     const void* LIBGAV1_RESTRICT const reference,
1322     const ptrdiff_t reference_stride, const int horizontal_filter_index,
1323     const int /*vertical_filter_index*/, const int horizontal_filter_id,
1324     const int /*vertical_filter_id*/, const int width, const int height,
1325     void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
1326   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
1327   // Set |src| to the outermost tap.
1328   const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
1329   auto* dest = static_cast<uint8_t*>(prediction);
1330 
1331   if (width > 2) {
1332     DoHorizontalPass(src, reference_stride, dest, pred_stride, width, height,
1333                      horizontal_filter_id, filter_index);
1334   } else {
1335     // Use non avx2 version for smaller widths.
1336     DoHorizontalPass2xH(src, reference_stride, dest, pred_stride, width, height,
1337                         horizontal_filter_id, filter_index);
1338   }
1339 }
1340 
ConvolveCompoundHorizontal_AVX2(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int horizontal_filter_index,const int,const int horizontal_filter_id,const int,const int width,const int height,void * LIBGAV1_RESTRICT prediction,const ptrdiff_t pred_stride)1341 void ConvolveCompoundHorizontal_AVX2(
1342     const void* LIBGAV1_RESTRICT const reference,
1343     const ptrdiff_t reference_stride, const int horizontal_filter_index,
1344     const int /*vertical_filter_index*/, const int horizontal_filter_id,
1345     const int /*vertical_filter_id*/, const int width, const int height,
1346     void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
1347   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
1348   // Set |src| to the outermost tap.
1349   const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
1350   auto* dest = static_cast<uint8_t*>(prediction);
1351   // All compound functions output to the predictor buffer with |pred_stride|
1352   // equal to |width|.
1353   assert(pred_stride == width);
1354   // Compound functions start at 4x4.
1355   assert(width >= 4 && height >= 4);
1356 
1357 #ifdef NDEBUG
1358   // Quiet compiler error.
1359   (void)pred_stride;
1360 #endif
1361 
1362   DoHorizontalPass</*is_2d=*/false, /*is_compound=*/true>(
1363       src, reference_stride, dest, width, width, height, horizontal_filter_id,
1364       filter_index);
1365 }
1366 
ConvolveCompound2D_AVX2(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int horizontal_filter_index,const int vertical_filter_index,const int horizontal_filter_id,const int vertical_filter_id,const int width,const int height,void * LIBGAV1_RESTRICT prediction,const ptrdiff_t pred_stride)1367 void ConvolveCompound2D_AVX2(
1368     const void* LIBGAV1_RESTRICT const reference,
1369     const ptrdiff_t reference_stride, const int horizontal_filter_index,
1370     const int vertical_filter_index, const int horizontal_filter_id,
1371     const int vertical_filter_id, const int width, const int height,
1372     void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
1373   const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
1374   const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
1375   const int vertical_taps =
1376       GetNumTapsInFilter(vert_filter_index, vertical_filter_id);
1377 
1378   // The output of the horizontal filter is guaranteed to fit in 16 bits.
1379   alignas(32) uint16_t
1380       intermediate_result[kMaxSuperBlockSizeInPixels *
1381                           (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
1382 #if LIBGAV1_MSAN
1383   // Quiet msan warnings. Set with random non-zero value to aid in debugging.
1384   memset(intermediate_result, 0x33, sizeof(intermediate_result));
1385 #endif
1386   const int intermediate_height = height + vertical_taps - 1;
1387 
1388   const ptrdiff_t src_stride = reference_stride;
1389   const auto* src = static_cast<const uint8_t*>(reference) -
1390                     (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset;
1391   DoHorizontalPass</*is_2d=*/true, /*is_compound=*/true>(
1392       src, src_stride, intermediate_result, width, width, intermediate_height,
1393       horizontal_filter_id, horiz_filter_index);
1394 
1395   // Vertical filter.
1396   auto* dest = static_cast<uint8_t*>(prediction);
1397   const ptrdiff_t dest_stride = pred_stride;
1398   assert(vertical_filter_id != 0);
1399 
1400   const __m128i v_filter =
1401       LoadLo8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]);
1402 
1403   // Use 256 bits for width > 8.
1404   if (width > 8) {
1405     __m256i taps_256[4];
1406     const __m128i v_filter_ext = _mm_cvtepi8_epi16(v_filter);
1407 
1408     if (vertical_taps == 8) {
1409       SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256);
1410       Filter2DVertical16xH<8, /*is_compound=*/true>(
1411           intermediate_result, dest, dest_stride, width, height, taps_256);
1412     } else if (vertical_taps == 6) {
1413       SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256);
1414       Filter2DVertical16xH<6, /*is_compound=*/true>(
1415           intermediate_result, dest, dest_stride, width, height, taps_256);
1416     } else if (vertical_taps == 4) {
1417       SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256);
1418       Filter2DVertical16xH<4, /*is_compound=*/true>(
1419           intermediate_result, dest, dest_stride, width, height, taps_256);
1420     } else {  // |vertical_taps| == 2
1421       SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter_ext, taps_256);
1422       Filter2DVertical16xH<2, /*is_compound=*/true>(
1423           intermediate_result, dest, dest_stride, width, height, taps_256);
1424     }
1425   } else {  // width <= 8
1426     __m128i taps[4];
1427     // Use 128 bit code.
1428     if (vertical_taps == 8) {
1429       SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter, taps);
1430       if (width == 4) {
1431         Filter2DVertical4xH<8, /*is_compound=*/true>(intermediate_result, dest,
1432                                                      dest_stride, height, taps);
1433       } else {
1434         Filter2DVertical<8, /*is_compound=*/true>(
1435             intermediate_result, dest, dest_stride, width, height, taps);
1436       }
1437     } else if (vertical_taps == 6) {
1438       SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter, taps);
1439       if (width == 4) {
1440         Filter2DVertical4xH<6, /*is_compound=*/true>(intermediate_result, dest,
1441                                                      dest_stride, height, taps);
1442       } else {
1443         Filter2DVertical<6, /*is_compound=*/true>(
1444             intermediate_result, dest, dest_stride, width, height, taps);
1445       }
1446     } else if (vertical_taps == 4) {
1447       SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter, taps);
1448       if (width == 4) {
1449         Filter2DVertical4xH<4, /*is_compound=*/true>(intermediate_result, dest,
1450                                                      dest_stride, height, taps);
1451       } else {
1452         Filter2DVertical<4, /*is_compound=*/true>(
1453             intermediate_result, dest, dest_stride, width, height, taps);
1454       }
1455     } else {  // |vertical_taps| == 2
1456       SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter, taps);
1457       if (width == 4) {
1458         Filter2DVertical4xH<2, /*is_compound=*/true>(intermediate_result, dest,
1459                                                      dest_stride, height, taps);
1460       } else {
1461         Filter2DVertical<2, /*is_compound=*/true>(
1462             intermediate_result, dest, dest_stride, width, height, taps);
1463       }
1464     }
1465   }
1466 }
1467 
Init8bpp()1468 void Init8bpp() {
1469   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
1470   assert(dsp != nullptr);
1471   dsp->convolve[0][0][0][1] = ConvolveHorizontal_AVX2;
1472   dsp->convolve[0][0][1][0] = ConvolveVertical_AVX2;
1473   dsp->convolve[0][0][1][1] = Convolve2D_AVX2;
1474 
1475   dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_AVX2;
1476   dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_AVX2;
1477   dsp->convolve[0][1][1][1] = ConvolveCompound2D_AVX2;
1478 }
1479 
1480 }  // namespace
1481 }  // namespace low_bitdepth
1482 
ConvolveInit_AVX2()1483 void ConvolveInit_AVX2() { low_bitdepth::Init8bpp(); }
1484 
1485 }  // namespace dsp
1486 }  // namespace libgav1
1487 
1488 #else   // !LIBGAV1_TARGETING_AVX2
1489 namespace libgav1 {
1490 namespace dsp {
1491 
// Build without AVX2 targeting: nothing to register, so this is a no-op that
// keeps the symbol available to callers.
void ConvolveInit_AVX2() {
}
1493 
1494 }  // namespace dsp
1495 }  // namespace libgav1
1496 #endif  // LIBGAV1_TARGETING_AVX2
1497