/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements a quantized version of the resize bilinear op.

#define EIGEN_USE_THREADS

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define USE_NEON
#define QUANTIZED_RESIZE_BILINEAR_USE_NEON
#include <arm_neon.h>
#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/image_resizer_state.h"

namespace tensorflow {

static constexpr bool USE_REFERENCE = false;

namespace {
// Compute the interpolation indices only once.
template <typename T_SCALE>
struct InterpolationCache {
  std::vector<int64_t> lower;  // Lower source index used in the interpolation
  std::vector<int64_t> upper;  // Upper source index used in the interpolation
  // 1-D linear interpolation scale (see:
  // https://en.wikipedia.org/wiki/Bilinear_interpolation)
  std::vector<float> lerp;
  std::vector<T_SCALE> ilerp;
};

template <typename T_SCALE, typename Scaler>
inline void ComputeInterpolationWeights(
    const int64_t out_size, const int64_t in_size, const float scale,
    const int resolution, InterpolationCache<T_SCALE>* interpolation) {
  const Scaler scaler;
  interpolation->lower.resize(out_size + 1);
  interpolation->upper.resize(out_size + 1);
  interpolation->lerp.resize(out_size + 1);
  interpolation->ilerp.resize(out_size + 1);

  interpolation->lower[out_size] = 0;
  interpolation->upper[out_size] = 0;
  for (int64_t i = out_size - 1; i >= 0; --i) {
    const float in = scaler(i, scale);
    const float in_f = std::floor(in);
    interpolation->lower[i] =
        std::max(static_cast<int64_t>(in_f), static_cast<int64_t>(0));
    interpolation->upper[i] =
        std::min(static_cast<int64_t>(std::ceil(in)), in_size - 1);
    interpolation->lower[i] =
        std::min(interpolation->lower[i], interpolation->upper[i]);
    interpolation->lerp[i] = in - in_f;
    interpolation->ilerp[i] =
        static_cast<T_SCALE>((in - in_f) * (1 << resolution));
  }
}
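
// A worked example, assuming the LegacyScaler (in = i * scale) with
// in_size = 2, out_size = 4, scale = in_size / out_size = 0.5, and
// resolution = 7 (so ilerp = lerp * 128):
//   i = 0: in = 0.0 -> lower = 0, upper = 0, lerp = 0.0, ilerp = 0
//   i = 1: in = 0.5 -> lower = 0, upper = 1, lerp = 0.5, ilerp = 64
//   i = 2: in = 1.0 -> lower = 1, upper = 1, lerp = 0.0, ilerp = 0
//   i = 3: in = 1.5 -> lower = 1, upper = 1 (clamped to in_size - 1),
//          lerp = 0.5, ilerp = 64
// The extra entry at [out_size] is zero-initialized, presumably so that
// vectorized readers can safely touch one element past the last output
// index.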

template <typename T_SCALE>
inline InterpolationCache<T_SCALE> BuildLerpCache(
    const int64_t out_size, const int64_t in_size, const float scale,
    const int index_step, const int resolution, const bool half_pixel_centers) {
  InterpolationCache<T_SCALE> cache;
  // Compute the cached interpolation weights on the x and y dimensions.
  if (half_pixel_centers) {
    ComputeInterpolationWeights<T_SCALE, HalfPixelScaler>(
        out_size, in_size, scale, resolution, &cache);
  } else {
    ComputeInterpolationWeights<T_SCALE, LegacyScaler>(out_size, in_size, scale,
                                                       resolution, &cache);
  }
  CHECK(index_step > 0);
  if (index_step > 1) {
    for (int i = 0; i < cache.lower.size(); ++i) {
      cache.lower[i] *= index_step;
      cache.upper[i] *= index_step;
    }
  }
  return cache;
}
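
// Note on index_step: callers building the x-dimension cache pass the
// channel count as index_step, so the cached lower/upper values become
// element offsets into an interleaved (NHWC) row rather than pixel
// indices. The per-channel loops below can then simply add c (or, in the
// NEON paths, a compile-time channel constant) to these offsets.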

/**
 * Computes the bilinear interpolation from the appropriate 4 float points
 * and the linear interpolation weights.
 */
template <typename T>
inline T ComputeLerpReference(const T in_top_left, const T in_top_right,
                              const T in_bottom_left, const T in_bottom_right,
                              const float x_lerp, const float y_lerp,
                              const float min, const float max) {
  const float top_left = QuantizedToFloat<T>(in_top_left, min, max);
  const float top_right = QuantizedToFloat<T>(in_top_right, min, max);
  const float bottom_left = QuantizedToFloat<T>(in_bottom_left, min, max);
  const float bottom_right = QuantizedToFloat<T>(in_bottom_right, min, max);
  const float top = top_left + (top_right - top_left) * x_lerp;
  const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
  const float out = top + (bottom - top) * y_lerp;
  return FloatToQuantized<T>(out, min, max);
}

template <typename T, typename T_SCALE, typename T_CALC>
inline T_CALC MulOffset(T a, T b, T_SCALE c) {
  return (static_cast<T_CALC>(a) - static_cast<T_CALC>(b)) *
         static_cast<T_CALC>(c);
}

template <int RESOLUTION, typename T, typename T_SCALE, typename T_CALC>
inline T ComputeLerp(const T top_left, const T top_right, const T bottom_left,
                     const T bottom_right, const T_SCALE x_lerp,
                     const T_SCALE y_lerp) {
  constexpr T_CALC RESOLUTION_MULT = (1 << RESOLUTION);
  const T_CALC top = static_cast<T_CALC>(top_left) * RESOLUTION_MULT +
                     MulOffset<T, T_SCALE, T_CALC>(top_right, top_left, x_lerp);
  const T_CALC bottom =
      static_cast<T_CALC>(bottom_left) * RESOLUTION_MULT +
      MulOffset<T, T_SCALE, T_CALC>(bottom_right, bottom_left, x_lerp);
  const T_CALC out = top + (bottom - top) / RESOLUTION_MULT * y_lerp;
  return static_cast<T>(
      static_cast<int32>((out + RESOLUTION_MULT / 2) / RESOLUTION_MULT));
}
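
// A worked fixed-point example, assuming RESOLUTION = 7 (so values are
// scaled by RESOLUTION_MULT = 128), top_left = 10, top_right = 20,
// bottom_left = 30, bottom_right = 40, and x_lerp = y_lerp = 64 (i.e. 0.5):
//   top    = 10 * 128 + (20 - 10) * 64 = 1920          (15.0 scaled by 128)
//   bottom = 30 * 128 + (40 - 30) * 64 = 4480          (35.0 scaled by 128)
//   out    = 1920 + (4480 - 1920) / 128 * 64 = 3200    (25.0 scaled by 128)
//   result = (3200 + 64) / 128 = 25
// which matches ComputeLerpReference's float result of 25 for the same
// inputs. The + RESOLUTION_MULT / 2 term rounds to nearest instead of
// truncating toward zero.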

#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
inline uint8x8_t ToUint8x8(const quint8* v0, const quint8* v1, const quint8* v2,
                           const quint8* v3, const quint8* v4, const quint8* v5,
                           const quint8* v6, const quint8* v7) {
  static const uint8x8_t ZERO_8x8 = vmov_n_u8(0);
  uint8x8_t ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v0), ZERO_8x8, 0);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v1), ret, 1);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v2), ret, 2);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v3), ret, 3);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v4), ret, 4);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v5), ret, 5);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v6), ret, 6);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v7), ret, 7);
  return ret;
}

inline int16x8_t ToInt16x8(const int16* v0, const int16* v1, const int16* v2,
                           const int16* v3, const int16* v4, const int16* v5,
                           const int16* v6, const int16* v7) {
  static const int16x8_t ZERO_16x8 = vmovq_n_s16(0);
  int16x8_t ret = vld1q_lane_s16(v0, ZERO_16x8, 0);
  ret = vld1q_lane_s16(v1, ret, 1);
  ret = vld1q_lane_s16(v2, ret, 2);
  ret = vld1q_lane_s16(v3, ret, 3);
  ret = vld1q_lane_s16(v4, ret, 4);
  ret = vld1q_lane_s16(v5, ret, 5);
  ret = vld1q_lane_s16(v6, ret, 6);
  ret = vld1q_lane_s16(v7, ret, 7);
  return ret;
}

inline int32x2_t ToInt32x2(const qint32* v0, const qint32* v1) {
  static const int32x2_t ZERO_32x2 = vmov_n_s32(0);
  const int32x2_t ret0 =
      vld1_lane_s32(reinterpret_cast<const int32*>(v0), ZERO_32x2, 0);
  const int32x2_t ret1 =
      vld1_lane_s32(reinterpret_cast<const int32*>(v1), ret0, 1);
  return ret1;
}

template <int RESOLUTION, bool X_LERP_SAME>
inline int32x2_t ComputeLerpx2(
    const qint32* top_left0, const qint32* top_right0,
    const qint32* bottom_left0, const qint32* bottom_right0,
    const qint32* top_left1, const qint32* top_right1,
    const qint32* bottom_left1, const qint32* bottom_right1,
    const int32* x_lerp, const int32x2_t y_lerpsx) {
  const int32x2_t x_lerpsx =
      X_LERP_SAME ? vld1_dup_s32(reinterpret_cast<const int32*>(x_lerp))
                  : vld1_s32(reinterpret_cast<const int32*>(x_lerp));

  const int32x2_t top_leftsx = ToInt32x2(top_left0, top_left1);
  const int32x2_t top_rightsx = ToInt32x2(top_right0, top_right1);
  const int32x2_t bottom_leftsx = ToInt32x2(bottom_left0, bottom_left1);
  const int32x2_t bottom_rightsx = ToInt32x2(bottom_right0, bottom_right1);

  const int32x2_t retval =
      ComputeLerp32x2<RESOLUTION>(top_leftsx, top_rightsx, bottom_leftsx,
                                  bottom_rightsx, x_lerpsx, y_lerpsx);
  return retval;
}

template <int RESOLUTION>
inline uint8x8_t ComputeLerpx8(
    const quint8* tl0, const quint8* tr0, const quint8* bl0, const quint8* br0,
    const int16* xlp0, const quint8* tl1, const quint8* tr1, const quint8* bl1,
    const quint8* br1, const int16* xlp1, const quint8* tl2, const quint8* tr2,
    const quint8* bl2, const quint8* br2, const int16* xlp2, const quint8* tl3,
    const quint8* tr3, const quint8* bl3, const quint8* br3, const int16* xlp3,
    const quint8* tl4, const quint8* tr4, const quint8* bl4, const quint8* br4,
    const int16* xlp4, const quint8* tl5, const quint8* tr5, const quint8* bl5,
    const quint8* br5, const int16* xlp5, const quint8* tl6, const quint8* tr6,
    const quint8* bl6, const quint8* br6, const int16* xlp6, const quint8* tl7,
    const quint8* tr7, const quint8* bl7, const quint8* br7, const int16* xlp7,
    const int16x8_t ys_lerpsx) {
  const uint8x8_t tl8x8 = ToUint8x8(tl0, tl1, tl2, tl3, tl4, tl5, tl6, tl7);
  const uint8x8_t tr8x8 = ToUint8x8(tr0, tr1, tr2, tr3, tr4, tr5, tr6, tr7);
  const uint8x8_t bl8x8 = ToUint8x8(bl0, bl1, bl2, bl3, bl4, bl5, bl6, bl7);
  const uint8x8_t br8x8 = ToUint8x8(br0, br1, br2, br3, br4, br5, br6, br7);
  const int16x8_t xs_lerpsx =
      ToInt16x8(xlp0, xlp1, xlp2, xlp3, xlp4, xlp5, xlp6, xlp7);
  return ComputeLerp8x8<RESOLUTION>(tl8x8, tr8x8, bl8x8, br8x8, xs_lerpsx,
                                    ys_lerpsx);
}
// Expand the address arithmetic at compile time to improve performance.
template <int RESOLUTION, int ID0, int CH0, int ID1, int CH1, int ID2, int CH2,
          int ID3, int CH3, int ID4, int CH4, int ID5, int CH5, int ID6,
          int CH6, int ID7, int CH7>
inline uint8x8_t ComputeLerpx8Tmpl(const quint8* const yl, const quint8* yu,
                                   const int64* xl, const int64* xu,
                                   const int16* xlp,
                                   const int16x8_t ys_lerpsx) {
  return ComputeLerpx8<RESOLUTION>(
      yl + xl[ID0] + CH0, yl + xu[ID0] + CH0, yu + xl[ID0] + CH0,
      yu + xu[ID0] + CH0, xlp + ID0, yl + xl[ID1] + CH1, yl + xu[ID1] + CH1,
      yu + xl[ID1] + CH1, yu + xu[ID1] + CH1, xlp + ID1, yl + xl[ID2] + CH2,
      yl + xu[ID2] + CH2, yu + xl[ID2] + CH2, yu + xu[ID2] + CH2, xlp + ID2,
      yl + xl[ID3] + CH3, yl + xu[ID3] + CH3, yu + xl[ID3] + CH3,
      yu + xu[ID3] + CH3, xlp + ID3, yl + xl[ID4] + CH4, yl + xu[ID4] + CH4,
      yu + xl[ID4] + CH4, yu + xu[ID4] + CH4, xlp + ID4, yl + xl[ID5] + CH5,
      yl + xu[ID5] + CH5, yu + xl[ID5] + CH5, yu + xu[ID5] + CH5, xlp + ID5,
      yl + xl[ID6] + CH6, yl + xu[ID6] + CH6, yu + xl[ID6] + CH6,
      yu + xu[ID6] + CH6, xlp + ID6, yl + xl[ID7] + CH7, yl + xu[ID7] + CH7,
      yu + xl[ID7] + CH7, yu + xu[ID7] + CH7, xlp + ID7, ys_lerpsx);
}
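
// Each (IDn, CHn) template pair selects the n-th output byte of the
// returned uint8x8_t: IDn indexes the cached x lower/upper/ilerp arrays
// (relative to the current x_start) and CHn is the channel offset within
// the interleaved pixel. This is how the callers below describe an
// arbitrary pixel/channel layout to a single 8-lane lerp.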

#endif

template <int RESOLUTION, typename T, typename T_SCALE, typename T_CALC>
inline void OutputLerpForChannels(const InterpolationCache<T_SCALE>& xs,
                                  const int64_t x, const T_SCALE ys_ilerp,
                                  const int channels, const float min,
                                  const float max, const T* ys_input_lower_ptr,
                                  const T* ys_input_upper_ptr,
                                  T* output_y_ptr) {
  const int64_t xs_lower = xs.lower[x];
  const int64_t xs_upper = xs.upper[x];
  const T_SCALE xs_ilerp = xs.ilerp[x];
  for (int c = 0; c < channels; ++c) {
    const T top_left = ys_input_lower_ptr[xs_lower + c];
    const T top_right = ys_input_lower_ptr[xs_upper + c];
    const T bottom_left = ys_input_upper_ptr[xs_lower + c];
    const T bottom_right = ys_input_upper_ptr[xs_upper + c];
    const T val = ComputeLerp<RESOLUTION, T, T_SCALE, T_CALC>(
        top_left, top_right, bottom_left, bottom_right, xs_ilerp, ys_ilerp);
    output_y_ptr[x * channels + c] = val;
  }
}

template <int RES>
inline void OutputLerp8x8x1(const InterpolationCache<int16>& xs,
                            const int64_t x_start, const int16_t ys_ilerp,
                            const float min, const float max,
                            const quint8* const ys_input_lower_ptr,
                            const quint8* const ys_input_upper_ptr,
                            quint8* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int16x8_t y_lerpsx = vmovq_n_s16(ys_ilerp);

  const uint8x8_t x0x7 =
      ComputeLerpx8Tmpl<RES, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start), x0x7);

#else
  for (int x = x_start; x < x_start + 8; ++x) {
    OutputLerpForChannels<RES, quint8, int16, int16>(
        xs, x, ys_ilerp, 1, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

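// Produces 8 output pixels x 3 channels = 24 contiguous output bytes per
// call. The three ComputeLerpx8Tmpl instantiations below enumerate the 24
// (pixel, channel) pairs in memory order, eight at a time, so each result
// can be written with a single 8-byte NEON store.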
template <int RES>
inline void OutputLerp8x8x3(const InterpolationCache<int16>& xs,
                            const int64_t x_start, const int16_t ys_ilerp,
                            const float min, const float max,
                            const quint8* const ys_input_lower_ptr,
                            const quint8* const ys_input_upper_ptr,
                            quint8* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int16x8_t y_lerpsx = vmovq_n_s16(ys_ilerp);

  const uint8x8_t x0c0x2c1 =
      ComputeLerpx8Tmpl<RES, 0, 0, 0, 1, 0, 2, 1, 0, 1, 1, 1, 2, 2, 0, 2, 1>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3), x0c0x2c1);

  const uint8x8_t x2c2x5c0 =
      ComputeLerpx8Tmpl<RES, 2, 2, 3, 0, 3, 1, 3, 2, 4, 0, 4, 1, 4, 2, 5, 0>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3 + 8), x2c2x5c0);

  const uint8x8_t x5c1x7c2 =
      ComputeLerpx8Tmpl<RES, 5, 1, 5, 2, 6, 0, 6, 1, 6, 2, 7, 0, 7, 1, 7, 2>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3 + 16),
          x5c1x7c2);

#else
  for (int x = x_start; x < x_start + 8; ++x) {
    OutputLerpForChannels<RES, quint8, int16, int16>(
        xs, x, ys_ilerp, 3, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

template <int RESOLUTION>
inline void OutputLerp32x4x1(const InterpolationCache<int32>& xs,
                             const int64_t x_start, const int32_t ys_ilerp,
                             const float min, const float max,
                             const qint32* const ys_input_lower_ptr,
                             const qint32* const ys_input_upper_ptr,
                             qint32* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int64 xs_lower0 = xs.lower[x_start];
  const int64 xs_upper0 = xs.upper[x_start];
  const int32* const xs_ilerp0 = &xs.ilerp[x_start];
  const int64 xs_lower1 = xs.lower[x_start + 1];
  const int64 xs_upper1 = xs.upper[x_start + 1];
  const int64 xs_lower2 = xs.lower[x_start + 2];
  const int64 xs_upper2 = xs.upper[x_start + 2];
  const int32* const xs_ilerp2 = &xs.ilerp[x_start + 2];
  const int64 xs_lower3 = xs.lower[x_start + 3];
  const int64 xs_upper3 = xs.upper[x_start + 3];

  const int32x2_t y_lerpsx = vmov_n_s32(ys_ilerp);

  const int32x2_t x0x1 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower0, ys_input_lower_ptr + xs_upper0,
      ys_input_upper_ptr + xs_lower0, ys_input_upper_ptr + xs_upper0,
      ys_input_lower_ptr + xs_lower1, ys_input_lower_ptr + xs_upper1,
      ys_input_upper_ptr + xs_lower1, ys_input_upper_ptr + xs_upper1, xs_ilerp0,
      y_lerpsx);

  const int32x2_t x2x3 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower2, ys_input_lower_ptr + xs_upper2,
      ys_input_upper_ptr + xs_lower2, ys_input_upper_ptr + xs_upper2,
      ys_input_lower_ptr + xs_lower3, ys_input_lower_ptr + xs_upper3,
      ys_input_upper_ptr + xs_lower3, ys_input_upper_ptr + xs_upper3, xs_ilerp2,
      y_lerpsx);

  const int32x4_t x0x1x2x3 = vcombine_s32(x0x1, x2x3);

  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start), x0x1x2x3);

#else
  for (int x = x_start; x < x_start + 4; ++x) {
    OutputLerpForChannels<RESOLUTION, qint32, int32, int64_t>(
        xs, x, ys_ilerp, 1, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

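// The qint32 analogue of OutputLerp8x8x3: 4 output pixels x 3 channels =
// 12 contiguous int32 values per call, computed two lanes at a time and
// written with three 4-lane NEON stores. X_LERP_SAME is true whenever both
// lanes of a ComputeLerpx2 call belong to the same output pixel.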
template <int RESOLUTION>
inline void OutputLerp32x4x3(const InterpolationCache<int32>& xs,
                             const int64_t x_start, const int32_t ys_ilerp,
                             const float min, const float max,
                             const qint32* const ys_input_lower_ptr,
                             const qint32* const ys_input_upper_ptr,
                             qint32* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int64 xs_lower0 = xs.lower[x_start];
  const int64 xs_upper0 = xs.upper[x_start];
  const int32* const xs_ilerp0 = &xs.ilerp[x_start];
  const int64 xs_lower1 = xs.lower[x_start + 1];
  const int64 xs_upper1 = xs.upper[x_start + 1];
  const int32* const xs_ilerp1 = &xs.ilerp[x_start + 1];
  const int64 xs_lower2 = xs.lower[x_start + 2];
  const int64 xs_upper2 = xs.upper[x_start + 2];
  const int32* const xs_ilerp2 = &xs.ilerp[x_start + 2];
  const int64 xs_lower3 = xs.lower[x_start + 3];
  const int64 xs_upper3 = xs.upper[x_start + 3];
  const int32* const xs_ilerp3 = &xs.ilerp[x_start + 3];

  const int32x2_t y_lerpsx = vmov_n_s32(ys_ilerp);

  const int32x2_t x0c0x0c1 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower0, ys_input_lower_ptr + xs_upper0,
      ys_input_upper_ptr + xs_lower0, ys_input_upper_ptr + xs_upper0,
      ys_input_lower_ptr + xs_lower0 + 1, ys_input_lower_ptr + xs_upper0 + 1,
      ys_input_upper_ptr + xs_lower0 + 1, ys_input_upper_ptr + xs_upper0 + 1,
      xs_ilerp0, y_lerpsx);

  const int32x2_t x0c2x1c0 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower0 + 2, ys_input_lower_ptr + xs_upper0 + 2,
      ys_input_upper_ptr + xs_lower0 + 2, ys_input_upper_ptr + xs_upper0 + 2,
      ys_input_lower_ptr + xs_lower1, ys_input_lower_ptr + xs_upper1,
      ys_input_upper_ptr + xs_lower1, ys_input_upper_ptr + xs_upper1, xs_ilerp0,
      y_lerpsx);

  const int32x2_t x1c1x1c2 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower1 + 1, ys_input_lower_ptr + xs_upper1 + 1,
      ys_input_upper_ptr + xs_lower1 + 1, ys_input_upper_ptr + xs_upper1 + 1,
      ys_input_lower_ptr + xs_lower1 + 2, ys_input_lower_ptr + xs_upper1 + 2,
      ys_input_upper_ptr + xs_lower1 + 2, ys_input_upper_ptr + xs_upper1 + 2,
      xs_ilerp1, y_lerpsx);

  const int32x2_t x2c0x2c1 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower2, ys_input_lower_ptr + xs_upper2,
      ys_input_upper_ptr + xs_lower2, ys_input_upper_ptr + xs_upper2,
      ys_input_lower_ptr + xs_lower2 + 1, ys_input_lower_ptr + xs_upper2 + 1,
      ys_input_upper_ptr + xs_lower2 + 1, ys_input_upper_ptr + xs_upper2 + 1,
      xs_ilerp2, y_lerpsx);

  const int32x2_t x2c2x3c0 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower2 + 2, ys_input_lower_ptr + xs_upper2 + 2,
      ys_input_upper_ptr + xs_lower2 + 2, ys_input_upper_ptr + xs_upper2 + 2,
      ys_input_lower_ptr + xs_lower3, ys_input_lower_ptr + xs_upper3,
      ys_input_upper_ptr + xs_lower3, ys_input_upper_ptr + xs_upper3, xs_ilerp2,
      y_lerpsx);

  const int32x2_t x3c1x3c2 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower3 + 1, ys_input_lower_ptr + xs_upper3 + 1,
      ys_input_upper_ptr + xs_lower3 + 1, ys_input_upper_ptr + xs_upper3 + 1,
      ys_input_lower_ptr + xs_lower3 + 2, ys_input_lower_ptr + xs_upper3 + 2,
      ys_input_upper_ptr + xs_lower3 + 2, ys_input_upper_ptr + xs_upper3 + 2,
      xs_ilerp3, y_lerpsx);

  const int32x4_t x0c0x0c1x0c2x1c0 = vcombine_s32(x0c0x0c1, x0c2x1c0);
  const int32x4_t x1c1x1c2x2c0x2c1 = vcombine_s32(x1c1x1c2, x2c0x2c1);
  const int32x4_t x2c2x3c0x3c1x3c2 = vcombine_s32(x2c2x3c0, x3c1x3c2);

  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3),
            x0c0x0c1x0c2x1c0);
  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3 + 4),
            x1c1x1c2x2c0x2c1);
  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3 + 8),
            x2c2x3c0x3c1x3c2);

#else
  for (int x = x_start; x < x_start + 4; ++x) {
    OutputLerpForChannels<RESOLUTION, qint32, int32, int64_t>(
        xs, x, ys_ilerp, 3, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

template <typename T>
void ResizeImageReference(typename TTypes<T, 4>::ConstTensor images,
                          const int batch_size, const int64_t in_height,
                          const int64_t in_width, const int64_t out_height,
                          const int64_t out_width, const int channels,
                          const float height_scale, const float width_scale,
                          const float in_min, const float in_max,
                          const bool half_pixel_centers,
                          typename TTypes<T, 4>::Tensor* output) {
  CHECK_NOTNULL(output);

  const InterpolationCache<float> xs = BuildLerpCache<float>(
      out_width, in_width, width_scale, channels, 0, half_pixel_centers);
  const InterpolationCache<float> ys = BuildLerpCache<float>(
      out_height, in_height, height_scale, 1, 0, half_pixel_centers);

  const int64_t in_row_size = in_width * channels;
  const int64_t in_batch_num_values = in_height * in_row_size;
  const int64_t out_row_size = out_width * channels;

  const T* input_b_ptr = images.data();

  T* output_y_ptr = output->data();
  for (int b = 0; b < batch_size; ++b) {
    for (int64_t y = 0; y < out_height; ++y) {
      const T* ys_input_lower_ptr = input_b_ptr + ys.lower[y] * in_row_size;
      const T* ys_input_upper_ptr = input_b_ptr + ys.upper[y] * in_row_size;
      const float ys_lerp = ys.lerp[y];
      for (int64_t x = 0; x < out_width; ++x) {
        const int64_t xs_lower = xs.lower[x];
        const int64_t xs_upper = xs.upper[x];
        const float xs_lerp = xs.lerp[x];
        for (int c = 0; c < channels; ++c) {
          const T top_left = ys_input_lower_ptr[xs_lower + c];
          const T top_right = ys_input_lower_ptr[xs_upper + c];
          const T bottom_left = ys_input_upper_ptr[xs_lower + c];
          const T bottom_right = ys_input_upper_ptr[xs_upper + c];
          const T val = ComputeLerpReference<T>(
              top_left, top_right, bottom_left, bottom_right, xs_lerp, ys_lerp,
              in_min, in_max);
          output_y_ptr[x * channels + c] = val;
        }
      }
      output_y_ptr += out_row_size;
    }
    input_b_ptr += in_batch_num_values;
  }
}

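// Generic fallback: types without a specialization below are resized via
// the float reference implementation.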
template <typename T>
void ResizeImage(typename TTypes<T, 4>::ConstTensor images,
                 const int batch_size, const int64_t in_height,
                 const int64_t in_width, const int64_t out_height,
                 const int64_t out_width, const int channels,
                 const float height_scale, const float width_scale,
                 const float in_min, const float in_max,
                 const bool half_pixel_centers,
                 typename TTypes<T, 4>::Tensor* output) {
  ResizeImageReference<T>(images, batch_size, in_height, in_width, out_height,
                          out_width, channels, height_scale, width_scale,
                          in_min, in_max, half_pixel_centers, output);
}

template <>
void ResizeImage<qint32>(typename TTypes<qint32, 4>::ConstTensor images,
                         const int batch_size, const int64_t in_height,
                         const int64_t in_width, const int64_t out_height,
                         const int64_t out_width, const int channels,
                         const float height_scale, const float width_scale,
                         const float in_min, const float in_max,
                         const bool half_pixel_centers,
                         typename TTypes<qint32, 4>::Tensor* output) {
  // 30 is the maximum resolution: with more fractional bits, the scaled
  // interpolation weights (and 1 << RESOLUTION) would overflow a signed
  // 32-bit int.
  constexpr int RESOLUTION = 30;
  constexpr int SIMD_STEP = 4;

  CHECK_NOTNULL(output);

  const InterpolationCache<int32> xs =
      BuildLerpCache<int32>(out_width, in_width, width_scale, channels,
                            RESOLUTION, half_pixel_centers);
  const InterpolationCache<int32> ys = BuildLerpCache<int32>(
      out_height, in_height, height_scale, 1, RESOLUTION, half_pixel_centers);

  const int64_t in_row_size = in_width * channels;
  const int64_t in_batch_num_values = in_height * in_row_size;
  const int64_t out_row_size = out_width * channels;

  const qint32* input_b_ptr = images.data();

  qint32* output_y_ptr = output->data();

  for (int b = 0; b < batch_size; ++b) {
    for (int64_t y = 0; y < out_height; ++y) {
      const qint32* ys_input_lower_ptr =
          input_b_ptr + ys.lower[y] * in_row_size;
      const qint32* ys_input_upper_ptr =
          input_b_ptr + ys.upper[y] * in_row_size;
      const int32_t ys_ilerp = ys.ilerp[y];
      // Optimized for channels == 1 and channels == 3, which are the most
      // common cases.
      int64_t x = 0;
      if (channels == 1) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp32x4x1<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                       ys_input_lower_ptr, ys_input_upper_ptr,
                                       output_y_ptr);
        }
      } else if (channels == 3) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp32x4x3<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                       ys_input_lower_ptr, ys_input_upper_ptr,
                                       output_y_ptr);
        }
      }
      for (; x < out_width; ++x) {
        OutputLerpForChannels<RESOLUTION, qint32, int32, int64_t>(
            xs, x, ys_ilerp, channels, in_min, in_max, ys_input_lower_ptr,
            ys_input_upper_ptr, output_y_ptr);
      }
      output_y_ptr += out_row_size;
    }
    input_b_ptr += in_batch_num_values;
  }
}

template <>
void ResizeImage<quint8>(typename TTypes<quint8, 4>::ConstTensor images,
                         const int batch_size, const int64_t in_height,
                         const int64_t in_width, const int64_t out_height,
                         const int64_t out_width, const int channels,
                         const float height_scale, const float width_scale,
                         const float in_min, const float in_max,
                         const bool half_pixel_centers,
                         typename TTypes<quint8, 4>::Tensor* output) {
  // 7 is the maximum resolution for an unsigned byte: an 8-bit pixel value
  // scaled by 1 << 7 still fits in a signed 16-bit int (255 * 128 <= 32767).
  constexpr int RESOLUTION = 7;
  constexpr int SIMD_STEP = 8;

  CHECK_NOTNULL(output);

  const InterpolationCache<int16> xs =
      BuildLerpCache<int16>(out_width, in_width, width_scale, channels,
                            RESOLUTION, half_pixel_centers);
  const InterpolationCache<int16> ys = BuildLerpCache<int16>(
      out_height, in_height, height_scale, 1, RESOLUTION, half_pixel_centers);

  const int64_t in_row_size = in_width * channels;
  const int64_t in_batch_num_values = in_height * in_row_size;
  const int64_t out_row_size = out_width * channels;

  const quint8* input_b_ptr = images.data();

  quint8* output_y_ptr = output->data();

  for (int b = 0; b < batch_size; ++b) {
    for (int64_t y = 0; y < out_height; ++y) {
      const quint8* ys_input_lower_ptr =
          input_b_ptr + ys.lower[y] * in_row_size;
      const quint8* ys_input_upper_ptr =
          input_b_ptr + ys.upper[y] * in_row_size;
      const int32_t ys_ilerp = ys.ilerp[y];
      // Optimized for channels == 1 and channels == 3, which are the most
      // common cases.
      // TODO(satok): Support a more generic NEON-optimized implementation
      // for other channel counts.
      int64_t x = 0;
      if (channels == 1) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp8x8x1<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                      ys_input_lower_ptr, ys_input_upper_ptr,
                                      output_y_ptr);
        }
      } else if (channels == 3) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp8x8x3<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                      ys_input_lower_ptr, ys_input_upper_ptr,
                                      output_y_ptr);
        }
      }
      for (; x < out_width; ++x) {
        OutputLerpForChannels<RESOLUTION, quint8, int16, int16>(
            xs, x, ys_ilerp, channels, in_min, in_max, ys_input_lower_ptr,
            ys_input_upper_ptr, output_y_ptr);
      }
      output_y_ptr += out_row_size;
    }
    input_b_ptr += in_batch_num_values;
  }
}

template <typename T>
void ResizeBilinear(const typename TTypes<T, 4>::ConstTensor& images,
                    const float height_scale, const float width_scale,
                    const float in_min, const float in_max,
                    const bool half_pixel_centers,
                    typename TTypes<T, 4>::Tensor* output) {
  CHECK_NOTNULL(output);

  const int batch_size = images.dimension(0);
  const int64_t in_height = images.dimension(1);
  const int64_t in_width = images.dimension(2);
  const int channels = images.dimension(3);

  const int64_t out_height = output->dimension(1);
  const int64_t out_width = output->dimension(2);

  // Handle no-op resizes efficiently.
  if (out_height == in_height && out_width == in_width) {
    *output = images.template cast<T>();
    return;
  }

  if (USE_REFERENCE) {
    ResizeImageReference<T>(images, batch_size, in_height, in_width, out_height,
                            out_width, channels, height_scale, width_scale,
                            in_min, in_max, half_pixel_centers, output);
  } else {
    ResizeImage<T>(images, batch_size, in_height, in_width, out_height,
                   out_width, channels, height_scale, width_scale, in_min,
                   in_max, half_pixel_centers, output);
  }
}

}  // namespace

template <class T>
class QuantizedResizeBilinearOp : public OpKernel {
 public:
  explicit QuantizedResizeBilinearOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
    OP_REQUIRES_OK(
        context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
  }

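  // Inputs: 0 = images, 1 = size (consumed by ImageResizerState), 2 = min,
  // 3 = max. Outputs: 0 = resized images, 1 = out_min, 2 = out_max. The
  // quantization range is passed through unchanged.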
  void Compute(OpKernelContext* context) override {
    const auto& in_min_tensor = context->input(2);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(in_min_tensor.shape()),
                errors::InvalidArgument("min must be a scalar"));
    const float in_min = in_min_tensor.flat<float>()(0);
    const auto& in_max_tensor = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(in_max_tensor.shape()),
                errors::InvalidArgument("max must be a scalar"));
    const float in_max = in_max_tensor.flat<float>()(0);

    ImageResizerState st(align_corners_, false);
    st.ValidateAndCreateOutput(context);

    if (!context->status().ok()) return;

    // Return if the output is empty.
    if (st.output->NumElements() == 0) return;

    typename TTypes<T, 4>::ConstTensor image_data(
        context->input(0).tensor<T, 4>());
    typename TTypes<T, 4>::Tensor output_data(st.output->tensor<T, 4>());

    ResizeBilinear<T>(image_data, st.height_scale, st.width_scale, in_min,
                      in_max, half_pixel_centers_, &output_data);
    Tensor* out_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &out_min));
    out_min->flat<float>()(0) = in_min;

    Tensor* out_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &out_max));
    out_max->flat<float>()(0) = in_max;
  }

 private:
  bool align_corners_;
  bool half_pixel_centers_;

  TF_DISALLOW_COPY_AND_ASSIGN(QuantizedResizeBilinearOp<T>);
};

#define REGISTER_CPU_KERNEL(type)                         \
  REGISTER_KERNEL_BUILDER(Name("QuantizedResizeBilinear") \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("size")         \
                              .TypeConstraint<type>("T"), \
                          QuantizedResizeBilinearOp<type>)

REGISTER_CPU_KERNEL(::tensorflow::quint8);
REGISTER_CPU_KERNEL(::tensorflow::qint32);
REGISTER_CPU_KERNEL(float);

}  // namespace tensorflow