/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
11*77c1e3ccSAndroid Build Coastguard Worker
12*77c1e3ccSAndroid Build Coastguard Worker #include <stdbool.h>
13*77c1e3ccSAndroid Build Coastguard Worker #include <assert.h>
14*77c1e3ccSAndroid Build Coastguard Worker #include <arm_neon.h>
15*77c1e3ccSAndroid Build Coastguard Worker
16*77c1e3ccSAndroid Build Coastguard Worker #include "config/aom_config.h"
17*77c1e3ccSAndroid Build Coastguard Worker #include "config/av1_rtcd.h"
18*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/ml.h"
19*77c1e3ccSAndroid Build Coastguard Worker
// Element-wise lower bound for eight floats held in two 4-lane vectors: every
// lane of *out_h and *out_l is replaced by the larger of itself and the
// matching lane of *zero. With an all-zero *zero this is ReLU.
static void nn_activate8(float32x4_t *out_h, float32x4_t *out_l,
                         const float32x4_t *zero) {
  const float32x4_t floor_v = *zero;
  *out_l = vmaxq_f32(*out_l, floor_v);
  *out_h = vmaxq_f32(*out_h, floor_v);
}
25*77c1e3ccSAndroid Build Coastguard Worker
// Element-wise lower bound for one 4-lane vector: every lane of *x is replaced
// by the larger of itself and the matching lane of *zero. With an all-zero
// *zero this is ReLU.
static void nn_activate4(float32x4_t *x, const float32x4_t *zero) {
  const float32x4_t clamped = vmaxq_f32(*x, *zero);
  *x = clamped;
}
29*77c1e3ccSAndroid Build Coastguard Worker
// Clamps the lvalue `x` to be non-negative in place (scalar ReLU) and
// evaluates to the stored value. `x` is expanded more than once, so it must be
// an lvalue expression without side effects.
#define CLAMP_0(x) ((x) = (x) > 0 ? (x) : 0)
31*77c1e3ccSAndroid Build Coastguard Worker
// Computes a single output node as bias + dot(inputs, weights), consuming the
// inputs in groups of eight.
//
//   num_inputs    number of inputs/weights. The loop always reads a whole
//                 group of 8, so this must be a multiple of 8.
//   inputs        num_inputs input activations.
//   weights       num_inputs weights belonging to this node.
//   layer_bias    points at this node's bias.
//   output_nodes  receives the one result.
//   output_layer  false: the result is clamped to be non-negative (ReLU);
//                 true: the result is stored unchanged.
static void nn_propagate_8to1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t acc = vdupq_n_f32(0);

  for (int k = 0; k < num_inputs; k += 8) {
    // Upper four products are accumulated before the lower four.
    acc = vmlaq_f32(acc, vld1q_f32(&inputs[k + 4]), vld1q_f32(&weights[k + 4]));
    acc = vmlaq_f32(acc, vld1q_f32(&inputs[k]), vld1q_f32(&weights[k]));
  }

  // Horizontal sum of the four accumulator lanes.
#if AOM_ARCH_AARCH64
  const float dot = vaddvq_f32(acc);
#else
  const float32x2_t halves = vadd_f32(vget_low_f32(acc), vget_high_f32(acc));
  const float dot = vget_lane_f32(vpadd_f32(halves, halves), 0);
#endif

  float sum = *layer_bias;
  sum += dot;
  if (!output_layer) sum = sum > 0 ? sum : 0;
  *output_nodes = sum;
}
61*77c1e3ccSAndroid Build Coastguard Worker
// Computes a single output node as bias + dot(inputs, weights) for an
// arbitrary input count: full groups of eight go through NEON, the remaining
// 0..7 inputs are added with scalar code.
//
//   num_inputs    number of inputs/weights.
//   inputs        num_inputs input activations.
//   weights       num_inputs weights belonging to this node.
//   layer_bias    points at this node's bias.
//   output_nodes  receives the one result.
//
// There is no output_layer flag: the result is always clamped to be
// non-negative (ReLU), so this helper is only correct for layers that want
// that activation.
static void nn_propagate_xto1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes) {
  float32x4_t acc = vdupq_n_f32(0);
  int k = 0;

  // Vector part: runs while at least eight inputs are left.
  for (; num_inputs - k > 7; k += 8) {
    acc = vmlaq_f32(acc, vld1q_f32(&inputs[k + 4]), vld1q_f32(&weights[k + 4]));
    acc = vmlaq_f32(acc, vld1q_f32(&inputs[k]), vld1q_f32(&weights[k]));
  }

  // Horizontal sum of the four accumulator lanes.
#if AOM_ARCH_AARCH64
  const float dot = vaddvq_f32(acc);
#else
  const float32x2_t halves = vadd_f32(vget_low_f32(acc), vget_high_f32(acc));
  const float dot = vget_lane_f32(vpadd_f32(halves, halves), 0);
#endif

  float sum = *layer_bias;
  sum += dot;

  // Scalar tail for the inputs the vector loop did not cover.
  while (k < num_inputs) {
    sum += weights[k] * inputs[k];
    k++;
  }

  sum = sum > 0 ? sum : 0;
  *output_nodes = sum;
}
96*77c1e3ccSAndroid Build Coastguard Worker
// Computes a single output node as bias + dot(inputs, weights) for a small
// input count.
//
//   num_inputs    number of inputs/weights. On AArch64 the first four are
//                 loaded with one vector load regardless of this value, so it
//                 must be at least 4 there.
//   inputs        num_inputs input activations.
//   weights       num_inputs weights belonging to this node.
//   layer_bias    points at this node's bias.
//   output_nodes  receives the one result.
//
// There is no output_layer flag: the result is always clamped to be
// non-negative (ReLU).
static void nn_propagate_xsto1(int num_inputs, const float *const inputs,
                               const float *const weights,
                               const float *layer_bias,
                               float *const output_nodes) {
  float sum = *layer_bias;
  int k = 0;

#if AOM_ARCH_AARCH64
  // First four products in one vector multiply, summed across lanes.
  sum += vaddvq_f32(vmulq_f32(vld1q_f32(inputs), vld1q_f32(weights)));
  k = 4;
#endif

  // Everything else (all inputs when not on AArch64) is scalar.
  while (k < num_inputs) {
    sum += weights[k] * inputs[k];
    k++;
  }

  sum = sum > 0 ? sum : 0;
  *output_nodes = sum;
}
115*77c1e3ccSAndroid Build Coastguard Worker
// Computes a single output node as bias + dot(inputs, weights), consuming the
// inputs in groups of four.
//
//   num_inputs    number of inputs/weights. The loop always reads a whole
//                 group of 4, so this must be a multiple of 4.
//   inputs        num_inputs input activations.
//   weights       num_inputs weights belonging to this node.
//   layer_bias    points at this node's bias.
//   output_nodes  receives the one result.
//   output_layer  false: the result is clamped to be non-negative (ReLU);
//                 true: the result is stored unchanged.
static void nn_propagate_4to1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t acc = vdupq_n_f32(0);

  for (int k = 0; k < num_inputs; k += 4) {
    acc = vmlaq_f32(acc, vld1q_f32(&inputs[k]), vld1q_f32(&weights[k]));
  }

  // Horizontal sum of the four accumulator lanes.
#if AOM_ARCH_AARCH64
  const float dot = vaddvq_f32(acc);
#else
  const float32x2_t halves = vadd_f32(vget_low_f32(acc), vget_high_f32(acc));
  const float dot = vget_lane_f32(vpadd_f32(halves, halves), 0);
#endif

  float sum = *layer_bias;
  sum += dot;
  if (!output_layer) sum = sum > 0 ? sum : 0;
  *output_nodes = sum;
}
141*77c1e3ccSAndroid Build Coastguard Worker
// Computes four consecutive output nodes at once, consuming the inputs in
// groups of four. Output r is layer_bias[r] + dot(inputs, row r of weights).
//
//   num_inputs    number of inputs; also the stride, in floats, between the
//                 weight rows of consecutive output nodes. Read in whole
//                 groups of 4, so it must be a multiple of 4.
//   inputs        num_inputs input activations.
//   weights       four rows of num_inputs weights each, stored back to back.
//   layer_bias    four biases.
//   output_nodes  receives the four results.
//   output_layer  false: results are clamped to be non-negative (ReLU);
//                 true: results are stored unchanged.
static void nn_propagate_4to4(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  const float32x4_t zero = vdupq_n_f32(0);
  // acc[r] carries four partial sums for output node r.
  float32x4_t acc[4] = { zero, zero, zero, zero };

  for (int k = 0; k < num_inputs; k += 4) {
    const float32x4_t x = vld1q_f32(&inputs[k]);
    for (int r = 0; r < 4; r++) {
      const float32x4_t w = vld1q_f32(&weights[r * num_inputs + k]);
      acc[r] = vmlaq_f32(acc[r], w, x);
    }
  }

  // Two rounds of pairwise adds collapse each accumulator to a scalar and
  // leave dot product r in lane r.
#if AOM_ARCH_AARCH64
  const float32x4_t dots =
      vpaddq_f32(vpaddq_f32(acc[0], acc[1]), vpaddq_f32(acc[2], acc[3]));
#else
  float32x2_t half[4];
  for (int r = 0; r < 4; r++) {
    half[r] = vpadd_f32(vget_low_f32(acc[r]), vget_high_f32(acc[r]));
  }
  const float32x4_t dots =
      vcombine_f32(vpadd_f32(half[0], half[1]), vpadd_f32(half[2], half[3]));
#endif

  float32x4_t result = vaddq_f32(vld1q_f32(layer_bias), dots);
  if (!output_layer) nn_activate4(&result, &zero);
  vst1q_f32(output_nodes, result);
}
179*77c1e3ccSAndroid Build Coastguard Worker
// Computes eight consecutive output nodes at once, consuming the inputs in
// groups of four. Output r is layer_bias[r] + dot(inputs, row r of weights).
//
//   num_inputs    number of inputs; also the stride, in floats, between the
//                 weight rows of consecutive output nodes. Read in whole
//                 groups of 4, so it must be a multiple of 4.
//   inputs        num_inputs input activations.
//   weights       eight rows of num_inputs weights each, stored back to back.
//   layer_bias    eight biases.
//   output_nodes  receives the eight results.
//   output_layer  false: results are clamped to be non-negative (ReLU);
//                 true: results are stored unchanged.
static void nn_propagate_4to8(const int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  const float32x4_t zero = vdupq_n_f32(0);
  // acc[r] carries four partial sums for output node r.
  float32x4_t acc[8] = { zero, zero, zero, zero, zero, zero, zero, zero };

  for (int k = 0; k < num_inputs; k += 4) {
    const float32x4_t x = vld1q_f32(&inputs[k]);
    for (int r = 0; r < 8; r++) {
      const float32x4_t w = vld1q_f32(&weights[r * num_inputs + k]);
      acc[r] = vmlaq_f32(acc[r], x, w);
    }
  }

  // dots[0] gets the dot products of rows 0..3, dots[1] those of rows 4..7.
  // Two rounds of pairwise adds collapse each accumulator to a scalar.
  float32x4_t dots[2];
  for (int g = 0; g < 2; g++) {
    const float32x4_t *const a = &acc[4 * g];
#if AOM_ARCH_AARCH64
    dots[g] = vpaddq_f32(vpaddq_f32(a[0], a[1]), vpaddq_f32(a[2], a[3]));
#else
    float32x2_t half[4];
    for (int r = 0; r < 4; r++) {
      half[r] = vpadd_f32(vget_low_f32(a[r]), vget_high_f32(a[r]));
    }
    dots[g] =
        vcombine_f32(vpadd_f32(half[0], half[1]), vpadd_f32(half[2], half[3]));
#endif
  }

  float32x4_t out_l = vaddq_f32(vld1q_f32(layer_bias), dots[0]);
  float32x4_t out_h = vaddq_f32(vld1q_f32(&layer_bias[4]), dots[1]);

  if (!output_layer) nn_activate8(&out_h, &out_l, &zero);
  vst1q_f32(output_nodes, out_l);
  vst1q_f32(&output_nodes[4], out_h);
}
224*77c1e3ccSAndroid Build Coastguard Worker
// Computes four consecutive output nodes at once, consuming the inputs in
// groups of eight. Output r is layer_bias[r] + dot(inputs, row r of weights).
//
//   num_inputs    number of inputs; also the stride, in floats, between the
//                 weight rows of consecutive output nodes. Read in whole
//                 groups of 8, so it must be a multiple of 8.
//   inputs        num_inputs input activations.
//   weights       four rows of num_inputs weights each, stored back to back.
//   layer_bias    four biases.
//   output_nodes  receives the four results.
//   output_layer  false: results are clamped to be non-negative (ReLU);
//                 true: results are stored unchanged.
static void nn_propagate_8to4(const int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  const float32x4_t zero = vdupq_n_f32(0);
  // acc[r] carries four partial sums for output node r.
  float32x4_t acc[4] = { zero, zero, zero, zero };

  for (int k = 0; k < num_inputs; k += 8) {
    const float32x4_t x_lo = vld1q_f32(&inputs[k]);
    const float32x4_t x_hi = vld1q_f32(&inputs[k + 4]);
    for (int r = 0; r < 4; r++) {
      const float *const w = &weights[r * num_inputs + k];
      // Lower four products are accumulated before the upper four.
      acc[r] = vmlaq_f32(acc[r], x_lo, vld1q_f32(w));
      acc[r] = vmlaq_f32(acc[r], x_hi, vld1q_f32(w + 4));
    }
  }

  // Two rounds of pairwise adds collapse each accumulator to a scalar and
  // leave dot product r in lane r.
#if AOM_ARCH_AARCH64
  const float32x4_t dots =
      vpaddq_f32(vpaddq_f32(acc[0], acc[1]), vpaddq_f32(acc[2], acc[3]));
#else
  float32x2_t half[4];
  for (int r = 0; r < 4; r++) {
    half[r] = vpadd_f32(vget_low_f32(acc[r]), vget_high_f32(acc[r]));
  }
  const float32x4_t dots =
      vcombine_f32(vpadd_f32(half[0], half[1]), vpadd_f32(half[2], half[3]));
#endif

  float32x4_t result = vaddq_f32(vld1q_f32(layer_bias), dots);
  if (!output_layer) nn_activate4(&result, &zero);
  vst1q_f32(output_nodes, result);
}
263*77c1e3ccSAndroid Build Coastguard Worker
// Calculate prediction based on the given input features and neural net config.
// Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each hidden
// layer.
//
// Runs nn_config->num_hidden_layers hidden layers followed by one output
// layer. Hidden layers apply ReLU; the output layer is written to `output`
// without any activation.
//
//   input_nodes  nn_config->num_inputs input features.
//   nn_config    layer sizes plus per-layer weights and biases. Weights of a
//                layer are laid out as one row of num_inputs floats per
//                output node.
//   reduce_prec  when non-zero, av1_nn_output_prec_reduce() is applied to the
//                final outputs.
//   output       receives nn_config->num_outputs values.
void av1_nn_predict_neon(const float *input_nodes,
                         const NN_CONFIG *const nn_config, int reduce_prec,
                         float *const output) {
  // Ping-pong scratch: a hidden layer writes one half while reading the
  // previous layer's results from the other half.
  float buf[2][NN_MAX_NODES_PER_LAYER];
  int buf_index = 0;
  int num_inputs = nn_config->num_inputs;
  // Hidden layers, except the final iteration is the output layer.
  for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
    const float *layer_weights = nn_config->weights[layer];
    const float *layer_bias = nn_config->bias[layer];
    bool output_layer = (layer == nn_config->num_hidden_layers);
    float *const output_nodes = output_layer ? output : buf[buf_index];
    const int num_outputs = output_layer ? nn_config->num_outputs
                                         : nn_config->num_hidden_nodes[layer];

    // Pick the widest kernel the layer shape allows.
    if (num_inputs % 4 == 0 && num_outputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out += 8) {
        nn_propagate_4to8(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 8 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        nn_propagate_8to4(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 4 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        nn_propagate_4to4(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_8to1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_4to1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (!output_layer && num_inputs > 8) {
      // nn_propagate_xto1() always applies ReLU, so it is only valid for
      // hidden layers.
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_xto1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out]);
      }
    } else if (!output_layer && num_inputs >= 4) {
      // nn_propagate_xsto1() always applies ReLU, so it is only valid for
      // hidden layers.
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_xsto1(num_inputs, input_nodes,
                           &layer_weights[out * num_inputs], &layer_bias[out],
                           &output_nodes[out]);
      }
    } else {
      // Scalar path: layers with fewer than four inputs, and output layers
      // whose input count is not a multiple of 4.
      for (int node = 0; node < num_outputs; ++node) {
        float val = layer_bias[node];
        for (int i = 0; i < num_inputs; ++i) {
          val += layer_weights[node * num_inputs + i] * input_nodes[i];
        }
        // ReLU as activation function, skipped for the output layer so that
        // negative outputs survive, as in the vector kernels above.
        if (!output_layer) val = val > 0.0f ? val : 0.0f;
        output_nodes[node] = val;
      }
    }
    // This layer's outputs feed the next layer; switch scratch halves.
    input_nodes = output_nodes;
    num_inputs = num_outputs;
    buf_index = 1 - buf_index;
  }
  if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
}
340