xref: /aosp_15_r20/external/libaom/test/cnn_test.cc (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2019, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
#include <assert.h>
#include <math.h>
#include <stdio.h>

#include <vector>

#include "gtest/gtest.h"

#include "config/av1_rtcd.h"

#include "aom_ports/aom_timer.h"
#include "av1/encoder/cnn.h"
#include "av1/encoder/partition_cnn_weights.h"
#include "test/acm_random.h"
#include "test/function_equivalence_test.h"
#include "test/util.h"
26 
// Squares its argument. The argument is expanded twice, so it must not have
// side effects.
27 #define SQR(x) ((x) * (x))
28 
29 // Best possible pixelwise guaranteed precision given each float has at most
30 // 3 specified decimals.
// Used as the absolute per-element bound when comparing CNN output against
// the expected values.
31 #define PIXELWISE_FLOAT_TOL 1E-2
32 
// Upper bounds on the mean squared error between expected and actual output.
// MSE_INT_TOL is passed as the `tolerance` argument by the tests whose data is
// integer valued; MSE_FLOAT_TOL is presumably for float-valued data (its use
// is not visible in this part of the file).
33 #define MSE_FLOAT_TOL 1E-6
34 #define MSE_INT_TOL 0
35 
36 // CNN convolve pixelwise error threshold for functional equivalence.
37 #define CNN_CONVOLVE_PIXELWISE_FLOAT_TOL 1E-3f
38 
39 namespace {
40 
41 class CNNTest : public ::testing::Test {
42  protected:
RunCNNTest(int image_width,int image_height,const float * input,const float * expected,const CNN_CONFIG * cnn_config,int in_stride,CNN_THREAD_DATA * thread_data,double tolerance)43   static void RunCNNTest(int image_width, int image_height, const float *input,
44                          const float *expected, const CNN_CONFIG *cnn_config,
45                          int in_stride, CNN_THREAD_DATA *thread_data,
46                          double tolerance) {
47     int out_width, out_height, out_channels;
48     av1_find_cnn_output_size(image_width, image_height, cnn_config, &out_width,
49                              &out_height, &out_channels);
50 
51     const int out_size = out_width * out_height;
52     const int out_stride = out_width;
53 
54     float *output_ =
55         (float *)aom_malloc(sizeof(*output_) * out_size * out_channels);
56     ASSERT_NE(output_, nullptr);
57     float *output[CNN_MAX_CHANNELS] = { nullptr };
58     for (int channel = 0; channel < out_channels; ++channel) {
59       output[channel] = output_ + (channel * out_size);
60     }
61     const int num_outputs = 1;
62     const int output_chs[1] = { out_channels };
63     const int output_strides[1] = { out_stride };
64     CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_strides,
65                                     output };
66 
67     RunMultiOutCNNTest(&input, image_width, image_height, in_stride, cnn_config,
68                        thread_data, &output_struct, &expected, tolerance);
69 
70     aom_free(output_);
71   }
72 
RunMultiOutCNNTest(const float ** input,int image_width,int image_height,int in_stride,const CNN_CONFIG * cnn_config,CNN_THREAD_DATA * thread_data,CNN_MULTI_OUT * output,const float ** expected,double tolerance)73   static void RunMultiOutCNNTest(const float **input, int image_width,
74                                  int image_height, int in_stride,
75                                  const CNN_CONFIG *cnn_config,
76                                  CNN_THREAD_DATA *thread_data,
77                                  CNN_MULTI_OUT *output, const float **expected,
78                                  double tolerance) {
79     const int num_outputs = output->num_outputs;
80     const int *output_chs = output->output_channels;
81 
82     int *out_widths = (int *)aom_calloc(num_outputs, sizeof(*out_widths));
83     int *out_heights = (int *)aom_calloc(num_outputs, sizeof(*out_heights));
84     int *not_used = (int *)aom_calloc(num_outputs, sizeof(*not_used));
85     ASSERT_NE(out_widths, nullptr);
86     ASSERT_NE(out_heights, nullptr);
87     ASSERT_NE(not_used, nullptr);
88 
89     av1_find_cnn_output_size(image_width, image_height, cnn_config, out_widths,
90                              out_heights, not_used);
91     ASSERT_TRUE(av1_cnn_predict(input, image_width, image_height, in_stride,
92                                 cnn_config, thread_data, output));
93 
94     int channel_offset = 0;
95     for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
96       const float *expected_out = expected[output_idx];
97       const int curr_output_chs = output_chs[output_idx];
98       const int out_size = out_widths[output_idx] * out_heights[output_idx];
99 
100       double mse = 0;
101       int expected_ite = 0;
102       for (int channel = 0; channel < curr_output_chs; ++channel) {
103         const float *buf_out = output->output_buffer[channel_offset];
104 
105         for (int i = 0; i < out_size; ++i) {
106           EXPECT_NEAR(expected_out[expected_ite], buf_out[i],
107                       PIXELWISE_FLOAT_TOL)
108               << " output " << output_idx << " channel " << channel << " pixel "
109               << expected_ite % out_size << ": " << expected_out[expected_ite]
110               << "/" << buf_out[i] << std::endl;
111           mse += SQR(expected_out[expected_ite] - buf_out[i]);
112           expected_ite++;
113         }
114 
115         channel_offset++;
116       }
117       mse /= (out_size * curr_output_chs);
118       EXPECT_LE(mse, tolerance) << " output " << output_idx << std::endl;
119     }
120 
121     aom_free(out_widths);
122     aom_free(out_heights);
123     aom_free(not_used);
124   }
125 
AssignLayerWeightsBiases(CNN_CONFIG * cnn_config,float * weights,float * bias)126   static void AssignLayerWeightsBiases(CNN_CONFIG *cnn_config, float *weights,
127                                        float *bias) {
128     size_t weight_offset = 0;
129     size_t bias_offset = 0;
130     for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
131       CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
132       layer_config->weights = weights + weight_offset;
133       layer_config->bias = bias + bias_offset;
134       weight_offset += layer_config->filter_width *
135                        layer_config->filter_height * layer_config->in_channels *
136                        layer_config->out_channels;
137       bias_offset += layer_config->out_channels;
138 
139       ASSERT_NE(layer_config->weights, nullptr);
140       ASSERT_NE(layer_config->bias, nullptr);
141     }
142   }
143 };
144 
145 }  // namespace
146 
TEST_F(CNNTest,TestMultilayerConvolution)147 TEST_F(CNNTest, TestMultilayerConvolution) {
148   int image_height = 16;
149   int image_width = 16;
150   int filter_height = 5;
151   int filter_width = 4;
152 
153   float input[] = {
154     -3, 1,  -3, 2,  -2, -2, 2,  -2, 1,  -2, -3, 1,  2,  2,  2,  -2, 0,  1,  -1,
155     -3, -1, -1, 1,  0,  -3, 1,  0,  -1, 1,  0,  0,  -3, -3, -3, 0,  2,  1,  -1,
156     2,  0,  1,  -3, -1, 2,  2,  1,  -2, 0,  -1, 0,  -2, -2, -1, 1,  0,  0,  0,
157     -2, -2, -2, 1,  1,  -2, 1,  1,  -2, -2, 1,  -2, -1, -2, -3, 2,  -3, -1, 1,
158     0,  -2, -2, -2, 1,  -2, -2, -1, -1, 2,  2,  2,  -1, 1,  -3, -3, 0,  2,  0,
159     2,  1,  -3, -3, 1,  2,  2,  1,  -2, -3, 0,  -3, 0,  -3, -2, 0,  1,  1,  0,
160     -3, 2,  -1, 2,  1,  0,  1,  -2, 1,  -1, -1, 2,  0,  -2, -3, 1,  1,  -2, -1,
161     -3, -3, -1, 0,  -3, -2, 0,  0,  1,  0,  -3, -2, -1, 1,  0,  2,  1,  0,  -3,
162     -2, -3, -3, -1, 0,  -2, 2,  -1, -3, 0,  -1, -1, 2,  0,  -3, -2, -1, 0,  0,
163     1,  -2, 1,  2,  1,  2,  2,  -3, 2,  -1, 0,  0,  -1, 0,  2,  2,  -1, 2,  -2,
164     1,  1,  -3, -3, 1,  -1, -1, -2, 2,  -2, -2, 2,  -1, -3, 2,  -3, 1,  -1, -1,
165     -3, 1,  -1, 1,  0,  -3, -3, 1,  -3, -3, 0,  2,  2,  -2, -1, 2,  0,  2,  1,
166     -1, -3, 0,  0,  -1, -1, 1,  0,  2,  0,  -3, 2,  1,  0,  1,  -3, 2,  -3, -3,
167     -1, -3, -3, 2,  0,  2,  -2, 1,  -1,
168   };
169 
170   float weights[] = {
171     -2, 2,  -2, 2,  -1, -3, 2,  2,  0,  0,  -3, -1, -2, -3, 1,  -1, 0,  0,  0,
172     2,  -2, 2,  -2, -3, 1,  1,  1,  -3, -1, 0,  1,  2,  -2, 0,  -1, -3, -1, -2,
173     2,  -3, -3, 1,  -2, -3, 0,  2,  1,  -3, -3, -1, -3, -2, -1, -3, -1, -3, -2,
174     -1, -3, -1, -2, -2, -3, 2,  0,  -3, 0,  -3, -3, 1,  -3, -1, 0,  -1, 1,  1,
175     -1, 1,  -2, 0,  2,  0,  -3, 1,  -1, -1, 2,  0,  1,  -3, -3, 1,  2,  -3, -3,
176     1,  -3, 2,  0,  -3, 1,  2,  2,  -2, -1, -2, 1,  1,  0,  -2, -2, 1,  2,  -1,
177     -3, 1,  -2, 2,  -3, -2, -3, 2,  1,  0,  -2, 0,  1,  -3, 2,  -2, -2, 0,  2,
178     -3, 2,  0,  0,  1,  -2, 1,  1,  -2, -1, -2, 1,  -2, 0,  -2, -2, 0,  -1, -1,
179     -3, -3, -3, 1,  -3, -2, 2,  -1, 2,  0,  2,  -2, 2,  -2, 1,  -3, -3, -1, 0,
180     2,  2,  1,  -1, -3, -1, -3, 2,  1,  -2, 0,  -3, -1, -3, -1, 2,  1,  0,  2,
181     -1, 1,  0,  1,  2,  -1, -2, 2,  1,  -3, -1, -3, 0,  1,  -2, 0,  -2, -3, 0,
182     -2, 2,  2,  0,  0,  2,  -3, 2,  -3, -2, 1,  2,  -3, -3, -1, -3, 0,  -3, -3,
183     -2, -2, -2, 0,  0,  1,  0,  0,  -1, 0,  0,  -3, 0,  -3, -1, -2, 1,  -2, -1,
184     2,  -2, 0,  0,  1,  0,  -2, -1, 0,  -3, 1,  0,  -1, -3, 1,  -1, 1,  -1, -3,
185     1,  0,  1,  1,  -1, 2,  2,  0,  0,  1,  -3, 2,  -2, -2, -3, -2, -1, -2, 2,
186     0,  2,  -2, -3, -1, -3, 2,  2,  -1, 2,  2,  -1, 0,  -3, 1,
187   };
188 
189   float bias[] = {
190     1, -1, 0, 1, 1, 1, -2,
191   };
192 
193   float expected_same[] = {
194     -1125, 2926,  6406,  631,   -1244, 97,    -1454, 2526,  1065,  3292,  3464,
195     2553,  -330,  532,   1038,  1182,  -402,  3758,  3392,  9854,  4365,  1408,
196     4736,  3134,  3838,  2409,  3221,  4350,  6750,  4045,  815,   1188,  2959,
197     9802,  9590,  4572,  5740,  4253,  1701,  7974,  7012,  6854,  7093,  3907,
198     4539,  3886,  4267,  3505,  465,   7824,  9219,  10026, 7968,  957,   2295,
199     5594,  10811, 9641,  5950,  10043, 8783,  3132,  1421,  1110,  4108,  13929,
200     10660, -84,   -61,   3932,  -180,  6811,  13393, 15147, 15640, 9337,  6961,
201     3808,  1604,  1398,  1047,  6739,  10144, 6517,  4698,  2678,  7389,  2595,
202     5248,  12075, 11272, 13951, 8820,  1090,  2199,  2206,  2788,  12116, 6683,
203     2612,  -291,  3183,  9414,  12316, 14524, 12333, 13208, 7832,  4664,  4657,
204     3534,  1298,  -666,  4250,  7707,  9103,  5760,  688,   9571,  15782, 14203,
205     14878, 17339, 14684, 8690,  5671,  875,   1429,  1531,  6173,  2984,  5558,
206     2996,  7928,  6733,  16117, 15262, 12757, 7980,  3923,  4795,  5973,  2051,
207     455,   -1922, 1816,  5906,  3321,  10908, 10910, 7377,  12204, 12809, 11195,
208     7451,  6666,  74,    -1645, -35,   -391,  3813,  7324,  892,   1656,  6095,
209     12193, 14648, 12156, 14663, 10251, 10325, 7821,  3925,  323,   697,   442,
210     1324,  4669,  7002,  5485,  5171,  5086,  10582, 11053, 9709,  11353, 8543,
211     5256,  2873,  235,   -628,  1496,  1878,  -867,  3420,  6865,  5937,  10182,
212     13277, 10069, 10789, 5998,  624,   -2082, 4417,  1258,  -1080, -819,  -1430,
213     1033,  5220,  6335,  8471,  8980,  11908, 14430, 12584, 8404,  1576,  -803,
214     985,   1481,  1367,  -193,  873,   3684,  2288,  6676,  9477,  11155, 9602,
215     9707,  10507, 4739,  3174,  -575,  -178,  3002,  1710,  423,   -477,  554,
216     3088,  2029,  5113,  5000,  3771,  6090,  5365,  1185,  2855,  399,   -312,
217     -1577, 176,   955,
218   };
219 
220   float expected_replicate[] = {
221     13768, 13528, 12999, 6906,  4618,  4043,  2611,  9955,  6685,  4776,  2753,
222     1036,  3063,  4544,  5183,  7349,  12451, 12501, 9131,  12753, 8908,  4058,
223     6299,  7542,  7115,  3307,  3360,  3543,  9754,  7808,  5991,  9019,  14320,
224     14919, 12492, 6871,  7373,  3336,  2085,  10604, 9377,  6882,  5009,  3103,
225     6220,  6278,  7588,  10196, 11045, 11563, 11842, 11911, 8279,  2030,  1858,
226     6368,  12123, 9909,  6347,  10345, 9365,  4038,  1673,  3051,  16492, 16649,
227     12276, 408,   -301,  4122,  -654,  7864,  14038, 15279, 15315, 9744,  8243,
228     5298,  746,   380,   9824,  9124,  10895, 6640,  4712,  2669,  6980,  2759,
229     5385,  12345, 11336, 13129, 8600,  2370,  3682,  5219,  12407, 13123, 6784,
230     2612,  -291,  3183,  9414,  12316, 14524, 12333, 13397, 7543,  3916,  4153,
231     4477,  4314,  7983,  8418,  9163,  9103,  5760,  688,   9571,  15782, 14203,
232     14878, 17718, 14570, 7940,  6642,  5094,  7133,  9964,  10219, 3224,  5558,
233     2996,  7928,  6733,  16117, 15262, 12757, 7958,  4401,  5187,  5476,  5529,
234     6055,  2206,  3909,  6015,  3321,  10908, 10910, 7377,  12204, 12809, 11195,
235     6967,  6840,  481,   -1600, 274,   1,     10373, 8514,  1123,  2117,  6758,
236     12736, 16223, 13585, 15988, 11771, 10600, 7918,  4156,  2840,  3111,  3287,
237     6359,  7652,  8813,  6530,  6967,  7789,  13671, 13990, 13247, 13241, 9836,
238     5251,  3024,  2313,  1834,  4187,  2637,  -1312, 2139,  7378,  7665,  11933,
239     15591, 15314, 15678, 9531,  2820,  -1516, 3400,  1314,  22,    363,   -2896,
240     -898,  5906,  7308,  10650, 12975, 16978, 20370, 18817, 12381, 4118,  -861,
241     -137,  236,   1802,  1632,  -350,  2334,  3400,  8680,  14064, 18216, 18675,
242     21765, 22871, 11491, 4937,  -1555, -11,   1669,  2392,  3265,  -5254, -217,
243     5001,  8063,  13444, 18884, 19706, 22794, 21064, 9545,  6689,  -7,    289,
244     -2021, 504,   2347,
245   };
246 
247   float expected_valid[] = {
248     2612,  -291,  3183,  9414,  12316, 14524, 12333, 9103,  5760,  688,
249     9571,  15782, 14203, 14878, 5558,  2996,  7928,  6733,  16117, 15262,
250     12757, 3321,  10908, 10910, 7377,  12204, 12809, 11195,
251   };
252 
253   CNN_CONFIG cnn_config = { 3,
254                             0,
255                             0,
256                             0,
257                             0,
258                             {
259                                 {
260                                     1,
261                                     filter_width,
262                                     filter_height,
263                                     3,
264                                     1,
265                                     1,
266                                     0,
267                                     nullptr,
268                                     nullptr,
269                                     PADDING_SAME_ZERO,
270                                     NONE,
271                                     0,
272                                     0,
273                                     BRANCH_NO_COPY,
274                                     BRANCH_NOC,
275                                     {},
276                                     {},
277                                     -1,
278                                 },
279                                 {
280                                     3,
281                                     filter_width,
282                                     filter_height,
283                                     3,
284                                     1,
285                                     1,
286                                     0,
287                                     nullptr,
288                                     nullptr,
289                                     PADDING_SAME_ZERO,
290                                     NONE,
291                                     0,
292                                     0,
293                                     BRANCH_NO_COPY,
294                                     BRANCH_NOC,
295                                     {},
296                                     {},
297                                     -1,
298                                 },
299                                 {
300                                     3,
301                                     filter_width,
302                                     filter_height,
303                                     1,
304                                     1,
305                                     1,
306                                     0,
307                                     nullptr,
308                                     nullptr,
309                                     PADDING_SAME_ZERO,
310                                     NONE,
311                                     0,
312                                     0,
313                                     BRANCH_NO_COPY,
314                                     BRANCH_NOC,
315                                     {},
316                                     {},
317                                     0,
318                                 },
319                             } };
320 
321   // Weights and biases need to be specified separately because
322   // of the offset.
323   AssignLayerWeightsBiases(&cnn_config, weights, bias);
324 
325   CNN_THREAD_DATA thread_data = { 1, nullptr };
326 
327   RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
328              image_width, &thread_data, MSE_INT_TOL);
329 
330   for (int i = 0; i < cnn_config.num_layers; ++i) {
331     cnn_config.layer_config[i].pad = PADDING_SAME_REPLICATE;
332   }
333 
334   RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
335              image_width, &thread_data, MSE_INT_TOL);
336 
337   for (int i = 0; i < cnn_config.num_layers; ++i) {
338     cnn_config.layer_config[i].pad = PADDING_VALID;
339   }
340 
341   RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
342              image_width, &thread_data, MSE_INT_TOL);
343 }
344 
TEST_F(CNNTest,TestRELUSingleLayer)345 TEST_F(CNNTest, TestRELUSingleLayer) {
346   int image_width = 8;
347   int image_height = 8;
348   int filter_height = 5;
349   int filter_width = 4;
350   float input[] = {
351     0, -2, -3, 1,  -1, 2,  -2, 1,  -3, -1, 0,  1,  -2, -3, -2, -2,
352     1, -3, 2,  -3, -1, -1, 2,  0,  -2, -3, 0,  -2, -3, 1,  -1, -1,
353     2, -2, 0,  -2, -3, -3, 1,  1,  -1, 1,  0,  1,  -3, 0,  2,  2,
354     0, -3, 1,  -3, 2,  -2, 1,  -1, -1, -2, -3, -2, -1, -3, -2, -1,
355   };
356   float expected_same[] = {
357     9,  0,  1,  1,  0,  3,  0,  19, 0,  12, 10, 0,  0,  0,  5, 0,
358     0,  18, 21, 7,  19, 4,  3,  0,  0,  9,  16, 0,  11, 16, 0, 11,
359     12, 2,  0,  11, 0,  16, 6,  0,  8,  22, 13, 10, 12, 0,  0, 0,
360     0,  1,  2,  12, 29, 6,  10, 0,  13, 0,  0,  5,  8,  10, 0, 0,
361   };
362   float expected_replicate[] = {
363     18, 17, 12, 2,  0,  0,  5,  11, 0,  17, 22, 6,  0,  0,  17, 0,
364     0,  18, 21, 7,  19, 4,  3,  5,  3,  9,  16, 0,  11, 16, 0,  3,
365     3,  2,  0,  11, 0,  16, 6,  0,  17, 22, 13, 10, 12, 0,  0,  0,
366     0,  4,  1,  10, 30, 7,  10, 0,  23, 8,  0,  13, 15, 19, 8,  10,
367   };
368   float expected_valid[] = {
369     18, 21, 7, 19, 4, 9, 16, 0, 11, 16, 2, 0, 11, 0, 16, 22, 13, 10, 12, 0,
370   };
371   float weights[] = {
372     -2, -3, 1, 2, 2, -2, -3, 0, -3, 2, 2, -3, -3, -2, 0, 1, 2, 0, -1, -1,
373   };
374   float bias[] = { -3 };
375 
376   CNN_CONFIG cnn_config = { 1,
377                             0,
378                             0,
379                             0,
380                             0,
381                             { {
382                                 1,
383                                 filter_width,
384                                 filter_height,
385                                 1,
386                                 1,
387                                 1,
388                                 0,
389                                 weights,
390                                 bias,
391                                 PADDING_SAME_ZERO,
392                                 RELU,
393                                 0,
394                                 0,
395                                 BRANCH_NO_COPY,
396                                 BRANCH_NOC,
397                                 {},
398                                 {},
399                                 0,
400                             } } };
401 
402   CNN_THREAD_DATA thread_data = { 1, nullptr };
403 
404   RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
405              image_width, &thread_data, MSE_INT_TOL);
406 
407   cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
408 
409   RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
410              image_width, &thread_data, MSE_INT_TOL);
411 
412   cnn_config.layer_config[0].pad = PADDING_VALID;
413 
414   RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
415              image_width, &thread_data, MSE_INT_TOL);
416 }
417 
TEST_F(CNNTest,TestVaryingStridesVaryingDimImages)418 TEST_F(CNNTest, TestVaryingStridesVaryingDimImages) {
419   float weights[] = {
420     1,  -5, -3, -4, -1, 1,  2,  -3, 2,  2,  -1, 1,  -5, 1,  1,
421     -3, -5, 3,  1,  4,  -2, -5, -2, -3, -5, 0,  -1, -5, 2,  -2,
422     -2, 1,  -2, -4, 1,  3,  -2, 2,  0,  -3, 2,  -3, -2, -3,
423   };
424   float bias[] = { 2 };
425 
426   CNN_CONFIG cnn_config = { 1,
427                             0,
428                             0,
429                             0,
430                             0,
431                             {
432                                 {
433                                     1,
434                                     4,
435                                     11,
436                                     1,
437                                     7,
438                                     6,
439                                     0,
440                                     weights,
441                                     bias,
442                                     PADDING_SAME_ZERO,
443                                     NONE,
444                                     0,
445                                     0,
446                                     BRANCH_NO_COPY,
447                                     BRANCH_NOC,
448                                     {},
449                                     {},
450                                     0,
451                                 },
452                             } };
453 
454   int image_height = 24;
455   int image_width = 17;
456   float input[] = {
457     -1, -3, 4,  4,  -5, 4,  3,  -5, -1, -3, 4,  -4, 2,  -3, 3,  -5, 2,  -1, -5,
458     1,  -1, 3,  1,  -3, -3, 4,  0,  2,  -3, -5, -5, -4, 0,  -5, -2, -3, -1, -2,
459     2,  -5, 4,  4,  0,  -4, -3, 1,  -3, -5, -4, -4, 1,  -2, -3, 3,  -3, -3, -1,
460     -5, -5, -2, 3,  1,  -1, -5, -5, 1,  -4, -2, -1, -2, -4, -4, 2,  -2, 2,  1,
461     -2, -4, -1, 1,  -2, -5, 3,  -2, -1, -1, -5, -3, 1,  -2, -2, -3, -1, -2, -4,
462     -2, 1,  -4, -1, 4,  3,  -4, 0,  4,  2,  2,  4,  -3, -5, 2,  2,  1,  -1, -4,
463     -2, 1,  3,  2,  0,  4,  -1, -3, 2,  1,  -4, 2,  2,  -4, -2, 0,  -2, -1, 4,
464     4,  2,  3,  -4, 2,  -4, -5, 4,  -1, -3, -1, 0,  -4, 1,  3,  -1, -3, -5, 3,
465     -2, -4, 1,  2,  -2, -3, -3, -5, 1,  -3, -1, 0,  -1, 3,  -4, -1, -5, -5, 1,
466     0,  0,  -2, -2, 2,  -2, 0,  0,  2,  0,  -3, 0,  -1, -4, -4, -1, 3,  -4, -4,
467     -1, 0,  -5, -3, -2, 4,  -3, -4, -4, 0,  -5, 1,  -2, -3, -3, -4, 4,  3,  4,
468     3,  3,  -1, 3,  1,  -3, -2, 3,  3,  0,  2,  -4, -3, 2,  2,  0,  -2, 4,  -2,
469     2,  -2, -1, -4, -2, 2,  -4, 3,  -1, 4,  1,  1,  4,  -1, -4, -4, 1,  1,  -2,
470     4,  -1, 3,  2,  -3, 4,  3,  1,  4,  0,  -4, 2,  0,  2,  4,  -2, -2, 4,  2,
471     -1, -2, 1,  -3, 2,  3,  -5, -3, 4,  4,  2,  -5, -4, -5, -2, -4, 2,  0,  2,
472     -5, 4,  -4, -2, -5, 2,  1,  0,  4,  1,  -2, -3, -4, -3, -4, 3,  3,  2,  0,
473     -3, 1,  -5, 4,  0,  4,  -1, 3,  -5, -5, -2, -1, -1, 4,  3,  3,  4,  3,  -4,
474     4,  -3, -3, -1, -4, -1, -4, -1, -2, 4,  -2, -4, 4,  4,  -3, -4, -1, 1,  2,
475     -1, -2, -2, 3,  2,  2,  -3, 0,  -1, 0,  3,  2,  -5, 0,  -4, 0,  0,  2,  -4,
476     -1, -1, 0,  -2, 0,  1,  0,  0,  4,  -5, -1, -5, 2,  -1, 0,  2,  -1, 1,  3,
477     -3, -5, -2, -3, 4,  -2, -2, -1, -3, -4, -1, -2, -4, 1,  4,  -3, -2, -1, 3,
478     -3, -2, 3,  2,  1,  -4, -3, -5, 1,
479   };
480   float expected_1[] = {
481     41, -26, 5, 76, 13, 83, -21, 53, -54, -14, 21, 121,
482   };
483 
484   CNN_THREAD_DATA thread_data = { 1, nullptr };
485 
486   RunCNNTest(image_width, image_height, input, expected_1, &cnn_config,
487              image_width, &thread_data, MSE_INT_TOL);
488 
489   cnn_config.layer_config[0].skip_width = 6;
490   cnn_config.layer_config[0].skip_height = 7;
491 
492   float expected_2[] = {
493     21, -50, 41, 20, 72, 127, -21, 103, 62, -37, 83, -3,
494   };
495   RunCNNTest(image_width, image_height, input, expected_2, &cnn_config,
496              image_width, &thread_data, MSE_INT_TOL);
497 
498   cnn_config.layer_config[0].skip_width = 3;
499   cnn_config.layer_config[0].skip_height = 10;
500 
501   float expected_3[] = {
502     -26, -21, -35, 69, 49,  4,  -51, -43, -56,
503     -41, 15,  -44, 40, -62, 63, 38,  27,  47,
504   };
505   RunCNNTest(image_width, image_height, input, expected_3, &cnn_config,
506              image_width, &thread_data, MSE_INT_TOL);
507 
508   cnn_config.layer_config[0].skip_width = 10;
509   cnn_config.layer_config[0].skip_height = 3;
510 
511   float expected_4[] = {
512     21, 49, 28, 87, 50, 40, 102, 81, 58, 85, 51, 66, 36, 19, -37, -45,
513   };
514 
515   RunCNNTest(image_width, image_height, input, expected_4, &cnn_config,
516              image_width, &thread_data, MSE_INT_TOL);
517 }
518 
// Evaluates a single 3x3 layer on an 8x8 image with the same value (3) used
// for both skip entries of the layer initializer, and compares against a
// precomputed 9-value result.
TEST_F(CNNTest, TestMaxPool) {
  const int img_w = 8;
  const int img_h = 8;
  const int step = 3;

  const float image[] = {
    1,  -4, -4, 8, 0, 7, -5, -2, 8, 2, 2, 8,  5,  -1, -1, 9,
    -3, 0,  -2, 0, 6, 3, -4, 8,  7, 8, 7, -1, 4,  -1, 0,  2,
    -5, -2, 8,  5, 5, 4, 2,  7,  4, 6, 2, 8,  8,  -4, -3, -4,
    -3, -1, 2,  3, 3, 6, -5, 8,  9, 5, 0, -2, -1, 6,  5,  7,
  };

  const float want[] = {
    49, 58, 70, 68, 68, 70, 48, 57, 88,
  };

  float weights[] = {
    3, 1, 3, 4, -1, 5, -2, 1, -4,
  };

  float bias[] = {
    -3,
  };

  // NOTE(review): the 1 following the two `step` entries presumably enables
  // max pooling, going by the test name - confirm against CNN_LAYER_CONFIG.
  CNN_CONFIG cnn_config = {
    1,
    0,
    0,
    0,
    0,
    {
        { 1, 3, 3, 1, step, step, 1, weights, bias, PADDING_SAME_ZERO, NONE, 0,
          0, BRANCH_NO_COPY, BRANCH_NOC, {}, {}, 0 },
    }
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };

  RunCNNTest(img_w, img_h, image, want, &cnn_config, img_w, &thread_data,
             MSE_INT_TOL);
}
573 
TEST_F(CNNTest,TestDeconvolveNonActivationSingleLayerSingleKernel)574 TEST_F(CNNTest, TestDeconvolveNonActivationSingleLayerSingleKernel) {
575   int image_width = 4;
576   int image_height = 7;
577   float input[] = {
578     9,  6,   181, 9,  218, 30, 80,  108, 68,  216, 70, 128, 179, 228,
579     33, 212, 34,  14, 48,  27, 230, 23,  202, 113, 80, 56,  122, 112,
580   };
581 
582   float expected_1_same[] = {
583     15,   -30,  36,   -525,  377, -193, 558, 531,  6,   -24,  -15,  124,
584     166,  -561, -356, -754,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
585     433,  -311, 711,  381,   247, -317, 453, 129,  215, -627, -409, -885,
586     17,   -255, -55,  -647,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
587     133,  -719, 633,  -225,  785, 191,  463, 79,   65,  9,    77,   -853,
588     -365, -949, -15,  -667,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
589     355,  -866, 990,  207,   747, 12,   520, -116, 176, -312, -133, -1370,
590     -426, -802, 143,  -771,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
591     65,   -79,  127,  -59,   135, -90,  195, 114,  31,  -91,  -57,  -133,
592     17,   -176, -72,  -276,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
593     457,  -302, 733,  58,    470, -475, 829, 490,  227, -670, -440, -790,
594     153,  -588, -294, -1150, -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
595     157,  -251, 349,  -185,  409, -293, 587, 251,  77,  -187, -107, -369,
596     7,    -481, -135, -827,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
597   };
598   float expected_1_valid[] = {
599     -30,  15,   -30,  36,   -525,  377,  -193,  558,  531,  24,   24,   6,
600     6,    -24,  -15,  124,  166,   -561, -356,  -754, -21,  -39,  -3,   -3,
601     -3,   -3,   -3,   -3,   -3,    -3,   -3,    -3,   -3,   -657, 433,  -311,
602     711,  381,  247,  -317, 453,   129,  321,   321,  215,  215,  -627, -409,
603     -885, 17,   -255, -55,  -647,  -219, -435,  -3,   -3,   -3,   -3,   -3,
604     -3,   -3,   -3,   -3,   -3,    -3,   -207,  133,  -719, 633,  -225, 785,
605     191,  463,  79,   381,  381,   65,   65,    9,    77,   -853, -365, -949,
606     -15,  -667, -259, -515, -3,    -3,   -3,    -3,   -3,   -3,   -3,   -3,
607     -3,   -3,   -3,   -540, 355,   -866, 990,   207,  747,  12,   520,  -116,
608     633,  633,  176,  176,  -312,  -133, -1370, -426, -802, 143,  -771, -427,
609     -851, -3,   -3,   -3,   -3,    -3,   -3,    -3,   -3,   -3,   -3,   -3,
610     -105, 65,   -79,  127,  -59,   135,  -90,   195,  114,  78,   78,   31,
611     31,   -91,  -57,  -133, 17,    -176, -72,   -276, -57,  -111, -3,   -3,
612     -3,   -3,   -3,   -3,   -3,    -3,   -3,    -3,   -3,   -693, 457,  -302,
613     733,  58,   470,  -475, 829,   490,  336,   336,  227,  227,  -670, -440,
614     -790, 153,  -588, -294, -1150, -229, -455,  -3,   -3,   -3,   -3,   -3,
615     -3,   -3,   -3,   -3,   -3,    -3,   -243,  157,  -251, 349,  -185, 409,
616     -293, 587,  251,  333,  333,   77,   77,    -187, -107, -369, 7,    -481,
617     -135, -827, -227, -451,
618   };
619   float weights_1[] = { -3, 2, -1, 3, 3, 1, 1, -3, -2, -4 };
620   float bias_1[] = { -3 };
621 
622   CNN_CONFIG cnn_config = { 1,
623                             0,
624                             0,
625                             0,
626                             0,
627                             { {
628                                 1,
629                                 5,
630                                 2,
631                                 1,
632                                 2,
633                                 3,
634                                 0,
635                                 weights_1,
636                                 bias_1,
637                                 PADDING_SAME_ZERO,
638                                 NONE,
639                                 1,
640                                 0,
641                                 BRANCH_NO_COPY,
642                                 BRANCH_NOC,
643                                 {},
644                                 {},
645                                 0,
646                             } } };
647 
648   CNN_THREAD_DATA thread_data = { 1, nullptr };
649 
650   RunCNNTest(image_width, image_height, input, expected_1_same, &cnn_config,
651              image_width, &thread_data, MSE_INT_TOL);
652 
653   // Change padding to valid
654   cnn_config.layer_config[0].pad = PADDING_VALID;
655 
656   RunCNNTest(image_width, image_height, input, expected_1_valid, &cnn_config,
657              image_width, &thread_data, MSE_INT_TOL);
658 
659   float expected_12_same[] = {
660     15,  -12,  6,    36,   -9,   -528, 377,  -184, 513,  558,  -12,  24,
661     6,   -30,  -15,  -33,  -21,  166,  154,  -546, -356, -718, -30,  -21,
662     433, -221, 561,  711,  -33,  -153, 247,  -83,  -87,  453,  -111, 321,
663     215, -657, -409, -845, -93,  17,   -43,  -243, -55,  -215, -327, -219,
664     133, -71,  -447, 633,  -219, 435,  785,  -73,  -177, 463,  -131, 381,
665     65,  -207, 77,   -59,  -651, -365, -797, -213, -15,  -155, -387, -259,
666     355, -182, -150, 990,  -231, 582,  747,  -36,  -540, 520,  -215, 633,
667     176, -540, -133, -491, -687, -426, -882, -102, 143,  77,   -639, -427,
668     65,  -37,  57,   127,  -17,  -105, 135,  -51,  60,   195,  -30,  78,
669     31,  -105, -57,  -125, -45,  17,   -11,  -147, -72,  -168, -84,  -57,
670     457, -233, 618,  733,  -26,  -540, 470,  -205, 264,  829,  -116, 336,
671     227, -693, -440, -900, -72,  153,  107,  -609, -294, -698, -342, -229,
672     157, -83,  69,   349,  -59,  -201, 409,  -125, 27,   587,  -115, 333,
673     77,  -243, -107, -267, -171, 7,    -105, -369, -135, -379, -339, -227,
674   };
675   float expected_12_valid[] = {
676     -30,  15,   -12,  6,    36,   -9,   -528, 377,  -184, 513,  558,  -12,
677     24,   24,   6,    6,    -30,  -15,  -33,  -21,  166,  154,  -546, -356,
678     -718, -30,  -21,  -39,  -657, 433,  -221, 561,  711,  -33,  -153, 247,
679     -83,  -87,  453,  -111, 321,  321,  215,  215,  -657, -409, -845, -93,
680     17,   -43,  -243, -55,  -215, -327, -219, -435, -207, 133,  -71,  -447,
681     633,  -219, 435,  785,  -73,  -177, 463,  -131, 381,  381,  65,   65,
682     -207, 77,   -59,  -651, -365, -797, -213, -15,  -155, -387, -259, -515,
683     -540, 355,  -182, -150, 990,  -231, 582,  747,  -36,  -540, 520,  -215,
684     633,  633,  176,  176,  -540, -133, -491, -687, -426, -882, -102, 143,
685     77,   -639, -427, -851, -105, 65,   -37,  57,   127,  -17,  -105, 135,
686     -51,  60,   195,  -30,  78,   78,   31,   31,   -105, -57,  -125, -45,
687     17,   -11,  -147, -72,  -168, -84,  -57,  -111, -693, 457,  -233, 618,
688     733,  -26,  -540, 470,  -205, 264,  829,  -116, 336,  336,  227,  227,
689     -693, -440, -900, -72,  153,  107,  -609, -294, -698, -342, -229, -455,
690     -243, 157,  -83,  69,   349,  -59,  -201, 409,  -125, 27,   587,  -115,
691     333,  333,  77,   77,   -243, -107, -267, -171, 7,    -105, -369, -135,
692     -379, -339, -227, -451,
693   };
694 
695   // Change skip_width, skip_height to {3, 2}
696   cnn_config.layer_config[0].skip_width = 3;
697   cnn_config.layer_config[0].skip_height = 2;
698   // Set padding to same
699   cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
700 
701   RunCNNTest(image_width, image_height, input, expected_12_same, &cnn_config,
702              image_width, &thread_data, MSE_INT_TOL);
703 
704   // Change padding to valid
705   cnn_config.layer_config[0].pad = PADDING_VALID;
706   RunCNNTest(image_width, image_height, input, expected_12_valid, &cnn_config,
707              image_width, &thread_data, MSE_INT_TOL);
708 
709   cnn_config.layer_config[0].filter_width = 4;
710   cnn_config.layer_config[0].filter_height = 3;
711   float weights_2[] = { -1, -3, -1, -3, 0, 2, -2, 4, 3, 0, 1, 4 };
712   float bias_2[] = { -4 };
713   cnn_config.layer_config[0].weights = weights_2;
714   cnn_config.layer_config[0].bias = bias_2;
715 
716   cnn_config.layer_config[0].skip_width = 5;
717   cnn_config.layer_config[0].skip_height = 2;
718   float expected_2_same[] = {
719     -13,  -31,  -13,  -31,  -4,   -10,  -22,  -10,  -22,  -4,   -185, -547,
720     -185, -547, -4,   -13,  -31,  -13,  -31,  -4,   -4,   14,   -22,  32,
721     -4,   -4,   8,    -16,  20,   -4,   -4,   358,  -366, 720,  -4,   -4,
722     14,   -22,  32,   -4,   -195, -658, -213, -622, -4,   -16,  -94,  -28,
723     -70,  -4,   459,  -244, 97,   480,  -4,   -85,  -328, -103, -292, -4,
724     -4,   432,  -440, 868,  -4,   -4,   56,   -64,  116,  -4,   -4,   156,
725     -164, 316,  -4,   -4,   212,  -220, 428,  -4,   582,  -208, 146,  664,
726     -4,   -130, -652, -190, -532, -4,   166,  -214, 6,    106,  -4,   192,
727     -388, -24,  44,   -4,   -4,   132,  -140, 268,  -4,   -4,   428,  -436,
728     860,  -4,   -4,   136,  -144, 276,  -4,   -4,   252,  -260, 508,  -4,
729     21,   -541, -115, -269, -4,   416,  -688, -16,  176,  -4,   173,  -103,
730     33,   177,  -4,   168,  -640, -88,  -128, -4,   -4,   354,  -362, 712,
731     -4,   -4,   452,  -460, 908,  -4,   -4,   62,   -70,  128,  -4,   -4,
732     420,  -428, 844,  -4,   499,  -106, 141,  610,  -4,   666,  -46,  210,
733     866,  -4,   47,   -148, -19,  -16,  -4,   605,  -85,  181,  763,  -4,
734     -4,   64,   -72,  132,  -4,   -4,   24,   -32,  52,   -4,   -4,   92,
735     -100, 188,  -4,   -4,   50,   -58,  104,  -4,   -132, -694, -200, -558,
736     -4,   15,   -73,  -13,  -17,  -4,   -62,  -610, -158, -418, -4,   -36,
737     -343, -90,  -235, -4,   -4,   456,  -464, 916,  -4,   -4,   42,   -50,
738     88,   -4,   -4,   400,  -408, 804,  -4,   -4,   222,  -230, 448,  -4,
739     606,  -244, 146,  676,  -4,   9,    -172, -37,  -80,  -4,   480,  -370,
740     76,   438,  -4,   223,  -340, -3,   112,  -4,   -4,   156,  -164, 316,
741     -4,   -4,   108,  -116, 220,  -4,   -4,   240,  -248, 484,  -4,   -4,
742     220,  -228, 444,  -4,
743   };
744   float expected_2_valid[] = {
745     -13,  -31,  -13,  -31,  -4,   -10,  -22,  -10,  -22,  -4,   -185, -547,
746     -185, -547, -4,   -13,  -31,  -13,  -31,  -4,   14,   -22,  32,   -4,
747     -4,   8,    -16,  20,   -4,   -4,   358,  -366, 720,  -4,   -4,   14,
748     -22,  32,   -195, -658, -213, -622, -4,   -16,  -94,  -28,  -70,  -4,
749     459,  -244, 97,   480,  -4,   -85,  -328, -103, -292, -4,   432,  -440,
750     868,  -4,   -4,   56,   -64,  116,  -4,   -4,   156,  -164, 316,  -4,
751     -4,   212,  -220, 428,  582,  -208, 146,  664,  -4,   -130, -652, -190,
752     -532, -4,   166,  -214, 6,    106,  -4,   192,  -388, -24,  44,   -4,
753     132,  -140, 268,  -4,   -4,   428,  -436, 860,  -4,   -4,   136,  -144,
754     276,  -4,   -4,   252,  -260, 508,  21,   -541, -115, -269, -4,   416,
755     -688, -16,  176,  -4,   173,  -103, 33,   177,  -4,   168,  -640, -88,
756     -128, -4,   354,  -362, 712,  -4,   -4,   452,  -460, 908,  -4,   -4,
757     62,   -70,  128,  -4,   -4,   420,  -428, 844,  499,  -106, 141,  610,
758     -4,   666,  -46,  210,  866,  -4,   47,   -148, -19,  -16,  -4,   605,
759     -85,  181,  763,  -4,   64,   -72,  132,  -4,   -4,   24,   -32,  52,
760     -4,   -4,   92,   -100, 188,  -4,   -4,   50,   -58,  104,  -132, -694,
761     -200, -558, -4,   15,   -73,  -13,  -17,  -4,   -62,  -610, -158, -418,
762     -4,   -36,  -343, -90,  -235, -4,   456,  -464, 916,  -4,   -4,   42,
763     -50,  88,   -4,   -4,   400,  -408, 804,  -4,   -4,   222,  -230, 448,
764     606,  -244, 146,  676,  -4,   9,    -172, -37,  -80,  -4,   480,  -370,
765     76,   438,  -4,   223,  -340, -3,   112,  -4,   156,  -164, 316,  -4,
766     -4,   108,  -116, 220,  -4,   -4,   240,  -248, 484,  -4,   -4,   220,
767     -228, 444,  236,  -4,   76,   316,  -4,   164,  -4,   52,   220,  -4,
768     362,  -4,   118,  484,  -4,   332,  -4,   108,  444,
769   };
770   // Set padding to same
771   cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
772 
773   RunCNNTest(image_width, image_height, input, expected_2_same, &cnn_config,
774              image_width, &thread_data, MSE_INT_TOL);
775 
776   cnn_config.layer_config[0].pad = PADDING_VALID;
777 
778   RunCNNTest(image_width, image_height, input, expected_2_valid, &cnn_config,
779              image_width, &thread_data, MSE_INT_TOL);
780 
781   cnn_config.layer_config[0].skip_width = 2;
782   cnn_config.layer_config[0].skip_height = 5;
783   float expected_21_same[] = {
784     -31,  -19,  -49,   -191, -565, -194, -574, -13,  14,   -22,  44,   -16,
785     382,  -366, 738,   -22,  -4,   23,   32,   545,  20,   204,  720,  5,
786     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
787     -4,   -4,   -4,    -4,   -658, -252, -748, -114, -334, -192, -568, -112,
788     432,  -440, 928,   -64,  276,  -164, 532,  -220, -4,   304,  868,  266,
789     116,  400,  316,   104,  -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
790     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -208, -288, -856, -290,
791     -862, -202, -598,  -132, 132,  -140, 700,  -436, 1000, -144, 532,  -260,
792     -4,   712,  268,   422,  860,  450,  276,  124,  -4,   -4,   -4,   -4,
793     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
794     -541, -411, -1225, -265, -787, -249, -739, -216, 354,  -362, 1168, -460,
795     974,  -70,  552,   -428, -4,   859,  712,  323,  908,  665,  128,  208,
796     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
797     -4,   -4,   -4,    -4,   -106, -52,  -148, -66,  -190, -79,  -229, -31,
798     64,   -72,  160,   -32,  148,  -100, 242,  -58,  -4,   72,   132,  154,
799     52,   125,  188,   23,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
800     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -694, -257, -763, -229,
801     -679, -319, -949,  -117, 456,  -464, 962,  -50,  492,  -408, 1030, -230,
802     -4,   295,  916,   625,  88,   537,  804,  109,  -4,   -4,   -4,   -4,
803     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
804     -244, -140, -412,  -182, -538, -238, -706, -116, 156,  -164, 428,  -116,
805     464,  -248, 708,   -228, -4,   244,  316,  418,  220,  454,  484,  108,
806     -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
807     -4,   -4,   -4,    -4,
808   };
809   float expected_21_valid[] = {
810     -13,  -31,  -19,  -49,  -191, -565, -194, -574, -13,  -31,   -4,   14,
811     -22,  44,   -16,  382,  -366, 738,  -22,  32,   23,   -4,    23,   32,
812     545,  20,   204,  720,  5,    32,   -4,   -4,   -4,   -4,    -4,   -4,
813     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
814     -4,   -4,   -222, -658, -252, -748, -114, -334, -192, -568,  -112, -328,
815     -4,   432,  -440, 928,  -64,  276,  -164, 532,  -220, 428,   650,  -4,
816     304,  868,  266,  116,  400,  316,  104,  428,  -4,   -4,    -4,   -4,
817     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
818     -4,   -4,   -4,   -4,   -72,  -208, -288, -856, -290, -862,  -202, -598,
819     -132, -388, -4,   132,  -140, 700,  -436, 1000, -144, 532,   -260, 508,
820     200,  -4,   712,  268,  422,  860,  450,  276,  124,  508,   -4,   -4,
821     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
822     -4,   -4,   -4,   -4,   -4,   -4,   -183, -541, -411, -1225, -265, -787,
823     -249, -739, -216, -640, -4,   354,  -362, 1168, -460, 974,   -70,  552,
824     -428, 844,  533,  -4,   859,  712,  323,  908,  665,  128,   208,  844,
825     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
826     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -38,  -106,  -52,  -148,
827     -66,  -190, -79,  -229, -31,  -85,  -4,   64,   -72,  160,   -32,  148,
828     -100, 242,  -58,  104,  98,   -4,   72,   132,  154,  52,    125,  188,
829     23,   104,  -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
830     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -234, -694,
831     -257, -763, -229, -679, -319, -949, -117, -343, -4,   456,   -464, 962,
832     -50,  492,  -408, 1030, -230, 448,  686,  -4,   295,  916,   625,  88,
833     537,  804,  109,  448,  -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
834     -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
835     -84,  -244, -140, -412, -182, -538, -238, -706, -116, -340,  -4,   156,
836     -164, 428,  -116, 464,  -248, 708,  -228, 444,  236,  -4,    244,  316,
837     418,  220,  454,  484,  108,  444,
838   };
839 
840   cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
841 
842   RunCNNTest(image_width, image_height, input, expected_21_same, &cnn_config,
843              image_width, &thread_data, MSE_INT_TOL);
844 
845   cnn_config.layer_config[0].pad = PADDING_VALID;
846 
847   RunCNNTest(image_width, image_height, input, expected_21_valid, &cnn_config,
848              image_width, &thread_data, MSE_INT_TOL);
849 }
850 
TEST_F(CNNTest,TestLargeKernelsAndStrides)851 TEST_F(CNNTest, TestLargeKernelsAndStrides) {
852   float input_10x11[] = {
853     4,  4,  2,  4,  2,  -5, -2, 3, -1, 0,  0,  1,  2,  0,  -5, -2, -5, 1,  -3,
854     -1, 4,  -3, 2,  -2, 1,  0,  1, -3, -3, -4, -2, -2, 1,  -4, -1, 4,  1,  -4,
855     -4, -4, 3,  2,  -5, 3,  -5, 1, 2,  -4, 1,  -1, 3,  4,  -2, 3,  -3, 3,  0,
856     2,  -4, -5, -5, -2, -1, -2, 1, 1,  1,  -2, 4,  -5, 4,  -1, -1, 2,  3,  -4,
857     2,  2,  3,  0,  0,  1,  0,  3, 2,  3,  1,  -2, 3,  -4, 3,  2,  4,  -2, 0,
858     4,  -4, 1,  -3, -3, -3, -5, 1, -3, -5, 0,  4,  -1, -3, 2,
859   };
860 
861   float weights_10x11[] = {
862     -3, 4,  -4, -3, -5, 1,  -2, 3,  1,  -4, -4, 0,  -1, 0,  3,  1,  -3, -2, 0,
863     -1, 1,  3,  -4, -4, -3, -3, -2, 4,  3,  -5, 4,  2,  -3, 4,  -2, -1, 2,  -1,
864     -5, 0,  -3, 0,  3,  -5, -5, 3,  -4, -1, -5, 3,  4,  0,  4,  -5, 2,  -1, 2,
865     -1, -1, -1, -5, 0,  -4, 3,  -1, 1,  1,  -1, 3,  2,  -5, -4, 0,  -4, 4,  -5,
866     -3, 4,  -5, 2,  -5, -4, -4, -1, 3,  3,  0,  2,  -4, 1,  -2, 1,  1,  0,  3,
867     -2, 0,  1,  2,  4,  -3, -1, -5, -5, 2,  -4, 1,  1,  2,  -4, -2, -2, 2,  1,
868     3,  4,  -5, 1,  -1, -3, -3, -1, -2, -5, 1,  -1, 0,  1,  4,  4,  0,  0,  4,
869     -3, -1, -5, -3, 0,  1,  1,  1,  -5, 3,  4,  3,  -5, 3,  -2, -2, 0,  -4, 0,
870     0,  -2, 1,  -4, -1, 0,  -5, -2, -2, -5, -3, -3, 1,  1,  -3, 2,  4,  2,  4,
871     -4, -3, 3,  1,  1,  3,  -4, 4,  -2, -3, -3, -3, -3, -4, -2, 3,  -5, 2,  4,
872     -1, -4, -4, 4,  -2, -1, 3,  -3, -4, -4, -2, 4,  1,  0,  2,  -1, 4,  -3, 1,
873     4,  -3, 4,  4,  0,  -4, 3,  -2, -3, 2,  3,  -1, -3, 2,  1,  4,  -2, -3, 1,
874     4,  -2, 2,  -2, -5, -2, 1,  4,  -1, -4, 4,  -5, 2,  -5, -4, -1, -2, 3,  1,
875     2,  1,  -5, 1,  -5, -4, -1, -2, 2,  -2, -4, -3, -2, -2, 4,  -1, 2,  2,  -4,
876     2,  -2, 4,  -4, -2, -2, 1,  -1, 1,  1,  1,  -4, -5, -2, 3,  -4, -1, 3,  -2,
877     3,  2,  -5, -4, 0,  3,  -2, -4, -5, 3,  -2, -4, 2,  -2, 1,  -4, 0,  2,  -5,
878     1,  -4, -1, -1, 4,  -5, -4, 0,  -5, -4, -3, -5, -4, 0,  2,  0,  -4, 2,  -2,
879     1,  1,  -3, 2,  0,  -4, 0,  -4, 1,  0,  -5, -1, -1, -1, -5, 4,  2,  2,  -4,
880     3,  -2, -2, 2,  -3, -2, -1, 2,  -4, -5, 2,  -2, -4, -5, -5, -1, 2,  -1, 0,
881     -5, -2, -2, -5, 0,  1,  -1, -5, 0,  3,  2,  3,  0,  -3, -2, 0,  -5, -1, -2,
882     2,  -4, -1, 2,  2,  -5, 2,  -4, 0,  3,  -3, 1,  0,  0,  1,  -5, -3, 1,  -1,
883     0,  -4, -3, 2,  -4, -4, 4,  -1, 0,  1,  2,  -4, -5, 4,  -2, 1,  -4, -4, -3,
884     -1, -1, 1,  -1, -4, -1, -4, -3, 2,  -1, -2, -4, 1,  1,  0,  -2, 0,  -4, 3,
885     -3, 0,  -4, -1, -4, 2,  -1, -2, -5, -1, -2, -3, 3,  -1, 0,  -3, 0,  1,  -5,
886     1,  -5, 0,  1,
887   };
888 
889   float bias_10x11[] = { 3 };
890 
891   float expected_10x11[] = {
892     118,
893   };
894 
895   CNN_CONFIG cnn_config = { 1,
896                             0,
897                             0,
898                             0,
899                             0,
900                             { {
901                                 1,
902                                 23,
903                                 20,
904                                 1,
905                                 15,
906                                 20,
907                                 0,
908                                 weights_10x11,
909                                 bias_10x11,
910                                 PADDING_SAME_ZERO,
911                                 NONE,
912                                 0,
913                                 0,
914                                 BRANCH_NO_COPY,
915                                 BRANCH_NOC,
916                                 {},
917                                 {},
918                                 0,
919                             } } };
920 
921   int image_height = 10;
922   int image_width = 11;
923 
924   CNN_THREAD_DATA thread_data = { 1, nullptr };
925 
926   RunCNNTest(image_width, image_height, input_10x11, expected_10x11,
927              &cnn_config, image_width, &thread_data, MSE_INT_TOL);
928 
929   float input_11x10[] = {
930     -2, -2, 3,  -5, -1, -3, 1,  3,  2,  1,  1,  -5, 4,  1,  3,  -5, 3,  -3, -5,
931     0,  -1, -3, -3, 1,  1,  -5, -1, -5, -5, -3, 0,  1,  -3, -1, -3, -3, 0,  3,
932     4,  -4, -1, 3,  -3, -1, -3, 1,  -3, -2, -1, -4, -3, 2,  -4, 1,  -4, -1, -3,
933     -5, -1, 2,  3,  0,  2,  2,  -5, 4,  1,  2,  -1, -4, 4,  -4, -4, 0,  -1, 1,
934     -1, 1,  -3, -3, -2, 1,  2,  4,  4,  4,  -3, -3, 0,  1,  0,  1,  4,  1,  3,
935     4,  -3, -2, -4, 4,  2,  0,  3,  4,  -1, 2,  -2, 1,  -3, -2,
936   };
937 
938   float weights_11x10[] = {
939     4,  -1, 1,  -1, 2,  4,  3,  3,  -4, 3,  -5, 1,  -1, -1, -2, -2, 0,  2,  -3,
940     -2, 3,  -5, -1, 0,  -1, -2, -2, -1, 2,  4,  3,  1,  0,  0,  -3, 3,  -4, -1,
941     -5, 4,  -2, -2, 1,  2,  -1, -3, 1,  2,  -5, 1,  -3, 3,  3,  0,  -4, -4, -5,
942     -3, -4, -4, 4,  -2, 4,  4,  -2, 2,  -5, -1, -2, -5, -1, 4,  -3, 3,  -2, 0,
943     -4, -3, 0,  -1, -2, 4,  2,  0,  -2, -5, -4, 1,  4,  -4, -2, 2,  -2, 1,  1,
944     -4, 1,  -4, -4, -2, 4,  2,  -1, -5, -5, 1,  -3, -3, 3,  -3, -5, -3, 4,  -1,
945     -1, -3, 0,  -4, 3,  -1, 0,  -2, 0,  -5, -2, -5, 2,  0,  -5, 2,  3,  -2, 2,
946     4,  -1, 1,  -3, 2,  3,  2,  0,  -5, -4, -5, 2,  1,  1,  -1, -2, 3,  4,  2,
947     -2, 4,  -2, 3,  1,  -4, -3, -1, 4,  4,  -3, -5, -2, 2,  0,  3,  -2, 3,  -1,
948     -4, 0,  -2, 0,  3,  4,  -2, -3, -2, 0,  3,  4,  2,  -4, 0,  1,  2,  2,  -1,
949     -1, 4,  1,  4,  -2, -1, -1, -5, 1,  -3, 3,  3,  -1, -4, 3,  -5, 0,  0,  -1,
950     -4, -1, -2, 4,  -2, 3,  3,  -3, 1,  -1, 2,  -1, 4,  4,  -2, -2, 4,  -2, 0,
951     3,  -3, -5, -1, -2, 4,  -4, 2,  -4, 0,  -2, 3,  -3, 2,  2,  -2, -5, -1, 4,
952     3,  -2, -1, 3,  3,  -1, 3,  0,  -3, 0,  4,  2,  0,  -1, 4,  1,  1,  2,  1,
953     3,  1,  1,  1,  -3, -5, -4, 4,  -4, 2,  0,  0,  -4, 1,  4,  -5, 4,  4,  0,
954     1,  0,  -2, -4, -4, -3, 0,  1,  -5, 4,  0,  -3, -2, -4, 2,  4,  1,  -5, 1,
955     -4, 1,  0,  -3, -3, 0,  2,  -5, 4,  3,  -2, -5, 3,  1,  -1, 0,  3,  -2, -2,
956     3,  -2, -5, 4,  1,  -2, 2,  -1, 0,  4,  0,  -5, 3,  -2, 1,  2,  1,  -5, -3,
957     -2, -5, 4,  -4, 0,  3,  2,  -1, -4, -1, 2,  1,  -2, 3,  -1, -4, 2,  0,  -3,
958     1,  -1, 2,  -5, -4, -1, -5, 1,  4,  3,  4,  2,  -3, 1,  -5, -1, 3,  0,  -1,
959     -4, 3,  4,  -5, 4,  4,  -3, 2,  -3, -1, -3, -5, -3, 2,  -3, -2, 1,  1,  0,
960     -5, 3,  2,  1,  -5, 1,  1,  1,  3,  4,  -4, -1, -2, 0,  -5, -3, -5, -2, -4,
961     3,  3,  3,  4,  0,  -4, -1, -5, 0,  -3, 1,  4,  4,  -4, 4,  -5, -5, -1, -2,
962     -5, 3,  -4, 4,  3,  0,  -3, 2,  -2, 0,  0,  4,  4,  0,  -2, 1,  -1, -3, 2,
963     -1, 1,  -3, -5,
964   };
965 
966   float bias_11x10[] = {
967     -5,
968   };
969 
970   float expected_11x10[] = {
971     36,  -84,  95,   45,  18,   46,   77,  -54, -99,  -149, 66,  49,  161, 11,
972     39,  61,   -66,  61,  4,    -3,   34,  -44, -23,  31,   64,  29,  47,  72,
973     -27, -27,  121,  -3,  100,  1,    30,  -78, -12,  -89,  -59, 8,   -16, 112,
974     91,  -102, -26,  -4,  30,   54,   4,   -84, -24,  -58,  27,  -53, -33, 5,
975     53,  -26,  63,   50,  -103, -130, -23, 6,   -104, -207, 73,  23,  77,  132,
976     38,  32,   -130, -44, -60,  7,    27,  176, 45,   -32,  -2,  99,  -97, 63,
977     69,  126,  47,   63,  136,  -57,  5,   16,  -40,  -157, 8,   38,  -44, -10,
978     91,  7,    122,  140, 30,   -105, 4,   -1,  113,  64,   180, 141,
979   };
980 
981   cnn_config.layer_config[0].weights = weights_11x10;
982   cnn_config.layer_config[0].bias = bias_11x10;
983   cnn_config.layer_config[0].filter_width = 20;
984   cnn_config.layer_config[0].filter_height = 23;
985   cnn_config.layer_config[0].skip_width = 1;
986   cnn_config.layer_config[0].skip_height = 1;
987   image_height = 11;
988   image_width = 10;
989 
990   RunCNNTest(image_width, image_height, input_11x10, expected_11x10,
991              &cnn_config, image_width, &thread_data, MSE_INT_TOL);
992 }
993 
TEST_F(CNNTest,TestSoftsignSingleLayer)994 TEST_F(CNNTest, TestSoftsignSingleLayer) {
995   int image_width = 8;
996   int image_height = 8;
997   int filter_height = 5;
998   int filter_width = 4;
999   float input[] = {
1000     -0.5220f, 0.8410f,  -0.8990f, -0.0090f, 0.6710f,  -0.9470f, -0.8240f,
1001     -0.0870f, 0.5380f,  0.4750f,  0.570f,   -0.3760f, -0.6960f, -0.5940f,
1002     -0.3830f, 0.080f,   -0.0980f, -0.4940f, -0.4030f, 0.9460f,  -0.6020f,
1003     0.4220f,  0.6190f,  0.6640f,  -0.9210f, -0.1470f, -0.2480f, -0.1120f,
1004     -0.580f,  -0.0650f, 0.3330f,  0.9860f,  -0.7430f, 0.7610f,  0.4840f,
1005     0.1030f,  0.9570f,  0.6120f,  -0.5240f, -0.1220f, -0.5850f, -0.270f,
1006     0.7840f,  -0.9790f, 0.7290f,  -0.30f,   -0.6460f, 0.0780f,  0.4750f,
1007     -0.0510f, 0.4550f,  0.3850f,  -0.7230f, 0.4460f,  -0.6260f, -0.810f,
1008     0.8720f,  -0.2120f, -0.580f,  -0.9510f, -0.8430f, -0.1340f, -0.0850f,
1009     0.9190f,
1010   };
1011   float expected_same[] = {
1012     0.430f,   0.660f,  0.5510f,  -0.610f,  0.450f,  -0.1610f, 0.0520f,  0.3240f,
1013     0.6820f,  0.3820f, 0.6360f,  0.7480f,  0.3080f, 0.090f,   0.3910f,  0.1730f,
1014     0.340f,   0.6660f, -0.4990f, 0.4280f,  0.1540f, 0.120f,   0.4670f,  0.6150f,
1015     -0.3880f, 0.7590f, 0.4190f,  0.7350f,  0.5310f, -0.5160f, -0.1760f, 0.6790f,
1016     -0.6780f, 0.5470f, 0.5750f,  -0.6420f, 0.7210f, -0.4620f, 0.5430f,  0.770f,
1017     -0.1990f, 0.3950f, 0.7860f,  -0.4380f, 0.7540f, 0.2640f,  -0.6430f, 0.4510f,
1018     -0.1260f, 0.1590f, -0.2110f, -0.0560f, 0.6570f, 0.680f,   0.5870f,  0.4720f,
1019     0.4040f,  0.3630f, 0.670f,   0.2360f,  0.410f,  0.6980f,  -0.5350f, 0.3940f,
1020   };
1021   float expected_replicate[] = {
1022     0.540f,   0.7230f,  -0.3530f, -0.2130f, 0.7440f,  -0.4470f, -0.6260f,
1023     -0.2050f, 0.7230f,  0.4630f,  0.5920f,  0.7440f,  0.6080f,  0.3130f,
1024     -0.5670f, -0.4720f, 0.5480f,  0.6660f,  -0.4990f, 0.4280f,  0.1540f,
1025     0.120f,   0.3390f,  0.6090f,  0.4160f,  0.7590f,  0.4190f,  0.7350f,
1026     0.5310f,  -0.5160f, -0.490f,  0.4450f,  -0.610f,  0.5470f,  0.5750f,
1027     -0.6420f, 0.7210f,  -0.4620f, 0.3150f,  0.7370f,  -0.5820f, 0.3950f,
1028     0.7860f,  -0.4380f, 0.7540f,  0.2640f,  -0.7430f, -0.5340f, -0.6270f,
1029     0.4430f,  0.4730f,  0.4570f,  0.7450f,  0.630f,   0.2620f,  0.3140f,
1030     -0.1840f, 0.1810f,  0.7210f,  0.2760f,  0.6430f,  0.6720f,  -0.4390f,
1031     0.2040f,
1032   };
1033   float expected_valid[] = {
1034     0.6660f,  -0.4990f, 0.4280f,  0.1540f,  0.120f,  0.7590f,  0.4190f,
1035     0.7350f,  0.5310f,  -0.5160f, 0.5470f,  0.5750f, -0.6420f, 0.7210f,
1036     -0.4620f, 0.3950f,  0.7860f,  -0.4380f, 0.7540f, 0.2640f,
1037   };
1038   float weights[] = {
1039     0.6210f,  0.3710f,  -0.2770f, -0.7230f, -0.2450f, 0.6770f,  0.3080f,
1040     -0.9880f, -0.080f,  0.7190f,  -0.6760f, -0.0170f, -0.8970f, 0.8260f,
1041     0.7390f,  -0.4550f, -0.4260f, -0.6330f, 0.0880f,  -0.9390f,
1042   };
1043   float bias[] = {
1044     0.750f,
1045   };
1046 
1047   CNN_CONFIG cnn_config = { 1,
1048                             0,
1049                             0,
1050                             0,
1051                             0,
1052                             { {
1053                                 1,
1054                                 filter_width,
1055                                 filter_height,
1056                                 1,
1057                                 1,
1058                                 1,
1059                                 0,
1060                                 weights,
1061                                 bias,
1062                                 PADDING_SAME_ZERO,
1063                                 SOFTSIGN,
1064                                 0,
1065                                 0,
1066                                 BRANCH_NO_COPY,
1067                                 BRANCH_NOC,
1068                                 {},
1069                                 {},
1070                                 0,
1071                             } } };
1072 
1073   CNN_THREAD_DATA thread_data = { 1, nullptr };
1074 
1075   RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
1076              image_width, &thread_data, MSE_FLOAT_TOL);
1077 
1078   cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
1079 
1080   RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
1081              image_width, &thread_data, MSE_FLOAT_TOL);
1082 
1083   cnn_config.layer_config[0].pad = PADDING_VALID;
1084 
1085   RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
1086              image_width, &thread_data, MSE_FLOAT_TOL);
1087 }
1088 
TEST_F(CNNTest,TestBranchTensorAdd)1089 TEST_F(CNNTest, TestBranchTensorAdd) {
1090   int filter_width = 2;
1091   int filter_height = 3;
1092 
1093   int image_width = 4;
1094   int image_height = 4;
1095 
1096   float input[] = {
1097     -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
1098   };
1099 
1100   float weights[] = {
1101     -3, -1, 4,  -1, -3, 3,  3,  0,  2,  0,  3,  2,  4,  4, 4,  -5, 1, -4,
1102     2,  -4, 1,  -3, 0,  4,  -5, 4,  0,  -4, -3, -1, 0,  0, -2, 0,  0, 2,
1103     -5, -1, 1,  -3, 3,  4,  3,  0,  1,  -1, 1,  1,  2,  4, -2, -5, 2, -2,
1104     3,  -2, 4,  -1, 0,  2,  3,  2,  -2, -1, -3, 1,  3,  4, -1, -3, 0, -4,
1105     4,  2,  -3, -3, -1, 0,  1,  0,  3,  3,  -3, 0,  3,  2, -5, -3, 4, -5,
1106     3,  -1, -1, -3, 0,  1,  -1, -4, 2,  4,  -1, 4,  -1, 1, 3,  4,  4, 4,
1107     0,  -1, -3, -3, -3, -3, 2,  -3, -2, 2,  3,  -3,
1108   };
1109 
1110   float bias[] = {
1111     3, 4, -1, -1, 2, 1, -2, 1, 4, 1, 3,
1112   };
1113 
1114   float expected[] = {
1115     -11502, -4101, -3424, 668,   -17950, -5470, -5504, 626,
1116     4835,   446,   1779,  -3483, 3679,   -4214, 4578,  -105,
1117   };
1118 
1119   int channels = 2;
1120 
1121   CNN_CONFIG cnn_config = { 6,
1122                             0,
1123                             0,
1124                             0,
1125                             0,
1126                             { {
1127                                   1,
1128                                   filter_width,
1129                                   filter_height,
1130                                   channels,
1131                                   1,
1132                                   1,
1133                                   0,
1134                                   weights,
1135                                   bias,
1136                                   PADDING_SAME_ZERO,
1137                                   NONE,
1138                                   0,
1139                                   0,
1140                                   BRANCH_NO_COPY,
1141                                   BRANCH_NOC,
1142                                   {},
1143                                   {},
1144                                   -1,
1145                               },
1146                               {
1147                                   channels,
1148                                   filter_width,
1149                                   filter_height,
1150                                   channels,
1151                                   1,
1152                                   1,
1153                                   0,
1154                                   nullptr,
1155                                   nullptr,
1156                                   PADDING_SAME_ZERO,
1157                                   NONE,
1158                                   0,
1159                                   0,
1160                                   BRANCH_INPUT,
1161                                   BRANCH_NOC,
1162                                   {
1163                                       0x02,
1164                                       0,
1165                                       0x00,
1166                                   },
1167                                   {},
1168                                   -1,
1169                               },
1170                               {
1171                                   channels,
1172                                   filter_width,
1173                                   filter_height,
1174                                   channels,
1175                                   1,
1176                                   1,
1177                                   0,
1178                                   nullptr,
1179                                   nullptr,
1180                                   PADDING_SAME_ZERO,
1181                                   NONE,
1182                                   0,
1183                                   1,
1184                                   BRANCH_NO_COPY,
1185                                   BRANCH_NOC,
1186                                   {},
1187                                   {},
1188                                   -1,
1189                               },
1190                               {
1191                                   channels,
1192                                   filter_width,
1193                                   filter_height,
1194                                   channels,
1195                                   1,
1196                                   1,
1197                                   0,
1198                                   nullptr,
1199                                   nullptr,
1200                                   PADDING_SAME_ZERO,
1201                                   NONE,
1202                                   0,
1203                                   1,
1204                                   BRANCH_NO_COPY,
1205                                   BRANCH_NOC,
1206                                   {},
1207                                   {},
1208                                   -1,
1209                               },
1210                               {
1211                                   channels,
1212                                   filter_width,
1213                                   filter_height,
1214                                   channels,
1215                                   1,
1216                                   1,
1217                                   0,
1218                                   nullptr,
1219                                   nullptr,
1220                                   PADDING_SAME_ZERO,
1221                                   NONE,
1222                                   0,
1223                                   0,
1224                                   BRANCH_NO_COPY,
1225                                   BRANCH_ADD,
1226                                   {
1227                                       0x00,
1228                                       0,
1229                                       0x02,
1230                                   },
1231                                   {},
1232                                   -1,
1233                               },
1234                               {
1235                                   channels,
1236                                   filter_width,
1237                                   filter_height,
1238                                   1,
1239                                   1,
1240                                   1,
1241                                   0,
1242                                   nullptr,
1243                                   nullptr,
1244                                   PADDING_SAME_ZERO,
1245                                   NONE,
1246                                   0,
1247                                   0,
1248                                   BRANCH_NO_COPY,
1249                                   BRANCH_NOC,
1250                                   {},
1251                                   {},
1252                                   0,
1253                               } } };
1254 
1255   // Weights and biases need to be specified separately because
1256   // of the offset.
1257   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1258 
1259   CNN_THREAD_DATA thread_data = { 1, nullptr };
1260 
1261   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1262              image_width, &thread_data, MSE_INT_TOL);
1263 }
1264 
TEST_F(CNNTest,TestBranchTensorConcatenation)1265 TEST_F(CNNTest, TestBranchTensorConcatenation) {
1266   int filter_width = 2;
1267   int filter_height = 3;
1268 
1269   int image_width = 4;
1270   int image_height = 4;
1271 
1272   float input[] = {
1273     -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
1274   };
1275 
1276   float weights[] = {
1277     3,  0,  2,  0,  2,  3,  1,  -3, 1,  -5, -3, 0,  -4, 4,  0,  -5, 0,  -5, -1,
1278     -2, -5, 0,  -3, 2,  -4, 2,  0,  2,  -1, 0,  -4, 3,  0,  0,  -1, -5, 2,  -1,
1279     4,  -4, -2, -3, -3, 3,  4,  -2, -1, -4, -1, 4,  4,  -1, 4,  3,  -4, 2,  -2,
1280     -4, -3, -2, 3,  -3, -5, -1, 3,  -2, 4,  1,  -4, -3, -5, -5, -3, 4,  -2, -2,
1281     -1, -5, -5, 0,  -1, -2, -3, 3,  -4, -5, 2,  -3, 1,  0,  -5, 2,  2,  -2, 0,
1282     2,  2,  -2, 4,  2,  2,  0,  1,  -5, -3, 0,  2,  -2, 1,  2,  -5, 2,  3,  3,
1283     -1, 3,  0,  -3, 3,  -4, -4, 3,  3,  -4, -2, 2,  -2, 2,  -2, -1, 3,  0,
1284   };
1285 
1286   float bias[] = {
1287     -3, -5, 4, -4, -3, -2, 0, 3, -4, 4, -3,
1288   };
1289 
1290   float expected[] = {
1291     -33533, -32087, -6741,  -2124, 39979, 41453, 14034, 689,
1292     -22611, -42203, -14882, -239,  15781, 15963, 9524,  837,
1293   };
1294 
1295   int channels = 2;
1296 
1297   CNN_CONFIG cnn_config = { 6,
1298                             0,
1299                             0,
1300                             0,
1301                             0,
1302                             { {
1303                                   1,
1304                                   filter_width,
1305                                   filter_height,
1306                                   channels,
1307                                   1,
1308                                   1,
1309                                   0,
1310                                   weights,
1311                                   bias,
1312                                   PADDING_SAME_ZERO,
1313                                   NONE,
1314                                   0,
1315                                   0,
1316                                   BRANCH_NO_COPY,
1317                                   BRANCH_NOC,
1318                                   {},
1319                                   {},
1320                                   -1,
1321                               },
1322                               {
1323                                   channels,
1324                                   filter_width,
1325                                   filter_height,
1326                                   channels,
1327                                   1,
1328                                   1,
1329                                   0,
1330                                   nullptr,
1331                                   nullptr,
1332                                   PADDING_SAME_ZERO,
1333                                   NONE,
1334                                   0,
1335                                   0,
1336                                   BRANCH_INPUT,
1337                                   BRANCH_NOC,
1338                                   {
1339                                       0x02,
1340                                       0,
1341                                       0x00,
1342                                   },
1343                                   {},
1344                                   -1,
1345                               },
1346                               {
1347                                   channels,
1348                                   filter_width,
1349                                   filter_height,
1350                                   channels,
1351                                   1,
1352                                   1,
1353                                   0,
1354                                   nullptr,
1355                                   nullptr,
1356                                   PADDING_SAME_ZERO,
1357                                   NONE,
1358                                   0,
1359                                   1,
1360                                   BRANCH_NO_COPY,
1361                                   BRANCH_NOC,
1362                                   {},
1363                                   {},
1364                                   -1,
1365                               },
1366                               {
1367                                   channels,
1368                                   filter_width,
1369                                   filter_height,
1370                                   channels,
1371                                   1,
1372                                   1,
1373                                   0,
1374                                   nullptr,
1375                                   nullptr,
1376                                   PADDING_SAME_ZERO,
1377                                   NONE,
1378                                   0,
1379                                   1,
1380                                   BRANCH_NO_COPY,
1381                                   BRANCH_NOC,
1382                                   {},
1383                                   {},
1384                                   -1,
1385                               },
1386                               {
1387                                   channels,
1388                                   filter_width,
1389                                   filter_height,
1390                                   channels,
1391                                   1,
1392                                   1,
1393                                   0,
1394                                   nullptr,
1395                                   nullptr,
1396                                   PADDING_SAME_ZERO,
1397                                   NONE,
1398                                   0,
1399                                   0,
1400                                   BRANCH_NO_COPY,
1401                                   BRANCH_CAT,
1402                                   {
1403                                       0x00,
1404                                       0,
1405                                       0x02,
1406                                   },
1407                                   {},
1408                                   -1,
1409                               },
1410                               {
1411                                   channels + channels,
1412                                   filter_width,
1413                                   filter_height,
1414                                   1,
1415                                   1,
1416                                   1,
1417                                   0,
1418                                   nullptr,
1419                                   nullptr,
1420                                   PADDING_SAME_ZERO,
1421                                   NONE,
1422                                   0,
1423                                   0,
1424                                   BRANCH_NO_COPY,
1425                                   BRANCH_NOC,
1426                                   {},
1427                                   {},
1428                                   0,
1429                               } } };
1430 
1431   // Weights and biases need to be specified separately because
1432   // of the offset.
1433   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1434 
1435   CNN_THREAD_DATA thread_data = { 1, nullptr };
1436 
1437   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1438              image_width, &thread_data, MSE_INT_TOL);
1439 }
1440 
1441 // TODO(logangw): Add test to test all combinations of branch_copy_type.
1442 
TEST_F(CNNTest,TestBranchCombinations)1443 TEST_F(CNNTest, TestBranchCombinations) {
1444   int filter_width = 2;
1445   int filter_height = 3;
1446 
1447   int image_width = 4;
1448   int image_height = 4;
1449 
1450   float input[] = {
1451     3, 2, -5, -4, 4, -2, -4, -3, 4, 2, -3, 2, -3, 1, -5, -1,
1452   };
1453 
1454   float weights[] = {
1455     2,  3,  0,  4,  4,  3,  1,  0,  1,  -5, 4,  -3, 3,  0,  4,  -1, -1, -5,
1456     2,  1,  -3, -5, 3,  -1, -3, -2, 0,  -2, 3,  0,  -2, -4, -2, -2, 2,  -5,
1457     4,  -5, 0,  1,  -5, -4, -3, -4, 2,  -2, 1,  0,  3,  -2, -4, 3,  4,  -4,
1458     -1, -1, -3, -2, -2, -1, 2,  0,  2,  -1, 2,  -4, -4, -1, 2,  0,  3,  -2,
1459     -2, 3,  -3, 4,  -2, 4,  3,  4,  1,  0,  -2, -3, -5, 1,  -3, 2,  0,  -2,
1460     -2, -1, -1, -5, -2, -3, -1, 3,  3,  4,  4,  0,  2,  1,  3,  -3, 2,  -5,
1461     -5, 1,  -5, -1, 3,  3,  2,  -4, -1, 3,  -4, -2, -5, -2, 1,  3,  2,  2,
1462     -5, -2, -3, -1, -2, -4, -1, -2, 2,  1,  -4, -4, 2,  0,  2,  0,  2,  -3,
1463     -2, -4, 4,  0,  1,  -3, -5, 4,  -1, 2,  3,  -5, -1, 0,  4,  -1, -1, 3,
1464     -1, -3, 3,  1,  4,  3,  4,  3,  -4, -5, -1, 3,  3,  -4, 3,  1,  3,  -5,
1465     3,  4,  -5, 4,  2,  -1, -5, 2,  1,  0,  4,  0,  -3, 2,  0,  2,  -2, 1,
1466     -1, -2, -1, -5, 4,  3,  3,  -2, 2,  4,  -5, -5, -3, -2, 4,  0,  -4, 1,
1467   };
1468 
1469   float bias[] = {
1470     -1, 4, 0, 2, 2, -2, 0, -4, -5, -1, 1, -2, 3, 0, 4, -2, 1, 0, 0,
1471   };
1472 
1473   float expected[] = {
1474     149496, 15553,  -24193, -20956, 134094, 86432,  -68283, -6366,
1475     -53031, 133739, 67407,  -13539, -53205, -58635, -20033, 1979,
1476   };
1477 
1478   int channels = 2;
1479 
1480   CNN_CONFIG cnn_config = { 10,
1481                             0,
1482                             0,
1483                             0,
1484                             0,
1485                             {
1486                                 {
1487                                     1,
1488                                     filter_width,
1489                                     filter_height,
1490                                     channels,
1491                                     1,
1492                                     1,
1493                                     0,
1494                                     weights,
1495                                     bias,
1496                                     PADDING_SAME_ZERO,
1497                                     NONE,
1498                                     0,
1499                                     0,
1500                                     BRANCH_NO_COPY,
1501                                     BRANCH_NOC,
1502                                     {},
1503                                     {},
1504                                     -1,
1505                                 },
1506                                 {
1507                                     channels,
1508                                     filter_width,
1509                                     filter_height,
1510                                     channels,
1511                                     1,
1512                                     1,
1513                                     0,
1514                                     nullptr,
1515                                     nullptr,
1516                                     PADDING_SAME_ZERO,
1517                                     NONE,
1518                                     0,
1519                                     0,
1520                                     BRANCH_INPUT,
1521                                     BRANCH_NOC,
1522                                     {
1523                                         0x06,
1524                                         0,
1525                                         0x00,
1526                                     },
1527                                     {},
1528                                     -1,
1529                                 },
1530                                 {
1531                                     channels,
1532                                     filter_width,
1533                                     filter_height,
1534                                     channels,
1535                                     1,
1536                                     1,
1537                                     0,
1538                                     nullptr,
1539                                     nullptr,
1540                                     PADDING_SAME_ZERO,
1541                                     NONE,
1542                                     0,
1543                                     2,
1544                                     BRANCH_OUTPUT,
1545                                     BRANCH_NOC,
1546                                     {
1547                                         0x08,
1548                                         0,
1549                                         0x00,
1550                                     },
1551                                     {},
1552                                     -1,
1553                                 },
1554                                 {
1555                                     channels,
1556                                     filter_width,
1557                                     filter_height,
1558                                     channels,
1559                                     1,
1560                                     1,
1561                                     0,
1562                                     nullptr,
1563                                     nullptr,
1564                                     PADDING_SAME_ZERO,
1565                                     NONE,
1566                                     0,
1567                                     3,
1568                                     BRANCH_NO_COPY,
1569                                     BRANCH_NOC,
1570                                     {},
1571                                     {},
1572                                     -1,
1573                                 },
1574                                 {
1575                                     channels,
1576                                     filter_width,
1577                                     filter_height,
1578                                     channels,
1579                                     1,
1580                                     1,
1581                                     0,
1582                                     nullptr,
1583                                     nullptr,
1584                                     PADDING_SAME_ZERO,
1585                                     NONE,
1586                                     0,
1587                                     2,
1588                                     BRANCH_NO_COPY,
1589                                     BRANCH_ADD,
1590                                     {
1591                                         0x00,
1592                                         0,
1593                                         0x08,
1594                                     },
1595                                     {},
1596                                     -1,
1597                                 },
1598                                 {
1599                                     channels,
1600                                     filter_width,
1601                                     filter_height,
1602                                     channels,
1603                                     1,
1604                                     1,
1605                                     0,
1606                                     nullptr,
1607                                     nullptr,
1608                                     PADDING_SAME_ZERO,
1609                                     NONE,
1610                                     0,
1611                                     2,
1612                                     BRANCH_NO_COPY,
1613                                     BRANCH_NOC,
1614                                     {},
1615                                     {},
1616                                     -1,
1617                                 },
1618                                 {
1619                                     channels,
1620                                     filter_width,
1621                                     filter_height,
1622                                     channels,
1623                                     1,
1624                                     1,
1625                                     0,
1626                                     nullptr,
1627                                     nullptr,
1628                                     PADDING_SAME_ZERO,
1629                                     NONE,
1630                                     0,
1631                                     1,
1632                                     BRANCH_NO_COPY,
1633                                     BRANCH_NOC,
1634                                     {},
1635                                     {},
1636                                     -1,
1637                                 },
1638                                 {
1639                                     channels,
1640                                     filter_width,
1641                                     filter_height,
1642                                     channels,
1643                                     1,
1644                                     1,
1645                                     0,
1646                                     nullptr,
1647                                     nullptr,
1648                                     PADDING_SAME_ZERO,
1649                                     NONE,
1650                                     0,
1651                                     1,
1652                                     BRANCH_NO_COPY,
1653                                     BRANCH_ADD,
1654                                     {
1655                                         0x00,
1656                                         0,
1657                                         0x0C,
1658                                     },
1659                                     {},
1660                                     -1,
1661                                 },
1662                                 {
1663                                     channels,
1664                                     filter_width,
1665                                     filter_height,
1666                                     channels,
1667                                     1,
1668                                     1,
1669                                     0,
1670                                     nullptr,
1671                                     nullptr,
1672                                     PADDING_SAME_ZERO,
1673                                     NONE,
1674                                     0,
1675                                     0,
1676                                     BRANCH_NO_COPY,
1677                                     BRANCH_ADD,
1678                                     {
1679                                         0x00,
1680                                         0,
1681                                         0x02,
1682                                     },
1683                                     {},
1684                                     -1,
1685                                 },
1686                                 {
1687                                     channels,
1688                                     filter_width,
1689                                     filter_height,
1690                                     1,
1691                                     1,
1692                                     1,
1693                                     0,
1694                                     nullptr,
1695                                     nullptr,
1696                                     PADDING_SAME_ZERO,
1697                                     NONE,
1698                                     0,
1699                                     0,
1700                                     BRANCH_NO_COPY,
1701                                     BRANCH_NOC,
1702                                     {},
1703                                     {},
1704                                     0,
1705                                 },
1706                             } };
1707 
1708   // Weights and biases need to be specified separately because
1709   // of the offset.
1710   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1711 
1712   CNN_THREAD_DATA thread_data = { 1, nullptr };
1713 
1714   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1715              image_width, &thread_data, MSE_INT_TOL);
1716 }
1717 
TEST_F(CNNTest,TestSplittingTensors)1718 TEST_F(CNNTest, TestSplittingTensors) {
1719   int filter_width = 2;
1720   int filter_height = 3;
1721 
1722   int image_width = 4;
1723   int image_height = 4;
1724 
1725   float input[] = {
1726     -1, -1, 2, 1, 3, 2, 4, -3, -4, -2, 2, -3, 1, -3, 4, -2,
1727   };
1728 
1729   float weights[] = {
1730     -4, 1,  0,  2,  3,  4,  4,  -4, -5, -3, 2,  2,  -4, -3, 3,  2,
1731     4,  -4, -3, -4, -4, 1,  -3, -5, -3, 4,  2,  -2, 2,  -1, -4, -1,
1732     -2, -3, 1,  1,  0,  -5, -1, 3,  3,  -5, -3, 0,  -3, 1,  -3, -1,
1733     1,  -3, -2, -2, 4,  -2, 0,  1,  2,  2,  -4, 2,  4,  0,  -5, -2,
1734     4,  4,  -5, 1,  0,  2,  -2, -5, -5, -3, -5, -5, 4,  -3, 0,  0,
1735     -4, -4, 0,  -5, -4, 0,  0,  -3, -5, -3, -1, 2,  -1, 4,  -1, 2,
1736   };
1737 
1738   float bias[] = {
1739     -4, -2, -3, -3, 3, 1, -2,
1740   };
1741 
1742   float expected[] = {
1743     530,  -762,  1469, 777,  849,   -771, -1698, 600,
1744     -658, -1821, 98,   -668, -1798, 30,   887,   -971,
1745   };
1746 
1747   CNN_CONFIG cnn_config = { 3,
1748                             0,
1749                             0,
1750                             0,
1751                             0,
1752                             {
1753                                 {
1754                                     1,
1755                                     filter_width,
1756                                     filter_height,
1757                                     4,
1758                                     1,
1759                                     1,
1760                                     0,
1761                                     nullptr,
1762                                     nullptr,
1763                                     PADDING_SAME_ZERO,
1764                                     NONE,
1765                                     0,
1766                                     0,
1767                                     BRANCH_OUTPUT,
1768                                     BRANCH_NOC,
1769                                     {
1770                                         0x02,
1771                                         2,
1772                                         0x00,
1773                                     },
1774                                     {},
1775                                     -1,
1776                                 },
1777                                 {
1778                                     4,
1779                                     filter_width,
1780                                     filter_height,
1781                                     2,
1782                                     1,
1783                                     1,
1784                                     0,
1785                                     nullptr,
1786                                     nullptr,
1787                                     PADDING_SAME_ZERO,
1788                                     NONE,
1789                                     0,
1790                                     0,
1791                                     BRANCH_NO_COPY,
1792                                     BRANCH_CAT,
1793                                     {
1794                                         0x00,
1795                                         0,
1796                                         0x02,
1797                                     },
1798                                     {},
1799                                     -1,
1800                                 },
1801                                 {
1802                                     4,
1803                                     filter_width,
1804                                     filter_height,
1805                                     1,
1806                                     1,
1807                                     1,
1808                                     0,
1809                                     nullptr,
1810                                     nullptr,
1811                                     PADDING_SAME_ZERO,
1812                                     NONE,
1813                                     0,
1814                                     0,
1815                                     BRANCH_NO_COPY,
1816                                     BRANCH_NOC,
1817                                     {},
1818                                     {},
1819                                     0,
1820                                 },
1821                             } };
1822 
1823   // Weights and biases need to be specified separately because
1824   // of the offset.
1825   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1826 
1827   CNN_THREAD_DATA thread_data = { 1, nullptr };
1828 
1829   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1830              image_width, &thread_data, MSE_INT_TOL);
1831 }
1832 
TEST_F(CNNTest,TestOutputChannelsCount)1833 TEST_F(CNNTest, TestOutputChannelsCount) {
1834   int filter_width = 1;
1835   int filter_height = 1;
1836 
1837   int image_width = 2;
1838   int image_height = 2;
1839 
1840   float input[] = { 0, 0, 0, 0 };
1841 
1842   float weights[] = { 0, 0, 0, 0, 0, 0, 0, 0 };
1843 
1844   float bias[] = { 0, 0, 0, 0, 0, 0 };
1845 
1846   float expected[] = {
1847     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1848     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1849   };
1850 
1851   CNN_CONFIG cnn_config = { 3,
1852                             0,
1853                             0,
1854                             0,
1855                             0,
1856                             {
1857                                 {
1858                                     1,
1859                                     filter_width,
1860                                     filter_height,
1861                                     2,
1862                                     1,
1863                                     1,
1864                                     0,
1865                                     weights,
1866                                     bias,
1867                                     PADDING_SAME_ZERO,
1868                                     NONE,
1869                                     0,
1870                                     0,
1871                                     BRANCH_INPUT,
1872                                     BRANCH_NOC,
1873                                     {
1874                                         0x06,
1875                                         0,
1876                                         0x00,
1877                                     },
1878                                     {},
1879                                     -1,
1880                                 },
1881                                 {
1882                                     1,
1883                                     filter_width,
1884                                     filter_height,
1885                                     2,
1886                                     1,
1887                                     1,
1888                                     0,
1889                                     weights,
1890                                     bias,
1891                                     PADDING_SAME_ZERO,
1892                                     NONE,
1893                                     0,
1894                                     2,
1895                                     BRANCH_NO_COPY,
1896                                     BRANCH_CAT,
1897                                     {
1898                                         0x00,
1899                                         0,
1900                                         0x03,
1901                                     },
1902                                     {},
1903                                     -1,
1904                                 },
1905                                 {
1906                                     2,
1907                                     filter_width,
1908                                     filter_height,
1909                                     2,
1910                                     1,
1911                                     1,
1912                                     0,
1913                                     weights,
1914                                     bias,
1915                                     PADDING_SAME_ZERO,
1916                                     NONE,
1917                                     0,
1918                                     0,
1919                                     BRANCH_NO_COPY,
1920                                     BRANCH_CAT,
1921                                     {
1922                                         0x00,
1923                                         0,
1924                                         0x04,
1925                                     },
1926                                     {},
1927                                     0,
1928                                 },
1929                             } };
1930 
1931   // Weights and biases need to be specified separately because
1932   // of the offset.
1933   AssignLayerWeightsBiases(&cnn_config, weights, bias);
1934 
1935   CNN_THREAD_DATA thread_data = { 1, nullptr };
1936 
1937   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1938              image_width, &thread_data, MSE_FLOAT_TOL);
1939 }
1940 
TEST_F(CNNTest,TestBatchNorm)1941 TEST_F(CNNTest, TestBatchNorm) {
1942   int image_width = 28;
1943   int image_height = 28;
1944   int filter_height = 7;
1945   int filter_width = 7;
1946   float input[] = {
1947     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1948     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1949     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1950     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1951     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1952     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1953     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1954     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1955     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1956     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1957     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1958     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1959     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1960     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1961     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1962     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1963     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1964     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1965     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1966     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1967     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1968     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1969     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1970     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1971     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1972     0.0f,       0.0f,       0.0117647f,  0.0705882f,  0.0705882f,  0.0705882f,
1973     0.494118f,  0.533333f,  0.686275f,   0.101961f,   0.65098f,    1.0f,
1974     0.968627f,  0.498039f,  0.0f,        0.0f,        0.0f,        0.0f,
1975     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1976     0.0f,       0.0f,       0.117647f,   0.141176f,   0.368627f,   0.603922f,
1977     0.666667f,  0.992157f,  0.992157f,   0.992157f,   0.992157f,   0.992157f,
1978     0.882353f,  0.67451f,   0.992157f,   0.94902f,    0.764706f,   0.25098f,
1979     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1980     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.192157f,
1981     0.933333f,  0.992157f,  0.992157f,   0.992157f,   0.992157f,   0.992157f,
1982     0.992157f,  0.992157f,  0.992157f,   0.984314f,   0.364706f,   0.321569f,
1983     0.321569f,  0.219608f,  0.152941f,   0.0f,        0.0f,        0.0f,
1984     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1985     0.0f,       0.0f,       0.0f,        0.0705882f,  0.858824f,   0.992157f,
1986     0.992157f,  0.992157f,  0.992157f,   0.992157f,   0.776471f,   0.713725f,
1987     0.968627f,  0.945098f,  0.0f,        0.0f,        0.0f,        0.0f,
1988     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1989     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1990     0.0f,       0.0f,       0.313725f,   0.611765f,   0.419608f,   0.992157f,
1991     0.992157f,  0.803922f,  0.0431373f,  0.0f,        0.168627f,   0.603922f,
1992     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1993     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1994     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1995     0.0f,       0.054902f,  0.00392157f, 0.603922f,   0.992157f,   0.352941f,
1996     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1997     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1998     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
1999     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2000     0.0f,       0.545098f,  0.992157f,   0.745098f,   0.00784314f, 0.0f,
2001     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2002     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2003     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2004     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0431373f,
2005     0.745098f,  0.992157f,  0.27451f,    0.0f,        0.0f,        0.0f,
2006     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2007     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2008     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2009     0.0f,       0.0f,       0.0f,        0.0f,        0.137255f,   0.945098f,
2010     0.882353f,  0.627451f,  0.423529f,   0.00392157f, 0.0f,        0.0f,
2011     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2012     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2013     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2014     0.0f,       0.0f,       0.0f,        0.317647f,   0.941176f,   0.992157f,
2015     0.992157f,  0.466667f,  0.0980392f,  0.0f,        0.0f,        0.0f,
2016     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2017     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2018     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2019     0.0f,       0.0f,       0.176471f,   0.729412f,   0.992157f,   0.992157f,
2020     0.588235f,  0.105882f,  0.0f,        0.0f,        0.0f,        0.0f,
2021     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2022     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2023     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2024     0.0f,       0.0627451f, 0.364706f,   0.988235f,   0.992157f,   0.733333f,
2025     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2026     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2027     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2028     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2029     0.0f,       0.976471f,  0.992157f,   0.976471f,   0.25098f,    0.0f,
2030     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2031     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2032     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2033     0.0f,       0.0f,       0.180392f,   0.509804f,   0.717647f,   0.992157f,
2034     0.992157f,  0.811765f,  0.00784314f, 0.0f,        0.0f,        0.0f,
2035     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2036     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2037     0.0f,       0.0f,       0.0f,        0.0f,        0.152941f,   0.580392f,
2038     0.898039f,  0.992157f,  0.992157f,   0.992157f,   0.980392f,   0.713725f,
2039     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2040     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2041     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2042     0.0941176f, 0.447059f,  0.866667f,   0.992157f,   0.992157f,   0.992157f,
2043     0.992157f,  0.788235f,  0.305882f,   0.0f,        0.0f,        0.0f,
2044     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2045     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2046     0.0f,       0.0f,       0.0901961f,  0.258824f,   0.835294f,   0.992157f,
2047     0.992157f,  0.992157f,  0.992157f,   0.776471f,   0.317647f,   0.00784314f,
2048     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2049     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2050     0.0f,       0.0f,       0.0f,        0.0f,        0.0705882f,  0.670588f,
2051     0.858824f,  0.992157f,  0.992157f,   0.992157f,   0.992157f,   0.764706f,
2052     0.313725f,  0.0352941f, 0.0f,        0.0f,        0.0f,        0.0f,
2053     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2054     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2055     0.215686f,  0.67451f,   0.886275f,   0.992157f,   0.992157f,   0.992157f,
2056     0.992157f,  0.956863f,  0.521569f,   0.0431373f,  0.0f,        0.0f,
2057     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2058     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2059     0.0f,       0.0f,       0.0f,        0.0f,        0.533333f,   0.992157f,
2060     0.992157f,  0.992157f,  0.831373f,   0.529412f,   0.517647f,   0.0627451f,
2061     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2062     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2063     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2064     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2065     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2066     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2067     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2068     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2069     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2070     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2071     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2072     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2073     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2074     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2075     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2076     0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
2077     0.0f,       0.0f,       0.0f,        0.0f
2078   };
2079   float expected[] = {
2080     -0.836424f, -0.857365f, -1.62739f,  -1.62739f,  -0.836424f, 5.40742f,
2081     0.920853f,  -0.692567f, -0.836424f, -0.534405f, -1.62739f,  -0.836424f,
2082     1.32602f,   1.36312f,   0.112766f,  -0.836424f, -0.192962f, 1.56975f,
2083     2.45777f,   0.944414f,  -0.192962f, -1.5519f,   -1.5519f,   -0.554006f,
2084     -0.192962f, 1.4231f,    -1.5519f,   -0.192962f, 1.3661f,    -1.5519f,
2085     -1.5519f,   -0.192962f, -0.843708f, -0.359025f, -0.843708f, -0.843708f,
2086     -0.843708f, 4.53065f,   0.0429584f, -0.796804f, -0.843708f, 0.3473f,
2087     -0.843708f, -0.843708f, -0.114439f, 3.14817f,   0.0811934f, -0.843708f
2088   };
2089   float kernel[] = {
2090     0.119643f,    -0.237864f,   0.0462892f,   0.0502297f,   -0.0134528f,
2091     0.146347f,    0.153133f,    0.0513307f,   0.0752369f,   0.0135557f,
2092     -0.111434f,   0.0941854f,   0.0788362f,   0.0299412f,   0.111762f,
2093     0.144066f,    0.00431504f,  -0.0177954f,  0.0738092f,   -0.0344215f,
2094     0.0832582f,   0.053989f,    -0.112691f,   0.0962145f,   0.0186525f,
2095     -0.00660205f, -0.111962f,   -0.126801f,   -0.231625f,   0.17309f,
2096     0.0748875f,   -0.179569f,   -0.00513812f, -0.156579f,   -0.147322f,
2097     0.184168f,    0.189308f,    -0.200359f,   -0.0156733f,  0.140649f,
2098     0.0858496f,   -0.0263217f,  -0.0740749f,  -0.112563f,   0.107528f,
2099     0.0609729f,   -0.221625f,   0.0769944f,   -0.00900815f, -0.00136441f,
2100     -0.0236521f,  -0.0418025f,  -0.00286299f, 0.12241f,     0.0964093f,
2101     -0.0150897f,  0.0532171f,   0.0625916f,   0.116939f,    0.118024f,
2102     0.161918f,    -0.00909767f, 0.100897f,    -0.054563f,   -0.175179f,
2103     -0.0687892f,  0.00734235f,  0.109833f,    -0.113776f,   0.0595405f,
2104     -0.170255f,   0.0124815f,   -0.0363301f,  -0.0127038f,  0.0445554f,
2105     -0.0729894f,  0.107428f,    -0.0341417f,  0.132619f,    0.00984557f,
2106     -0.00443654f, 0.202929f,    0.0945134f,   0.0148725f,   0.00998574f,
2107     -0.0226449f,  0.0478197f,   -0.0793442f,  0.0707599f,   -0.084225f,
2108     0.0865795f,   0.071104f,    -0.047894f,   0.0838322f,   0.0635493f,
2109     -0.00370265f, -0.157247f,   -0.0289622f,  -0.0590963f,  0.13207f,
2110     0.00468011f,  -0.0345372f,  0.217939f,    0.18861f,     -0.0290393f,
2111     -0.0440664f,  0.0126197f,   -0.129132f,   -0.124943f,   0.0968156f,
2112     -0.0853643f,  -0.182305f,   0.00461618f,  -0.147095f,   -0.230282f,
2113     0.00856019f,  0.0278893f,   -0.0300229f,  0.0417871f,   0.0804717f,
2114     -0.0768571f,  -0.0397085f,  -0.0601096f,  0.100901f,    -0.0184926f,
2115     0.0350673f,   0.0971094f,   -0.0171837f,  -0.289644f,   -0.0899041f,
2116     0.08998f,     -0.160319f,   -0.0195103f,  0.0392167f,   -0.137864f,
2117     -0.0136294f,  0.0330886f,   -0.0409244f,  -0.092533f,   -0.0427934f,
2118     -0.191144f,   -0.0969461f,  0.112035f,    0.138611f,    0.128717f,
2119     0.191184f,    0.197462f
2120   };
2121   float bias[] = { 0.186703f, 0.204358f, -0.0230452f };
2122 
2123   float bn_gamma[] = { 1.32173f, 1.26171f, 1.21966f };
2124   float bn_beta[] = { -0.232595f, -0.222652f, -0.232209f };
2125   float bn_mean[] = { 0.329233f, 0.199894f, 0.12389f };
2126   float bn_std[] = { 0.311986f, 0.189737f, 0.247104f };
2127 
2128   CNN_BATCHNORM_PARAMS bn_params = {
2129     bn_gamma,
2130     bn_beta,
2131     bn_mean,
2132     bn_std,
2133   };
2134 
2135   CNN_CONFIG cnn_config = {
2136     1,
2137     0,
2138     0,
2139     0,
2140     0,
2141     {
2142         {
2143             1,
2144             filter_width,
2145             filter_height,
2146             3,
2147             7,
2148             7,
2149             0,
2150             kernel,
2151             bias,
2152             PADDING_VALID,
2153             RELU,
2154             0,
2155             0,
2156             BRANCH_NO_COPY,
2157             BRANCH_NOC,
2158             {},
2159             bn_params,
2160             0,
2161         },
2162     },
2163   };
2164 
2165   CNN_THREAD_DATA thread_data = { 1, nullptr };
2166 
2167   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
2168              image_width, &thread_data, MSE_FLOAT_TOL);
2169 }
2170 
TEST_F(CNNTest,TestMultithreading)2171 TEST_F(CNNTest, TestMultithreading) {
2172   int image_height = 2;
2173   int image_width = 2;
2174   int filter_height = 3;
2175   int filter_width = 3;
2176 
2177   float input[] = {
2178     -2,
2179     4,
2180     1,
2181     0,
2182   };
2183 
2184   float weights[] = {
2185     -4, 2, -2, 0,  -4, 4, -3, -3, -3, -1, 1,  0,  -5, -3, 0, -5, 0, 0,
2186     -1, 0, 2,  -5, 0,  1, 4,  2,  1,  0,  -2, -1, -5, -3, 2, -2, 1, -5,
2187   };
2188 
2189   float bias[] = {
2190     -4,
2191     -3,
2192     -2,
2193     3,
2194   };
2195 
2196   float expected[] = {
2197     2, 10, -8, -17, -24, 5, -15, 6, -5, -5, 7, -10, 4, 13, 9, -14,
2198   };
2199 
2200   CNN_CONFIG cnn_config = {
2201     1,
2202     0,
2203     0,
2204     0,
2205     0,
2206     {
2207         {
2208             1,
2209             filter_width,
2210             filter_height,
2211             4,
2212             1,
2213             1,
2214             0,
2215             weights,
2216             bias,
2217             PADDING_SAME_ZERO,
2218             NONE,
2219             0,
2220             0,
2221             BRANCH_NO_COPY,
2222             BRANCH_NOC,
2223             {},
2224             {},
2225             0,
2226         },
2227     },
2228   };
2229 
2230   CNN_THREAD_DATA thread_data = { 1, nullptr };
2231 
2232   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
2233              image_width, &thread_data, MSE_FLOAT_TOL);
2234 
2235   const AVxWorkerInterface *const winterface = aom_get_worker_interface();
2236   AVxWorker workers[4];
2237 
2238   for (int i = 0; i < 4; ++i) {
2239     winterface->init(&workers[i]);
2240   }
2241 
2242   thread_data = { 4, workers };
2243 
2244   RunCNNTest(image_width, image_height, input, expected, &cnn_config,
2245              image_width, &thread_data, MSE_FLOAT_TOL);
2246 
2247   for (int i = 0; i < 4; ++i) {
2248     winterface->end(&workers[i]);
2249   }
2250 }
2251 
TEST_F(CNNTest,TestMultiOutput)2252 TEST_F(CNNTest, TestMultiOutput) {
2253   const int image_dim = 8;
2254   const int image_ch = 3;
2255   const int filter_dim = 2;
2256   const int stride = 2;
2257   const int num_filters = 2;
2258 
2259   const float input_[] = {
2260     1.7537929121f,     0.134331551012f,    0.123580039877f,   0.957731845246f,
2261     0.391006834217f,   1.00699352042f,     -0.778177955829f,  -0.814166433059f,
2262     -0.656374394915f,  0.321967305228f,    -2.19455719176f,   0.708035038966f,
2263     0.409148822266f,   -0.318254408902f,   0.152450211189f,   -0.250210793369f,
2264     0.826811563186f,   1.6804156584f,      0.273626975978f,   0.437936241887f,
2265     -0.329935520167f,  -0.288761611645f,   0.156937008304f,   0.271054157295f,
2266     -0.0224828854332f, 1.70110336895f,     -0.989066699309f,  1.30863131729f,
2267     -0.165813705702f,  0.00380178619265f,  -0.0837342367587f, 0.760954783156f,
2268     -0.413610373524f,  1.17968204175f,     0.720295719536f,   0.308718974472f,
2269     -1.10091337671f,   0.693160033687f,    -0.0202862320697f, 1.0221927503f,
2270     -1.24521801881f,   -0.478501952308f,   -1.71648619442f,   -0.182571723636f,
2271     0.339292649504f,   2.0806519131f,      0.967974033444f,   0.175248672328f,
2272     0.0658124561472f,  0.795504169496f,    0.750592557361f,   -1.46631013249f,
2273     -1.79052846838f,   -1.03672179515f,    -0.841985521653f,  1.20995011489f,
2274     0.140859718215f,   -0.651552622661f,   0.451065110806f,   1.1189443693f,
2275     0.100213260593f,   -0.834076868118f,   -1.28734321611f,   1.22064420095f,
2276     -0.364143084361f,  0.750961509335f,    -0.888689074553f,  -0.8253547106f,
2277     -1.21800999027f,   -0.966670603566f,   1.37384014741f,    0.47281264834f,
2278     -0.420416235531f,  0.520163906493f,    0.501296589423f,   1.53418976951f,
2279     0.715234751485f,   0.644551588907f,    0.0763504863375f,  -0.0018541943723f,
2280     0.322853189656f,   -0.795099723224f,   -0.125177096675f,  1.4476577471f,
2281     -0.585888410088f,  -1.44391754955f,    -0.610543221933f,  -0.221859179799f,
2282     0.252060200774f,   -0.86287169623f,    -0.0350246229157f, 1.0932311997f,
2283     0.899464648842f,   -0.468806951704f,   -0.300861137168f,  1.15776414206f,
2284     1.03268544738f,    -0.171579585622f,   -0.179136557119f,  -0.354091003368f,
2285     -0.612298249394f,  -1.20237379258f,    1.54604109659f,    0.130664370287f,
2286     0.885225111868f,   1.0362799581f,      0.980561720868f,   -0.619379186999f,
2287     -1.33818929924f,   -0.237233737961f,   -1.89335425073f,   0.567821011321f,
2288     0.862420368465f,   -1.37380916821f,    0.352190056666f,   0.611261516274f,
2289     0.393237747152f,   0.894686247967f,    0.190405182149f,   0.264872662911f,
2290     -0.0657009133797f, 0.0580512653493f,   -0.401825294366f,  0.4106081318f,
2291     0.49484512188f,    -0.0751103149442f,  -1.43243736382f,   1.79855656009f,
2292     -1.1075351975f,    0.000354882733011f, -0.950716438608f,  1.27129831688f,
2293     1.00495189838f,    0.110358656713f,    1.08315032822f,    -0.972676676218f,
2294     -0.0757668962831f, 1.88932045165f,     -0.0672638136275f, 0.425913010161f,
2295     -0.781540372017f,  0.976000248609f,    0.687218504122f,   1.31374513445f,
2296     -0.932658930672f,  -1.25339468479f,    0.422071294078f,   -0.24189927912f,
2297     0.216906604642f,   -1.88720997548f,    1.99252872889f,    0.353943735777f,
2298     0.737434784132f,   -1.17848645017f,    1.70424254896f,    0.775297112968f,
2299     -0.516392797501f,  0.398130609129f,    0.737248101457f,   0.166282500886f,
2300     1.24699015468f,    0.47116183125f,     1.19091180182f,    -0.372695424578f,
2301     0.219773209389f,   -0.829467838962f,   -0.52533122724f,   1.98707754595f,
2302     0.553692606972f,   -0.933228902369f,   1.55427751643f,    -1.08813399144f,
2303     -0.325686682094f,  0.205091443796f,    -1.70381666435f,   0.466465327942f,
2304     1.73126863447f,    -0.939133672634f,   1.48318077459f,    -0.599414038168f,
2305     -1.1583078687f,    0.518116190201f,    0.133571482458f,   0.84958342672f,
2306     1.02205000597f,    -0.0772082009087f,  -1.69567503859f,   1.4697939436f,
2307     1.67813743122f,    -0.627911582938f,   0.131380509137f,   -1.35717850726f,
2308   };
2309   const float *input[3] = { input_, &input_[image_dim * image_dim],
2310                             &input_[2 * image_dim * image_dim] };
2311 
2312   const float bias[] = { 0.0f, 0.0f };
2313 
2314   const float weights_1[] = {
2315     -0.489547413618f, 0.141916424749f,  -0.279286485585f,  -0.115322211094f,
2316     0.299572786936f,  0.205289980785f,  -0.536254480088f,  -0.253626313744f,
2317     -0.422883815849f, -0.169702966298f, -0.540104704793f,  0.495319646763f,
2318     0.298799079422f,  -0.10054550901f,  -0.306085047056f,  0.171061886165f,
2319     -0.108058703878f, -0.410734629888f, -0.0640674673049f, -0.386524840979f,
2320     -0.157203423678f, -0.362138920529f, -0.216206085209f,  0.147502517971f,
2321   };
2322 
2323   const float weights_2[] = {
2324     0.207580604357f,  0.480821146263f,  -0.29111909562f,   0.47422567493f,
2325     0.206892553253f,  -0.235067084092f, 0.354516800602f,   -0.212399370252f,
2326     -0.419071343731f, -0.050350731631f, -0.0516457320279f, -0.0359310500731f,
2327     0.567044864811f,  -0.060341127522f, 0.0501464839637f,  -0.437785677916f,
2328   };
2329 
2330   const float weights_3[] = {
2331     -0.0690452401448f, -0.356657338763f,   -0.219464031809f, 0.551288365843f,
2332     0.181372090853f,   -0.00245268542109f, 0.409000696276f,  -0.593209108763f,
2333     0.587352566749f,   -0.243720660227f,   0.266232713887f,  -0.00439285245097f,
2334     0.252883228305f,   0.152646192631f,    0.0918944932026f, 0.398853715057f,
2335   };
2336 
2337   const float weights_4[] = {
2338     0.207560791573f,   0.194201350401f,   0.227802322443f,  0.206533663345f,
2339     0.0557331066805f,  0.0224159800424f,  -0.143939197467f, -0.27703361602f,
2340     0.130643888389f,   -0.269456557461f,  0.186242862864f,  -0.162879944774f,
2341     -0.145503996718f,  -0.0768822987581f, -0.203127976359f, -0.238119922873f,
2342     -0.258806479994f,  0.0357957680385f,  -0.1027606976f,   -0.287920082345f,
2343     0.189047820993f,   0.250711538481f,   -0.272815714175f, -0.0431449742024f,
2344     0.207261230996f,   -0.0396472677451f, 0.131236557412f,  0.174291832499f,
2345     -0.251515885765f,  -0.107164007499f,  0.185824534748f,  -0.00561585838161f,
2346     0.273393799578f,   -0.139563699075f,  -0.263922456031f, -0.118859844081f,
2347     0.109230982597f,   -0.170170294794f,  0.0123025648515f, -0.0839368964355f,
2348     -0.0774058234297f, 0.255847138286f,   -0.208430879637f, 0.279170114319f,
2349     -0.272890330712f,  -0.217725903006f,  -0.295923275459f, -0.17008723953f,
2350     -0.284281803405f,  0.281406323629f,   0.266910044663f,  -0.209963914338f,
2351     0.271980962964f,   0.142013581699f,   -0.143896509026f, -0.290509242975f,
2352     -0.305768180935f,  0.196902832117f,   -0.090424189662f, -0.147460802346f,
2353     0.217722016651f,   0.12353848977f,    -0.169177363577f, -0.0454230918512f,
2354   };
2355 
2356   const float expected_0[] = {
2357     -2.04858441055f,  -2.12883075791f,    -0.045177363807f, 0.763949675768f,
2358     -0.544361512821f, -1.58123168032f,    1.89319847039f,   0.16859080901f,
2359     -1.16023321135f,  -0.396988107751f,   1.76637090744f,   -1.40434786514f,
2360     0.908227575669f,  0.817064817605f,    0.215631134908f,  -0.848605613428f,
2361     -0.106756747018f, 0.0193027166685f,   0.801345615113f,  -0.395407237598f,
2362     -1.79983795658f,  -1.73054496242f,    0.0584392594454f, -0.388786095569f,
2363     -0.237269619354f, 0.000843578271263f, -1.24043512104f,  0.487839445893f,
2364     -0.394259726605f, 0.559632843424f,    -0.527224052291f, -1.53792340282f,
2365   };
2366 
2367   const float expected_1[] = {
2368     0.0f, 0.0f,           0.0f, 0.0f, 0.4057888292f, 0.325309571755f,
2369     0.0f, 1.22013465602f,
2370   };
2371 
2372   const float expected_2[] = {
2373     0.156119444687f,
2374     0.517385299817f,
2375   };
2376 
2377   const float expected_3[] = {
2378     0.224177852984f,
2379     0.503384419034f,
2380     0.156119444687f,
2381     0.517385299817f,
2382   };
2383 
2384   const float *expected[] = { expected_0, expected_1, expected_2, expected_3 };
2385 
2386   CNN_CONFIG cnn_config = {
2387     4,  // num_layers
2388     0,  // is_residue
2389     0,  // ext_width
2390     0,  // ext_height
2391     0,  // strict_bounds
2392     {
2393         // layer_config
2394         {
2395             image_ch,           // in_channels
2396             filter_dim,         // filter_width
2397             filter_dim,         // filter_height
2398             num_filters,        // out_channels
2399             stride,             // skip_width
2400             stride,             // skip_height
2401             0,                  // max_pool
2402             weights_1,          // weights
2403             bias,               // bias
2404             PADDING_SAME_ZERO,  // pad
2405             NONE,               // activation
2406             0,                  // deconvolve
2407             0,                  // branch
2408             BRANCH_OUTPUT,      // branch_copy_type
2409             BRANCH_NOC,         // branch_combine_type
2410             { 2, 0, 0 },        // branch_config
2411             {},                 // bn_params
2412             0,                  // output_num
2413         },
2414         {
2415             num_filters,        // in_channels
2416             filter_dim,         // filter_width
2417             filter_dim,         // filter_height
2418             num_filters,        // out_channels
2419             stride,             // skip_width
2420             stride,             // skip_height
2421             0,                  // max_pool
2422             weights_2,          // weights
2423             bias,               // bias
2424             PADDING_SAME_ZERO,  // pad
2425             RELU,               // activation
2426             0,                  // deconvolve
2427             0,                  // branch
2428             BRANCH_NO_COPY,     // branch_copy_type
2429             BRANCH_NOC,         // branch_combine_type
2430             {},                 // branch_config
2431             {},                 // bn_params
2432             1,                  // output_num
2433         },
2434         {
2435             num_filters,        // in_channels
2436             filter_dim,         // filter_width
2437             filter_dim,         // filter_height
2438             num_filters,        // out_channels
2439             stride,             // skip_width
2440             stride,             // skip_height
2441             0,                  // max_pool
2442             weights_3,          // weights
2443             bias,               // bias
2444             PADDING_SAME_ZERO,  // pad
2445             RELU,               // activation
2446             0,                  // deconvolve
2447             0,                  // branch
2448             BRANCH_NO_COPY,     // branch_copy_type
2449             BRANCH_NOC,         // branch_combine_type
2450             {},                 // branch_config
2451             {},                 // bn_params
2452             2,                  // output_num
2453         },
2454         {
2455             num_filters,     // in_channels
2456             2 * filter_dim,  // filter_width
2457             2 * filter_dim,  // filter_height
2458             num_filters,     // out_channels
2459             2 * stride,      // skip_width
2460             2 * stride,      // skip_height
2461             0,               // max_pool
2462             weights_4,       // weights
2463             bias,            // bias
2464             PADDING_VALID,   // pad
2465             RELU,            // activation
2466             0,               // deconvolve
2467             1,               // branch
2468             BRANCH_NO_COPY,  // branch_copy_type
2469             BRANCH_CAT,      // branch_combine_type
2470             { 0, 0, 1 },     // branch_config
2471             {},              // bn_params
2472             3,               // output_num
2473         },
2474     },
2475   };
2476 
2477   CNN_THREAD_DATA thread_data = { 1, nullptr };
2478 
2479   const int num_outputs = 4;
2480   const int output_chs[4] = { filter_dim, filter_dim, filter_dim,
2481                               2 * filter_dim };
2482   const int output_dims[4] = { 4, 2, 1, 1 };
2483   const int output_sizes[4] = {
2484     output_chs[0] * output_dims[0] * output_dims[0],
2485     output_chs[1] * output_dims[1] * output_dims[1],
2486     output_chs[2] * output_dims[2] * output_dims[2],
2487     output_chs[3] * output_dims[3] * output_dims[3],
2488   };
2489   float *const output_ = (float *)aom_malloc(
2490       sizeof(*output_) *
2491       (output_sizes[0] + output_sizes[1] + output_sizes[2] + output_sizes[3]));
2492   ASSERT_NE(output_, nullptr);
2493   float *output[CNN_MAX_CHANNELS] = { nullptr };
2494   int ch_ite = 0;
2495   float *output_ite = output_;
2496   for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
2497     for (int channel = 0; channel < output_chs[output_idx]; ++channel) {
2498       output[ch_ite++] = output_ite;
2499       output_ite += output_dims[output_idx] * output_dims[output_idx];
2500     }
2501   }
2502   CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_dims,
2503                                   output };
2504 
2505   RunMultiOutCNNTest(input, image_dim, image_dim, image_dim, &cnn_config,
2506                      &thread_data, &output_struct, expected, MSE_FLOAT_TOL);
2507 
2508   aom_free(output_);
2509 }
2510 
2511 namespace {
2512 
2513 typedef void (*CNNConvolveNoMaxpoolPaddingValidFunc)(
2514     const float **input, int in_width, int in_height, int in_stride,
2515     const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride,
2516     int start_idx, int cstep, int channel_step);
2517 
2518 typedef libaom_test::FuncParam<CNNConvolveNoMaxpoolPaddingValidFunc>
2519     CNNConvolveTestFuncs;
2520 
2521 class CNNConvolveTest : public ::testing::TestWithParam<CNNConvolveTestFuncs> {
2522  protected:
SetUp()2523   void SetUp() override { params_ = GetParam(); }
2524 
RunCNNConvolveSetup(int run_times)2525   void RunCNNConvolveSetup(int run_times) {
2526     int in_width = 65;
2527     int in_height = 65;
2528 
2529     const CNN_CONFIG *cnn_config = &av1_intra_mode_cnn_partition_cnn_config;
2530 
2531     for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
2532       int out_width = 0, out_height = 0;
2533       int in_size = in_width * in_height;
2534       // Get current layer output width and height.
2535       av1_find_cnn_layer_output_size(in_height, in_width,
2536                                      &cnn_config->layer_config[layer],
2537                                      &out_width, &out_height);
2538 
2539       int out_size = out_width * out_height;
2540       float *input[20], *output_ref[20], *output_mod[20];
2541 
2542       float *input_data =
2543           (float *)aom_malloc(sizeof(*input_data) * in_size *
2544                               cnn_config->layer_config[layer].in_channels);
2545       float *temp_ptr = input_data;
2546       ASSERT_NE(temp_ptr, nullptr);
2547       for (int i = 0; i < cnn_config->layer_config[layer].in_channels; ++i) {
2548         input[i] = temp_ptr;
2549         for (int j = 0; j < in_size; j++) {
2550           *(temp_ptr++) = ((float)rng_.Rand31() - (1 << 30)) / (1u << 31);
2551         }
2552       }
2553 
2554       float *out_data_ref = (float *)aom_calloc(
2555           sizeof(*out_data_ref),
2556           out_size * cnn_config->layer_config[layer].out_channels);
2557       ASSERT_NE(out_data_ref, nullptr);
2558       float *out_data_mod = (float *)aom_calloc(
2559           sizeof(*out_data_mod),
2560           out_size * cnn_config->layer_config[layer].out_channels);
2561       ASSERT_NE(out_data_mod, nullptr);
2562       float *temp_ptr1 = out_data_ref;
2563       float *temp_ptr2 = out_data_mod;
2564       for (int i = 0; i < cnn_config->layer_config[layer].out_channels; ++i) {
2565         output_ref[i] = temp_ptr1;
2566         output_mod[i] = temp_ptr2;
2567         temp_ptr1 += out_size;
2568         temp_ptr2 += out_size;
2569       }
2570 
2571       RunCNNConvolveTest(input, in_width, in_height, out_size,
2572                          &cnn_config->layer_config[layer], 0, 1, run_times,
2573                          layer, output_ref, output_mod, out_width);
2574 
2575       // Set current layer output width and height as next layer input width and
2576       // height.
2577       in_width = out_width;
2578       in_height = out_height;
2579 
2580       aom_free(input_data);
2581       aom_free(out_data_ref);
2582       aom_free(out_data_mod);
2583     }
2584   }
2585 
RunCNNConvolveTest(float ** input,int in_width,int in_height,int out_size,const CNN_LAYER_CONFIG * layer_config,int start_idx,int step,int run_times,int layer,float ** output_ref,float ** output_mod,int out_stride)2586   void RunCNNConvolveTest(float **input, int in_width, int in_height,
2587                           int out_size, const CNN_LAYER_CONFIG *layer_config,
2588                           int start_idx, int step, int run_times, int layer,
2589                           float **output_ref, float **output_mod,
2590                           int out_stride) {
2591     const int cstep = layer_config->in_channels * layer_config->out_channels;
2592     const int channel_step = AOMMAX(step, 1);
2593     aom_usec_timer timer;
2594     aom_usec_timer_start(&timer);
2595     for (int i = 0; i < run_times; ++i) {
2596       params_.ref_func((const float **)input, in_width, in_height, in_width,
2597                        layer_config, output_ref, out_stride, start_idx, cstep,
2598                        channel_step);
2599     }
2600     aom_usec_timer_mark(&timer);
2601     const double time1 = static_cast<double>(aom_usec_timer_elapsed(&timer));
2602 
2603     aom_usec_timer_start(&timer);
2604     for (int i = 0; i < run_times; ++i) {
2605       params_.tst_func((const float **)input, in_width, in_height, in_width,
2606                        layer_config, output_mod, out_stride, start_idx, cstep,
2607                        channel_step);
2608     }
2609     aom_usec_timer_mark(&timer);
2610     const double time2 = static_cast<double>(aom_usec_timer_elapsed(&timer));
2611 
2612     if (run_times > 1) {
2613       printf("layer : %d \n", layer);
2614       printf("%7.2f/%7.2fns (%3.2f)\n", time1, time2, time1 / time2);
2615     } else {
2616       for (int channel = 0; channel < layer_config->out_channels; ++channel) {
2617         const float *buf_ref = output_ref[channel];
2618         const float *buf_mod = output_mod[channel];
2619 
2620         for (int i = 0; i < out_size; ++i) {
2621           if (buf_ref[i] < CNN_CONVOLVE_PIXELWISE_FLOAT_TOL) {
2622             ASSERT_LE(buf_ref[i], CNN_CONVOLVE_PIXELWISE_FLOAT_TOL)
2623                 << "Reference output was near-zero, test output was not ("
2624                 << buf_mod[i] << ")";
2625           } else {
2626             const float error = buf_ref[i] - buf_mod[i];
2627             const float relative_error = fabsf(error / buf_ref[i]);
2628             ASSERT_LE(relative_error, CNN_CONVOLVE_PIXELWISE_FLOAT_TOL)
2629                 << " channel " << channel << " pixel " << i << ": "
2630                 << buf_ref[i] << "/" << buf_mod[i] << std::endl;
2631           }
2632         }
2633       }
2634     }
2635   }
2636 
2637  private:
2638   CNNConvolveTestFuncs params_;
2639   libaom_test::ACMRandom rng_;
2640 };
2641 GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CNNConvolveTest);
2642 
// Runs each layer a single time, which makes RunCNNConvolveTest() compare the
// test function's output against the reference function's output.
TEST_P(CNNConvolveTest, CheckOutput) {
  const int kRunTimes = 1;
  RunCNNConvolveSetup(kRunTimes);
}
2644 
// Runs each layer many times, which makes RunCNNConvolveTest() print the
// elapsed time of the reference and test functions instead of comparing
// their outputs.
TEST_P(CNNConvolveTest, DISABLED_Speed) {
  const int kRunTimes = 100000;
  RunCNNConvolveSetup(kRunTimes);
}
2646 
#if HAVE_AVX2 && !CONFIG_EXCLUDE_SIMD_MISMATCH
// Pairs the C implementation (reference) with the AVX2 implementation (test).
INSTANTIATE_TEST_SUITE_P(
    AVX2, CNNConvolveTest,
    ::testing::Values(CNNConvolveTestFuncs(
        &av1_cnn_convolve_no_maxpool_padding_valid_c,
        &av1_cnn_convolve_no_maxpool_padding_valid_avx2)));
#endif  // HAVE_AVX2 && !CONFIG_EXCLUDE_SIMD_MISMATCH
2653 
#if HAVE_NEON
// Pairs the C implementation (reference) with the NEON implementation (test).
INSTANTIATE_TEST_SUITE_P(
    NEON, CNNConvolveTest,
    ::testing::Values(CNNConvolveTestFuncs(
        &av1_cnn_convolve_no_maxpool_padding_valid_c,
        &av1_cnn_convolve_no_maxpool_padding_valid_neon)));
#endif  // HAVE_NEON
2660 
2661 }  // namespace
2662