xref: /aosp_15_r20/external/armnn/samples/KeywordSpotting/include/KeywordSpottingPipeline.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "ArmnnNetworkExecutor.hpp"
9 #include "Decoder.hpp"
10 #include "MFCC.hpp"
11 #include "DsCNNPreprocessor.hpp"
12 
13 namespace kws
14 {
15 /**
16  * Generic Keyword Spotting pipeline with 3 steps: data pre-processing, inference execution and inference
17  * result post-processing.
18  *
19  */
20 class KWSPipeline
21 {
22 public:
23 
24     /**
25      * Creates speech recognition pipeline with given network executor and decoder.
26      * @param executor - unique pointer to inference runner
27      * @param decoder - unique pointer to inference results decoder
28      */
29     KWSPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor,
30                 std::unique_ptr<Decoder> decoder,
31                 std::unique_ptr<DsCNNPreprocessor> preProcessor);
32 
33     /**
34      * @brief Standard audio pre-processing implementation.
35      *
36      * Preprocesses and prepares the data for inference by
37      * extracting the MFCC features.
38 
39      * @param[in] audio - the raw audio data
40      */
41 
42     std::vector<int8_t> PreProcessing(std::vector<float>& audio);
43 
44     /**
45      * @brief Executes inference
46      *
47      * Calls inference runner provided during instance construction.
48      *
49      * @param[in] preprocessedData - input inference data. Data type should be aligned with input tensor.
50      * @param[out] result - raw inference results.
51      */
52     void Inference(const std::vector<int8_t>& preprocessedData, common::InferenceResults<int8_t>& result);
53 
54     /**
55      * @brief Standard inference results post-processing implementation.
56      *
57      * Decodes inference results using decoder provided during construction.
58      *
59      * @param[in] inferenceResult - inference results to be decoded.
60      * @param[in] labels - the words we use for the model
61      */
62     void PostProcessing(common::InferenceResults<int8_t>& inferenceResults,
63                         std::map<int, std::string>& labels,
64                         const std::function<void (int, std::string&, float)>& callback);
65 
66     /**
67      * @brief Get the number of samples for the pipeline input
68 
69      * @return - number of samples for the pipeline
70      */
71     int getInputSamplesSize();
72 
73 protected:
74     std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> m_executor;
75     std::unique_ptr<Decoder> m_decoder;
76     std::unique_ptr<DsCNNPreprocessor> m_preProcessor;
77 };
78 
79 using IPipelinePtr = std::unique_ptr<kws::KWSPipeline>;
80 
81 /**
82  * Constructs speech recognition pipeline based on configuration provided.
83  *
84  * @param[in] config - speech recognition pipeline configuration.
85  * @param[in] labels - asr labels
86  *
87  * @return unique pointer to asr pipeline.
88  */
89 IPipelinePtr CreatePipeline(common::PipelineOptions& config);
90 
91 };// namespace kws