# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import os
from typing import List, Tuple

import numpy as np
from tflite_runtime import interpreter as tflite


class TFLiteNetworkExecutor:
    """Runs inference on a TensorFlow Lite model through an external delegate library."""

    def __init__(self, model_file: str, backends: list, delegate_path: str):
        """
        Creates an inference executor for a given network and a list of backends.

        Args:
            model_file: User-specified model file.
            backends: List of backends to optimize network (joined with ',' and
                passed to the delegate as its "backends" option).
            delegate_path: tflite delegate file path (.so).

        Raises:
            FileNotFoundError: If `model_file` does not exist.
            ValueError: If `model_file` does not have a `.tflite` extension.
        """
        self.model_file = model_file
        self.backends = backends
        self.delegate_path = delegate_path
        self.interpreter, self.input_details, self.output_details = self.create_network()

    def run(self, input_data_list: list) -> List[np.ndarray]:
        """
        Executes inference for the loaded network.

        Every input tensor is bound first and the interpreter is invoked exactly
        once, so multi-input models run on a complete set of inputs and each
        output tensor appears only once in the result.

        Args:
            input_data_list: List of input frames; element i is bound to
                `self.input_details[i]`.

        Returns:
            list: Inference results as a list of ndarrays, one per entry of
            `self.output_details`, in that order.
        """
        for index, input_data in enumerate(input_data_list):
            self.interpreter.set_tensor(self.input_details[index]['index'], input_data)
        self.interpreter.invoke()
        return [self.interpreter.get_tensor(detail['index']) for detail in self.output_details]

    def create_network(self):
        """
        Creates a network based on the model file and a list of backends.

        Returns:
            interpreter: A TensorFlow Lite object for executing inference.
            input_details: Contains essential information about the model input.
            output_details: Used to map output tensor and its memory.

        Raises:
            FileNotFoundError: If the model file does not exist.
            ValueError: If the model file extension is not `.tflite`.
        """
        # Controls whether optimizations are used or not.
        # Please note that optimizations can improve performance in some cases, but it can also
        # degrade the performance in other cases. Accuracy might also be affected.
        optimization_enable = "true"

        if not os.path.exists(self.model_file):
            raise FileNotFoundError(f'Model file not found for: {self.model_file}')

        _, ext = os.path.splitext(self.model_file)
        # Compare case-insensitively so e.g. "model.TFLite" is accepted as well.
        if ext.lower() != '.tflite':
            raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")

        # Delegate options are passed as strings, hence "true" rather than True above.
        armnn_delegate = tflite.load_delegate(library=self.delegate_path,
                                              options={"backends": ','.join(self.backends),
                                                       "logging-severity": "info",
                                                       "enable-fast-math": optimization_enable,
                                                       "reduce-fp32-to-fp16": optimization_enable})
        interpreter = tflite.Interpreter(model_path=self.model_file,
                                         experimental_delegates=[armnn_delegate])
        interpreter.allocate_tensors()

        # Get input and output binding information
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        return interpreter, input_details, output_details

    def get_data_type(self):
        """
        Get the input data type of the initiated network.

        Returns:
            The numpy data type (`dtype`) of the network's first input tensor.
        """
        return self.input_details[0]['dtype']

    def get_shape(self):
        """
        Get the input shape of the initiated network.

        Returns:
            tuple: The shape of the network's first input tensor.
        """
        return tuple(self.input_details[0]['shape'])