xref: /aosp_15_r20/external/armnn/tests/MobileNetSsdDatabase.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
#include "InferenceTestImage.hpp"
#include "ObjectDetectionCommon.hpp"

#include <armnnUtils/QuantizeHelper.hpp>

#include <armnn/TypesUtils.hpp>
#include <armnn/utility/NumericCast.hpp>

#include <array>
#include <memory>
#include <string>
#include <utility>
#include <vector>
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker namespace
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker struct MobileNetSsdTestCaseData
25*89c4ff92SAndroid Build Coastguard Worker {
MobileNetSsdTestCaseData__anon3c500c880111::MobileNetSsdTestCaseData26*89c4ff92SAndroid Build Coastguard Worker     MobileNetSsdTestCaseData(
27*89c4ff92SAndroid Build Coastguard Worker         const std::vector<uint8_t>& inputData,
28*89c4ff92SAndroid Build Coastguard Worker         const std::vector<DetectedObject>& expectedDetectedObject,
29*89c4ff92SAndroid Build Coastguard Worker         const std::vector<std::vector<float>>& expectedOutput)
30*89c4ff92SAndroid Build Coastguard Worker         : m_InputData(inputData)
31*89c4ff92SAndroid Build Coastguard Worker         , m_ExpectedDetectedObject(expectedDetectedObject)
32*89c4ff92SAndroid Build Coastguard Worker         , m_ExpectedOutput(expectedOutput)
33*89c4ff92SAndroid Build Coastguard Worker     {}
34*89c4ff92SAndroid Build Coastguard Worker 
35*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t>            m_InputData;
36*89c4ff92SAndroid Build Coastguard Worker     std::vector<DetectedObject>     m_ExpectedDetectedObject;
37*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<float>> m_ExpectedOutput;
38*89c4ff92SAndroid Build Coastguard Worker };
39*89c4ff92SAndroid Build Coastguard Worker 
40*89c4ff92SAndroid Build Coastguard Worker class MobileNetSsdDatabase
41*89c4ff92SAndroid Build Coastguard Worker {
42*89c4ff92SAndroid Build Coastguard Worker public:
43*89c4ff92SAndroid Build Coastguard Worker     explicit MobileNetSsdDatabase(const std::string& imageDir, float scale, int offset);
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<MobileNetSsdTestCaseData> GetTestCaseData(unsigned int testCaseId);
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker private:
48*89c4ff92SAndroid Build Coastguard Worker     std::string m_ImageDir;
49*89c4ff92SAndroid Build Coastguard Worker     float m_Scale;
50*89c4ff92SAndroid Build Coastguard Worker     int m_Offset;
51*89c4ff92SAndroid Build Coastguard Worker };
52*89c4ff92SAndroid Build Coastguard Worker 
// Square input resolution expected by the network; test images of any other size are resized to it.
constexpr unsigned int k_MobileNetSsdImageHeight = 300u;
constexpr unsigned int k_MobileNetSsdImageWidth  = k_MobileNetSsdImageHeight;
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker // Test cases
57*89c4ff92SAndroid Build Coastguard Worker const std::array<ObjectDetectionInput, 1> g_PerTestCaseInput =
58*89c4ff92SAndroid Build Coastguard Worker {
59*89c4ff92SAndroid Build Coastguard Worker     ObjectDetectionInput
60*89c4ff92SAndroid Build Coastguard Worker     {
61*89c4ff92SAndroid Build Coastguard Worker         "Cat.jpg",
62*89c4ff92SAndroid Build Coastguard Worker         {
63*89c4ff92SAndroid Build Coastguard Worker           DetectedObject(16.0f, BoundingBox(0.216785252f, 0.079726994f, 0.927124202f, 0.939067304f), 0.79296875f)
64*89c4ff92SAndroid Build Coastguard Worker         }
65*89c4ff92SAndroid Build Coastguard Worker     }
66*89c4ff92SAndroid Build Coastguard Worker };
67*89c4ff92SAndroid Build Coastguard Worker 
MobileNetSsdDatabase(const std::string & imageDir,float scale,int offset)68*89c4ff92SAndroid Build Coastguard Worker MobileNetSsdDatabase::MobileNetSsdDatabase(const std::string& imageDir, float scale, int offset)
69*89c4ff92SAndroid Build Coastguard Worker     : m_ImageDir(imageDir)
70*89c4ff92SAndroid Build Coastguard Worker     , m_Scale(scale)
71*89c4ff92SAndroid Build Coastguard Worker     , m_Offset(offset)
72*89c4ff92SAndroid Build Coastguard Worker {}
73*89c4ff92SAndroid Build Coastguard Worker 
GetTestCaseData(unsigned int testCaseId)74*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<MobileNetSsdTestCaseData> MobileNetSsdDatabase::GetTestCaseData(unsigned int testCaseId)
75*89c4ff92SAndroid Build Coastguard Worker {
76*89c4ff92SAndroid Build Coastguard Worker     const unsigned int safeTestCaseId =
77*89c4ff92SAndroid Build Coastguard Worker         testCaseId % armnn::numeric_cast<unsigned int>(g_PerTestCaseInput.size());
78*89c4ff92SAndroid Build Coastguard Worker     const ObjectDetectionInput& testCaseInput = g_PerTestCaseInput[safeTestCaseId];
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker     // Load test case input
81*89c4ff92SAndroid Build Coastguard Worker     const std::string imagePath = m_ImageDir + testCaseInput.first;
82*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> imageData;
83*89c4ff92SAndroid Build Coastguard Worker     try
84*89c4ff92SAndroid Build Coastguard Worker     {
85*89c4ff92SAndroid Build Coastguard Worker         InferenceTestImage image(imagePath.c_str());
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker         // Resize image (if needed)
88*89c4ff92SAndroid Build Coastguard Worker         const unsigned int width  = image.GetWidth();
89*89c4ff92SAndroid Build Coastguard Worker         const unsigned int height = image.GetHeight();
90*89c4ff92SAndroid Build Coastguard Worker         if (width != k_MobileNetSsdImageWidth || height != k_MobileNetSsdImageHeight)
91*89c4ff92SAndroid Build Coastguard Worker         {
92*89c4ff92SAndroid Build Coastguard Worker             image.Resize(k_MobileNetSsdImageWidth, k_MobileNetSsdImageHeight, CHECK_LOCATION());
93*89c4ff92SAndroid Build Coastguard Worker         }
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker         // Get image data as a vector of floats
96*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> floatImageData = GetImageDataAsNormalizedFloats(ImageChannelLayout::Rgb, image);
97*89c4ff92SAndroid Build Coastguard Worker         imageData = armnnUtils::QuantizedVector<uint8_t>(floatImageData, m_Scale, m_Offset);
98*89c4ff92SAndroid Build Coastguard Worker     }
99*89c4ff92SAndroid Build Coastguard Worker     catch (const InferenceTestImageException& e)
100*89c4ff92SAndroid Build Coastguard Worker     {
101*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
102*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
103*89c4ff92SAndroid Build Coastguard Worker     }
104*89c4ff92SAndroid Build Coastguard Worker 
105*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> numDetections = { static_cast<float>(testCaseInput.second.size()) };
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> detectionBoxes;
108*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> detectionClasses;
109*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> detectionScores;
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker     for (DetectedObject expectedObject : testCaseInput.second)
112*89c4ff92SAndroid Build Coastguard Worker     {
113*89c4ff92SAndroid Build Coastguard Worker             detectionBoxes.push_back(expectedObject.m_BoundingBox.m_YMin);
114*89c4ff92SAndroid Build Coastguard Worker             detectionBoxes.push_back(expectedObject.m_BoundingBox.m_XMin);
115*89c4ff92SAndroid Build Coastguard Worker             detectionBoxes.push_back(expectedObject.m_BoundingBox.m_YMax);
116*89c4ff92SAndroid Build Coastguard Worker             detectionBoxes.push_back(expectedObject.m_BoundingBox.m_XMax);
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker             detectionClasses.push_back(expectedObject.m_Class);
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker             detectionScores.push_back(expectedObject.m_Confidence);
121*89c4ff92SAndroid Build Coastguard Worker     }
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker     // Prepare test case expected output
124*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<float>> expectedOutputs;
125*89c4ff92SAndroid Build Coastguard Worker     expectedOutputs.reserve(4);
126*89c4ff92SAndroid Build Coastguard Worker     expectedOutputs.push_back(detectionBoxes);
127*89c4ff92SAndroid Build Coastguard Worker     expectedOutputs.push_back(detectionClasses);
128*89c4ff92SAndroid Build Coastguard Worker     expectedOutputs.push_back(detectionScores);
129*89c4ff92SAndroid Build Coastguard Worker     expectedOutputs.push_back(numDetections);
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<MobileNetSsdTestCaseData>(imageData, testCaseInput.second, expectedOutputs);
132*89c4ff92SAndroid Build Coastguard Worker }
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
135