# Copyright © 2020 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
"""Tests for the pyarmnn IRuntime bindings.

Covers runtime creation (including external-profiling options), loading and
unloading optimized networks, running inference with EnqueueWorkload, the
per-network profiler, and SWIG ownership of the wrapped objects.
"""
import os
import warnings

import numpy as np
import pytest

import pyarmnn as ann

# Values the mock model is expected to produce for the stored reference input
# 'tflite_parser/input_lite.npy'; row i is compared against output tensor i.
_MOCK_MODEL_EXPECTED_OUTPUTS = np.array([[4, 85, 108, 29, 8, 16, 0, 2, 5, 0]])


def _optimize(preferred_backends, network, runtime):
    """Optimize ``network`` for ``preferred_backends`` using ``runtime``'s device spec.

    The optimizer messages are discarded; only the optimized network is returned.
    """
    opt_network, _ = ann.Optimize(network, preferred_backends,
                                  runtime.GetDeviceSpec(), ann.OptimizerOptions())
    return opt_network


def _assert_outputs_match(output_vectors, expected_results):
    """Assert that output array i equals ``expected_results[i]`` element-wise.

    Each output is flattened before comparison because the expected values are
    stored as 1-D rows, while the output arrays may carry the tensor's full
    shape (e.g. a leading batch dimension). An IndexError is raised if there
    are fewer outputs than expected rows.
    """
    for i, expected in enumerate(expected_results):
        np.testing.assert_array_equal(np.ravel(output_vectors[i]), expected)


@pytest.fixture(scope="function")
def random_runtime(shared_data_folder):
    """Build an un-optimized mock-model network plus random uint8 input tensors.

    ``shared_data_folder`` is a fixture defined outside this module that points
    at the test data directory.

    Yields:
        (preferred_backends, network, runtime, input_tensors, output_tensors),
        where input_tensors / output_tensors are lists of (binding id, tensor)
        pairs suitable for ``IRuntime.EnqueueWorkload``.
    """
    parser = ann.ITfLiteParser()
    network = parser.CreateNetworkFromBinaryFile(os.path.join(shared_data_folder, 'mock_model.tflite'))
    preferred_backends = [ann.BackendId('CpuRef')]
    options = ann.CreationOptions()

    runtime = ann.IRuntime(options)

    # Use the last subgraph of the model.
    graph_id = parser.GetSubgraphCount() - 1
    input_names = parser.GetSubgraphInputTensorNames(graph_id)

    input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
    input_tensor_id = input_binding_info[0]
    input_tensor_info = input_binding_info[1]
    # Mark the info constant so it can back the ConstTensor input below.
    input_tensor_info.SetConstant()

    # randint's upper bound is exclusive: 256 covers the full uint8 range 0..255.
    input_data = np.random.randint(256, size=input_tensor_info.GetNumElements(), dtype=np.uint8)
    input_tensors = [(input_tensor_id, ann.ConstTensor(input_tensor_info, input_data))]

    output_tensors = []
    for output_name in parser.GetSubgraphOutputTensorNames(graph_id):
        out_bind_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
        out_tensor_id = out_bind_info[0]
        out_tensor_info = out_bind_info[1]
        output_tensors.append((out_tensor_id, ann.Tensor(out_tensor_info)))

    yield preferred_backends, network, runtime, input_tensors, output_tensors


@pytest.fixture(scope='function')
def mock_model_runtime(shared_data_folder):
    """Load the mock model into a CpuRef runtime, fed with the stored reference input.

    Yields:
        (runtime, net_id, input_tensors, output_tensors) where ``net_id`` is the
        id returned by ``LoadNetwork`` for the optimized mock model.
    """
    parser = ann.ITfLiteParser()
    network = parser.CreateNetworkFromBinaryFile(os.path.join(shared_data_folder, 'mock_model.tflite'))
    graph_id = 0

    input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, "input_1")

    input_tensor_data = np.load(os.path.join(shared_data_folder, 'tflite_parser', 'input_lite.npy'))

    preferred_backends = [ann.BackendId('CpuRef')]

    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)

    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(), ann.OptimizerOptions())
    # Printed so optimizer/loader diagnostics appear in the output of a failing test.
    print(messages)

    net_id, messages = runtime.LoadNetwork(opt_network)
    print(messages)

    input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor_data])

    outputs_binding_info = [parser.GetNetworkOutputBindingInfo(graph_id, output_name)
                            for output_name in parser.GetSubgraphOutputTensorNames(graph_id)]
    output_tensors = ann.make_output_tensors(outputs_binding_info)

    yield runtime, net_id, input_tensors, output_tensors


def test_python_disowns_network(random_runtime):
    """After LoadNetwork, Python must no longer own the optimized network."""
    preferred_backends, network, runtime, *_ = random_runtime
    opt_network = _optimize(preferred_backends, network, runtime)

    runtime.LoadNetwork(opt_network)

    assert not opt_network.thisown


def test_load_network(random_runtime):
    """Loading into a fresh runtime yields no messages and network id 0."""
    preferred_backends, network, runtime, *_ = random_runtime
    opt_network = _optimize(preferred_backends, network, runtime)

    net_id, messages = runtime.LoadNetwork(opt_network)
    assert "" == messages
    assert net_id == 0


def test_create_runtime_with_external_profiling_enabled(tmp_path):
    """A runtime can be created with file-only external profiling and timeline enabled."""
    options = ann.CreationOptions()

    options.m_ProfilingOptions.m_FileOnly = True
    options.m_ProfilingOptions.m_EnableProfiling = True
    # Per-test capture files instead of shared /tmp paths.
    options.m_ProfilingOptions.m_OutgoingCaptureFile = str(tmp_path / "outgoing.txt")
    options.m_ProfilingOptions.m_IncomingCaptureFile = str(tmp_path / "incoming.txt")
    options.m_ProfilingOptions.m_TimelineEnabled = True
    options.m_ProfilingOptions.m_CapturePeriod = 1000
    options.m_ProfilingOptions.m_FileFormat = "JSON"

    runtime = ann.IRuntime(options)

    assert runtime is not None


def test_create_runtime_with_external_profiling_enabled_invalid_options(tmp_path):
    """Enabling the timeline while profiling is disabled must be rejected."""
    options = ann.CreationOptions()

    options.m_ProfilingOptions.m_FileOnly = True
    options.m_ProfilingOptions.m_EnableProfiling = False
    options.m_ProfilingOptions.m_OutgoingCaptureFile = str(tmp_path / "outgoing.txt")
    options.m_ProfilingOptions.m_IncomingCaptureFile = str(tmp_path / "incoming.txt")
    options.m_ProfilingOptions.m_TimelineEnabled = True
    options.m_ProfilingOptions.m_CapturePeriod = 1000
    options.m_ProfilingOptions.m_FileFormat = "JSON"

    with pytest.raises(RuntimeError) as err:
        ann.IRuntime(options)

    expected_error_message = "It is not possible to enable timeline reporting without profiling being enabled"
    assert expected_error_message in str(err.value)


def test_load_network_properties_provided(random_runtime):
    """LoadNetwork accepts explicit INetworkProperties (async disabled, undefined memory sources)."""
    preferred_backends, network, runtime, *_ = random_runtime
    opt_network = _optimize(preferred_backends, network, runtime)

    input_source = ann.MemorySource_Undefined
    output_source = ann.MemorySource_Undefined
    properties = ann.INetworkProperties(False, input_source, output_source)
    net_id, messages = runtime.LoadNetwork(opt_network, properties)
    assert "" == messages
    assert net_id == 0


def test_network_properties_constructor(random_runtime):
    """INetworkProperties exposes its constructor arguments and defaults, and can be used to load."""
    preferred_backends, network, runtime, *_ = random_runtime
    opt_network = _optimize(preferred_backends, network, runtime)

    input_source = ann.MemorySource_Undefined
    output_source = ann.MemorySource_Undefined
    properties = ann.INetworkProperties(True, input_source, output_source)
    assert properties.m_AsyncEnabled
    assert not properties.m_ProfilingEnabled
    assert properties.m_OutputNetworkDetailsMethod == ann.ProfilingDetailsMethod_Undefined
    assert properties.m_InputSource == ann.MemorySource_Undefined
    assert properties.m_OutputSource == ann.MemorySource_Undefined

    net_id, messages = runtime.LoadNetwork(opt_network, properties)
    assert "" == messages
    assert net_id == 0


def test_unload_network_fails_for_invalid_net_id(random_runtime):
    """Unloading an id that was never loaded raises RuntimeError."""
    preferred_backends, network, runtime, *_ = random_runtime

    _optimize(preferred_backends, network, runtime)

    # Nothing has been loaded into this runtime, so id 9 cannot be valid.
    with pytest.raises(RuntimeError) as err:
        runtime.UnloadNetwork(9)

    expected_error_message = "Failed to unload network."
    assert expected_error_message in str(err.value)


def test_enqueue_workload(random_runtime):
    """Inference on random input completes without raising."""
    preferred_backends, network, runtime, input_tensors, output_tensors = random_runtime
    opt_network = _optimize(preferred_backends, network, runtime)

    net_id, _ = runtime.LoadNetwork(opt_network)
    runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)


def test_enqueue_workload_fails_with_empty_input_tensors(random_runtime):
    """EnqueueWorkload rejects an input list that does not match the network's inputs."""
    preferred_backends, network, runtime, _, output_tensors = random_runtime
    input_tensors = []
    opt_network = _optimize(preferred_backends, network, runtime)

    net_id, _ = runtime.LoadNetwork(opt_network)
    with pytest.raises(RuntimeError) as err:
        runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)

    expected_error_message = "Number of inputs provided does not match network."
    assert expected_error_message in str(err.value)


@pytest.mark.x86_64
@pytest.mark.parametrize('count', [5])
def test_multiple_inference_runs_yield_same_result(count, mock_model_runtime):
    """
    Test that results remain consistent among multiple runs of the same inference.
    """
    runtime, net_id, input_tensors, output_tensors = mock_model_runtime

    for _ in range(count):
        runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)

        output_vectors = ann.workload_tensors_to_ndarray(output_tensors)

        # Compare element-wise after every run; the previous `.all() == .all()`
        # only compared two booleans and could not detect wrong results.
        _assert_outputs_match(output_vectors, _MOCK_MODEL_EXPECTED_OUTPUTS)


@pytest.mark.aarch64
def test_aarch64_inference_results(mock_model_runtime):
    """The mock model produces the reference output on aarch64."""
    runtime, net_id, input_tensors, output_tensors = mock_model_runtime

    runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)

    output_vectors = ann.workload_tensors_to_ndarray(output_tensors)

    _assert_outputs_match(output_vectors, _MOCK_MODEL_EXPECTED_OUTPUTS)


def test_enqueue_workload_with_profiler(random_runtime):
    """
    Tests ArmNN's profiling extension
    """
    preferred_backends, network, runtime, input_tensors, output_tensors = random_runtime
    opt_network = _optimize(preferred_backends, network, runtime)
    net_id, _ = runtime.LoadNetwork(opt_network)

    profiler = runtime.GetProfiler(net_id)
    # By default profiling should be turned off:
    assert profiler.IsProfilingEnabled() is False

    # Enable profiling:
    profiler.EnableProfiling(True)
    assert profiler.IsProfilingEnabled() is True

    # Run the inference:
    runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)

    # Get profile output as a string:
    str_profile = profiler.as_json()

    # Verify that certain markers are present:
    assert len(str_profile) != 0
    assert '"ArmNN": {' in str_profile

    # Get events analysis output as a string:
    str_events_analysis = profiler.event_log()

    assert "Event Sequence - Name | Duration (ms) | Start (ms) | Stop (ms) | Device" in str_events_analysis

    # The profiler proxy returned by GetProfiler is not owned by Python.
    assert not profiler.thisown


def test_check_runtime_swig_ownership(random_runtime):
    """SWIG must own the runtime proxy (``thisown`` set).

    Ownership means the underlying runtime is released when the Python object
    is garbage-collected.
    """
    runtime = random_runtime[2]
    assert runtime.thisown