xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/evaluation/stages/object_detection_stage.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/tools/evaluation/stages/object_detection_stage.h"

#include <algorithm>
#include <fstream>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/utils.h"
26 
27 namespace tflite {
28 namespace evaluation {
29 
Init(const DelegateProviders * delegate_providers)30 TfLiteStatus ObjectDetectionStage::Init(
31     const DelegateProviders* delegate_providers) {
32   // Ensure inference params are provided.
33   if (!config_.specification().has_object_detection_params()) {
34     LOG(ERROR) << "ObjectDetectionParams not provided";
35     return kTfLiteError;
36   }
37   auto& params = config_.specification().object_detection_params();
38   if (!params.has_inference_params()) {
39     LOG(ERROR) << "inference_params not provided";
40     return kTfLiteError;
41   }
42   if (all_labels_ == nullptr) {
43     LOG(ERROR) << "Detection output labels not provided";
44     return kTfLiteError;
45   }
46 
47   // TfliteInferenceStage.
48   EvaluationStageConfig tflite_inference_config;
49   tflite_inference_config.set_name("tflite_inference");
50   *tflite_inference_config.mutable_specification()
51        ->mutable_tflite_inference_params() = params.inference_params();
52   inference_stage_ =
53       std::make_unique<TfliteInferenceStage>(tflite_inference_config);
54   TF_LITE_ENSURE_STATUS(inference_stage_->Init(delegate_providers));
55 
56   // Validate model inputs.
57   const TfLiteModelInfo* model_info = inference_stage_->GetModelInfo();
58   if (model_info->inputs.size() != 1 || model_info->outputs.size() != 4) {
59     LOG(ERROR) << "Object detection model must have 1 input & 4 outputs";
60     return kTfLiteError;
61   }
62   TfLiteType input_type = model_info->inputs[0]->type;
63   auto* input_shape = model_info->inputs[0]->dims;
64   // Input should be of the shape {1, height, width, 3}
65   if (input_shape->size != 4 || input_shape->data[0] != 1 ||
66       input_shape->data[3] != 3) {
67     LOG(ERROR) << "Invalid input shape for model";
68     return kTfLiteError;
69   }
70 
71   // ImagePreprocessingStage
72   tflite::evaluation::ImagePreprocessingConfigBuilder builder(
73       "image_preprocessing", input_type);
74   builder.AddResizingStep(input_shape->data[2], input_shape->data[1], false);
75   builder.AddDefaultNormalizationStep();
76   preprocessing_stage_ =
77       std::make_unique<ImagePreprocessingStage>(builder.build());
78   TF_LITE_ENSURE_STATUS(preprocessing_stage_->Init());
79 
80   // ObjectDetectionAveragePrecisionStage
81   EvaluationStageConfig eval_config;
82   eval_config.set_name("average_precision");
83   *eval_config.mutable_specification()
84        ->mutable_object_detection_average_precision_params() =
85       params.ap_params();
86   eval_config.mutable_specification()
87       ->mutable_object_detection_average_precision_params()
88       ->set_num_classes(all_labels_->size());
89   eval_stage_ =
90       std::make_unique<ObjectDetectionAveragePrecisionStage>(eval_config);
91   TF_LITE_ENSURE_STATUS(eval_stage_->Init());
92 
93   return kTfLiteOk;
94 }
95 
Run()96 TfLiteStatus ObjectDetectionStage::Run() {
97   if (image_path_.empty()) {
98     LOG(ERROR) << "Input image not set";
99     return kTfLiteError;
100   }
101 
102   // Preprocessing.
103   preprocessing_stage_->SetImagePath(&image_path_);
104   TF_LITE_ENSURE_STATUS(preprocessing_stage_->Run());
105 
106   // Inference.
107   std::vector<void*> data_ptrs = {};
108   data_ptrs.push_back(preprocessing_stage_->GetPreprocessedImageData());
109   inference_stage_->SetInputs(data_ptrs);
110   TF_LITE_ENSURE_STATUS(inference_stage_->Run());
111 
112   // Convert model output to ObjectsSet.
113   predicted_objects_.Clear();
114   const int class_offset =
115       config_.specification().object_detection_params().class_offset();
116   const std::vector<void*>* outputs = inference_stage_->GetOutputs();
117   int num_detections = static_cast<int>(*static_cast<float*>(outputs->at(3)));
118   float* detected_label_boxes = static_cast<float*>(outputs->at(0));
119   float* detected_label_indices = static_cast<float*>(outputs->at(1));
120   float* detected_label_probabilities = static_cast<float*>(outputs->at(2));
121   for (int i = 0; i < num_detections; ++i) {
122     const int bounding_box_offset = i * 4;
123     auto* object = predicted_objects_.add_objects();
124     // Bounding box
125     auto* bbox = object->mutable_bounding_box();
126     bbox->set_normalized_top(detected_label_boxes[bounding_box_offset + 0]);
127     bbox->set_normalized_left(detected_label_boxes[bounding_box_offset + 1]);
128     bbox->set_normalized_bottom(detected_label_boxes[bounding_box_offset + 2]);
129     bbox->set_normalized_right(detected_label_boxes[bounding_box_offset + 3]);
130     // Class.
131     object->set_class_id(static_cast<int>(detected_label_indices[i]) +
132                          class_offset);
133     // Score
134     object->set_score(detected_label_probabilities[i]);
135   }
136 
137   // AP Evaluation.
138   eval_stage_->SetEvalInputs(predicted_objects_, *ground_truth_objects_);
139   TF_LITE_ENSURE_STATUS(eval_stage_->Run());
140 
141   return kTfLiteOk;
142 }
143 
LatestMetrics()144 EvaluationStageMetrics ObjectDetectionStage::LatestMetrics() {
145   EvaluationStageMetrics metrics;
146   auto* detection_metrics =
147       metrics.mutable_process_metrics()->mutable_object_detection_metrics();
148 
149   *detection_metrics->mutable_pre_processing_latency() =
150       preprocessing_stage_->LatestMetrics().process_metrics().total_latency();
151   EvaluationStageMetrics inference_metrics = inference_stage_->LatestMetrics();
152   *detection_metrics->mutable_inference_latency() =
153       inference_metrics.process_metrics().total_latency();
154   *detection_metrics->mutable_inference_metrics() =
155       inference_metrics.process_metrics().tflite_inference_metrics();
156   *detection_metrics->mutable_average_precision_metrics() =
157       eval_stage_->LatestMetrics()
158           .process_metrics()
159           .object_detection_average_precision_metrics();
160   metrics.set_num_runs(inference_metrics.num_runs());
161   return metrics;
162 }
163 
PopulateGroundTruth(const std::string & grouth_truth_proto_file,absl::flat_hash_map<std::string,ObjectDetectionResult> * ground_truth_mapping)164 TfLiteStatus PopulateGroundTruth(
165     const std::string& grouth_truth_proto_file,
166     absl::flat_hash_map<std::string, ObjectDetectionResult>*
167         ground_truth_mapping) {
168   if (ground_truth_mapping == nullptr) {
169     return kTfLiteError;
170   }
171   ground_truth_mapping->clear();
172 
173   // Read the ground truth dump.
174   std::ifstream t(grouth_truth_proto_file);
175   std::string proto_str((std::istreambuf_iterator<char>(t)),
176                         std::istreambuf_iterator<char>());
177   ObjectDetectionGroundTruth ground_truth_proto;
178   ground_truth_proto.ParseFromString(proto_str);
179 
180   for (const auto& image_ground_truth :
181        ground_truth_proto.detection_results()) {
182     (*ground_truth_mapping)[image_ground_truth.image_name()] =
183         image_ground_truth;
184   }
185 
186   return kTfLiteOk;
187 }
188 
189 }  // namespace evaluation
190 }  // namespace tflite
191