xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/TensorHandleFactoryRegistry.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <backendsCommon/TensorHandleFactoryRegistry.hpp>
7 #include <armnn/backends/IMemoryManager.hpp>
8 
9 namespace armnn
10 {
11 
RegisterFactory(std::unique_ptr<ITensorHandleFactory> newFactory)12 void TensorHandleFactoryRegistry::RegisterFactory(std::unique_ptr <ITensorHandleFactory> newFactory)
13 {
14     if (!newFactory)
15     {
16         return;
17     }
18 
19     ITensorHandleFactory::FactoryId id = newFactory->GetId();
20 
21     // Don't register duplicates
22     for (auto& registeredFactory : m_Factories)
23     {
24         if (id == registeredFactory->GetId())
25         {
26             return;
27         }
28     }
29 
30     // Take ownership of the new allocator
31     m_Factories.push_back(std::move(newFactory));
32 }
33 
RegisterMemoryManager(std::shared_ptr<armnn::IMemoryManager> memoryManger)34 void TensorHandleFactoryRegistry::RegisterMemoryManager(std::shared_ptr<armnn::IMemoryManager> memoryManger)
35 {
36     m_MemoryManagers.push_back(memoryManger);
37 }
38 
GetFactory(ITensorHandleFactory::FactoryId id) const39 ITensorHandleFactory* TensorHandleFactoryRegistry::GetFactory(ITensorHandleFactory::FactoryId id) const
40 {
41     for (auto& factory : m_Factories)
42     {
43         if (factory->GetId() == id)
44         {
45             return factory.get();
46         }
47     }
48 
49     return nullptr;
50 }
51 
GetFactory(ITensorHandleFactory::FactoryId id,MemorySource memSource) const52 ITensorHandleFactory* TensorHandleFactoryRegistry::GetFactory(ITensorHandleFactory::FactoryId id,
53                                                               MemorySource memSource) const
54 {
55     for (auto& factory : m_Factories)
56     {
57         if (factory->GetId() == id && factory->GetImportFlags() == static_cast<MemorySourceFlags>(memSource))
58         {
59             return factory.get();
60         }
61     }
62 
63     return nullptr;
64 }
65 
RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId,ITensorHandleFactory::FactoryId importFactoryId)66 void TensorHandleFactoryRegistry::RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId,
67                                                                    ITensorHandleFactory::FactoryId importFactoryId)
68 {
69     m_FactoryMappings[copyFactoryId] = importFactoryId;
70 }
71 
GetMatchingImportFactoryId(ITensorHandleFactory::FactoryId copyFactoryId)72 ITensorHandleFactory::FactoryId TensorHandleFactoryRegistry::GetMatchingImportFactoryId(
73     ITensorHandleFactory::FactoryId copyFactoryId)
74 {
75     return m_FactoryMappings[copyFactoryId];
76 }
77 
AquireMemory()78 void TensorHandleFactoryRegistry::AquireMemory()
79 {
80     for (auto& mgr : m_MemoryManagers)
81     {
82         mgr->Acquire();
83     }
84 }
85 
ReleaseMemory()86 void TensorHandleFactoryRegistry::ReleaseMemory()
87 {
88     for (auto& mgr : m_MemoryManagers)
89     {
90         mgr->Release();
91     }
92 }
93 
94 } // namespace armnn
95