xref: /aosp_15_r20/external/armnn/python/pyarmnn/test/test_deserializer.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
2*89c4ff92SAndroid Build Coastguard Worker# SPDX-License-Identifier: MIT
3*89c4ff92SAndroid Build Coastguard Workerimport os
4*89c4ff92SAndroid Build Coastguard Worker
5*89c4ff92SAndroid Build Coastguard Workerimport pytest
6*89c4ff92SAndroid Build Coastguard Workerimport pyarmnn as ann
7*89c4ff92SAndroid Build Coastguard Workerimport numpy as np
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker
@pytest.fixture()
def parser(shared_data_folder):
    """
    Provide an IDeserializer that has already loaded the shared mock model.

    The network object returned by CreateNetworkFromBinary is discarded; the
    tests in this file that use this fixture only inspect the parser itself
    (ownership flag and input/output binding info).
    """
    model_path = os.path.join(shared_data_folder, 'mock_model.armnn')
    deserializer = ann.IDeserializer()
    deserializer.CreateNetworkFromBinary(model_path)
    return deserializer
19*89c4ff92SAndroid Build Coastguard Worker
20*89c4ff92SAndroid Build Coastguard Worker
def test_deserializer_swig_destroy():
    """
    Check that SWIG generated a Python-side destructor for IDeserializer.

    The destructor must exist and be the SWIG-generated ``delete_IDeserializer``.
    """
    # getattr with a default turns a missing attribute into a readable assertion
    # failure instead of an AttributeError. The message is only shown when the
    # assertion fails, so it describes the failure case.
    destructor = getattr(ann.IDeserializer, '__swig_destroy__', None)
    assert destructor, "There is no swig python destructor defined"
    assert destructor.__name__ == "delete_IDeserializer"
24*89c4ff92SAndroid Build Coastguard Worker
25*89c4ff92SAndroid Build Coastguard Worker
def test_check_deserializer_swig_ownership(parser):
    """
    Verify that the Python proxy owns the underlying native deserializer.

    With ``thisown`` set, SWIG owns the wrapped object, which allows it to be
    garbage-collected automatically once it is no longer in use.
    """
    owned_by_python = parser.thisown
    assert owned_by_python
31*89c4ff92SAndroid Build Coastguard Worker
32*89c4ff92SAndroid Build Coastguard Worker
def test_deserializer_get_network_input_binding_info(parser):
    """Check the tensor info the deserializer reports for the 'input_1' input."""
    # Placeholder layer id: the implementation does not use this argument.
    unused_layer_id = 0
    binding_info = parser.GetNetworkInputBindingInfo(unused_layer_id, 'input_1')

    # Element 1 of the binding info is the tensor info being checked.
    tensor_info = binding_info[1]

    # NOTE(review): data type 2 is presumably armnn's QAsymmU8 enum value — confirm.
    assert tensor_info.GetDataType() == 2
    assert tensor_info.GetNumDimensions() == 4
    assert tensor_info.GetNumElements() == 784
    assert tensor_info.GetQuantizationOffset() == 128
    assert tensor_info.GetQuantizationScale() == 0.007843137718737125
46*89c4ff92SAndroid Build Coastguard Worker
47*89c4ff92SAndroid Build Coastguard Worker
def test_deserializer_get_network_output_binding_info(parser):
    """Check the tensor info the deserializer reports for the 'dense/Softmax' output."""
    # Placeholder layer id: the implementation does not use this argument.
    unused_layer_id = 0
    binding_info = parser.GetNetworkOutputBindingInfo(unused_layer_id, "dense/Softmax")

    # Element 1 of the binding info is the tensor info being checked.
    output_info = binding_info[1]

    # NOTE(review): data type 2 is presumably armnn's QAsymmU8 enum value — confirm.
    assert output_info.GetDataType() == 2
    assert output_info.GetNumDimensions() == 2
    assert output_info.GetNumElements() == 10
    assert output_info.GetQuantizationOffset() == 0
    assert output_info.GetQuantizationScale() == 0.00390625
63*89c4ff92SAndroid Build Coastguard Worker
64*89c4ff92SAndroid Build Coastguard Worker
def test_deserializer_filenotfound_exception(shared_data_folder):
    """Loading a non-existent .armnn file must raise a RuntimeError."""
    deserializer = ann.IDeserializer()
    missing_model = os.path.join(shared_data_folder, 'some_unknown_network.armnn')

    # The full exception message embeds the absolute path, which differs from
    # machine to machine, so only a fixed fragment of it is matched.
    with pytest.raises(RuntimeError, match='Cannot read the file'):
        deserializer.CreateNetworkFromBinary(missing_model)
74*89c4ff92SAndroid Build Coastguard Worker
75*89c4ff92SAndroid Build Coastguard Worker
def test_deserializer_end_to_end(shared_data_folder):
    """
    Deserialize the mock model, run one inference and compare with a golden output.

    The network is optimized with the CpuAcc/CpuRef backend list and loaded into a
    fresh runtime. It is fed the stored input from ``deserializer/input_lite.npy``.
    Its single output is compared element-wise against
    ``deserializer/golden_output_lite.npy``.
    """
    parser = ann.IDeserializer()

    network = parser.CreateNetworkFromBinary(os.path.join(shared_data_folder, "mock_model.armnn"))

    # use 0 as a dummy value for layer_id, which is unused in the actual implementation
    layer_id = 0
    input_name = 'input_1'
    output_name = 'dense/Softmax'

    input_binding_info = parser.GetNetworkInputBindingInfo(layer_id, input_name)

    # Candidate backends passed to Optimize; CpuAcc is listed before CpuRef.
    preferred_backends = [ann.BackendId('CpuAcc'), ann.BackendId('CpuRef')]

    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)

    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(), ann.OptimizerOptions())
    # Surface the reported messages if the assertion fails.
    assert 0 == len(messages), "Optimize reported: {}".format(messages)

    net_id, messages = runtime.LoadNetwork(opt_network)
    assert "" == messages, "LoadNetwork reported: {}".format(messages)

    # Load test image data stored in input_lite.npy
    input_tensor_data = np.load(os.path.join(shared_data_folder, 'deserializer/input_lite.npy'))
    input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor_data])

    # One output tensor, built from the output binding's tensor info, for
    # EnqueueWorkload to write into. Binding info is (binding id, tensor info).
    out_bind_info = parser.GetNetworkOutputBindingInfo(layer_id, output_name)
    output_tensors = [(out_bind_info[0], ann.Tensor(out_bind_info[1]))]

    runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)

    # Read back the data written into each output tensor.
    output_vectors = [out_tensor.get_memory_area() for _, out_tensor in output_tensors]

    # Load golden output file for result comparison.
    expected_outputs = np.load(os.path.join(shared_data_folder, 'deserializer/golden_output_lite.npy'))

    # Compare flattened arrays. A plain `(a == b).all()` silently broadcasts
    # mismatched shapes, e.g. (10, 1) against (10,). assert_array_equal also
    # checks the element count and reports which elements differ.
    np.testing.assert_array_equal(np.ravel(output_vectors[0]), np.ravel(expected_outputs))
120*89c4ff92SAndroid Build Coastguard Worker    assert (expected_outputs == output_vectors[0]).all()
121