xref: /aosp_15_r20/external/armnn/samples/ObjectDetection/src/YoloResultDecoder.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
#include "YoloResultDecoder.hpp"

#include "NonMaxSuppression.hpp"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker namespace od
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker 
Decode(const common::InferenceResults<float> & networkResults,const common::Size & outputFrameSize,const common::Size & resizedFrameSize,const std::vector<std::string> & labels)16*89c4ff92SAndroid Build Coastguard Worker DetectedObjects YoloResultDecoder::Decode(const common::InferenceResults<float>& networkResults,
17*89c4ff92SAndroid Build Coastguard Worker                                          const common::Size& outputFrameSize,
18*89c4ff92SAndroid Build Coastguard Worker                                          const common::Size& resizedFrameSize,
19*89c4ff92SAndroid Build Coastguard Worker                                          const std::vector<std::string>& labels)
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker     // Yolo v3 network outputs 1 tensor
23*89c4ff92SAndroid Build Coastguard Worker     if (networkResults.size() != 1)
24*89c4ff92SAndroid Build Coastguard Worker     {
25*89c4ff92SAndroid Build Coastguard Worker         throw std::runtime_error("Number of outputs from Yolo model doesn't equal 1");
26*89c4ff92SAndroid Build Coastguard Worker     }
27*89c4ff92SAndroid Build Coastguard Worker     auto element_step = m_boxElements + m_confidenceElements + m_numClasses;
28*89c4ff92SAndroid Build Coastguard Worker 
29*89c4ff92SAndroid Build Coastguard Worker     float longEdgeInput = std::max(resizedFrameSize.m_Width, resizedFrameSize.m_Height);
30*89c4ff92SAndroid Build Coastguard Worker     float longEdgeOutput = std::max(outputFrameSize.m_Width, outputFrameSize.m_Height);
31*89c4ff92SAndroid Build Coastguard Worker     const float resizeFactor = longEdgeOutput/longEdgeInput;
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     DetectedObjects detectedObjects;
34*89c4ff92SAndroid Build Coastguard Worker     DetectedObjects resultsAfterNMS;
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker     for (const common::InferenceResult<float>& result : networkResults)
37*89c4ff92SAndroid Build Coastguard Worker     {
38*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < m_numBoxes; ++i)
39*89c4ff92SAndroid Build Coastguard Worker         {
40*89c4ff92SAndroid Build Coastguard Worker             const float* cur_box = &result[i * element_step];
41*89c4ff92SAndroid Build Coastguard Worker             // Objectness score
42*89c4ff92SAndroid Build Coastguard Worker             if (cur_box[4] > m_objectThreshold)
43*89c4ff92SAndroid Build Coastguard Worker             {
44*89c4ff92SAndroid Build Coastguard Worker                 for (unsigned int classIndex = 0; classIndex < m_numClasses; ++classIndex)
45*89c4ff92SAndroid Build Coastguard Worker                 {
46*89c4ff92SAndroid Build Coastguard Worker                     const float class_prob =  cur_box[4] * cur_box[5 + classIndex];
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker                     // class confidence
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker                     if (class_prob > m_ClsThreshold)
51*89c4ff92SAndroid Build Coastguard Worker                     {
52*89c4ff92SAndroid Build Coastguard Worker                         DetectedObject detectedObject;
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker                         detectedObject.SetScore(class_prob);
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker                         float topLeftX = cur_box[0] * resizeFactor;
57*89c4ff92SAndroid Build Coastguard Worker                         float topLeftY = cur_box[1] * resizeFactor;
58*89c4ff92SAndroid Build Coastguard Worker                         float botRightX = cur_box[2] * resizeFactor;
59*89c4ff92SAndroid Build Coastguard Worker                         float botRightY = cur_box[3] * resizeFactor;
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker                         assert(botRightX > topLeftX);
62*89c4ff92SAndroid Build Coastguard Worker                         assert(botRightY > topLeftY);
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker                         detectedObject.SetBoundingBox({static_cast<int>(topLeftX),
65*89c4ff92SAndroid Build Coastguard Worker                                                        static_cast<int>(topLeftY),
66*89c4ff92SAndroid Build Coastguard Worker                                                        static_cast<unsigned int>(botRightX-topLeftX),
67*89c4ff92SAndroid Build Coastguard Worker                                                        static_cast<unsigned int>(botRightY-topLeftY)});
68*89c4ff92SAndroid Build Coastguard Worker                         if(labels.size() > classIndex)
69*89c4ff92SAndroid Build Coastguard Worker                         {
70*89c4ff92SAndroid Build Coastguard Worker                             detectedObject.SetLabel(labels.at(classIndex));
71*89c4ff92SAndroid Build Coastguard Worker                         }
72*89c4ff92SAndroid Build Coastguard Worker                         else
73*89c4ff92SAndroid Build Coastguard Worker                         {
74*89c4ff92SAndroid Build Coastguard Worker                             detectedObject.SetLabel(std::to_string(classIndex));
75*89c4ff92SAndroid Build Coastguard Worker                         }
76*89c4ff92SAndroid Build Coastguard Worker                         detectedObject.SetId(classIndex);
77*89c4ff92SAndroid Build Coastguard Worker                         detectedObjects.emplace_back(detectedObject);
78*89c4ff92SAndroid Build Coastguard Worker                     }
79*89c4ff92SAndroid Build Coastguard Worker                 }
80*89c4ff92SAndroid Build Coastguard Worker             }
81*89c4ff92SAndroid Build Coastguard Worker         }
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker         std::vector<int> keepIndiciesAfterNMS = od::NonMaxSuppression(detectedObjects, m_NmsThreshold);
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker         for (const int ind: keepIndiciesAfterNMS)
86*89c4ff92SAndroid Build Coastguard Worker         {
87*89c4ff92SAndroid Build Coastguard Worker             resultsAfterNMS.emplace_back(detectedObjects[ind]);
88*89c4ff92SAndroid Build Coastguard Worker         }
89*89c4ff92SAndroid Build Coastguard Worker     }
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     return resultsAfterNMS;
92*89c4ff92SAndroid Build Coastguard Worker }
93*89c4ff92SAndroid Build Coastguard Worker 
YoloResultDecoder(float NMSThreshold,float ClsThreshold,float ObjectThreshold)94*89c4ff92SAndroid Build Coastguard Worker YoloResultDecoder::YoloResultDecoder(float NMSThreshold, float ClsThreshold, float ObjectThreshold)
95*89c4ff92SAndroid Build Coastguard Worker         : m_NmsThreshold(NMSThreshold), m_ClsThreshold(ClsThreshold), m_objectThreshold(ObjectThreshold) {}
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker }// namespace od
98*89c4ff92SAndroid Build Coastguard Worker 
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker 
101