# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import tflite_runtime.interpreter as tflite
import numpy as np
import os


def _random_input(shape, dtype):
    """Return an array of random values with the given shape and dtype.

    Integer dtypes are sampled from ``[iinfo.min, iinfo.max)`` and bools from
    ``{False, True}``. Other (floating) dtypes are sampled uniformly from
    ``[0, 1)``. Sampling directly in the target dtype matters: casting
    ``random_sample``'s ``[0, 1)`` floats to an integer dtype would truncate
    every element to 0.
    """
    shape = tuple(int(dim) for dim in shape)
    dtype = np.dtype(dtype)
    if dtype == np.bool_:
        return np.random.randint(0, 2, size=shape).astype(dtype)
    if np.issubdtype(dtype, np.integer):
        info = np.iinfo(dtype)
        # The upper bound of randint is exclusive; using info.max (not max + 1)
        # avoids overflowing the bound for 64-bit unsigned types.
        return np.random.randint(int(info.min), int(info.max), size=shape, dtype=dtype)
    return np.random.random_sample(shape).astype(dtype)


def run_mock_model(delegate, test_data_folder):
    """Load ``mock_model.tflite`` with ``delegate`` and invoke it once on random input.

    Args:
        delegate: A TFLite delegate object passed to the interpreter via
            ``experimental_delegates``.
        test_data_folder: Directory containing ``mock_model.tflite``.

    Only the first input tensor is populated. The random values match that
    tensor's own dtype as reported by ``get_input_details()``.
    """
    model_path = os.path.join(test_data_folder, 'mock_model.tflite')
    interpreter = tflite.Interpreter(model_path=model_path,
                                     experimental_delegates=[delegate])
    interpreter.allocate_tensors()

    # Test model on random input data.
    input_detail = interpreter.get_input_details()[0]
    input_data = _random_input(input_detail['shape'], input_detail['dtype'])
    interpreter.set_tensor(input_detail['index'], input_data)

    interpreter.invoke()


def run_inference(test_data_folder, model_filename, inputs, delegates=None):
    """Run a TFLite model once and return all of its output tensors.

    Args:
        test_data_folder: Directory containing the model file.
        model_filename: Name of the ``.tflite`` file inside ``test_data_folder``.
        inputs: Input arrays, matched to the model's input tensors by position.
        delegates: Optional list of delegates passed to the interpreter via
            ``experimental_delegates``.

    Returns:
        A list with one array per output tensor, in the order reported by
        ``get_output_details()``.

    Raises:
        IndexError: If more inputs are given than the model has input tensors.
    """
    model_path = os.path.join(test_data_folder, model_filename)
    interpreter = tflite.Interpreter(model_path=model_path,
                                     experimental_delegates=delegates)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Index input_details explicitly (rather than zip) so that supplying too
    # many inputs fails loudly instead of being silently truncated.
    for i, input_data in enumerate(inputs):
        interpreter.set_tensor(input_details[i]['index'], input_data)

    interpreter.invoke()

    return [interpreter.get_tensor(output['index']) for output in output_details]


def compare_outputs(outputs, expected_outputs, rtol=1e-05, atol=1e-08):
    """Assert that ``outputs`` match ``expected_outputs`` in count, shape, dtype and value.

    Args:
        outputs: Sequence of arrays produced by the model.
        expected_outputs: Sequence of reference arrays, in the same order.
        rtol: Relative tolerance used for floating-point outputs.
        atol: Absolute tolerance used for floating-point outputs.

    Non-floating outputs (integers, bools) must match exactly. Floating
    outputs are compared with ``np.allclose(output, expected, rtol, atol)``.

    Raises:
        AssertionError: On the first mismatch found; the message names the
            offending output index.
    """
    assert len(outputs) == len(expected_outputs), 'Incorrect number of outputs'
    for i, (output, expected) in enumerate(zip(outputs, expected_outputs)):
        assert output.shape == expected.shape, 'Incorrect output shape on output#{}'.format(i)
        assert output.dtype == expected.dtype, 'Incorrect output data type on output#{}'.format(i)
        # Compare element values. ndarray.all() only reduces each array to a
        # single truthiness flag, so it cannot detect differing values.
        if np.issubdtype(expected.dtype, np.inexact):
            values_match = np.allclose(output, expected, rtol=rtol, atol=atol)
        else:
            values_match = np.array_equal(output, expected)
        assert values_match, 'Incorrect output value on output#{}'.format(i)