xref: /aosp_15_r20/external/armnn/src/armnnTestUtils/UnitTests.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "armnnTestUtils/TensorHelpers.hpp"
8 
9 #include <armnn/Logging.hpp>
10 #include <armnn/Utils.hpp>
11 #include <reference/RefWorkloadFactory.hpp>
12 #include <reference/test/RefWorkloadFactoryHelper.hpp>
13 
14 #include <backendsCommon/test/WorkloadFactoryHelper.hpp>
15 
16 #include <armnnTestUtils/LayerTestResult.hpp>
17 #include <armnnTestUtils/TensorCopyUtils.hpp>
18 #include <armnnTestUtils/WorkloadTestUtils.hpp>
19 
20 #include <doctest/doctest.h>
21 
ConfigureLoggingTest()22 inline void ConfigureLoggingTest()
23 {
24     // Configures logging for both the ARMNN library and this test program.
25     armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
26 }
27 
28 // The following macros require the caller to have defined FactoryType, with one of the following using statements:
29 //
30 //      using FactoryType = armnn::RefWorkloadFactory;
31 //      using FactoryType = armnn::ClWorkloadFactory;
32 //      using FactoryType = armnn::NeonWorkloadFactory;
33 
34 /// Executes CHECK_MESSAGE on CompareTensors() return value so that the predicate_result message is reported.
35 /// If the test reports itself as not supported then the tensors are not compared.
36 /// Additionally this checks that the supportedness reported by the test matches the name of the test.
37 /// Unsupported tests must be 'tagged' by including "UNSUPPORTED" in their name.
38 /// This is useful because it clarifies that the feature being tested is not actually supported
39 /// (a passed test with the name of a feature would imply that feature was supported).
40 /// If support is added for a feature, the test case will fail because the name incorrectly contains UNSUPPORTED.
41 /// If support is removed for a feature, the test case will fail because the name doesn't contain UNSUPPORTED.
42 template <typename T, std::size_t n>
CompareTestResultIfSupported(const std::string & testName,const LayerTestResult<T,n> & testResult)43 void CompareTestResultIfSupported(const std::string& testName, const LayerTestResult<T, n>& testResult)
44 {
45     bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
46     CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.m_Supported,
47                   "The test name does not match the supportedness it is reporting");
48     if (testResult.m_Supported)
49     {
50         auto result = CompareTensors(testResult.m_ActualData,
51                                      testResult.m_ExpectedData,
52                                      testResult.m_ActualShape,
53                                      testResult.m_ExpectedShape,
54                                      testResult.m_CompareBoolean);
55        CHECK_MESSAGE(result.m_Result, result.m_Message.str());
56     }
57 }
58 
59 template <typename T, std::size_t n>
CompareTestResultIfSupported(const std::string & testName,const std::vector<LayerTestResult<T,n>> & testResult)60 void CompareTestResultIfSupported(const std::string& testName, const std::vector<LayerTestResult<T, n>>& testResult)
61 {
62     bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
63     for (unsigned int i = 0; i < testResult.size(); ++i)
64     {
65         CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].m_Supported,
66                       "The test name does not match the supportedness it is reporting");
67         if (testResult[i].m_Supported)
68         {
69             auto result = CompareTensors(testResult[i].m_ActualData,
70                                          testResult[i].m_ExpectedData,
71                                          testResult[i].m_ActualShape,
72                                          testResult[i].m_ExpectedShape);
73             CHECK_MESSAGE(result.m_Result, result.m_Message.str());
74         }
75     }
76 }
77 
78 template<typename FactoryType, typename TFuncPtr, typename... Args>
RunTestFunction(const char * testName,TFuncPtr testFunction,Args...args)79 void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
80 {
81     std::unique_ptr<armnn::IProfiler> profiler = std::make_unique<armnn::IProfiler>();
82     armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
83 
84     auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
85     FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
86 
87     auto testResult = (*testFunction)(workloadFactory, memoryManager, args...);
88     CompareTestResultIfSupported(testName, testResult);
89 
90     armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr);
91 }
92 
93 
94 template<typename FactoryType, typename TFuncPtr, typename... Args>
RunTestFunctionUsingTensorHandleFactory(const char * testName,TFuncPtr testFunction,Args...args)95 void RunTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args)
96 {
97     std::unique_ptr<armnn::IProfiler> profiler = std::make_unique<armnn::IProfiler>();
98     armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
99 
100     auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
101     FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
102 
103     auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
104 
105     auto testResult = (*testFunction)(workloadFactory, memoryManager, tensorHandleFactory, args...);
106     CompareTestResultIfSupported(testName, testResult);
107 
108     armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr);
109 }
110 
// Declares a doctest TEST_CASE named after TestName whose body just calls TestFunction()
// with no arguments. No result comparison is performed by the macro itself.
#define ARMNN_SIMPLE_TEST_CASE(TestName, TestFunction)                                                   \
    TEST_CASE(#TestName)                                                                                 \
    {                                                                                                    \
        TestFunction();                                                                                  \
    }
116 
// Declares a doctest TEST_CASE named after TestName that runs TestFunction through
// RunTestFunction<FactoryType>, forwarding any extra macro arguments to it.
// The caller must have FactoryType defined. `##__VA_ARGS__` is the GNU-style form that
// drops the preceding comma when no extra arguments are supplied.
#define ARMNN_AUTO_TEST_CASE(TestName, TestFunction, ...)                                                \
    TEST_CASE(#TestName)                                                                                 \
    {                                                                                                    \
        RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__);                           \
    }
122 
// Fixture variant of the auto test case: declares a doctest TEST_CASE_FIXTURE on Fixture,
// named after TestName, that runs TestFunction through RunTestFunction<FactoryType> and
// forwards any extra macro arguments. The caller must have FactoryType defined.
#define ARMNN_AUTO_TEST_FIXTURE(TestName, Fixture, TestFunction, ...)                                    \
    TEST_CASE_FIXTURE(Fixture, #TestName)                                                                \
    {                                                                                                    \
        RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__);                           \
    }
128 
// Declares a doctest TEST_CASE named after TestName that runs TestFunction through
// RunTestFunctionUsingTensorHandleFactory<FactoryType>, i.e. the test function additionally
// receives a tensor handle factory. Extra macro arguments are forwarded.
// The caller must have FactoryType defined.
#define ARMNN_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...)                                       \
    TEST_CASE(#TestName)                                                                                 \
    {                                                                                                    \
        RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__);   \
    }
134 
// Fixture variant of the tensor-handle-factory test case: declares a doctest TEST_CASE_FIXTURE
// on Fixture, named after TestName, that runs TestFunction through
// RunTestFunctionUsingTensorHandleFactory<FactoryType> and forwards any extra macro arguments.
// The caller must have FactoryType defined.
#define ARMNN_AUTO_TEST_FIXTURE_WITH_THF(TestName, Fixture, TestFunction, ...)                           \
    TEST_CASE_FIXTURE(Fixture, #TestName)                                                                \
    {                                                                                                    \
        RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__);   \
    }
140 
141 template<typename FactoryType, typename TFuncPtr, typename... Args>
CompareRefTestFunction(const char * testName,TFuncPtr testFunction,Args...args)142 void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
143 {
144     auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
145     FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
146 
147     armnn::RefWorkloadFactory refWorkloadFactory;
148 
149     auto testResult = (*testFunction)(workloadFactory, memoryManager, refWorkloadFactory, args...);
150     CompareTestResultIfSupported(testName, testResult);
151 }
152 
153 template<typename FactoryType, typename TFuncPtr, typename... Args>
CompareRefTestFunctionUsingTensorHandleFactory(const char * testName,TFuncPtr testFunction,Args...args)154 void CompareRefTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args)
155 {
156     auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
157     FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
158     auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
159 
160     armnn::RefWorkloadFactory refWorkloadFactory;
161     auto refMemoryManager = WorkloadFactoryHelper<armnn::RefWorkloadFactory>::GetMemoryManager();
162     auto refTensorHandleFactory = RefWorkloadFactoryHelper::GetTensorHandleFactory(refMemoryManager);
163 
164     auto testResult = (*testFunction)(
165         workloadFactory, memoryManager, refWorkloadFactory, tensorHandleFactory, refTensorHandleFactory, args...);
166     CompareTestResultIfSupported(testName, testResult);
167 }
168 
// Declares a doctest TEST_CASE named after TestName that runs TestFunction through
// CompareRefTestFunction<FactoryType>, forwarding any extra macro arguments.
// The caller must have FactoryType defined.
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...)                                          \
    TEST_CASE(#TestName)                                                                                       \
    {                                                                                                          \
        CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__);                          \
    }
174 
// Declares a doctest TEST_CASE named after TestName that runs TestFunction through
// CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>, forwarding any extra macro
// arguments. The caller must have FactoryType defined.
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...)                                 \
    TEST_CASE(#TestName)                                                                                       \
    {                                                                                                          \
        CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__);  \
    }
180 
// Fixture variant of the reference comparison test case: declares a doctest TEST_CASE_FIXTURE
// on Fixture, named after TestName, that runs TestFunction through
// CompareRefTestFunction<FactoryType> and forwards any extra macro arguments.
// The caller must have FactoryType defined.
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...)                              \
    TEST_CASE_FIXTURE(Fixture, #TestName)                                                                      \
    {                                                                                                          \
        CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__);                          \
    }
186 
// Fixture variant of the tensor-handle-factory reference comparison test case: declares a
// doctest TEST_CASE_FIXTURE on Fixture, named after TestName, that runs TestFunction through
// CompareRefTestFunctionUsingTensorHandleFactory<FactoryType> and forwards any extra macro
// arguments. The caller must have FactoryType defined.
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE_WITH_THF(TestName, Fixture, TestFunction, ...)                     \
    TEST_CASE_FIXTURE(Fixture, #TestName)                                                                      \
    {                                                                                                          \
        CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__);  \
    }
192