1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/evaluation/stages/object_detection_average_precision_stage.h"
16 
17 #include <stdint.h>
18 
19 #include <numeric>
20 
21 #include "tensorflow/core/platform/logging.h"
22 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
23 
24 namespace tflite {
25 namespace evaluation {
26 namespace {
27 
ConvertProtoToDetection(const ObjectDetectionResult::ObjectInstance & input,int image_id)28 image::Detection ConvertProtoToDetection(
29     const ObjectDetectionResult::ObjectInstance& input, int image_id) {
30   image::Detection detection;
31   detection.box.x.min = input.bounding_box().normalized_left();
32   detection.box.x.max = input.bounding_box().normalized_right();
33   detection.box.y.min = input.bounding_box().normalized_top();
34   detection.box.y.max = input.bounding_box().normalized_bottom();
35   detection.imgid = image_id;
36   detection.score = input.score();
37   return detection;
38 }
39 
40 }  // namespace
41 
Init()42 TfLiteStatus ObjectDetectionAveragePrecisionStage::Init() {
43   num_classes_ = config_.specification()
44                      .object_detection_average_precision_params()
45                      .num_classes();
46   if (num_classes_ <= 0) {
47     LOG(ERROR) << "num_classes cannot be <= 0";
48     return kTfLiteError;
49   }
50 
51   // Initialize per-class data structures.
52   for (int i = 0; i < num_classes_; ++i) {
53     ground_truth_object_vectors_.emplace_back();
54     predicted_object_vectors_.emplace_back();
55   }
56   return kTfLiteOk;
57 }
58 
Run()59 TfLiteStatus ObjectDetectionAveragePrecisionStage::Run() {
60   for (int i = 0; i < ground_truth_objects_.objects_size(); ++i) {
61     const int class_id = ground_truth_objects_.objects(i).class_id();
62     if (class_id >= num_classes_) {
63       LOG(ERROR) << "Encountered invalid class ID: " << class_id;
64       return kTfLiteError;
65     }
66 
67     ground_truth_object_vectors_[class_id].push_back(ConvertProtoToDetection(
68         ground_truth_objects_.objects(i), current_image_index_));
69   }
70 
71   for (int i = 0; i < predicted_objects_.objects_size(); ++i) {
72     const int class_id = predicted_objects_.objects(i).class_id();
73     if (class_id >= num_classes_) {
74       LOG(ERROR) << "Encountered invalid class ID: " << class_id;
75       return kTfLiteError;
76     }
77 
78     predicted_object_vectors_[class_id].push_back(ConvertProtoToDetection(
79         predicted_objects_.objects(i), current_image_index_));
80   }
81 
82   current_image_index_++;
83   return kTfLiteOk;
84 }
85 
LatestMetrics()86 EvaluationStageMetrics ObjectDetectionAveragePrecisionStage::LatestMetrics() {
87   EvaluationStageMetrics metrics;
88   if (current_image_index_ == 0) return metrics;
89 
90   metrics.set_num_runs(current_image_index_);
91   auto* ap_metrics = metrics.mutable_process_metrics()
92                          ->mutable_object_detection_average_precision_metrics();
93   auto& ap_params =
94       config_.specification().object_detection_average_precision_params();
95 
96   std::vector<float> iou_thresholds;
97   if (ap_params.iou_thresholds_size() == 0) {
98     // Default IoU thresholds as defined by COCO evaluation.
99     // Refer: http://cocodataset.org/#detection-eval
100     float threshold = 0.5;
101     for (int i = 0; i < 10; ++i) {
102       iou_thresholds.push_back(threshold + i * 0.05);
103     }
104   } else {
105     for (auto& threshold : ap_params.iou_thresholds()) {
106       iou_thresholds.push_back(threshold);
107     }
108   }
109 
110   image::AveragePrecision::Options opts;
111   opts.num_recall_points = ap_params.num_recall_points();
112 
113   float ap_sum = 0;
114   int num_total_aps = 0;
115   for (float threshold : iou_thresholds) {
116     float threshold_ap_sum = 0;
117     int num_counted_classes = 0;
118 
119     for (int i = 0; i < num_classes_; ++i) {
120       // Skip if this class wasn't encountered at all.
121       // TODO(b/133772912): Investigate the validity of this snippet when a
122       // subset of the classes is encountered in datasets.
123       if (ground_truth_object_vectors_[i].empty() &&
124           predicted_object_vectors_[i].empty())
125         continue;
126 
127       // Output is NaN if there are no ground truth objects.
128       // So we assume 0.
129       float ap_value = 0.0;
130       if (!ground_truth_object_vectors_[i].empty()) {
131         opts.iou_threshold = threshold;
132         ap_value = image::AveragePrecision(opts).FromBoxes(
133             ground_truth_object_vectors_[i], predicted_object_vectors_[i]);
134       }
135 
136       ap_sum += ap_value;
137       num_total_aps += 1;
138       threshold_ap_sum += ap_value;
139       num_counted_classes += 1;
140     }
141 
142     if (num_counted_classes == 0) continue;
143     auto* threshold_ap = ap_metrics->add_individual_average_precisions();
144     threshold_ap->set_average_precision(threshold_ap_sum / num_counted_classes);
145     threshold_ap->set_iou_threshold(threshold);
146   }
147 
148   if (num_total_aps == 0) return metrics;
149   ap_metrics->set_overall_mean_average_precision(ap_sum / num_total_aps);
150   return metrics;
151 }
152 
153 }  // namespace evaluation
154 }  // namespace tflite
155