xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/CompatibilityTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED)
7*89c4ff92SAndroid Build Coastguard Worker #include <cl/ClBackend.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #endif
9*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTENEON_ENABLED)
10*89c4ff92SAndroid Build Coastguard Worker #include <neon/NeonBackend.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #endif
12*89c4ff92SAndroid Build Coastguard Worker #include <reference/RefBackend.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendHelper.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker #include <Network.hpp>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker #include <vector>
20*89c4ff92SAndroid Build Coastguard Worker #include <string>
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
23*89c4ff92SAndroid Build Coastguard Worker 
#if defined(ARMCOMPUTENEON_ENABLED) && defined(ARMCOMPUTECL_ENABLED)

TEST_SUITE("BackendsCompatibility")
{
// Partially disabled Test Suite
//
// Builds the chain Input -> Softmax x4 -> Output, with the layers alternating between the
// Neon and Cl backends. It then runs SelectTensorHandleStrategy and AddCompatibilityLayers
// over that graph and checks that strategy selection reports neither an error nor a warning.
// Two groups of checks are currently disabled and kept below as comments:
//  - the checks on the selected tensor handle factories and edge strategies;
//  - the checks on the number of MemCopy/MemImport layers.
TEST_CASE("Neon_Cl_DirectCompatibility_Test")
{
    auto neonBackend = std::make_unique<NeonBackend>();
    auto clBackend   = std::make_unique<ClBackend>();

    // One registry receives the tensor handle factories of both backends.
    TensorHandleFactoryRegistry registry;
    neonBackend->RegisterTensorHandleFactories(registry);
    clBackend->RegisterTensorHandleFactories(registry);

    // Read the ids before ownership of the backends is handed over to the map.
    const BackendId& neonBackendId = neonBackend->GetId();
    const BackendId& clBackendId   = clBackend->GetId();

    BackendsMap backends;
    backends[neonBackendId] = std::move(neonBackend);
    backends[clBackendId]   = std::move(clBackend);

    armnn::Graph graph;

    // Layers are created in data-flow order; their backend assignment alternates Neon / Cl.
    armnn::InputLayer* const inputLayer = graph.AddLayer<armnn::InputLayer>(0, "input");
    inputLayer->SetBackendId(neonBackendId);

    armnn::SoftmaxDescriptor smDesc;

    // Adds a softmax layer built from smDesc and assigns it to the given backend.
    auto addSoftmax = [&graph, &smDesc](const char* name, const BackendId& backendId)
    {
        armnn::SoftmaxLayer* const layer = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, name);
        layer->SetBackendId(backendId);
        return layer;
    };

    armnn::SoftmaxLayer* const softmaxLayer1 = addSoftmax("softmax1", clBackendId);
    armnn::SoftmaxLayer* const softmaxLayer2 = addSoftmax("softmax2", neonBackendId);
    armnn::SoftmaxLayer* const softmaxLayer3 = addSoftmax("softmax3", clBackendId);
    armnn::SoftmaxLayer* const softmaxLayer4 = addSoftmax("softmax4", neonBackendId);

    armnn::OutputLayer* const outputLayer = graph.AddLayer<armnn::OutputLayer>(0, "output");
    outputLayer->SetBackendId(clBackendId);

    // Wire the layers into a straight chain: each output slot 0 feeds the next input slot 0.
    Layer* previous = inputLayer;
    for (Layer* next : std::vector<Layer*>{softmaxLayer1, softmaxLayer2, softmaxLayer3, softmaxLayer4, outputLayer})
    {
        previous->GetOutputSlot(0).Connect(next->GetInputSlot(0));
        previous = next;
    }

    graph.TopologicalSort();

    std::vector<std::string> errors;
    auto result = SelectTensorHandleStrategy(graph, backends, registry, true, true, errors);

    CHECK(result.m_Error == false);
    CHECK(result.m_Warning == false);

    // OutputSlot& inputLayerOut = inputLayer->GetOutputSlot(0);
    // OutputSlot& softmaxLayer1Out = softmaxLayer1->GetOutputSlot(0);
    // OutputSlot& softmaxLayer2Out = softmaxLayer2->GetOutputSlot(0);
    // OutputSlot& softmaxLayer3Out = softmaxLayer3->GetOutputSlot(0);
    // OutputSlot& softmaxLayer4Out = softmaxLayer4->GetOutputSlot(0);

    // // Check that the correct factory was selected
    // CHECK(inputLayerOut.GetTensorHandleFactoryId()    == "Arm/Cl/TensorHandleFactory");
    // CHECK(softmaxLayer1Out.GetTensorHandleFactoryId() == "Arm/Cl/TensorHandleFactory");
    // CHECK(softmaxLayer2Out.GetTensorHandleFactoryId() == "Arm/Cl/TensorHandleFactory");
    // CHECK(softmaxLayer3Out.GetTensorHandleFactoryId() == "Arm/Cl/TensorHandleFactory");
    // CHECK(softmaxLayer4Out.GetTensorHandleFactoryId() == "Arm/Cl/TensorHandleFactory");

    // // Check that the correct strategy was selected
    // CHECK((inputLayerOut.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
    // CHECK((softmaxLayer1Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
    // CHECK((softmaxLayer2Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
    // CHECK((softmaxLayer3Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
    // CHECK((softmaxLayer4Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));

    graph.AddCompatibilityLayers(backends, registry);

    // Count the MemCopy and MemImport layers now present in the graph, in a single pass.
    int copyCount = 0;
    int importCount = 0;
    graph.ForEachLayer([&copyCount, &importCount](Layer* layer)
    {
        const LayerType type = layer->GetType();
        if (type == LayerType::MemCopy)
        {
            ++copyCount;
        }
        else if (type == LayerType::MemImport)
        {
            ++importCount;
        }
    });
    // CHECK(copyCount == 0);
    // CHECK(importCount == 0);
}

}
#endif
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("BackendCapability")
130*89c4ff92SAndroid Build Coastguard Worker {
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker namespace
133*89c4ff92SAndroid Build Coastguard Worker {
134*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNNREF_ENABLED) || defined(ARMCOMPUTENEON_ENABLED) || defined(ARMCOMPUTECL_ENABLED)
CapabilityTestHelper(BackendCapabilities & capabilities,std::vector<std::pair<std::string,bool>> capabilityVector)135*89c4ff92SAndroid Build Coastguard Worker void CapabilityTestHelper(BackendCapabilities &capabilities,
136*89c4ff92SAndroid Build Coastguard Worker                           std::vector<std::pair<std::string, bool>> capabilityVector)
137*89c4ff92SAndroid Build Coastguard Worker {
138*89c4ff92SAndroid Build Coastguard Worker     for (auto pair : capabilityVector)
139*89c4ff92SAndroid Build Coastguard Worker     {
140*89c4ff92SAndroid Build Coastguard Worker         CHECK_MESSAGE(armnn::HasCapability(pair.first, capabilities),
141*89c4ff92SAndroid Build Coastguard Worker                         pair.first << " capability was not been found");
142*89c4ff92SAndroid Build Coastguard Worker         CHECK_MESSAGE(armnn::HasCapability(BackendOptions::BackendOption{pair.first, pair.second}, capabilities),
143*89c4ff92SAndroid Build Coastguard Worker                         pair.first << " capability set incorrectly");
144*89c4ff92SAndroid Build Coastguard Worker     }
145*89c4ff92SAndroid Build Coastguard Worker }
146*89c4ff92SAndroid Build Coastguard Worker #endif
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNNREF_ENABLED)
149*89c4ff92SAndroid Build Coastguard Worker 
150*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("Ref_Backends_Unknown_Capability_Test")
151*89c4ff92SAndroid Build Coastguard Worker {
152*89c4ff92SAndroid Build Coastguard Worker     auto refBackend = std::make_unique<RefBackend>();
153*89c4ff92SAndroid Build Coastguard Worker     auto refCapabilities = refBackend->GetCapabilities();
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendOptions::BackendOption AsyncExecutionFalse{"AsyncExecution", false};
156*89c4ff92SAndroid Build Coastguard Worker     CHECK(!armnn::HasCapability(AsyncExecutionFalse, refCapabilities));
157*89c4ff92SAndroid Build Coastguard Worker 
158*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendOptions::BackendOption AsyncExecutionInt{"AsyncExecution", 50};
159*89c4ff92SAndroid Build Coastguard Worker     CHECK(!armnn::HasCapability(AsyncExecutionFalse, refCapabilities));
160*89c4ff92SAndroid Build Coastguard Worker 
161*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendOptions::BackendOption AsyncExecutionFloat{"AsyncExecution", 0.0f};
162*89c4ff92SAndroid Build Coastguard Worker     CHECK(!armnn::HasCapability(AsyncExecutionFloat, refCapabilities));
163*89c4ff92SAndroid Build Coastguard Worker 
164*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendOptions::BackendOption AsyncExecutionString{"AsyncExecution", "true"};
165*89c4ff92SAndroid Build Coastguard Worker     CHECK(!armnn::HasCapability(AsyncExecutionString, refCapabilities));
166*89c4ff92SAndroid Build Coastguard Worker 
167*89c4ff92SAndroid Build Coastguard Worker     CHECK(!armnn::HasCapability("Telekinesis", refCapabilities));
168*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendOptions::BackendOption unknownCapability{"Telekinesis", true};
169*89c4ff92SAndroid Build Coastguard Worker     CHECK(!armnn::HasCapability(unknownCapability, refCapabilities));
170*89c4ff92SAndroid Build Coastguard Worker }
171*89c4ff92SAndroid Build Coastguard Worker 
172*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Ref_Backends_Capability_Test")
173*89c4ff92SAndroid Build Coastguard Worker {
174*89c4ff92SAndroid Build Coastguard Worker     auto refBackend = std::make_unique<RefBackend>();
175*89c4ff92SAndroid Build Coastguard Worker     auto refCapabilities = refBackend->GetCapabilities();
176*89c4ff92SAndroid Build Coastguard Worker 
177*89c4ff92SAndroid Build Coastguard Worker     CapabilityTestHelper(refCapabilities,
178*89c4ff92SAndroid Build Coastguard Worker                          {{"NonConstWeights", true},
179*89c4ff92SAndroid Build Coastguard Worker                           {"AsyncExecution", true},
180*89c4ff92SAndroid Build Coastguard Worker                           {"ProtectedContentAllocation", false},
181*89c4ff92SAndroid Build Coastguard Worker                           {"ConstantTensorsAsInputs", true},
182*89c4ff92SAndroid Build Coastguard Worker                           {"PreImportIOTensors", true},
183*89c4ff92SAndroid Build Coastguard Worker                           {"ExternallyManagedMemory", true},
184*89c4ff92SAndroid Build Coastguard Worker                           {"MultiAxisPacking", false}});
185*89c4ff92SAndroid Build Coastguard Worker }
186*89c4ff92SAndroid Build Coastguard Worker 
187*89c4ff92SAndroid Build Coastguard Worker #endif
188*89c4ff92SAndroid Build Coastguard Worker 
189*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTENEON_ENABLED)
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Neon_Backends_Capability_Test")
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker     auto neonBackend = std::make_unique<NeonBackend>();
194*89c4ff92SAndroid Build Coastguard Worker     auto neonCapabilities = neonBackend->GetCapabilities();
195*89c4ff92SAndroid Build Coastguard Worker 
196*89c4ff92SAndroid Build Coastguard Worker     CapabilityTestHelper(neonCapabilities,
197*89c4ff92SAndroid Build Coastguard Worker                          {{"NonConstWeights", true},
198*89c4ff92SAndroid Build Coastguard Worker                           {"AsyncExecution", false},
199*89c4ff92SAndroid Build Coastguard Worker                           {"ProtectedContentAllocation", false},
200*89c4ff92SAndroid Build Coastguard Worker                           {"ConstantTensorsAsInputs", true},
201*89c4ff92SAndroid Build Coastguard Worker                           {"PreImportIOTensors", false},
202*89c4ff92SAndroid Build Coastguard Worker                           {"ExternallyManagedMemory", true},
203*89c4ff92SAndroid Build Coastguard Worker                           {"MultiAxisPacking", false}});
204*89c4ff92SAndroid Build Coastguard Worker }
205*89c4ff92SAndroid Build Coastguard Worker 
206*89c4ff92SAndroid Build Coastguard Worker #endif
207*89c4ff92SAndroid Build Coastguard Worker 
208*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED)
209*89c4ff92SAndroid Build Coastguard Worker 
210*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Cl_Backends_Capability_Test")
211*89c4ff92SAndroid Build Coastguard Worker {
212*89c4ff92SAndroid Build Coastguard Worker     auto clBackend = std::make_unique<ClBackend>();
213*89c4ff92SAndroid Build Coastguard Worker     auto clCapabilities = clBackend->GetCapabilities();
214*89c4ff92SAndroid Build Coastguard Worker 
215*89c4ff92SAndroid Build Coastguard Worker     CapabilityTestHelper(clCapabilities,
216*89c4ff92SAndroid Build Coastguard Worker                          {{"NonConstWeights", false},
217*89c4ff92SAndroid Build Coastguard Worker                           {"AsyncExecution", false},
218*89c4ff92SAndroid Build Coastguard Worker                           {"ProtectedContentAllocation", true},
219*89c4ff92SAndroid Build Coastguard Worker                           {"ConstantTensorsAsInputs", true},
220*89c4ff92SAndroid Build Coastguard Worker                           {"PreImportIOTensors", false},
221*89c4ff92SAndroid Build Coastguard Worker                           {"ExternallyManagedMemory", true},
222*89c4ff92SAndroid Build Coastguard Worker                           {"MultiAxisPacking", false}});
223*89c4ff92SAndroid Build Coastguard Worker }
224*89c4ff92SAndroid Build Coastguard Worker 
225*89c4ff92SAndroid Build Coastguard Worker #endif
226*89c4ff92SAndroid Build Coastguard Worker }
227*89c4ff92SAndroid Build Coastguard Worker }
228