/* Source: /btstack/port/stm32-f4discovery-usb/Drivers/CMSIS/NN/NN_Lib_Tests/nn_test/arm_nnexamples_nn_test.h (revision a8f7f3fcbcd51f8d2e92aca076b6a9f812db358c) */
1 #ifndef _MAIN_H_
2 #define _MAIN_H_
3 
#include <math.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>

#include "arm_math.h"

#include "arm_nnfunctions.h"
#include "ref_functions.h"
12 
13 extern int test_index;
14 extern q7_t test_flags[50];
15 
/**
 * Pre-fill the reference and optimized q7 output buffers with different values.
 *
 * ref is set to 0 and opt to 37. Because the two buffers start out different,
 * an element of opt that is never overwritten will presumably show up as a
 * mismatch in verify_results_q7() (unless the reference output happens to be
 * 37 at that index).
 *
 * Defined static inline because this header contains the definition itself:
 * with external linkage, including the header from two translation units
 * would produce duplicate-symbol link errors.
 *
 * @param ref     buffer for the reference output, at least length elements
 * @param opt     buffer for the output under test, at least length elements
 * @param length  number of q7_t elements to fill; values <= 0 do nothing
 */
static inline void initialize_results_q7(q7_t * ref, q7_t * opt, int length)
{
    /* The CMSIS fill routines take an unsigned block size, so a negative
     * int would be converted to a huge count and overrun both buffers. */
    if (length <= 0)
    {
        return;
    }

    arm_fill_q7(0, ref, length);
    arm_fill_q7(37, opt, length);
}
21 
/**
 * Pre-fill the reference and optimized q15 output buffers with different values.
 *
 * ref is set to 0 and opt to 0x5F5. Because the two buffers start out
 * different, an element of opt that is never overwritten will presumably show
 * up as a mismatch in verify_results_q15() (unless the reference output
 * happens to be 0x5F5 at that index).
 *
 * Defined static inline because this header contains the definition itself:
 * with external linkage, including the header from two translation units
 * would produce duplicate-symbol link errors.
 *
 * @param ref     buffer for the reference output, at least length elements
 * @param opt     buffer for the output under test, at least length elements
 * @param length  number of q15_t elements to fill; values <= 0 do nothing
 */
static inline void initialize_results_q15(q15_t * ref, q15_t * opt, int length)
{
    /* The CMSIS fill routines take an unsigned block size, so a negative
     * int would be converted to a huge count and overrun both buffers. */
    if (length <= 0)
    {
        return;
    }

    arm_fill_q15(0, ref, length);
    arm_fill_q15(0x5F5, opt, length);
}
27 
verify_results_q7(q7_t * ref,q7_t * opt,int length)28 void verify_results_q7(q7_t * ref, q7_t * opt, int length)
29 {
30 
31     bool      if_match = true;
32 
33     for (int i = 0; i < length; i++)
34     {
35         if (ref[i] != opt[i])
36         {
37             printf("Output mismatch at %d, expected %d, actual %d\r\n", i, ref[i], opt[i]);
38 
39             if_match = false;
40         }
41     }
42 
43     if (if_match == true)
44     {
45         printf("Outputs match.\r\n\r\n");
46         test_flags[test_index++] = 0;
47     } else {
48         test_flags[test_index++] = 1;
49     }
50 
51 }
52 
verify_results_q15(q15_t * ref,q15_t * opt,int length)53 void verify_results_q15(q15_t * ref, q15_t * opt, int length)
54 {
55 
56     bool      if_match = true;
57 
58     for (int i = 0; i < length; i++)
59     {
60         if (ref[i] != opt[i])
61         {
62             printf("Output mismatch at %d, expected %d, actual %d\r\n", i, ref[i], opt[i]);
63 
64             if_match = false;
65         }
66     }
67 
68     if (if_match == true)
69     {
70         printf("Outputs match.\r\n\r\n");
71         test_flags[test_index++] = 0;
72     } else {
73         test_flags[test_index++] = 1;
74     }
75 
76 }
77 
78 #endif
79