xref: /aosp_15_r20/external/armnn/delegate/python/test/utils.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
3*89c4ff92SAndroid Build Coastguard Worker
4*89c4ff92SAndroid Build Coastguard Workerimport tflite_runtime.interpreter as tflite
5*89c4ff92SAndroid Build Coastguard Workerimport numpy as np
6*89c4ff92SAndroid Build Coastguard Workerimport os
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker
def run_mock_model(delegate, test_data_folder):
    """Load ``mock_model.tflite`` with *delegate* attached and run it once.

    The first input tensor is filled with random data. Outputs are not read,
    so this only exercises loading and invoking the model through the
    delegate.

    Args:
        delegate: Delegate object handed to the interpreter via
            ``experimental_delegates``.
        test_data_folder: Directory containing ``mock_model.tflite``.
    """
    model_path = os.path.join(test_data_folder, 'mock_model.tflite')
    interpreter = tflite.Interpreter(model_path=model_path,
                                     experimental_delegates=[delegate])
    interpreter.allocate_tensors()

    input_detail = interpreter.get_input_details()[0]
    input_shape = input_detail['shape']
    # Use the dtype the model declares for this tensor rather than assuming one.
    input_dtype = np.dtype(input_detail['dtype'])

    # Casting random_sample() output (values in [0, 1)) straight to an integer
    # dtype truncates every element to 0, so integer tensors are instead filled
    # with values drawn from the dtype's full range.
    if np.issubdtype(input_dtype, np.integer):
        info = np.iinfo(input_dtype)
        input_data = np.random.randint(info.min, int(info.max) + 1,
                                       size=input_shape, dtype=input_dtype)
    else:
        input_data = np.random.random_sample(input_shape).astype(input_dtype)
    interpreter.set_tensor(input_detail['index'], input_data)

    interpreter.invoke()
25*89c4ff92SAndroid Build Coastguard Worker
def run_inference(test_data_folder, model_filename, inputs, delegates=None):
    """Run a TfLite model once and return all of its outputs.

    Args:
        test_data_folder: Directory containing the model file.
        model_filename: Name of the ``.tflite`` file inside *test_data_folder*.
        inputs: Sequence of arrays, one per model input, in the order reported
            by ``Interpreter.get_input_details()``. Each array is passed to
            ``set_tensor`` as-is.
        delegates: Optional list of delegates forwarded to the interpreter as
            ``experimental_delegates``.

    Returns:
        A list with one array per model output, in the order reported by
        ``Interpreter.get_output_details()``.

    Raises:
        ValueError: If the number of *inputs* differs from the number of
            model inputs.
    """
    model_path = os.path.join(test_data_folder, model_filename)
    interpreter = tflite.Interpreter(model_path=model_path,
                                     experimental_delegates=delegates)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Without this check, extra inputs fail with an opaque IndexError and
    # missing inputs leave some input tensors unset.
    if len(inputs) != len(input_details):
        raise ValueError('Expected {} inputs but got {}'.format(
            len(input_details), len(inputs)))

    for detail, data in zip(input_details, inputs):
        interpreter.set_tensor(detail['index'], data)

    interpreter.invoke()

    return [interpreter.get_tensor(output['index']) for output in output_details]
47*89c4ff92SAndroid Build Coastguard Worker
def compare_outputs(outputs, expected_outputs, *, rtol=1e-05, atol=1e-08):
    """Assert that inference outputs match the expected arrays.

    Each output must match its expected counterpart in shape, dtype and
    element values. Floating-point outputs are compared with ``np.allclose``
    using *rtol*/*atol*; all other dtypes must be exactly equal.

    Args:
        outputs: Sequence of numpy arrays produced by the model.
        expected_outputs: Sequence of reference numpy arrays, in the same order.
        rtol: Relative tolerance used for floating-point outputs.
        atol: Absolute tolerance used for floating-point outputs.

    Raises:
        AssertionError: On any mismatch in output count, shape, dtype or values.
    """
    assert len(outputs) == len(expected_outputs), 'Incorrect number of outputs'
    for i, (output, expected) in enumerate(zip(outputs, expected_outputs)):
        assert output.shape == expected.shape, 'Incorrect output shape on output#{}'.format(i)
        assert output.dtype == expected.dtype, 'Incorrect output data type on output#{}'.format(i)
        # ndarray.all() collapses an array to a single truth value, so comparing
        # those would accept almost any pair of outputs; compare elements instead.
        if np.issubdtype(expected.dtype, np.floating):
            values_match = np.allclose(output, expected, rtol=rtol, atol=atol)
        else:
            values_match = np.array_equal(output, expected)
        assert values_match, 'Incorrect output value on output#{}'.format(i)