xref: /aosp_15_r20/external/armnn/samples/KeywordSpotting/test/KeywordSpottingPipelineTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include <algorithm>
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

#include <catch.hpp>

#include "KeywordSpottingPipeline.hpp"
#include "DsCNNPreprocessor.hpp"
11 
GetResourceFilePath(const std::string & filename)12 static std::string GetResourceFilePath(const std::string& filename)
13 {
14     std::string testResources = TEST_RESOURCE_DIR;
15     if (testResources.empty())
16     {
17         throw std::invalid_argument("Invalid test resources directory provided");
18     }
19     else
20     {
21         if(testResources.back() != '/')
22         {
23             return testResources + "/" + filename;
24         }
25         else
26         {
27             return testResources + filename;
28         }
29     }
30 }
31 
32 TEST_CASE("Test Keyword spotting pipeline")
33 {
34     const int8_t ifm0_kws [] =
35     {
36     -0x1b, 0x4f, 0x7a, -0x55, 0x6, -0x11, 0x6e, -0x6, 0x67, -0x7e, -0xd, 0x6, 0x49, 0x79, -0x1e, 0xe,
37      0x1d, 0x6e, 0x6f, 0x6f, -0x2e, -0x4b, 0x2, -0x3e, 0x40, -0x4b, -0x7, 0x31, -0x38, -0x64, -0x28,
38      0xc, -0x1d, 0xf, 0x1c, 0x5a, -0x4b, 0x56, 0x7e, 0x9, -0x29, 0x13, -0x65, -0xa, 0x34, -0x59, 0x41,
39     -0x6f, 0x75, 0x67, -0x5f, 0x17, 0x4a, -0x76, -0x7a, 0x49, -0x19, -0x41, 0x78, 0x40, 0x44, 0xe,
40     -0x51, -0x5c, 0x3d, 0x24, 0x76, -0x66, -0x11, 0x5e, 0x7b, -0x4, 0x7a, 0x9, 0x13, 0x8, -0x21, -0x11,
41      0x13, 0x7a, 0x25, 0x6, -0x68, 0x6a, -0x30, -0x16, -0x43, -0x27, 0x4c, 0x6b, -0x14, -0x12, -0x5f,
42      0x49, -0x2a, 0x44, 0x57, -0x78, -0x72, 0x62, -0x8, -0x38, -0x73, -0x2, -0x80, 0x79, -0x3f, 0x57,
43      0x9, -0x7e, -0x34, -0x59, 0x19, -0x66, 0x58, -0x3b, -0x69, -0x1a, 0x13, -0x2f, -0x2f, 0x13, 0x35,
44     -0x30, 0x1e, 0x3b, -0x71, 0x67, 0x7d, -0x5d, 0x1a, 0x69, -0x53, -0x38, -0xf, 0x76, 0x2, 0x7e, 0x45,
45     -0xa, 0x59, -0x6b, -0x28, -0x5d, -0x63, -0x7d, -0x3, 0x48, 0x74, -0x75, -0x7a, 0x1f, -0x53, 0x5b,
46      0x4d, -0x18, -0x4a, 0x39, -0x52, 0x5a, -0x6b, -0x41, -0x3e, -0x61, -0x80, -0x52, 0x67, 0x71, -0x47,
47      0x79, -0x41, 0x3a, -0x8, -0x1f, 0x4d, -0x7, 0x5b, 0x6b, -0x1b, -0x8, -0x20, -0x21, 0x7c, -0x74,
48      0x25, -0x68, -0xe, -0x7e, -0x45, -0x28, 0x45, -0x1a, -0x39, 0x78, 0x11, 0x48, -0x6b, -0x7b, -0x43,
49     -0x21, 0x38, 0x46, 0x7c, -0x5d, 0x59, 0x53, -0x3f, -0x15, 0x59, -0x17, 0x75, 0x2f, 0x7c, 0x68, 0x6a,
50      0x0, -0x10, 0x5b, 0x61, 0x36, -0x41, 0x33, 0x23, -0x80, -0x1d, -0xb, -0x56, 0x2d, 0x68, -0x68,
51      0x2f, 0x48, -0x5d, -0x44, 0x64, -0x27, 0x68, -0x13, 0x39, -0x3f, 0x18, 0x31, 0x15, -0x78, -0x2,
52      0x72, 0x60, 0x59, -0x30, -0x22, 0x73, 0x61, 0x76, -0x4, -0x62, -0x64, -0x80, -0x32, -0x16, 0x51,
53     -0x2, -0x70, 0x71, 0x3f, -0x5f, -0x35, -0x3c, 0x79, 0x48, 0x61, 0x5b, -0x20, -0x1e, -0x68, -0x1c,
54      0x6c, 0x3a, 0x28, -0x36, -0x3e, 0x5f, -0x75, -0x73, 0x1e, 0x75, -0x66, -0x22, 0x20, -0x64, 0x67,
55      0x36, 0x14, 0x37, -0xa, -0xe, 0x8, -0x37, -0x43, 0x21, -0x8, 0x54, 0x1, 0x34, -0x2c, -0x73, -0x11,
56     -0x48, -0x1c, -0x40, 0x14, 0x4e, -0x53, 0x25, 0x5e, 0x14, 0x4f, 0x7c, 0x6d, -0x61, -0x38, 0x35,
57     -0x5a, -0x44, 0x12, 0x52, -0x60, 0x22, -0x1c, -0x8, -0x4, -0x6b, -0x71, 0x43, 0xb, 0x7b, -0x7,
58     -0x3c, -0x3b, -0x40, -0xd, 0x44, 0x6, 0x30, 0x38, 0x57, 0x1f, -0x7, 0x2, 0x4f, 0x64, 0x7c, -0x3,
59     -0x13, -0x71, -0x45, -0x53, -0x52, 0x2b, -0x11, -0x1d, -0x2, -0x29, -0x37, 0x3d, 0x19, 0x76, 0x18,
60      0x1d, 0x12, -0x29, -0x5e, -0x54, -0x48, 0x5d, -0x41, -0x3f, 0x7e, -0x2a, 0x41, 0x57, -0x65, -0x15,
61      0x12, 0x1f, -0x57, 0x79, -0x64, 0x3a, -0x2f, 0x7f, -0x6c, 0xa, 0x52, -0x1f, -0x41, 0x6e, -0x4b,
62      0x3d, -0x1b, -0x42, 0x22, -0x3c, -0x35, -0xf, 0xc, 0x32, -0x15, -0x68, -0x21, 0x0, -0x16, 0x14,
63     -0x10, -0x5b, 0x2f, 0x21, 0x41, -0x8, -0x12, -0xa, 0x10, 0xf, 0x7e, -0x76, -0x1d, 0x2b, -0x49,
64      0x42, -0x25, -0x78, -0x69, -0x2c, 0x3f, 0xc, 0x52, 0x6d, 0x2e, -0x13, 0x76, 0x37, -0x36, -0x51,
65     -0x5, -0x63, -0x4f, 0x1c, 0x6b, -0x4b, 0x71, -0x12, 0x72, -0x3f,-0x4a, 0xf, 0x3a, -0xd, 0x38, 0x3b,
66     -0x5d, 0x75, -0x43, -0x10, -0xa, -0x7a, 0x1a, -0x44, 0x1c, 0x6a, 0x43, -0x1b, -0x35, 0x7d, -0x2c,
67     -0x10, 0x5b, -0x42, -0x4f, 0x69, 0x1f, 0x1b, -0x64, -0x21, 0x19, -0x5d, 0x2e, -0x2a, -0x65, -0x13,
68     -0x70, -0x6e
69     };
70 
71     const int8_t ofm0_kws [] =
72     {
73     -0x80, 0x7f, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80
74     };
75 
76     // First 640 samples from yes.wav.
77     std::vector<int16_t> testWav = std::vector<int16_t>
78     {
79     139, 143, 164, 163, 157, 156, 151, 148, 172, 171,
80     165, 169, 149, 142, 145, 147, 166, 146, 112, 132,
81     132, 136, 165, 176, 176, 152, 138, 158, 179, 185,
82     183, 148, 121, 130, 167, 204, 163, 132, 165, 184,
83     193, 205, 210, 204, 195, 178, 168, 197, 207, 201,
84     197, 177, 185, 196, 191, 198, 196, 183, 193, 181,
85     157, 170, 167, 159, 164, 152, 146, 167, 180, 171,
86     194, 232, 204, 173, 171, 172, 184, 169, 175, 199,
87     200, 195, 185, 214, 214, 193, 196, 191, 204, 191,
88     172, 187, 183, 192, 203, 172, 182, 228, 232, 205,
89     177, 174, 191, 210, 210, 211, 197, 177, 198, 217,
90     233, 236, 203, 191, 169, 145, 149, 161, 198, 206,
91     176, 137, 142, 181, 200, 215, 201, 188, 166, 162,
92     184, 155, 135, 132, 126, 142, 169, 184, 172, 156,
93     132, 119, 150, 147, 154, 160, 125, 130, 137, 154,
94     161, 168, 195, 182, 160, 134, 138, 146, 130, 120,
95     101, 122, 137, 118, 117, 131, 145, 140, 146, 148,
96     148, 168, 159, 134, 114, 114, 130, 147, 147, 134,
97     125, 98, 107, 127, 99, 79, 84, 107, 117, 114,
98     93, 92, 127, 112, 109, 110, 96, 118, 97, 87,
99     110, 95, 128, 153, 147, 165, 146, 106, 101, 137,
100     139, 96, 73, 90, 91, 51, 69, 102, 100, 103,
101     96, 101, 123, 107, 82, 89, 118, 127, 99, 100,
102     111, 97, 111, 123, 106, 121, 133, 103, 100, 88,
103     85, 111, 114, 125, 102, 91, 97, 84, 139, 157,
104     109, 66, 72, 129, 111, 90, 127, 126, 101, 109,
105     142, 138, 129, 159, 140, 80, 74, 78, 76, 98,
106     68, 42, 106, 143, 112, 102, 115, 114, 82, 75,
107     92, 80, 110, 114, 66, 86, 119, 101, 101, 103,
108     118, 145, 85, 40, 62, 88, 95, 87, 73, 64,
109     86, 71, 71, 105, 80, 73, 96, 92, 85, 90,
110     81, 86, 105, 100, 89, 78, 102, 114, 95, 98,
111     69, 70, 108, 112, 111, 90, 104, 137, 143, 160,
112     145, 121, 98, 86, 91, 87, 115, 123, 109, 99,
113     85, 120, 131, 116, 125, 144, 153, 111, 98, 110,
114     93, 89, 101, 137, 155, 142, 108, 94, 136, 145,
115     129, 129, 122, 109, 90, 76, 81, 110, 119, 96,
116     95, 102, 105, 111, 90, 89, 111, 115, 86, 51,
117     107, 140, 105, 105, 110, 142, 125, 76, 75, 69,
118     65, 52, 61, 69, 55, 42, 47, 58, 37, 35,
119     24, 20, 44, 22, 16, 26, 6, 3, 4, 23,
120     60, 51, 30, 12, 24, 31, -9, -16, -13, 13,
121     19, 9, 37, 55, 70, 36, 23, 57, 45, 33,
122     50, 59, 18, 11, 62, 74, 52, 8, -3, 26,
123     51, 48, -5, -9, 12, -7, -12, -5, 28, 41,
124     -2, -30, -13, 31, 33, -12, -22, -8, -15, -17,
125     2, -6, -25, -27, -24, -8, 4, -9, -52, -47,
126     -9, -32, -45, -5, 41, 15, -32, -14, 2, -1,
127     -10, -30, -32, -25, -21, -17, -14, 8, -4, -13,
128     34, 18, -36, -38, -18, -19, -28, -17, -14, -16,
129     -2, -20, -27, 12, 11, -17, -33, -12, -22, -64,
130     -42, -26, -23, -22, -37, -51, -53, -30, -18, -48,
131     -69, -38, -54, -96, -72, -49, -50, -57, -41, -22,
132     -43, -64, -54, -23, -49, -69, -41, -44, -42, -49,
133     -40, -26, -54, -50, -38, -49, -70, -94, -89, -69,
134     -56, -65, -71, -47, -39, -49, -79, -91, -56, -46,
135     -62, -86, -64, -32, -47, -50, -71, -77, -65, -68,
136     -52, -51, -61, -67, -61, -81, -93, -52, -59, -62,
137     -51, -75, -76, -50, -32, -54, -68, -70, -43, 1,
138     -42, -92, -80, -41, -38, -79, -69, -49, -82, -122,
139     -93, -21, -24, -61, -70, -73, -62, -74, -69, -43,
140     -25, -15, -43, -23, -26, -69, -44, -12, 1, -51,
141     -78, -13, 3, -53, -105, -72, -24, -62, -66, -31,
142     -40, -65, -86, -64, -44, -55, -63, -61, -37, -41,
143     };
144 
145     // Golden audio ops mfcc output for the above wav.
146     const std::vector<float> testWavMfcc
147     {
148     -22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072,
149     };
150 
151     std::vector<float> testWavFloat(640);
152     constexpr float normaliser = 1.0/(1u<<15u);
153     std::transform(testWav.begin(), testWav.end(), testWavFloat.begin(),
154                    std::bind1st(std::multiplies<float>(), normaliser));
155 
156     const float DsCNNInputQuantizationScale = 1.107164;
157     const int DsCNNInputQuantizationOffset = 95;
158 
159     std::map<int,std::string> labels =
160     {
161         {0,"silence"},
162         {1, "unknown"},
163         { 2, "yes"},
164         { 3,"no"},
165         { 4, "up"},
166         { 5, "down"},
167         { 6, "left"},
168         { 7, "right"},
169         { 8, "on"},
170         { 9, "off"},
171         { 10, "stop"},
172         {11, "go"}
173     };
174     common::PipelineOptions options;
175     options.m_ModelFilePath = GetResourceFilePath("ds_cnn_clustered_int8.tflite");
176     options.m_ModelName = "DS_CNN_CLUSTERED_INT8";
177     options.m_backends = {"CpuAcc", "CpuRef"};
178     kws::IPipelinePtr kwsPipeline = kws::CreatePipeline(options);
179 
180     CHECK(kwsPipeline->getInputSamplesSize() == 16000);
181     std::vector<int8_t> expectedWavMfcc;
182     for(auto& i : testWavMfcc)
183     {
184         expectedWavMfcc.push_back(
185             (i + DsCNNInputQuantizationScale * DsCNNInputQuantizationOffset) / DsCNNInputQuantizationScale);
186     }
187 
188     SECTION("Pre-processing")
189     {
190         testWavFloat.resize(16000);
191         expectedWavMfcc.resize(49 * 10);
192         std::vector<int8_t> preprocessedData = kwsPipeline->PreProcessing(testWavFloat);
193         CHECK(preprocessedData.size() == expectedWavMfcc.size());
194         for(int i = 0; i < 10; ++i)
195         {
196             CHECK(expectedWavMfcc[i] == Approx(preprocessedData[i]).margin(1));
197         }
198     }
199 
200     SECTION("Execute inference")
201     {
202         common::InferenceResults<int8_t> result;
203         std::vector<int8_t> IFM(std::begin(ifm0_kws), std::end(ifm0_kws));
204         kwsPipeline->Inference(IFM, result);
205         std::vector<int8_t> OFM(std::begin(ofm0_kws), std::end(ofm0_kws));
206 
207         CHECK(1 == result.size());
208         CHECK(OFM.size() == result[0].size());
209 
210         int count = 0;
211         for (auto& i : result)
212         {
213             for (signed char& j : i)
214             {
215                 CHECK(j == OFM[count++]);
216 
217             }
218         }
219     }
220 
221     SECTION("Convert inference result to keyword")
222     {
223         std::vector< std::vector< int8_t >> modelOutput = {{1, 4, 2, 3, 1, 1, 3, 1, 43, 1, 6, 1}};
224         kwsPipeline->PostProcessing(modelOutput, labels,
__anonb24260750102(int index, std::string& label, float prob) 225                                     [](int index, std::string& label, float prob) -> void {
226                                         CHECK(index == 8);
227                                         CHECK(label == "on");
228                                     });
229     }
230 }
231