1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker #pragma once 6*89c4ff92SAndroid Build Coastguard Worker 7*89c4ff92SAndroid Build Coastguard Worker #include "InferenceTest.hpp" 8*89c4ff92SAndroid Build Coastguard Worker #include "YoloDatabase.hpp" 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp> 11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp> 12*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/FloatingPointComparison.hpp> 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker #include <algorithm> 15*89c4ff92SAndroid Build Coastguard Worker #include <array> 16*89c4ff92SAndroid Build Coastguard Worker #include <utility> 17*89c4ff92SAndroid Build Coastguard Worker 18*89c4ff92SAndroid Build Coastguard Worker constexpr size_t YoloOutputSize = 1470; 19*89c4ff92SAndroid Build Coastguard Worker 20*89c4ff92SAndroid Build Coastguard Worker template <typename Model> 21*89c4ff92SAndroid Build Coastguard Worker class YoloTestCase : public InferenceModelTestCase<Model> 22*89c4ff92SAndroid Build Coastguard Worker { 23*89c4ff92SAndroid Build Coastguard Worker public: YoloTestCase(Model & model,unsigned int testCaseId,YoloTestCaseData & testCaseData)24*89c4ff92SAndroid Build Coastguard Worker YoloTestCase(Model& model, 25*89c4ff92SAndroid Build Coastguard Worker unsigned int testCaseId, 26*89c4ff92SAndroid Build Coastguard Worker YoloTestCaseData& testCaseData) 27*89c4ff92SAndroid Build Coastguard Worker : InferenceModelTestCase<Model>(model, testCaseId, { std::move(testCaseData.m_InputImage) }, { YoloOutputSize }) 28*89c4ff92SAndroid Build Coastguard Worker , m_TopObjectDetections(std::move(testCaseData.m_TopObjectDetections)) 29*89c4ff92SAndroid Build Coastguard Worker { 30*89c4ff92SAndroid Build Coastguard Worker } 31*89c4ff92SAndroid Build Coastguard Worker ProcessResult(const InferenceTestOptions & options)32*89c4ff92SAndroid Build Coastguard Worker virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) override 33*89c4ff92SAndroid Build Coastguard Worker { 34*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(options); 35*89c4ff92SAndroid Build Coastguard Worker 36*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& output = mapbox::util::get<std::vector<float>>(this->GetOutputs()[0]); 37*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(output.size() == YoloOutputSize); 38*89c4ff92SAndroid Build Coastguard Worker 39*89c4ff92SAndroid Build Coastguard Worker constexpr unsigned int gridSize = 7; 40*89c4ff92SAndroid Build Coastguard Worker constexpr unsigned int numClasses = 20; 41*89c4ff92SAndroid Build Coastguard Worker constexpr unsigned int numScales = 2; 42*89c4ff92SAndroid Build Coastguard Worker 43*89c4ff92SAndroid Build Coastguard Worker const float* outputPtr = output.data(); 44*89c4ff92SAndroid Build Coastguard Worker 45*89c4ff92SAndroid Build Coastguard Worker // Range 0-980. Class probabilities. 7x7x20 46*89c4ff92SAndroid Build Coastguard Worker vector<vector<vector<float>>> classProbabilities(gridSize, vector<vector<float>>(gridSize, 47*89c4ff92SAndroid Build Coastguard Worker vector<float>(numClasses))); 48*89c4ff92SAndroid Build Coastguard Worker for (unsigned int y = 0; y < gridSize; ++y) 49*89c4ff92SAndroid Build Coastguard Worker { 50*89c4ff92SAndroid Build Coastguard Worker for (unsigned int x = 0; x < gridSize; ++x) 51*89c4ff92SAndroid Build Coastguard Worker { 52*89c4ff92SAndroid Build Coastguard Worker for (unsigned int c = 0; c < numClasses; ++c) 53*89c4ff92SAndroid Build Coastguard Worker { 54*89c4ff92SAndroid Build Coastguard Worker classProbabilities[y][x][c] = *outputPtr++; 55*89c4ff92SAndroid Build Coastguard Worker } 56*89c4ff92SAndroid Build Coastguard Worker } 57*89c4ff92SAndroid Build Coastguard Worker } 58*89c4ff92SAndroid Build Coastguard Worker 59*89c4ff92SAndroid Build Coastguard Worker // Range 980-1078. Scales. 7x7x2 60*89c4ff92SAndroid Build Coastguard Worker vector<vector<vector<float>>> scales(gridSize, vector<vector<float>>(gridSize, vector<float>(numScales))); 61*89c4ff92SAndroid Build Coastguard Worker for (unsigned int y = 0; y < gridSize; ++y) 62*89c4ff92SAndroid Build Coastguard Worker { 63*89c4ff92SAndroid Build Coastguard Worker for (unsigned int x = 0; x < gridSize; ++x) 64*89c4ff92SAndroid Build Coastguard Worker { 65*89c4ff92SAndroid Build Coastguard Worker for (unsigned int s = 0; s < numScales; ++s) 66*89c4ff92SAndroid Build Coastguard Worker { 67*89c4ff92SAndroid Build Coastguard Worker scales[y][x][s] = *outputPtr++; 68*89c4ff92SAndroid Build Coastguard Worker } 69*89c4ff92SAndroid Build Coastguard Worker } 70*89c4ff92SAndroid Build Coastguard Worker } 71*89c4ff92SAndroid Build Coastguard Worker 72*89c4ff92SAndroid Build Coastguard Worker // Range 1078-1469. Bounding boxes. 7x7x2x4 73*89c4ff92SAndroid Build Coastguard Worker constexpr float imageWidthAsFloat = static_cast<float>(YoloImageWidth); 74*89c4ff92SAndroid Build Coastguard Worker constexpr float imageHeightAsFloat = static_cast<float>(YoloImageHeight); 75*89c4ff92SAndroid Build Coastguard Worker 76*89c4ff92SAndroid Build Coastguard Worker vector<vector<vector<vector<float>>>> boxes(gridSize, vector<vector<vector<float>>> 77*89c4ff92SAndroid Build Coastguard Worker (gridSize, vector<vector<float>>(numScales, vector<float>(4)))); 78*89c4ff92SAndroid Build Coastguard Worker for (unsigned int y = 0; y < gridSize; ++y) 79*89c4ff92SAndroid Build Coastguard Worker { 80*89c4ff92SAndroid Build Coastguard Worker for (unsigned int x = 0; x < gridSize; ++x) 81*89c4ff92SAndroid Build Coastguard Worker { 82*89c4ff92SAndroid Build Coastguard Worker for (unsigned int s = 0; s < numScales; ++s) 83*89c4ff92SAndroid Build Coastguard Worker { 84*89c4ff92SAndroid Build Coastguard Worker float bx = *outputPtr++; 85*89c4ff92SAndroid Build Coastguard Worker float by = *outputPtr++; 86*89c4ff92SAndroid Build Coastguard Worker float bw = *outputPtr++; 87*89c4ff92SAndroid Build Coastguard Worker float bh = *outputPtr++; 88*89c4ff92SAndroid Build Coastguard Worker 89*89c4ff92SAndroid Build Coastguard Worker boxes[y][x][s][0] = ((bx + static_cast<float>(x)) / 7.0f) * imageWidthAsFloat; 90*89c4ff92SAndroid Build Coastguard Worker boxes[y][x][s][1] = ((by + static_cast<float>(y)) / 7.0f) * imageHeightAsFloat; 91*89c4ff92SAndroid Build Coastguard Worker boxes[y][x][s][2] = bw * bw * static_cast<float>(imageWidthAsFloat); 92*89c4ff92SAndroid Build Coastguard Worker boxes[y][x][s][3] = bh * bh * static_cast<float>(imageHeightAsFloat); 93*89c4ff92SAndroid Build Coastguard Worker } 94*89c4ff92SAndroid Build Coastguard Worker } 95*89c4ff92SAndroid Build Coastguard Worker } 96*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(output.data() + YoloOutputSize == outputPtr); 97*89c4ff92SAndroid Build Coastguard Worker 98*89c4ff92SAndroid Build Coastguard Worker std::vector<YoloDetectedObject> detectedObjects; 99*89c4ff92SAndroid Build Coastguard Worker detectedObjects.reserve(gridSize * gridSize * numScales * numClasses); 100*89c4ff92SAndroid Build Coastguard Worker 101*89c4ff92SAndroid Build Coastguard Worker for (unsigned int y = 0; y < gridSize; ++y) 102*89c4ff92SAndroid Build Coastguard Worker { 103*89c4ff92SAndroid Build Coastguard Worker for (unsigned int x = 0; x < gridSize; ++x) 104*89c4ff92SAndroid Build Coastguard Worker { 105*89c4ff92SAndroid Build Coastguard Worker for (unsigned int s = 0; s < numScales; ++s) 106*89c4ff92SAndroid Build Coastguard Worker { 107*89c4ff92SAndroid Build Coastguard Worker for (unsigned int c = 0; c < numClasses; ++c) 108*89c4ff92SAndroid Build Coastguard Worker { 109*89c4ff92SAndroid Build Coastguard Worker // Resolved confidence: class probabilities * scales. 110*89c4ff92SAndroid Build Coastguard Worker const float confidence = classProbabilities[y][x][c] * scales[y][x][s]; 111*89c4ff92SAndroid Build Coastguard Worker 112*89c4ff92SAndroid Build Coastguard Worker // Resolves bounding box and stores. 113*89c4ff92SAndroid Build Coastguard Worker YoloBoundingBox box; 114*89c4ff92SAndroid Build Coastguard Worker box.m_X = boxes[y][x][s][0]; 115*89c4ff92SAndroid Build Coastguard Worker box.m_Y = boxes[y][x][s][1]; 116*89c4ff92SAndroid Build Coastguard Worker box.m_W = boxes[y][x][s][2]; 117*89c4ff92SAndroid Build Coastguard Worker box.m_H = boxes[y][x][s][3]; 118*89c4ff92SAndroid Build Coastguard Worker 119*89c4ff92SAndroid Build Coastguard Worker detectedObjects.emplace_back(c, box, confidence); 120*89c4ff92SAndroid Build Coastguard Worker } 121*89c4ff92SAndroid Build Coastguard Worker } 122*89c4ff92SAndroid Build Coastguard Worker } 123*89c4ff92SAndroid Build Coastguard Worker } 124*89c4ff92SAndroid Build Coastguard Worker 125*89c4ff92SAndroid Build Coastguard Worker // Sorts detected objects by confidence. 126*89c4ff92SAndroid Build Coastguard Worker std::sort(detectedObjects.begin(), detectedObjects.end(), 127*89c4ff92SAndroid Build Coastguard Worker [](const YoloDetectedObject& a, const YoloDetectedObject& b) 128*89c4ff92SAndroid Build Coastguard Worker { 129*89c4ff92SAndroid Build Coastguard Worker // Sorts by largest confidence first, then by class. 130*89c4ff92SAndroid Build Coastguard Worker return a.m_Confidence > b.m_Confidence 131*89c4ff92SAndroid Build Coastguard Worker || (a.m_Confidence == b.m_Confidence && a.m_Class > b.m_Class); 132*89c4ff92SAndroid Build Coastguard Worker }); 133*89c4ff92SAndroid Build Coastguard Worker 134*89c4ff92SAndroid Build Coastguard Worker // Checks the top N detections. 135*89c4ff92SAndroid Build Coastguard Worker auto outputIt = detectedObjects.begin(); 136*89c4ff92SAndroid Build Coastguard Worker auto outputEnd = detectedObjects.end(); 137*89c4ff92SAndroid Build Coastguard Worker 138*89c4ff92SAndroid Build Coastguard Worker for (const YoloDetectedObject& expectedDetection : m_TopObjectDetections) 139*89c4ff92SAndroid Build Coastguard Worker { 140*89c4ff92SAndroid Build Coastguard Worker if (outputIt == outputEnd) 141*89c4ff92SAndroid Build Coastguard Worker { 142*89c4ff92SAndroid Build Coastguard Worker // Somehow expected more things to check than detections found by the model. 143*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Abort; 144*89c4ff92SAndroid Build Coastguard Worker } 145*89c4ff92SAndroid Build Coastguard Worker 146*89c4ff92SAndroid Build Coastguard Worker const YoloDetectedObject& detectedObject = *outputIt; 147*89c4ff92SAndroid Build Coastguard Worker if (detectedObject.m_Class != expectedDetection.m_Class) 148*89c4ff92SAndroid Build Coastguard Worker { 149*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Prediction for test case " << this->GetTestCaseId() << 150*89c4ff92SAndroid Build Coastguard Worker " is incorrect: Expected (" << expectedDetection.m_Class << ")" << 151*89c4ff92SAndroid Build Coastguard Worker " but predicted (" << detectedObject.m_Class << ")"; 152*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Failed; 153*89c4ff92SAndroid Build Coastguard Worker } 154*89c4ff92SAndroid Build Coastguard Worker 155*89c4ff92SAndroid Build Coastguard Worker if (!armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_X, expectedDetection.m_Box.m_X) || 156*89c4ff92SAndroid Build Coastguard Worker !armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_Y, expectedDetection.m_Box.m_Y) || 157*89c4ff92SAndroid Build Coastguard Worker !armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_W, expectedDetection.m_Box.m_W) || 158*89c4ff92SAndroid Build Coastguard Worker !armnnUtils::within_percentage_tolerance(detectedObject.m_Box.m_H, expectedDetection.m_Box.m_H) || 159*89c4ff92SAndroid Build Coastguard Worker !armnnUtils::within_percentage_tolerance(detectedObject.m_Confidence, expectedDetection.m_Confidence)) 160*89c4ff92SAndroid Build Coastguard Worker { 161*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Detected bounding box for test case " << this->GetTestCaseId() << 162*89c4ff92SAndroid Build Coastguard Worker " is incorrect"; 163*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Failed; 164*89c4ff92SAndroid Build Coastguard Worker } 165*89c4ff92SAndroid Build Coastguard Worker 166*89c4ff92SAndroid Build Coastguard Worker ++outputIt; 167*89c4ff92SAndroid Build Coastguard Worker } 168*89c4ff92SAndroid Build Coastguard Worker 169*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Ok; 170*89c4ff92SAndroid Build Coastguard Worker } 171*89c4ff92SAndroid Build Coastguard Worker 172*89c4ff92SAndroid Build Coastguard Worker private: 173*89c4ff92SAndroid Build Coastguard Worker std::vector<YoloDetectedObject> m_TopObjectDetections; 174*89c4ff92SAndroid Build Coastguard Worker }; 175*89c4ff92SAndroid Build Coastguard Worker 176*89c4ff92SAndroid Build Coastguard Worker template <typename Model> 177*89c4ff92SAndroid Build Coastguard Worker class YoloTestCaseProvider : public IInferenceTestCaseProvider 178*89c4ff92SAndroid Build Coastguard Worker { 179*89c4ff92SAndroid Build Coastguard Worker public: 180*89c4ff92SAndroid Build Coastguard Worker template <typename TConstructModelCallable> YoloTestCaseProvider(TConstructModelCallable constructModel)181*89c4ff92SAndroid Build Coastguard Worker explicit YoloTestCaseProvider(TConstructModelCallable constructModel) 182*89c4ff92SAndroid Build Coastguard Worker : m_ConstructModel(constructModel) 183*89c4ff92SAndroid Build Coastguard Worker { 184*89c4ff92SAndroid Build Coastguard Worker } 185*89c4ff92SAndroid Build Coastguard Worker AddCommandLineOptions(cxxopts::Options & options,std::vector<std::string> & required)186*89c4ff92SAndroid Build Coastguard Worker virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override 187*89c4ff92SAndroid Build Coastguard Worker { 188*89c4ff92SAndroid Build Coastguard Worker options 189*89c4ff92SAndroid Build Coastguard Worker .allow_unrecognised_options() 190*89c4ff92SAndroid Build Coastguard Worker .add_options() 191*89c4ff92SAndroid Build Coastguard Worker ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir)); 192*89c4ff92SAndroid Build Coastguard Worker 193*89c4ff92SAndroid Build Coastguard Worker Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required); 194*89c4ff92SAndroid Build Coastguard Worker } 195*89c4ff92SAndroid Build Coastguard Worker ProcessCommandLineOptions(const InferenceTestOptions & commonOptions)196*89c4ff92SAndroid Build Coastguard Worker virtual bool ProcessCommandLineOptions(const InferenceTestOptions& commonOptions) override 197*89c4ff92SAndroid Build Coastguard Worker { 198*89c4ff92SAndroid Build Coastguard Worker if (!ValidateDirectory(m_DataDir)) 199*89c4ff92SAndroid Build Coastguard Worker { 200*89c4ff92SAndroid Build Coastguard Worker return false; 201*89c4ff92SAndroid Build Coastguard Worker } 202*89c4ff92SAndroid Build Coastguard Worker 203*89c4ff92SAndroid Build Coastguard Worker m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions); 204*89c4ff92SAndroid Build Coastguard Worker if (!m_Model) 205*89c4ff92SAndroid Build Coastguard Worker { 206*89c4ff92SAndroid Build Coastguard Worker return false; 207*89c4ff92SAndroid Build Coastguard Worker } 208*89c4ff92SAndroid Build Coastguard Worker 209*89c4ff92SAndroid Build Coastguard Worker m_Database = std::make_unique<YoloDatabase>(m_DataDir.c_str()); 210*89c4ff92SAndroid Build Coastguard Worker if (!m_Database) 211*89c4ff92SAndroid Build Coastguard Worker { 212*89c4ff92SAndroid Build Coastguard Worker return false; 213*89c4ff92SAndroid Build Coastguard Worker } 214*89c4ff92SAndroid Build Coastguard Worker 215*89c4ff92SAndroid Build Coastguard Worker return true; 216*89c4ff92SAndroid Build Coastguard Worker } 217*89c4ff92SAndroid Build Coastguard Worker GetTestCase(unsigned int testCaseId)218*89c4ff92SAndroid Build Coastguard Worker virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override 219*89c4ff92SAndroid Build Coastguard Worker { 220*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<YoloTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId); 221*89c4ff92SAndroid Build Coastguard Worker if (!testCaseData) 222*89c4ff92SAndroid Build Coastguard Worker { 223*89c4ff92SAndroid Build Coastguard Worker return nullptr; 224*89c4ff92SAndroid Build Coastguard Worker } 225*89c4ff92SAndroid Build Coastguard Worker 226*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<YoloTestCase<Model>>(*m_Model, testCaseId, *testCaseData); 227*89c4ff92SAndroid Build Coastguard Worker } 228*89c4ff92SAndroid Build Coastguard Worker 229*89c4ff92SAndroid Build Coastguard Worker private: 230*89c4ff92SAndroid Build Coastguard Worker typename Model::CommandLineOptions m_ModelCommandLineOptions; 231*89c4ff92SAndroid Build Coastguard Worker std::function<std::unique_ptr<Model>(const InferenceTestOptions&, 232*89c4ff92SAndroid Build Coastguard Worker typename Model::CommandLineOptions)> m_ConstructModel; 233*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Model> m_Model; 234*89c4ff92SAndroid Build Coastguard Worker 235*89c4ff92SAndroid Build Coastguard Worker std::string m_DataDir; 236*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<YoloDatabase> m_Database; 237*89c4ff92SAndroid Build Coastguard Worker }; 238