1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #pragma once 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <algorithm> 9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Types.hpp> 10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp> 11*89c4ff92SAndroid Build Coastguard Worker #include <mapbox/variant.hpp> 12*89c4ff92SAndroid Build Coastguard Worker #include <cstddef> 13*89c4ff92SAndroid Build Coastguard Worker #include <functional> 14*89c4ff92SAndroid Build Coastguard Worker #include <iostream> 15*89c4ff92SAndroid Build Coastguard Worker #include <map> 16*89c4ff92SAndroid Build Coastguard Worker #include <string> 17*89c4ff92SAndroid Build Coastguard Worker #include <vector> 18*89c4ff92SAndroid Build Coastguard Worker 19*89c4ff92SAndroid Build Coastguard Worker namespace armnnUtils 20*89c4ff92SAndroid Build Coastguard Worker { 21*89c4ff92SAndroid Build Coastguard Worker 22*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 23*89c4ff92SAndroid Build Coastguard Worker 24*89c4ff92SAndroid Build Coastguard Worker // Category names associated with a label 25*89c4ff92SAndroid Build Coastguard Worker using LabelCategoryNames = std::vector<std::string>; 26*89c4ff92SAndroid Build Coastguard Worker 27*89c4ff92SAndroid Build Coastguard Worker /** Split a string into tokens by a delimiter 28*89c4ff92SAndroid Build Coastguard Worker * 29*89c4ff92SAndroid Build Coastguard Worker * @param[in] originalString Original string to be split 30*89c4ff92SAndroid Build Coastguard Worker * @param[in] delimiter Delimiter used to split \p originalString 31*89c4ff92SAndroid Build Coastguard Worker * @param[in] includeEmptyToekn If true, include empty tokens in the result 32*89c4ff92SAndroid Build Coastguard Worker * @return A vector of tokens split from \p originalString by \delimiter 33*89c4ff92SAndroid Build Coastguard Worker */ 34*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> 35*89c4ff92SAndroid Build Coastguard Worker SplitBy(const std::string& originalString, const std::string& delimiter = " ", bool includeEmptyToken = false); 36*89c4ff92SAndroid Build Coastguard Worker 37*89c4ff92SAndroid Build Coastguard Worker /** Remove any preceding and trailing character specified in the characterSet. 38*89c4ff92SAndroid Build Coastguard Worker * 39*89c4ff92SAndroid Build Coastguard Worker * @param[in] originalString Original string to be stripped 40*89c4ff92SAndroid Build Coastguard Worker * @param[in] characterSet Set of characters to be stripped from \p originalString 41*89c4ff92SAndroid Build Coastguard Worker * @return A string stripped of all characters specified in \p characterSet from \p originalString 42*89c4ff92SAndroid Build Coastguard Worker */ 43*89c4ff92SAndroid Build Coastguard Worker std::string Strip(const std::string& originalString, const std::string& characterSet = " "); 44*89c4ff92SAndroid Build Coastguard Worker 45*89c4ff92SAndroid Build Coastguard Worker class ModelAccuracyChecker 46*89c4ff92SAndroid Build Coastguard Worker { 47*89c4ff92SAndroid Build Coastguard Worker public: 48*89c4ff92SAndroid Build Coastguard Worker /** Constructor for a model top k accuracy checker 49*89c4ff92SAndroid Build Coastguard Worker * 50*89c4ff92SAndroid Build Coastguard Worker * @param[in] validationLabelSet Mapping from names of images to be validated, to category names of their 51*89c4ff92SAndroid Build Coastguard Worker corresponding ground-truth labels. 52*89c4ff92SAndroid Build Coastguard Worker * @param[in] modelOutputLabels Mapping from output nodes to the category names of their corresponding labels 53*89c4ff92SAndroid Build Coastguard Worker Note that an output node can have multiple category names. 54*89c4ff92SAndroid Build Coastguard Worker */ 55*89c4ff92SAndroid Build Coastguard Worker ModelAccuracyChecker(const std::map<std::string, std::string>& validationLabelSet, 56*89c4ff92SAndroid Build Coastguard Worker const std::vector<LabelCategoryNames>& modelOutputLabels); 57*89c4ff92SAndroid Build Coastguard Worker 58*89c4ff92SAndroid Build Coastguard Worker /** Get Top K accuracy 59*89c4ff92SAndroid Build Coastguard Worker * 60*89c4ff92SAndroid Build Coastguard Worker * @param[in] k The number of top predictions to use for validating the ground-truth label. For example, if \p k is 61*89c4ff92SAndroid Build Coastguard Worker 3, then a prediction is considered correct as long as the ground-truth appears in the top 3 62*89c4ff92SAndroid Build Coastguard Worker predictions. 63*89c4ff92SAndroid Build Coastguard Worker * @return The accuracy, according to the top \p k th predictions. 64*89c4ff92SAndroid Build Coastguard Worker */ 65*89c4ff92SAndroid Build Coastguard Worker float GetAccuracy(unsigned int k); 66*89c4ff92SAndroid Build Coastguard Worker 67*89c4ff92SAndroid Build Coastguard Worker /** Record the prediction result of an image 68*89c4ff92SAndroid Build Coastguard Worker * 69*89c4ff92SAndroid Build Coastguard Worker * @param[in] imageName Name of the image. 70*89c4ff92SAndroid Build Coastguard Worker * @param[in] outputTensor Output tensor of the network running \p imageName. 71*89c4ff92SAndroid Build Coastguard Worker */ 72*89c4ff92SAndroid Build Coastguard Worker template <typename TContainer> AddImageResult(const std::string & imageName,std::vector<TContainer> outputTensor)73*89c4ff92SAndroid Build Coastguard Worker void AddImageResult(const std::string& imageName, std::vector<TContainer> outputTensor) 74*89c4ff92SAndroid Build Coastguard Worker { 75*89c4ff92SAndroid Build Coastguard Worker // Increment the total number of images processed 76*89c4ff92SAndroid Build Coastguard Worker ++m_ImagesProcessed; 77*89c4ff92SAndroid Build Coastguard Worker 78*89c4ff92SAndroid Build Coastguard Worker std::map<int, float> confidenceMap; 79*89c4ff92SAndroid Build Coastguard Worker auto& output = outputTensor[0]; 80*89c4ff92SAndroid Build Coastguard Worker 81*89c4ff92SAndroid Build Coastguard Worker // Create a map of all predictions 82*89c4ff92SAndroid Build Coastguard Worker mapbox::util::apply_visitor([&confidenceMap](auto && value) 83*89c4ff92SAndroid Build Coastguard Worker { 84*89c4ff92SAndroid Build Coastguard Worker int index = 0; 85*89c4ff92SAndroid Build Coastguard Worker for (const auto & o : value) 86*89c4ff92SAndroid Build Coastguard Worker { 87*89c4ff92SAndroid Build Coastguard Worker if (o > 0) 88*89c4ff92SAndroid Build Coastguard Worker { 89*89c4ff92SAndroid Build Coastguard Worker confidenceMap.insert(std::pair<int, float>(index, static_cast<float>(o))); 90*89c4ff92SAndroid Build Coastguard Worker } 91*89c4ff92SAndroid Build Coastguard Worker ++index; 92*89c4ff92SAndroid Build Coastguard Worker } 93*89c4ff92SAndroid Build Coastguard Worker }, 94*89c4ff92SAndroid Build Coastguard Worker output); 95*89c4ff92SAndroid Build Coastguard Worker 96*89c4ff92SAndroid Build Coastguard Worker // Create a comparator for sorting the map in order of highest probability 97*89c4ff92SAndroid Build Coastguard Worker typedef std::function<bool(std::pair<int, float>, std::pair<int, float>)> Comparator; 98*89c4ff92SAndroid Build Coastguard Worker 99*89c4ff92SAndroid Build Coastguard Worker Comparator compFunctor = 100*89c4ff92SAndroid Build Coastguard Worker [](std::pair<int, float> element1, std::pair<int, float> element2) 101*89c4ff92SAndroid Build Coastguard Worker { 102*89c4ff92SAndroid Build Coastguard Worker return element1.second > element2.second; 103*89c4ff92SAndroid Build Coastguard Worker }; 104*89c4ff92SAndroid Build Coastguard Worker 105*89c4ff92SAndroid Build Coastguard Worker // Do the sorting and store in an ordered set 106*89c4ff92SAndroid Build Coastguard Worker std::set<std::pair<int, float>, Comparator> setOfPredictions( 107*89c4ff92SAndroid Build Coastguard Worker confidenceMap.begin(), confidenceMap.end(), compFunctor); 108*89c4ff92SAndroid Build Coastguard Worker 109*89c4ff92SAndroid Build Coastguard Worker const std::string correctLabel = m_GroundTruthLabelSet.at(imageName); 110*89c4ff92SAndroid Build Coastguard Worker 111*89c4ff92SAndroid Build Coastguard Worker unsigned int index = 1; 112*89c4ff92SAndroid Build Coastguard Worker for (std::pair<int, float> element : setOfPredictions) 113*89c4ff92SAndroid Build Coastguard Worker { 114*89c4ff92SAndroid Build Coastguard Worker if (index >= m_TopK.size()) 115*89c4ff92SAndroid Build Coastguard Worker { 116*89c4ff92SAndroid Build Coastguard Worker break; 117*89c4ff92SAndroid Build Coastguard Worker } 118*89c4ff92SAndroid Build Coastguard Worker // Check if the ground truth label value is included in the topi prediction. 119*89c4ff92SAndroid Build Coastguard Worker // Note that a prediction can have multiple prediction labels. 120*89c4ff92SAndroid Build Coastguard Worker const LabelCategoryNames predictionLabels = m_ModelOutputLabels[static_cast<size_t>(element.first)]; 121*89c4ff92SAndroid Build Coastguard Worker if (std::find(predictionLabels.begin(), predictionLabels.end(), correctLabel) != predictionLabels.end()) 122*89c4ff92SAndroid Build Coastguard Worker { 123*89c4ff92SAndroid Build Coastguard Worker ++m_TopK[index]; 124*89c4ff92SAndroid Build Coastguard Worker break; 125*89c4ff92SAndroid Build Coastguard Worker } 126*89c4ff92SAndroid Build Coastguard Worker ++index; 127*89c4ff92SAndroid Build Coastguard Worker } 128*89c4ff92SAndroid Build Coastguard Worker } 129*89c4ff92SAndroid Build Coastguard Worker 130*89c4ff92SAndroid Build Coastguard Worker private: 131*89c4ff92SAndroid Build Coastguard Worker const std::map<std::string, std::string> m_GroundTruthLabelSet; 132*89c4ff92SAndroid Build Coastguard Worker const std::vector<LabelCategoryNames> m_ModelOutputLabels; 133*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> m_TopK = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; 134*89c4ff92SAndroid Build Coastguard Worker unsigned int m_ImagesProcessed = 0; 135*89c4ff92SAndroid Build Coastguard Worker }; 136*89c4ff92SAndroid Build Coastguard Worker } //namespace armnnUtils 137*89c4ff92SAndroid Build Coastguard Worker 138