xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Flatten.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include  "ParserPrototxtFixture.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_Flatter")
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker struct FlattenMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12*89c4ff92SAndroid Build Coastguard Worker {
FlattenMainFixtureFlattenMainFixture13*89c4ff92SAndroid Build Coastguard Worker     FlattenMainFixture(const std::string& dataType)
14*89c4ff92SAndroid Build Coastguard Worker     {
15*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
16*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
17*89c4ff92SAndroid Build Coastguard Worker                    producer_name:  "CNTK"
18*89c4ff92SAndroid Build Coastguard Worker                    producer_version:  "2.5.1"
19*89c4ff92SAndroid Build Coastguard Worker                    domain:  "ai.cntk"
20*89c4ff92SAndroid Build Coastguard Worker                    model_version: 1
21*89c4ff92SAndroid Build Coastguard Worker                    graph {
22*89c4ff92SAndroid Build Coastguard Worker                      name:  "CNTKGraph"
23*89c4ff92SAndroid Build Coastguard Worker                      input {
24*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
25*89c4ff92SAndroid Build Coastguard Worker                         type {
26*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
27*89c4ff92SAndroid Build Coastguard Worker                             elem_type: )" + dataType + R"(
28*89c4ff92SAndroid Build Coastguard Worker                             shape {
29*89c4ff92SAndroid Build Coastguard Worker                               dim {
30*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 2
31*89c4ff92SAndroid Build Coastguard Worker                               }
32*89c4ff92SAndroid Build Coastguard Worker                               dim {
33*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 2
34*89c4ff92SAndroid Build Coastguard Worker                               }
35*89c4ff92SAndroid Build Coastguard Worker                               dim {
36*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
37*89c4ff92SAndroid Build Coastguard Worker                               }
38*89c4ff92SAndroid Build Coastguard Worker                               dim {
39*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
40*89c4ff92SAndroid Build Coastguard Worker                               }
41*89c4ff92SAndroid Build Coastguard Worker                             }
42*89c4ff92SAndroid Build Coastguard Worker                           }
43*89c4ff92SAndroid Build Coastguard Worker                         }
44*89c4ff92SAndroid Build Coastguard Worker                       }
45*89c4ff92SAndroid Build Coastguard Worker                      node {
46*89c4ff92SAndroid Build Coastguard Worker                          input: "Input"
47*89c4ff92SAndroid Build Coastguard Worker                          output: "Output"
48*89c4ff92SAndroid Build Coastguard Worker                          name: "flatten"
49*89c4ff92SAndroid Build Coastguard Worker                          op_type: "Flatten"
50*89c4ff92SAndroid Build Coastguard Worker                          attribute {
51*89c4ff92SAndroid Build Coastguard Worker                            name: "axis"
52*89c4ff92SAndroid Build Coastguard Worker                            i: 2
53*89c4ff92SAndroid Build Coastguard Worker                            type: INT
54*89c4ff92SAndroid Build Coastguard Worker                          }
55*89c4ff92SAndroid Build Coastguard Worker                       }
56*89c4ff92SAndroid Build Coastguard Worker                       output {
57*89c4ff92SAndroid Build Coastguard Worker                           name: "Output"
58*89c4ff92SAndroid Build Coastguard Worker                           type {
59*89c4ff92SAndroid Build Coastguard Worker                              tensor_type {
60*89c4ff92SAndroid Build Coastguard Worker                                elem_type: 1
61*89c4ff92SAndroid Build Coastguard Worker                                shape {
62*89c4ff92SAndroid Build Coastguard Worker                                    dim {
63*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 4
64*89c4ff92SAndroid Build Coastguard Worker                                    }
65*89c4ff92SAndroid Build Coastguard Worker                                    dim {
66*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 9
67*89c4ff92SAndroid Build Coastguard Worker                                    }
68*89c4ff92SAndroid Build Coastguard Worker                                }
69*89c4ff92SAndroid Build Coastguard Worker                             }
70*89c4ff92SAndroid Build Coastguard Worker                           }
71*89c4ff92SAndroid Build Coastguard Worker                        }
72*89c4ff92SAndroid Build Coastguard Worker                     }
73*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
74*89c4ff92SAndroid Build Coastguard Worker                       version: 7
75*89c4ff92SAndroid Build Coastguard Worker                     })";
76*89c4ff92SAndroid Build Coastguard Worker     }
77*89c4ff92SAndroid Build Coastguard Worker };
78*89c4ff92SAndroid Build Coastguard Worker 
79*89c4ff92SAndroid Build Coastguard Worker struct FlattenDefaultAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
80*89c4ff92SAndroid Build Coastguard Worker {
FlattenDefaultAxisFixtureFlattenDefaultAxisFixture81*89c4ff92SAndroid Build Coastguard Worker     FlattenDefaultAxisFixture(const std::string& dataType)
82*89c4ff92SAndroid Build Coastguard Worker     {
83*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
84*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
85*89c4ff92SAndroid Build Coastguard Worker                    producer_name:  "CNTK"
86*89c4ff92SAndroid Build Coastguard Worker                    producer_version:  "2.5.1"
87*89c4ff92SAndroid Build Coastguard Worker                    domain:  "ai.cntk"
88*89c4ff92SAndroid Build Coastguard Worker                    model_version: 1
89*89c4ff92SAndroid Build Coastguard Worker                    graph {
90*89c4ff92SAndroid Build Coastguard Worker                      name:  "CNTKGraph"
91*89c4ff92SAndroid Build Coastguard Worker                      input {
92*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
93*89c4ff92SAndroid Build Coastguard Worker                         type {
94*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
95*89c4ff92SAndroid Build Coastguard Worker                             elem_type: )" + dataType + R"(
96*89c4ff92SAndroid Build Coastguard Worker                             shape {
97*89c4ff92SAndroid Build Coastguard Worker                               dim {
98*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 2
99*89c4ff92SAndroid Build Coastguard Worker                               }
100*89c4ff92SAndroid Build Coastguard Worker                               dim {
101*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 2
102*89c4ff92SAndroid Build Coastguard Worker                               }
103*89c4ff92SAndroid Build Coastguard Worker                               dim {
104*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
105*89c4ff92SAndroid Build Coastguard Worker                               }
106*89c4ff92SAndroid Build Coastguard Worker                               dim {
107*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
108*89c4ff92SAndroid Build Coastguard Worker                               }
109*89c4ff92SAndroid Build Coastguard Worker                             }
110*89c4ff92SAndroid Build Coastguard Worker                           }
111*89c4ff92SAndroid Build Coastguard Worker                         }
112*89c4ff92SAndroid Build Coastguard Worker                       }
113*89c4ff92SAndroid Build Coastguard Worker                      node {
114*89c4ff92SAndroid Build Coastguard Worker                          input: "Input"
115*89c4ff92SAndroid Build Coastguard Worker                          output: "Output"
116*89c4ff92SAndroid Build Coastguard Worker                          name: "flatten"
117*89c4ff92SAndroid Build Coastguard Worker                          op_type: "Flatten"
118*89c4ff92SAndroid Build Coastguard Worker                       }
119*89c4ff92SAndroid Build Coastguard Worker                       output {
120*89c4ff92SAndroid Build Coastguard Worker                           name: "Output"
121*89c4ff92SAndroid Build Coastguard Worker                           type {
122*89c4ff92SAndroid Build Coastguard Worker                              tensor_type {
123*89c4ff92SAndroid Build Coastguard Worker                                elem_type: 1
124*89c4ff92SAndroid Build Coastguard Worker                                shape {
125*89c4ff92SAndroid Build Coastguard Worker                                    dim {
126*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 2
127*89c4ff92SAndroid Build Coastguard Worker                                    }
128*89c4ff92SAndroid Build Coastguard Worker                                    dim {
129*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 18
130*89c4ff92SAndroid Build Coastguard Worker                                    }
131*89c4ff92SAndroid Build Coastguard Worker                                }
132*89c4ff92SAndroid Build Coastguard Worker                             }
133*89c4ff92SAndroid Build Coastguard Worker                           }
134*89c4ff92SAndroid Build Coastguard Worker                        }
135*89c4ff92SAndroid Build Coastguard Worker                     }
136*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
137*89c4ff92SAndroid Build Coastguard Worker                       version: 7
138*89c4ff92SAndroid Build Coastguard Worker                     })";
139*89c4ff92SAndroid Build Coastguard Worker     }
140*89c4ff92SAndroid Build Coastguard Worker };
141*89c4ff92SAndroid Build Coastguard Worker 
142*89c4ff92SAndroid Build Coastguard Worker struct FlattenAxisZeroFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
143*89c4ff92SAndroid Build Coastguard Worker {
FlattenAxisZeroFixtureFlattenAxisZeroFixture144*89c4ff92SAndroid Build Coastguard Worker     FlattenAxisZeroFixture(const std::string& dataType)
145*89c4ff92SAndroid Build Coastguard Worker     {
146*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
147*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
148*89c4ff92SAndroid Build Coastguard Worker                    producer_name:  "CNTK"
149*89c4ff92SAndroid Build Coastguard Worker                    producer_version:  "2.5.1"
150*89c4ff92SAndroid Build Coastguard Worker                    domain:  "ai.cntk"
151*89c4ff92SAndroid Build Coastguard Worker                    model_version: 1
152*89c4ff92SAndroid Build Coastguard Worker                    graph {
153*89c4ff92SAndroid Build Coastguard Worker                      name:  "CNTKGraph"
154*89c4ff92SAndroid Build Coastguard Worker                      input {
155*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
156*89c4ff92SAndroid Build Coastguard Worker                         type {
157*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
158*89c4ff92SAndroid Build Coastguard Worker                             elem_type: )" + dataType + R"(
159*89c4ff92SAndroid Build Coastguard Worker                             shape {
160*89c4ff92SAndroid Build Coastguard Worker                               dim {
161*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 2
162*89c4ff92SAndroid Build Coastguard Worker                               }
163*89c4ff92SAndroid Build Coastguard Worker                               dim {
164*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 2
165*89c4ff92SAndroid Build Coastguard Worker                               }
166*89c4ff92SAndroid Build Coastguard Worker                               dim {
167*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
168*89c4ff92SAndroid Build Coastguard Worker                               }
169*89c4ff92SAndroid Build Coastguard Worker                               dim {
170*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
171*89c4ff92SAndroid Build Coastguard Worker                               }
172*89c4ff92SAndroid Build Coastguard Worker                             }
173*89c4ff92SAndroid Build Coastguard Worker                           }
174*89c4ff92SAndroid Build Coastguard Worker                         }
175*89c4ff92SAndroid Build Coastguard Worker                       }
176*89c4ff92SAndroid Build Coastguard Worker                      node {
177*89c4ff92SAndroid Build Coastguard Worker                          input: "Input"
178*89c4ff92SAndroid Build Coastguard Worker                          output: "Output"
179*89c4ff92SAndroid Build Coastguard Worker                          name: "flatten"
180*89c4ff92SAndroid Build Coastguard Worker                          op_type: "Flatten"
181*89c4ff92SAndroid Build Coastguard Worker                          attribute {
182*89c4ff92SAndroid Build Coastguard Worker                            name: "axis"
183*89c4ff92SAndroid Build Coastguard Worker                            i: 0
184*89c4ff92SAndroid Build Coastguard Worker                            type: INT
185*89c4ff92SAndroid Build Coastguard Worker                          }
186*89c4ff92SAndroid Build Coastguard Worker                       }
187*89c4ff92SAndroid Build Coastguard Worker                       output {
188*89c4ff92SAndroid Build Coastguard Worker                           name: "Output"
189*89c4ff92SAndroid Build Coastguard Worker                           type {
190*89c4ff92SAndroid Build Coastguard Worker                              tensor_type {
191*89c4ff92SAndroid Build Coastguard Worker                                elem_type: 1
192*89c4ff92SAndroid Build Coastguard Worker                                shape {
193*89c4ff92SAndroid Build Coastguard Worker                                    dim {
194*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 1
195*89c4ff92SAndroid Build Coastguard Worker                                    }
196*89c4ff92SAndroid Build Coastguard Worker                                    dim {
197*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 36
198*89c4ff92SAndroid Build Coastguard Worker                                    }
199*89c4ff92SAndroid Build Coastguard Worker                                }
200*89c4ff92SAndroid Build Coastguard Worker                             }
201*89c4ff92SAndroid Build Coastguard Worker                           }
202*89c4ff92SAndroid Build Coastguard Worker                        }
203*89c4ff92SAndroid Build Coastguard Worker                     }
204*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
205*89c4ff92SAndroid Build Coastguard Worker                       version: 7
206*89c4ff92SAndroid Build Coastguard Worker                     })";
207*89c4ff92SAndroid Build Coastguard Worker     }
208*89c4ff92SAndroid Build Coastguard Worker };
209*89c4ff92SAndroid Build Coastguard Worker 
210*89c4ff92SAndroid Build Coastguard Worker struct FlattenNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
211*89c4ff92SAndroid Build Coastguard Worker {
FlattenNegativeAxisFixtureFlattenNegativeAxisFixture212*89c4ff92SAndroid Build Coastguard Worker     FlattenNegativeAxisFixture(const std::string& dataType)
213*89c4ff92SAndroid Build Coastguard Worker     {
214*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
215*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
216*89c4ff92SAndroid Build Coastguard Worker                    producer_name:  "CNTK"
217*89c4ff92SAndroid Build Coastguard Worker                    producer_version:  "2.5.1"
218*89c4ff92SAndroid Build Coastguard Worker                    domain:  "ai.cntk"
219*89c4ff92SAndroid Build Coastguard Worker                    model_version: 1
220*89c4ff92SAndroid Build Coastguard Worker                    graph {
221*89c4ff92SAndroid Build Coastguard Worker                      name:  "CNTKGraph"
222*89c4ff92SAndroid Build Coastguard Worker                      input {
223*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
224*89c4ff92SAndroid Build Coastguard Worker                         type {
225*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
226*89c4ff92SAndroid Build Coastguard Worker                             elem_type: )" + dataType + R"(
227*89c4ff92SAndroid Build Coastguard Worker                             shape {
228*89c4ff92SAndroid Build Coastguard Worker                               dim {
229*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 2
230*89c4ff92SAndroid Build Coastguard Worker                               }
231*89c4ff92SAndroid Build Coastguard Worker                               dim {
232*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 2
233*89c4ff92SAndroid Build Coastguard Worker                               }
234*89c4ff92SAndroid Build Coastguard Worker                               dim {
235*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
236*89c4ff92SAndroid Build Coastguard Worker                               }
237*89c4ff92SAndroid Build Coastguard Worker                               dim {
238*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
239*89c4ff92SAndroid Build Coastguard Worker                               }
240*89c4ff92SAndroid Build Coastguard Worker                             }
241*89c4ff92SAndroid Build Coastguard Worker                           }
242*89c4ff92SAndroid Build Coastguard Worker                         }
243*89c4ff92SAndroid Build Coastguard Worker                       }
244*89c4ff92SAndroid Build Coastguard Worker                      node {
245*89c4ff92SAndroid Build Coastguard Worker                          input: "Input"
246*89c4ff92SAndroid Build Coastguard Worker                          output: "Output"
247*89c4ff92SAndroid Build Coastguard Worker                          name: "flatten"
248*89c4ff92SAndroid Build Coastguard Worker                          op_type: "Flatten"
249*89c4ff92SAndroid Build Coastguard Worker                          attribute {
250*89c4ff92SAndroid Build Coastguard Worker                            name: "axis"
251*89c4ff92SAndroid Build Coastguard Worker                            i: -1
252*89c4ff92SAndroid Build Coastguard Worker                            type: INT
253*89c4ff92SAndroid Build Coastguard Worker                          }
254*89c4ff92SAndroid Build Coastguard Worker                       }
255*89c4ff92SAndroid Build Coastguard Worker                       output {
256*89c4ff92SAndroid Build Coastguard Worker                           name: "Output"
257*89c4ff92SAndroid Build Coastguard Worker                           type {
258*89c4ff92SAndroid Build Coastguard Worker                              tensor_type {
259*89c4ff92SAndroid Build Coastguard Worker                                elem_type: 1
260*89c4ff92SAndroid Build Coastguard Worker                                shape {
261*89c4ff92SAndroid Build Coastguard Worker                                    dim {
262*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 12
263*89c4ff92SAndroid Build Coastguard Worker                                    }
264*89c4ff92SAndroid Build Coastguard Worker                                    dim {
265*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 3
266*89c4ff92SAndroid Build Coastguard Worker                                    }
267*89c4ff92SAndroid Build Coastguard Worker                                }
268*89c4ff92SAndroid Build Coastguard Worker                             }
269*89c4ff92SAndroid Build Coastguard Worker                           }
270*89c4ff92SAndroid Build Coastguard Worker                        }
271*89c4ff92SAndroid Build Coastguard Worker                     }
272*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
273*89c4ff92SAndroid Build Coastguard Worker                       version: 7
274*89c4ff92SAndroid Build Coastguard Worker                     })";
275*89c4ff92SAndroid Build Coastguard Worker     }
276*89c4ff92SAndroid Build Coastguard Worker };
277*89c4ff92SAndroid Build Coastguard Worker 
278*89c4ff92SAndroid Build Coastguard Worker struct FlattenInvalidNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
279*89c4ff92SAndroid Build Coastguard Worker {
FlattenInvalidNegativeAxisFixtureFlattenInvalidNegativeAxisFixture280*89c4ff92SAndroid Build Coastguard Worker     FlattenInvalidNegativeAxisFixture(const std::string& dataType)
281*89c4ff92SAndroid Build Coastguard Worker     {
282*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
283*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
284*89c4ff92SAndroid Build Coastguard Worker                    producer_name:  "CNTK"
285*89c4ff92SAndroid Build Coastguard Worker                    producer_version:  "2.5.1"
286*89c4ff92SAndroid Build Coastguard Worker                    domain:  "ai.cntk"
287*89c4ff92SAndroid Build Coastguard Worker                    model_version: 1
288*89c4ff92SAndroid Build Coastguard Worker                    graph {
289*89c4ff92SAndroid Build Coastguard Worker                      name:  "CNTKGraph"
290*89c4ff92SAndroid Build Coastguard Worker                      input {
291*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
292*89c4ff92SAndroid Build Coastguard Worker                         type {
293*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
294*89c4ff92SAndroid Build Coastguard Worker                             elem_type: )" + dataType + R"(
295*89c4ff92SAndroid Build Coastguard Worker                             shape {
296*89c4ff92SAndroid Build Coastguard Worker                               dim {
297*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 2
298*89c4ff92SAndroid Build Coastguard Worker                               }
299*89c4ff92SAndroid Build Coastguard Worker                               dim {
300*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 2
301*89c4ff92SAndroid Build Coastguard Worker                               }
302*89c4ff92SAndroid Build Coastguard Worker                               dim {
303*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
304*89c4ff92SAndroid Build Coastguard Worker                               }
305*89c4ff92SAndroid Build Coastguard Worker                               dim {
306*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
307*89c4ff92SAndroid Build Coastguard Worker                               }
308*89c4ff92SAndroid Build Coastguard Worker                             }
309*89c4ff92SAndroid Build Coastguard Worker                           }
310*89c4ff92SAndroid Build Coastguard Worker                         }
311*89c4ff92SAndroid Build Coastguard Worker                       }
312*89c4ff92SAndroid Build Coastguard Worker                      node {
313*89c4ff92SAndroid Build Coastguard Worker                          input: "Input"
314*89c4ff92SAndroid Build Coastguard Worker                          output: "Output"
315*89c4ff92SAndroid Build Coastguard Worker                          name: "flatten"
316*89c4ff92SAndroid Build Coastguard Worker                          op_type: "Flatten"
317*89c4ff92SAndroid Build Coastguard Worker                          attribute {
318*89c4ff92SAndroid Build Coastguard Worker                            name: "axis"
319*89c4ff92SAndroid Build Coastguard Worker                            i: -5
320*89c4ff92SAndroid Build Coastguard Worker                            type: INT
321*89c4ff92SAndroid Build Coastguard Worker                          }
322*89c4ff92SAndroid Build Coastguard Worker                       }
323*89c4ff92SAndroid Build Coastguard Worker                       output {
324*89c4ff92SAndroid Build Coastguard Worker                           name: "Output"
325*89c4ff92SAndroid Build Coastguard Worker                           type {
326*89c4ff92SAndroid Build Coastguard Worker                              tensor_type {
327*89c4ff92SAndroid Build Coastguard Worker                                elem_type: 1
328*89c4ff92SAndroid Build Coastguard Worker                                shape {
329*89c4ff92SAndroid Build Coastguard Worker                                    dim {
330*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 12
331*89c4ff92SAndroid Build Coastguard Worker                                    }
332*89c4ff92SAndroid Build Coastguard Worker                                    dim {
333*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 3
334*89c4ff92SAndroid Build Coastguard Worker                                    }
335*89c4ff92SAndroid Build Coastguard Worker                                }
336*89c4ff92SAndroid Build Coastguard Worker                             }
337*89c4ff92SAndroid Build Coastguard Worker                           }
338*89c4ff92SAndroid Build Coastguard Worker                        }
339*89c4ff92SAndroid Build Coastguard Worker                     }
340*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
341*89c4ff92SAndroid Build Coastguard Worker                       version: 7
342*89c4ff92SAndroid Build Coastguard Worker                     })";
343*89c4ff92SAndroid Build Coastguard Worker     }
344*89c4ff92SAndroid Build Coastguard Worker };
345*89c4ff92SAndroid Build Coastguard Worker 
346*89c4ff92SAndroid Build Coastguard Worker struct FlattenValidFixture : FlattenMainFixture
347*89c4ff92SAndroid Build Coastguard Worker {
FlattenValidFixtureFlattenValidFixture348*89c4ff92SAndroid Build Coastguard Worker     FlattenValidFixture() : FlattenMainFixture("1") {
349*89c4ff92SAndroid Build Coastguard Worker         Setup();
350*89c4ff92SAndroid Build Coastguard Worker     }
351*89c4ff92SAndroid Build Coastguard Worker };
352*89c4ff92SAndroid Build Coastguard Worker 
353*89c4ff92SAndroid Build Coastguard Worker struct FlattenDefaultValidFixture : FlattenDefaultAxisFixture
354*89c4ff92SAndroid Build Coastguard Worker {
FlattenDefaultValidFixtureFlattenDefaultValidFixture355*89c4ff92SAndroid Build Coastguard Worker     FlattenDefaultValidFixture() : FlattenDefaultAxisFixture("1") {
356*89c4ff92SAndroid Build Coastguard Worker         Setup();
357*89c4ff92SAndroid Build Coastguard Worker     }
358*89c4ff92SAndroid Build Coastguard Worker };
359*89c4ff92SAndroid Build Coastguard Worker 
360*89c4ff92SAndroid Build Coastguard Worker struct FlattenAxisZeroValidFixture : FlattenAxisZeroFixture
361*89c4ff92SAndroid Build Coastguard Worker {
FlattenAxisZeroValidFixtureFlattenAxisZeroValidFixture362*89c4ff92SAndroid Build Coastguard Worker     FlattenAxisZeroValidFixture() : FlattenAxisZeroFixture("1") {
363*89c4ff92SAndroid Build Coastguard Worker         Setup();
364*89c4ff92SAndroid Build Coastguard Worker     }
365*89c4ff92SAndroid Build Coastguard Worker };
366*89c4ff92SAndroid Build Coastguard Worker 
367*89c4ff92SAndroid Build Coastguard Worker struct FlattenNegativeAxisValidFixture : FlattenNegativeAxisFixture
368*89c4ff92SAndroid Build Coastguard Worker {
FlattenNegativeAxisValidFixtureFlattenNegativeAxisValidFixture369*89c4ff92SAndroid Build Coastguard Worker     FlattenNegativeAxisValidFixture() : FlattenNegativeAxisFixture("1") {
370*89c4ff92SAndroid Build Coastguard Worker         Setup();
371*89c4ff92SAndroid Build Coastguard Worker     }
372*89c4ff92SAndroid Build Coastguard Worker };
373*89c4ff92SAndroid Build Coastguard Worker 
374*89c4ff92SAndroid Build Coastguard Worker struct FlattenInvalidFixture : FlattenMainFixture
375*89c4ff92SAndroid Build Coastguard Worker {
FlattenInvalidFixtureFlattenInvalidFixture376*89c4ff92SAndroid Build Coastguard Worker     FlattenInvalidFixture() : FlattenMainFixture("10") { }
377*89c4ff92SAndroid Build Coastguard Worker };
378*89c4ff92SAndroid Build Coastguard Worker 
379*89c4ff92SAndroid Build Coastguard Worker struct FlattenInvalidAxisFixture : FlattenInvalidNegativeAxisFixture
380*89c4ff92SAndroid Build Coastguard Worker {
FlattenInvalidAxisFixtureFlattenInvalidAxisFixture381*89c4ff92SAndroid Build Coastguard Worker     FlattenInvalidAxisFixture() : FlattenInvalidNegativeAxisFixture("1") { }
382*89c4ff92SAndroid Build Coastguard Worker };
383*89c4ff92SAndroid Build Coastguard Worker 
384*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FlattenValidFixture, "ValidFlattenTest")
385*89c4ff92SAndroid Build Coastguard Worker {
386*89c4ff92SAndroid Build Coastguard Worker     RunTest<2>({{"Input",
387*89c4ff92SAndroid Build Coastguard Worker                           { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
388*89c4ff92SAndroid Build Coastguard Worker                             1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
389*89c4ff92SAndroid Build Coastguard Worker                             1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
390*89c4ff92SAndroid Build Coastguard Worker                 {{"Output",
391*89c4ff92SAndroid Build Coastguard Worker                           { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
392*89c4ff92SAndroid Build Coastguard Worker                             1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
393*89c4ff92SAndroid Build Coastguard Worker                             1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
394*89c4ff92SAndroid Build Coastguard Worker }
395*89c4ff92SAndroid Build Coastguard Worker 
396*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FlattenDefaultValidFixture, "ValidFlattenDefaultTest")
397*89c4ff92SAndroid Build Coastguard Worker {
398*89c4ff92SAndroid Build Coastguard Worker     RunTest<2>({{"Input",
399*89c4ff92SAndroid Build Coastguard Worker                     { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
400*89c4ff92SAndroid Build Coastguard Worker                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
401*89c4ff92SAndroid Build Coastguard Worker                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
402*89c4ff92SAndroid Build Coastguard Worker                {{"Output",
403*89c4ff92SAndroid Build Coastguard Worker                     { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
404*89c4ff92SAndroid Build Coastguard Worker                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
405*89c4ff92SAndroid Build Coastguard Worker                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
406*89c4ff92SAndroid Build Coastguard Worker }
407*89c4ff92SAndroid Build Coastguard Worker 
408*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FlattenAxisZeroValidFixture, "ValidFlattenAxisZeroTest")
409*89c4ff92SAndroid Build Coastguard Worker {
410*89c4ff92SAndroid Build Coastguard Worker     RunTest<2>({{"Input",
411*89c4ff92SAndroid Build Coastguard Worker                     { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
412*89c4ff92SAndroid Build Coastguard Worker                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
413*89c4ff92SAndroid Build Coastguard Worker                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
414*89c4ff92SAndroid Build Coastguard Worker                {{"Output",
415*89c4ff92SAndroid Build Coastguard Worker                     { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
416*89c4ff92SAndroid Build Coastguard Worker                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
417*89c4ff92SAndroid Build Coastguard Worker                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
418*89c4ff92SAndroid Build Coastguard Worker }
419*89c4ff92SAndroid Build Coastguard Worker 
420*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FlattenNegativeAxisValidFixture, "ValidFlattenNegativeAxisTest")
421*89c4ff92SAndroid Build Coastguard Worker {
422*89c4ff92SAndroid Build Coastguard Worker     RunTest<2>({{"Input",
423*89c4ff92SAndroid Build Coastguard Worker                     { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
424*89c4ff92SAndroid Build Coastguard Worker                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
425*89c4ff92SAndroid Build Coastguard Worker                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
426*89c4ff92SAndroid Build Coastguard Worker                {{"Output",
427*89c4ff92SAndroid Build Coastguard Worker                     { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
428*89c4ff92SAndroid Build Coastguard Worker                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
429*89c4ff92SAndroid Build Coastguard Worker                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
430*89c4ff92SAndroid Build Coastguard Worker }
431*89c4ff92SAndroid Build Coastguard Worker 
432*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FlattenInvalidFixture, "IncorrectDataTypeFlatten")
433*89c4ff92SAndroid Build Coastguard Worker {
434*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(Setup(), armnn::ParseException);
435*89c4ff92SAndroid Build Coastguard Worker }
436*89c4ff92SAndroid Build Coastguard Worker 
437*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FlattenInvalidAxisFixture, "IncorrectAxisFlatten")
438*89c4ff92SAndroid Build Coastguard Worker {
439*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(Setup(), armnn::ParseException);
440*89c4ff92SAndroid Build Coastguard Worker }
441*89c4ff92SAndroid Build Coastguard Worker 
442*89c4ff92SAndroid Build Coastguard Worker }
443