1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker #pragma once 6*89c4ff92SAndroid Build Coastguard Worker 7*89c4ff92SAndroid Build Coastguard Worker #include "InferenceTest.hpp" 8*89c4ff92SAndroid Build Coastguard Worker #include "MobileNetSsdDatabase.hpp" 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp> 11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp> 12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp> 13*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/FloatingPointComparison.hpp> 14*89c4ff92SAndroid Build Coastguard Worker 15*89c4ff92SAndroid Build Coastguard Worker #include <vector> 16*89c4ff92SAndroid Build Coastguard Worker 17*89c4ff92SAndroid Build Coastguard Worker namespace 18*89c4ff92SAndroid Build Coastguard Worker { 19*89c4ff92SAndroid Build Coastguard Worker 20*89c4ff92SAndroid Build Coastguard Worker template<typename Model> 21*89c4ff92SAndroid Build Coastguard Worker class MobileNetSsdTestCase : public InferenceModelTestCase<Model> 22*89c4ff92SAndroid Build Coastguard Worker { 23*89c4ff92SAndroid Build Coastguard Worker public: MobileNetSsdTestCase(Model & model,unsigned int testCaseId,const MobileNetSsdTestCaseData & testCaseData)24*89c4ff92SAndroid Build Coastguard Worker MobileNetSsdTestCase(Model& model, 25*89c4ff92SAndroid Build Coastguard Worker unsigned int testCaseId, 26*89c4ff92SAndroid Build Coastguard Worker const MobileNetSsdTestCaseData& testCaseData) 27*89c4ff92SAndroid Build Coastguard Worker : InferenceModelTestCase<Model>(model, 28*89c4ff92SAndroid Build Coastguard Worker testCaseId, 29*89c4ff92SAndroid Build Coastguard Worker { std::move(testCaseData.m_InputData) }, 30*89c4ff92SAndroid Build Coastguard Worker { k_OutputSize1, k_OutputSize2, k_OutputSize3, k_OutputSize4 }) 31*89c4ff92SAndroid Build Coastguard Worker , m_DetectedObjects(testCaseData.m_ExpectedDetectedObject) 32*89c4ff92SAndroid Build Coastguard Worker {} 33*89c4ff92SAndroid Build Coastguard Worker ProcessResult(const InferenceTestOptions & options)34*89c4ff92SAndroid Build Coastguard Worker TestCaseResult ProcessResult(const InferenceTestOptions& options) override 35*89c4ff92SAndroid Build Coastguard Worker { 36*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(options); 37*89c4ff92SAndroid Build Coastguard Worker 38*89c4ff92SAndroid Build Coastguard Worker // bounding boxes 39*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& output1 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[0]); 40*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(output1.size() == k_OutputSize1); 41*89c4ff92SAndroid Build Coastguard Worker 42*89c4ff92SAndroid Build Coastguard Worker // classes 43*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& output2 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[1]); 44*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(output2.size() == k_OutputSize2); 45*89c4ff92SAndroid Build Coastguard Worker 46*89c4ff92SAndroid Build Coastguard Worker // scores 47*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& output3 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[2]); 48*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(output3.size() == k_OutputSize3); 49*89c4ff92SAndroid Build Coastguard Worker 50*89c4ff92SAndroid Build Coastguard Worker // valid detections 51*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& output4 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[3]); 52*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(output4.size() == k_OutputSize4); 53*89c4ff92SAndroid Build Coastguard Worker 54*89c4ff92SAndroid Build Coastguard Worker const size_t numDetections = armnn::numeric_cast<size_t>(output4[0]); 55*89c4ff92SAndroid Build Coastguard Worker 56*89c4ff92SAndroid Build Coastguard Worker // Check if number of valid detections matches expectations 57*89c4ff92SAndroid Build Coastguard Worker const size_t expectedNumDetections = m_DetectedObjects.size(); 58*89c4ff92SAndroid Build Coastguard Worker if (numDetections != expectedNumDetections) 59*89c4ff92SAndroid Build Coastguard Worker { 60*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Number of detections is incorrect: Expected (" << 61*89c4ff92SAndroid Build Coastguard Worker expectedNumDetections << ")" << " but got (" << numDetections << ")"; 62*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Failed; 63*89c4ff92SAndroid Build Coastguard Worker } 64*89c4ff92SAndroid Build Coastguard Worker 65*89c4ff92SAndroid Build Coastguard Worker // Extract detected objects from output data 66*89c4ff92SAndroid Build Coastguard Worker std::vector<DetectedObject> detectedObjects; 67*89c4ff92SAndroid Build Coastguard Worker const float* outputData = output1.data(); 68*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0u; i < numDetections; i++) 69*89c4ff92SAndroid Build Coastguard Worker { 70*89c4ff92SAndroid Build Coastguard Worker // NOTE: Order of coordinates in output data is yMin, xMin, yMax, xMax 71*89c4ff92SAndroid Build Coastguard Worker float yMin = *outputData++; 72*89c4ff92SAndroid Build Coastguard Worker float xMin = *outputData++; 73*89c4ff92SAndroid Build Coastguard Worker float yMax = *outputData++; 74*89c4ff92SAndroid Build Coastguard Worker float xMax = *outputData++; 75*89c4ff92SAndroid Build Coastguard Worker 76*89c4ff92SAndroid Build Coastguard Worker DetectedObject detectedObject( 77*89c4ff92SAndroid Build Coastguard Worker output2.at(i), 78*89c4ff92SAndroid Build Coastguard Worker BoundingBox(xMin, yMin, xMax, yMax), 79*89c4ff92SAndroid Build Coastguard Worker output3.at(i)); 80*89c4ff92SAndroid Build Coastguard Worker 81*89c4ff92SAndroid Build Coastguard Worker detectedObjects.push_back(detectedObject); 82*89c4ff92SAndroid Build Coastguard Worker } 83*89c4ff92SAndroid Build Coastguard Worker 84*89c4ff92SAndroid Build Coastguard Worker std::sort(detectedObjects.begin(), detectedObjects.end()); 85*89c4ff92SAndroid Build Coastguard Worker std::sort(m_DetectedObjects.begin(), m_DetectedObjects.end()); 86*89c4ff92SAndroid Build Coastguard Worker 87*89c4ff92SAndroid Build Coastguard Worker // Compare detected objects with expected results 88*89c4ff92SAndroid Build Coastguard Worker std::vector<DetectedObject>::const_iterator it = detectedObjects.begin(); 89*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numDetections; i++) 90*89c4ff92SAndroid Build Coastguard Worker { 91*89c4ff92SAndroid Build Coastguard Worker if (it == detectedObjects.end()) 92*89c4ff92SAndroid Build Coastguard Worker { 93*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "No more detected objects found! Index out of bounds: " << i; 94*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Abort; 95*89c4ff92SAndroid Build Coastguard Worker } 96*89c4ff92SAndroid Build Coastguard Worker 97*89c4ff92SAndroid Build Coastguard Worker const DetectedObject& detectedObject = *it; 98*89c4ff92SAndroid Build Coastguard Worker const DetectedObject& expectedObject = m_DetectedObjects[i]; 99*89c4ff92SAndroid Build Coastguard Worker 100*89c4ff92SAndroid Build Coastguard Worker if (detectedObject.m_Class != expectedObject.m_Class) 101*89c4ff92SAndroid Build Coastguard Worker { 102*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Prediction for test case " << this->GetTestCaseId() << 103*89c4ff92SAndroid Build Coastguard Worker " is incorrect: Expected (" << expectedObject.m_Class << ")" << 104*89c4ff92SAndroid Build Coastguard Worker " but predicted (" << detectedObject.m_Class << ")"; 105*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Failed; 106*89c4ff92SAndroid Build Coastguard Worker } 107*89c4ff92SAndroid Build Coastguard Worker 108*89c4ff92SAndroid Build Coastguard Worker if(!armnnUtils::within_percentage_tolerance(detectedObject.m_Confidence, expectedObject.m_Confidence)) 109*89c4ff92SAndroid Build Coastguard Worker { 110*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Confidence of prediction for test case " << this->GetTestCaseId() << 111*89c4ff92SAndroid Build Coastguard Worker " is incorrect: Expected (" << expectedObject.m_Confidence << ") +- 1.0 pc" << 112*89c4ff92SAndroid Build Coastguard Worker " but predicted (" << detectedObject.m_Confidence << ")"; 113*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Failed; 114*89c4ff92SAndroid Build Coastguard Worker } 115*89c4ff92SAndroid Build Coastguard Worker 116*89c4ff92SAndroid Build Coastguard Worker if (!armnnUtils::within_percentage_tolerance(detectedObject.m_BoundingBox.m_XMin, 117*89c4ff92SAndroid Build Coastguard Worker expectedObject.m_BoundingBox.m_XMin) || 118*89c4ff92SAndroid Build Coastguard Worker !armnnUtils::within_percentage_tolerance(detectedObject.m_BoundingBox.m_YMin, 119*89c4ff92SAndroid Build Coastguard Worker expectedObject.m_BoundingBox.m_YMin) || 120*89c4ff92SAndroid Build Coastguard Worker !armnnUtils::within_percentage_tolerance(detectedObject.m_BoundingBox.m_XMax, 121*89c4ff92SAndroid Build Coastguard Worker expectedObject.m_BoundingBox.m_XMax) || 122*89c4ff92SAndroid Build Coastguard Worker !armnnUtils::within_percentage_tolerance(detectedObject.m_BoundingBox.m_YMax, 123*89c4ff92SAndroid Build Coastguard Worker expectedObject.m_BoundingBox.m_YMax)) 124*89c4ff92SAndroid Build Coastguard Worker { 125*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Detected bounding box for test case " << this->GetTestCaseId() << 126*89c4ff92SAndroid Build Coastguard Worker " is incorrect"; 127*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Failed; 128*89c4ff92SAndroid Build Coastguard Worker } 129*89c4ff92SAndroid Build Coastguard Worker 130*89c4ff92SAndroid Build Coastguard Worker ++it; 131*89c4ff92SAndroid Build Coastguard Worker } 132*89c4ff92SAndroid Build Coastguard Worker 133*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Ok; 134*89c4ff92SAndroid Build Coastguard Worker } 135*89c4ff92SAndroid Build Coastguard Worker 136*89c4ff92SAndroid Build Coastguard Worker private: 137*89c4ff92SAndroid Build Coastguard Worker static constexpr unsigned int k_Shape = 10u; 138*89c4ff92SAndroid Build Coastguard Worker 139*89c4ff92SAndroid Build Coastguard Worker static constexpr unsigned int k_OutputSize1 = k_Shape * 4u; 140*89c4ff92SAndroid Build Coastguard Worker static constexpr unsigned int k_OutputSize2 = k_Shape; 141*89c4ff92SAndroid Build Coastguard Worker static constexpr unsigned int k_OutputSize3 = k_Shape; 142*89c4ff92SAndroid Build Coastguard Worker static constexpr unsigned int k_OutputSize4 = 1u; 143*89c4ff92SAndroid Build Coastguard Worker 144*89c4ff92SAndroid Build Coastguard Worker std::vector<DetectedObject> m_DetectedObjects; 145*89c4ff92SAndroid Build Coastguard Worker }; 146*89c4ff92SAndroid Build Coastguard Worker 147*89c4ff92SAndroid Build Coastguard Worker template <typename Model> 148*89c4ff92SAndroid Build Coastguard Worker class MobileNetSsdTestCaseProvider : public IInferenceTestCaseProvider 149*89c4ff92SAndroid Build Coastguard Worker { 150*89c4ff92SAndroid Build Coastguard Worker public: 151*89c4ff92SAndroid Build Coastguard Worker template <typename TConstructModelCallable> MobileNetSsdTestCaseProvider(TConstructModelCallable constructModel)152*89c4ff92SAndroid Build Coastguard Worker explicit MobileNetSsdTestCaseProvider(TConstructModelCallable constructModel) 153*89c4ff92SAndroid Build Coastguard Worker : m_ConstructModel(constructModel) 154*89c4ff92SAndroid Build Coastguard Worker {} 155*89c4ff92SAndroid Build Coastguard Worker AddCommandLineOptions(cxxopts::Options & options,std::vector<std::string> & required)156*89c4ff92SAndroid Build Coastguard Worker virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override 157*89c4ff92SAndroid Build Coastguard Worker { 158*89c4ff92SAndroid Build Coastguard Worker options 159*89c4ff92SAndroid Build Coastguard Worker .allow_unrecognised_options() 160*89c4ff92SAndroid Build Coastguard Worker .add_options() 161*89c4ff92SAndroid Build Coastguard Worker ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir)); 162*89c4ff92SAndroid Build Coastguard Worker 163*89c4ff92SAndroid Build Coastguard Worker required.emplace_back("data-dir"); 164*89c4ff92SAndroid Build Coastguard Worker 165*89c4ff92SAndroid Build Coastguard Worker Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required); 166*89c4ff92SAndroid Build Coastguard Worker } 167*89c4ff92SAndroid Build Coastguard Worker ProcessCommandLineOptions(const InferenceTestOptions & commonOptions)168*89c4ff92SAndroid Build Coastguard Worker virtual bool ProcessCommandLineOptions(const InferenceTestOptions& commonOptions) override 169*89c4ff92SAndroid Build Coastguard Worker { 170*89c4ff92SAndroid Build Coastguard Worker if (!ValidateDirectory(m_DataDir)) 171*89c4ff92SAndroid Build Coastguard Worker { 172*89c4ff92SAndroid Build Coastguard Worker return false; 173*89c4ff92SAndroid Build Coastguard Worker } 174*89c4ff92SAndroid Build Coastguard Worker 175*89c4ff92SAndroid Build Coastguard Worker m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions); 176*89c4ff92SAndroid Build Coastguard Worker if (!m_Model) 177*89c4ff92SAndroid Build Coastguard Worker { 178*89c4ff92SAndroid Build Coastguard Worker return false; 179*89c4ff92SAndroid Build Coastguard Worker } 180*89c4ff92SAndroid Build Coastguard Worker std::pair<float, int32_t> qParams = m_Model->GetInputQuantizationParams(); 181*89c4ff92SAndroid Build Coastguard Worker m_Database = std::make_unique<MobileNetSsdDatabase>(m_DataDir.c_str(), qParams.first, qParams.second); 182*89c4ff92SAndroid Build Coastguard Worker if (!m_Database) 183*89c4ff92SAndroid Build Coastguard Worker { 184*89c4ff92SAndroid Build Coastguard Worker return false; 185*89c4ff92SAndroid Build Coastguard Worker } 186*89c4ff92SAndroid Build Coastguard Worker 187*89c4ff92SAndroid Build Coastguard Worker return true; 188*89c4ff92SAndroid Build Coastguard Worker } 189*89c4ff92SAndroid Build Coastguard Worker GetTestCase(unsigned int testCaseId)190*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override 191*89c4ff92SAndroid Build Coastguard Worker { 192*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<MobileNetSsdTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId); 193*89c4ff92SAndroid Build Coastguard Worker if (!testCaseData) 194*89c4ff92SAndroid Build Coastguard Worker { 195*89c4ff92SAndroid Build Coastguard Worker return nullptr; 196*89c4ff92SAndroid Build Coastguard Worker } 197*89c4ff92SAndroid Build Coastguard Worker 198*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<MobileNetSsdTestCase<Model>>(*m_Model, testCaseId, *testCaseData); 199*89c4ff92SAndroid Build Coastguard Worker } 200*89c4ff92SAndroid Build Coastguard Worker 201*89c4ff92SAndroid Build Coastguard Worker private: 202*89c4ff92SAndroid Build Coastguard Worker typename Model::CommandLineOptions m_ModelCommandLineOptions; 203*89c4ff92SAndroid Build Coastguard Worker std::function<std::unique_ptr<Model>(const InferenceTestOptions &, 204*89c4ff92SAndroid Build Coastguard Worker typename Model::CommandLineOptions)> m_ConstructModel; 205*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Model> m_Model; 206*89c4ff92SAndroid Build Coastguard Worker 207*89c4ff92SAndroid Build Coastguard Worker std::string m_DataDir; 208*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<MobileNetSsdDatabase> m_Database; 209*89c4ff92SAndroid Build Coastguard Worker }; 210*89c4ff92SAndroid Build Coastguard Worker 211*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace