xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/Resize.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "Resize.hpp"
7 
8 #include "TensorBufferArrayView.hpp"
9 
10 #include <armnn/utility/NumericCast.hpp>
11 
12 #include <cmath>
13 #include <algorithm>
14 
15 using namespace armnnUtils;
16 
17 namespace armnn
18 {
19 
20 namespace
21 {
22 
/// Linear interpolation: returns a when w == 0 and b when w == 1, blending proportionally in between.
inline float Lerp(float a, float b, float w)
{
    // Weight given to the start value a.
    const float weightA = 1.f - w;
    return w * b + weightA * a;
}
27 
EuclideanDistance(float Xa,float Ya,const unsigned int Xb,const unsigned int Yb)28 inline double EuclideanDistance(float Xa, float Ya, const unsigned int Xb, const unsigned int Yb)
29 {
30     return std::sqrt(pow(Xa - armnn::numeric_cast<float>(Xb), 2) + pow(Ya - armnn::numeric_cast<float>(Yb), 2));
31 }
32 
CalculateResizeScale(const unsigned int & InputSize,const unsigned int & OutputSize,const bool & AlignCorners)33 inline float CalculateResizeScale(const unsigned int& InputSize,
34                                   const unsigned int& OutputSize,
35                                   const bool& AlignCorners)
36 {
37     return (AlignCorners && OutputSize > 1)
38             ?  armnn::numeric_cast<float>(InputSize - 1) / armnn::numeric_cast<float>(OutputSize - 1)
39             :  armnn::numeric_cast<float>(InputSize) / armnn::numeric_cast<float>(OutputSize);
40 }
41 
PixelScaler(const unsigned int & Pixel,const float & Scale,const bool & HalfPixelCenters,armnn::ResizeMethod & resizeMethod)42 inline float PixelScaler(const unsigned int& Pixel,
43                          const float& Scale,
44                          const bool& HalfPixelCenters,
45                          armnn::ResizeMethod& resizeMethod)
46 {
47     // For Half Pixel Centers the Top Left texel is assumed to be at 0.5,0.5
48     if (HalfPixelCenters && resizeMethod == armnn::ResizeMethod::Bilinear)
49     {
50         return (static_cast<float>(Pixel) + 0.5f) * Scale - 0.5f;
51     }
52     // Nearest Neighbour doesn't need to have 0.5f trimmed off as it will floor the values later
53     else if (HalfPixelCenters && resizeMethod == armnn::ResizeMethod::NearestNeighbor)
54     {
55         return (static_cast<float>(Pixel) + 0.5f) * Scale;
56     }
57     else
58     {
59         return static_cast<float>(Pixel) * Scale;
60     }
61 }
62 
63 }// anonymous namespace
64 
Resize(Decoder<float> & in,const TensorInfo & inputInfo,Encoder<float> & out,const TensorInfo & outputInfo,DataLayoutIndexed dataLayout,armnn::ResizeMethod resizeMethod,bool alignCorners,bool halfPixelCenters)65 void Resize(Decoder<float>&   in,
66             const TensorInfo& inputInfo,
67             Encoder<float>&   out,
68             const TensorInfo& outputInfo,
69             DataLayoutIndexed dataLayout,
70             armnn::ResizeMethod resizeMethod,
71             bool alignCorners,
72             bool halfPixelCenters)
73 {
74     // alignCorners and halfPixelCenters cannot both be true
75     ARMNN_ASSERT(!(alignCorners && halfPixelCenters));
76 
77     // We follow the definition of TensorFlow and AndroidNN: the top-left corner of a texel in the output
78     // image is projected into the input image to figure out the interpolants and weights. Note that this
79     // will yield different results than if projecting the centre of output texels.
80 
81     const unsigned int batchSize = inputInfo.GetShape()[0];
82     const unsigned int channelCount = inputInfo.GetShape()[dataLayout.GetChannelsIndex()];
83 
84     const unsigned int inputHeight = inputInfo.GetShape()[dataLayout.GetHeightIndex()];
85     const unsigned int inputWidth = inputInfo.GetShape()[dataLayout.GetWidthIndex()];
86     const unsigned int outputHeight = outputInfo.GetShape()[dataLayout.GetHeightIndex()];
87     const unsigned int outputWidth = outputInfo.GetShape()[dataLayout.GetWidthIndex()];
88 
89     // How much to scale pixel coordinates in the output image, to get the corresponding pixel coordinates
90     // in the input image.
91     const float scaleY = CalculateResizeScale(inputHeight, outputHeight, alignCorners);
92     const float scaleX = CalculateResizeScale(inputWidth, outputWidth, alignCorners);
93 
94     TensorShape inputShape =  inputInfo.GetShape();
95     TensorShape outputShape =  outputInfo.GetShape();
96 
97     for (unsigned int n = 0; n < batchSize; ++n)
98     {
99         for (unsigned int c = 0; c < channelCount; ++c)
100         {
101             for (unsigned int y = 0; y < outputHeight; ++y)
102             {
103                 // Corresponding real-valued height coordinate in input image.
104                 float iy = PixelScaler(y, scaleY, halfPixelCenters, resizeMethod);
105 
106                 // Discrete height coordinate of top-left texel (in the 2x2 texel area used for interpolation).
107                 const float fiy = (resizeMethod == armnn::ResizeMethod::NearestNeighbor && alignCorners) ?
108                                   roundf(iy) : floorf(iy);
109                 // Pixel scaling a value with Half Pixel Centers can be negative, if so set to 0
110                 const unsigned int y0 = static_cast<unsigned int>(std::max(fiy, 0.0f));
111 
112                 // Interpolation weight (range [0,1]).
113                 const float yw = iy - fiy;
114 
115                 for (unsigned int x = 0; x < outputWidth; ++x)
116                 {
117                     // Real-valued and discrete width coordinates in input image.
118                     float ix = PixelScaler(x, scaleX, halfPixelCenters, resizeMethod);
119 
120                     // Nearest Neighbour uses rounding to align to corners
121                     const float fix = resizeMethod == armnn::ResizeMethod::NearestNeighbor && alignCorners ?
122                                       roundf(ix) : floorf(ix);
123                     // Pixel scaling a value with Half Pixel Centers can be negative, if so set to 0
124                     const unsigned int x0 = static_cast<unsigned int>(std::max(fix, 0.0f));
125 
126                     // Interpolation weight (range [0,1]).
127                     const float xw = ix - fix;
128 
129                     unsigned int x1;
130                     unsigned int y1;
131                     // Half Pixel Centers uses the scaling to compute a weighted parameter for nearby pixels
132                     if (halfPixelCenters)
133                     {
134                         x1 = std::min(static_cast<unsigned int>(std::ceil(ix)), inputWidth - 1u);
135                         y1 = std::min(static_cast<unsigned int>(std::ceil(iy)), inputHeight - 1u);
136                     }
137                     // Discrete width/height coordinates of texels below and to the right of (x0, y0).
138                     else
139                     {
140                         x1 = std::min(x0 + 1, inputWidth - 1u);
141                         y1 = std::min(y0 + 1, inputHeight - 1u);
142                     }
143 
144                     float interpolatedValue;
145                     switch (resizeMethod)
146                     {
147                         case armnn::ResizeMethod::Bilinear:
148                         {
149                             in[dataLayout.GetIndex(inputShape, n, c, y0, x0)];
150                             float input1 = in.Get();
151                             in[dataLayout.GetIndex(inputShape, n, c, y0, x1)];
152                             float input2 = in.Get();
153                             in[dataLayout.GetIndex(inputShape, n, c, y1, x0)];
154                             float input3 = in.Get();
155                             in[dataLayout.GetIndex(inputShape, n, c, y1, x1)];
156                             float input4 = in.Get();
157 
158                             const float ly0 = Lerp(input1, input2, xw); // lerp along row y0.
159                             const float ly1 = Lerp(input3, input4, xw); // lerp along row y1.
160                             interpolatedValue = Lerp(ly0, ly1, yw);
161                             break;
162                         }
163                         case armnn::ResizeMethod::NearestNeighbor:
164                         {
165                             // calculate euclidean distance to the 4 neighbours
166                             auto distance00 = EuclideanDistance(fix, fiy, x0, y0);
167                             auto distance01 = EuclideanDistance(fix, fiy, x0, y1);
168                             auto distance10 = EuclideanDistance(fix, fiy, x1, y0);
169                             auto distance11 = EuclideanDistance(fix, fiy, x1, y1);
170 
171                             auto minimum = std::min( { distance00, distance01, distance10, distance11 } );
172 
173                             unsigned int xNearest = 0;
174                             unsigned int yNearest = 0;
175 
176                             if (minimum == distance00)
177                             {
178                                xNearest = x0;
179                                yNearest = y0;
180                             }
181                             else if (minimum == distance01)
182                             {
183                                 xNearest = x0;
184                                 yNearest = y1;
185                             }
186                             else if (minimum == distance10)
187                             {
188                                 xNearest = x1;
189                                 yNearest = y0;
190                             }
191                             else if (minimum == distance11)
192                             {
193                                 xNearest = x1;
194                                 yNearest = y1;
195                             }
196                             else
197                             {
198                                 throw armnn::InvalidArgumentException("Resize Nearest Neighbor failure");
199                             }
200 
201                             in[dataLayout.GetIndex(inputShape, n, c, yNearest, xNearest)];
202                             interpolatedValue = in.Get();
203                             break;
204                         }
205                         default:
206                             throw armnn::InvalidArgumentException("Unknown resize method: " +
207                                                                   std::to_string(static_cast<int>(resizeMethod)));
208                     }
209                     out[dataLayout.GetIndex(outputShape, n, c, y, x)];
210                     out.Set(interpolatedValue);
211                 }
212             }
213         }
214     }
215 }
216 
217 } //namespace armnn
218