xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Addition.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnOnnxParser/IOnnxParser.hpp"
7 #include  "ParserPrototxtFixture.hpp"
8 
9 TEST_SUITE("OnnxParser_Addition")
10 {
11 struct AddMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12 {
AddMainFixtureAddMainFixture13     AddMainFixture(const std::string& dataType)
14     {
15         m_Prototext = R"(
16                    ir_version: 3
17                    producer_name:  "CNTK"
18                    producer_version:  "2.5.1"
19                    domain:  "ai.cntk"
20                    model_version: 1
21                    graph {
22                      name:  "CNTKGraph"
23                      input {
24                         name: "Input0"
25                         type {
26                           tensor_type {
27                             elem_type: )" + dataType + R"(
28                             shape {
29                               dim {
30                                 dim_value: 1
31                               }
32                               dim {
33                                 dim_value: 1
34                               }
35                               dim {
36                                 dim_value: 2
37                               }
38                               dim {
39                                 dim_value: 2
40                               }
41                             }
42                           }
43                         }
44                       }
45                       input {
46                          name: "Input1"
47                          type {
48                            tensor_type {
49                              elem_type: )" + dataType + R"(
50                              shape {
51                                dim {
52                                  dim_value: 1
53                                }
54                                dim {
55                                  dim_value: 1
56                                }
57                                dim {
58                                  dim_value: 2
59                                }
60                                dim {
61                                  dim_value: 2
62                                }
63                              }
64                            }
65                          }
66                        }
67                        node {
68                             input: "Input0"
69                             input: "Input1"
70                             output: "Output"
71                             name: "addition"
72                             op_type: "Add"
73                             doc_string: ""
74                             domain: ""
75                           }
76                           output {
77                               name: "Output"
78                               type {
79                                  tensor_type {
80                                    elem_type: 1
81                                    shape {
82                                        dim {
83                                            dim_value: 1
84                                        }
85                                        dim {
86                                            dim_value: 1
87                                        }
88                                        dim {
89                                            dim_value: 2
90                                        }
91                                        dim {
92                                            dim_value: 2
93                                        }
94                                    }
95                                 }
96                             }
97                         }
98                     }
99                    opset_import {
100                       version: 7
101                     })";
102     }
103 };
104 
105 struct AddValidFixture : AddMainFixture
106 {
AddValidFixtureAddValidFixture107     AddValidFixture() : AddMainFixture("1") {
108         Setup();
109     }
110 };
111 
112 struct AddInvalidFixture : AddMainFixture
113 {
AddInvalidFixtureAddInvalidFixture114     AddInvalidFixture() : AddMainFixture("6") { }
115 };
116 
117 struct AddValidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
118 {
AddValidBroadcastFixtureAddValidBroadcastFixture119     AddValidBroadcastFixture() {
120 
121         m_Prototext = R"(
122                    ir_version: 3
123                    producer_name:  "CNTK"
124                    producer_version:  "2.5.1"
125                    domain:  "ai.cntk"
126                    model_version: 1
127                    graph {
128                      name:  "CNTKGraph"
129                      input {
130                         name: "Input0"
131                         type {
132                           tensor_type {
133                             elem_type: 1
134                             shape {
135                               dim {
136                                 dim_value: 1
137                               }
138                               dim {
139                                 dim_value: 1
140                               }
141                               dim {
142                                 dim_value: 1
143                               }
144                               dim {
145                                 dim_value: 4
146                               }
147                             }
148                           }
149                         }
150                       }
151                       input {
152                          name: "Input1"
153                          type {
154                            tensor_type {
155                              elem_type: 1
156                              shape {
157                                  dim {
158                                    dim_value: 4
159                                  }
160                              }
161                            }
162                          }
163                        }
164                        node {
165                             input: "Input0"
166                             input: "Input1"
167                             output: "Output"
168                             name: "addition"
169                             op_type: "Add"
170                             doc_string: ""
171                             domain: ""
172                           }
173                           output {
174                               name: "Output"
175                               type {
176                                  tensor_type {
177                                    elem_type: 1
178                                    shape {
179                                        dim {
180                                            dim_value: 1
181                                        }
182                                        dim {
183                                            dim_value: 1
184                                        }
185                                        dim {
186                                            dim_value: 1
187                                        }
188                                        dim {
189                                            dim_value: 4
190                                        }
191                                    }
192                                 }
193                             }
194                         }
195                     }
196                    opset_import {
197                       version: 7
198                     })";
199         Setup();
200     }
201 };
202 
203 struct AddInvalidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
204 {
AddInvalidBroadcastFixtureAddInvalidBroadcastFixture205     AddInvalidBroadcastFixture() {
206 
207         m_Prototext = R"(
208                    ir_version: 3
209                    producer_name:  "CNTK"
210                    producer_version:  "2.5.1"
211                    domain:  "ai.cntk"
212                    model_version: 1
213                    graph {
214                      name:  "CNTKGraph"
215                      input {
216                         name: "Input0"
217                         type {
218                           tensor_type {
219                             elem_type: 1
220                             shape {
221                               dim {
222                                 dim_value: 1
223                               }
224                               dim {
225                                 dim_value: 1
226                               }
227                               dim {
228                                 dim_value: 1
229                               }
230                               dim {
231                                 dim_value: 3
232                               }
233                             }
234                           }
235                         }
236                       }
237                       input {
238                          name: "Input1"
239                          type {
240                            tensor_type {
241                              elem_type: 1
242                              shape {
243                                  dim {
244                                    dim_value: 4
245                                  }
246                              }
247                            }
248                          }
249                        }
250                        node {
251                             input: "Input0"
252                             input: "Input1"
253                             output: "Output"
254                             name: "addition"
255                             op_type: "Add"
256                             doc_string: ""
257                             domain: ""
258                           }
259                           output {
260                               name: "Output"
261                               type {
262                                  tensor_type {
263                                    elem_type: 1
264                                    shape {
265                                        dim {
266                                            dim_value: 1
267                                        }
268                                        dim {
269                                            dim_value: 1
270                                        }
271                                        dim {
272                                            dim_value: 1
273                                        }
274                                        dim {
275                                            dim_value: 4
276                                        }
277                                    }
278                                 }
279                             }
280                         }
281                     }
282                    opset_import {
283                       version: 7
284                     })";
285     }
286 };
287 
288 struct AddScalarFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
289 {
AddScalarFixtureAddScalarFixture290     AddScalarFixture(const std::string& dataType)
291     {
292         m_Prototext = R"(
293                    ir_version: 3
294                    producer_name:  "CNTK"
295                    producer_version:  "2.5.1"
296                    domain:  "ai.cntk"
297                    model_version: 1
298                    graph {
299                      name:  "CNTKGraph"
300                      input {
301                         name: "Input0"
302                         type {
303                           tensor_type {
304                             elem_type: )" + dataType + R"(
305                             shape {
306                               dim {
307                                 dim_value: 1
308                               }
309                               dim {
310                                 dim_value: 1
311                               }
312                               dim {
313                                 dim_value: 2
314                               }
315                               dim {
316                                 dim_value: 2
317                               }
318                             }
319                           }
320                         }
321                       }
322                       input {
323                          name: "Input1"
324                          type {
325                            tensor_type {
326                              elem_type: )" + dataType + R"(
327                              shape {
328                                dim {
329                                  dim_value: 1
330                                }
331                              }
332                            }
333                          }
334                        }
335                        node {
336                             input: "Input0"
337                             input: "Input1"
338                             output: "Output"
339                             name: "addition"
340                             op_type: "Add"
341                             doc_string: ""
342                             domain: ""
343                           }
344                           output {
345                               name: "Output"
346                               type {
347                                  tensor_type {
348                                    elem_type: 1
349                                    shape {
350                                        dim {
351                                            dim_value: 1
352                                        }
353                                        dim {
354                                            dim_value: 1
355                                        }
356                                        dim {
357                                            dim_value: 2
358                                        }
359                                        dim {
360                                            dim_value: 2
361                                        }
362                                    }
363                                 }
364                             }
365                         }
366                     }
367                    opset_import {
368                       version: 7
369                     })";
370     }
371 };
372 
373 struct AddValidScalarFixture : AddScalarFixture
374 {
AddValidScalarFixtureAddValidScalarFixture375     AddValidScalarFixture() : AddScalarFixture("1") {
376         Setup();
377     }
378 };
379 
380 struct AddInvalidScalarFixture : AddScalarFixture
381 {
AddInvalidScalarFixtureAddInvalidScalarFixture382     AddInvalidScalarFixture() : AddScalarFixture("6") { }
383 };
384 
385 TEST_CASE_FIXTURE(AddValidFixture, "ValidAddTest")
386 {
387     RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
388                 {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}});
389 }
390 
391 TEST_CASE_FIXTURE(AddInvalidFixture, "IncorrectDataTypeAdd")
392 {
393    CHECK_THROWS_AS(Setup(), armnn::ParseException);
394 }
395 
396 TEST_CASE_FIXTURE(AddInvalidBroadcastFixture, "InvalidBroadcastAdd")
397 {
398    CHECK_THROWS_AS(Setup(), armnn::ParseException);
399 }
400 
401 TEST_CASE_FIXTURE(AddValidBroadcastFixture, "ValidBroadcastAdd")
402 {
403     RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
404                 {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}});
405 }
406 
407 TEST_CASE_FIXTURE(AddValidScalarFixture, "ValidAddScalarTest")
408 {
409     RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
410                 {"Input1", {-8.0f}}}, {{"Output", {-7.0, -6.0, -11.0, -12.0}}});
411 }
412 
413 TEST_CASE_FIXTURE(AddInvalidScalarFixture, "IncorrectDataTypeAddScalar")
414 {
415     CHECK_THROWS_AS(Setup(), armnn::ParseException);
416 }
417 
418 }