//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "InferenceTest.hpp"

#include <armnn/Utils.hpp>
#include <armnn/utility/Assert.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnnUtils/TContainer.hpp>

#include "CxxoptsUtils.hpp"

#include <cxxopts/cxxopts.hpp>
#include <fmt/format.h>

#include <fstream>
#include <iostream>
#include <iomanip>
#include <array>
#include <chrono>

using namespace std;
using namespace std::chrono;
using namespace armnn::test;

namespace armnn
{
namespace test
{

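// Builds a test case for a single classification inference. The two int
// references let every test case created by the same provider update one
// shared set of accuracy counters.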
template <typename TTestCaseDatabase, typename TModel>
ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
    int& numInferencesRef,
    int& numCorrectInferencesRef,
    const std::vector<unsigned int>& validationPredictions,
    std::vector<unsigned int>* validationPredictionsOut,
    TModel& model,
    unsigned int testCaseId,
    unsigned int label,
    std::vector<typename TModel::DataType> modelInput)
    : InferenceModelTestCase<TModel>(
            model, testCaseId, std::vector<armnnUtils::TContainer>{ modelInput }, { model.GetOutputSize() })
    , m_Label(label)
    , m_QuantizationParams(model.GetQuantizationParams())
    , m_NumInferencesRef(numInferencesRef)
    , m_NumCorrectInferencesRef(numCorrectInferencesRef)
    , m_ValidationPredictions(validationPredictions)
    , m_ValidationPredictionsOut(validationPredictionsOut)
{
}

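// Visitor that gathers the network output into a map ordered by prediction
// value, dequantizing uint8 outputs with the supplied scale and offset so the
// strongest predictions can be read from the back of the map.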
struct ClassifierResultProcessor
{
    using ResultMap = std::map<float, int>;

    ClassifierResultProcessor(float scale, int offset)
        : m_Scale(scale)
        , m_Offset(offset)
    {}

    void operator()(const std::vector<float>& values)
    {
        SortPredictions(values, [](float value)
                                {
                                    return value;
                                });
    }

    void operator()(const std::vector<int8_t>& values)
    {
        SortPredictions(values, [](int8_t value)
                                {
                                    return value;
                                });
    }

    void operator()(const std::vector<uint8_t>& values)
    {
        auto& scale = m_Scale;
        auto& offset = m_Offset;
        SortPredictions(values, [&scale, &offset](uint8_t value)
                                {
                                    return armnn::Dequantize(value, scale, offset);
                                });
    }

    void operator()(const std::vector<int>& values)
    {
        IgnoreUnused(values);
        ARMNN_ASSERT_MSG(false, "Non-float predictions output not supported.");
    }

    ResultMap& GetResultMap() { return m_ResultMap; }

private:
    template<typename Container, typename Delegate>
    void SortPredictions(const Container& c, Delegate delegate)
    {
        int index = 0;
        for (const auto& value : c)
        {
            int classification = index++;
            // Take the first class with each probability.
            // This avoids strange results when looping over batched results produced
            // with identical test data.
            // Compare on the transformed key so the uint8 path is consistent
            // with the dequantized values actually stored in the map.
            const float key = delegate(value);
            ResultMap::iterator lb = m_ResultMap.lower_bound(key);

            if (lb == m_ResultMap.end() || m_ResultMap.key_comp()(key, lb->first))
            {
                // The key is not already in the map, so insert it.
                m_ResultMap.insert(lb, ResultMap::value_type(key, classification));
            }
        }
    }

    ResultMap m_ResultMap;

    float m_Scale = 0.0f;
    int m_Offset = 0;
};

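// Scores the output of a single inference: logs the strongest predictions,
// takes the index of the largest output element as the predicted class, and
// checks it against the expected label and the validation file, if any.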
template <typename TTestCaseDatabase, typename TModel>
TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
{
    auto& output = this->GetOutputs()[0];
    const auto testCaseId = this->GetTestCaseId();

    ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second);
    mapbox::util::apply_visitor(resultProcessor, output);

    ARMNN_LOG(info) << "= Prediction values for test #" << testCaseId;
    auto it = resultProcessor.GetResultMap().rbegin();
    for (int i = 0; i < 5 && it != resultProcessor.GetResultMap().rend(); ++i)
    {
        ARMNN_LOG(info) << "Top(" << (i + 1) << ") prediction is " << it->second <<
            " with value: " << (it->first);
        ++it;
    }

    unsigned int prediction = 0;
    mapbox::util::apply_visitor([&](auto&& value)
                                {
                                    prediction = armnn::numeric_cast<unsigned int>(
                                        std::distance(value.begin(), std::max_element(value.begin(), value.end())));
                                },
                                output);

    // If we're just running the defaultTestCaseIds, each one must be classified correctly.
    if (params.m_IterationCount == 0 && prediction != m_Label)
    {
        ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
            " is incorrect (should be " << m_Label << ")";
        return TestCaseResult::Failed;
    }

    // If a validation file was provided as input, check that the prediction matches the expected one.
    if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
    {
        ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
            " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
        return TestCaseResult::Failed;
    }

    // If a validation file was requested as output, store the prediction for writing later.
    if (m_ValidationPredictionsOut)
    {
        m_ValidationPredictionsOut->push_back(prediction);
    }

    // Update the accuracy statistics.
    m_NumInferencesRef++;
    if (prediction == m_Label)
    {
        m_NumCorrectInferencesRef++;
    }

    return TestCaseResult::Ok;
}

template <typename TDatabase, typename InferenceModel>
template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
    TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
    : m_ConstructModel(constructModel)
    , m_ConstructDatabase(constructDatabase)
    , m_NumInferences(0)
    , m_NumCorrectInferences(0)
{
}

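// Registers this provider's command-line options (validation files and the
// test data directory), then lets the model type add its own options.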
template <typename TDatabase, typename InferenceModel>
void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
    cxxopts::Options& options, std::vector<std::string>& required)
{
    options
        .allow_unrecognised_options()
        .add_options()
            ("validation-file-in",
             "Reads expected predictions from the given file and confirms they match the actual predictions.",
             cxxopts::value<std::string>(m_ValidationFileIn)->default_value(""))
            ("validation-file-out", "Predictions are saved to the given file for later use via --validation-file-in.",
             cxxopts::value<std::string>(m_ValidationFileOut)->default_value(""))
            ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir));

    required.emplace_back("data-dir"); // Add to the list of required arguments to be checked.

    InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
}

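// Validates the parsed options, then constructs the model and the test case
// database. Returns false if any of these steps fails.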
template <typename TDatabase, typename InferenceModel>
bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions(
    const InferenceTestOptions& commonOptions)
{
    if (!ValidateDirectory(m_DataDir))
    {
        return false;
    }

    ReadPredictions();

    m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
    if (!m_Model)
    {
        return false;
    }

    m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
    if (!m_Database)
    {
        return false;
    }

    return true;
}

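// Looks up the data for the given test case id in the database and wraps it
// in a ClassifierTestCase, or returns nullptr if there is no such test case.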
template <typename TDatabase, typename InferenceModel>
std::unique_ptr<IInferenceTestCase>
ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
{
    std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
    if (testCaseData == nullptr)
    {
        return nullptr;
    }

    return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
        m_NumInferences,
        m_NumCorrectInferences,
        m_ValidationPredictions,
        m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
        *m_Model,
        testCaseId,
        testCaseData->m_Label,
        std::move(testCaseData->m_InputImage));
}

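// Called once all test cases have run: logs the overall accuracy and, if an
// output validation file was requested, writes the collected predictions to it.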
template <typename TDatabase, typename InferenceModel>
bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
{
    const double accuracy = armnn::numeric_cast<double>(m_NumCorrectInferences) /
        armnn::numeric_cast<double>(m_NumInferences);
    ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;

    // If a validation file was requested as output, the predictions are saved to it.
    if (!m_ValidationFileOut.empty())
    {
        std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
        if (validationFileOut.good())
        {
            for (const unsigned int prediction : m_ValidationPredictionsOut)
            {
                validationFileOut << prediction << std::endl;
            }
        }
        else
        {
            ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut;
            return false;
        }
    }

    return true;
}

template <typename TDatabase, typename InferenceModel>
void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
{
    // Reads the expected predictions from the input validation file (if provided).
    if (!m_ValidationFileIn.empty())
    {
        std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
        if (validationFileIn.good())
        {
            // Check the stream state after each extraction rather than testing eof()
            // up front, which would append a spurious value after the last entry.
            unsigned int i;
            while (validationFileIn >> i)
            {
                m_ValidationPredictions.emplace_back(i);
            }
        }
        else
        {
            throw armnn::Exception(fmt::format("Failed to open input validation file: {}",
                                               m_ValidationFileIn));
        }
    }
}

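// Common entry point for the inference test programs: configures logging,
// constructs the test case provider, parses the command line and runs the test.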
template<typename TConstructTestCaseProvider>
int InferenceTestMain(int argc,
    char* argv[],
    const std::vector<unsigned int>& defaultTestCaseIds,
    TConstructTestCaseProvider constructTestCaseProvider)
{
    // Configures logging for both the ARMNN library and this test program.
#ifdef NDEBUG
    armnn::LogSeverity level = armnn::LogSeverity::Info;
#else
    armnn::LogSeverity level = armnn::LogSeverity::Debug;
#endif
    armnn::ConfigureLogging(true, true, level);

    try
    {
        std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
        if (!testCaseProvider)
        {
            return 1;
        }

        InferenceTestOptions inferenceTestOptions;
        if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
        {
            return 1;
        }

        const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
        return success ? 0 : 1;
    }
    catch (armnn::Exception const& e)
    {
        ARMNN_LOG(fatal) << "Armnn Error: " << e.what();
        return 1;
    }
}

//
// This function creates a classifier inference test based on:
//  - a model file name, which can be a binary or a text file for protobuf formats
//  - an input tensor name
//  - an output tensor name
//  - a set of test case ids
//  - a callback method which creates an object that can return images,
//    called 'Database' in these tests
//  - an input tensor shape
//
template<typename TDatabase,
         typename TParser,
         typename TConstructDatabaseCallable>
int ClassifierInferenceTestMain(int argc,
                                char* argv[],
                                const char* modelFilename,
                                bool isModelBinary,
                                const char* inputBindingName,
                                const char* outputBindingName,
                                const std::vector<unsigned int>& defaultTestCaseIds,
                                TConstructDatabaseCallable constructDatabase,
                                const armnn::TensorShape* inputTensorShape)
{
    ARMNN_ASSERT(modelFilename);
    ARMNN_ASSERT(inputBindingName);
    ARMNN_ASSERT(outputBindingName);

    return InferenceTestMain(argc, argv, defaultTestCaseIds,
        [=]()
        {
            using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>;
            using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;

            return make_unique<TestCaseProvider>(constructDatabase,
                [&](const InferenceTestOptions& commonOptions,
                    typename InferenceModel::CommandLineOptions modelOptions)
                {
                    if (!ValidateDirectory(modelOptions.m_ModelDir))
                    {
                        return std::unique_ptr<InferenceModel>();
                    }

                    typename InferenceModel::Params modelParams;
                    modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
                    modelParams.m_InputBindings  = { inputBindingName };
                    modelParams.m_OutputBindings = { outputBindingName };

                    if (inputTensorShape)
                    {
                        modelParams.m_InputShapes.push_back(*inputTensorShape);
                    }

                    modelParams.m_IsModelBinary = isModelBinary;
                    modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
                    modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
                    modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;

                    return std::make_unique<InferenceModel>(modelParams,
                                                            commonOptions.m_EnableProfiling,
                                                            commonOptions.m_DynamicBackendsPath);
                });
        });
}
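
// Example usage (a minimal sketch, not part of this file): a test executable
// would typically call ClassifierInferenceTestMain from its main() entry point.
// 'MyDatabase', 'MyParser', the model file name and the binding names below are
// hypothetical placeholders. The database callable receives the data directory
// and the constructed model, matching how the provider invokes it above.
//
//     int main(int argc, char* argv[])
//     {
//         const armnn::TensorShape inputShape({ 1, 224, 224, 3 });
//         return armnn::test::ClassifierInferenceTestMain<MyDatabase, MyParser>(
//             argc, argv,
//             "model.pb",          // model file, relative to the model directory option
//             true,                // the model file is binary
//             "input", "output",   // input/output tensor binding names
//             { 0, 1, 2 },         // default test case ids
//             [](const char* dataDir, auto& /*model*/) { return MyDatabase(dataDir); },
//             &inputShape);
//     }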

} // namespace test
} // namespace armnn