xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/evaluation/stages/utils/image_metrics.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_UTILS_IMAGE_METRICS_H_
16 #define TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_UTILS_IMAGE_METRICS_H_
17 
18 #include <stdint.h>
19 
20 #include <vector>
21 
22 namespace tflite {
23 namespace evaluation {
24 namespace image {
25 
// An axis-aligned 2D box, described by one interval per axis.
struct Box2D {
  // A span [min, max] along a single axis. Both ends default to 0.
  struct Interval {
    float min = 0;
    float max = 0;

    Interval() = default;
    Interval(float lo, float hi) : min(lo), max(hi) {}
  };

  Interval x;  // Extent along the x axis.
  Interval y;  // Extent along the y axis.

  // Interval helpers. Definitions are not visible here; presumably Length is
  // the interval's extent and Intersection the length of the overlap of `a`
  // and `b` -- confirm against the implementation file.
  static float Length(const Interval& a);
  static float Intersection(const Interval& a, const Interval& b);

  // Box-level geometry (defined out of line).
  float Area() const;
  float Intersection(const Box2D& other) const;
  float Union(const Box2D& other) const;

  // Intersection of this box and `other`, normalized by the union of the two
  // boxes.
  float IoU(const Box2D& other) const;

  // Intersection of this box and `other`, normalized by the area of this box
  // only.
  float Overlap(const Box2D& other) const;
};
51 
// Controls how a groundtruth object takes part in the evaluation.
enum IgnoreType {
  // The object is included in this evaluation.
  kDontIgnore = 0,
  // Only the first prediction box matched to this object is ignored. Useful
  // when this groundtruth object is not intended to be evaluated.
  kIgnoreOneMatch = 1,
  // Every prediction box matched to this object is ignored. Typically used to
  // mark an area that has not been labeled.
  kIgnoreAllMatches = 2,
};
63 
64 struct Detection {
65  public:
66   bool difficult = false;
67   int64_t imgid = 0;
68   float score = 0;
69   Box2D box;
70   IgnoreType ignore = IgnoreType::kDontIgnore;
71 
DetectionDetection72   Detection() {}
DetectionDetection73   Detection(bool d, int64_t id, float s, Box2D b)
74       : difficult(d), imgid(id), score(s), box(b) {}
DetectionDetection75   Detection(bool d, int64_t id, float s, Box2D b, IgnoreType i)
76       : difficult(d), imgid(id), score(s), box(b), ignore(i) {}
77 };
78 
// One (precision, recall) pair, e.g. a single point on a precision-recall
// curve.
struct PR {
  float p = 0;  // Precision.
  float r = 0;  // Recall.

  PR(float precision, float recall) : p(precision), r(recall) {}
};
85 
86 class AveragePrecision {
87  public:
88   // iou_threshold: A predicted box matches a ground truth box if and only if
89   //   IoU between these two are larger than this iou_threshold. Default: 0.5.
90   // num_recall_points: AP is computed as the average of maximum precision at (1
91   //   + num_recall_points) recall levels. E.g., if num_recall_points is 10,
92   //   recall levels are 0., 0.1, 0.2, ..., 0.9, 1.0.
93   // Default: 100. If num_recall_points < 0, AveragePrecision of 0 is returned.
94   struct Options {
95     float iou_threshold = 0.5;
96     int num_recall_points = 100;
97   };
AveragePrecision()98   AveragePrecision() : AveragePrecision(Options()) {}
AveragePrecision(const Options & opts)99   explicit AveragePrecision(const Options& opts) : opts_(opts) {}
100 
101   // Given a sequence of precision-recall points ordered by the recall in
102   // non-increasing order, returns the average of maximum precisions at
103   // different recall values (0.0, 0.1, 0.2, ..., 0.9, 1.0).
104   // The p-r pairs at these fixed recall points will be written to pr_out, if
105   // it is not null_ptr.
106   float FromPRCurve(const std::vector<PR>& pr,
107                     std::vector<PR>* pr_out = nullptr);
108 
109   // An axis aligned bounding box for an image with id 'imageid'.  Score
110   // indicates its confidence.
111   //
112   // 'difficult' is a special bit specific to Pascal VOC dataset and tasks using
113   // the data. If 'difficult' is true, by convention, the box is often ignored
114   // during the AP calculation. I.e., if a predicted box matches a 'difficult'
115   // ground box, this predicted box is ignored as if the model does not make
116   // such a prediction.
117 
118   // Given the set of ground truth boxes and a set of predicted boxes, returns
119   // the average of the maximum precisions at different recall values.
120   float FromBoxes(const std::vector<Detection>& groundtruth,
121                   const std::vector<Detection>& prediction,
122                   std::vector<PR>* pr_out = nullptr);
123 
124  private:
125   Options opts_;
126 };
127 
128 }  // namespace image
129 }  // namespace evaluation
130 }  // namespace tflite
131 
132 #endif  // TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_UTILS_IMAGE_METRICS_H_
133