1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp> 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <GraphUtils.hpp> 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("ClFallback") 13*89c4ff92SAndroid Build Coastguard Worker { 14*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ClImportEnabledFallbackToNeon") 15*89c4ff92SAndroid Build Coastguard Worker { 16*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 17*89c4ff92SAndroid Build Coastguard Worker 18*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; 19*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options)); 20*89c4ff92SAndroid Build Coastguard Worker 21*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network. 22*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create()); 23*89c4ff92SAndroid Build Coastguard Worker 24*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input0 = net->AddInputLayer(0, "input0"); 25*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input1 = net->AddInputLayer(1, "input1"); 26*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input2 = net->AddInputLayer(2, "input2"); 27*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add"); 28*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub"); 29*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output"); 30*89c4ff92SAndroid Build Coastguard Worker 31*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).Connect(add->GetInputSlot(0)); 32*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).Connect(add->GetInputSlot(1)); 33*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0)); 34*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).Connect(sub->GetInputSlot(1)); 35*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).Connect(output->GetInputSlot(0)); 36*89c4ff92SAndroid Build Coastguard Worker 37*89c4ff92SAndroid Build Coastguard Worker TensorInfo info = TensorInfo({ 1, 2, 4, 2 }, DataType::Float32); 38*89c4ff92SAndroid Build Coastguard Worker info.SetConstant(true); 39*89c4ff92SAndroid Build Coastguard Worker 40*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).SetTensorInfo(info); 41*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).SetTensorInfo(info); 42*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).SetTensorInfo(info); 43*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).SetTensorInfo(info); 44*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).SetTensorInfo(info); 45*89c4ff92SAndroid Build Coastguard Worker 46*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends = { Compute::GpuAcc, Compute::CpuAcc }; 47*89c4ff92SAndroid Build Coastguard Worker // Use BackendSelectionHint to specify CpuAcc for Subtraction layer 48*89c4ff92SAndroid Build Coastguard Worker sub->BackendSelectionHint(backends[1]); 49*89c4ff92SAndroid Build Coastguard Worker 50*89c4ff92SAndroid Build Coastguard Worker // optimize the network 51*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque optOptions; 52*89c4ff92SAndroid Build Coastguard Worker optOptions.SetImportEnabled(true); 53*89c4ff92SAndroid Build Coastguard Worker optOptions.SetExportEnabled(true); 54*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions); 55*89c4ff92SAndroid Build Coastguard Worker 56*89c4ff92SAndroid Build Coastguard Worker Graph& graph = GetGraphForTesting(optNet.get()); 57*89c4ff92SAndroid Build Coastguard Worker 58*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0"); 59*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1"); 60*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2"); 61*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add"); 62*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]"); 63*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub"); 64*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "output"); 65*89c4ff92SAndroid Build Coastguard Worker 66*89c4ff92SAndroid Build Coastguard Worker // Checks order is valid. 67*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer0, layer1)); 68*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer1, layer2)); 69*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer2, layer3)); 70*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer3, layer4)); 71*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer4, layer5)); 72*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer5, layer6)); 73*89c4ff92SAndroid Build Coastguard Worker 74*89c4ff92SAndroid Build Coastguard Worker // Use memory import between backends 75*89c4ff92SAndroid Build Coastguard Worker CHECK((layer4->GetType() == LayerType::MemCopy)); 76*89c4ff92SAndroid Build Coastguard Worker 77*89c4ff92SAndroid Build Coastguard Worker // Correctly use backend hint 78*89c4ff92SAndroid Build Coastguard Worker CHECK((layer5->GetBackendId() == Compute::CpuAcc )); 79*89c4ff92SAndroid Build Coastguard Worker 80*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime. It should pass. 81*89c4ff92SAndroid Build Coastguard Worker NetworkId netId; 82*89c4ff92SAndroid Build Coastguard Worker std::string ignoredErrorMessage; 83*89c4ff92SAndroid Build Coastguard Worker INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); 84*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); 85*89c4ff92SAndroid Build Coastguard Worker 86*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output 87*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValue0 88*89c4ff92SAndroid Build Coastguard Worker { 89*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f, 1.0f, 1.0f, 2.0f, 2.0f 90*89c4ff92SAndroid Build Coastguard Worker }; 91*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValue1 92*89c4ff92SAndroid Build Coastguard Worker { 93*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 0.0f, 1.0f, 1.0f, 2.0f 94*89c4ff92SAndroid Build Coastguard Worker }; 95*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData2 96*89c4ff92SAndroid Build Coastguard Worker { 97*89c4ff92SAndroid Build Coastguard Worker 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 12.0f, 11.0f, 10.0f, 9.0f 98*89c4ff92SAndroid Build Coastguard Worker }; 99*89c4ff92SAndroid Build Coastguard Worker 100*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(16); 101*89c4ff92SAndroid Build Coastguard Worker 102*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput 103*89c4ff92SAndroid Build Coastguard Worker { 104*89c4ff92SAndroid Build Coastguard Worker 11.0f, 9.0f, 7.0f, 5.0f, 3.0f, 1.0f, -1.0f, -3.0f, -5.0f, -7.0f, -9.0f, -11.0f, 11.0f, 9.0f, 7.0f, 5.0f 105*89c4ff92SAndroid Build Coastguard Worker }; 106*89c4ff92SAndroid Build Coastguard Worker 107*89c4ff92SAndroid Build Coastguard Worker // Prepare aligned data 108*89c4ff92SAndroid Build Coastguard Worker unsigned int numElements = info.GetNumElements(); 109*89c4ff92SAndroid Build Coastguard Worker size_t totalBytes = numElements * sizeof(float); 110*89c4ff92SAndroid Build Coastguard Worker const size_t alignment = 64; 111*89c4ff92SAndroid Build Coastguard Worker size_t space = totalBytes + alignment + alignment; 112*89c4ff92SAndroid Build Coastguard Worker auto inputData0 = std::make_unique<uint8_t[]>(space); 113*89c4ff92SAndroid Build Coastguard Worker void* alignedInputPtr0 = inputData0.get(); 114*89c4ff92SAndroid Build Coastguard Worker CHECK(std::align(alignment, totalBytes, alignedInputPtr0, space)); 115*89c4ff92SAndroid Build Coastguard Worker 116*89c4ff92SAndroid Build Coastguard Worker auto* intputPtr0 = reinterpret_cast<float*>(alignedInputPtr0); 117*89c4ff92SAndroid Build Coastguard Worker std::copy(inputValue0.begin(), inputValue0.end(), intputPtr0); 118*89c4ff92SAndroid Build Coastguard Worker 119*89c4ff92SAndroid Build Coastguard Worker auto inputData1 = std::make_unique<uint8_t[]>(space); 120*89c4ff92SAndroid Build Coastguard Worker void* alignedInputPtr1 = inputData1.get(); 121*89c4ff92SAndroid Build Coastguard Worker CHECK(std::align(alignment, totalBytes, alignedInputPtr1, space)); 122*89c4ff92SAndroid Build Coastguard Worker 123*89c4ff92SAndroid Build Coastguard Worker auto* intputPtr1 = reinterpret_cast<float*>(alignedInputPtr1); 124*89c4ff92SAndroid Build Coastguard Worker std::copy(inputValue1.begin(), inputValue1.end(), intputPtr1); 125*89c4ff92SAndroid Build Coastguard Worker 126*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors 127*89c4ff92SAndroid Build Coastguard Worker { 128*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), alignedInputPtr0) }, 129*89c4ff92SAndroid Build Coastguard Worker { 1, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 1), alignedInputPtr1) }, 130*89c4ff92SAndroid Build Coastguard Worker { 2, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 2), inputData2.data()) } 131*89c4ff92SAndroid Build Coastguard Worker }; 132*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors 133*89c4ff92SAndroid Build Coastguard Worker { 134*89c4ff92SAndroid Build Coastguard Worker { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) } 135*89c4ff92SAndroid Build Coastguard Worker }; 136*89c4ff92SAndroid Build Coastguard Worker 137*89c4ff92SAndroid Build Coastguard Worker runtime->GetProfiler(netId)->EnableProfiling(true); 138*89c4ff92SAndroid Build Coastguard Worker 139*89c4ff92SAndroid Build Coastguard Worker // Do the inference 140*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors); 141*89c4ff92SAndroid Build Coastguard Worker 142*89c4ff92SAndroid Build Coastguard Worker // Retrieve the Profiler.Print() output to get the workload execution 143*89c4ff92SAndroid Build Coastguard Worker ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance(); 144*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss; 145*89c4ff92SAndroid Build Coastguard Worker profilerManager.GetProfiler()->Print(ss);; 146*89c4ff92SAndroid Build Coastguard Worker std::string dump = ss.str(); 147*89c4ff92SAndroid Build Coastguard Worker 148*89c4ff92SAndroid Build Coastguard Worker // Executed Subtraction using CpuAcc 149*89c4ff92SAndroid Build Coastguard Worker std::size_t found = dump.find("NeonSubtractionWorkload_Execute"); 150*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 151*89c4ff92SAndroid Build Coastguard Worker 152*89c4ff92SAndroid Build Coastguard Worker // Contain CopyMemGeneric 153*89c4ff92SAndroid Build Coastguard Worker found = dump.find("CopyMemGeneric"); 154*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 155*89c4ff92SAndroid Build Coastguard Worker 156*89c4ff92SAndroid Build Coastguard Worker // Check output is as expected 157*89c4ff92SAndroid Build Coastguard Worker CHECK(outputData == expectedOutput); 158*89c4ff92SAndroid Build Coastguard Worker 159*89c4ff92SAndroid Build Coastguard Worker runtime->UnloadNetwork(netId); 160*89c4ff92SAndroid Build Coastguard Worker } 161*89c4ff92SAndroid Build Coastguard Worker 162*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ClImportDisabledFallbackToNeon") 163*89c4ff92SAndroid Build Coastguard Worker { 164*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 165*89c4ff92SAndroid Build Coastguard Worker 166*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; 167*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options)); 168*89c4ff92SAndroid Build Coastguard Worker 169*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network. 170*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create()); 171*89c4ff92SAndroid Build Coastguard Worker 172*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input0 = net->AddInputLayer(0, "input0"); 173*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input1 = net->AddInputLayer(1, "input1"); 174*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input2 = net->AddInputLayer(2, "input2"); 175*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add"); 176*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub"); 177*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output"); 178*89c4ff92SAndroid Build Coastguard Worker 179*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).Connect(add->GetInputSlot(0)); 180*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).Connect(add->GetInputSlot(1)); 181*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0)); 182*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).Connect(sub->GetInputSlot(1)); 183*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).Connect(output->GetInputSlot(0)); 184*89c4ff92SAndroid Build Coastguard Worker 185*89c4ff92SAndroid Build Coastguard Worker TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32); 186*89c4ff92SAndroid Build Coastguard Worker info.SetConstant(true); 187*89c4ff92SAndroid Build Coastguard Worker 188*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).SetTensorInfo(info); 189*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).SetTensorInfo(info); 190*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).SetTensorInfo(info); 191*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).SetTensorInfo(info); 192*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).SetTensorInfo(info); 193*89c4ff92SAndroid Build Coastguard Worker 194*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends = { Compute::GpuAcc, Compute::CpuAcc }; 195*89c4ff92SAndroid Build Coastguard Worker // Use BackendSelectionHint to specify CpuAcc for Subtraction layer 196*89c4ff92SAndroid Build Coastguard Worker sub->BackendSelectionHint(backends[1]); 197*89c4ff92SAndroid Build Coastguard Worker 198*89c4ff92SAndroid Build Coastguard Worker // optimize the network 199*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque optOptions; 200*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions); 201*89c4ff92SAndroid Build Coastguard Worker 202*89c4ff92SAndroid Build Coastguard Worker Graph& graph = GetGraphForTesting(optNet.get()); 203*89c4ff92SAndroid Build Coastguard Worker 204*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0"); 205*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1"); 206*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2"); 207*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add"); 208*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]"); 209*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub"); 210*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "output"); 211*89c4ff92SAndroid Build Coastguard Worker 212*89c4ff92SAndroid Build Coastguard Worker // Checks order is valid. 213*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer0, layer1)); 214*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer1, layer2)); 215*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer2, layer3)); 216*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer3, layer4)); 217*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer4, layer5)); 218*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer5, layer6)); 219*89c4ff92SAndroid Build Coastguard Worker 220*89c4ff92SAndroid Build Coastguard Worker // Use memory import between backends 221*89c4ff92SAndroid Build Coastguard Worker CHECK((layer4->GetType() == LayerType::MemCopy)); 222*89c4ff92SAndroid Build Coastguard Worker 223*89c4ff92SAndroid Build Coastguard Worker // Correctly use backend hint 224*89c4ff92SAndroid Build Coastguard Worker CHECK((layer5->GetBackendId() == Compute::CpuAcc )); 225*89c4ff92SAndroid Build Coastguard Worker 226*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime. It should pass. 227*89c4ff92SAndroid Build Coastguard Worker NetworkId netId; 228*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(netId, std::move(optNet)); 229*89c4ff92SAndroid Build Coastguard Worker 230*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output 231*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData0 232*89c4ff92SAndroid Build Coastguard Worker { 233*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f 234*89c4ff92SAndroid Build Coastguard Worker }; 235*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData1 236*89c4ff92SAndroid Build Coastguard Worker { 237*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f 238*89c4ff92SAndroid Build Coastguard Worker }; 239*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData2 240*89c4ff92SAndroid Build Coastguard Worker { 241*89c4ff92SAndroid Build Coastguard Worker 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f 242*89c4ff92SAndroid Build Coastguard Worker }; 243*89c4ff92SAndroid Build Coastguard Worker 244*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(12); 245*89c4ff92SAndroid Build Coastguard Worker 246*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput 247*89c4ff92SAndroid Build Coastguard Worker { 248*89c4ff92SAndroid Build Coastguard Worker 11.0f, 9.0f, 7.0f, 5.0f, 3.0f, 1.0f, -1.0f, -3.0f, -5.0f, -7.0f, -9.0f, -11.0f 249*89c4ff92SAndroid Build Coastguard Worker }; 250*89c4ff92SAndroid Build Coastguard Worker 251*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors 252*89c4ff92SAndroid Build Coastguard Worker { 253*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData0.data()) }, 254*89c4ff92SAndroid Build Coastguard Worker { 1, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 1), inputData1.data()) }, 255*89c4ff92SAndroid Build Coastguard Worker { 2, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 2), inputData2.data()) } 256*89c4ff92SAndroid Build Coastguard Worker }; 257*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors 258*89c4ff92SAndroid Build Coastguard Worker { 259*89c4ff92SAndroid Build Coastguard Worker { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) } 260*89c4ff92SAndroid Build Coastguard Worker }; 261*89c4ff92SAndroid Build Coastguard Worker 262*89c4ff92SAndroid Build Coastguard Worker runtime->GetProfiler(netId)->EnableProfiling(true); 263*89c4ff92SAndroid Build Coastguard Worker 264*89c4ff92SAndroid Build Coastguard Worker // Do the inference 265*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors); 266*89c4ff92SAndroid Build Coastguard Worker 267*89c4ff92SAndroid Build Coastguard Worker // Retrieve the Profiler.Print() output to get the workload execution 268*89c4ff92SAndroid Build Coastguard Worker ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance(); 269*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss; 270*89c4ff92SAndroid Build Coastguard Worker profilerManager.GetProfiler()->Print(ss);; 271*89c4ff92SAndroid Build Coastguard Worker std::string dump = ss.str(); 272*89c4ff92SAndroid Build Coastguard Worker 273*89c4ff92SAndroid Build Coastguard Worker // Executed Subtraction using CpuAcc 274*89c4ff92SAndroid Build Coastguard Worker std::size_t found = dump.find("NeonSubtractionWorkload_Execute"); 275*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 276*89c4ff92SAndroid Build Coastguard Worker 277*89c4ff92SAndroid Build Coastguard Worker // Contain CopyMemGeneric 278*89c4ff92SAndroid Build Coastguard Worker found = dump.find("CopyMemGeneric"); 279*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 280*89c4ff92SAndroid Build Coastguard Worker 281*89c4ff92SAndroid Build Coastguard Worker // Check output is as expected 282*89c4ff92SAndroid Build Coastguard Worker CHECK(outputData == expectedOutput); 283*89c4ff92SAndroid Build Coastguard Worker } 284*89c4ff92SAndroid Build Coastguard Worker 285*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ClImportEnabledFallbackSubgraphToNeon") 286*89c4ff92SAndroid Build Coastguard Worker { 287*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 288*89c4ff92SAndroid Build Coastguard Worker 289*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; 290*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options)); 291*89c4ff92SAndroid Build Coastguard Worker 292*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network. 293*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create()); 294*89c4ff92SAndroid Build Coastguard Worker 295*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor desc; 296*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolWidth = 2; 297*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolHeight = 2; 298*89c4ff92SAndroid Build Coastguard Worker desc.m_StrideX = 2; 299*89c4ff92SAndroid Build Coastguard Worker desc.m_StrideY = 2; 300*89c4ff92SAndroid Build Coastguard Worker 301*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input0 = net->AddInputLayer(0, "input0"); 302*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input1 = net->AddInputLayer(1, "input1"); 303*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input2 = net->AddInputLayer(2, "input2"); 304*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add"); 305*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub"); 306*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* pooling = net->AddPooling2dLayer(desc, "pooling"); 307*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output"); 308*89c4ff92SAndroid Build Coastguard Worker 309*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).Connect(add->GetInputSlot(0)); 310*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).Connect(add->GetInputSlot(1)); 311*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0)); 312*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).Connect(sub->GetInputSlot(1)); 313*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).Connect(pooling->GetInputSlot(0)); 314*89c4ff92SAndroid Build Coastguard Worker pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0)); 315*89c4ff92SAndroid Build Coastguard Worker 316*89c4ff92SAndroid Build Coastguard Worker TensorInfo info = TensorInfo({ 1, 2, 4, 2 }, DataType::Float32); 317*89c4ff92SAndroid Build Coastguard Worker info.SetConstant(true); 318*89c4ff92SAndroid Build Coastguard Worker TensorInfo poolingInfo = TensorInfo({ 1, 2, 2, 1 }, DataType::Float32); 319*89c4ff92SAndroid Build Coastguard Worker 320*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).SetTensorInfo(info); 321*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).SetTensorInfo(info); 322*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).SetTensorInfo(info); 323*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).SetTensorInfo(info); 324*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).SetTensorInfo(info); 325*89c4ff92SAndroid Build Coastguard Worker pooling->GetOutputSlot(0).SetTensorInfo(poolingInfo); 326*89c4ff92SAndroid Build Coastguard Worker 327*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends = { Compute::GpuAcc, Compute::CpuAcc }; 328*89c4ff92SAndroid Build Coastguard Worker // Use BackendSelectionHint to specify CpuAcc for Subtraction layer 329*89c4ff92SAndroid Build Coastguard Worker sub->BackendSelectionHint(backends[1]); 330*89c4ff92SAndroid Build Coastguard Worker 331*89c4ff92SAndroid Build Coastguard Worker // optimize the network 332*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque optOptions; 333*89c4ff92SAndroid Build Coastguard Worker optOptions.SetImportEnabled(true); 334*89c4ff92SAndroid Build Coastguard Worker optOptions.SetExportEnabled(true); 335*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions); 336*89c4ff92SAndroid Build Coastguard Worker 337*89c4ff92SAndroid Build Coastguard Worker Graph& graph = GetGraphForTesting(optNet.get()); 338*89c4ff92SAndroid Build Coastguard Worker 339*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0"); 340*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1"); 341*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2"); 342*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add"); 343*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]"); 344*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub"); 345*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "[ sub (0) -> pooling (0) ]"); 346*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer7 = GetFirstLayerWithName(graph, "pooling"); 347*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer8 = GetFirstLayerWithName(graph, "output"); 348*89c4ff92SAndroid Build Coastguard Worker 349*89c4ff92SAndroid Build Coastguard Worker // Checks order is valid. 350*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer0, layer1)); 351*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer1, layer2)); 352*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer2, layer3)); 353*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer3, layer4)); 354*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer4, layer5)); 355*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer5, layer6)); 356*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer6, layer7)); 357*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer7, layer8)); 358*89c4ff92SAndroid Build Coastguard Worker 359*89c4ff92SAndroid Build Coastguard Worker // Use memory import between backends 360*89c4ff92SAndroid Build Coastguard Worker CHECK((layer4->GetType() == LayerType::MemCopy)); 361*89c4ff92SAndroid Build Coastguard Worker CHECK((layer6->GetType() == LayerType::MemCopy)); 362*89c4ff92SAndroid Build Coastguard Worker 363*89c4ff92SAndroid Build Coastguard Worker // Correctly use backend hint 364*89c4ff92SAndroid Build Coastguard Worker CHECK((layer5->GetBackendId() == Compute::CpuAcc )); 365*89c4ff92SAndroid Build Coastguard Worker 366*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime. It should pass. 367*89c4ff92SAndroid Build Coastguard Worker NetworkId netId; 368*89c4ff92SAndroid Build Coastguard Worker std::string ignoredErrorMessage; 369*89c4ff92SAndroid Build Coastguard Worker INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); 370*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); 371*89c4ff92SAndroid Build Coastguard Worker 372*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output 373*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValue0 374*89c4ff92SAndroid Build Coastguard Worker { 375*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f, 1.0f, 1.0f, 2.0f, 2.0f 376*89c4ff92SAndroid Build Coastguard Worker }; 377*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValue1 378*89c4ff92SAndroid Build Coastguard Worker { 379*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 0.0f, 1.0f, 1.0f, 2.0f 380*89c4ff92SAndroid Build Coastguard Worker }; 381*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData2 382*89c4ff92SAndroid Build Coastguard Worker { 383*89c4ff92SAndroid Build Coastguard Worker 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 12.0f, 11.0f, 10.0f, 9.0f 384*89c4ff92SAndroid Build Coastguard Worker }; 385*89c4ff92SAndroid Build Coastguard Worker 386*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(4); 387*89c4ff92SAndroid Build Coastguard Worker 388*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput{ 11.0f, 3.0f, -5.0f, 11.0f }; 389*89c4ff92SAndroid Build Coastguard Worker 390*89c4ff92SAndroid Build Coastguard Worker unsigned int numElements = info.GetNumElements(); 391*89c4ff92SAndroid Build Coastguard Worker size_t totalBytes = numElements * sizeof(float); 392*89c4ff92SAndroid Build Coastguard Worker const size_t alignment = 64; 393*89c4ff92SAndroid Build Coastguard Worker size_t space = totalBytes + alignment + alignment; 394*89c4ff92SAndroid Build Coastguard Worker auto inputData0 = std::make_unique<uint8_t[]>(space); 395*89c4ff92SAndroid Build Coastguard Worker void* alignedInputPtr0 = inputData0.get(); 396*89c4ff92SAndroid Build Coastguard Worker CHECK(std::align(alignment, totalBytes, alignedInputPtr0, space)); 397*89c4ff92SAndroid Build Coastguard Worker 398*89c4ff92SAndroid Build Coastguard Worker auto* intputPtr0 = reinterpret_cast<float*>(alignedInputPtr0); 399*89c4ff92SAndroid Build Coastguard Worker std::copy(inputValue0.begin(), inputValue0.end(), intputPtr0); 400*89c4ff92SAndroid Build Coastguard Worker 401*89c4ff92SAndroid Build Coastguard Worker auto inputData1 = std::make_unique<uint8_t[]>(space); 402*89c4ff92SAndroid Build Coastguard Worker void* alignedInputPtr1 = inputData1.get(); 403*89c4ff92SAndroid Build Coastguard Worker CHECK(std::align(alignment, totalBytes, alignedInputPtr1, space)); 404*89c4ff92SAndroid Build Coastguard Worker 405*89c4ff92SAndroid Build Coastguard Worker auto* intputPtr1 = reinterpret_cast<float*>(alignedInputPtr1); 406*89c4ff92SAndroid Build Coastguard Worker std::copy(inputValue1.begin(), inputValue1.end(), intputPtr1); 407*89c4ff92SAndroid Build Coastguard Worker 408*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors 409*89c4ff92SAndroid Build Coastguard Worker { 410*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), alignedInputPtr0) }, 411*89c4ff92SAndroid Build Coastguard Worker { 1, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 1), alignedInputPtr1) }, 412*89c4ff92SAndroid Build Coastguard Worker { 2, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 2), inputData2.data()) } 413*89c4ff92SAndroid Build Coastguard Worker }; 414*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors 415*89c4ff92SAndroid Build Coastguard Worker { 416*89c4ff92SAndroid Build Coastguard Worker { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) } 417*89c4ff92SAndroid Build Coastguard Worker }; 418*89c4ff92SAndroid Build Coastguard Worker 419*89c4ff92SAndroid Build Coastguard Worker runtime->GetProfiler(netId)->EnableProfiling(true); 420*89c4ff92SAndroid Build Coastguard Worker 421*89c4ff92SAndroid Build Coastguard Worker // Do the inference 422*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors); 423*89c4ff92SAndroid Build Coastguard Worker 424*89c4ff92SAndroid Build Coastguard Worker // Retrieve the Profiler.Print() output to get the workload execution 425*89c4ff92SAndroid Build Coastguard Worker ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance(); 426*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss; 427*89c4ff92SAndroid Build Coastguard Worker profilerManager.GetProfiler()->Print(ss);; 428*89c4ff92SAndroid Build Coastguard Worker std::string dump = ss.str(); 429*89c4ff92SAndroid Build Coastguard Worker 430*89c4ff92SAndroid Build Coastguard Worker // Executed Subtraction using CpuAcc 431*89c4ff92SAndroid Build Coastguard Worker std::size_t found = dump.find("NeonSubtractionWorkload_Execute"); 432*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 433*89c4ff92SAndroid Build Coastguard Worker 434*89c4ff92SAndroid Build Coastguard Worker // Correctly switch back to GpuAcc 435*89c4ff92SAndroid Build Coastguard Worker found = dump.find("ClPooling2dWorkload_Execute"); 436*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 437*89c4ff92SAndroid Build Coastguard Worker 438*89c4ff92SAndroid Build Coastguard Worker // Contain CopyMemGeneric 439*89c4ff92SAndroid Build Coastguard Worker found = dump.find("CopyMemGeneric"); 440*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 441*89c4ff92SAndroid Build Coastguard Worker 442*89c4ff92SAndroid Build Coastguard Worker // Check output is as expected 443*89c4ff92SAndroid Build Coastguard Worker CHECK(outputData == expectedOutput); 444*89c4ff92SAndroid Build Coastguard Worker 445*89c4ff92SAndroid Build Coastguard Worker runtime->UnloadNetwork(netId); 446*89c4ff92SAndroid Build Coastguard Worker } 447*89c4ff92SAndroid Build Coastguard Worker 448*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ClImportDisableFallbackSubgraphToNeon") 449*89c4ff92SAndroid Build Coastguard Worker { 450*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 451*89c4ff92SAndroid Build Coastguard Worker 452*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; 453*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options)); 454*89c4ff92SAndroid Build Coastguard Worker 455*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network. 456*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create()); 457*89c4ff92SAndroid Build Coastguard Worker 458*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor desc; 459*89c4ff92SAndroid Build Coastguard Worker 460*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input0 = net->AddInputLayer(0, "input0"); 461*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input1 = net->AddInputLayer(1, "input1"); 462*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input2 = net->AddInputLayer(2, "input2"); 463*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add"); 464*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub"); 465*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* pooling = net->AddPooling2dLayer(desc, "pooling"); 466*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output"); 467*89c4ff92SAndroid Build Coastguard Worker 468*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).Connect(add->GetInputSlot(0)); 469*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).Connect(add->GetInputSlot(1)); 470*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0)); 471*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).Connect(sub->GetInputSlot(1)); 472*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).Connect(pooling->GetInputSlot(0)); 473*89c4ff92SAndroid Build Coastguard Worker pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0)); 474*89c4ff92SAndroid Build Coastguard Worker 475*89c4ff92SAndroid Build Coastguard Worker TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32); 476*89c4ff92SAndroid Build Coastguard Worker info.SetConstant(true); 477*89c4ff92SAndroid Build Coastguard Worker TensorInfo poolingInfo = TensorInfo({ 1, 2, 1, 1 }, DataType::Float32); 478*89c4ff92SAndroid Build Coastguard Worker 479*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).SetTensorInfo(info); 480*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).SetTensorInfo(info); 481*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).SetTensorInfo(info); 482*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).SetTensorInfo(info); 483*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).SetTensorInfo(info); 484*89c4ff92SAndroid Build Coastguard Worker pooling->GetOutputSlot(0).SetTensorInfo(poolingInfo); 485*89c4ff92SAndroid Build Coastguard Worker 486*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends = { Compute::GpuAcc, Compute::CpuAcc }; 487*89c4ff92SAndroid Build Coastguard Worker // Use BackendSelectionHint to specify CpuAcc for Subtraction layer 488*89c4ff92SAndroid Build Coastguard Worker sub->BackendSelectionHint(backends[1]); 489*89c4ff92SAndroid Build Coastguard Worker 490*89c4ff92SAndroid Build Coastguard Worker // optimize the network 491*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque optOptions; 492*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions); 493*89c4ff92SAndroid Build Coastguard Worker 494*89c4ff92SAndroid Build Coastguard Worker Graph& graph = GetGraphForTesting(optNet.get()); 495*89c4ff92SAndroid Build Coastguard Worker 496*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0"); 497*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1"); 498*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2"); 499*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add"); 500*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]"); 501*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub"); 502*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "[ sub (0) -> pooling (0) ]"); 503*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer7 = GetFirstLayerWithName(graph, "pooling"); 504*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer8 = GetFirstLayerWithName(graph, "output"); 505*89c4ff92SAndroid Build Coastguard Worker 506*89c4ff92SAndroid Build Coastguard Worker // Checks order is valid. 507*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer0, layer1)); 508*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer1, layer2)); 509*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer2, layer3)); 510*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer3, layer4)); 511*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer4, layer5)); 512*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer5, layer6)); 513*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer6, layer7)); 514*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer7, layer8)); 515*89c4ff92SAndroid Build Coastguard Worker 516*89c4ff92SAndroid Build Coastguard Worker // Use memory import between backends 517*89c4ff92SAndroid Build Coastguard Worker CHECK((layer4->GetType() == LayerType::MemCopy)); 518*89c4ff92SAndroid Build Coastguard Worker CHECK((layer6->GetType() == LayerType::MemCopy)); 519*89c4ff92SAndroid Build Coastguard Worker 520*89c4ff92SAndroid Build Coastguard Worker // Correctly use backend hint 521*89c4ff92SAndroid Build Coastguard Worker CHECK((layer5->GetBackendId() == Compute::CpuAcc )); 522*89c4ff92SAndroid Build Coastguard Worker 523*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime. It should pass. 524*89c4ff92SAndroid Build Coastguard Worker NetworkId netId; 525*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(netId, std::move(optNet)); 526*89c4ff92SAndroid Build Coastguard Worker 527*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output 528*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData0 529*89c4ff92SAndroid Build Coastguard Worker { 530*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f 531*89c4ff92SAndroid Build Coastguard Worker }; 532*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData1 533*89c4ff92SAndroid Build Coastguard Worker { 534*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f 535*89c4ff92SAndroid Build Coastguard Worker }; 536*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData2 537*89c4ff92SAndroid Build Coastguard Worker { 538*89c4ff92SAndroid Build Coastguard Worker 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f 539*89c4ff92SAndroid Build Coastguard Worker }; 540*89c4ff92SAndroid Build Coastguard Worker 541*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(2); 542*89c4ff92SAndroid Build Coastguard Worker 543*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput{ 11.0f, -1.0f }; 544*89c4ff92SAndroid Build Coastguard Worker 545*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors 546*89c4ff92SAndroid Build Coastguard Worker { 547*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData0.data()) }, 548*89c4ff92SAndroid Build Coastguard Worker { 1, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 1), inputData1.data()) }, 549*89c4ff92SAndroid Build Coastguard Worker { 2, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 2), inputData2.data()) } 550*89c4ff92SAndroid Build Coastguard Worker }; 551*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors 552*89c4ff92SAndroid Build Coastguard Worker { 553*89c4ff92SAndroid Build Coastguard Worker { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) } 554*89c4ff92SAndroid Build Coastguard Worker }; 555*89c4ff92SAndroid Build Coastguard Worker 556*89c4ff92SAndroid Build Coastguard Worker runtime->GetProfiler(netId)->EnableProfiling(true); 557*89c4ff92SAndroid Build Coastguard Worker 558*89c4ff92SAndroid Build Coastguard Worker // Do the inference 559*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors); 560*89c4ff92SAndroid Build Coastguard Worker 561*89c4ff92SAndroid Build Coastguard Worker // Retrieve the Profiler.Print() output to get the workload execution 562*89c4ff92SAndroid Build Coastguard Worker ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance(); 563*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss; 564*89c4ff92SAndroid Build Coastguard Worker profilerManager.GetProfiler()->Print(ss);; 565*89c4ff92SAndroid Build Coastguard Worker std::string dump = ss.str(); 566*89c4ff92SAndroid Build Coastguard Worker 567*89c4ff92SAndroid Build Coastguard Worker // Executed Subtraction using CpuAcc 568*89c4ff92SAndroid Build Coastguard Worker std::size_t found = dump.find("NeonSubtractionWorkload_Execute"); 569*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 570*89c4ff92SAndroid Build Coastguard Worker 571*89c4ff92SAndroid Build Coastguard Worker // Correctly switch back to GpuAcc 572*89c4ff92SAndroid Build Coastguard Worker found = dump.find("ClPooling2dWorkload_Execute"); 573*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 574*89c4ff92SAndroid Build Coastguard Worker 575*89c4ff92SAndroid Build Coastguard Worker // Contain CopyMemGeneric 576*89c4ff92SAndroid Build Coastguard Worker found = dump.find("CopyMemGeneric"); 577*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 578*89c4ff92SAndroid Build Coastguard Worker 579*89c4ff92SAndroid Build Coastguard Worker // Check output is as expected 580*89c4ff92SAndroid Build Coastguard Worker CHECK(outputData == expectedOutput); 581*89c4ff92SAndroid Build Coastguard Worker } 582*89c4ff92SAndroid Build Coastguard Worker 583*89c4ff92SAndroid Build Coastguard Worker } 584