xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Reshape.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include  "ParserPrototxtFixture.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "OnnxParserTestUtils.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_Reshape")
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker struct ReshapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13*89c4ff92SAndroid Build Coastguard Worker {
ReshapeMainFixtureReshapeMainFixture14*89c4ff92SAndroid Build Coastguard Worker     ReshapeMainFixture(const std::string& dataType)
15*89c4ff92SAndroid Build Coastguard Worker     {
16*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
17*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
18*89c4ff92SAndroid Build Coastguard Worker                    producer_name:  "CNTK"
19*89c4ff92SAndroid Build Coastguard Worker                    producer_version:  "2.5.1"
20*89c4ff92SAndroid Build Coastguard Worker                    domain:  "ai.cntk"
21*89c4ff92SAndroid Build Coastguard Worker                    model_version: 1
22*89c4ff92SAndroid Build Coastguard Worker                    graph {
23*89c4ff92SAndroid Build Coastguard Worker                      name:  "CNTKGraph"
24*89c4ff92SAndroid Build Coastguard Worker                      input {
25*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
26*89c4ff92SAndroid Build Coastguard Worker                         type {
27*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
28*89c4ff92SAndroid Build Coastguard Worker                             elem_type: )" + dataType + R"(
29*89c4ff92SAndroid Build Coastguard Worker                             shape {
30*89c4ff92SAndroid Build Coastguard Worker                               dim {
31*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 4
32*89c4ff92SAndroid Build Coastguard Worker                               }
33*89c4ff92SAndroid Build Coastguard Worker                             }
34*89c4ff92SAndroid Build Coastguard Worker                           }
35*89c4ff92SAndroid Build Coastguard Worker                         }
36*89c4ff92SAndroid Build Coastguard Worker                       }
37*89c4ff92SAndroid Build Coastguard Worker                       input {
38*89c4ff92SAndroid Build Coastguard Worker                          name: "Shape"
39*89c4ff92SAndroid Build Coastguard Worker                          type {
40*89c4ff92SAndroid Build Coastguard Worker                            tensor_type {
41*89c4ff92SAndroid Build Coastguard Worker                              elem_type: 7
42*89c4ff92SAndroid Build Coastguard Worker                              shape {
43*89c4ff92SAndroid Build Coastguard Worker                                dim {
44*89c4ff92SAndroid Build Coastguard Worker                                  dim_value: 2
45*89c4ff92SAndroid Build Coastguard Worker                                }
46*89c4ff92SAndroid Build Coastguard Worker                              }
47*89c4ff92SAndroid Build Coastguard Worker                            }
48*89c4ff92SAndroid Build Coastguard Worker                          }
49*89c4ff92SAndroid Build Coastguard Worker                        }
50*89c4ff92SAndroid Build Coastguard Worker                      node {
51*89c4ff92SAndroid Build Coastguard Worker                          input: "Input"
52*89c4ff92SAndroid Build Coastguard Worker                          input: "Shape"
53*89c4ff92SAndroid Build Coastguard Worker                          output: "Output"
54*89c4ff92SAndroid Build Coastguard Worker                          name: "reshape"
55*89c4ff92SAndroid Build Coastguard Worker                          op_type: "Reshape"
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker                       }
58*89c4ff92SAndroid Build Coastguard Worker                       initializer {
59*89c4ff92SAndroid Build Coastguard Worker                         dims: 2
60*89c4ff92SAndroid Build Coastguard Worker                         data_type: 7
61*89c4ff92SAndroid Build Coastguard Worker                         int64_data: 2
62*89c4ff92SAndroid Build Coastguard Worker                         int64_data: 2
63*89c4ff92SAndroid Build Coastguard Worker                         name: "Shape"
64*89c4ff92SAndroid Build Coastguard Worker                      }
65*89c4ff92SAndroid Build Coastguard Worker                       output {
66*89c4ff92SAndroid Build Coastguard Worker                           name: "Output"
67*89c4ff92SAndroid Build Coastguard Worker                           type {
68*89c4ff92SAndroid Build Coastguard Worker                              tensor_type {
69*89c4ff92SAndroid Build Coastguard Worker                                elem_type: 1
70*89c4ff92SAndroid Build Coastguard Worker                                shape {
71*89c4ff92SAndroid Build Coastguard Worker                                    dim {
72*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 2
73*89c4ff92SAndroid Build Coastguard Worker                                    }
74*89c4ff92SAndroid Build Coastguard Worker                                    dim {
75*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 2
76*89c4ff92SAndroid Build Coastguard Worker                                    }
77*89c4ff92SAndroid Build Coastguard Worker                                }
78*89c4ff92SAndroid Build Coastguard Worker                             }
79*89c4ff92SAndroid Build Coastguard Worker                           }
80*89c4ff92SAndroid Build Coastguard Worker                        }
81*89c4ff92SAndroid Build Coastguard Worker                     }
82*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
83*89c4ff92SAndroid Build Coastguard Worker                       version: 7
84*89c4ff92SAndroid Build Coastguard Worker                     })";
85*89c4ff92SAndroid Build Coastguard Worker     }
86*89c4ff92SAndroid Build Coastguard Worker };
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker struct ReshapeRank4Fixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
89*89c4ff92SAndroid Build Coastguard Worker {
ReshapeRank4FixtureReshapeRank4Fixture90*89c4ff92SAndroid Build Coastguard Worker     ReshapeRank4Fixture(const std::string& dataType)
91*89c4ff92SAndroid Build Coastguard Worker     {
92*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
93*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
94*89c4ff92SAndroid Build Coastguard Worker                    producer_name:  "CNTK"
95*89c4ff92SAndroid Build Coastguard Worker                    producer_version:  "2.5.1"
96*89c4ff92SAndroid Build Coastguard Worker                    domain:  "ai.cntk"
97*89c4ff92SAndroid Build Coastguard Worker                    model_version: 1
98*89c4ff92SAndroid Build Coastguard Worker                    graph {
99*89c4ff92SAndroid Build Coastguard Worker                      name:  "CNTKGraph"
100*89c4ff92SAndroid Build Coastguard Worker                      input {
101*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
102*89c4ff92SAndroid Build Coastguard Worker                         type {
103*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
104*89c4ff92SAndroid Build Coastguard Worker                             elem_type: )" + dataType + R"(
105*89c4ff92SAndroid Build Coastguard Worker                             shape {
106*89c4ff92SAndroid Build Coastguard Worker                               dim {
107*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 2
108*89c4ff92SAndroid Build Coastguard Worker                               }
109*89c4ff92SAndroid Build Coastguard Worker                               dim {
110*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 2
111*89c4ff92SAndroid Build Coastguard Worker                               }
112*89c4ff92SAndroid Build Coastguard Worker                               dim {
113*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
114*89c4ff92SAndroid Build Coastguard Worker                               }
115*89c4ff92SAndroid Build Coastguard Worker                               dim {
116*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
117*89c4ff92SAndroid Build Coastguard Worker                               }
118*89c4ff92SAndroid Build Coastguard Worker                             }
119*89c4ff92SAndroid Build Coastguard Worker                           }
120*89c4ff92SAndroid Build Coastguard Worker                         }
121*89c4ff92SAndroid Build Coastguard Worker                       }
122*89c4ff92SAndroid Build Coastguard Worker                       input {
123*89c4ff92SAndroid Build Coastguard Worker                          name: "Shape"
124*89c4ff92SAndroid Build Coastguard Worker                          type {
125*89c4ff92SAndroid Build Coastguard Worker                            tensor_type {
126*89c4ff92SAndroid Build Coastguard Worker                              elem_type: 7
127*89c4ff92SAndroid Build Coastguard Worker                              shape {
128*89c4ff92SAndroid Build Coastguard Worker                                dim {
129*89c4ff92SAndroid Build Coastguard Worker                                  dim_value: 2
130*89c4ff92SAndroid Build Coastguard Worker                                }
131*89c4ff92SAndroid Build Coastguard Worker                              }
132*89c4ff92SAndroid Build Coastguard Worker                            }
133*89c4ff92SAndroid Build Coastguard Worker                          }
134*89c4ff92SAndroid Build Coastguard Worker                        }
135*89c4ff92SAndroid Build Coastguard Worker                      node {
136*89c4ff92SAndroid Build Coastguard Worker                          input: "Input"
137*89c4ff92SAndroid Build Coastguard Worker                          input: "Shape"
138*89c4ff92SAndroid Build Coastguard Worker                          output: "Output"
139*89c4ff92SAndroid Build Coastguard Worker                          name: "reshape"
140*89c4ff92SAndroid Build Coastguard Worker                          op_type: "Reshape"
141*89c4ff92SAndroid Build Coastguard Worker 
142*89c4ff92SAndroid Build Coastguard Worker                       }
143*89c4ff92SAndroid Build Coastguard Worker                       initializer {
144*89c4ff92SAndroid Build Coastguard Worker                         dims: 2
145*89c4ff92SAndroid Build Coastguard Worker                         data_type: 7
146*89c4ff92SAndroid Build Coastguard Worker                         int64_data: 2
147*89c4ff92SAndroid Build Coastguard Worker                         int64_data: 2
148*89c4ff92SAndroid Build Coastguard Worker                         name: "Shape"
149*89c4ff92SAndroid Build Coastguard Worker                      }
150*89c4ff92SAndroid Build Coastguard Worker                       output {
151*89c4ff92SAndroid Build Coastguard Worker                           name: "Output"
152*89c4ff92SAndroid Build Coastguard Worker                           type {
153*89c4ff92SAndroid Build Coastguard Worker                              tensor_type {
154*89c4ff92SAndroid Build Coastguard Worker                                elem_type: 1
155*89c4ff92SAndroid Build Coastguard Worker                                shape {
156*89c4ff92SAndroid Build Coastguard Worker                                    dim {
157*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 6
158*89c4ff92SAndroid Build Coastguard Worker                                    }
159*89c4ff92SAndroid Build Coastguard Worker                                    dim {
160*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 6
161*89c4ff92SAndroid Build Coastguard Worker                                    }
162*89c4ff92SAndroid Build Coastguard Worker                                }
163*89c4ff92SAndroid Build Coastguard Worker                             }
164*89c4ff92SAndroid Build Coastguard Worker                           }
165*89c4ff92SAndroid Build Coastguard Worker                        }
166*89c4ff92SAndroid Build Coastguard Worker                     }
167*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
168*89c4ff92SAndroid Build Coastguard Worker                       version: 7
169*89c4ff92SAndroid Build Coastguard Worker                     })";
170*89c4ff92SAndroid Build Coastguard Worker     }
171*89c4ff92SAndroid Build Coastguard Worker };
172*89c4ff92SAndroid Build Coastguard Worker 
173*89c4ff92SAndroid Build Coastguard Worker struct ReshapeValidFixture : ReshapeMainFixture
174*89c4ff92SAndroid Build Coastguard Worker {
ReshapeValidFixtureReshapeValidFixture175*89c4ff92SAndroid Build Coastguard Worker     ReshapeValidFixture() : ReshapeMainFixture("1") {
176*89c4ff92SAndroid Build Coastguard Worker         Setup();
177*89c4ff92SAndroid Build Coastguard Worker     }
178*89c4ff92SAndroid Build Coastguard Worker };
179*89c4ff92SAndroid Build Coastguard Worker 
180*89c4ff92SAndroid Build Coastguard Worker struct ReshapeValidRank4Fixture : ReshapeRank4Fixture
181*89c4ff92SAndroid Build Coastguard Worker {
ReshapeValidRank4FixtureReshapeValidRank4Fixture182*89c4ff92SAndroid Build Coastguard Worker     ReshapeValidRank4Fixture() : ReshapeRank4Fixture("1") {
183*89c4ff92SAndroid Build Coastguard Worker         Setup();
184*89c4ff92SAndroid Build Coastguard Worker     }
185*89c4ff92SAndroid Build Coastguard Worker };
186*89c4ff92SAndroid Build Coastguard Worker 
187*89c4ff92SAndroid Build Coastguard Worker struct ReshapeInvalidFixture : ReshapeMainFixture
188*89c4ff92SAndroid Build Coastguard Worker {
ReshapeInvalidFixtureReshapeInvalidFixture189*89c4ff92SAndroid Build Coastguard Worker     ReshapeInvalidFixture() : ReshapeMainFixture("10") { }
190*89c4ff92SAndroid Build Coastguard Worker };
191*89c4ff92SAndroid Build Coastguard Worker 
192*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeValidFixture, "ValidReshapeTest")
193*89c4ff92SAndroid Build Coastguard Worker {
194*89c4ff92SAndroid Build Coastguard Worker     RunTest<2>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f }}}, {{"Output", { 0.0f, 1.0f, 2.0f, 3.0f }}});
195*89c4ff92SAndroid Build Coastguard Worker }
196*89c4ff92SAndroid Build Coastguard Worker 
197*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeValidRank4Fixture, "ValidRank4ReshapeTest")
198*89c4ff92SAndroid Build Coastguard Worker {
199*89c4ff92SAndroid Build Coastguard Worker     RunTest<2>(
200*89c4ff92SAndroid Build Coastguard Worker         {{"Input",
201*89c4ff92SAndroid Build Coastguard Worker                    {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
202*89c4ff92SAndroid Build Coastguard Worker                     1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
203*89c4ff92SAndroid Build Coastguard Worker                     1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}},
204*89c4ff92SAndroid Build Coastguard Worker         {{"Output",
205*89c4ff92SAndroid Build Coastguard Worker                     {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
206*89c4ff92SAndroid Build Coastguard Worker                      1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
207*89c4ff92SAndroid Build Coastguard Worker                      1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}});
208*89c4ff92SAndroid Build Coastguard Worker }
209*89c4ff92SAndroid Build Coastguard Worker 
210*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeInvalidFixture, "IncorrectDataTypeReshape")
211*89c4ff92SAndroid Build Coastguard Worker {
212*89c4ff92SAndroid Build Coastguard Worker    CHECK_THROWS_AS(Setup(), armnn::ParseException);
213*89c4ff92SAndroid Build Coastguard Worker }
214*89c4ff92SAndroid Build Coastguard Worker 
215*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNegativeReshapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
216*89c4ff92SAndroid Build Coastguard Worker {
ReshapeNegativeReshapeFixtureReshapeNegativeReshapeFixture217*89c4ff92SAndroid Build Coastguard Worker     ReshapeNegativeReshapeFixture(const std::vector<int>& inputShape,
218*89c4ff92SAndroid Build Coastguard Worker                                   const std::vector<int>& shapeInputShape,
219*89c4ff92SAndroid Build Coastguard Worker                                   const std::vector<int>& outputShape,
220*89c4ff92SAndroid Build Coastguard Worker                                   const std::string& shape)
221*89c4ff92SAndroid Build Coastguard Worker         {
222*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
223*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
224*89c4ff92SAndroid Build Coastguard Worker                    producer_name: "onnx-example"
225*89c4ff92SAndroid Build Coastguard Worker                    graph {
226*89c4ff92SAndroid Build Coastguard Worker                      name:  "ReshapeGrapn"
227*89c4ff92SAndroid Build Coastguard Worker                      input {
228*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
229*89c4ff92SAndroid Build Coastguard Worker                         type {
230*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
231*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
232*89c4ff92SAndroid Build Coastguard Worker                             shape {
233*89c4ff92SAndroid Build Coastguard Worker                               )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
234*89c4ff92SAndroid Build Coastguard Worker                             }
235*89c4ff92SAndroid Build Coastguard Worker                           }
236*89c4ff92SAndroid Build Coastguard Worker                         }
237*89c4ff92SAndroid Build Coastguard Worker                       }
238*89c4ff92SAndroid Build Coastguard Worker                       input {
239*89c4ff92SAndroid Build Coastguard Worker                          name: "Shape"
240*89c4ff92SAndroid Build Coastguard Worker                          type {
241*89c4ff92SAndroid Build Coastguard Worker                            tensor_type {
242*89c4ff92SAndroid Build Coastguard Worker                              elem_type: 7
243*89c4ff92SAndroid Build Coastguard Worker                              shape {
244*89c4ff92SAndroid Build Coastguard Worker                                )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"(
245*89c4ff92SAndroid Build Coastguard Worker                              }
246*89c4ff92SAndroid Build Coastguard Worker                            }
247*89c4ff92SAndroid Build Coastguard Worker                          }
248*89c4ff92SAndroid Build Coastguard Worker                        }
249*89c4ff92SAndroid Build Coastguard Worker                      node {
250*89c4ff92SAndroid Build Coastguard Worker                          input: "Input"
251*89c4ff92SAndroid Build Coastguard Worker                          input: "Shape"
252*89c4ff92SAndroid Build Coastguard Worker                          output: "Output"
253*89c4ff92SAndroid Build Coastguard Worker                          name: "reshape"
254*89c4ff92SAndroid Build Coastguard Worker                          op_type: "Reshape"
255*89c4ff92SAndroid Build Coastguard Worker                       }
256*89c4ff92SAndroid Build Coastguard Worker                       initializer {
257*89c4ff92SAndroid Build Coastguard Worker                         dims: 2
258*89c4ff92SAndroid Build Coastguard Worker                         data_type: 7
259*89c4ff92SAndroid Build Coastguard Worker                         )" + shape + R"(
260*89c4ff92SAndroid Build Coastguard Worker                         name: "Shape"
261*89c4ff92SAndroid Build Coastguard Worker                      }
262*89c4ff92SAndroid Build Coastguard Worker                       output {
263*89c4ff92SAndroid Build Coastguard Worker                           name: "Output"
264*89c4ff92SAndroid Build Coastguard Worker                           type {
265*89c4ff92SAndroid Build Coastguard Worker                              tensor_type {
266*89c4ff92SAndroid Build Coastguard Worker                                elem_type: 1
267*89c4ff92SAndroid Build Coastguard Worker                                shape {
268*89c4ff92SAndroid Build Coastguard Worker                                  )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
269*89c4ff92SAndroid Build Coastguard Worker                                }
270*89c4ff92SAndroid Build Coastguard Worker                             }
271*89c4ff92SAndroid Build Coastguard Worker                           }
272*89c4ff92SAndroid Build Coastguard Worker                        }
273*89c4ff92SAndroid Build Coastguard Worker                     }
274*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
275*89c4ff92SAndroid Build Coastguard Worker                       version: 7
276*89c4ff92SAndroid Build Coastguard Worker                    })";
277*89c4ff92SAndroid Build Coastguard Worker     }
278*89c4ff92SAndroid Build Coastguard Worker };
279*89c4ff92SAndroid Build Coastguard Worker 
280*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNegativeReshape1DFixture : ReshapeNegativeReshapeFixture
281*89c4ff92SAndroid Build Coastguard Worker {
ReshapeNegativeReshape1DFixtureReshapeNegativeReshape1DFixture282*89c4ff92SAndroid Build Coastguard Worker     ReshapeNegativeReshape1DFixture() : ReshapeNegativeReshapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 }, "int64_data: -1")
283*89c4ff92SAndroid Build Coastguard Worker     {
284*89c4ff92SAndroid Build Coastguard Worker         Setup();
285*89c4ff92SAndroid Build Coastguard Worker     }
286*89c4ff92SAndroid Build Coastguard Worker };
287*89c4ff92SAndroid Build Coastguard Worker 
288*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNegativeReshape2DFixture : ReshapeNegativeReshapeFixture
289*89c4ff92SAndroid Build Coastguard Worker {
ReshapeNegativeReshape2DFixtureReshapeNegativeReshape2DFixture290*89c4ff92SAndroid Build Coastguard Worker     ReshapeNegativeReshape2DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 },
291*89c4ff92SAndroid Build Coastguard Worker                                                                       { 2 },
292*89c4ff92SAndroid Build Coastguard Worker                                                                       { 2, 6 },
293*89c4ff92SAndroid Build Coastguard Worker                                                                       "int64_data: -1  int64_data: 6")
294*89c4ff92SAndroid Build Coastguard Worker     {
295*89c4ff92SAndroid Build Coastguard Worker         Setup();
296*89c4ff92SAndroid Build Coastguard Worker     }
297*89c4ff92SAndroid Build Coastguard Worker };
298*89c4ff92SAndroid Build Coastguard Worker 
299*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNegativeReshape3DFixture : ReshapeNegativeReshapeFixture
300*89c4ff92SAndroid Build Coastguard Worker {
ReshapeNegativeReshape3DFixtureReshapeNegativeReshape3DFixture301*89c4ff92SAndroid Build Coastguard Worker     ReshapeNegativeReshape3DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 },
302*89c4ff92SAndroid Build Coastguard Worker                                                                       { 3 },
303*89c4ff92SAndroid Build Coastguard Worker                                                                       { 3, 1, 4 },
304*89c4ff92SAndroid Build Coastguard Worker                                                                       "int64_data: 3  int64_data: -1  int64_data: 4")
305*89c4ff92SAndroid Build Coastguard Worker     {
306*89c4ff92SAndroid Build Coastguard Worker         Setup();
307*89c4ff92SAndroid Build Coastguard Worker     }
308*89c4ff92SAndroid Build Coastguard Worker };
309*89c4ff92SAndroid Build Coastguard Worker 
310*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNegativeReshape4DFixture : ReshapeNegativeReshapeFixture
311*89c4ff92SAndroid Build Coastguard Worker {
ReshapeNegativeReshape4DFixtureReshapeNegativeReshape4DFixture312*89c4ff92SAndroid Build Coastguard Worker     ReshapeNegativeReshape4DFixture() : ReshapeNegativeReshapeFixture(
313*89c4ff92SAndroid Build Coastguard Worker         { 2, 3, 1, 2 },
314*89c4ff92SAndroid Build Coastguard Worker         { 4 },
315*89c4ff92SAndroid Build Coastguard Worker         { 3, 1, 2, 2 },
316*89c4ff92SAndroid Build Coastguard Worker         "int64_data: 3  int64_data: 1  int64_data: 2  int64_data: -1")
317*89c4ff92SAndroid Build Coastguard Worker     {
318*89c4ff92SAndroid Build Coastguard Worker         Setup();
319*89c4ff92SAndroid Build Coastguard Worker     }
320*89c4ff92SAndroid Build Coastguard Worker };
321*89c4ff92SAndroid Build Coastguard Worker 
322*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeNegativeReshape1DFixture, "ReshapeNegativeReshape1DTest")
323*89c4ff92SAndroid Build Coastguard Worker {
324*89c4ff92SAndroid Build Coastguard Worker     RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
325*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
326*89c4ff92SAndroid Build Coastguard Worker }
327*89c4ff92SAndroid Build Coastguard Worker 
328*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeNegativeReshape2DFixture, "ReshapeNegativeReshape2DTest")
329*89c4ff92SAndroid Build Coastguard Worker {
330*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
331*89c4ff92SAndroid Build Coastguard Worker                                    7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
332*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
333*89c4ff92SAndroid Build Coastguard Worker                                     7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
334*89c4ff92SAndroid Build Coastguard Worker }
335*89c4ff92SAndroid Build Coastguard Worker 
336*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeNegativeReshape3DFixture, "ReshapeNegativeReshape3DTest")
337*89c4ff92SAndroid Build Coastguard Worker {
338*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
339*89c4ff92SAndroid Build Coastguard Worker                                    7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
340*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
341*89c4ff92SAndroid Build Coastguard Worker                                     7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
342*89c4ff92SAndroid Build Coastguard Worker }
343*89c4ff92SAndroid Build Coastguard Worker 
344*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeNegativeReshape4DFixture, "ReshapeNegativeReshape4DTest")
345*89c4ff92SAndroid Build Coastguard Worker {
346*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
347*89c4ff92SAndroid Build Coastguard Worker                                    7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
348*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
349*89c4ff92SAndroid Build Coastguard Worker                                     7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
350*89c4ff92SAndroid Build Coastguard Worker }
351*89c4ff92SAndroid Build Coastguard Worker 
352*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNonConstShapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
353*89c4ff92SAndroid Build Coastguard Worker {
ReshapeNonConstShapeFixtureReshapeNonConstShapeFixture354*89c4ff92SAndroid Build Coastguard Worker     ReshapeNonConstShapeFixture(const std::vector<int>& inputShape,
355*89c4ff92SAndroid Build Coastguard Worker                                 const std::vector<int>& shapeInputShape,
356*89c4ff92SAndroid Build Coastguard Worker                                 const std::vector<int>& outputShape)
357*89c4ff92SAndroid Build Coastguard Worker     {
358*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
359*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
360*89c4ff92SAndroid Build Coastguard Worker                    producer_name: "onnx-example"
361*89c4ff92SAndroid Build Coastguard Worker                    graph {
362*89c4ff92SAndroid Build Coastguard Worker                      name:  "ReshapeGrapn"
363*89c4ff92SAndroid Build Coastguard Worker                      input {
364*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
365*89c4ff92SAndroid Build Coastguard Worker                         type {
366*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
367*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
368*89c4ff92SAndroid Build Coastguard Worker                             shape {
369*89c4ff92SAndroid Build Coastguard Worker                               )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
370*89c4ff92SAndroid Build Coastguard Worker                             }
371*89c4ff92SAndroid Build Coastguard Worker                           }
372*89c4ff92SAndroid Build Coastguard Worker                         }
373*89c4ff92SAndroid Build Coastguard Worker                       }
374*89c4ff92SAndroid Build Coastguard Worker                       input {
375*89c4ff92SAndroid Build Coastguard Worker                          name: "Shape"
376*89c4ff92SAndroid Build Coastguard Worker                          type {
377*89c4ff92SAndroid Build Coastguard Worker                            tensor_type {
378*89c4ff92SAndroid Build Coastguard Worker                              elem_type: 7
379*89c4ff92SAndroid Build Coastguard Worker                              shape {
380*89c4ff92SAndroid Build Coastguard Worker                                )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"(
381*89c4ff92SAndroid Build Coastguard Worker                              }
382*89c4ff92SAndroid Build Coastguard Worker                            }
383*89c4ff92SAndroid Build Coastguard Worker                          }
384*89c4ff92SAndroid Build Coastguard Worker                        }
385*89c4ff92SAndroid Build Coastguard Worker                      node {
386*89c4ff92SAndroid Build Coastguard Worker                          input: "Input"
387*89c4ff92SAndroid Build Coastguard Worker                          input: "Shape"
388*89c4ff92SAndroid Build Coastguard Worker                          output: "Output"
389*89c4ff92SAndroid Build Coastguard Worker                          name: "reshape"
390*89c4ff92SAndroid Build Coastguard Worker                          op_type: "Reshape"
391*89c4ff92SAndroid Build Coastguard Worker                       }
392*89c4ff92SAndroid Build Coastguard Worker                       output {
393*89c4ff92SAndroid Build Coastguard Worker                           name: "Output"
394*89c4ff92SAndroid Build Coastguard Worker                           type {
395*89c4ff92SAndroid Build Coastguard Worker                              tensor_type {
396*89c4ff92SAndroid Build Coastguard Worker                                elem_type: 1
397*89c4ff92SAndroid Build Coastguard Worker                                shape {
398*89c4ff92SAndroid Build Coastguard Worker                                  )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
399*89c4ff92SAndroid Build Coastguard Worker                                }
400*89c4ff92SAndroid Build Coastguard Worker                             }
401*89c4ff92SAndroid Build Coastguard Worker                           }
402*89c4ff92SAndroid Build Coastguard Worker                        }
403*89c4ff92SAndroid Build Coastguard Worker                     }
404*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
405*89c4ff92SAndroid Build Coastguard Worker                       version: 7
406*89c4ff92SAndroid Build Coastguard Worker                    })";
407*89c4ff92SAndroid Build Coastguard Worker     }
408*89c4ff92SAndroid Build Coastguard Worker };
409*89c4ff92SAndroid Build Coastguard Worker 
410*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNonConst1DShapeFixture : ReshapeNonConstShapeFixture
411*89c4ff92SAndroid Build Coastguard Worker {
ReshapeNonConst1DShapeFixtureReshapeNonConst1DShapeFixture412*89c4ff92SAndroid Build Coastguard Worker     ReshapeNonConst1DShapeFixture() : ReshapeNonConstShapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 })
413*89c4ff92SAndroid Build Coastguard Worker     {
414*89c4ff92SAndroid Build Coastguard Worker         Setup();
415*89c4ff92SAndroid Build Coastguard Worker     }
416*89c4ff92SAndroid Build Coastguard Worker };
417*89c4ff92SAndroid Build Coastguard Worker 
418*89c4ff92SAndroid Build Coastguard Worker struct ReshapeNonConst2DShapeFixture : ReshapeNonConstShapeFixture
419*89c4ff92SAndroid Build Coastguard Worker {
ReshapeNonConst2DShapeFixtureReshapeNonConst2DShapeFixture420*89c4ff92SAndroid Build Coastguard Worker     ReshapeNonConst2DShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 2 }, { 2, 12 })
421*89c4ff92SAndroid Build Coastguard Worker     {
422*89c4ff92SAndroid Build Coastguard Worker         Setup();
423*89c4ff92SAndroid Build Coastguard Worker     }
424*89c4ff92SAndroid Build Coastguard Worker };
425*89c4ff92SAndroid Build Coastguard Worker 
426*89c4ff92SAndroid Build Coastguard Worker struct ReshapeInvalidNonConstShapeFixture : ReshapeNonConstShapeFixture
427*89c4ff92SAndroid Build Coastguard Worker {
ReshapeInvalidNonConstShapeFixtureReshapeInvalidNonConstShapeFixture428*89c4ff92SAndroid Build Coastguard Worker     ReshapeInvalidNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 3 }, { 2, 3, 4 })
429*89c4ff92SAndroid Build Coastguard Worker     {
430*89c4ff92SAndroid Build Coastguard Worker     }
431*89c4ff92SAndroid Build Coastguard Worker };
432*89c4ff92SAndroid Build Coastguard Worker 
433*89c4ff92SAndroid Build Coastguard Worker struct ReshapeInvalidDimNonConstShapeFixture : ReshapeNonConstShapeFixture
434*89c4ff92SAndroid Build Coastguard Worker {
ReshapeInvalidDimNonConstShapeFixtureReshapeInvalidDimNonConstShapeFixture435*89c4ff92SAndroid Build Coastguard Worker     ReshapeInvalidDimNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 1, 2 }, { 2, 3, 4 })
436*89c4ff92SAndroid Build Coastguard Worker     {
437*89c4ff92SAndroid Build Coastguard Worker     }
438*89c4ff92SAndroid Build Coastguard Worker };
439*89c4ff92SAndroid Build Coastguard Worker 
440*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeNonConst1DShapeFixture, "ReshapeNonConst1DShapeTest")
441*89c4ff92SAndroid Build Coastguard Worker {
442*89c4ff92SAndroid Build Coastguard Worker     RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
443*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
444*89c4ff92SAndroid Build Coastguard Worker }
445*89c4ff92SAndroid Build Coastguard Worker 
446*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeNonConst2DShapeFixture, "ReshapeNonConst2DShapeTest")
447*89c4ff92SAndroid Build Coastguard Worker {
448*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
449*89c4ff92SAndroid Build Coastguard Worker                                    7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
450*89c4ff92SAndroid Build Coastguard Worker                                    13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
451*89c4ff92SAndroid Build Coastguard Worker                                    19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}},
452*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
453*89c4ff92SAndroid Build Coastguard Worker                                     7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
454*89c4ff92SAndroid Build Coastguard Worker                                     13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
455*89c4ff92SAndroid Build Coastguard Worker                                     19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}});
456*89c4ff92SAndroid Build Coastguard Worker }
457*89c4ff92SAndroid Build Coastguard Worker 
458*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeInvalidNonConstShapeFixture, "ReshapeInvalidNonConstShapeTest")
459*89c4ff92SAndroid Build Coastguard Worker {
460*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(Setup(), armnn::ParseException);
461*89c4ff92SAndroid Build Coastguard Worker }
462*89c4ff92SAndroid Build Coastguard Worker 
463*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeInvalidDimNonConstShapeFixture, "ReshapeInvalidDimNonConstShapeTest")
464*89c4ff92SAndroid Build Coastguard Worker {
465*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(Setup(), armnn::ParseException);
466*89c4ff92SAndroid Build Coastguard Worker }
467*89c4ff92SAndroid Build Coastguard Worker 
468*89c4ff92SAndroid Build Coastguard Worker }
469