xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/GetSubgraphInputsOutputs.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

6 #include "ParserFlatbuffersFixture.hpp"
7 
8 using armnnTfLiteParser::TfLiteParserImpl;
9 using ModelPtr = TfLiteParserImpl::ModelPtr;
10 using TensorRawPtr = TfLiteParserImpl::TensorRawPtr;
11 
12 TEST_SUITE("TensorflowLiteParser_GetSubgraphInputsOutputs")
13 {
14 struct GetSubgraphInputsOutputsMainFixture : public ParserFlatbuffersFixture
15 {
GetSubgraphInputsOutputsMainFixtureGetSubgraphInputsOutputsMainFixture16     explicit GetSubgraphInputsOutputsMainFixture(const std::string& inputs, const std::string& outputs)
17     {
18         m_JsonString = R"(
19         {
20             "version": 3,
21             "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
22             "subgraphs": [
23             {
24                 "tensors": [
25                 {
26                     "shape": [ 1, 1, 1, 1 ] ,
27                     "type": "UINT8",
28                             "buffer": 0,
29                             "name": "OutputTensor",
30                             "quantization": {
31                                 "min": [ 0.0 ],
32                                 "max": [ 255.0 ],
33                                 "scale": [ 1.0 ],
34                                 "zero_point": [ 0 ]
35                             }
36                 },
37                 {
38                     "shape": [ 1, 2, 2, 1 ] ,
39                     "type": "UINT8",
40                             "buffer": 1,
41                             "name": "InputTensor",
42                             "quantization": {
43                                 "min": [ -1.2 ],
44                                 "max": [ 25.5 ],
45                                 "scale": [ 0.25 ],
46                                 "zero_point": [ 10 ]
47                             }
48                 }
49                 ],
50                 "inputs": )"
51                             + inputs
52                             + R"(,
53                 "outputs": )"
54                             + outputs
55                             + R"(,
56                 "operators": [ {
57                         "opcode_index": 0,
58                         "inputs": [ 1 ],
59                         "outputs": [ 0 ],
60                         "builtin_options_type": "Pool2DOptions",
61                         "builtin_options":
62                         {
63                             "padding": "VALID",
64                             "stride_w": 2,
65                             "stride_h": 2,
66                             "filter_width": 2,
67                             "filter_height": 2,
68                             "fused_activation_function": "NONE"
69                         },
70                         "custom_options_format": "FLEXBUFFERS"
71                     } ]
72                 },
73                 {
74                     "tensors": [
75                         {
76                             "shape": [ 1, 3, 3, 1 ],
77                             "type": "UINT8",
78                             "buffer": 0,
79                             "name": "ConvInputTensor",
80                             "quantization": {
81                                 "scale": [ 1.0 ],
82                                 "zero_point": [ 0 ],
83                             }
84                         },
85                         {
86                             "shape": [ 1, 1, 1, 1 ],
87                             "type": "UINT8",
88                             "buffer": 1,
89                             "name": "ConvOutputTensor",
90                             "quantization": {
91                                 "min": [ 0.0 ],
92                                 "max": [ 511.0 ],
93                                 "scale": [ 2.0 ],
94                                 "zero_point": [ 0 ],
95                             }
96                         },
97                         {
98                             "shape": [ 1, 3, 3, 1 ],
99                             "type": "UINT8",
100                             "buffer": 2,
101                             "name": "filterTensor",
102                             "quantization": {
103                                 "min": [ 0.0 ],
104                                 "max": [ 255.0 ],
105                                 "scale": [ 1.0 ],
106                                 "zero_point": [ 0 ],
107                             }
108                         }
109                     ],
110                     "inputs": [ 0 ],
111                     "outputs": [ 1 ],
112                     "operators": [
113                         {
114                             "opcode_index": 0,
115                             "inputs": [ 0, 2 ],
116                             "outputs": [ 1 ],
117                             "builtin_options_type": "Conv2DOptions",
118                             "builtin_options": {
119                                 "padding": "VALID",
120                                 "stride_w": 1,
121                                 "stride_h": 1,
122                                 "fused_activation_function": "NONE"
123                             },
124                             "custom_options_format": "FLEXBUFFERS"
125                         }
126                     ],
127                 }
128             ],
129             "description": "Test Subgraph Inputs Outputs",
130             "buffers" : [
131                     { },
132                     { },
133                     { "data": [ 2,1,0, 6,2,1, 4,1,2 ], },
134                     { },
135                 ]
136         })";
137 
138         ReadStringToBinary();
139     }
140 
141 };
142 
143 struct GetEmptySubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture
144 {
GetEmptySubgraphInputsOutputsFixtureGetEmptySubgraphInputsOutputsFixture145     GetEmptySubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ ]", "[ ]") {}
146 };
147 
148 struct GetSubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture
149 {
GetSubgraphInputsOutputsFixtureGetSubgraphInputsOutputsFixture150     GetSubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ 1 ]", "[ 0 ]") {}
151 };
152 
153 TEST_CASE_FIXTURE(GetEmptySubgraphInputsOutputsFixture, "GetEmptySubgraphInputs")
154 {
155     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
156                                                                              m_GraphBinary.size());
157     TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 0);
158     CHECK_EQ(0, subgraphTensors.size());
159 }
160 
161 TEST_CASE_FIXTURE(GetEmptySubgraphInputsOutputsFixture, "GetEmptySubgraphOutputs")
162 {
163     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
164                                                                              m_GraphBinary.size());
165     TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 0);
166     CHECK_EQ(0, subgraphTensors.size());
167 }
168 
169 TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputs")
170 {
171     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
172                                                                              m_GraphBinary.size());
173     TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 0);
174     CHECK_EQ(1, subgraphTensors.size());
175     CHECK_EQ(1, subgraphTensors[0].first);
176     CheckTensors(subgraphTensors[0].second, 4, { 1, 2, 2, 1 }, tflite::TensorType::TensorType_UINT8, 1,
177                       "InputTensor", { -1.2f }, { 25.5f }, { 0.25f }, { 10 });
178 }
179 
180 TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputsSimpleQuantized")
181 {
182     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
183                                                                              m_GraphBinary.size());
184     TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 0);
185     CHECK_EQ(1, subgraphTensors.size());
186     CHECK_EQ(0, subgraphTensors[0].first);
187     CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 0,
188                       "OutputTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
189 }
190 
191 TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputsEmptyMinMax")
192 {
193     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
194                                                                              m_GraphBinary.size());
195     TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 1);
196     CHECK_EQ(1, subgraphTensors.size());
197     CHECK_EQ(0, subgraphTensors[0].first);
198     CheckTensors(subgraphTensors[0].second, 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 0,
199                       "ConvInputTensor", { }, { }, { 1.0f }, { 0 });
200 }
201 
202 TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputs")
203 {
204     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
205                                                                              m_GraphBinary.size());
206     TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 1);
207     CHECK_EQ(1, subgraphTensors.size());
208     CHECK_EQ(1, subgraphTensors[0].first);
209     CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 1,
210                       "ConvOutputTensor", { 0.0f }, { 511.0f }, { 2.0f }, { 0 });
211 }
212 
213 TEST_CASE("GetSubgraphInputsNullModel")
214 {
215     CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphInputs(nullptr, 0), armnn::ParseException);
216 }
217 
218 TEST_CASE("GetSubgraphOutputsNullModel")
219 {
220     CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphOutputs(nullptr, 0), armnn::ParseException);
221 }
222 
223 TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputsInvalidSubgraph")
224 {
225     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
226                                                                              m_GraphBinary.size());
227     CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphInputs(model, 2), armnn::ParseException);
228 }
229 
230 TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputsInvalidSubgraph")
231 {
232     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
233                                                                              m_GraphBinary.size());
234     CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphOutputs(model, 2), armnn::ParseException);
235 }
236 
237 }