// xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Shape.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "armnnOnnxParser/IOnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"
#include "OnnxParserTestUtils.hpp"

#include <string>
#include <vector>

10 TEST_SUITE("OnnxParser_Shape")
11 {
12 
13 struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14 {
ShapeMainFixtureShapeMainFixture15     ShapeMainFixture(const std::string& inputType,
16                      const std::string& outputType,
17                      const std::string& outputDim,
18                      const std::vector<int>& inputShape)
19     {
20         m_Prototext = R"(
21                     ir_version: 8
22                     producer_name: "onnx-example"
23                     graph {
24                       node {
25                         input: "Input"
26                         output: "Output"
27                         op_type: "Shape"
28                       }
29                       name: "shape-model"
30                       input {
31                         name: "Input"
32                         type {
33                           tensor_type {
34                             elem_type: )" + inputType + R"(
35                             shape {
36                               )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
37                             }
38                           }
39                         }
40                       }
41                       output {
42                         name: "Output"
43                         type {
44                           tensor_type {
45                             elem_type: )" + outputType + R"(
46                             shape {
47                               dim {
48                                 dim_value: )" + outputDim + R"(
49                               }
50                             }
51                           }
52                         }
53                       }
54                     }
55                     opset_import {
56                       version: 10
57                     })";
58     }
59 };
60 
61 struct ShapeFloatFixture : ShapeMainFixture
62 {
ShapeFloatFixtureShapeFloatFixture63     ShapeFloatFixture() : ShapeMainFixture("1", "7", "4", { 1, 3, 1, 5 })
64     {
65         Setup();
66     }
67 };
68 
69 struct ShapeIntFixture : ShapeMainFixture
70 {
ShapeIntFixtureShapeIntFixture71     ShapeIntFixture() : ShapeMainFixture("7", "7", "4", { 1, 3, 1, 5 })
72     {
73         Setup();
74     }
75 };
76 
77 struct Shape3DFixture : ShapeMainFixture
78 {
Shape3DFixtureShape3DFixture79     Shape3DFixture() : ShapeMainFixture("1", "7", "3", { 3, 2, 3 })
80     {
81         Setup();
82     }
83 };
84 
85 struct Shape2DFixture : ShapeMainFixture
86 {
Shape2DFixtureShape2DFixture87     Shape2DFixture() : ShapeMainFixture("1", "7", "2", { 2, 3 })
88     {
89         Setup();
90     }
91 };
92 
93 struct Shape1DFixture : ShapeMainFixture
94 {
Shape1DFixtureShape1DFixture95     Shape1DFixture() : ShapeMainFixture("1", "7", "1", { 5 })
96     {
97         Setup();
98     }
99 };
100 
101 TEST_CASE_FIXTURE(ShapeFloatFixture, "FloatValidShapeTest")
102 {
103     RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f,
104                                  4.0f, 3.0f, 2.0f, 1.0f, 0.0f,
105                                  0.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}, {{"Output", { 1, 3, 1, 5 }}});
106 }
107 
108 TEST_CASE_FIXTURE(ShapeIntFixture, "IntValidShapeTest")
109 {
110     RunTest<1, int>({{"Input", { 0, 1, 2, 3, 4,
111                                  4, 3, 2, 1, 0,
112                                  0, 1, 2, 3, 4 }}}, {{"Output", { 1, 3, 1, 5 }}});
113 }
114 
115 TEST_CASE_FIXTURE(Shape3DFixture, "Shape3DTest")
116 {
117     RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
118                                  5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f,
119                                  0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 3, 2, 3 }}});
120 }
121 
122 TEST_CASE_FIXTURE(Shape2DFixture, "Shape2DTest")
123 {
124     RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 2, 3 }}});
125 }
126 
127 TEST_CASE_FIXTURE(Shape1DFixture, "Shape1DTest")
128 {
129     RunTest<1, int>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 5 }}});
130 }
131 
132 }
133