xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/TensorHandleFactoryRegistry.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Types.hpp>
9 #include <armnn/backends/ITensorHandleFactory.hpp>
10 #include <map>
11 #include <memory>
12 #include <vector>
13 
14 namespace armnn
15 {
16 
17 //Forward
18 class IMemoryManager;
19 
20 using CopyAndImportFactoryPairs = std::map<ITensorHandleFactory::FactoryId, ITensorHandleFactory::FactoryId>;
21 
22 ///
23 class TensorHandleFactoryRegistry
24 {
25 public:
26     TensorHandleFactoryRegistry() = default;
27 
28     TensorHandleFactoryRegistry(const TensorHandleFactoryRegistry& other) = delete;
29     TensorHandleFactoryRegistry(TensorHandleFactoryRegistry&& other) = delete;
30 
31     /// Register a TensorHandleFactory and transfer ownership
32     void RegisterFactory(std::unique_ptr<ITensorHandleFactory> allocator);
33 
34     /// Register a memory manager with shared ownership
35     void RegisterMemoryManager(std::shared_ptr<IMemoryManager> memoryManger);
36 
37     /// Find a TensorHandleFactory by Id
38     /// Returns nullptr if not found
39     ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id) const;
40 
41     /// Overload of above allowing specification of Memory Source
42     ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id,
43                                      MemorySource memSource) const;
44 
45     /// Register a pair of TensorHandleFactory Id for Memory Copy and TensorHandleFactory Id for Memory Import
46     void RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId,
47                                           ITensorHandleFactory::FactoryId importFactoryId);
48 
49     /// Get a matching TensorHandleFatory Id for Memory Import given TensorHandleFactory Id for Memory Copy
50     ITensorHandleFactory::FactoryId GetMatchingImportFactoryId(ITensorHandleFactory::FactoryId copyFactoryId);
51 
52     /// Aquire memory required for inference
53     void AquireMemory();
54 
55     /// Release memory required for inference
56     void ReleaseMemory();
57 
GetMemoryManagers()58     std::vector<std::shared_ptr<IMemoryManager>>& GetMemoryManagers()
59     {
60         return m_MemoryManagers;
61     }
62 
63 private:
64     std::vector<std::unique_ptr<ITensorHandleFactory>> m_Factories;
65     std::vector<std::shared_ptr<IMemoryManager>> m_MemoryManagers;
66     CopyAndImportFactoryPairs m_FactoryMappings;
67 };
68 
69 } // namespace armnn
70