xref: /aosp_15_r20/external/armnn/include/armnnTestUtils/TensorHelpers.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/PredicateResult.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/FloatingPointComparison.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/QuantizeHelper.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker #include <array>
18*89c4ff92SAndroid Build Coastguard Worker #include <cmath>
19*89c4ff92SAndroid Build Coastguard Worker #include <random>
20*89c4ff92SAndroid Build Coastguard Worker #include <vector>
21*89c4ff92SAndroid Build Coastguard Worker 
// Absolute tolerance used by the floating-point comparer when either operand is exactly zero,
// because a percentage tolerance degenerates into an exact match in that case.
constexpr float g_FloatCloseToZeroTolerance = 1e-6f;
23*89c4ff92SAndroid Build Coastguard Worker 
// Element comparer for quantized (integral) types: two values match when they are equal or
// differ by exactly one quantization step, to absorb rounding differences between backends.
//
// The floating-point specialization (isQuantized == false) is defined separately.
template<typename T, bool isQuantized = true>
struct SelectiveComparer
{
    // Returns true when |a - b| <= 1.
    // Computing max(a, b) - min(a, b) directly can overflow for wide signed types
    // (e.g. INT32_MAX vs INT32_MIN), which is undefined behaviour. Instead, step the larger value
    // down by one. It is strictly greater than the smaller value, so it can never underflow.
    static bool Compare(T a, T b)
    {
        if (a == b)
        {
            return true;
        }
        return (a > b) ? (a - 1 == b) : (b - 1 == a);
    }
};
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker template<typename T>
35*89c4ff92SAndroid Build Coastguard Worker struct SelectiveComparer<T, false>
36*89c4ff92SAndroid Build Coastguard Worker {
CompareSelectiveComparer37*89c4ff92SAndroid Build Coastguard Worker     static bool Compare(T a, T b)
38*89c4ff92SAndroid Build Coastguard Worker     {
39*89c4ff92SAndroid Build Coastguard Worker         // If a or b is zero, percent_tolerance does an exact match, so compare to a small, constant tolerance instead.
40*89c4ff92SAndroid Build Coastguard Worker         if (a == 0.0f || b == 0.0f)
41*89c4ff92SAndroid Build Coastguard Worker         {
42*89c4ff92SAndroid Build Coastguard Worker             return std::abs(a - b) <= g_FloatCloseToZeroTolerance;
43*89c4ff92SAndroid Build Coastguard Worker         }
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker         if (std::isinf(a) && a == b)
46*89c4ff92SAndroid Build Coastguard Worker         {
47*89c4ff92SAndroid Build Coastguard Worker             return true;
48*89c4ff92SAndroid Build Coastguard Worker         }
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker         if (std::isnan(a) && std::isnan(b))
51*89c4ff92SAndroid Build Coastguard Worker         {
52*89c4ff92SAndroid Build Coastguard Worker             return true;
53*89c4ff92SAndroid Build Coastguard Worker         }
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker         // For unquantized floats we use a tolerance of 1%.
56*89c4ff92SAndroid Build Coastguard Worker         return armnnUtils::within_percentage_tolerance(a, b);
57*89c4ff92SAndroid Build Coastguard Worker     }
58*89c4ff92SAndroid Build Coastguard Worker };
59*89c4ff92SAndroid Build Coastguard Worker 
60*89c4ff92SAndroid Build Coastguard Worker template<typename T>
SelectiveCompare(T a,T b)61*89c4ff92SAndroid Build Coastguard Worker bool SelectiveCompare(T a, T b)
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker     return SelectiveComparer<T, armnn::IsQuantizedType<T>()>::Compare(a, b);
64*89c4ff92SAndroid Build Coastguard Worker };
65*89c4ff92SAndroid Build Coastguard Worker 
// Compares two elements as booleans: they match when both are zero or both are non-zero.
// Needed because e.g. a uint8_t "true" may be any value in 1-255.
template<typename T>
bool SelectiveCompareBoolean(T a, T b)
{
    const bool aIsTrue = (a != 0);
    const bool bIsTrue = (b != 0);
    return aIsTrue == bIsTrue;
}
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker template <typename T>
CompareTensors(const std::vector<T> & actualData,const std::vector<T> & expectedData,const armnn::TensorShape & actualShape,const armnn::TensorShape & expectedShape,bool compareBoolean=false,bool isDynamic=false)73*89c4ff92SAndroid Build Coastguard Worker armnn::PredicateResult CompareTensors(const std::vector<T>& actualData,
74*89c4ff92SAndroid Build Coastguard Worker                                       const std::vector<T>& expectedData,
75*89c4ff92SAndroid Build Coastguard Worker                                       const armnn::TensorShape& actualShape,
76*89c4ff92SAndroid Build Coastguard Worker                                       const armnn::TensorShape& expectedShape,
77*89c4ff92SAndroid Build Coastguard Worker                                       bool compareBoolean = false,
78*89c4ff92SAndroid Build Coastguard Worker                                       bool isDynamic = false)
79*89c4ff92SAndroid Build Coastguard Worker {
80*89c4ff92SAndroid Build Coastguard Worker     if (actualData.size() != expectedData.size())
81*89c4ff92SAndroid Build Coastguard Worker     {
82*89c4ff92SAndroid Build Coastguard Worker         armnn::PredicateResult res(false);
83*89c4ff92SAndroid Build Coastguard Worker         res.Message() << "Different data size ["
84*89c4ff92SAndroid Build Coastguard Worker                       << actualData.size()
85*89c4ff92SAndroid Build Coastguard Worker                       << "!="
86*89c4ff92SAndroid Build Coastguard Worker                       << expectedData.size()
87*89c4ff92SAndroid Build Coastguard Worker                       << "]";
88*89c4ff92SAndroid Build Coastguard Worker         return res;
89*89c4ff92SAndroid Build Coastguard Worker     }
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     if (actualShape.GetNumDimensions() != expectedShape.GetNumDimensions())
92*89c4ff92SAndroid Build Coastguard Worker     {
93*89c4ff92SAndroid Build Coastguard Worker         armnn::PredicateResult res(false);
94*89c4ff92SAndroid Build Coastguard Worker         res.Message() << "Different number of dimensions ["
95*89c4ff92SAndroid Build Coastguard Worker                       << actualShape.GetNumDimensions()
96*89c4ff92SAndroid Build Coastguard Worker                       << "!="
97*89c4ff92SAndroid Build Coastguard Worker                       << expectedShape.GetNumDimensions()
98*89c4ff92SAndroid Build Coastguard Worker                       << "]";
99*89c4ff92SAndroid Build Coastguard Worker         return res;
100*89c4ff92SAndroid Build Coastguard Worker     }
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker     if (actualShape.GetNumElements() != expectedShape.GetNumElements())
103*89c4ff92SAndroid Build Coastguard Worker     {
104*89c4ff92SAndroid Build Coastguard Worker         armnn::PredicateResult res(false);
105*89c4ff92SAndroid Build Coastguard Worker         res.Message() << "Different number of elements ["
106*89c4ff92SAndroid Build Coastguard Worker                       << actualShape.GetNumElements()
107*89c4ff92SAndroid Build Coastguard Worker                       << "!="
108*89c4ff92SAndroid Build Coastguard Worker                       << expectedShape.GetNumElements()
109*89c4ff92SAndroid Build Coastguard Worker                       << "]";
110*89c4ff92SAndroid Build Coastguard Worker         return res;
111*89c4ff92SAndroid Build Coastguard Worker     }
112*89c4ff92SAndroid Build Coastguard Worker 
113*89c4ff92SAndroid Build Coastguard Worker     unsigned int numberOfDimensions = actualShape.GetNumDimensions();
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker     if (!isDynamic)
116*89c4ff92SAndroid Build Coastguard Worker     {
117*89c4ff92SAndroid Build Coastguard Worker         // Checks they are same shape.
118*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < numberOfDimensions; ++i)
119*89c4ff92SAndroid Build Coastguard Worker         {
120*89c4ff92SAndroid Build Coastguard Worker             if (actualShape[i] != expectedShape[i])
121*89c4ff92SAndroid Build Coastguard Worker             {
122*89c4ff92SAndroid Build Coastguard Worker                 armnn::PredicateResult res(false);
123*89c4ff92SAndroid Build Coastguard Worker                 res.Message() << "Different shapes ["
124*89c4ff92SAndroid Build Coastguard Worker                               << actualShape[i]
125*89c4ff92SAndroid Build Coastguard Worker                               << "!="
126*89c4ff92SAndroid Build Coastguard Worker                               << expectedShape[i]
127*89c4ff92SAndroid Build Coastguard Worker                               << "]";
128*89c4ff92SAndroid Build Coastguard Worker                 return res;
129*89c4ff92SAndroid Build Coastguard Worker             }
130*89c4ff92SAndroid Build Coastguard Worker         }
131*89c4ff92SAndroid Build Coastguard Worker     }
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker     // Fun iteration over n dimensions.
134*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> indices;
135*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < numberOfDimensions; i++)
136*89c4ff92SAndroid Build Coastguard Worker     {
137*89c4ff92SAndroid Build Coastguard Worker         indices.emplace_back(0);
138*89c4ff92SAndroid Build Coastguard Worker     }
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker     std::stringstream errorString;
141*89c4ff92SAndroid Build Coastguard Worker     int numFailedElements = 0;
142*89c4ff92SAndroid Build Coastguard Worker     constexpr int maxReportedDifferences = 3;
143*89c4ff92SAndroid Build Coastguard Worker     unsigned int index = 0;
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker     // Compare data element by element.
146*89c4ff92SAndroid Build Coastguard Worker     while (true)
147*89c4ff92SAndroid Build Coastguard Worker     {
148*89c4ff92SAndroid Build Coastguard Worker         bool comparison;
149*89c4ff92SAndroid Build Coastguard Worker         // As true for uint8_t is non-zero (1-255) we must have a dedicated compare for Booleans.
150*89c4ff92SAndroid Build Coastguard Worker         if(compareBoolean)
151*89c4ff92SAndroid Build Coastguard Worker         {
152*89c4ff92SAndroid Build Coastguard Worker             comparison = SelectiveCompareBoolean(actualData[index], expectedData[index]);
153*89c4ff92SAndroid Build Coastguard Worker         }
154*89c4ff92SAndroid Build Coastguard Worker         else
155*89c4ff92SAndroid Build Coastguard Worker         {
156*89c4ff92SAndroid Build Coastguard Worker             comparison = SelectiveCompare(actualData[index], expectedData[index]);
157*89c4ff92SAndroid Build Coastguard Worker         }
158*89c4ff92SAndroid Build Coastguard Worker 
159*89c4ff92SAndroid Build Coastguard Worker         if (!comparison)
160*89c4ff92SAndroid Build Coastguard Worker         {
161*89c4ff92SAndroid Build Coastguard Worker             ++numFailedElements;
162*89c4ff92SAndroid Build Coastguard Worker 
163*89c4ff92SAndroid Build Coastguard Worker             if (numFailedElements <= maxReportedDifferences)
164*89c4ff92SAndroid Build Coastguard Worker             {
165*89c4ff92SAndroid Build Coastguard Worker                 if (numFailedElements >= 2)
166*89c4ff92SAndroid Build Coastguard Worker                 {
167*89c4ff92SAndroid Build Coastguard Worker                     errorString << ", ";
168*89c4ff92SAndroid Build Coastguard Worker                 }
169*89c4ff92SAndroid Build Coastguard Worker                 errorString << "[";
170*89c4ff92SAndroid Build Coastguard Worker                 for (unsigned int i = 0; i < numberOfDimensions; ++i)
171*89c4ff92SAndroid Build Coastguard Worker                 {
172*89c4ff92SAndroid Build Coastguard Worker                     errorString << indices[i];
173*89c4ff92SAndroid Build Coastguard Worker                     if (i != numberOfDimensions - 1)
174*89c4ff92SAndroid Build Coastguard Worker                     {
175*89c4ff92SAndroid Build Coastguard Worker                         errorString << ",";
176*89c4ff92SAndroid Build Coastguard Worker                     }
177*89c4ff92SAndroid Build Coastguard Worker                 }
178*89c4ff92SAndroid Build Coastguard Worker                 errorString << "]";
179*89c4ff92SAndroid Build Coastguard Worker 
180*89c4ff92SAndroid Build Coastguard Worker                 errorString << " (" << +actualData[index] << " != " << +expectedData[index] << ")";
181*89c4ff92SAndroid Build Coastguard Worker             }
182*89c4ff92SAndroid Build Coastguard Worker         }
183*89c4ff92SAndroid Build Coastguard Worker 
184*89c4ff92SAndroid Build Coastguard Worker         ++indices[numberOfDimensions - 1];
185*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i=numberOfDimensions-1; i>0; i--)
186*89c4ff92SAndroid Build Coastguard Worker         {
187*89c4ff92SAndroid Build Coastguard Worker             if (indices[i] == actualShape[i])
188*89c4ff92SAndroid Build Coastguard Worker             {
189*89c4ff92SAndroid Build Coastguard Worker                 indices[i] = 0;
190*89c4ff92SAndroid Build Coastguard Worker                 ++indices[i - 1];
191*89c4ff92SAndroid Build Coastguard Worker             }
192*89c4ff92SAndroid Build Coastguard Worker         }
193*89c4ff92SAndroid Build Coastguard Worker         if (indices[0] == actualShape[0])
194*89c4ff92SAndroid Build Coastguard Worker         {
195*89c4ff92SAndroid Build Coastguard Worker             break;
196*89c4ff92SAndroid Build Coastguard Worker         }
197*89c4ff92SAndroid Build Coastguard Worker 
198*89c4ff92SAndroid Build Coastguard Worker         index++;
199*89c4ff92SAndroid Build Coastguard Worker     }
200*89c4ff92SAndroid Build Coastguard Worker 
201*89c4ff92SAndroid Build Coastguard Worker     armnn::PredicateResult comparisonResult(true);
202*89c4ff92SAndroid Build Coastguard Worker     if (numFailedElements > 0)
203*89c4ff92SAndroid Build Coastguard Worker     {
204*89c4ff92SAndroid Build Coastguard Worker         comparisonResult.SetResult(false);
205*89c4ff92SAndroid Build Coastguard Worker         comparisonResult.Message() << numFailedElements << " different values at: ";
206*89c4ff92SAndroid Build Coastguard Worker         if (numFailedElements > maxReportedDifferences)
207*89c4ff92SAndroid Build Coastguard Worker         {
208*89c4ff92SAndroid Build Coastguard Worker             errorString << ", ... (and " << (numFailedElements - maxReportedDifferences) << " other differences)";
209*89c4ff92SAndroid Build Coastguard Worker         }
210*89c4ff92SAndroid Build Coastguard Worker         comparisonResult.Message() << errorString.str();
211*89c4ff92SAndroid Build Coastguard Worker     }
212*89c4ff92SAndroid Build Coastguard Worker 
213*89c4ff92SAndroid Build Coastguard Worker     return comparisonResult;
214*89c4ff92SAndroid Build Coastguard Worker }
215*89c4ff92SAndroid Build Coastguard Worker 
216*89c4ff92SAndroid Build Coastguard Worker template <typename T>
MakeRandomTensor(const armnn::TensorInfo & tensorInfo,unsigned int seed,float min=-10.0f,float max=10.0f)217*89c4ff92SAndroid Build Coastguard Worker std::vector<T> MakeRandomTensor(const armnn::TensorInfo& tensorInfo,
218*89c4ff92SAndroid Build Coastguard Worker                                 unsigned int seed,
219*89c4ff92SAndroid Build Coastguard Worker                                 float        min = -10.0f,
220*89c4ff92SAndroid Build Coastguard Worker                                 float        max = 10.0f)
221*89c4ff92SAndroid Build Coastguard Worker {
222*89c4ff92SAndroid Build Coastguard Worker     std::mt19937 gen(seed);
223*89c4ff92SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> dist(min, max);
224*89c4ff92SAndroid Build Coastguard Worker 
225*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> init(tensorInfo.GetNumElements());
226*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < init.size(); i++)
227*89c4ff92SAndroid Build Coastguard Worker     {
228*89c4ff92SAndroid Build Coastguard Worker         init[i] = dist(gen);
229*89c4ff92SAndroid Build Coastguard Worker     }
230*89c4ff92SAndroid Build Coastguard Worker 
231*89c4ff92SAndroid Build Coastguard Worker     const float   qScale  = tensorInfo.GetQuantizationScale();
232*89c4ff92SAndroid Build Coastguard Worker     const int32_t qOffset = tensorInfo.GetQuantizationOffset();
233*89c4ff92SAndroid Build Coastguard Worker 
234*89c4ff92SAndroid Build Coastguard Worker     return armnnUtils::QuantizedVector<T>(init, qScale, qOffset);
235*89c4ff92SAndroid Build Coastguard Worker }
236