1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
#include "YoloResultDecoder.hpp"

#include "NonMaxSuppression.hpp"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker namespace od
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker
Decode(const common::InferenceResults<float> & networkResults,const common::Size & outputFrameSize,const common::Size & resizedFrameSize,const std::vector<std::string> & labels)16*89c4ff92SAndroid Build Coastguard Worker DetectedObjects YoloResultDecoder::Decode(const common::InferenceResults<float>& networkResults,
17*89c4ff92SAndroid Build Coastguard Worker const common::Size& outputFrameSize,
18*89c4ff92SAndroid Build Coastguard Worker const common::Size& resizedFrameSize,
19*89c4ff92SAndroid Build Coastguard Worker const std::vector<std::string>& labels)
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker
22*89c4ff92SAndroid Build Coastguard Worker // Yolo v3 network outputs 1 tensor
23*89c4ff92SAndroid Build Coastguard Worker if (networkResults.size() != 1)
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker throw std::runtime_error("Number of outputs from Yolo model doesn't equal 1");
26*89c4ff92SAndroid Build Coastguard Worker }
27*89c4ff92SAndroid Build Coastguard Worker auto element_step = m_boxElements + m_confidenceElements + m_numClasses;
28*89c4ff92SAndroid Build Coastguard Worker
29*89c4ff92SAndroid Build Coastguard Worker float longEdgeInput = std::max(resizedFrameSize.m_Width, resizedFrameSize.m_Height);
30*89c4ff92SAndroid Build Coastguard Worker float longEdgeOutput = std::max(outputFrameSize.m_Width, outputFrameSize.m_Height);
31*89c4ff92SAndroid Build Coastguard Worker const float resizeFactor = longEdgeOutput/longEdgeInput;
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker DetectedObjects detectedObjects;
34*89c4ff92SAndroid Build Coastguard Worker DetectedObjects resultsAfterNMS;
35*89c4ff92SAndroid Build Coastguard Worker
36*89c4ff92SAndroid Build Coastguard Worker for (const common::InferenceResult<float>& result : networkResults)
37*89c4ff92SAndroid Build Coastguard Worker {
38*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < m_numBoxes; ++i)
39*89c4ff92SAndroid Build Coastguard Worker {
40*89c4ff92SAndroid Build Coastguard Worker const float* cur_box = &result[i * element_step];
41*89c4ff92SAndroid Build Coastguard Worker // Objectness score
42*89c4ff92SAndroid Build Coastguard Worker if (cur_box[4] > m_objectThreshold)
43*89c4ff92SAndroid Build Coastguard Worker {
44*89c4ff92SAndroid Build Coastguard Worker for (unsigned int classIndex = 0; classIndex < m_numClasses; ++classIndex)
45*89c4ff92SAndroid Build Coastguard Worker {
46*89c4ff92SAndroid Build Coastguard Worker const float class_prob = cur_box[4] * cur_box[5 + classIndex];
47*89c4ff92SAndroid Build Coastguard Worker
48*89c4ff92SAndroid Build Coastguard Worker // class confidence
49*89c4ff92SAndroid Build Coastguard Worker
50*89c4ff92SAndroid Build Coastguard Worker if (class_prob > m_ClsThreshold)
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker DetectedObject detectedObject;
53*89c4ff92SAndroid Build Coastguard Worker
54*89c4ff92SAndroid Build Coastguard Worker detectedObject.SetScore(class_prob);
55*89c4ff92SAndroid Build Coastguard Worker
56*89c4ff92SAndroid Build Coastguard Worker float topLeftX = cur_box[0] * resizeFactor;
57*89c4ff92SAndroid Build Coastguard Worker float topLeftY = cur_box[1] * resizeFactor;
58*89c4ff92SAndroid Build Coastguard Worker float botRightX = cur_box[2] * resizeFactor;
59*89c4ff92SAndroid Build Coastguard Worker float botRightY = cur_box[3] * resizeFactor;
60*89c4ff92SAndroid Build Coastguard Worker
61*89c4ff92SAndroid Build Coastguard Worker assert(botRightX > topLeftX);
62*89c4ff92SAndroid Build Coastguard Worker assert(botRightY > topLeftY);
63*89c4ff92SAndroid Build Coastguard Worker
64*89c4ff92SAndroid Build Coastguard Worker detectedObject.SetBoundingBox({static_cast<int>(topLeftX),
65*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(topLeftY),
66*89c4ff92SAndroid Build Coastguard Worker static_cast<unsigned int>(botRightX-topLeftX),
67*89c4ff92SAndroid Build Coastguard Worker static_cast<unsigned int>(botRightY-topLeftY)});
68*89c4ff92SAndroid Build Coastguard Worker if(labels.size() > classIndex)
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker detectedObject.SetLabel(labels.at(classIndex));
71*89c4ff92SAndroid Build Coastguard Worker }
72*89c4ff92SAndroid Build Coastguard Worker else
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker detectedObject.SetLabel(std::to_string(classIndex));
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker detectedObject.SetId(classIndex);
77*89c4ff92SAndroid Build Coastguard Worker detectedObjects.emplace_back(detectedObject);
78*89c4ff92SAndroid Build Coastguard Worker }
79*89c4ff92SAndroid Build Coastguard Worker }
80*89c4ff92SAndroid Build Coastguard Worker }
81*89c4ff92SAndroid Build Coastguard Worker }
82*89c4ff92SAndroid Build Coastguard Worker
83*89c4ff92SAndroid Build Coastguard Worker std::vector<int> keepIndiciesAfterNMS = od::NonMaxSuppression(detectedObjects, m_NmsThreshold);
84*89c4ff92SAndroid Build Coastguard Worker
85*89c4ff92SAndroid Build Coastguard Worker for (const int ind: keepIndiciesAfterNMS)
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker resultsAfterNMS.emplace_back(detectedObjects[ind]);
88*89c4ff92SAndroid Build Coastguard Worker }
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker return resultsAfterNMS;
92*89c4ff92SAndroid Build Coastguard Worker }
93*89c4ff92SAndroid Build Coastguard Worker
YoloResultDecoder(float NMSThreshold,float ClsThreshold,float ObjectThreshold)94*89c4ff92SAndroid Build Coastguard Worker YoloResultDecoder::YoloResultDecoder(float NMSThreshold, float ClsThreshold, float ObjectThreshold)
95*89c4ff92SAndroid Build Coastguard Worker : m_NmsThreshold(NMSThreshold), m_ClsThreshold(ClsThreshold), m_objectThreshold(ObjectThreshold) {}
96*89c4ff92SAndroid Build Coastguard Worker
97*89c4ff92SAndroid Build Coastguard Worker }// namespace od
98*89c4ff92SAndroid Build Coastguard Worker
99*89c4ff92SAndroid Build Coastguard Worker
100*89c4ff92SAndroid Build Coastguard Worker
101