# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

"""Keyword Spotting with PyArmNN demo for processing live microphone data or pre-recorded files."""

import sys
import os
from argparse import ArgumentParser

import numpy as np
import sounddevice as sd

# Make the sibling "common" directory importable for the shared helper modules below.
script_dir = os.path.dirname(__file__)
sys.path.insert(1, os.path.join(script_dir, '..', 'common'))

from network_executor import ArmnnNetworkExecutor
from utils import prepare_input_data, dequantize_output
from mfcc import AudioPreprocessor, MFCC, MFCCParams
from audio_utils import decode, display_text
from audio_capture import AudioCaptureParams, CaptureAudioStream, capture_audio

# Model Specific Labels: network output index -> keyword text
labels = {0: 'silence',
          1: 'unknown',
          2: 'yes',
          3: 'no',
          4: 'up',
          5: 'down',
          6: 'left',
          7: 'right',
          8: 'on',
          9: 'off',
          10: 'stop',
          11: 'go'}


def parse_args():
    """Parse and return the command-line arguments for the KWS demo."""
    parser = ArgumentParser(description="KWS with PyArmNN")
    parser.add_argument(
        "--audio_file_path",
        required=False,
        type=str,
        help="Path to the audio file to perform KWS",
    )
    parser.add_argument(
        "--duration",
        type=int,
        default=0,
        help="""Duration for recording audio in seconds. Values <= 0 result in infinite
        recording. Defaults to infinite.""",
    )
    parser.add_argument(
        "--model_file_path",
        required=True,
        type=str,
        help="Path to KWS model to use",
    )
    parser.add_argument(
        "--preferred_backends",
        type=str,
        nargs="+",
        default=["CpuAcc", "CpuRef"],
        help="""List of backends in order of preference for optimizing
        subgraphs, falling back to the next backend in the list on unsupported
        layers. Defaults to [CpuAcc, CpuRef]""",
    )
    return parser.parse_args()


def recognise_speech(audio_data, network, preprocessor, threshold):
    """Run one window of audio through the network and display any confident keyword.

    Args:
        audio_data: Raw audio samples for a single capture window.
        network: ArmnnNetworkExecutor wrapping the loaded KWS model.
        preprocessor: AudioPreprocessor producing MFCC features for the model.
        threshold: Minimum classification score required to display a result.
    """
    # Prepare the input Tensors (MFCC extraction + quantization to the model's input type)
    input_data = prepare_input_data(audio_data, network.get_data_type(), network.get_input_quantization_scale(0),
                                    network.get_input_quantization_offset(0), preprocessor)
    # Run inference
    output_result = network.run([input_data])

    # Dequantize each output tensor back to float scores where the model is quantized.
    dequantized_result = []
    for index, ofm in enumerate(output_result):
        dequantized_result.append(dequantize_output(ofm, network.is_output_quantized(index),
                                                    network.get_output_quantization_scale(index),
                                                    network.get_output_quantization_offset(index)))

    # Decode the text and display result if above threshold
    decoded_result = decode(dequantized_result, labels)

    # decoded_result[1] is the classification score for the best label.
    if decoded_result[1] > threshold:
        display_text(decoded_result)


def main(args):
    """Run KWS on a pre-recorded file if one was supplied, otherwise stream from the microphone."""
    # Read command line args and invoke mic streaming if no file path supplied
    audio_file = args.audio_file_path
    streaming_enabled = not audio_file

    # Create the ArmNN inference runner
    network = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)

    # Specify model specific audio data requirements
    # Overlap value specifies the number of samples to rewind between each data window
    audio_capture_params = AudioCaptureParams(dtype=np.float32, overlap=2000, min_samples=16000, sampling_freq=16000,
                                              mono=True)

    # Create the preprocessor
    mfcc_params = MFCCParams(sampling_freq=16000, num_fbank_bins=40, mel_lo_freq=20, mel_hi_freq=4000,
                             num_mfcc_feats=10, frame_len=640, use_htk_method=True, n_fft=1024)
    mfcc = MFCC(mfcc_params)
    preprocessor = AudioPreprocessor(mfcc, model_input_size=49, stride=320)

    # Set threshold for displaying classification and commence stream or file processing
    threshold = .90
    if streaming_enabled:
        # Initialise audio stream
        record_stream = CaptureAudioStream(audio_capture_params)
        record_stream.set_stream_defaults()
        record_stream.set_recording_duration(args.duration)
        record_stream.countdown()

        with sd.InputStream(callback=record_stream.callback):
            print("Recording audio. Please speak.")
            while record_stream.is_active:
                audio_data = record_stream.capture_data()
                recognise_speech(audio_data, network, preprocessor, threshold)
                record_stream.is_first_window = False
            print("\nFinished recording.")

    # If file path has been supplied read-in and run inference
    else:
        print("Processing Audio Frames...")
        buffer = capture_audio(audio_file, audio_capture_params)
        for audio_data in buffer:
            recognise_speech(audio_data, network, preprocessor, threshold)


if __name__ == "__main__":
    args = parse_args()
    main(args)