xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/detection_postprocess.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <math.h>
16 #include <stddef.h>
17 #include <stdint.h>
18 
19 #include <algorithm>
20 #include <initializer_list>
21 #include <numeric>
22 #include <vector>
23 
24 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
25 #include "tensorflow/lite/c/common.h"
26 #include "tensorflow/lite/kernels/internal/compatibility.h"
27 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
28 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
29 #include "tensorflow/lite/kernels/internal/tensor.h"
30 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
31 #include "tensorflow/lite/kernels/kernel_util.h"
32 
33 namespace tflite {
34 namespace ops {
35 namespace custom {
36 namespace detection_postprocess {
37 
38 // Input tensors
39 constexpr int kInputTensorBoxEncodings = 0;
40 constexpr int kInputTensorClassPredictions = 1;
41 constexpr int kInputTensorAnchors = 2;
42 
43 // Output tensors
44 // When max_classes_per_detection > 1, detection boxes will be replicated by the
45 // number of detected classes of that box. Dummy data will be appended if the
46 // number of classes is smaller than max_classes_per_detection.
47 constexpr int kOutputTensorDetectionBoxes = 0;
48 constexpr int kOutputTensorDetectionClasses = 1;
49 constexpr int kOutputTensorDetectionScores = 2;
50 constexpr int kOutputTensorNumDetections = 3;
51 
52 constexpr int kNumCoordBox = 4;
53 constexpr int kBatchSize = 1;
54 
55 constexpr int kNumDetectionsPerClass = 100;
56 
57 // Object Detection model produces axis-aligned boxes in two formats:
58 // BoxCorner represents the lower left corner (xmin, ymin) and
59 // the upper right corner (xmax, ymax).
60 // CenterSize represents the center (xcenter, ycenter), height and width.
61 // BoxCornerEncoding and CenterSizeEncoding are related as follows:
62 // ycenter = y / y_scale * anchor.h + anchor.y;
63 // xcenter = x / x_scale * anchor.w + anchor.x;
64 // half_h = 0.5 * exp(h / h_scale) * anchor.h;
65 // half_w = 0.5 * exp(w / w_scale) * anchor.w;
66 // ymin = ycenter - half_h
67 // ymax = ycenter + half_h
68 // xmin = xcenter - half_w
69 // xmax = xcenter + half_w
// Corner form of a box. The member order (ymin, xmin, ymax, xmax) is the
// in-memory order used when a float tensor is reinterpreted as an array of
// BoxCornerEncoding, so it must not be changed.
70 struct BoxCornerEncoding {
71   float ymin;
72   float xmin;
73   float ymax;
74   float xmax;
75 };
76 
// Center/size form of a box. Also used to hold anchors and the per-coordinate
// scale values stored in OpData. The member order (y, x, h, w) is the in-memory
// order used when a float tensor is reinterpreted as an array of
// CenterSizeEncoding, so it must not be changed.
77 struct CenterSizeEncoding {
78   float y;
79   float x;
80   float h;
81   float w;
82 };
83 // We make sure that the memory allocations are contiguous with static assert.
// Both structs must be exactly kNumCoordBox packed floats, otherwise the
// reinterpret_cast-based tensor views in this file would use a wrong stride.
84 static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox,
85               "Size of BoxCornerEncoding is 4 float values");
86 static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox,
87               "Size of CenterSizeEncoding is 4 float values");
88 
// Per-node state created by Init() from the op's flexbuffer custom options and
// released by Free().
89 struct OpData {
// Read from the "max_detections" option.
90   int max_detections;
// Read from the "max_classes_per_detection" option.
91   int max_classes_per_detection;  // Fast Non-Max-Suppression
// Read from "detections_per_class"; kNumDetectionsPerClass when absent.
92   int detections_per_class;       // Regular Non-Max-Suppression
// Read from "nms_score_threshold"; scores below it are dropped before NMS.
93   float non_max_suppression_score_threshold;
// Read from "nms_iou_threshold"; the single-class NMS requires it in (0, 1].
94   float intersection_over_union_threshold;
// Read from the "num_classes" option.
95   int num_classes;
// Read from "use_regular_nms"; false when absent.
96   bool use_regular_non_max_suppression;
// Read from "y_scale", "x_scale", "h_scale" and "w_scale".
97   CenterSizeEncoding scale_values;
98   // Indices of Temporary tensors
// Both indices are filled in by context->AddTensors() in Init().
99   int decoded_boxes_index;
100   int scores_index;
101 };
102 
Init(TfLiteContext * context,const char * buffer,size_t length)103 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
104   auto* op_data = new OpData;
105   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
106   const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
107   op_data->max_detections = m["max_detections"].AsInt32();
108   op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
109   if (m["detections_per_class"].IsNull())
110     op_data->detections_per_class = kNumDetectionsPerClass;
111   else
112     op_data->detections_per_class = m["detections_per_class"].AsInt32();
113   if (m["use_regular_nms"].IsNull())
114     op_data->use_regular_non_max_suppression = false;
115   else
116     op_data->use_regular_non_max_suppression = m["use_regular_nms"].AsBool();
117 
118   op_data->non_max_suppression_score_threshold =
119       m["nms_score_threshold"].AsFloat();
120   op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat();
121   op_data->num_classes = m["num_classes"].AsInt32();
122   op_data->scale_values.y = m["y_scale"].AsFloat();
123   op_data->scale_values.x = m["x_scale"].AsFloat();
124   op_data->scale_values.h = m["h_scale"].AsFloat();
125   op_data->scale_values.w = m["w_scale"].AsFloat();
126   context->AddTensors(context, 1, &op_data->decoded_boxes_index);
127   context->AddTensors(context, 1, &op_data->scores_index);
128   return op_data;
129 }
130 
Free(TfLiteContext * context,void * buffer)131 void Free(TfLiteContext* context, void* buffer) {
132   delete static_cast<OpData*>(buffer);
133 }
134 
SetTensorSizes(TfLiteContext * context,TfLiteTensor * tensor,std::initializer_list<int> values)135 TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor,
136                             std::initializer_list<int> values) {
137   TfLiteIntArray* size = TfLiteIntArrayCreate(values.size());
138   int index = 0;
139   for (const auto& v : values) {
140     size->data[index] = v;
141     ++index;
142   }
143   return context->ResizeTensor(context, tensor, size);
144 }
145 
Prepare(TfLiteContext * context,TfLiteNode * node)146 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
147   auto* op_data = static_cast<OpData*>(node->user_data);
148   // Inputs: box_encodings, scores, anchors
149   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
150   const TfLiteTensor* input_box_encodings;
151   TF_LITE_ENSURE_OK(context,
152                     GetInputSafe(context, node, kInputTensorBoxEncodings,
153                                  &input_box_encodings));
154   const TfLiteTensor* input_class_predictions;
155   TF_LITE_ENSURE_OK(context,
156                     GetInputSafe(context, node, kInputTensorClassPredictions,
157                                  &input_class_predictions));
158   const TfLiteTensor* input_anchors;
159   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorAnchors,
160                                           &input_anchors));
161   TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
162   TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
163   TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
164   // number of detected boxes
165   const int num_detected_boxes =
166       op_data->max_detections * op_data->max_classes_per_detection;
167 
168   // Outputs: detection_boxes, detection_scores, detection_classes,
169   // num_detections
170   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
171   // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4)
172   TfLiteTensor* detection_boxes;
173   TF_LITE_ENSURE_OK(context,
174                     GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
175                                   &detection_boxes));
176   detection_boxes->type = kTfLiteFloat32;
177   SetTensorSizes(context, detection_boxes,
178                  {kBatchSize, num_detected_boxes, kNumCoordBox});
179 
180   // Output Tensor detection_classes: size is set to (1, num_detected_boxes)
181   TfLiteTensor* detection_classes;
182   TF_LITE_ENSURE_OK(context,
183                     GetOutputSafe(context, node, kOutputTensorDetectionClasses,
184                                   &detection_classes));
185   detection_classes->type = kTfLiteFloat32;
186   SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes});
187 
188   // Output Tensor detection_scores: size is set to (1, num_detected_boxes)
189   TfLiteTensor* detection_scores;
190   TF_LITE_ENSURE_OK(context,
191                     GetOutputSafe(context, node, kOutputTensorDetectionScores,
192                                   &detection_scores));
193   detection_scores->type = kTfLiteFloat32;
194   SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes});
195 
196   // Output Tensor num_detections: size is set to 1
197   TfLiteTensor* num_detections;
198   TF_LITE_ENSURE_OK(context,
199                     GetOutputSafe(context, node, kOutputTensorNumDetections,
200                                   &num_detections));
201   num_detections->type = kTfLiteFloat32;
202   SetTensorSizes(context, num_detections, {1});
203 
204   // Temporary tensors
205   TfLiteIntArrayFree(node->temporaries);
206   node->temporaries = TfLiteIntArrayCreate(2);
207   node->temporaries->data[0] = op_data->decoded_boxes_index;
208   node->temporaries->data[1] = op_data->scores_index;
209 
210   // decoded_boxes
211   TfLiteTensor* decoded_boxes = &context->tensors[op_data->decoded_boxes_index];
212   decoded_boxes->type = kTfLiteFloat32;
213   decoded_boxes->allocation_type = kTfLiteArenaRw;
214   SetTensorSizes(context, decoded_boxes,
215                  {input_box_encodings->dims->data[1], kNumCoordBox});
216 
217   // scores
218   TfLiteTensor* scores = &context->tensors[op_data->scores_index];
219   scores->type = kTfLiteFloat32;
220   scores->allocation_type = kTfLiteArenaRw;
221   SetTensorSizes(context, scores,
222                  {input_class_predictions->dims->data[1],
223                   input_class_predictions->dims->data[2]});
224 
225   return kTfLiteOk;
226 }
227 
// Functor that maps a quantized uint8 value back to float using
//   real = (quantized - zero_point) * scale.
class Dequantizer {
 public:
  Dequantizer(int zero_point, float scale)
      : zero_point_(zero_point), scale_(scale) {}

  // Returns the dequantized value of `x`.
  float operator()(uint8_t x) {
    const float shifted = static_cast<float>(x) - zero_point_;
    return shifted * scale_;
  }

 private:
  int zero_point_;
  float scale_;
};
240 
DequantizeBoxEncodings(const TfLiteTensor * input_box_encodings,int idx,float quant_zero_point,float quant_scale,int length_box_encoding,CenterSizeEncoding * box_centersize)241 void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx,
242                             float quant_zero_point, float quant_scale,
243                             int length_box_encoding,
244                             CenterSizeEncoding* box_centersize) {
245   const uint8* boxes =
246       GetTensorData<uint8>(input_box_encodings) + length_box_encoding * idx;
247   Dequantizer dequantize(quant_zero_point, quant_scale);
248   // See definition of the KeyPointBoxCoder at
249   // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/keypoint_box_coder.py
250   // The first four elements are the box coordinates, which is the same as the
251   // FastRnnBoxCoder at
252   // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/faster_rcnn_box_coder.py
253   box_centersize->y = dequantize(boxes[0]);
254   box_centersize->x = dequantize(boxes[1]);
255   box_centersize->h = dequantize(boxes[2]);
256   box_centersize->w = dequantize(boxes[3]);
257 }
258 
259 template <class T>
ReInterpretTensor(const TfLiteTensor * tensor)260 T ReInterpretTensor(const TfLiteTensor* tensor) {
261   const float* tensor_base = GetTensorData<float>(tensor);
262   return reinterpret_cast<T>(tensor_base);
263 }
264 
265 template <class T>
ReInterpretTensor(TfLiteTensor * tensor)266 T ReInterpretTensor(TfLiteTensor* tensor) {
267   float* tensor_base = GetTensorData<float>(tensor);
268   return reinterpret_cast<T>(tensor_base);
269 }
270 
DecodeCenterSizeBoxes(TfLiteContext * context,TfLiteNode * node,OpData * op_data)271 TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
272                                    OpData* op_data) {
273   // Parse input tensor boxencodings
274   const TfLiteTensor* input_box_encodings;
275   TF_LITE_ENSURE_OK(context,
276                     GetInputSafe(context, node, kInputTensorBoxEncodings,
277                                  &input_box_encodings));
278   TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
279   const int num_boxes = input_box_encodings->dims->data[1];
280   TF_LITE_ENSURE(context, input_box_encodings->dims->data[2] >= kNumCoordBox);
281   const TfLiteTensor* input_anchors;
282   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorAnchors,
283                                           &input_anchors));
284 
285   // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors
286   CenterSizeEncoding box_centersize;
287   CenterSizeEncoding scale_values = op_data->scale_values;
288   CenterSizeEncoding anchor;
289   for (int idx = 0; idx < num_boxes; ++idx) {
290     switch (input_box_encodings->type) {
291         // Quantized
292       case kTfLiteUInt8:
293         DequantizeBoxEncodings(
294             input_box_encodings, idx,
295             static_cast<float>(input_box_encodings->params.zero_point),
296             static_cast<float>(input_box_encodings->params.scale),
297             input_box_encodings->dims->data[2], &box_centersize);
298         DequantizeBoxEncodings(
299             input_anchors, idx,
300             static_cast<float>(input_anchors->params.zero_point),
301             static_cast<float>(input_anchors->params.scale), kNumCoordBox,
302             &anchor);
303         break;
304         // Float
305       case kTfLiteFloat32: {
306         // Please see DequantizeBoxEncodings function for the support detail.
307         const int box_encoding_idx = idx * input_box_encodings->dims->data[2];
308         const float* boxes =
309             &(GetTensorData<float>(input_box_encodings)[box_encoding_idx]);
310         box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes);
311         TF_LITE_ENSURE_EQ(context, input_anchors->type, kTfLiteFloat32);
312         anchor =
313             ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
314         break;
315       }
316       default:
317         // Unsupported type.
318         return kTfLiteError;
319     }
320 
321     float ycenter = static_cast<float>(static_cast<double>(box_centersize.y) /
322                                            static_cast<double>(scale_values.y) *
323                                            static_cast<double>(anchor.h) +
324                                        static_cast<double>(anchor.y));
325 
326     float xcenter = static_cast<float>(static_cast<double>(box_centersize.x) /
327                                            static_cast<double>(scale_values.x) *
328                                            static_cast<double>(anchor.w) +
329                                        static_cast<double>(anchor.x));
330 
331     float half_h =
332         static_cast<float>(0.5 *
333                            (std::exp(static_cast<double>(box_centersize.h) /
334                                      static_cast<double>(scale_values.h))) *
335                            static_cast<double>(anchor.h));
336     float half_w =
337         static_cast<float>(0.5 *
338                            (std::exp(static_cast<double>(box_centersize.w) /
339                                      static_cast<double>(scale_values.w))) *
340                            static_cast<double>(anchor.w));
341 
342     TfLiteTensor* decoded_boxes =
343         &context->tensors[op_data->decoded_boxes_index];
344     TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
345     auto& box = ReInterpretTensor<BoxCornerEncoding*>(decoded_boxes)[idx];
346     box.ymin = ycenter - half_h;
347     box.xmin = xcenter - half_w;
348     box.ymax = ycenter + half_h;
349     box.xmax = xcenter + half_w;
350   }
351   return kTfLiteOk;
352 }
353 
DecreasingPartialArgSort(const float * values,int num_values,int num_to_sort,int * indices)354 void DecreasingPartialArgSort(const float* values, int num_values,
355                               int num_to_sort, int* indices) {
356   if (num_to_sort == 1) {
357     indices[0] = optimized_ops::ArgMaxVector(values, num_values);
358   } else {
359     std::iota(indices, indices + num_values, 0);
360     std::partial_sort(
361         indices, indices + num_to_sort, indices + num_values,
362         [&values](const int i, const int j) { return values[i] > values[j]; });
363   }
364 }
365 
// Writes into `indices` the permutation of [0, num_values) that orders
// `values` from largest to smallest. Equal values keep their original relative
// order, so the result is fully determined by the input.
void DecreasingArgSort(const float* values, int num_values, int* indices) {
  int* const first = indices;
  int* const last = indices + num_values;
  std::iota(first, last, 0);

  // A stable sort is used on purpose: it makes the output completely defined,
  // so that TFL and TFLM can be bit-exact.
  const auto higher_value_first = [values](const int lhs, const int rhs) {
    return values[lhs] > values[rhs];
  };
  std::stable_sort(first, last, higher_value_first);
}
375 
// Appends to `keep_values` every entry of `values` that is >= `threshold`, and
// to `keep_indices` the position of that entry within `values`. Entries are
// appended in input order; neither output vector is cleared first.
void SelectDetectionsAboveScoreThreshold(const std::vector<float>& values,
                                         const float threshold,
                                         std::vector<float>* keep_values,
                                         std::vector<int>* keep_indices) {
  int position = 0;
  for (const float score : values) {
    if (score >= threshold) {
      keep_values->emplace_back(score);
      keep_indices->emplace_back(position);
    }
    ++position;
  }
}
387 
ValidateBoxes(const TfLiteTensor * decoded_boxes,const int num_boxes)388 bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) {
389   for (int i = 0; i < num_boxes; ++i) {
390     auto& box = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
391     // Note: `ComputeIntersectionOverUnion` properly handles degenerated boxes
392     // (xmin == xmax and/or ymin == ymax) as it just returns 0 in case the box
393     // area is <= 0.
394     if (box.ymin > box.ymax || box.xmin > box.xmax) {
395       return false;
396     }
397   }
398   return true;
399 }
400 
ComputeIntersectionOverUnion(const TfLiteTensor * decoded_boxes,const int i,const int j)401 float ComputeIntersectionOverUnion(const TfLiteTensor* decoded_boxes,
402                                    const int i, const int j) {
403   auto& box_i = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
404   auto& box_j = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[j];
405   const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin);
406   const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin);
407   if (area_i <= 0 || area_j <= 0) return 0.0;
408   const float intersection_ymin = std::max<float>(box_i.ymin, box_j.ymin);
409   const float intersection_xmin = std::max<float>(box_i.xmin, box_j.xmin);
410   const float intersection_ymax = std::min<float>(box_i.ymax, box_j.ymax);
411   const float intersection_xmax = std::min<float>(box_i.xmax, box_j.xmax);
412   const float intersection_area =
413       std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
414       std::max<float>(intersection_xmax - intersection_xmin, 0.0);
415   return intersection_area / (area_i + area_j - intersection_area);
416 }
417 
418 // NonMaxSuppressionSingleClass() prunes out the box locations with high overlap
419 // before selecting the highest scoring boxes (max_detections in number)
420 // It assumes all boxes are good in beginning and sorts based on the scores.
421 // If lower-scoring box has too much overlap with a higher-scoring box,
422 // we get rid of the lower-scoring box.
423 // Complexity is O(N^2) pairwise comparison between boxes
NonMaxSuppressionSingleClassHelper(TfLiteContext * context,TfLiteNode * node,OpData * op_data,const std::vector<float> & scores,int max_detections,std::vector<int> * selected)424 TfLiteStatus NonMaxSuppressionSingleClassHelper(
425     TfLiteContext* context, TfLiteNode* node, OpData* op_data,
426     const std::vector<float>& scores, int max_detections,
427     std::vector<int>* selected) {
428   const TfLiteTensor* input_box_encodings;
429   TF_LITE_ENSURE_OK(context,
430                     GetInputSafe(context, node, kInputTensorBoxEncodings,
431                                  &input_box_encodings));
432   const TfLiteTensor* decoded_boxes =
433       &context->tensors[op_data->decoded_boxes_index];
434   const int num_boxes = input_box_encodings->dims->data[1];
435   const float non_max_suppression_score_threshold =
436       op_data->non_max_suppression_score_threshold;
437   const float intersection_over_union_threshold =
438       op_data->intersection_over_union_threshold;
439   // Maximum detections should be positive.
440   TF_LITE_ENSURE(context, (max_detections >= 0));
441   // intersection_over_union_threshold should be positive
442   // and should be less than 1.
443   TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
444                               (intersection_over_union_threshold <= 1.0f));
445   // Validate boxes
446   TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
447   TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));
448 
449   // threshold scores
450   std::vector<int> keep_indices;
451   // TODO(b/177068807): Remove the dynamic allocation and replace it
452   // with temporaries, esp for std::vector<float>
453   std::vector<float> keep_scores;
454   SelectDetectionsAboveScoreThreshold(
455       scores, non_max_suppression_score_threshold, &keep_scores, &keep_indices);
456 
457   int num_scores_kept = keep_scores.size();
458   std::vector<int> sorted_indices;
459   sorted_indices.resize(num_scores_kept);
460   DecreasingArgSort(keep_scores.data(), num_scores_kept, sorted_indices.data());
461 
462   const int num_boxes_kept = num_scores_kept;
463   const int output_size = std::min(num_boxes_kept, max_detections);
464   selected->clear();
465   int num_active_candidate = num_boxes_kept;
466   std::vector<uint8_t> active_box_candidate(num_boxes_kept, 1);
467 
468   for (int i = 0; i < num_boxes_kept; ++i) {
469     if (num_active_candidate == 0 || selected->size() >= output_size) break;
470     if (active_box_candidate[i] == 1) {
471       selected->push_back(keep_indices[sorted_indices[i]]);
472       active_box_candidate[i] = 0;
473       num_active_candidate--;
474     } else {
475       continue;
476     }
477     for (int j = i + 1; j < num_boxes_kept; ++j) {
478       if (active_box_candidate[j] == 1) {
479         TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
480         float intersection_over_union = ComputeIntersectionOverUnion(
481             decoded_boxes, keep_indices[sorted_indices[i]],
482             keep_indices[sorted_indices[j]]);
483 
484         if (intersection_over_union > intersection_over_union_threshold) {
485           active_box_candidate[j] = 0;
486           num_active_candidate--;
487         }
488       }
489     }
490   }
491   return kTfLiteOk;
492 }
493 
// One detection kept by the regular (per-class) non-max suppression.
494 struct BoxInfo {
// Flattened position of the detection in the scores array; ComputeNMSResult()
// stores box * num_classes_with_background + class column + label_offset.
495   int index;
// Score of the detection; entries are kept sorted by decreasing score.
496   float score;
497 };
498 
// Read-only bundle of everything ComputeNMSResult() needs, shared by all
// NonMaxSuppressionWorkerTask instances.
499 struct NMSTaskParam {
500   // Caller retains the ownership of `context`, `node`, `op_data` and `scores`.
501   // Caller should ensure their lifetime is longer than NMSTaskParam instance.
502   TfLiteContext* context;
503   TfLiteNode* node;
504   OpData* op_data;
// Base of the score data; ComputeNMSResult() reads it with one row of
// `num_classes_with_background` values per box.
505   const float* scores;
506 
// Number of class columns to process.
507   int num_classes;
// Number of boxes (rows of `scores`).
508   int num_boxes;
// Added to a class column to obtain its position inside a score row.
509   int label_offset;
// Row stride of `scores`.
510   int num_classes_with_background;
// Maximum number of boxes selected for a single class.
511   int num_detections_per_class;
// Maximum number of entries kept after merging the per-class results.
512   int max_detections;
// Reference to a caller-owned vector; not read by ComputeNMSResult() or
// NonMaxSuppressionWorkerTask.
513   std::vector<int>& num_selected;
514 };
515 
InplaceMergeBoxInfo(std::vector<BoxInfo> & boxes,int mid_index,int end_index)516 void InplaceMergeBoxInfo(std::vector<BoxInfo>& boxes, int mid_index,
517                          int end_index) {
518   std::inplace_merge(
519       boxes.begin(), boxes.begin() + mid_index, boxes.begin() + end_index,
520       [](const BoxInfo& a, const BoxInfo& b) { return a.score >= b.score; });
521 }
522 
ComputeNMSResult(const NMSTaskParam & nms_task_param,int col_begin,int col_end,int & sorted_indices_size,std::vector<BoxInfo> & resulted_sorted_box_info)523 TfLiteStatus ComputeNMSResult(const NMSTaskParam& nms_task_param, int col_begin,
524                               int col_end, int& sorted_indices_size,
525                               std::vector<BoxInfo>& resulted_sorted_box_info) {
526   std::vector<float> class_scores(nms_task_param.num_boxes);
527   std::vector<int> selected;
528   selected.reserve(nms_task_param.num_detections_per_class);
529 
530   for (int col = col_begin; col <= col_end; ++col) {
531     const float* scores_base =
532         nms_task_param.scores + col + nms_task_param.label_offset;
533     for (int row = 0; row < nms_task_param.num_boxes; row++) {
534       // Get scores of boxes corresponding to all anchors for single class
535       class_scores[row] = *scores_base;
536       scores_base += nms_task_param.num_classes_with_background;
537     }
538 
539     // Perform non-maximal suppression on single class
540     selected.clear();
541     TF_LITE_ENSURE_OK(
542         nms_task_param.context,
543         NonMaxSuppressionSingleClassHelper(
544             nms_task_param.context, nms_task_param.node, nms_task_param.op_data,
545             class_scores, nms_task_param.num_detections_per_class, &selected));
546     if (selected.empty()) {
547       continue;
548     }
549 
550     for (int i = 0; i < selected.size(); ++i) {
551       resulted_sorted_box_info[sorted_indices_size + i].score =
552           class_scores[selected[i]];
553       resulted_sorted_box_info[sorted_indices_size + i].index =
554           (selected[i] * nms_task_param.num_classes_with_background + col +
555            nms_task_param.label_offset);
556     }
557 
558     // In-place merge the original boxes and new selected boxes which are both
559     // sorted by scores.
560     InplaceMergeBoxInfo(resulted_sorted_box_info, sorted_indices_size,
561                         sorted_indices_size + selected.size());
562 
563     sorted_indices_size =
564         std::min(sorted_indices_size + static_cast<int>(selected.size()),
565                  nms_task_param.max_detections);
566   }
567   return kTfLiteOk;
568 }
569 
570 struct NonMaxSuppressionWorkerTask : cpu_backend_threadpool::Task {
NonMaxSuppressionWorkerTasktflite::ops::custom::detection_postprocess::NonMaxSuppressionWorkerTask571   NonMaxSuppressionWorkerTask(NMSTaskParam& nms_task_param,
572                               std::atomic<int>& next_col, int col_begin)
573       : nms_task_param(nms_task_param),
574         next_col(next_col),
575         col_begin(col_begin),
576         sorted_indices_size(0) {}
Runtflite::ops::custom::detection_postprocess::NonMaxSuppressionWorkerTask577   void Run() override {
578     sorted_box_info.resize(nms_task_param.num_detections_per_class +
579                            nms_task_param.max_detections);
580     for (int col = col_begin; col < nms_task_param.num_classes;
581          col = (++next_col)) {
582       if (ComputeNMSResult(nms_task_param, col, col, sorted_indices_size,
583                            sorted_box_info) != kTfLiteOk) {
584         break;
585       }
586     }
587   }
588   NMSTaskParam& nms_task_param;
589   // A shared atomic variable across threads, representing the next col this
590   // task will work on after completing the work for 'col_begin'
591   std::atomic<int>& next_col;
592   const int col_begin;
593   int sorted_indices_size;
594   std::vector<BoxInfo> sorted_box_info;
595 };
596 
597 // This function implements a regular version of Non Maximal Suppression (NMS)
598 // for multiple classes where
599 // 1) we do NMS separately for each class across all anchors and
600 // 2) keep only the highest anchor scores across all classes
601 // 3) The worst runtime of the regular NMS is O(K*N^2)
602 // where N is the number of anchors and K the number of
603 // classes.
NonMaxSuppressionMultiClassRegularHelper(TfLiteContext * context,TfLiteNode * node,OpData * op_data,const float * scores)604 TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
605                                                       TfLiteNode* node,
606                                                       OpData* op_data,
607                                                       const float* scores) {
608   const TfLiteTensor* input_box_encodings;
609   TF_LITE_ENSURE_OK(context,
610                     GetInputSafe(context, node, kInputTensorBoxEncodings,
611                                  &input_box_encodings));
612   const TfLiteTensor* input_class_predictions;
613   TF_LITE_ENSURE_OK(context,
614                     GetInputSafe(context, node, kInputTensorClassPredictions,
615                                  &input_class_predictions));
616   const TfLiteTensor* decoded_boxes =
617       &context->tensors[op_data->decoded_boxes_index];
618 
619   TfLiteTensor* detection_boxes;
620   TF_LITE_ENSURE_OK(context,
621                     GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
622                                   &detection_boxes));
623   TfLiteTensor* detection_classes;
624   TF_LITE_ENSURE_OK(context,
625                     GetOutputSafe(context, node, kOutputTensorDetectionClasses,
626                                   &detection_classes));
627   TfLiteTensor* detection_scores;
628   TF_LITE_ENSURE_OK(context,
629                     GetOutputSafe(context, node, kOutputTensorDetectionScores,
630                                   &detection_scores));
631   TfLiteTensor* num_detections;
632   TF_LITE_ENSURE_OK(context,
633                     GetOutputSafe(context, node, kOutputTensorNumDetections,
634                                   &num_detections));
635 
636   const int num_boxes = input_box_encodings->dims->data[1];
637   const int num_classes = op_data->num_classes;
638   const int num_detections_per_class =
639       std::min(op_data->detections_per_class, op_data->max_detections);
640   const int max_detections = op_data->max_detections;
641   const int num_classes_with_background =
642       input_class_predictions->dims->data[2];
643   // The row index offset is 1 if background class is included and 0 otherwise.
644   int label_offset = num_classes_with_background - num_classes;
645   TF_LITE_ENSURE(context, num_detections_per_class > 0);
646 
647   int sorted_indices_size = 0;
648   std::vector<BoxInfo> box_info_after_regular_non_max_suppression(
649       max_detections + num_detections_per_class);
650   std::vector<int> num_selected(num_classes);
651 
652   NMSTaskParam nms_task_param{context,
653                               node,
654                               op_data,
655                               scores,
656                               num_classes,
657                               num_boxes,
658                               label_offset,
659                               num_classes_with_background,
660                               num_detections_per_class,
661                               max_detections,
662                               num_selected};
663 
664   int num_threads =
665       CpuBackendContext::GetFromContext(context)->max_num_threads();
666   if (num_threads == 1) {
667     // For each class, perform non-max suppression.
668     TF_LITE_ENSURE_OK(
669         context, ComputeNMSResult(nms_task_param, /* col_begin= */ 0,
670                                   num_classes - 1, sorted_indices_size,
671                                   box_info_after_regular_non_max_suppression));
672   } else {
673     std::atomic<int> next_col(num_threads);
674     std::vector<NonMaxSuppressionWorkerTask> tasks;
675     tasks.reserve(num_threads);
676     for (int i = 0; i < num_threads; ++i) {
677       tasks.emplace_back(
678           NonMaxSuppressionWorkerTask(nms_task_param, next_col, i));
679     }
680     cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
681                                     CpuBackendContext::GetFromContext(context));
682 
683     // Merge results from tasks.
684     for (int j = 0; j < tasks.size(); ++j) {
685       if (tasks[j].sorted_indices_size == 0) {
686         continue;
687       }
688       memcpy(&box_info_after_regular_non_max_suppression[sorted_indices_size],
689              &tasks[j].sorted_box_info[0],
690              sizeof(BoxInfo) * tasks[j].sorted_indices_size);
691       InplaceMergeBoxInfo(box_info_after_regular_non_max_suppression,
692                           sorted_indices_size,
693                           sorted_indices_size + tasks[j].sorted_indices_size);
694       sorted_indices_size = std::min(
695           sorted_indices_size + tasks[j].sorted_indices_size, max_detections);
696     }
697   }
698 
699   // Allocate output tensors
700   for (int output_box_index = 0; output_box_index < max_detections;
701        output_box_index++) {
702     if (output_box_index < sorted_indices_size) {
703       const int anchor_index = floor(
704           box_info_after_regular_non_max_suppression[output_box_index].index /
705           num_classes_with_background);
706       const int class_index =
707           box_info_after_regular_non_max_suppression[output_box_index].index -
708           anchor_index * num_classes_with_background - label_offset;
709       const float selected_score =
710           box_info_after_regular_non_max_suppression[output_box_index].score;
711       // detection_boxes
712       TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
713       TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
714       ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
715           ReInterpretTensor<const BoxCornerEncoding*>(
716               decoded_boxes)[anchor_index];
717       // detection_classes
718       GetTensorData<float>(detection_classes)[output_box_index] = class_index;
719       // detection_scores
720       GetTensorData<float>(detection_scores)[output_box_index] = selected_score;
721     } else {
722       TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
723       ReInterpretTensor<BoxCornerEncoding*>(
724           detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
725       // detection_classes
726       GetTensorData<float>(detection_classes)[output_box_index] = 0.0f;
727       // detection_scores
728       GetTensorData<float>(detection_scores)[output_box_index] = 0.0f;
729     }
730   }
731   GetTensorData<float>(num_detections)[0] = sorted_indices_size;
732   box_info_after_regular_non_max_suppression.clear();
733   return kTfLiteOk;
734 }
735 
736 // This function implements a fast version of Non Maximal Suppression for
737 // multiple classes where
738 // 1) we keep the top-k scores for each anchor and
739 // 2) during NMS, each anchor only uses the highest class score for sorting.
740 // 3) Compared to standard NMS, the worst runtime of this version is O(N^2)
741 // instead of O(KN^2) where N is the number of anchors and K the number of
742 // classes.
NonMaxSuppressionMultiClassFastHelper(TfLiteContext * context,TfLiteNode * node,OpData * op_data,const float * scores)743 TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
744                                                    TfLiteNode* node,
745                                                    OpData* op_data,
746                                                    const float* scores) {
747   const TfLiteTensor* input_box_encodings;
748   TF_LITE_ENSURE_OK(context,
749                     GetInputSafe(context, node, kInputTensorBoxEncodings,
750                                  &input_box_encodings));
751   const TfLiteTensor* input_class_predictions;
752   TF_LITE_ENSURE_OK(context,
753                     GetInputSafe(context, node, kInputTensorClassPredictions,
754                                  &input_class_predictions));
755   const TfLiteTensor* decoded_boxes =
756       &context->tensors[op_data->decoded_boxes_index];
757 
758   TfLiteTensor* detection_boxes;
759   TF_LITE_ENSURE_OK(context,
760                     GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
761                                   &detection_boxes));
762   TfLiteTensor* detection_classes;
763   TF_LITE_ENSURE_OK(context,
764                     GetOutputSafe(context, node, kOutputTensorDetectionClasses,
765                                   &detection_classes));
766   TfLiteTensor* detection_scores;
767   TF_LITE_ENSURE_OK(context,
768                     GetOutputSafe(context, node, kOutputTensorDetectionScores,
769                                   &detection_scores));
770   TfLiteTensor* num_detections;
771   TF_LITE_ENSURE_OK(context,
772                     GetOutputSafe(context, node, kOutputTensorNumDetections,
773                                   &num_detections));
774 
775   const int num_boxes = input_box_encodings->dims->data[1];
776   const int num_classes = op_data->num_classes;
777   const int max_categories_per_anchor = op_data->max_classes_per_detection;
778   const int num_classes_with_background =
779       input_class_predictions->dims->data[2];
780   // The row index offset is 1 if background class is included and 0 otherwise.
781   int label_offset = num_classes_with_background - num_classes;
782   TF_LITE_ENSURE(context, (max_categories_per_anchor > 0));
783   const int num_categories_per_anchor =
784       std::min(max_categories_per_anchor, num_classes);
785   std::vector<float> max_scores;
786   max_scores.resize(num_boxes);
787   std::vector<int> sorted_class_indices;
788   sorted_class_indices.resize(num_boxes * num_categories_per_anchor);
789   for (int row = 0; row < num_boxes; row++) {
790     const float* box_scores =
791         scores + row * num_classes_with_background + label_offset;
792     int* class_indices =
793         sorted_class_indices.data() + row * num_categories_per_anchor;
794     DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor,
795                              class_indices);
796     max_scores[row] = box_scores[class_indices[0]];
797   }
798   // Perform non-maximal suppression on max scores
799   std::vector<int> selected;
800   TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
801       context, node, op_data, max_scores, op_data->max_detections, &selected));
802   // Allocate output tensors
803   int output_box_index = 0;
804   for (const auto& selected_index : selected) {
805     const float* box_scores =
806         scores + selected_index * num_classes_with_background + label_offset;
807     const int* class_indices = sorted_class_indices.data() +
808                                selected_index * num_categories_per_anchor;
809 
810     for (int col = 0; col < num_categories_per_anchor; ++col) {
811       int box_offset = max_categories_per_anchor * output_box_index + col;
812       // detection_boxes
813       TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
814       TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
815       ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
816           ReInterpretTensor<const BoxCornerEncoding*>(
817               decoded_boxes)[selected_index];
818       // detection_classes
819       GetTensorData<float>(detection_classes)[box_offset] = class_indices[col];
820       // detection_scores
821       GetTensorData<float>(detection_scores)[box_offset] =
822           box_scores[class_indices[col]];
823     }
824     output_box_index++;
825   }
826   GetTensorData<float>(num_detections)[0] = output_box_index;
827   return kTfLiteOk;
828 }
829 
DequantizeClassPredictions(const TfLiteTensor * input_class_predictions,const int num_boxes,const int num_classes_with_background,TfLiteTensor * scores)830 void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions,
831                                 const int num_boxes,
832                                 const int num_classes_with_background,
833                                 TfLiteTensor* scores) {
834   float quant_zero_point =
835       static_cast<float>(input_class_predictions->params.zero_point);
836   float quant_scale = static_cast<float>(input_class_predictions->params.scale);
837   tflite::DequantizationParams op_params;
838   op_params.zero_point = quant_zero_point;
839   op_params.scale = quant_scale;
840   const auto shape = RuntimeShape(1, num_boxes * num_classes_with_background);
841   optimized_ops::Dequantize(op_params, shape,
842                             GetTensorData<uint8>(input_class_predictions),
843                             shape, GetTensorData<float>(scores));
844 }
845 
NonMaxSuppressionMultiClass(TfLiteContext * context,TfLiteNode * node,OpData * op_data)846 TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
847                                          TfLiteNode* node, OpData* op_data) {
848   // Get the input tensors
849   const TfLiteTensor* input_box_encodings;
850   TF_LITE_ENSURE_OK(context,
851                     GetInputSafe(context, node, kInputTensorBoxEncodings,
852                                  &input_box_encodings));
853   const TfLiteTensor* input_class_predictions;
854   TF_LITE_ENSURE_OK(context,
855                     GetInputSafe(context, node, kInputTensorClassPredictions,
856                                  &input_class_predictions));
857   const int num_boxes = input_box_encodings->dims->data[1];
858   const int num_classes = op_data->num_classes;
859   TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],
860                     kBatchSize);
861   TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes);
862   const int num_classes_with_background =
863       input_class_predictions->dims->data[2];
864 
865   TF_LITE_ENSURE(context, (num_classes_with_background - num_classes <= 1));
866   TF_LITE_ENSURE(context, (num_classes_with_background >= num_classes));
867 
868   const TfLiteTensor* scores;
869   switch (input_class_predictions->type) {
870     case kTfLiteUInt8: {
871       TfLiteTensor* temporary_scores = &context->tensors[op_data->scores_index];
872       DequantizeClassPredictions(input_class_predictions, num_boxes,
873                                  num_classes_with_background, temporary_scores);
874       scores = temporary_scores;
875     } break;
876     case kTfLiteFloat32:
877       scores = input_class_predictions;
878       break;
879     default:
880       // Unsupported type.
881       return kTfLiteError;
882   }
883   if (op_data->use_regular_non_max_suppression)
884     TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassRegularHelper(
885         context, node, op_data, GetTensorData<float>(scores)));
886   else
887     TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassFastHelper(
888         context, node, op_data, GetTensorData<float>(scores)));
889 
890   return kTfLiteOk;
891 }
892 
Eval(TfLiteContext * context,TfLiteNode * node)893 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
894   // TODO(b/177068051):  Generalize for any batch size.
895   TF_LITE_ENSURE(context, (kBatchSize == 1));
896   auto* op_data = static_cast<OpData*>(node->user_data);
897   // These two functions correspond to two blocks in the Object Detection model.
898   // In future, we would like to break the custom op in two blocks, which is
899   // currently not feasible because we would like to input quantized inputs
900   // and do all calculations in float. Mixed quantized/float calculations are
901   // currently not supported in TFLite.
902 
903   // This fills in temporary decoded_boxes
904   // by transforming input_box_encodings and input_anchors from
905   // CenterSizeEncodings to BoxCornerEncoding
906   TF_LITE_ENSURE_STATUS(DecodeCenterSizeBoxes(context, node, op_data));
907   // This fills in the output tensors
908   // by choosing effective set of decoded boxes
909   // based on Non Maximal Suppression, i.e. selecting
910   // highest scoring non-overlapping boxes.
911   TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClass(context, node, op_data));
912 
913   return kTfLiteOk;
914 }
915 }  // namespace detection_postprocess
916 
Register_DETECTION_POSTPROCESS()917 TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
918   static TfLiteRegistration r = {
919       detection_postprocess::Init, detection_postprocess::Free,
920       detection_postprocess::Prepare, detection_postprocess::Eval};
921   return &r;
922 }
923 
924 // Since the op is named "TFLite_Detection_PostProcess", the selective build
925 // tool will assume the register function is named
926 // "Register_TFLITE_DETECTION_POST_PROCESS".
Register_TFLITE_DETECTION_POST_PROCESS()927 TfLiteRegistration* Register_TFLITE_DETECTION_POST_PROCESS() {
928   return Register_DETECTION_POSTPROCESS();
929 }
930 
931 }  // namespace custom
932 }  // namespace ops
933 }  // namespace tflite
934