xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeComparison.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersSerializeFixture.hpp"
7 #include <armnnDeserializer/IDeserializer.hpp>
8 
9 #include <armnnUtils/QuantizeHelper.hpp>
10 #include <ResolveType.hpp>
11 
12 #include <string>
13 
14 TEST_SUITE("Deserializer_Comparison")
15 {
// Declares `Simple<operation><dataType>Fixture`, a default-constructible
// SimpleComparisonFixture that forwards the stringized data-type and operation
// names (e.g. "Float32", "Equal") to the JSON network template.
#define DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType)                  \
struct Simple##operation##dataType##Fixture : public SimpleComparisonFixture    \
{                                                                               \
    Simple##operation##dataType##Fixture()                                      \
        : SimpleComparisonFixture(#dataType, #operation)                        \
    {                                                                           \
    }                                                                           \
};
22 
// Declares the fixture for (operation, dataType) plus a doctest case named
// "<operation><dataType>" (e.g. "EqualFloat32"). The case quantizes both shared
// float inputs to the requested data type using scale 1 and offset 0, runs the
// deserialized network, and checks the Boolean output against
// s_TestData.m_Output<operation>.
#define DECLARE_SIMPLE_COMPARISON_TEST_CASE(operation, dataType)                               \
DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType)                                         \
TEST_CASE_FIXTURE(Simple##operation##dataType##Fixture, #operation #dataType)                  \
{                                                                                              \
    using InputElementType = armnn::ResolveType<armnn::DataType::dataType>;                    \
    constexpr float   quantScale  = 1.0f;                                                      \
    constexpr int32_t quantOffset = 0;                                                         \
    const auto quantizedInput0 =                                                               \
        armnnUtils::QuantizedVector<InputElementType>(s_TestData.m_InputData0,                 \
                                                      quantScale, quantOffset);                \
    const auto quantizedInput1 =                                                               \
        armnnUtils::QuantizedVector<InputElementType>(s_TestData.m_InputData1,                 \
                                                      quantScale, quantOffset);                \
    RunTest<4, armnn::DataType::dataType, armnn::DataType::Boolean>(                           \
        0,                                                                                     \
        { { "InputLayer0", quantizedInput0 },                                                  \
          { "InputLayer1", quantizedInput1 } },                                                \
        { { "OutputLayer", s_TestData.m_Output##operation } });                                \
}
36 
// Test fixture that describes a four-layer network as JSON text and then calls
// Setup(). The layer indices are:
//   0: InputLayer0 (binding id 0)
//   1: InputLayer1 (binding id 1)
//   2: ComparisonLayer, fed by layers 0 and 1
//   3: OutputLayer (binding id 0), fed by layer 2
// Presumably the JSON is the flatbuffers text form of the ArmNN serializer
// schema consumed by ParserFlatbuffersSerializeFixture. TODO: confirm.
struct ComparisonFixture : public ParserFlatbuffersSerializeFixture
{
    // inputShape0 / inputShape1 / outputShape: spliced verbatim into the JSON as
    //     `dimensions` values, e.g. "[ 2, 2, 2, 2 ]".
    // inputDataType: `dataType` of both input tensors, e.g. "Float32" or "QAsymmU8".
    // comparisonOperation: the Comparison descriptor's `operation`, e.g. "Equal".
    // The comparison and output tensors are always declared as Boolean.
    explicit ComparisonFixture(const std::string& inputShape0,
                               const std::string& inputShape1,
                               const std::string& outputShape,
                               const std::string& inputDataType,
                               const std::string& comparisonOperation)
    {
        // NOTE(review): both InputLayer entries declare an input slot connected to
        // layer 0 / slot 0. Presumably the deserializer ignores input slots on
        // input layers; confirm before relying on this wiring.
        m_JsonString = R"(
            {
                inputIds: [0, 1],
                outputIds: [3],
                layers: [
                    {
                        layer_type: "InputLayer",
                        layer: {
                            base: {
                                layerBindingId: 0,
                                base: {
                                    index: 0,
                                    layerName: "InputLayer0",
                                    layerType: "Input",
                                    inputSlots: [{
                                        index: 0,
                                        connection: { sourceLayerIndex:0, outputSlotIndex:0 },
                                    }],
                                    outputSlots: [{
                                        index: 0,
                                        tensorInfo: {
                                            dimensions: )" + inputShape0 + R"(,
                                            dataType: )" + inputDataType + R"(
                                        },
                                    }],
                                },
                            }
                        },
                    },
                    {
                        layer_type: "InputLayer",
                        layer: {
                            base: {
                                layerBindingId: 1,
                                base: {
                                      index:1,
                                      layerName: "InputLayer1",
                                      layerType: "Input",
                                      inputSlots: [{
                                          index: 0,
                                          connection: { sourceLayerIndex:0, outputSlotIndex:0 },
                                      }],
                                      outputSlots: [{
                                          index: 0,
                                          tensorInfo: {
                                              dimensions: )" + inputShape1 + R"(,
                                              dataType: )" + inputDataType + R"(
                                          },
                                      }],
                                },
                            }
                        },
                    },
                    {
                        layer_type: "ComparisonLayer",
                        layer: {
                            base: {
                                 index:2,
                                 layerName: "ComparisonLayer",
                                 layerType: "Comparison",
                                 inputSlots: [{
                                     index: 0,
                                     connection: { sourceLayerIndex:0, outputSlotIndex:0 },
                                 },
                                 {
                                     index: 1,
                                     connection: { sourceLayerIndex:1, outputSlotIndex:0 },
                                 }],
                                 outputSlots: [{
                                     index: 0,
                                     tensorInfo: {
                                         dimensions: )" + outputShape + R"(,
                                         dataType: Boolean
                                     },
                                 }],
                            },
                            descriptor: {
                                operation: )" + comparisonOperation + R"(
                            }
                        },
                    },
                    {
                        layer_type: "OutputLayer",
                        layer: {
                            base:{
                                layerBindingId: 0,
                                base: {
                                    index: 3,
                                    layerName: "OutputLayer",
                                    layerType: "Output",
                                    inputSlots: [{
                                        index: 0,
                                        connection: { sourceLayerIndex:2, outputSlotIndex:0 },
                                    }],
                                    outputSlots: [{
                                        index: 0,
                                        tensorInfo: {
                                            dimensions: )" + outputShape + R"(,
                                            dataType: Boolean
                                        },
                                    }],
                                }
                            }
                        },
                    }
                ]
            }
        )";
        // Setup() comes from ParserFlatbuffersSerializeFixture. Presumably it
        // parses m_JsonString and loads the resulting network; confirm there.
        Setup();
    }
};
156 
// Shared input and expected-output tables for the 2x2x2x2 comparison tests.
//
// Each 16-element input is four runs of four identical values, so every run
// exercises one relationship between input0 and input1:
//   run 0: 1 vs 1 (equal)     run 1: 5 vs 3 (input0 greater)
//   run 2: 3 vs 5 (less)      run 3: 4 vs 4 (equal)
// Each m_Output<Op> holds the expected Boolean (1 or 0) result of
// `input0 <Op> input1` for every element.
struct SimpleComparisonTestData
{
    SimpleComparisonTestData()
        : m_InputData0{ 1.f, 1.f, 1.f, 1.f,   5.f, 5.f, 5.f, 5.f,
                        3.f, 3.f, 3.f, 3.f,   4.f, 4.f, 4.f, 4.f }
        , m_InputData1{ 1.f, 1.f, 1.f, 1.f,   3.f, 3.f, 3.f, 3.f,
                        5.f, 5.f, 5.f, 5.f,   4.f, 4.f, 4.f, 4.f }
        // Runs:                 eq          gt          lt          eq
        , m_OutputEqual         { 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1 }
        , m_OutputGreater       { 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 }
        , m_OutputGreaterOrEqual{ 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1 }
        , m_OutputLess          { 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0 }
        , m_OutputLessOrEqual   { 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1 }
        , m_OutputNotEqual      { 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0 }
    {
    }

    // Float inputs; the test cases quantize these per data type.
    std::vector<float> m_InputData0;
    std::vector<float> m_InputData1;

    // Expected Boolean outputs, one table per comparison operation.
    std::vector<uint8_t> m_OutputEqual;
    std::vector<uint8_t> m_OutputGreater;
    std::vector<uint8_t> m_OutputGreaterOrEqual;
    std::vector<uint8_t> m_OutputLess;
    std::vector<uint8_t> m_OutputLessOrEqual;
    std::vector<uint8_t> m_OutputNotEqual;
};
220 
221 struct SimpleComparisonFixture : public ComparisonFixture
222 {
SimpleComparisonFixtureSimpleComparisonFixture223     SimpleComparisonFixture(const std::string& inputDataType,
224                             const std::string& comparisonOperation)
225         : ComparisonFixture("[ 2, 2, 2, 2 ]", // inputShape0
226                             "[ 2, 2, 2, 2 ]", // inputShape1
227                             "[ 2, 2, 2, 2 ]", // outputShape,
228                             inputDataType,
229                             comparisonOperation) {}
230 
231     static SimpleComparisonTestData s_TestData;
232 };
233 
234 SimpleComparisonTestData SimpleComparisonFixture::s_TestData;
235 
// Float32 cases. Each line declares Simple<Op>Float32Fixture and a test case
// named "<Op>Float32" that checks the network output against
// SimpleComparisonTestData::m_Output<Op>.
DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal,          Float32)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater,        Float32)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, Float32)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less,           Float32)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual,    Float32)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual,       Float32)


// QAsymmU8 cases: the same float data, quantized with scale 1 / offset 0
// (presumably exact for these small integral values; confirm against
// armnnUtils::QuantizedVector), and the same expected Boolean tables.
DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal,          QAsymmU8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater,        QAsymmU8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, QAsymmU8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less,           QAsymmU8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual,    QAsymmU8)
DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual,       QAsymmU8)
250 
251 }
252