xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/run_eval.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <cstdlib>
#include <fstream>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/object_detection_stage.h"
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
#include "tensorflow/lite/tools/evaluation/utils.h"
#include "tensorflow/lite/tools/logging.h"
30 #include "tensorflow/lite/tools/logging.h"
31 
32 namespace tflite {
33 namespace evaluation {
34 
// Names of the command-line flags accepted by this evaluation task.

// Model and its label file.
constexpr char kModelFileFlag[] = "model_file";
constexpr char kModelOutputLabelsFlag[] = "model_output_labels";
// Evaluation data.
constexpr char kGroundTruthImagesPathFlag[] = "ground_truth_images_path";
constexpr char kGroundTruthProtoFileFlag[] = "ground_truth_proto";
// Output and diagnostics.
constexpr char kOutputFilePathFlag[] = "output_file_path";
constexpr char kDebugModeFlag[] = "debug_mode";
// Inference runtime configuration.
constexpr char kInterpreterThreadsFlag[] = "num_interpreter_threads";
constexpr char kDelegateFlag[] = "delegate";
43 
// Returns the final component of `str`: everything after the last '/' or '\'
// separator. A string with no separator is already a bare file name and is
// returned unchanged. A path that ends in a separator yields "".
std::string GetNameFromPath(const std::string& str) {
  // Keep the position unsigned; find_last_of returns size_type, and narrowing
  // it to int made the npos comparison rely on a signed/unsigned conversion.
  const std::string::size_type pos = str.find_last_of("/\\");
  if (pos == std::string::npos) return str;
  return str.substr(pos + 1);
}
49 
50 class CocoObjectDetection : public TaskExecutor {
51  public:
CocoObjectDetection()52   CocoObjectDetection() : debug_mode_(false), num_interpreter_threads_(1) {}
~CocoObjectDetection()53   ~CocoObjectDetection() override {}
54 
55  protected:
56   std::vector<Flag> GetFlags() final;
57 
58   // If the run is successful, the latest metrics will be returned.
59   std::optional<EvaluationStageMetrics> RunImpl() final;
60 
61  private:
62   void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
63   std::string model_file_path_;
64   std::string model_output_labels_path_;
65   std::string ground_truth_images_path_;
66   std::string ground_truth_proto_file_;
67   std::string output_file_path_;
68   bool debug_mode_;
69   std::string delegate_;
70   int num_interpreter_threads_;
71 };
72 
GetFlags()73 std::vector<Flag> CocoObjectDetection::GetFlags() {
74   std::vector<tflite::Flag> flag_list = {
75       tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
76                                "Path to test tflite model file."),
77       tflite::Flag::CreateFlag(
78           kModelOutputLabelsFlag, &model_output_labels_path_,
79           "Path to labels that correspond to output of model."
80           " E.g. in case of COCO-trained SSD model, this is the path to file "
81           "where each line contains a class detected by the model in correct "
82           "order, starting from background."),
83       tflite::Flag::CreateFlag(
84           kGroundTruthImagesPathFlag, &ground_truth_images_path_,
85           "Path to ground truth images. These will be evaluated in "
86           "alphabetical order of filenames"),
87       tflite::Flag::CreateFlag(kGroundTruthProtoFileFlag,
88                                &ground_truth_proto_file_,
89                                "Path to file containing "
90                                "tflite::evaluation::ObjectDetectionGroundTruth "
91                                "proto in binary serialized format. If left "
92                                "empty, mAP numbers are not output."),
93       tflite::Flag::CreateFlag(
94           kOutputFilePathFlag, &output_file_path_,
95           "File to output to. Contains only metrics proto if debug_mode is "
96           "off, and per-image predictions also otherwise."),
97       tflite::Flag::CreateFlag(kDebugModeFlag, &debug_mode_,
98                                "Whether to enable debug mode. Per-image "
99                                "predictions are written to the output file "
100                                "along with metrics."),
101       tflite::Flag::CreateFlag(
102           kInterpreterThreadsFlag, &num_interpreter_threads_,
103           "Number of interpreter threads to use for inference."),
104       tflite::Flag::CreateFlag(
105           kDelegateFlag, &delegate_,
106           "Delegate to use for inference, if available. "
107           "Must be one of {'nnapi', 'gpu', 'xnnpack', 'hexagon'}"),
108   };
109   return flag_list;
110 }
111 
RunImpl()112 std::optional<EvaluationStageMetrics> CocoObjectDetection::RunImpl() {
113   // Process images in filename-sorted order.
114   std::vector<std::string> image_paths;
115   if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_),
116                          &image_paths) != kTfLiteOk) {
117     return std::nullopt;
118   }
119 
120   std::vector<std::string> model_labels;
121   if (!ReadFileLines(model_output_labels_path_, &model_labels)) {
122     TFLITE_LOG(ERROR) << "Could not read model output labels file";
123     return std::nullopt;
124   }
125 
126   EvaluationStageConfig eval_config;
127   eval_config.set_name("object_detection");
128   auto* detection_params =
129       eval_config.mutable_specification()->mutable_object_detection_params();
130   auto* inference_params = detection_params->mutable_inference_params();
131   inference_params->set_model_file_path(model_file_path_);
132   inference_params->set_num_threads(num_interpreter_threads_);
133   inference_params->set_delegate(ParseStringToDelegateType(delegate_));
134 
135   // Get ground truth data.
136   absl::flat_hash_map<std::string, ObjectDetectionResult> ground_truth_map;
137   if (!ground_truth_proto_file_.empty()) {
138     PopulateGroundTruth(ground_truth_proto_file_, &ground_truth_map);
139   }
140 
141   ObjectDetectionStage eval(eval_config);
142 
143   eval.SetAllLabels(model_labels);
144   if (eval.Init(&delegate_providers_) != kTfLiteOk) return std::nullopt;
145 
146   const int step = image_paths.size() / 100;
147   for (int i = 0; i < image_paths.size(); ++i) {
148     if (step > 1 && i % step == 0) {
149       TFLITE_LOG(INFO) << "Finished: " << i / step << "%";
150     }
151 
152     const std::string image_name = GetNameFromPath(image_paths[i]);
153     eval.SetInputs(image_paths[i], ground_truth_map[image_name]);
154     if (eval.Run() != kTfLiteOk) return std::nullopt;
155 
156     if (debug_mode_) {
157       ObjectDetectionResult prediction = *eval.GetLatestPrediction();
158       TFLITE_LOG(INFO) << "Image: " << image_name << "\n";
159       for (int i = 0; i < prediction.objects_size(); ++i) {
160         const auto& object = prediction.objects(i);
161         TFLITE_LOG(INFO) << "Object [" << i << "]";
162         TFLITE_LOG(INFO) << "  Score: " << object.score();
163         TFLITE_LOG(INFO) << "  Class-ID: " << object.class_id();
164         TFLITE_LOG(INFO) << "  Bounding Box:";
165         const auto& bounding_box = object.bounding_box();
166         TFLITE_LOG(INFO) << "    Normalized Top: "
167                          << bounding_box.normalized_top();
168         TFLITE_LOG(INFO) << "    Normalized Bottom: "
169                          << bounding_box.normalized_bottom();
170         TFLITE_LOG(INFO) << "    Normalized Left: "
171                          << bounding_box.normalized_left();
172         TFLITE_LOG(INFO) << "    Normalized Right: "
173                          << bounding_box.normalized_right();
174       }
175       TFLITE_LOG(INFO)
176           << "======================================================\n";
177     }
178   }
179 
180   // Write metrics to file.
181   EvaluationStageMetrics latest_metrics = eval.LatestMetrics();
182   if (ground_truth_proto_file_.empty()) {
183     TFLITE_LOG(WARN) << "mAP metrics are meaningless w/o ground truth.";
184     latest_metrics.mutable_process_metrics()
185         ->mutable_object_detection_metrics()
186         ->clear_average_precision_metrics();
187   }
188 
189   OutputResult(latest_metrics);
190   return std::make_optional(latest_metrics);
191 }
192 
OutputResult(const EvaluationStageMetrics & latest_metrics) const193 void CocoObjectDetection::OutputResult(
194     const EvaluationStageMetrics& latest_metrics) const {
195   if (!output_file_path_.empty()) {
196     std::ofstream metrics_ofile;
197     metrics_ofile.open(output_file_path_, std::ios::out);
198     metrics_ofile << latest_metrics.SerializeAsString();
199     metrics_ofile.close();
200   }
201   TFLITE_LOG(INFO) << "Num evaluation runs: " << latest_metrics.num_runs();
202   const auto object_detection_metrics =
203       latest_metrics.process_metrics().object_detection_metrics();
204   const auto& preprocessing_latency =
205       object_detection_metrics.pre_processing_latency();
206   TFLITE_LOG(INFO) << "Preprocessing latency: avg="
207                    << preprocessing_latency.avg_us() << "(us), std_dev="
208                    << preprocessing_latency.std_deviation_us() << "(us)";
209   const auto& inference_latency = object_detection_metrics.inference_latency();
210   TFLITE_LOG(INFO) << "Inference latency: avg=" << inference_latency.avg_us()
211                    << "(us), std_dev=" << inference_latency.std_deviation_us()
212                    << "(us)";
213   const auto& precision_metrics =
214       object_detection_metrics.average_precision_metrics();
215   for (int i = 0; i < precision_metrics.individual_average_precisions_size();
216        ++i) {
217     const auto ap_metric = precision_metrics.individual_average_precisions(i);
218     TFLITE_LOG(INFO) << "Average Precision [IOU Threshold="
219                      << ap_metric.iou_threshold()
220                      << "]: " << ap_metric.average_precision();
221   }
222   TFLITE_LOG(INFO) << "Overall mAP: "
223                    << precision_metrics.overall_mean_average_precision();
224 }
225 
CreateTaskExecutor()226 std::unique_ptr<TaskExecutor> CreateTaskExecutor() {
227   return std::unique_ptr<TaskExecutor>(new CocoObjectDetection());
228 }
229 
230 }  // namespace evaluation
231 }  // namespace tflite
232