xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/SpaceToDepth.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "SpaceToDepth.hpp"
7 
8 #include <armnnUtils/DataLayoutIndexed.hpp>
9 
10 using namespace armnnUtils;
11 
12 namespace {
GetOffset(const armnn::TensorShape & shape,unsigned int c,unsigned int h,unsigned int w,unsigned int b,const DataLayoutIndexed & dataLayout)13     unsigned int GetOffset(const armnn::TensorShape& shape,
14         unsigned int c,
15         unsigned int h,
16         unsigned int w,
17         unsigned int b,
18         const DataLayoutIndexed& dataLayout)
19     {
20         if (dataLayout.GetDataLayout() == armnn::DataLayout::NHWC)
21         {
22             return ((b * shape[dataLayout.GetHeightIndex()] + h) * shape[dataLayout.GetWidthIndex()] + w) *
23                 shape[dataLayout.GetChannelsIndex()] + c;
24         }
25         else
26         {
27             return ((b * shape[dataLayout.GetChannelsIndex()] + c) * shape[dataLayout.GetHeightIndex()] + h) *
28                 shape[dataLayout.GetWidthIndex()] + w;
29         }
30     }
31 }
32 
33 namespace armnn
34 {
35 
SpaceToDepth(const TensorInfo & inputInfo,const TensorInfo & outputInfo,const SpaceToDepthDescriptor & params,Decoder<float> & inputData,Encoder<float> & outputData)36 void SpaceToDepth(const TensorInfo& inputInfo,
37                   const TensorInfo& outputInfo,
38                   const SpaceToDepthDescriptor& params,
39                   Decoder<float>& inputData,
40                   Encoder<float>& outputData)
41 {
42     DataLayoutIndexed dataLayout = params.m_DataLayout;
43 
44     const TensorShape& inputShape = inputInfo.GetShape();
45     const TensorShape& outputShape = outputInfo.GetShape();
46 
47     const unsigned int inputBatchSize = inputShape[0];
48     const unsigned int inputChannels = inputShape[dataLayout.GetChannelsIndex()];
49 
50     const unsigned int outputHeight = outputShape[dataLayout.GetHeightIndex()];
51     const unsigned int outputWidth = outputShape[dataLayout.GetWidthIndex()];
52     const unsigned int outputChannels = outputShape[dataLayout.GetChannelsIndex()];
53 
54     const unsigned int blockSize = params.m_BlockSize;
55 
56     if (blockSize == 0)
57     {
58         throw InvalidArgumentException(
59             "Input shape must be divisible by block size in all spatial dimensions: Block size is"
60             " equal to zero");
61     }
62 
63     for (unsigned int outChannelIndex = 0; outChannelIndex < outputChannels; outChannelIndex++)
64     {
65         unsigned int inChannelIndex = outChannelIndex % inputChannels;
66 
67         unsigned int shiftW = (outChannelIndex / inputChannels) % blockSize;
68         unsigned int shiftH = (outChannelIndex / inputChannels) / blockSize;
69 
70         for (unsigned int outH = 0; outH < outputHeight; outH++)
71         {
72             for (unsigned int outW = 0; outW < outputWidth; outW++)
73             {
74                 for (unsigned int inBatchIndex = 0; inBatchIndex < inputBatchSize; inBatchIndex++)
75                 {
76                     unsigned int inOffset = GetOffset(inputShape,
77                         inChannelIndex,
78                         (outH * blockSize + shiftH),
79                         (outW * blockSize + shiftW),
80                         inBatchIndex,
81                         dataLayout);
82 
83                     unsigned int outOffset = GetOffset(outputShape,
84                         outChannelIndex,
85                         outH,
86                         outW,
87                         inBatchIndex,
88                         dataLayout);
89 
90                     outputData += outOffset;
91                     inputData += inOffset;
92                     outputData.Set(inputData.Get());
93                     inputData -= inOffset;
94                     outputData -= outOffset;
95                 }
96             }
97         }
98     }
99 }
100 
101 void SpaceToDepth(const TensorInfo& inputInfo,
102     const TensorInfo& outputInfo,
103     const SpaceToDepthDescriptor& params,
104     Decoder<float>& inputData,
105     Encoder<float>& outData);
106 
107 } //namespace armnn
108