1 // 2 // Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Types.hpp> 9 #include <armnn/backends/ITensorHandleFactory.hpp> 10 #include <map> 11 #include <memory> 12 #include <vector> 13 14 namespace armnn 15 { 16 17 //Forward 18 class IMemoryManager; 19 20 using CopyAndImportFactoryPairs = std::map<ITensorHandleFactory::FactoryId, ITensorHandleFactory::FactoryId>; 21 22 /// 23 class TensorHandleFactoryRegistry 24 { 25 public: 26 TensorHandleFactoryRegistry() = default; 27 28 TensorHandleFactoryRegistry(const TensorHandleFactoryRegistry& other) = delete; 29 TensorHandleFactoryRegistry(TensorHandleFactoryRegistry&& other) = delete; 30 31 /// Register a TensorHandleFactory and transfer ownership 32 void RegisterFactory(std::unique_ptr<ITensorHandleFactory> allocator); 33 34 /// Register a memory manager with shared ownership 35 void RegisterMemoryManager(std::shared_ptr<IMemoryManager> memoryManger); 36 37 /// Find a TensorHandleFactory by Id 38 /// Returns nullptr if not found 39 ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id) const; 40 41 /// Overload of above allowing specification of Memory Source 42 ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id, 43 MemorySource memSource) const; 44 45 /// Register a pair of TensorHandleFactory Id for Memory Copy and TensorHandleFactory Id for Memory Import 46 void RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId, 47 ITensorHandleFactory::FactoryId importFactoryId); 48 49 /// Get a matching TensorHandleFatory Id for Memory Import given TensorHandleFactory Id for Memory Copy 50 ITensorHandleFactory::FactoryId GetMatchingImportFactoryId(ITensorHandleFactory::FactoryId copyFactoryId); 51 52 /// Aquire memory required for inference 53 void AquireMemory(); 54 55 /// Release memory required for inference 56 void ReleaseMemory(); 57 GetMemoryManagers()58 std::vector<std::shared_ptr<IMemoryManager>>& GetMemoryManagers() 59 { 60 return m_MemoryManagers; 61 } 62 63 private: 64 std::vector<std::unique_ptr<ITensorHandleFactory>> m_Factories; 65 std::vector<std::shared_ptr<IMemoryManager>> m_MemoryManagers; 66 CopyAndImportFactoryPairs m_FactoryMappings; 67 }; 68 69 } // namespace armnn 70