/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <immintrin.h>
#include <math.h>

#include "aom_dsp/aom_dsp_common.h"
#include "av1/common/av1_common_int.h"
#include "av1/encoder/cnn.h"

// These masks rearrange source pixels in the order shown below.
// shuffle_src_layer0[0][8]: applied to source pixels 0 to 7.
// shuffle_src_layer0[1][8]: applied to source pixels 7 to 14.
// This shuffling is needed to process 3 5x5 blocks, which need
// source pixels in the following order.
// 1st 5x5 block: source pixels needed are 0 to 4,
// 2nd 5x5 block: source pixels needed are 4 to 8,
// 3rd 5x5 block: source pixels needed are 8 to 12.
// Source pixels are loaded as shown below.
// load_src0 : 0, 1, 2, 3, 4, 5, 6, 7
// load_src1 : 7, 8, 9, 10, 11, 12, 13, 14
// After applying the masks, the source pixels will be in the order:
// load_src0 : 0, 1, 2, 3, 4, 4, 5, 6
//             consists of the 5 pixels needed for the 1st 5x5 block and
//             the first 3 pixels needed for the 2nd 5x5 block.
// load_src1 : 7, 8, 8, 9, 10, 11, 12, x
//             consists of the last 2 pixels needed for the 2nd 5x5 block and
//             the 5 pixels needed for the 3rd 5x5 block.
DECLARE_ALIGNED(32, static const uint32_t,
                shuffle_src_layer0[2][8]) = { { 0, 1, 2, 3, 4, 4, 5, 6 },
                                              { 0, 1, 1, 2, 3, 4, 5, 0 } };

// These masks rearrange the weights to match the shuffled source pixel order.
DECLARE_ALIGNED(32, static const uint32_t,
                shuffle_weight_layer0[2][8]) = { { 0, 1, 2, 3, 4, 0, 1, 2 },
                                                 { 3, 4, 0, 1, 2, 3, 4, 0 } };

// Shuffle masks used to rearrange the weights corresponding to layer 1 and
// layer 2. For layer 1 and layer 2, convolution happens at 2x2 as
// filter_width and filter_height are equal to 2. So the weights are
// rearranged in the order shown below to match the source pixels. Basically
// these masks replicate the weights across the width of 2.
DECLARE_ALIGNED(32, static const uint32_t,
                shuffle_weight_layer_1_and_2[2][8]) = {
  { 0, 1, 0, 1, 0, 1, 0, 1 }, { 2, 3, 2, 3, 2, 3, 2, 3 }
};
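// For a 2x2 filter with weights { w0, w1, w2, w3 } (row-major), these masks
// produce { w0, w1, w0, w1, ... } and { w2, w3, w2, w3, ... }, so one
// register covers the top rows of 4 horizontal 2x2 blocks per 8-wide source
// load and the other covers the bottom rows.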

// After the stages of multiplication and accumulation, the output values
// in the register are jumbled. In order to store the register into the
// output buffer in the proper order, the following mask is applied to the
// output register.
DECLARE_ALIGNED(32, static const uint32_t,
                shuffle_output_layer_1_and_2[8]) = { 0, 1, 4, 5, 2, 3, 6, 7 };
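// The jumbling comes from _mm256_hadd_ps(), which operates within each
// 128-bit lane, leaving the eight 2x2 block sums in lane order
// 0, 1, 4, 5, 2, 3, 6, 7; permuting with the mask above restores the
// sequential block order 0 to 7.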

// Load weights needed for layer 0 (for 5x5 block processing),
// and fill the registers appropriately to match source pixel mapping.
static inline void prepare_weights_for_5x5_convolve(
    const float *layer_config_weights, int off, float weight[5][8],
    const int cstep, __m256 *shuffle_weight, const __m256i weight_mask_0,
    const __m256i weight_mask_1) {
  for (int row = 0; row < 5; ++row) {
    for (int col = 0; col < 5; ++col) {
      weight[row][col] = layer_config_weights[off];
      off += cstep;
    }
  }
  shuffle_weight[0] = _mm256_loadu_ps(weight[0]);
  shuffle_weight[1] = _mm256_loadu_ps(weight[1]);
  shuffle_weight[2] = _mm256_loadu_ps(weight[2]);
  shuffle_weight[3] = _mm256_loadu_ps(weight[3]);
  shuffle_weight[4] = _mm256_loadu_ps(weight[4]);

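  // Rearrange each row's weights to match the shuffled source order:
  // weight_mask_0 turns { w0, w1, w2, w3, w4, ... } into
  // { w0, w1, w2, w3, w4, w0, w1, w2 } (1st block plus the first 3 columns
  // of the 2nd block); applying weight_mask_1 to those permuted rows yields
  // { w3, w4, w0, w1, w2, w3, w4, w0 } (last 2 columns of the 2nd block plus
  // the 3rd block).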
  shuffle_weight[0] =
      _mm256_permutevar8x32_ps(shuffle_weight[0], weight_mask_0);
  shuffle_weight[1] =
      _mm256_permutevar8x32_ps(shuffle_weight[1], weight_mask_0);
  shuffle_weight[2] =
      _mm256_permutevar8x32_ps(shuffle_weight[2], weight_mask_0);
  shuffle_weight[3] =
      _mm256_permutevar8x32_ps(shuffle_weight[3], weight_mask_0);
  shuffle_weight[4] =
      _mm256_permutevar8x32_ps(shuffle_weight[4], weight_mask_0);
  shuffle_weight[5] =
      _mm256_permutevar8x32_ps(shuffle_weight[0], weight_mask_1);
  shuffle_weight[6] =
      _mm256_permutevar8x32_ps(shuffle_weight[1], weight_mask_1);
  shuffle_weight[7] =
      _mm256_permutevar8x32_ps(shuffle_weight[2], weight_mask_1);
  shuffle_weight[8] =
      _mm256_permutevar8x32_ps(shuffle_weight[3], weight_mask_1);
  shuffle_weight[9] =
      _mm256_permutevar8x32_ps(shuffle_weight[4], weight_mask_1);
}

// For each row, loads source pixels 0 to 7 (load_src_0) and 7 to 14
// (load_src_1), and arranges them appropriately to process 3 blocks.
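// On completion, each lane of accum_src_0 holds a per-column partial sum
// (summed over the 5 rows) for the 1st block (lanes 0-4) and the first 3
// columns of the 2nd block (lanes 5-7), while accum_src_1 holds the last 2
// columns of the 2nd block (lanes 0-1) and the 3rd block (lanes 2-6);
// lane 7 of accum_src_1 is unused.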
#define PERFORM_CONVOLVE_FOR_3_5X5_BLOCKS()                            \
  do {                                                                 \
    for (int row = 0; row < 5; row++) {                                \
      load_src_0 = _mm256_loadu_ps(input_ptr);                         \
      load_src_1 = _mm256_loadu_ps(input_ptr + 7);                     \
      load_src_0 = _mm256_permutevar8x32_ps(load_src_0, block0_1);     \
      load_src_1 = _mm256_permutevar8x32_ps(load_src_1, block1_2);     \
      load_src_0 = _mm256_mul_ps(load_src_0, shuffle_weight[0 + row]); \
      load_src_1 = _mm256_mul_ps(load_src_1, shuffle_weight[5 + row]); \
      accum_src_0 = _mm256_add_ps(load_src_0, accum_src_0);            \
      accum_src_1 = _mm256_add_ps(load_src_1, accum_src_1);            \
      input_ptr += in_stride;                                          \
    }                                                                  \
  } while (0)

// Load masks needed for shuffling of output and weights.
static inline void load_shuffle_masks_for_2x2_convolve(__m256i *output_mask,
                                                       __m256i *weight_mask) {
  // Load shuffle buffer needed to sort the output.
  *output_mask =
      _mm256_load_si256((const __m256i *)shuffle_output_layer_1_and_2);

  // Load shuffle buffers needed for weight.
  weight_mask[0] =
      _mm256_load_si256((const __m256i *)shuffle_weight_layer_1_and_2[0]);
  weight_mask[1] =
      _mm256_load_si256((const __m256i *)shuffle_weight_layer_1_and_2[1]);
}

// Load weights needed for layers 1 and 2 (for 2x2 block processing),
// and fill the registers appropriately to match source pixel mapping.
static inline void prepare_weights_for_2x2_convolve(
    const float *layer_config_weights, int off, const int cstep,
    __m256 *shuffle_weight, __m256i *weight_mask) {
  // Weights needed for the 2x2 block.
  float weight[4] = { 0 };
  for (int i = 0; i < 4; ++i) {
    weight[i] = layer_config_weights[off];
    off += cstep;
  }

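  // weight_vec holds { w0, w1, w2, w3 } in its low 128 bits; the permutes
  // below broadcast the top-row pair { w0, w1 } and the bottom-row pair
  // { w2, w3 } across all 8 lanes so that each register covers 4 horizontal
  // 2x2 blocks per 8-wide source load.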
  const __m256 weight_vec = _mm256_castps128_ps256(_mm_loadu_ps(weight));
  shuffle_weight[0] = _mm256_permutevar8x32_ps(weight_vec, weight_mask[0]);
  shuffle_weight[1] = _mm256_permutevar8x32_ps(weight_vec, weight_mask[1]);
}

// Do convolution of one 5x5 block.
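// The first 4 columns of each row are multiplied in SSE registers (the low
// half of the corresponding shuffle_weight register); the 5th column is
// accumulated separately as a scalar in last_column_sum using weight[row][4].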
#define PERFORM_CONVOLVE_FOR_1_5X5_BLOCK(w, accum0, in_stride)           \
  do {                                                                   \
    __m128 load_src[5];                                                  \
    load_src[0] = _mm_loadu_ps(input_ptr);                               \
    last_column_sum += input_ptr[4] * weight[0][4];                      \
    input_ptr += in_stride;                                              \
    load_src[1] = _mm_loadu_ps(input_ptr);                               \
    last_column_sum += input_ptr[4] * weight[1][4];                      \
    input_ptr += in_stride;                                              \
    load_src[2] = _mm_loadu_ps(input_ptr);                               \
    last_column_sum += input_ptr[4] * weight[2][4];                      \
    input_ptr += in_stride;                                              \
    load_src[3] = _mm_loadu_ps(input_ptr);                               \
    last_column_sum += input_ptr[4] * weight[3][4];                      \
    input_ptr += in_stride;                                              \
    load_src[4] = _mm_loadu_ps(input_ptr);                               \
    last_column_sum += input_ptr[4] * weight[4][4];                      \
                                                                         \
    load_src[0] = _mm_mul_ps(load_src[0], _mm256_castps256_ps128(w[0])); \
    load_src[1] = _mm_mul_ps(load_src[1], _mm256_castps256_ps128(w[1])); \
    load_src[2] = _mm_mul_ps(load_src[2], _mm256_castps256_ps128(w[2])); \
    load_src[3] = _mm_mul_ps(load_src[3], _mm256_castps256_ps128(w[3])); \
    load_src[4] = _mm_mul_ps(load_src[4], _mm256_castps256_ps128(w[4])); \
                                                                         \
    accum0 = _mm_add_ps(load_src[0], accum0);                            \
    load_src[1] = _mm_add_ps(load_src[1], load_src[2]);                  \
    load_src[3] = _mm_add_ps(load_src[3], load_src[4]);                  \
    load_src[1] = _mm_add_ps(load_src[1], load_src[3]);                  \
    accum0 = _mm_add_ps(accum0, load_src[1]);                            \
  } while (0)

// Do convolution on 8 horizontal 2x2 blocks.
static inline void perform_convolve_for_8h_2x2_blocks(
    const float *input_ptr, int in_stride, __m256 *weight, __m256 *out_accum,
    __m256i shuffle_output_mask) {
  __m256 load_src[4];
  // Load input into source registers.
  load_src[0] = _mm256_loadu_ps(input_ptr);
  load_src[1] = _mm256_loadu_ps(input_ptr + 8);
  load_src[2] = _mm256_loadu_ps(input_ptr + in_stride);
  load_src[3] = _mm256_loadu_ps(input_ptr + in_stride + 8);

  // Multiply the loaded input with corresponding weights.
  load_src[0] = _mm256_mul_ps(load_src[0], weight[0]);
  load_src[1] = _mm256_mul_ps(load_src[1], weight[0]);
  load_src[2] = _mm256_mul_ps(load_src[2], weight[1]);
  load_src[3] = _mm256_mul_ps(load_src[3], weight[1]);

  // Accumulate across 2x2 blocks.
  load_src[0] = _mm256_add_ps(load_src[0], load_src[2]);
  load_src[1] = _mm256_add_ps(load_src[1], load_src[3]);
  load_src[0] = _mm256_hadd_ps(load_src[0], load_src[1]);

  // Sort the output in order to store into output buffer.
  load_src[0] = _mm256_permutevar8x32_ps(load_src[0], shuffle_output_mask);
  *out_accum = _mm256_add_ps(*out_accum, load_src[0]);
}

// Do convolution on 8 (4 horizontal x 2 vertical) 2x2 blocks.
static inline void perform_convolve_for_4hx2v_2x2_blocks(
    const float *input_ptr, int in_stride, __m256 *weight, __m256 *out_accum,
    __m256i shuffle_output_mask) {
  __m256 load_src[4];
  // Load input into source registers.
  load_src[0] = _mm256_loadu_ps(input_ptr);
  load_src[1] = _mm256_loadu_ps(input_ptr + in_stride);
  load_src[2] = _mm256_loadu_ps(input_ptr + (in_stride * 2));
  load_src[3] = _mm256_loadu_ps(input_ptr + (in_stride * 3));

  // Multiply the loaded input with corresponding weights.
  load_src[0] = _mm256_mul_ps(load_src[0], weight[0]);
  load_src[1] = _mm256_mul_ps(load_src[1], weight[1]);
  load_src[2] = _mm256_mul_ps(load_src[2], weight[0]);
  load_src[3] = _mm256_mul_ps(load_src[3], weight[1]);

  // Accumulate across 2x2 blocks.
  load_src[0] = _mm256_add_ps(load_src[0], load_src[1]);
  load_src[2] = _mm256_add_ps(load_src[2], load_src[3]);
  load_src[0] = _mm256_hadd_ps(load_src[0], load_src[2]);

  // Sort the output in order to store into output buffer.
  load_src[0] = _mm256_permutevar8x32_ps(load_src[0], shuffle_output_mask);
  *out_accum = _mm256_add_ps(*out_accum, load_src[0]);
}

// AVX2 variant of av1_cnn_convolve_no_maxpool_padding_valid_c(), when
// filter_width and filter_height are equal to 5.
// CNN convolve parsing is based on av1_intra_mode_cnn_partition_cnn_config.
// Based on the configuration set for each layer, the current encoder
// always chooses the case of no_maxpool_padding_valid.
// Also, for layer 0, convolution happens at the 5x5 level as the
// filter_width and filter_height are set to 5.
static void cnn_convolve_no_maxpool_padding_valid_5x5_avx2(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    int start_idx, const int cstep, const int channel_step) {
  const int kFilterWidth = 5;
  const int kFilterHeight = 5;
  const int kSkipWidth = 4;
  const int kSkipHeight = 4;
  assert(layer_config->filter_width == kFilterWidth &&
         layer_config->filter_height == kFilterHeight);
  assert(layer_config->skip_width == kSkipWidth &&
         layer_config->skip_height == kSkipHeight);

  // Load shuffle buffers needed for source.
  const __m256i block0_1 =
      _mm256_load_si256((const __m256i *)shuffle_src_layer0[0]);
  const __m256i block1_2 =
      _mm256_load_si256((const __m256i *)shuffle_src_layer0[1]);

  // Load shuffle buffers needed for weight.
  const __m256i weight_mask_0 =
      _mm256_load_si256((const __m256i *)shuffle_weight_layer0[0]);
  const __m256i weight_mask_1 =
      _mm256_load_si256((const __m256i *)shuffle_weight_layer0[1]);

  // Width to advance by to move to the next iteration of processing 3 5x5
  // blocks.
  const int kSkipWidthForNextIter = kSkipWidth * 3;

  // Minimum width required to process 3 5x5 blocks at a time.
  // min width (for processing 3 5x5 blocks) = 2 * skip_width + filter_width,
  // i.e. 2 * 4 + 5 = 13 here.
  // skip_width specifies how much the width moves for the next block
  // convolution and filter_width specifies over how many pixels the filter
  // is applied.
  const int kMinWidthFor3_5x5Blocks = (kSkipWidth * 2) + kFilterWidth;
  for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
    const float out_ch_bias = layer_config->bias[i];
    for (int k = 0; k < layer_config->in_channels; ++k) {
      __m256 shuffle_weight[10];

      // The weights needed are 5x5; for SIMD purposes the array is made 5x8.
      float weight[5][8] = { { 0 } };
      int off = k * layer_config->out_channels + i;
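      // 'off' indexes the first tap of the filter for in-channel k and
      // out-channel i within layer_config->weights; successive taps of the
      // same filter are cstep elements apart (see
      // prepare_weights_for_5x5_convolve()).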

      // In layer 0, the convolution process happens at 5x5.
      // The weights for the 5x5 filter are the same for every block position
      // within an in-channel, which is why they are loaded once per
      // in-channel, outside the spatial loops below.
      prepare_weights_for_5x5_convolve(layer_config->weights, off, weight,
                                       cstep, shuffle_weight, weight_mask_0,
                                       weight_mask_1);

      for (int h = 0, u = 0; h < in_height - kFilterHeight + 1;
           h += kSkipHeight, ++u) {
        const int out_h = u * out_stride;
        int v = 0;
        int w = 0;
        int rem_width = in_width;
        // Process 3 5x5 blocks at a time, if sufficient width is present.
        while (rem_width >= kMinWidthFor3_5x5Blocks) {
          __m256 load_src_0, load_src_1;
          __m256 accum_src_0 = _mm256_setzero_ps();
          __m256 accum_src_1 = _mm256_setzero_ps();
          const float *input_ptr = &input[k][h * in_stride + w];
          PERFORM_CONVOLVE_FOR_3_5X5_BLOCKS();

          // Accumulate across columns.
          __m256 accum = _mm256_hadd_ps(accum_src_0, accum_src_1);
          __m128 tmp_reg_0 = _mm256_extractf128_ps(accum_src_0, 1);
          __m128 tmp_reg_1 = _mm256_extractf128_ps(accum_src_1, 1);

          __m128 accum_l = _mm256_castps256_ps128(accum);
          __m128 accum_h = _mm256_extractf128_ps(accum, 1);

          __m128 tmp_reg_2 = _mm_add_ps(accum_l, tmp_reg_0);
          __m128 tmp_reg_3 = _mm_add_ps(tmp_reg_0, accum_h);
          __m128 tmp_reg_4 = _mm_add_ps(tmp_reg_1, accum_h);

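          // With A = lanes of accum_src_0 and B = lanes of accum_src_1, each
          // block sum is split into two partial sums:
          //   1st block: (A0 + A1 + A4) in tmp_reg_2[0] and (A2 + A3) in
          //   accum_l[1],
          //   2nd block: (A5 + A6 + A7) in tmp_reg_3[1] and (B0 + B1) in
          //   accum_l[2],
          //   3rd block: (B4 + B5 + B6) in tmp_reg_4[2] and (B2 + B3) in
          //   accum_l[3].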
          // 1st 5x5 block output.
          output[i][out_h + v] =
              out_ch_bias + _mm_cvtss_f32(tmp_reg_2) +
              _mm_cvtss_f32(_mm_shuffle_ps(accum_l, accum_l, 1));

          // 2nd 5x5 block output.
          output[i][out_h + v + 1] =
              out_ch_bias +
              _mm_cvtss_f32(_mm_shuffle_ps(tmp_reg_3, tmp_reg_3, 1)) +
              _mm_cvtss_f32(_mm_shuffle_ps(accum_l, accum_l, 2));

          // 3rd 5x5 block output.
          output[i][out_h + v + 2] =
              out_ch_bias +
              _mm_cvtss_f32(_mm_shuffle_ps(tmp_reg_4, tmp_reg_4, 2)) +
              _mm_cvtss_f32(_mm_shuffle_ps(accum_l, accum_l, 3));

          v += 3;
          w += kSkipWidthForNextIter;
          rem_width -= kSkipWidthForNextIter;
        }

        // Process the remaining blocks one 5x5 block at a time.
        while (rem_width >= kFilterWidth) {
          float last_column_sum = 0;
          __m128 accum = _mm_setzero_ps();
          const float *input_ptr = &input[k][h * in_stride + w];
          PERFORM_CONVOLVE_FOR_1_5X5_BLOCK(shuffle_weight, accum, in_stride);

          // Accumulate across columns.
          accum = _mm_hadd_ps(accum, accum);
          output[i][out_h + v] = out_ch_bias + last_column_sum +
                                 _mm_cvtss_f32(accum) +
                                 _mm_cvtss_f32(_mm_shuffle_ps(accum, accum, 1));

          v += 1;
          w += kSkipWidth;
          rem_width -= kSkipWidth;
        }
      }
    }
  }
}

// AVX2 implementation for layer 1.
static inline void cnn_convolve_no_maxpool_padding_valid_layer1_avx2(
    const float **input, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    int start_idx, const int cstep, const int channel_step) {
  __m256i weight_mask[2];
  __m256i shuffle_output_mask;
  load_shuffle_masks_for_2x2_convolve(&shuffle_output_mask, weight_mask);

  const int kInHeight = 16;
  const int kFilterHeight = 2;
  const int kSkipHeight = 2;
  for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
    __m256 bias_reg = _mm256_set1_ps(layer_config->bias[i]);
    // out_accum registers are used to store the 2x2 convolve outputs
    // (calculated over the input block size), which are accumulated across
    // the in_channels. As per the design, each iteration of the 'h' loop
    // processes 8 (horizontal) 2x2 blocks and stores them in the
    // corresponding out_accum register (as the input size is 16x16, a total
    // of 64 2x2 blocks are present and 8 out_accum registers are enough to
    // store the outputs). Hence the 'j' and 'h' loops below run over the
    // number of out_accum registers.
    __m256 out_accum[8];
    for (int j = 0; j < 8; ++j) out_accum[j] = bias_reg;
    for (int k = 0; k < layer_config->in_channels; ++k) {
      __m256 shuffle_weight[2];
      int off = k * layer_config->out_channels + i;
      // In layer 1, the convolution process happens at 2x2.
      // The weights for the 2x2 filter are the same for every block position
      // within an in-channel, which is why they are loaded once per
      // in-channel.
      prepare_weights_for_2x2_convolve(layer_config->weights, off, cstep,
                                       shuffle_weight, weight_mask);

      for (int h = 0, u = 0; h < kInHeight - kFilterHeight + 1;
           h += kSkipHeight, ++u) {
        const float *input_ptr = &input[k][h * in_stride];
        perform_convolve_for_8h_2x2_blocks(input_ptr, in_stride, shuffle_weight,
                                           &out_accum[u], shuffle_output_mask);
      }
    }
    // Store output of layer 1.
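    // Each out_accum register holds one 8-wide output row, so the 8 registers
    // together form this layer's full 8x8 output.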
    for (int j = 0; j < 8; ++j) {
      _mm256_storeu_ps(&output[i][j * out_stride], out_accum[j]);
    }
  }
}

// AVX2 implementation for layer 2.
static inline void cnn_convolve_no_maxpool_padding_valid_layer2_avx2(
    const float **input, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    int start_idx, const int cstep, const int channel_step) {
  __m256i weight_mask[2];
  __m256i shuffle_output_mask;
  load_shuffle_masks_for_2x2_convolve(&shuffle_output_mask, weight_mask);

  const int kInHeight = 8;
  const int kFilterHeight = 2;
  const int kSkipHeight = 2;
  for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
    __m256 bias_reg = _mm256_set1_ps(layer_config->bias[i]);
    // out_accum registers are used to store the 2x2 convolve outputs
    // (calculated over the input block size), which are accumulated across
    // the in_channels. As per the design, each iteration of the 'h' loop
    // processes 8 (4 horizontal x 2 vertical) 2x2 blocks and stores them in
    // the corresponding out_accum register (as the input size is 8x8, a total
    // of 16 2x2 blocks are present and 2 out_accum registers are enough to
    // store the outputs). Hence the 'j' and 'h' loops below run over the
    // number of out_accum registers.
    __m256 out_accum[2];

    // Height to advance by to move to the next iteration, since 2 2x2 blocks
    // are processed vertically per iteration.
    const int kSkipHeightForNextIter = kSkipHeight * 2;
    for (int j = 0; j < 2; ++j) out_accum[j] = bias_reg;
    for (int k = 0; k < layer_config->in_channels; ++k) {
      __m256 shuffle_weight[2];
      int off = k * layer_config->out_channels + i;
      // In layer 2, the convolution process happens at 2x2.
      // The weights for the 2x2 filter are the same for every block position
      // within an in-channel, which is why they are loaded once per
      // in-channel.
      prepare_weights_for_2x2_convolve(layer_config->weights, off, cstep,
                                       shuffle_weight, weight_mask);

      for (int h = 0, u = 0; h < kInHeight - kFilterHeight + 1;
           h += kSkipHeightForNextIter, ++u) {
        const float *input_ptr = &input[k][h * in_stride];
        perform_convolve_for_4hx2v_2x2_blocks(input_ptr, in_stride,
                                              shuffle_weight, &out_accum[u],
                                              shuffle_output_mask);
      }
    }
    // Store output of layer 2.
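    // Each out_accum register holds two consecutive 4-wide output rows in
    // row-major order, so storing register j writes rows 2 * j and 2 * j + 1
    // of this layer's 4x4 output.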
    for (int j = 0; j < 2; ++j) {
      _mm256_storeu_ps(&output[i][j * out_stride * 2], out_accum[j]);
    }
  }
}

// AVX2 variant of av1_cnn_convolve_no_maxpool_padding_valid_c(), when
// filter_width and filter_height are equal to 2.
// As per the layer config set by av1_intra_mode_cnn_partition_cnn_config,
// the filter_width and filter_height are equal to 2 for layer >= 1. So
// convolution happens at 2x2 for layer >= 1.
static void cnn_convolve_no_maxpool_padding_valid_2x2_avx2(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    int start_idx, const int cstep, const int channel_step) {
  assert(layer_config->filter_width == 2 && layer_config->filter_height == 2);
  assert(layer_config->skip_width == 2 && layer_config->skip_height == 2);

  if (in_width == 16 && in_height == 16) {
    // This case of in_width and in_height equal to 16 corresponds to layer 1.
    // The output size of this layer is 8x8.
    cnn_convolve_no_maxpool_padding_valid_layer1_avx2(
        input, in_stride, layer_config, output, out_stride, start_idx, cstep,
        channel_step);
  } else if (in_width == 8 && in_height == 8) {
    // This case of in_width and in_height equal to 8 corresponds to layer 2.
    // The output size of this layer is 4x4.
    cnn_convolve_no_maxpool_padding_valid_layer2_avx2(
        input, in_stride, layer_config, output, out_stride, start_idx, cstep,
        channel_step);
  } else {
    // For layers 3 and 4, the input is of size 4x4 and 2x2 respectively.
    // Implementing SIMD for these cases might not be optimal, which is why
    // the C path is called for layer >= 3.
    av1_cnn_convolve_no_maxpool_padding_valid_c(
        input, in_width, in_height, in_stride, layer_config, output, out_stride,
        start_idx, cstep, channel_step);
  }
}

// AVX2 variant of av1_cnn_convolve_no_maxpool_padding_valid_c().
// As per the current encoder, the av1_cnn_convolve function gets called for
// a block size equal to 64x64. av1_cnn_convolve() uses the layer config
// values set by av1_intra_mode_cnn_partition_cnn_config. The following are a
// few details related to each layer's config parameters.
// Layer_Number in_size out_size filter_wd filter_ht skip_wd skip_ht
//     0         64x64    16x16      5         5         4       4
//     1         16x16    8x8        2         2         2       2
//     2         8x8      4x4        2         2         2       2
//     3         4x4      2x2        2         2         2       2
//     4         2x2      1x1        2         2         2       2
// Here,
// filter_wd = filter_width and filter_ht = filter_height,
// skip_wd = skip_width and skip_ht = skip_height.
void av1_cnn_convolve_no_maxpool_padding_valid_avx2(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride,
    int start_idx, int cstep, int channel_step) {
  if (layer_config->filter_width == 5 && layer_config->filter_height == 5 &&
      layer_config->skip_width == 4 && layer_config->skip_height == 4) {
    cnn_convolve_no_maxpool_padding_valid_5x5_avx2(
        input, in_width, in_height, in_stride, layer_config, output, out_stride,
        start_idx, cstep, channel_step);
  } else if (layer_config->filter_width == 2 &&
             layer_config->filter_height == 2 &&
             layer_config->skip_width == 2 && layer_config->skip_height == 2) {
    cnn_convolve_no_maxpool_padding_valid_2x2_avx2(
        input, in_width, in_height, in_stride, layer_config, output, out_stride,
        start_idx, cstep, channel_step);
  } else {
    av1_cnn_convolve_no_maxpool_padding_valid_c(
        input, in_width, in_height, in_stride, layer_config, output, out_stride,
        start_idx, cstep, channel_step);
  }
}