1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/stages/object_detection_stage.h"
#include "tensorflow/lite/tools/evaluation/tasks/task_executor.h"
#include "tensorflow/lite/tools/evaluation/utils.h"
#include "tensorflow/lite/tools/logging.h"
31
32 namespace tflite {
33 namespace evaluation {
34
// Names of the command-line flags accepted by this evaluation task; their
// help texts are registered in CocoObjectDetection::GetFlags.

// Model and label inputs.
static constexpr char kModelFileFlag[] = "model_file";
static constexpr char kModelOutputLabelsFlag[] = "model_output_labels";
// Ground-truth inputs.
static constexpr char kGroundTruthImagesPathFlag[] = "ground_truth_images_path";
static constexpr char kGroundTruthProtoFileFlag[] = "ground_truth_proto";
// Output and runtime configuration.
static constexpr char kOutputFilePathFlag[] = "output_file_path";
static constexpr char kDebugModeFlag[] = "debug_mode";
static constexpr char kInterpreterThreadsFlag[] = "num_interpreter_threads";
static constexpr char kDelegateFlag[] = "delegate";
43
// Returns the final component of the path `str`, i.e. the text after the last
// '/' or '\\'. Returns an empty string when `str` contains no separator at all,
// and also when it ends with a separator.
std::string GetNameFromPath(const std::string& str) {
  // Keep the full size_type: narrowing find_last_of()'s result to `int` relies
  // on npos wrapping to -1 and breaks for very long strings.
  const std::string::size_type pos = str.find_last_of("/\\");
  if (pos == std::string::npos) return "";
  return str.substr(pos + 1);
}
49
50 class CocoObjectDetection : public TaskExecutor {
51 public:
CocoObjectDetection()52 CocoObjectDetection() : debug_mode_(false), num_interpreter_threads_(1) {}
~CocoObjectDetection()53 ~CocoObjectDetection() override {}
54
55 protected:
56 std::vector<Flag> GetFlags() final;
57
58 // If the run is successful, the latest metrics will be returned.
59 std::optional<EvaluationStageMetrics> RunImpl() final;
60
61 private:
62 void OutputResult(const EvaluationStageMetrics& latest_metrics) const;
63 std::string model_file_path_;
64 std::string model_output_labels_path_;
65 std::string ground_truth_images_path_;
66 std::string ground_truth_proto_file_;
67 std::string output_file_path_;
68 bool debug_mode_;
69 std::string delegate_;
70 int num_interpreter_threads_;
71 };
72
GetFlags()73 std::vector<Flag> CocoObjectDetection::GetFlags() {
74 std::vector<tflite::Flag> flag_list = {
75 tflite::Flag::CreateFlag(kModelFileFlag, &model_file_path_,
76 "Path to test tflite model file."),
77 tflite::Flag::CreateFlag(
78 kModelOutputLabelsFlag, &model_output_labels_path_,
79 "Path to labels that correspond to output of model."
80 " E.g. in case of COCO-trained SSD model, this is the path to file "
81 "where each line contains a class detected by the model in correct "
82 "order, starting from background."),
83 tflite::Flag::CreateFlag(
84 kGroundTruthImagesPathFlag, &ground_truth_images_path_,
85 "Path to ground truth images. These will be evaluated in "
86 "alphabetical order of filenames"),
87 tflite::Flag::CreateFlag(kGroundTruthProtoFileFlag,
88 &ground_truth_proto_file_,
89 "Path to file containing "
90 "tflite::evaluation::ObjectDetectionGroundTruth "
91 "proto in binary serialized format. If left "
92 "empty, mAP numbers are not output."),
93 tflite::Flag::CreateFlag(
94 kOutputFilePathFlag, &output_file_path_,
95 "File to output to. Contains only metrics proto if debug_mode is "
96 "off, and per-image predictions also otherwise."),
97 tflite::Flag::CreateFlag(kDebugModeFlag, &debug_mode_,
98 "Whether to enable debug mode. Per-image "
99 "predictions are written to the output file "
100 "along with metrics."),
101 tflite::Flag::CreateFlag(
102 kInterpreterThreadsFlag, &num_interpreter_threads_,
103 "Number of interpreter threads to use for inference."),
104 tflite::Flag::CreateFlag(
105 kDelegateFlag, &delegate_,
106 "Delegate to use for inference, if available. "
107 "Must be one of {'nnapi', 'gpu', 'xnnpack', 'hexagon'}"),
108 };
109 return flag_list;
110 }
111
RunImpl()112 std::optional<EvaluationStageMetrics> CocoObjectDetection::RunImpl() {
113 // Process images in filename-sorted order.
114 std::vector<std::string> image_paths;
115 if (GetSortedFileNames(StripTrailingSlashes(ground_truth_images_path_),
116 &image_paths) != kTfLiteOk) {
117 return std::nullopt;
118 }
119
120 std::vector<std::string> model_labels;
121 if (!ReadFileLines(model_output_labels_path_, &model_labels)) {
122 TFLITE_LOG(ERROR) << "Could not read model output labels file";
123 return std::nullopt;
124 }
125
126 EvaluationStageConfig eval_config;
127 eval_config.set_name("object_detection");
128 auto* detection_params =
129 eval_config.mutable_specification()->mutable_object_detection_params();
130 auto* inference_params = detection_params->mutable_inference_params();
131 inference_params->set_model_file_path(model_file_path_);
132 inference_params->set_num_threads(num_interpreter_threads_);
133 inference_params->set_delegate(ParseStringToDelegateType(delegate_));
134
135 // Get ground truth data.
136 absl::flat_hash_map<std::string, ObjectDetectionResult> ground_truth_map;
137 if (!ground_truth_proto_file_.empty()) {
138 PopulateGroundTruth(ground_truth_proto_file_, &ground_truth_map);
139 }
140
141 ObjectDetectionStage eval(eval_config);
142
143 eval.SetAllLabels(model_labels);
144 if (eval.Init(&delegate_providers_) != kTfLiteOk) return std::nullopt;
145
146 const int step = image_paths.size() / 100;
147 for (int i = 0; i < image_paths.size(); ++i) {
148 if (step > 1 && i % step == 0) {
149 TFLITE_LOG(INFO) << "Finished: " << i / step << "%";
150 }
151
152 const std::string image_name = GetNameFromPath(image_paths[i]);
153 eval.SetInputs(image_paths[i], ground_truth_map[image_name]);
154 if (eval.Run() != kTfLiteOk) return std::nullopt;
155
156 if (debug_mode_) {
157 ObjectDetectionResult prediction = *eval.GetLatestPrediction();
158 TFLITE_LOG(INFO) << "Image: " << image_name << "\n";
159 for (int i = 0; i < prediction.objects_size(); ++i) {
160 const auto& object = prediction.objects(i);
161 TFLITE_LOG(INFO) << "Object [" << i << "]";
162 TFLITE_LOG(INFO) << " Score: " << object.score();
163 TFLITE_LOG(INFO) << " Class-ID: " << object.class_id();
164 TFLITE_LOG(INFO) << " Bounding Box:";
165 const auto& bounding_box = object.bounding_box();
166 TFLITE_LOG(INFO) << " Normalized Top: "
167 << bounding_box.normalized_top();
168 TFLITE_LOG(INFO) << " Normalized Bottom: "
169 << bounding_box.normalized_bottom();
170 TFLITE_LOG(INFO) << " Normalized Left: "
171 << bounding_box.normalized_left();
172 TFLITE_LOG(INFO) << " Normalized Right: "
173 << bounding_box.normalized_right();
174 }
175 TFLITE_LOG(INFO)
176 << "======================================================\n";
177 }
178 }
179
180 // Write metrics to file.
181 EvaluationStageMetrics latest_metrics = eval.LatestMetrics();
182 if (ground_truth_proto_file_.empty()) {
183 TFLITE_LOG(WARN) << "mAP metrics are meaningless w/o ground truth.";
184 latest_metrics.mutable_process_metrics()
185 ->mutable_object_detection_metrics()
186 ->clear_average_precision_metrics();
187 }
188
189 OutputResult(latest_metrics);
190 return std::make_optional(latest_metrics);
191 }
192
OutputResult(const EvaluationStageMetrics & latest_metrics) const193 void CocoObjectDetection::OutputResult(
194 const EvaluationStageMetrics& latest_metrics) const {
195 if (!output_file_path_.empty()) {
196 std::ofstream metrics_ofile;
197 metrics_ofile.open(output_file_path_, std::ios::out);
198 metrics_ofile << latest_metrics.SerializeAsString();
199 metrics_ofile.close();
200 }
201 TFLITE_LOG(INFO) << "Num evaluation runs: " << latest_metrics.num_runs();
202 const auto object_detection_metrics =
203 latest_metrics.process_metrics().object_detection_metrics();
204 const auto& preprocessing_latency =
205 object_detection_metrics.pre_processing_latency();
206 TFLITE_LOG(INFO) << "Preprocessing latency: avg="
207 << preprocessing_latency.avg_us() << "(us), std_dev="
208 << preprocessing_latency.std_deviation_us() << "(us)";
209 const auto& inference_latency = object_detection_metrics.inference_latency();
210 TFLITE_LOG(INFO) << "Inference latency: avg=" << inference_latency.avg_us()
211 << "(us), std_dev=" << inference_latency.std_deviation_us()
212 << "(us)";
213 const auto& precision_metrics =
214 object_detection_metrics.average_precision_metrics();
215 for (int i = 0; i < precision_metrics.individual_average_precisions_size();
216 ++i) {
217 const auto ap_metric = precision_metrics.individual_average_precisions(i);
218 TFLITE_LOG(INFO) << "Average Precision [IOU Threshold="
219 << ap_metric.iou_threshold()
220 << "]: " << ap_metric.average_precision();
221 }
222 TFLITE_LOG(INFO) << "Overall mAP: "
223 << precision_metrics.overall_mean_average_precision();
224 }
225
CreateTaskExecutor()226 std::unique_ptr<TaskExecutor> CreateTaskExecutor() {
227 return std::unique_ptr<TaskExecutor>(new CocoObjectDetection());
228 }
229
230 } // namespace evaluation
231 } // namespace tflite
232