xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Prelu.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <doctest/doctest.h>
7 #include "ParserFlatbuffersFixture.hpp"
8 
9 
10 TEST_SUITE("TensorflowLiteParser_Prelu")
11 {
12 struct PreluFixture : public ParserFlatbuffersFixture
13 {
PreluFixturePreluFixture14     explicit PreluFixture(const std::string& inputShape,
15                           const std::string& alphaShape,
16                           const std::string& outputShape,
17                           const std::string& inputIndex,
18                           const std::string& alphaData)
19     {
20         m_JsonString = R"(
21             {
22               "version": 3,
23               "operator_codes": [
24                 {
25                   "builtin_code": "PRELU",
26                   "version": 1
27                 }
28               ],
29               "subgraphs": [
30                 {
31                   "tensors": [
32                     {
33                       "shape": )" + inputShape + R"(,
34                       "type": "FLOAT32",
35                       "buffer": 1,
36                       "name": "input0",
37                       "quantization": {
38                         "details_type": "NONE",
39                         "quantized_dimension": 0
40                       },
41                       "is_variable": false
42                     },
43                     {
44                       "shape": )" + alphaShape + R"(,
45                       "type": "FLOAT32",
46                       "buffer": 2,
47                       "name": "input1",
48                       "quantization": {
49                         "details_type": "NONE",
50                         "quantized_dimension": 0
51                       },
52                       "is_variable": false
53                     },
54                     {
55                       "shape": )" + outputShape + R"(,
56                       "type": "FLOAT32",
57                       "buffer": 3,
58                       "name": "output",
59                       "quantization": {
60                         "details_type": "NONE",
61                         "quantized_dimension": 0
62                       },
63                       "is_variable": false
64                     }
65                   ],
66                   "inputs": )" + inputIndex + R"(,
67                   "outputs": [
68                     2
69                   ],
70                   "operators": [
71                     {
72                       "opcode_index": 0,
73                       "inputs": [
74                         0,
75                         1
76                       ],
77                       "outputs": [
78                         2
79                       ],
80                       "builtin_options_type": "NONE",
81                       "custom_options_format": "FLEXBUFFERS"
82                     }
83                   ],
84                   "name": "main"
85                 }
86               ],
87               "description": "MLIR Converted.",
88               "buffers": [
89                 {
90                 },
91                 {
92                 },
93                 { )" + alphaData + R"(
94                 },
95                 {
96                 }
97               ]
98             }
99         )";
100         Setup();
101     }
102 };
103 
104 struct PreluNetworkFixture : public ParserFlatbuffersFixture
105 {
PreluNetworkFixturePreluNetworkFixture106     explicit PreluNetworkFixture()
107     {
108         m_JsonString = R"(
109             {
110               "version": 3,
111               "operator_codes": [
112                 {
113                   "builtin_code": "PRELU",
114                   "version": 1
115                 },
116                 {
117                   "builtin_code": "MUL",
118                   "version": 1
119                 },
120                 {
121                   "builtin_code": "ADD",
122                   "version": 1
123                 }
124               ],
125               "subgraphs": [
126                 {
127                   "tensors": [
128                     {
129                       "shape": [
130                         1,
131                         2,
132                         3
133                       ],
134                       "type": "FLOAT32",
135                       "buffer": 6,
136                       "name": "output",
137                       "quantization": {
138                         "details_type": "NONE",
139                         "quantized_dimension": 0
140                       },
141                     },
142                     {
143                       "shape": [
144                         1,
145                         2,
146                         3
147                       ],
148                       "type": "FLOAT32",
149                       "buffer": 5,
150                       "name": "mul",
151                       "quantization": {
152                         "details_type": "NONE",
153                         "quantized_dimension": 0
154                       }
155                     },
156                     {
157                       "shape": [
158                         1,
159                         2,
160                         3
161                       ],
162                       "type": "FLOAT32",
163                       "buffer": 1,
164                       "name": "input0",
165                       "quantization": {
166                         "details_type": "NONE",
167                         "quantized_dimension": 0
168                       }
169                     },
170                     {
171                       "shape": [
172                         2,
173                         3
174                       ],
175                       "type": "FLOAT32",
176                       "buffer": 2,
177                       "name": "alpha",
178                       "quantization": {
179                         "details_type": "NONE",
180                         "quantized_dimension": 0
181                       }
182                     },
183                     {
184                       "shape": [
185                         1
186                       ],
187                       "type": "FLOAT32",
188                       "buffer": 3,
189                       "name": "const0",
190                       "quantization": {
191                         "details_type": "NONE",
192                         "quantized_dimension": 0
193                       }
194                     },
195                     {
196                       "shape": [
197                         1,
198                         2,
199                         3
200                       ],
201                       "type": "FLOAT32",
202                       "buffer": 4,
203                       "name": "prelumul",
204                       "quantization": {
205                         "details_type": "NONE",
206                         "quantized_dimension": 0
207                       }
208                     }
209                   ],
210                   "inputs": [
211                     2
212                   ],
213                   "outputs": [
214                     0
215                   ],
216                   "operators": [
217                     {
218                       "opcode_index": 0,
219                       "inputs": [
220                         2,
221                         3
222                       ],
223                       "outputs": [
224                         5
225                       ],
226                       "builtin_options_type": "NONE",
227                       "custom_options_format": "FLEXBUFFERS"
228                     },
229                     {
230                       "opcode_index": 1,
231                       "inputs": [
232                         5,
233                         4
234                       ],
235                       "outputs": [
236                         1
237                       ],
238                       "builtin_options_type": "MulOptions",
239                       "builtin_options": {
240                         "fused_activation_function": "NONE"
241                       },
242                       "custom_options_format": "FLEXBUFFERS"
243                     },
244                     {
245                       "opcode_index": 2,
246                       "inputs": [
247                         5,
248                         1
249                       ],
250                       "outputs": [
251                         0
252                       ],
253                       "builtin_options_type": "AddOptions",
254                       "builtin_options": {
255                         "fused_activation_function": "NONE"
256                       },
257                       "custom_options_format": "FLEXBUFFERS"
258                     }
259                   ],
260                   "name": "main"
261                 }
262               ],
263               "buffers": [
264                 {
265                 },
266                 {
267                 },
268                 {
269                   "data": [
270                     0,
271                     0,
272                     128,
273                     62,
274                     0,
275                     0,
276                     128,
277                     62,
278                     0,
279                     0,
280                     128,
281                     62,
282                     0,
283                     0,
284                     128,
285                     62,
286                     0,
287                     0,
288                     128,
289                     62,
290                     0,
291                     0,
292                     128,
293                     62
294                   ]
295                 },
296                 {
297                   "data": [
298                     0,
299                     0,
300                     160,
301                     64
302                   ]
303                 },
304                 {
305                 },
306                 {
307                 },
308                 {
309                 },
310                 {
311                 }
312               ],
313             }
314         )";
315         Setup();
316     }
317 };
318 
319 struct SimplePreluFixture : PreluFixture
320 {
SimplePreluFixtureSimplePreluFixture321     SimplePreluFixture() : PreluFixture("[ 2, 3 ]",
322                                         "[ 1 ]",
323                                         "[ 2, 3 ]",
324                                         "[ 0, 1 ]",
325                                         "") {}
326 };
327 
328 struct PreluConstAlphaFixture : PreluFixture
329 {
PreluConstAlphaFixturePreluConstAlphaFixture330     PreluConstAlphaFixture() : PreluFixture(
331         "[ 1, 2, 3 ]",
332         "[ 1, 2, 3 ]",
333         "[ 1, 2, 3 ]",
334         "[ 0 ]",
335         "\"data\": [ 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62 ]"){}
336 };
337 
338 struct PreluBroadcastAlphaFixture : PreluFixture
339 {
PreluBroadcastAlphaFixturePreluBroadcastAlphaFixture340     PreluBroadcastAlphaFixture() : PreluFixture(
341         "[ 1, 1, 2, 3 ]",
342         "[ 1, 3 ]",
343         "[ 1, 1, 2, 3 ]",
344         "[ 0 ]",
345         "\"data\": [ 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62 ]"){}
346 };
347 
348 struct PreluDynamicTensorFixture : PreluFixture
349 {
PreluDynamicTensorFixturePreluDynamicTensorFixture350     PreluDynamicTensorFixture() : PreluFixture("[ 2, 3 ]",
351                                                "[ 1, 1 ]",
352                                                "[]",
353                                                "[ 0 ]",
354                                                "\"data\": [ 0, 0, 128, 62 ]") {}
355 };
356 
357 TEST_CASE_FIXTURE(SimplePreluFixture, "SimplePrelu")
358 {
359   RunTest<2, armnn::DataType::Float32>(
360       0,
361       {{"input0", { -14.f, 2.f, 0.f, 1.f, -5.f, 14.f }},{"input1", { 0.25f }}},
362       {{"output", { -3.5f, 2.f, 0.f, 1.f, -1.25f, 14.f }}});
363 }
364 
365 TEST_CASE_FIXTURE(PreluConstAlphaFixture, "PreluConstAlpha")
366 {
367   RunTest<3, armnn::DataType::Float32>(
368       0,
369       {{"input0", { -14.f, 2.f, 0.f, 1.f, -5.f, 14.f }}},
370       {{"output", { -3.5f, 2.f, 0.f, 1.f, -1.25f, 14.f }}});
371 }
372 
373 TEST_CASE_FIXTURE(PreluBroadcastAlphaFixture, "PreluBroadcastAlpha")
374 {
375   RunTest<4, armnn::DataType::Float32>(
376       0,
377       {{"input0", { -14.f, 2.f, 0.f, 1.f, -5.f, 14.f }}},
378       {{"output", { -3.5f, 2.f, 0.f, 1.f, -1.25f, 14.f }}});
379 }
380 
381 TEST_CASE_FIXTURE(PreluDynamicTensorFixture, "PreluDynamicTensor")
382 {
383   RunTest<2, armnn::DataType::Float32, armnn::DataType::Float32>(
384       0,
385       {{"input0", { -14.f, 2.f, 0.f, 1.f, -5.f, 14.f }}},
386       {{"output", { -3.5f, 2.f, 0.f, 1.f, -1.25f, 14.f }}},
387       true);
388 }
389 
390 TEST_CASE_FIXTURE(PreluNetworkFixture, "PreluNetwork")
391 {
392   RunTest<3, armnn::DataType::Float32>(
393       0,
394       {{"input0", { -14.f, 2.f, 0.f, 1.f, -5.f, 14.f }}},
395       {{"output", { -21.f, 12.f, 0.f, 6.f, -7.5f, 84.f }}});
396 }
397 
398 }
399