xref: /aosp_15_r20/external/armnn/src/backends/neon/NeonBackend.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "NeonBackend.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "NeonBackendId.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "NeonBackendModelContext.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "NeonWorkloadFactory.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "NeonLayerSupport.hpp"
11*89c4ff92SAndroid Build Coastguard Worker #include "NeonTensorHandleFactory.hpp"
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeSubgraphUtils.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeUtils.hpp>
18*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/BaseMemoryManager.hpp>
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendContext.hpp>
21*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IMemoryManager.hpp>
22*89c4ff92SAndroid Build Coastguard Worker 
23*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
24*89c4ff92SAndroid Build Coastguard Worker 
25*89c4ff92SAndroid Build Coastguard Worker #include <neon/workloads/NeonAdditionWorkload.hpp>
26*89c4ff92SAndroid Build Coastguard Worker #include <neon/workloads/NeonBatchNormalizationWorkload.hpp>
27*89c4ff92SAndroid Build Coastguard Worker #include <neon/workloads/NeonConvolution2dWorkload.hpp>
28*89c4ff92SAndroid Build Coastguard Worker #include <neon/workloads/NeonDepthwiseConvolutionWorkload.hpp>
29*89c4ff92SAndroid Build Coastguard Worker #include <neon/workloads/NeonDivisionWorkload.hpp>
30*89c4ff92SAndroid Build Coastguard Worker #include <neon/workloads/NeonFullyConnectedWorkload.hpp>
31*89c4ff92SAndroid Build Coastguard Worker #include <neon/workloads/NeonMultiplicationWorkload.hpp>
32*89c4ff92SAndroid Build Coastguard Worker #include <neon/workloads/NeonReduceWorkload.hpp>
33*89c4ff92SAndroid Build Coastguard Worker #include <neon/workloads/NeonSubtractionWorkload.hpp>
34*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/DefaultAllocator.hpp>
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker #include <Optimizer.hpp>
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/Types.h>
39*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/Allocator.h>
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker namespace armnn
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker 
GetIdStatic()44*89c4ff92SAndroid Build Coastguard Worker const BackendId& NeonBackend::GetIdStatic()
45*89c4ff92SAndroid Build Coastguard Worker {
46*89c4ff92SAndroid Build Coastguard Worker     static const BackendId s_Id{NeonBackendId()};
47*89c4ff92SAndroid Build Coastguard Worker     return s_Id;
48*89c4ff92SAndroid Build Coastguard Worker }
49*89c4ff92SAndroid Build Coastguard Worker 
CreateMemoryManager() const50*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IMemoryManagerUniquePtr NeonBackend::CreateMemoryManager() const
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<NeonMemoryManager>(std::make_unique<arm_compute::Allocator>(),
53*89c4ff92SAndroid Build Coastguard Worker                                                BaseMemoryManager::MemoryAffinity::Offset);
54*89c4ff92SAndroid Build Coastguard Worker }
55*89c4ff92SAndroid Build Coastguard Worker 
// Creates a Neon workload factory that uses the supplied memory manager.
// The generic memory manager pointer is downcast to NeonMemoryManager before being handed over.
IBackendInternal::IWorkloadFactoryPtr NeonBackend::CreateWorkloadFactory(
    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager) const
{
    auto neonMemoryManager = PolymorphicPointerDowncast<NeonMemoryManager>(memoryManager);
    return std::make_unique<NeonWorkloadFactory>(std::move(neonMemoryManager));
}
62*89c4ff92SAndroid Build Coastguard Worker 
// Creates a Neon workload factory that uses the supplied memory manager and a backend-specific
// model context built from modelOptions.
IBackendInternal::IWorkloadFactoryPtr NeonBackend::CreateWorkloadFactory(
    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager, const ModelOptions& modelOptions) const
{
    auto neonMemoryManager = PolymorphicPointerDowncast<NeonMemoryManager>(memoryManager);
    auto modelContext      = CreateBackendSpecificModelContext(modelOptions);

    return std::make_unique<NeonWorkloadFactory>(std::move(neonMemoryManager), std::move(modelContext));
}
69*89c4ff92SAndroid Build Coastguard Worker 
CreateWorkloadFactory(class TensorHandleFactoryRegistry & tensorHandleFactoryRegistry) const70*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IWorkloadFactoryPtr NeonBackend::CreateWorkloadFactory(
71*89c4ff92SAndroid Build Coastguard Worker     class TensorHandleFactoryRegistry& tensorHandleFactoryRegistry) const
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker     auto memoryManager = std::make_shared<NeonMemoryManager>(std::make_unique<arm_compute::Allocator>(),
74*89c4ff92SAndroid Build Coastguard Worker                                                              BaseMemoryManager::MemoryAffinity::Offset);
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker     tensorHandleFactoryRegistry.RegisterMemoryManager(memoryManager);
77*89c4ff92SAndroid Build Coastguard Worker 
78*89c4ff92SAndroid Build Coastguard Worker     auto factory = std::make_unique<NeonTensorHandleFactory>(memoryManager);
79*89c4ff92SAndroid Build Coastguard Worker     // Register copy and import factory pair
80*89c4ff92SAndroid Build Coastguard Worker     tensorHandleFactoryRegistry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId());
81*89c4ff92SAndroid Build Coastguard Worker     // Register the factory
82*89c4ff92SAndroid Build Coastguard Worker     tensorHandleFactoryRegistry.RegisterFactory(std::move(factory));
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<NeonWorkloadFactory>(
86*89c4ff92SAndroid Build Coastguard Worker         PolymorphicPointerDowncast<NeonMemoryManager>(memoryManager));
87*89c4ff92SAndroid Build Coastguard Worker }
88*89c4ff92SAndroid Build Coastguard Worker 
CreateWorkloadFactory(TensorHandleFactoryRegistry & tensorHandleFactoryRegistry,const ModelOptions & modelOptions) const89*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IWorkloadFactoryPtr NeonBackend::CreateWorkloadFactory(
90*89c4ff92SAndroid Build Coastguard Worker     TensorHandleFactoryRegistry& tensorHandleFactoryRegistry, const ModelOptions& modelOptions) const
91*89c4ff92SAndroid Build Coastguard Worker {
92*89c4ff92SAndroid Build Coastguard Worker     auto memoryManager = std::make_shared<NeonMemoryManager>(std::make_unique<arm_compute::Allocator>(),
93*89c4ff92SAndroid Build Coastguard Worker                                                              BaseMemoryManager::MemoryAffinity::Offset);
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker     tensorHandleFactoryRegistry.RegisterMemoryManager(memoryManager);
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker     auto factory = std::make_unique<NeonTensorHandleFactory>(memoryManager);
98*89c4ff92SAndroid Build Coastguard Worker     // Register copy and import factory pair
99*89c4ff92SAndroid Build Coastguard Worker     tensorHandleFactoryRegistry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId());
100*89c4ff92SAndroid Build Coastguard Worker     // Register the factory
101*89c4ff92SAndroid Build Coastguard Worker     tensorHandleFactoryRegistry.RegisterFactory(std::move(factory));
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<NeonWorkloadFactory>(
104*89c4ff92SAndroid Build Coastguard Worker         PolymorphicPointerDowncast<NeonMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions));
105*89c4ff92SAndroid Build Coastguard Worker }
106*89c4ff92SAndroid Build Coastguard Worker 
CreateBackendContext(const IRuntime::CreationOptions &) const107*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendContextPtr NeonBackend::CreateBackendContext(const IRuntime::CreationOptions&) const
108*89c4ff92SAndroid Build Coastguard Worker {
109*89c4ff92SAndroid Build Coastguard Worker     return IBackendContextPtr{};
110*89c4ff92SAndroid Build Coastguard Worker }
111*89c4ff92SAndroid Build Coastguard Worker 
CreateBackendProfilingContext(const IRuntime::CreationOptions &,IBackendProfilingPtr &)112*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendProfilingContextPtr NeonBackend::CreateBackendProfilingContext(
113*89c4ff92SAndroid Build Coastguard Worker     const IRuntime::CreationOptions&, IBackendProfilingPtr&)
114*89c4ff92SAndroid Build Coastguard Worker {
115*89c4ff92SAndroid Build Coastguard Worker     return IBackendProfilingContextPtr{};
116*89c4ff92SAndroid Build Coastguard Worker }
117*89c4ff92SAndroid Build Coastguard Worker 
CreateBackendSpecificModelContext(const ModelOptions & modelOptions) const118*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendSpecificModelContextPtr NeonBackend::CreateBackendSpecificModelContext(
119*89c4ff92SAndroid Build Coastguard Worker     const ModelOptions& modelOptions) const
120*89c4ff92SAndroid Build Coastguard Worker {
121*89c4ff92SAndroid Build Coastguard Worker     return IBackendSpecificModelContextPtr{new NeonBackendModelContext{modelOptions}};
122*89c4ff92SAndroid Build Coastguard Worker }
123*89c4ff92SAndroid Build Coastguard Worker 
GetLayerSupport() const124*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::ILayerSupportSharedPtr NeonBackend::GetLayerSupport() const
125*89c4ff92SAndroid Build Coastguard Worker {
126*89c4ff92SAndroid Build Coastguard Worker     static ILayerSupportSharedPtr layerSupport
127*89c4ff92SAndroid Build Coastguard Worker         {
128*89c4ff92SAndroid Build Coastguard Worker             new NeonLayerSupport(IBackendInternal::IBackendSpecificModelContextPtr{})
129*89c4ff92SAndroid Build Coastguard Worker         };
130*89c4ff92SAndroid Build Coastguard Worker     return layerSupport;
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker 
GetLayerSupport(const ModelOptions & modelOptions) const133*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::ILayerSupportSharedPtr NeonBackend::GetLayerSupport(const ModelOptions& modelOptions) const
134*89c4ff92SAndroid Build Coastguard Worker {
135*89c4ff92SAndroid Build Coastguard Worker     static ILayerSupportSharedPtr layerSupport
136*89c4ff92SAndroid Build Coastguard Worker         {
137*89c4ff92SAndroid Build Coastguard Worker             new NeonLayerSupport(CreateBackendSpecificModelContext(modelOptions))
138*89c4ff92SAndroid Build Coastguard Worker         };
139*89c4ff92SAndroid Build Coastguard Worker     return layerSupport;
140*89c4ff92SAndroid Build Coastguard Worker }
141*89c4ff92SAndroid Build Coastguard Worker 
OptimizeSubgraphView(const SubgraphView & subgraph,const ModelOptions & modelOptions) const142*89c4ff92SAndroid Build Coastguard Worker OptimizationViews NeonBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
143*89c4ff92SAndroid Build Coastguard Worker                                                     const ModelOptions& modelOptions) const
144*89c4ff92SAndroid Build Coastguard Worker {
145*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews optimizationViews(modelOptions);
146*89c4ff92SAndroid Build Coastguard Worker 
147*89c4ff92SAndroid Build Coastguard Worker     auto it = subgraph.endIConnectable();
148*89c4ff92SAndroid Build Coastguard Worker     std::map<LayerGuid, Layer*> untouched;
149*89c4ff92SAndroid Build Coastguard Worker 
150*89c4ff92SAndroid Build Coastguard Worker     while (it != subgraph.beginIConnectable())
151*89c4ff92SAndroid Build Coastguard Worker     {
152*89c4ff92SAndroid Build Coastguard Worker         --it;
153*89c4ff92SAndroid Build Coastguard Worker         Layer& base = *(PolymorphicDowncast<Layer*>(*it));
154*89c4ff92SAndroid Build Coastguard Worker         untouched.insert({base.GetGuid(), &base});
155*89c4ff92SAndroid Build Coastguard Worker     }
156*89c4ff92SAndroid Build Coastguard Worker 
157*89c4ff92SAndroid Build Coastguard Worker     it = subgraph.endIConnectable();
158*89c4ff92SAndroid Build Coastguard Worker     while (it != subgraph.beginIConnectable())
159*89c4ff92SAndroid Build Coastguard Worker     {
160*89c4ff92SAndroid Build Coastguard Worker         --it;
161*89c4ff92SAndroid Build Coastguard Worker         Layer& base = *(PolymorphicDowncast<Layer*>(*it));
162*89c4ff92SAndroid Build Coastguard Worker 
163*89c4ff92SAndroid Build Coastguard Worker         // Fuse activation into previous layer if supported by backend
164*89c4ff92SAndroid Build Coastguard Worker         if ((base.GetType() == LayerType::DepthwiseConvolution2d || base.GetType() == LayerType::Convolution2d
165*89c4ff92SAndroid Build Coastguard Worker              || base.GetType() == LayerType::BatchNormalization || base.GetType() == LayerType::FullyConnected
166*89c4ff92SAndroid Build Coastguard Worker              || base.GetType() == LayerType::Addition || base.GetType() == LayerType::Multiplication
167*89c4ff92SAndroid Build Coastguard Worker              || base.GetType() == LayerType::Subtraction || base.GetType() == LayerType::Division)
168*89c4ff92SAndroid Build Coastguard Worker             && (base.GetAdditionalInformation<ActivationDescriptor>() == nullptr))
169*89c4ff92SAndroid Build Coastguard Worker         {
170*89c4ff92SAndroid Build Coastguard Worker             for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
171*89c4ff92SAndroid Build Coastguard Worker             {
172*89c4ff92SAndroid Build Coastguard Worker                 if (output->GetNumConnections() == 1)
173*89c4ff92SAndroid Build Coastguard Worker                 {
174*89c4ff92SAndroid Build Coastguard Worker                     for (auto&& childInput : output->GetConnections())
175*89c4ff92SAndroid Build Coastguard Worker                     {
176*89c4ff92SAndroid Build Coastguard Worker                         if ((childInput->GetOwningLayer().GetType() == LayerType::Activation) &&
177*89c4ff92SAndroid Build Coastguard Worker                             (checkDataTypeInputandOutput(childInput->GetOwningLayer())))
178*89c4ff92SAndroid Build Coastguard Worker                         {
179*89c4ff92SAndroid Build Coastguard Worker                             Layer& child = childInput->GetOwningLayer();
180*89c4ff92SAndroid Build Coastguard Worker 
181*89c4ff92SAndroid Build Coastguard Worker                             auto* activationLayer = PolymorphicDowncast<ActivationLayer*>(&child);
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker                             const std::string name = std::string("fused-") + child.GetName() + std::string("-into-") +
184*89c4ff92SAndroid Build Coastguard Worker                                                      base.GetName();
185*89c4ff92SAndroid Build Coastguard Worker 
186*89c4ff92SAndroid Build Coastguard Worker                             // Get params from activation layer
187*89c4ff92SAndroid Build Coastguard Worker                             ActivationDescriptor activationDesc = activationLayer->GetParameters();
188*89c4ff92SAndroid Build Coastguard Worker 
189*89c4ff92SAndroid Build Coastguard Worker                             if (base.GetType() == LayerType::Convolution2d)
190*89c4ff92SAndroid Build Coastguard Worker                             {
191*89c4ff92SAndroid Build Coastguard Worker                                 Convolution2dLayer* baseLayer = PolymorphicDowncast<Convolution2dLayer*>(&base);
192*89c4ff92SAndroid Build Coastguard Worker 
193*89c4ff92SAndroid Build Coastguard Worker                                 Optional<TensorInfo> biases;
194*89c4ff92SAndroid Build Coastguard Worker 
195*89c4ff92SAndroid Build Coastguard Worker                                 if (baseLayer->GetParameters().m_BiasEnabled)
196*89c4ff92SAndroid Build Coastguard Worker                                 {
197*89c4ff92SAndroid Build Coastguard Worker                                     biases = baseLayer->GetInputSlot(2).GetConnectedOutputSlot()->GetTensorInfo();
198*89c4ff92SAndroid Build Coastguard Worker                                 }
199*89c4ff92SAndroid Build Coastguard Worker 
200*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = NeonConvolution2dWorkloadValidate(
201*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
202*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
203*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetParameters(),
204*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
205*89c4ff92SAndroid Build Coastguard Worker                                         biases,
206*89c4ff92SAndroid Build Coastguard Worker                                         false,
207*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
208*89c4ff92SAndroid Build Coastguard Worker 
209*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
210*89c4ff92SAndroid Build Coastguard Worker                                 {
211*89c4ff92SAndroid Build Coastguard Worker                                     FuseConvolution2dLayer<Convolution2dLayer>(optimizationViews,
212*89c4ff92SAndroid Build Coastguard Worker                                                                                baseLayer,
213*89c4ff92SAndroid Build Coastguard Worker                                                                                activationLayer,
214*89c4ff92SAndroid Build Coastguard Worker                                                                                activationDesc,
215*89c4ff92SAndroid Build Coastguard Worker                                                                                name);
216*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
217*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
218*89c4ff92SAndroid Build Coastguard Worker                                 }
219*89c4ff92SAndroid Build Coastguard Worker                             }
220*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::DepthwiseConvolution2d)
221*89c4ff92SAndroid Build Coastguard Worker                             {
222*89c4ff92SAndroid Build Coastguard Worker                                 DepthwiseConvolution2dLayer* baseLayer =
223*89c4ff92SAndroid Build Coastguard Worker                                         PolymorphicDowncast<DepthwiseConvolution2dLayer*>(&base);
224*89c4ff92SAndroid Build Coastguard Worker 
225*89c4ff92SAndroid Build Coastguard Worker                                 Optional<TensorInfo> biases;
226*89c4ff92SAndroid Build Coastguard Worker 
227*89c4ff92SAndroid Build Coastguard Worker                                 if (baseLayer->GetParameters().m_BiasEnabled)
228*89c4ff92SAndroid Build Coastguard Worker                                 {
229*89c4ff92SAndroid Build Coastguard Worker                                     biases = baseLayer->GetInputSlot(2).GetConnectedOutputSlot()->GetTensorInfo();
230*89c4ff92SAndroid Build Coastguard Worker                                 }
231*89c4ff92SAndroid Build Coastguard Worker 
232*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = NeonDepthwiseConvolutionWorkloadValidate(
233*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
234*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
235*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetParameters(),
236*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
237*89c4ff92SAndroid Build Coastguard Worker                                         biases,
238*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
239*89c4ff92SAndroid Build Coastguard Worker 
240*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
241*89c4ff92SAndroid Build Coastguard Worker                                 {
242*89c4ff92SAndroid Build Coastguard Worker                                     FuseDepthwiseConvolution2dLayer<DepthwiseConvolution2dLayer>(optimizationViews,
243*89c4ff92SAndroid Build Coastguard Worker                                                                                                  baseLayer,
244*89c4ff92SAndroid Build Coastguard Worker                                                                                                  activationLayer,
245*89c4ff92SAndroid Build Coastguard Worker                                                                                                  activationDesc,
246*89c4ff92SAndroid Build Coastguard Worker                                                                                                  name);
247*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
248*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
249*89c4ff92SAndroid Build Coastguard Worker                                 }
250*89c4ff92SAndroid Build Coastguard Worker                             }
251*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::FullyConnected)
252*89c4ff92SAndroid Build Coastguard Worker                             {
253*89c4ff92SAndroid Build Coastguard Worker                                 FullyConnectedLayer* baseLayer = PolymorphicDowncast<FullyConnectedLayer*>(&base);
254*89c4ff92SAndroid Build Coastguard Worker                                 FullyConnectedDescriptor descriptor = baseLayer->GetParameters();
255*89c4ff92SAndroid Build Coastguard Worker 
256*89c4ff92SAndroid Build Coastguard Worker                                 // As bias is optional only try to get TensorInfo from input if bias is enabled.
257*89c4ff92SAndroid Build Coastguard Worker                                 Optional<TensorInfo> biases;
258*89c4ff92SAndroid Build Coastguard Worker                                 if (descriptor.m_BiasEnabled)
259*89c4ff92SAndroid Build Coastguard Worker                                 {
260*89c4ff92SAndroid Build Coastguard Worker                                     biases = baseLayer->GetInputSlot(2).GetConnectedOutputSlot()->GetTensorInfo();
261*89c4ff92SAndroid Build Coastguard Worker                                 }
262*89c4ff92SAndroid Build Coastguard Worker 
263*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = NeonFullyConnectedWorkloadValidate(
264*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
265*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
266*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
267*89c4ff92SAndroid Build Coastguard Worker                                         biases,
268*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetParameters(),
269*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
270*89c4ff92SAndroid Build Coastguard Worker 
271*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
272*89c4ff92SAndroid Build Coastguard Worker                                 {
273*89c4ff92SAndroid Build Coastguard Worker                                     FuseFullyConnectedLayer<FullyConnectedLayer>(optimizationViews,
274*89c4ff92SAndroid Build Coastguard Worker                                                                                  baseLayer,
275*89c4ff92SAndroid Build Coastguard Worker                                                                                  activationLayer,
276*89c4ff92SAndroid Build Coastguard Worker                                                                                  activationDesc,
277*89c4ff92SAndroid Build Coastguard Worker                                                                                  name);
278*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
279*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
280*89c4ff92SAndroid Build Coastguard Worker                                 }
281*89c4ff92SAndroid Build Coastguard Worker                             }
282*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::BatchNormalization)
283*89c4ff92SAndroid Build Coastguard Worker                             {
284*89c4ff92SAndroid Build Coastguard Worker                                 BatchNormalizationLayer* baseLayer =
285*89c4ff92SAndroid Build Coastguard Worker                                         PolymorphicDowncast<BatchNormalizationLayer*>(&base);
286*89c4ff92SAndroid Build Coastguard Worker 
287*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = NeonBatchNormalizationValidate(
288*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
289*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
290*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->m_Mean->GetTensorInfo(),
291*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->m_Variance->GetTensorInfo(),
292*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->m_Beta->GetTensorInfo(),
293*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->m_Gamma->GetTensorInfo(),
294*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetParameters(),
295*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
296*89c4ff92SAndroid Build Coastguard Worker 
297*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
298*89c4ff92SAndroid Build Coastguard Worker                                 {
299*89c4ff92SAndroid Build Coastguard Worker                                     BatchNormalizationLayer* replacementLayer =
300*89c4ff92SAndroid Build Coastguard Worker                                         FuseBatchNormalizationLayer<BatchNormalizationLayer>(optimizationViews,
301*89c4ff92SAndroid Build Coastguard Worker                                                                                              baseLayer,
302*89c4ff92SAndroid Build Coastguard Worker                                                                                              activationLayer,
303*89c4ff92SAndroid Build Coastguard Worker                                                                                              activationDesc,
304*89c4ff92SAndroid Build Coastguard Worker                                                                                              name);
305*89c4ff92SAndroid Build Coastguard Worker 
306*89c4ff92SAndroid Build Coastguard Worker                                     replacementLayer->m_Beta     = std::move(baseLayer->m_Beta);
307*89c4ff92SAndroid Build Coastguard Worker                                     replacementLayer->m_Gamma    = std::move(baseLayer->m_Gamma);
308*89c4ff92SAndroid Build Coastguard Worker                                     replacementLayer->m_Mean     = std::move(baseLayer->m_Mean);
309*89c4ff92SAndroid Build Coastguard Worker                                     replacementLayer->m_Variance = std::move(baseLayer->m_Variance);
310*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
311*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
312*89c4ff92SAndroid Build Coastguard Worker                                 }
313*89c4ff92SAndroid Build Coastguard Worker                             }
314*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::Addition)
315*89c4ff92SAndroid Build Coastguard Worker                             {
316*89c4ff92SAndroid Build Coastguard Worker                                 AdditionLayer* baseLayer = PolymorphicDowncast<AdditionLayer*>(&base);
317*89c4ff92SAndroid Build Coastguard Worker 
318*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = NeonAdditionWorkloadValidate(
319*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
320*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
321*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
322*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
323*89c4ff92SAndroid Build Coastguard Worker 
324*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
325*89c4ff92SAndroid Build Coastguard Worker                                 {
326*89c4ff92SAndroid Build Coastguard Worker                                     FuseAdditionLayer<AdditionLayer>(optimizationViews,
327*89c4ff92SAndroid Build Coastguard Worker                                                                      baseLayer,
328*89c4ff92SAndroid Build Coastguard Worker                                                                      activationLayer,
329*89c4ff92SAndroid Build Coastguard Worker                                                                      activationDesc,
330*89c4ff92SAndroid Build Coastguard Worker                                                                      name);
331*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
332*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
333*89c4ff92SAndroid Build Coastguard Worker                                 }
334*89c4ff92SAndroid Build Coastguard Worker                             }
335*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::Division)
336*89c4ff92SAndroid Build Coastguard Worker                             {
337*89c4ff92SAndroid Build Coastguard Worker                                 DivisionLayer* baseLayer = PolymorphicDowncast<DivisionLayer*>(&base);
338*89c4ff92SAndroid Build Coastguard Worker 
339*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = NeonDivisionWorkloadValidate(
340*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
341*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
342*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
343*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
344*89c4ff92SAndroid Build Coastguard Worker 
345*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
346*89c4ff92SAndroid Build Coastguard Worker                                 {
347*89c4ff92SAndroid Build Coastguard Worker                                     FuseDivisionLayer<DivisionLayer>(optimizationViews,
348*89c4ff92SAndroid Build Coastguard Worker                                                                      baseLayer,
349*89c4ff92SAndroid Build Coastguard Worker                                                                      activationLayer,
350*89c4ff92SAndroid Build Coastguard Worker                                                                      activationDesc,
351*89c4ff92SAndroid Build Coastguard Worker                                                                      name);
352*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
353*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
354*89c4ff92SAndroid Build Coastguard Worker                                 }
355*89c4ff92SAndroid Build Coastguard Worker                             }
356*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::Multiplication)
357*89c4ff92SAndroid Build Coastguard Worker                             {
358*89c4ff92SAndroid Build Coastguard Worker                                 MultiplicationLayer* baseLayer = PolymorphicDowncast<MultiplicationLayer*>(&base);
359*89c4ff92SAndroid Build Coastguard Worker 
360*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = NeonMultiplicationWorkloadValidate(
361*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
362*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
363*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
364*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
365*89c4ff92SAndroid Build Coastguard Worker 
366*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
367*89c4ff92SAndroid Build Coastguard Worker                                 {
368*89c4ff92SAndroid Build Coastguard Worker                                     FuseMultiplicationLayer<MultiplicationLayer>(optimizationViews,
369*89c4ff92SAndroid Build Coastguard Worker                                                                                  baseLayer,
370*89c4ff92SAndroid Build Coastguard Worker                                                                                  activationLayer,
371*89c4ff92SAndroid Build Coastguard Worker                                                                                  activationDesc,
372*89c4ff92SAndroid Build Coastguard Worker                                                                                  name);
373*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
374*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
375*89c4ff92SAndroid Build Coastguard Worker                                 }
376*89c4ff92SAndroid Build Coastguard Worker                             }
377*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::Subtraction)
378*89c4ff92SAndroid Build Coastguard Worker                             {
379*89c4ff92SAndroid Build Coastguard Worker                                 SubtractionLayer* baseLayer = PolymorphicDowncast<SubtractionLayer*>(&base);
380*89c4ff92SAndroid Build Coastguard Worker 
381*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = NeonSubtractionWorkloadValidate(
382*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
383*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
384*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
385*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
386*89c4ff92SAndroid Build Coastguard Worker 
387*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
388*89c4ff92SAndroid Build Coastguard Worker                                 {
389*89c4ff92SAndroid Build Coastguard Worker                                     FuseSubtractionLayer<SubtractionLayer>(optimizationViews,
390*89c4ff92SAndroid Build Coastguard Worker                                                                            baseLayer,
391*89c4ff92SAndroid Build Coastguard Worker                                                                            activationLayer,
392*89c4ff92SAndroid Build Coastguard Worker                                                                            activationDesc,
393*89c4ff92SAndroid Build Coastguard Worker                                                                            name);
394*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
395*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
396*89c4ff92SAndroid Build Coastguard Worker                                 }
397*89c4ff92SAndroid Build Coastguard Worker                             }
398*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::ElementwiseBinary)
399*89c4ff92SAndroid Build Coastguard Worker                             {
400*89c4ff92SAndroid Build Coastguard Worker                                 ElementwiseBinaryLayer* baseLayer = PolymorphicDowncast<ElementwiseBinaryLayer*>(&base);
401*89c4ff92SAndroid Build Coastguard Worker 
402*89c4ff92SAndroid Build Coastguard Worker                                 if (baseLayer->GetParameters().m_Operation == BinaryOperation::Add)
403*89c4ff92SAndroid Build Coastguard Worker                                 {
404*89c4ff92SAndroid Build Coastguard Worker                                     arm_compute::Status status = NeonAdditionWorkloadValidate(
405*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
406*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
407*89c4ff92SAndroid Build Coastguard Worker                                             activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
408*89c4ff92SAndroid Build Coastguard Worker                                             &activationDesc);
409*89c4ff92SAndroid Build Coastguard Worker 
410*89c4ff92SAndroid Build Coastguard Worker                                     if (status)
411*89c4ff92SAndroid Build Coastguard Worker                                     {
412*89c4ff92SAndroid Build Coastguard Worker                                         FuseElementwiseBinaryLayer<ElementwiseBinaryLayer>(optimizationViews,
413*89c4ff92SAndroid Build Coastguard Worker                                                                                            baseLayer,
414*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationLayer,
415*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationDesc,
416*89c4ff92SAndroid Build Coastguard Worker                                                                                            BinaryOperation::Add,
417*89c4ff92SAndroid Build Coastguard Worker                                                                                            name);
418*89c4ff92SAndroid Build Coastguard Worker                                         untouched.erase(baseLayer->GetGuid());
419*89c4ff92SAndroid Build Coastguard Worker                                         untouched.erase(activationLayer->GetGuid());
420*89c4ff92SAndroid Build Coastguard Worker                                     }
421*89c4ff92SAndroid Build Coastguard Worker                                 }
422*89c4ff92SAndroid Build Coastguard Worker                                 else if (baseLayer->GetParameters().m_Operation == BinaryOperation::Div)
423*89c4ff92SAndroid Build Coastguard Worker                                 {
424*89c4ff92SAndroid Build Coastguard Worker                                     arm_compute::Status status = NeonDivisionWorkloadValidate(
425*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
426*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
427*89c4ff92SAndroid Build Coastguard Worker                                             activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
428*89c4ff92SAndroid Build Coastguard Worker                                             &activationDesc);
429*89c4ff92SAndroid Build Coastguard Worker 
430*89c4ff92SAndroid Build Coastguard Worker                                     if (status)
431*89c4ff92SAndroid Build Coastguard Worker                                     {
432*89c4ff92SAndroid Build Coastguard Worker                                         FuseElementwiseBinaryLayer<ElementwiseBinaryLayer>(optimizationViews,
433*89c4ff92SAndroid Build Coastguard Worker                                                                                            baseLayer,
434*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationLayer,
435*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationDesc,
436*89c4ff92SAndroid Build Coastguard Worker                                                                                            BinaryOperation::Div,
437*89c4ff92SAndroid Build Coastguard Worker                                                                                            name);
438*89c4ff92SAndroid Build Coastguard Worker                                         untouched.erase(baseLayer->GetGuid());
439*89c4ff92SAndroid Build Coastguard Worker                                         untouched.erase(activationLayer->GetGuid());
440*89c4ff92SAndroid Build Coastguard Worker                                     }
441*89c4ff92SAndroid Build Coastguard Worker                                 }
442*89c4ff92SAndroid Build Coastguard Worker                                 else if (baseLayer->GetParameters().m_Operation == BinaryOperation::Mul)
443*89c4ff92SAndroid Build Coastguard Worker                                 {
444*89c4ff92SAndroid Build Coastguard Worker                                     arm_compute::Status status = NeonMultiplicationWorkloadValidate(
445*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
446*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
447*89c4ff92SAndroid Build Coastguard Worker                                             activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
448*89c4ff92SAndroid Build Coastguard Worker                                             &activationDesc);
449*89c4ff92SAndroid Build Coastguard Worker 
450*89c4ff92SAndroid Build Coastguard Worker                                     if (status)
451*89c4ff92SAndroid Build Coastguard Worker                                     {
452*89c4ff92SAndroid Build Coastguard Worker                                         FuseElementwiseBinaryLayer<ElementwiseBinaryLayer>(optimizationViews,
453*89c4ff92SAndroid Build Coastguard Worker                                                                                            baseLayer,
454*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationLayer,
455*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationDesc,
456*89c4ff92SAndroid Build Coastguard Worker                                                                                            BinaryOperation::Mul,
457*89c4ff92SAndroid Build Coastguard Worker                                                                                            name);
458*89c4ff92SAndroid Build Coastguard Worker                                         untouched.erase(baseLayer->GetGuid());
459*89c4ff92SAndroid Build Coastguard Worker                                         untouched.erase(activationLayer->GetGuid());
460*89c4ff92SAndroid Build Coastguard Worker                                     }
461*89c4ff92SAndroid Build Coastguard Worker                                 }
462*89c4ff92SAndroid Build Coastguard Worker                                 else if (baseLayer->GetParameters().m_Operation == BinaryOperation::Sub)
463*89c4ff92SAndroid Build Coastguard Worker                                 {
464*89c4ff92SAndroid Build Coastguard Worker                                     arm_compute::Status status = NeonSubtractionWorkloadValidate(
465*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
466*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
467*89c4ff92SAndroid Build Coastguard Worker                                             activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
468*89c4ff92SAndroid Build Coastguard Worker                                             &activationDesc);
469*89c4ff92SAndroid Build Coastguard Worker 
470*89c4ff92SAndroid Build Coastguard Worker                                     if (status)
471*89c4ff92SAndroid Build Coastguard Worker                                     {
472*89c4ff92SAndroid Build Coastguard Worker                                         FuseElementwiseBinaryLayer<ElementwiseBinaryLayer>(optimizationViews,
473*89c4ff92SAndroid Build Coastguard Worker                                                                                            baseLayer,
474*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationLayer,
475*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationDesc,
476*89c4ff92SAndroid Build Coastguard Worker                                                                                            BinaryOperation::Sub,
477*89c4ff92SAndroid Build Coastguard Worker                                                                                            name);
478*89c4ff92SAndroid Build Coastguard Worker                                         untouched.erase(baseLayer->GetGuid());
479*89c4ff92SAndroid Build Coastguard Worker                                         untouched.erase(activationLayer->GetGuid());
480*89c4ff92SAndroid Build Coastguard Worker                                     }
481*89c4ff92SAndroid Build Coastguard Worker                                 }
482*89c4ff92SAndroid Build Coastguard Worker                                 // No fusion available for other BinaryOperations
483*89c4ff92SAndroid Build Coastguard Worker                             }
484*89c4ff92SAndroid Build Coastguard Worker                         }
485*89c4ff92SAndroid Build Coastguard Worker                     }
486*89c4ff92SAndroid Build Coastguard Worker                 }
487*89c4ff92SAndroid Build Coastguard Worker             }
488*89c4ff92SAndroid Build Coastguard Worker         }
489*89c4ff92SAndroid Build Coastguard Worker 
490*89c4ff92SAndroid Build Coastguard Worker         // Separate reduce layer with multiple axes into multiple reduce layers with 1 axis.
491*89c4ff92SAndroid Build Coastguard Worker         if (base.GetType() == LayerType::Reduce)
492*89c4ff92SAndroid Build Coastguard Worker         {
493*89c4ff92SAndroid Build Coastguard Worker             ReduceLayer* baseLayer            = PolymorphicDowncast<ReduceLayer*>(&base);
494*89c4ff92SAndroid Build Coastguard Worker             ReduceDescriptor reduceDescriptor = baseLayer->GetParameters();
495*89c4ff92SAndroid Build Coastguard Worker 
496*89c4ff92SAndroid Build Coastguard Worker             if (!reduceDescriptor.m_vAxis.empty() && reduceDescriptor.m_vAxis.size() > 1)
497*89c4ff92SAndroid Build Coastguard Worker             {
498*89c4ff92SAndroid Build Coastguard Worker                 // Add new layers to the graph and connect them.
499*89c4ff92SAndroid Build Coastguard Worker                 std::vector<IConnectableLayer*> layers = ChainReduceLayers<ReduceLayer>(optimizationViews,
500*89c4ff92SAndroid Build Coastguard Worker                                                                                         baseLayer,
501*89c4ff92SAndroid Build Coastguard Worker                                                                                         reduceDescriptor);
502*89c4ff92SAndroid Build Coastguard Worker 
503*89c4ff92SAndroid Build Coastguard Worker                 // Replace existing baselayer with new subgraph.
504*89c4ff92SAndroid Build Coastguard Worker                 ReplaceLayers<ReduceLayer>(optimizationViews, baseLayer, layers);
505*89c4ff92SAndroid Build Coastguard Worker                 untouched.erase(baseLayer->GetGuid());
506*89c4ff92SAndroid Build Coastguard Worker             }
507*89c4ff92SAndroid Build Coastguard Worker         }
508*89c4ff92SAndroid Build Coastguard Worker     }
509*89c4ff92SAndroid Build Coastguard Worker 
510*89c4ff92SAndroid Build Coastguard Worker     if (optimizationViews.GetSubstitutions().empty())
511*89c4ff92SAndroid Build Coastguard Worker     {
512*89c4ff92SAndroid Build Coastguard Worker         optimizationViews.AddUntouchedSubgraph(SubgraphView(subgraph));
513*89c4ff92SAndroid Build Coastguard Worker     }
514*89c4ff92SAndroid Build Coastguard Worker     else
515*89c4ff92SAndroid Build Coastguard Worker     {
516*89c4ff92SAndroid Build Coastguard Worker         ReportUntouchedLayers(optimizationViews, untouched);
517*89c4ff92SAndroid Build Coastguard Worker     }
518*89c4ff92SAndroid Build Coastguard Worker 
519*89c4ff92SAndroid Build Coastguard Worker     return optimizationViews;
520*89c4ff92SAndroid Build Coastguard Worker }
521*89c4ff92SAndroid Build Coastguard Worker 
GetHandleFactoryPreferences() const522*89c4ff92SAndroid Build Coastguard Worker std::vector<ITensorHandleFactory::FactoryId> NeonBackend::GetHandleFactoryPreferences() const
523*89c4ff92SAndroid Build Coastguard Worker {
524*89c4ff92SAndroid Build Coastguard Worker     return std::vector<ITensorHandleFactory::FactoryId>() = { NeonTensorHandleFactory::GetIdStatic() };
525*89c4ff92SAndroid Build Coastguard Worker }
526*89c4ff92SAndroid Build Coastguard Worker 
RegisterTensorHandleFactories(class TensorHandleFactoryRegistry & registry)527*89c4ff92SAndroid Build Coastguard Worker void NeonBackend::RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& registry)
528*89c4ff92SAndroid Build Coastguard Worker {
529*89c4ff92SAndroid Build Coastguard Worker     auto memoryManager = std::make_shared<NeonMemoryManager>(std::make_unique<arm_compute::Allocator>(),
530*89c4ff92SAndroid Build Coastguard Worker                                                              BaseMemoryManager::MemoryAffinity::Offset);
531*89c4ff92SAndroid Build Coastguard Worker 
532*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterMemoryManager(memoryManager);
533*89c4ff92SAndroid Build Coastguard Worker 
534*89c4ff92SAndroid Build Coastguard Worker     auto factory = std::make_unique<NeonTensorHandleFactory>(memoryManager);
535*89c4ff92SAndroid Build Coastguard Worker     // Register copy and import factory pair
536*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId());
537*89c4ff92SAndroid Build Coastguard Worker     // Register the factory
538*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterFactory(std::move(factory));
539*89c4ff92SAndroid Build Coastguard Worker }
540*89c4ff92SAndroid Build Coastguard Worker 
GetDefaultAllocator() const541*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ICustomAllocator> NeonBackend::GetDefaultAllocator() const
542*89c4ff92SAndroid Build Coastguard Worker {
543*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<DefaultAllocator>();
544*89c4ff92SAndroid Build Coastguard Worker }
545*89c4ff92SAndroid Build Coastguard Worker 
546*89c4ff92SAndroid Build Coastguard Worker 
547*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
548