1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "armnnOnnxParser/IOnnxParser.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 9 TEST_SUITE("OnnxParser_Addition") 10 { 11 struct AddMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 12 { AddMainFixtureAddMainFixture13 AddMainFixture(const std::string& dataType) 14 { 15 m_Prototext = R"( 16 ir_version: 3 17 producer_name: "CNTK" 18 producer_version: "2.5.1" 19 domain: "ai.cntk" 20 model_version: 1 21 graph { 22 name: "CNTKGraph" 23 input { 24 name: "Input0" 25 type { 26 tensor_type { 27 elem_type: )" + dataType + R"( 28 shape { 29 dim { 30 dim_value: 1 31 } 32 dim { 33 dim_value: 1 34 } 35 dim { 36 dim_value: 2 37 } 38 dim { 39 dim_value: 2 40 } 41 } 42 } 43 } 44 } 45 input { 46 name: "Input1" 47 type { 48 tensor_type { 49 elem_type: )" + dataType + R"( 50 shape { 51 dim { 52 dim_value: 1 53 } 54 dim { 55 dim_value: 1 56 } 57 dim { 58 dim_value: 2 59 } 60 dim { 61 dim_value: 2 62 } 63 } 64 } 65 } 66 } 67 node { 68 input: "Input0" 69 input: "Input1" 70 output: "Output" 71 name: "addition" 72 op_type: "Add" 73 doc_string: "" 74 domain: "" 75 } 76 output { 77 name: "Output" 78 type { 79 tensor_type { 80 elem_type: 1 81 shape { 82 dim { 83 dim_value: 1 84 } 85 dim { 86 dim_value: 1 87 } 88 dim { 89 dim_value: 2 90 } 91 dim { 92 dim_value: 2 93 } 94 } 95 } 96 } 97 } 98 } 99 opset_import { 100 version: 7 101 })"; 102 } 103 }; 104 105 struct AddValidFixture : AddMainFixture 106 { AddValidFixtureAddValidFixture107 AddValidFixture() : AddMainFixture("1") { 108 Setup(); 109 } 110 }; 111 112 struct AddInvalidFixture : AddMainFixture 113 { AddInvalidFixtureAddInvalidFixture114 AddInvalidFixture() : AddMainFixture("6") { } 115 }; 116 117 struct AddValidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 118 { AddValidBroadcastFixtureAddValidBroadcastFixture119 AddValidBroadcastFixture() { 120 121 m_Prototext = R"( 122 ir_version: 3 123 producer_name: "CNTK" 124 producer_version: "2.5.1" 125 domain: "ai.cntk" 126 model_version: 1 127 graph { 128 name: "CNTKGraph" 129 input { 130 name: "Input0" 131 type { 132 tensor_type { 133 elem_type: 1 134 shape { 135 dim { 136 dim_value: 1 137 } 138 dim { 139 dim_value: 1 140 } 141 dim { 142 dim_value: 1 143 } 144 dim { 145 dim_value: 4 146 } 147 } 148 } 149 } 150 } 151 input { 152 name: "Input1" 153 type { 154 tensor_type { 155 elem_type: 1 156 shape { 157 dim { 158 dim_value: 4 159 } 160 } 161 } 162 } 163 } 164 node { 165 input: "Input0" 166 input: "Input1" 167 output: "Output" 168 name: "addition" 169 op_type: "Add" 170 doc_string: "" 171 domain: "" 172 } 173 output { 174 name: "Output" 175 type { 176 tensor_type { 177 elem_type: 1 178 shape { 179 dim { 180 dim_value: 1 181 } 182 dim { 183 dim_value: 1 184 } 185 dim { 186 dim_value: 1 187 } 188 dim { 189 dim_value: 4 190 } 191 } 192 } 193 } 194 } 195 } 196 opset_import { 197 version: 7 198 })"; 199 Setup(); 200 } 201 }; 202 203 struct AddInvalidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 204 { AddInvalidBroadcastFixtureAddInvalidBroadcastFixture205 AddInvalidBroadcastFixture() { 206 207 m_Prototext = R"( 208 ir_version: 3 209 producer_name: "CNTK" 210 producer_version: "2.5.1" 211 domain: "ai.cntk" 212 model_version: 1 213 graph { 214 name: "CNTKGraph" 215 input { 216 name: "Input0" 217 type { 218 tensor_type { 219 elem_type: 1 220 shape { 221 dim { 222 dim_value: 1 223 } 224 dim { 225 dim_value: 1 226 } 227 dim { 228 dim_value: 1 229 } 230 dim { 231 dim_value: 3 232 } 233 } 234 } 235 } 236 } 237 input { 238 name: "Input1" 239 type { 240 tensor_type { 241 elem_type: 1 242 shape { 243 dim { 244 dim_value: 4 245 } 246 } 247 } 248 } 249 } 250 node { 251 input: "Input0" 252 input: "Input1" 253 output: "Output" 254 name: "addition" 255 op_type: "Add" 256 doc_string: "" 257 domain: "" 258 } 259 output { 260 name: "Output" 261 type { 262 tensor_type { 263 elem_type: 1 264 shape { 265 dim { 266 dim_value: 1 267 } 268 dim { 269 dim_value: 1 270 } 271 dim { 272 dim_value: 1 273 } 274 dim { 275 dim_value: 4 276 } 277 } 278 } 279 } 280 } 281 } 282 opset_import { 283 version: 7 284 })"; 285 } 286 }; 287 288 struct AddScalarFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 289 { AddScalarFixtureAddScalarFixture290 AddScalarFixture(const std::string& dataType) 291 { 292 m_Prototext = R"( 293 ir_version: 3 294 producer_name: "CNTK" 295 producer_version: "2.5.1" 296 domain: "ai.cntk" 297 model_version: 1 298 graph { 299 name: "CNTKGraph" 300 input { 301 name: "Input0" 302 type { 303 tensor_type { 304 elem_type: )" + dataType + R"( 305 shape { 306 dim { 307 dim_value: 1 308 } 309 dim { 310 dim_value: 1 311 } 312 dim { 313 dim_value: 2 314 } 315 dim { 316 dim_value: 2 317 } 318 } 319 } 320 } 321 } 322 input { 323 name: "Input1" 324 type { 325 tensor_type { 326 elem_type: )" + dataType + R"( 327 shape { 328 dim { 329 dim_value: 1 330 } 331 } 332 } 333 } 334 } 335 node { 336 input: "Input0" 337 input: "Input1" 338 output: "Output" 339 name: "addition" 340 op_type: "Add" 341 doc_string: "" 342 domain: "" 343 } 344 output { 345 name: "Output" 346 type { 347 tensor_type { 348 elem_type: 1 349 shape { 350 dim { 351 dim_value: 1 352 } 353 dim { 354 dim_value: 1 355 } 356 dim { 357 dim_value: 2 358 } 359 dim { 360 dim_value: 2 361 } 362 } 363 } 364 } 365 } 366 } 367 opset_import { 368 version: 7 369 })"; 370 } 371 }; 372 373 struct AddValidScalarFixture : AddScalarFixture 374 { AddValidScalarFixtureAddValidScalarFixture375 AddValidScalarFixture() : AddScalarFixture("1") { 376 Setup(); 377 } 378 }; 379 380 struct AddInvalidScalarFixture : AddScalarFixture 381 { AddInvalidScalarFixtureAddInvalidScalarFixture382 AddInvalidScalarFixture() : AddScalarFixture("6") { } 383 }; 384 385 TEST_CASE_FIXTURE(AddValidFixture, "ValidAddTest") 386 { 387 RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}}, 388 {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}}); 389 } 390 391 TEST_CASE_FIXTURE(AddInvalidFixture, "IncorrectDataTypeAdd") 392 { 393 CHECK_THROWS_AS(Setup(), armnn::ParseException); 394 } 395 396 TEST_CASE_FIXTURE(AddInvalidBroadcastFixture, "InvalidBroadcastAdd") 397 { 398 CHECK_THROWS_AS(Setup(), armnn::ParseException); 399 } 400 401 TEST_CASE_FIXTURE(AddValidBroadcastFixture, "ValidBroadcastAdd") 402 { 403 RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}}, 404 {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}}); 405 } 406 407 TEST_CASE_FIXTURE(AddValidScalarFixture, "ValidAddScalarTest") 408 { 409 RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}}, 410 {"Input1", {-8.0f}}}, {{"Output", {-7.0, -6.0, -11.0, -12.0}}}); 411 } 412 413 TEST_CASE_FIXTURE(AddInvalidScalarFixture, "IncorrectDataTypeAddScalar") 414 { 415 CHECK_THROWS_AS(Setup(), armnn::ParseException); 416 } 417 418 }