xref: /aosp_15_r20/external/armnn/src/armnnUtils/ModelAccuracyChecker.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <algorithm>
9 #include <armnn/Types.hpp>
10 #include <armnn/utility/Assert.hpp>
11 #include <mapbox/variant.hpp>
12 #include <cstddef>
13 #include <functional>
14 #include <iostream>
15 #include <map>
16 #include <string>
17 #include <vector>
18 
19 namespace armnnUtils
20 {
21 
// NOTE(review): a using-directive in a header pulls every armnn name into this namespace for all
// includers of this file. Kept as-is because code in this namespace (and possibly its includers)
// may rely on it — consider replacing with targeted using-declarations.
using namespace armnn;

// Category names associated with a single model output label.
// One output node can map to several category names.
using LabelCategoryNames = std::vector<std::string>;
26 
/** Split a string into tokens by a delimiter
 *
 * @param[in] originalString    Original string to be split
 * @param[in] delimiter         Delimiter used to split \p originalString
 * @param[in] includeEmptyToken If true, include empty tokens in the result
 * @return A vector of tokens split from \p originalString by \p delimiter
 */
std::vector<std::string>
    SplitBy(const std::string& originalString, const std::string& delimiter = " ", bool includeEmptyToken = false);
36 
/** Remove leading and trailing characters that appear in \p characterSet.
 *
 * @param[in] originalString    Original string to be stripped
 * @param[in] characterSet      Set of characters to be stripped from both ends of \p originalString
 * @return \p originalString with all leading and trailing characters found in \p characterSet removed
 */
std::string Strip(const std::string& originalString, const std::string& characterSet = " ");
44 
45 class ModelAccuracyChecker
46 {
47 public:
48     /** Constructor for a model top k accuracy checker
49      *
50      * @param[in] validationLabelSet Mapping from names of images to be validated, to category names of their
51                                      corresponding ground-truth labels.
52      * @param[in] modelOutputLabels  Mapping from output nodes to the category names of their corresponding labels
53                                      Note that an output node can have multiple category names.
54      */
55     ModelAccuracyChecker(const std::map<std::string, std::string>& validationLabelSet,
56                          const std::vector<LabelCategoryNames>& modelOutputLabels);
57 
58     /** Get Top K accuracy
59      *
60      * @param[in] k The number of top predictions to use for validating the ground-truth label. For example, if \p k is
61                     3, then a prediction is considered correct as long as the ground-truth appears in the top 3
62                     predictions.
63      * @return  The accuracy, according to the top \p k th predictions.
64      */
65     float GetAccuracy(unsigned int k);
66 
67     /** Record the prediction result of an image
68      *
69      * @param[in] imageName     Name of the image.
70      * @param[in] outputTensor  Output tensor of the network running \p imageName.
71      */
72     template <typename TContainer>
AddImageResult(const std::string & imageName,std::vector<TContainer> outputTensor)73     void AddImageResult(const std::string& imageName, std::vector<TContainer> outputTensor)
74     {
75         // Increment the total number of images processed
76         ++m_ImagesProcessed;
77 
78         std::map<int, float> confidenceMap;
79         auto& output = outputTensor[0];
80 
81         // Create a map of all predictions
82         mapbox::util::apply_visitor([&confidenceMap](auto && value)
83                              {
84                                  int index = 0;
85                                  for (const auto & o : value)
86                                  {
87                                      if (o > 0)
88                                      {
89                                          confidenceMap.insert(std::pair<int, float>(index, static_cast<float>(o)));
90                                      }
91                                      ++index;
92                                  }
93                              },
94                              output);
95 
96         // Create a comparator for sorting the map in order of highest probability
97         typedef std::function<bool(std::pair<int, float>, std::pair<int, float>)> Comparator;
98 
99         Comparator compFunctor =
100             [](std::pair<int, float> element1, std::pair<int, float> element2)
101             {
102                 return element1.second > element2.second;
103             };
104 
105         // Do the sorting and store in an ordered set
106         std::set<std::pair<int, float>, Comparator> setOfPredictions(
107             confidenceMap.begin(), confidenceMap.end(), compFunctor);
108 
109         const std::string correctLabel = m_GroundTruthLabelSet.at(imageName);
110 
111         unsigned int index = 1;
112         for (std::pair<int, float> element : setOfPredictions)
113         {
114             if (index >= m_TopK.size())
115             {
116                 break;
117             }
118             // Check if the ground truth label value is included in the topi prediction.
119             // Note that a prediction can have multiple prediction labels.
120             const LabelCategoryNames predictionLabels = m_ModelOutputLabels[static_cast<size_t>(element.first)];
121             if (std::find(predictionLabels.begin(), predictionLabels.end(), correctLabel) != predictionLabels.end())
122             {
123                 ++m_TopK[index];
124                 break;
125             }
126             ++index;
127         }
128     }
129 
130 private:
131     const std::map<std::string, std::string> m_GroundTruthLabelSet;
132     const std::vector<LabelCategoryNames> m_ModelOutputLabels;
133     std::vector<unsigned int> m_TopK = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
134     unsigned int m_ImagesProcessed   = 0;
135 };
136 } //namespace armnnUtils
137 
138