xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/image/non_max_suppression_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/image_ops.cc
17 
18 #define EIGEN_USE_THREADS
19 
#include "tensorflow/core/kernels/image/non_max_suppression_op.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <deque>
#include <functional>
#include <limits>
#include <queue>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
36 
37 namespace tensorflow {
38 namespace {
39 
40 typedef Eigen::ThreadPoolDevice CPUDevice;
41 
CheckScoreSizes(OpKernelContext * context,int num_boxes,const Tensor & scores)42 static inline void CheckScoreSizes(OpKernelContext* context, int num_boxes,
43                                    const Tensor& scores) {
44   // The shape of 'scores' is [num_boxes]
45   OP_REQUIRES(context, scores.dims() == 1,
46               errors::InvalidArgument(
47                   "scores must be 1-D", scores.shape().DebugString(),
48                   " (Shape must be rank 1 but is rank ", scores.dims(), ")"));
49   OP_REQUIRES(
50       context, scores.dim_size(0) == num_boxes,
51       errors::InvalidArgument("scores has incompatible shape (Dimensions must "
52                               "be equal, but are ",
53                               num_boxes, " and ", scores.dim_size(0), ")"));
54 }
55 
ParseAndCheckOverlapSizes(OpKernelContext * context,const Tensor & overlaps,int * num_boxes)56 static inline void ParseAndCheckOverlapSizes(OpKernelContext* context,
57                                              const Tensor& overlaps,
58                                              int* num_boxes) {
59   // the shape of 'overlaps' is [num_boxes, num_boxes]
60   OP_REQUIRES(context, overlaps.dims() == 2,
61               errors::InvalidArgument("overlaps must be 2-D",
62                                       overlaps.shape().DebugString()));
63 
64   *num_boxes = overlaps.dim_size(0);
65   OP_REQUIRES(context, overlaps.dim_size(1) == *num_boxes,
66               errors::InvalidArgument("overlaps must be square",
67                                       overlaps.shape().DebugString()));
68 }
69 
ParseAndCheckBoxSizes(OpKernelContext * context,const Tensor & boxes,int * num_boxes)70 static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
71                                          const Tensor& boxes, int* num_boxes) {
72   // The shape of 'boxes' is [num_boxes, 4]
73   OP_REQUIRES(context, boxes.dims() == 2,
74               errors::InvalidArgument(
75                   "boxes must be 2-D", boxes.shape().DebugString(),
76                   " (Shape must be rank 2 but is rank ", boxes.dims(), ")"));
77   *num_boxes = boxes.dim_size(0);
78   OP_REQUIRES(context, boxes.dim_size(1) == 4,
79               errors::InvalidArgument("boxes must have 4 columns (Dimension "
80                                       "must be 4 but is ",
81                                       boxes.dim_size(1), ")"));
82 }
83 
CheckCombinedNMSScoreSizes(OpKernelContext * context,int num_boxes,const Tensor & scores)84 static inline void CheckCombinedNMSScoreSizes(OpKernelContext* context,
85                                               int num_boxes,
86                                               const Tensor& scores) {
87   // The shape of 'scores' is [batch_size, num_boxes, num_classes]
88   OP_REQUIRES(context, scores.dims() == 3,
89               errors::InvalidArgument("scores must be 3-D",
90                                       scores.shape().DebugString()));
91   OP_REQUIRES(context, scores.dim_size(1) == num_boxes,
92               errors::InvalidArgument("scores has incompatible shape"));
93 }
94 
ParseAndCheckCombinedNMSBoxSizes(OpKernelContext * context,const Tensor & boxes,int * num_boxes,const int num_classes)95 static inline void ParseAndCheckCombinedNMSBoxSizes(OpKernelContext* context,
96                                                     const Tensor& boxes,
97                                                     int* num_boxes,
98                                                     const int num_classes) {
99   // The shape of 'boxes' is [batch_size, num_boxes, q, 4]
100   OP_REQUIRES(context, boxes.dims() == 4,
101               errors::InvalidArgument("boxes must be 4-D",
102                                       boxes.shape().DebugString()));
103 
104   bool box_check = boxes.dim_size(2) == 1 || boxes.dim_size(2) == num_classes;
105   OP_REQUIRES(context, box_check,
106               errors::InvalidArgument(
107                   "third dimension of boxes must be either 1 or num classes"));
108   *num_boxes = boxes.dim_size(1);
109   OP_REQUIRES(context, boxes.dim_size(3) == 4,
110               errors::InvalidArgument("boxes must have 4 columns"));
111 }
112 // Return intersection-over-union overlap between boxes i and j
113 template <typename T>
IOU(typename TTypes<T,2>::ConstTensor boxes,int i,int j)114 static inline float IOU(typename TTypes<T, 2>::ConstTensor boxes, int i,
115                         int j) {
116   const float ymin_i = Eigen::numext::mini<float>(
117       static_cast<float>(boxes(i, 0)), static_cast<float>(boxes(i, 2)));
118   const float xmin_i = Eigen::numext::mini<float>(
119       static_cast<float>(boxes(i, 1)), static_cast<float>(boxes(i, 3)));
120   const float ymax_i = Eigen::numext::maxi<float>(
121       static_cast<float>(boxes(i, 0)), static_cast<float>(boxes(i, 2)));
122   const float xmax_i = Eigen::numext::maxi<float>(
123       static_cast<float>(boxes(i, 1)), static_cast<float>(boxes(i, 3)));
124   const float ymin_j = Eigen::numext::mini<float>(
125       static_cast<float>(boxes(j, 0)), static_cast<float>(boxes(j, 2)));
126   const float xmin_j = Eigen::numext::mini<float>(
127       static_cast<float>(boxes(j, 1)), static_cast<float>(boxes(j, 3)));
128   const float ymax_j = Eigen::numext::maxi<float>(
129       static_cast<float>(boxes(j, 0)), static_cast<float>(boxes(j, 2)));
130   const float xmax_j = Eigen::numext::maxi<float>(
131       static_cast<float>(boxes(j, 1)), static_cast<float>(boxes(j, 3)));
132   const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
133   const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
134   if (area_i <= 0 || area_j <= 0) {
135     return 0.0;
136   }
137   const float intersection_ymin = Eigen::numext::maxi<float>(ymin_i, ymin_j);
138   const float intersection_xmin = Eigen::numext::maxi<float>(xmin_i, xmin_j);
139   const float intersection_ymax = Eigen::numext::mini<float>(ymax_i, ymax_j);
140   const float intersection_xmax = Eigen::numext::mini<float>(xmax_i, xmax_j);
141   const float intersection_area =
142       Eigen::numext::maxi<float>(intersection_ymax - intersection_ymin, 0.0) *
143       Eigen::numext::maxi<float>(intersection_xmax - intersection_xmin, 0.0);
144   return intersection_area / (area_i + area_j - intersection_area);
145 }
146 
IOU(const float * boxes,int i,int j)147 static inline float IOU(const float* boxes, int i, int j) {
148   const float ymin_i = Eigen::numext::mini<float>(boxes[i], boxes[i + 2]);
149   const float xmin_i = Eigen::numext::mini<float>(boxes[i + 1], boxes[i + 3]);
150   const float ymax_i = Eigen::numext::maxi<float>(boxes[i], boxes[i + 2]);
151   const float xmax_i = Eigen::numext::maxi<float>(boxes[i + 1], boxes[i + 3]);
152   const float ymin_j = Eigen::numext::mini<float>(boxes[j], boxes[j + 2]);
153   const float xmin_j = Eigen::numext::mini<float>(boxes[j + 1], boxes[j + 3]);
154   const float ymax_j = Eigen::numext::maxi<float>(boxes[j], boxes[j + 2]);
155   const float xmax_j = Eigen::numext::maxi<float>(boxes[j + 1], boxes[j + 3]);
156   const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
157   const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
158   if (area_i <= 0 || area_j <= 0) {
159     return 0.0;
160   }
161   const float intersection_ymin = Eigen::numext::maxi<float>(ymin_i, ymin_j);
162   const float intersection_xmin = Eigen::numext::maxi<float>(xmin_i, xmin_j);
163   const float intersection_ymax = Eigen::numext::mini<float>(ymax_i, ymax_j);
164   const float intersection_xmax = Eigen::numext::mini<float>(xmax_i, xmax_j);
165   const float intersection_area =
166       Eigen::numext::maxi<float>(intersection_ymax - intersection_ymin, 0.0) *
167       Eigen::numext::maxi<float>(intersection_xmax - intersection_xmin, 0.0);
168   return intersection_area / (area_i + area_j - intersection_area);
169 }
170 
171 template <typename T>
Overlap(typename TTypes<T,2>::ConstTensor overlaps,int i,int j)172 static inline T Overlap(typename TTypes<T, 2>::ConstTensor overlaps, int i,
173                         int j) {
174   return overlaps(i, j);
175 }
176 
177 template <typename T>
CreateIOUSimilarityFn(const Tensor & boxes)178 static inline std::function<float(int, int)> CreateIOUSimilarityFn(
179     const Tensor& boxes) {
180   typename TTypes<T, 2>::ConstTensor boxes_data = boxes.tensor<T, 2>();
181   return std::bind(&IOU<T>, boxes_data, std::placeholders::_1,
182                    std::placeholders::_2);
183 }
184 
185 template <typename T>
CreateOverlapSimilarityFn(const Tensor & overlaps)186 static inline std::function<T(int, int)> CreateOverlapSimilarityFn(
187     const Tensor& overlaps) {
188   typename TTypes<T, 2>::ConstTensor overlaps_data =
189       overlaps.tensor<float, 2>();
190   return std::bind(&Overlap<T>, overlaps_data, std::placeholders::_1,
191                    std::placeholders::_2);
192 }
193 
194 template <typename T>
DoNonMaxSuppressionOp(OpKernelContext * context,const Tensor & scores,int num_boxes,const Tensor & max_output_size,const T similarity_threshold,const T score_threshold,const T soft_nms_sigma,const std::function<float (int,int)> & similarity_fn,bool return_scores_tensor=false,bool pad_to_max_output_size=false,int * ptr_num_valid_outputs=nullptr)195 void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores,
196                            int num_boxes, const Tensor& max_output_size,
197                            const T similarity_threshold,
198                            const T score_threshold, const T soft_nms_sigma,
199                            const std::function<float(int, int)>& similarity_fn,
200                            bool return_scores_tensor = false,
201                            bool pad_to_max_output_size = false,
202                            int* ptr_num_valid_outputs = nullptr) {
203   const int output_size = max_output_size.scalar<int>()();
204   OP_REQUIRES(context, output_size >= 0,
205               errors::InvalidArgument("output size must be non-negative"));
206 
207   std::vector<T> scores_data(num_boxes);
208   std::copy_n(scores.flat<T>().data(), num_boxes, scores_data.begin());
209 
210   // Data structure for a selection candidate in NMS.
211   struct Candidate {
212     int box_index;
213     T score;
214     int suppress_begin_index;
215   };
216 
217   auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
218     return ((bs_i.score == bs_j.score) && (bs_i.box_index > bs_j.box_index)) ||
219            bs_i.score < bs_j.score;
220   };
221   std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)>
222       candidate_priority_queue(cmp);
223   for (int i = 0; i < scores_data.size(); ++i) {
224     if (scores_data[i] > score_threshold) {
225       candidate_priority_queue.emplace(Candidate({i, scores_data[i], 0}));
226     }
227   }
228 
229   T scale = static_cast<T>(0.0);
230   bool is_soft_nms = soft_nms_sigma > static_cast<T>(0.0);
231   if (is_soft_nms) {
232     scale = static_cast<T>(-0.5) / soft_nms_sigma;
233   }
234 
235   auto suppress_weight = [similarity_threshold, scale,
236                           is_soft_nms](const T sim) {
237     const T weight = Eigen::numext::exp<T>(scale * sim * sim);
238     return is_soft_nms || sim <= similarity_threshold ? weight
239                                                       : static_cast<T>(0.0);
240   };
241 
242   std::vector<int> selected;
243   std::vector<T> selected_scores;
244   float similarity;
245   T original_score;
246   Candidate next_candidate;
247 
248   while (selected.size() < output_size && !candidate_priority_queue.empty()) {
249     next_candidate = candidate_priority_queue.top();
250     original_score = next_candidate.score;
251     candidate_priority_queue.pop();
252 
253     // Overlapping boxes are likely to have similar scores, therefore we
254     // iterate through the previously selected boxes backwards in order to
255     // see if `next_candidate` should be suppressed. We also enforce a property
256     // that a candidate can be suppressed by another candidate no more than
257     // once via `suppress_begin_index` which tracks which previously selected
258     // boxes have already been compared against next_candidate prior to a given
259     // iteration.  These previous selected boxes are then skipped over in the
260     // following loop.
261     bool should_hard_suppress = false;
262     for (int j = static_cast<int>(selected.size()) - 1;
263          j >= next_candidate.suppress_begin_index; --j) {
264       similarity = similarity_fn(next_candidate.box_index, selected[j]);
265 
266       next_candidate.score *= suppress_weight(static_cast<T>(similarity));
267 
268       // First decide whether to perform hard suppression
269       if (!is_soft_nms && static_cast<T>(similarity) > similarity_threshold) {
270         should_hard_suppress = true;
271         break;
272       }
273 
274       // If next_candidate survives hard suppression, apply soft suppression
275       if (next_candidate.score <= score_threshold) break;
276     }
277     // If `next_candidate.score` has not dropped below `score_threshold`
278     // by this point, then we know that we went through all of the previous
279     // selections and can safely update `suppress_begin_index` to
280     // `selected.size()`. If on the other hand `next_candidate.score`
281     // *has* dropped below the score threshold, then since `suppress_weight`
282     // always returns values in [0, 1], further suppression by items that were
283     // not covered in the above for loop would not have caused the algorithm
284     // to select this item. We thus do the same update to
285     // `suppress_begin_index`, but really, this element will not be added back
286     // into the priority queue in the following.
287     next_candidate.suppress_begin_index = selected.size();
288 
289     if (!should_hard_suppress) {
290       if (next_candidate.score == original_score) {
291         // Suppression has not occurred, so select next_candidate
292         selected.push_back(next_candidate.box_index);
293         selected_scores.push_back(next_candidate.score);
294         continue;
295       }
296       if (next_candidate.score > score_threshold) {
297         // Soft suppression has occurred and current score is still greater than
298         // score_threshold; add next_candidate back onto priority queue.
299         candidate_priority_queue.push(next_candidate);
300       }
301     }
302   }
303 
304   int num_valid_outputs = selected.size();
305   if (pad_to_max_output_size) {
306     selected.resize(output_size, 0);
307     selected_scores.resize(output_size, static_cast<T>(0));
308   }
309   if (ptr_num_valid_outputs) {
310     *ptr_num_valid_outputs = num_valid_outputs;
311   }
312 
313   // Allocate output tensors
314   Tensor* output_indices = nullptr;
315   TensorShape output_shape({static_cast<int>(selected.size())});
316   OP_REQUIRES_OK(context,
317                  context->allocate_output(0, output_shape, &output_indices));
318   TTypes<int, 1>::Tensor output_indices_data = output_indices->tensor<int, 1>();
319   std::copy_n(selected.begin(), selected.size(), output_indices_data.data());
320 
321   if (return_scores_tensor) {
322     Tensor* output_scores = nullptr;
323     OP_REQUIRES_OK(context,
324                    context->allocate_output(1, output_shape, &output_scores));
325     typename TTypes<T, 1>::Tensor output_scores_data =
326         output_scores->tensor<T, 1>();
327     std::copy_n(selected_scores.begin(), selected_scores.size(),
328                 output_scores_data.data());
329   }
330 }
331 
// One box kept by per-class NMS in the combined (batched) op. Result slots
// that NMS never fills are initialized with box_index == -1, and
// SelectResultPerBatch treats only box_index > -1 as valid.
struct ResultCandidate {
  int box_index;       // Index of the box within its batch element.
  float score;         // Score of the box for `class_idx`.
  int class_idx;       // Class the box was selected for.
  float box_coord[4];  // The box's four input coordinates, copied verbatim.
};
338 
DoNMSPerClass(int batch_idx,int class_idx,const float * boxes_data,const float * scores_data,int num_boxes,int q,int num_classes,const int size_per_class,const float score_threshold,const float iou_threshold,std::vector<ResultCandidate> & result_candidate_vec)339 void DoNMSPerClass(int batch_idx, int class_idx, const float* boxes_data,
340                    const float* scores_data, int num_boxes, int q,
341                    int num_classes, const int size_per_class,
342                    const float score_threshold, const float iou_threshold,
343                    std::vector<ResultCandidate>& result_candidate_vec) {
344   // Do NMS, get the candidate indices of form vector<int>
345   // Data structure for selection candidate in NMS.
346   struct Candidate {
347     int box_index;
348     float score;
349   };
350   auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
351     return bs_i.score < bs_j.score;
352   };
353   std::priority_queue<Candidate, std::vector<Candidate>, decltype(cmp)>
354       candidate_priority_queue(cmp);
355   float temp_score;
356   for (int i = 0; i < num_boxes; ++i) {
357     temp_score = scores_data[i * num_classes + class_idx];
358     if (temp_score > score_threshold) {
359       candidate_priority_queue.emplace(Candidate({i, temp_score}));
360     }
361   }
362 
363   std::vector<int> selected;
364   Candidate next_candidate;
365 
366   int candidate_box_data_idx, selected_box_data_idx, class_box_idx;
367   class_box_idx = (q > 1) ? class_idx : 0;
368 
369   float iou;
370   while (selected.size() < size_per_class &&
371          !candidate_priority_queue.empty()) {
372     next_candidate = candidate_priority_queue.top();
373     candidate_priority_queue.pop();
374 
375     candidate_box_data_idx = (next_candidate.box_index * q + class_box_idx) * 4;
376 
377     // Overlapping boxes are likely to have similar scores,
378     // therefore we iterate through the previously selected boxes backwards
379     // in order to see if `next_candidate` should be suppressed.
380     bool should_select = true;
381     for (int j = selected.size() - 1; j >= 0; --j) {
382       selected_box_data_idx = (selected[j] * q + class_box_idx) * 4;
383       iou = IOU(boxes_data, candidate_box_data_idx, selected_box_data_idx);
384       if (iou > iou_threshold) {
385         should_select = false;
386         break;
387       }
388     }
389 
390     if (should_select) {
391       // Add the selected box to the result candidate. Sorted by score
392       result_candidate_vec[selected.size() + size_per_class * class_idx] = {
393           next_candidate.box_index,
394           next_candidate.score,
395           class_idx,
396           {boxes_data[candidate_box_data_idx],
397            boxes_data[candidate_box_data_idx + 1],
398            boxes_data[candidate_box_data_idx + 2],
399            boxes_data[candidate_box_data_idx + 3]}};
400       selected.push_back(next_candidate.box_index);
401     }
402   }
403 }
404 
SelectResultPerBatch(std::vector<float> & nmsed_boxes,std::vector<float> & nmsed_scores,std::vector<float> & nmsed_classes,std::vector<ResultCandidate> & result_candidate_vec,std::vector<int> & final_valid_detections,const int batch_idx,int total_size_per_batch,bool pad_per_class,int max_size_per_batch,bool clip_boxes,int per_batch_size)405 void SelectResultPerBatch(std::vector<float>& nmsed_boxes,
406                           std::vector<float>& nmsed_scores,
407                           std::vector<float>& nmsed_classes,
408                           std::vector<ResultCandidate>& result_candidate_vec,
409                           std::vector<int>& final_valid_detections,
410                           const int batch_idx, int total_size_per_batch,
411                           bool pad_per_class, int max_size_per_batch,
412                           bool clip_boxes, int per_batch_size) {
413   auto rc_cmp = [](const ResultCandidate rc_i, const ResultCandidate rc_j) {
414     return rc_i.score > rc_j.score;
415   };
416   std::sort(result_candidate_vec.begin(), result_candidate_vec.end(), rc_cmp);
417 
418   int max_detections = 0;
419   int result_candidate_size =
420       std::count_if(result_candidate_vec.begin(), result_candidate_vec.end(),
421                     [](ResultCandidate rc) { return rc.box_index > -1; });
422   // If pad_per_class is false, we always pad to max_total_size
423   if (!pad_per_class) {
424     max_detections = std::min(result_candidate_size, total_size_per_batch);
425   } else {
426     max_detections = std::min(per_batch_size, result_candidate_size);
427   }
428 
429   final_valid_detections[batch_idx] = max_detections;
430 
431   int curr_total_size = max_detections;
432   int result_idx = 0;
433   // Pick the top max_detections values
434   while (curr_total_size > 0 && result_idx < result_candidate_vec.size()) {
435     ResultCandidate next_candidate = result_candidate_vec[result_idx++];
436     // Add to final output vectors
437     if (clip_boxes) {
438       const float box_min = 0.0;
439       const float box_max = 1.0;
440       nmsed_boxes.push_back(
441           std::max(std::min(next_candidate.box_coord[0], box_max), box_min));
442       nmsed_boxes.push_back(
443           std::max(std::min(next_candidate.box_coord[1], box_max), box_min));
444       nmsed_boxes.push_back(
445           std::max(std::min(next_candidate.box_coord[2], box_max), box_min));
446       nmsed_boxes.push_back(
447           std::max(std::min(next_candidate.box_coord[3], box_max), box_min));
448     } else {
449       nmsed_boxes.push_back(next_candidate.box_coord[0]);
450       nmsed_boxes.push_back(next_candidate.box_coord[1]);
451       nmsed_boxes.push_back(next_candidate.box_coord[2]);
452       nmsed_boxes.push_back(next_candidate.box_coord[3]);
453     }
454     nmsed_scores.push_back(next_candidate.score);
455     nmsed_classes.push_back(next_candidate.class_idx);
456     curr_total_size--;
457   }
458 
459   nmsed_boxes.resize(per_batch_size * 4, 0);
460   nmsed_scores.resize(per_batch_size, 0);
461   nmsed_classes.resize(per_batch_size, 0);
462 }
463 
BatchedNonMaxSuppressionOp(OpKernelContext * context,const Tensor & inp_boxes,const Tensor & inp_scores,int num_boxes,const int max_size_per_class,const int total_size_per_batch,const float score_threshold,const float iou_threshold,bool pad_per_class=false,bool clip_boxes=true)464 void BatchedNonMaxSuppressionOp(
465     OpKernelContext* context, const Tensor& inp_boxes, const Tensor& inp_scores,
466     int num_boxes, const int max_size_per_class, const int total_size_per_batch,
467     const float score_threshold, const float iou_threshold,
468     bool pad_per_class = false, bool clip_boxes = true) {
469   const int num_batches = inp_boxes.dim_size(0);
470   int num_classes = inp_scores.dim_size(2);
471   int q = inp_boxes.dim_size(2);
472 
473   const float* scores_data =
474       const_cast<float*>(inp_scores.flat<float>().data());
475   const float* boxes_data = const_cast<float*>(inp_boxes.flat<float>().data());
476 
477   int boxes_per_batch = num_boxes * q * 4;
478   int scores_per_batch = num_boxes * num_classes;
479   const int size_per_class = std::min(max_size_per_class, num_boxes);
480   std::vector<std::vector<ResultCandidate>> result_candidate_vec(
481       num_batches,
482       std::vector<ResultCandidate>(size_per_class * num_classes,
483                                    {-1, -1.0, -1, {0.0, 0.0, 0.0, 0.0}}));
484 
485   // [num_batches, per_batch_size * 4]
486   std::vector<std::vector<float>> nmsed_boxes(num_batches);
487   // [num_batches, per_batch_size]
488   std::vector<std::vector<float>> nmsed_scores(num_batches);
489   // [num_batches, per_batch_size]
490   std::vector<std::vector<float>> nmsed_classes(num_batches);
491   // [num_batches]
492   std::vector<int> final_valid_detections(num_batches);
493 
494   auto shard_nms = [&](int begin, int end) {
495     for (int idx = begin; idx < end; ++idx) {
496       int batch_idx = idx / num_classes;
497       int class_idx = idx % num_classes;
498       DoNMSPerClass(batch_idx, class_idx,
499                     boxes_data + boxes_per_batch * batch_idx,
500                     scores_data + scores_per_batch * batch_idx, num_boxes, q,
501                     num_classes, size_per_class, score_threshold, iou_threshold,
502                     result_candidate_vec[batch_idx]);
503     }
504   };
505 
506   int length = num_batches * num_classes;
507   // Input data boxes_data, scores_data
508   int input_bytes = num_boxes * 10 * sizeof(float);
509   int output_bytes = num_boxes * 10 * sizeof(float);
510   int compute_cycles = Eigen::TensorOpCost::AddCost<int>() * num_boxes * 14 +
511                        Eigen::TensorOpCost::MulCost<int>() * num_boxes * 9 +
512                        Eigen::TensorOpCost::MulCost<float>() * num_boxes * 9 +
513                        Eigen::TensorOpCost::AddCost<float>() * num_boxes * 8;
514   // The cost here is not the actual number of cycles, but rather a set of
515   // hand-tuned numbers that seem to work best.
516   const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);
517   const CPUDevice& d = context->eigen_device<CPUDevice>();
518   d.parallelFor(length, cost, shard_nms);
519 
520   int per_batch_size = total_size_per_batch;
521   if (pad_per_class) {
522     per_batch_size =
523         std::min(total_size_per_batch, max_size_per_class * num_classes);
524   }
525 
526   Tensor* valid_detections_t = nullptr;
527   TensorShape valid_detections_shape({num_batches});
528   OP_REQUIRES_OK(context, context->allocate_output(3, valid_detections_shape,
529                                                    &valid_detections_t));
530   auto valid_detections_flat = valid_detections_t->template flat<int>();
531 
532   auto shard_result = [&](int begin, int end) {
533     for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
534       SelectResultPerBatch(
535           nmsed_boxes[batch_idx], nmsed_scores[batch_idx],
536           nmsed_classes[batch_idx], result_candidate_vec[batch_idx],
537           final_valid_detections, batch_idx, total_size_per_batch,
538           pad_per_class, max_size_per_class * num_classes, clip_boxes,
539           per_batch_size);
540       valid_detections_flat(batch_idx) = final_valid_detections[batch_idx];
541     }
542   };
543   length = num_batches;
544   // Input data boxes_data, scores_data
545   input_bytes =
546       num_boxes * 10 * sizeof(float) + per_batch_size * 6 * sizeof(float);
547   output_bytes =
548       num_boxes * 5 * sizeof(float) + per_batch_size * 6 * sizeof(float);
549   compute_cycles = Eigen::TensorOpCost::AddCost<int>() * num_boxes * 5 +
550                    Eigen::TensorOpCost::AddCost<float>() * num_boxes * 5;
551   // The cost here is not the actual number of cycles, but rather a set of
552   // hand-tuned numbers that seem to work best.
553   const Eigen::TensorOpCost cost_result(input_bytes, output_bytes,
554                                         compute_cycles);
555   d.parallelFor(length, cost_result, shard_result);
556 
557   Tensor* nmsed_boxes_t = nullptr;
558   TensorShape boxes_shape({num_batches, per_batch_size, 4});
559   OP_REQUIRES_OK(context,
560                  context->allocate_output(0, boxes_shape, &nmsed_boxes_t));
561   auto nmsed_boxes_flat = nmsed_boxes_t->template flat<float>();
562 
563   Tensor* nmsed_scores_t = nullptr;
564   TensorShape scores_shape({num_batches, per_batch_size});
565   OP_REQUIRES_OK(context,
566                  context->allocate_output(1, scores_shape, &nmsed_scores_t));
567   auto nmsed_scores_flat = nmsed_scores_t->template flat<float>();
568 
569   Tensor* nmsed_classes_t = nullptr;
570   OP_REQUIRES_OK(context,
571                  context->allocate_output(2, scores_shape, &nmsed_classes_t));
572   auto nmsed_classes_flat = nmsed_classes_t->template flat<float>();
573 
574   auto shard_copy_result = [&](int begin, int end) {
575     for (int idx = begin; idx < end; ++idx) {
576       int batch_idx = idx / per_batch_size;
577       int j = idx % per_batch_size;
578       nmsed_scores_flat(idx) = nmsed_scores[batch_idx][j];
579       nmsed_classes_flat(idx) = nmsed_classes[batch_idx][j];
580       for (int k = 0; k < 4; ++k) {
581         nmsed_boxes_flat(idx * 4 + k) = nmsed_boxes[batch_idx][j * 4 + k];
582       }
583     }
584   };
585   length = num_batches * per_batch_size;
586   // Input data boxes_data, scores_data
587   input_bytes = 6 * sizeof(float);
588   output_bytes = 6 * sizeof(float);
589   compute_cycles = Eigen::TensorOpCost::AddCost<int>() * 2 +
590                    Eigen::TensorOpCost::MulCost<int>() * 2 +
591                    Eigen::TensorOpCost::DivCost<float>() * 2;
592   const Eigen::TensorOpCost cost_copy_result(input_bytes, output_bytes,
593                                              compute_cycles);
594   d.parallelFor(length, cost_copy_result, shard_copy_result);
595 }
596 
597 // Extract a scalar of type T from a tensor, with correct type checking.
598 // This is necessary because several of the kernels here assume
599 // T == T_threshold.
600 template <typename T>
GetScalar(const Tensor & tensor)601 T GetScalar(const Tensor& tensor) {
602   switch (tensor.dtype()) {
603     case DT_FLOAT:
604       return static_cast<T>(tensor.scalar<float>()());
605     case DT_DOUBLE:
606       return static_cast<T>(tensor.scalar<double>()());
607     case DT_BFLOAT16:
608       return static_cast<T>(tensor.scalar<Eigen::bfloat16>()());
609     case DT_HALF:
610       return static_cast<T>(tensor.scalar<Eigen::half>()());
611     default:
612       DCHECK(false) << "Unsupported type " << tensor.dtype();
613       break;
614   }
615   return static_cast<T>(0);
616 }
617 
618 }  // namespace
619 
620 template <typename Device>
621 class NonMaxSuppressionOp : public OpKernel {
622  public:
NonMaxSuppressionOp(OpKernelConstruction * context)623   explicit NonMaxSuppressionOp(OpKernelConstruction* context)
624       : OpKernel(context) {
625     OP_REQUIRES_OK(context, context->GetAttr("iou_threshold", &iou_threshold_));
626   }
627 
Compute(OpKernelContext * context)628   void Compute(OpKernelContext* context) override {
629     // boxes: [num_boxes, 4]
630     const Tensor& boxes = context->input(0);
631     // scores: [num_boxes]
632     const Tensor& scores = context->input(1);
633     // max_output_size: scalar
634     const Tensor& max_output_size = context->input(2);
635     OP_REQUIRES(
636         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
637         errors::InvalidArgument("max_output_size must be 0-D, got shape ",
638                                 max_output_size.shape().DebugString()));
639 
640     OP_REQUIRES(context, iou_threshold_ >= 0 && iou_threshold_ <= 1,
641                 errors::InvalidArgument("iou_threshold must be in [0, 1]"));
642     int num_boxes = 0;
643     ParseAndCheckBoxSizes(context, boxes, &num_boxes);
644     CheckScoreSizes(context, num_boxes, scores);
645     if (!context->status().ok()) {
646       return;
647     }
648     auto similarity_fn = CreateIOUSimilarityFn<float>(boxes);
649 
650     const float score_threshold_val = std::numeric_limits<float>::lowest();
651     const float dummy_soft_nms_sigma = static_cast<float>(0.0);
652     DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
653                                  iou_threshold_, score_threshold_val,
654                                  dummy_soft_nms_sigma, similarity_fn);
655   }
656 
657  private:
658   float iou_threshold_;
659 };
660 
661 template <typename Device, typename T>
662 class NonMaxSuppressionV2Op : public OpKernel {
663  public:
NonMaxSuppressionV2Op(OpKernelConstruction * context)664   explicit NonMaxSuppressionV2Op(OpKernelConstruction* context)
665       : OpKernel(context) {}
666 
Compute(OpKernelContext * context)667   void Compute(OpKernelContext* context) override {
668     // boxes: [num_boxes, 4]
669     const Tensor& boxes = context->input(0);
670     // scores: [num_boxes]
671     const Tensor& scores = context->input(1);
672     // max_output_size: scalar
673     const Tensor& max_output_size = context->input(2);
674     OP_REQUIRES(
675         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
676         errors::InvalidArgument("max_output_size must be 0-D, got shape ",
677                                 max_output_size.shape().DebugString()));
678     // iou_threshold: scalar
679     const Tensor& iou_threshold = context->input(3);
680     OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
681                 errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
682                                         iou_threshold.shape().DebugString()));
683     const T iou_threshold_val = GetScalar<T>(iou_threshold);
684 
685     OP_REQUIRES(context,
686                 iou_threshold_val >= static_cast<T>(0.0) &&
687                     iou_threshold_val <= static_cast<T>(1.0),
688                 errors::InvalidArgument("iou_threshold must be in [0, 1]"));
689     int num_boxes = 0;
690     ParseAndCheckBoxSizes(context, boxes, &num_boxes);
691     CheckScoreSizes(context, num_boxes, scores);
692     if (!context->status().ok()) {
693       return;
694     }
695     auto similarity_fn = CreateIOUSimilarityFn<T>(boxes);
696 
697     const T score_threshold_val = std::numeric_limits<T>::lowest();
698     const T dummy_soft_nms_sigma = static_cast<T>(0.0);
699     DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
700                              iou_threshold_val, score_threshold_val,
701                              dummy_soft_nms_sigma, similarity_fn);
702   }
703 };
704 
705 template <typename Device, typename T>
706 class NonMaxSuppressionV3Op : public OpKernel {
707  public:
NonMaxSuppressionV3Op(OpKernelConstruction * context)708   explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
709       : OpKernel(context) {}
710 
Compute(OpKernelContext * context)711   void Compute(OpKernelContext* context) override {
712     // boxes: [num_boxes, 4]
713     const Tensor& boxes = context->input(0);
714     // scores: [num_boxes]
715     const Tensor& scores = context->input(1);
716     // max_output_size: scalar
717     const Tensor& max_output_size = context->input(2);
718     OP_REQUIRES(
719         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
720         errors::InvalidArgument("max_output_size must be 0-D, got shape ",
721                                 max_output_size.shape().DebugString(),
722                                 " (Shape must be rank 0 but is ", "rank ",
723                                 max_output_size.dims(), ")"));
724     // iou_threshold: scalar
725     const Tensor& iou_threshold = context->input(3);
726     OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
727                 errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
728                                         iou_threshold.shape().DebugString(),
729                                         " (Shape must be rank 0 but is rank ",
730                                         iou_threshold.dims(), ")"));
731     const T iou_threshold_val = GetScalar<T>(iou_threshold);
732     OP_REQUIRES(context,
733                 iou_threshold_val >= static_cast<T>(0.0) &&
734                     iou_threshold_val <= static_cast<T>(1.0),
735                 errors::InvalidArgument("iou_threshold must be in [0, 1]"));
736     // score_threshold: scalar
737     const Tensor& score_threshold = context->input(4);
738     OP_REQUIRES(
739         context, TensorShapeUtils::IsScalar(score_threshold.shape()),
740         errors::InvalidArgument("score_threshold must be 0-D, got shape ",
741                                 score_threshold.shape().DebugString()));
742     const T score_threshold_val = GetScalar<T>(score_threshold);
743 
744     int num_boxes = 0;
745     ParseAndCheckBoxSizes(context, boxes, &num_boxes);
746     CheckScoreSizes(context, num_boxes, scores);
747     if (!context->status().ok()) {
748       return;
749     }
750 
751     auto similarity_fn = CreateIOUSimilarityFn<T>(boxes);
752 
753     const T dummy_soft_nms_sigma = static_cast<T>(0.0);
754     DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
755                              iou_threshold_val, score_threshold_val,
756                              dummy_soft_nms_sigma, similarity_fn);
757   }
758 };
759 
760 template <typename Device, typename T>
761 class NonMaxSuppressionV4Op : public OpKernel {
762  public:
NonMaxSuppressionV4Op(OpKernelConstruction * context)763   explicit NonMaxSuppressionV4Op(OpKernelConstruction* context)
764       : OpKernel(context) {
765     OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
766                                              &pad_to_max_output_size_));
767   }
768 
Compute(OpKernelContext * context)769   void Compute(OpKernelContext* context) override {
770     // boxes: [num_boxes, 4]
771     const Tensor& boxes = context->input(0);
772     // scores: [num_boxes]
773     const Tensor& scores = context->input(1);
774     // max_output_size: scalar
775     const Tensor& max_output_size = context->input(2);
776     OP_REQUIRES(
777         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
778         errors::InvalidArgument("max_output_size must be 0-D, got shape ",
779                                 max_output_size.shape().DebugString()));
780     // iou_threshold: scalar
781     const Tensor& iou_threshold = context->input(3);
782     OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
783                 errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
784                                         iou_threshold.shape().DebugString()));
785     const T iou_threshold_val = GetScalar<T>(iou_threshold);
786     OP_REQUIRES(context,
787                 iou_threshold_val >= static_cast<T>(0.0) &&
788                     iou_threshold_val <= static_cast<T>(1.0),
789                 errors::InvalidArgument("iou_threshold must be in [0, 1]"));
790     // score_threshold: scalar
791     const Tensor& score_threshold = context->input(4);
792     OP_REQUIRES(
793         context, TensorShapeUtils::IsScalar(score_threshold.shape()),
794         errors::InvalidArgument("score_threshold must be 0-D, got shape ",
795                                 score_threshold.shape().DebugString()));
796     const T score_threshold_val = GetScalar<T>(score_threshold);
797 
798     int num_boxes = 0;
799     ParseAndCheckBoxSizes(context, boxes, &num_boxes);
800     CheckScoreSizes(context, num_boxes, scores);
801     if (!context->status().ok()) {
802       return;
803     }
804 
805     auto similarity_fn = CreateIOUSimilarityFn<T>(boxes);
806     int num_valid_outputs;
807 
808     bool return_scores_tensor_ = false;
809     const T dummy_soft_nms_sigma = static_cast<T>(0.0);
810     DoNonMaxSuppressionOp<T>(
811         context, scores, num_boxes, max_output_size, iou_threshold_val,
812         score_threshold_val, dummy_soft_nms_sigma, similarity_fn,
813         return_scores_tensor_, pad_to_max_output_size_, &num_valid_outputs);
814     if (!context->status().ok()) {
815       return;
816     }
817 
818     // Allocate scalar output tensor for number of indices computed.
819     Tensor* num_outputs_t = nullptr;
820     OP_REQUIRES_OK(context, context->allocate_output(
821                                 1, tensorflow::TensorShape{}, &num_outputs_t));
822     num_outputs_t->scalar<int32>().setConstant(num_valid_outputs);
823   }
824 
825  private:
826   bool pad_to_max_output_size_;
827 };
828 
829 template <typename Device, typename T>
830 class NonMaxSuppressionV5Op : public OpKernel {
831  public:
NonMaxSuppressionV5Op(OpKernelConstruction * context)832   explicit NonMaxSuppressionV5Op(OpKernelConstruction* context)
833       : OpKernel(context) {
834     OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
835                                              &pad_to_max_output_size_));
836   }
837 
Compute(OpKernelContext * context)838   void Compute(OpKernelContext* context) override {
839     // boxes: [num_boxes, 4]
840     const Tensor& boxes = context->input(0);
841     // scores: [num_boxes]
842     const Tensor& scores = context->input(1);
843     // max_output_size: scalar
844     const Tensor& max_output_size = context->input(2);
845     OP_REQUIRES(
846         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
847         errors::InvalidArgument("max_output_size must be 0-D, got shape ",
848                                 max_output_size.shape().DebugString()));
849     // iou_threshold: scalar
850     const Tensor& iou_threshold = context->input(3);
851     OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
852                 errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
853                                         iou_threshold.shape().DebugString()));
854     const T iou_threshold_val = iou_threshold.scalar<T>()();
855     OP_REQUIRES(context,
856                 iou_threshold_val >= static_cast<T>(0.0) &&
857                     iou_threshold_val <= static_cast<T>(1.0),
858                 errors::InvalidArgument("iou_threshold must be in [0, 1]"));
859     // score_threshold: scalar
860     const Tensor& score_threshold = context->input(4);
861     OP_REQUIRES(
862         context, TensorShapeUtils::IsScalar(score_threshold.shape()),
863         errors::InvalidArgument("score_threshold must be 0-D, got shape ",
864                                 score_threshold.shape().DebugString()));
865     const T score_threshold_val = score_threshold.scalar<T>()();
866 
867     // soft_nms_sigma: scalar
868     const Tensor& soft_nms_sigma = context->input(5);
869     OP_REQUIRES(
870         context, TensorShapeUtils::IsScalar(soft_nms_sigma.shape()),
871         errors::InvalidArgument("soft_nms_sigma must be 0-D, got shape ",
872                                 soft_nms_sigma.shape().DebugString()));
873     const T soft_nms_sigma_val = soft_nms_sigma.scalar<T>()();
874     OP_REQUIRES(context, soft_nms_sigma_val >= static_cast<T>(0.0),
875                 errors::InvalidArgument("soft_nms_sigma_val must be >= 0"));
876 
877     int num_boxes = 0;
878     ParseAndCheckBoxSizes(context, boxes, &num_boxes);
879     CheckScoreSizes(context, num_boxes, scores);
880     if (!context->status().ok()) {
881       return;
882     }
883 
884     auto similarity_fn = CreateIOUSimilarityFn<T>(boxes);
885     int num_valid_outputs;
886 
887     // For NonMaxSuppressionV5Op, we always return a second output holding
888     // corresponding scores, so `return_scores_tensor` should never be false.
889     const bool return_scores_tensor_ = true;
890     DoNonMaxSuppressionOp<T>(
891         context, scores, num_boxes, max_output_size, iou_threshold_val,
892         score_threshold_val, soft_nms_sigma_val, similarity_fn,
893         return_scores_tensor_, pad_to_max_output_size_, &num_valid_outputs);
894     if (!context->status().ok()) {
895       return;
896     }
897 
898     // Allocate scalar output tensor for number of indices computed.
899     Tensor* num_outputs_t = nullptr;
900     OP_REQUIRES_OK(context, context->allocate_output(
901                                 2, tensorflow::TensorShape{}, &num_outputs_t));
902     num_outputs_t->scalar<int32>().setConstant(num_valid_outputs);
903   }
904 
905  private:
906   bool pad_to_max_output_size_;
907 };
908 
909 template <typename Device>
910 class NonMaxSuppressionWithOverlapsOp : public OpKernel {
911  public:
NonMaxSuppressionWithOverlapsOp(OpKernelConstruction * context)912   explicit NonMaxSuppressionWithOverlapsOp(OpKernelConstruction* context)
913       : OpKernel(context) {}
914 
Compute(OpKernelContext * context)915   void Compute(OpKernelContext* context) override {
916     // overlaps: [num_boxes, num_boxes]
917     const Tensor& overlaps = context->input(0);
918     // scores: [num_boxes]
919     const Tensor& scores = context->input(1);
920     // max_output_size: scalar
921     const Tensor& max_output_size = context->input(2);
922     OP_REQUIRES(
923         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
924         errors::InvalidArgument("max_output_size must be 0-D, got shape ",
925                                 max_output_size.shape().DebugString()));
926     // overlap_threshold: scalar
927     const Tensor& overlap_threshold = context->input(3);
928     OP_REQUIRES(
929         context, TensorShapeUtils::IsScalar(overlap_threshold.shape()),
930         errors::InvalidArgument("overlap_threshold must be 0-D, got shape ",
931                                 overlap_threshold.shape().DebugString()));
932     const float overlap_threshold_val = overlap_threshold.scalar<float>()();
933 
934     // score_threshold: scalar
935     const Tensor& score_threshold = context->input(4);
936     OP_REQUIRES(
937         context, TensorShapeUtils::IsScalar(score_threshold.shape()),
938         errors::InvalidArgument("score_threshold must be 0-D, got shape ",
939                                 score_threshold.shape().DebugString()));
940     const float score_threshold_val = score_threshold.scalar<float>()();
941 
942     int num_boxes = 0;
943     ParseAndCheckOverlapSizes(context, overlaps, &num_boxes);
944     CheckScoreSizes(context, num_boxes, scores);
945     if (!context->status().ok()) {
946       return;
947     }
948     auto similarity_fn = CreateOverlapSimilarityFn<float>(overlaps);
949 
950     const float dummy_soft_nms_sigma = static_cast<float>(0.0);
951     DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
952                                  overlap_threshold_val, score_threshold_val,
953                                  dummy_soft_nms_sigma, similarity_fn);
954   }
955 };
956 
957 template <typename Device>
958 class CombinedNonMaxSuppressionOp : public OpKernel {
959  public:
CombinedNonMaxSuppressionOp(OpKernelConstruction * context)960   explicit CombinedNonMaxSuppressionOp(OpKernelConstruction* context)
961       : OpKernel(context) {
962     OP_REQUIRES_OK(context, context->GetAttr("pad_per_class", &pad_per_class_));
963     OP_REQUIRES_OK(context, context->GetAttr("clip_boxes", &clip_boxes_));
964   }
965 
Compute(OpKernelContext * context)966   void Compute(OpKernelContext* context) override {
967     // boxes: [batch_size, num_anchors, q, 4]
968     const Tensor& boxes = context->input(0);
969     // scores: [batch_size, num_anchors, num_classes]
970     const Tensor& scores = context->input(1);
971     OP_REQUIRES(
972         context, (boxes.dim_size(0) == scores.dim_size(0)),
973         errors::InvalidArgument("boxes and scores must have same batch size"));
974 
975     // max_output_size: scalar
976     const Tensor& max_output_size = context->input(2);
977     OP_REQUIRES(
978         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
979         errors::InvalidArgument("max_size_per_class must be 0-D, got shape ",
980                                 max_output_size.shape().DebugString()));
981     const int max_size_per_class = max_output_size.scalar<int>()();
982     OP_REQUIRES(context, max_size_per_class > 0,
983                 errors::InvalidArgument("max_size_per_class must be positive"));
984     // max_total_size: scalar
985     const Tensor& max_total_size = context->input(3);
986     OP_REQUIRES(
987         context, TensorShapeUtils::IsScalar(max_total_size.shape()),
988         errors::InvalidArgument("max_total_size must be 0-D, got shape ",
989                                 max_total_size.shape().DebugString()));
990     const int max_total_size_per_batch = max_total_size.scalar<int>()();
991     OP_REQUIRES(context, max_total_size_per_batch > 0,
992                 errors::InvalidArgument("max_total_size must be > 0"));
993     // Throw warning when `max_total_size` is too large as it may cause OOM.
994     if (max_total_size_per_batch > pow(10, 6)) {
995       LOG(WARNING) << "Detected a large value for `max_total_size`. This may "
996                    << "cause OOM error. (max_total_size: "
997                    << max_total_size.scalar<int>()() << ")";
998     }
999     // iou_threshold: scalar
1000     const Tensor& iou_threshold = context->input(4);
1001     OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
1002                 errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
1003                                         iou_threshold.shape().DebugString()));
1004     const float iou_threshold_val = iou_threshold.scalar<float>()();
1005 
1006     // score_threshold: scalar
1007     const Tensor& score_threshold = context->input(5);
1008     OP_REQUIRES(
1009         context, TensorShapeUtils::IsScalar(score_threshold.shape()),
1010         errors::InvalidArgument("score_threshold must be 0-D, got shape ",
1011                                 score_threshold.shape().DebugString()));
1012     const float score_threshold_val = score_threshold.scalar<float>()();
1013 
1014     OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
1015                 errors::InvalidArgument("iou_threshold must be in [0, 1]"));
1016     int num_boxes = 0;
1017     const int num_classes = scores.dim_size(2);
1018     ParseAndCheckCombinedNMSBoxSizes(context, boxes, &num_boxes, num_classes);
1019     CheckCombinedNMSScoreSizes(context, num_boxes, scores);
1020 
1021     if (!context->status().ok()) {
1022       return;
1023     }
1024     BatchedNonMaxSuppressionOp(context, boxes, scores, num_boxes,
1025                                max_size_per_class, max_total_size_per_batch,
1026                                score_threshold_val, iou_threshold_val,
1027                                pad_per_class_, clip_boxes_);
1028   }
1029 
1030  private:
1031   bool pad_per_class_;
1032   bool clip_boxes_;
1033 };
1034 
1035 REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
1036                         NonMaxSuppressionOp<CPUDevice>);
1037 
1038 REGISTER_KERNEL_BUILDER(
1039     Name("NonMaxSuppressionV2").TypeConstraint<float>("T").Device(DEVICE_CPU),
1040     NonMaxSuppressionV2Op<CPUDevice, float>);
1041 REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2")
1042                             .TypeConstraint<Eigen::half>("T")
1043                             .Device(DEVICE_CPU),
1044                         NonMaxSuppressionV2Op<CPUDevice, Eigen::half>);
1045 
1046 REGISTER_KERNEL_BUILDER(
1047     Name("NonMaxSuppressionV3").TypeConstraint<float>("T").Device(DEVICE_CPU),
1048     NonMaxSuppressionV3Op<CPUDevice, float>);
1049 REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3")
1050                             .TypeConstraint<Eigen::half>("T")
1051                             .Device(DEVICE_CPU),
1052                         NonMaxSuppressionV3Op<CPUDevice, Eigen::half>);
1053 
1054 REGISTER_KERNEL_BUILDER(
1055     Name("NonMaxSuppressionV4").TypeConstraint<float>("T").Device(DEVICE_CPU),
1056     NonMaxSuppressionV4Op<CPUDevice, float>);
1057 REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4")
1058                             .TypeConstraint<Eigen::half>("T")
1059                             .Device(DEVICE_CPU),
1060                         NonMaxSuppressionV4Op<CPUDevice, Eigen::half>);
1061 
1062 REGISTER_KERNEL_BUILDER(
1063     Name("NonMaxSuppressionV5").TypeConstraint<float>("T").Device(DEVICE_CPU),
1064     NonMaxSuppressionV5Op<CPUDevice, float>);
1065 REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV5")
1066                             .TypeConstraint<Eigen::half>("T")
1067                             .Device(DEVICE_CPU),
1068                         NonMaxSuppressionV5Op<CPUDevice, Eigen::half>);
1069 
1070 REGISTER_KERNEL_BUILDER(
1071     Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
1072     NonMaxSuppressionWithOverlapsOp<CPUDevice>);
1073 
1074 REGISTER_KERNEL_BUILDER(Name("CombinedNonMaxSuppression").Device(DEVICE_CPU),
1075                         CombinedNonMaxSuppressionOp<CPUDevice>);
1076 
1077 }  // namespace tensorflow
1078