// xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/Pad.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "Pad.hpp"

#include "BaseIterator.hpp"
#include "Decoders.hpp"
#include "Encoders.hpp"

#include <armnnUtils/TensorUtils.hpp>

#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <limits>
#include <utility>
#include <vector>

20 namespace
21 {
22 
FillOutputWithPadValue(armnn::Encoder<float> & outputData,const float padValue,const unsigned int numOutputElements)23 void FillOutputWithPadValue(armnn::Encoder<float>& outputData,
24                             const float padValue,
25                             const unsigned int numOutputElements)
26 {
27     for (unsigned int i = 0; i < numOutputElements; ++i)
28     {
29         outputData[i];
30         outputData.Set(padValue);
31     }
32 }
33 
34 } // anonymous namespace
35 
36 namespace armnn
37 {
38 
Pad(const TensorInfo & inputInfo,const TensorInfo & outputInfo,const ITensorHandle * inputHandle,ITensorHandle * outputHandle,const PadQueueDescriptor & data)39 void Pad(const TensorInfo& inputInfo,
40          const TensorInfo& outputInfo,
41          const ITensorHandle* inputHandle,
42          ITensorHandle* outputHandle,
43          const PadQueueDescriptor& data)
44 {
45     auto padList  = data.m_Parameters.m_PadList;
46     auto padValue = data.m_Parameters.m_PadValue;
47 
48     unsigned int numOutputElements = outputInfo.GetNumElements();
49 
50     TensorShape outputShape = outputInfo.GetShape();
51     TensorShape inputShape  = inputInfo.GetShape();
52 
53     unsigned int numInputDimensions = inputShape.GetNumDimensions();
54 
55 #ifndef NDEBUG
56 
57     unsigned int numOutputDimensions = outputShape.GetNumDimensions();
58     assert(numInputDimensions == numOutputDimensions);
59 
60 #endif
61 
62     unsigned int inputBatches  = 0;
63     unsigned int inputChannels = 0;
64     unsigned int inputHeight   = 0;
65     unsigned int inputWidth    = 0;
66 
67     unsigned int outputChannels = 0;
68     unsigned int outputHeight   = 0;
69     unsigned int outputWidth    = 0;
70 
71     auto inputData = MakeDecoder<float>(inputInfo, inputHandle->Map());
72     auto outData   = MakeEncoder<float>(outputInfo, outputHandle->Map());
73 
74     // Fill the output tensor with Pad value first
75     if (outputInfo.IsQuantized())
76     {
77         // For Quantized types Pad Value should not be quantized with scale and offset of the tensor info
78         auto temporaryInfo = TensorInfo(outputInfo.GetShape(), outputInfo.GetDataType(), 1.0f, 0);
79         auto outputData = MakeEncoder<float>(temporaryInfo, outputHandle->Map());
80         FillOutputWithPadValue(*outputData, padValue, numOutputElements);
81     }
82     else
83     {
84         FillOutputWithPadValue(*outData, padValue, numOutputElements);
85     }
86 
87     Decoder<float>& input  = *inputData;
88     Encoder<float>& output = *outData;
89 
90     switch(numInputDimensions) {
91 
92         case 1:
93             inputWidth = inputShape[0];
94             for (unsigned int w = 0; w < inputWidth ; w++)
95             {
96                 input[w];
97                 auto inputValue = input.Get();
98                 auto outputIndex = w + std::get<0>(padList[0]);
99                 output[outputIndex];
100                 output.Set(inputValue);
101             }
102 
103             break;
104         case 2  :
105             inputHeight = inputShape[0];
106             inputWidth  = inputShape[1];
107             outputWidth = outputShape[1];
108 
109             for (unsigned int h = 0; h < inputHeight; h++)
110             {
111                 for (unsigned int w = 0; w < inputWidth ; w++)
112                 {
113                     input[h * inputWidth + w];
114                     auto inputValue  = input.Get();
115                     auto outputIndex = (h + std::get<0>(padList[0])) * outputWidth + (w + std::get<0>(padList[1]));
116                     output[outputIndex];
117                     output.Set(inputValue);
118                 }
119             }
120 
121             break;
122         case 3  :
123             inputChannels = inputShape[0];
124             inputHeight   = inputShape[1];
125             inputWidth    = inputShape[2];
126             outputHeight  = outputShape[1];
127             outputWidth   = outputShape[2];
128 
129             for (unsigned int c = 0; c < inputChannels; c++)
130             {
131                 for (unsigned int h = 0; h < inputHeight; h++)
132                 {
133                     for (unsigned int w = 0; w < inputWidth ; w++)
134                     {
135                         input[c * inputHeight * inputWidth + h * inputWidth + w];
136                         auto inputValue  = input.Get();
137                         auto outputIndex = (c + std::get<0>(padList[0])) * outputHeight * outputWidth
138                                            + (h + std::get<0>(padList[1])) * outputWidth
139                                            + (w + std::get<0>(padList[2]));
140                         output[outputIndex];
141                         output.Set(inputValue);
142                     }
143                 }
144             }
145 
146             break;
147         case 4  :
148             inputBatches   = inputShape[0];
149             inputChannels  = inputShape[1];
150             inputHeight    = inputShape[2];
151             inputWidth     = inputShape[3];
152             outputChannels = outputShape[1];
153             outputHeight   = outputShape[2];
154             outputWidth    = outputShape[3];
155 
156             for (unsigned int b = 0; b < inputBatches; b++)
157             {
158                 for (unsigned int c = 0; c < inputChannels; c++)
159                 {
160                     for (unsigned int h = 0; h < inputHeight; h++)
161                     {
162                         for (unsigned int w = 0; w < inputWidth ; w++)
163                         {
164                             input[b * inputChannels * inputHeight * inputWidth
165                                       + c * inputHeight * inputWidth
166                                       + h * inputWidth
167                                       + w];
168                             auto inputValue  = input.Get();
169                             auto outputIndex = (b + std::get<0>(padList[0]))
170                                                * outputChannels * outputHeight * outputWidth
171                                                + (c + std::get<0>(padList[1])) * outputHeight * outputWidth
172                                                + (h + std::get<0>(padList[2])) * outputWidth
173                                                + (w + std::get<0>(padList[3]));
174                             output[outputIndex];
175                             output.Set(inputValue);
176                         }
177                     }
178                 }
179             }
180 
181             break;
182         default :
183             break;
184     }
185 }
186 
187 } //namespace armnn