#
# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
#
import argparse
from pathlib import Path
from typing import Union

import tflite_runtime.interpreter as tflite
from PIL import Image
import numpy as np

# File extensions accepted for each file-type command-line option.
_IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg")
_MODEL_SUFFIXES = (".tflite",)
_LABEL_SUFFIXES = (".txt",)
# Backend names accepted by --preferred_backends.
_SUPPORTED_BACKENDS = ("GpuAcc", "CpuAcc", "CpuRef")


def _check_file(file_p: Path, suffixes: tuple, bad_suffix_msg: str):
    """Validate the extension of a path, then check that the path exists.

    args:
        - file_p: pathlib.Path
        - suffixes: tuple -> accepted lower-case extensions (with leading dot)
        - bad_suffix_msg: str -> message for the IOError raised on a bad
          extension

    returns:
        - None

    raises:
        - IOError: if the extension is not one of `suffixes`.
        - FileNotFoundError: if the path does not exist.
    """
    # Compare in lower case so that e.g. "photo.JPG" is accepted as well.
    if file_p.suffix.lower() not in suffixes:
        raise IOError(bad_suffix_msg)
    if not file_p.exists():
        # A single formatted argument: passing two positional arguments to an
        # OSError subclass would be interpreted as (errno, strerror).
        raise FileNotFoundError("Cannot find {}".format(file_p.name))


def check_args(args: argparse.Namespace):
    """Check the values used in the command-line have acceptable values

    The image, model and label files are checked in that order; for each one
    the extension is validated before its existence.

    args:
        - args: argparse.Namespace

    returns:
        - None

    raises:
        - FileNotFoundError: if passed files do not exist.
        - IOError: if files are of incorrect format.
        - ValueError: if an unsupported backend was requested.
    """
    _check_file(
        args.input_image,
        _IMAGE_SUFFIXES,
        "--input_image option should point to an image file of the "
        "format .jpg, .jpeg, .png",
    )
    _check_file(
        args.model_file,
        _MODEL_SUFFIXES,
        "--model_file should point to a tflite file.",
    )
    _check_file(
        args.label_file,
        _LABEL_SUFFIXES,
        "--label_file expects a .txt file.",
    )

    # check all args given in preferred backends make sense
    if any(
        backend not in _SUPPORTED_BACKENDS for backend in args.preferred_backends
    ):
        raise ValueError(
            "Incorrect backends given. Please choose from "
            "'GpuAcc', 'CpuAcc', 'CpuRef'."
        )

    return None


def load_image(image_path: Path, model_input_dims: Union[tuple, list], grayscale: bool):
    """load an image and put into correct format for the tensorflow lite model

    args:
        - image_path: pathlib.Path
        - model_input_dims: tuple (or array-like). (height,width)
        - grayscale: bool -> convert the resized image to PIL mode "LA" when
          True

    returns:
        - image: np.array with a leading axis of size 1 added
    """
    height, width = model_input_dims
    # Open inside a context manager so the underlying file handle is released;
    # resize() returns an independent, fully loaded copy. PIL takes the target
    # size as (width, height).
    with Image.open(image_path) as source:
        image = source.resize((width, height))
    # convert to greyscale if expected
    if grayscale:
        # NOTE(review): "LA" keeps an alpha channel alongside luminance (two
        # channels); confirm whether single-channel "L" is what the model
        # expects.
        image = image.convert("LA")

    # add the leading axis in front of the image data
    image = np.expand_dims(np.asarray(image), axis=0)

    return image


def load_delegate(delegate_path: Path, backends: list):
    """load the armnn delegate.

    args:
        - delegate_path: pathlib.Path -> location of your libarmnnDelegate.so
        - backends: list -> list of backends you want to use in string format

    returns:
        - armnn_delegate: tflite.delegate
    """
    # create a comma separated string
    backend_string = ",".join(backends)
    # load delegate; the library path is passed as a plain string, consistent
    # with how the model path is handed to the interpreter
    armnn_delegate = tflite.load_delegate(
        library=str(delegate_path),
        options={"backends": backend_string, "logging-severity": "info"},
    )

    return armnn_delegate


def load_tf_model(model_path: Path, armnn_delegate: tflite.Delegate):
    """load a tflite model for use with the armnn delegate.

    args:
        - model_path: pathlib.Path
        - armnn_delegate: tflite.TfLiteDelegate

    returns:
        - interpreter: tflite.Interpreter (tensors already allocated)
    """
    interpreter = tflite.Interpreter(
        model_path=model_path.as_posix(), experimental_delegates=[armnn_delegate]
    )
    interpreter.allocate_tensors()

    return interpreter


def run_inference(interpreter, input_image):
    """Run inference on a processed input image and return the output from
    inference.

    Only the first input tensor and the first output tensor of the model are
    used.

    args:
        - interpreter: tflite_runtime.interpreter.Interpreter
        - input_image: np.array

    returns:
        - output_data: np.array
    """
    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Feed the prepared image to the model and execute it.
    interpreter.set_tensor(input_details[0]["index"], input_image)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]["index"])

    return output_data


def create_mapping(label_mapping_p):
    """Creates a Python dictionary mapping an index to a label.

    label_mapping[idx] = label

    The index is the zero-based line number of the label in the file. Line
    endings are stripped from each label.

    args:
        - label_mapping_p: pathlib.Path

    returns:
        - label_mapping: dict
    """
    with open(label_mapping_p) as label_mapping_raw:
        # Strip the line terminator so labels print cleanly.
        label_mapping = {
            idx: line.rstrip("\r\n") for idx, line in enumerate(label_mapping_raw)
        }

    return label_mapping


def process_output(output_data, label_mapping):
    """Process the output tensor into a label from the labelmapping file. Takes
    the index of the maximum value from the output array.

    args:
        - output_data: np.array
        - label_mapping: dict

    returns:
        - str: labelmapping for max index.
    """
    # Only the first entry along the leading axis is considered; cast to a
    # plain int so the dictionary lookup does not depend on numpy scalar types.
    idx = int(np.argmax(output_data[0]))

    return label_mapping[idx]


def main(args):
    """Run the inference for options passed in the command line.

    args:
        - args: argparse.Namespace

    returns:
        - None
    """
    # sanity check on args
    check_args(args)
    # load in the armnn delegate
    armnn_delegate = load_delegate(args.delegate_path, args.preferred_backends)
    # load tflite model
    interpreter = load_tf_model(args.model_file, armnn_delegate)
    # get input shape for image resizing: entries 1 and 2 of the first input
    # tensor's shape are used as (height, width)
    input_shape = interpreter.get_input_details()[0]["shape"]
    height, width = input_shape[1], input_shape[2]
    input_shape = (height, width)
    # load input image
    input_image = load_image(args.input_image, input_shape, False)
    # get label mapping
    labelmapping = create_mapping(args.label_file)
    output_tensor = run_inference(interpreter, input_image)
    output_prediction = process_output(output_tensor, labelmapping)

    print("Prediction: ", output_prediction)

    return None


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--input_image", help="File path of image file", type=Path, required=True
    )
    parser.add_argument(
        "--model_file",
        help="File path of the model tflite file",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "--label_file",
        help="File path of model labelmapping file",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "--delegate_path",
        help="File path of ArmNN delegate file",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "--preferred_backends",
        help="list of backends in order of preference",
        type=str,
        nargs="+",
        required=False,
        default=["CpuAcc", "CpuRef"],
    )
    args = parser.parse_args()

    main(args)