xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/GetSubgraphInputsOutputs.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker using armnnTfLiteParser::TfLiteParserImpl;
9*89c4ff92SAndroid Build Coastguard Worker using ModelPtr = TfLiteParserImpl::ModelPtr;
10*89c4ff92SAndroid Build Coastguard Worker using TensorRawPtr = TfLiteParserImpl::TensorRawPtr;
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_GetSubgraphInputsOutputs")
13*89c4ff92SAndroid Build Coastguard Worker {
14*89c4ff92SAndroid Build Coastguard Worker struct GetSubgraphInputsOutputsMainFixture : public ParserFlatbuffersFixture
15*89c4ff92SAndroid Build Coastguard Worker {
GetSubgraphInputsOutputsMainFixtureGetSubgraphInputsOutputsMainFixture16*89c4ff92SAndroid Build Coastguard Worker     explicit GetSubgraphInputsOutputsMainFixture(const std::string& inputs, const std::string& outputs)
17*89c4ff92SAndroid Build Coastguard Worker     {
18*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
19*89c4ff92SAndroid Build Coastguard Worker         {
20*89c4ff92SAndroid Build Coastguard Worker             "version": 3,
21*89c4ff92SAndroid Build Coastguard Worker             "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
22*89c4ff92SAndroid Build Coastguard Worker             "subgraphs": [
23*89c4ff92SAndroid Build Coastguard Worker             {
24*89c4ff92SAndroid Build Coastguard Worker                 "tensors": [
25*89c4ff92SAndroid Build Coastguard Worker                 {
26*89c4ff92SAndroid Build Coastguard Worker                     "shape": [ 1, 1, 1, 1 ] ,
27*89c4ff92SAndroid Build Coastguard Worker                     "type": "UINT8",
28*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 0,
29*89c4ff92SAndroid Build Coastguard Worker                             "name": "OutputTensor",
30*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
31*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
32*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
33*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
34*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ]
35*89c4ff92SAndroid Build Coastguard Worker                             }
36*89c4ff92SAndroid Build Coastguard Worker                 },
37*89c4ff92SAndroid Build Coastguard Worker                 {
38*89c4ff92SAndroid Build Coastguard Worker                     "shape": [ 1, 2, 2, 1 ] ,
39*89c4ff92SAndroid Build Coastguard Worker                     "type": "UINT8",
40*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 1,
41*89c4ff92SAndroid Build Coastguard Worker                             "name": "InputTensor",
42*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
43*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ -1.2 ],
44*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 25.5 ],
45*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 0.25 ],
46*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 10 ]
47*89c4ff92SAndroid Build Coastguard Worker                             }
48*89c4ff92SAndroid Build Coastguard Worker                 }
49*89c4ff92SAndroid Build Coastguard Worker                 ],
50*89c4ff92SAndroid Build Coastguard Worker                 "inputs": )"
51*89c4ff92SAndroid Build Coastguard Worker                             + inputs
52*89c4ff92SAndroid Build Coastguard Worker                             + R"(,
53*89c4ff92SAndroid Build Coastguard Worker                 "outputs": )"
54*89c4ff92SAndroid Build Coastguard Worker                             + outputs
55*89c4ff92SAndroid Build Coastguard Worker                             + R"(,
56*89c4ff92SAndroid Build Coastguard Worker                 "operators": [ {
57*89c4ff92SAndroid Build Coastguard Worker                         "opcode_index": 0,
58*89c4ff92SAndroid Build Coastguard Worker                         "inputs": [ 1 ],
59*89c4ff92SAndroid Build Coastguard Worker                         "outputs": [ 0 ],
60*89c4ff92SAndroid Build Coastguard Worker                         "builtin_options_type": "Pool2DOptions",
61*89c4ff92SAndroid Build Coastguard Worker                         "builtin_options":
62*89c4ff92SAndroid Build Coastguard Worker                         {
63*89c4ff92SAndroid Build Coastguard Worker                             "padding": "VALID",
64*89c4ff92SAndroid Build Coastguard Worker                             "stride_w": 2,
65*89c4ff92SAndroid Build Coastguard Worker                             "stride_h": 2,
66*89c4ff92SAndroid Build Coastguard Worker                             "filter_width": 2,
67*89c4ff92SAndroid Build Coastguard Worker                             "filter_height": 2,
68*89c4ff92SAndroid Build Coastguard Worker                             "fused_activation_function": "NONE"
69*89c4ff92SAndroid Build Coastguard Worker                         },
70*89c4ff92SAndroid Build Coastguard Worker                         "custom_options_format": "FLEXBUFFERS"
71*89c4ff92SAndroid Build Coastguard Worker                     } ]
72*89c4ff92SAndroid Build Coastguard Worker                 },
73*89c4ff92SAndroid Build Coastguard Worker                 {
74*89c4ff92SAndroid Build Coastguard Worker                     "tensors": [
75*89c4ff92SAndroid Build Coastguard Worker                         {
76*89c4ff92SAndroid Build Coastguard Worker                             "shape": [ 1, 3, 3, 1 ],
77*89c4ff92SAndroid Build Coastguard Worker                             "type": "UINT8",
78*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 0,
79*89c4ff92SAndroid Build Coastguard Worker                             "name": "ConvInputTensor",
80*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
81*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
82*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
83*89c4ff92SAndroid Build Coastguard Worker                             }
84*89c4ff92SAndroid Build Coastguard Worker                         },
85*89c4ff92SAndroid Build Coastguard Worker                         {
86*89c4ff92SAndroid Build Coastguard Worker                             "shape": [ 1, 1, 1, 1 ],
87*89c4ff92SAndroid Build Coastguard Worker                             "type": "UINT8",
88*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 1,
89*89c4ff92SAndroid Build Coastguard Worker                             "name": "ConvOutputTensor",
90*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
91*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
92*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 511.0 ],
93*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 2.0 ],
94*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
95*89c4ff92SAndroid Build Coastguard Worker                             }
96*89c4ff92SAndroid Build Coastguard Worker                         },
97*89c4ff92SAndroid Build Coastguard Worker                         {
98*89c4ff92SAndroid Build Coastguard Worker                             "shape": [ 1, 3, 3, 1 ],
99*89c4ff92SAndroid Build Coastguard Worker                             "type": "UINT8",
100*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 2,
101*89c4ff92SAndroid Build Coastguard Worker                             "name": "filterTensor",
102*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
103*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
104*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
105*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
106*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
107*89c4ff92SAndroid Build Coastguard Worker                             }
108*89c4ff92SAndroid Build Coastguard Worker                         }
109*89c4ff92SAndroid Build Coastguard Worker                     ],
110*89c4ff92SAndroid Build Coastguard Worker                     "inputs": [ 0 ],
111*89c4ff92SAndroid Build Coastguard Worker                     "outputs": [ 1 ],
112*89c4ff92SAndroid Build Coastguard Worker                     "operators": [
113*89c4ff92SAndroid Build Coastguard Worker                         {
114*89c4ff92SAndroid Build Coastguard Worker                             "opcode_index": 0,
115*89c4ff92SAndroid Build Coastguard Worker                             "inputs": [ 0, 2 ],
116*89c4ff92SAndroid Build Coastguard Worker                             "outputs": [ 1 ],
117*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options_type": "Conv2DOptions",
118*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options": {
119*89c4ff92SAndroid Build Coastguard Worker                                 "padding": "VALID",
120*89c4ff92SAndroid Build Coastguard Worker                                 "stride_w": 1,
121*89c4ff92SAndroid Build Coastguard Worker                                 "stride_h": 1,
122*89c4ff92SAndroid Build Coastguard Worker                                 "fused_activation_function": "NONE"
123*89c4ff92SAndroid Build Coastguard Worker                             },
124*89c4ff92SAndroid Build Coastguard Worker                             "custom_options_format": "FLEXBUFFERS"
125*89c4ff92SAndroid Build Coastguard Worker                         }
126*89c4ff92SAndroid Build Coastguard Worker                     ],
127*89c4ff92SAndroid Build Coastguard Worker                 }
128*89c4ff92SAndroid Build Coastguard Worker             ],
129*89c4ff92SAndroid Build Coastguard Worker             "description": "Test Subgraph Inputs Outputs",
130*89c4ff92SAndroid Build Coastguard Worker             "buffers" : [
131*89c4ff92SAndroid Build Coastguard Worker                     { },
132*89c4ff92SAndroid Build Coastguard Worker                     { },
133*89c4ff92SAndroid Build Coastguard Worker                     { "data": [ 2,1,0, 6,2,1, 4,1,2 ], },
134*89c4ff92SAndroid Build Coastguard Worker                     { },
135*89c4ff92SAndroid Build Coastguard Worker                 ]
136*89c4ff92SAndroid Build Coastguard Worker         })";
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker         ReadStringToBinary();
139*89c4ff92SAndroid Build Coastguard Worker     }
140*89c4ff92SAndroid Build Coastguard Worker 
141*89c4ff92SAndroid Build Coastguard Worker };
142*89c4ff92SAndroid Build Coastguard Worker 
143*89c4ff92SAndroid Build Coastguard Worker struct GetEmptySubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture
144*89c4ff92SAndroid Build Coastguard Worker {
GetEmptySubgraphInputsOutputsFixtureGetEmptySubgraphInputsOutputsFixture145*89c4ff92SAndroid Build Coastguard Worker     GetEmptySubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ ]", "[ ]") {}
146*89c4ff92SAndroid Build Coastguard Worker };
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker struct GetSubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture
149*89c4ff92SAndroid Build Coastguard Worker {
GetSubgraphInputsOutputsFixtureGetSubgraphInputsOutputsFixture150*89c4ff92SAndroid Build Coastguard Worker     GetSubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ 1 ]", "[ 0 ]") {}
151*89c4ff92SAndroid Build Coastguard Worker };
152*89c4ff92SAndroid Build Coastguard Worker 
153*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetEmptySubgraphInputsOutputsFixture, "GetEmptySubgraphInputs")
154*89c4ff92SAndroid Build Coastguard Worker {
155*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
156*89c4ff92SAndroid Build Coastguard Worker                                                                              m_GraphBinary.size());
157*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 0);
158*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(0, subgraphTensors.size());
159*89c4ff92SAndroid Build Coastguard Worker }
160*89c4ff92SAndroid Build Coastguard Worker 
161*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetEmptySubgraphInputsOutputsFixture, "GetEmptySubgraphOutputs")
162*89c4ff92SAndroid Build Coastguard Worker {
163*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
164*89c4ff92SAndroid Build Coastguard Worker                                                                              m_GraphBinary.size());
165*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 0);
166*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(0, subgraphTensors.size());
167*89c4ff92SAndroid Build Coastguard Worker }
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputs")
170*89c4ff92SAndroid Build Coastguard Worker {
171*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
172*89c4ff92SAndroid Build Coastguard Worker                                                                              m_GraphBinary.size());
173*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 0);
174*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(1, subgraphTensors.size());
175*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(1, subgraphTensors[0].first);
176*89c4ff92SAndroid Build Coastguard Worker     CheckTensors(subgraphTensors[0].second, 4, { 1, 2, 2, 1 }, tflite::TensorType::TensorType_UINT8, 1,
177*89c4ff92SAndroid Build Coastguard Worker                       "InputTensor", { -1.2f }, { 25.5f }, { 0.25f }, { 10 });
178*89c4ff92SAndroid Build Coastguard Worker }
179*89c4ff92SAndroid Build Coastguard Worker 
180*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputsSimpleQuantized")
181*89c4ff92SAndroid Build Coastguard Worker {
182*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
183*89c4ff92SAndroid Build Coastguard Worker                                                                              m_GraphBinary.size());
184*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 0);
185*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(1, subgraphTensors.size());
186*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(0, subgraphTensors[0].first);
187*89c4ff92SAndroid Build Coastguard Worker     CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 0,
188*89c4ff92SAndroid Build Coastguard Worker                       "OutputTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
189*89c4ff92SAndroid Build Coastguard Worker }
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputsEmptyMinMax")
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
194*89c4ff92SAndroid Build Coastguard Worker                                                                              m_GraphBinary.size());
195*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 1);
196*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(1, subgraphTensors.size());
197*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(0, subgraphTensors[0].first);
198*89c4ff92SAndroid Build Coastguard Worker     CheckTensors(subgraphTensors[0].second, 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 0,
199*89c4ff92SAndroid Build Coastguard Worker                       "ConvInputTensor", { }, { }, { 1.0f }, { 0 });
200*89c4ff92SAndroid Build Coastguard Worker }
201*89c4ff92SAndroid Build Coastguard Worker 
202*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputs")
203*89c4ff92SAndroid Build Coastguard Worker {
204*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
205*89c4ff92SAndroid Build Coastguard Worker                                                                              m_GraphBinary.size());
206*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 1);
207*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(1, subgraphTensors.size());
208*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(1, subgraphTensors[0].first);
209*89c4ff92SAndroid Build Coastguard Worker     CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 1,
210*89c4ff92SAndroid Build Coastguard Worker                       "ConvOutputTensor", { 0.0f }, { 511.0f }, { 2.0f }, { 0 });
211*89c4ff92SAndroid Build Coastguard Worker }
212*89c4ff92SAndroid Build Coastguard Worker 
213*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GetSubgraphInputsNullModel")
214*89c4ff92SAndroid Build Coastguard Worker {
215*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphInputs(nullptr, 0), armnn::ParseException);
216*89c4ff92SAndroid Build Coastguard Worker }
217*89c4ff92SAndroid Build Coastguard Worker 
218*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GetSubgraphOutputsNullModel")
219*89c4ff92SAndroid Build Coastguard Worker {
220*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphOutputs(nullptr, 0), armnn::ParseException);
221*89c4ff92SAndroid Build Coastguard Worker }
222*89c4ff92SAndroid Build Coastguard Worker 
223*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphInputsInvalidSubgraph")
224*89c4ff92SAndroid Build Coastguard Worker {
225*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
226*89c4ff92SAndroid Build Coastguard Worker                                                                              m_GraphBinary.size());
227*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphInputs(model, 2), armnn::ParseException);
228*89c4ff92SAndroid Build Coastguard Worker }
229*89c4ff92SAndroid Build Coastguard Worker 
230*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetSubgraphInputsOutputsFixture, "GetSubgraphOutputsInvalidSubgraph")
231*89c4ff92SAndroid Build Coastguard Worker {
232*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
233*89c4ff92SAndroid Build Coastguard Worker                                                                              m_GraphBinary.size());
234*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(TfLiteParserImpl::GetSubgraphOutputs(model, 2), armnn::ParseException);
235*89c4ff92SAndroid Build Coastguard Worker }
236*89c4ff92SAndroid Build Coastguard Worker 
237*89c4ff92SAndroid Build Coastguard Worker }