1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "TestUtils.hpp"
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
9*89c4ff92SAndroid Build Coastguard Worker {
10*89c4ff92SAndroid Build Coastguard Worker
CompareData(bool tensor1[],bool tensor2[],size_t tensorSize)11*89c4ff92SAndroid Build Coastguard Worker void CompareData(bool tensor1[], bool tensor2[], size_t tensorSize)
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker auto compareBool = [](auto a, auto b) {return (((a == 0) && (b == 0)) || ((a != 0) && (b != 0)));};
14*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < tensorSize; i++)
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker CHECK(compareBool(tensor1[i], tensor2[i]));
17*89c4ff92SAndroid Build Coastguard Worker }
18*89c4ff92SAndroid Build Coastguard Worker }
19*89c4ff92SAndroid Build Coastguard Worker
CompareData(std::vector<bool> & tensor1,std::vector<bool> & tensor2,size_t tensorSize)20*89c4ff92SAndroid Build Coastguard Worker void CompareData(std::vector<bool>& tensor1, std::vector<bool>& tensor2, size_t tensorSize)
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker auto compareBool = [](auto a, auto b) {return (((a == 0) && (b == 0)) || ((a != 0) && (b != 0)));};
23*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < tensorSize; i++)
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker CHECK(compareBool(tensor1[i], tensor2[i]));
26*89c4ff92SAndroid Build Coastguard Worker }
27*89c4ff92SAndroid Build Coastguard Worker }
28*89c4ff92SAndroid Build Coastguard Worker
CompareData(float tensor1[],float tensor2[],size_t tensorSize)29*89c4ff92SAndroid Build Coastguard Worker void CompareData(float tensor1[], float tensor2[], size_t tensorSize)
30*89c4ff92SAndroid Build Coastguard Worker {
31*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < tensorSize; i++)
32*89c4ff92SAndroid Build Coastguard Worker {
33*89c4ff92SAndroid Build Coastguard Worker CHECK(tensor1[i] == doctest::Approx( tensor2[i] ));
34*89c4ff92SAndroid Build Coastguard Worker }
35*89c4ff92SAndroid Build Coastguard Worker }
36*89c4ff92SAndroid Build Coastguard Worker
CompareData(float tensor1[],float tensor2[],size_t tensorSize,float percentTolerance)37*89c4ff92SAndroid Build Coastguard Worker void CompareData(float tensor1[], float tensor2[], size_t tensorSize, float percentTolerance)
38*89c4ff92SAndroid Build Coastguard Worker {
39*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < tensorSize; i++)
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker CHECK(std::max(tensor1[i], tensor2[i]) - std::min(tensor1[i], tensor2[i]) <=
42*89c4ff92SAndroid Build Coastguard Worker std::abs(tensor1[i]*percentTolerance/100));
43*89c4ff92SAndroid Build Coastguard Worker }
44*89c4ff92SAndroid Build Coastguard Worker }
45*89c4ff92SAndroid Build Coastguard Worker
CompareData(uint8_t tensor1[],uint8_t tensor2[],size_t tensorSize)46*89c4ff92SAndroid Build Coastguard Worker void CompareData(uint8_t tensor1[], uint8_t tensor2[], size_t tensorSize)
47*89c4ff92SAndroid Build Coastguard Worker {
48*89c4ff92SAndroid Build Coastguard Worker uint8_t tolerance = 1;
49*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < tensorSize; i++)
50*89c4ff92SAndroid Build Coastguard Worker {
51*89c4ff92SAndroid Build Coastguard Worker CHECK(std::max(tensor1[i], tensor2[i]) - std::min(tensor1[i], tensor2[i]) <= tolerance);
52*89c4ff92SAndroid Build Coastguard Worker }
53*89c4ff92SAndroid Build Coastguard Worker }
54*89c4ff92SAndroid Build Coastguard Worker
CompareData(int16_t tensor1[],int16_t tensor2[],size_t tensorSize)55*89c4ff92SAndroid Build Coastguard Worker void CompareData(int16_t tensor1[], int16_t tensor2[], size_t tensorSize)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker int16_t tolerance = 1;
58*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < tensorSize; i++)
59*89c4ff92SAndroid Build Coastguard Worker {
60*89c4ff92SAndroid Build Coastguard Worker CHECK(std::max(tensor1[i], tensor2[i]) - std::min(tensor1[i], tensor2[i]) <= tolerance);
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker
CompareData(int32_t tensor1[],int32_t tensor2[],size_t tensorSize)64*89c4ff92SAndroid Build Coastguard Worker void CompareData(int32_t tensor1[], int32_t tensor2[], size_t tensorSize)
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker int32_t tolerance = 1;
67*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < tensorSize; i++)
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker CHECK(std::max(tensor1[i], tensor2[i]) - std::min(tensor1[i], tensor2[i]) <= tolerance);
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker }
72*89c4ff92SAndroid Build Coastguard Worker
CompareData(int8_t tensor1[],int8_t tensor2[],size_t tensorSize)73*89c4ff92SAndroid Build Coastguard Worker void CompareData(int8_t tensor1[], int8_t tensor2[], size_t tensorSize)
74*89c4ff92SAndroid Build Coastguard Worker {
75*89c4ff92SAndroid Build Coastguard Worker int8_t tolerance = 1;
76*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < tensorSize; i++)
77*89c4ff92SAndroid Build Coastguard Worker {
78*89c4ff92SAndroid Build Coastguard Worker CHECK(std::max(tensor1[i], tensor2[i]) - std::min(tensor1[i], tensor2[i]) <= tolerance);
79*89c4ff92SAndroid Build Coastguard Worker }
80*89c4ff92SAndroid Build Coastguard Worker }
81*89c4ff92SAndroid Build Coastguard Worker
CompareData(Half tensor1[],Half tensor2[],size_t tensorSize)82*89c4ff92SAndroid Build Coastguard Worker void CompareData(Half tensor1[], Half tensor2[], size_t tensorSize)
83*89c4ff92SAndroid Build Coastguard Worker {
84*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < tensorSize; i++)
85*89c4ff92SAndroid Build Coastguard Worker {
86*89c4ff92SAndroid Build Coastguard Worker CHECK(tensor1[i] == doctest::Approx( tensor2[i] ));
87*89c4ff92SAndroid Build Coastguard Worker }
88*89c4ff92SAndroid Build Coastguard Worker }
89*89c4ff92SAndroid Build Coastguard Worker
CompareData(TfLiteFloat16 tensor1[],TfLiteFloat16 tensor2[],size_t tensorSize)90*89c4ff92SAndroid Build Coastguard Worker void CompareData(TfLiteFloat16 tensor1[], TfLiteFloat16 tensor2[], size_t tensorSize)
91*89c4ff92SAndroid Build Coastguard Worker {
92*89c4ff92SAndroid Build Coastguard Worker uint16_t tolerance = 1;
93*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < tensorSize; i++)
94*89c4ff92SAndroid Build Coastguard Worker {
95*89c4ff92SAndroid Build Coastguard Worker uint16_t tensor1Data = tensor1[i].data;
96*89c4ff92SAndroid Build Coastguard Worker uint16_t tensor2Data = tensor2[i].data;
97*89c4ff92SAndroid Build Coastguard Worker CHECK(std::max(tensor1Data, tensor2Data) - std::min(tensor1Data, tensor2Data) <= tolerance);
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker
CompareData(TfLiteFloat16 tensor1[],Half tensor2[],size_t tensorSize)101*89c4ff92SAndroid Build Coastguard Worker void CompareData(TfLiteFloat16 tensor1[], Half tensor2[], size_t tensorSize) {
102*89c4ff92SAndroid Build Coastguard Worker uint16_t tolerance = 1;
103*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < tensorSize; i++)
104*89c4ff92SAndroid Build Coastguard Worker {
105*89c4ff92SAndroid Build Coastguard Worker uint16_t tensor1Data = tensor1[i].data;
106*89c4ff92SAndroid Build Coastguard Worker uint16_t tensor2Data = half_float::detail::float2half<std::round_indeterminate, float>(tensor2[i]);
107*89c4ff92SAndroid Build Coastguard Worker CHECK(std::max(tensor1Data, tensor2Data) - std::min(tensor1Data, tensor2Data) <= tolerance);
108*89c4ff92SAndroid Build Coastguard Worker }
109*89c4ff92SAndroid Build Coastguard Worker }
110*89c4ff92SAndroid Build Coastguard Worker
CompareOutputShape(const std::vector<int32_t> & tfLiteDelegateShape,const std::vector<int32_t> & armnnDelegateShape,const std::vector<int32_t> & expectedOutputShape)111*89c4ff92SAndroid Build Coastguard Worker void CompareOutputShape(const std::vector<int32_t>& tfLiteDelegateShape,
112*89c4ff92SAndroid Build Coastguard Worker const std::vector<int32_t>& armnnDelegateShape,
113*89c4ff92SAndroid Build Coastguard Worker const std::vector<int32_t>& expectedOutputShape)
114*89c4ff92SAndroid Build Coastguard Worker {
115*89c4ff92SAndroid Build Coastguard Worker CHECK(expectedOutputShape.size() == tfLiteDelegateShape.size());
116*89c4ff92SAndroid Build Coastguard Worker CHECK(expectedOutputShape.size() == armnnDelegateShape.size());
117*89c4ff92SAndroid Build Coastguard Worker
118*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < expectedOutputShape.size(); i++)
119*89c4ff92SAndroid Build Coastguard Worker {
120*89c4ff92SAndroid Build Coastguard Worker CHECK(expectedOutputShape[i] == armnnDelegateShape[i]);
121*89c4ff92SAndroid Build Coastguard Worker CHECK(tfLiteDelegateShape[i] == expectedOutputShape[i]);
122*89c4ff92SAndroid Build Coastguard Worker CHECK(tfLiteDelegateShape[i] == armnnDelegateShape[i]);
123*89c4ff92SAndroid Build Coastguard Worker }
124*89c4ff92SAndroid Build Coastguard Worker }
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate