xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/evaluation/stages/utils/image_metrics.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/evaluation/stages/utils/image_metrics.h"
16 
17 #include <algorithm>
18 #include <cmath>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/platform/logging.h"
22 
23 namespace tflite {
24 namespace evaluation {
25 namespace image {
26 
Length(const Box2D::Interval & a)27 float Box2D::Length(const Box2D::Interval& a) {
28   return std::max(0.f, a.max - a.min);
29 }
30 
Intersection(const Box2D::Interval & a,const Box2D::Interval & b)31 float Box2D::Intersection(const Box2D::Interval& a, const Box2D::Interval& b) {
32   return Length(Interval{std::max(a.min, b.min), std::min(a.max, b.max)});
33 }
34 
Area() const35 float Box2D::Area() const { return Length(x) * Length(y); }
36 
Intersection(const Box2D & other) const37 float Box2D::Intersection(const Box2D& other) const {
38   return Intersection(x, other.x) * Intersection(y, other.y);
39 }
40 
Union(const Box2D & other) const41 float Box2D::Union(const Box2D& other) const {
42   return Area() + other.Area() - Intersection(other);
43 }
44 
IoU(const Box2D & other) const45 float Box2D::IoU(const Box2D& other) const {
46   const float total = Union(other);
47   if (total > 0) {
48     return Intersection(other) / total;
49   } else {
50     return 0.0;
51   }
52 }
53 
Overlap(const Box2D & other) const54 float Box2D::Overlap(const Box2D& other) const {
55   const float intersection = Intersection(other);
56   return intersection > 0 ? intersection / Area() : 0.0;
57 }
58 
FromPRCurve(const std::vector<PR> & pr,std::vector<PR> * pr_out)59 float AveragePrecision::FromPRCurve(const std::vector<PR>& pr,
60                                     std::vector<PR>* pr_out) {
61   // Because pr[...] are ordered by recall, iterate backward to compute max
62   // precision. p(r) = max_{r' >= r} p(r') for r in 0.0, 0.1, 0.2, ..., 0.9,
63   // 1.0. Then, take the average of (num_recal_points) quantities.
64   float p = 0;
65   float sum = 0;
66   int r_level = opts_.num_recall_points;
67   for (int i = pr.size() - 1; i >= 0; --i) {
68     const PR& item = pr[i];
69     if (i > 0) {
70       if (item.r < pr[i - 1].r) {
71         LOG(ERROR) << "recall points are not in order: " << pr[i - 1].r << ", "
72                    << item.r;
73         return 0;
74       }
75     }
76 
77     // Because r takes values opts_.num_recall_points, opts_.num_recall_points -
78     // 1, ..., 0, the following condition is checking whether item.r crosses r /
79     // opts_.num_recall_points. I.e., 1.0, 0.90, ..., 0.01, 0.0.  We don't use
80     // float to represent r because 0.01 is not representable precisely.
81     while (item.r * opts_.num_recall_points < r_level) {
82       const float recall =
83           static_cast<float>(r_level) / opts_.num_recall_points;
84       if (r_level < 0) {
85         LOG(ERROR) << "Number of recall points should be > 0";
86         return 0;
87       }
88       sum += p;
89       r_level -= 1;
90       if (pr_out != nullptr) {
91         pr_out->emplace_back(p, recall);
92       }
93     }
94     p = std::max(p, item.p);
95   }
96   for (; r_level >= 0; --r_level) {
97     const float recall = static_cast<float>(r_level) / opts_.num_recall_points;
98     sum += p;
99     if (pr_out != nullptr) {
100       pr_out->emplace_back(p, recall);
101     }
102   }
103   return sum / (1 + opts_.num_recall_points);
104 }
105 
FromBoxes(const std::vector<Detection> & groundtruth,const std::vector<Detection> & prediction,std::vector<PR> * pr_out)106 float AveragePrecision::FromBoxes(const std::vector<Detection>& groundtruth,
107                                   const std::vector<Detection>& prediction,
108                                   std::vector<PR>* pr_out) {
109   // Index ground truth boxes based on imageid.
110   absl::flat_hash_map<int64_t, std::list<Detection>> gt;
111   int num_gt = 0;
112   for (auto& box : groundtruth) {
113     gt[box.imgid].push_back(box);
114     if (!box.difficult && box.ignore == kDontIgnore) {
115       ++num_gt;
116     }
117   }
118 
119   if (num_gt == 0) {
120     return NAN;
121   }
122 
123   // Sort all predicted boxes by their scores in a non-ascending order.
124   std::vector<Detection> pd = prediction;
125   std::sort(pd.begin(), pd.end(), [](const Detection& a, const Detection& b) {
126     return a.score > b.score;
127   });
128 
129   // Computes p-r for every prediction.
130   std::vector<PR> pr;
131   int correct = 0;
132   int num_pd = 0;
133   for (int i = 0; i < pd.size(); ++i) {
134     const Detection& b = pd[i];
135     auto* g = &gt[b.imgid];
136     auto best = g->end();
137     float best_iou = -INFINITY;
138     for (auto it = g->begin(); it != g->end(); ++it) {
139       const auto iou = b.box.IoU(it->box);
140       if (iou > best_iou) {
141         best = it;
142         best_iou = iou;
143       }
144     }
145     if ((best != g->end()) && (best_iou >= opts_.iou_threshold)) {
146       if (best->difficult) {
147         continue;
148       }
149       switch (best->ignore) {
150         case kDontIgnore: {
151           ++correct;
152           ++num_pd;
153           g->erase(best);
154           pr.push_back({static_cast<float>(correct) / num_pd,
155                         static_cast<float>(correct) / num_gt});
156           break;
157         }
158         case kIgnoreOneMatch: {
159           g->erase(best);
160           break;
161         }
162         case kIgnoreAllMatches: {
163           break;
164         }
165       }
166     } else {
167       ++num_pd;
168       pr.push_back({static_cast<float>(correct) / num_pd,
169                     static_cast<float>(correct) / num_gt});
170     }
171   }
172   return FromPRCurve(pr, pr_out);
173 }
174 
175 }  // namespace image
176 }  // namespace evaluation
177 }  // namespace tflite
178