/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/image_ops.cc

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/image/non_max_suppression_op.h"

#include <algorithm>
#include <cmath>
#include <deque>
#include <functional>
#include <limits>
#include <queue>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace {

typedef Eigen::ThreadPoolDevice CPUDevice;
static inline void CheckScoreSizes(OpKernelContext* context, int num_boxes,
                                   const Tensor& scores) {
  // The shape of 'scores' is [num_boxes]
  OP_REQUIRES(context, scores.dims() == 1,
              errors::InvalidArgument("scores must be 1-D, got shape ",
                                      scores.shape().DebugString(),
                                      " (Shape must be rank 1 but is rank ",
                                      scores.dims(), ")"));
  OP_REQUIRES(
      context, scores.dim_size(0) == num_boxes,
      errors::InvalidArgument("scores has incompatible shape (Dimensions must "
                              "be equal, but are ",
                              num_boxes, " and ", scores.dim_size(0), ")"));
}

static inline void ParseAndCheckOverlapSizes(OpKernelContext* context,
                                             const Tensor& overlaps,
                                             int* num_boxes) {
  // The shape of 'overlaps' is [num_boxes, num_boxes]
  OP_REQUIRES(context, overlaps.dims() == 2,
              errors::InvalidArgument("overlaps must be 2-D, got shape ",
                                      overlaps.shape().DebugString()));

  *num_boxes = overlaps.dim_size(0);
  OP_REQUIRES(context, overlaps.dim_size(1) == *num_boxes,
              errors::InvalidArgument("overlaps must be square, got shape ",
                                      overlaps.shape().DebugString()));
}

static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
                                         const Tensor& boxes, int* num_boxes) {
  // The shape of 'boxes' is [num_boxes, 4]
  OP_REQUIRES(context, boxes.dims() == 2,
              errors::InvalidArgument("boxes must be 2-D, got shape ",
                                      boxes.shape().DebugString(),
                                      " (Shape must be rank 2 but is rank ",
                                      boxes.dims(), ")"));
  *num_boxes = boxes.dim_size(0);
  OP_REQUIRES(context, boxes.dim_size(1) == 4,
              errors::InvalidArgument("boxes must have 4 columns (Dimension "
                                      "must be 4 but is ",
                                      boxes.dim_size(1), ")"));
}

static inline void CheckCombinedNMSScoreSizes(OpKernelContext* context,
                                              int num_boxes,
                                              const Tensor& scores) {
  // The shape of 'scores' is [batch_size, num_boxes, num_classes]
  OP_REQUIRES(context, scores.dims() == 3,
              errors::InvalidArgument("scores must be 3-D, got shape ",
                                      scores.shape().DebugString()));
  OP_REQUIRES(context, scores.dim_size(1) == num_boxes,
              errors::InvalidArgument("scores has incompatible shape"));
}

static inline void ParseAndCheckCombinedNMSBoxSizes(OpKernelContext* context,
                                                    const Tensor& boxes,
                                                    int* num_boxes,
                                                    const int num_classes) {
  // The shape of 'boxes' is [batch_size, num_boxes, q, 4]
  OP_REQUIRES(context, boxes.dims() == 4,
              errors::InvalidArgument("boxes must be 4-D, got shape ",
                                      boxes.shape().DebugString()));

  bool box_check = boxes.dim_size(2) == 1 || boxes.dim_size(2) == num_classes;
  OP_REQUIRES(context, box_check,
              errors::InvalidArgument(
                  "third dimension of boxes must be either 1 or num_classes"));
  *num_boxes = boxes.dim_size(1);
  OP_REQUIRES(context, boxes.dim_size(3) == 4,
              errors::InvalidArgument("boxes must have 4 columns"));
}

// Returns the intersection-over-union (IOU) overlap between boxes i and j.
template <typename T>
static inline float IOU(typename TTypes<T, 2>::ConstTensor boxes, int i,
                        int j) {
  const float ymin_i = Eigen::numext::mini<float>(
      static_cast<float>(boxes(i, 0)), static_cast<float>(boxes(i, 2)));
  const float xmin_i = Eigen::numext::mini<float>(
      static_cast<float>(boxes(i, 1)), static_cast<float>(boxes(i, 3)));
  const float ymax_i = Eigen::numext::maxi<float>(
      static_cast<float>(boxes(i, 0)), static_cast<float>(boxes(i, 2)));
  const float xmax_i = Eigen::numext::maxi<float>(
      static_cast<float>(boxes(i, 1)), static_cast<float>(boxes(i, 3)));
  const float ymin_j = Eigen::numext::mini<float>(
      static_cast<float>(boxes(j, 0)), static_cast<float>(boxes(j, 2)));
  const float xmin_j = Eigen::numext::mini<float>(
      static_cast<float>(boxes(j, 1)), static_cast<float>(boxes(j, 3)));
  const float ymax_j = Eigen::numext::maxi<float>(
      static_cast<float>(boxes(j, 0)), static_cast<float>(boxes(j, 2)));
  const float xmax_j = Eigen::numext::maxi<float>(
      static_cast<float>(boxes(j, 1)), static_cast<float>(boxes(j, 3)));
  const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
  const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
  if (area_i <= 0 || area_j <= 0) {
    return 0.0;
  }
  const float intersection_ymin = Eigen::numext::maxi<float>(ymin_i, ymin_j);
  const float intersection_xmin = Eigen::numext::maxi<float>(xmin_i, xmin_j);
  const float intersection_ymax = Eigen::numext::mini<float>(ymax_i, ymax_j);
  const float intersection_xmax = Eigen::numext::mini<float>(xmax_i, xmax_j);
  const float intersection_area =
      Eigen::numext::maxi<float>(intersection_ymax - intersection_ymin, 0.0) *
      Eigen::numext::maxi<float>(intersection_xmax - intersection_xmin, 0.0);
  return intersection_area / (area_i + area_j - intersection_area);
}

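// Overload of IOU for a raw array of box coordinates. Here `i` and `j` are
// flat offsets of each box's first coordinate (i.e., the box index already
// multiplied by 4), as computed by DoNMSPerClass below.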
static inline float IOU(const float* boxes, int i, int j) {
  const float ymin_i = Eigen::numext::mini<float>(boxes[i], boxes[i + 2]);
  const float xmin_i = Eigen::numext::mini<float>(boxes[i + 1], boxes[i + 3]);
  const float ymax_i = Eigen::numext::maxi<float>(boxes[i], boxes[i + 2]);
  const float xmax_i = Eigen::numext::maxi<float>(boxes[i + 1], boxes[i + 3]);
  const float ymin_j = Eigen::numext::mini<float>(boxes[j], boxes[j + 2]);
  const float xmin_j = Eigen::numext::mini<float>(boxes[j + 1], boxes[j + 3]);
  const float ymax_j = Eigen::numext::maxi<float>(boxes[j], boxes[j + 2]);
  const float xmax_j = Eigen::numext::maxi<float>(boxes[j + 1], boxes[j + 3]);
  const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
  const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
  if (area_i <= 0 || area_j <= 0) {
    return 0.0;
  }
  const float intersection_ymin = Eigen::numext::maxi<float>(ymin_i, ymin_j);
  const float intersection_xmin = Eigen::numext::maxi<float>(xmin_i, xmin_j);
  const float intersection_ymax = Eigen::numext::mini<float>(ymax_i, ymax_j);
  const float intersection_xmax = Eigen::numext::mini<float>(xmax_i, xmax_j);
  const float intersection_area =
      Eigen::numext::maxi<float>(intersection_ymax - intersection_ymin, 0.0) *
      Eigen::numext::maxi<float>(intersection_xmax - intersection_xmin, 0.0);
  return intersection_area / (area_i + area_j - intersection_area);
}

template <typename T>
static inline T Overlap(typename TTypes<T, 2>::ConstTensor overlaps, int i,
                        int j) {
  return overlaps(i, j);
}

template <typename T>
static inline std::function<float(int, int)> CreateIOUSimilarityFn(
    const Tensor& boxes) {
  typename TTypes<T, 2>::ConstTensor boxes_data = boxes.tensor<T, 2>();
  return std::bind(&IOU<T>, boxes_data, std::placeholders::_1,
                   std::placeholders::_2);
}

template <typename T>
static inline std::function<T(int, int)> CreateOverlapSimilarityFn(
    const Tensor& overlaps) {
  typename TTypes<T, 2>::ConstTensor overlaps_data = overlaps.tensor<T, 2>();
  return std::bind(&Overlap<T>, overlaps_data, std::placeholders::_1,
                   std::placeholders::_2);
}

template <typename T>
void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores,
                           int num_boxes, const Tensor& max_output_size,
                           const T similarity_threshold,
                           const T score_threshold, const T soft_nms_sigma,
                           const std::function<float(int, int)>& similarity_fn,
                           bool return_scores_tensor = false,
                           bool pad_to_max_output_size = false,
                           int* ptr_num_valid_outputs = nullptr) {
  const int output_size = max_output_size.scalar<int>()();
  OP_REQUIRES(context, output_size >= 0,
              errors::InvalidArgument("output size must be non-negative"));

  std::vector<T> scores_data(num_boxes);
  std::copy_n(scores.flat<T>().data(), num_boxes, scores_data.begin());

  // Data structure for a selection candidate in NMS.
  struct Candidate {
    int box_index;
    T score;
    int suppress_begin_index;
  };

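  // Order candidates in a max-heap by score; ties are broken in favor of the
  // smaller box index so that the selection order is deterministic.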
  auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
    return ((bs_i.score == bs_j.score) && (bs_i.box_index > bs_j.box_index)) ||
           bs_i.score < bs_j.score;
  };
  std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)>
      candidate_priority_queue(cmp);
  for (int i = 0; i < scores_data.size(); ++i) {
    if (scores_data[i] > score_threshold) {
      candidate_priority_queue.emplace(Candidate({i, scores_data[i], 0}));
    }
  }

  T scale = static_cast<T>(0.0);
  bool is_soft_nms = soft_nms_sigma > static_cast<T>(0.0);
  if (is_soft_nms) {
    scale = static_cast<T>(-0.5) / soft_nms_sigma;
  }

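  // `suppress_weight` implements the Gaussian re-scoring rule of Soft-NMS
  // (Bodla et al., 2017): weight = exp(-0.5 * similarity^2 / soft_nms_sigma).
  // When soft_nms_sigma == 0, `scale` stays 0 and the weight degenerates to
  // 1 (similarity <= threshold) or 0 (similarity > threshold), i.e. hard NMS.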
  auto suppress_weight = [similarity_threshold, scale,
                          is_soft_nms](const T sim) {
    const T weight = Eigen::numext::exp<T>(scale * sim * sim);
    return is_soft_nms || sim <= similarity_threshold ? weight
                                                      : static_cast<T>(0.0);
  };

  std::vector<int> selected;
  std::vector<T> selected_scores;
  float similarity;
  T original_score;
  Candidate next_candidate;

  while (selected.size() < output_size && !candidate_priority_queue.empty()) {
    next_candidate = candidate_priority_queue.top();
    original_score = next_candidate.score;
    candidate_priority_queue.pop();

    // Overlapping boxes are likely to have similar scores, therefore we
    // iterate through the previously selected boxes backwards in order to
    // see if `next_candidate` should be suppressed. We also enforce the
    // property that a candidate can be suppressed by another candidate no
    // more than once, via `suppress_begin_index`, which tracks which
    // previously selected boxes have already been compared against
    // `next_candidate` prior to a given iteration. Those boxes are skipped
    // in the following loop.
    bool should_hard_suppress = false;
    for (int j = static_cast<int>(selected.size()) - 1;
         j >= next_candidate.suppress_begin_index; --j) {
      similarity = similarity_fn(next_candidate.box_index, selected[j]);

      next_candidate.score *= suppress_weight(static_cast<T>(similarity));

      // First decide whether to perform hard suppression.
      if (!is_soft_nms && static_cast<T>(similarity) > similarity_threshold) {
        should_hard_suppress = true;
        break;
      }

      // If next_candidate survives hard suppression, apply soft suppression.
      if (next_candidate.score <= score_threshold) break;
    }
    // If `next_candidate.score` has not dropped below `score_threshold`
    // by this point, then we know that we went through all of the previous
    // selections and can safely update `suppress_begin_index` to
    // `selected.size()`. If, on the other hand, `next_candidate.score`
    // *has* dropped below the score threshold, then since `suppress_weight`
    // always returns values in [0, 1], further suppression by items that were
    // not covered in the above for loop would not have caused the algorithm
    // to select this item. We thus do the same update to
    // `suppress_begin_index`, but really, this element will not be added back
    // into the priority queue.
    next_candidate.suppress_begin_index = selected.size();

    if (!should_hard_suppress) {
      if (next_candidate.score == original_score) {
        // Suppression has not occurred, so select next_candidate.
        selected.push_back(next_candidate.box_index);
        selected_scores.push_back(next_candidate.score);
        continue;
      }
      if (next_candidate.score > score_threshold) {
        // Soft suppression occurred and the current score is still greater
        // than score_threshold; add next_candidate back onto the priority
        // queue.
        candidate_priority_queue.push(next_candidate);
      }
    }
  }

  int num_valid_outputs = selected.size();
  if (pad_to_max_output_size) {
    selected.resize(output_size, 0);
    selected_scores.resize(output_size, static_cast<T>(0));
  }
  if (ptr_num_valid_outputs) {
    *ptr_num_valid_outputs = num_valid_outputs;
  }

  // Allocate output tensors.
  Tensor* output_indices = nullptr;
  TensorShape output_shape({static_cast<int>(selected.size())});
  OP_REQUIRES_OK(context,
                 context->allocate_output(0, output_shape, &output_indices));
  TTypes<int, 1>::Tensor output_indices_data = output_indices->tensor<int, 1>();
  std::copy_n(selected.begin(), selected.size(), output_indices_data.data());

  if (return_scores_tensor) {
    Tensor* output_scores = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, output_shape, &output_scores));
    typename TTypes<T, 1>::Tensor output_scores_data =
        output_scores->tensor<T, 1>();
    std::copy_n(selected_scores.begin(), selected_scores.size(),
                output_scores_data.data());
  }
}

struct ResultCandidate {
  int box_index;
  float score;
  int class_idx;
  float box_coord[4];
};

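// Runs hard NMS for a single (batch, class) pair and records the selected
// boxes in `result_candidate_vec`, which holds `size_per_class` slots per
// class; the k-th selection for class c lands in slot k + size_per_class * c.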
void DoNMSPerClass(int batch_idx, int class_idx, const float* boxes_data,
                   const float* scores_data, int num_boxes, int q,
                   int num_classes, const int size_per_class,
                   const float score_threshold, const float iou_threshold,
                   std::vector<ResultCandidate>& result_candidate_vec) {
  // Data structure for a selection candidate in NMS.
  struct Candidate {
    int box_index;
    float score;
  };
  auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
    return bs_i.score < bs_j.score;
  };
  std::priority_queue<Candidate, std::vector<Candidate>, decltype(cmp)>
      candidate_priority_queue(cmp);
  float temp_score;
  for (int i = 0; i < num_boxes; ++i) {
    temp_score = scores_data[i * num_classes + class_idx];
    if (temp_score > score_threshold) {
      candidate_priority_queue.emplace(Candidate({i, temp_score}));
    }
  }

  std::vector<int> selected;
  Candidate next_candidate;

  int candidate_box_data_idx, selected_box_data_idx, class_box_idx;
  class_box_idx = (q > 1) ? class_idx : 0;

  float iou;
  while (selected.size() < size_per_class &&
         !candidate_priority_queue.empty()) {
    next_candidate = candidate_priority_queue.top();
    candidate_priority_queue.pop();

    candidate_box_data_idx = (next_candidate.box_index * q + class_box_idx) * 4;

    // Overlapping boxes are likely to have similar scores, therefore we
    // iterate through the previously selected boxes backwards in order to
    // see if `next_candidate` should be suppressed.
    bool should_select = true;
    for (int j = selected.size() - 1; j >= 0; --j) {
      selected_box_data_idx = (selected[j] * q + class_box_idx) * 4;
      iou = IOU(boxes_data, candidate_box_data_idx, selected_box_data_idx);
      if (iou > iou_threshold) {
        should_select = false;
        break;
      }
    }

    if (should_select) {
      // Record the selected box. Candidates are popped in descending score
      // order, so the entries for this class are filled sorted by score.
      result_candidate_vec[selected.size() + size_per_class * class_idx] = {
          next_candidate.box_index,
          next_candidate.score,
          class_idx,
          {boxes_data[candidate_box_data_idx],
           boxes_data[candidate_box_data_idx + 1],
           boxes_data[candidate_box_data_idx + 2],
           boxes_data[candidate_box_data_idx + 3]}};
      selected.push_back(next_candidate.box_index);
    }
  }
}

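// Merges the per-class NMS results for one batch element into the final
// output vectors: candidates are sorted by score across classes, the top
// `max_detections` are kept, and the outputs are zero-padded to
// `per_batch_size`.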
void SelectResultPerBatch(std::vector<float>& nmsed_boxes,
                          std::vector<float>& nmsed_scores,
                          std::vector<float>& nmsed_classes,
                          std::vector<ResultCandidate>& result_candidate_vec,
                          std::vector<int>& final_valid_detections,
                          const int batch_idx, int total_size_per_batch,
                          bool pad_per_class, int max_size_per_batch,
                          bool clip_boxes, int per_batch_size) {
  auto rc_cmp = [](const ResultCandidate rc_i, const ResultCandidate rc_j) {
    return rc_i.score > rc_j.score;
  };
  std::sort(result_candidate_vec.begin(), result_candidate_vec.end(), rc_cmp);

  int max_detections = 0;
  int result_candidate_size =
      std::count_if(result_candidate_vec.begin(), result_candidate_vec.end(),
                    [](ResultCandidate rc) { return rc.box_index > -1; });
  // If pad_per_class is false, we always pad to max_total_size.
  if (!pad_per_class) {
    max_detections = std::min(result_candidate_size, total_size_per_batch);
  } else {
    max_detections = std::min(per_batch_size, result_candidate_size);
  }

  final_valid_detections[batch_idx] = max_detections;

  int curr_total_size = max_detections;
  int result_idx = 0;
  // Pick the top max_detections values.
  while (curr_total_size > 0 && result_idx < result_candidate_vec.size()) {
    ResultCandidate next_candidate = result_candidate_vec[result_idx++];
    // Add to the final output vectors.
    if (clip_boxes) {
      const float box_min = 0.0;
      const float box_max = 1.0;
      nmsed_boxes.push_back(
          std::max(std::min(next_candidate.box_coord[0], box_max), box_min));
      nmsed_boxes.push_back(
          std::max(std::min(next_candidate.box_coord[1], box_max), box_min));
      nmsed_boxes.push_back(
          std::max(std::min(next_candidate.box_coord[2], box_max), box_min));
      nmsed_boxes.push_back(
          std::max(std::min(next_candidate.box_coord[3], box_max), box_min));
    } else {
      nmsed_boxes.push_back(next_candidate.box_coord[0]);
      nmsed_boxes.push_back(next_candidate.box_coord[1]);
      nmsed_boxes.push_back(next_candidate.box_coord[2]);
      nmsed_boxes.push_back(next_candidate.box_coord[3]);
    }
    nmsed_scores.push_back(next_candidate.score);
    nmsed_classes.push_back(next_candidate.class_idx);
    curr_total_size--;
  }

  nmsed_boxes.resize(per_batch_size * 4, 0);
  nmsed_scores.resize(per_batch_size, 0);
  nmsed_classes.resize(per_batch_size, 0);
}

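// Implements CombinedNonMaxSuppression on CPU. Per-(batch, class) NMS and the
// per-batch result selection are sharded across the Eigen thread pool with
// parallelFor, using hand-tuned TensorOpCost estimates (see comments below).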
void BatchedNonMaxSuppressionOp(
    OpKernelContext* context, const Tensor& inp_boxes, const Tensor& inp_scores,
    int num_boxes, const int max_size_per_class, const int total_size_per_batch,
    const float score_threshold, const float iou_threshold,
    bool pad_per_class = false, bool clip_boxes = true) {
  const int num_batches = inp_boxes.dim_size(0);
  int num_classes = inp_scores.dim_size(2);
  int q = inp_boxes.dim_size(2);

  const float* scores_data = inp_scores.flat<float>().data();
  const float* boxes_data = inp_boxes.flat<float>().data();

  int boxes_per_batch = num_boxes * q * 4;
  int scores_per_batch = num_boxes * num_classes;
  const int size_per_class = std::min(max_size_per_class, num_boxes);
  std::vector<std::vector<ResultCandidate>> result_candidate_vec(
      num_batches,
      std::vector<ResultCandidate>(size_per_class * num_classes,
                                   {-1, -1.0, -1, {0.0, 0.0, 0.0, 0.0}}));

  // [num_batches, per_batch_size * 4]
  std::vector<std::vector<float>> nmsed_boxes(num_batches);
  // [num_batches, per_batch_size]
  std::vector<std::vector<float>> nmsed_scores(num_batches);
  // [num_batches, per_batch_size]
  std::vector<std::vector<float>> nmsed_classes(num_batches);
  // [num_batches]
  std::vector<int> final_valid_detections(num_batches);

  auto shard_nms = [&](int begin, int end) {
    for (int idx = begin; idx < end; ++idx) {
      int batch_idx = idx / num_classes;
      int class_idx = idx % num_classes;
      DoNMSPerClass(batch_idx, class_idx,
                    boxes_data + boxes_per_batch * batch_idx,
                    scores_data + scores_per_batch * batch_idx, num_boxes, q,
                    num_classes, size_per_class, score_threshold, iou_threshold,
                    result_candidate_vec[batch_idx]);
    }
  };

  int length = num_batches * num_classes;
  // Input data: boxes_data, scores_data.
  int input_bytes = num_boxes * 10 * sizeof(float);
  int output_bytes = num_boxes * 10 * sizeof(float);
  int compute_cycles = Eigen::TensorOpCost::AddCost<int>() * num_boxes * 14 +
                       Eigen::TensorOpCost::MulCost<int>() * num_boxes * 9 +
                       Eigen::TensorOpCost::MulCost<float>() * num_boxes * 9 +
                       Eigen::TensorOpCost::AddCost<float>() * num_boxes * 8;
  // The cost here is not the actual number of cycles, but rather a set of
  // hand-tuned numbers that seem to work best.
  const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);
  const CPUDevice& d = context->eigen_device<CPUDevice>();
  d.parallelFor(length, cost, shard_nms);

  int per_batch_size = total_size_per_batch;
  if (pad_per_class) {
    per_batch_size =
        std::min(total_size_per_batch, max_size_per_class * num_classes);
  }

  Tensor* valid_detections_t = nullptr;
  TensorShape valid_detections_shape({num_batches});
  OP_REQUIRES_OK(context, context->allocate_output(3, valid_detections_shape,
                                                   &valid_detections_t));
  auto valid_detections_flat = valid_detections_t->template flat<int>();

  auto shard_result = [&](int begin, int end) {
    for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
      SelectResultPerBatch(
          nmsed_boxes[batch_idx], nmsed_scores[batch_idx],
          nmsed_classes[batch_idx], result_candidate_vec[batch_idx],
          final_valid_detections, batch_idx, total_size_per_batch,
          pad_per_class, max_size_per_class * num_classes, clip_boxes,
          per_batch_size);
      valid_detections_flat(batch_idx) = final_valid_detections[batch_idx];
    }
  };
  length = num_batches;
  // Input data: boxes_data, scores_data.
  input_bytes =
      num_boxes * 10 * sizeof(float) + per_batch_size * 6 * sizeof(float);
  output_bytes =
      num_boxes * 5 * sizeof(float) + per_batch_size * 6 * sizeof(float);
  compute_cycles = Eigen::TensorOpCost::AddCost<int>() * num_boxes * 5 +
                   Eigen::TensorOpCost::AddCost<float>() * num_boxes * 5;
  // The cost here is not the actual number of cycles, but rather a set of
  // hand-tuned numbers that seem to work best.
  const Eigen::TensorOpCost cost_result(input_bytes, output_bytes,
                                        compute_cycles);
  d.parallelFor(length, cost_result, shard_result);

  Tensor* nmsed_boxes_t = nullptr;
  TensorShape boxes_shape({num_batches, per_batch_size, 4});
  OP_REQUIRES_OK(context,
                 context->allocate_output(0, boxes_shape, &nmsed_boxes_t));
  auto nmsed_boxes_flat = nmsed_boxes_t->template flat<float>();

  Tensor* nmsed_scores_t = nullptr;
  TensorShape scores_shape({num_batches, per_batch_size});
  OP_REQUIRES_OK(context,
                 context->allocate_output(1, scores_shape, &nmsed_scores_t));
  auto nmsed_scores_flat = nmsed_scores_t->template flat<float>();

  Tensor* nmsed_classes_t = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output(2, scores_shape, &nmsed_classes_t));
  auto nmsed_classes_flat = nmsed_classes_t->template flat<float>();

  auto shard_copy_result = [&](int begin, int end) {
    for (int idx = begin; idx < end; ++idx) {
      int batch_idx = idx / per_batch_size;
      int j = idx % per_batch_size;
      nmsed_scores_flat(idx) = nmsed_scores[batch_idx][j];
      nmsed_classes_flat(idx) = nmsed_classes[batch_idx][j];
      for (int k = 0; k < 4; ++k) {
        nmsed_boxes_flat(idx * 4 + k) = nmsed_boxes[batch_idx][j * 4 + k];
      }
    }
  };
  length = num_batches * per_batch_size;
  // Input data: boxes_data, scores_data.
  input_bytes = 6 * sizeof(float);
  output_bytes = 6 * sizeof(float);
  compute_cycles = Eigen::TensorOpCost::AddCost<int>() * 2 +
                   Eigen::TensorOpCost::MulCost<int>() * 2 +
                   Eigen::TensorOpCost::DivCost<float>() * 2;
  const Eigen::TensorOpCost cost_copy_result(input_bytes, output_bytes,
                                             compute_cycles);
  d.parallelFor(length, cost_copy_result, shard_copy_result);
}

// Extracts a scalar of type T from a tensor, with correct type checking.
// This is necessary because several of the kernels here assume
// T == T_threshold.
template <typename T>
T GetScalar(const Tensor& tensor) {
  switch (tensor.dtype()) {
    case DT_FLOAT:
      return static_cast<T>(tensor.scalar<float>()());
    case DT_DOUBLE:
      return static_cast<T>(tensor.scalar<double>()());
    case DT_BFLOAT16:
      return static_cast<T>(tensor.scalar<Eigen::bfloat16>()());
    case DT_HALF:
      return static_cast<T>(tensor.scalar<Eigen::half>()());
    default:
      DCHECK(false) << "Unsupported type " << tensor.dtype();
      break;
  }
  return static_cast<T>(0);
}

}  // namespace

template <typename Device>
class NonMaxSuppressionOp : public OpKernel {
 public:
  explicit NonMaxSuppressionOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("iou_threshold", &iou_threshold_));
  }

  void Compute(OpKernelContext* context) override {
    // boxes: [num_boxes, 4]
    const Tensor& boxes = context->input(0);
    // scores: [num_boxes]
    const Tensor& scores = context->input(1);
    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
                                max_output_size.shape().DebugString()));

    OP_REQUIRES(context, iou_threshold_ >= 0 && iou_threshold_ <= 1,
                errors::InvalidArgument("iou_threshold must be in [0, 1]"));
    int num_boxes = 0;
    ParseAndCheckBoxSizes(context, boxes, &num_boxes);
    CheckScoreSizes(context, num_boxes, scores);
    if (!context->status().ok()) {
      return;
    }
    auto similarity_fn = CreateIOUSimilarityFn<float>(boxes);

    const float score_threshold_val = std::numeric_limits<float>::lowest();
    const float dummy_soft_nms_sigma = static_cast<float>(0.0);
    DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
                                 iou_threshold_, score_threshold_val,
                                 dummy_soft_nms_sigma, similarity_fn);
  }

 private:
  float iou_threshold_;
};

template <typename Device, typename T>
class NonMaxSuppressionV2Op : public OpKernel {
 public:
  explicit NonMaxSuppressionV2Op(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // boxes: [num_boxes, 4]
    const Tensor& boxes = context->input(0);
    // scores: [num_boxes]
    const Tensor& scores = context->input(1);
    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
                                max_output_size.shape().DebugString()));
    // iou_threshold: scalar
    const Tensor& iou_threshold = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
                errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
                                        iou_threshold.shape().DebugString()));
    const T iou_threshold_val = GetScalar<T>(iou_threshold);

    OP_REQUIRES(context,
                iou_threshold_val >= static_cast<T>(0.0) &&
                    iou_threshold_val <= static_cast<T>(1.0),
                errors::InvalidArgument("iou_threshold must be in [0, 1]"));
    int num_boxes = 0;
    ParseAndCheckBoxSizes(context, boxes, &num_boxes);
    CheckScoreSizes(context, num_boxes, scores);
    if (!context->status().ok()) {
      return;
    }
    auto similarity_fn = CreateIOUSimilarityFn<T>(boxes);

    const T score_threshold_val = std::numeric_limits<T>::lowest();
    const T dummy_soft_nms_sigma = static_cast<T>(0.0);
    DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
                             iou_threshold_val, score_threshold_val,
                             dummy_soft_nms_sigma, similarity_fn);
  }
};

template <typename Device, typename T>
class NonMaxSuppressionV3Op : public OpKernel {
 public:
  explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // boxes: [num_boxes, 4]
    const Tensor& boxes = context->input(0);
    // scores: [num_boxes]
    const Tensor& scores = context->input(1);
    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
                                max_output_size.shape().DebugString(),
                                " (Shape must be rank 0 but is rank ",
                                max_output_size.dims(), ")"));
    // iou_threshold: scalar
    const Tensor& iou_threshold = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
                errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
                                        iou_threshold.shape().DebugString(),
                                        " (Shape must be rank 0 but is rank ",
                                        iou_threshold.dims(), ")"));
    const T iou_threshold_val = GetScalar<T>(iou_threshold);
    OP_REQUIRES(context,
                iou_threshold_val >= static_cast<T>(0.0) &&
                    iou_threshold_val <= static_cast<T>(1.0),
                errors::InvalidArgument("iou_threshold must be in [0, 1]"));
    // score_threshold: scalar
    const Tensor& score_threshold = context->input(4);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(score_threshold.shape()),
        errors::InvalidArgument("score_threshold must be 0-D, got shape ",
                                score_threshold.shape().DebugString()));
    const T score_threshold_val = GetScalar<T>(score_threshold);

    int num_boxes = 0;
    ParseAndCheckBoxSizes(context, boxes, &num_boxes);
    CheckScoreSizes(context, num_boxes, scores);
    if (!context->status().ok()) {
      return;
    }

    auto similarity_fn = CreateIOUSimilarityFn<T>(boxes);

    const T dummy_soft_nms_sigma = static_cast<T>(0.0);
    DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
                             iou_threshold_val, score_threshold_val,
                             dummy_soft_nms_sigma, similarity_fn);
  }
};

template <typename Device, typename T>
class NonMaxSuppressionV4Op : public OpKernel {
 public:
  explicit NonMaxSuppressionV4Op(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
                                             &pad_to_max_output_size_));
  }

  void Compute(OpKernelContext* context) override {
    // boxes: [num_boxes, 4]
    const Tensor& boxes = context->input(0);
    // scores: [num_boxes]
    const Tensor& scores = context->input(1);
    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
                                max_output_size.shape().DebugString()));
    // iou_threshold: scalar
    const Tensor& iou_threshold = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
                errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
                                        iou_threshold.shape().DebugString()));
    const T iou_threshold_val = GetScalar<T>(iou_threshold);
    OP_REQUIRES(context,
                iou_threshold_val >= static_cast<T>(0.0) &&
                    iou_threshold_val <= static_cast<T>(1.0),
                errors::InvalidArgument("iou_threshold must be in [0, 1]"));
    // score_threshold: scalar
    const Tensor& score_threshold = context->input(4);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(score_threshold.shape()),
        errors::InvalidArgument("score_threshold must be 0-D, got shape ",
                                score_threshold.shape().DebugString()));
    const T score_threshold_val = GetScalar<T>(score_threshold);

    int num_boxes = 0;
    ParseAndCheckBoxSizes(context, boxes, &num_boxes);
    CheckScoreSizes(context, num_boxes, scores);
    if (!context->status().ok()) {
      return;
    }

    auto similarity_fn = CreateIOUSimilarityFn<T>(boxes);
    int num_valid_outputs;

    const bool return_scores_tensor = false;
    const T dummy_soft_nms_sigma = static_cast<T>(0.0);
    DoNonMaxSuppressionOp<T>(
        context, scores, num_boxes, max_output_size, iou_threshold_val,
        score_threshold_val, dummy_soft_nms_sigma, similarity_fn,
        return_scores_tensor, pad_to_max_output_size_, &num_valid_outputs);
    if (!context->status().ok()) {
      return;
    }

    // Allocate a scalar output tensor for the number of indices computed.
    Tensor* num_outputs_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                1, tensorflow::TensorShape{}, &num_outputs_t));
    num_outputs_t->scalar<int32>().setConstant(num_valid_outputs);
  }

 private:
  bool pad_to_max_output_size_;
};

template <typename Device, typename T>
class NonMaxSuppressionV5Op : public OpKernel {
 public:
  explicit NonMaxSuppressionV5Op(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
                                             &pad_to_max_output_size_));
  }

  void Compute(OpKernelContext* context) override {
    // boxes: [num_boxes, 4]
    const Tensor& boxes = context->input(0);
    // scores: [num_boxes]
    const Tensor& scores = context->input(1);
    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
                                max_output_size.shape().DebugString()));
    // iou_threshold: scalar
    const Tensor& iou_threshold = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
                errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
                                        iou_threshold.shape().DebugString()));
    const T iou_threshold_val = iou_threshold.scalar<T>()();
    OP_REQUIRES(context,
                iou_threshold_val >= static_cast<T>(0.0) &&
                    iou_threshold_val <= static_cast<T>(1.0),
                errors::InvalidArgument("iou_threshold must be in [0, 1]"));
    // score_threshold: scalar
    const Tensor& score_threshold = context->input(4);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(score_threshold.shape()),
        errors::InvalidArgument("score_threshold must be 0-D, got shape ",
                                score_threshold.shape().DebugString()));
    const T score_threshold_val = score_threshold.scalar<T>()();

    // soft_nms_sigma: scalar
    const Tensor& soft_nms_sigma = context->input(5);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(soft_nms_sigma.shape()),
        errors::InvalidArgument("soft_nms_sigma must be 0-D, got shape ",
                                soft_nms_sigma.shape().DebugString()));
    const T soft_nms_sigma_val = soft_nms_sigma.scalar<T>()();
    OP_REQUIRES(context, soft_nms_sigma_val >= static_cast<T>(0.0),
                errors::InvalidArgument("soft_nms_sigma must be >= 0"));

    int num_boxes = 0;
    ParseAndCheckBoxSizes(context, boxes, &num_boxes);
    CheckScoreSizes(context, num_boxes, scores);
    if (!context->status().ok()) {
      return;
    }

    auto similarity_fn = CreateIOUSimilarityFn<T>(boxes);
    int num_valid_outputs;

    // NonMaxSuppressionV5Op always returns a second output holding the
    // corresponding scores, so `return_scores_tensor` is never false here.
    const bool return_scores_tensor = true;
    DoNonMaxSuppressionOp<T>(
        context, scores, num_boxes, max_output_size, iou_threshold_val,
        score_threshold_val, soft_nms_sigma_val, similarity_fn,
        return_scores_tensor, pad_to_max_output_size_, &num_valid_outputs);
    if (!context->status().ok()) {
      return;
    }

    // Allocate a scalar output tensor for the number of indices computed.
    Tensor* num_outputs_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                2, tensorflow::TensorShape{}, &num_outputs_t));
    num_outputs_t->scalar<int32>().setConstant(num_valid_outputs);
  }

 private:
  bool pad_to_max_output_size_;
};

template <typename Device>
class NonMaxSuppressionWithOverlapsOp : public OpKernel {
 public:
  explicit NonMaxSuppressionWithOverlapsOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // overlaps: [num_boxes, num_boxes]
    const Tensor& overlaps = context->input(0);
    // scores: [num_boxes]
    const Tensor& scores = context->input(1);
    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
                                max_output_size.shape().DebugString()));
    // overlap_threshold: scalar
    const Tensor& overlap_threshold = context->input(3);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(overlap_threshold.shape()),
        errors::InvalidArgument("overlap_threshold must be 0-D, got shape ",
                                overlap_threshold.shape().DebugString()));
    const float overlap_threshold_val = overlap_threshold.scalar<float>()();

    // score_threshold: scalar
    const Tensor& score_threshold = context->input(4);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(score_threshold.shape()),
        errors::InvalidArgument("score_threshold must be 0-D, got shape ",
                                score_threshold.shape().DebugString()));
    const float score_threshold_val = score_threshold.scalar<float>()();

    int num_boxes = 0;
    ParseAndCheckOverlapSizes(context, overlaps, &num_boxes);
    CheckScoreSizes(context, num_boxes, scores);
    if (!context->status().ok()) {
      return;
    }
    auto similarity_fn = CreateOverlapSimilarityFn<float>(overlaps);

    const float dummy_soft_nms_sigma = static_cast<float>(0.0);
    DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
                                 overlap_threshold_val, score_threshold_val,
                                 dummy_soft_nms_sigma, similarity_fn);
  }
};

template <typename Device>
class CombinedNonMaxSuppressionOp : public OpKernel {
 public:
  explicit CombinedNonMaxSuppressionOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("pad_per_class", &pad_per_class_));
    OP_REQUIRES_OK(context, context->GetAttr("clip_boxes", &clip_boxes_));
  }

  void Compute(OpKernelContext* context) override {
    // boxes: [batch_size, num_anchors, q, 4]
    const Tensor& boxes = context->input(0);
    // scores: [batch_size, num_anchors, num_classes]
    const Tensor& scores = context->input(1);
    OP_REQUIRES(
        context, (boxes.dim_size(0) == scores.dim_size(0)),
        errors::InvalidArgument("boxes and scores must have same batch size"));

    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
        errors::InvalidArgument("max_size_per_class must be 0-D, got shape ",
                                max_output_size.shape().DebugString()));
    const int max_size_per_class = max_output_size.scalar<int>()();
    OP_REQUIRES(context, max_size_per_class > 0,
                errors::InvalidArgument("max_size_per_class must be positive"));
    // max_total_size: scalar
    const Tensor& max_total_size = context->input(3);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_total_size.shape()),
        errors::InvalidArgument("max_total_size must be 0-D, got shape ",
                                max_total_size.shape().DebugString()));
    const int max_total_size_per_batch = max_total_size.scalar<int>()();
    OP_REQUIRES(context, max_total_size_per_batch > 0,
                errors::InvalidArgument("max_total_size must be > 0"));
    // Warn when `max_total_size` is very large, as it may cause an OOM error.
    if (max_total_size_per_batch > pow(10, 6)) {
      LOG(WARNING) << "Detected a large value for `max_total_size`. This may "
                   << "cause an OOM error. (max_total_size: "
                   << max_total_size.scalar<int>()() << ")";
    }
    // iou_threshold: scalar
    const Tensor& iou_threshold = context->input(4);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
                errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
                                        iou_threshold.shape().DebugString()));
    const float iou_threshold_val = iou_threshold.scalar<float>()();

    // score_threshold: scalar
    const Tensor& score_threshold = context->input(5);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(score_threshold.shape()),
        errors::InvalidArgument("score_threshold must be 0-D, got shape ",
                                score_threshold.shape().DebugString()));
    const float score_threshold_val = score_threshold.scalar<float>()();

    OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
                errors::InvalidArgument("iou_threshold must be in [0, 1]"));
    int num_boxes = 0;
    const int num_classes = scores.dim_size(2);
    ParseAndCheckCombinedNMSBoxSizes(context, boxes, &num_boxes, num_classes);
    CheckCombinedNMSScoreSizes(context, num_boxes, scores);

    if (!context->status().ok()) {
      return;
    }
    BatchedNonMaxSuppressionOp(context, boxes, scores, num_boxes,
                               max_size_per_class, max_total_size_per_batch,
                               score_threshold_val, iou_threshold_val,
                               pad_per_class_, clip_boxes_);
  }

 private:
  bool pad_per_class_;
  bool clip_boxes_;
};

REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
                        NonMaxSuppressionOp<CPUDevice>);

REGISTER_KERNEL_BUILDER(
    Name("NonMaxSuppressionV2").TypeConstraint<float>("T").Device(DEVICE_CPU),
    NonMaxSuppressionV2Op<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2")
                            .TypeConstraint<Eigen::half>("T")
                            .Device(DEVICE_CPU),
                        NonMaxSuppressionV2Op<CPUDevice, Eigen::half>);

REGISTER_KERNEL_BUILDER(
    Name("NonMaxSuppressionV3").TypeConstraint<float>("T").Device(DEVICE_CPU),
    NonMaxSuppressionV3Op<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3")
                            .TypeConstraint<Eigen::half>("T")
                            .Device(DEVICE_CPU),
                        NonMaxSuppressionV3Op<CPUDevice, Eigen::half>);

REGISTER_KERNEL_BUILDER(
    Name("NonMaxSuppressionV4").TypeConstraint<float>("T").Device(DEVICE_CPU),
    NonMaxSuppressionV4Op<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4")
                            .TypeConstraint<Eigen::half>("T")
                            .Device(DEVICE_CPU),
                        NonMaxSuppressionV4Op<CPUDevice, Eigen::half>);

REGISTER_KERNEL_BUILDER(
    Name("NonMaxSuppressionV5").TypeConstraint<float>("T").Device(DEVICE_CPU),
    NonMaxSuppressionV5Op<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV5")
                            .TypeConstraint<Eigen::half>("T")
                            .Device(DEVICE_CPU),
                        NonMaxSuppressionV5Op<CPUDevice, Eigen::half>);

REGISTER_KERNEL_BUILDER(
    Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
    NonMaxSuppressionWithOverlapsOp<CPUDevice>);

REGISTER_KERNEL_BUILDER(Name("CombinedNonMaxSuppression").Device(DEVICE_CPU),
                        CombinedNonMaxSuppressionOp<CPUDevice>);

}  // namespace tensorflow