/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_UTILS_IMAGE_METRICS_H_
#define TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_UTILS_IMAGE_METRICS_H_

#include <stdint.h>

#include <vector>

namespace tflite {
namespace evaluation {
namespace image {

// A 2-D box described by one [min, max] interval per axis.
struct Box2D {
  // Extent of the box along a single axis. Both bounds default to 0.
  struct Interval {
    float min = 0;
    float max = 0;
    Interval() = default;
    // `x` becomes the lower bound (min), `y` the upper bound (max).
    Interval(float x, float y) : min(x), max(y) {}
  };

  Interval x;
  Interval y;

  // Geometry helpers. Only declared here; the definitions live outside this
  // header.
  static float Length(const Interval& a);
  static float Intersection(const Interval& a, const Interval& b);
  float Area() const;
  float Intersection(const Box2D& other) const;
  float Union(const Box2D& other) const;
  // Intersection of this box and the given box normalized over the union of
  // this box and the given box.
  float IoU(const Box2D& other) const;
  // Intersection of this box and the given box normalized over the area of
  // this box.
  float Overlap(const Box2D& other) const;
};

// If the value is:
// - kDontIgnore: The object is included in this evaluation.
// - kIgnoreOneMatch: the first matched prediction bbox will be ignored. This
//   is useful when this groundtruth object is not intended to be evaluated.
// - kIgnoreAllMatches: all matched prediction bbox will be ignored. Typically
//   it is used to mark an area that has not been labeled.
enum IgnoreType {
  kDontIgnore = 0,
  kIgnoreOneMatch = 1,
  kIgnoreAllMatches = 2,
};

// An axis aligned bounding box for an image with id 'imgid'. Score
// indicates its confidence.
//
// 'difficult' is a special bit specific to Pascal VOC dataset and tasks using
// the data. If 'difficult' is true, by convention, the box is often ignored
// during the AP calculation. I.e., if a predicted box matches a 'difficult'
// ground box, this predicted box is ignored as if the model does not make
// such a prediction.
struct Detection {
  bool difficult = false;
  int64_t imgid = 0;
  float score = 0;
  Box2D box;
  IgnoreType ignore = IgnoreType::kDontIgnore;

  Detection() = default;
  // Leaves `ignore` at its default, kDontIgnore.
  Detection(bool d, int64_t id, float s, Box2D b)
      : difficult(d), imgid(id), score(s), box(b) {}
  Detection(bool d, int64_t id, float s, Box2D b, IgnoreType i)
      : difficult(d), imgid(id), score(s), box(b), ignore(i) {}
};

// Precision and recall.
struct PR {
  float p = 0;
  float r = 0;
  // Default-constructs to p = 0, r = 0. Without this constructor the default
  // member initializers above could never take effect and PR could not be
  // value-initialized (e.g. by std::vector<PR>(n)).
  PR() = default;
  PR(const float p_, const float r_) : p(p_), r(r_) {}
};

class AveragePrecision {
 public:
  // iou_threshold: A predicted box matches a ground truth box if and only if
  //   IoU between these two is larger than this iou_threshold. Default: 0.5.
  // num_recall_points: AP is computed as the average of maximum precision at (1
  //   + num_recall_points) recall levels. E.g., if num_recall_points is 10,
  //   recall levels are 0., 0.1, 0.2, ..., 0.9, 1.0.
  //   Default: 100. If num_recall_points < 0, AveragePrecision of 0 is returned.
  struct Options {
    float iou_threshold = 0.5f;
    int num_recall_points = 100;
  };

  // Uses default-constructed Options.
  AveragePrecision() : AveragePrecision(Options()) {}
  explicit AveragePrecision(const Options& opts) : opts_(opts) {}

  // Given a sequence of precision-recall points ordered by the recall in
  // non-increasing order, returns the average of maximum precisions at
  // different recall values (0.0, 0.1, 0.2, ..., 0.9, 1.0).
  // The p-r pairs at these fixed recall points will be written to pr_out, if
  // it is not nullptr.
  float FromPRCurve(const std::vector<PR>& pr,
                    std::vector<PR>* pr_out = nullptr);

  // Given the set of ground truth boxes and a set of predicted boxes, returns
  // the average of the maximum precisions at different recall values.
  float FromBoxes(const std::vector<Detection>& groundtruth,
                  const std::vector<Detection>& prediction,
                  std::vector<PR>* pr_out = nullptr);

 private:
  Options opts_;
};

}  // namespace image
}  // namespace evaluation
}  // namespace tflite

#endif  // TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_UTILS_IMAGE_METRICS_H_