xref: /aosp_15_r20/external/libaom/av1/encoder/arm/ml_neon.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdbool.h>
#include <assert.h>
#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"
#include "av1/encoder/ml.h"

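// Apply the ReLU activation, max(x, 0), to eight floats held in the
// (*out_h, *out_l) vector pair.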
static void nn_activate8(float32x4_t *out_h, float32x4_t *out_l,
                         const float32x4_t *zero) {
  *out_h = vmaxq_f32(*out_h, *zero);
  *out_l = vmaxq_f32(*out_l, *zero);
}

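// Apply the ReLU activation, max(x, 0), to four floats held in *x.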
static void nn_activate4(float32x4_t *x, const float32x4_t *zero) {
  *x = vmaxq_f32(*x, *zero);
}

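// Scalar ReLU: clamp x to zero if negative (modifies x in place and
// evaluates to the clamped value).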
#define CLAMP_0(x) (x = x > 0 ? x : 0)

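// Compute one output node as bias + dot(inputs, weights), processing eight
// inputs per iteration. num_inputs must be a multiple of 8. ReLU is applied
// unless this is the output layer.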
static void nn_propagate_8to1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t vadd = zero;
  float total = *layer_bias;

  for (int in = 0; in < num_inputs; in += 8) {
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);

    const float32x4_t weights_h = vld1q_f32(&weights[in + 4]);
    const float32x4_t weights_l = vld1q_f32(&weights[in]);

    vadd = vmlaq_f32(vadd, inputs_h, weights_h);
    vadd = vmlaq_f32(vadd, inputs_l, weights_l);
  }
#if AOM_ARCH_AARCH64
  total += vaddvq_f32(vadd);
#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif

  if (!output_layer) CLAMP_0(total);
  *output_nodes = total;
}

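// Compute one output node for an arbitrary num_inputs: blocks of eight inputs
// are handled with vector multiply-accumulates and the remainder with a
// scalar loop. The ReLU clamp is always applied.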
static void nn_propagate_xto1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes) {
  float32x4_t vadd = vdupq_n_f32(0);

  float total = *layer_bias;
  int j = num_inputs;
  int in = 0;
  while (j > 7) {
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);

    const float32x4_t weights_h = vld1q_f32(&weights[in + 4]);
    const float32x4_t weights_l = vld1q_f32(&weights[in]);

    vadd = vmlaq_f32(vadd, inputs_h, weights_h);
    vadd = vmlaq_f32(vadd, inputs_l, weights_l);
    in += 8;
    j -= 8;
  }

#if AOM_ARCH_AARCH64
  total += vaddvq_f32(vadd);
#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif
  for (; in < num_inputs; in++) total += weights[in] * inputs[in];

  *output_nodes = CLAMP_0(total);
}

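// Compute one output node when num_inputs is small (at least 4): on AArch64
// the first four inputs use a single vector multiply and the remainder is
// scalar; on 32-bit Arm the whole dot product is scalar. The ReLU clamp is
// always applied.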
static void nn_propagate_xsto1(int num_inputs, const float *const inputs,
                               const float *const weights,
                               const float *layer_bias,
                               float *const output_nodes) {
  float total = *layer_bias;
#if AOM_ARCH_AARCH64
  const float32x4_t v_inputs = vld1q_f32(inputs);
  const float32x4_t v_weights = vld1q_f32(weights);
  const float32x4_t vadd = vmulq_f32(v_inputs, v_weights);
  total += vaddvq_f32(vadd);
  int in = 4;
#else
  int in = 0;
#endif
  for (; in < num_inputs; in++) total += weights[in] * inputs[in];

  *output_nodes = CLAMP_0(total);
}

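// Compute one output node, processing four inputs per iteration. num_inputs
// must be a multiple of 4. ReLU is applied unless this is the output layer.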
static void nn_propagate_4to1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t vadd = zero;
  float total = *layer_bias;

  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_inputs = vld1q_f32(&inputs[in]);
    const float32x4_t v_weights = vld1q_f32(&weights[in]);
    vadd = vmlaq_f32(vadd, v_inputs, v_weights);
  }

#if AOM_ARCH_AARCH64
  total += vaddvq_f32(vadd);
#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif

  if (!output_layer) CLAMP_0(total);
  *output_nodes = total;
}

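// Compute four output nodes at a time, processing four inputs per iteration.
// num_inputs must be a multiple of 4. The per-node accumulators are reduced
// with pairwise adds and added to the bias vector; ReLU is applied unless
// this is the output layer.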
static void nn_propagate_4to4(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t outputs = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);

  float32x4_t mul0[2] = { zero, zero };
  float32x4_t mul1[2] = { zero, zero };
  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_input = vld1q_f32(&inputs[in]);

    for (int i = 0; i < 2; i++) {
      const float32x4_t weight0 = vld1q_f32(&weights[in + 2 * i * num_inputs]);
      mul0[i] = vmlaq_f32(mul0[i], weight0, v_input);
      const float32x4_t weight1 =
          vld1q_f32(&weights[in + (2 * i + 1) * num_inputs]);
      mul1[i] = vmlaq_f32(mul1[i], weight1, v_input);
    }
  }
  for (int i = 0; i < 2; i++)
#if AOM_ARCH_AARCH64
    mul0[i] = vpaddq_f32(mul0[i], mul1[i]);
  const float32x4_t hh = vpaddq_f32(mul0[0], mul0[1]);
#else
    mul0[i] =
        vcombine_f32(vpadd_f32(vget_low_f32(mul0[i]), vget_high_f32(mul0[i])),
                     vpadd_f32(vget_low_f32(mul1[i]), vget_high_f32(mul1[i])));
  const float32x4_t hh =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[0]), vget_high_f32(mul0[0])),
                   vpadd_f32(vget_low_f32(mul0[1]), vget_high_f32(mul0[1])));
#endif

  outputs = vaddq_f32(outputs, hh);
  if (!output_layer) nn_activate4(&outputs, &zero);
  vst1q_f32(output_nodes, outputs);
}

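// Compute eight output nodes at a time, processing four inputs per iteration.
// num_inputs must be a multiple of 4. ReLU is applied unless this is the
// output layer.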
static void nn_propagate_4to8(const int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t out_h = vld1q_f32(&layer_bias[4]);
  float32x4_t out_l = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t mul0[4] = { zero, zero, zero, zero };
  float32x4_t mul1[4] = { zero, zero, zero, zero };

  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_input = vld1q_f32(&inputs[in]);
    for (int i = 0; i < 4; i++) {
      const float32x4_t weight0 = vld1q_f32(&weights[in + 2 * i * num_inputs]);
      const float32x4_t weight1 =
          vld1q_f32(&weights[in + (2 * i + 1) * num_inputs]);
      mul0[i] = vmlaq_f32(mul0[i], v_input, weight0);
      mul1[i] = vmlaq_f32(mul1[i], v_input, weight1);
    }
  }
  for (int i = 0; i < 4; i++)
#if AOM_ARCH_AARCH64
    mul0[i] = vpaddq_f32(mul0[i], mul1[i]);
  const float32x4_t hh0 = vpaddq_f32(mul0[0], mul0[1]);
  const float32x4_t hh1 = vpaddq_f32(mul0[2], mul0[3]);
#else
    mul0[i] =
        vcombine_f32(vpadd_f32(vget_low_f32(mul0[i]), vget_high_f32(mul0[i])),
                     vpadd_f32(vget_low_f32(mul1[i]), vget_high_f32(mul1[i])));
  const float32x4_t hh0 =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[0]), vget_high_f32(mul0[0])),
                   vpadd_f32(vget_low_f32(mul0[1]), vget_high_f32(mul0[1])));
  const float32x4_t hh1 =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[2]), vget_high_f32(mul0[2])),
                   vpadd_f32(vget_low_f32(mul0[3]), vget_high_f32(mul0[3])));
#endif

  out_h = vaddq_f32(out_h, hh1);
  out_l = vaddq_f32(out_l, hh0);

  if (!output_layer) nn_activate8(&out_h, &out_l, &zero);
  vst1q_f32(&output_nodes[4], out_h);
  vst1q_f32(output_nodes, out_l);
}

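// Compute four output nodes at a time, processing eight inputs per iteration.
// num_inputs must be a multiple of 8. ReLU is applied unless this is the
// output layer.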
static void nn_propagate_8to4(const int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t outputs = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t add[4] = { zero, zero, zero, zero };
  for (int in = 0; in < num_inputs; in += 8) {
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);

    for (int i = 0; i < 4; i++) {
      const float32x4_t weight_l = vld1q_f32(&weights[in + i * num_inputs]);
      const float32x4_t weight_h = vld1q_f32(&weights[in + i * num_inputs + 4]);
      add[i] = vmlaq_f32(add[i], inputs_l, weight_l);
      add[i] = vmlaq_f32(add[i], inputs_h, weight_h);
    }
  }
#if AOM_ARCH_AARCH64
  const float32x4_t hadd_h = vpaddq_f32(add[2], add[3]);
  const float32x4_t hadd_l = vpaddq_f32(add[0], add[1]);
  const float32x4_t haddhadd = vpaddq_f32(hadd_l, hadd_h);
#else
  const float32x4_t hadd_h =
      vcombine_f32(vpadd_f32(vget_low_f32(add[2]), vget_high_f32(add[2])),
                   vpadd_f32(vget_low_f32(add[3]), vget_high_f32(add[3])));
  const float32x4_t hadd_l =
      vcombine_f32(vpadd_f32(vget_low_f32(add[0]), vget_high_f32(add[0])),
                   vpadd_f32(vget_low_f32(add[1]), vget_high_f32(add[1])));
  const float32x4_t haddhadd =
      vcombine_f32(vpadd_f32(vget_low_f32(hadd_l), vget_high_f32(hadd_l)),
                   vpadd_f32(vget_low_f32(hadd_h), vget_high_f32(hadd_h)));
#endif

  outputs = vaddq_f32(outputs, haddhadd);
  if (!output_layer) nn_activate4(&outputs, &zero);
  vst1q_f32(output_nodes, outputs);
}

// Calculate prediction based on the given input features and neural net
// config. Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each
// hidden layer.
void av1_nn_predict_neon(const float *input_nodes,
                         const NN_CONFIG *const nn_config, int reduce_prec,
                         float *const output) {
  float buf[2][NN_MAX_NODES_PER_LAYER];
  int buf_index = 0;
  int num_inputs = nn_config->num_inputs;
  // Propagate through the hidden layers; the final iteration computes the
  // output layer.
  for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
    const float *layer_weights = nn_config->weights[layer];
    const float *layer_bias = nn_config->bias[layer];
    bool output_layer = (layer == nn_config->num_hidden_layers);
    float *const output_nodes = output_layer ? output : buf[buf_index];
    const int num_outputs = output_layer ? nn_config->num_outputs
                                         : nn_config->num_hidden_nodes[layer];

    // Dispatch to the widest kernel whose block sizes divide the layer
    // dimensions; otherwise fall back to the generic scalar loop below.
    if (num_inputs % 4 == 0 && num_outputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out += 8) {
        nn_propagate_4to8(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 8 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        nn_propagate_8to4(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 4 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        nn_propagate_4to4(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_8to1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_4to1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs > 8) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_xto1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out]);
      }
    } else if (num_inputs >= 4) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_xsto1(num_inputs, input_nodes,
                           &layer_weights[out * num_inputs], &layer_bias[out],
                           &output_nodes[out]);
      }
    } else {
      for (int node = 0; node < num_outputs; ++node) {
        float val = layer_bias[node];
        for (int i = 0; i < num_inputs; ++i)
          val += layer_weights[node * num_inputs + i] * input_nodes[i];
        // ReLU as activation function.
        val = val > 0.0f ? val : 0.0f;  // Could use AOMMAX().
        output_nodes[node] = val;
      }
    }
    // This layer's outputs become the next layer's inputs; ping-pong between
    // the two scratch buffers.
    input_nodes = output_nodes;
    num_inputs = num_outputs;
    buf_index = 1 - buf_index;
  }
  if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
}