1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Types.hpp> 7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp> 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendInternal.hpp> 10*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/memoryOptimizerStrategyLibrary/strategies/ConstantMemoryStrategy.hpp> 11*89c4ff92SAndroid Build Coastguard Worker #include <reference/RefBackend.hpp> 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 14*89c4ff92SAndroid Build Coastguard Worker 15*89c4ff92SAndroid Build Coastguard Worker namespace 16*89c4ff92SAndroid Build Coastguard Worker { 17*89c4ff92SAndroid Build Coastguard Worker 18*89c4ff92SAndroid Build Coastguard Worker class SwapRegistryStorage : public armnn::BackendRegistry 19*89c4ff92SAndroid Build Coastguard Worker { 20*89c4ff92SAndroid Build Coastguard Worker public: SwapRegistryStorage()21*89c4ff92SAndroid Build Coastguard Worker SwapRegistryStorage() : armnn::BackendRegistry() 22*89c4ff92SAndroid Build Coastguard Worker { 23*89c4ff92SAndroid Build Coastguard Worker Swap(armnn::BackendRegistryInstance(), m_TempStorage); 24*89c4ff92SAndroid Build Coastguard Worker } 25*89c4ff92SAndroid Build Coastguard Worker ~SwapRegistryStorage()26*89c4ff92SAndroid Build Coastguard Worker ~SwapRegistryStorage() 27*89c4ff92SAndroid Build Coastguard Worker { 28*89c4ff92SAndroid Build Coastguard Worker Swap(armnn::BackendRegistryInstance(),m_TempStorage); 29*89c4ff92SAndroid Build Coastguard Worker } 30*89c4ff92SAndroid Build Coastguard Worker 31*89c4ff92SAndroid Build Coastguard Worker private: 32*89c4ff92SAndroid Build Coastguard Worker FactoryStorage m_TempStorage; 33*89c4ff92SAndroid Build Coastguard Worker }; 34*89c4ff92SAndroid Build Coastguard Worker 35*89c4ff92SAndroid Build Coastguard Worker } 36*89c4ff92SAndroid Build Coastguard Worker 37*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("BackendRegistryTests") 38*89c4ff92SAndroid Build Coastguard Worker { 39*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SwapRegistry") 40*89c4ff92SAndroid Build Coastguard Worker { 41*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 42*89c4ff92SAndroid Build Coastguard Worker auto nFactories = BackendRegistryInstance().Size(); 43*89c4ff92SAndroid Build Coastguard Worker { 44*89c4ff92SAndroid Build Coastguard Worker SwapRegistryStorage helper; 45*89c4ff92SAndroid Build Coastguard Worker CHECK(BackendRegistryInstance().Size() == 0); 46*89c4ff92SAndroid Build Coastguard Worker } 47*89c4ff92SAndroid Build Coastguard Worker CHECK(BackendRegistryInstance().Size() == nFactories); 48*89c4ff92SAndroid Build Coastguard Worker } 49*89c4ff92SAndroid Build Coastguard Worker 50*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TestRegistryHelper") 51*89c4ff92SAndroid Build Coastguard Worker { 52*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 53*89c4ff92SAndroid Build Coastguard Worker SwapRegistryStorage helper; 54*89c4ff92SAndroid Build Coastguard Worker 55*89c4ff92SAndroid Build Coastguard Worker bool called = false; 56*89c4ff92SAndroid Build Coastguard Worker 57*89c4ff92SAndroid Build Coastguard Worker BackendRegistry::StaticRegistryInitializer factoryHelper( 58*89c4ff92SAndroid Build Coastguard Worker BackendRegistryInstance(), 59*89c4ff92SAndroid Build Coastguard Worker "HelloWorld", 60*89c4ff92SAndroid Build Coastguard Worker [&called]() __anon0e237d1a0202() 61*89c4ff92SAndroid Build Coastguard Worker { 62*89c4ff92SAndroid Build Coastguard Worker called = true; 63*89c4ff92SAndroid Build Coastguard Worker return armnn::IBackendInternalUniquePtr(nullptr); 64*89c4ff92SAndroid Build Coastguard Worker } 65*89c4ff92SAndroid Build Coastguard Worker ); 66*89c4ff92SAndroid Build Coastguard Worker 67*89c4ff92SAndroid Build Coastguard Worker // sanity check: the factory has not been called yet 68*89c4ff92SAndroid Build Coastguard Worker CHECK(called == false); 69*89c4ff92SAndroid Build Coastguard Worker 70*89c4ff92SAndroid Build Coastguard Worker auto factoryFunction = BackendRegistryInstance().GetFactory("HelloWorld"); 71*89c4ff92SAndroid Build Coastguard Worker 72*89c4ff92SAndroid Build Coastguard Worker // sanity check: the factory still not called 73*89c4ff92SAndroid Build Coastguard Worker CHECK(called == false); 74*89c4ff92SAndroid Build Coastguard Worker 75*89c4ff92SAndroid Build Coastguard Worker factoryFunction(); 76*89c4ff92SAndroid Build Coastguard Worker CHECK(called == true); 77*89c4ff92SAndroid Build Coastguard Worker BackendRegistryInstance().Deregister("HelloWorld"); 78*89c4ff92SAndroid Build Coastguard Worker } 79*89c4ff92SAndroid Build Coastguard Worker 80*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TestDirectCallToRegistry") 81*89c4ff92SAndroid Build Coastguard Worker { 82*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 83*89c4ff92SAndroid Build Coastguard Worker SwapRegistryStorage helper; 84*89c4ff92SAndroid Build Coastguard Worker 85*89c4ff92SAndroid Build Coastguard Worker bool called = false; 86*89c4ff92SAndroid Build Coastguard Worker BackendRegistryInstance().Register( 87*89c4ff92SAndroid Build Coastguard Worker "HelloWorld", 88*89c4ff92SAndroid Build Coastguard Worker [&called]() __anon0e237d1a0302() 89*89c4ff92SAndroid Build Coastguard Worker { 90*89c4ff92SAndroid Build Coastguard Worker called = true; 91*89c4ff92SAndroid Build Coastguard Worker return armnn::IBackendInternalUniquePtr(nullptr); 92*89c4ff92SAndroid Build Coastguard Worker } 93*89c4ff92SAndroid Build Coastguard Worker ); 94*89c4ff92SAndroid Build Coastguard Worker 95*89c4ff92SAndroid Build Coastguard Worker // sanity check: the factory has not been called yet 96*89c4ff92SAndroid Build Coastguard Worker CHECK(called == false); 97*89c4ff92SAndroid Build Coastguard Worker 98*89c4ff92SAndroid Build Coastguard Worker auto factoryFunction = BackendRegistryInstance().GetFactory("HelloWorld"); 99*89c4ff92SAndroid Build Coastguard Worker 100*89c4ff92SAndroid Build Coastguard Worker // sanity check: the factory still not called 101*89c4ff92SAndroid Build Coastguard Worker CHECK(called == false); 102*89c4ff92SAndroid Build Coastguard Worker 103*89c4ff92SAndroid Build Coastguard Worker factoryFunction(); 104*89c4ff92SAndroid Build Coastguard Worker CHECK(called == true); 105*89c4ff92SAndroid Build Coastguard Worker BackendRegistryInstance().Deregister("HelloWorld"); 106*89c4ff92SAndroid Build Coastguard Worker } 107*89c4ff92SAndroid Build Coastguard Worker 108*89c4ff92SAndroid Build Coastguard Worker // Test that backends can throw exceptions during their factory function to prevent loading in an unsuitable 109*89c4ff92SAndroid Build Coastguard Worker // environment. For example Neon Backend loading on armhf device without neon support. 110*89c4ff92SAndroid Build Coastguard Worker // In reality the dynamic backend is loaded in during the LoadDynamicBackends(options.m_DynamicBackendsPath) 111*89c4ff92SAndroid Build Coastguard Worker // step of runtime constructor, then the factory function is called to check if supported, in case 112*89c4ff92SAndroid Build Coastguard Worker // of Neon not being detected the exception is raised and so the backend is not added to the supportedBackends 113*89c4ff92SAndroid Build Coastguard Worker // list 114*89c4ff92SAndroid Build Coastguard Worker 115*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ThrowBackendUnavailableException") 116*89c4ff92SAndroid Build Coastguard Worker { 117*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 118*89c4ff92SAndroid Build Coastguard Worker 119*89c4ff92SAndroid Build Coastguard Worker const BackendId mockBackendId("MockDynamicBackend"); 120*89c4ff92SAndroid Build Coastguard Worker 121*89c4ff92SAndroid Build Coastguard Worker const std::string exceptionMessage("Mock error message to test unavailable backend"); 122*89c4ff92SAndroid Build Coastguard Worker 123*89c4ff92SAndroid Build Coastguard Worker // Register the mock backend with a factory function lambda that always throws 124*89c4ff92SAndroid Build Coastguard Worker BackendRegistryInstance().Register(mockBackendId, 125*89c4ff92SAndroid Build Coastguard Worker [exceptionMessage]() __anon0e237d1a0402() 126*89c4ff92SAndroid Build Coastguard Worker { 127*89c4ff92SAndroid Build Coastguard Worker throw armnn::BackendUnavailableException(exceptionMessage); 128*89c4ff92SAndroid Build Coastguard Worker return IBackendInternalUniquePtr(); // Satisfy return type 129*89c4ff92SAndroid Build Coastguard Worker }); 130*89c4ff92SAndroid Build Coastguard Worker 131*89c4ff92SAndroid Build Coastguard Worker // Get the factory function of the mock backend 132*89c4ff92SAndroid Build Coastguard Worker auto factoryFunc = BackendRegistryInstance().GetFactory(mockBackendId); 133*89c4ff92SAndroid Build Coastguard Worker 134*89c4ff92SAndroid Build Coastguard Worker try 135*89c4ff92SAndroid Build Coastguard Worker { 136*89c4ff92SAndroid Build Coastguard Worker // Call the factory function as done during runtime backend registering 137*89c4ff92SAndroid Build Coastguard Worker auto backend = factoryFunc(); 138*89c4ff92SAndroid Build Coastguard Worker FAIL("Expected exception to have been thrown"); 139*89c4ff92SAndroid Build Coastguard Worker } 140*89c4ff92SAndroid Build Coastguard Worker catch (const BackendUnavailableException& e) 141*89c4ff92SAndroid Build Coastguard Worker { 142*89c4ff92SAndroid Build Coastguard Worker // Caught 143*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(e.what(), exceptionMessage); 144*89c4ff92SAndroid Build Coastguard Worker } 145*89c4ff92SAndroid Build Coastguard Worker // Clean up the registry for the next test. 146*89c4ff92SAndroid Build Coastguard Worker BackendRegistryInstance().Deregister(mockBackendId); 147*89c4ff92SAndroid Build Coastguard Worker } 148*89c4ff92SAndroid Build Coastguard Worker 149*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNNREF_ENABLED) 150*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("RegisterMemoryOptimizerStrategy") 151*89c4ff92SAndroid Build Coastguard Worker { 152*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 153*89c4ff92SAndroid Build Coastguard Worker 154*89c4ff92SAndroid Build Coastguard Worker const BackendId cpuRefBackendId(armnn::Compute::CpuRef); 155*89c4ff92SAndroid Build Coastguard Worker CHECK(BackendRegistryInstance().GetMemoryOptimizerStrategies().empty()); 156*89c4ff92SAndroid Build Coastguard Worker 157*89c4ff92SAndroid Build Coastguard Worker // Register the memory optimizer 158*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<IMemoryOptimizerStrategy> memoryOptimizerStrategy = 159*89c4ff92SAndroid Build Coastguard Worker std::make_shared<ConstantMemoryStrategy>(); 160*89c4ff92SAndroid Build Coastguard Worker BackendRegistryInstance().RegisterMemoryOptimizerStrategy(cpuRefBackendId, memoryOptimizerStrategy); 161*89c4ff92SAndroid Build Coastguard Worker CHECK(!BackendRegistryInstance().GetMemoryOptimizerStrategies().empty()); 162*89c4ff92SAndroid Build Coastguard Worker CHECK(BackendRegistryInstance().GetMemoryOptimizerStrategies().size() == 1); 163*89c4ff92SAndroid Build Coastguard Worker // De-register the memory optimizer 164*89c4ff92SAndroid Build Coastguard Worker BackendRegistryInstance().DeregisterMemoryOptimizerStrategy(cpuRefBackendId); 165*89c4ff92SAndroid Build Coastguard Worker CHECK(BackendRegistryInstance().GetMemoryOptimizerStrategies().empty()); 166*89c4ff92SAndroid Build Coastguard Worker } 167*89c4ff92SAndroid Build Coastguard Worker #endif 168*89c4ff92SAndroid Build Coastguard Worker 169*89c4ff92SAndroid Build Coastguard Worker } 170