1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/evaluation/stages/utils/image_metrics.h"
16
17 #include <algorithm>
18 #include <cmath>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/platform/logging.h"
22
23 namespace tflite {
24 namespace evaluation {
25 namespace image {
26
Length(const Box2D::Interval & a)27 float Box2D::Length(const Box2D::Interval& a) {
28 return std::max(0.f, a.max - a.min);
29 }
30
Intersection(const Box2D::Interval & a,const Box2D::Interval & b)31 float Box2D::Intersection(const Box2D::Interval& a, const Box2D::Interval& b) {
32 return Length(Interval{std::max(a.min, b.min), std::min(a.max, b.max)});
33 }
34
Area() const35 float Box2D::Area() const { return Length(x) * Length(y); }
36
Intersection(const Box2D & other) const37 float Box2D::Intersection(const Box2D& other) const {
38 return Intersection(x, other.x) * Intersection(y, other.y);
39 }
40
Union(const Box2D & other) const41 float Box2D::Union(const Box2D& other) const {
42 return Area() + other.Area() - Intersection(other);
43 }
44
IoU(const Box2D & other) const45 float Box2D::IoU(const Box2D& other) const {
46 const float total = Union(other);
47 if (total > 0) {
48 return Intersection(other) / total;
49 } else {
50 return 0.0;
51 }
52 }
53
Overlap(const Box2D & other) const54 float Box2D::Overlap(const Box2D& other) const {
55 const float intersection = Intersection(other);
56 return intersection > 0 ? intersection / Area() : 0.0;
57 }
58
FromPRCurve(const std::vector<PR> & pr,std::vector<PR> * pr_out)59 float AveragePrecision::FromPRCurve(const std::vector<PR>& pr,
60 std::vector<PR>* pr_out) {
61 // Because pr[...] are ordered by recall, iterate backward to compute max
62 // precision. p(r) = max_{r' >= r} p(r') for r in 0.0, 0.1, 0.2, ..., 0.9,
63 // 1.0. Then, take the average of (num_recal_points) quantities.
64 float p = 0;
65 float sum = 0;
66 int r_level = opts_.num_recall_points;
67 for (int i = pr.size() - 1; i >= 0; --i) {
68 const PR& item = pr[i];
69 if (i > 0) {
70 if (item.r < pr[i - 1].r) {
71 LOG(ERROR) << "recall points are not in order: " << pr[i - 1].r << ", "
72 << item.r;
73 return 0;
74 }
75 }
76
77 // Because r takes values opts_.num_recall_points, opts_.num_recall_points -
78 // 1, ..., 0, the following condition is checking whether item.r crosses r /
79 // opts_.num_recall_points. I.e., 1.0, 0.90, ..., 0.01, 0.0. We don't use
80 // float to represent r because 0.01 is not representable precisely.
81 while (item.r * opts_.num_recall_points < r_level) {
82 const float recall =
83 static_cast<float>(r_level) / opts_.num_recall_points;
84 if (r_level < 0) {
85 LOG(ERROR) << "Number of recall points should be > 0";
86 return 0;
87 }
88 sum += p;
89 r_level -= 1;
90 if (pr_out != nullptr) {
91 pr_out->emplace_back(p, recall);
92 }
93 }
94 p = std::max(p, item.p);
95 }
96 for (; r_level >= 0; --r_level) {
97 const float recall = static_cast<float>(r_level) / opts_.num_recall_points;
98 sum += p;
99 if (pr_out != nullptr) {
100 pr_out->emplace_back(p, recall);
101 }
102 }
103 return sum / (1 + opts_.num_recall_points);
104 }
105
FromBoxes(const std::vector<Detection> & groundtruth,const std::vector<Detection> & prediction,std::vector<PR> * pr_out)106 float AveragePrecision::FromBoxes(const std::vector<Detection>& groundtruth,
107 const std::vector<Detection>& prediction,
108 std::vector<PR>* pr_out) {
109 // Index ground truth boxes based on imageid.
110 absl::flat_hash_map<int64_t, std::list<Detection>> gt;
111 int num_gt = 0;
112 for (auto& box : groundtruth) {
113 gt[box.imgid].push_back(box);
114 if (!box.difficult && box.ignore == kDontIgnore) {
115 ++num_gt;
116 }
117 }
118
119 if (num_gt == 0) {
120 return NAN;
121 }
122
123 // Sort all predicted boxes by their scores in a non-ascending order.
124 std::vector<Detection> pd = prediction;
125 std::sort(pd.begin(), pd.end(), [](const Detection& a, const Detection& b) {
126 return a.score > b.score;
127 });
128
129 // Computes p-r for every prediction.
130 std::vector<PR> pr;
131 int correct = 0;
132 int num_pd = 0;
133 for (int i = 0; i < pd.size(); ++i) {
134 const Detection& b = pd[i];
135 auto* g = >[b.imgid];
136 auto best = g->end();
137 float best_iou = -INFINITY;
138 for (auto it = g->begin(); it != g->end(); ++it) {
139 const auto iou = b.box.IoU(it->box);
140 if (iou > best_iou) {
141 best = it;
142 best_iou = iou;
143 }
144 }
145 if ((best != g->end()) && (best_iou >= opts_.iou_threshold)) {
146 if (best->difficult) {
147 continue;
148 }
149 switch (best->ignore) {
150 case kDontIgnore: {
151 ++correct;
152 ++num_pd;
153 g->erase(best);
154 pr.push_back({static_cast<float>(correct) / num_pd,
155 static_cast<float>(correct) / num_gt});
156 break;
157 }
158 case kIgnoreOneMatch: {
159 g->erase(best);
160 break;
161 }
162 case kIgnoreAllMatches: {
163 break;
164 }
165 }
166 } else {
167 ++num_pd;
168 pr.push_back({static_cast<float>(correct) / num_pd,
169 static_cast<float>(correct) / num_gt});
170 }
171 }
172 return FromPRCurve(pr, pr_out);
173 }
174
175 } // namespace image
176 } // namespace evaluation
177 } // namespace tflite
178