xref: /aosp_15_r20/external/armnn/src/backends/tosaReference/TosaRefBackend.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "TosaRefBackend.hpp"
7 #include "TosaRefBackendId.hpp"
8 #include "TosaRefWorkloadFactory.hpp"
9 #include "TosaRefLayerSupport.hpp"
10 #include "TosaRefTensorHandleFactory.hpp"
11 
12 #include <tosaCommon/TosaMappings.hpp>
13 #include <armnn/BackendRegistry.hpp>
14 #include <armnn/backends/IBackendContext.hpp>
15 #include <armnn/backends/IMemoryManager.hpp>
16 #include <armnn/utility/PolymorphicDowncast.hpp>
17 #include <backendsCommon/DefaultAllocator.hpp>
18 #include <backendsCommon/SubgraphUtils.hpp>
19 
20 #include <Optimizer.hpp>
21 
22 namespace armnn
23 {
24 
// Deleter helper for type-erased blobs: restores the concrete type T of the pointer
// and deletes it. Used below as the deleter of the TosaSerializationHandler blob
// that is handed to ArmNN.
template <typename T>
void DeleteAsType(const void* const blob)
{
    const T* const typedBlob = static_cast<const T*>(blob);
    delete typedBlob;
}
31 
GetIdStatic()32 const BackendId& TosaRefBackend::GetIdStatic()
33 {
34     static const BackendId s_Id{TosaRefBackendId()};
35     return s_Id;
36 }
37 
// Creates a TosaRefWorkloadFactory on top of an existing memory manager.
// The generic memory manager pointer is downcast to TosaRefMemoryManager before it is
// passed to the factory.
IBackendInternal::IWorkloadFactoryPtr TosaRefBackend::CreateWorkloadFactory(
    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager) const
{
    auto tosaMemoryManager = PolymorphicPointerDowncast<TosaRefMemoryManager>(memoryManager);
    return std::make_unique<TosaRefWorkloadFactory>(tosaMemoryManager);
}
43 
CreateWorkloadFactory(class TensorHandleFactoryRegistry & tensorHandleFactoryRegistry) const44 IBackendInternal::IWorkloadFactoryPtr TosaRefBackend::CreateWorkloadFactory(
45     class TensorHandleFactoryRegistry& tensorHandleFactoryRegistry) const
46 {
47     auto memoryManager = std::make_shared<TosaRefMemoryManager>();
48 
49     tensorHandleFactoryRegistry.RegisterMemoryManager(memoryManager);
50 
51     auto factory = std::make_unique<TosaRefTensorHandleFactory>(memoryManager);
52     // Register copy and import factory pair
53     tensorHandleFactoryRegistry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId());
54     // Register the factory
55     tensorHandleFactoryRegistry.RegisterFactory(std::move(factory));
56 
57     return std::make_unique<TosaRefWorkloadFactory>(PolymorphicPointerDowncast<TosaRefMemoryManager>(memoryManager));
58 }
59 
CreateBackendContext(const IRuntime::CreationOptions &) const60 IBackendInternal::IBackendContextPtr TosaRefBackend::CreateBackendContext(const IRuntime::CreationOptions&) const
61 {
62     return IBackendContextPtr{};
63 }
64 
CreateBackendProfilingContext(const IRuntime::CreationOptions &,IBackendProfilingPtr &)65 IBackendInternal::IBackendProfilingContextPtr TosaRefBackend::CreateBackendProfilingContext(
66     const IRuntime::CreationOptions&, IBackendProfilingPtr&)
67 {
68     return IBackendProfilingContextPtr{};
69 }
70 
CreateMemoryManager() const71 IBackendInternal::IMemoryManagerUniquePtr TosaRefBackend::CreateMemoryManager() const
72 {
73     return std::make_unique<TosaRefMemoryManager>();
74 }
75 
GetLayerSupport() const76 IBackendInternal::ILayerSupportSharedPtr TosaRefBackend::GetLayerSupport() const
77 {
78     static ILayerSupportSharedPtr layerSupport{new TosaRefLayerSupport};
79     return layerSupport;
80 }
81 
OptimizeSubgraphView(const SubgraphView & subgraph,const ModelOptions & modelOptions) const82 OptimizationViews TosaRefBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
83                                                        const ModelOptions& modelOptions) const
84 {
85     OptimizationViews optimizationViews(modelOptions);
86 
87     auto handler = std::make_unique<TosaSerializationHandler>();
88 
89     std::vector<std::string> graphInputs;
90     std::vector<std::string> graphOutputs;
91 
92     std::vector<TosaSerializationOperator*> operators;
93     std::vector<TosaSerializationTensor*> tensors;
94 
95     auto it = subgraph.endIConnectable();
96     while (it != subgraph.beginIConnectable())
97     {
98         --it;
99         Layer& base = *(PolymorphicDowncast<Layer*>(*it));
100 
101         if(base.GetType() == armnn::LayerType::Input ||
102            base.GetType() == armnn::LayerType::Output)
103         {
104             continue;
105         }
106 
107         tosa::TosaSerializationBasicBlock* mappings = GetTosaMappingFromLayer(&base);
108 
109         // Loop through inputs to see if there are any graph inputs, if so save them.
110         // If it's an input to the graph "input" can be found in the string.
111         for (uint32_t i = 0; i < mappings->GetInputs().size(); i++)
112         {
113             std::basic_string<char> blockInputName = mappings->GetInputs()[i];
114 
115             if (blockInputName.find("input") != std::string::npos)
116             {
117                 graphInputs.push_back(blockInputName);
118             }
119         }
120 
121         // Loop through outputs to see if there are any graph outputs, if so save them.
122         // If it's an output to the graph "output" can be found in the string.
123         for (uint32_t i = 0; i < mappings->GetOutputs().size(); i++)
124         {
125             std::basic_string<char> blockOutputName = mappings->GetOutputs()[i];
126 
127             if (blockOutputName.find("output") != std::string::npos)
128             {
129                 graphOutputs.push_back(blockOutputName);
130             }
131         }
132 
133         auto blockOperators = mappings->GetOperators();
134         operators.insert(operators.end(), blockOperators.begin(), blockOperators.end());
135 
136         auto blockTensors = mappings->GetTensors();
137         tensors.insert(tensors.end(), blockTensors.begin(), blockTensors.end());
138     }
139 
140     // Add all mappings to main block, the TOSA Reference Model requires the full graph to be in one block called main.
141     auto* block = new TosaSerializationBasicBlock("main", operators, tensors, graphInputs, graphOutputs);
142 
143     handler.get()->GetBlocks().push_back(block);
144 
145     auto compiledBlob =
146             std::make_unique<PreCompiledObjectPtr>(handler.release(), DeleteAsType<TosaSerializationHandler>);
147 
148     IConnectableLayer* preCompiledLayer = optimizationViews.GetINetwork()->AddPrecompiledLayer(
149             PreCompiledDescriptor(subgraph.GetNumInputSlots(), subgraph.GetNumOutputSlots()),
150             std::move(*compiledBlob),
151             armnn::Optional<BackendId>(GetId()),
152             "TOSA_Pre_Compiled_Layer");
153 
154     // Copy the output tensor infos from sub-graph
155     for (unsigned int i = 0; i < subgraph.GetNumOutputSlots(); i++)
156     {
157         preCompiledLayer->GetOutputSlot(i).SetTensorInfo(subgraph.GetIOutputSlot(i)->GetTensorInfo());
158     }
159 
160     optimizationViews.AddSubstitution({ std::move(subgraph), SubgraphView(preCompiledLayer) });
161     return optimizationViews;
162 }
163 
164 
GetHandleFactoryPreferences() const165 std::vector<ITensorHandleFactory::FactoryId> TosaRefBackend::GetHandleFactoryPreferences() const
166 {
167     return std::vector<ITensorHandleFactory::FactoryId> { TosaRefTensorHandleFactory::GetIdStatic() };
168 }
169 
RegisterTensorHandleFactories(class TensorHandleFactoryRegistry & registry)170 void TosaRefBackend::RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& registry)
171 {
172     auto memoryManager = std::make_shared<TosaRefMemoryManager>();
173 
174     registry.RegisterMemoryManager(memoryManager);
175 
176     auto factory = std::make_unique<TosaRefTensorHandleFactory>(memoryManager);
177 
178     // Register copy and import factory pair
179     registry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId());
180     // Register the factory
181     registry.RegisterFactory(std::move(factory));
182 }
183 
GetDefaultAllocator() const184 std::unique_ptr<ICustomAllocator> TosaRefBackend::GetDefaultAllocator() const
185 {
186     return std::make_unique<DefaultAllocator>();
187 }
188 
189 } // namespace armnn
190