# Copyright © 2020 NXP and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

from urllib.parse import urlparse
from PIL import Image
from zipfile import ZipFile
import os
import pyarmnn as ann
import numpy as np
import requests
import argparse
import warnings

DEFAULT_IMAGE_URL = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'

def run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info):
    """Runs inference on a set of images.

    Args:
        runtime: Arm NN runtime
        net_id: Network ID
        images: Loaded images to run inference on
        labels: Loaded labels per class
        input_binding_info: Network input information
        output_binding_info: Network output information

    Returns:
        None
    """
    output_tensors = ann.make_output_tensors([output_binding_info])
    for idx, im in enumerate(images):
        # Create input tensors
        input_tensors = ann.make_input_tensors([input_binding_info], [im])

        # Run inference
        print("Running inference({0}) ...".format(idx))
        runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)

        # Process output
        # output tensor has a shape (1, 1001)
        out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0]
        results = np.argsort(out_tensor)[::-1]
        print_top_n(5, results, labels, out_tensor)

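# Note (illustrative sketch, not part of the original module): run_inference() expects
# `images` to be already preprocessed to match `input_binding_info`, e.g. for a
# hypothetical TfLite classifier (file and tensor names are placeholders):
#
#     net_id, graph_id, parser, runtime = create_tflite_network("model.tflite")
#     input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, "input")
#     output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, "output")
#     run_inference(runtime, net_id, load_images(["kitten.jpg"], 224, 224),
#                   load_labels("labels.txt"), input_binding_info, output_binding_info)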

def unzip_file(filename: str):
    """Unzips a file.

    Args:
        filename (str): Name of the file

    Returns:
        None
    """
    with ZipFile(filename, 'r') as zip_obj:
        zip_obj.extractall()

def parse_command_line(desc: str = ""):
    """Parses the command line arguments of the script.

    Args:
        desc (str): Script description

    Returns:
        Namespace: Parsed command line arguments
    """
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument("-v", "--verbose", help="Increase output verbosity",
                        action="store_true")
    parser.add_argument("-d", "--data-dir", help="Data directory which contains all the images.",
                        action="store", default="")
    parser.add_argument("-m", "--model-dir",
                        help="Model directory which contains the model file (TFLite, ONNX).", action="store",
                        default="")
    return parser.parse_args()

def __create_network(model_file: str, backends: list, parser=None):
    """Creates a network based on a file and parser type.

    Args:
        model_file (str): Path of the model file
        backends (list): List of backends to use when running inference.
        parser: Parser instance (e.g. pyarmnn.ITfLiteParser/pyarmnn.IOnnxParser).

    Returns:
        int: Network ID
        IParser: Parser instance used to create the network
        IRuntime: Runtime object instance
    """
    args = parse_command_line()
    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)

    if parser is None:
        # try to determine what parser to create based on the model file extension
        _, ext = os.path.splitext(model_file)
        if ext == ".onnx":
            parser = ann.IOnnxParser()
        elif ext == ".tflite":
            parser = ann.ITfLiteParser()
    assert parser is not None, "Unsupported model file format: {}".format(model_file)

    network = parser.CreateNetworkFromBinaryFile(model_file)

    preferred_backends = []
    for b in backends:
        preferred_backends.append(ann.BackendId(b))

    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                         ann.OptimizerOptions())
    if args.verbose:
        for m in messages:
            warnings.warn(m)

    net_id, w = runtime.LoadNetwork(opt_network)
    if args.verbose and w:
        warnings.warn(w)

    return net_id, parser, runtime

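# Illustrative use of the private helper above (a sketch; the public wrappers below are
# what example scripts normally call). The parser is inferred from the file extension
# when none is supplied; the file name is a placeholder:
#
#     net_id, parser, runtime = __create_network("model.tflite", ['CpuAcc', 'CpuRef'])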

def create_tflite_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')):
    """Creates a network from a tflite model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference.

    Returns:
        int: Network ID.
        int: Graph ID.
        ITfLiteParser: TF Lite parser instance.
        IRuntime: Runtime object instance.
    """
    net_id, parser, runtime = __create_network(model_file, backends, ann.ITfLiteParser())
    graph_id = parser.GetSubgraphCount() - 1

    return net_id, graph_id, parser, runtime

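# Example (illustrative; the file name is a placeholder): loads, optimizes and registers
# a TfLite model with the Arm NN runtime in a single call:
#
#     net_id, graph_id, parser, runtime = create_tflite_network("mobilenet_v1.tflite",
#                                                               ['CpuAcc', 'CpuRef'])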

def create_onnx_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')):
    """Creates a network from an onnx model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference.

    Returns:
        int: Network ID.
        IOnnxParser: ONNX parser instance.
        IRuntime: Runtime object instance.
    """
    return __create_network(model_file, backends, ann.IOnnxParser())

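# Example (illustrative; the file name is a placeholder). Unlike the TfLite variant,
# only the network id, parser and runtime are returned (no graph id):
#
#     net_id, parser, runtime = create_onnx_network("mobilenet_v2.onnx")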

def preprocess_default(img: Image, width: int, height: int, data_type, scale: float, mean: list,
                       stddev: list):
    """Default image preprocessing function.

    Args:
        img (PIL.Image): PIL.Image object instance.
        width (int): Width to resize to.
        height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset.
        stddev (list): RGB standard deviation.

    Returns:
        np.array: Resized and preprocessed image.
    """
    img = img.resize((width, height), Image.BILINEAR)
    img = img.convert('RGB')
    img = np.array(img)
    img = np.reshape(img, (-1, 3))  # reshape to [RGB][RGB]...
    img = ((img / scale) - mean) / stddev
    img = img.flatten().astype(data_type)
    return img

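# Example (illustrative; the file name is a placeholder): preparing one image for a
# quantized 224x224 uint8 model that needs no mean/stddev normalisation:
#
#     img = Image.open("kitten.jpg")
#     tensor_data = preprocess_default(img, 224, 224, np.uint8, 1., [0., 0., 0.], [1., 1., 1.])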

def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
                scale: float = 1., mean: list = (0., 0., 0.), stddev: list = (1., 1., 1.),
                preprocess_fn=preprocess_default):
    """Loads images, resizes and performs any additional preprocessing to run inference.

    Args:
        image_files (list): List of image file paths.
        input_width (int): Width to resize to.
        input_height (int): Height to resize to.
        data_type: Data Type to cast the images to.
        scale (float): Scaling value.
        mean (list): RGB mean offset.
        stddev (list): RGB standard deviation.
        preprocess_fn: Preprocessing function.

    Returns:
        list: Resized and preprocessed images.
    """
    images = []
    for i in image_files:
        img = Image.open(i)
        img = preprocess_fn(img, input_width, input_height, data_type, scale, mean, stddev)
        images.append(img)
    return images

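# Example (illustrative; the file name and values are placeholders): a float32 model
# that expects inputs scaled to [-1, 1] can be fed by dividing by 127.5 and subtracting 1:
#
#     images = load_images(["kitten.jpg"], 224, 224, np.float32,
#                          scale=127.5, mean=[1., 1., 1.], stddev=[1., 1., 1.])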

def load_labels(label_file: str):
    """Loads a labels file containing a label per line.

    Args:
        label_file (str): Labels file path.

    Returns:
        list: List of labels read from a file.
    """
    with open(label_file, 'r') as f:
        labels = [l.rstrip() for l in f]
        return labels

def print_top_n(N: int, results: list, labels: list, prob: list):
    """Prints the top-N results.

    Args:
        N (int): Result count to print.
        results (list): Top prediction indices.
        labels (list): A list of labels for every class.
        prob (list): A list of probabilities for every class.

    Returns:
        None
    """
    assert len(results) >= 1 and len(results) == len(labels) == len(prob)
    for i in range(min(len(results), N)):
        print("class={0} ; value={1}".format(labels[results[i]], prob[results[i]]))

def download_file(url: str, force: bool = False, filename: str = None):
    """Downloads a file.

    Args:
        url (str): File url.
        force (bool): Forces the download even if the file already exists.
        filename (str): Saves the file under this name when set.

    Raises:
        RuntimeError: If for some reason the download fails.

    Returns:
        str: Path to the downloaded file.
    """
    try:
        if filename is None:  # extract filename from url when None
            filename = urlparse(url)
            filename = os.path.basename(filename.path)

        print("Downloading '{0}' from '{1}' ...".format(filename, url))
        if not os.path.exists(filename) or force:
            r = requests.get(url)
            with open(filename, 'wb') as f:
                f.write(r.content)
            print("Finished.")
        else:
            print("File already exists.")
    except Exception as e:
        raise RuntimeError("Unable to download file.") from e

    return filename

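# Example (illustrative): fetches the default test image into the working directory,
# skipping the download when the file is already present unless force=True:
#
#     image_path = download_file(DEFAULT_IMAGE_URL)
#     image_path = download_file(DEFAULT_IMAGE_URL, force=True)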

def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = None, download_url: str = None):
    """Gets model and labels.

    Args:
        model_dir (str): Folder in which model and label files can be found
        model (str): Name of the model file
        labels (str): Name of the labels file
        archive (str): Name of the archive file (optional - needed only when the model and labels are packed in an archive)
        download_url (str or list): Archive url or urls if multiple files (optional - provide only when the files should be downloaded)

    Returns:
        tuple (str, str): Output model and label file paths
    """
    labels = os.path.join(model_dir, labels)
    model = os.path.join(model_dir, model)

    if os.path.exists(labels) and os.path.exists(model):
        print("Found model ({0}) and labels ({1}).".format(model, labels))
    elif archive is not None and os.path.exists(os.path.join(model_dir, archive)):
        print("Found archive ({0}). Unzipping ...".format(archive))
        unzip_file(os.path.join(model_dir, archive))
    elif download_url is not None:
        print("Model, labels or archive not found. Downloading ...")
        try:
            if isinstance(download_url, str):
                download_url = [download_url]
            for dl in download_url:
                archive = download_file(dl)
                if dl.lower().endswith(".zip"):
                    unzip_file(archive)
        except RuntimeError:
            print("Unable to download file ({}).".format(download_url))

    if not os.path.exists(labels) or not os.path.exists(model):
        raise RuntimeError("Unable to provide model and labels.")

    return model, labels

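# Example (illustrative; all file names and the URL are hypothetical placeholders):
# looks for the model and labels in `args.model_dir` and falls back to downloading and
# unzipping an archive that contains both:
#
#     model_path, labels_path = get_model_and_labels(
#         args.model_dir, "model.tflite", "labels.txt",
#         archive="model_and_labels.zip",
#         download_url="https://example.com/model_and_labels.zip")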

def list_images(folder: str = None, formats: list = ('.jpg', '.jpeg')):
    """Lists files of a certain format in a folder.

    Args:
        folder (str): Path to the folder to search
        formats (list): List of supported file extensions

    Returns:
        list: A list of found files
    """
    files = []
    if folder and not os.path.exists(folder):
        print("Folder '{}' does not exist.".format(folder))
        return files

    for file in os.listdir(folder if folder else os.getcwd()):
        for frmt in formats:
            if file.lower().endswith(frmt):
                files.append(os.path.join(folder, file) if folder else file)
                break  # only break the inner format loop

    return files

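# Example (illustrative; the folder name is a placeholder): collects jpeg and png files
# from an "images" folder, or from the current working directory when no folder is given:
#
#     image_files = list_images("images", ('.jpg', '.jpeg', '.png'))
#     image_files = list_images()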

def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
    """Gets images to run inference on, downloading a default image if none are found.

    Args:
        image_dir (str): Directory to search for images
        image_url (str): Url of a fallback image to download

    Returns:
        list: List of image file paths
    """
    images = list_images(image_dir)
    if not images and image_url is not None:
        print("No images found. Downloading ...")
        try:
            images = [download_file(image_url)]
        except RuntimeError:
            print("Unable to download file ({0}).".format(image_url))

    if not images:
        raise RuntimeError("Unable to provide images.")

    return images
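
# End-to-end usage sketch (illustrative; tensor names, file names and sizes are
# placeholders, not values prescribed by this module):
#
#     args = parse_command_line("Image classification example")
#     model_path, labels_path = get_model_and_labels(args.model_dir, "model.tflite", "labels.txt")
#     net_id, graph_id, parser, runtime = create_tflite_network(model_path)
#     input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, "input")
#     output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, "output")
#     image_files = get_images(args.data_dir)
#     images = load_images(image_files, 224, 224)
#     labels = load_labels(labels_path)
#     run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)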
359