1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "StackLayer.hpp"
6 #include "LayerCloneBase.hpp"
7
8 #include <armnn/TypesUtils.hpp>
9 #include <armnn/backends/WorkloadData.hpp>
10 #include <armnn/backends/WorkloadFactory.hpp>
11
12 #include <queue>
13
14 namespace armnn
15 {
16
StackLayer(const StackDescriptor & param,const char * name)17 StackLayer::StackLayer(const StackDescriptor& param, const char* name)
18 : LayerWithParameters(param.m_NumInputs, 1, LayerType::Stack, param, name)
19 {
20 }
21
CreateWorkload(const IWorkloadFactory & factory) const22 std::unique_ptr<IWorkload> StackLayer::CreateWorkload(const IWorkloadFactory& factory) const
23 {
24 StackQueueDescriptor descriptor;
25 SetAdditionalInfo(descriptor);
26
27 return factory.CreateWorkload(LayerType::Stack, descriptor, PrepInfoAndDesc(descriptor));
28 }
29
Clone(Graph & graph) const30 StackLayer* StackLayer::Clone(Graph& graph) const
31 {
32 return CloneBase<StackLayer>(graph, m_Param, GetName());
33 }
34
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const35 std::vector<TensorShape> StackLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
36 {
37 IgnoreUnused(inputShapes);
38
39 const TensorShape& inputShape = m_Param.m_InputShape;
40 const unsigned int inputNumDimensions = inputShape.GetNumDimensions();
41 const unsigned int axis = m_Param.m_Axis;
42
43 ARMNN_ASSERT(axis <= inputNumDimensions);
44
45 std::vector<unsigned int> dimensionSizes(inputNumDimensions + 1, 0);
46 for (unsigned int i = 0; i < axis; ++i)
47 {
48 dimensionSizes[i] = inputShape[i];
49 }
50
51 dimensionSizes[axis] = m_Param.m_NumInputs;
52
53 for (unsigned int i = axis + 1; i < inputNumDimensions + 1; ++i)
54 {
55 dimensionSizes[i] = inputShape[i-1];
56 }
57
58 TensorShape targetShape = TensorShape(inputNumDimensions + 1, dimensionSizes.data());
59
60 return std::vector<TensorShape>({ targetShape });
61 }
62
ValidateTensorShapesFromInputs()63 void StackLayer::ValidateTensorShapesFromInputs()
64 {
65 // Validates Stack layer.
66 ConditionalThrowIfNotEqual<LayerValidationException>(
67 "StackLayer: Num Input Slots must match Num Inputs.",
68 m_Param.m_NumInputs,
69 GetNumInputSlots());
70
71 VerifyLayerConnections(m_Param.m_NumInputs, CHECK_LOCATION());
72
73 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
74
75 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
76
77 // Constructs and validates input shapes
78 std::vector<TensorShape> inputShapes;
79 for (unsigned int i = 0; i < GetNumInputSlots(); ++i)
80 {
81 TensorShape inputShape = GetInputSlot(i).GetConnection()->GetTensorInfo().GetShape();
82 if (inputShape != m_Param.m_InputShape)
83 {
84 throw LayerValidationException("StackLayer: TensorShape set on InputSlot[" +
85 std::to_string(i) +
86 "] does not match defined input shape");
87 }
88 inputShapes.push_back(inputShape);
89 }
90
91 auto inferredShapes = InferOutputShapes(inputShapes);
92
93 ARMNN_ASSERT(inferredShapes.size() == 1);
94
95 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "StackLayer");
96 }
97
ExecuteStrategy(IStrategy & strategy) const98 void StackLayer::ExecuteStrategy(IStrategy& strategy) const
99 {
100 strategy.ExecuteStrategy(this, GetParameters(), {}, GetName());
101 }
102
103 } // namespace armnn armnn
104