xref: /aosp_15_r20/external/armnn/src/backends/cl/test/ClFallbackTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <GraphUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("ClFallback")
13*89c4ff92SAndroid Build Coastguard Worker {
14*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ClImportEnabledFallbackToNeon")
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
19*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
22*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
25*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input1 = net->AddInputLayer(1, "input1");
26*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input2 = net->AddInputLayer(2, "input2");
27*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add");
28*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub");
29*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).Connect(add->GetInputSlot(0));
32*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).Connect(add->GetInputSlot(1));
33*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0));
34*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(sub->GetInputSlot(1));
35*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).Connect(output->GetInputSlot(0));
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker     TensorInfo info = TensorInfo({ 1, 2, 4, 2 }, DataType::Float32);
38*89c4ff92SAndroid Build Coastguard Worker     info.SetConstant(true);
39*89c4ff92SAndroid Build Coastguard Worker 
40*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).SetTensorInfo(info);
41*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).SetTensorInfo(info);
42*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).SetTensorInfo(info);
43*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(info);
44*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).SetTensorInfo(info);
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> backends = { Compute::GpuAcc, Compute::CpuAcc };
47*89c4ff92SAndroid Build Coastguard Worker     // Use BackendSelectionHint to specify CpuAcc for Subtraction layer
48*89c4ff92SAndroid Build Coastguard Worker     sub->BackendSelectionHint(backends[1]);
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
51*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optOptions;
52*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetImportEnabled(true);
53*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetExportEnabled(true);
54*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions);
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = GetGraphForTesting(optNet.get());
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0");
59*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1");
60*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2");
61*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add");
62*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]");
63*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub");
64*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "output");
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker     // Checks order is valid.
67*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer0, layer1));
68*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer1, layer2));
69*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer2, layer3));
70*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer3, layer4));
71*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer4, layer5));
72*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer5, layer6));
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     // Use memory import between backends
75*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer4->GetType() == LayerType::MemCopy));
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     // Correctly use backend hint
78*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer5->GetBackendId() == Compute::CpuAcc ));
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
81*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
82*89c4ff92SAndroid Build Coastguard Worker     std::string ignoredErrorMessage;
83*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
84*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties);
85*89c4ff92SAndroid Build Coastguard Worker 
86*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
87*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputValue0
88*89c4ff92SAndroid Build Coastguard Worker     {
89*89c4ff92SAndroid Build Coastguard Worker         1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f, 1.0f, 1.0f, 2.0f, 2.0f
90*89c4ff92SAndroid Build Coastguard Worker     };
91*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputValue1
92*89c4ff92SAndroid Build Coastguard Worker     {
93*89c4ff92SAndroid Build Coastguard Worker         0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 0.0f, 1.0f, 1.0f, 2.0f
94*89c4ff92SAndroid Build Coastguard Worker     };
95*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData2
96*89c4ff92SAndroid Build Coastguard Worker     {
97*89c4ff92SAndroid Build Coastguard Worker         12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 12.0f, 11.0f, 10.0f, 9.0f
98*89c4ff92SAndroid Build Coastguard Worker     };
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(16);
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
103*89c4ff92SAndroid Build Coastguard Worker     {
104*89c4ff92SAndroid Build Coastguard Worker         11.0f, 9.0f, 7.0f, 5.0f, 3.0f, 1.0f, -1.0f, -3.0f, -5.0f, -7.0f, -9.0f, -11.0f, 11.0f, 9.0f, 7.0f, 5.0f
105*89c4ff92SAndroid Build Coastguard Worker     };
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker     // Prepare aligned data
108*89c4ff92SAndroid Build Coastguard Worker     unsigned int numElements = info.GetNumElements();
109*89c4ff92SAndroid Build Coastguard Worker     size_t totalBytes = numElements * sizeof(float);
110*89c4ff92SAndroid Build Coastguard Worker     const size_t alignment = 64;
111*89c4ff92SAndroid Build Coastguard Worker     size_t space = totalBytes + alignment + alignment;
112*89c4ff92SAndroid Build Coastguard Worker     auto inputData0 = std::make_unique<uint8_t[]>(space);
113*89c4ff92SAndroid Build Coastguard Worker     void* alignedInputPtr0 = inputData0.get();
114*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::align(alignment, totalBytes, alignedInputPtr0, space));
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker     auto* intputPtr0 = reinterpret_cast<float*>(alignedInputPtr0);
117*89c4ff92SAndroid Build Coastguard Worker     std::copy(inputValue0.begin(), inputValue0.end(), intputPtr0);
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker     auto inputData1 = std::make_unique<uint8_t[]>(space);
120*89c4ff92SAndroid Build Coastguard Worker     void* alignedInputPtr1 = inputData1.get();
121*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::align(alignment, totalBytes, alignedInputPtr1, space));
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker     auto* intputPtr1 = reinterpret_cast<float*>(alignedInputPtr1);
124*89c4ff92SAndroid Build Coastguard Worker     std::copy(inputValue1.begin(), inputValue1.end(), intputPtr1);
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
127*89c4ff92SAndroid Build Coastguard Worker     {
128*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), alignedInputPtr0) },
129*89c4ff92SAndroid Build Coastguard Worker         { 1, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 1), alignedInputPtr1) },
130*89c4ff92SAndroid Build Coastguard Worker         { 2, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 2), inputData2.data()) }
131*89c4ff92SAndroid Build Coastguard Worker     };
132*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
133*89c4ff92SAndroid Build Coastguard Worker     {
134*89c4ff92SAndroid Build Coastguard Worker         { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) }
135*89c4ff92SAndroid Build Coastguard Worker     };
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
138*89c4ff92SAndroid Build Coastguard Worker 
139*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
140*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
141*89c4ff92SAndroid Build Coastguard Worker 
142*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
143*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
144*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
145*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);;
146*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker     // Executed Subtraction using CpuAcc
149*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = dump.find("NeonSubtractionWorkload_Execute");
150*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
151*89c4ff92SAndroid Build Coastguard Worker 
152*89c4ff92SAndroid Build Coastguard Worker     // Contain CopyMemGeneric
153*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("CopyMemGeneric");
154*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
155*89c4ff92SAndroid Build Coastguard Worker 
156*89c4ff92SAndroid Build Coastguard Worker     // Check output is as expected
157*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputData == expectedOutput);
158*89c4ff92SAndroid Build Coastguard Worker 
159*89c4ff92SAndroid Build Coastguard Worker     runtime->UnloadNetwork(netId);
160*89c4ff92SAndroid Build Coastguard Worker }
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ClImportDisabledFallbackToNeon")
163*89c4ff92SAndroid Build Coastguard Worker {
164*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
165*89c4ff92SAndroid Build Coastguard Worker 
166*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
167*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
170*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
171*89c4ff92SAndroid Build Coastguard Worker 
172*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
173*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input1 = net->AddInputLayer(1, "input1");
174*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input2 = net->AddInputLayer(2, "input2");
175*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add");
176*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub");
177*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
178*89c4ff92SAndroid Build Coastguard Worker 
179*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).Connect(add->GetInputSlot(0));
180*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).Connect(add->GetInputSlot(1));
181*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0));
182*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(sub->GetInputSlot(1));
183*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).Connect(output->GetInputSlot(0));
184*89c4ff92SAndroid Build Coastguard Worker 
185*89c4ff92SAndroid Build Coastguard Worker     TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32);
186*89c4ff92SAndroid Build Coastguard Worker     info.SetConstant(true);
187*89c4ff92SAndroid Build Coastguard Worker 
188*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).SetTensorInfo(info);
189*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).SetTensorInfo(info);
190*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).SetTensorInfo(info);
191*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(info);
192*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).SetTensorInfo(info);
193*89c4ff92SAndroid Build Coastguard Worker 
194*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> backends = { Compute::GpuAcc, Compute::CpuAcc };
195*89c4ff92SAndroid Build Coastguard Worker     // Use BackendSelectionHint to specify CpuAcc for Subtraction layer
196*89c4ff92SAndroid Build Coastguard Worker     sub->BackendSelectionHint(backends[1]);
197*89c4ff92SAndroid Build Coastguard Worker 
198*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
199*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optOptions;
200*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions);
201*89c4ff92SAndroid Build Coastguard Worker 
202*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = GetGraphForTesting(optNet.get());
203*89c4ff92SAndroid Build Coastguard Worker 
204*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0");
205*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1");
206*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2");
207*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add");
208*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]");
209*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub");
210*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "output");
211*89c4ff92SAndroid Build Coastguard Worker 
212*89c4ff92SAndroid Build Coastguard Worker     // Checks order is valid.
213*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer0, layer1));
214*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer1, layer2));
215*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer2, layer3));
216*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer3, layer4));
217*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer4, layer5));
218*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer5, layer6));
219*89c4ff92SAndroid Build Coastguard Worker 
220*89c4ff92SAndroid Build Coastguard Worker     // Use memory import between backends
221*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer4->GetType() == LayerType::MemCopy));
222*89c4ff92SAndroid Build Coastguard Worker 
223*89c4ff92SAndroid Build Coastguard Worker     // Correctly use backend hint
224*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer5->GetBackendId() == Compute::CpuAcc ));
225*89c4ff92SAndroid Build Coastguard Worker 
226*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
227*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
228*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(netId, std::move(optNet));
229*89c4ff92SAndroid Build Coastguard Worker 
230*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
231*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData0
232*89c4ff92SAndroid Build Coastguard Worker     {
233*89c4ff92SAndroid Build Coastguard Worker         1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f
234*89c4ff92SAndroid Build Coastguard Worker     };
235*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData1
236*89c4ff92SAndroid Build Coastguard Worker     {
237*89c4ff92SAndroid Build Coastguard Worker         0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f
238*89c4ff92SAndroid Build Coastguard Worker     };
239*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData2
240*89c4ff92SAndroid Build Coastguard Worker     {
241*89c4ff92SAndroid Build Coastguard Worker         12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f
242*89c4ff92SAndroid Build Coastguard Worker     };
243*89c4ff92SAndroid Build Coastguard Worker 
244*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(12);
245*89c4ff92SAndroid Build Coastguard Worker 
246*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput
247*89c4ff92SAndroid Build Coastguard Worker     {
248*89c4ff92SAndroid Build Coastguard Worker         11.0f, 9.0f, 7.0f, 5.0f, 3.0f, 1.0f, -1.0f, -3.0f, -5.0f, -7.0f, -9.0f, -11.0f
249*89c4ff92SAndroid Build Coastguard Worker     };
250*89c4ff92SAndroid Build Coastguard Worker 
251*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
252*89c4ff92SAndroid Build Coastguard Worker     {
253*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData0.data()) },
254*89c4ff92SAndroid Build Coastguard Worker         { 1, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 1), inputData1.data()) },
255*89c4ff92SAndroid Build Coastguard Worker         { 2, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 2), inputData2.data()) }
256*89c4ff92SAndroid Build Coastguard Worker     };
257*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
258*89c4ff92SAndroid Build Coastguard Worker     {
259*89c4ff92SAndroid Build Coastguard Worker         { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) }
260*89c4ff92SAndroid Build Coastguard Worker     };
261*89c4ff92SAndroid Build Coastguard Worker 
262*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
263*89c4ff92SAndroid Build Coastguard Worker 
264*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
265*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
266*89c4ff92SAndroid Build Coastguard Worker 
267*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
268*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
269*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
270*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);;
271*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
272*89c4ff92SAndroid Build Coastguard Worker 
273*89c4ff92SAndroid Build Coastguard Worker     // Executed Subtraction using CpuAcc
274*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = dump.find("NeonSubtractionWorkload_Execute");
275*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
276*89c4ff92SAndroid Build Coastguard Worker 
277*89c4ff92SAndroid Build Coastguard Worker     // Contain CopyMemGeneric
278*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("CopyMemGeneric");
279*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
280*89c4ff92SAndroid Build Coastguard Worker 
281*89c4ff92SAndroid Build Coastguard Worker     // Check output is as expected
282*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputData == expectedOutput);
283*89c4ff92SAndroid Build Coastguard Worker }
284*89c4ff92SAndroid Build Coastguard Worker 
285*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ClImportEnabledFallbackSubgraphToNeon")
286*89c4ff92SAndroid Build Coastguard Worker {
287*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
288*89c4ff92SAndroid Build Coastguard Worker 
289*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
290*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
291*89c4ff92SAndroid Build Coastguard Worker 
292*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
293*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
294*89c4ff92SAndroid Build Coastguard Worker 
295*89c4ff92SAndroid Build Coastguard Worker     Pooling2dDescriptor desc;
296*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolWidth = 2;
297*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolHeight = 2;
298*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideX = 2;
299*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideY = 2;
300*89c4ff92SAndroid Build Coastguard Worker 
301*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
302*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input1 = net->AddInputLayer(1, "input1");
303*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input2 = net->AddInputLayer(2, "input2");
304*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add");
305*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub");
306*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* pooling = net->AddPooling2dLayer(desc, "pooling");
307*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
308*89c4ff92SAndroid Build Coastguard Worker 
309*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).Connect(add->GetInputSlot(0));
310*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).Connect(add->GetInputSlot(1));
311*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0));
312*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(sub->GetInputSlot(1));
313*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).Connect(pooling->GetInputSlot(0));
314*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0));
315*89c4ff92SAndroid Build Coastguard Worker 
316*89c4ff92SAndroid Build Coastguard Worker     TensorInfo info = TensorInfo({ 1, 2, 4, 2 }, DataType::Float32);
317*89c4ff92SAndroid Build Coastguard Worker     info.SetConstant(true);
318*89c4ff92SAndroid Build Coastguard Worker     TensorInfo poolingInfo = TensorInfo({ 1, 2, 2, 1 }, DataType::Float32);
319*89c4ff92SAndroid Build Coastguard Worker 
320*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).SetTensorInfo(info);
321*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).SetTensorInfo(info);
322*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).SetTensorInfo(info);
323*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(info);
324*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).SetTensorInfo(info);
325*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).SetTensorInfo(poolingInfo);
326*89c4ff92SAndroid Build Coastguard Worker 
327*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> backends = { Compute::GpuAcc, Compute::CpuAcc };
328*89c4ff92SAndroid Build Coastguard Worker     // Use BackendSelectionHint to specify CpuAcc for Subtraction layer
329*89c4ff92SAndroid Build Coastguard Worker     sub->BackendSelectionHint(backends[1]);
330*89c4ff92SAndroid Build Coastguard Worker 
331*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
332*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optOptions;
333*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetImportEnabled(true);
334*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetExportEnabled(true);
335*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions);
336*89c4ff92SAndroid Build Coastguard Worker 
337*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = GetGraphForTesting(optNet.get());
338*89c4ff92SAndroid Build Coastguard Worker 
339*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0");
340*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1");
341*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2");
342*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add");
343*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]");
344*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub");
345*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "[ sub (0) -> pooling (0) ]");
346*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer7 = GetFirstLayerWithName(graph, "pooling");
347*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer8 = GetFirstLayerWithName(graph, "output");
348*89c4ff92SAndroid Build Coastguard Worker 
349*89c4ff92SAndroid Build Coastguard Worker     // Checks order is valid.
350*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer0, layer1));
351*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer1, layer2));
352*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer2, layer3));
353*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer3, layer4));
354*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer4, layer5));
355*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer5, layer6));
356*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer6, layer7));
357*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer7, layer8));
358*89c4ff92SAndroid Build Coastguard Worker 
359*89c4ff92SAndroid Build Coastguard Worker     // Use memory import between backends
360*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer4->GetType() == LayerType::MemCopy));
361*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer6->GetType() == LayerType::MemCopy));
362*89c4ff92SAndroid Build Coastguard Worker 
363*89c4ff92SAndroid Build Coastguard Worker     // Correctly use backend hint
364*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer5->GetBackendId() == Compute::CpuAcc ));
365*89c4ff92SAndroid Build Coastguard Worker 
366*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
367*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
368*89c4ff92SAndroid Build Coastguard Worker     std::string ignoredErrorMessage;
369*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
370*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties);
371*89c4ff92SAndroid Build Coastguard Worker 
372*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
373*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputValue0
374*89c4ff92SAndroid Build Coastguard Worker     {
375*89c4ff92SAndroid Build Coastguard Worker         1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f, 1.0f, 1.0f, 2.0f, 2.0f
376*89c4ff92SAndroid Build Coastguard Worker     };
377*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputValue1
378*89c4ff92SAndroid Build Coastguard Worker     {
379*89c4ff92SAndroid Build Coastguard Worker         0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 0.0f, 1.0f, 1.0f, 2.0f
380*89c4ff92SAndroid Build Coastguard Worker     };
381*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData2
382*89c4ff92SAndroid Build Coastguard Worker     {
383*89c4ff92SAndroid Build Coastguard Worker         12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 12.0f, 11.0f, 10.0f, 9.0f
384*89c4ff92SAndroid Build Coastguard Worker     };
385*89c4ff92SAndroid Build Coastguard Worker 
386*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(4);
387*89c4ff92SAndroid Build Coastguard Worker 
388*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput{ 11.0f, 3.0f, -5.0f, 11.0f };
389*89c4ff92SAndroid Build Coastguard Worker 
390*89c4ff92SAndroid Build Coastguard Worker     unsigned int numElements = info.GetNumElements();
391*89c4ff92SAndroid Build Coastguard Worker     size_t totalBytes = numElements * sizeof(float);
392*89c4ff92SAndroid Build Coastguard Worker     const size_t alignment = 64;
393*89c4ff92SAndroid Build Coastguard Worker     size_t space = totalBytes + alignment + alignment;
394*89c4ff92SAndroid Build Coastguard Worker     auto inputData0 = std::make_unique<uint8_t[]>(space);
395*89c4ff92SAndroid Build Coastguard Worker     void* alignedInputPtr0 = inputData0.get();
396*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::align(alignment, totalBytes, alignedInputPtr0, space));
397*89c4ff92SAndroid Build Coastguard Worker 
398*89c4ff92SAndroid Build Coastguard Worker     auto* intputPtr0 = reinterpret_cast<float*>(alignedInputPtr0);
399*89c4ff92SAndroid Build Coastguard Worker     std::copy(inputValue0.begin(), inputValue0.end(), intputPtr0);
400*89c4ff92SAndroid Build Coastguard Worker 
401*89c4ff92SAndroid Build Coastguard Worker     auto inputData1 = std::make_unique<uint8_t[]>(space);
402*89c4ff92SAndroid Build Coastguard Worker     void* alignedInputPtr1 = inputData1.get();
403*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::align(alignment, totalBytes, alignedInputPtr1, space));
404*89c4ff92SAndroid Build Coastguard Worker 
405*89c4ff92SAndroid Build Coastguard Worker     auto* intputPtr1 = reinterpret_cast<float*>(alignedInputPtr1);
406*89c4ff92SAndroid Build Coastguard Worker     std::copy(inputValue1.begin(), inputValue1.end(), intputPtr1);
407*89c4ff92SAndroid Build Coastguard Worker 
408*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
409*89c4ff92SAndroid Build Coastguard Worker     {
410*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), alignedInputPtr0) },
411*89c4ff92SAndroid Build Coastguard Worker         { 1, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 1), alignedInputPtr1) },
412*89c4ff92SAndroid Build Coastguard Worker         { 2, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 2), inputData2.data()) }
413*89c4ff92SAndroid Build Coastguard Worker     };
414*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
415*89c4ff92SAndroid Build Coastguard Worker     {
416*89c4ff92SAndroid Build Coastguard Worker         { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) }
417*89c4ff92SAndroid Build Coastguard Worker     };
418*89c4ff92SAndroid Build Coastguard Worker 
419*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
420*89c4ff92SAndroid Build Coastguard Worker 
421*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
422*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
423*89c4ff92SAndroid Build Coastguard Worker 
424*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
425*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
426*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
427*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);;
428*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
429*89c4ff92SAndroid Build Coastguard Worker 
430*89c4ff92SAndroid Build Coastguard Worker     // Executed Subtraction using CpuAcc
431*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = dump.find("NeonSubtractionWorkload_Execute");
432*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
433*89c4ff92SAndroid Build Coastguard Worker 
434*89c4ff92SAndroid Build Coastguard Worker     // Correctly switch back to GpuAcc
435*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("ClPooling2dWorkload_Execute");
436*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
437*89c4ff92SAndroid Build Coastguard Worker 
438*89c4ff92SAndroid Build Coastguard Worker     // Contain CopyMemGeneric
439*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("CopyMemGeneric");
440*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
441*89c4ff92SAndroid Build Coastguard Worker 
442*89c4ff92SAndroid Build Coastguard Worker     // Check output is as expected
443*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputData == expectedOutput);
444*89c4ff92SAndroid Build Coastguard Worker 
445*89c4ff92SAndroid Build Coastguard Worker     runtime->UnloadNetwork(netId);
446*89c4ff92SAndroid Build Coastguard Worker }
447*89c4ff92SAndroid Build Coastguard Worker 
448*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ClImportDisableFallbackSubgraphToNeon")
449*89c4ff92SAndroid Build Coastguard Worker {
450*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
451*89c4ff92SAndroid Build Coastguard Worker 
452*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
453*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime(IRuntime::Create(options));
454*89c4ff92SAndroid Build Coastguard Worker 
455*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
456*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
457*89c4ff92SAndroid Build Coastguard Worker 
458*89c4ff92SAndroid Build Coastguard Worker     Pooling2dDescriptor desc;
459*89c4ff92SAndroid Build Coastguard Worker 
460*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input0 = net->AddInputLayer(0, "input0");
461*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input1 = net->AddInputLayer(1, "input1");
462*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input2 = net->AddInputLayer(2, "input2");
463*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add");
464*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub");
465*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* pooling = net->AddPooling2dLayer(desc, "pooling");
466*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
467*89c4ff92SAndroid Build Coastguard Worker 
468*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).Connect(add->GetInputSlot(0));
469*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).Connect(add->GetInputSlot(1));
470*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0));
471*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(sub->GetInputSlot(1));
472*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).Connect(pooling->GetInputSlot(0));
473*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0));
474*89c4ff92SAndroid Build Coastguard Worker 
475*89c4ff92SAndroid Build Coastguard Worker     TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32);
476*89c4ff92SAndroid Build Coastguard Worker     info.SetConstant(true);
477*89c4ff92SAndroid Build Coastguard Worker     TensorInfo poolingInfo = TensorInfo({ 1, 2, 1, 1 }, DataType::Float32);
478*89c4ff92SAndroid Build Coastguard Worker 
479*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot(0).SetTensorInfo(info);
480*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).SetTensorInfo(info);
481*89c4ff92SAndroid Build Coastguard Worker     input2->GetOutputSlot(0).SetTensorInfo(info);
482*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(info);
483*89c4ff92SAndroid Build Coastguard Worker     sub->GetOutputSlot(0).SetTensorInfo(info);
484*89c4ff92SAndroid Build Coastguard Worker     pooling->GetOutputSlot(0).SetTensorInfo(poolingInfo);
485*89c4ff92SAndroid Build Coastguard Worker 
486*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> backends = { Compute::GpuAcc, Compute::CpuAcc };
487*89c4ff92SAndroid Build Coastguard Worker     // Use BackendSelectionHint to specify CpuAcc for Subtraction layer
488*89c4ff92SAndroid Build Coastguard Worker     sub->BackendSelectionHint(backends[1]);
489*89c4ff92SAndroid Build Coastguard Worker 
490*89c4ff92SAndroid Build Coastguard Worker     // optimize the network
491*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optOptions;
492*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions);
493*89c4ff92SAndroid Build Coastguard Worker 
494*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = GetGraphForTesting(optNet.get());
495*89c4ff92SAndroid Build Coastguard Worker 
496*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0");
497*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1");
498*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2");
499*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add");
500*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]");
501*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub");
502*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "[ sub (0) -> pooling (0) ]");
503*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer7 = GetFirstLayerWithName(graph, "pooling");
504*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* const layer8 = GetFirstLayerWithName(graph, "output");
505*89c4ff92SAndroid Build Coastguard Worker 
506*89c4ff92SAndroid Build Coastguard Worker     // Checks order is valid.
507*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer0, layer1));
508*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer1, layer2));
509*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer2, layer3));
510*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer3, layer4));
511*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer4, layer5));
512*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer5, layer6));
513*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer6, layer7));
514*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckOrder(graph, layer7, layer8));
515*89c4ff92SAndroid Build Coastguard Worker 
516*89c4ff92SAndroid Build Coastguard Worker     // Use memory import between backends
517*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer4->GetType() == LayerType::MemCopy));
518*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer6->GetType() == LayerType::MemCopy));
519*89c4ff92SAndroid Build Coastguard Worker 
520*89c4ff92SAndroid Build Coastguard Worker     // Correctly use backend hint
521*89c4ff92SAndroid Build Coastguard Worker     CHECK((layer5->GetBackendId() == Compute::CpuAcc ));
522*89c4ff92SAndroid Build Coastguard Worker 
523*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime. It should pass.
524*89c4ff92SAndroid Build Coastguard Worker     NetworkId netId;
525*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(netId, std::move(optNet));
526*89c4ff92SAndroid Build Coastguard Worker 
527*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
528*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData0
529*89c4ff92SAndroid Build Coastguard Worker     {
530*89c4ff92SAndroid Build Coastguard Worker         1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f
531*89c4ff92SAndroid Build Coastguard Worker     };
532*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData1
533*89c4ff92SAndroid Build Coastguard Worker     {
534*89c4ff92SAndroid Build Coastguard Worker         0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f
535*89c4ff92SAndroid Build Coastguard Worker     };
536*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData2
537*89c4ff92SAndroid Build Coastguard Worker     {
538*89c4ff92SAndroid Build Coastguard Worker         12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f
539*89c4ff92SAndroid Build Coastguard Worker     };
540*89c4ff92SAndroid Build Coastguard Worker 
541*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(2);
542*89c4ff92SAndroid Build Coastguard Worker 
543*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput{ 11.0f, -1.0f };
544*89c4ff92SAndroid Build Coastguard Worker 
545*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
546*89c4ff92SAndroid Build Coastguard Worker     {
547*89c4ff92SAndroid Build Coastguard Worker         { 0, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData0.data()) },
548*89c4ff92SAndroid Build Coastguard Worker         { 1, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 1), inputData1.data()) },
549*89c4ff92SAndroid Build Coastguard Worker         { 2, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 2), inputData2.data()) }
550*89c4ff92SAndroid Build Coastguard Worker     };
551*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
552*89c4ff92SAndroid Build Coastguard Worker     {
553*89c4ff92SAndroid Build Coastguard Worker         { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) }
554*89c4ff92SAndroid Build Coastguard Worker     };
555*89c4ff92SAndroid Build Coastguard Worker 
556*89c4ff92SAndroid Build Coastguard Worker     runtime->GetProfiler(netId)->EnableProfiling(true);
557*89c4ff92SAndroid Build Coastguard Worker 
558*89c4ff92SAndroid Build Coastguard Worker     // Do the inference
559*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
560*89c4ff92SAndroid Build Coastguard Worker 
561*89c4ff92SAndroid Build Coastguard Worker     // Retrieve the Profiler.Print() output to get the workload execution
562*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance();
563*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss;
564*89c4ff92SAndroid Build Coastguard Worker     profilerManager.GetProfiler()->Print(ss);;
565*89c4ff92SAndroid Build Coastguard Worker     std::string dump = ss.str();
566*89c4ff92SAndroid Build Coastguard Worker 
567*89c4ff92SAndroid Build Coastguard Worker     // Executed Subtraction using CpuAcc
568*89c4ff92SAndroid Build Coastguard Worker     std::size_t found = dump.find("NeonSubtractionWorkload_Execute");
569*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
570*89c4ff92SAndroid Build Coastguard Worker 
571*89c4ff92SAndroid Build Coastguard Worker     // Correctly switch back to GpuAcc
572*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("ClPooling2dWorkload_Execute");
573*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
574*89c4ff92SAndroid Build Coastguard Worker 
575*89c4ff92SAndroid Build Coastguard Worker     // Contain CopyMemGeneric
576*89c4ff92SAndroid Build Coastguard Worker     found = dump.find("CopyMemGeneric");
577*89c4ff92SAndroid Build Coastguard Worker     CHECK(found != std::string::npos);
578*89c4ff92SAndroid Build Coastguard Worker 
579*89c4ff92SAndroid Build Coastguard Worker     // Check output is as expected
580*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputData == expectedOutput);
581*89c4ff92SAndroid Build Coastguard Worker }
582*89c4ff92SAndroid Build Coastguard Worker 
583*89c4ff92SAndroid Build Coastguard Worker }
584