xref: /aosp_15_r20/external/armnn/src/armnn/layers/BatchToSpaceNdLayer.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include "BatchToSpaceNdLayer.hpp"
#include "LayerCloneBase.hpp"
#include "LayerWithParameters.hpp"
#include "BatchToSpaceNdLayer.hpp"

#include <armnn/Exceptions.hpp>
#include <armnn/TypesUtils.hpp>

#include <armnnUtils/DataLayoutIndexed.hpp>

#include <armnn/backends/TensorHandle.hpp>
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/backends/WorkloadFactory.hpp>

#include <functional>
#include <numeric>
#include <string>
20 
21 using namespace armnnUtils;
22 
23 namespace armnn
24 {
25 
BatchToSpaceNdLayer(const armnn::BatchToSpaceNdDescriptor & param,const char * name)26 BatchToSpaceNdLayer::BatchToSpaceNdLayer(const armnn::BatchToSpaceNdDescriptor& param, const char* name)
27     : LayerWithParameters(1, 1, LayerType::BatchToSpaceNd, param, name)
28 {
29 }
30 
CreateWorkload(const IWorkloadFactory & factory) const31 std::unique_ptr<IWorkload> BatchToSpaceNdLayer::CreateWorkload(const IWorkloadFactory& factory) const
32 {
33     BatchToSpaceNdQueueDescriptor descriptor;
34     SetAdditionalInfo(descriptor);
35 
36     return factory.CreateWorkload(LayerType::BatchToSpaceNd, descriptor, PrepInfoAndDesc(descriptor));
37 }
38 
Clone(Graph & graph) const39 BatchToSpaceNdLayer* BatchToSpaceNdLayer::Clone(Graph& graph) const
40 {
41     auto layer = CloneBase<BatchToSpaceNdLayer>(graph, m_Param, GetName());
42     return std::move(layer);
43 }
44 
ValidateTensorShapesFromInputs()45 void BatchToSpaceNdLayer::ValidateTensorShapesFromInputs()
46 {
47     VerifyLayerConnections(1, CHECK_LOCATION());
48 
49     const TensorShape &outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
50 
51     VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
52 
53     auto inferredShapes = InferOutputShapes({GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape()});
54 
55     ARMNN_ASSERT(inferredShapes.size() == 1);
56 
57     ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "BatchToSpaceNdLayer");
58 }
59 
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const60 std::vector<TensorShape> BatchToSpaceNdLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
61 {
62     ARMNN_ASSERT(inputShapes.size() == 1);
63 
64     const TensorShape& inputShape = inputShapes[0];
65     TensorShape outputShape(inputShape);
66 
67     unsigned int accumulatedBlockShape = std::accumulate(m_Param.m_BlockShape.begin(),
68                                                          m_Param.m_BlockShape.end(),
69                                                          1U,
70                                                          std::multiplies<>());
71 
72     ARMNN_ASSERT(inputShape[0] % accumulatedBlockShape == 0);
73 
74     outputShape[0] = inputShape[0] / accumulatedBlockShape;
75 
76     DataLayoutIndexed dimensionIndices = m_Param.m_DataLayout;
77     unsigned int heightIndex = dimensionIndices.GetHeightIndex();
78     unsigned int widthIndex = dimensionIndices.GetWidthIndex();
79 
80     unsigned int heightCrop = m_Param.m_Crops[0].first + m_Param.m_Crops[0].second;
81     unsigned int widthCrop = m_Param.m_Crops[1].first + m_Param.m_Crops[1].second;
82 
83     unsigned int outputHeight = inputShape[heightIndex] * m_Param.m_BlockShape[0];
84     unsigned int outputWidth = inputShape[widthIndex] * m_Param.m_BlockShape[1];
85 
86     ARMNN_ASSERT_MSG(heightCrop <= outputHeight,
87         "BatchToSpaceLayer: Overall height crop should be less than or equal to the uncropped output height.");
88 
89     ARMNN_ASSERT_MSG(widthCrop <= outputWidth,
90         "BatchToSpaceLayer: Overall width crop should be less than or equal to the uncropped output width.");
91 
92     outputShape[heightIndex] = outputHeight - heightCrop;
93     outputShape[widthIndex] = outputWidth - widthCrop;
94 
95     return std::vector<TensorShape>({ outputShape });
96 }
97 
ExecuteStrategy(IStrategy & strategy) const98 void BatchToSpaceNdLayer::ExecuteStrategy(IStrategy& strategy) const
99 {
100     strategy.ExecuteStrategy(this, GetParameters(), {}, GetName());
101 }
102 
103 } // namespace armnn
104