xref: /aosp_15_r20/external/armnn/python/pyarmnn/test/test_runtime.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker# Copyright © 2020 Arm Ltd. All rights reserved.
2*89c4ff92SAndroid Build Coastguard Worker# SPDX-License-Identifier: MIT
3*89c4ff92SAndroid Build Coastguard Workerimport os
4*89c4ff92SAndroid Build Coastguard Worker
5*89c4ff92SAndroid Build Coastguard Workerimport pytest
6*89c4ff92SAndroid Build Coastguard Workerimport warnings
7*89c4ff92SAndroid Build Coastguard Workerimport numpy as np
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Workerimport pyarmnn as ann
10*89c4ff92SAndroid Build Coastguard Worker
11*89c4ff92SAndroid Build Coastguard Worker
@pytest.fixture(scope="function")
def random_runtime(shared_data_folder):
    """Provide an un-optimized mock network, a CpuRef runtime and matching I/O tensors.

    The network is parsed from ``mock_model.tflite`` but is neither optimized nor
    loaded; each test does that itself with the returned backends and runtime.

    Args:
        shared_data_folder: fixture giving the directory that holds ``mock_model.tflite``.

    Yields:
        tuple: ``(preferred_backends, network, runtime, input_tensors, output_tensors)``.
        ``input_tensors`` is a one-element list ``[(binding_id, ConstTensor)]`` filled
        with random uint8 data for the first input of the selected subgraph;
        ``output_tensors`` is a list of ``(binding_id, Tensor)`` pairs, one per output
        of that subgraph.
    """
    parser = ann.ITfLiteParser()
    network = parser.CreateNetworkFromBinaryFile(os.path.join(shared_data_folder, 'mock_model.tflite'))
    preferred_backends = [ann.BackendId('CpuRef')]
    options = ann.CreationOptions()

    runtime = ann.IRuntime(options)

    # Bind against the last subgraph reported by the parser.
    graph_id = parser.GetSubgraphCount() - 1

    input_names = parser.GetSubgraphInputTensorNames(graph_id)
    input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
    input_tensor_id = input_binding_info[0]
    input_tensor_info = input_binding_info[1]
    # Flag the info as constant before it is wrapped in a ConstTensor below
    # (presumably ConstTensor rejects non-constant infos -- confirm for the Arm NN version in use).
    input_tensor_info.SetConstant()

    # randint's upper bound is exclusive, so 256 (not 255) is needed to cover every uint8 value.
    input_data = np.random.randint(256, size=input_tensor_info.GetNumElements(), dtype=np.uint8)
    input_tensors = [(input_tensor_id, ann.ConstTensor(input_tensor_info, input_data))]

    # One freshly allocated (writable) Tensor per subgraph output.
    output_tensors = []
    for output_name in parser.GetSubgraphOutputTensorNames(graph_id):
        out_bind_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
        output_tensors.append((out_bind_info[0], ann.Tensor(out_bind_info[1])))

    yield preferred_backends, network, runtime, input_tensors, output_tensors
52*89c4ff92SAndroid Build Coastguard Worker
53*89c4ff92SAndroid Build Coastguard Worker
@pytest.fixture(scope='function')
def mock_model_runtime(shared_data_folder):
    """Provide a runtime with ``mock_model.tflite`` already optimized and loaded on CpuRef.

    The input tensor named ``input_1`` of subgraph 0 is fed from
    ``tflite_parser/input_lite.npy``, and one output tensor is allocated per output of
    that subgraph. The messages returned by ``Optimize`` and ``LoadNetwork`` are printed.

    Args:
        shared_data_folder: fixture giving the directory holding the model and input data.

    Yields:
        tuple: ``(runtime, net_id, input_tensors, output_tensors)``, ready to be passed
        to ``runtime.EnqueueWorkload``.
    """
    tflite_parser = ann.ITfLiteParser()
    model_path = os.path.join(shared_data_folder, 'mock_model.tflite')
    network = tflite_parser.CreateNetworkFromBinaryFile(model_path)
    subgraph = 0

    in_binding = tflite_parser.GetNetworkInputBindingInfo(subgraph, "input_1")
    in_data = np.load(os.path.join(shared_data_folder, 'tflite_parser/input_lite.npy'))

    backends = [ann.BackendId('CpuRef')]

    creation_options = ann.CreationOptions()
    runtime = ann.IRuntime(creation_options)

    device_spec = runtime.GetDeviceSpec()
    optimized, optimize_messages = ann.Optimize(network, backends, device_spec, ann.OptimizerOptions())
    print(optimize_messages)

    net_id, load_messages = runtime.LoadNetwork(optimized)
    print(load_messages)

    input_tensors = ann.make_input_tensors([in_binding], [in_data])

    # Output bindings are resolved by name, in the order the parser lists them.
    out_bindings = [tflite_parser.GetNetworkOutputBindingInfo(subgraph, out_name)
                    for out_name in tflite_parser.GetSubgraphOutputTensorNames(subgraph)]
    output_tensors = ann.make_output_tensors(out_bindings)

    yield runtime, net_id, input_tensors, output_tensors
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker
def test_python_disowns_network(random_runtime):
    """Loading an optimized network must take its ownership away from Python.

    After ``runtime.LoadNetwork`` the SWIG proxy's ``thisown`` flag has to be cleared,
    so the Python wrapper no longer claims ownership of the underlying object.
    """
    backends, network, runtime = random_runtime[:3]
    device_spec = runtime.GetDeviceSpec()

    optimized, _messages = ann.Optimize(network, backends, device_spec, ann.OptimizerOptions())
    runtime.LoadNetwork(optimized)

    assert not optimized.thisown
100*89c4ff92SAndroid Build Coastguard Worker
101*89c4ff92SAndroid Build Coastguard Worker
def test_load_network(random_runtime):
    """Loading the first optimized network returns id 0 and an empty message string."""
    backends, network, runtime = random_runtime[:3]

    optimized, _ = ann.Optimize(network, backends, runtime.GetDeviceSpec(), ann.OptimizerOptions())
    network_id, load_messages = runtime.LoadNetwork(optimized)

    assert load_messages == ""
    assert network_id == 0
113*89c4ff92SAndroid Build Coastguard Worker
114*89c4ff92SAndroid Build Coastguard Worker
def test_create_runtime_with_external_profiling_enabled():
    """A runtime can be created with file-only external profiling and timeline reporting on."""
    # NOTE(review): the capture paths are hard-coded under /tmp (POSIX-only);
    # consider pytest's tmp_path if this must run elsewhere.
    profiling_settings = {
        "m_FileOnly": True,
        "m_EnableProfiling": True,
        "m_OutgoingCaptureFile": "/tmp/outgoing.txt",
        "m_IncomingCaptureFile": "/tmp/incoming.txt",
        "m_TimelineEnabled": True,
        "m_CapturePeriod": 1000,
        "m_FileFormat": "JSON",
    }

    options = ann.CreationOptions()
    # Applied in insertion order, i.e. the same order as the explicit assignments would be.
    for field_name, value in profiling_settings.items():
        setattr(options.m_ProfilingOptions, field_name, value)

    runtime = ann.IRuntime(options)

    assert runtime is not None
130*89c4ff92SAndroid Build Coastguard Worker
131*89c4ff92SAndroid Build Coastguard Worker
def test_create_runtime_with_external_profiling_enabled_invalid_options():
    """Runtime creation must fail when timeline reporting is requested without profiling.

    ``m_TimelineEnabled`` is set while ``m_EnableProfiling`` is False; constructing
    ``ann.IRuntime`` with these options must raise ``RuntimeError`` carrying an
    explanatory message.
    """
    options = ann.CreationOptions()

    options.m_ProfilingOptions.m_FileOnly = True
    options.m_ProfilingOptions.m_EnableProfiling = False  # the invalid combination under test
    options.m_ProfilingOptions.m_OutgoingCaptureFile = "/tmp/outgoing.txt"
    options.m_ProfilingOptions.m_IncomingCaptureFile = "/tmp/incoming.txt"
    options.m_ProfilingOptions.m_TimelineEnabled = True
    options.m_ProfilingOptions.m_CapturePeriod = 1000
    options.m_ProfilingOptions.m_FileFormat = "JSON"

    with pytest.raises(RuntimeError) as err:
        # Construction is expected to throw, so the result is intentionally not bound.
        ann.IRuntime(options)

    expected_error_message = "It is not possible to enable timeline reporting without profiling being enabled"
    assert expected_error_message in str(err.value)
149*89c4ff92SAndroid Build Coastguard Worker
150*89c4ff92SAndroid Build Coastguard Worker
def test_load_network_properties_provided(random_runtime):
    """LoadNetwork accepts explicit INetworkProperties (async off, undefined memory sources)."""
    backends, network, runtime = random_runtime[:3]

    optimized, _ = ann.Optimize(network, backends, runtime.GetDeviceSpec(), ann.OptimizerOptions())

    undefined_source = ann.MemorySource_Undefined
    properties = ann.INetworkProperties(False, undefined_source, undefined_source)

    network_id, load_messages = runtime.LoadNetwork(optimized, properties)
    assert load_messages == ""
    assert network_id == 0
165*89c4ff92SAndroid Build Coastguard Worker
166*89c4ff92SAndroid Build Coastguard Worker
def test_network_properties_constructor(random_runtime):
    """INetworkProperties(async_enabled, input_source, output_source) exposes its settings.

    Checks that the constructor arguments are reflected in the matching members,
    that profiling defaults to off and the network-details method to Undefined, and
    that the resulting properties can be used to load a network.
    """
    preferred_backends = random_runtime[0]
    network = random_runtime[1]
    runtime = random_runtime[2]

    opt_network, _ = ann.Optimize(network, preferred_backends,
                                  runtime.GetDeviceSpec(), ann.OptimizerOptions())

    input_source = ann.MemorySource_Undefined
    output_source = ann.MemorySource_Undefined
    properties = ann.INetworkProperties(True, input_source, output_source)

    # Compare with the bool singletons by identity (PEP 8); unlike `== True` this
    # also rejects a truthy non-bool value.
    assert properties.m_AsyncEnabled is True
    assert properties.m_ProfilingEnabled is False
    assert properties.m_OutputNetworkDetailsMethod == ann.ProfilingDetailsMethod_Undefined
    assert properties.m_InputSource == ann.MemorySource_Undefined
    assert properties.m_OutputSource == ann.MemorySource_Undefined

    net_id, messages = runtime.LoadNetwork(opt_network, properties)
    assert "" == messages
    assert net_id == 0
187*89c4ff92SAndroid Build Coastguard Worker
188*89c4ff92SAndroid Build Coastguard Worker
def test_unload_network_fails_for_invalid_net_id(random_runtime):
    """UnloadNetwork raises RuntimeError for a network id that was never loaded."""
    backends, network, runtime = random_runtime[:3]

    # The network is optimized but never loaded, so no id has been handed out.
    ann.Optimize(network, backends, runtime.GetDeviceSpec(), ann.OptimizerOptions())

    unknown_net_id = 9
    with pytest.raises(RuntimeError) as excinfo:
        runtime.UnloadNetwork(unknown_net_id)

    assert "Failed to unload network." in str(excinfo.value)
201*89c4ff92SAndroid Build Coastguard Worker
202*89c4ff92SAndroid Build Coastguard Worker
def test_enqueue_workload(random_runtime):
    """A loaded network runs one inference on the fixture's random input without raising."""
    backends, network, runtime, inputs, outputs = random_runtime

    optimized, _ = ann.Optimize(network, backends, runtime.GetDeviceSpec(), ann.OptimizerOptions())
    network_id, _ = runtime.LoadNetwork(optimized)

    runtime.EnqueueWorkload(network_id, inputs, outputs)
215*89c4ff92SAndroid Build Coastguard Worker
216*89c4ff92SAndroid Build Coastguard Worker
def test_enqueue_workload_fails_with_empty_input_tensors(random_runtime):
    """EnqueueWorkload rejects an input list whose length does not match the network."""
    backends, network, runtime, _, outputs = random_runtime
    no_inputs = []

    optimized, _ = ann.Optimize(network, backends, runtime.GetDeviceSpec(), ann.OptimizerOptions())
    network_id, _ = runtime.LoadNetwork(optimized)

    with pytest.raises(RuntimeError) as excinfo:
        runtime.EnqueueWorkload(network_id, no_inputs, outputs)

    assert "Number of inputs provided does not match network." in str(excinfo.value)
233*89c4ff92SAndroid Build Coastguard Worker
234*89c4ff92SAndroid Build Coastguard Worker
@pytest.mark.x86_64
@pytest.mark.parametrize('count', [5])
def test_multiple_inference_runs_yield_same_result(count, mock_model_runtime):
    """
    Test that results remain consistent among multiple runs of the same inference.

    Each of the ``count`` runs must reproduce the reference output element for element.
    The arrays are compared directly. ``a.all() == b.all()`` would only compare whether
    each array is entirely non-zero, which is False for the reference output (it contains
    zeros), so that form could never detect a wrong result.
    """
    runtime, net_id, input_tensors, output_tensors = mock_model_runtime

    expected_results = np.array([[4,  85, 108,  29,   8,  16,   0,   2,   5,   0]])

    for _ in range(count):
        runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)

        output_vectors = ann.workload_tensors_to_ndarray(output_tensors)

        for i, expected in enumerate(expected_results):
            # ravel so that an extra leading (e.g. batch) dimension on the output array
            # does not cause a spurious shape mismatch against the 1-D reference row.
            np.testing.assert_array_equal(np.ravel(output_vectors[i]), expected)
255*89c4ff92SAndroid Build Coastguard Worker
256*89c4ff92SAndroid Build Coastguard Worker
@pytest.mark.aarch64
def test_aarch64_inference_results(mock_model_runtime):
    """Check that one inference on aarch64 reproduces the reference output exactly.

    The arrays are compared element-wise. ``a.all() == b.all()`` would only compare
    whether each array is entirely non-zero, which is False for the reference output
    (it contains zeros), so it could not catch a wrong result.
    """
    runtime, net_id, input_tensors, output_tensors = mock_model_runtime

    runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)

    output_vectors = ann.workload_tensors_to_ndarray(output_tensors)

    expected_results = np.array([[4,  85, 108,  29,   8,  16,   0,   2,   5,   0]])

    for i, expected in enumerate(expected_results):
        # ravel so that an extra leading (e.g. batch) dimension on the output array
        # does not cause a spurious shape mismatch against the 1-D reference row.
        np.testing.assert_array_equal(np.ravel(output_vectors[i]), expected)
273*89c4ff92SAndroid Build Coastguard Worker
274*89c4ff92SAndroid Build Coastguard Worker
def test_enqueue_workload_with_profiler(random_runtime):
    """Exercise the per-network profiler exposed by the runtime (ArmNN's profiling extension).

    Profiling starts disabled and can be switched on. After one inference it produces a
    JSON report containing an ``"ArmNN"`` object and a textual event log. The profiler
    proxy must not own its underlying object (``thisown == 0``).
    """
    backends, network, runtime, inputs, outputs = random_runtime

    optimized, _ = ann.Optimize(network, backends, runtime.GetDeviceSpec(), ann.OptimizerOptions())
    network_id, _ = runtime.LoadNetwork(optimized)

    profiler = runtime.GetProfiler(network_id)

    # Off by default; must report enabled once switched on.
    assert profiler.IsProfilingEnabled() is False
    profiler.EnableProfiling(True)
    assert profiler.IsProfilingEnabled() is True

    runtime.EnqueueWorkload(network_id, inputs, outputs)

    json_report = profiler.as_json()
    assert len(json_report) != 0
    assert json_report.find('"ArmNN": {') > 0

    event_log = profiler.event_log()
    assert "Event Sequence - Name | Duration (ms) | Start (ms) | Stop (ms) | Device" in event_log

    assert profiler.thisown == 0
313*89c4ff92SAndroid Build Coastguard Worker
314*89c4ff92SAndroid Build Coastguard Worker
def test_check_runtime_swig_ownership(random_runtime):
    """The runtime proxy created by the fixture must be owned by Python.

    With ``thisown`` set, SWIG owns the underlying C++ runtime on Python's behalf and
    releases it automatically when the Python object is garbage-collected.
    """
    _, _, runtime, _, _ = random_runtime
    assert runtime.thisown
321