1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "RefBackend.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "RefBackendId.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "RefWorkloadFactory.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "RefLayerSupport.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "RefTensorHandleFactory.hpp"
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendContext.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IMemoryManager.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/DefaultAllocator.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/SubgraphUtils.hpp>
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker #include <Optimizer.hpp>
20*89c4ff92SAndroid Build Coastguard Worker
21*89c4ff92SAndroid Build Coastguard Worker namespace armnn
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker
GetIdStatic()24*89c4ff92SAndroid Build Coastguard Worker const BackendId& RefBackend::GetIdStatic()
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker static const BackendId s_Id{RefBackendId()};
27*89c4ff92SAndroid Build Coastguard Worker return s_Id;
28*89c4ff92SAndroid Build Coastguard Worker }
29*89c4ff92SAndroid Build Coastguard Worker
/// Creates a reference workload factory that uses an existing memory manager.
/// @param memoryManager Memory manager shared with the factory. It is downcast to
///                      RefMemoryManager, so callers must pass a RefMemoryManager instance.
/// @return A newly allocated RefWorkloadFactory.
IBackendInternal::IWorkloadFactoryPtr RefBackend::CreateWorkloadFactory(
    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager) const
{
    auto refMemoryManager = PolymorphicPointerDowncast<RefMemoryManager>(memoryManager);
    return std::make_unique<RefWorkloadFactory>(refMemoryManager);
}
35*89c4ff92SAndroid Build Coastguard Worker
CreateWorkloadFactory(class TensorHandleFactoryRegistry & tensorHandleFactoryRegistry) const36*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IWorkloadFactoryPtr RefBackend::CreateWorkloadFactory(
37*89c4ff92SAndroid Build Coastguard Worker class TensorHandleFactoryRegistry& tensorHandleFactoryRegistry) const
38*89c4ff92SAndroid Build Coastguard Worker {
39*89c4ff92SAndroid Build Coastguard Worker auto memoryManager = std::make_shared<RefMemoryManager>();
40*89c4ff92SAndroid Build Coastguard Worker
41*89c4ff92SAndroid Build Coastguard Worker tensorHandleFactoryRegistry.RegisterMemoryManager(memoryManager);
42*89c4ff92SAndroid Build Coastguard Worker
43*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<RefTensorHandleFactory> factory = std::make_unique<RefTensorHandleFactory>(memoryManager);
44*89c4ff92SAndroid Build Coastguard Worker // Register copy and import factory pair
45*89c4ff92SAndroid Build Coastguard Worker tensorHandleFactoryRegistry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId());
46*89c4ff92SAndroid Build Coastguard Worker // Register the factory
47*89c4ff92SAndroid Build Coastguard Worker tensorHandleFactoryRegistry.RegisterFactory(std::move(factory));
48*89c4ff92SAndroid Build Coastguard Worker
49*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<RefWorkloadFactory>(PolymorphicPointerDowncast<RefMemoryManager>(memoryManager));
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker
CreateBackendContext(const IRuntime::CreationOptions &) const52*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendContextPtr RefBackend::CreateBackendContext(const IRuntime::CreationOptions&) const
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker return IBackendContextPtr{};
55*89c4ff92SAndroid Build Coastguard Worker }
56*89c4ff92SAndroid Build Coastguard Worker
CreateBackendProfilingContext(const IRuntime::CreationOptions &,IBackendProfilingPtr &)57*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendProfilingContextPtr RefBackend::CreateBackendProfilingContext(
58*89c4ff92SAndroid Build Coastguard Worker const IRuntime::CreationOptions&, IBackendProfilingPtr&)
59*89c4ff92SAndroid Build Coastguard Worker {
60*89c4ff92SAndroid Build Coastguard Worker return IBackendProfilingContextPtr{};
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker
CreateMemoryManager() const63*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IMemoryManagerUniquePtr RefBackend::CreateMemoryManager() const
64*89c4ff92SAndroid Build Coastguard Worker {
65*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<RefMemoryManager>();
66*89c4ff92SAndroid Build Coastguard Worker }
67*89c4ff92SAndroid Build Coastguard Worker
GetLayerSupport() const68*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::ILayerSupportSharedPtr RefBackend::GetLayerSupport() const
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker static ILayerSupportSharedPtr layerSupport{new RefLayerSupport};
71*89c4ff92SAndroid Build Coastguard Worker return layerSupport;
72*89c4ff92SAndroid Build Coastguard Worker }
73*89c4ff92SAndroid Build Coastguard Worker
OptimizeSubgraphView(const SubgraphView & subgraph,const ModelOptions & modelOptions) const74*89c4ff92SAndroid Build Coastguard Worker OptimizationViews RefBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
75*89c4ff92SAndroid Build Coastguard Worker const ModelOptions& modelOptions) const
76*89c4ff92SAndroid Build Coastguard Worker {
77*89c4ff92SAndroid Build Coastguard Worker OptimizationViews optimizationViews(modelOptions);
78*89c4ff92SAndroid Build Coastguard Worker
79*89c4ff92SAndroid Build Coastguard Worker auto it = subgraph.endIConnectable();
80*89c4ff92SAndroid Build Coastguard Worker std::map<LayerGuid, Layer*> untouched;
81*89c4ff92SAndroid Build Coastguard Worker
82*89c4ff92SAndroid Build Coastguard Worker while (it != subgraph.beginIConnectable())
83*89c4ff92SAndroid Build Coastguard Worker {
84*89c4ff92SAndroid Build Coastguard Worker --it;
85*89c4ff92SAndroid Build Coastguard Worker Layer& base = *(PolymorphicDowncast<Layer*>(*it));
86*89c4ff92SAndroid Build Coastguard Worker untouched.insert({base.GetGuid(), &base});
87*89c4ff92SAndroid Build Coastguard Worker }
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker it = subgraph.endIConnectable();
90*89c4ff92SAndroid Build Coastguard Worker while (it != subgraph.beginIConnectable())
91*89c4ff92SAndroid Build Coastguard Worker {
92*89c4ff92SAndroid Build Coastguard Worker --it;
93*89c4ff92SAndroid Build Coastguard Worker Layer& base = *(PolymorphicDowncast<Layer*>(*it));
94*89c4ff92SAndroid Build Coastguard Worker
95*89c4ff92SAndroid Build Coastguard Worker // Special case to fuse padding into average pooling 2d for quantized datatype.
96*89c4ff92SAndroid Build Coastguard Worker // Required to be done as a backend specific optimization as Neon does not support this special case.
97*89c4ff92SAndroid Build Coastguard Worker if (base.GetType() == LayerType::Pooling2d)
98*89c4ff92SAndroid Build Coastguard Worker {
99*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* baseLayer = PolymorphicDowncast<Pooling2dLayer*>(&base);
100*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor poolingDescriptor = baseLayer->GetParameters();
101*89c4ff92SAndroid Build Coastguard Worker
102*89c4ff92SAndroid Build Coastguard Worker if (baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer().GetType() == LayerType::Pad)
103*89c4ff92SAndroid Build Coastguard Worker {
104*89c4ff92SAndroid Build Coastguard Worker PadLayer* padLayer = PolymorphicDowncast<PadLayer*>(
105*89c4ff92SAndroid Build Coastguard Worker &baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer());
106*89c4ff92SAndroid Build Coastguard Worker if (padLayer->GetOutputSlot(0).GetNumConnections() == 1 &&
107*89c4ff92SAndroid Build Coastguard Worker optimizations::pad_fold::TryFoldPadIntoLayer2d(padLayer->GetParameters(),
108*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor,
109*89c4ff92SAndroid Build Coastguard Worker padLayer->GetOutputSlot().GetTensorInfo(),
110*89c4ff92SAndroid Build Coastguard Worker true))
111*89c4ff92SAndroid Build Coastguard Worker {
112*89c4ff92SAndroid Build Coastguard Worker FoldPadIntoAveragePool2d<Pooling2dLayer>(optimizationViews, baseLayer,
113*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor, padLayer);
114*89c4ff92SAndroid Build Coastguard Worker untouched.erase(baseLayer->GetGuid());
115*89c4ff92SAndroid Build Coastguard Worker untouched.erase(padLayer->GetGuid());
116*89c4ff92SAndroid Build Coastguard Worker }
117*89c4ff92SAndroid Build Coastguard Worker }
118*89c4ff92SAndroid Build Coastguard Worker }
119*89c4ff92SAndroid Build Coastguard Worker }
120*89c4ff92SAndroid Build Coastguard Worker
121*89c4ff92SAndroid Build Coastguard Worker if (optimizationViews.GetSubstitutions().empty())
122*89c4ff92SAndroid Build Coastguard Worker {
123*89c4ff92SAndroid Build Coastguard Worker optimizationViews.AddUntouchedSubgraph(SubgraphView(subgraph));
124*89c4ff92SAndroid Build Coastguard Worker }
125*89c4ff92SAndroid Build Coastguard Worker else
126*89c4ff92SAndroid Build Coastguard Worker {
127*89c4ff92SAndroid Build Coastguard Worker ReportUntouchedLayers(optimizationViews, untouched);
128*89c4ff92SAndroid Build Coastguard Worker }
129*89c4ff92SAndroid Build Coastguard Worker
130*89c4ff92SAndroid Build Coastguard Worker return optimizationViews;
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker
GetHandleFactoryPreferences() const133*89c4ff92SAndroid Build Coastguard Worker std::vector<ITensorHandleFactory::FactoryId> RefBackend::GetHandleFactoryPreferences() const
134*89c4ff92SAndroid Build Coastguard Worker {
135*89c4ff92SAndroid Build Coastguard Worker return std::vector<ITensorHandleFactory::FactoryId> { RefTensorHandleFactory::GetIdStatic() };
136*89c4ff92SAndroid Build Coastguard Worker }
137*89c4ff92SAndroid Build Coastguard Worker
RegisterTensorHandleFactories(class TensorHandleFactoryRegistry & registry)138*89c4ff92SAndroid Build Coastguard Worker void RefBackend::RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& registry)
139*89c4ff92SAndroid Build Coastguard Worker {
140*89c4ff92SAndroid Build Coastguard Worker auto memoryManager = std::make_shared<RefMemoryManager>();
141*89c4ff92SAndroid Build Coastguard Worker
142*89c4ff92SAndroid Build Coastguard Worker registry.RegisterMemoryManager(memoryManager);
143*89c4ff92SAndroid Build Coastguard Worker
144*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<RefTensorHandleFactory> factory = std::make_unique<RefTensorHandleFactory>(memoryManager);
145*89c4ff92SAndroid Build Coastguard Worker
146*89c4ff92SAndroid Build Coastguard Worker // Register copy and import factory pair
147*89c4ff92SAndroid Build Coastguard Worker registry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId());
148*89c4ff92SAndroid Build Coastguard Worker // Register the factory
149*89c4ff92SAndroid Build Coastguard Worker registry.RegisterFactory(std::move(factory));
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker
GetDefaultAllocator() const152*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ICustomAllocator> RefBackend::GetDefaultAllocator() const
153*89c4ff92SAndroid Build Coastguard Worker {
154*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<DefaultAllocator>();
155*89c4ff92SAndroid Build Coastguard Worker }
156*89c4ff92SAndroid Build Coastguard Worker
CreateExecutionData(WorkingMemDescriptor & workingMemDescriptor) const157*89c4ff92SAndroid Build Coastguard Worker ExecutionData RefBackend::CreateExecutionData(WorkingMemDescriptor& workingMemDescriptor) const
158*89c4ff92SAndroid Build Coastguard Worker {
159*89c4ff92SAndroid Build Coastguard Worker ExecutionData executionData;
160*89c4ff92SAndroid Build Coastguard Worker executionData.m_Data = &workingMemDescriptor;
161*89c4ff92SAndroid Build Coastguard Worker return executionData;
162*89c4ff92SAndroid Build Coastguard Worker }
163*89c4ff92SAndroid Build Coastguard Worker
UpdateExecutionData(ExecutionData & executionData,WorkingMemDescriptor & workingMemDescriptor) const164*89c4ff92SAndroid Build Coastguard Worker void RefBackend::UpdateExecutionData(ExecutionData& executionData, WorkingMemDescriptor& workingMemDescriptor) const
165*89c4ff92SAndroid Build Coastguard Worker {
166*89c4ff92SAndroid Build Coastguard Worker executionData.m_Data = &workingMemDescriptor;
167*89c4ff92SAndroid Build Coastguard Worker }
168*89c4ff92SAndroid Build Coastguard Worker
169*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
170