# Copyright © 2020,2023 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
"""Tests for the pyarmnn ONNX parser bindings (``ann.IOnnxParser``)."""
import os

import pytest
import pyarmnn as ann
import numpy as np


@pytest.fixture()
def parser(shared_data_folder):
    """
    Parse and set up the test network used by the tests below.

    Creates a fresh ``IOnnxParser``, parses ``mock_model.onnx`` from the
    shared test data folder with it, and yields the parser so tests can
    query the network's input/output binding info.
    """

    # create onnx parser
    parser = ann.IOnnxParser()

    # path to model
    path_to_model = os.path.join(shared_data_folder, 'mock_model.onnx')

    # parse onnx binary & create network
    parser.CreateNetworkFromBinaryFile(path_to_model)

    yield parser


def test_onnx_parser_swig_destroy():
    """Check that SWIG generated a Python-side destructor for IOnnxParser."""
    assert ann.IOnnxParser.__swig_destroy__, "There is no swig python destructor defined"
    assert ann.IOnnxParser.__swig_destroy__.__name__ == "delete_IOnnxParser"


def test_check_onnx_parser_swig_ownership(parser):
    """Check that the Python wrapper owns the underlying C++ parser object."""
    # Check to see that SWIG has ownership for parser. This instructs SWIG to take
    # ownership of the return value. This allows the value to be automatically
    # garbage-collected when it is no longer in use
    assert parser.thisown


def test_onnx_parser_get_network_input_binding_info(parser):
    """Check the tensor info reported for the model's "input" binding."""
    input_binding_info = parser.GetNetworkInputBindingInfo("input")

    # Element [1] of the binding info is the tensor info.
    tensor = input_binding_info[1]
    assert tensor.GetDataType() == 1  # 1 == armnn DataType::Float32
    assert tensor.GetNumDimensions() == 4
    assert tensor.GetNumElements() == 784
    assert tensor.GetQuantizationOffset() == 0
    assert tensor.GetQuantizationScale() == 1


def test_onnx_parser_get_network_output_binding_info(parser):
    """Check the tensor info reported for the model's "output" binding."""
    output_binding_info = parser.GetNetworkOutputBindingInfo("output")

    # Element [1] of the binding info is the tensor info.
    tensor = output_binding_info[1]
    assert tensor.GetDataType() == 1  # 1 == armnn DataType::Float32
    assert tensor.GetNumDimensions() == 4
    assert tensor.GetNumElements() == 10
    assert tensor.GetQuantizationOffset() == 0
    assert tensor.GetQuantizationScale() == 1


def test_onnx_filenotfound_exception(shared_data_folder):
    """Check that parsing a non-existent model file raises RuntimeError."""
    parser = ann.IOnnxParser()

    # path to a model file that does not exist
    path_to_model = os.path.join(shared_data_folder, 'some_unknown_model.onnx')

    # attempting to parse the missing file must fail
    with pytest.raises(RuntimeError) as err:
        parser.CreateNetworkFromBinaryFile(path_to_model)

    # Only check for part of the exception since the exception returns
    # absolute path which will change on different machines.
    assert 'Invalid (null) filename' in str(err.value)


def test_onnx_parser_end_to_end(shared_data_folder):
    """
    Parse, optimize, load and run the mock ONNX model.

    The network output is compared against a stored golden output.
    """
    # Bind only a local name: assigning to ann.IOnnxParser as well would
    # replace the module's class with this instance for every later test.
    parser = ann.IOnnxParser()

    network = parser.CreateNetworkFromBinaryFile(os.path.join(shared_data_folder, 'mock_model.onnx'))

    # load test image data stored in input_onnx.npy
    input_binding_info = parser.GetNetworkInputBindingInfo("input")
    input_tensor_data = np.load(
        os.path.join(shared_data_folder, 'onnx_parser', 'input_onnx.npy')).astype(np.float32)

    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)

    # Prefer the accelerated CPU backend, falling back to the reference one.
    preferred_backends = [ann.BackendId('CpuAcc'), ann.BackendId('CpuRef')]
    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(), ann.OptimizerOptions())

    # Optimization must not emit any warnings/errors.
    assert 0 == len(messages)

    net_id, messages = runtime.LoadNetwork(opt_network)

    # Loading must succeed without messages.
    assert "" == messages

    input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor_data])
    output_tensors = ann.make_output_tensors([parser.GetNetworkOutputBindingInfo("output")])

    runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)

    output = ann.workload_tensors_to_ndarray(output_tensors)

    # Load golden output file for result comparison.
    golden_output = np.load(os.path.join(shared_data_folder, 'onnx_parser', 'golden_output_onnx.npy'))

    # Check that output matches golden output to 4 decimal places (there are slight rounding differences after this)
    np.testing.assert_almost_equal(output[0], golden_output, decimal=4)