//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "SplitterLayer.hpp"

#include "LayerCloneBase.hpp"

#include <armnn/TypesUtils.hpp>
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/backends/WorkloadFactory.hpp>

#include <algorithm>
#include <memory>
#include <set>
#include <type_traits>
#include <vector>
12
13 namespace armnn
14 {
15
SplitterLayer(const ViewsDescriptor & param,const char * name)16 SplitterLayer::SplitterLayer(const ViewsDescriptor& param, const char* name)
17 : LayerWithParameters(1, param.GetNumViews(), LayerType::Splitter, param, name)
18 {
19 }
20
CreateWorkload(const IWorkloadFactory & factory) const21 std::unique_ptr<IWorkload> SplitterLayer::CreateWorkload(const IWorkloadFactory& factory) const
22 {
23 SplitterQueueDescriptor descriptor;
24
25 // Copies the window origins to the descriptor.
26 for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
27 {
28 descriptor.m_ViewOrigins.emplace_back(
29 std::vector<unsigned int>(m_Param.GetViewOrigin(i), m_Param.GetViewOrigin(i) + m_Param.GetNumDimensions()));
30 }
31
32 SetAdditionalInfo(descriptor);
33
34 return factory.CreateWorkload(LayerType::Splitter, descriptor, PrepInfoAndDesc(descriptor));
35 }
36
37 template<typename FactoryType>
CreateTensors(const TensorHandleFactoryRegistry & registry,const FactoryType & factory,bool isMemoryManaged)38 void SplitterLayer::CreateTensors(const TensorHandleFactoryRegistry& registry,
39 const FactoryType& factory,
40 bool isMemoryManaged)
41 {
42 //If sub tensors are supported than all the "splitter" need to do is to
43 //set the outputs to be appropriate sub tensors of the input.
44 bool useSubTensors = factory.SupportsSubTensors();
45
46 if (useSubTensors)
47 {
48 // Get outputHandler of previous layer
49 const OutputHandler& outputHandler = GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler();
50 const OutputSlot* slot = GetInputSlots()[0].GetConnectedOutputSlot();
51
52 const TensorInfo& parentInfo = outputHandler.GetTensorInfo();
53
54 ITensorHandle* inputData = outputHandler.GetData();
55
56 std::vector<std::unique_ptr<ITensorHandle>> subTensors;
57
58 // check if split is along the x or y (2 innermost dimensions)
59 auto numberOfDimensions = m_Param.GetNumDimensions();
60
61 // Compute split axis within class as aclCommon function causes header issues when included
62 auto ComputeSplitAxis = [&](const armnn::SplitterDescriptor& desc, const TensorShape& input)
63 {
64 unsigned int numSplit = desc.GetNumViews();
65 unsigned int numDimensions = desc.GetNumDimensions();
66 std::set<unsigned int> splitAxis;
67
68 for (unsigned int i = 0; i < numSplit; ++i)
69 {
70 for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
71 {
72 if (desc.GetViewSizes(i)[dimIdx] != input[dimIdx])
73 {
74 splitAxis.insert(dimIdx);
75 }
76 }
77 }
78 return splitAxis;
79 };
80
81 std::set<unsigned int> axis = ComputeSplitAxis(m_Param, parentInfo.GetShape());
82 std::set<unsigned int>::iterator axisIt = axis.begin();
83
84 bool isOnXorY = m_Param.GetNumDimensions() >= 3 &&
85 ((*axisIt == numberOfDimensions - 1) ||
86 (*axisIt == numberOfDimensions - 2));
87
88 //Creates the outputs as subtensors of the input.
89 for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
90 {
91 const TensorInfo& info = m_OutputHandlers[i].GetTensorInfo();
92
93 OutputSlot& outSlot = GetOutputSlot(i);
94 ITensorHandleFactory::FactoryId factoryId = outSlot.GetTensorHandleFactoryId();
95
96 const unsigned int numOutputSlots = GetNumOutputSlots();
97
98 // if split along x or y (2 innermost dimensions) and the next layers do not require padding
99 bool canUseSubTensorOnXorY = true;
100 bool isTensorHandleFactory = std::is_same<armnn::ITensorHandleFactory, FactoryType>::value;
101 if (isTensorHandleFactory)
102 {
103 for (unsigned int it = 0; it < numOutputSlots; ++it)
104 {
105 InputSlot* inputSlot = GetOutputSlot(it).GetConnection(0);
106 ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId);
107 std::vector<Capability> capabilities =
108 handleFactory->GetCapabilities(&(inputSlot->GetOwningLayer()),
109 this,
110 CapabilityClass::PaddingRequired);
111 if (isOnXorY)
112 {
113 canUseSubTensorOnXorY = false;
114 if (capabilities.empty())
115 {
116 canUseSubTensorOnXorY = true;
117 }
118 }
119
120 if (!canUseSubTensorOnXorY)
121 {
122 break;
123 }
124 }
125 }
126
127 auto CreateSubTensor = [&]()
128 {
129 // Make sure:
130 // 1) quantization parameters are in the same space
131 // 2) the same TensorHandleFactory is used for input and split layer output
132 // 3) the output does not go to a Constant layer or input layer
133 // 4) if split along x or y (2 innermost dimensions) and the next layers do not require padding
134 if (parentInfo.IsTypeSpaceMatch(info) && //(1)
135 factoryId == slot->GetTensorHandleFactoryId() && //(2)
136 GetOutputSlot(i).GetConnection(0)->GetOwningLayer().GetType() != LayerType::Constant && //(3)
137 GetOutputSlot(i).GetConnection(0)->GetOwningLayer().GetType() != LayerType::Input && //(3)
138 canUseSubTensorOnXorY) //(4)
139 {
140 ARMNN_NO_DEPRECATE_WARN_BEGIN
141 return factory.CreateSubTensorHandle(*inputData,
142 info.GetShape(),
143 this->m_Param.GetViewOrigin(i));
144 ARMNN_NO_DEPRECATE_WARN_END
145 }
146 return std::unique_ptr<ITensorHandle>();
147 };
148
149 auto subTensor = CreateSubTensor();
150 if (!subTensor)
151 {
152 useSubTensors = false;
153 break; //Failed to create a valid sub-tensor, so stop trying with the rest of the views.
154 }
155 subTensors.push_back(std::move(subTensor));
156 }
157
158 if (useSubTensors)
159 {
160 unsigned int i = 0;
161 for (auto& subTensor : subTensors)
162 {
163 m_OutputHandlers[i].SetData(std::move(subTensor));
164 ++i;
165 }
166 }
167 }
168
169 if (!useSubTensors)
170 {
171 for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
172 {
173 m_OutputHandlers[i].CreateTensorHandles(factory, isMemoryManaged);
174 }
175 }
176 }
177
CreateTensorHandles(const TensorHandleFactoryRegistry & registry,const IWorkloadFactory & workloadFactory,const bool isMemoryManaged)178 void SplitterLayer::CreateTensorHandles(const TensorHandleFactoryRegistry& registry,
179 const IWorkloadFactory& workloadFactory,
180 const bool isMemoryManaged)
181 {
182 OutputSlot& slot = GetOutputSlot(0);
183 ITensorHandleFactory::FactoryId factoryId = slot.GetTensorHandleFactoryId();
184
185 if (factoryId == ITensorHandleFactory::LegacyFactoryId)
186 {
187 CreateTensors(registry, workloadFactory, isMemoryManaged);
188 }
189 else
190 {
191 ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId);
192 ARMNN_ASSERT(handleFactory);
193 CreateTensors(registry, *handleFactory, isMemoryManaged);
194 }
195 }
196
Clone(Graph & graph) const197 SplitterLayer* SplitterLayer::Clone(Graph& graph) const
198 {
199 return CloneBase<SplitterLayer>(graph, m_Param, GetName());
200 }
201
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const202 std::vector<TensorShape> SplitterLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
203 {
204 IgnoreUnused(inputShapes);
205 ARMNN_ASSERT(inputShapes.size() == m_Param.GetNumViews());
206 std::vector<TensorShape> outShapes;
207 //Output shapes must match View shapes.
208 for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
209 {
210 const uint32_t* sizes = m_Param.GetViewSizes(viewIdx);
211 outShapes.push_back(TensorShape(m_Param.GetNumDimensions(), sizes));
212 }
213 return outShapes;
214 }
215
ValidateTensorShapesFromInputs()216 void SplitterLayer::ValidateTensorShapesFromInputs()
217 {
218 std::for_each(BeginOutputSlots(), EndOutputSlots(), [&](OutputSlot& outputSlot)
219 {
220 VerifyShapeInferenceType(outputSlot.GetTensorInfo().GetShape(), m_ShapeInferenceMethod);
221 });
222
223 std::vector<TensorShape> views;
224 for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
225 {
226 const uint32_t* sizes = m_Param.GetViewSizes(viewIdx);
227 views.push_back(TensorShape(m_Param.GetNumDimensions(), sizes));
228 }
229
230 auto inferredShapes = InferOutputShapes(views);
231
232 ARMNN_ASSERT(inferredShapes.size() == m_Param.GetNumViews());
233
234 for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
235 {
236 ValidateAndCopyShape(GetOutputSlot(viewIdx).GetTensorInfo().GetShape(),
237 inferredShapes[viewIdx],
238 m_ShapeInferenceMethod,
239 "SplitterLayer",
240 viewIdx);
241 }
242 }
243
ExecuteStrategy(IStrategy & strategy) const244 void SplitterLayer::ExecuteStrategy(IStrategy& strategy) const
245 {
246 strategy.ExecuteStrategy(this, GetParameters(), {}, GetName());
247 }
248
249 } // namespace armnn
250