xref: /aosp_15_r20/external/armnn/src/armnn/layers/StackLayer.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "StackLayer.hpp"
6 #include "LayerCloneBase.hpp"
7 
8 #include <armnn/TypesUtils.hpp>
9 #include <armnn/backends/WorkloadData.hpp>
10 #include <armnn/backends/WorkloadFactory.hpp>
11 
12 #include <queue>
13 
14 namespace armnn
15 {
16 
StackLayer(const StackDescriptor & param,const char * name)17 StackLayer::StackLayer(const StackDescriptor& param, const char* name)
18     : LayerWithParameters(param.m_NumInputs, 1, LayerType::Stack, param, name)
19 {
20 }
21 
CreateWorkload(const IWorkloadFactory & factory) const22 std::unique_ptr<IWorkload> StackLayer::CreateWorkload(const IWorkloadFactory& factory) const
23 {
24     StackQueueDescriptor descriptor;
25     SetAdditionalInfo(descriptor);
26 
27     return factory.CreateWorkload(LayerType::Stack, descriptor, PrepInfoAndDesc(descriptor));
28 }
29 
Clone(Graph & graph) const30 StackLayer* StackLayer::Clone(Graph& graph) const
31 {
32     return CloneBase<StackLayer>(graph, m_Param, GetName());
33 }
34 
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const35 std::vector<TensorShape> StackLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
36 {
37     IgnoreUnused(inputShapes);
38 
39     const TensorShape& inputShape = m_Param.m_InputShape;
40     const unsigned int inputNumDimensions = inputShape.GetNumDimensions();
41     const unsigned int axis = m_Param.m_Axis;
42 
43     ARMNN_ASSERT(axis <= inputNumDimensions);
44 
45     std::vector<unsigned int> dimensionSizes(inputNumDimensions + 1, 0);
46     for (unsigned int i = 0; i < axis; ++i)
47     {
48         dimensionSizes[i] = inputShape[i];
49     }
50 
51     dimensionSizes[axis] = m_Param.m_NumInputs;
52 
53     for (unsigned int i = axis + 1; i < inputNumDimensions + 1; ++i)
54     {
55         dimensionSizes[i] = inputShape[i-1];
56     }
57 
58     TensorShape targetShape = TensorShape(inputNumDimensions + 1, dimensionSizes.data());
59 
60     return std::vector<TensorShape>({ targetShape });
61 }
62 
ValidateTensorShapesFromInputs()63 void StackLayer::ValidateTensorShapesFromInputs()
64 {
65     // Validates Stack layer.
66     ConditionalThrowIfNotEqual<LayerValidationException>(
67         "StackLayer: Num Input Slots must match Num Inputs.",
68         m_Param.m_NumInputs,
69         GetNumInputSlots());
70 
71     VerifyLayerConnections(m_Param.m_NumInputs, CHECK_LOCATION());
72 
73     const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
74 
75     VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
76 
77     // Constructs and validates input shapes
78     std::vector<TensorShape> inputShapes;
79     for (unsigned int i = 0; i < GetNumInputSlots(); ++i)
80     {
81         TensorShape inputShape = GetInputSlot(i).GetConnection()->GetTensorInfo().GetShape();
82         if (inputShape != m_Param.m_InputShape)
83         {
84             throw LayerValidationException("StackLayer: TensorShape set on InputSlot[" +
85                                            std::to_string(i) +
86                                            "] does not match defined input shape");
87         }
88         inputShapes.push_back(inputShape);
89     }
90 
91     auto inferredShapes = InferOutputShapes(inputShapes);
92 
93     ARMNN_ASSERT(inferredShapes.size() == 1);
94 
95     ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "StackLayer");
96 }
97 
ExecuteStrategy(IStrategy & strategy) const98 void StackLayer::ExecuteStrategy(IStrategy& strategy) const
99 {
100     strategy.ExecuteStrategy(this, GetParameters(), {}, GetName());
101 }
102 
103 } // namespace armnn armnn
104