1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/lite/tools/evaluation/stages/object_detection_average_precision_stage.h" 16 17 #include <stdint.h> 18 19 #include <numeric> 20 21 #include "tensorflow/core/platform/logging.h" 22 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" 23 24 namespace tflite { 25 namespace evaluation { 26 namespace { 27 ConvertProtoToDetection(const ObjectDetectionResult::ObjectInstance & input,int image_id)28image::Detection ConvertProtoToDetection( 29 const ObjectDetectionResult::ObjectInstance& input, int image_id) { 30 image::Detection detection; 31 detection.box.x.min = input.bounding_box().normalized_left(); 32 detection.box.x.max = input.bounding_box().normalized_right(); 33 detection.box.y.min = input.bounding_box().normalized_top(); 34 detection.box.y.max = input.bounding_box().normalized_bottom(); 35 detection.imgid = image_id; 36 detection.score = input.score(); 37 return detection; 38 } 39 40 } // namespace 41 Init()42TfLiteStatus ObjectDetectionAveragePrecisionStage::Init() { 43 num_classes_ = config_.specification() 44 .object_detection_average_precision_params() 45 .num_classes(); 46 if (num_classes_ <= 0) { 47 LOG(ERROR) << "num_classes cannot be <= 0"; 48 return kTfLiteError; 49 } 50 51 // Initialize per-class data structures. 52 for (int i = 0; i < num_classes_; ++i) { 53 ground_truth_object_vectors_.emplace_back(); 54 predicted_object_vectors_.emplace_back(); 55 } 56 return kTfLiteOk; 57 } 58 Run()59TfLiteStatus ObjectDetectionAveragePrecisionStage::Run() { 60 for (int i = 0; i < ground_truth_objects_.objects_size(); ++i) { 61 const int class_id = ground_truth_objects_.objects(i).class_id(); 62 if (class_id >= num_classes_) { 63 LOG(ERROR) << "Encountered invalid class ID: " << class_id; 64 return kTfLiteError; 65 } 66 67 ground_truth_object_vectors_[class_id].push_back(ConvertProtoToDetection( 68 ground_truth_objects_.objects(i), current_image_index_)); 69 } 70 71 for (int i = 0; i < predicted_objects_.objects_size(); ++i) { 72 const int class_id = predicted_objects_.objects(i).class_id(); 73 if (class_id >= num_classes_) { 74 LOG(ERROR) << "Encountered invalid class ID: " << class_id; 75 return kTfLiteError; 76 } 77 78 predicted_object_vectors_[class_id].push_back(ConvertProtoToDetection( 79 predicted_objects_.objects(i), current_image_index_)); 80 } 81 82 current_image_index_++; 83 return kTfLiteOk; 84 } 85 LatestMetrics()86EvaluationStageMetrics ObjectDetectionAveragePrecisionStage::LatestMetrics() { 87 EvaluationStageMetrics metrics; 88 if (current_image_index_ == 0) return metrics; 89 90 metrics.set_num_runs(current_image_index_); 91 auto* ap_metrics = metrics.mutable_process_metrics() 92 ->mutable_object_detection_average_precision_metrics(); 93 auto& ap_params = 94 config_.specification().object_detection_average_precision_params(); 95 96 std::vector<float> iou_thresholds; 97 if (ap_params.iou_thresholds_size() == 0) { 98 // Default IoU thresholds as defined by COCO evaluation. 99 // Refer: http://cocodataset.org/#detection-eval 100 float threshold = 0.5; 101 for (int i = 0; i < 10; ++i) { 102 iou_thresholds.push_back(threshold + i * 0.05); 103 } 104 } else { 105 for (auto& threshold : ap_params.iou_thresholds()) { 106 iou_thresholds.push_back(threshold); 107 } 108 } 109 110 image::AveragePrecision::Options opts; 111 opts.num_recall_points = ap_params.num_recall_points(); 112 113 float ap_sum = 0; 114 int num_total_aps = 0; 115 for (float threshold : iou_thresholds) { 116 float threshold_ap_sum = 0; 117 int num_counted_classes = 0; 118 119 for (int i = 0; i < num_classes_; ++i) { 120 // Skip if this class wasn't encountered at all. 121 // TODO(b/133772912): Investigate the validity of this snippet when a 122 // subset of the classes is encountered in datasets. 123 if (ground_truth_object_vectors_[i].empty() && 124 predicted_object_vectors_[i].empty()) 125 continue; 126 127 // Output is NaN if there are no ground truth objects. 128 // So we assume 0. 129 float ap_value = 0.0; 130 if (!ground_truth_object_vectors_[i].empty()) { 131 opts.iou_threshold = threshold; 132 ap_value = image::AveragePrecision(opts).FromBoxes( 133 ground_truth_object_vectors_[i], predicted_object_vectors_[i]); 134 } 135 136 ap_sum += ap_value; 137 num_total_aps += 1; 138 threshold_ap_sum += ap_value; 139 num_counted_classes += 1; 140 } 141 142 if (num_counted_classes == 0) continue; 143 auto* threshold_ap = ap_metrics->add_individual_average_precisions(); 144 threshold_ap->set_average_precision(threshold_ap_sum / num_counted_classes); 145 threshold_ap->set_iou_threshold(threshold); 146 } 147 148 if (num_total_aps == 0) return metrics; 149 ap_metrics->set_overall_mean_average_precision(ap_sum / num_total_aps); 150 return metrics; 151 } 152 153 } // namespace evaluation 154 } // namespace tflite 155