1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "armnnTestUtils/TensorHelpers.hpp"
8
9 #include <armnn/Logging.hpp>
10 #include <armnn/Utils.hpp>
11 #include <reference/RefWorkloadFactory.hpp>
12 #include <reference/test/RefWorkloadFactoryHelper.hpp>
13
14 #include <backendsCommon/test/WorkloadFactoryHelper.hpp>
15
16 #include <armnnTestUtils/LayerTestResult.hpp>
17 #include <armnnTestUtils/TensorCopyUtils.hpp>
18 #include <armnnTestUtils/WorkloadTestUtils.hpp>
19
20 #include <doctest/doctest.h>
21
ConfigureLoggingTest()22 inline void ConfigureLoggingTest()
23 {
24 // Configures logging for both the ARMNN library and this test program.
25 armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
26 }
27
28 // The following macros require the caller to have defined FactoryType, with one of the following using statements:
29 //
30 // using FactoryType = armnn::RefWorkloadFactory;
31 // using FactoryType = armnn::ClWorkloadFactory;
32 // using FactoryType = armnn::NeonWorkloadFactory;
33
34 /// Executes CHECK_MESSAGE on CompareTensors() return value so that the predicate_result message is reported.
35 /// If the test reports itself as not supported then the tensors are not compared.
36 /// Additionally this checks that the supportedness reported by the test matches the name of the test.
37 /// Unsupported tests must be 'tagged' by including "UNSUPPORTED" in their name.
38 /// This is useful because it clarifies that the feature being tested is not actually supported
39 /// (a passed test with the name of a feature would imply that feature was supported).
40 /// If support is added for a feature, the test case will fail because the name incorrectly contains UNSUPPORTED.
41 /// If support is removed for a feature, the test case will fail because the name doesn't contain UNSUPPORTED.
42 template <typename T, std::size_t n>
CompareTestResultIfSupported(const std::string & testName,const LayerTestResult<T,n> & testResult)43 void CompareTestResultIfSupported(const std::string& testName, const LayerTestResult<T, n>& testResult)
44 {
45 bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
46 CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.m_Supported,
47 "The test name does not match the supportedness it is reporting");
48 if (testResult.m_Supported)
49 {
50 auto result = CompareTensors(testResult.m_ActualData,
51 testResult.m_ExpectedData,
52 testResult.m_ActualShape,
53 testResult.m_ExpectedShape,
54 testResult.m_CompareBoolean);
55 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
56 }
57 }
58
59 template <typename T, std::size_t n>
CompareTestResultIfSupported(const std::string & testName,const std::vector<LayerTestResult<T,n>> & testResult)60 void CompareTestResultIfSupported(const std::string& testName, const std::vector<LayerTestResult<T, n>>& testResult)
61 {
62 bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
63 for (unsigned int i = 0; i < testResult.size(); ++i)
64 {
65 CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].m_Supported,
66 "The test name does not match the supportedness it is reporting");
67 if (testResult[i].m_Supported)
68 {
69 auto result = CompareTensors(testResult[i].m_ActualData,
70 testResult[i].m_ExpectedData,
71 testResult[i].m_ActualShape,
72 testResult[i].m_ExpectedShape);
73 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
74 }
75 }
76 }
77
78 template<typename FactoryType, typename TFuncPtr, typename... Args>
RunTestFunction(const char * testName,TFuncPtr testFunction,Args...args)79 void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
80 {
81 std::unique_ptr<armnn::IProfiler> profiler = std::make_unique<armnn::IProfiler>();
82 armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
83
84 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
85 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
86
87 auto testResult = (*testFunction)(workloadFactory, memoryManager, args...);
88 CompareTestResultIfSupported(testName, testResult);
89
90 armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr);
91 }
92
93
94 template<typename FactoryType, typename TFuncPtr, typename... Args>
RunTestFunctionUsingTensorHandleFactory(const char * testName,TFuncPtr testFunction,Args...args)95 void RunTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args)
96 {
97 std::unique_ptr<armnn::IProfiler> profiler = std::make_unique<armnn::IProfiler>();
98 armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
99
100 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
101 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
102
103 auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
104
105 auto testResult = (*testFunction)(workloadFactory, memoryManager, tensorHandleFactory, args...);
106 CompareTestResultIfSupported(testName, testResult);
107
108 armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr);
109 }
110
// ARMNN_SIMPLE_TEST_CASE(Name, Fn): defines a doctest case "Name" whose body is just Fn().
// No workload factory, memory manager or profiler is created for it.
#define ARMNN_SIMPLE_TEST_CASE(caseName, testFn) \
    TEST_CASE(#caseName)                         \
    {                                            \
        testFn();                                \
    }

// ARMNN_AUTO_TEST_CASE(Name, Fn, args...): defines a doctest case "Name" that runs Fn through
// RunTestFunction<FactoryType>, passing any trailing args. The including file must declare
// FactoryType.
#define ARMNN_AUTO_TEST_CASE(caseName, testFn, ...)                              \
    TEST_CASE(#caseName)                                                         \
    {                                                                            \
        RunTestFunction<FactoryType>(#caseName, &testFn, ##__VA_ARGS__);         \
    }

// ARMNN_AUTO_TEST_FIXTURE(Name, Fixture, Fn, args...): same as ARMNN_AUTO_TEST_CASE, but the
// case is declared with TEST_CASE_FIXTURE(Fixture, ...).
#define ARMNN_AUTO_TEST_FIXTURE(caseName, fixtureType, testFn, ...)              \
    TEST_CASE_FIXTURE(fixtureType, #caseName)                                    \
    {                                                                            \
        RunTestFunction<FactoryType>(#caseName, &testFn, ##__VA_ARGS__);         \
    }

// ARMNN_AUTO_TEST_CASE_WITH_THF(Name, Fn, args...): like ARMNN_AUTO_TEST_CASE, but Fn is run
// through RunTestFunctionUsingTensorHandleFactory, so it also receives a tensor handle factory.
#define ARMNN_AUTO_TEST_CASE_WITH_THF(caseName, testFn, ...)                                    \
    TEST_CASE(#caseName)                                                                        \
    {                                                                                           \
        RunTestFunctionUsingTensorHandleFactory<FactoryType>(#caseName, &testFn, ##__VA_ARGS__); \
    }

// ARMNN_AUTO_TEST_FIXTURE_WITH_THF(Name, Fixture, Fn, args...): fixture variant of
// ARMNN_AUTO_TEST_CASE_WITH_THF, declared with TEST_CASE_FIXTURE(Fixture, ...).
#define ARMNN_AUTO_TEST_FIXTURE_WITH_THF(caseName, fixtureType, testFn, ...)                    \
    TEST_CASE_FIXTURE(fixtureType, #caseName)                                                   \
    {                                                                                           \
        RunTestFunctionUsingTensorHandleFactory<FactoryType>(#caseName, &testFn, ##__VA_ARGS__); \
    }
140
141 template<typename FactoryType, typename TFuncPtr, typename... Args>
CompareRefTestFunction(const char * testName,TFuncPtr testFunction,Args...args)142 void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
143 {
144 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
145 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
146
147 armnn::RefWorkloadFactory refWorkloadFactory;
148
149 auto testResult = (*testFunction)(workloadFactory, memoryManager, refWorkloadFactory, args...);
150 CompareTestResultIfSupported(testName, testResult);
151 }
152
153 template<typename FactoryType, typename TFuncPtr, typename... Args>
CompareRefTestFunctionUsingTensorHandleFactory(const char * testName,TFuncPtr testFunction,Args...args)154 void CompareRefTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args)
155 {
156 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
157 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
158 auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
159
160 armnn::RefWorkloadFactory refWorkloadFactory;
161 auto refMemoryManager = WorkloadFactoryHelper<armnn::RefWorkloadFactory>::GetMemoryManager();
162 auto refTensorHandleFactory = RefWorkloadFactoryHelper::GetTensorHandleFactory(refMemoryManager);
163
164 auto testResult = (*testFunction)(
165 workloadFactory, memoryManager, refWorkloadFactory, tensorHandleFactory, refTensorHandleFactory, args...);
166 CompareTestResultIfSupported(testName, testResult);
167 }
168
// ARMNN_COMPARE_REF_AUTO_TEST_CASE(Name, Fn, args...): defines a doctest case "Name" that runs
// Fn through CompareRefTestFunction<FactoryType>, comparing the backend against the reference.
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE(caseName, testFn, ...)                                       \
    TEST_CASE(#caseName)                                                                              \
    {                                                                                                 \
        CompareRefTestFunction<FactoryType>(#caseName, &testFn, ##__VA_ARGS__);                       \
    }

// ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF(Name, Fn, args...): like the macro above, but Fn is
// run through CompareRefTestFunctionUsingTensorHandleFactory, so it also gets tensor handle factories.
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF(caseName, testFn, ...)                              \
    TEST_CASE(#caseName)                                                                              \
    {                                                                                                 \
        CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#caseName, &testFn, ##__VA_ARGS__); \
    }

// ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(Name, Fixture, Fn, args...): fixture variant of
// ARMNN_COMPARE_REF_AUTO_TEST_CASE, declared with TEST_CASE_FIXTURE(Fixture, ...).
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(caseName, fixtureType, testFn, ...)                       \
    TEST_CASE_FIXTURE(fixtureType, #caseName)                                                         \
    {                                                                                                 \
        CompareRefTestFunction<FactoryType>(#caseName, &testFn, ##__VA_ARGS__);                       \
    }

// ARMNN_COMPARE_REF_FIXTURE_TEST_CASE_WITH_THF(Name, Fixture, Fn, args...): fixture variant of
// ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF, declared with TEST_CASE_FIXTURE(Fixture, ...).
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE_WITH_THF(caseName, fixtureType, testFn, ...)              \
    TEST_CASE_FIXTURE(fixtureType, #caseName)                                                         \
    {                                                                                                 \
        CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#caseName, &testFn, ##__VA_ARGS__); \
    }
192