xref: /aosp_15_r20/external/armnn/include/armnnTestUtils/TensorHelpers.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <armnnTestUtils/PredicateResult.hpp>
8 
9 #include <armnn/Tensor.hpp>
10 #include <armnn/utility/Assert.hpp>
11 #include <armnnUtils/FloatingPointComparison.hpp>
12 
13 #include <armnnUtils/QuantizeHelper.hpp>
14 
15 #include <doctest/doctest.h>
16 
17 #include <array>
18 #include <cmath>
19 #include <random>
20 #include <vector>
21 
/// Absolute tolerance used by the floating point comparer whenever one of the operands is exactly zero.
constexpr float g_FloatCloseToZeroTolerance{1e-6f};
23 
/// Comparer for quantized values: two values match when they differ by at most one quantization step.
///
/// @tparam T           Element type being compared.
/// @tparam isQuantized Selects this (quantized) comparer when true; the floating point
///                     specialisation handles the false case.
template<typename T, bool isQuantized = true>
struct SelectiveComparer
{
    /// @return true when |a - b| <= 1.
    static bool Compare(T a, T b)
    {
        const T hi = std::max(a, b);
        const T lo = std::min(a, b);

        // Computing hi - lo directly overflows signed types when the operands are far apart
        // (e.g. INT32_MAX - INT32_MIN), which is undefined behaviour and in practice wraps to a
        // small value that wrongly passes the check. When hi > lo, hi - 1 cannot overflow, and
        // "hi - 1 <= lo" is equivalent to "hi - lo <= 1".
        return (hi == lo) || (hi - 1 <= lo);
    }
};
33 
34 template<typename T>
35 struct SelectiveComparer<T, false>
36 {
CompareSelectiveComparer37     static bool Compare(T a, T b)
38     {
39         // If a or b is zero, percent_tolerance does an exact match, so compare to a small, constant tolerance instead.
40         if (a == 0.0f || b == 0.0f)
41         {
42             return std::abs(a - b) <= g_FloatCloseToZeroTolerance;
43         }
44 
45         if (std::isinf(a) && a == b)
46         {
47             return true;
48         }
49 
50         if (std::isnan(a) && std::isnan(b))
51         {
52             return true;
53         }
54 
55         // For unquantized floats we use a tolerance of 1%.
56         return armnnUtils::within_percentage_tolerance(a, b);
57     }
58 };
59 
60 template<typename T>
SelectiveCompare(T a,T b)61 bool SelectiveCompare(T a, T b)
62 {
63     return SelectiveComparer<T, armnn::IsQuantizedType<T>()>::Compare(a, b);
64 };
65 
/// Compares two values as booleans: zero is false, any other value is true.
///
/// @return true when both values are zero or both are non-zero.
template<typename T>
bool SelectiveCompareBoolean(T a, T b)
{
    const bool aIsTrue = (a != 0);
    const bool bIsTrue = (b != 0);
    return aIsTrue == bIsTrue;
}
71 
72 template <typename T>
CompareTensors(const std::vector<T> & actualData,const std::vector<T> & expectedData,const armnn::TensorShape & actualShape,const armnn::TensorShape & expectedShape,bool compareBoolean=false,bool isDynamic=false)73 armnn::PredicateResult CompareTensors(const std::vector<T>& actualData,
74                                       const std::vector<T>& expectedData,
75                                       const armnn::TensorShape& actualShape,
76                                       const armnn::TensorShape& expectedShape,
77                                       bool compareBoolean = false,
78                                       bool isDynamic = false)
79 {
80     if (actualData.size() != expectedData.size())
81     {
82         armnn::PredicateResult res(false);
83         res.Message() << "Different data size ["
84                       << actualData.size()
85                       << "!="
86                       << expectedData.size()
87                       << "]";
88         return res;
89     }
90 
91     if (actualShape.GetNumDimensions() != expectedShape.GetNumDimensions())
92     {
93         armnn::PredicateResult res(false);
94         res.Message() << "Different number of dimensions ["
95                       << actualShape.GetNumDimensions()
96                       << "!="
97                       << expectedShape.GetNumDimensions()
98                       << "]";
99         return res;
100     }
101 
102     if (actualShape.GetNumElements() != expectedShape.GetNumElements())
103     {
104         armnn::PredicateResult res(false);
105         res.Message() << "Different number of elements ["
106                       << actualShape.GetNumElements()
107                       << "!="
108                       << expectedShape.GetNumElements()
109                       << "]";
110         return res;
111     }
112 
113     unsigned int numberOfDimensions = actualShape.GetNumDimensions();
114 
115     if (!isDynamic)
116     {
117         // Checks they are same shape.
118         for (unsigned int i = 0; i < numberOfDimensions; ++i)
119         {
120             if (actualShape[i] != expectedShape[i])
121             {
122                 armnn::PredicateResult res(false);
123                 res.Message() << "Different shapes ["
124                               << actualShape[i]
125                               << "!="
126                               << expectedShape[i]
127                               << "]";
128                 return res;
129             }
130         }
131     }
132 
133     // Fun iteration over n dimensions.
134     std::vector<unsigned int> indices;
135     for (unsigned int i = 0; i < numberOfDimensions; i++)
136     {
137         indices.emplace_back(0);
138     }
139 
140     std::stringstream errorString;
141     int numFailedElements = 0;
142     constexpr int maxReportedDifferences = 3;
143     unsigned int index = 0;
144 
145     // Compare data element by element.
146     while (true)
147     {
148         bool comparison;
149         // As true for uint8_t is non-zero (1-255) we must have a dedicated compare for Booleans.
150         if(compareBoolean)
151         {
152             comparison = SelectiveCompareBoolean(actualData[index], expectedData[index]);
153         }
154         else
155         {
156             comparison = SelectiveCompare(actualData[index], expectedData[index]);
157         }
158 
159         if (!comparison)
160         {
161             ++numFailedElements;
162 
163             if (numFailedElements <= maxReportedDifferences)
164             {
165                 if (numFailedElements >= 2)
166                 {
167                     errorString << ", ";
168                 }
169                 errorString << "[";
170                 for (unsigned int i = 0; i < numberOfDimensions; ++i)
171                 {
172                     errorString << indices[i];
173                     if (i != numberOfDimensions - 1)
174                     {
175                         errorString << ",";
176                     }
177                 }
178                 errorString << "]";
179 
180                 errorString << " (" << +actualData[index] << " != " << +expectedData[index] << ")";
181             }
182         }
183 
184         ++indices[numberOfDimensions - 1];
185         for (unsigned int i=numberOfDimensions-1; i>0; i--)
186         {
187             if (indices[i] == actualShape[i])
188             {
189                 indices[i] = 0;
190                 ++indices[i - 1];
191             }
192         }
193         if (indices[0] == actualShape[0])
194         {
195             break;
196         }
197 
198         index++;
199     }
200 
201     armnn::PredicateResult comparisonResult(true);
202     if (numFailedElements > 0)
203     {
204         comparisonResult.SetResult(false);
205         comparisonResult.Message() << numFailedElements << " different values at: ";
206         if (numFailedElements > maxReportedDifferences)
207         {
208             errorString << ", ... (and " << (numFailedElements - maxReportedDifferences) << " other differences)";
209         }
210         comparisonResult.Message() << errorString.str();
211     }
212 
213     return comparisonResult;
214 }
215 
216 template <typename T>
MakeRandomTensor(const armnn::TensorInfo & tensorInfo,unsigned int seed,float min=-10.0f,float max=10.0f)217 std::vector<T> MakeRandomTensor(const armnn::TensorInfo& tensorInfo,
218                                 unsigned int seed,
219                                 float        min = -10.0f,
220                                 float        max = 10.0f)
221 {
222     std::mt19937 gen(seed);
223     std::uniform_real_distribution<float> dist(min, max);
224 
225     std::vector<float> init(tensorInfo.GetNumElements());
226     for (unsigned int i = 0; i < init.size(); i++)
227     {
228         init[i] = dist(gen);
229     }
230 
231     const float   qScale  = tensorInfo.GetQuantizationScale();
232     const int32_t qOffset = tensorInfo.GetQuantizationOffset();
233 
234     return armnnUtils::QuantizedVector<T>(init, qScale, qOffset);
235 }
236