xref: /aosp_15_r20/external/armnn/src/armnn/layers/SplitterLayer.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "SplitterLayer.hpp"
6 
7 #include "LayerCloneBase.hpp"
8 
9 #include <armnn/TypesUtils.hpp>
10 #include <armnn/backends/WorkloadData.hpp>
11 #include <armnn/backends/WorkloadFactory.hpp>
12 
13 namespace armnn
14 {
15 
/// Constructs a splitter layer: one input slot, and one output slot per view
/// declared in the descriptor.
/// @param param Views descriptor describing how the input is split.
/// @param name  Optional layer name (may be nullptr).
SplitterLayer::SplitterLayer(const ViewsDescriptor& param, const char* name)
    : LayerWithParameters(1, param.GetNumViews(), LayerType::Splitter, param, name)
{
}
20 
CreateWorkload(const IWorkloadFactory & factory) const21 std::unique_ptr<IWorkload> SplitterLayer::CreateWorkload(const IWorkloadFactory& factory) const
22 {
23     SplitterQueueDescriptor descriptor;
24 
25     // Copies the window origins to the descriptor.
26     for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
27     {
28         descriptor.m_ViewOrigins.emplace_back(
29             std::vector<unsigned int>(m_Param.GetViewOrigin(i), m_Param.GetViewOrigin(i) + m_Param.GetNumDimensions()));
30     }
31 
32     SetAdditionalInfo(descriptor);
33 
34     return factory.CreateWorkload(LayerType::Splitter, descriptor, PrepInfoAndDesc(descriptor));
35 }
36 
/// Creates the output tensor handles for this splitter.
///
/// Preferred path: when the factory supports sub-tensors, each output is made a
/// sub-tensor (a view) of the parent input tensor, avoiding copies. That is only
/// done when a series of compatibility checks all pass (see the numbered list in
/// CreateSubTensor below); otherwise each output gets its own standalone handle.
///
/// @param registry        Registry used to look up handle factories by id.
/// @param factory         Either an IWorkloadFactory or an ITensorHandleFactory
///                        (template parameter selects which overloads are used).
/// @param isMemoryManaged Passed through to CreateTensorHandles for the
///                        non-sub-tensor fallback path.
template<typename FactoryType>
void SplitterLayer::CreateTensors(const TensorHandleFactoryRegistry& registry,
                                  const FactoryType& factory,
                                  bool isMemoryManaged)
{
    //If sub tensors are supported than all the "splitter" need to do is to
    //set the outputs to be appropriate sub tensors of the input.
    bool useSubTensors = factory.SupportsSubTensors();

    if (useSubTensors)
    {
        // Get outputHandler of previous layer
        const OutputHandler& outputHandler = GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler();
        const OutputSlot* slot = GetInputSlots()[0].GetConnectedOutputSlot();

        const TensorInfo& parentInfo = outputHandler.GetTensorInfo();

        // Parent tensor handle the sub-tensors will be carved out of.
        ITensorHandle* inputData = outputHandler.GetData();

        std::vector<std::unique_ptr<ITensorHandle>> subTensors;

        // check if split is along the x or y (2 innermost dimensions)
        auto numberOfDimensions = m_Param.GetNumDimensions();

        // Compute split axis within class as aclCommon function causes header issues when included
        // An axis is "split" if any view's size differs from the input size in that dimension.
        auto ComputeSplitAxis = [&](const armnn::SplitterDescriptor& desc, const TensorShape& input)
        {
            unsigned int numSplit = desc.GetNumViews();
            unsigned int numDimensions = desc.GetNumDimensions();
            std::set<unsigned int> splitAxis;

            for (unsigned int i = 0; i < numSplit; ++i)
            {
                for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
                {
                    if (desc.GetViewSizes(i)[dimIdx] != input[dimIdx])
                    {
                        splitAxis.insert(dimIdx);
                    }
                }
            }
            return splitAxis;
        };

        std::set<unsigned int> axis = ComputeSplitAxis(m_Param, parentInfo.GetShape());
        std::set<unsigned int>::iterator axisIt = axis.begin();

        // NOTE(review): *axisIt assumes at least one axis differs (axis non-empty);
        // a degenerate split where every view equals the input would dereference
        // end() here — confirm upstream validation rules this out.
        bool isOnXorY = m_Param.GetNumDimensions() >= 3 &&
                            ((*axisIt == numberOfDimensions - 1) ||
                                (*axisIt == numberOfDimensions - 2));

        //Creates the outputs as subtensors of the input.
        for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
        {
            const TensorInfo& info = m_OutputHandlers[i].GetTensorInfo();

            OutputSlot& outSlot = GetOutputSlot(i);
            ITensorHandleFactory::FactoryId factoryId = outSlot.GetTensorHandleFactoryId();

            const unsigned int numOutputSlots = GetNumOutputSlots();

            // if split along x or y (2 innermost dimensions) and the next layers do not require padding
            bool canUseSubTensorOnXorY = true;
            bool isTensorHandleFactory = std::is_same<armnn::ITensorHandleFactory, FactoryType>::value;
            if (isTensorHandleFactory)
            {
                // For an x/y split, every consumer of every output must report no
                // padding requirement, otherwise sub-tensors cannot be used.
                for (unsigned int it = 0; it < numOutputSlots; ++it)
                {
                    // NOTE(review): GetConnection(0) is not null-checked — assumes
                    // every output slot is connected at this point; verify against callers.
                    InputSlot* inputSlot = GetOutputSlot(it).GetConnection(0);
                    ITensorHandleFactory* handleFactory  = registry.GetFactory(factoryId);
                    std::vector<Capability> capabilities =
                        handleFactory->GetCapabilities(&(inputSlot->GetOwningLayer()),
                                                       this,
                                                       CapabilityClass::PaddingRequired);
                    if (isOnXorY)
                    {
                        // Padding capability present => consumer needs padding => no sub-tensor.
                        canUseSubTensorOnXorY = false;
                        if (capabilities.empty())
                        {
                            canUseSubTensorOnXorY = true;
                        }
                    }

                    if (!canUseSubTensorOnXorY)
                    {
                        break;
                    }
                }
            }

            auto CreateSubTensor = [&]()
            {
                // Make sure:
                // 1) quantization parameters are in the same space
                // 2) the same TensorHandleFactory is used for input and split layer output
                // 3) the output does not go to a Constant layer or input layer
                // 4) if split along x or y (2 innermost dimensions) and the next layers do not require padding
                if (parentInfo.IsTypeSpaceMatch(info) && //(1)
                    factoryId == slot->GetTensorHandleFactoryId() && //(2)
                    GetOutputSlot(i).GetConnection(0)->GetOwningLayer().GetType() != LayerType::Constant && //(3)
                    GetOutputSlot(i).GetConnection(0)->GetOwningLayer().GetType() != LayerType::Input && //(3)
                    canUseSubTensorOnXorY) //(4)
                {
                    ARMNN_NO_DEPRECATE_WARN_BEGIN
                    return factory.CreateSubTensorHandle(*inputData,
                                                         info.GetShape(),
                                                         this->m_Param.GetViewOrigin(i));
                    ARMNN_NO_DEPRECATE_WARN_END
                }
                // Null handle signals "sub-tensor not possible for this view".
                return std::unique_ptr<ITensorHandle>();
            };

            auto subTensor = CreateSubTensor();
            if (!subTensor)
            {
                useSubTensors = false;
                break; //Failed to create a valid sub-tensor, so stop trying with the rest of the views.
            }
            subTensors.push_back(std::move(subTensor));
        }

        // Only commit the sub-tensors if every view produced one.
        if (useSubTensors)
        {
            unsigned int i = 0;
            for (auto& subTensor : subTensors)
            {
                m_OutputHandlers[i].SetData(std::move(subTensor));
                ++i;
            }
        }
    }

    // Fallback: allocate an independent tensor handle per view.
    if (!useSubTensors)
    {
        for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
        {
            m_OutputHandlers[i].CreateTensorHandles(factory, isMemoryManaged);
        }
    }
}
177 
CreateTensorHandles(const TensorHandleFactoryRegistry & registry,const IWorkloadFactory & workloadFactory,const bool isMemoryManaged)178 void SplitterLayer::CreateTensorHandles(const TensorHandleFactoryRegistry& registry,
179                                         const IWorkloadFactory& workloadFactory,
180                                         const bool isMemoryManaged)
181 {
182     OutputSlot& slot = GetOutputSlot(0);
183     ITensorHandleFactory::FactoryId factoryId = slot.GetTensorHandleFactoryId();
184 
185     if (factoryId == ITensorHandleFactory::LegacyFactoryId)
186     {
187         CreateTensors(registry, workloadFactory, isMemoryManaged);
188     }
189     else
190     {
191         ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId);
192         ARMNN_ASSERT(handleFactory);
193         CreateTensors(registry, *handleFactory, isMemoryManaged);
194     }
195 }
196 
/// Creates a copy of this layer (same descriptor and name) in the given graph.
/// @param graph Graph that will own the clone.
/// @return Pointer to the newly created layer, owned by the graph.
SplitterLayer* SplitterLayer::Clone(Graph& graph) const
{
    return CloneBase<SplitterLayer>(graph, m_Param, GetName());
}
201 
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const202 std::vector<TensorShape> SplitterLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
203 {
204     IgnoreUnused(inputShapes);
205     ARMNN_ASSERT(inputShapes.size() ==  m_Param.GetNumViews());
206     std::vector<TensorShape> outShapes;
207     //Output shapes must match View shapes.
208     for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
209     {
210         const uint32_t* sizes = m_Param.GetViewSizes(viewIdx);
211         outShapes.push_back(TensorShape(m_Param.GetNumDimensions(), sizes));
212     }
213     return outShapes;
214 }
215 
ValidateTensorShapesFromInputs()216 void SplitterLayer::ValidateTensorShapesFromInputs()
217 {
218     std::for_each(BeginOutputSlots(), EndOutputSlots(), [&](OutputSlot& outputSlot)
219     {
220         VerifyShapeInferenceType(outputSlot.GetTensorInfo().GetShape(), m_ShapeInferenceMethod);
221     });
222 
223     std::vector<TensorShape> views;
224     for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
225     {
226         const uint32_t* sizes = m_Param.GetViewSizes(viewIdx);
227         views.push_back(TensorShape(m_Param.GetNumDimensions(), sizes));
228     }
229 
230     auto inferredShapes = InferOutputShapes(views);
231 
232     ARMNN_ASSERT(inferredShapes.size() == m_Param.GetNumViews());
233 
234     for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
235     {
236         ValidateAndCopyShape(GetOutputSlot(viewIdx).GetTensorInfo().GetShape(),
237                              inferredShapes[viewIdx],
238                              m_ShapeInferenceMethod,
239                              "SplitterLayer",
240                              viewIdx);
241     }
242 }
243 
/// Visitor hook: applies the strategy to this layer. The splitter carries no
/// constant tensors, so an empty list is passed.
/// @param strategy Strategy to execute against this layer.
void SplitterLayer::ExecuteStrategy(IStrategy& strategy) const
{
    strategy.ExecuteStrategy(this, GetParameters(), {}, GetName());
}
248 
249 } // namespace armnn
250