/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/evaluation/stages/object_detection_stage.h"

#include <fstream>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/utils.h"

namespace tflite {
namespace evaluation {

TfLiteStatus ObjectDetectionStage::Init(
    const DelegateProviders* delegate_providers) {
  // Ensure inference params are provided.
  if (!config_.specification().has_object_detection_params()) {
    LOG(ERROR) << "ObjectDetectionParams not provided";
    return kTfLiteError;
  }
  auto& params = config_.specification().object_detection_params();
  if (!params.has_inference_params()) {
    LOG(ERROR) << "inference_params not provided";
    return kTfLiteError;
  }
  if (all_labels_ == nullptr) {
    LOG(ERROR) << "Detection output labels not provided";
    return kTfLiteError;
  }

  // TfliteInferenceStage.
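  // The nested inference stage owns the interpreter; delegate_providers is
  // forwarded so that any requested delegates can be applied when it
  // initializes.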
  EvaluationStageConfig tflite_inference_config;
  tflite_inference_config.set_name("tflite_inference");
  *tflite_inference_config.mutable_specification()
      ->mutable_tflite_inference_params() = params.inference_params();
  inference_stage_ =
      std::make_unique<TfliteInferenceStage>(tflite_inference_config);
  TF_LITE_ENSURE_STATUS(inference_stage_->Init(delegate_providers));

  // Validate model inputs.
  const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo();
  if (model_info->inputs.size() != 1 || model_info->outputs.size() != 4) {
    LOG(ERROR) << "Object detection model must have 1 input & 4 outputs";
    return kTfLiteError;
  }
  TfLiteType input_type = model_info->inputs[0]->type;
  auto* input_shape = model_info->inputs[0]->dims;
  // Input should be of the shape {1, height, width, 3}.
  if (input_shape->size != 4 || input_shape->data[0] != 1 ||
      input_shape->data[3] != 3) {
    LOG(ERROR) << "Invalid input shape for model";
    return kTfLiteError;
  }

  // ImagePreprocessingStage.
  tflite::evaluation::ImagePreprocessingConfigBuilder builder(
      "image_preprocessing", input_type);
  builder.AddResizingStep(input_shape->data[2], input_shape->data[1], false);
  builder.AddDefaultNormalizationStep();
  preprocessing_stage_ =
      std::make_unique<ImagePreprocessingStage>(builder.build());
  TF_LITE_ENSURE_STATUS(preprocessing_stage_->Init());

  // ObjectDetectionAveragePrecisionStage.
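  // The AP stage is configured with the caller-provided ap_params; the number
  // of classes is taken from the label list supplied to this stage.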
  EvaluationStageConfig eval_config;
  eval_config.set_name("average_precision");
  *eval_config.mutable_specification()
      ->mutable_object_detection_average_precision_params() =
      params.ap_params();
  eval_config.mutable_specification()
      ->mutable_object_detection_average_precision_params()
      ->set_num_classes(all_labels_->size());
  eval_stage_ =
      std::make_unique<ObjectDetectionAveragePrecisionStage>(eval_config);
  TF_LITE_ENSURE_STATUS(eval_stage_->Init());

  return kTfLiteOk;
}

TfLiteStatus ObjectDetectionStage::Run() {
  if (image_path_.empty()) {
    LOG(ERROR) << "Input image not set";
    return kTfLiteError;
  }

  // Preprocessing.
  preprocessing_stage_->SetImagePath(&image_path_);
  TF_LITE_ENSURE_STATUS(preprocessing_stage_->Run());

  // Inference.
  std::vector<void*> data_ptrs = {};
  data_ptrs.push_back(preprocessing_stage_->GetPreprocessedImageData());
  inference_stage_->SetInputs(data_ptrs);
  TF_LITE_ENSURE_STATUS(inference_stage_->Run());

  // Convert model output to ObjectsSet.
  predicted_objects_.Clear();
  const int class_offset =
      config_.specification().object_detection_params().class_offset();
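  // class_offset is added to each predicted class index below so that model
  // outputs line up with the ground-truth label ids (assumed intent: models
  // whose class indexing is shifted, e.g. by a background class).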
  const std::vector<void*>* outputs = inference_stage_->GetOutputs();
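  // The four output tensors follow the usual TFLite detection-postprocess
  // layout: boxes [1, N, 4] as {top, left, bottom, right} in normalized
  // coordinates, class indices [1, N], scores [1, N], and the detection count.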
  int num_detections = static_cast<int>(*static_cast<float*>(outputs->at(3)));
  float* detected_label_boxes = static_cast<float*>(outputs->at(0));
  float* detected_label_indices = static_cast<float*>(outputs->at(1));
  float* detected_label_probabilities = static_cast<float*>(outputs->at(2));
  for (int i = 0; i < num_detections; ++i) {
    const int bounding_box_offset = i * 4;
    auto* object = predicted_objects_.add_objects();
    // Bounding box.
    auto* bbox = object->mutable_bounding_box();
    bbox->set_normalized_top(detected_label_boxes[bounding_box_offset + 0]);
    bbox->set_normalized_left(detected_label_boxes[bounding_box_offset + 1]);
    bbox->set_normalized_bottom(detected_label_boxes[bounding_box_offset + 2]);
    bbox->set_normalized_right(detected_label_boxes[bounding_box_offset + 3]);
    // Class.
    object->set_class_id(static_cast<int>(detected_label_indices[i]) +
                         class_offset);
    // Score.
    object->set_score(detected_label_probabilities[i]);
  }

  // AP Evaluation.
  eval_stage_->SetEvalInputs(predicted_objects_, *ground_truth_objects_);
  TF_LITE_ENSURE_STATUS(eval_stage_->Run());

  return kTfLiteOk;
}

EvaluationStageMetrics ObjectDetectionStage::LatestMetrics() {
  EvaluationStageMetrics metrics;
  auto* detection_metrics =
      metrics.mutable_process_metrics()->mutable_object_detection_metrics();

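  // Aggregate latency and accuracy numbers reported by the nested stages.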
  *detection_metrics->mutable_pre_processing_latency() =
      preprocessing_stage_->LatestMetrics().process_metrics().total_latency();
  EvaluationStageMetrics inference_metrics = inference_stage_->LatestMetrics();
  *detection_metrics->mutable_inference_latency() =
      inference_metrics.process_metrics().total_latency();
  *detection_metrics->mutable_inference_metrics() =
      inference_metrics.process_metrics().tflite_inference_metrics();
  *detection_metrics->mutable_average_precision_metrics() =
      eval_stage_->LatestMetrics()
          .process_metrics()
          .object_detection_average_precision_metrics();
  metrics.set_num_runs(inference_metrics.num_runs());
  return metrics;
}

TfLiteStatus PopulateGroundTruth(
    const std::string& ground_truth_proto_file,
    absl::flat_hash_map<std::string, ObjectDetectionResult>*
        ground_truth_mapping) {
  if (ground_truth_mapping == nullptr) {
    return kTfLiteError;
  }
  ground_truth_mapping->clear();

  // Read the ground truth dump.
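  // The file is expected to hold a binary-serialized ObjectDetectionGroundTruth
  // proto mapping image names to their ground-truth objects.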
  std::ifstream t(ground_truth_proto_file);
  std::string proto_str((std::istreambuf_iterator<char>(t)),
                        std::istreambuf_iterator<char>());
  ObjectDetectionGroundTruth ground_truth_proto;
  if (!ground_truth_proto.ParseFromString(proto_str)) {
    LOG(ERROR) << "Failed to parse ground truth proto file: "
               << ground_truth_proto_file;
    return kTfLiteError;
  }

  for (const auto& image_ground_truth :
       ground_truth_proto.detection_results()) {
    (*ground_truth_mapping)[image_ground_truth.image_name()] =
        image_ground_truth;
  }

  return kTfLiteOk;
}

}  // namespace evaluation
}  // namespace tflite