1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include "ParserPrototxtFixture.hpp" 8*89c4ff92SAndroid Build Coastguard Worker #include "OnnxParserTestUtils.hpp" 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_Reshape") 11*89c4ff92SAndroid Build Coastguard Worker { 12*89c4ff92SAndroid Build Coastguard Worker struct ReshapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 13*89c4ff92SAndroid Build Coastguard Worker { ReshapeMainFixtureReshapeMainFixture14*89c4ff92SAndroid Build Coastguard Worker ReshapeMainFixture(const std::string& dataType) 15*89c4ff92SAndroid Build Coastguard Worker { 16*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 17*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 18*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK" 19*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1" 20*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk" 21*89c4ff92SAndroid Build Coastguard Worker model_version: 1 22*89c4ff92SAndroid Build Coastguard Worker graph { 23*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph" 24*89c4ff92SAndroid Build Coastguard Worker input { 25*89c4ff92SAndroid Build Coastguard Worker name: "Input" 26*89c4ff92SAndroid Build Coastguard Worker type { 27*89c4ff92SAndroid Build Coastguard Worker tensor_type { 28*89c4ff92SAndroid Build Coastguard Worker elem_type: )" + dataType + R"( 29*89c4ff92SAndroid Build Coastguard Worker shape { 30*89c4ff92SAndroid Build Coastguard Worker dim { 31*89c4ff92SAndroid Build Coastguard Worker dim_value: 4 32*89c4ff92SAndroid Build Coastguard Worker } 33*89c4ff92SAndroid Build Coastguard Worker } 34*89c4ff92SAndroid Build Coastguard Worker } 35*89c4ff92SAndroid Build Coastguard Worker } 36*89c4ff92SAndroid Build Coastguard Worker } 37*89c4ff92SAndroid Build Coastguard Worker input { 38*89c4ff92SAndroid Build Coastguard Worker name: "Shape" 39*89c4ff92SAndroid Build Coastguard Worker type { 40*89c4ff92SAndroid Build Coastguard Worker tensor_type { 41*89c4ff92SAndroid Build Coastguard Worker elem_type: 7 42*89c4ff92SAndroid Build Coastguard Worker shape { 43*89c4ff92SAndroid Build Coastguard Worker dim { 44*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 45*89c4ff92SAndroid Build Coastguard Worker } 46*89c4ff92SAndroid Build Coastguard Worker } 47*89c4ff92SAndroid Build Coastguard Worker } 48*89c4ff92SAndroid Build Coastguard Worker } 49*89c4ff92SAndroid Build Coastguard Worker } 50*89c4ff92SAndroid Build Coastguard Worker node { 51*89c4ff92SAndroid Build Coastguard Worker input: "Input" 52*89c4ff92SAndroid Build Coastguard Worker input: "Shape" 53*89c4ff92SAndroid Build Coastguard Worker output: "Output" 54*89c4ff92SAndroid Build Coastguard Worker name: "reshape" 55*89c4ff92SAndroid Build Coastguard Worker op_type: "Reshape" 56*89c4ff92SAndroid Build Coastguard Worker 57*89c4ff92SAndroid Build Coastguard Worker } 58*89c4ff92SAndroid Build Coastguard Worker initializer { 59*89c4ff92SAndroid Build Coastguard Worker dims: 2 60*89c4ff92SAndroid Build Coastguard Worker data_type: 7 61*89c4ff92SAndroid Build Coastguard Worker int64_data: 2 62*89c4ff92SAndroid Build Coastguard Worker int64_data: 2 63*89c4ff92SAndroid Build Coastguard Worker name: "Shape" 64*89c4ff92SAndroid Build Coastguard Worker } 65*89c4ff92SAndroid Build Coastguard Worker output { 66*89c4ff92SAndroid Build Coastguard Worker name: "Output" 67*89c4ff92SAndroid Build Coastguard Worker type { 68*89c4ff92SAndroid Build Coastguard Worker tensor_type { 69*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 70*89c4ff92SAndroid Build Coastguard Worker shape { 71*89c4ff92SAndroid Build Coastguard Worker dim { 72*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 73*89c4ff92SAndroid Build Coastguard Worker } 74*89c4ff92SAndroid Build Coastguard Worker dim { 75*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 76*89c4ff92SAndroid Build Coastguard Worker } 77*89c4ff92SAndroid Build Coastguard Worker } 78*89c4ff92SAndroid Build Coastguard Worker } 79*89c4ff92SAndroid Build Coastguard Worker } 80*89c4ff92SAndroid Build Coastguard Worker } 81*89c4ff92SAndroid Build Coastguard Worker } 82*89c4ff92SAndroid Build Coastguard Worker opset_import { 83*89c4ff92SAndroid Build Coastguard Worker version: 7 84*89c4ff92SAndroid Build Coastguard Worker })"; 85*89c4ff92SAndroid Build Coastguard Worker } 86*89c4ff92SAndroid Build Coastguard Worker }; 87*89c4ff92SAndroid Build Coastguard Worker 88*89c4ff92SAndroid Build Coastguard Worker struct ReshapeRank4Fixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 89*89c4ff92SAndroid Build Coastguard Worker { ReshapeRank4FixtureReshapeRank4Fixture90*89c4ff92SAndroid Build Coastguard Worker ReshapeRank4Fixture(const std::string& dataType) 91*89c4ff92SAndroid Build Coastguard Worker { 92*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 93*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 94*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK" 95*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1" 96*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk" 97*89c4ff92SAndroid Build Coastguard Worker model_version: 1 98*89c4ff92SAndroid Build Coastguard Worker graph { 99*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph" 100*89c4ff92SAndroid Build Coastguard Worker input { 101*89c4ff92SAndroid Build Coastguard Worker name: "Input" 102*89c4ff92SAndroid Build Coastguard Worker type { 103*89c4ff92SAndroid Build Coastguard Worker tensor_type { 104*89c4ff92SAndroid Build Coastguard Worker elem_type: )" + dataType + R"( 105*89c4ff92SAndroid Build Coastguard Worker shape { 106*89c4ff92SAndroid Build Coastguard Worker dim { 107*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 108*89c4ff92SAndroid Build Coastguard Worker } 109*89c4ff92SAndroid Build Coastguard Worker dim { 110*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 111*89c4ff92SAndroid Build Coastguard Worker } 112*89c4ff92SAndroid Build Coastguard Worker dim { 113*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 114*89c4ff92SAndroid Build Coastguard Worker } 115*89c4ff92SAndroid Build Coastguard Worker dim { 116*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 117*89c4ff92SAndroid Build Coastguard Worker } 118*89c4ff92SAndroid Build Coastguard Worker } 119*89c4ff92SAndroid Build Coastguard Worker } 120*89c4ff92SAndroid Build Coastguard Worker } 121*89c4ff92SAndroid Build Coastguard Worker } 122*89c4ff92SAndroid Build Coastguard Worker input { 123*89c4ff92SAndroid Build Coastguard Worker name: "Shape" 124*89c4ff92SAndroid Build Coastguard Worker type { 125*89c4ff92SAndroid Build Coastguard Worker tensor_type { 126*89c4ff92SAndroid Build Coastguard Worker elem_type: 7 127*89c4ff92SAndroid Build Coastguard Worker shape { 128*89c4ff92SAndroid Build Coastguard Worker dim { 129*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 130*89c4ff92SAndroid Build Coastguard Worker } 131*89c4ff92SAndroid Build Coastguard Worker } 132*89c4ff92SAndroid Build Coastguard Worker } 133*89c4ff92SAndroid Build Coastguard Worker } 134*89c4ff92SAndroid Build Coastguard Worker } 135*89c4ff92SAndroid Build Coastguard Worker node { 136*89c4ff92SAndroid Build Coastguard Worker input: "Input" 137*89c4ff92SAndroid Build Coastguard Worker input: "Shape" 138*89c4ff92SAndroid Build Coastguard Worker output: "Output" 139*89c4ff92SAndroid Build Coastguard Worker name: "reshape" 140*89c4ff92SAndroid Build Coastguard Worker op_type: "Reshape" 141*89c4ff92SAndroid Build Coastguard Worker 142*89c4ff92SAndroid Build Coastguard Worker } 143*89c4ff92SAndroid Build Coastguard Worker initializer { 144*89c4ff92SAndroid Build Coastguard Worker dims: 2 145*89c4ff92SAndroid Build Coastguard Worker data_type: 7 146*89c4ff92SAndroid Build Coastguard Worker int64_data: 2 147*89c4ff92SAndroid Build Coastguard Worker int64_data: 2 148*89c4ff92SAndroid Build Coastguard Worker name: "Shape" 149*89c4ff92SAndroid Build Coastguard Worker } 150*89c4ff92SAndroid Build Coastguard Worker output { 151*89c4ff92SAndroid Build Coastguard Worker name: "Output" 152*89c4ff92SAndroid Build Coastguard Worker type { 153*89c4ff92SAndroid Build Coastguard Worker tensor_type { 154*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 155*89c4ff92SAndroid Build Coastguard Worker shape { 156*89c4ff92SAndroid Build Coastguard Worker dim { 157*89c4ff92SAndroid Build Coastguard Worker dim_value: 6 158*89c4ff92SAndroid Build Coastguard Worker } 159*89c4ff92SAndroid Build Coastguard Worker dim { 160*89c4ff92SAndroid Build Coastguard Worker dim_value: 6 161*89c4ff92SAndroid Build Coastguard Worker } 162*89c4ff92SAndroid Build Coastguard Worker } 163*89c4ff92SAndroid Build Coastguard Worker } 164*89c4ff92SAndroid Build Coastguard Worker } 165*89c4ff92SAndroid Build Coastguard Worker } 166*89c4ff92SAndroid Build Coastguard Worker } 167*89c4ff92SAndroid Build Coastguard Worker opset_import { 168*89c4ff92SAndroid Build Coastguard Worker version: 7 169*89c4ff92SAndroid Build Coastguard Worker })"; 170*89c4ff92SAndroid Build Coastguard Worker } 171*89c4ff92SAndroid Build Coastguard Worker }; 172*89c4ff92SAndroid Build Coastguard Worker 173*89c4ff92SAndroid Build Coastguard Worker struct ReshapeValidFixture : ReshapeMainFixture 174*89c4ff92SAndroid Build Coastguard Worker { ReshapeValidFixtureReshapeValidFixture175*89c4ff92SAndroid Build Coastguard Worker ReshapeValidFixture() : ReshapeMainFixture("1") { 176*89c4ff92SAndroid Build Coastguard Worker Setup(); 177*89c4ff92SAndroid Build Coastguard Worker } 178*89c4ff92SAndroid Build Coastguard Worker }; 179*89c4ff92SAndroid Build Coastguard Worker 180*89c4ff92SAndroid Build Coastguard Worker struct ReshapeValidRank4Fixture : ReshapeRank4Fixture 181*89c4ff92SAndroid Build Coastguard Worker { ReshapeValidRank4FixtureReshapeValidRank4Fixture182*89c4ff92SAndroid Build Coastguard Worker ReshapeValidRank4Fixture() : ReshapeRank4Fixture("1") { 183*89c4ff92SAndroid Build Coastguard Worker Setup(); 184*89c4ff92SAndroid Build Coastguard Worker } 185*89c4ff92SAndroid Build Coastguard Worker }; 186*89c4ff92SAndroid Build Coastguard Worker 187*89c4ff92SAndroid Build Coastguard Worker struct ReshapeInvalidFixture : ReshapeMainFixture 188*89c4ff92SAndroid Build Coastguard Worker { ReshapeInvalidFixtureReshapeInvalidFixture189*89c4ff92SAndroid Build Coastguard Worker ReshapeInvalidFixture() : ReshapeMainFixture("10") { } 190*89c4ff92SAndroid Build Coastguard Worker }; 191*89c4ff92SAndroid Build Coastguard Worker 192*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeValidFixture, "ValidReshapeTest") 193*89c4ff92SAndroid Build Coastguard Worker { 194*89c4ff92SAndroid Build Coastguard Worker RunTest<2>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f }}}, {{"Output", { 0.0f, 1.0f, 2.0f, 3.0f }}}); 195*89c4ff92SAndroid Build Coastguard Worker } 196*89c4ff92SAndroid Build Coastguard Worker 197*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeValidRank4Fixture, "ValidRank4ReshapeTest") 198*89c4ff92SAndroid Build Coastguard Worker { 199*89c4ff92SAndroid Build Coastguard Worker RunTest<2>( 200*89c4ff92SAndroid Build Coastguard Worker {{"Input", 201*89c4ff92SAndroid Build Coastguard Worker {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 202*89c4ff92SAndroid Build Coastguard Worker 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 203*89c4ff92SAndroid Build Coastguard Worker 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}}, 204*89c4ff92SAndroid Build Coastguard Worker {{"Output", 205*89c4ff92SAndroid Build Coastguard Worker {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 206*89c4ff92SAndroid Build Coastguard Worker 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 207*89c4ff92SAndroid Build Coastguard Worker 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}}); 208*89c4ff92SAndroid Build Coastguard Worker } 209*89c4ff92SAndroid Build Coastguard Worker 210*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeInvalidFixture, "IncorrectDataTypeReshape") 211*89c4ff92SAndroid Build Coastguard Worker { 212*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(Setup(), armnn::ParseException); 213*89c4ff92SAndroid Build Coastguard Worker } 214*89c4ff92SAndroid Build Coastguard Worker 215*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNegativeReshapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 216*89c4ff92SAndroid Build Coastguard Worker { ReshapeNegativeReshapeFixtureReshapeNegativeReshapeFixture217*89c4ff92SAndroid Build Coastguard Worker ReshapeNegativeReshapeFixture(const std::vector<int>& inputShape, 218*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& shapeInputShape, 219*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& outputShape, 220*89c4ff92SAndroid Build Coastguard Worker const std::string& shape) 221*89c4ff92SAndroid Build Coastguard Worker { 222*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 223*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 224*89c4ff92SAndroid Build Coastguard Worker producer_name: "onnx-example" 225*89c4ff92SAndroid Build Coastguard Worker graph { 226*89c4ff92SAndroid Build Coastguard Worker name: "ReshapeGrapn" 227*89c4ff92SAndroid Build Coastguard Worker input { 228*89c4ff92SAndroid Build Coastguard Worker name: "Input" 229*89c4ff92SAndroid Build Coastguard Worker type { 230*89c4ff92SAndroid Build Coastguard Worker tensor_type { 231*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 232*89c4ff92SAndroid Build Coastguard Worker shape { 233*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( 234*89c4ff92SAndroid Build Coastguard Worker } 235*89c4ff92SAndroid Build Coastguard Worker } 236*89c4ff92SAndroid Build Coastguard Worker } 237*89c4ff92SAndroid Build Coastguard Worker } 238*89c4ff92SAndroid Build Coastguard Worker input { 239*89c4ff92SAndroid Build Coastguard Worker name: "Shape" 240*89c4ff92SAndroid Build Coastguard Worker type { 241*89c4ff92SAndroid Build Coastguard Worker tensor_type { 242*89c4ff92SAndroid Build Coastguard Worker elem_type: 7 243*89c4ff92SAndroid Build Coastguard Worker shape { 244*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"( 245*89c4ff92SAndroid Build Coastguard Worker } 246*89c4ff92SAndroid Build Coastguard Worker } 247*89c4ff92SAndroid Build Coastguard Worker } 248*89c4ff92SAndroid Build Coastguard Worker } 249*89c4ff92SAndroid Build Coastguard Worker node { 250*89c4ff92SAndroid Build Coastguard Worker input: "Input" 251*89c4ff92SAndroid Build Coastguard Worker input: "Shape" 252*89c4ff92SAndroid Build Coastguard Worker output: "Output" 253*89c4ff92SAndroid Build Coastguard Worker name: "reshape" 254*89c4ff92SAndroid Build Coastguard Worker op_type: "Reshape" 255*89c4ff92SAndroid Build Coastguard Worker } 256*89c4ff92SAndroid Build Coastguard Worker initializer { 257*89c4ff92SAndroid Build Coastguard Worker dims: 2 258*89c4ff92SAndroid Build Coastguard Worker data_type: 7 259*89c4ff92SAndroid Build Coastguard Worker )" + shape + R"( 260*89c4ff92SAndroid Build Coastguard Worker name: "Shape" 261*89c4ff92SAndroid Build Coastguard Worker } 262*89c4ff92SAndroid Build Coastguard Worker output { 263*89c4ff92SAndroid Build Coastguard Worker name: "Output" 264*89c4ff92SAndroid Build Coastguard Worker type { 265*89c4ff92SAndroid Build Coastguard Worker tensor_type { 266*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 267*89c4ff92SAndroid Build Coastguard Worker shape { 268*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( 269*89c4ff92SAndroid Build Coastguard Worker } 270*89c4ff92SAndroid Build Coastguard Worker } 271*89c4ff92SAndroid Build Coastguard Worker } 272*89c4ff92SAndroid Build Coastguard Worker } 273*89c4ff92SAndroid Build Coastguard Worker } 274*89c4ff92SAndroid Build Coastguard Worker opset_import { 275*89c4ff92SAndroid Build Coastguard Worker version: 7 276*89c4ff92SAndroid Build Coastguard Worker })"; 277*89c4ff92SAndroid Build Coastguard Worker } 278*89c4ff92SAndroid Build Coastguard Worker }; 279*89c4ff92SAndroid Build Coastguard Worker 280*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNegativeReshape1DFixture : ReshapeNegativeReshapeFixture 281*89c4ff92SAndroid Build Coastguard Worker { ReshapeNegativeReshape1DFixtureReshapeNegativeReshape1DFixture282*89c4ff92SAndroid Build Coastguard Worker ReshapeNegativeReshape1DFixture() : ReshapeNegativeReshapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 }, "int64_data: -1") 283*89c4ff92SAndroid Build Coastguard Worker { 284*89c4ff92SAndroid Build Coastguard Worker Setup(); 285*89c4ff92SAndroid Build Coastguard Worker } 286*89c4ff92SAndroid Build Coastguard Worker }; 287*89c4ff92SAndroid Build Coastguard Worker 288*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNegativeReshape2DFixture : ReshapeNegativeReshapeFixture 289*89c4ff92SAndroid Build Coastguard Worker { ReshapeNegativeReshape2DFixtureReshapeNegativeReshape2DFixture290*89c4ff92SAndroid Build Coastguard Worker ReshapeNegativeReshape2DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 }, 291*89c4ff92SAndroid Build Coastguard Worker { 2 }, 292*89c4ff92SAndroid Build Coastguard Worker { 2, 6 }, 293*89c4ff92SAndroid Build Coastguard Worker "int64_data: -1 int64_data: 6") 294*89c4ff92SAndroid Build Coastguard Worker { 295*89c4ff92SAndroid Build Coastguard Worker Setup(); 296*89c4ff92SAndroid Build Coastguard Worker } 297*89c4ff92SAndroid Build Coastguard Worker }; 298*89c4ff92SAndroid Build Coastguard Worker 299*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNegativeReshape3DFixture : ReshapeNegativeReshapeFixture 300*89c4ff92SAndroid Build Coastguard Worker { ReshapeNegativeReshape3DFixtureReshapeNegativeReshape3DFixture301*89c4ff92SAndroid Build Coastguard Worker ReshapeNegativeReshape3DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 }, 302*89c4ff92SAndroid Build Coastguard Worker { 3 }, 303*89c4ff92SAndroid Build Coastguard Worker { 3, 1, 4 }, 304*89c4ff92SAndroid Build Coastguard Worker "int64_data: 3 int64_data: -1 int64_data: 4") 305*89c4ff92SAndroid Build Coastguard Worker { 306*89c4ff92SAndroid Build Coastguard Worker Setup(); 307*89c4ff92SAndroid Build Coastguard Worker } 308*89c4ff92SAndroid Build Coastguard Worker }; 309*89c4ff92SAndroid Build Coastguard Worker 310*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNegativeReshape4DFixture : ReshapeNegativeReshapeFixture 311*89c4ff92SAndroid Build Coastguard Worker { ReshapeNegativeReshape4DFixtureReshapeNegativeReshape4DFixture312*89c4ff92SAndroid Build Coastguard Worker ReshapeNegativeReshape4DFixture() : ReshapeNegativeReshapeFixture( 313*89c4ff92SAndroid Build Coastguard Worker { 2, 3, 1, 2 }, 314*89c4ff92SAndroid Build Coastguard Worker { 4 }, 315*89c4ff92SAndroid Build Coastguard Worker { 3, 1, 2, 2 }, 316*89c4ff92SAndroid Build Coastguard Worker "int64_data: 3 int64_data: 1 int64_data: 2 int64_data: -1") 317*89c4ff92SAndroid Build Coastguard Worker { 318*89c4ff92SAndroid Build Coastguard Worker Setup(); 319*89c4ff92SAndroid Build Coastguard Worker } 320*89c4ff92SAndroid Build Coastguard Worker }; 321*89c4ff92SAndroid Build Coastguard Worker 322*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeNegativeReshape1DFixture, "ReshapeNegativeReshape1DTest") 323*89c4ff92SAndroid Build Coastguard Worker { 324*89c4ff92SAndroid Build Coastguard Worker RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}, 325*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}); 326*89c4ff92SAndroid Build Coastguard Worker } 327*89c4ff92SAndroid Build Coastguard Worker 328*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeNegativeReshape2DFixture, "ReshapeNegativeReshape2DTest") 329*89c4ff92SAndroid Build Coastguard Worker { 330*89c4ff92SAndroid Build Coastguard Worker RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 331*89c4ff92SAndroid Build Coastguard Worker 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}, 332*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 333*89c4ff92SAndroid Build Coastguard Worker 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}); 334*89c4ff92SAndroid Build Coastguard Worker } 335*89c4ff92SAndroid Build Coastguard Worker 336*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeNegativeReshape3DFixture, "ReshapeNegativeReshape3DTest") 337*89c4ff92SAndroid Build Coastguard Worker { 338*89c4ff92SAndroid Build Coastguard Worker RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 339*89c4ff92SAndroid Build Coastguard Worker 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}, 340*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 341*89c4ff92SAndroid Build Coastguard Worker 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}); 342*89c4ff92SAndroid Build Coastguard Worker } 343*89c4ff92SAndroid Build Coastguard Worker 344*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeNegativeReshape4DFixture, "ReshapeNegativeReshape4DTest") 345*89c4ff92SAndroid Build Coastguard Worker { 346*89c4ff92SAndroid Build Coastguard Worker RunTest<4, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 347*89c4ff92SAndroid Build Coastguard Worker 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}, 348*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 349*89c4ff92SAndroid Build Coastguard Worker 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}); 350*89c4ff92SAndroid Build Coastguard Worker } 351*89c4ff92SAndroid Build Coastguard Worker 352*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNonConstShapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 353*89c4ff92SAndroid Build Coastguard Worker { ReshapeNonConstShapeFixtureReshapeNonConstShapeFixture354*89c4ff92SAndroid Build Coastguard Worker ReshapeNonConstShapeFixture(const std::vector<int>& inputShape, 355*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& shapeInputShape, 356*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& outputShape) 357*89c4ff92SAndroid Build Coastguard Worker { 358*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 359*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 360*89c4ff92SAndroid Build Coastguard Worker producer_name: "onnx-example" 361*89c4ff92SAndroid Build Coastguard Worker graph { 362*89c4ff92SAndroid Build Coastguard Worker name: "ReshapeGrapn" 363*89c4ff92SAndroid Build Coastguard Worker input { 364*89c4ff92SAndroid Build Coastguard Worker name: "Input" 365*89c4ff92SAndroid Build Coastguard Worker type { 366*89c4ff92SAndroid Build Coastguard Worker tensor_type { 367*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 368*89c4ff92SAndroid Build Coastguard Worker shape { 369*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( 370*89c4ff92SAndroid Build Coastguard Worker } 371*89c4ff92SAndroid Build Coastguard Worker } 372*89c4ff92SAndroid Build Coastguard Worker } 373*89c4ff92SAndroid Build Coastguard Worker } 374*89c4ff92SAndroid Build Coastguard Worker input { 375*89c4ff92SAndroid Build Coastguard Worker name: "Shape" 376*89c4ff92SAndroid Build Coastguard Worker type { 377*89c4ff92SAndroid Build Coastguard Worker tensor_type { 378*89c4ff92SAndroid Build Coastguard Worker elem_type: 7 379*89c4ff92SAndroid Build Coastguard Worker shape { 380*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"( 381*89c4ff92SAndroid Build Coastguard Worker } 382*89c4ff92SAndroid Build Coastguard Worker } 383*89c4ff92SAndroid Build Coastguard Worker } 384*89c4ff92SAndroid Build Coastguard Worker } 385*89c4ff92SAndroid Build Coastguard Worker node { 386*89c4ff92SAndroid Build Coastguard Worker input: "Input" 387*89c4ff92SAndroid Build Coastguard Worker input: "Shape" 388*89c4ff92SAndroid Build Coastguard Worker output: "Output" 389*89c4ff92SAndroid Build Coastguard Worker name: "reshape" 390*89c4ff92SAndroid Build Coastguard Worker op_type: "Reshape" 391*89c4ff92SAndroid Build Coastguard Worker } 392*89c4ff92SAndroid Build Coastguard Worker output { 393*89c4ff92SAndroid Build Coastguard Worker name: "Output" 394*89c4ff92SAndroid Build Coastguard Worker type { 395*89c4ff92SAndroid Build Coastguard Worker tensor_type { 396*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 397*89c4ff92SAndroid Build Coastguard Worker shape { 398*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( 399*89c4ff92SAndroid Build Coastguard Worker } 400*89c4ff92SAndroid Build Coastguard Worker } 401*89c4ff92SAndroid Build Coastguard Worker } 402*89c4ff92SAndroid Build Coastguard Worker } 403*89c4ff92SAndroid Build Coastguard Worker } 404*89c4ff92SAndroid Build Coastguard Worker opset_import { 405*89c4ff92SAndroid Build Coastguard Worker version: 7 406*89c4ff92SAndroid Build Coastguard Worker })"; 407*89c4ff92SAndroid Build Coastguard Worker } 408*89c4ff92SAndroid Build Coastguard Worker }; 409*89c4ff92SAndroid Build Coastguard Worker 410*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNonConst1DShapeFixture : ReshapeNonConstShapeFixture 411*89c4ff92SAndroid Build Coastguard Worker { ReshapeNonConst1DShapeFixtureReshapeNonConst1DShapeFixture412*89c4ff92SAndroid Build Coastguard Worker ReshapeNonConst1DShapeFixture() : ReshapeNonConstShapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 }) 413*89c4ff92SAndroid Build Coastguard Worker { 414*89c4ff92SAndroid Build Coastguard Worker Setup(); 415*89c4ff92SAndroid Build Coastguard Worker } 416*89c4ff92SAndroid Build Coastguard Worker }; 417*89c4ff92SAndroid Build Coastguard Worker 418*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNonConst2DShapeFixture : ReshapeNonConstShapeFixture 419*89c4ff92SAndroid Build Coastguard Worker { ReshapeNonConst2DShapeFixtureReshapeNonConst2DShapeFixture420*89c4ff92SAndroid Build Coastguard Worker ReshapeNonConst2DShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 2 }, { 2, 12 }) 421*89c4ff92SAndroid Build Coastguard Worker { 422*89c4ff92SAndroid Build Coastguard Worker Setup(); 423*89c4ff92SAndroid Build Coastguard Worker } 424*89c4ff92SAndroid Build Coastguard Worker }; 425*89c4ff92SAndroid Build Coastguard Worker 426*89c4ff92SAndroid Build Coastguard Worker struct ReshapeInvalidNonConstShapeFixture : ReshapeNonConstShapeFixture 427*89c4ff92SAndroid Build Coastguard Worker { ReshapeInvalidNonConstShapeFixtureReshapeInvalidNonConstShapeFixture428*89c4ff92SAndroid Build Coastguard Worker ReshapeInvalidNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 3 }, { 2, 3, 4 }) 429*89c4ff92SAndroid Build Coastguard Worker { 430*89c4ff92SAndroid Build Coastguard Worker } 431*89c4ff92SAndroid Build Coastguard Worker }; 432*89c4ff92SAndroid Build Coastguard Worker 433*89c4ff92SAndroid Build Coastguard Worker struct ReshapeInvalidDimNonConstShapeFixture : ReshapeNonConstShapeFixture 434*89c4ff92SAndroid Build Coastguard Worker { ReshapeInvalidDimNonConstShapeFixtureReshapeInvalidDimNonConstShapeFixture435*89c4ff92SAndroid Build Coastguard Worker ReshapeInvalidDimNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 1, 2 }, { 2, 3, 4 }) 436*89c4ff92SAndroid Build Coastguard Worker { 437*89c4ff92SAndroid Build Coastguard Worker } 438*89c4ff92SAndroid Build Coastguard Worker }; 439*89c4ff92SAndroid Build Coastguard Worker 440*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeNonConst1DShapeFixture, "ReshapeNonConst1DShapeTest") 441*89c4ff92SAndroid Build Coastguard Worker { 442*89c4ff92SAndroid Build Coastguard Worker RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}, 443*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}); 444*89c4ff92SAndroid Build Coastguard Worker } 445*89c4ff92SAndroid Build Coastguard Worker 446*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeNonConst2DShapeFixture, "ReshapeNonConst2DShapeTest") 447*89c4ff92SAndroid Build Coastguard Worker { 448*89c4ff92SAndroid Build Coastguard Worker RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 449*89c4ff92SAndroid Build Coastguard Worker 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 450*89c4ff92SAndroid Build Coastguard Worker 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 451*89c4ff92SAndroid Build Coastguard Worker 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}}, 452*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 453*89c4ff92SAndroid Build Coastguard Worker 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 454*89c4ff92SAndroid Build Coastguard Worker 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 455*89c4ff92SAndroid Build Coastguard Worker 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}}); 456*89c4ff92SAndroid Build Coastguard Worker } 457*89c4ff92SAndroid Build Coastguard Worker 458*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeInvalidNonConstShapeFixture, "ReshapeInvalidNonConstShapeTest") 459*89c4ff92SAndroid Build Coastguard Worker { 460*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(Setup(), armnn::ParseException); 461*89c4ff92SAndroid Build Coastguard Worker } 462*89c4ff92SAndroid Build Coastguard Worker 463*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeInvalidDimNonConstShapeFixture, "ReshapeInvalidDimNonConstShapeTest") 464*89c4ff92SAndroid Build Coastguard Worker { 465*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(Setup(), armnn::ParseException); 466*89c4ff92SAndroid Build Coastguard Worker } 467*89c4ff92SAndroid Build Coastguard Worker 468*89c4ff92SAndroid Build Coastguard Worker } 469