/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <atomic>
#include <cstring>
#include <initializer_list>
#include <numeric>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {
namespace detection_postprocess {

// Input tensors
constexpr int kInputTensorBoxEncodings = 0;
constexpr int kInputTensorClassPredictions = 1;
constexpr int kInputTensorAnchors = 2;

// Output tensors
// When max_classes_per_detection > 1, detection boxes will be replicated by the
// number of detected classes of that box. Dummy data will be appended if the
// number of classes is smaller than max_classes_per_detection.
constexpr int kOutputTensorDetectionBoxes = 0;
constexpr int kOutputTensorDetectionClasses = 1;
constexpr int kOutputTensorDetectionScores = 2;
constexpr int kOutputTensorNumDetections = 3;

constexpr int kNumCoordBox = 4;
constexpr int kBatchSize = 1;

constexpr int kNumDetectionsPerClass = 100;

// Object Detection model produces axis-aligned boxes in two formats:
// BoxCorner represents the lower left corner (xmin, ymin) and
// the upper right corner (xmax, ymax).
// CenterSize represents the center (xcenter, ycenter), height and width.
// BoxCornerEncoding and CenterSizeEncoding are related as follows:
// ycenter = y / y_scale * anchor.h + anchor.y;
// xcenter = x / x_scale * anchor.w + anchor.x;
// half_h = 0.5 * exp(h / h_scale) * anchor.h;
// half_w = 0.5 * exp(w / w_scale) * anchor.w;
// ymin = ycenter - half_h
// ymax = ycenter + half_h
// xmin = xcenter - half_w
// xmax = xcenter + half_w
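//
// Worked example (illustrative, using the common SSD scale values
// y_scale = x_scale = 10 and h_scale = w_scale = 5): an all-zero encoding
// with anchor (y = 0.5, x = 0.5, h = 1.0, w = 1.0) decodes to
// ycenter = xcenter = 0.5 and half_h = half_w = 0.5 * exp(0) * 1.0 = 0.5,
// i.e. the corner box (ymin, xmin, ymax, xmax) = (0, 0, 1, 1).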
struct BoxCornerEncoding {
  float ymin;
  float xmin;
  float ymax;
  float xmax;
};

struct CenterSizeEncoding {
  float y;
  float x;
  float h;
  float w;
};
// We make sure that the memory allocations are contiguous with static assert.
static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox,
              "Size of BoxCornerEncoding is 4 float values");
static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox,
              "Size of CenterSizeEncoding is 4 float values");

struct OpData {
  int max_detections;
  int max_classes_per_detection;  // Fast Non-Max-Suppression
  int detections_per_class;       // Regular Non-Max-Suppression
  float non_max_suppression_score_threshold;
  float intersection_over_union_threshold;
  int num_classes;
  bool use_regular_non_max_suppression;
  CenterSizeEncoding scale_values;
  // Indices of Temporary tensors
  int decoded_boxes_index;
  int scores_index;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData;
  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
  op_data->max_detections = m["max_detections"].AsInt32();
  op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
  if (m["detections_per_class"].IsNull())
    op_data->detections_per_class = kNumDetectionsPerClass;
  else
    op_data->detections_per_class = m["detections_per_class"].AsInt32();
  if (m["use_regular_nms"].IsNull())
    op_data->use_regular_non_max_suppression = false;
  else
    op_data->use_regular_non_max_suppression = m["use_regular_nms"].AsBool();

  op_data->non_max_suppression_score_threshold =
      m["nms_score_threshold"].AsFloat();
  op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat();
  op_data->num_classes = m["num_classes"].AsInt32();
  op_data->scale_values.y = m["y_scale"].AsFloat();
  op_data->scale_values.x = m["x_scale"].AsFloat();
  op_data->scale_values.h = m["h_scale"].AsFloat();
  op_data->scale_values.w = m["w_scale"].AsFloat();
  context->AddTensors(context, 1, &op_data->decoded_boxes_index);
  context->AddTensors(context, 1, &op_data->scores_index);
  return op_data;
}
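
// Sketch of how a converter might serialize the option map parsed by Init()
// above (illustrative only; the key names match the lookups in Init, and the
// values shown are typical SSD defaults, not mandated by this kernel):
//   flexbuffers::Builder fbb;
//   fbb.Map([&]() {
//     fbb.Int("max_detections", 10);
//     fbb.Int("max_classes_per_detection", 1);
//     fbb.Bool("use_regular_nms", false);
//     fbb.Float("nms_score_threshold", 0.6f);
//     fbb.Float("nms_iou_threshold", 0.5f);
//     fbb.Int("num_classes", 90);
//     fbb.Float("y_scale", 10.0f);
//     fbb.Float("x_scale", 10.0f);
//     fbb.Float("h_scale", 5.0f);
//     fbb.Float("w_scale", 5.0f);
//   });
//   fbb.Finish();  // fbb.GetBuffer() is then the `buffer` passed to Init().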

void Free(TfLiteContext* context, void* buffer) {
  delete static_cast<OpData*>(buffer);
}

TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor,
                            std::initializer_list<int> values) {
  TfLiteIntArray* size = TfLiteIntArrayCreate(values.size());
  int index = 0;
  for (const auto& v : values) {
    size->data[index] = v;
    ++index;
  }
  return context->ResizeTensor(context, tensor, size);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* op_data = static_cast<OpData*>(node->user_data);
  // Inputs: box_encodings, scores, anchors
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  const TfLiteTensor* input_box_encodings;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorBoxEncodings,
                                 &input_box_encodings));
  const TfLiteTensor* input_class_predictions;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorClassPredictions,
                                 &input_class_predictions));
  const TfLiteTensor* input_anchors;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorAnchors,
                                          &input_anchors));
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
  // number of detected boxes
  const int num_detected_boxes =
      op_data->max_detections * op_data->max_classes_per_detection;

  // Outputs: detection_boxes, detection_scores, detection_classes,
  // num_detections
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
  // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4)
  TfLiteTensor* detection_boxes;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
                                  &detection_boxes));
  detection_boxes->type = kTfLiteFloat32;
  SetTensorSizes(context, detection_boxes,
                 {kBatchSize, num_detected_boxes, kNumCoordBox});

  // Output Tensor detection_classes: size is set to (1, num_detected_boxes)
  TfLiteTensor* detection_classes;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionClasses,
                                  &detection_classes));
  detection_classes->type = kTfLiteFloat32;
  SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes});

  // Output Tensor detection_scores: size is set to (1, num_detected_boxes)
  TfLiteTensor* detection_scores;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionScores,
                                  &detection_scores));
  detection_scores->type = kTfLiteFloat32;
  SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes});

  // Output Tensor num_detections: size is set to 1
  TfLiteTensor* num_detections;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorNumDetections,
                                  &num_detections));
  num_detections->type = kTfLiteFloat32;
  SetTensorSizes(context, num_detections, {1});

  // Temporary tensors
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(2);
  node->temporaries->data[0] = op_data->decoded_boxes_index;
  node->temporaries->data[1] = op_data->scores_index;

  // decoded_boxes
  TfLiteTensor* decoded_boxes =
      &context->tensors[op_data->decoded_boxes_index];
  decoded_boxes->type = kTfLiteFloat32;
  decoded_boxes->allocation_type = kTfLiteArenaRw;
  SetTensorSizes(context, decoded_boxes,
                 {input_box_encodings->dims->data[1], kNumCoordBox});

  // scores
  TfLiteTensor* scores = &context->tensors[op_data->scores_index];
  scores->type = kTfLiteFloat32;
  scores->allocation_type = kTfLiteArenaRw;
  SetTensorSizes(context, scores,
                 {input_class_predictions->dims->data[1],
                  input_class_predictions->dims->data[2]});

  return kTfLiteOk;
}

class Dequantizer {
 public:
  Dequantizer(int zero_point, float scale)
      : zero_point_(zero_point), scale_(scale) {}
  float operator()(uint8 x) {
    return (static_cast<float>(x) - zero_point_) * scale_;
  }

 private:
  int zero_point_;
  float scale_;
};
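
// Worked example (illustrative): with zero_point = 128 and scale = 0.1f, a
// quantized value of 148 dequantizes to (148 - 128) * 0.1 = 2.0f.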

void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx,
                            float quant_zero_point, float quant_scale,
                            int length_box_encoding,
                            CenterSizeEncoding* box_centersize) {
  const uint8* boxes =
      GetTensorData<uint8>(input_box_encodings) + length_box_encoding * idx;
  Dequantizer dequantize(quant_zero_point, quant_scale);
  // See the definition of the KeypointBoxCoder at
  // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/keypoint_box_coder.py
  // The first four elements are the box coordinates, which is the same as the
  // FasterRcnnBoxCoder at
  // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/faster_rcnn_box_coder.py
  box_centersize->y = dequantize(boxes[0]);
  box_centersize->x = dequantize(boxes[1]);
  box_centersize->h = dequantize(boxes[2]);
  box_centersize->w = dequantize(boxes[3]);
}

template <class T>
T ReInterpretTensor(const TfLiteTensor* tensor) {
  const float* tensor_base = GetTensorData<float>(tensor);
  return reinterpret_cast<T>(tensor_base);
}

template <class T>
T ReInterpretTensor(TfLiteTensor* tensor) {
  float* tensor_base = GetTensorData<float>(tensor);
  return reinterpret_cast<T>(tensor_base);
}

TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
                                   OpData* op_data) {
  // Parse the input tensor holding the box encodings.
  const TfLiteTensor* input_box_encodings;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorBoxEncodings,
                                 &input_box_encodings));
  TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
  const int num_boxes = input_box_encodings->dims->data[1];
  TF_LITE_ENSURE(context, input_box_encodings->dims->data[2] >= kNumCoordBox);
  const TfLiteTensor* input_anchors;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorAnchors,
                                          &input_anchors));

  // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors.
  CenterSizeEncoding box_centersize;
  CenterSizeEncoding scale_values = op_data->scale_values;
  CenterSizeEncoding anchor;
  for (int idx = 0; idx < num_boxes; ++idx) {
    switch (input_box_encodings->type) {
      // Quantized
      case kTfLiteUInt8:
        DequantizeBoxEncodings(
            input_box_encodings, idx,
            static_cast<float>(input_box_encodings->params.zero_point),
            static_cast<float>(input_box_encodings->params.scale),
            input_box_encodings->dims->data[2], &box_centersize);
        DequantizeBoxEncodings(
            input_anchors, idx,
            static_cast<float>(input_anchors->params.zero_point),
            static_cast<float>(input_anchors->params.scale), kNumCoordBox,
            &anchor);
        break;
      // Float
      case kTfLiteFloat32: {
        // Please see the DequantizeBoxEncodings function for details on the
        // supported encoding layout.
        const int box_encoding_idx = idx * input_box_encodings->dims->data[2];
        const float* boxes =
            &(GetTensorData<float>(input_box_encodings)[box_encoding_idx]);
        box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes);
        TF_LITE_ENSURE_EQ(context, input_anchors->type, kTfLiteFloat32);
        anchor =
            ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
        break;
      }
      default:
        // Unsupported type.
        return kTfLiteError;
    }

    float ycenter = static_cast<float>(static_cast<double>(box_centersize.y) /
                                           static_cast<double>(scale_values.y) *
                                           static_cast<double>(anchor.h) +
                                       static_cast<double>(anchor.y));

    float xcenter = static_cast<float>(static_cast<double>(box_centersize.x) /
                                           static_cast<double>(scale_values.x) *
                                           static_cast<double>(anchor.w) +
                                       static_cast<double>(anchor.x));

    float half_h =
        static_cast<float>(0.5 *
                           (std::exp(static_cast<double>(box_centersize.h) /
                                     static_cast<double>(scale_values.h))) *
                           static_cast<double>(anchor.h));
    float half_w =
        static_cast<float>(0.5 *
                           (std::exp(static_cast<double>(box_centersize.w) /
                                     static_cast<double>(scale_values.w))) *
                           static_cast<double>(anchor.w));

    TfLiteTensor* decoded_boxes =
        &context->tensors[op_data->decoded_boxes_index];
    TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
    auto& box = ReInterpretTensor<BoxCornerEncoding*>(decoded_boxes)[idx];
    box.ymin = ycenter - half_h;
    box.xmin = xcenter - half_w;
    box.ymax = ycenter + half_h;
    box.xmax = xcenter + half_w;
  }
  return kTfLiteOk;
}

void DecreasingPartialArgSort(const float* values, int num_values,
                              int num_to_sort, int* indices) {
  if (num_to_sort == 1) {
    indices[0] = optimized_ops::ArgMaxVector(values, num_values);
  } else {
    std::iota(indices, indices + num_values, 0);
    std::partial_sort(
        indices, indices + num_to_sort, indices + num_values,
        [&values](const int i, const int j) { return values[i] > values[j]; });
  }
}
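
// Worked example (illustrative): for values = {0.1f, 0.9f, 0.4f} and
// num_to_sort = 2, the first two entries of indices become {1, 2}, the
// positions of 0.9 and 0.4; with num_to_sort = 1 only the argmax position 1
// is written.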

void DecreasingArgSort(const float* values, int num_values, int* indices) {
  std::iota(indices, indices + num_values, 0);

  // We want a stable sort here, in order to get a completely defined output.
  // That way, TFL and TFLM can be bit-exact.
  std::stable_sort(
      indices, indices + num_values,
      [&values](const int i, const int j) { return values[i] > values[j]; });
}

void SelectDetectionsAboveScoreThreshold(const std::vector<float>& values,
                                         const float threshold,
                                         std::vector<float>* keep_values,
                                         std::vector<int>* keep_indices) {
  for (int i = 0; i < values.size(); i++) {
    if (values[i] >= threshold) {
      keep_values->emplace_back(values[i]);
      keep_indices->emplace_back(i);
    }
  }
}

bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) {
  for (int i = 0; i < num_boxes; ++i) {
    auto& box = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
    // Note: `ComputeIntersectionOverUnion` properly handles degenerate boxes
    // (xmin == xmax and/or ymin == ymax), as it simply returns 0 when the box
    // area is <= 0.
    if (box.ymin > box.ymax || box.xmin > box.xmax) {
      return false;
    }
  }
  return true;
}

float ComputeIntersectionOverUnion(const TfLiteTensor* decoded_boxes,
                                   const int i, const int j) {
  auto& box_i = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
  auto& box_j = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[j];
  const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin);
  const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin);
  if (area_i <= 0 || area_j <= 0) return 0.0;
  const float intersection_ymin = std::max<float>(box_i.ymin, box_j.ymin);
  const float intersection_xmin = std::max<float>(box_i.xmin, box_j.xmin);
  const float intersection_ymax = std::min<float>(box_i.ymax, box_j.ymax);
  const float intersection_xmax = std::min<float>(box_i.xmax, box_j.xmax);
  const float intersection_area =
      std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
      std::max<float>(intersection_xmax - intersection_xmin, 0.0);
  return intersection_area / (area_i + area_j - intersection_area);
}
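
// Worked example (illustrative): boxes i = (0, 0, 1, 1) and j = (0, 0.5, 1,
// 1.5) each have area 1 and intersect in (0, 0.5, 1, 1), which has area 0.5,
// so the IoU is 0.5 / (1 + 1 - 0.5) = 1/3.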

// NonMaxSuppressionSingleClass() prunes out the box locations with high
// overlap before selecting the highest scoring boxes (max_detections in
// number). It assumes all boxes are good at the beginning and sorts based on
// the scores. If a lower-scoring box has too much overlap with a
// higher-scoring box, the lower-scoring box is removed.
// Complexity is O(N^2), from the pairwise comparisons between boxes.
TfLiteStatus NonMaxSuppressionSingleClassHelper(
    TfLiteContext* context, TfLiteNode* node, OpData* op_data,
    const std::vector<float>& scores, int max_detections,
    std::vector<int>* selected) {
  const TfLiteTensor* input_box_encodings;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorBoxEncodings,
                                 &input_box_encodings));
  const TfLiteTensor* decoded_boxes =
      &context->tensors[op_data->decoded_boxes_index];
  const int num_boxes = input_box_encodings->dims->data[1];
  const float non_max_suppression_score_threshold =
      op_data->non_max_suppression_score_threshold;
  const float intersection_over_union_threshold =
      op_data->intersection_over_union_threshold;
  // Maximum detections should be non-negative.
  TF_LITE_ENSURE(context, (max_detections >= 0));
  // intersection_over_union_threshold should be positive
  // and should be less than or equal to 1.
  TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
                              (intersection_over_union_threshold <= 1.0f));
  // Validate boxes
  TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
  TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));

  // threshold scores
  std::vector<int> keep_indices;
  // TODO(b/177068807): Remove the dynamic allocation and replace it
  // with temporaries, esp for std::vector<float>
  std::vector<float> keep_scores;
  SelectDetectionsAboveScoreThreshold(
      scores, non_max_suppression_score_threshold, &keep_scores, &keep_indices);

  int num_scores_kept = keep_scores.size();
  std::vector<int> sorted_indices;
  sorted_indices.resize(num_scores_kept);
  DecreasingArgSort(keep_scores.data(), num_scores_kept, sorted_indices.data());

  const int num_boxes_kept = num_scores_kept;
  const int output_size = std::min(num_boxes_kept, max_detections);
  selected->clear();
  int num_active_candidate = num_boxes_kept;
  std::vector<uint8_t> active_box_candidate(num_boxes_kept, 1);

  for (int i = 0; i < num_boxes_kept; ++i) {
    if (num_active_candidate == 0 || selected->size() >= output_size) break;
    if (active_box_candidate[i] == 1) {
      selected->push_back(keep_indices[sorted_indices[i]]);
      active_box_candidate[i] = 0;
      num_active_candidate--;
    } else {
      continue;
    }
    for (int j = i + 1; j < num_boxes_kept; ++j) {
      if (active_box_candidate[j] == 1) {
        TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
        float intersection_over_union = ComputeIntersectionOverUnion(
            decoded_boxes, keep_indices[sorted_indices[i]],
            keep_indices[sorted_indices[j]]);

        if (intersection_over_union > intersection_over_union_threshold) {
          active_box_candidate[j] = 0;
          num_active_candidate--;
        }
      }
    }
  }
  return kTfLiteOk;
}
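
// Worked example (illustrative): for two identical boxes with scores
// {0.9, 0.8} and intersection_over_union_threshold = 0.5, the helper keeps
// only the higher-scoring box; the second box is suppressed because its IoU
// with the first is 1.0 > 0.5.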

struct BoxInfo {
  int index;
  float score;
};

struct NMSTaskParam {
  // Caller retains the ownership of `context`, `node`, `op_data` and `scores`.
  // Caller should ensure their lifetime is longer than NMSTaskParam instance.
  TfLiteContext* context;
  TfLiteNode* node;
  OpData* op_data;
  const float* scores;

  int num_classes;
  int num_boxes;
  int label_offset;
  int num_classes_with_background;
  int num_detections_per_class;
  int max_detections;
  std::vector<int>& num_selected;
};

void InplaceMergeBoxInfo(std::vector<BoxInfo>& boxes, int mid_index,
                         int end_index) {
  std::inplace_merge(
      boxes.begin(), boxes.begin() + mid_index, boxes.begin() + end_index,
      [](const BoxInfo& a, const BoxInfo& b) { return a.score >= b.score; });
}

TfLiteStatus ComputeNMSResult(const NMSTaskParam& nms_task_param, int col_begin,
                              int col_end, int& sorted_indices_size,
                              std::vector<BoxInfo>& resulted_sorted_box_info) {
  std::vector<float> class_scores(nms_task_param.num_boxes);
  std::vector<int> selected;
  selected.reserve(nms_task_param.num_detections_per_class);

  for (int col = col_begin; col <= col_end; ++col) {
    const float* scores_base =
        nms_task_param.scores + col + nms_task_param.label_offset;
    for (int row = 0; row < nms_task_param.num_boxes; row++) {
      // Get the scores of the boxes corresponding to all anchors for a single
      // class.
      class_scores[row] = *scores_base;
      scores_base += nms_task_param.num_classes_with_background;
    }

    // Perform non-maximal suppression on a single class.
    selected.clear();
    TF_LITE_ENSURE_OK(
        nms_task_param.context,
        NonMaxSuppressionSingleClassHelper(
            nms_task_param.context, nms_task_param.node, nms_task_param.op_data,
            class_scores, nms_task_param.num_detections_per_class, &selected));
    if (selected.empty()) {
      continue;
    }

    for (int i = 0; i < selected.size(); ++i) {
      resulted_sorted_box_info[sorted_indices_size + i].score =
          class_scores[selected[i]];
      resulted_sorted_box_info[sorted_indices_size + i].index =
          (selected[i] * nms_task_param.num_classes_with_background + col +
           nms_task_param.label_offset);
    }

    // In-place merge the original boxes and new selected boxes which are both
    // sorted by scores.
    InplaceMergeBoxInfo(resulted_sorted_box_info, sorted_indices_size,
                        sorted_indices_size + selected.size());

    sorted_indices_size =
        std::min(sorted_indices_size + static_cast<int>(selected.size()),
                 nms_task_param.max_detections);
  }
  return kTfLiteOk;
}

struct NonMaxSuppressionWorkerTask : cpu_backend_threadpool::Task {
  NonMaxSuppressionWorkerTask(NMSTaskParam& nms_task_param,
                              std::atomic<int>& next_col, int col_begin)
      : nms_task_param(nms_task_param),
        next_col(next_col),
        col_begin(col_begin),
        sorted_indices_size(0) {}
  void Run() override {
    sorted_box_info.resize(nms_task_param.num_detections_per_class +
                           nms_task_param.max_detections);
    for (int col = col_begin; col < nms_task_param.num_classes;
         col = (++next_col)) {
      if (ComputeNMSResult(nms_task_param, col, col, sorted_indices_size,
                           sorted_box_info) != kTfLiteOk) {
        break;
      }
    }
  }
  NMSTaskParam& nms_task_param;
  // A shared atomic variable across threads, representing the next col this
  // task will work on after completing the work for 'col_begin'
  std::atomic<int>& next_col;
  const int col_begin;
  int sorted_indices_size;
  std::vector<BoxInfo> sorted_box_info;
};

// This function implements a regular version of Non Maximal Suppression (NMS)
// for multiple classes where
// 1) we do NMS separately for each class across all anchors and
// 2) keep only the highest anchor scores across all classes.
// The worst-case runtime of the regular NMS is O(K*N^2), where N is the
// number of anchors and K the number of classes.
TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
                                                      TfLiteNode* node,
                                                      OpData* op_data,
                                                      const float* scores) {
  const TfLiteTensor* input_box_encodings;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorBoxEncodings,
                                 &input_box_encodings));
  const TfLiteTensor* input_class_predictions;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorClassPredictions,
                                 &input_class_predictions));
  const TfLiteTensor* decoded_boxes =
      &context->tensors[op_data->decoded_boxes_index];

  TfLiteTensor* detection_boxes;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
                                  &detection_boxes));
  TfLiteTensor* detection_classes;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionClasses,
                                  &detection_classes));
  TfLiteTensor* detection_scores;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionScores,
                                  &detection_scores));
  TfLiteTensor* num_detections;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorNumDetections,
                                  &num_detections));

  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  const int num_detections_per_class =
      std::min(op_data->detections_per_class, op_data->max_detections);
  const int max_detections = op_data->max_detections;
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];
  // The row index offset is 1 if background class is included and 0 otherwise.
  int label_offset = num_classes_with_background - num_classes;
  TF_LITE_ENSURE(context, num_detections_per_class > 0);

  int sorted_indices_size = 0;
  std::vector<BoxInfo> box_info_after_regular_non_max_suppression(
      max_detections + num_detections_per_class);
  std::vector<int> num_selected(num_classes);

  NMSTaskParam nms_task_param{context,
                              node,
                              op_data,
                              scores,
                              num_classes,
                              num_boxes,
                              label_offset,
                              num_classes_with_background,
                              num_detections_per_class,
                              max_detections,
                              num_selected};

  int num_threads =
      CpuBackendContext::GetFromContext(context)->max_num_threads();
  if (num_threads == 1) {
    // For each class, perform non-max suppression.
    TF_LITE_ENSURE_OK(
        context, ComputeNMSResult(nms_task_param, /* col_begin= */ 0,
                                  num_classes - 1, sorted_indices_size,
                                  box_info_after_regular_non_max_suppression));
  } else {
    std::atomic<int> next_col(num_threads);
    std::vector<NonMaxSuppressionWorkerTask> tasks;
    tasks.reserve(num_threads);
    for (int i = 0; i < num_threads; ++i) {
      tasks.emplace_back(
          NonMaxSuppressionWorkerTask(nms_task_param, next_col, i));
    }
    cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
                                    CpuBackendContext::GetFromContext(context));

    // Merge results from tasks.
    for (int j = 0; j < tasks.size(); ++j) {
      if (tasks[j].sorted_indices_size == 0) {
        continue;
      }
      memcpy(&box_info_after_regular_non_max_suppression[sorted_indices_size],
             &tasks[j].sorted_box_info[0],
             sizeof(BoxInfo) * tasks[j].sorted_indices_size);
      InplaceMergeBoxInfo(box_info_after_regular_non_max_suppression,
                          sorted_indices_size,
                          sorted_indices_size + tasks[j].sorted_indices_size);
      sorted_indices_size = std::min(
          sorted_indices_size + tasks[j].sorted_indices_size, max_detections);
    }
  }

  // Fill in the output tensors.
  for (int output_box_index = 0; output_box_index < max_detections;
       output_box_index++) {
    if (output_box_index < sorted_indices_size) {
      const int anchor_index = floor(
          box_info_after_regular_non_max_suppression[output_box_index].index /
          num_classes_with_background);
      const int class_index =
          box_info_after_regular_non_max_suppression[output_box_index].index -
          anchor_index * num_classes_with_background - label_offset;
      const float selected_score =
          box_info_after_regular_non_max_suppression[output_box_index].score;
      // detection_boxes
      TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
      TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
          ReInterpretTensor<const BoxCornerEncoding*>(
              decoded_boxes)[anchor_index];
      // detection_classes
      GetTensorData<float>(detection_classes)[output_box_index] = class_index;
      // detection_scores
      GetTensorData<float>(detection_scores)[output_box_index] =
          selected_score;
    } else {
      TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
      ReInterpretTensor<BoxCornerEncoding*>(
          detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
      // detection_classes
      GetTensorData<float>(detection_classes)[output_box_index] = 0.0f;
      // detection_scores
      GetTensorData<float>(detection_scores)[output_box_index] = 0.0f;
    }
  }
  GetTensorData<float>(num_detections)[0] = sorted_indices_size;
  box_info_after_regular_non_max_suppression.clear();
  return kTfLiteOk;
}

// This function implements a fast version of Non Maximal Suppression for
// multiple classes where
// 1) we keep the top-k scores for each anchor and
// 2) during NMS, each anchor only uses the highest class score for sorting.
// Compared to standard NMS, the worst-case runtime of this version is O(N^2)
// instead of O(K*N^2), where N is the number of anchors and K the number of
// classes.
TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
                                                   TfLiteNode* node,
                                                   OpData* op_data,
                                                   const float* scores) {
  const TfLiteTensor* input_box_encodings;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorBoxEncodings,
                                 &input_box_encodings));
  const TfLiteTensor* input_class_predictions;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorClassPredictions,
                                 &input_class_predictions));
  const TfLiteTensor* decoded_boxes =
      &context->tensors[op_data->decoded_boxes_index];

  TfLiteTensor* detection_boxes;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
                                  &detection_boxes));
  TfLiteTensor* detection_classes;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionClasses,
                                  &detection_classes));
  TfLiteTensor* detection_scores;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorDetectionScores,
                                  &detection_scores));
  TfLiteTensor* num_detections;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensorNumDetections,
                                  &num_detections));

  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  const int max_categories_per_anchor = op_data->max_classes_per_detection;
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];
  // The row index offset is 1 if background class is included and 0 otherwise.
  int label_offset = num_classes_with_background - num_classes;
  TF_LITE_ENSURE(context, (max_categories_per_anchor > 0));
  const int num_categories_per_anchor =
      std::min(max_categories_per_anchor, num_classes);
  std::vector<float> max_scores;
  max_scores.resize(num_boxes);
  std::vector<int> sorted_class_indices;
  sorted_class_indices.resize(num_boxes * num_categories_per_anchor);
  for (int row = 0; row < num_boxes; row++) {
    const float* box_scores =
        scores + row * num_classes_with_background + label_offset;
    int* class_indices =
        sorted_class_indices.data() + row * num_categories_per_anchor;
    DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor,
                             class_indices);
    max_scores[row] = box_scores[class_indices[0]];
  }
  // Perform non-maximal suppression on the max scores.
  std::vector<int> selected;
  TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
      context, node, op_data, max_scores, op_data->max_detections, &selected));
  // Fill in the output tensors.
  int output_box_index = 0;
  for (const auto& selected_index : selected) {
    const float* box_scores =
        scores + selected_index * num_classes_with_background + label_offset;
    const int* class_indices = sorted_class_indices.data() +
                               selected_index * num_categories_per_anchor;

    for (int col = 0; col < num_categories_per_anchor; ++col) {
      int box_offset = max_categories_per_anchor * output_box_index + col;
      // detection_boxes
      TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
      TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
          ReInterpretTensor<const BoxCornerEncoding*>(
              decoded_boxes)[selected_index];
      // detection_classes
      GetTensorData<float>(detection_classes)[box_offset] = class_indices[col];
      // detection_scores
      GetTensorData<float>(detection_scores)[box_offset] =
          box_scores[class_indices[col]];
    }
    output_box_index++;
  }
  GetTensorData<float>(num_detections)[0] = output_box_index;
  return kTfLiteOk;
}

void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions,
                                const int num_boxes,
                                const int num_classes_with_background,
                                TfLiteTensor* scores) {
  float quant_zero_point =
      static_cast<float>(input_class_predictions->params.zero_point);
  float quant_scale =
      static_cast<float>(input_class_predictions->params.scale);
  tflite::DequantizationParams op_params;
  op_params.zero_point = quant_zero_point;
  op_params.scale = quant_scale;
  const auto shape = RuntimeShape(1, num_boxes * num_classes_with_background);
  optimized_ops::Dequantize(op_params, shape,
                            GetTensorData<uint8>(input_class_predictions),
                            shape, GetTensorData<float>(scores));
}
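
// Worked example (illustrative): with zero_point = 0 and scale = 1.0f / 255,
// a quantized class score of 255 dequantizes to (255 - 0) * (1 / 255) = 1.0f,
// following the same (x - zero_point) * scale rule as the Dequantizer above.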

TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
                                         TfLiteNode* node, OpData* op_data) {
  // Get the input tensors
  const TfLiteTensor* input_box_encodings;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorBoxEncodings,
                                 &input_box_encodings));
  const TfLiteTensor* input_class_predictions;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensorClassPredictions,
                                 &input_class_predictions));
  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],
                    kBatchSize);
  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1],
                    num_boxes);
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];

  TF_LITE_ENSURE(context, (num_classes_with_background - num_classes <= 1));
  TF_LITE_ENSURE(context, (num_classes_with_background >= num_classes));

  const TfLiteTensor* scores;
  switch (input_class_predictions->type) {
    case kTfLiteUInt8: {
      TfLiteTensor* temporary_scores =
          &context->tensors[op_data->scores_index];
      DequantizeClassPredictions(input_class_predictions, num_boxes,
                                 num_classes_with_background, temporary_scores);
      scores = temporary_scores;
    } break;
    case kTfLiteFloat32:
      scores = input_class_predictions;
      break;
    default:
      // Unsupported type.
      return kTfLiteError;
  }
  if (op_data->use_regular_non_max_suppression)
    TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassRegularHelper(
        context, node, op_data, GetTensorData<float>(scores)));
  else
    TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassFastHelper(
        context, node, op_data, GetTensorData<float>(scores)));

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // TODO(b/177068051): Generalize for any batch size.
  TF_LITE_ENSURE(context, (kBatchSize == 1));
  auto* op_data = static_cast<OpData*>(node->user_data);
  // These two functions correspond to two blocks in the Object Detection
  // model. In the future, we would like to break the custom op into two
  // blocks, which is currently not feasible because we would like to take
  // quantized inputs and do all calculations in float. Mixed quantized/float
  // calculations are currently not supported in TFLite.

  // This fills in the temporary decoded_boxes tensor by transforming
  // input_box_encodings and input_anchors from CenterSizeEncodings to
  // BoxCornerEncoding.
  TF_LITE_ENSURE_STATUS(DecodeCenterSizeBoxes(context, node, op_data));
  // This fills in the output tensors by choosing an effective set of decoded
  // boxes based on Non Maximal Suppression, i.e. selecting the highest
  // scoring non-overlapping boxes.
  TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClass(context, node, op_data));

  return kTfLiteOk;
}
}  // namespace detection_postprocess

TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
  static TfLiteRegistration r = {
      detection_postprocess::Init, detection_postprocess::Free,
      detection_postprocess::Prepare, detection_postprocess::Eval};
  return &r;
}
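
// Usage sketch (illustrative; interpreter construction is elided): a client
// resolves this kernel by registering it under the op's custom name, e.g.
//   tflite::MutableOpResolver resolver;
//   resolver.AddCustom("TFLite_Detection_PostProcess",
//                      tflite::ops::custom::Register_DETECTION_POSTPROCESS());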

// Since the op is named "TFLite_Detection_PostProcess", the selective build
// tool will assume the register function is named
// "Register_TFLITE_DETECTION_POST_PROCESS".
TfLiteRegistration* Register_TFLITE_DETECTION_POST_PROCESS() {
  return Register_DETECTION_POSTPROCESS();
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite