/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdbool.h>
#include <assert.h>
#include <immintrin.h>

#include "config/av1_rtcd.h"
#include "av1/encoder/ml.h"
#include "av1/encoder/x86/ml_sse3.h"

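// Computes paired horizontal sums of two consecutive weight rows against the
// current 8 inputs: rows (2 * i) and (2 * i + 1) of the weight matrix are
// multiplied element-wise with inputs256 and their hadd result is stored in
// hadd[i]. Expects inputs256, weight_idx, tot_num_inputs, weights, hadd and
// the loop counter i to be in scope at the expansion site.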
#define CALC_OUTPUT_FOR_2ROWS                                               \
  const int index = weight_idx + (2 * i * tot_num_inputs);                  \
  const __m256 weight0 = _mm256_loadu_ps(&weights[index]);                  \
  const __m256 weight1 = _mm256_loadu_ps(&weights[index + tot_num_inputs]); \
  const __m256 mul0 = _mm256_mul_ps(inputs256, weight0);                    \
  const __m256 mul1 = _mm256_mul_ps(inputs256, weight1);                    \
  hadd[i] = _mm256_hadd_ps(mul0, mul1);

static inline void nn_propagate_8to1(
    const float *const inputs, const float *const weights,
    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
    int num_outputs, float *const output_nodes, int is_clip_required) {
  // Process one output row at a time.
  for (int out = 0; out < num_outputs; out++) {
    __m256 in_result = _mm256_setzero_ps();
    float bias_val = bias[out];
    for (int in = 0; in < num_inputs_to_process; in += 8) {
      const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]);
      const int weight_idx = in + (out * tot_num_inputs);
      const __m256 weight0 = _mm256_loadu_ps(&weights[weight_idx]);
      const __m256 mul0 = _mm256_mul_ps(inputs256, weight0);
      in_result = _mm256_add_ps(in_result, mul0);
    }
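    // Horizontally reduce the 8 partial products accumulated in in_result
    // down to a single scalar dot product for this output row.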
    const __m128 low_128 = _mm256_castps256_ps128(in_result);
    const __m128 high_128 = _mm256_extractf128_ps(in_result, 1);
    const __m128 sum_par_0 = _mm_add_ps(low_128, high_128);
    const __m128 sum_par_1 = _mm_hadd_ps(sum_par_0, sum_par_0);
    const __m128 sum_tot =
        _mm_add_ps(_mm_shuffle_ps(sum_par_1, sum_par_1, 0x99), sum_par_1);

    bias_val += _mm_cvtss_f32(sum_tot);
    if (is_clip_required) bias_val = AOMMAX(bias_val, 0);
    output_nodes[out] = bias_val;
  }
}

static inline void nn_propagate_8to4(
    const float *const inputs, const float *const weights,
    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
    int num_outputs, float *const output_nodes, int is_clip_required) {
  __m256 hadd[2];
  for (int out = 0; out < num_outputs; out += 4) {
    __m128 bias_reg = _mm_loadu_ps(&bias[out]);
    __m128 in_result = _mm_setzero_ps();
    for (int in = 0; in < num_inputs_to_process; in += 8) {
      const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]);
      const int weight_idx = in + (out * tot_num_inputs);
      // Process two output rows at a time.
      for (int i = 0; i < 2; i++) {
        CALC_OUTPUT_FOR_2ROWS
      }

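      // hadd[0] holds the paired sums for rows out+0/out+1 and hadd[1] for
      // rows out+2/out+3. A second hadd interleaves the four rows, and
      // adding the two 128-bit lanes yields one partial sum per row.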
      const __m256 sum_par = _mm256_hadd_ps(hadd[0], hadd[1]);
      const __m128 low_128 = _mm256_castps256_ps128(sum_par);
      const __m128 high_128 = _mm256_extractf128_ps(sum_par, 1);
      const __m128 result = _mm_add_ps(low_128, high_128);

      in_result = _mm_add_ps(in_result, result);
    }

    in_result = _mm_add_ps(in_result, bias_reg);
    if (is_clip_required) in_result = _mm_max_ps(in_result, _mm_setzero_ps());
    _mm_storeu_ps(&output_nodes[out], in_result);
  }
}

static inline void nn_propagate_8to8(
    const float *const inputs, const float *const weights,
    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
    int num_outputs, float *const output_nodes, int is_clip_required) {
  __m256 hadd[4];
  for (int out = 0; out < num_outputs; out += 8) {
    __m256 bias_reg = _mm256_loadu_ps(&bias[out]);
    __m256 in_result = _mm256_setzero_ps();
    for (int in = 0; in < num_inputs_to_process; in += 8) {
      const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]);
      const int weight_idx = in + (out * tot_num_inputs);
      // Process two output rows at a time.
      for (int i = 0; i < 4; i++) {
        CALC_OUTPUT_FOR_2ROWS
      }
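      // hh0/hh1 interleave the paired sums of rows out+0..out+3 and
      // out+4..out+7 within each 128-bit lane; the cross-lane permutes
      // gather matching halves so a single add produces eight per-row
      // partial sums in row order.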
      const __m256 hh0 = _mm256_hadd_ps(hadd[0], hadd[1]);
      const __m256 hh1 = _mm256_hadd_ps(hadd[2], hadd[3]);

      __m256 ht_0 = _mm256_permute2f128_ps(hh0, hh1, 0x20);
      __m256 ht_1 = _mm256_permute2f128_ps(hh0, hh1, 0x31);

      __m256 result = _mm256_add_ps(ht_0, ht_1);
      in_result = _mm256_add_ps(in_result, result);
    }
    in_result = _mm256_add_ps(in_result, bias_reg);
    if (is_clip_required)
      in_result = _mm256_max_ps(in_result, _mm256_setzero_ps());
    _mm256_storeu_ps(&output_nodes[out], in_result);
  }
}

static inline void nn_propagate_input_multiple_of_8(
    const float *const inputs, const float *const weights,
    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
    bool is_output_layer, int num_outputs, float *const output_nodes) {
  // Clip (ReLU) the outputs only for hidden layers, and only when this call
  // processes all of the layer's inputs. When only a partial set of inputs
  // is handled here, the remaining inputs are accumulated by the caller and
  // clipping is deferred until then.
  const int is_clip_required =
      !is_output_layer && num_inputs_to_process == tot_num_inputs;
  if (num_outputs % 8 == 0) {
    nn_propagate_8to8(inputs, weights, bias, num_inputs_to_process,
                      tot_num_inputs, num_outputs, output_nodes,
                      is_clip_required);
  } else if (num_outputs % 4 == 0) {
    nn_propagate_8to4(inputs, weights, bias, num_inputs_to_process,
                      tot_num_inputs, num_outputs, output_nodes,
                      is_clip_required);
  } else {
    nn_propagate_8to1(inputs, weights, bias, num_inputs_to_process,
                      tot_num_inputs, num_outputs, output_nodes,
                      is_clip_required);
  }
}

void av1_nn_predict_avx2(const float *input_nodes,
                         const NN_CONFIG *const nn_config, int reduce_prec,
                         float *const output) {
  // Ping-pong buffers that hold the intermediate activations of one layer
  // while the next layer is computed.
  float buf[2][NN_MAX_NODES_PER_LAYER];
  int buf_index = 0;
  int num_inputs = nn_config->num_inputs;
  assert(num_inputs > 0 && num_inputs <= NN_MAX_NODES_PER_LAYER);

  for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
    const float *layer_weights = nn_config->weights[layer];
    const float *layer_bias = nn_config->bias[layer];
    bool is_output_layer = layer == nn_config->num_hidden_layers;
    float *const output_nodes = is_output_layer ? output : &buf[buf_index][0];
    const int num_outputs = is_output_layer
                                ? nn_config->num_outputs
                                : nn_config->num_hidden_nodes[layer];
    assert(num_outputs > 0 && num_outputs <= NN_MAX_NODES_PER_LAYER);

    // Process inputs in multiples of 8 using AVX2 intrinsics.
    if (num_inputs % 8 == 0) {
      nn_propagate_input_multiple_of_8(input_nodes, layer_weights, layer_bias,
                                       num_inputs, num_inputs, is_output_layer,
                                       num_outputs, output_nodes);
    } else {
      // When the number of inputs is not a multiple of 8, use a hybrid
      // approach: AVX2 for the largest multiple-of-8 chunk and SSE3 for the
      // remaining inputs.
      const int in_mul_8 = num_inputs / 8;
      const int num_inputs_to_process = in_mul_8 * 8;
      int bias_is_considered = 0;
      if (in_mul_8) {
        nn_propagate_input_multiple_of_8(
            input_nodes, layer_weights, layer_bias, num_inputs_to_process,
            num_inputs, is_output_layer, num_outputs, output_nodes);
        bias_is_considered = 1;
      }

      // If the AVX2 pass above already accumulated into output_nodes (bias
      // included), continue from those partial sums; otherwise start from
      // the bias values directly.
      const float *out_temp = bias_is_considered ? output_nodes : layer_bias;
      const int input_remaining = num_inputs % 8;
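      // Dispatch on the alignment of the remaining inputs and of the output
      // count: remainders that are a multiple of 4 use the SSE3 4-wide
      // kernels; anything else falls back to scalar math kept in SSE
      // registers.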
      if (input_remaining % 4 == 0 && num_outputs % 8 == 0) {
        for (int out = 0; out < num_outputs; out += 8) {
          __m128 out_h = _mm_loadu_ps(&out_temp[out + 4]);
          __m128 out_l = _mm_loadu_ps(&out_temp[out]);
          for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
            av1_nn_propagate_4to8_sse3(&input_nodes[in],
                                       &layer_weights[out * num_inputs + in],
                                       &out_h, &out_l, num_inputs);
          }
          if (!is_output_layer) {
            const __m128 zero = _mm_setzero_ps();
            out_h = _mm_max_ps(out_h, zero);
            out_l = _mm_max_ps(out_l, zero);
          }
          _mm_storeu_ps(&output_nodes[out + 4], out_h);
          _mm_storeu_ps(&output_nodes[out], out_l);
        }
      } else if (input_remaining % 4 == 0 && num_outputs % 4 == 0) {
        for (int out = 0; out < num_outputs; out += 4) {
          __m128 outputs = _mm_loadu_ps(&out_temp[out]);
          for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
            av1_nn_propagate_4to4_sse3(&input_nodes[in],
                                       &layer_weights[out * num_inputs + in],
                                       &outputs, num_inputs);
          }
          if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
          _mm_storeu_ps(&output_nodes[out], outputs);
        }
      } else if (input_remaining % 4 == 0) {
        for (int out = 0; out < num_outputs; out++) {
          __m128 outputs = _mm_load1_ps(&out_temp[out]);
          for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
            av1_nn_propagate_4to1_sse3(&input_nodes[in],
                                       &layer_weights[out * num_inputs + in],
                                       &outputs);
          }
          if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
          output_nodes[out] = _mm_cvtss_f32(outputs);
        }
      } else {
        // Use SSE instructions for scalar operations to avoid the latency
        // of swapping between SIMD and FPU modes.
        for (int out = 0; out < num_outputs; out++) {
          __m128 outputs = _mm_load1_ps(&out_temp[out]);
          for (int in_node = in_mul_8 * 8; in_node < num_inputs; in_node++) {
            __m128 input = _mm_load1_ps(&input_nodes[in_node]);
            __m128 weight =
                _mm_load1_ps(&layer_weights[num_inputs * out + in_node]);
            outputs = _mm_add_ps(outputs, _mm_mul_ps(input, weight));
          }
          if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
          output_nodes[out] = _mm_cvtss_f32(outputs);
        }
      }
    }
    // Before processing the next layer, treat the output of the current
    // layer as the input to the next layer.
    input_nodes = output_nodes;
    num_inputs = num_outputs;
    buf_index = 1 - buf_index;
  }
  if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
}