/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdbool.h>
#include <assert.h>
#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"
#include "av1/encoder/ml.h"

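// Applies ReLU to eight values held in a pair of float32x4_t registers.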
static void nn_activate8(float32x4_t *out_h, float32x4_t *out_l,
                         const float32x4_t *zero) {
  *out_h = vmaxq_f32(*out_h, *zero);
  *out_l = vmaxq_f32(*out_l, *zero);
}

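// Applies ReLU to four values held in a single float32x4_t register.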
static void nn_activate4(float32x4_t *x, const float32x4_t *zero) {
  *x = vmaxq_f32(*x, *zero);
}

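// Scalar ReLU: clamps x to zero from below, in place. The expression
// evaluates to the clamped value.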
#define CLAMP_0(x) (x = x > 0 ? x : 0)

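// Computes one output node from num_inputs inputs (num_inputs a multiple of
// 8): multiply-accumulates inputs and weights eight at a time, adds the bias
// and applies ReLU unless this is the output layer.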
static void nn_propagate_8to1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t vadd = zero;
  float total = *layer_bias;

  for (int in = 0; in < num_inputs; in += 8) {
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);

    const float32x4_t weights_h = vld1q_f32(&weights[in + 4]);
    const float32x4_t weights_l = vld1q_f32(&weights[in]);

    vadd = vmlaq_f32(vadd, inputs_h, weights_h);
    vadd = vmlaq_f32(vadd, inputs_l, weights_l);
  }
#if AOM_ARCH_AARCH64
  total += vaddvq_f32(vadd);
#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif

  if (!output_layer) CLAMP_0(total);
  *output_nodes = total;
}

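// Computes one output node for an arbitrary num_inputs: blocks of eight are
// handled with vector multiply-accumulates, the remainder with a scalar loop.
// ReLU is applied unconditionally.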
static void nn_propagate_xto1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes) {
  float32x4_t vadd = vdupq_n_f32(0);

  float total = *layer_bias;
  int j = num_inputs;
  int in = 0;
  while (j > 7) {
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);

    const float32x4_t weights_h = vld1q_f32(&weights[in + 4]);
    const float32x4_t weights_l = vld1q_f32(&weights[in]);

    vadd = vmlaq_f32(vadd, inputs_h, weights_h);
    vadd = vmlaq_f32(vadd, inputs_l, weights_l);
    in += 8;
    j -= 8;
  }

#if AOM_ARCH_AARCH64
  total += vaddvq_f32(vadd);
#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif
  for (; in < num_inputs; in++) total += weights[in] * inputs[in];

  *output_nodes = CLAMP_0(total);
}

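// Computes one output node for a small num_inputs. On AArch64 the first four
// products are vectorized (at least four inputs are required); the remaining
// products, and all of them on 32-bit Arm, are accumulated in scalar code.
// ReLU is applied unconditionally.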
static void nn_propagate_xsto1(int num_inputs, const float *const inputs,
                               const float *const weights,
                               const float *layer_bias,
                               float *const output_nodes) {
  float total = *layer_bias;
#if AOM_ARCH_AARCH64
  const float32x4_t v_inputs = vld1q_f32(inputs);
  const float32x4_t v_weights = vld1q_f32(weights);
  const float32x4_t vadd = vmulq_f32(v_inputs, v_weights);
  total += vaddvq_f32(vadd);
  int in = 4;
#else
  int in = 0;
#endif
  for (; in < num_inputs; in++) total += weights[in] * inputs[in];

  *output_nodes = CLAMP_0(total);
}

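// Computes one output node (num_inputs a multiple of 4): multiply-accumulates
// inputs and weights four at a time, adds the bias and applies ReLU unless
// this is the output layer.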
static void nn_propagate_4to1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t vadd = zero;
  float total = *layer_bias;

  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_inputs = vld1q_f32(&inputs[in]);
    const float32x4_t v_weights = vld1q_f32(&weights[in]);
    vadd = vmlaq_f32(vadd, v_inputs, v_weights);
  }

#if AOM_ARCH_AARCH64
  total += vaddvq_f32(vadd);
#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif

  if (!output_layer) CLAMP_0(total);
  *output_nodes = total;
}

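// Computes four output nodes (num_inputs a multiple of 4). Each output keeps
// its own 4-wide accumulator; pairwise adds reduce the accumulators to one
// vector before the bias is added and the optional ReLU applied.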
static void nn_propagate_4to4(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t outputs = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);

  float32x4_t mul0[2] = { zero, zero };
  float32x4_t mul1[2] = { zero, zero };
  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_input = vld1q_f32(&inputs[in]);

    for (int i = 0; i < 2; i++) {
      const float32x4_t weight0 = vld1q_f32(&weights[in + 2 * i * num_inputs]);
      mul0[i] = vmlaq_f32(mul0[i], weight0, v_input);
      const float32x4_t weight1 =
          vld1q_f32(&weights[in + (2 * i + 1) * num_inputs]);
      mul1[i] = vmlaq_f32(mul1[i], weight1, v_input);
    }
  }
  for (int i = 0; i < 2; i++)
#if AOM_ARCH_AARCH64
    mul0[i] = vpaddq_f32(mul0[i], mul1[i]);
  const float32x4_t hh = vpaddq_f32(mul0[0], mul0[1]);
#else
    mul0[i] =
        vcombine_f32(vpadd_f32(vget_low_f32(mul0[i]), vget_high_f32(mul0[i])),
                     vpadd_f32(vget_low_f32(mul1[i]), vget_high_f32(mul1[i])));
  const float32x4_t hh =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[0]), vget_high_f32(mul0[0])),
                   vpadd_f32(vget_low_f32(mul0[1]), vget_high_f32(mul0[1])));
#endif

  outputs = vaddq_f32(outputs, hh);
  if (!output_layer) nn_activate4(&outputs, &zero);
  vst1q_f32(output_nodes, outputs);
}

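// Computes eight output nodes (num_inputs a multiple of 4), using eight
// 4-wide accumulators that pairwise adds reduce to two result vectors.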
static void nn_propagate_4to8(const int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t out_h = vld1q_f32(&layer_bias[4]);
  float32x4_t out_l = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t mul0[4] = { zero, zero, zero, zero };
  float32x4_t mul1[4] = { zero, zero, zero, zero };

  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_input = vld1q_f32(&inputs[in]);
    for (int i = 0; i < 4; i++) {
      const float32x4_t weight0 = vld1q_f32(&weights[in + 2 * i * num_inputs]);
      const float32x4_t weight1 =
          vld1q_f32(&weights[in + (2 * i + 1) * num_inputs]);
      mul0[i] = vmlaq_f32(mul0[i], v_input, weight0);
      mul1[i] = vmlaq_f32(mul1[i], v_input, weight1);
    }
  }
  for (int i = 0; i < 4; i++)
#if AOM_ARCH_AARCH64
    mul0[i] = vpaddq_f32(mul0[i], mul1[i]);
  const float32x4_t hh0 = vpaddq_f32(mul0[0], mul0[1]);
  const float32x4_t hh1 = vpaddq_f32(mul0[2], mul0[3]);
#else
    mul0[i] =
        vcombine_f32(vpadd_f32(vget_low_f32(mul0[i]), vget_high_f32(mul0[i])),
                     vpadd_f32(vget_low_f32(mul1[i]), vget_high_f32(mul1[i])));
  const float32x4_t hh0 =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[0]), vget_high_f32(mul0[0])),
                   vpadd_f32(vget_low_f32(mul0[1]), vget_high_f32(mul0[1])));
  const float32x4_t hh1 =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[2]), vget_high_f32(mul0[2])),
                   vpadd_f32(vget_low_f32(mul0[3]), vget_high_f32(mul0[3])));
#endif

  out_h = vaddq_f32(out_h, hh1);
  out_l = vaddq_f32(out_l, hh0);

  if (!output_layer) nn_activate8(&out_h, &out_l, &zero);
  vst1q_f32(&output_nodes[4], out_h);
  vst1q_f32(output_nodes, out_l);
}

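// Computes four output nodes (num_inputs a multiple of 8), consuming eight
// inputs per iteration and reducing the four accumulators with pairwise adds.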
static void nn_propagate_8to4(const int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t outputs = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t add[4] = { zero, zero, zero, zero };
  for (int in = 0; in < num_inputs; in += 8) {
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);

    for (int i = 0; i < 4; i++) {
      const float32x4_t weight_l = vld1q_f32(&weights[in + i * num_inputs]);
      const float32x4_t weight_h = vld1q_f32(&weights[in + i * num_inputs + 4]);
      add[i] = vmlaq_f32(add[i], inputs_l, weight_l);
      add[i] = vmlaq_f32(add[i], inputs_h, weight_h);
    }
  }
#if AOM_ARCH_AARCH64
  const float32x4_t hadd_h = vpaddq_f32(add[2], add[3]);
  const float32x4_t hadd_l = vpaddq_f32(add[0], add[1]);
  const float32x4_t haddhadd = vpaddq_f32(hadd_l, hadd_h);
#else
  const float32x4_t hadd_h =
      vcombine_f32(vpadd_f32(vget_low_f32(add[2]), vget_high_f32(add[2])),
                   vpadd_f32(vget_low_f32(add[3]), vget_high_f32(add[3])));
  const float32x4_t hadd_l =
      vcombine_f32(vpadd_f32(vget_low_f32(add[0]), vget_high_f32(add[0])),
                   vpadd_f32(vget_low_f32(add[1]), vget_high_f32(add[1])));
  const float32x4_t haddhadd =
      vcombine_f32(vpadd_f32(vget_low_f32(hadd_l), vget_high_f32(hadd_l)),
                   vpadd_f32(vget_low_f32(hadd_h), vget_high_f32(hadd_h)));
#endif

  outputs = vaddq_f32(outputs, haddhadd);
  if (!output_layer) nn_activate4(&outputs, &zero);
  vst1q_f32(output_nodes, outputs);
}

// Calculate prediction based on the given input features and neural net config.
// Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each hidden
// layer.
void av1_nn_predict_neon(const float *input_nodes,
                         const NN_CONFIG *const nn_config, int reduce_prec,
                         float *const output) {
  float buf[2][NN_MAX_NODES_PER_LAYER];
  int buf_index = 0;
  int num_inputs = nn_config->num_inputs;
  // Propagate through the hidden layers; the final iteration handles the
  // output layer.
  for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
    const float *layer_weights = nn_config->weights[layer];
    const float *layer_bias = nn_config->bias[layer];
    bool output_layer = (layer == nn_config->num_hidden_layers);
    float *const output_nodes = output_layer ? output : buf[buf_index];
    const int num_outputs = output_layer ? nn_config->num_outputs
                                         : nn_config->num_hidden_nodes[layer];

    if (num_inputs % 4 == 0 && num_outputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out += 8) {
        nn_propagate_4to8(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 8 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        nn_propagate_8to4(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 4 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        nn_propagate_4to4(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_8to1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_4to1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs > 8) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_xto1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out]);
      }
    } else if (num_inputs >= 4) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_xsto1(num_inputs, input_nodes,
                           &layer_weights[out * num_inputs], &layer_bias[out],
                           &output_nodes[out]);
      }
    } else {
      for (int node = 0; node < num_outputs; ++node) {
        float val = layer_bias[node];
        for (int i = 0; i < num_inputs; ++i)
          val += layer_weights[node * num_inputs + i] * input_nodes[i];
        // ReLU as activation function.
        val = val > 0.0f ? val : 0.0f;  // Could use AOMMAX().
        output_nodes[node] = val;
      }
    }
    input_nodes = output_nodes;
    num_inputs = num_outputs;
    buf_index = 1 - buf_index;
  }
  if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
}