xref: /aosp_15_r20/external/armnn/tests/MobileNetSsdInferenceTest.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include "InferenceTest.hpp"
#include "MobileNetSsdDatabase.hpp"

#include <armnn/utility/Assert.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnnUtils/FloatingPointComparison.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
16 
17 namespace
18 {
19 
20 template<typename Model>
21 class MobileNetSsdTestCase : public InferenceModelTestCase<Model>
22 {
23 public:
MobileNetSsdTestCase(Model & model,unsigned int testCaseId,const MobileNetSsdTestCaseData & testCaseData)24     MobileNetSsdTestCase(Model& model,
25                          unsigned int testCaseId,
26                          const MobileNetSsdTestCaseData& testCaseData)
27         : InferenceModelTestCase<Model>(model,
28                                         testCaseId,
29                                         { std::move(testCaseData.m_InputData) },
30                                         { k_OutputSize1, k_OutputSize2, k_OutputSize3, k_OutputSize4 })
31         , m_DetectedObjects(testCaseData.m_ExpectedDetectedObject)
32     {}
33 
ProcessResult(const InferenceTestOptions & options)34     TestCaseResult ProcessResult(const InferenceTestOptions& options) override
35     {
36         armnn::IgnoreUnused(options);
37 
38         // bounding boxes
39         const std::vector<float>& output1 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[0]);
40         ARMNN_ASSERT(output1.size() == k_OutputSize1);
41 
42         // classes
43         const std::vector<float>& output2 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[1]);
44         ARMNN_ASSERT(output2.size() == k_OutputSize2);
45 
46         // scores
47         const std::vector<float>& output3 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[2]);
48         ARMNN_ASSERT(output3.size() == k_OutputSize3);
49 
50         // valid detections
51         const std::vector<float>& output4 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[3]);
52         ARMNN_ASSERT(output4.size() == k_OutputSize4);
53 
54         const size_t numDetections = armnn::numeric_cast<size_t>(output4[0]);
55 
56         // Check if number of valid detections matches expectations
57         const size_t expectedNumDetections = m_DetectedObjects.size();
58         if (numDetections != expectedNumDetections)
59         {
60             ARMNN_LOG(error) << "Number of detections is incorrect: Expected (" <<
61                 expectedNumDetections << ")" << " but got (" << numDetections << ")";
62             return TestCaseResult::Failed;
63         }
64 
65         // Extract detected objects from output data
66         std::vector<DetectedObject> detectedObjects;
67         const float* outputData = output1.data();
68         for (unsigned int i = 0u; i < numDetections; i++)
69         {
70             // NOTE: Order of coordinates in output data is yMin, xMin, yMax, xMax
71             float yMin = *outputData++;
72             float xMin = *outputData++;
73             float yMax = *outputData++;
74             float xMax = *outputData++;
75 
76             DetectedObject detectedObject(
77                 output2.at(i),
78                 BoundingBox(xMin, yMin, xMax, yMax),
79                 output3.at(i));
80 
81             detectedObjects.push_back(detectedObject);
82         }
83 
84         std::sort(detectedObjects.begin(), detectedObjects.end());
85         std::sort(m_DetectedObjects.begin(), m_DetectedObjects.end());
86 
87         // Compare detected objects with expected results
88         std::vector<DetectedObject>::const_iterator it = detectedObjects.begin();
89         for (unsigned int i = 0; i < numDetections; i++)
90         {
91             if (it == detectedObjects.end())
92             {
93                 ARMNN_LOG(error) << "No more detected objects found! Index out of bounds: " << i;
94                 return TestCaseResult::Abort;
95             }
96 
97             const DetectedObject& detectedObject = *it;
98             const DetectedObject& expectedObject = m_DetectedObjects[i];
99 
100             if (detectedObject.m_Class != expectedObject.m_Class)
101             {
102                 ARMNN_LOG(error) << "Prediction for test case " << this->GetTestCaseId() <<
103                     " is incorrect: Expected (" << expectedObject.m_Class << ")" <<
104                     " but predicted (" << detectedObject.m_Class << ")";
105                 return TestCaseResult::Failed;
106             }
107 
108             if(!armnnUtils::within_percentage_tolerance(detectedObject.m_Confidence, expectedObject.m_Confidence))
109             {
110                 ARMNN_LOG(error) << "Confidence of prediction for test case " << this->GetTestCaseId() <<
111                     " is incorrect: Expected (" << expectedObject.m_Confidence << ")  +- 1.0 pc" <<
112                     " but predicted (" << detectedObject.m_Confidence << ")";
113                 return TestCaseResult::Failed;
114             }
115 
116             if (!armnnUtils::within_percentage_tolerance(detectedObject.m_BoundingBox.m_XMin,
117                                                          expectedObject.m_BoundingBox.m_XMin) ||
118                 !armnnUtils::within_percentage_tolerance(detectedObject.m_BoundingBox.m_YMin,
119                                                          expectedObject.m_BoundingBox.m_YMin) ||
120                 !armnnUtils::within_percentage_tolerance(detectedObject.m_BoundingBox.m_XMax,
121                                                          expectedObject.m_BoundingBox.m_XMax) ||
122                 !armnnUtils::within_percentage_tolerance(detectedObject.m_BoundingBox.m_YMax,
123                                                          expectedObject.m_BoundingBox.m_YMax))
124             {
125                 ARMNN_LOG(error) << "Detected bounding box for test case " << this->GetTestCaseId() <<
126                     " is incorrect";
127                 return TestCaseResult::Failed;
128             }
129 
130             ++it;
131         }
132 
133         return TestCaseResult::Ok;
134     }
135 
136 private:
137     static constexpr unsigned int k_Shape       = 10u;
138 
139     static constexpr unsigned int k_OutputSize1 = k_Shape * 4u;
140     static constexpr unsigned int k_OutputSize2 = k_Shape;
141     static constexpr unsigned int k_OutputSize3 = k_Shape;
142     static constexpr unsigned int k_OutputSize4 = 1u;
143 
144     std::vector<DetectedObject>                 m_DetectedObjects;
145 };
146 
147 template <typename Model>
148 class MobileNetSsdTestCaseProvider : public IInferenceTestCaseProvider
149 {
150 public:
151     template <typename TConstructModelCallable>
MobileNetSsdTestCaseProvider(TConstructModelCallable constructModel)152     explicit MobileNetSsdTestCaseProvider(TConstructModelCallable constructModel)
153         : m_ConstructModel(constructModel)
154     {}
155 
AddCommandLineOptions(cxxopts::Options & options,std::vector<std::string> & required)156     virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override
157     {
158         options
159             .allow_unrecognised_options()
160             .add_options()
161                 ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir));
162 
163         required.emplace_back("data-dir");
164 
165         Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
166     }
167 
ProcessCommandLineOptions(const InferenceTestOptions & commonOptions)168     virtual bool ProcessCommandLineOptions(const InferenceTestOptions& commonOptions) override
169     {
170         if (!ValidateDirectory(m_DataDir))
171         {
172             return false;
173         }
174 
175         m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
176         if (!m_Model)
177         {
178             return false;
179         }
180         std::pair<float, int32_t> qParams = m_Model->GetInputQuantizationParams();
181         m_Database = std::make_unique<MobileNetSsdDatabase>(m_DataDir.c_str(), qParams.first, qParams.second);
182         if (!m_Database)
183         {
184             return false;
185         }
186 
187         return true;
188     }
189 
GetTestCase(unsigned int testCaseId)190     std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
191     {
192         std::unique_ptr<MobileNetSsdTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
193         if (!testCaseData)
194         {
195             return nullptr;
196         }
197 
198         return std::make_unique<MobileNetSsdTestCase<Model>>(*m_Model, testCaseId, *testCaseData);
199     }
200 
201 private:
202     typename Model::CommandLineOptions m_ModelCommandLineOptions;
203     std::function<std::unique_ptr<Model>(const InferenceTestOptions &,
204                                          typename Model::CommandLineOptions)> m_ConstructModel;
205     std::unique_ptr<Model> m_Model;
206 
207     std::string m_DataDir;
208     std::unique_ptr<MobileNetSsdDatabase> m_Database;
209 };
210 
211 } // anonymous namespace