1 // 2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "ArmnnNetworkExecutor.hpp" 9 #include "Decoder.hpp" 10 #include "MFCC.hpp" 11 #include "DsCNNPreprocessor.hpp" 12 13 namespace kws 14 { 15 /** 16 * Generic Keyword Spotting pipeline with 3 steps: data pre-processing, inference execution and inference 17 * result post-processing. 18 * 19 */ 20 class KWSPipeline 21 { 22 public: 23 24 /** 25 * Creates speech recognition pipeline with given network executor and decoder. 26 * @param executor - unique pointer to inference runner 27 * @param decoder - unique pointer to inference results decoder 28 */ 29 KWSPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor, 30 std::unique_ptr<Decoder> decoder, 31 std::unique_ptr<DsCNNPreprocessor> preProcessor); 32 33 /** 34 * @brief Standard audio pre-processing implementation. 35 * 36 * Preprocesses and prepares the data for inference by 37 * extracting the MFCC features. 38 39 * @param[in] audio - the raw audio data 40 */ 41 42 std::vector<int8_t> PreProcessing(std::vector<float>& audio); 43 44 /** 45 * @brief Executes inference 46 * 47 * Calls inference runner provided during instance construction. 48 * 49 * @param[in] preprocessedData - input inference data. Data type should be aligned with input tensor. 50 * @param[out] result - raw inference results. 51 */ 52 void Inference(const std::vector<int8_t>& preprocessedData, common::InferenceResults<int8_t>& result); 53 54 /** 55 * @brief Standard inference results post-processing implementation. 56 * 57 * Decodes inference results using decoder provided during construction. 58 * 59 * @param[in] inferenceResult - inference results to be decoded. 60 * @param[in] labels - the words we use for the model 61 */ 62 void PostProcessing(common::InferenceResults<int8_t>& inferenceResults, 63 std::map<int, std::string>& labels, 64 const std::function<void (int, std::string&, float)>& callback); 65 66 /** 67 * @brief Get the number of samples for the pipeline input 68 69 * @return - number of samples for the pipeline 70 */ 71 int getInputSamplesSize(); 72 73 protected: 74 std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> m_executor; 75 std::unique_ptr<Decoder> m_decoder; 76 std::unique_ptr<DsCNNPreprocessor> m_preProcessor; 77 }; 78 79 using IPipelinePtr = std::unique_ptr<kws::KWSPipeline>; 80 81 /** 82 * Constructs speech recognition pipeline based on configuration provided. 83 * 84 * @param[in] config - speech recognition pipeline configuration. 85 * @param[in] labels - asr labels 86 * 87 * @return unique pointer to asr pipeline. 88 */ 89 IPipelinePtr CreatePipeline(common::PipelineOptions& config); 90 91 };// namespace kws