xref: /aosp_15_r20/external/armnn/samples/ObjectDetection/src/NonMaxSuppression.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
#include "NonMaxSuppression.hpp"

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker namespace od
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker 
/**
* @brief Builds the ascending index sequence 0, 1, ..., k - 1.
*
* @param[in]  k  Number of indices to generate.
* @return     Vector of k consecutive unsigned indices starting at 0 (empty when k is 0).
*
*/
static std::vector<unsigned int> GenerateRangeK(unsigned int k)
{
    std::vector<unsigned int> range(k);
    // Seed with an unsigned literal so the running value has the element type;
    // an int seed would be incremented as a signed value and converted on every store.
    std::iota(range.begin(), range.end(), 0U);
    return range;
}
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker /**
21*89c4ff92SAndroid Build Coastguard Worker * @brief Returns the intersection over union for two bounding boxes
22*89c4ff92SAndroid Build Coastguard Worker *
23*89c4ff92SAndroid Build Coastguard Worker * @param[in]  First detect containing bounding box.
24*89c4ff92SAndroid Build Coastguard Worker * @param[in]  Second detect containing bounding box.
25*89c4ff92SAndroid Build Coastguard Worker * @return     Calculated intersection over union.
26*89c4ff92SAndroid Build Coastguard Worker *
27*89c4ff92SAndroid Build Coastguard Worker */
IntersectionOverUnion(DetectedObject & detect1,DetectedObject & detect2)28*89c4ff92SAndroid Build Coastguard Worker static double IntersectionOverUnion(DetectedObject& detect1, DetectedObject& detect2)
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker     uint32_t area1 = (detect1.GetBoundingBox().GetHeight() * detect1.GetBoundingBox().GetWidth());
31*89c4ff92SAndroid Build Coastguard Worker     uint32_t area2 = (detect2.GetBoundingBox().GetHeight() * detect2.GetBoundingBox().GetWidth());
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     float yMinIntersection = std::max(detect1.GetBoundingBox().GetY(), detect2.GetBoundingBox().GetY());
34*89c4ff92SAndroid Build Coastguard Worker     float xMinIntersection = std::max(detect1.GetBoundingBox().GetX(), detect2.GetBoundingBox().GetX());
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker     float yMaxIntersection = std::min(detect1.GetBoundingBox().GetY() + detect1.GetBoundingBox().GetHeight(),
37*89c4ff92SAndroid Build Coastguard Worker                                       detect2.GetBoundingBox().GetY() + detect2.GetBoundingBox().GetHeight());
38*89c4ff92SAndroid Build Coastguard Worker     float xMaxIntersection = std::min(detect1.GetBoundingBox().GetX() + detect1.GetBoundingBox().GetWidth(),
39*89c4ff92SAndroid Build Coastguard Worker                                       detect2.GetBoundingBox().GetX() + detect2.GetBoundingBox().GetWidth());
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker     double areaIntersection = std::max(yMaxIntersection - yMinIntersection, 0.0f) *
42*89c4ff92SAndroid Build Coastguard Worker                               std::max(xMaxIntersection - xMinIntersection, 0.0f);
43*89c4ff92SAndroid Build Coastguard Worker     double areaUnion = area1 + area2 - areaIntersection;
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker     return areaIntersection / areaUnion;
46*89c4ff92SAndroid Build Coastguard Worker }
47*89c4ff92SAndroid Build Coastguard Worker 
NonMaxSuppression(DetectedObjects & inputDetections,float iouThresh)48*89c4ff92SAndroid Build Coastguard Worker std::vector<int> NonMaxSuppression(DetectedObjects& inputDetections, float iouThresh)
49*89c4ff92SAndroid Build Coastguard Worker {
50*89c4ff92SAndroid Build Coastguard Worker     // Sort indicies of detections by highest score to lowest.
51*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> sortedIndicies = GenerateRangeK(inputDetections.size());
52*89c4ff92SAndroid Build Coastguard Worker     std::sort(sortedIndicies.begin(), sortedIndicies.end(),
53*89c4ff92SAndroid Build Coastguard Worker         [&inputDetections](int idx1, int idx2)
54*89c4ff92SAndroid Build Coastguard Worker         {
55*89c4ff92SAndroid Build Coastguard Worker             return inputDetections[idx1].GetScore() > inputDetections[idx2].GetScore();
56*89c4ff92SAndroid Build Coastguard Worker         });
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker     std::vector<bool> visited(inputDetections.size(), false);
59*89c4ff92SAndroid Build Coastguard Worker     std::vector<int> outputIndiciesAfterNMS;
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker     for (int i=0; i < inputDetections.size(); ++i)
62*89c4ff92SAndroid Build Coastguard Worker     {
63*89c4ff92SAndroid Build Coastguard Worker         // Each new unvisited detect should be kept.
64*89c4ff92SAndroid Build Coastguard Worker         if (!visited[sortedIndicies[i]])
65*89c4ff92SAndroid Build Coastguard Worker         {
66*89c4ff92SAndroid Build Coastguard Worker             outputIndiciesAfterNMS.emplace_back(sortedIndicies[i]);
67*89c4ff92SAndroid Build Coastguard Worker             visited[sortedIndicies[i]] = true;
68*89c4ff92SAndroid Build Coastguard Worker         }
69*89c4ff92SAndroid Build Coastguard Worker 
70*89c4ff92SAndroid Build Coastguard Worker         // Look for detections to suppress.
71*89c4ff92SAndroid Build Coastguard Worker         for (int j=i+1; j<inputDetections.size(); ++j)
72*89c4ff92SAndroid Build Coastguard Worker         {
73*89c4ff92SAndroid Build Coastguard Worker             // Skip if already kept or suppressed.
74*89c4ff92SAndroid Build Coastguard Worker             if (!visited[sortedIndicies[j]])
75*89c4ff92SAndroid Build Coastguard Worker             {
76*89c4ff92SAndroid Build Coastguard Worker                 // Detects must have the same label to be suppressed.
77*89c4ff92SAndroid Build Coastguard Worker                 if (inputDetections[sortedIndicies[j]].GetLabel() == inputDetections[sortedIndicies[i]].GetLabel())
78*89c4ff92SAndroid Build Coastguard Worker                 {
79*89c4ff92SAndroid Build Coastguard Worker                     auto iou = IntersectionOverUnion(inputDetections[sortedIndicies[i]],
80*89c4ff92SAndroid Build Coastguard Worker                                                     inputDetections[sortedIndicies[j]]);
81*89c4ff92SAndroid Build Coastguard Worker                     if (iou > iouThresh)
82*89c4ff92SAndroid Build Coastguard Worker                     {
83*89c4ff92SAndroid Build Coastguard Worker                         visited[sortedIndicies[j]] = true;
84*89c4ff92SAndroid Build Coastguard Worker                     }
85*89c4ff92SAndroid Build Coastguard Worker                 }
86*89c4ff92SAndroid Build Coastguard Worker             }
87*89c4ff92SAndroid Build Coastguard Worker         }
88*89c4ff92SAndroid Build Coastguard Worker     }
89*89c4ff92SAndroid Build Coastguard Worker     return outputIndiciesAfterNMS;
90*89c4ff92SAndroid Build Coastguard Worker }
91*89c4ff92SAndroid Build Coastguard Worker 
92*89c4ff92SAndroid Build Coastguard Worker } // namespace od
93