xref: /aosp_15_r20/external/armnn/src/backends/reference/RefBackend.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "RefBackend.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "RefBackendId.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "RefWorkloadFactory.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "RefLayerSupport.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "RefTensorHandleFactory.hpp"
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendContext.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IMemoryManager.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/DefaultAllocator.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/SubgraphUtils.hpp>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker #include <Optimizer.hpp>
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker namespace armnn
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker 
GetIdStatic()24*89c4ff92SAndroid Build Coastguard Worker const BackendId& RefBackend::GetIdStatic()
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker     static const BackendId s_Id{RefBackendId()};
27*89c4ff92SAndroid Build Coastguard Worker     return s_Id;
28*89c4ff92SAndroid Build Coastguard Worker }
29*89c4ff92SAndroid Build Coastguard Worker 
// Creates a RefWorkloadFactory bound to the given memory manager.
// The generic memory-manager pointer is downcast to RefMemoryManager before being handed
// to the factory. NOTE(review): this assumes the caller passes a RefMemoryManager
// (e.g. one produced by CreateMemoryManager()); confirm at call sites.
IBackendInternal::IWorkloadFactoryPtr RefBackend::CreateWorkloadFactory(
    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager) const
{
    auto refMemoryManager = PolymorphicPointerDowncast<RefMemoryManager>(memoryManager);
    return std::make_unique<RefWorkloadFactory>(std::move(refMemoryManager));
}
35*89c4ff92SAndroid Build Coastguard Worker 
CreateWorkloadFactory(class TensorHandleFactoryRegistry & tensorHandleFactoryRegistry) const36*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IWorkloadFactoryPtr RefBackend::CreateWorkloadFactory(
37*89c4ff92SAndroid Build Coastguard Worker     class TensorHandleFactoryRegistry& tensorHandleFactoryRegistry) const
38*89c4ff92SAndroid Build Coastguard Worker {
39*89c4ff92SAndroid Build Coastguard Worker     auto memoryManager = std::make_shared<RefMemoryManager>();
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker     tensorHandleFactoryRegistry.RegisterMemoryManager(memoryManager);
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<RefTensorHandleFactory> factory = std::make_unique<RefTensorHandleFactory>(memoryManager);
44*89c4ff92SAndroid Build Coastguard Worker     // Register copy and import factory pair
45*89c4ff92SAndroid Build Coastguard Worker     tensorHandleFactoryRegistry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId());
46*89c4ff92SAndroid Build Coastguard Worker     // Register the factory
47*89c4ff92SAndroid Build Coastguard Worker     tensorHandleFactoryRegistry.RegisterFactory(std::move(factory));
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<RefWorkloadFactory>(PolymorphicPointerDowncast<RefMemoryManager>(memoryManager));
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker 
CreateBackendContext(const IRuntime::CreationOptions &) const52*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendContextPtr RefBackend::CreateBackendContext(const IRuntime::CreationOptions&) const
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker     return IBackendContextPtr{};
55*89c4ff92SAndroid Build Coastguard Worker }
56*89c4ff92SAndroid Build Coastguard Worker 
CreateBackendProfilingContext(const IRuntime::CreationOptions &,IBackendProfilingPtr &)57*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendProfilingContextPtr RefBackend::CreateBackendProfilingContext(
58*89c4ff92SAndroid Build Coastguard Worker     const IRuntime::CreationOptions&, IBackendProfilingPtr&)
59*89c4ff92SAndroid Build Coastguard Worker {
60*89c4ff92SAndroid Build Coastguard Worker     return IBackendProfilingContextPtr{};
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker 
CreateMemoryManager() const63*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IMemoryManagerUniquePtr RefBackend::CreateMemoryManager() const
64*89c4ff92SAndroid Build Coastguard Worker {
65*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<RefMemoryManager>();
66*89c4ff92SAndroid Build Coastguard Worker }
67*89c4ff92SAndroid Build Coastguard Worker 
GetLayerSupport() const68*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::ILayerSupportSharedPtr RefBackend::GetLayerSupport() const
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker     static ILayerSupportSharedPtr layerSupport{new RefLayerSupport};
71*89c4ff92SAndroid Build Coastguard Worker     return layerSupport;
72*89c4ff92SAndroid Build Coastguard Worker }
73*89c4ff92SAndroid Build Coastguard Worker 
OptimizeSubgraphView(const SubgraphView & subgraph,const ModelOptions & modelOptions) const74*89c4ff92SAndroid Build Coastguard Worker OptimizationViews RefBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
75*89c4ff92SAndroid Build Coastguard Worker                                                    const ModelOptions& modelOptions) const
76*89c4ff92SAndroid Build Coastguard Worker {
77*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews optimizationViews(modelOptions);
78*89c4ff92SAndroid Build Coastguard Worker 
79*89c4ff92SAndroid Build Coastguard Worker     auto it = subgraph.endIConnectable();
80*89c4ff92SAndroid Build Coastguard Worker     std::map<LayerGuid, Layer*> untouched;
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker     while (it != subgraph.beginIConnectable())
83*89c4ff92SAndroid Build Coastguard Worker     {
84*89c4ff92SAndroid Build Coastguard Worker         --it;
85*89c4ff92SAndroid Build Coastguard Worker         Layer& base = *(PolymorphicDowncast<Layer*>(*it));
86*89c4ff92SAndroid Build Coastguard Worker         untouched.insert({base.GetGuid(), &base});
87*89c4ff92SAndroid Build Coastguard Worker     }
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     it = subgraph.endIConnectable();
90*89c4ff92SAndroid Build Coastguard Worker     while (it != subgraph.beginIConnectable())
91*89c4ff92SAndroid Build Coastguard Worker     {
92*89c4ff92SAndroid Build Coastguard Worker         --it;
93*89c4ff92SAndroid Build Coastguard Worker         Layer& base = *(PolymorphicDowncast<Layer*>(*it));
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker         // Special case to fuse padding into average pooling 2d for quantized datatype.
96*89c4ff92SAndroid Build Coastguard Worker         // Required to be done as a backend specific optimization as Neon does not support this special case.
97*89c4ff92SAndroid Build Coastguard Worker         if (base.GetType() == LayerType::Pooling2d)
98*89c4ff92SAndroid Build Coastguard Worker         {
99*89c4ff92SAndroid Build Coastguard Worker             Pooling2dLayer* baseLayer = PolymorphicDowncast<Pooling2dLayer*>(&base);
100*89c4ff92SAndroid Build Coastguard Worker             Pooling2dDescriptor poolingDescriptor = baseLayer->GetParameters();
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker             if (baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer().GetType() == LayerType::Pad)
103*89c4ff92SAndroid Build Coastguard Worker             {
104*89c4ff92SAndroid Build Coastguard Worker                 PadLayer* padLayer = PolymorphicDowncast<PadLayer*>(
105*89c4ff92SAndroid Build Coastguard Worker                     &baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer());
106*89c4ff92SAndroid Build Coastguard Worker                 if (padLayer->GetOutputSlot(0).GetNumConnections() == 1 &&
107*89c4ff92SAndroid Build Coastguard Worker                     optimizations::pad_fold::TryFoldPadIntoLayer2d(padLayer->GetParameters(),
108*89c4ff92SAndroid Build Coastguard Worker                                                                    poolingDescriptor,
109*89c4ff92SAndroid Build Coastguard Worker                                                                    padLayer->GetOutputSlot().GetTensorInfo(),
110*89c4ff92SAndroid Build Coastguard Worker                                                                    true))
111*89c4ff92SAndroid Build Coastguard Worker                 {
112*89c4ff92SAndroid Build Coastguard Worker                     FoldPadIntoAveragePool2d<Pooling2dLayer>(optimizationViews, baseLayer,
113*89c4ff92SAndroid Build Coastguard Worker                                                              poolingDescriptor, padLayer);
114*89c4ff92SAndroid Build Coastguard Worker                     untouched.erase(baseLayer->GetGuid());
115*89c4ff92SAndroid Build Coastguard Worker                     untouched.erase(padLayer->GetGuid());
116*89c4ff92SAndroid Build Coastguard Worker                 }
117*89c4ff92SAndroid Build Coastguard Worker             }
118*89c4ff92SAndroid Build Coastguard Worker         }
119*89c4ff92SAndroid Build Coastguard Worker     }
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker     if (optimizationViews.GetSubstitutions().empty())
122*89c4ff92SAndroid Build Coastguard Worker     {
123*89c4ff92SAndroid Build Coastguard Worker         optimizationViews.AddUntouchedSubgraph(SubgraphView(subgraph));
124*89c4ff92SAndroid Build Coastguard Worker     }
125*89c4ff92SAndroid Build Coastguard Worker     else
126*89c4ff92SAndroid Build Coastguard Worker     {
127*89c4ff92SAndroid Build Coastguard Worker         ReportUntouchedLayers(optimizationViews, untouched);
128*89c4ff92SAndroid Build Coastguard Worker     }
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker     return optimizationViews;
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker 
GetHandleFactoryPreferences() const133*89c4ff92SAndroid Build Coastguard Worker std::vector<ITensorHandleFactory::FactoryId> RefBackend::GetHandleFactoryPreferences() const
134*89c4ff92SAndroid Build Coastguard Worker {
135*89c4ff92SAndroid Build Coastguard Worker     return std::vector<ITensorHandleFactory::FactoryId> { RefTensorHandleFactory::GetIdStatic() };
136*89c4ff92SAndroid Build Coastguard Worker }
137*89c4ff92SAndroid Build Coastguard Worker 
RegisterTensorHandleFactories(class TensorHandleFactoryRegistry & registry)138*89c4ff92SAndroid Build Coastguard Worker void RefBackend::RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& registry)
139*89c4ff92SAndroid Build Coastguard Worker {
140*89c4ff92SAndroid Build Coastguard Worker     auto memoryManager = std::make_shared<RefMemoryManager>();
141*89c4ff92SAndroid Build Coastguard Worker 
142*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterMemoryManager(memoryManager);
143*89c4ff92SAndroid Build Coastguard Worker 
144*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<RefTensorHandleFactory> factory = std::make_unique<RefTensorHandleFactory>(memoryManager);
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker     // Register copy and import factory pair
147*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId());
148*89c4ff92SAndroid Build Coastguard Worker     // Register the factory
149*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterFactory(std::move(factory));
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker 
GetDefaultAllocator() const152*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ICustomAllocator> RefBackend::GetDefaultAllocator() const
153*89c4ff92SAndroid Build Coastguard Worker {
154*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<DefaultAllocator>();
155*89c4ff92SAndroid Build Coastguard Worker }
156*89c4ff92SAndroid Build Coastguard Worker 
CreateExecutionData(WorkingMemDescriptor & workingMemDescriptor) const157*89c4ff92SAndroid Build Coastguard Worker ExecutionData RefBackend::CreateExecutionData(WorkingMemDescriptor& workingMemDescriptor) const
158*89c4ff92SAndroid Build Coastguard Worker {
159*89c4ff92SAndroid Build Coastguard Worker     ExecutionData executionData;
160*89c4ff92SAndroid Build Coastguard Worker     executionData.m_Data = &workingMemDescriptor;
161*89c4ff92SAndroid Build Coastguard Worker     return executionData;
162*89c4ff92SAndroid Build Coastguard Worker }
163*89c4ff92SAndroid Build Coastguard Worker 
UpdateExecutionData(ExecutionData & executionData,WorkingMemDescriptor & workingMemDescriptor) const164*89c4ff92SAndroid Build Coastguard Worker void RefBackend::UpdateExecutionData(ExecutionData& executionData, WorkingMemDescriptor& workingMemDescriptor) const
165*89c4ff92SAndroid Build Coastguard Worker {
166*89c4ff92SAndroid Build Coastguard Worker     executionData.m_Data = &workingMemDescriptor;
167*89c4ff92SAndroid Build Coastguard Worker }
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
170