1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_OBJECT_DETECTION_STAGE_H_ 16 #define TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_OBJECT_DETECTION_STAGE_H_ 17 18 #include <memory> 19 #include <string> 20 #include <vector> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" 24 #include "tensorflow/lite/tools/evaluation/evaluation_stage.h" 25 #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" 26 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" 27 #include "tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h" 28 #include "tensorflow/lite/tools/evaluation/stages/object_detection_average_precision_stage.h" 29 #include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h" 30 31 namespace tflite { 32 namespace evaluation { 33 34 // An EvaluationStage to encapsulate the complete Object Detection task. 35 // Assumes that the object detection model's signature (number of 36 // inputs/outputs, ordering of outputs & what they denote) is same as the 37 // MobileNet SSD model: 38 // https://www.tensorflow.org/lite/examples/object_detection/overview#output_signature. 39 // Input size/type & number of detections could be different. 40 // 41 // This class will be extended to support other types of detection models, if 42 // required in the future. 43 class ObjectDetectionStage : public EvaluationStage { 44 public: ObjectDetectionStage(const EvaluationStageConfig & config)45 explicit ObjectDetectionStage(const EvaluationStageConfig& config) 46 : EvaluationStage(config) {} 47 Init()48 TfLiteStatus Init() override { return Init(nullptr); } 49 TfLiteStatus Init(const DelegateProviders* delegate_providers); 50 51 TfLiteStatus Run() override; 52 53 EvaluationStageMetrics LatestMetrics() override; 54 55 // Call before Init(). all_labels should contain all possible object labels 56 // that can be detected by the model, in the correct order. all_labels should 57 // outlive the call to Init(). SetAllLabels(const std::vector<std::string> & all_labels)58 void SetAllLabels(const std::vector<std::string>& all_labels) { 59 all_labels_ = &all_labels; 60 } 61 62 // Call before Run(). 63 // ground_truth_objects instance should outlive the call to Run(). SetInputs(const std::string & image_path,const ObjectDetectionResult & ground_truth_objects)64 void SetInputs(const std::string& image_path, 65 const ObjectDetectionResult& ground_truth_objects) { 66 image_path_ = image_path; 67 ground_truth_objects_ = &ground_truth_objects; 68 } 69 70 // Provides a pointer to the underlying TfLiteInferenceStage. 71 // Returns non-null value only if this stage has been initialized. GetInferenceStage()72 TfliteInferenceStage* const GetInferenceStage() { 73 return inference_stage_.get(); 74 } 75 76 // Returns a const pointer to the latest inference output. GetLatestPrediction()77 const ObjectDetectionResult* GetLatestPrediction() { 78 return &predicted_objects_; 79 } 80 81 private: 82 const std::vector<std::string>* all_labels_ = nullptr; 83 std::unique_ptr<ImagePreprocessingStage> preprocessing_stage_; 84 std::unique_ptr<TfliteInferenceStage> inference_stage_; 85 std::unique_ptr<ObjectDetectionAveragePrecisionStage> eval_stage_; 86 std::string image_path_; 87 88 // Obtained from SetInputs(...). 89 const ObjectDetectionResult* ground_truth_objects_; 90 // Reflects the outputs generated from the latest call to Run(). 91 ObjectDetectionResult predicted_objects_; 92 }; 93 94 // Reads a tflite::evaluation::ObjectDetectionGroundTruth instance from a 95 // textproto file and populates a mapping of image name to 96 // ObjectDetectionResult. 97 // File with ObjectDetectionGroundTruth can be generated using the 98 // preprocess_coco_minival.py script in evaluation/tasks/coco_object_detection. 99 // Useful for wrappers/scripts that use ObjectDetectionStage. 100 TfLiteStatus PopulateGroundTruth( 101 const std::string& grouth_truth_proto_file, 102 absl::flat_hash_map<std::string, ObjectDetectionResult>* 103 ground_truth_mapping); 104 105 } // namespace evaluation 106 } // namespace tflite 107 108 #endif // TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_OBJECT_DETECTION_STAGE_H_ 109