xref: /aosp_15_r20/external/armnn/src/armnnUtils/ModelAccuracyChecker.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Types.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <mapbox/variant.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <cstddef>
13*89c4ff92SAndroid Build Coastguard Worker #include <functional>
14*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
15*89c4ff92SAndroid Build Coastguard Worker #include <map>
16*89c4ff92SAndroid Build Coastguard Worker #include <string>
17*89c4ff92SAndroid Build Coastguard Worker #include <vector>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker namespace armnnUtils
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker 
// NOTE(review): a using-directive at namespace scope in a header makes every armnn name visible
// inside armnnUtils for all translation units that include this file. Kept as-is because code
// elsewhere may rely on it.
using namespace armnn;

// Category names associated with a label (one label may carry several category names)
using LabelCategoryNames = std::vector<std::string>;
26*89c4ff92SAndroid Build Coastguard Worker 
/** Split a string into tokens by a delimiter
 *
 * @param[in] originalString    Original string to be split
 * @param[in] delimiter         Delimiter used to split \p originalString (defaults to a single space)
 * @param[in] includeEmptyToken If true, include empty tokens in the result (defaults to false)
 * @return A vector of tokens split from \p originalString by \p delimiter
 */
std::vector<std::string>
    SplitBy(const std::string& originalString, const std::string& delimiter = " ", bool includeEmptyToken = false);
36*89c4ff92SAndroid Build Coastguard Worker 
/** Remove any preceding and trailing character specified in the characterSet.
 *
 * @param[in] originalString    Original string to be stripped
 * @param[in] characterSet      Set of characters to be stripped from \p originalString
 *                              (defaults to a single space)
 * @return \p originalString with the leading and trailing characters that belong to \p characterSet removed
 */
std::string Strip(const std::string& originalString, const std::string& characterSet = " ");
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker class ModelAccuracyChecker
46*89c4ff92SAndroid Build Coastguard Worker {
47*89c4ff92SAndroid Build Coastguard Worker public:
48*89c4ff92SAndroid Build Coastguard Worker     /** Constructor for a model top k accuracy checker
49*89c4ff92SAndroid Build Coastguard Worker      *
50*89c4ff92SAndroid Build Coastguard Worker      * @param[in] validationLabelSet Mapping from names of images to be validated, to category names of their
51*89c4ff92SAndroid Build Coastguard Worker                                      corresponding ground-truth labels.
52*89c4ff92SAndroid Build Coastguard Worker      * @param[in] modelOutputLabels  Mapping from output nodes to the category names of their corresponding labels
53*89c4ff92SAndroid Build Coastguard Worker                                      Note that an output node can have multiple category names.
54*89c4ff92SAndroid Build Coastguard Worker      */
55*89c4ff92SAndroid Build Coastguard Worker     ModelAccuracyChecker(const std::map<std::string, std::string>& validationLabelSet,
56*89c4ff92SAndroid Build Coastguard Worker                          const std::vector<LabelCategoryNames>& modelOutputLabels);
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker     /** Get Top K accuracy
59*89c4ff92SAndroid Build Coastguard Worker      *
60*89c4ff92SAndroid Build Coastguard Worker      * @param[in] k The number of top predictions to use for validating the ground-truth label. For example, if \p k is
61*89c4ff92SAndroid Build Coastguard Worker                     3, then a prediction is considered correct as long as the ground-truth appears in the top 3
62*89c4ff92SAndroid Build Coastguard Worker                     predictions.
63*89c4ff92SAndroid Build Coastguard Worker      * @return  The accuracy, according to the top \p k th predictions.
64*89c4ff92SAndroid Build Coastguard Worker      */
65*89c4ff92SAndroid Build Coastguard Worker     float GetAccuracy(unsigned int k);
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker     /** Record the prediction result of an image
68*89c4ff92SAndroid Build Coastguard Worker      *
69*89c4ff92SAndroid Build Coastguard Worker      * @param[in] imageName     Name of the image.
70*89c4ff92SAndroid Build Coastguard Worker      * @param[in] outputTensor  Output tensor of the network running \p imageName.
71*89c4ff92SAndroid Build Coastguard Worker      */
72*89c4ff92SAndroid Build Coastguard Worker     template <typename TContainer>
AddImageResult(const std::string & imageName,std::vector<TContainer> outputTensor)73*89c4ff92SAndroid Build Coastguard Worker     void AddImageResult(const std::string& imageName, std::vector<TContainer> outputTensor)
74*89c4ff92SAndroid Build Coastguard Worker     {
75*89c4ff92SAndroid Build Coastguard Worker         // Increment the total number of images processed
76*89c4ff92SAndroid Build Coastguard Worker         ++m_ImagesProcessed;
77*89c4ff92SAndroid Build Coastguard Worker 
78*89c4ff92SAndroid Build Coastguard Worker         std::map<int, float> confidenceMap;
79*89c4ff92SAndroid Build Coastguard Worker         auto& output = outputTensor[0];
80*89c4ff92SAndroid Build Coastguard Worker 
81*89c4ff92SAndroid Build Coastguard Worker         // Create a map of all predictions
82*89c4ff92SAndroid Build Coastguard Worker         mapbox::util::apply_visitor([&confidenceMap](auto && value)
83*89c4ff92SAndroid Build Coastguard Worker                              {
84*89c4ff92SAndroid Build Coastguard Worker                                  int index = 0;
85*89c4ff92SAndroid Build Coastguard Worker                                  for (const auto & o : value)
86*89c4ff92SAndroid Build Coastguard Worker                                  {
87*89c4ff92SAndroid Build Coastguard Worker                                      if (o > 0)
88*89c4ff92SAndroid Build Coastguard Worker                                      {
89*89c4ff92SAndroid Build Coastguard Worker                                          confidenceMap.insert(std::pair<int, float>(index, static_cast<float>(o)));
90*89c4ff92SAndroid Build Coastguard Worker                                      }
91*89c4ff92SAndroid Build Coastguard Worker                                      ++index;
92*89c4ff92SAndroid Build Coastguard Worker                                  }
93*89c4ff92SAndroid Build Coastguard Worker                              },
94*89c4ff92SAndroid Build Coastguard Worker                              output);
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker         // Create a comparator for sorting the map in order of highest probability
97*89c4ff92SAndroid Build Coastguard Worker         typedef std::function<bool(std::pair<int, float>, std::pair<int, float>)> Comparator;
98*89c4ff92SAndroid Build Coastguard Worker 
99*89c4ff92SAndroid Build Coastguard Worker         Comparator compFunctor =
100*89c4ff92SAndroid Build Coastguard Worker             [](std::pair<int, float> element1, std::pair<int, float> element2)
101*89c4ff92SAndroid Build Coastguard Worker             {
102*89c4ff92SAndroid Build Coastguard Worker                 return element1.second > element2.second;
103*89c4ff92SAndroid Build Coastguard Worker             };
104*89c4ff92SAndroid Build Coastguard Worker 
105*89c4ff92SAndroid Build Coastguard Worker         // Do the sorting and store in an ordered set
106*89c4ff92SAndroid Build Coastguard Worker         std::set<std::pair<int, float>, Comparator> setOfPredictions(
107*89c4ff92SAndroid Build Coastguard Worker             confidenceMap.begin(), confidenceMap.end(), compFunctor);
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker         const std::string correctLabel = m_GroundTruthLabelSet.at(imageName);
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker         unsigned int index = 1;
112*89c4ff92SAndroid Build Coastguard Worker         for (std::pair<int, float> element : setOfPredictions)
113*89c4ff92SAndroid Build Coastguard Worker         {
114*89c4ff92SAndroid Build Coastguard Worker             if (index >= m_TopK.size())
115*89c4ff92SAndroid Build Coastguard Worker             {
116*89c4ff92SAndroid Build Coastguard Worker                 break;
117*89c4ff92SAndroid Build Coastguard Worker             }
118*89c4ff92SAndroid Build Coastguard Worker             // Check if the ground truth label value is included in the topi prediction.
119*89c4ff92SAndroid Build Coastguard Worker             // Note that a prediction can have multiple prediction labels.
120*89c4ff92SAndroid Build Coastguard Worker             const LabelCategoryNames predictionLabels = m_ModelOutputLabels[static_cast<size_t>(element.first)];
121*89c4ff92SAndroid Build Coastguard Worker             if (std::find(predictionLabels.begin(), predictionLabels.end(), correctLabel) != predictionLabels.end())
122*89c4ff92SAndroid Build Coastguard Worker             {
123*89c4ff92SAndroid Build Coastguard Worker                 ++m_TopK[index];
124*89c4ff92SAndroid Build Coastguard Worker                 break;
125*89c4ff92SAndroid Build Coastguard Worker             }
126*89c4ff92SAndroid Build Coastguard Worker             ++index;
127*89c4ff92SAndroid Build Coastguard Worker         }
128*89c4ff92SAndroid Build Coastguard Worker     }
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker private:
131*89c4ff92SAndroid Build Coastguard Worker     const std::map<std::string, std::string> m_GroundTruthLabelSet;
132*89c4ff92SAndroid Build Coastguard Worker     const std::vector<LabelCategoryNames> m_ModelOutputLabels;
133*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> m_TopK = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
134*89c4ff92SAndroid Build Coastguard Worker     unsigned int m_ImagesProcessed   = 0;
135*89c4ff92SAndroid Build Coastguard Worker };
136*89c4ff92SAndroid Build Coastguard Worker } //namespace armnnUtils
137*89c4ff92SAndroid Build Coastguard Worker 
138