xref: /aosp_15_r20/external/armnn/samples/SpeechRecognition/src/Main.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
10 
11 #include "CmdArgsParser.hpp"
12 #include "ArmnnNetworkExecutor.hpp"
13 #include "AudioCapture.hpp"
14 #include "SpeechRecognitionPipeline.hpp"
15 #include "Wav2LetterMFCC.hpp"
16 
// Containers for raw int8 model output; main() hands an InferenceResults instance
// to the pipeline's Inference call to be filled in.
using InferenceResult = std::vector<int8_t>;
using InferenceResults = std::vector<InferenceResult>;

// Names of the command-line options referenced by this executable.
const std::string AUDIO_FILE_PATH{"--audio-file-path"};
const std::string MODEL_FILE_PATH{"--model-file-path"};
const std::string LABEL_PATH{"--label-path"};
const std::string PREFERRED_BACKENDS{"--preferred-backends"};
const std::string HELP{"--help"};
25 
// Index -> symbol table handed to asr::CreatePipeline in main().
// NOTE(review): presumably the model's output class index mapped to the character it
// decodes to, with "$" as a special (blank) symbol - confirm against the pipeline code.
std::map<int, std::string> labels{
    {0,  "a"}, {1,  "b"}, {2,  "c"}, {3,  "d"}, {4,  "e"}, {5,  "f"},
    {6,  "g"}, {7,  "h"}, {8,  "i"}, {9,  "j"}, {10, "k"}, {11, "l"},
    {12, "m"}, {13, "n"}, {14, "o"}, {15, "p"}, {16, "q"}, {17, "r"},
    {18, "s"}, {19, "t"}, {20, "u"}, {21, "v"}, {22, "w"}, {23, "x"},
    {24, "y"}, {25, "z"},
    // Non-alphabetic symbols.
    {26, "\'"},
    {27, " "},
    {28, "$"}
};
58 
59 /*
60  * The accepted options for this Speech Recognition executable
61  */
62 static std::map<std::string, std::string> CMD_OPTIONS =
63 {
64     {AUDIO_FILE_PATH,    "[REQUIRED] Path to the Audio file to run speech recognition on"},
65     {MODEL_FILE_PATH,    "[REQUIRED] Path to the Speech Recognition model to use"},
66     {PREFERRED_BACKENDS, "[OPTIONAL] Takes the preferred backends in preference order, separated by comma."
67                          " For example: CpuAcc,GpuAcc,CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]."
68                          " Defaults to CpuAcc,CpuRef"}
69 };
70 
71 /*
72  * Reads the user supplied backend preference, splits it by comma, and returns an ordered vector
73  */
GetPreferredBackendList(const std::string & preferredBackends)74 std::vector<armnn::BackendId> GetPreferredBackendList(const std::string& preferredBackends)
75 {
76     std::vector<armnn::BackendId> backends;
77     std::stringstream ss(preferredBackends);
78 
79     while (ss.good())
80     {
81         std::string backend;
82         std::getline(ss, backend, ',');
83         backends.emplace_back(backend);
84     }
85     return backends;
86 }
87 
main(int argc,char * argv[])88 int main(int argc, char* argv[])
89 {
90     bool isFirstWindow = true;
91     std::string currentRContext = "";
92 
93     std::map<std::string, std::string> options;
94 
95     int result = ParseOptions(options, CMD_OPTIONS, argv, argc);
96     if (result != 0)
97     {
98         return result;
99     }
100 
101     // Create the network options
102     common::PipelineOptions pipelineOptions;
103     pipelineOptions.m_ModelFilePath = GetSpecifiedOption(options, MODEL_FILE_PATH);
104     pipelineOptions.m_ModelName = "Wav2Letter";
105     if (CheckOptionSpecified(options, PREFERRED_BACKENDS))
106     {
107         pipelineOptions.m_backends = GetPreferredBackendList((GetSpecifiedOption(options, PREFERRED_BACKENDS)));
108     }
109     else
110     {
111         pipelineOptions.m_backends = {"CpuAcc", "CpuRef"};
112     }
113 
114     asr::IPipelinePtr asrPipeline = asr::CreatePipeline(pipelineOptions, labels);
115 
116     audio::AudioCapture capture;
117     std::vector<float> audioData = audio::AudioCapture::LoadAudioFile(GetSpecifiedOption(options, AUDIO_FILE_PATH));
118     capture.InitSlidingWindow(audioData.data(), audioData.size(), asrPipeline->getInputSamplesSize(),
119                               asrPipeline->getSlidingWindowOffset());
120 
121     while (capture.HasNext())
122     {
123         std::vector<float> audioBlock = capture.Next();
124         InferenceResults results;
125 
126         std::vector<int8_t> preprocessedData = asrPipeline->PreProcessing(audioBlock);
127         asrPipeline->Inference<int8_t>(preprocessedData, results);
128         asrPipeline->PostProcessing<int8_t>(results, isFirstWindow, !capture.HasNext(), currentRContext);
129     }
130 
131     return 0;
132 }