1 // 2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "armnnOnnxParser/IOnnxParser.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 #include "OnnxParserTestUtils.hpp" 9 10 TEST_SUITE("OnnxParser_Shape") 11 { 12 13 struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 14 { ShapeMainFixtureShapeMainFixture15 ShapeMainFixture(const std::string& inputType, 16 const std::string& outputType, 17 const std::string& outputDim, 18 const std::vector<int>& inputShape) 19 { 20 m_Prototext = R"( 21 ir_version: 8 22 producer_name: "onnx-example" 23 graph { 24 node { 25 input: "Input" 26 output: "Output" 27 op_type: "Shape" 28 } 29 name: "shape-model" 30 input { 31 name: "Input" 32 type { 33 tensor_type { 34 elem_type: )" + inputType + R"( 35 shape { 36 )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( 37 } 38 } 39 } 40 } 41 output { 42 name: "Output" 43 type { 44 tensor_type { 45 elem_type: )" + outputType + R"( 46 shape { 47 dim { 48 dim_value: )" + outputDim + R"( 49 } 50 } 51 } 52 } 53 } 54 } 55 opset_import { 56 version: 10 57 })"; 58 } 59 }; 60 61 struct ShapeFloatFixture : ShapeMainFixture 62 { ShapeFloatFixtureShapeFloatFixture63 ShapeFloatFixture() : ShapeMainFixture("1", "7", "4", { 1, 3, 1, 5 }) 64 { 65 Setup(); 66 } 67 }; 68 69 struct ShapeIntFixture : ShapeMainFixture 70 { ShapeIntFixtureShapeIntFixture71 ShapeIntFixture() : ShapeMainFixture("7", "7", "4", { 1, 3, 1, 5 }) 72 { 73 Setup(); 74 } 75 }; 76 77 struct Shape3DFixture : ShapeMainFixture 78 { Shape3DFixtureShape3DFixture79 Shape3DFixture() : ShapeMainFixture("1", "7", "3", { 3, 2, 3 }) 80 { 81 Setup(); 82 } 83 }; 84 85 struct Shape2DFixture : ShapeMainFixture 86 { Shape2DFixtureShape2DFixture87 Shape2DFixture() : ShapeMainFixture("1", "7", "2", { 2, 3 }) 88 { 89 Setup(); 90 } 91 }; 92 93 struct Shape1DFixture : ShapeMainFixture 94 { Shape1DFixtureShape1DFixture95 Shape1DFixture() : ShapeMainFixture("1", "7", "1", { 5 }) 96 { 97 Setup(); 98 } 99 }; 100 101 TEST_CASE_FIXTURE(ShapeFloatFixture, "FloatValidShapeTest") 102 { 103 RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 104 4.0f, 3.0f, 2.0f, 1.0f, 0.0f, 105 0.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}, {{"Output", { 1, 3, 1, 5 }}}); 106 } 107 108 TEST_CASE_FIXTURE(ShapeIntFixture, "IntValidShapeTest") 109 { 110 RunTest<1, int>({{"Input", { 0, 1, 2, 3, 4, 111 4, 3, 2, 1, 0, 112 0, 1, 2, 3, 4 }}}, {{"Output", { 1, 3, 1, 5 }}}); 113 } 114 115 TEST_CASE_FIXTURE(Shape3DFixture, "Shape3DTest") 116 { 117 RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 118 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f, 119 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 3, 2, 3 }}}); 120 } 121 122 TEST_CASE_FIXTURE(Shape2DFixture, "Shape2DTest") 123 { 124 RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 2, 3 }}}); 125 } 126 127 TEST_CASE_FIXTURE(Shape1DFixture, "Shape1DTest") 128 { 129 RunTest<1, int>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 5 }}}); 130 } 131 132 } 133