xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/GetInputsOutputs.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 using armnnTfLiteParser::TfLiteParserImpl;
9 using ModelPtr = TfLiteParserImpl::ModelPtr;
10 
11 TEST_SUITE("TensorflowLiteParser_GetInputsOutputs")
12 {
13 struct GetInputsOutputsMainFixture : public ParserFlatbuffersFixture
14 {
GetInputsOutputsMainFixtureGetInputsOutputsMainFixture15     explicit GetInputsOutputsMainFixture(const std::string& inputs, const std::string& outputs)
16     {
17         m_JsonString = R"(
18         {
19             "version": 3,
20             "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
21             "subgraphs": [
22             {
23                 "tensors": [
24                 {
25                     "shape": [ 1, 1, 1, 1 ] ,
26                     "type": "UINT8",
27                             "buffer": 0,
28                             "name": "OutputTensor",
29                             "quantization": {
30                                 "min": [ 0.0 ],
31                                 "max": [ 255.0 ],
32                                 "scale": [ 1.0 ],
33                                 "zero_point": [ 0 ]
34                             }
35                 },
36                 {
37                     "shape": [ 1, 2, 2, 1 ] ,
38                     "type": "UINT8",
39                             "buffer": 1,
40                             "name": "InputTensor",
41                             "quantization": {
42                                 "min": [ -1.2 ],
43                                 "max": [ 25.5 ],
44                                 "scale": [ 0.25 ],
45                                 "zero_point": [ 10 ]
46                             }
47                 }
48                 ],
49                 "inputs": [ 1 ],
50                 "outputs": [ 0 ],
51                 "operators": [ {
52                         "opcode_index": 0,
53                         "inputs":  )"
54                             + inputs
55                             + R"(,
56                         "outputs": )"
57                             + outputs
58                             + R"(,
59                         "builtin_options_type": "Pool2DOptions",
60                         "builtin_options":
61                         {
62                             "padding": "VALID",
63                             "stride_w": 2,
64                             "stride_h": 2,
65                             "filter_width": 2,
66                             "filter_height": 2,
67                             "fused_activation_function": "NONE"
68                         },
69                         "custom_options_format": "FLEXBUFFERS"
70                     } ]
71                 },
72                 {
73                     "tensors": [
74                         {
75                             "shape": [ 1, 3, 3, 1 ],
76                             "type": "UINT8",
77                             "buffer": 0,
78                             "name": "ConvInputTensor",
79                             "quantization": {
80                                 "scale": [ 1.0 ],
81                                 "zero_point": [ 0 ],
82                             }
83                         },
84                         {
85                             "shape": [ 1, 1, 1, 1 ],
86                             "type": "UINT8",
87                             "buffer": 1,
88                             "name": "ConvOutputTensor",
89                             "quantization": {
90                                 "min": [ 0.0 ],
91                                 "max": [ 511.0 ],
92                                 "scale": [ 2.0 ],
93                                 "zero_point": [ 0 ],
94                             }
95                         },
96                         {
97                             "shape": [ 1, 3, 3, 1 ],
98                             "type": "UINT8",
99                             "buffer": 2,
100                             "name": "filterTensor",
101                             "quantization": {
102                                 "min": [ 0.0 ],
103                                 "max": [ 255.0 ],
104                                 "scale": [ 1.0 ],
105                                 "zero_point": [ 0 ],
106                             }
107                         }
108                     ],
109                     "inputs": [ 0 ],
110                     "outputs": [ 1 ],
111                     "operators": [
112                         {
113                             "opcode_index": 0,
114                             "inputs": [ 0, 2 ],
115                             "outputs": [ 1 ],
116                             "builtin_options_type": "Conv2DOptions",
117                             "builtin_options": {
118                                 "padding": "VALID",
119                                 "stride_w": 1,
120                                 "stride_h": 1,
121                                 "fused_activation_function": "NONE"
122                             },
123                             "custom_options_format": "FLEXBUFFERS"
124                         }
125                     ],
126                 }
127             ],
128             "description": "Test Subgraph Inputs Outputs",
129             "buffers" : [
130                     { },
131                     { },
132                     { "data": [ 2,1,0, 6,2,1, 4,1,2 ], },
133                     { },
134                 ]
135         })";
136 
137         ReadStringToBinary();
138     }
139 
140 };
141 
142 struct GetEmptyInputsOutputsFixture : GetInputsOutputsMainFixture
143 {
GetEmptyInputsOutputsFixtureGetEmptyInputsOutputsFixture144     GetEmptyInputsOutputsFixture() : GetInputsOutputsMainFixture("[ ]", "[ ]") {}
145 };
146 
147 struct GetInputsOutputsFixture : GetInputsOutputsMainFixture
148 {
GetInputsOutputsFixtureGetInputsOutputsFixture149     GetInputsOutputsFixture() : GetInputsOutputsMainFixture("[ 1 ]", "[ 0 ]") {}
150 };
151 
152 TEST_CASE_FIXTURE(GetEmptyInputsOutputsFixture, "GetEmptyInputs")
153 {
154     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
155                                                                              m_GraphBinary.size());
156     TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetInputs(model, 0, 0);
157     CHECK_EQ(0, tensors.size());
158 }
159 
160 TEST_CASE_FIXTURE(GetEmptyInputsOutputsFixture, "GetEmptyOutputs")
161 {
162     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
163                                                                              m_GraphBinary.size());
164     TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetOutputs(model, 0, 0);
165     CHECK_EQ(0, tensors.size());
166 }
167 
168 TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetInputs")
169 {
170     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
171                                                                              m_GraphBinary.size());
172     TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetInputs(model, 0, 0);
173     CHECK_EQ(1, tensors.size());
174     CheckTensors(tensors[0], 4, { 1, 2, 2, 1 }, tflite::TensorType::TensorType_UINT8, 1,
175                       "InputTensor", { -1.2f }, { 25.5f }, { 0.25f }, { 10 });
176 }
177 
178 TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetOutputs")
179 {
180     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
181                                                                              m_GraphBinary.size());
182     TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetOutputs(model, 0, 0);
183     CHECK_EQ(1, tensors.size());
184     CheckTensors(tensors[0], 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 0,
185                       "OutputTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
186 }
187 
188 TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetInputsMultipleInputs")
189 {
190     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
191                                                                              m_GraphBinary.size());
192     TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetInputs(model, 1, 0);
193     CHECK_EQ(2, tensors.size());
194     CheckTensors(tensors[0], 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 0,
195                       "ConvInputTensor", { }, { }, { 1.0f }, { 0 });
196     CheckTensors(tensors[1], 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 2,
197                       "filterTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
198 }
199 
200 TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetOutputs2")
201 {
202     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
203                                                                              m_GraphBinary.size());
204     TfLiteParserImpl::TensorRawPtrVector tensors = TfLiteParserImpl::GetOutputs(model, 1, 0);
205     CHECK_EQ(1, tensors.size());
206     CheckTensors(tensors[0], 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 1,
207                       "ConvOutputTensor", { 0.0f }, { 511.0f }, { 2.0f }, { 0 });
208 }
209 
210 TEST_CASE("GetInputsNullModel")
211 {
212     CHECK_THROWS_AS(TfLiteParserImpl::GetInputs(nullptr, 0, 0), armnn::ParseException);
213 }
214 
215 TEST_CASE("GetOutputsNullModel")
216 {
217     CHECK_THROWS_AS(TfLiteParserImpl::GetOutputs(nullptr, 0, 0), armnn::ParseException);
218 }
219 
220 TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetInputsInvalidSubgraph")
221 {
222     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
223                                                                              m_GraphBinary.size());
224     CHECK_THROWS_AS(TfLiteParserImpl::GetInputs(model, 2, 0), armnn::ParseException);
225 }
226 
227 TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetOutputsInvalidSubgraph")
228 {
229     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
230                                                                              m_GraphBinary.size());
231     CHECK_THROWS_AS(TfLiteParserImpl::GetOutputs(model, 2, 0), armnn::ParseException);
232 }
233 
234 TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetInputsInvalidOperator")
235 {
236     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
237                                                                              m_GraphBinary.size());
238     CHECK_THROWS_AS(TfLiteParserImpl::GetInputs(model, 0, 1), armnn::ParseException);
239 }
240 
241 TEST_CASE_FIXTURE(GetInputsOutputsFixture, "GetOutputsInvalidOperator")
242 {
243     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
244                                                                              m_GraphBinary.size());
245     CHECK_THROWS_AS(TfLiteParserImpl::GetOutputs(model, 0, 1), armnn::ParseException);
246 }
247 
248 }