xref: /aosp_15_r20/external/armnn/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3
4"""Keyword Spotting with PyArmNN demo for processing live microphone data or pre-recorded files."""
5
6import sys
7import os
8from argparse import ArgumentParser
9
10import numpy as np
11import sounddevice as sd
12
# Put ../common (resolved relative to this script's directory) on the module
# search path; presumably that is where the helper modules imported below
# (network_executor, utils, mfcc, audio_utils, audio_capture) live.
# Inserting at index 1 leaves sys.path[0] (normally the script's own
# directory) ahead of it.
script_dir = os.path.dirname(__file__)
sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
15
16from network_executor import ArmnnNetworkExecutor
17from utils import prepare_input_data, dequantize_output
18from mfcc import AudioPreprocessor, MFCC, MFCCParams
19from audio_utils import decode, display_text
20from audio_capture import AudioCaptureParams, CaptureAudioStream, capture_audio
21
# Model-specific labels: maps each class index (0-11) to its keyword.
# The list order defines the index, so entries must not be reordered.
labels = dict(enumerate([
    'silence',
    'unknown',
    'yes',
    'no',
    'up',
    'down',
    'left',
    'right',
    'on',
    'off',
    'stop',
    'go',
]))
35
36
def parse_args(argv=None):
    """Parse the command-line options for the KWS demo.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) keeps the original behaviour
            of reading ``sys.argv[1:]``; passing a list lets callers and tests
            drive the parser programmatically.

    Returns:
        argparse.Namespace with attributes ``audio_file_path`` (str or None),
        ``duration`` (int), ``model_file_path`` (str) and
        ``preferred_backends`` (list of str).
    """
    parser = ArgumentParser(description="KWS with PyArmNN")
    parser.add_argument(
        "--audio_file_path",
        required=False,
        type=str,
        help="Path to the audio file to perform KWS",
    )
    parser.add_argument(
        "--duration",
        type=int,
        default=0,
        help="""Duration for recording audio in seconds. Values <= 0 result in infinite
           recording. Defaults to infinite.""",
    )
    parser.add_argument(
        "--model_file_path",
        required=True,
        type=str,
        help="Path to KWS model to use",
    )
    parser.add_argument(
        "--preferred_backends",
        type=str,
        nargs="+",
        default=["CpuAcc", "CpuRef"],
        help="""List of backends in order of preference for optimizing
        subgraphs, falling back to the next backend in the list on unsupported
        layers. Defaults to [CpuAcc, CpuRef]""",
    )
    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
68
69
def recognise_speech(audio_data, network, preprocessor, threshold):
    """Classify one window of audio and display the result if it is confident enough.

    Args:
        audio_data: Audio samples for one inference window, as yielded by the
            capture helpers (format not validated here).
        network: ArmNN executor that runs the model; its input/output
            quantization parameters are queried for (de)quantization.
        preprocessor: Feature extractor forwarded to ``prepare_input_data``.
        threshold: The result is displayed only when ``decode(...)[1]`` is
            strictly greater than this value.

    Returns:
        The value returned by ``decode`` for this window, whether or not it was
        displayed. Existing callers may ignore it.
    """
    # Build the model input tensor, quantizing with the parameters of input 0.
    # Only a single model input is fed (see network.run below).
    input_data = prepare_input_data(audio_data, network.get_data_type(),
                                    network.get_input_quantization_scale(0),
                                    network.get_input_quantization_offset(0),
                                    preprocessor)
    output_result = network.run([input_data])

    # Dequantize each output tensor with its own quantization parameters.
    dequantized_result = [
        dequantize_output(ofm, network.is_output_quantized(index),
                          network.get_output_quantization_scale(index),
                          network.get_output_quantization_offset(index))
        for index, ofm in enumerate(output_result)
    ]

    # NOTE(review): presumably decode returns (label, score); index 1 is
    # compared with the threshold. Confirm against audio_utils.decode.
    decoded_result = decode(dequantized_result, labels)
    if decoded_result[1] > threshold:
        display_text(decoded_result)
    return decoded_result
88
89
def main(args):
    """Run keyword spotting on an audio file or on live microphone input.

    Args:
        args: Namespace from ``parse_args()``.
            - ``audio_file_path`` selects file mode when set; otherwise audio
              is streamed from the microphone.
            - ``duration`` is the recording length in seconds for
              microphone mode.
            - ``model_file_path`` and ``preferred_backends`` configure the
              ArmNN executor.
    """
    # Model-specific sample rate. Capture and MFCC extraction must agree on
    # it, so it is defined once here.
    sampling_freq = 16000

    # Create the ArmNN inference runner.
    network = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)

    # Model-specific audio data requirements. `overlap` is the number of
    # samples to rewind between consecutive data windows.
    audio_capture_params = AudioCaptureParams(dtype=np.float32, overlap=2000, min_samples=16000,
                                              sampling_freq=sampling_freq, mono=True)

    # Create the MFCC feature preprocessor.
    mfcc_params = MFCCParams(sampling_freq=sampling_freq, num_fbank_bins=40, mel_lo_freq=20,
                             mel_hi_freq=4000, num_mfcc_feats=10, frame_len=640,
                             use_htk_method=True, n_fft=1024)
    preprocessor = AudioPreprocessor(MFCC(mfcc_params), model_input_size=49, stride=320)

    # Only classifications scoring above this threshold are displayed.
    threshold = .90

    if args.audio_file_path:
        _classify_file(args.audio_file_path, audio_capture_params, network, preprocessor, threshold)
    else:
        _classify_live(args.duration, audio_capture_params, network, preprocessor, threshold)


def _classify_live(duration, audio_capture_params, network, preprocessor, threshold):
    """Stream microphone audio and classify each captured window until recording stops."""
    record_stream = CaptureAudioStream(audio_capture_params)
    record_stream.set_stream_defaults()
    record_stream.set_recording_duration(duration)
    record_stream.countdown()

    try:
        with sd.InputStream(callback=record_stream.callback):
            print("Recording audio. Please speak.")
            while record_stream.is_active:
                audio_data = record_stream.capture_data()
                recognise_speech(audio_data, network, preprocessor, threshold)
                record_stream.is_first_window = False
    except KeyboardInterrupt:
        # With --duration <= 0 recording never ends on its own, so Ctrl+C is
        # the expected way to stop. The `with` has already closed the stream;
        # finish cleanly instead of dumping a traceback.
        pass
    print("\nFinished recording.")


def _classify_file(audio_file, audio_capture_params, network, preprocessor, threshold):
    """Classify every audio window that ``capture_audio`` yields for ``audio_file``."""
    print("Processing Audio Frames...")
    for audio_data in capture_audio(audio_file, audio_capture_params):
        recognise_speech(audio_data, network, preprocessor, threshold)
135
136
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the demo.
    args = parse_args()
    main(args)
140