xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/image/crop_and_resize_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/image_ops.cc
17 
18 #define EIGEN_USE_THREADS
19 
20 #include "tensorflow/core/kernels/image/crop_and_resize_op.h"
21 
22 #include <functional>
23 #include <string>
24 
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 #include "tensorflow/core/framework/bounds_check.h"
27 #include "tensorflow/core/framework/register_types.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/tensor_reference.h"
30 #include "tensorflow/core/framework/tensor_shape.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/status.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/core/util/determinism.h"
37 #include "tensorflow/core/util/work_sharder.h"
38 
39 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
40 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
41 #include "tensorflow/core/platform/stream_executor.h"
42 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
43 
44 #if GOOGLE_CUDA
45 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
46 using stream_executor::cuda::ScopedActivateExecutorContext;
47 #elif TENSORFLOW_USE_ROCM
48 #include "tensorflow/core/platform/rocm.h"
49 using stream_executor::rocm::ScopedActivateExecutorContext;
50 #endif
51 
52 namespace tensorflow {
53 namespace {
54 
55 typedef Eigen::ThreadPoolDevice CPUDevice;
56 typedef Eigen::GpuDevice GPUDevice;
57 using Callback = std::function<void()>;
58 
ParseAndCheckBoxSizes(const Tensor & boxes,const Tensor & box_index,int * num_boxes)59 static inline Status ParseAndCheckBoxSizes(const Tensor& boxes,
60                                            const Tensor& box_index,
61                                            int* num_boxes) {
62   if (boxes.NumElements() == 0 && box_index.NumElements() == 0) {
63     *num_boxes = 0;
64     return OkStatus();
65   }
66   // The shape of 'boxes' is [num_boxes, 4].
67   if (boxes.dims() != 2) {
68     return errors::InvalidArgument("boxes must be 2-D",
69                                    boxes.shape().DebugString());
70   }
71   *num_boxes = boxes.dim_size(0);
72   if (boxes.dim_size(1) != 4) {
73     return errors::InvalidArgument("boxes must have 4 columns");
74   }
75   // The shape of 'box_index' is [num_boxes].
76   if (box_index.dims() != 1) {
77     return errors::InvalidArgument("box_index must be 1-D",
78                                    box_index.shape().DebugString());
79   }
80   if (box_index.dim_size(0) != *num_boxes) {
81     return errors::InvalidArgument("box_index has incompatible shape");
82   }
83   return OkStatus();
84 }
85 
86 // Conditionally calls the compute callback if all values in box_index are in
87 // [0, batch_size) then calls done.
88 template <typename Device>
89 inline void RunIfBoxIndexIsValid(
90     OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
91     int batch_size, const Callback& compute, const Callback& done);
92 
93 // Specialization of CheckValidBoxIndex for a CPUDevice.
94 template <>
RunIfBoxIndexIsValid(OpKernelContext * context,typename TTypes<int32,1>::ConstTensor box_index,int batch_size,const Callback & compute,const Callback & done)95 inline void RunIfBoxIndexIsValid<CPUDevice>(
96     OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
97     int batch_size, const Callback& compute, const Callback& done) {
98   const int num_boxes = box_index.dimension(0);
99   for (int b = 0; b < num_boxes; ++b) {
100     OP_REQUIRES_ASYNC(
101         context, FastBoundsCheck(box_index(b), batch_size),
102         errors::OutOfRange("box_index has values outside [0, batch_size)"),
103         done);
104   }
105   if (compute) {
106     compute();
107   }
108   if (done) {
109     done();
110   }
111 }
112 
113 }  // namespace
114 
115 template <typename Device, typename T>
116 class CropAndResizeOp : public AsyncOpKernel {
117  public:
CropAndResizeOp(OpKernelConstruction * context)118   explicit CropAndResizeOp(OpKernelConstruction* context)
119       : AsyncOpKernel(context) {
120     OP_REQUIRES_OK(context, context->GetAttr("method", &method_));
121     OP_REQUIRES(context, method_ == "bilinear" || method_ == "nearest",
122                 errors::InvalidArgument(
123                     "method must be 'bilinear' or 'nearest'", method_));
124     OP_REQUIRES_OK(context, context->GetAttr("extrapolation_value",
125                                              &extrapolation_value_));
126   }
127 
ComputeAsync(OpKernelContext * context,DoneCallback done)128   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
129     // The shape of 'image' is [batch_size, image_height, image_width,
130     // channels].
131     const Tensor& image = context->input(0);
132     // The shape of 'boxes' is [num_boxes, 4].
133     const Tensor& boxes = context->input(1);
134     // The shape of 'box_index' is [num_boxes].
135     const Tensor& box_index = context->input(2);
136     // The shape of 'crop_size' is [2].
137     const Tensor& crop_size = context->input(3);
138 
139     // Validate inputs dimensions.
140     OP_REQUIRES_ASYNC(context, image.dims() == 4,
141                       errors::InvalidArgument("input image must be 4-D",
142                                               image.shape().DebugString()),
143                       done);
144     const int batch_size = image.dim_size(0);
145     const int image_height = image.dim_size(1);
146     const int image_width = image.dim_size(2);
147     const int depth = image.dim_size(3);
148     OP_REQUIRES_ASYNC(
149         context, image_height > 0 && image_width > 0,
150         errors::InvalidArgument("image dimensions must be positive"), done);
151     int num_boxes = 0;
152     OP_REQUIRES_OK_ASYNC(
153         context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
154 
155     OP_REQUIRES_ASYNC(context, crop_size.dims() == 1,
156                       errors::InvalidArgument("crop_size must be 1-D",
157                                               crop_size.shape().DebugString()),
158                       done);
159     OP_REQUIRES_ASYNC(
160         context, crop_size.dim_size(0) == 2,
161         errors::InvalidArgument("crop_size must have two elements",
162                                 crop_size.shape().DebugString()),
163         done);
164 
165     // Copy and validate crop sizes.
166     auto crop_size_vec = crop_size.vec<int32>();
167     const int crop_height = internal::SubtleMustCopy(crop_size_vec(0));
168     const int crop_width = internal::SubtleMustCopy(crop_size_vec(1));
169     OP_REQUIRES_ASYNC(
170         context, crop_height > 0 && crop_width > 0,
171         errors::InvalidArgument("crop dimensions must be positive"), done);
172 
173     TensorShape shape;
174     OP_REQUIRES_OK_ASYNC(context, shape.AddDimWithStatus(num_boxes), done);
175     OP_REQUIRES_OK_ASYNC(context, shape.AddDimWithStatus(crop_height), done);
176     OP_REQUIRES_OK_ASYNC(context, shape.AddDimWithStatus(crop_width), done);
177     OP_REQUIRES_OK_ASYNC(context, shape.AddDimWithStatus(depth), done);
178     // Allocate output tensor.
179     Tensor* output = nullptr;
180     OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, shape, &output),
181                          done);
182 
183     auto compute_callback = [this, context, output]() {
184       const Tensor& image = context->input(0);
185       const Tensor& boxes = context->input(1);
186       const Tensor& box_index = context->input(2);
187       const bool status = functor::CropAndResize<Device, T>()(
188           context, image.tensor<T, 4>(), boxes.tensor<float, 2>(),
189           box_index.tensor<int32, 1>(), method_, extrapolation_value_,
190           output->tensor<float, 4>());
191 
192       if (!status) {
193         context->SetStatus(
194             errors::Internal("Failed to launch CropAndResizeKernel."));
195       }
196     };
197 
198     RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
199                                  batch_size, std::move(compute_callback),
200                                  std::move(done));
201   }
202 
203  private:
204   float extrapolation_value_;
205   string method_;
206 };
207 
208 // Partial specialization of CropAndResize functor for a CPUDevice.
209 namespace functor {
210 template <typename T>
211 struct CropAndResize<CPUDevice, T> {
operator ()tensorflow::functor::CropAndResize212   bool operator()(OpKernelContext* context,
213                   typename TTypes<T, 4>::ConstTensor image,
214                   typename TTypes<float, 2>::ConstTensor boxes,
215                   typename TTypes<int32, 1>::ConstTensor box_index,
216                   const string& method_name, float extrapolation_value,
217                   typename TTypes<float, 4>::Tensor crops) {
218     const int batch_size = image.dimension(0);
219     const int image_height = image.dimension(1);
220     const int image_width = image.dimension(2);
221 
222     const int num_boxes = crops.dimension(0);
223     const int crop_height = crops.dimension(1);
224     const int crop_width = crops.dimension(2);
225     const int depth = crops.dimension(3);
226 
227     // Since `functor::CropAndResize` operates on float, we first validate
228     // that we don't overflow (since overflow causes undefined behavior which
229     // could result in segfault in this scenario).
230     const Eigen::Tensor<bool, 0, Eigen::RowMajor> only_finite_elements =
231         boxes.isfinite().all();
232     if (!only_finite_elements()) {
233       context->SetStatus(errors::InvalidArgument(
234           "Boxes contains at least one element that is not finite"));
235       return false;
236     }
237 
238     // Sharding across boxes.
239     auto CropAndResizePerBox = [&](int64_t start_box, int64_t limit_box) {
240       for (int b = start_box; b < limit_box; ++b) {
241         const float y1 = boxes(b, 0);
242         const float x1 = boxes(b, 1);
243         const float y2 = boxes(b, 2);
244         const float x2 = boxes(b, 3);
245 
246         const int32_t b_in = box_index(b);
247         if (!FastBoundsCheck(b_in, batch_size)) {
248           continue;
249         }
250 
251         const float height_scale =
252             (crop_height > 1)
253                 ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
254                 : 0;
255         const float width_scale =
256             (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1)
257                              : 0;
258 
259         for (int y = 0; y < crop_height; ++y) {
260           const float in_y = (crop_height > 1)
261                                  ? y1 * (image_height - 1) + y * height_scale
262                                  : 0.5 * (y1 + y2) * (image_height - 1);
263           if (in_y < 0 || in_y > image_height - 1) {
264             for (int x = 0; x < crop_width; ++x) {
265               for (int d = 0; d < depth; ++d) {
266                 crops(b, y, x, d) = extrapolation_value;
267               }
268             }
269             continue;
270           }
271           if (method_name == "bilinear") {
272             const int top_y_index = floorf(in_y);
273             const int bottom_y_index = ceilf(in_y);
274             const float y_lerp = in_y - top_y_index;
275 
276             for (int x = 0; x < crop_width; ++x) {
277               const float in_x = (crop_width > 1)
278                                      ? x1 * (image_width - 1) + x * width_scale
279                                      : 0.5 * (x1 + x2) * (image_width - 1);
280               if (in_x < 0 || in_x > image_width - 1) {
281                 for (int d = 0; d < depth; ++d) {
282                   crops(b, y, x, d) = extrapolation_value;
283                 }
284                 continue;
285               }
286               const int left_x_index = floorf(in_x);
287               const int right_x_index = ceilf(in_x);
288               const float x_lerp = in_x - left_x_index;
289 
290               for (int d = 0; d < depth; ++d) {
291                 const float top_left(static_cast<float>(
292                     image(b_in, top_y_index, left_x_index, d)));
293                 const float top_right(static_cast<float>(
294                     image(b_in, top_y_index, right_x_index, d)));
295                 const float bottom_left(static_cast<float>(
296                     image(b_in, bottom_y_index, left_x_index, d)));
297                 const float bottom_right(static_cast<float>(
298                     image(b_in, bottom_y_index, right_x_index, d)));
299                 const float top = top_left + (top_right - top_left) * x_lerp;
300                 const float bottom =
301                     bottom_left + (bottom_right - bottom_left) * x_lerp;
302                 crops(b, y, x, d) = top + (bottom - top) * y_lerp;
303               }
304             }
305           } else {  // method == "nearest"
306             for (int x = 0; x < crop_width; ++x) {
307               const float in_x = (crop_width > 1)
308                                      ? x1 * (image_width - 1) + x * width_scale
309                                      : 0.5 * (x1 + x2) * (image_width - 1);
310               if (in_x < 0 || in_x > image_width - 1) {
311                 for (int d = 0; d < depth; ++d) {
312                   crops(b, y, x, d) = extrapolation_value;
313                 }
314                 continue;
315               }
316               const int closest_x_index = roundf(in_x);
317               const int closest_y_index = roundf(in_y);
318               for (int d = 0; d < depth; ++d) {
319                 crops(b, y, x, d) = static_cast<float>(
320                     image(b_in, closest_y_index, closest_x_index, d));
321               }
322             }
323           }
324         }
325       }
326     };
327 
328     // A rough estimation of the cost for each cropped box.
329     double cost_per_pixel =
330         depth * (Eigen::TensorOpCost::AddCost<float>() * 6 +
331                  Eigen::TensorOpCost::MulCost<float>() * 3 +
332                  Eigen::TensorOpCost::CastCost<T, float>() * 4) +
333         (Eigen::TensorOpCost::AddCost<float>() * 2 +
334          Eigen::TensorOpCost::AddCost<float>() * 3);
335     if (method_name == "nearest") {
336       cost_per_pixel = depth * Eigen::TensorOpCost::CastCost<T, float>() +
337                        Eigen::TensorOpCost::AddCost<float>() * 4 +
338                        Eigen::TensorOpCost::MulCost<float>() * 4;
339     }
340     const double cost_per_box = crop_height * crop_width * cost_per_pixel;
341 
342     const DeviceBase::CpuWorkerThreads& worker_threads =
343         *(context->device()->tensorflow_cpu_worker_threads());
344     Shard(worker_threads.num_threads, worker_threads.workers, num_boxes,
345           cost_per_box, CropAndResizePerBox);
346 
347     return true;
348   }
349 };
350 
351 }  // namespace functor
352 
353 template <typename Device, typename T>
354 class CropAndResizeGradImageOp : public AsyncOpKernel {
355  public:
CropAndResizeGradImageOp(OpKernelConstruction * context)356   explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
357       : AsyncOpKernel(context) {
358     OP_REQUIRES_OK(context, context->GetAttr("method", &method_));
359     OP_REQUIRES(context, method_ == "bilinear" || method_ == "nearest",
360                 errors::InvalidArgument(
361                     "method must be 'bilinear' or 'nearest'", method_));
362   }
363 
ComputeAsync(OpKernelContext * context,DoneCallback done)364   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
365     // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
366     const Tensor& grads = context->input(0);
367     // The shape of 'boxes' is [num_boxes, 4].
368     const Tensor& boxes = context->input(1);
369     // The shape of 'box_index' is [num_boxes].
370     const Tensor& box_index = context->input(2);
371     // The shape of 'image_size' is [4].
372     const Tensor& image_size = context->input(3);
373 
374     // Validate input shapes.
375     OP_REQUIRES_ASYNC(context, grads.dims() == 4,
376                       errors::InvalidArgument("grads image must be 4-D",
377                                               grads.shape().DebugString()),
378                       done);
379     const int crop_height = grads.dim_size(1);
380     const int crop_width = grads.dim_size(2);
381     OP_REQUIRES_ASYNC(
382         context, crop_height > 0 && crop_width > 0,
383         errors::InvalidArgument("grads dimensions must be positive"), done);
384     int num_boxes = 0;
385     OP_REQUIRES_OK_ASYNC(
386         context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
387     OP_REQUIRES_ASYNC(
388         context, grads.dim_size(0) == num_boxes,
389         errors::InvalidArgument("boxes and grads have incompatible shape"),
390         done);
391 
392     OP_REQUIRES_ASYNC(context, image_size.dims() == 1,
393                       errors::InvalidArgument("image_size must be 1-D",
394                                               image_size.shape().DebugString()),
395                       done);
396     OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4,
397                       errors::InvalidArgument("image_size must have 4 elements",
398                                               image_size.shape().DebugString()),
399                       done);
400     auto image_size_vec = image_size.vec<int32>();
401     const int batch_size = internal::SubtleMustCopy(image_size_vec(0));
402     const int image_height = internal::SubtleMustCopy(image_size_vec(1));
403     const int image_width = internal::SubtleMustCopy(image_size_vec(2));
404     const int depth = internal::SubtleMustCopy(image_size_vec(3));
405     OP_REQUIRES_ASYNC(
406         context, image_height > 0 && image_width > 0,
407         errors::InvalidArgument("image dimensions must be positive"), done);
408     OP_REQUIRES_ASYNC(
409         context, grads.dim_size(3) == depth,
410         errors::InvalidArgument("image_size and grads are incompatible"), done);
411 
412     if (std::is_same<Device, GPUDevice>::value) {
413       OP_REQUIRES_ASYNC(
414           context, !OpDeterminismRequired(),
415           errors::Unimplemented(
416               "Deterministic GPU implementation of CropAndResizeBackpropImage"
417               " not available."),
418           done);
419     }
420 
421     TensorShape shape;
422     OP_REQUIRES_OK_ASYNC(context, shape.AddDimWithStatus(batch_size), done);
423     OP_REQUIRES_OK_ASYNC(context, shape.AddDimWithStatus(image_height), done);
424     OP_REQUIRES_OK_ASYNC(context, shape.AddDimWithStatus(image_width), done);
425     OP_REQUIRES_OK_ASYNC(context, shape.AddDimWithStatus(depth), done);
426     // Allocate output tensor.
427     Tensor* output = nullptr;
428     OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, shape, &output),
429                          done);
430 
431     auto compute_callback = [this, context, output]() {
432       const Tensor& grads = context->input(0);
433       const Tensor& boxes = context->input(1);
434       const Tensor& box_index = context->input(2);
435       const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
436           context, grads.tensor<float, 4>(), boxes.tensor<float, 2>(),
437           box_index.tensor<int32, 1>(), output->tensor<T, 4>(), method_);
438 
439       if (!status) {
440         context->SetStatus(errors::Internal(
441             "Failed to launch CropAndResizeBackpropImage kernel."));
442       }
443     };
444 
445     RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
446                                  batch_size, std::move(compute_callback),
447                                  std::move(done));
448   }
449 
450  private:
451   string method_;
452 };
453 
454 // Partial specialization of CropAndResizeBackpropImage functor for a CPUDevice.
455 namespace functor {
456 template <typename T>
457 struct CropAndResizeBackpropImage<CPUDevice, T> {
operator ()tensorflow::functor::CropAndResizeBackpropImage458   bool operator()(const OpKernelContext* context,
459                   typename TTypes<float, 4>::ConstTensor grads,
460                   typename TTypes<float, 2>::ConstTensor boxes,
461                   typename TTypes<int32, 1>::ConstTensor box_index,
462                   typename TTypes<T, 4>::Tensor grads_image,
463                   const string& method_name) {
464     const int batch_size = grads_image.dimension(0);
465     const int image_height = grads_image.dimension(1);
466     const int image_width = grads_image.dimension(2);
467 
468     const int num_boxes = grads.dimension(0);
469     const int crop_height = grads.dimension(1);
470     const int crop_width = grads.dimension(2);
471     const int depth = grads.dimension(3);
472 
473     grads_image.setZero();
474 
475     auto CropAndResizeBackImgPerBox = [&](int64_t start_box,
476                                           int64_t limit_box) {
477       for (int b = start_box; b < limit_box; ++b) {
478         const float y1 = boxes(b, 0);
479         const float x1 = boxes(b, 1);
480         const float y2 = boxes(b, 2);
481         const float x2 = boxes(b, 3);
482 
483         const int32_t b_in = box_index(b);
484         if (!FastBoundsCheck(b_in, batch_size)) {
485           continue;
486         }
487 
488         const float height_scale =
489             (crop_height > 1)
490                 ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
491                 : 0;
492         const float width_scale =
493             (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1)
494                              : 0;
495 
496         for (int y = 0; y < crop_height; ++y) {
497           const float in_y = (crop_height > 1)
498                                  ? y1 * (image_height - 1) + y * height_scale
499                                  : 0.5 * (y1 + y2) * (image_height - 1);
500           if (in_y < 0 || in_y > image_height - 1) {
501             continue;
502           }
503           const int top_y_index = floorf(in_y);
504           const int bottom_y_index = ceilf(in_y);
505           const float y_lerp = in_y - top_y_index;
506 
507           for (int x = 0; x < crop_width; ++x) {
508             const float in_x = (crop_width > 1)
509                                    ? x1 * (image_width - 1) + x * width_scale
510                                    : 0.5 * (x1 + x2) * (image_width - 1);
511             if (in_x < 0 || in_x > image_width - 1) {
512               continue;
513             }
514 
515             if (method_name == "bilinear") {
516               const int left_x_index = floorf(in_x);
517               const int right_x_index = ceilf(in_x);
518               const float x_lerp = in_x - left_x_index;
519 
520               for (int d = 0; d < depth; ++d) {
521                 const float dtop = (1 - y_lerp) * grads(b, y, x, d);
522                 grads_image(b_in, top_y_index, left_x_index, d) +=
523                     static_cast<T>((1 - x_lerp) * dtop);
524                 grads_image(b_in, top_y_index, right_x_index, d) +=
525                     static_cast<T>(x_lerp * dtop);
526                 const float dbottom = y_lerp * grads(b, y, x, d);
527                 grads_image(b_in, bottom_y_index, left_x_index, d) +=
528                     static_cast<T>((1 - x_lerp) * dbottom);
529                 grads_image(b_in, bottom_y_index, right_x_index, d) +=
530                     static_cast<T>(x_lerp * dbottom);
531               }
532             } else {  // method_name == "nearest"
533               for (int d = 0; d < depth; ++d) {
534                 int closest_x_index = roundf(in_x);
535                 int closest_y_index = roundf(in_y);
536                 grads_image(b_in, closest_y_index, closest_x_index, d) +=
537                     static_cast<T>(grads(b, y, x, d));
538               }
539             }
540           }
541         }
542       }
543     };
544 
545     // A rough estimation of the cost for each cropped box.
546     // Including calculation cost in the depth loop and pixel loop.
547     const double cost_per_pixel =
548         (method_name == "bilinear"
549              ? depth * (Eigen::TensorOpCost::AddCost<float>() * 7 +
550                         Eigen::TensorOpCost::MulCost<float>() * 6 +
551                         Eigen::TensorOpCost::CastCost<T, float>() * 4) +
552                    Eigen::TensorOpCost::AddCost<float>() * 4
553              : depth * (Eigen::TensorOpCost::AddCost<float>() +
554                         Eigen::TensorOpCost::CastCost<T, float>()) +
555                    Eigen::TensorOpCost::AddCost<float>() * 3);
556 
557     const double cost_per_box = crop_height * crop_width * cost_per_pixel;
558 
559     const DeviceBase::CpuWorkerThreads& worker_threads =
560         *(context->device()->tensorflow_cpu_worker_threads());
561 
562     // Sharding introduces nondeterminism when the gradients associated with
563     // more than two crops backprop into the same element in the source image.
564     int max_threads = OpDeterminismRequired() ? 1 : worker_threads.num_threads;
565 
566     Shard(max_threads, worker_threads.workers, num_boxes, cost_per_box,
567           CropAndResizeBackImgPerBox);
568 
569     return true;
570   }
571 };
572 
573 }  // namespace functor
574 
575 template <typename Device, typename T>
576 class CropAndResizeGradBoxesOp : public AsyncOpKernel {
577  public:
CropAndResizeGradBoxesOp(OpKernelConstruction * context)578   explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context)
579       : AsyncOpKernel(context) {
580     string method;
581     OP_REQUIRES_OK(context, context->GetAttr("method", &method));
582     OP_REQUIRES(context, method == "bilinear",
583                 errors::InvalidArgument("method must be 'bilinear'", method));
584   }
585 
ComputeAsync(OpKernelContext * context,DoneCallback done)586   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
587     // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
588     const Tensor& grads = context->input(0);
589     // The shape of 'boxes' is [num_boxes, 4].
590     const Tensor& boxes = context->input(2);
591     // The shape of 'box_index' is [num_boxes].
592     const Tensor& box_index = context->input(3);
593     // The shape of 'image' is [batch_size, image_height, image_width, depth].
594     const Tensor& image = context->input(1);
595 
596     // Validate input shapes.
597     OP_REQUIRES_ASYNC(context, grads.dims() == 4,
598                       errors::InvalidArgument("grads image must be 4-D",
599                                               grads.shape().DebugString()),
600                       done);
601     const int crop_height = grads.dim_size(1);
602     const int crop_width = grads.dim_size(2);
603     const int depth = grads.dim_size(3);
604     OP_REQUIRES_ASYNC(
605         context, crop_height > 0 && crop_width > 0,
606         errors::InvalidArgument("grads dimensions must be positive"), done);
607 
608     OP_REQUIRES_ASYNC(context, image.dims() == 4,
609                       errors::InvalidArgument("input image must be 4-D",
610                                               image.shape().DebugString()),
611                       done);
612     const int batch_size = image.dim_size(0);
613     const int image_height = image.dim_size(1);
614     const int image_width = image.dim_size(2);
615     OP_REQUIRES_ASYNC(
616         context, image_height > 0 && image_width > 0,
617         errors::InvalidArgument("image dimensions must be positive"), done);
618     OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth,
619                       errors::InvalidArgument("image, grads depth differ"),
620                       done);
621 
622     int num_boxes = 0;
623     OP_REQUIRES_OK_ASYNC(
624         context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
625 
626     OP_REQUIRES_ASYNC(
627         context, grads.dim_size(0) == num_boxes,
628         errors::InvalidArgument("boxes and grads have incompatible shape"),
629         done);
630 
631     if (std::is_same<Device, GPUDevice>::value) {
632       OP_REQUIRES_ASYNC(
633           context, !OpDeterminismRequired(),
634           errors::Unimplemented(
635               "Deterministic GPU implementation of CropAndResizeBackpropBoxes"
636               " not available."),
637           done);
638     }
639 
640     // Allocate output tensor.
641     Tensor* output = nullptr;
642     OP_REQUIRES_OK_ASYNC(
643         context,
644         context->allocate_output(0, TensorShape({num_boxes, 4}), &output),
645         done);
646 
647     auto compute_callback = [context, output]() {
648       const Tensor& grads = context->input(0);
649       const Tensor& image = context->input(1);
650       const Tensor& boxes = context->input(2);
651       const Tensor& box_index = context->input(3);
652       const bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
653           context->eigen_device<Device>(), grads.tensor<float, 4>(),
654           image.tensor<T, 4>(), boxes.tensor<float, 2>(),
655           box_index.tensor<int32, 1>(), output->tensor<float, 2>());
656       if (!status) {
657         context->SetStatus(errors::Internal(
658             "Failed to launch CropAndResizeBackpropBoxes kernel."));
659       }
660     };
661 
662     RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
663                                  batch_size, std::move(compute_callback),
664                                  std::move(done));
665   }
666 };
667 
668 // Partial specialization of CropAndResizeBackpropBoxes functor for a CPUDevice.
669 namespace functor {
670 template <typename T>
671 struct CropAndResizeBackpropBoxes<CPUDevice, T> {
operator ()tensorflow::functor::CropAndResizeBackpropBoxes672   bool operator()(const CPUDevice& d,
673                   typename TTypes<float, 4>::ConstTensor grads,
674                   typename TTypes<T, 4>::ConstTensor image,
675                   typename TTypes<float, 2>::ConstTensor boxes,
676                   typename TTypes<int32, 1>::ConstTensor box_index,
677                   typename TTypes<float, 2>::Tensor grads_boxes) {
678     const int batch_size = image.dimension(0);
679     const int image_height = image.dimension(1);
680     const int image_width = image.dimension(2);
681 
682     const int num_boxes = grads.dimension(0);
683     const int crop_height = grads.dimension(1);
684     const int crop_width = grads.dimension(2);
685     const int depth = grads.dimension(3);
686 
687     grads_boxes.setZero();
688 
689     for (int b = 0; b < num_boxes; ++b) {
690       const float y1 = boxes(b, 0);
691       const float x1 = boxes(b, 1);
692       const float y2 = boxes(b, 2);
693       const float x2 = boxes(b, 3);
694 
695       const int32_t b_in = box_index(b);
696       if (!FastBoundsCheck(b_in, batch_size)) {
697         continue;
698       }
699 
700       const float height_ratio =
701           (crop_height > 1)
702               ? static_cast<float>(image_height - 1) / (crop_height - 1)
703               : 0;
704       const float width_ratio =
705           (crop_width > 1)
706               ? static_cast<float>(image_width - 1) / (crop_width - 1)
707               : 0;
708 
709       const float height_scale =
710           (crop_height > 1) ? (y2 - y1) * height_ratio : 0;
711       const float width_scale = (crop_width > 1) ? (x2 - x1) * width_ratio : 0;
712 
713       for (int y = 0; y < crop_height; ++y) {
714         const float in_y = (crop_height > 1)
715                                ? y1 * (image_height - 1) + y * height_scale
716                                : 0.5 * (y1 + y2) * (image_height - 1);
717         if (in_y < 0 || in_y > image_height - 1) {
718           continue;
719         }
720         const int top_y_index = floorf(in_y);
721         const int bottom_y_index = ceilf(in_y);
722         const float y_lerp = in_y - top_y_index;
723 
724         for (int x = 0; x < crop_width; ++x) {
725           const float in_x = (crop_width > 1)
726                                  ? x1 * (image_width - 1) + x * width_scale
727                                  : 0.5 * (x1 + x2) * (image_width - 1);
728           if (in_x < 0 || in_x > image_width - 1) {
729             continue;
730           }
731           const int left_x_index = floorf(in_x);
732           const int right_x_index = ceilf(in_x);
733           const float x_lerp = in_x - left_x_index;
734 
735           for (int d = 0; d < depth; ++d) {
736             const float top_left(
737                 static_cast<float>(image(b_in, top_y_index, left_x_index, d)));
738             const float top_right(
739                 static_cast<float>(image(b_in, top_y_index, right_x_index, d)));
740             const float bottom_left(static_cast<float>(
741                 image(b_in, bottom_y_index, left_x_index, d)));
742             const float bottom_right(static_cast<float>(
743                 image(b_in, bottom_y_index, right_x_index, d)));
744             // Compute the image gradient.
745             float image_grad_y = (1 - x_lerp) * (bottom_left - top_left) +
746                                  x_lerp * (bottom_right - top_right);
747             float image_grad_x = (1 - y_lerp) * (top_right - top_left) +
748                                  y_lerp * (bottom_right - bottom_left);
749             // Modulate the image gradient with the incoming gradient.
750             const float top_grad = grads(b, y, x, d);
751             image_grad_y *= top_grad;
752             image_grad_x *= top_grad;
753             // dy1, dy2
754             if (crop_height > 1) {
755               grads_boxes(b, 0) +=
756                   image_grad_y * (image_height - 1 - y * height_ratio);
757               grads_boxes(b, 2) += image_grad_y * (y * height_ratio);
758             } else {
759               grads_boxes(b, 0) += image_grad_y * 0.5 * (image_height - 1);
760               grads_boxes(b, 2) += image_grad_y * 0.5 * (image_height - 1);
761             }
762             // dx1, dx2
763             if (crop_width > 1) {
764               grads_boxes(b, 1) +=
765                   image_grad_x * (image_width - 1 - x * width_ratio);
766               grads_boxes(b, 3) += image_grad_x * (x * width_ratio);
767             } else {
768               grads_boxes(b, 1) += image_grad_x * 0.5 * (image_width - 1);
769               grads_boxes(b, 3) += image_grad_x * 0.5 * (image_width - 1);
770             }
771           }
772         }
773       }
774     }
775     return true;
776   }
777 };
778 
779 }  // namespace functor
780 
781 #define REGISTER_KERNEL(T)                                \
782   REGISTER_KERNEL_BUILDER(Name("CropAndResize")           \
783                               .Device(DEVICE_CPU)         \
784                               .TypeConstraint<T>("T")     \
785                               .HostMemory("crop_size"),   \
786                           CropAndResizeOp<CPUDevice, T>); \
787                                                           \
788   REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes")  \
789                               .Device(DEVICE_CPU)         \
790                               .TypeConstraint<T>("T"),    \
791                           CropAndResizeGradBoxesOp<CPUDevice, T>);
792 
793 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
794 
795 #undef REGISTER_KERNEL
796 
797 #define REGISTER_KERNEL(T)                               \
798   REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
799                               .Device(DEVICE_CPU)        \
800                               .TypeConstraint<T>("T")    \
801                               .HostMemory("image_size"), \
802                           CropAndResizeGradImageOp<CPUDevice, T>);
803 
804 TF_CALL_half(REGISTER_KERNEL);
805 TF_CALL_float(REGISTER_KERNEL);
806 TF_CALL_double(REGISTER_KERNEL);
807 
808 #undef REGISTER_KERNEL
809 
810 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
811 
812 // Forward declaration of the CheckValidBoxIndexHelper specialization for GPU.
813 namespace functor {
814 template <>
815 void CheckValidBoxIndexHelper<GPUDevice>::operator()(
816     const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_index,
817     int batch_size, typename TTypes<bool, 0>::Tensor isvalid);
818 extern template struct CheckValidBoxIndexHelper<GPUDevice>;
819 }  // namespace functor
820 
821 namespace {
822 
823 // Specialization of CheckValidBoxIndex for a GPUDevice.
824 template <>
RunIfBoxIndexIsValid(OpKernelContext * context,typename TTypes<int32,1>::ConstTensor box_index,int batch_size,const Callback & compute,const Callback & done)825 inline void RunIfBoxIndexIsValid<GPUDevice>(
826     OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
827     int batch_size, const Callback& compute, const Callback& done) {
828   const int num_boxes = box_index.dimension(0);
829   if (num_boxes == 0) {
830     compute();
831     done();
832     return;
833   }
834 
835   Tensor isvalid_dev_tensor;
836   OP_REQUIRES_OK_ASYNC(
837       context,
838       context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
839                              &isvalid_dev_tensor),
840       done);
841   typename TTypes<bool, 0>::Tensor isvalid_dev =
842       isvalid_dev_tensor.tensor<bool, 0>();
843 
844   // Run the actual box check on the device.
845   functor::CheckValidBoxIndexHelper<GPUDevice>()(
846       context->eigen_device<GPUDevice>(), box_index, batch_size, isvalid_dev);
847 
848   // Copy the result back to the host.
849   auto* stream = context->op_device_context()->stream();
850   OP_REQUIRES_ASYNC(context, stream,
851                     errors::Internal("No GPU stream available."), done);
852   Tensor isvalid_host_tensor;
853   // Use pinned host memory on the host to avoid unnecessary
854   // synchronization.
855   AllocatorAttributes alloc_attr;
856   alloc_attr.set_on_host(true);
857   alloc_attr.set_gpu_compatible(true);
858   OP_REQUIRES_OK_ASYNC(
859       context,
860       context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
861                              &isvalid_host_tensor, alloc_attr),
862       done);
863   se::DeviceMemoryBase wrapped(isvalid_dev.data(), sizeof(bool));
864   const bool status =
865       stream
866           ->ThenMemcpy(
867               isvalid_host_tensor.scalar<bool>().data() /* destination */,
868               wrapped /* source */, sizeof(bool))
869           .ok();
870   OP_REQUIRES_ASYNC(
871       context, status,
872       errors::Internal("Failed to launch copy of isvalid from device to host."),
873       done);
874 
875   // We capture both temporary tensors to prevent them from being deallocated
876   // when ComputeAsync returns and before the closure runs.
877   TensorReference isvalid_dev_ref(isvalid_dev_tensor);
878   auto wrapped_callback = [context, isvalid_host_tensor, isvalid_dev_ref,
879                            compute, done]() {
880     auto stream = context->op_device_context()->stream();
881     ScopedActivateExecutorContext scoped_activation{stream->parent()};
882     const bool isvalid = isvalid_host_tensor.scalar<bool>()();
883     isvalid_dev_ref.Unref();
884     OP_REQUIRES_ASYNC(
885         context, isvalid,
886         errors::OutOfRange("box_index has values outside [0, batch_size)"),
887         done);
888     compute();
889     done();
890   };
891 
892   context->device()
893       ->tensorflow_accelerator_device_info()
894       ->event_mgr->ThenExecute(stream, wrapped_callback);
895 }
896 
897 }  // namespace
898 
899 #define REGISTER_KERNEL(T)                                         \
900   REGISTER_KERNEL_BUILDER(Name("CropAndResize")                    \
901                               .Device(DEVICE_GPU)                  \
902                               .TypeConstraint<T>("T")              \
903                               .HostMemory("crop_size"),            \
904                           CropAndResizeOp<GPUDevice, T>);          \
905                                                                    \
906   REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage")           \
907                               .Device(DEVICE_GPU)                  \
908                               .TypeConstraint<T>("T")              \
909                               .HostMemory("image_size"),           \
910                           CropAndResizeGradImageOp<GPUDevice, T>); \
911                                                                    \
912   REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes")           \
913                               .Device(DEVICE_GPU)                  \
914                               .TypeConstraint<T>("T"),             \
915                           CropAndResizeGradBoxesOp<GPUDevice, T>);
916 
917 TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
918 
919 #undef REGISTER_KERNEL
920 
921 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
922 
923 }  // namespace tensorflow
924