1*77c1e3ccSAndroid Build Coastguard Worker /*
2*77c1e3ccSAndroid Build Coastguard Worker * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
3*77c1e3ccSAndroid Build Coastguard Worker *
4*77c1e3ccSAndroid Build Coastguard Worker * This source code is subject to the terms of the BSD 2 Clause License and
5*77c1e3ccSAndroid Build Coastguard Worker * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6*77c1e3ccSAndroid Build Coastguard Worker * was not distributed with this source code in the LICENSE file, you can
7*77c1e3ccSAndroid Build Coastguard Worker * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8*77c1e3ccSAndroid Build Coastguard Worker * Media Patent License 1.0 was not distributed with this source code in the
9*77c1e3ccSAndroid Build Coastguard Worker * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10*77c1e3ccSAndroid Build Coastguard Worker */
11*77c1e3ccSAndroid Build Coastguard Worker
#include <cstring>
#include <tuple>

#include "gtest/gtest.h"

#include "aom/aom_integer.h"
#include "aom_ports/aom_timer.h"
#include "av1/encoder/ml.h"
#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"
#include "config/av1_rtcd.h"
#include "test/util.h"
#include "test/register_state_check.h"
#include "test/acm_random.h"
25*77c1e3ccSAndroid Build Coastguard Worker
26*77c1e3ccSAndroid Build Coastguard Worker namespace {
27*77c1e3ccSAndroid Build Coastguard Worker typedef void (*NnPredict_Func)(const float *const input_nodes,
28*77c1e3ccSAndroid Build Coastguard Worker const NN_CONFIG *const nn_config,
29*77c1e3ccSAndroid Build Coastguard Worker int reduce_prec, float *const output);
30*77c1e3ccSAndroid Build Coastguard Worker
31*77c1e3ccSAndroid Build Coastguard Worker typedef std::tuple<const NnPredict_Func> NnPredictTestParam;
32*77c1e3ccSAndroid Build Coastguard Worker
33*77c1e3ccSAndroid Build Coastguard Worker const float epsilon = 1e-3f; // Error threshold for functional equivalence
34*77c1e3ccSAndroid Build Coastguard Worker
35*77c1e3ccSAndroid Build Coastguard Worker class NnPredictTest : public ::testing::TestWithParam<NnPredictTestParam> {
36*77c1e3ccSAndroid Build Coastguard Worker public:
SetUp()37*77c1e3ccSAndroid Build Coastguard Worker void SetUp() override {
38*77c1e3ccSAndroid Build Coastguard Worker const int MAX_NODES2 = NN_MAX_NODES_PER_LAYER * NN_MAX_NODES_PER_LAYER;
39*77c1e3ccSAndroid Build Coastguard Worker // Allocate two massive buffers on the heap for edge weights and node bias
40*77c1e3ccSAndroid Build Coastguard Worker // Then set-up the double-dimension arrays pointing into the big buffers
41*77c1e3ccSAndroid Build Coastguard Worker weights_buf = (float *)aom_malloc(MAX_NODES2 * (NN_MAX_HIDDEN_LAYERS + 1) *
42*77c1e3ccSAndroid Build Coastguard Worker sizeof(*weights_buf));
43*77c1e3ccSAndroid Build Coastguard Worker bias_buf =
44*77c1e3ccSAndroid Build Coastguard Worker (float *)aom_malloc(NN_MAX_NODES_PER_LAYER *
45*77c1e3ccSAndroid Build Coastguard Worker (NN_MAX_HIDDEN_LAYERS + 1) * sizeof(*bias_buf));
46*77c1e3ccSAndroid Build Coastguard Worker ASSERT_NE(weights_buf, nullptr);
47*77c1e3ccSAndroid Build Coastguard Worker ASSERT_NE(bias_buf, nullptr);
48*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < NN_MAX_HIDDEN_LAYERS + 1; i++) {
49*77c1e3ccSAndroid Build Coastguard Worker weights[i] = &weights_buf[i * MAX_NODES2];
50*77c1e3ccSAndroid Build Coastguard Worker bias[i] = &bias_buf[i * NN_MAX_NODES_PER_LAYER];
51*77c1e3ccSAndroid Build Coastguard Worker }
52*77c1e3ccSAndroid Build Coastguard Worker target_func_ = GET_PARAM(0);
53*77c1e3ccSAndroid Build Coastguard Worker }
TearDown()54*77c1e3ccSAndroid Build Coastguard Worker void TearDown() override {
55*77c1e3ccSAndroid Build Coastguard Worker aom_free(weights_buf);
56*77c1e3ccSAndroid Build Coastguard Worker aom_free(bias_buf);
57*77c1e3ccSAndroid Build Coastguard Worker }
58*77c1e3ccSAndroid Build Coastguard Worker void RunNnPredictTest(const NN_CONFIG *const shape);
59*77c1e3ccSAndroid Build Coastguard Worker void RunNnPredictSpeedTest(const NN_CONFIG *const shape, const int run_times);
60*77c1e3ccSAndroid Build Coastguard Worker void RunNnPredictTest_all(const NN_CONFIG *const shapes,
61*77c1e3ccSAndroid Build Coastguard Worker const int num_shapes);
62*77c1e3ccSAndroid Build Coastguard Worker void RunNnPredictSpeedTest_all(const NN_CONFIG *const shapes,
63*77c1e3ccSAndroid Build Coastguard Worker const int num_shapes, const int run_times);
64*77c1e3ccSAndroid Build Coastguard Worker
65*77c1e3ccSAndroid Build Coastguard Worker private:
66*77c1e3ccSAndroid Build Coastguard Worker NnPredict_Func target_func_;
67*77c1e3ccSAndroid Build Coastguard Worker libaom_test::ACMRandom rng_;
68*77c1e3ccSAndroid Build Coastguard Worker float *weights[NN_MAX_HIDDEN_LAYERS + 1] = {};
69*77c1e3ccSAndroid Build Coastguard Worker float *bias[NN_MAX_HIDDEN_LAYERS + 1] = {};
70*77c1e3ccSAndroid Build Coastguard Worker float *weights_buf = nullptr, *bias_buf = nullptr;
71*77c1e3ccSAndroid Build Coastguard Worker };
72*77c1e3ccSAndroid Build Coastguard Worker GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(NnPredictTest);
73*77c1e3ccSAndroid Build Coastguard Worker
RunNnPredictTest(const NN_CONFIG * const shape)74*77c1e3ccSAndroid Build Coastguard Worker void NnPredictTest::RunNnPredictTest(const NN_CONFIG *const shape) {
75*77c1e3ccSAndroid Build Coastguard Worker float inputs[NN_MAX_NODES_PER_LAYER] = { 0 };
76*77c1e3ccSAndroid Build Coastguard Worker float outputs_test[NN_MAX_NODES_PER_LAYER] = { 0 };
77*77c1e3ccSAndroid Build Coastguard Worker float outputs_ref[NN_MAX_NODES_PER_LAYER] = { 0 };
78*77c1e3ccSAndroid Build Coastguard Worker
79*77c1e3ccSAndroid Build Coastguard Worker NN_CONFIG nn_config;
80*77c1e3ccSAndroid Build Coastguard Worker memcpy(&nn_config, shape, sizeof(nn_config));
81*77c1e3ccSAndroid Build Coastguard Worker
82*77c1e3ccSAndroid Build Coastguard Worker char shape_str[32] = { 0 };
83*77c1e3ccSAndroid Build Coastguard Worker snprintf(shape_str, sizeof(shape_str), "%d", shape->num_inputs);
84*77c1e3ccSAndroid Build Coastguard Worker for (int layer = 0; layer < shape->num_hidden_layers; layer++)
85*77c1e3ccSAndroid Build Coastguard Worker snprintf(&shape_str[strlen(shape_str)],
86*77c1e3ccSAndroid Build Coastguard Worker sizeof(shape_str) - strlen(shape_str), "x%d",
87*77c1e3ccSAndroid Build Coastguard Worker shape->num_hidden_nodes[layer]);
88*77c1e3ccSAndroid Build Coastguard Worker snprintf(&shape_str[strlen(shape_str)], sizeof(shape_str) - strlen(shape_str),
89*77c1e3ccSAndroid Build Coastguard Worker "x%d", shape->num_outputs);
90*77c1e3ccSAndroid Build Coastguard Worker
91*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < NN_MAX_HIDDEN_LAYERS + 1; i++) {
92*77c1e3ccSAndroid Build Coastguard Worker nn_config.weights[i] = weights[i];
93*77c1e3ccSAndroid Build Coastguard Worker nn_config.bias[i] = bias[i];
94*77c1e3ccSAndroid Build Coastguard Worker }
95*77c1e3ccSAndroid Build Coastguard Worker
96*77c1e3ccSAndroid Build Coastguard Worker for (int iter = 0; iter < 10000 && !HasFatalFailure(); ++iter) {
97*77c1e3ccSAndroid Build Coastguard Worker for (int node = 0; node < shape->num_inputs; node++) {
98*77c1e3ccSAndroid Build Coastguard Worker inputs[node] = ((float)rng_.Rand31() - (1 << 30)) / (1u << 31);
99*77c1e3ccSAndroid Build Coastguard Worker }
100*77c1e3ccSAndroid Build Coastguard Worker for (int layer = 0; layer < shape->num_hidden_layers; layer++) {
101*77c1e3ccSAndroid Build Coastguard Worker for (int node = 0; node < NN_MAX_NODES_PER_LAYER; node++) {
102*77c1e3ccSAndroid Build Coastguard Worker bias[layer][node] = ((float)rng_.Rand31() - (1 << 30)) / (1u << 31);
103*77c1e3ccSAndroid Build Coastguard Worker }
104*77c1e3ccSAndroid Build Coastguard Worker for (int node = 0; node < NN_MAX_NODES_PER_LAYER * NN_MAX_NODES_PER_LAYER;
105*77c1e3ccSAndroid Build Coastguard Worker node++) {
106*77c1e3ccSAndroid Build Coastguard Worker weights[layer][node] = ((float)rng_.Rand31() - (1 << 30)) / (1u << 31);
107*77c1e3ccSAndroid Build Coastguard Worker }
108*77c1e3ccSAndroid Build Coastguard Worker }
109*77c1e3ccSAndroid Build Coastguard Worker // Now the outputs:
110*77c1e3ccSAndroid Build Coastguard Worker int layer = shape->num_hidden_layers;
111*77c1e3ccSAndroid Build Coastguard Worker for (int node = 0; node < NN_MAX_NODES_PER_LAYER; node++) {
112*77c1e3ccSAndroid Build Coastguard Worker bias[layer][node] = ((float)rng_.Rand31() - (1 << 30)) / (1u << 31);
113*77c1e3ccSAndroid Build Coastguard Worker }
114*77c1e3ccSAndroid Build Coastguard Worker for (int node = 0; node < NN_MAX_NODES_PER_LAYER * NN_MAX_NODES_PER_LAYER;
115*77c1e3ccSAndroid Build Coastguard Worker node++) {
116*77c1e3ccSAndroid Build Coastguard Worker weights[layer][node] = ((float)rng_.Rand31() - (1 << 30)) / (1u << 31);
117*77c1e3ccSAndroid Build Coastguard Worker }
118*77c1e3ccSAndroid Build Coastguard Worker
119*77c1e3ccSAndroid Build Coastguard Worker av1_nn_predict_c(inputs, &nn_config, 0, outputs_ref);
120*77c1e3ccSAndroid Build Coastguard Worker target_func_(inputs, &nn_config, 0, outputs_test);
121*77c1e3ccSAndroid Build Coastguard Worker
122*77c1e3ccSAndroid Build Coastguard Worker for (int node = 0; node < shape->num_outputs; node++) {
123*77c1e3ccSAndroid Build Coastguard Worker if (outputs_ref[node] < epsilon) {
124*77c1e3ccSAndroid Build Coastguard Worker ASSERT_LE(outputs_test[node], epsilon)
125*77c1e3ccSAndroid Build Coastguard Worker << "Reference output was near-zero, test output was not ("
126*77c1e3ccSAndroid Build Coastguard Worker << shape_str << ")";
127*77c1e3ccSAndroid Build Coastguard Worker } else {
128*77c1e3ccSAndroid Build Coastguard Worker const float error = outputs_ref[node] - outputs_test[node];
129*77c1e3ccSAndroid Build Coastguard Worker const float relative_error = fabsf(error / outputs_ref[node]);
130*77c1e3ccSAndroid Build Coastguard Worker ASSERT_LE(relative_error, epsilon)
131*77c1e3ccSAndroid Build Coastguard Worker << "Excessive relative error between reference and test ("
132*77c1e3ccSAndroid Build Coastguard Worker << shape_str << ")";
133*77c1e3ccSAndroid Build Coastguard Worker }
134*77c1e3ccSAndroid Build Coastguard Worker }
135*77c1e3ccSAndroid Build Coastguard Worker }
136*77c1e3ccSAndroid Build Coastguard Worker }
137*77c1e3ccSAndroid Build Coastguard Worker
RunNnPredictSpeedTest(const NN_CONFIG * const shape,const int run_times)138*77c1e3ccSAndroid Build Coastguard Worker void NnPredictTest::RunNnPredictSpeedTest(const NN_CONFIG *const shape,
139*77c1e3ccSAndroid Build Coastguard Worker const int run_times) {
140*77c1e3ccSAndroid Build Coastguard Worker float inputs[NN_MAX_NODES_PER_LAYER] = { 0 };
141*77c1e3ccSAndroid Build Coastguard Worker float outputs_test[NN_MAX_NODES_PER_LAYER] = { 0 };
142*77c1e3ccSAndroid Build Coastguard Worker float outputs_ref[NN_MAX_NODES_PER_LAYER] = { 0 };
143*77c1e3ccSAndroid Build Coastguard Worker
144*77c1e3ccSAndroid Build Coastguard Worker NN_CONFIG nn_config;
145*77c1e3ccSAndroid Build Coastguard Worker memcpy(&nn_config, shape, sizeof(nn_config));
146*77c1e3ccSAndroid Build Coastguard Worker
147*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < NN_MAX_HIDDEN_LAYERS; i++) {
148*77c1e3ccSAndroid Build Coastguard Worker nn_config.weights[i] = weights[i];
149*77c1e3ccSAndroid Build Coastguard Worker nn_config.bias[i] = bias[i];
150*77c1e3ccSAndroid Build Coastguard Worker }
151*77c1e3ccSAndroid Build Coastguard Worker // Don't bother actually changing the values for inputs/weights/bias: it
152*77c1e3ccSAndroid Build Coastguard Worker // shouldn't make any difference for a speed test.
153*77c1e3ccSAndroid Build Coastguard Worker
154*77c1e3ccSAndroid Build Coastguard Worker aom_usec_timer timer;
155*77c1e3ccSAndroid Build Coastguard Worker aom_usec_timer_start(&timer);
156*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < run_times; ++i) {
157*77c1e3ccSAndroid Build Coastguard Worker av1_nn_predict_c(inputs, &nn_config, 0, outputs_ref);
158*77c1e3ccSAndroid Build Coastguard Worker }
159*77c1e3ccSAndroid Build Coastguard Worker aom_usec_timer_mark(&timer);
160*77c1e3ccSAndroid Build Coastguard Worker const double time1 = static_cast<double>(aom_usec_timer_elapsed(&timer));
161*77c1e3ccSAndroid Build Coastguard Worker aom_usec_timer_start(&timer);
162*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < run_times; ++i) {
163*77c1e3ccSAndroid Build Coastguard Worker target_func_(inputs, &nn_config, 0, outputs_test);
164*77c1e3ccSAndroid Build Coastguard Worker }
165*77c1e3ccSAndroid Build Coastguard Worker aom_usec_timer_mark(&timer);
166*77c1e3ccSAndroid Build Coastguard Worker const double time2 = static_cast<double>(aom_usec_timer_elapsed(&timer));
167*77c1e3ccSAndroid Build Coastguard Worker
168*77c1e3ccSAndroid Build Coastguard Worker printf("%d", shape->num_inputs);
169*77c1e3ccSAndroid Build Coastguard Worker for (int layer = 0; layer < shape->num_hidden_layers; layer++)
170*77c1e3ccSAndroid Build Coastguard Worker printf("x%d", shape->num_hidden_nodes[layer]);
171*77c1e3ccSAndroid Build Coastguard Worker printf("x%d: ", shape->num_outputs);
172*77c1e3ccSAndroid Build Coastguard Worker printf("%7.2f/%7.2fns (%3.2f)\n", time1, time2, time1 / time2);
173*77c1e3ccSAndroid Build Coastguard Worker }
174*77c1e3ccSAndroid Build Coastguard Worker
175*77c1e3ccSAndroid Build Coastguard Worker // This is all the neural network shapes observed executed in a few different
176*77c1e3ccSAndroid Build Coastguard Worker // runs of the encoder. It also conveniently covers all the kernels
177*77c1e3ccSAndroid Build Coastguard Worker // implemented.
178*77c1e3ccSAndroid Build Coastguard Worker static const NN_CONFIG kShapes[] = {
179*77c1e3ccSAndroid Build Coastguard Worker { 37, 1, 2, { 16, 24 }, {}, {} }, { 24, 24, 1, { 12 }, {}, {} },
180*77c1e3ccSAndroid Build Coastguard Worker { 10, 16, 1, { 64 }, {}, {} }, { 12, 1, 1, { 12 }, {}, {} },
181*77c1e3ccSAndroid Build Coastguard Worker { 12, 1, 1, { 24 }, {}, {} }, { 12, 1, 1, { 32 }, {}, {} },
182*77c1e3ccSAndroid Build Coastguard Worker { 18, 4, 1, { 24 }, {}, {} }, { 18, 4, 1, { 32 }, {}, {} },
183*77c1e3ccSAndroid Build Coastguard Worker { 4, 1, 1, { 16 }, {}, {} }, { 8, 1, 0, { 0 }, {}, {} },
184*77c1e3ccSAndroid Build Coastguard Worker { 8, 4, 1, { 16 }, {}, {} }, { 8, 1, 1, { 32 }, {}, {} },
185*77c1e3ccSAndroid Build Coastguard Worker { 9, 3, 1, { 32 }, {}, {} }, { 8, 4, 0, { 0 }, {}, {} },
186*77c1e3ccSAndroid Build Coastguard Worker { 8, 8, 0, { 0 }, {}, {} }, { 4, 4, 1, { 8 }, {}, {} },
187*77c1e3ccSAndroid Build Coastguard Worker { 4, 3, 0, { 64 }, {}, {} },
188*77c1e3ccSAndroid Build Coastguard Worker };
189*77c1e3ccSAndroid Build Coastguard Worker
RunNnPredictTest_all(const NN_CONFIG * const shapes,const int num_shapes)190*77c1e3ccSAndroid Build Coastguard Worker void NnPredictTest::RunNnPredictTest_all(const NN_CONFIG *const shapes,
191*77c1e3ccSAndroid Build Coastguard Worker const int num_shapes) {
192*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < num_shapes; i++) RunNnPredictTest(&shapes[i]);
193*77c1e3ccSAndroid Build Coastguard Worker }
194*77c1e3ccSAndroid Build Coastguard Worker
RunNnPredictSpeedTest_all(const NN_CONFIG * const shapes,const int num_shapes,const int run_times)195*77c1e3ccSAndroid Build Coastguard Worker void NnPredictTest::RunNnPredictSpeedTest_all(const NN_CONFIG *const shapes,
196*77c1e3ccSAndroid Build Coastguard Worker const int num_shapes,
197*77c1e3ccSAndroid Build Coastguard Worker const int run_times) {
198*77c1e3ccSAndroid Build Coastguard Worker for (int i = 0; i < num_shapes; i++)
199*77c1e3ccSAndroid Build Coastguard Worker NnPredictTest::RunNnPredictSpeedTest(&shapes[i], run_times);
200*77c1e3ccSAndroid Build Coastguard Worker }
201*77c1e3ccSAndroid Build Coastguard Worker
// Functional check: every entry of kShapes must match the C reference.
TEST_P(NnPredictTest, RandomValues) {
  const int num_shapes = static_cast<int>(sizeof(kShapes) / sizeof(*kShapes));
  RunNnPredictTest_all(kShapes, num_shapes);
}
205*77c1e3ccSAndroid Build Coastguard Worker
// Benchmark, disabled by default: times every entry of kShapes for ten
// million calls per implementation.
TEST_P(NnPredictTest, DISABLED_Speed) {
  const int kRunTimes = 10000000;
  const int num_shapes = static_cast<int>(sizeof(kShapes) / sizeof(kShapes[0]));
  RunNnPredictSpeedTest_all(kShapes, num_shapes, kRunTimes);
}
210*77c1e3ccSAndroid Build Coastguard Worker
// One instantiation per SIMD implementation compiled into this build. All of
// them are skipped when CONFIG_EXCLUDE_SIMD_MISMATCH is enabled.
#if !CONFIG_EXCLUDE_SIMD_MISMATCH && HAVE_SSE3
INSTANTIATE_TEST_SUITE_P(SSE3, NnPredictTest,
                         ::testing::Values(av1_nn_predict_sse3));
#endif  // !CONFIG_EXCLUDE_SIMD_MISMATCH && HAVE_SSE3

#if !CONFIG_EXCLUDE_SIMD_MISMATCH && HAVE_AVX2
INSTANTIATE_TEST_SUITE_P(AVX2, NnPredictTest,
                         ::testing::Values(av1_nn_predict_avx2));
#endif  // !CONFIG_EXCLUDE_SIMD_MISMATCH && HAVE_AVX2

#if !CONFIG_EXCLUDE_SIMD_MISMATCH && HAVE_NEON
INSTANTIATE_TEST_SUITE_P(NEON, NnPredictTest,
                         ::testing::Values(av1_nn_predict_neon));
#endif  // !CONFIG_EXCLUDE_SIMD_MISMATCH && HAVE_NEON
227*77c1e3ccSAndroid Build Coastguard Worker
228*77c1e3ccSAndroid Build Coastguard Worker } // namespace
229