xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/BatchToSpaceNd.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "BatchToSpaceNd.hpp"
7 
8 #include "RefWorkloadUtils.hpp"
9 
10 #include <armnn/Types.hpp>
11 
12 #include <armnn/utility/Assert.hpp>
13 
14 using namespace armnnUtils;
15 
16 namespace armnn
17 {
18 
Offset(const TensorShape & shape,unsigned int batch,unsigned int height,unsigned int width,unsigned int channels,const DataLayoutIndexed & dataLayout)19 inline unsigned int Offset(const TensorShape& shape, unsigned int batch, unsigned int height, unsigned int width,
20         unsigned int channels, const DataLayoutIndexed& dataLayout)
21 {
22     if (dataLayout.GetDataLayout() == DataLayout::NHWC)
23     {
24         return ((batch * shape[dataLayout.GetHeightIndex()] + height) * shape[dataLayout.GetWidthIndex()] + width) *
25                shape[dataLayout.GetChannelsIndex()] + channels;
26     }
27     else
28     {
29         return ((batch * shape[dataLayout.GetChannelsIndex()] + channels) *
30                shape[dataLayout.GetHeightIndex()] + height) *
31                shape[dataLayout.GetWidthIndex()] + width;
32     }
33 }
34 
BatchToSpaceNd(const DataLayoutIndexed & dataLayout,const TensorInfo & inputTensorInfo,const TensorInfo & outputTensorInfo,const std::vector<unsigned int> & blockShape,const std::vector<std::pair<unsigned int,unsigned int>> & cropsData,Decoder<float> & inputDecoder,Encoder<float> & outputEncoder)35 void BatchToSpaceNd(const DataLayoutIndexed& dataLayout,
36                     const TensorInfo& inputTensorInfo,
37                     const TensorInfo& outputTensorInfo,
38                     const std::vector<unsigned int>& blockShape,
39                     const std::vector<std::pair<unsigned int, unsigned int>>& cropsData,
40                     Decoder<float>& inputDecoder,
41                     Encoder<float>& outputEncoder)
42 {
43     TensorShape inputShape = inputTensorInfo.GetShape();
44 
45     ARMNN_ASSERT_MSG(inputShape.GetNumDimensions() == 4, "Expected Input with 4 Dimensions");
46 
47     TensorShape outputShape = outputTensorInfo.GetShape();
48 
49     ARMNN_ASSERT_MSG(outputShape.GetNumDimensions() == 4, "Expected Output with 4 Dimensions");
50 
51     const unsigned int inputBatchSize = inputShape[0];
52     const unsigned int channels = inputShape[dataLayout.GetChannelsIndex()];
53 
54     const unsigned int outputBatchSize = outputShape[0];
55     const unsigned int outputHeight = outputShape[dataLayout.GetHeightIndex()];
56     const unsigned int outputWidth = outputShape[dataLayout.GetWidthIndex()];
57 
58     ARMNN_ASSERT_MSG(blockShape.size() > 0, "BlockShape must contain 1 or more entries");
59 
60     const unsigned int blockShapeHeight = blockShape[0];
61     const unsigned int blockShapeWidth = blockShape[1];
62 
63     ARMNN_ASSERT_MSG(cropsData.size() > 0, "Crops must contain 1 or more entries");
64 
65     const unsigned int cropsTop = cropsData[0].first;
66     const unsigned int cropsLeft = cropsData[1].first;
67 
68     for (unsigned int inBatch = 0; inBatch < inputBatchSize; ++inBatch)
69     {
70         const unsigned int outBatch = inBatch % outputBatchSize;
71         const unsigned int spatialOffset = inBatch / outputBatchSize;
72 
73         for (unsigned int inH = 0; inH < inputTensorInfo.GetShape()[dataLayout.GetHeightIndex()]; ++inH) {
74             const unsigned int outH = inH * blockShapeHeight + spatialOffset / blockShapeWidth - cropsTop;
75 
76             if (outH >= outputHeight)
77             {
78                 continue;
79             }
80 
81             for (unsigned int inW = 0; inW < inputTensorInfo.GetShape()[dataLayout.GetWidthIndex()]; ++inW) {
82                 const unsigned int outW = inW * blockShapeWidth + spatialOffset % blockShapeWidth - cropsLeft;
83 
84                 if (outW >= outputWidth)
85                 {
86                     continue;
87                 }
88 
89                 for (unsigned int c = 0; c < channels; c++)
90                 {
91                     unsigned int outOffset = Offset(outputShape, outBatch, outH, outW, c, dataLayout);
92                     unsigned int inOffset = Offset(inputShape, inBatch, inH, inW, c, dataLayout);
93 
94                     outputEncoder[outOffset];
95                     inputDecoder[inOffset];
96                     outputEncoder.Set(inputDecoder.Get());
97                 }
98             }
99         }
100     }
101 }
102 
103 } //namespace armnn
104