1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <algorithm> 9 #include <armnn/Types.hpp> 10 #include <armnn/utility/Assert.hpp> 11 #include <mapbox/variant.hpp> 12 #include <cstddef> 13 #include <functional> 14 #include <iostream> 15 #include <map> 16 #include <string> 17 #include <vector> 18 19 namespace armnnUtils 20 { 21 22 using namespace armnn; 23 24 // Category names associated with a label 25 using LabelCategoryNames = std::vector<std::string>; 26 27 /** Split a string into tokens by a delimiter 28 * 29 * @param[in] originalString Original string to be split 30 * @param[in] delimiter Delimiter used to split \p originalString 31 * @param[in] includeEmptyToekn If true, include empty tokens in the result 32 * @return A vector of tokens split from \p originalString by \delimiter 33 */ 34 std::vector<std::string> 35 SplitBy(const std::string& originalString, const std::string& delimiter = " ", bool includeEmptyToken = false); 36 37 /** Remove any preceding and trailing character specified in the characterSet. 38 * 39 * @param[in] originalString Original string to be stripped 40 * @param[in] characterSet Set of characters to be stripped from \p originalString 41 * @return A string stripped of all characters specified in \p characterSet from \p originalString 42 */ 43 std::string Strip(const std::string& originalString, const std::string& characterSet = " "); 44 45 class ModelAccuracyChecker 46 { 47 public: 48 /** Constructor for a model top k accuracy checker 49 * 50 * @param[in] validationLabelSet Mapping from names of images to be validated, to category names of their 51 corresponding ground-truth labels. 52 * @param[in] modelOutputLabels Mapping from output nodes to the category names of their corresponding labels 53 Note that an output node can have multiple category names. 54 */ 55 ModelAccuracyChecker(const std::map<std::string, std::string>& validationLabelSet, 56 const std::vector<LabelCategoryNames>& modelOutputLabels); 57 58 /** Get Top K accuracy 59 * 60 * @param[in] k The number of top predictions to use for validating the ground-truth label. For example, if \p k is 61 3, then a prediction is considered correct as long as the ground-truth appears in the top 3 62 predictions. 63 * @return The accuracy, according to the top \p k th predictions. 64 */ 65 float GetAccuracy(unsigned int k); 66 67 /** Record the prediction result of an image 68 * 69 * @param[in] imageName Name of the image. 70 * @param[in] outputTensor Output tensor of the network running \p imageName. 71 */ 72 template <typename TContainer> AddImageResult(const std::string & imageName,std::vector<TContainer> outputTensor)73 void AddImageResult(const std::string& imageName, std::vector<TContainer> outputTensor) 74 { 75 // Increment the total number of images processed 76 ++m_ImagesProcessed; 77 78 std::map<int, float> confidenceMap; 79 auto& output = outputTensor[0]; 80 81 // Create a map of all predictions 82 mapbox::util::apply_visitor([&confidenceMap](auto && value) 83 { 84 int index = 0; 85 for (const auto & o : value) 86 { 87 if (o > 0) 88 { 89 confidenceMap.insert(std::pair<int, float>(index, static_cast<float>(o))); 90 } 91 ++index; 92 } 93 }, 94 output); 95 96 // Create a comparator for sorting the map in order of highest probability 97 typedef std::function<bool(std::pair<int, float>, std::pair<int, float>)> Comparator; 98 99 Comparator compFunctor = 100 [](std::pair<int, float> element1, std::pair<int, float> element2) 101 { 102 return element1.second > element2.second; 103 }; 104 105 // Do the sorting and store in an ordered set 106 std::set<std::pair<int, float>, Comparator> setOfPredictions( 107 confidenceMap.begin(), confidenceMap.end(), compFunctor); 108 109 const std::string correctLabel = m_GroundTruthLabelSet.at(imageName); 110 111 unsigned int index = 1; 112 for (std::pair<int, float> element : setOfPredictions) 113 { 114 if (index >= m_TopK.size()) 115 { 116 break; 117 } 118 // Check if the ground truth label value is included in the topi prediction. 119 // Note that a prediction can have multiple prediction labels. 120 const LabelCategoryNames predictionLabels = m_ModelOutputLabels[static_cast<size_t>(element.first)]; 121 if (std::find(predictionLabels.begin(), predictionLabels.end(), correctLabel) != predictionLabels.end()) 122 { 123 ++m_TopK[index]; 124 break; 125 } 126 ++index; 127 } 128 } 129 130 private: 131 const std::map<std::string, std::string> m_GroundTruthLabelSet; 132 const std::vector<LabelCategoryNames> m_ModelOutputLabels; 133 std::vector<unsigned int> m_TopK = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; 134 unsigned int m_ImagesProcessed = 0; 135 }; 136 } //namespace armnnUtils 137 138