//
// Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#if defined(ARMCOMPUTECL_ENABLED)
#include <cl/ClBackend.hpp>
#endif
#if defined(ARMCOMPUTENEON_ENABLED)
#include <neon/NeonBackend.hpp>
#endif
#include <reference/RefBackend.hpp>
#include <armnn/BackendHelper.hpp>

#include <Network.hpp>

#include <doctest/doctest.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

using namespace armnn;

#if defined(ARMCOMPUTENEON_ENABLED) && defined(ARMCOMPUTECL_ENABLED)

TEST_SUITE("BackendsCompatibility")
{
// Partially disabled Test Suite: several expectations below are commented out and kept only for reference.

/// Builds a linear graph (input -> softmax1 -> softmax2 -> softmax3 -> softmax4 -> output) whose layers
/// alternate between the Neon and Cl backends, runs tensor-handle strategy selection over it and checks that
/// it reports neither an error nor a warning, then inserts compatibility layers and counts the MemCopy and
/// MemImport layers present in the resulting graph.
TEST_CASE("Neon_Cl_DirectCompatibility_Test")
{
    auto neonBackend = std::make_unique<NeonBackend>();
    auto clBackend = std::make_unique<ClBackend>();

    // Both backends register their tensor handle factories into one shared registry.
    TensorHandleFactoryRegistry registry;
    neonBackend->RegisterTensorHandleFactories(registry);
    clBackend->RegisterTensorHandleFactories(registry);

    // Moving the unique_ptrs below only transfers ownership; the backend objects stay alive inside
    // 'backends', so these references remain valid for the rest of the test.
    const BackendId& neonBackendId = neonBackend->GetId();
    const BackendId& clBackendId = clBackend->GetId();

    BackendsMap backends;
    backends[neonBackendId] = std::move(neonBackend);
    backends[clBackendId] = std::move(clBackend);

    armnn::Graph graph;

    armnn::InputLayer* const inputLayer = graph.AddLayer<armnn::InputLayer>(0, "input");
    inputLayer->SetBackendId(neonBackendId);

    // Four softmax layers whose backend assignment alternates Cl, Neon, Cl, Neon.
    armnn::SoftmaxDescriptor smDesc;
    armnn::SoftmaxLayer* const softmaxLayer1 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax1");
    softmaxLayer1->SetBackendId(clBackendId);

    armnn::SoftmaxLayer* const softmaxLayer2 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax2");
    softmaxLayer2->SetBackendId(neonBackendId);

    armnn::SoftmaxLayer* const softmaxLayer3 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax3");
    softmaxLayer3->SetBackendId(clBackendId);

    armnn::SoftmaxLayer* const softmaxLayer4 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax4");
    softmaxLayer4->SetBackendId(neonBackendId);

    armnn::OutputLayer* const outputLayer = graph.AddLayer<armnn::OutputLayer>(0, "output");
    outputLayer->SetBackendId(clBackendId);

    // Wire the layers into a single chain so every edge crosses a backend boundary.
    inputLayer->GetOutputSlot(0).Connect(softmaxLayer1->GetInputSlot(0));
    softmaxLayer1->GetOutputSlot(0).Connect(softmaxLayer2->GetInputSlot(0));
    softmaxLayer2->GetOutputSlot(0).Connect(softmaxLayer3->GetInputSlot(0));
    softmaxLayer3->GetOutputSlot(0).Connect(softmaxLayer4->GetInputSlot(0));
    softmaxLayer4->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));

    graph.TopologicalSort();

    std::vector<std::string> errors;
    // NOTE(review): the two boolean arguments are presumably the import/export-enabled flags — confirm
    // against the SelectTensorHandleStrategy declaration in Network.hpp.
    auto result = SelectTensorHandleStrategy(graph, backends, registry, true, true, errors);

    CHECK(result.m_Error == false);
    CHECK(result.m_Warning == false);

    // Disabled expectations, kept for reference:
    // OutputSlot& inputLayerOut = inputLayer->GetOutputSlot(0);
    // OutputSlot& softmaxLayer1Out = softmaxLayer1->GetOutputSlot(0);
    // OutputSlot& softmaxLayer2Out = softmaxLayer2->GetOutputSlot(0);
    // OutputSlot& softmaxLayer3Out = softmaxLayer3->GetOutputSlot(0);
    // OutputSlot& softmaxLayer4Out = softmaxLayer4->GetOutputSlot(0);

    // // Check that the correct factory was selected
    // CHECK(inputLayerOut.GetTensorHandleFactoryId() == "Arm/Cl/TensorHandleFactory");
    // CHECK(softmaxLayer1Out.GetTensorHandleFactoryId() == "Arm/Cl/TensorHandleFactory");
    // CHECK(softmaxLayer2Out.GetTensorHandleFactoryId() == "Arm/Cl/TensorHandleFactory");
    // CHECK(softmaxLayer3Out.GetTensorHandleFactoryId() == "Arm/Cl/TensorHandleFactory");
    // CHECK(softmaxLayer4Out.GetTensorHandleFactoryId() == "Arm/Cl/TensorHandleFactory");

    // // Check that the correct strategy was selected
    // CHECK((inputLayerOut.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
    // CHECK((softmaxLayer1Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
    // CHECK((softmaxLayer2Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
    // CHECK((softmaxLayer3Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
    // CHECK((softmaxLayer4Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));

    graph.AddCompatibilityLayers(backends, registry);

    // Test for copy layers
    int copyCount = 0;
    graph.ForEachLayer([&copyCount](Layer* layer)
    {
        if (layer->GetType() == LayerType::MemCopy)
        {
            copyCount++;
        }
    });
    // CHECK(copyCount == 0);

    // Test for import layers
    int importCount = 0;
    graph.ForEachLayer([&importCount](Layer* layer)
    {
        if (layer->GetType() == LayerType::MemImport)
        {
            importCount++;
        }
    });
    // CHECK(importCount == 0);
}

} // TEST_SUITE("BackendsCompatibility")
#endif

129*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("BackendCapability")
130*89c4ff92SAndroid Build Coastguard Worker {
131*89c4ff92SAndroid Build Coastguard Worker
132*89c4ff92SAndroid Build Coastguard Worker namespace
133*89c4ff92SAndroid Build Coastguard Worker {
134*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNNREF_ENABLED) || defined(ARMCOMPUTENEON_ENABLED) || defined(ARMCOMPUTECL_ENABLED)
CapabilityTestHelper(BackendCapabilities & capabilities,std::vector<std::pair<std::string,bool>> capabilityVector)135*89c4ff92SAndroid Build Coastguard Worker void CapabilityTestHelper(BackendCapabilities &capabilities,
136*89c4ff92SAndroid Build Coastguard Worker std::vector<std::pair<std::string, bool>> capabilityVector)
137*89c4ff92SAndroid Build Coastguard Worker {
138*89c4ff92SAndroid Build Coastguard Worker for (auto pair : capabilityVector)
139*89c4ff92SAndroid Build Coastguard Worker {
140*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(armnn::HasCapability(pair.first, capabilities),
141*89c4ff92SAndroid Build Coastguard Worker pair.first << " capability was not been found");
142*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(armnn::HasCapability(BackendOptions::BackendOption{pair.first, pair.second}, capabilities),
143*89c4ff92SAndroid Build Coastguard Worker pair.first << " capability set incorrectly");
144*89c4ff92SAndroid Build Coastguard Worker }
145*89c4ff92SAndroid Build Coastguard Worker }
146*89c4ff92SAndroid Build Coastguard Worker #endif
147*89c4ff92SAndroid Build Coastguard Worker
148*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNNREF_ENABLED)
149*89c4ff92SAndroid Build Coastguard Worker
150*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("Ref_Backends_Unknown_Capability_Test")
151*89c4ff92SAndroid Build Coastguard Worker {
152*89c4ff92SAndroid Build Coastguard Worker auto refBackend = std::make_unique<RefBackend>();
153*89c4ff92SAndroid Build Coastguard Worker auto refCapabilities = refBackend->GetCapabilities();
154*89c4ff92SAndroid Build Coastguard Worker
155*89c4ff92SAndroid Build Coastguard Worker armnn::BackendOptions::BackendOption AsyncExecutionFalse{"AsyncExecution", false};
156*89c4ff92SAndroid Build Coastguard Worker CHECK(!armnn::HasCapability(AsyncExecutionFalse, refCapabilities));
157*89c4ff92SAndroid Build Coastguard Worker
158*89c4ff92SAndroid Build Coastguard Worker armnn::BackendOptions::BackendOption AsyncExecutionInt{"AsyncExecution", 50};
159*89c4ff92SAndroid Build Coastguard Worker CHECK(!armnn::HasCapability(AsyncExecutionFalse, refCapabilities));
160*89c4ff92SAndroid Build Coastguard Worker
161*89c4ff92SAndroid Build Coastguard Worker armnn::BackendOptions::BackendOption AsyncExecutionFloat{"AsyncExecution", 0.0f};
162*89c4ff92SAndroid Build Coastguard Worker CHECK(!armnn::HasCapability(AsyncExecutionFloat, refCapabilities));
163*89c4ff92SAndroid Build Coastguard Worker
164*89c4ff92SAndroid Build Coastguard Worker armnn::BackendOptions::BackendOption AsyncExecutionString{"AsyncExecution", "true"};
165*89c4ff92SAndroid Build Coastguard Worker CHECK(!armnn::HasCapability(AsyncExecutionString, refCapabilities));
166*89c4ff92SAndroid Build Coastguard Worker
167*89c4ff92SAndroid Build Coastguard Worker CHECK(!armnn::HasCapability("Telekinesis", refCapabilities));
168*89c4ff92SAndroid Build Coastguard Worker armnn::BackendOptions::BackendOption unknownCapability{"Telekinesis", true};
169*89c4ff92SAndroid Build Coastguard Worker CHECK(!armnn::HasCapability(unknownCapability, refCapabilities));
170*89c4ff92SAndroid Build Coastguard Worker }
171*89c4ff92SAndroid Build Coastguard Worker
172*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Ref_Backends_Capability_Test")
173*89c4ff92SAndroid Build Coastguard Worker {
174*89c4ff92SAndroid Build Coastguard Worker auto refBackend = std::make_unique<RefBackend>();
175*89c4ff92SAndroid Build Coastguard Worker auto refCapabilities = refBackend->GetCapabilities();
176*89c4ff92SAndroid Build Coastguard Worker
177*89c4ff92SAndroid Build Coastguard Worker CapabilityTestHelper(refCapabilities,
178*89c4ff92SAndroid Build Coastguard Worker {{"NonConstWeights", true},
179*89c4ff92SAndroid Build Coastguard Worker {"AsyncExecution", true},
180*89c4ff92SAndroid Build Coastguard Worker {"ProtectedContentAllocation", false},
181*89c4ff92SAndroid Build Coastguard Worker {"ConstantTensorsAsInputs", true},
182*89c4ff92SAndroid Build Coastguard Worker {"PreImportIOTensors", true},
183*89c4ff92SAndroid Build Coastguard Worker {"ExternallyManagedMemory", true},
184*89c4ff92SAndroid Build Coastguard Worker {"MultiAxisPacking", false}});
185*89c4ff92SAndroid Build Coastguard Worker }
186*89c4ff92SAndroid Build Coastguard Worker
187*89c4ff92SAndroid Build Coastguard Worker #endif
188*89c4ff92SAndroid Build Coastguard Worker
189*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTENEON_ENABLED)
190*89c4ff92SAndroid Build Coastguard Worker
191*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Neon_Backends_Capability_Test")
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker auto neonBackend = std::make_unique<NeonBackend>();
194*89c4ff92SAndroid Build Coastguard Worker auto neonCapabilities = neonBackend->GetCapabilities();
195*89c4ff92SAndroid Build Coastguard Worker
196*89c4ff92SAndroid Build Coastguard Worker CapabilityTestHelper(neonCapabilities,
197*89c4ff92SAndroid Build Coastguard Worker {{"NonConstWeights", true},
198*89c4ff92SAndroid Build Coastguard Worker {"AsyncExecution", false},
199*89c4ff92SAndroid Build Coastguard Worker {"ProtectedContentAllocation", false},
200*89c4ff92SAndroid Build Coastguard Worker {"ConstantTensorsAsInputs", true},
201*89c4ff92SAndroid Build Coastguard Worker {"PreImportIOTensors", false},
202*89c4ff92SAndroid Build Coastguard Worker {"ExternallyManagedMemory", true},
203*89c4ff92SAndroid Build Coastguard Worker {"MultiAxisPacking", false}});
204*89c4ff92SAndroid Build Coastguard Worker }
205*89c4ff92SAndroid Build Coastguard Worker
206*89c4ff92SAndroid Build Coastguard Worker #endif
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED)
209*89c4ff92SAndroid Build Coastguard Worker
210*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Cl_Backends_Capability_Test")
211*89c4ff92SAndroid Build Coastguard Worker {
212*89c4ff92SAndroid Build Coastguard Worker auto clBackend = std::make_unique<ClBackend>();
213*89c4ff92SAndroid Build Coastguard Worker auto clCapabilities = clBackend->GetCapabilities();
214*89c4ff92SAndroid Build Coastguard Worker
215*89c4ff92SAndroid Build Coastguard Worker CapabilityTestHelper(clCapabilities,
216*89c4ff92SAndroid Build Coastguard Worker {{"NonConstWeights", false},
217*89c4ff92SAndroid Build Coastguard Worker {"AsyncExecution", false},
218*89c4ff92SAndroid Build Coastguard Worker {"ProtectedContentAllocation", true},
219*89c4ff92SAndroid Build Coastguard Worker {"ConstantTensorsAsInputs", true},
220*89c4ff92SAndroid Build Coastguard Worker {"PreImportIOTensors", false},
221*89c4ff92SAndroid Build Coastguard Worker {"ExternallyManagedMemory", true},
222*89c4ff92SAndroid Build Coastguard Worker {"MultiAxisPacking", false}});
223*89c4ff92SAndroid Build Coastguard Worker }
224*89c4ff92SAndroid Build Coastguard Worker
225*89c4ff92SAndroid Build Coastguard Worker #endif
226*89c4ff92SAndroid Build Coastguard Worker }
227*89c4ff92SAndroid Build Coastguard Worker }
228