1*77c1e3ccSAndroid Build Coastguard Worker /*
2*77c1e3ccSAndroid Build Coastguard Worker * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
3*77c1e3ccSAndroid Build Coastguard Worker *
4*77c1e3ccSAndroid Build Coastguard Worker * This source code is subject to the terms of the BSD 2 Clause License and
5*77c1e3ccSAndroid Build Coastguard Worker * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6*77c1e3ccSAndroid Build Coastguard Worker * was not distributed with this source code in the LICENSE file, you can
7*77c1e3ccSAndroid Build Coastguard Worker * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8*77c1e3ccSAndroid Build Coastguard Worker * Media Patent License 1.0 was not distributed with this source code in the
9*77c1e3ccSAndroid Build Coastguard Worker * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10*77c1e3ccSAndroid Build Coastguard Worker */
11*77c1e3ccSAndroid Build Coastguard Worker
12*77c1e3ccSAndroid Build Coastguard Worker #include <stdbool.h>
13*77c1e3ccSAndroid Build Coastguard Worker #include <assert.h>
14*77c1e3ccSAndroid Build Coastguard Worker
15*77c1e3ccSAndroid Build Coastguard Worker #include "config/av1_rtcd.h"
16*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/ml.h"
17*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/x86/ml_sse3.h"
18*77c1e3ccSAndroid Build Coastguard Worker
// Adds the dot product of the eight floats at |inputs| with the eight floats
// at |weights| to every lane of |*output|.
//
// The running total is deliberately kept in a full 128-bit register (all four
// lanes carry the same value) instead of a scalar float, so the caller never
// has to move data between SIMD and scalar floating-point code per call.
static void nn_propagate_8to1(const float *const inputs,
                              const float *const weights,
                              __m128 *const output) {
  const __m128 prod_lo = _mm_mul_ps(_mm_loadu_ps(inputs), _mm_loadu_ps(weights));
  const __m128 prod_hi =
      _mm_mul_ps(_mm_loadu_ps(inputs + 4), _mm_loadu_ps(weights + 4));

  // Lane k now holds in[k] * w[k] + in[k + 4] * w[k + 4].
  __m128 acc = _mm_add_ps(prod_lo, prod_hi);

  // Two horizontal adds collapse the four lanes and broadcast the total.
  for (int step = 0; step < 2; ++step) acc = _mm_hadd_ps(acc, acc);

  *output = _mm_add_ps(*output, acc);
}
43*77c1e3ccSAndroid Build Coastguard Worker
// Adds the dot product of the four floats at |inputs| with the four floats at
// |weights| to every lane of |*output|.
void av1_nn_propagate_4to1_sse3(const float *const inputs,
                                const float *const weights,
                                __m128 *const output) {
  __m128 acc = _mm_mul_ps(_mm_loadu_ps(inputs), _mm_loadu_ps(weights));

  // Each horizontal add pairs up neighbouring partial sums; after two passes
  // every lane holds (p0 + p1) + (p2 + p3).
  for (int step = 0; step < 2; ++step) acc = _mm_hadd_ps(acc, acc);

  *output = _mm_add_ps(*output, acc);
}
60*77c1e3ccSAndroid Build Coastguard Worker
// Accumulates four dot products into |*outputs|: lane r receives the dot
// product of the four floats at |inputs| with weights[r * num_inputs + 0..3],
// for r = 0..3. |num_inputs| is therefore the row stride of |weights| in
// floats; only the first four floats of each row are read.
void av1_nn_propagate_4to4_sse3(const float *const inputs,
                                const float *const weights,
                                __m128 *const outputs, const int num_inputs) {
  const __m128 x = _mm_loadu_ps(inputs);

  const __m128 row0 = _mm_mul_ps(_mm_loadu_ps(weights), x);
  const __m128 row1 = _mm_mul_ps(_mm_loadu_ps(weights + num_inputs), x);
  const __m128 row2 = _mm_mul_ps(_mm_loadu_ps(weights + 2 * num_inputs), x);
  const __m128 row3 = _mm_mul_ps(_mm_loadu_ps(weights + 3 * num_inputs), x);

  // First level: lanes hold half-row sums of rows (0, 1) and (2, 3).
  const __m128 pairs01 = _mm_hadd_ps(row0, row1);
  const __m128 pairs23 = _mm_hadd_ps(row2, row3);

  // Second level: lane r holds the complete sum for row r.
  *outputs = _mm_add_ps(*outputs, _mm_hadd_ps(pairs01, pairs23));
}
82*77c1e3ccSAndroid Build Coastguard Worker
// Accumulates eight dot products of the four floats at |inputs|: rows 0..3 of
// |weights| go into the lanes of |*out_l| and rows 4..7 into the lanes of
// |*out_h|. Row r starts at weights[r * num_inputs], so |num_inputs| is the
// row stride in floats; only the first four floats of each row are read.
void av1_nn_propagate_4to8_sse3(const float *const inputs,
                                const float *const weights, __m128 *const out_h,
                                __m128 *const out_l, const int num_inputs) {
  const __m128 x = _mm_loadu_ps(inputs);
  __m128 group_sum[2];  // [0]: rows 0..3, [1]: rows 4..7

  for (int g = 0; g < 2; ++g) {
    __m128 pair_sum[2];
    for (int p = 0; p < 2; ++p) {
      const int row = 4 * g + 2 * p;
      const __m128 even =
          _mm_mul_ps(x, _mm_loadu_ps(weights + row * num_inputs));
      const __m128 odd =
          _mm_mul_ps(x, _mm_loadu_ps(weights + (row + 1) * num_inputs));
      // Lanes 0,1: half sums of |row|; lanes 2,3: half sums of |row + 1|.
      pair_sum[p] = _mm_hadd_ps(even, odd);
    }
    // Lane k holds the complete dot product for row 4 * g + k.
    group_sum[g] = _mm_hadd_ps(pair_sum[0], pair_sum[1]);
  }

  // Update the high half first, matching the established accumulation order.
  *out_h = _mm_add_ps(*out_h, group_sum[1]);
  *out_l = _mm_add_ps(*out_l, group_sum[0]);
}
109*77c1e3ccSAndroid Build Coastguard Worker
// Accumulates four dot products into |*outputs|: lane r receives the dot
// product of the eight floats at |inputs| with weights[r * num_inputs + 0..7],
// for r = 0..3. |num_inputs| is the row stride of |weights| in floats.
static void nn_propagate_8to4(const float *const inputs,
                              const float *const weights, __m128 *const outputs,
                              const int num_inputs) {
  const __m128 x_lo = _mm_loadu_ps(inputs);
  const __m128 x_hi = _mm_loadu_ps(inputs + 4);

  // rows[r], lane k: in[k] * w[r][k] + in[k + 4] * w[r][k + 4].
  __m128 rows[4];
  for (int r = 0; r < 4; ++r) {
    const float *const w = weights + r * num_inputs;
    rows[r] = _mm_add_ps(_mm_mul_ps(x_lo, _mm_loadu_ps(w)),
                         _mm_mul_ps(x_hi, _mm_loadu_ps(w + 4)));
  }

  // Two levels of horizontal adds leave row r's total in lane r.
  const __m128 half01 = _mm_hadd_ps(rows[0], rows[1]);
  const __m128 half23 = _mm_hadd_ps(rows[2], rows[3]);
  *outputs = _mm_add_ps(*outputs, _mm_hadd_ps(half01, half23));
}
141*77c1e3ccSAndroid Build Coastguard Worker
// Applies ReLU, max(x, 0), in place to the eight activations held in the two
// vectors |*out_h| and |*out_l|.
static void nn_activate8(__m128 *out_h, __m128 *out_l) {
  *out_h = _mm_max_ps(*out_h, _mm_setzero_ps());
  *out_l = _mm_max_ps(*out_l, _mm_setzero_ps());
}
147*77c1e3ccSAndroid Build Coastguard Worker
// Applies ReLU, max(x, 0), in place to the four lanes of |*x|.
static void nn_activate4(__m128 *x) {
  const __m128 zero = _mm_setzero_ps();
  *x = _mm_max_ps(*x, zero);
}
149*77c1e3ccSAndroid Build Coastguard Worker
150*77c1e3ccSAndroid Build Coastguard Worker // Calculate prediction based on the given input features and neural net config.
151*77c1e3ccSAndroid Build Coastguard Worker // Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each hidden
152*77c1e3ccSAndroid Build Coastguard Worker // layer.
av1_nn_predict_sse3(const float * input_nodes,const NN_CONFIG * const nn_config,int reduce_prec,float * const output)153*77c1e3ccSAndroid Build Coastguard Worker void av1_nn_predict_sse3(const float *input_nodes,
154*77c1e3ccSAndroid Build Coastguard Worker const NN_CONFIG *const nn_config, int reduce_prec,
155*77c1e3ccSAndroid Build Coastguard Worker float *const output) {
156*77c1e3ccSAndroid Build Coastguard Worker float buf[2][NN_MAX_NODES_PER_LAYER];
157*77c1e3ccSAndroid Build Coastguard Worker int buf_index = 0;
158*77c1e3ccSAndroid Build Coastguard Worker int num_inputs = nn_config->num_inputs;
159*77c1e3ccSAndroid Build Coastguard Worker
160*77c1e3ccSAndroid Build Coastguard Worker // Hidden layers, except the final iteration is the output layer.
161*77c1e3ccSAndroid Build Coastguard Worker for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
162*77c1e3ccSAndroid Build Coastguard Worker const float *layer_weights = nn_config->weights[layer];
163*77c1e3ccSAndroid Build Coastguard Worker const float *layer_bias = nn_config->bias[layer];
164*77c1e3ccSAndroid Build Coastguard Worker bool output_layer = (layer == nn_config->num_hidden_layers);
165*77c1e3ccSAndroid Build Coastguard Worker float *const output_nodes = output_layer ? output : &buf[buf_index][0];
166*77c1e3ccSAndroid Build Coastguard Worker const int num_outputs = output_layer ? nn_config->num_outputs
167*77c1e3ccSAndroid Build Coastguard Worker : nn_config->num_hidden_nodes[layer];
168*77c1e3ccSAndroid Build Coastguard Worker
169*77c1e3ccSAndroid Build Coastguard Worker if (num_inputs % 4 == 0 && num_outputs % 8 == 0) {
170*77c1e3ccSAndroid Build Coastguard Worker for (int out = 0; out < num_outputs; out += 8) {
171*77c1e3ccSAndroid Build Coastguard Worker __m128 out_h = _mm_loadu_ps(&layer_bias[out + 4]);
172*77c1e3ccSAndroid Build Coastguard Worker __m128 out_l = _mm_loadu_ps(&layer_bias[out]);
173*77c1e3ccSAndroid Build Coastguard Worker for (int in = 0; in < num_inputs; in += 4) {
174*77c1e3ccSAndroid Build Coastguard Worker av1_nn_propagate_4to8_sse3(&input_nodes[in],
175*77c1e3ccSAndroid Build Coastguard Worker &layer_weights[out * num_inputs + in],
176*77c1e3ccSAndroid Build Coastguard Worker &out_h, &out_l, num_inputs);
177*77c1e3ccSAndroid Build Coastguard Worker }
178*77c1e3ccSAndroid Build Coastguard Worker if (!output_layer) nn_activate8(&out_h, &out_l);
179*77c1e3ccSAndroid Build Coastguard Worker _mm_storeu_ps(&output_nodes[out + 4], out_h);
180*77c1e3ccSAndroid Build Coastguard Worker _mm_storeu_ps(&output_nodes[out], out_l);
181*77c1e3ccSAndroid Build Coastguard Worker }
182*77c1e3ccSAndroid Build Coastguard Worker } else if (num_inputs % 8 == 0 && num_outputs % 4 == 0) {
183*77c1e3ccSAndroid Build Coastguard Worker for (int out = 0; out < num_outputs; out += 4) {
184*77c1e3ccSAndroid Build Coastguard Worker __m128 outputs = _mm_loadu_ps(&layer_bias[out]);
185*77c1e3ccSAndroid Build Coastguard Worker for (int in = 0; in < num_inputs; in += 8) {
186*77c1e3ccSAndroid Build Coastguard Worker nn_propagate_8to4(&input_nodes[in],
187*77c1e3ccSAndroid Build Coastguard Worker &layer_weights[out * num_inputs + in], &outputs,
188*77c1e3ccSAndroid Build Coastguard Worker num_inputs);
189*77c1e3ccSAndroid Build Coastguard Worker }
190*77c1e3ccSAndroid Build Coastguard Worker if (!output_layer) nn_activate4(&outputs);
191*77c1e3ccSAndroid Build Coastguard Worker _mm_storeu_ps(&output_nodes[out], outputs);
192*77c1e3ccSAndroid Build Coastguard Worker }
193*77c1e3ccSAndroid Build Coastguard Worker } else if (num_inputs % 4 == 0 && num_outputs % 4 == 0) {
194*77c1e3ccSAndroid Build Coastguard Worker for (int out = 0; out < num_outputs; out += 4) {
195*77c1e3ccSAndroid Build Coastguard Worker __m128 outputs = _mm_loadu_ps(&layer_bias[out]);
196*77c1e3ccSAndroid Build Coastguard Worker for (int in = 0; in < num_inputs; in += 4) {
197*77c1e3ccSAndroid Build Coastguard Worker av1_nn_propagate_4to4_sse3(&input_nodes[in],
198*77c1e3ccSAndroid Build Coastguard Worker &layer_weights[out * num_inputs + in],
199*77c1e3ccSAndroid Build Coastguard Worker &outputs, num_inputs);
200*77c1e3ccSAndroid Build Coastguard Worker }
201*77c1e3ccSAndroid Build Coastguard Worker if (!output_layer) nn_activate4(&outputs);
202*77c1e3ccSAndroid Build Coastguard Worker _mm_storeu_ps(&output_nodes[out], outputs);
203*77c1e3ccSAndroid Build Coastguard Worker }
204*77c1e3ccSAndroid Build Coastguard Worker } else if (num_inputs % 8 == 0) {
205*77c1e3ccSAndroid Build Coastguard Worker for (int out = 0; out < num_outputs; out++) {
206*77c1e3ccSAndroid Build Coastguard Worker __m128 total = _mm_load1_ps(&layer_bias[out]);
207*77c1e3ccSAndroid Build Coastguard Worker for (int in = 0; in < num_inputs; in += 8) {
208*77c1e3ccSAndroid Build Coastguard Worker nn_propagate_8to1(&input_nodes[in],
209*77c1e3ccSAndroid Build Coastguard Worker &layer_weights[out * num_inputs + in], &total);
210*77c1e3ccSAndroid Build Coastguard Worker }
211*77c1e3ccSAndroid Build Coastguard Worker if (!output_layer) nn_activate4(&total);
212*77c1e3ccSAndroid Build Coastguard Worker output_nodes[out] = _mm_cvtss_f32(total);
213*77c1e3ccSAndroid Build Coastguard Worker }
214*77c1e3ccSAndroid Build Coastguard Worker } else if (num_inputs % 4 == 0) {
215*77c1e3ccSAndroid Build Coastguard Worker for (int out = 0; out < num_outputs; out++) {
216*77c1e3ccSAndroid Build Coastguard Worker __m128 total = _mm_load1_ps(&layer_bias[out]);
217*77c1e3ccSAndroid Build Coastguard Worker for (int in = 0; in < num_inputs; in += 4) {
218*77c1e3ccSAndroid Build Coastguard Worker av1_nn_propagate_4to1_sse3(
219*77c1e3ccSAndroid Build Coastguard Worker &input_nodes[in], &layer_weights[out * num_inputs + in], &total);
220*77c1e3ccSAndroid Build Coastguard Worker }
221*77c1e3ccSAndroid Build Coastguard Worker if (!output_layer) nn_activate4(&total);
222*77c1e3ccSAndroid Build Coastguard Worker output_nodes[out] = _mm_cvtss_f32(total);
223*77c1e3ccSAndroid Build Coastguard Worker }
224*77c1e3ccSAndroid Build Coastguard Worker } else {
225*77c1e3ccSAndroid Build Coastguard Worker // Use SSE instructions for scalar operations to avoid the latency of
226*77c1e3ccSAndroid Build Coastguard Worker // swapping between SIMD and FPU modes.
227*77c1e3ccSAndroid Build Coastguard Worker for (int out = 0; out < num_outputs; out++) {
228*77c1e3ccSAndroid Build Coastguard Worker __m128 total = _mm_load1_ps(&layer_bias[out]);
229*77c1e3ccSAndroid Build Coastguard Worker for (int in_node = 0; in_node < num_inputs; in_node++) {
230*77c1e3ccSAndroid Build Coastguard Worker __m128 input = _mm_load1_ps(&input_nodes[in_node]);
231*77c1e3ccSAndroid Build Coastguard Worker __m128 weight =
232*77c1e3ccSAndroid Build Coastguard Worker _mm_load1_ps(&layer_weights[num_inputs * out + in_node]);
233*77c1e3ccSAndroid Build Coastguard Worker total = _mm_add_ps(total, _mm_mul_ps(input, weight));
234*77c1e3ccSAndroid Build Coastguard Worker }
235*77c1e3ccSAndroid Build Coastguard Worker if (!output_layer) nn_activate4(&total);
236*77c1e3ccSAndroid Build Coastguard Worker output_nodes[out] = _mm_cvtss_f32(total);
237*77c1e3ccSAndroid Build Coastguard Worker }
238*77c1e3ccSAndroid Build Coastguard Worker }
239*77c1e3ccSAndroid Build Coastguard Worker input_nodes = output_nodes;
240*77c1e3ccSAndroid Build Coastguard Worker num_inputs = num_outputs;
241*77c1e3ccSAndroid Build Coastguard Worker buf_index = 1 - buf_index;
242*77c1e3ccSAndroid Build Coastguard Worker }
243*77c1e3ccSAndroid Build Coastguard Worker if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
244*77c1e3ccSAndroid Build Coastguard Worker }
245*77c1e3ccSAndroid Build Coastguard Worker
// Fast per-lane approximation of exp(y), after N. N. Schraudolph, "A Fast,
// Compact Approximation of the Exponential Function", Neural Computation,
// 11(4):853-862, 1999.
//
// y is scaled by 2^23 / ln(2), converted to an integer, and added to the bit
// pattern of an IEEE-754 float exponent bias (127 << 23) minus a tuning
// constant; the resulting integer is reinterpreted as a float.
static inline __m128 approx_exp(__m128 y) {
  // (1 << 23) / ln(2): moves the integer part of y / ln(2) into the exponent
  // field.
  const __m128 scale = _mm_set1_ps((1 << 23) / 0.69314718056f);
  // 127 is the exponent offset of IEEE-754 single precision; 60801 is the magic
  // number that controls the accuracy of the approximation.
  const __m128i bias = _mm_set1_epi32(127 * (1 << 23) - 60801);

  const __m128i scaled_bits = _mm_cvtps_epi32(_mm_mul_ps(y, scale));
  return _mm_castsi128_ps(_mm_add_epi32(scaled_bits, bias));
}
263*77c1e3ccSAndroid Build Coastguard Worker
// Horizontal maximum: returns a vector whose four lanes all hold the largest
// lane of |reg| (for non-NaN inputs).
static inline __m128 reduce_max(__m128 reg) {
  // Fold the upper pair of lanes onto the lower pair (lanes 2,3,0,1)...
  const __m128 halves =
      _mm_max_ps(reg, _mm_shuffle_ps(reg, reg, _MM_SHUFFLE(1, 0, 3, 2)));
  // ...then fold neighbouring lanes (lanes 1,0,3,2).
  return _mm_max_ps(halves,
                    _mm_shuffle_ps(halves, halves, _MM_SHUFFLE(2, 3, 0, 1)));
}
275*77c1e3ccSAndroid Build Coastguard Worker
// Horizontal sum: returns a vector whose four lanes all hold the sum of the
// four lanes of |reg|.
static inline __m128 reduce_sum(__m128 reg) {
  // Add the upper pair of lanes onto the lower pair (lanes 2,3,0,1)...
  const __m128 halves =
      _mm_add_ps(reg, _mm_shuffle_ps(reg, reg, _MM_SHUFFLE(1, 0, 3, 2)));
  // ...then add neighbouring lanes (lanes 1,0,3,2).
  return _mm_add_ps(halves,
                    _mm_shuffle_ps(halves, halves, _MM_SHUFFLE(2, 3, 0, 1)));
}
287*77c1e3ccSAndroid Build Coastguard Worker
av1_nn_fast_softmax_16_sse3(const float * input,float * output)288*77c1e3ccSAndroid Build Coastguard Worker void av1_nn_fast_softmax_16_sse3(const float *input, float *output) {
289*77c1e3ccSAndroid Build Coastguard Worker // Clips at -10 to avoid underflowing
290*77c1e3ccSAndroid Build Coastguard Worker const __m128 clipper = _mm_set1_ps(-10.0f);
291*77c1e3ccSAndroid Build Coastguard Worker
292*77c1e3ccSAndroid Build Coastguard Worker // Load in 16 values
293*77c1e3ccSAndroid Build Coastguard Worker __m128 in_0 = _mm_loadu_ps(&input[0]);
294*77c1e3ccSAndroid Build Coastguard Worker __m128 in_1 = _mm_loadu_ps(&input[4]);
295*77c1e3ccSAndroid Build Coastguard Worker __m128 in_2 = _mm_loadu_ps(&input[8]);
296*77c1e3ccSAndroid Build Coastguard Worker __m128 in_3 = _mm_loadu_ps(&input[12]);
297*77c1e3ccSAndroid Build Coastguard Worker
298*77c1e3ccSAndroid Build Coastguard Worker // Get the max
299*77c1e3ccSAndroid Build Coastguard Worker __m128 max_0 = _mm_max_ps(in_0, in_1);
300*77c1e3ccSAndroid Build Coastguard Worker __m128 max_1 = _mm_max_ps(in_2, in_3);
301*77c1e3ccSAndroid Build Coastguard Worker
302*77c1e3ccSAndroid Build Coastguard Worker max_0 = _mm_max_ps(max_0, max_1);
303*77c1e3ccSAndroid Build Coastguard Worker max_0 = reduce_max(max_0);
304*77c1e3ccSAndroid Build Coastguard Worker
305*77c1e3ccSAndroid Build Coastguard Worker // Subtract the max off and clip
306*77c1e3ccSAndroid Build Coastguard Worker in_0 = _mm_sub_ps(in_0, max_0);
307*77c1e3ccSAndroid Build Coastguard Worker in_1 = _mm_sub_ps(in_1, max_0);
308*77c1e3ccSAndroid Build Coastguard Worker in_2 = _mm_sub_ps(in_2, max_0);
309*77c1e3ccSAndroid Build Coastguard Worker in_3 = _mm_sub_ps(in_3, max_0);
310*77c1e3ccSAndroid Build Coastguard Worker
311*77c1e3ccSAndroid Build Coastguard Worker in_0 = _mm_max_ps(in_0, clipper);
312*77c1e3ccSAndroid Build Coastguard Worker in_1 = _mm_max_ps(in_1, clipper);
313*77c1e3ccSAndroid Build Coastguard Worker in_2 = _mm_max_ps(in_2, clipper);
314*77c1e3ccSAndroid Build Coastguard Worker in_3 = _mm_max_ps(in_3, clipper);
315*77c1e3ccSAndroid Build Coastguard Worker
316*77c1e3ccSAndroid Build Coastguard Worker // Exponentiate and compute the denominator
317*77c1e3ccSAndroid Build Coastguard Worker __m128 sum = in_0 = approx_exp(in_0);
318*77c1e3ccSAndroid Build Coastguard Worker in_1 = approx_exp(in_1);
319*77c1e3ccSAndroid Build Coastguard Worker sum = _mm_add_ps(sum, in_1);
320*77c1e3ccSAndroid Build Coastguard Worker in_2 = approx_exp(in_2);
321*77c1e3ccSAndroid Build Coastguard Worker sum = _mm_add_ps(sum, in_2);
322*77c1e3ccSAndroid Build Coastguard Worker in_3 = approx_exp(in_3);
323*77c1e3ccSAndroid Build Coastguard Worker sum = _mm_add_ps(sum, in_3);
324*77c1e3ccSAndroid Build Coastguard Worker sum = reduce_sum(sum);
325*77c1e3ccSAndroid Build Coastguard Worker
326*77c1e3ccSAndroid Build Coastguard Worker // Divide to get the probability
327*77c1e3ccSAndroid Build Coastguard Worker in_0 = _mm_div_ps(in_0, sum);
328*77c1e3ccSAndroid Build Coastguard Worker in_1 = _mm_div_ps(in_1, sum);
329*77c1e3ccSAndroid Build Coastguard Worker in_2 = _mm_div_ps(in_2, sum);
330*77c1e3ccSAndroid Build Coastguard Worker in_3 = _mm_div_ps(in_3, sum);
331*77c1e3ccSAndroid Build Coastguard Worker
332*77c1e3ccSAndroid Build Coastguard Worker _mm_storeu_ps(&output[0], in_0);
333*77c1e3ccSAndroid Build Coastguard Worker _mm_storeu_ps(&output[4], in_1);
334*77c1e3ccSAndroid Build Coastguard Worker _mm_storeu_ps(&output[8], in_2);
335*77c1e3ccSAndroid Build Coastguard Worker _mm_storeu_ps(&output[12], in_3);
336*77c1e3ccSAndroid Build Coastguard Worker }
337