xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/image/resize_bilinear_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/image_ops.cc
17 #define EIGEN_USE_THREADS
18 
19 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
20 #define EIGEN_USE_GPU
21 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
22 
23 #include "tensorflow/core/kernels/image/resize_bilinear_op.h"
24 
25 #ifdef __SSE4_1__
26 #include <xmmintrin.h>
27 #endif
28 
29 #include <memory>
30 
31 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
32 #include "tensorflow/core/framework/op_kernel.h"
33 #include "tensorflow/core/framework/register_types.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/framework/types.h"
37 #include "tensorflow/core/kernels/cast_op.h"
38 #include "tensorflow/core/lib/core/status.h"
39 #include "tensorflow/core/platform/logging.h"
40 #include "tensorflow/core/util/image_resizer_state.h"
41 
42 namespace tensorflow {
43 
44 typedef Eigen::ThreadPoolDevice CPUDevice;
45 typedef Eigen::GpuDevice GPUDevice;
46 
47 template <typename Device, typename T>
48 class ResizeBilinearOp : public OpKernel {
49  public:
ResizeBilinearOp(OpKernelConstruction * context)50   explicit ResizeBilinearOp(OpKernelConstruction* context) : OpKernel(context) {
51     OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
52     OP_REQUIRES_OK(
53         context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
54   }
55 
Compute(OpKernelContext * context)56   void Compute(OpKernelContext* context) override {
57     ImageResizerState st(align_corners_, half_pixel_centers_);
58     st.ValidateAndCreateOutput(context);
59 
60     if (!context->status().ok()) return;
61 
62     // Return if the output is empty.
63     if (st.output->NumElements() == 0) return;
64 
65     typename TTypes<T, 4>::ConstTensor image_data(
66         context->input(0).tensor<T, 4>());
67     TTypes<float, 4>::Tensor output_data = st.output->tensor<float, 4>();
68 
69     functor::ResizeBilinear<Device, T>()(
70         context->eigen_device<Device>(), image_data, st.height_scale,
71         st.width_scale, half_pixel_centers_, output_data);
72   }
73 
74  private:
75   bool align_corners_;
76   bool half_pixel_centers_;
77 };
78 
79 namespace {
// Compute the interpolation indices only once.
struct CachedInterpolation {
  int64_t lower;  // Lower source index used in the interpolation
  int64_t upper;  // Upper source index used in the interpolation
  // 1-D linear interpolation scale (see:
  // https://en.wikipedia.org/wiki/Bilinear_interpolation)
  float lerp;
};

// Precomputes, for every output coordinate along one axis, the two source
// indices to blend and the blend weight.
//
// Args:
//   scaler:        callable mapping (output index, scale) to a fractional
//                  source coordinate.
//   out_size:      number of output coordinates along this axis.
//   in_size:       number of source coordinates along this axis.
//   scale:         scale factor forwarded to `scaler`.
//   interpolation: array with room for `out_size + 1` entries. Entry
//                  `out_size` is a sentinel whose lower/upper are set to 0
//                  (its `lerp` is left untouched).
//
// Both `lower` and `upper` are clamped to [0, in_size - 1]. The indices are
// later used for unchecked pointer arithmetic, so a scaler result outside the
// source range (for example through float rounding on very large sizes) must
// not be allowed to produce an out-of-range index. For in-range coordinates
// the result is unchanged by the clamping.
template <typename Scaler>
inline void compute_interpolation_weights(const Scaler scaler,
                                          const int64_t out_size,
                                          const int64_t in_size,
                                          const float scale,
                                          CachedInterpolation* interpolation) {
  interpolation[out_size].lower = 0;
  interpolation[out_size].upper = 0;

  const int64_t first_index = 0;
  // Guard against in_size == 0 so the clamp range never inverts.
  const int64_t last_index = std::max(in_size - 1, first_index);

  for (int64_t i = out_size - 1; i >= 0; --i) {
    const float in = scaler(i, scale);
    const float in_f = std::floor(in);
    const int64_t floor_index = static_cast<int64_t>(in_f);
    const int64_t ceil_index = static_cast<int64_t>(std::ceil(in));
    interpolation[i].lower =
        std::min(std::max(floor_index, first_index), last_index);
    interpolation[i].upper =
        std::min(std::max(ceil_index, first_index), last_index);
    interpolation[i].lerp = in - in_f;
  }
}
107 
/**
 * Bilinear blend of four corner samples.
 *
 * The top pair and the bottom pair are each blended horizontally with
 * `x_lerp`; the two results are then blended vertically with `y_lerp`.
 * A weight of 0 selects the left/top sample, a weight of 1 the right/bottom.
 */
inline float compute_lerp(const float top_left, const float top_right,
                          const float bottom_left, const float bottom_right,
                          const float x_lerp, const float y_lerp) {
  // 1-D linear interpolation: `from` at weight 0, `to` at weight 1.
  const auto mix = [](const float from, const float to, const float weight) {
    return from + (to - from) * weight;
  };
  const float top_row = mix(top_left, top_right, x_lerp);
  const float bottom_row = mix(bottom_left, bottom_right, x_lerp);
  return mix(top_row, bottom_row, y_lerp);
}
119 
120 #ifdef __SSE4_1__
/* Four-lane SSE counterpart of compute_lerp(): every lane is blended
 * independently using the same formula (horizontal blend of the top and
 * bottom pairs with x_lerp, then a vertical blend with y_lerp). */
inline __m128 compute_lerp_v(const __m128 top_left, const __m128 top_right,
                             const __m128 bottom_left,
                             const __m128 bottom_right, const __m128 x_lerp,
                             const __m128 y_lerp) {
  const __m128 top_delta = _mm_sub_ps(top_right, top_left);
  const __m128 top = _mm_add_ps(top_left, _mm_mul_ps(top_delta, x_lerp));

  const __m128 bottom_delta = _mm_sub_ps(bottom_right, bottom_left);
  const __m128 bottom =
      _mm_add_ps(bottom_left, _mm_mul_ps(bottom_delta, x_lerp));

  const __m128 vertical_delta = _mm_sub_ps(bottom, top);
  return _mm_add_ps(top, _mm_mul_ps(vertical_delta, y_lerp));
}
132 #endif
133 
134 template <typename T>
ResizeLineChannels(const T * const ys_input_lower_ptr,const T * const ys_input_upper_ptr,const CachedInterpolation * const xs,const float ys_lerp,const int64_t out_width,float * out_y,const int channels)135 void ResizeLineChannels(const T* const ys_input_lower_ptr,
136                         const T* const ys_input_upper_ptr,
137                         const CachedInterpolation* const xs,
138                         const float ys_lerp, const int64_t out_width,
139                         float* out_y, const int channels) {
140   for (int64_t x = 0; x < out_width; ++x) {
141     const int64_t xs_lower = xs[x].lower;
142     const int64_t xs_upper = xs[x].upper;
143     const float xs_lerp = xs[x].lerp;
144 
145     for (int c = 0; c < channels; ++c) {
146       const float top_left(ys_input_lower_ptr[xs_lower + c]);
147       const float top_right(ys_input_lower_ptr[xs_upper + c]);
148       const float bottom_left(ys_input_upper_ptr[xs_lower + c]);
149       const float bottom_right(ys_input_upper_ptr[xs_upper + c]);
150 
151       out_y[x * channels + c] = compute_lerp(top_left, top_right, bottom_left,
152                                              bottom_right, xs_lerp, ys_lerp);
153     }
154   }
155 }
156 
157 #ifdef __SSE4_1__
158 
// Packs values[0..2] into lanes 0..2 of an SSE register, converting each
// element to float individually; lane 3 is set to 0.
template <typename T>
inline __m128 load_3xfloat_v(T* values) {
  const float lane0 = static_cast<float>(values[0]);
  const float lane1 = static_cast<float>(values[1]);
  const float lane2 = static_cast<float>(values[2]);
  // _mm_set_ps takes its arguments from the highest lane down to lane 0.
  return _mm_set_ps(0.0f, lane2, lane1, lane0);
}

// Specialization for `float*`: a single unaligned load of four floats. The
// buffer must therefore hold at least 4 readable floats, and lane 3 carries
// whatever value follows the three requested ones.
// NOTE(review): the call sites in this file pass `const T*`, for which T is
// deduced as `const float` rather than `float`, so this specialization does
// not appear to be selected for them. Before making it match const pointers,
// confirm that reading a fourth float past the last input pixel is safe.
template <>
inline __m128 load_3xfloat_v(float* values) {
  return _mm_loadu_ps(values);
}
172 
173 template <typename T>
ResizeLine3ChannelsVector(const T * const ys_input_lower_ptr,const T * const ys_input_upper_ptr,const CachedInterpolation * const xs,const float ys_lerp,const int64_t out_width,float * out_y)174 void ResizeLine3ChannelsVector(const T* const ys_input_lower_ptr,
175                                const T* const ys_input_upper_ptr,
176                                const CachedInterpolation* const xs,
177                                const float ys_lerp, const int64_t out_width,
178                                float* out_y) {
179   const __m128 ys_lerp_v = _mm_set1_ps(ys_lerp);
180   // All pixels but the last one can overflow, vectorize the inside of the
181   // row.
182   int64_t x = 0;
183   for (x = 0; x < out_width - 1; ++x) {
184     const int64_t xs_lower = xs[x].lower;
185     const int64_t xs_upper = xs[x].upper;
186     const __m128 xs_lerp_v = _mm_set1_ps(xs[x].lerp);
187 
188     const __m128 top_left_v = load_3xfloat_v(ys_input_lower_ptr + xs_lower);
189     const __m128 top_right_v = load_3xfloat_v(ys_input_lower_ptr + xs_upper);
190     const __m128 bottom_left_v = load_3xfloat_v(ys_input_upper_ptr + xs_lower);
191     const __m128 bottom_right_v = load_3xfloat_v(ys_input_upper_ptr + xs_upper);
192 
193     _mm_storeu_ps(out_y + x * 3,
194                   compute_lerp_v(top_left_v, top_right_v, bottom_left_v,
195                                  bottom_right_v, xs_lerp_v, ys_lerp_v));
196   }
197   // The last pixel of each row must be done in a non-vectorized way
198   // because we cannot overflow.
199   ResizeLineChannels(ys_input_lower_ptr, ys_input_upper_ptr, xs + out_width - 1,
200                      ys_lerp, 1, out_y + (out_width - 1) * 3, 3);
201 }
202 #endif
203 
204 template <typename T>
205 void resize_image(
206     typename TTypes<T, 4>::ConstTensor images, const int batch_size,
207     const int64_t in_height, const int64_t in_width, const int64_t out_height,
208     const int64_t out_width, const int channels,
209     const std::vector<CachedInterpolation>& xs,
210     const std::vector<CachedInterpolation>& ys,
211     typename TTypes<float, 4>::Tensor output) TF_ATTRIBUTE_NOINLINE;
212 template <typename T>
resize_image(typename TTypes<T,4>::ConstTensor images,const int batch_size,const int64_t in_height,const int64_t in_width,const int64_t out_height,const int64_t out_width,const int channels,const std::vector<CachedInterpolation> & xs_vec,const std::vector<CachedInterpolation> & ys,typename TTypes<float,4>::Tensor output)213 void resize_image(typename TTypes<T, 4>::ConstTensor images,
214                   const int batch_size, const int64_t in_height,
215                   const int64_t in_width, const int64_t out_height,
216                   const int64_t out_width, const int channels,
217                   const std::vector<CachedInterpolation>& xs_vec,
218                   const std::vector<CachedInterpolation>& ys,
219                   typename TTypes<float, 4>::Tensor output) {
220   const int64_t in_row_size = in_width * channels;
221   const int64_t in_batch_num_values = in_height * in_row_size;
222   const int64_t out_row_size = out_width * channels;
223 
224   const T* input_b_ptr = images.data();
225   const CachedInterpolation* xs = xs_vec.data();
226 
227   if (channels == 3) {
228     float* output_y_ptr = output.data();
229     for (int b = 0; b < batch_size; ++b) {
230       for (int64_t y = 0; y < out_height; ++y) {
231         const T* ys_input_lower_ptr = input_b_ptr + ys[y].lower * in_row_size;
232         const T* ys_input_upper_ptr = input_b_ptr + ys[y].upper * in_row_size;
233 #ifdef __SSE4_1__
234         ResizeLine3ChannelsVector(ys_input_lower_ptr, ys_input_upper_ptr, xs,
235                                   ys[y].lerp, out_width, output_y_ptr);
236 #else
237         ResizeLineChannels(ys_input_lower_ptr, ys_input_upper_ptr, xs,
238                            ys[y].lerp, out_width, output_y_ptr, 3);
239 #endif
240         output_y_ptr += out_row_size;
241       }
242       input_b_ptr += in_batch_num_values;
243     }
244   } else {
245     float* output_y_ptr = output.data();
246     for (int b = 0; b < batch_size; ++b) {
247       for (int64_t y = 0; y < out_height; ++y) {
248         const T* ys_input_lower_ptr = input_b_ptr + ys[y].lower * in_row_size;
249         const T* ys_input_upper_ptr = input_b_ptr + ys[y].upper * in_row_size;
250 
251         ResizeLineChannels(ys_input_lower_ptr, ys_input_upper_ptr, xs,
252                            ys[y].lerp, out_width, output_y_ptr, channels);
253 
254         output_y_ptr += out_row_size;
255       }
256       input_b_ptr += in_batch_num_values;
257     }
258   }
259 }
260 
261 // Casts from float16 to T.
262 template <typename Device, typename T>
263 struct CastFloatTo {
operator ()tensorflow::__anon739419410111::CastFloatTo264   void operator()(const Device& d, typename TTypes<float>::ConstFlat input,
265                   typename TTypes<T>::Flat output) {
266     output.device(d) = input.template cast<T>();
267   }
268 };
269 
270 template <typename T>
271 struct CastFloatTo<GPUDevice, T> {
operator ()tensorflow::__anon739419410111::CastFloatTo272   void operator()(const GPUDevice& d, typename TTypes<float>::ConstFlat input,
273                   typename TTypes<T>::Flat output) {
274     // Use existing cast functor instead of directly casting Eigen tensor, as
275     // otherwise we need to instantiate the cast function in a .cu.cc file
276     functor::CastFunctor<GPUDevice, T, float> cast;
277     cast(d, output, input);
278   }
279 };
280 
281 }  // namespace
282 
283 // Partial specialization of ResizeBilinear functor for a CPUDevice.
284 namespace functor {
285 template <typename T>
286 struct ResizeBilinear<CPUDevice, T> {
operator ()tensorflow::functor::ResizeBilinear287   void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor images,
288                   const float height_scale, const float width_scale,
289                   bool half_pixel_centers,
290                   typename TTypes<float, 4>::Tensor output) {
291     const int batch_size = images.dimension(0);
292     const int64_t in_height = images.dimension(1);
293     const int64_t in_width = images.dimension(2);
294     const int channels = images.dimension(3);
295 
296     const int64_t out_height = output.dimension(1);
297     const int64_t out_width = output.dimension(2);
298 
299     // Handle no-op resizes efficiently.
300     if (out_height == in_height && out_width == in_width) {
301       output = images.template cast<float>();
302       return;
303     }
304 
305     std::vector<CachedInterpolation> ys(out_height + 1);
306     std::vector<CachedInterpolation> xs(out_width + 1);
307 
308     // Compute the cached interpolation weights on the x and y dimensions.
309     if (half_pixel_centers) {
310       compute_interpolation_weights(HalfPixelScaler(), out_height, in_height,
311                                     height_scale, ys.data());
312       compute_interpolation_weights(HalfPixelScaler(), out_width, in_width,
313                                     width_scale, xs.data());
314 
315     } else {
316       compute_interpolation_weights(LegacyScaler(), out_height, in_height,
317                                     height_scale, ys.data());
318       compute_interpolation_weights(LegacyScaler(), out_width, in_width,
319                                     width_scale, xs.data());
320     }
321     // Scale x interpolation weights to avoid a multiplication during iteration.
322     for (int i = 0; i < xs.size(); ++i) {
323       xs[i].lower *= channels;
324       xs[i].upper *= channels;
325     }
326 
327     resize_image<T>(images, batch_size, in_height, in_width, out_height,
328                     out_width, channels, xs, ys, output);
329   }
330 };
331 }  // namespace functor
332 
333 template <typename Device, typename T>
334 class ResizeBilinearOpGrad : public OpKernel {
335  public:
ResizeBilinearOpGrad(OpKernelConstruction * context)336   explicit ResizeBilinearOpGrad(OpKernelConstruction* context)
337       : OpKernel(context) {
338     OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
339     OP_REQUIRES_OK(
340         context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
341   }
342 
Compute(OpKernelContext * context)343   void Compute(OpKernelContext* context) override {
344     // Validate input.
345     ImageResizerGradientState st(align_corners_, half_pixel_centers_);
346     st.ValidateAndCreateOutput(context);
347 
348     if (!context->status().ok()) return;
349 
350     // First argument is gradient with respect to resized image.
351     TTypes<float, 4>::ConstTensor input_grad =
352         context->input(0).tensor<float, 4>();
353 
354     if (!std::is_same<T, Eigen::half>::value &&
355         !std::is_same<T, Eigen::bfloat16>::value) {
356       typename TTypes<T, 4>::Tensor output_grad(st.output->tensor<T, 4>());
357       functor::ResizeBilinearGrad<Device, T>()(
358           context->eigen_device<Device>(), input_grad, st.height_scale,
359           st.width_scale, half_pixel_centers_, output_grad);
360     } else {
361       // Accumulate output to float instead of half/bfloat16 tensor, since float
362       // accumulation is more numerically stable and GPU half implementation is
363       // slow.
364       // TODO(b/165759037): Create optimized and numerically stable half and
365       // bfloat16 implementation
366       Tensor output_grad;
367       OP_REQUIRES_OK(context, context->allocate_temp(
368                                   DT_FLOAT, st.output->shape(), &output_grad));
369       functor::ResizeBilinearGrad<Device, float>()(
370           context->eigen_device<Device>(), input_grad, st.height_scale,
371           st.width_scale, half_pixel_centers_, output_grad.tensor<float, 4>());
372       const Tensor& output_grad_const = output_grad;
373       CastFloatTo<Device, T>{}(context->template eigen_device<Device>(),
374                                output_grad_const.template flat<float>(),
375                                st.output->template flat<T>());
376     }
377   }
378 
379  private:
380   bool align_corners_;
381   bool half_pixel_centers_;
382 };
383 
384 // Partial specialization of ResizeBilinearGrad functor for a CPUDevice.
385 namespace functor {
386 
387 template <typename T>
388 struct ResizeBilinearGrad<CPUDevice, T> {
389   template <typename Scaler>
ResizeGradCoretensorflow::functor::ResizeBilinearGrad390   void ResizeGradCore(const Scaler& scaler,
391                       typename TTypes<float, 4>::ConstTensor input_grad,
392                       const float height_scale, const float width_scale,
393                       typename TTypes<T, 4>::Tensor output_grad) {
394     const Eigen::Index batch = output_grad.dimension(0);
395     const Eigen::Index original_height = output_grad.dimension(1);
396     const Eigen::Index original_width = output_grad.dimension(2);
397     const Eigen::Index channels = output_grad.dimension(3);
398 
399     const Eigen::Index resized_height = input_grad.dimension(1);
400     const Eigen::Index resized_width = input_grad.dimension(2);
401 
402     output_grad.setZero();
403 
404     // Each resized output pixel was computed as a weighted average of four
405     // input pixels. Here we find the four input pixel locations that
406     // contributed to each output pixel and propagate the gradient at the output
407     // pixel location to each of those four input pixel locations in the same
408     // proportions that they originally contributed to the output pixel.
409     // Here is the forward-propagation pseudo-code, for reference:
410     // resized(b, y, x, c) = top_left     * (1 - y) * (1 - x)
411     //                     + top_right    * (1 - y) *      x
412     //                     + bottom_left  *      y  * (1 - x)
413     //                     + bottom_right *      y  *      x
414     for (Eigen::Index b = 0; b < batch; ++b) {
415       for (Eigen::Index y = 0; y < resized_height; ++y) {
416         const float in_y = scaler(y, height_scale);
417         const Eigen::Index top_y_index =
418             std::max(static_cast<Eigen::Index>(floorf(in_y)),
419                      static_cast<Eigen::Index>(0));
420         const Eigen::Index bottom_y_index = std::min(
421             static_cast<Eigen::Index>(ceilf(in_y)), original_height - 1);
422         const float y_lerp = in_y - floorf(in_y);
423         const float inverse_y_lerp = (1.0f - y_lerp);
424         for (Eigen::Index x = 0; x < resized_width; ++x) {
425           const float in_x = scaler(x, width_scale);
426           const Eigen::Index left_x_index =
427               std::max(static_cast<Eigen::Index>(floorf(in_x)),
428                        static_cast<Eigen::Index>(0));
429           const Eigen::Index right_x_index = std::min(
430               static_cast<Eigen::Index>(ceilf(in_x)), original_width - 1);
431           const float x_lerp = in_x - floorf(in_x);
432           const float inverse_x_lerp = (1.0f - x_lerp);
433           // TODO(b/158287314): Look into vectorizing this.
434           for (Eigen::Index c = 0; c < channels; ++c) {
435             output_grad(b, top_y_index, left_x_index, c) +=
436                 T(input_grad(b, y, x, c) * inverse_y_lerp * inverse_x_lerp);
437             output_grad(b, top_y_index, right_x_index, c) +=
438                 T(input_grad(b, y, x, c) * inverse_y_lerp * x_lerp);
439             output_grad(b, bottom_y_index, left_x_index, c) +=
440                 T(input_grad(b, y, x, c) * y_lerp * inverse_x_lerp);
441             output_grad(b, bottom_y_index, right_x_index, c) +=
442                 T(input_grad(b, y, x, c) * y_lerp * x_lerp);
443           }
444         }
445       }
446     }
447   }
operator ()tensorflow::functor::ResizeBilinearGrad448   void operator()(const CPUDevice& d,
449                   typename TTypes<float, 4>::ConstTensor input_grad,
450                   const float height_scale, const float width_scale,
451                   const bool half_pixel_centers,
452                   typename TTypes<T, 4>::Tensor output_grad) {
453     if (half_pixel_centers) {
454       return ResizeGradCore(HalfPixelScaler(), input_grad, height_scale,
455                             width_scale, output_grad);
456     } else {
457       return ResizeGradCore(LegacyScaler(), input_grad, height_scale,
458                             width_scale, output_grad);
459     }
460   }
461 };
462 
463 }  // namespace functor
464 
// CPU registration of the forward op: for each type T passed in by
// TF_CALL_REAL_NUMBER_TYPES, builds a "ResizeBilinear" kernel for DEVICE_CPU
// constrained on attribute "T", with the "size" input marked HostMemory, and
// binds it to ResizeBilinearOp<CPUDevice, T>. The helper macro is undefined
// again right after use.
465 #define REGISTER_KERNEL(T)                            \
466   REGISTER_KERNEL_BUILDER(Name("ResizeBilinear")      \
467                               .Device(DEVICE_CPU)     \
468                               .TypeConstraint<T>("T") \
469                               .HostMemory("size"),    \
470                           ResizeBilinearOp<CPUDevice, T>);
471 
472 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
473 
474 #undef REGISTER_KERNEL
475 
// CPU registration of the gradient op: binds "ResizeBilinearGrad" on
// DEVICE_CPU to ResizeBilinearOpGrad<CPUDevice, T> for the four types listed
// below (half, float, double, bfloat16). The helper macro is undefined again
// right after use.
476 #define REGISTER_GRAD_KERNEL(T)                                             \
477   REGISTER_KERNEL_BUILDER(                                                  \
478       Name("ResizeBilinearGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
479       ResizeBilinearOpGrad<CPUDevice, T>);
480 
481 TF_CALL_half(REGISTER_GRAD_KERNEL);
482 TF_CALL_float(REGISTER_GRAD_KERNEL);
483 TF_CALL_double(REGISTER_GRAD_KERNEL);
484 TF_CALL_bfloat16(REGISTER_GRAD_KERNEL);
485 
486 #undef REGISTER_GRAD_KERNEL
487 
// GPU registrations, compiled only when GOOGLE_CUDA or TENSORFLOW_USE_ROCM is
// set. Both the forward op ("ResizeBilinear", with the "size" input marked
// HostMemory) and the gradient op ("ResizeBilinearGrad") are registered on
// DEVICE_GPU for every type passed in by TF_CALL_GPU_NUMBER_TYPES, bound to
// the GPUDevice instantiations of the kernel classes defined in this file.
488 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
489 
490 #define REGISTER_KERNEL(T)                            \
491   REGISTER_KERNEL_BUILDER(Name("ResizeBilinear")      \
492                               .Device(DEVICE_GPU)     \
493                               .TypeConstraint<T>("T") \
494                               .HostMemory("size"),    \
495                           ResizeBilinearOp<GPUDevice, T>);
496 
497 TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
498 
499 #undef REGISTER_KERNEL
500 
501 #define REGISTER_GRAD_KERNEL(T)                                             \
502   REGISTER_KERNEL_BUILDER(                                                  \
503       Name("ResizeBilinearGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
504       ResizeBilinearOpGrad<GPUDevice, T>);
505 
506 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GRAD_KERNEL);
507 
508 #undef REGISTER_GRAD_KERNEL
509 
510 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
511 
512 }  // namespace tensorflow
513