# Copyright © 2020 NXP and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

from urllib.parse import urlparse
from PIL import Image
from zipfile import ZipFile
import os
import pyarmnn as ann
import numpy as np
import requests
import argparse
import warnings

DEFAULT_IMAGE_URL = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'


def run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info):
    """Runs inference on a set of images and prints the top-5 results for each one.

    Args:
        runtime: Arm NN runtime
        net_id: Network ID
        images: Loaded images to run inference on
        labels: Loaded labels per class
        input_binding_info: Network input information
        output_binding_info: Network output information

    Returns:
        None
    """
    # Output tensors are allocated once and overwritten by every inference.
    output_tensors = ann.make_output_tensors([output_binding_info])
    for idx, im in enumerate(images):
        # Create input tensors
        input_tensors = ann.make_input_tensors([input_binding_info], [im])

        # Run inference
        print("Running inference({0}) ...".format(idx))
        runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)

        # Process output: first output tensor, first (only) batch entry.
        # Assumes a single-image batch, e.g. shape (1, 1001) for MobileNet - confirm per model.
        out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0]
        # Class indices ordered from highest to lowest score.
        results = np.argsort(out_tensor)[::-1]
        print_top_n(5, results, labels, out_tensor)


def unzip_file(filename: str, dest: str = None):
    """Unzips a file.

    Args:
        filename (str): Path of the zip archive.
        dest (str): Directory to extract into. When None or empty the archive is
            extracted into the current working directory.

    Returns:
        None
    """
    with ZipFile(filename, 'r') as zip_obj:
        # extractall() treats path=None as the current working directory.
        zip_obj.extractall(dest or None)


def parse_command_line(desc: str = ""):
    """Adds arguments to the script.

    Args:
        desc (str): Script description

    Returns:
        Namespace: Arguments to the script command
    """
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument("-v", "--verbose", help="Increase output verbosity",
                        action="store_true")
    parser.add_argument("-d", "--data-dir", help="Data directory which contains all the images.",
                        action="store", default="")
    parser.add_argument("-m", "--model-dir",
                        help="Model directory which contains the model file (TFLite, ONNX).", action="store",
                        default="")
    return parser.parse_args()


def __create_network(model_file: str, backends: list, parser=None, verbose: bool = None):
    """Creates a network based on a file and parser type.

    Args:
        model_file (str): Path of the model file
        backends (list): List of backends to use when running inference.
        parser: Parser instance (pyarmnn.ITfLiteParser/pyarmnn.IOnnxParser...). When None, the
            parser is chosen from the model file extension ('.tflite' or '.onnx').
        verbose (bool): Emit optimizer and loader messages as warnings. When None, the
            script's --verbose command-line flag is used instead.

    Raises:
        ValueError: If no parser is given and the model file extension is not recognised.

    Returns:
        int: Network ID
        IParser: Parser instance used to create the network
        IRuntime: Runtime object instance
    """
    if verbose is None:
        # Backwards-compatible fallback: read the flag from sys.argv. argparse exits on
        # unrecognised arguments, so callers embedding this module should pass verbose.
        verbose = parse_command_line().verbose

    if parser is None:
        # Try to determine what parser to create based on the model extension.
        _, ext = os.path.splitext(model_file)
        ext = ext.lower()
        if ext == ".onnx":
            parser = ann.IOnnxParser()
        elif ext == ".tflite":
            parser = ann.ITfLiteParser()
        else:
            raise ValueError("Unable to determine parser for model file '{0}'.".format(model_file))

    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)

    network = parser.CreateNetworkFromBinaryFile(model_file)

    preferred_backends = [ann.BackendId(b) for b in backends]

    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                         ann.OptimizerOptions())
    if verbose:
        for m in messages:
            warnings.warn(m)

    net_id, w = runtime.LoadNetwork(opt_network)
    if verbose and w:
        warnings.warn(w)

    return net_id, parser, runtime


def create_tflite_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef'), verbose: bool = None):
    """Creates a network from a tflite model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference.
        verbose (bool): Emit optimizer and loader messages as warnings. When None, the
            script's --verbose command-line flag is used instead.

    Returns:
        int: Network ID.
        int: Graph ID (index of the last subgraph reported by the parser).
        ITFliteParser: TF Lite parser instance.
        IRuntime: Runtime object instance.
    """
    net_id, parser, runtime = __create_network(model_file, backends, ann.ITfLiteParser(),
                                               verbose=verbose)
    graph_id = parser.GetSubgraphCount() - 1

    return net_id, graph_id, parser, runtime


def create_onnx_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef'), verbose: bool = None):
    """Creates a network from an onnx model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference.
        verbose (bool): Emit optimizer and loader messages as warnings. When None, the
            script's --verbose command-line flag is used instead.

    Returns:
        int: Network ID.
        IOnnxParser: ONNX parser instance.
        IRuntime: Runtime object instance.
    """
    return __create_network(model_file, backends, ann.IOnnxParser(), verbose=verbose)


def preprocess_default(img: Image, width: int, height: int, data_type, scale: float, mean: list,
                       stddev: list):
    """Default preprocessing image function.

    Args:
        img (PIL.Image): PIL.Image object instance.
        width (int): Width to resize to.
        height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset.
        stddev (list): RGB standard deviation.

    Returns:
        np.array: Flattened image in interleaved RGB order, computed as
            ((pixel / scale) - mean) / stddev and cast to data_type.
    """
    # Convert before resizing: Pillow silently uses nearest-neighbour resampling for
    # palette ('P') and bilevel ('1') images, which would degrade those inputs.
    img = img.convert('RGB')
    img = img.resize((width, height), Image.BILINEAR)
    img = np.array(img)
    img = np.reshape(img, (-1, 3))  # reshape to [RGB][RGB]...
    img = ((img / scale) - mean) / stddev
    img = img.flatten().astype(data_type)
    return img


def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
                scale: float = 1., mean: list = (0., 0., 0.), stddev: list = (1., 1., 1.),
                preprocess_fn=preprocess_default):
    """Loads images, resizes and performs any additional preprocessing to run inference.

    Args:
        image_files (list): List of image file paths.
        input_width (int): Width to resize to.
        input_height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset.
        stddev (list): RGB standard deviation.
        preprocess_fn: Preprocessing function, called with the same arguments as
            preprocess_default.

    Returns:
        list: Resized and preprocessed images, one per input file.
    """
    images = []
    for path in image_files:
        # Close each file handle as soon as its image has been preprocessed.
        with Image.open(path) as img:
            images.append(preprocess_fn(img, input_width, input_height, data_type, scale, mean, stddev))
    return images


def load_labels(label_file: str):
    """Loads a labels file containing a label per line.

    Args:
        label_file (str): Labels file path.

    Returns:
        list: List of labels read from a file.
    """
    with open(label_file, 'r') as f:
        labels = [l.rstrip() for l in f]
    return labels


def print_top_n(N: int, results: list, labels: list, prob: list):
    """Prints TOP-N results

    Args:
        N (int): Result count to print.
        results (list): Top prediction indices.
        labels (list): A list of labels for every class.
        prob (list): A list of probabilities for every class.

    Returns:
        None
    """
    assert (len(results) >= 1 and len(results) == len(labels) == len(prob))
    for i in range(min(len(results), N)):
        print("class={0} ; value={1}".format(labels[results[i]], prob[results[i]]))


def download_file(url: str, force: bool = False, filename: str = None):
    """Downloads a file.

    Args:
        url (str): File url.
        force (bool): Forces to download the file even if it exists.
        filename (str): Renames the file when set. By default the last component of the
            URL path is used, relative to the current working directory.

    Raises:
        RuntimeError: If for some reason download fails, including HTTP error statuses.

    Returns:
        str: Path to the downloaded file.
    """
    try:
        if filename is None:  # extract filename from url when None
            filename = os.path.basename(urlparse(url).path)

        print("Downloading '{0}' from '{1}' ...".format(filename, url))
        if not os.path.exists(filename) or force:
            r = requests.get(url)
            # Without this check an HTTP error page (e.g. 404) would be saved as the file.
            r.raise_for_status()
            with open(filename, 'wb') as f:
                f.write(r.content)
            print("Finished.")
        else:
            print("File already exists.")
    except Exception as err:
        # Only ordinary errors are converted; KeyboardInterrupt/SystemExit still propagate.
        raise RuntimeError("Unable to download file.") from err

    return filename


def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = None, download_url: str = None):
    """Gets model and labels.

    Looks for the model and labels in model_dir. If either is missing, the archive found in
    model_dir is extracted into model_dir. Otherwise the file(s) at download_url are downloaded
    into model_dir, and any '.zip' download is extracted there.

    Args:
        model_dir(str): Folder in which model and label files can be found ('' means the
            current working directory)
        model (str): Name of the model file
        labels (str): Name of the labels file
        archive (str): Name of the archive file (optional - only needed when the model and
            labels are shipped as an archive)
        download_url(str or list): Archive url or urls if multiple files (optional - only needed
            to download them)

    Raises:
        RuntimeError: If the model or labels file is still missing afterwards.

    Returns:
        tuple (str, str): Model and labels file paths
    """
    labels = os.path.join(model_dir, labels)
    model = os.path.join(model_dir, model)

    if os.path.exists(labels) and os.path.exists(model):
        print("Found model ({0}) and labels ({1}).".format(model, labels))
    elif archive is not None and os.path.exists(os.path.join(model_dir, archive)):
        print("Found archive ({0}). Unzipping ...".format(archive))
        # Extract the archive that was actually found, into the folder searched above.
        unzip_file(os.path.join(model_dir, archive), model_dir)
    elif download_url is not None:
        print("Model, labels or archive not found. Downloading ...")
        try:
            if isinstance(download_url, str):
                download_url = [download_url]
            if model_dir:
                os.makedirs(model_dir, exist_ok=True)
            for dl in download_url:
                # Save downloads into model_dir, where the model and labels are looked up.
                name = os.path.basename(urlparse(dl).path)
                archive = download_file(dl, filename=os.path.join(model_dir, name))
                if dl.lower().endswith(".zip"):
                    unzip_file(archive, model_dir)
        except (RuntimeError, OSError):
            print("Unable to download file ({}).".format(download_url))

    if not os.path.exists(labels) or not os.path.exists(model):
        raise RuntimeError("Unable to provide model and labels.")

    return model, labels


def list_images(folder: str = None, formats: list = ('.jpg', '.jpeg')):
    """Lists files of a certain format in a folder.

    Args:
        folder (str): Path to the folder to search (current working directory when empty)
        formats (list): List of supported file extensions, matched case-insensitively

    Returns:
        list: A list of found files
    """
    files = []
    if folder and not os.path.exists(folder):
        print("Folder '{}' does not exist.".format(folder))
        return files

    # str.endswith accepts a tuple of suffixes; lower-case them to match the lowered name.
    extensions = tuple(frmt.lower() for frmt in formats)
    for file in os.listdir(folder if folder else os.getcwd()):
        if file.lower().endswith(extensions):
            files.append(os.path.join(folder, file) if folder else file)

    return files


def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
    """Gets images from a directory, downloading a default image if none are found.

    Args:
        image_dir (str): Directory to search for images (current working directory when empty)
        image_url (str): Image url to download when no images are found (None to disable)

    Raises:
        RuntimeError: If no images could be found or downloaded.

    Returns:
        list: Image file paths
    """
    images = list_images(image_dir)
    if not images and image_url is not None:
        print("No images found. Downloading ...")
        try:
            images = [download_file(image_url)]
        except RuntimeError:
            print("Unable to download file ({0}).".format(image_url))

    if not images:
        raise RuntimeError("Unable to provide images.")

    return images