xref: /aosp_15_r20/external/armnn/tests/YoloInferenceTest.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
#include "InferenceTest.hpp"
#include "YoloDatabase.hpp"

#include <armnn/utility/Assert.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnnUtils/FloatingPointComparison.hpp>

#include <algorithm>
#include <array>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
17*89c4ff92SAndroid Build Coastguard Worker 
// Number of floats in the YOLO network's single output tensor: a 7x7 grid in which every cell
// carries 20 class probabilities, 2 scales and 2 boxes of 4 coordinates each (7 * 7 * 30 = 1470).
constexpr size_t YoloOutputSize = 7 * 7 * (20 + 2 + 2 * 4);
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker template <typename Model>
21*89c4ff92SAndroid Build Coastguard Worker class YoloTestCase : public InferenceModelTestCase<Model>
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker public:
YoloTestCase(Model & model,unsigned int testCaseId,YoloTestCaseData & testCaseData)24*89c4ff92SAndroid Build Coastguard Worker     YoloTestCase(Model& model,
25*89c4ff92SAndroid Build Coastguard Worker         unsigned int testCaseId,
26*89c4ff92SAndroid Build Coastguard Worker         YoloTestCaseData& testCaseData)
27*89c4ff92SAndroid Build Coastguard Worker      : InferenceModelTestCase<Model>(model, testCaseId, { std::move(testCaseData.m_InputImage) }, { YoloOutputSize })
28*89c4ff92SAndroid Build Coastguard Worker      , m_TopObjectDetections(std::move(testCaseData.m_TopObjectDetections))
29*89c4ff92SAndroid Build Coastguard Worker     {
30*89c4ff92SAndroid Build Coastguard Worker     }
31*89c4ff92SAndroid Build Coastguard Worker 
ProcessResult(const InferenceTestOptions & options)32*89c4ff92SAndroid Build Coastguard Worker     virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) override
33*89c4ff92SAndroid Build Coastguard Worker     {
34*89c4ff92SAndroid Build Coastguard Worker         armnn::IgnoreUnused(options);
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker         const std::vector<float>& output = mapbox::util::get<std::vector<float>>(this->GetOutputs()[0]);
37*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(output.size() == YoloOutputSize);
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker         constexpr unsigned int gridSize = 7;
40*89c4ff92SAndroid Build Coastguard Worker         constexpr unsigned int numClasses = 20;
41*89c4ff92SAndroid Build Coastguard Worker         constexpr unsigned int numScales = 2;
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker         const float* outputPtr =  output.data();
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker         // Range 0-980. Class probabilities. 7x7x20
46*89c4ff92SAndroid Build Coastguard Worker         vector<vector<vector<float>>> classProbabilities(gridSize, vector<vector<float>>(gridSize,
47*89c4ff92SAndroid Build Coastguard Worker                                                          vector<float>(numClasses)));
48*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int y = 0; y < gridSize; ++y)
49*89c4ff92SAndroid Build Coastguard Worker         {
50*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int x = 0; x < gridSize; ++x)
51*89c4ff92SAndroid Build Coastguard Worker             {
52*89c4ff92SAndroid Build Coastguard Worker                 for (unsigned int c = 0; c < numClasses; ++c)
53*89c4ff92SAndroid Build Coastguard Worker                 {
54*89c4ff92SAndroid Build Coastguard Worker                     classProbabilities[y][x][c] = *outputPtr++;
55*89c4ff92SAndroid Build Coastguard Worker                 }
56*89c4ff92SAndroid Build Coastguard Worker             }
57*89c4ff92SAndroid Build Coastguard Worker         }
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker         // Range 980-1078. Scales. 7x7x2
60*89c4ff92SAndroid Build Coastguard Worker         vector<vector<vector<float>>> scales(gridSize, vector<vector<float>>(gridSize, vector<float>(numScales)));
61*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int y = 0; y < gridSize; ++y)
62*89c4ff92SAndroid Build Coastguard Worker         {
63*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int x = 0; x < gridSize; ++x)
64*89c4ff92SAndroid Build Coastguard Worker             {
65*89c4ff92SAndroid Build Coastguard Worker                 for (unsigned int s = 0; s < numScales; ++s)
66*89c4ff92SAndroid Build Coastguard Worker                 {
67*89c4ff92SAndroid Build Coastguard Worker                     scales[y][x][s] = *outputPtr++;
68*89c4ff92SAndroid Build Coastguard Worker                 }
69*89c4ff92SAndroid Build Coastguard Worker             }
70*89c4ff92SAndroid Build Coastguard Worker         }
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker         // Range 1078-1469. Bounding boxes. 7x7x2x4
73*89c4ff92SAndroid Build Coastguard Worker         constexpr float imageWidthAsFloat = static_cast<float>(YoloImageWidth);
74*89c4ff92SAndroid Build Coastguard Worker         constexpr float imageHeightAsFloat = static_cast<float>(YoloImageHeight);
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker         vector<vector<vector<vector<float>>>> boxes(gridSize, vector<vector<vector<float>>>
77*89c4ff92SAndroid Build Coastguard Worker             (gridSize, vector<vector<float>>(numScales, vector<float>(4))));
78*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int y = 0; y < gridSize; ++y)
79*89c4ff92SAndroid Build Coastguard Worker         {
80*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int x = 0; x < gridSize; ++x)
81*89c4ff92SAndroid Build Coastguard Worker             {
82*89c4ff92SAndroid Build Coastguard Worker                 for (unsigned int s = 0; s < numScales; ++s)
83*89c4ff92SAndroid Build Coastguard Worker                 {
84*89c4ff92SAndroid Build Coastguard Worker                     float bx = *outputPtr++;
85*89c4ff92SAndroid Build Coastguard Worker                     float by = *outputPtr++;
86*89c4ff92SAndroid Build Coastguard Worker                     float bw = *outputPtr++;
87*89c4ff92SAndroid Build Coastguard Worker                     float bh = *outputPtr++;
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker                     boxes[y][x][s][0] = ((bx + static_cast<float>(x)) / 7.0f) * imageWidthAsFloat;
90*89c4ff92SAndroid Build Coastguard Worker                     boxes[y][x][s][1] = ((by + static_cast<float>(y)) / 7.0f) * imageHeightAsFloat;
91*89c4ff92SAndroid Build Coastguard Worker                     boxes[y][x][s][2] = bw * bw * static_cast<float>(imageWidthAsFloat);
92*89c4ff92SAndroid Build Coastguard Worker                     boxes[y][x][s][3] = bh * bh * static_cast<float>(imageHeightAsFloat);
93*89c4ff92SAndroid Build Coastguard Worker                 }
94*89c4ff92SAndroid Build Coastguard Worker             }
95*89c4ff92SAndroid Build Coastguard Worker         }
96*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(output.data() + YoloOutputSize == outputPtr);
97*89c4ff92SAndroid Build Coastguard Worker 
98*89c4ff92SAndroid Build Coastguard Worker         std::vector<YoloDetectedObject> detectedObjects;
99*89c4ff92SAndroid Build Coastguard Worker         detectedObjects.reserve(gridSize * gridSize * numScales * numClasses);
100*89c4ff92SAndroid Build Coastguard Worker 
101*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int y = 0; y < gridSize; ++y)
102*89c4ff92SAndroid Build Coastguard Worker         {
103*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int x = 0; x < gridSize; ++x)
104*89c4ff92SAndroid Build Coastguard Worker             {
105*89c4ff92SAndroid Build Coastguard Worker                 for (unsigned int s = 0; s < numScales; ++s)
106*89c4ff92SAndroid Build Coastguard Worker                 {
107*89c4ff92SAndroid Build Coastguard Worker                     for (unsigned int c = 0; c < numClasses; ++c)
108*89c4ff92SAndroid Build Coastguard Worker                     {
109*89c4ff92SAndroid Build Coastguard Worker                         // Resolved confidence: class probabilities * scales.
110*89c4ff92SAndroid Build Coastguard Worker                         const float confidence = classProbabilities[y][x][c] * scales[y][x][s];
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker                         // Resolves bounding box and stores.
113*89c4ff92SAndroid Build Coastguard Worker                         YoloBoundingBox box;
114*89c4ff92SAndroid Build Coastguard Worker                         box.m_X = boxes[y][x][s][0];
115*89c4ff92SAndroid Build Coastguard Worker                         box.m_Y = boxes[y][x][s][1];
116*89c4ff92SAndroid Build Coastguard Worker                         box.m_W = boxes[y][x][s][2];
117*89c4ff92SAndroid Build Coastguard Worker                         box.m_H = boxes[y][x][s][3];
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker                         detectedObjects.emplace_back(c, box, confidence);
120*89c4ff92SAndroid Build Coastguard Worker                     }
121*89c4ff92SAndroid Build Coastguard Worker                 }
122*89c4ff92SAndroid Build Coastguard Worker             }
123*89c4ff92SAndroid Build Coastguard Worker         }
124*89c4ff92SAndroid Build Coastguard Worker 
125*89c4ff92SAndroid Build Coastguard Worker         // Sorts detected objects by confidence.
126*89c4ff92SAndroid Build Coastguard Worker         std::sort(detectedObjects.begin(), detectedObjects.end(),
127*89c4ff92SAndroid Build Coastguard Worker             [](const YoloDetectedObject& a, const YoloDetectedObject& b)
128*89c4ff92SAndroid Build Coastguard Worker             {
129*89c4ff92SAndroid Build Coastguard Worker                 // Sorts by largest confidence first, then by class.
130*89c4ff92SAndroid Build Coastguard Worker                 return a.m_Confidence > b.m_Confidence
131*89c4ff92SAndroid Build Coastguard Worker                     || (a.m_Confidence == b.m_Confidence && a.m_Class > b.m_Class);
132*89c4ff92SAndroid Build Coastguard Worker             });
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker         // Checks the top N detections.
135*89c4ff92SAndroid Build Coastguard Worker         auto outputIt  = detectedObjects.begin();
136*89c4ff92SAndroid Build Coastguard Worker         auto outputEnd = detectedObjects.end();
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker         for (const YoloDetectedObject& expectedDetection : m_TopObjectDetections)
139*89c4ff92SAndroid Build Coastguard Worker         {
140*89c4ff92SAndroid Build Coastguard Worker             if (outputIt == outputEnd)
141*89c4ff92SAndroid Build Coastguard Worker             {
142*89c4ff92SAndroid Build Coastguard Worker                 // Somehow expected more things to check than detections found by the model.
143*89c4ff92SAndroid Build Coastguard Worker                 return TestCaseResult::Abort;
144*89c4ff92SAndroid Build Coastguard Worker             }
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker             const YoloDetectedObject& detectedObject = *outputIt;
147*89c4ff92SAndroid Build Coastguard Worker             if (detectedObject.m_Class != expectedDetection.m_Class)
148*89c4ff92SAndroid Build Coastguard Worker             {
149*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(error) << "Prediction for test case " << this->GetTestCaseId() <<
150*89c4ff92SAndroid Build Coastguard Worker                     " is incorrect: Expected (" << expectedDetection.m_Class << ")" <<
151*89c4ff92SAndroid Build Coastguard Worker                     " but predicted (" << detectedObject.m_Class << ")";
152*89c4ff92SAndroid Build Coastguard Worker                 return TestCaseResult::Failed;
153*89c4ff92SAndroid Build Coastguard Worker             }
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker             if (!armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_X, expectedDetection.m_Box.m_X) ||
156*89c4ff92SAndroid Build Coastguard Worker                 !armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_Y, expectedDetection.m_Box.m_Y) ||
157*89c4ff92SAndroid Build Coastguard Worker                 !armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_W, expectedDetection.m_Box.m_W) ||
158*89c4ff92SAndroid Build Coastguard Worker                 !armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_H, expectedDetection.m_Box.m_H) ||
159*89c4ff92SAndroid Build Coastguard Worker                 !armnnUtils::within_percentage_tolerance(detectedObject.m_Confidence, expectedDetection.m_Confidence))
160*89c4ff92SAndroid Build Coastguard Worker             {
161*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(error) << "Detected bounding box for test case " << this->GetTestCaseId() <<
162*89c4ff92SAndroid Build Coastguard Worker                     " is incorrect";
163*89c4ff92SAndroid Build Coastguard Worker                 return TestCaseResult::Failed;
164*89c4ff92SAndroid Build Coastguard Worker             }
165*89c4ff92SAndroid Build Coastguard Worker 
166*89c4ff92SAndroid Build Coastguard Worker             ++outputIt;
167*89c4ff92SAndroid Build Coastguard Worker         }
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker         return TestCaseResult::Ok;
170*89c4ff92SAndroid Build Coastguard Worker     }
171*89c4ff92SAndroid Build Coastguard Worker 
172*89c4ff92SAndroid Build Coastguard Worker private:
173*89c4ff92SAndroid Build Coastguard Worker     std::vector<YoloDetectedObject> m_TopObjectDetections;
174*89c4ff92SAndroid Build Coastguard Worker };
175*89c4ff92SAndroid Build Coastguard Worker 
176*89c4ff92SAndroid Build Coastguard Worker template <typename Model>
177*89c4ff92SAndroid Build Coastguard Worker class YoloTestCaseProvider : public IInferenceTestCaseProvider
178*89c4ff92SAndroid Build Coastguard Worker {
179*89c4ff92SAndroid Build Coastguard Worker public:
180*89c4ff92SAndroid Build Coastguard Worker     template <typename TConstructModelCallable>
YoloTestCaseProvider(TConstructModelCallable constructModel)181*89c4ff92SAndroid Build Coastguard Worker     explicit YoloTestCaseProvider(TConstructModelCallable constructModel)
182*89c4ff92SAndroid Build Coastguard Worker         : m_ConstructModel(constructModel)
183*89c4ff92SAndroid Build Coastguard Worker     {
184*89c4ff92SAndroid Build Coastguard Worker     }
185*89c4ff92SAndroid Build Coastguard Worker 
AddCommandLineOptions(cxxopts::Options & options,std::vector<std::string> & required)186*89c4ff92SAndroid Build Coastguard Worker     virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override
187*89c4ff92SAndroid Build Coastguard Worker     {
188*89c4ff92SAndroid Build Coastguard Worker         options
189*89c4ff92SAndroid Build Coastguard Worker             .allow_unrecognised_options()
190*89c4ff92SAndroid Build Coastguard Worker             .add_options()
191*89c4ff92SAndroid Build Coastguard Worker                 ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir));
192*89c4ff92SAndroid Build Coastguard Worker 
193*89c4ff92SAndroid Build Coastguard Worker         Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
194*89c4ff92SAndroid Build Coastguard Worker     }
195*89c4ff92SAndroid Build Coastguard Worker 
ProcessCommandLineOptions(const InferenceTestOptions & commonOptions)196*89c4ff92SAndroid Build Coastguard Worker     virtual bool ProcessCommandLineOptions(const InferenceTestOptions& commonOptions) override
197*89c4ff92SAndroid Build Coastguard Worker     {
198*89c4ff92SAndroid Build Coastguard Worker         if (!ValidateDirectory(m_DataDir))
199*89c4ff92SAndroid Build Coastguard Worker         {
200*89c4ff92SAndroid Build Coastguard Worker             return false;
201*89c4ff92SAndroid Build Coastguard Worker         }
202*89c4ff92SAndroid Build Coastguard Worker 
203*89c4ff92SAndroid Build Coastguard Worker         m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
204*89c4ff92SAndroid Build Coastguard Worker         if (!m_Model)
205*89c4ff92SAndroid Build Coastguard Worker         {
206*89c4ff92SAndroid Build Coastguard Worker             return false;
207*89c4ff92SAndroid Build Coastguard Worker         }
208*89c4ff92SAndroid Build Coastguard Worker 
209*89c4ff92SAndroid Build Coastguard Worker         m_Database = std::make_unique<YoloDatabase>(m_DataDir.c_str());
210*89c4ff92SAndroid Build Coastguard Worker         if (!m_Database)
211*89c4ff92SAndroid Build Coastguard Worker         {
212*89c4ff92SAndroid Build Coastguard Worker             return false;
213*89c4ff92SAndroid Build Coastguard Worker         }
214*89c4ff92SAndroid Build Coastguard Worker 
215*89c4ff92SAndroid Build Coastguard Worker         return true;
216*89c4ff92SAndroid Build Coastguard Worker     }
217*89c4ff92SAndroid Build Coastguard Worker 
GetTestCase(unsigned int testCaseId)218*89c4ff92SAndroid Build Coastguard Worker     virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
219*89c4ff92SAndroid Build Coastguard Worker     {
220*89c4ff92SAndroid Build Coastguard Worker         std::unique_ptr<YoloTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
221*89c4ff92SAndroid Build Coastguard Worker         if (!testCaseData)
222*89c4ff92SAndroid Build Coastguard Worker         {
223*89c4ff92SAndroid Build Coastguard Worker             return nullptr;
224*89c4ff92SAndroid Build Coastguard Worker         }
225*89c4ff92SAndroid Build Coastguard Worker 
226*89c4ff92SAndroid Build Coastguard Worker         return std::make_unique<YoloTestCase<Model>>(*m_Model, testCaseId, *testCaseData);
227*89c4ff92SAndroid Build Coastguard Worker     }
228*89c4ff92SAndroid Build Coastguard Worker 
229*89c4ff92SAndroid Build Coastguard Worker private:
230*89c4ff92SAndroid Build Coastguard Worker     typename Model::CommandLineOptions m_ModelCommandLineOptions;
231*89c4ff92SAndroid Build Coastguard Worker     std::function<std::unique_ptr<Model>(const InferenceTestOptions&,
232*89c4ff92SAndroid Build Coastguard Worker                                          typename Model::CommandLineOptions)> m_ConstructModel;
233*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<Model> m_Model;
234*89c4ff92SAndroid Build Coastguard Worker 
235*89c4ff92SAndroid Build Coastguard Worker     std::string m_DataDir;
236*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<YoloDatabase> m_Database;
237*89c4ff92SAndroid Build Coastguard Worker };
238