1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include "NonMaxSuppression.hpp"
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker namespace od
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker
/**
 * @brief Builds the index sequence 0, 1, ..., k - 1.
 *
 * @param[in] k Number of consecutive indices to produce.
 * @return Vector of length k whose element at position n equals n.
 */
static std::vector<unsigned int> GenerateRangeK(unsigned int k)
{
    std::vector<unsigned int> indices;
    indices.reserve(k);
    for (unsigned int next = 0; next < k; ++next)
    {
        indices.push_back(next);
    }
    return indices;
}
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker
20*89c4ff92SAndroid Build Coastguard Worker /**
21*89c4ff92SAndroid Build Coastguard Worker * @brief Returns the intersection over union for two bounding boxes
22*89c4ff92SAndroid Build Coastguard Worker *
23*89c4ff92SAndroid Build Coastguard Worker * @param[in] First detect containing bounding box.
24*89c4ff92SAndroid Build Coastguard Worker * @param[in] Second detect containing bounding box.
25*89c4ff92SAndroid Build Coastguard Worker * @return Calculated intersection over union.
26*89c4ff92SAndroid Build Coastguard Worker *
27*89c4ff92SAndroid Build Coastguard Worker */
IntersectionOverUnion(DetectedObject & detect1,DetectedObject & detect2)28*89c4ff92SAndroid Build Coastguard Worker static double IntersectionOverUnion(DetectedObject& detect1, DetectedObject& detect2)
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker uint32_t area1 = (detect1.GetBoundingBox().GetHeight() * detect1.GetBoundingBox().GetWidth());
31*89c4ff92SAndroid Build Coastguard Worker uint32_t area2 = (detect2.GetBoundingBox().GetHeight() * detect2.GetBoundingBox().GetWidth());
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker float yMinIntersection = std::max(detect1.GetBoundingBox().GetY(), detect2.GetBoundingBox().GetY());
34*89c4ff92SAndroid Build Coastguard Worker float xMinIntersection = std::max(detect1.GetBoundingBox().GetX(), detect2.GetBoundingBox().GetX());
35*89c4ff92SAndroid Build Coastguard Worker
36*89c4ff92SAndroid Build Coastguard Worker float yMaxIntersection = std::min(detect1.GetBoundingBox().GetY() + detect1.GetBoundingBox().GetHeight(),
37*89c4ff92SAndroid Build Coastguard Worker detect2.GetBoundingBox().GetY() + detect2.GetBoundingBox().GetHeight());
38*89c4ff92SAndroid Build Coastguard Worker float xMaxIntersection = std::min(detect1.GetBoundingBox().GetX() + detect1.GetBoundingBox().GetWidth(),
39*89c4ff92SAndroid Build Coastguard Worker detect2.GetBoundingBox().GetX() + detect2.GetBoundingBox().GetWidth());
40*89c4ff92SAndroid Build Coastguard Worker
41*89c4ff92SAndroid Build Coastguard Worker double areaIntersection = std::max(yMaxIntersection - yMinIntersection, 0.0f) *
42*89c4ff92SAndroid Build Coastguard Worker std::max(xMaxIntersection - xMinIntersection, 0.0f);
43*89c4ff92SAndroid Build Coastguard Worker double areaUnion = area1 + area2 - areaIntersection;
44*89c4ff92SAndroid Build Coastguard Worker
45*89c4ff92SAndroid Build Coastguard Worker return areaIntersection / areaUnion;
46*89c4ff92SAndroid Build Coastguard Worker }
47*89c4ff92SAndroid Build Coastguard Worker
NonMaxSuppression(DetectedObjects & inputDetections,float iouThresh)48*89c4ff92SAndroid Build Coastguard Worker std::vector<int> NonMaxSuppression(DetectedObjects& inputDetections, float iouThresh)
49*89c4ff92SAndroid Build Coastguard Worker {
50*89c4ff92SAndroid Build Coastguard Worker // Sort indicies of detections by highest score to lowest.
51*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> sortedIndicies = GenerateRangeK(inputDetections.size());
52*89c4ff92SAndroid Build Coastguard Worker std::sort(sortedIndicies.begin(), sortedIndicies.end(),
53*89c4ff92SAndroid Build Coastguard Worker [&inputDetections](int idx1, int idx2)
54*89c4ff92SAndroid Build Coastguard Worker {
55*89c4ff92SAndroid Build Coastguard Worker return inputDetections[idx1].GetScore() > inputDetections[idx2].GetScore();
56*89c4ff92SAndroid Build Coastguard Worker });
57*89c4ff92SAndroid Build Coastguard Worker
58*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> visited(inputDetections.size(), false);
59*89c4ff92SAndroid Build Coastguard Worker std::vector<int> outputIndiciesAfterNMS;
60*89c4ff92SAndroid Build Coastguard Worker
61*89c4ff92SAndroid Build Coastguard Worker for (int i=0; i < inputDetections.size(); ++i)
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker // Each new unvisited detect should be kept.
64*89c4ff92SAndroid Build Coastguard Worker if (!visited[sortedIndicies[i]])
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker outputIndiciesAfterNMS.emplace_back(sortedIndicies[i]);
67*89c4ff92SAndroid Build Coastguard Worker visited[sortedIndicies[i]] = true;
68*89c4ff92SAndroid Build Coastguard Worker }
69*89c4ff92SAndroid Build Coastguard Worker
70*89c4ff92SAndroid Build Coastguard Worker // Look for detections to suppress.
71*89c4ff92SAndroid Build Coastguard Worker for (int j=i+1; j<inputDetections.size(); ++j)
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker // Skip if already kept or suppressed.
74*89c4ff92SAndroid Build Coastguard Worker if (!visited[sortedIndicies[j]])
75*89c4ff92SAndroid Build Coastguard Worker {
76*89c4ff92SAndroid Build Coastguard Worker // Detects must have the same label to be suppressed.
77*89c4ff92SAndroid Build Coastguard Worker if (inputDetections[sortedIndicies[j]].GetLabel() == inputDetections[sortedIndicies[i]].GetLabel())
78*89c4ff92SAndroid Build Coastguard Worker {
79*89c4ff92SAndroid Build Coastguard Worker auto iou = IntersectionOverUnion(inputDetections[sortedIndicies[i]],
80*89c4ff92SAndroid Build Coastguard Worker inputDetections[sortedIndicies[j]]);
81*89c4ff92SAndroid Build Coastguard Worker if (iou > iouThresh)
82*89c4ff92SAndroid Build Coastguard Worker {
83*89c4ff92SAndroid Build Coastguard Worker visited[sortedIndicies[j]] = true;
84*89c4ff92SAndroid Build Coastguard Worker }
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker }
87*89c4ff92SAndroid Build Coastguard Worker }
88*89c4ff92SAndroid Build Coastguard Worker }
89*89c4ff92SAndroid Build Coastguard Worker return outputIndiciesAfterNMS;
90*89c4ff92SAndroid Build Coastguard Worker }
91*89c4ff92SAndroid Build Coastguard Worker
92*89c4ff92SAndroid Build Coastguard Worker } // namespace od
93