xref: /aosp_15_r20/external/armnn/python/pyarmnn/examples/common/utils.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3
4"""Contains helper functions that can be used across the example apps."""
5
6import os
7import errno
8from pathlib import Path
9
10import numpy as np
11import datetime
12
13
def dict_labels(labels_file_path: str, include_rgb=False) -> dict:
    """Creates a dictionary of labels from the input labels file.

    Args:
        labels_file_path: Path to file containing labels to map model outputs.
        include_rgb: Adds randomly generated RGB values to the values of the
            dictionary. Used for plotting bounding boxes of different colours.

    Returns:
        Dictionary with classification indices for keys and labels for values.
        When `include_rgb` is True each value is a `(label, (r, g, b))` tuple.

    Raises:
        FileNotFoundError:
            Provided `labels_file_path` does not exist.
    """
    labels_file = Path(labels_file_path)
    if not labels_file.is_file():
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), labels_file_path
        )

    labels = {}
    with open(labels_file, "r") as f:
        for idx, line in enumerate(f):
            if include_rgb:
                # Pair each label with a random colour so callers can draw
                # distinguishable bounding boxes per class.
                labels[idx] = line.strip("\n"), tuple(np.random.random(size=3) * 255)
            else:
                labels[idx] = line.strip("\n")
    # Return after the file handle is closed (the original returned from
    # inside the `with` block).
    return labels
43
44
def prepare_input_data(audio_data, input_data_type, input_quant_scale, input_quant_offset, mfcc_preprocessor):
    """
    Takes a block of audio data, extracts the MFCC features, and quantizes the
    array when the model expects an integer input type.

    Args:
        audio_data: The audio data to process
        input_data_type: The model's input data type
        input_quant_scale: The model's quantization scale
        input_quant_offset: The model's quantization offset
        mfcc_preprocessor: The mfcc preprocessor instance
    Returns:
        input_data: The prepared input data
    """

    input_data = mfcc_preprocessor.extract_features(audio_data)
    # Float models take the features as-is; integer models need quantizing
    # to the model's (u)int8 domain first.
    if input_data_type != np.float32:
        input_data = quantize_input(input_data, input_data_type, input_quant_scale, input_quant_offset)
    return input_data
65
66
def quantize_input(data, input_data_type, input_quant_scale, input_quant_offset):
    """Quantize the float input to (u)int8 ready for inputting to model.

    Args:
        data: 2D numpy array of float features to quantize.
        input_data_type: Target numpy dtype; must be np.int8 or np.uint8.
        input_quant_scale: The model's quantization scale.
        input_quant_offset: The model's quantization offset.

    Returns:
        A new array with dtype `input_data_type` (the input is not modified).

    Raises:
        RuntimeError: If `data` does not have exactly 2 dimensions.
        ValueError: If `input_data_type` is not np.int8 or np.uint8.
    """
    if data.ndim != 2:
        raise RuntimeError("Audio data must have 2 dimensions for quantization")

    if (input_data_type != np.int8) and (input_data_type != np.uint8):
        raise ValueError("Could not quantize data to required data type")

    d_min = np.iinfo(input_data_type).min
    d_max = np.iinfo(input_data_type).max

    # Vectorized quantize + clamp replaces the original per-element Python
    # loop (much faster) and no longer mutates the caller's array in place.
    # astype truncates toward zero, matching the original int cast.
    quantized = (data / input_quant_scale) + input_quant_offset
    return np.clip(quantized, d_min, d_max).astype(input_data_type)
84
85
def dequantize_output(data, is_output_quantized, output_quant_scale, output_quant_offset):
    """Dequantize the (u)int8 output to float.

    Args:
        data: 2D numpy array of model output values.
        is_output_quantized: When False, `data` is returned unchanged.
        output_quant_scale: The model's output quantization scale.
        output_quant_offset: The model's output quantization offset.

    Returns:
        A new float array when quantized, otherwise the original `data`.

    Raises:
        RuntimeError: If dequantizing and `data` is not 2-dimensional.
    """
    # Guard clause: nothing to do for float outputs.
    if not is_output_quantized:
        return data

    if data.ndim != 2:
        raise RuntimeError("Data must have 2 dimensions for quantization")

    # Vectorized dequantization replaces the original per-element Python loop.
    # astype(float) copies, so the caller's array is untouched (as before).
    return (data.astype(float) - output_quant_offset) * output_quant_scale
98
99
class Profiling:
    """Simple wall-clock timer that prints elapsed time in microseconds.

    All methods are no-ops (returning 0 where applicable) when constructed
    with enabled=False, so call sites need no conditional guards.
    """

    def __init__(self, enabled: bool):
        self.m_start = 0          # start timestamp (datetime once started)
        self.m_end = 0            # end timestamp (datetime once stopped)
        self.m_enabled = enabled  # master switch for all profiling calls

    def profiling_start(self):
        """Record the start timestamp; no-op when profiling is disabled."""
        if self.m_enabled:
            self.m_start = datetime.datetime.now()

    def profiling_stop_and_print_us(self, msg):
        """Stop the timer, print and return the elapsed time in microseconds.

        Args:
            msg: Label to include in the printed profiling line.

        Returns:
            Elapsed microseconds since profiling_start, or 0 when disabled.
        """
        if self.m_enabled:
            self.m_end = datetime.datetime.now()
            period = self.m_end - self.m_start
            # Include period.days so durations of 24h or more are not
            # truncated (the original seconds+microseconds sum dropped the
            # timedelta's days field).
            period_us = ((period.days * 86_400 + period.seconds) * 1_000_000
                         + period.microseconds)
            print(f'Profiling: {msg} : {period_us:,} microSeconds')
            return period_us
        return 0
118