1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include <backendsCommon/TensorHandleFactoryRegistry.hpp> 7 #include <armnn/backends/IMemoryManager.hpp> 8 9 namespace armnn 10 { 11 RegisterFactory(std::unique_ptr<ITensorHandleFactory> newFactory)12void TensorHandleFactoryRegistry::RegisterFactory(std::unique_ptr <ITensorHandleFactory> newFactory) 13 { 14 if (!newFactory) 15 { 16 return; 17 } 18 19 ITensorHandleFactory::FactoryId id = newFactory->GetId(); 20 21 // Don't register duplicates 22 for (auto& registeredFactory : m_Factories) 23 { 24 if (id == registeredFactory->GetId()) 25 { 26 return; 27 } 28 } 29 30 // Take ownership of the new allocator 31 m_Factories.push_back(std::move(newFactory)); 32 } 33 RegisterMemoryManager(std::shared_ptr<armnn::IMemoryManager> memoryManger)34void TensorHandleFactoryRegistry::RegisterMemoryManager(std::shared_ptr<armnn::IMemoryManager> memoryManger) 35 { 36 m_MemoryManagers.push_back(memoryManger); 37 } 38 GetFactory(ITensorHandleFactory::FactoryId id) const39ITensorHandleFactory* TensorHandleFactoryRegistry::GetFactory(ITensorHandleFactory::FactoryId id) const 40 { 41 for (auto& factory : m_Factories) 42 { 43 if (factory->GetId() == id) 44 { 45 return factory.get(); 46 } 47 } 48 49 return nullptr; 50 } 51 GetFactory(ITensorHandleFactory::FactoryId id,MemorySource memSource) const52ITensorHandleFactory* TensorHandleFactoryRegistry::GetFactory(ITensorHandleFactory::FactoryId id, 53 MemorySource memSource) const 54 { 55 for (auto& factory : m_Factories) 56 { 57 if (factory->GetId() == id && factory->GetImportFlags() == static_cast<MemorySourceFlags>(memSource)) 58 { 59 return factory.get(); 60 } 61 } 62 63 return nullptr; 64 } 65 RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId,ITensorHandleFactory::FactoryId importFactoryId)66void TensorHandleFactoryRegistry::RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId, 67 ITensorHandleFactory::FactoryId importFactoryId) 68 { 69 m_FactoryMappings[copyFactoryId] = importFactoryId; 70 } 71 GetMatchingImportFactoryId(ITensorHandleFactory::FactoryId copyFactoryId)72ITensorHandleFactory::FactoryId TensorHandleFactoryRegistry::GetMatchingImportFactoryId( 73 ITensorHandleFactory::FactoryId copyFactoryId) 74 { 75 return m_FactoryMappings[copyFactoryId]; 76 } 77 AquireMemory()78void TensorHandleFactoryRegistry::AquireMemory() 79 { 80 for (auto& mgr : m_MemoryManagers) 81 { 82 mgr->Acquire(); 83 } 84 } 85 ReleaseMemory()86void TensorHandleFactoryRegistry::ReleaseMemory() 87 { 88 for (auto& mgr : m_MemoryManagers) 89 { 90 mgr->Release(); 91 } 92 } 93 94 } // namespace armnn 95