xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/BackendRegistryTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Types.hpp>
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendInternal.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/memoryOptimizerStrategyLibrary/strategies/ConstantMemoryStrategy.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <reference/RefBackend.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker class SwapRegistryStorage : public armnn::BackendRegistry
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker public:
SwapRegistryStorage()21*89c4ff92SAndroid Build Coastguard Worker     SwapRegistryStorage() : armnn::BackendRegistry()
22*89c4ff92SAndroid Build Coastguard Worker     {
23*89c4ff92SAndroid Build Coastguard Worker         Swap(armnn::BackendRegistryInstance(),  m_TempStorage);
24*89c4ff92SAndroid Build Coastguard Worker     }
25*89c4ff92SAndroid Build Coastguard Worker 
~SwapRegistryStorage()26*89c4ff92SAndroid Build Coastguard Worker     ~SwapRegistryStorage()
27*89c4ff92SAndroid Build Coastguard Worker     {
28*89c4ff92SAndroid Build Coastguard Worker         Swap(armnn::BackendRegistryInstance(),m_TempStorage);
29*89c4ff92SAndroid Build Coastguard Worker     }
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker private:
32*89c4ff92SAndroid Build Coastguard Worker     FactoryStorage m_TempStorage;
33*89c4ff92SAndroid Build Coastguard Worker };
34*89c4ff92SAndroid Build Coastguard Worker 
35*89c4ff92SAndroid Build Coastguard Worker }
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("BackendRegistryTests")
38*89c4ff92SAndroid Build Coastguard Worker {
39*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SwapRegistry")
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
42*89c4ff92SAndroid Build Coastguard Worker     auto nFactories = BackendRegistryInstance().Size();
43*89c4ff92SAndroid Build Coastguard Worker     {
44*89c4ff92SAndroid Build Coastguard Worker         SwapRegistryStorage helper;
45*89c4ff92SAndroid Build Coastguard Worker         CHECK(BackendRegistryInstance().Size() == 0);
46*89c4ff92SAndroid Build Coastguard Worker     }
47*89c4ff92SAndroid Build Coastguard Worker     CHECK(BackendRegistryInstance().Size() == nFactories);
48*89c4ff92SAndroid Build Coastguard Worker }
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TestRegistryHelper")
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
53*89c4ff92SAndroid Build Coastguard Worker     SwapRegistryStorage helper;
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker     bool called = false;
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker     BackendRegistry::StaticRegistryInitializer factoryHelper(
58*89c4ff92SAndroid Build Coastguard Worker         BackendRegistryInstance(),
59*89c4ff92SAndroid Build Coastguard Worker         "HelloWorld",
60*89c4ff92SAndroid Build Coastguard Worker         [&called]()
__anon0e237d1a0202() 61*89c4ff92SAndroid Build Coastguard Worker         {
62*89c4ff92SAndroid Build Coastguard Worker             called = true;
63*89c4ff92SAndroid Build Coastguard Worker             return armnn::IBackendInternalUniquePtr(nullptr);
64*89c4ff92SAndroid Build Coastguard Worker         }
65*89c4ff92SAndroid Build Coastguard Worker     );
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker     // sanity check: the factory has not been called yet
68*89c4ff92SAndroid Build Coastguard Worker     CHECK(called == false);
69*89c4ff92SAndroid Build Coastguard Worker 
70*89c4ff92SAndroid Build Coastguard Worker     auto factoryFunction = BackendRegistryInstance().GetFactory("HelloWorld");
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker     // sanity check: the factory still not called
73*89c4ff92SAndroid Build Coastguard Worker     CHECK(called == false);
74*89c4ff92SAndroid Build Coastguard Worker 
75*89c4ff92SAndroid Build Coastguard Worker     factoryFunction();
76*89c4ff92SAndroid Build Coastguard Worker     CHECK(called == true);
77*89c4ff92SAndroid Build Coastguard Worker     BackendRegistryInstance().Deregister("HelloWorld");
78*89c4ff92SAndroid Build Coastguard Worker }
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TestDirectCallToRegistry")
81*89c4ff92SAndroid Build Coastguard Worker {
82*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
83*89c4ff92SAndroid Build Coastguard Worker     SwapRegistryStorage helper;
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker     bool called = false;
86*89c4ff92SAndroid Build Coastguard Worker     BackendRegistryInstance().Register(
87*89c4ff92SAndroid Build Coastguard Worker         "HelloWorld",
88*89c4ff92SAndroid Build Coastguard Worker         [&called]()
__anon0e237d1a0302() 89*89c4ff92SAndroid Build Coastguard Worker         {
90*89c4ff92SAndroid Build Coastguard Worker             called = true;
91*89c4ff92SAndroid Build Coastguard Worker             return armnn::IBackendInternalUniquePtr(nullptr);
92*89c4ff92SAndroid Build Coastguard Worker         }
93*89c4ff92SAndroid Build Coastguard Worker     );
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker     // sanity check: the factory has not been called yet
96*89c4ff92SAndroid Build Coastguard Worker     CHECK(called == false);
97*89c4ff92SAndroid Build Coastguard Worker 
98*89c4ff92SAndroid Build Coastguard Worker     auto factoryFunction = BackendRegistryInstance().GetFactory("HelloWorld");
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker     // sanity check: the factory still not called
101*89c4ff92SAndroid Build Coastguard Worker     CHECK(called == false);
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker     factoryFunction();
104*89c4ff92SAndroid Build Coastguard Worker     CHECK(called == true);
105*89c4ff92SAndroid Build Coastguard Worker     BackendRegistryInstance().Deregister("HelloWorld");
106*89c4ff92SAndroid Build Coastguard Worker }
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker // Test that backends can throw exceptions during their factory function to prevent loading in an unsuitable
109*89c4ff92SAndroid Build Coastguard Worker // environment. For example Neon Backend loading on armhf device without neon support.
110*89c4ff92SAndroid Build Coastguard Worker // In reality the dynamic backend is loaded in during the LoadDynamicBackends(options.m_DynamicBackendsPath)
111*89c4ff92SAndroid Build Coastguard Worker // step of runtime constructor, then the factory function is called to check if supported, in case
112*89c4ff92SAndroid Build Coastguard Worker // of Neon not being detected the exception is raised and so the backend is not added to the supportedBackends
113*89c4ff92SAndroid Build Coastguard Worker // list
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ThrowBackendUnavailableException")
116*89c4ff92SAndroid Build Coastguard Worker {
117*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker     const BackendId mockBackendId("MockDynamicBackend");
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker     const std::string exceptionMessage("Mock error message to test unavailable backend");
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker     // Register the mock backend with a factory function lambda that always throws
124*89c4ff92SAndroid Build Coastguard Worker     BackendRegistryInstance().Register(mockBackendId,
125*89c4ff92SAndroid Build Coastguard Worker             [exceptionMessage]()
__anon0e237d1a0402() 126*89c4ff92SAndroid Build Coastguard Worker             {
127*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::BackendUnavailableException(exceptionMessage);
128*89c4ff92SAndroid Build Coastguard Worker                 return IBackendInternalUniquePtr(); // Satisfy return type
129*89c4ff92SAndroid Build Coastguard Worker             });
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker     // Get the factory function of the mock backend
132*89c4ff92SAndroid Build Coastguard Worker     auto factoryFunc = BackendRegistryInstance().GetFactory(mockBackendId);
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker     try
135*89c4ff92SAndroid Build Coastguard Worker     {
136*89c4ff92SAndroid Build Coastguard Worker         // Call the factory function as done during runtime backend registering
137*89c4ff92SAndroid Build Coastguard Worker         auto backend = factoryFunc();
138*89c4ff92SAndroid Build Coastguard Worker         FAIL("Expected exception to have been thrown");
139*89c4ff92SAndroid Build Coastguard Worker     }
140*89c4ff92SAndroid Build Coastguard Worker     catch (const BackendUnavailableException& e)
141*89c4ff92SAndroid Build Coastguard Worker     {
142*89c4ff92SAndroid Build Coastguard Worker         // Caught
143*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(e.what(), exceptionMessage);
144*89c4ff92SAndroid Build Coastguard Worker     }
145*89c4ff92SAndroid Build Coastguard Worker     // Clean up the registry for the next test.
146*89c4ff92SAndroid Build Coastguard Worker     BackendRegistryInstance().Deregister(mockBackendId);
147*89c4ff92SAndroid Build Coastguard Worker }
148*89c4ff92SAndroid Build Coastguard Worker 
149*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNNREF_ENABLED)
150*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("RegisterMemoryOptimizerStrategy")
151*89c4ff92SAndroid Build Coastguard Worker {
152*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
153*89c4ff92SAndroid Build Coastguard Worker 
154*89c4ff92SAndroid Build Coastguard Worker     const BackendId cpuRefBackendId(armnn::Compute::CpuRef);
155*89c4ff92SAndroid Build Coastguard Worker     CHECK(BackendRegistryInstance().GetMemoryOptimizerStrategies().empty());
156*89c4ff92SAndroid Build Coastguard Worker 
157*89c4ff92SAndroid Build Coastguard Worker     // Register the memory optimizer
158*89c4ff92SAndroid Build Coastguard Worker     std::shared_ptr<IMemoryOptimizerStrategy> memoryOptimizerStrategy =
159*89c4ff92SAndroid Build Coastguard Worker         std::make_shared<ConstantMemoryStrategy>();
160*89c4ff92SAndroid Build Coastguard Worker     BackendRegistryInstance().RegisterMemoryOptimizerStrategy(cpuRefBackendId, memoryOptimizerStrategy);
161*89c4ff92SAndroid Build Coastguard Worker     CHECK(!BackendRegistryInstance().GetMemoryOptimizerStrategies().empty());
162*89c4ff92SAndroid Build Coastguard Worker     CHECK(BackendRegistryInstance().GetMemoryOptimizerStrategies().size() == 1);
163*89c4ff92SAndroid Build Coastguard Worker     // De-register the memory optimizer
164*89c4ff92SAndroid Build Coastguard Worker     BackendRegistryInstance().DeregisterMemoryOptimizerStrategy(cpuRefBackendId);
165*89c4ff92SAndroid Build Coastguard Worker     CHECK(BackendRegistryInstance().GetMemoryOptimizerStrategies().empty());
166*89c4ff92SAndroid Build Coastguard Worker }
167*89c4ff92SAndroid Build Coastguard Worker #endif
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker }
170