xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/LoadModel.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Filesystem.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker using armnnTfLiteParser::TfLiteParserImpl;
11*89c4ff92SAndroid Build Coastguard Worker using ModelPtr = TfLiteParserImpl::ModelPtr;
12*89c4ff92SAndroid Build Coastguard Worker using SubgraphPtr = TfLiteParserImpl::SubgraphPtr;
13*89c4ff92SAndroid Build Coastguard Worker using OperatorPtr = TfLiteParserImpl::OperatorPtr;
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_LoadModel")
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker struct LoadModelFixture : public ParserFlatbuffersFixture
18*89c4ff92SAndroid Build Coastguard Worker {
LoadModelFixtureLoadModelFixture19*89c4ff92SAndroid Build Coastguard Worker     explicit LoadModelFixture()
20*89c4ff92SAndroid Build Coastguard Worker     {
21*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
22*89c4ff92SAndroid Build Coastguard Worker         {
23*89c4ff92SAndroid Build Coastguard Worker             "version": 3,
24*89c4ff92SAndroid Build Coastguard Worker             "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
25*89c4ff92SAndroid Build Coastguard Worker             "subgraphs": [
26*89c4ff92SAndroid Build Coastguard Worker             {
27*89c4ff92SAndroid Build Coastguard Worker                 "tensors": [
28*89c4ff92SAndroid Build Coastguard Worker                 {
29*89c4ff92SAndroid Build Coastguard Worker                     "shape": [ 1, 1, 1, 1 ] ,
30*89c4ff92SAndroid Build Coastguard Worker                     "type": "UINT8",
31*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 0,
32*89c4ff92SAndroid Build Coastguard Worker                             "name": "OutputTensor",
33*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
34*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
35*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
36*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
37*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ]
38*89c4ff92SAndroid Build Coastguard Worker                             }
39*89c4ff92SAndroid Build Coastguard Worker                 },
40*89c4ff92SAndroid Build Coastguard Worker                 {
41*89c4ff92SAndroid Build Coastguard Worker                     "shape": [ 1, 2, 2, 1 ] ,
42*89c4ff92SAndroid Build Coastguard Worker                     "type": "UINT8",
43*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 1,
44*89c4ff92SAndroid Build Coastguard Worker                             "name": "InputTensor",
45*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
46*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
47*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
48*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
49*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ]
50*89c4ff92SAndroid Build Coastguard Worker                             }
51*89c4ff92SAndroid Build Coastguard Worker                 }
52*89c4ff92SAndroid Build Coastguard Worker                 ],
53*89c4ff92SAndroid Build Coastguard Worker                 "inputs": [ 1 ],
54*89c4ff92SAndroid Build Coastguard Worker                 "outputs": [ 0 ],
55*89c4ff92SAndroid Build Coastguard Worker                 "operators": [ {
56*89c4ff92SAndroid Build Coastguard Worker                         "opcode_index": 0,
57*89c4ff92SAndroid Build Coastguard Worker                         "inputs": [ 1 ],
58*89c4ff92SAndroid Build Coastguard Worker                         "outputs": [ 0 ],
59*89c4ff92SAndroid Build Coastguard Worker                         "builtin_options_type": "Pool2DOptions",
60*89c4ff92SAndroid Build Coastguard Worker                         "builtin_options":
61*89c4ff92SAndroid Build Coastguard Worker                         {
62*89c4ff92SAndroid Build Coastguard Worker                             "padding": "VALID",
63*89c4ff92SAndroid Build Coastguard Worker                             "stride_w": 2,
64*89c4ff92SAndroid Build Coastguard Worker                             "stride_h": 2,
65*89c4ff92SAndroid Build Coastguard Worker                             "filter_width": 2,
66*89c4ff92SAndroid Build Coastguard Worker                             "filter_height": 2,
67*89c4ff92SAndroid Build Coastguard Worker                             "fused_activation_function": "NONE"
68*89c4ff92SAndroid Build Coastguard Worker                         },
69*89c4ff92SAndroid Build Coastguard Worker                         "custom_options_format": "FLEXBUFFERS"
70*89c4ff92SAndroid Build Coastguard Worker                     } ]
71*89c4ff92SAndroid Build Coastguard Worker                 },
72*89c4ff92SAndroid Build Coastguard Worker                 {
73*89c4ff92SAndroid Build Coastguard Worker                     "tensors": [
74*89c4ff92SAndroid Build Coastguard Worker                         {
75*89c4ff92SAndroid Build Coastguard Worker                             "shape": [ 1, 3, 3, 1 ],
76*89c4ff92SAndroid Build Coastguard Worker                             "type": "UINT8",
77*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 0,
78*89c4ff92SAndroid Build Coastguard Worker                             "name": "ConvInputTensor",
79*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
80*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
81*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
82*89c4ff92SAndroid Build Coastguard Worker                             }
83*89c4ff92SAndroid Build Coastguard Worker                         },
84*89c4ff92SAndroid Build Coastguard Worker                         {
85*89c4ff92SAndroid Build Coastguard Worker                             "shape": [ 1, 1, 1, 1 ],
86*89c4ff92SAndroid Build Coastguard Worker                             "type": "UINT8",
87*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 1,
88*89c4ff92SAndroid Build Coastguard Worker                             "name": "ConvOutputTensor",
89*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
90*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
91*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 511.0 ],
92*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 2.0 ],
93*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
94*89c4ff92SAndroid Build Coastguard Worker                             }
95*89c4ff92SAndroid Build Coastguard Worker                         },
96*89c4ff92SAndroid Build Coastguard Worker                         {
97*89c4ff92SAndroid Build Coastguard Worker                             "shape": [ 1, 3, 3, 1 ],
98*89c4ff92SAndroid Build Coastguard Worker                             "type": "UINT8",
99*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 2,
100*89c4ff92SAndroid Build Coastguard Worker                             "name": "filterTensor",
101*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
102*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
103*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
104*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
105*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
106*89c4ff92SAndroid Build Coastguard Worker                             }
107*89c4ff92SAndroid Build Coastguard Worker                         }
108*89c4ff92SAndroid Build Coastguard Worker                     ],
109*89c4ff92SAndroid Build Coastguard Worker                     "inputs": [ 0 ],
110*89c4ff92SAndroid Build Coastguard Worker                     "outputs": [ 1 ],
111*89c4ff92SAndroid Build Coastguard Worker                     "operators": [
112*89c4ff92SAndroid Build Coastguard Worker                         {
113*89c4ff92SAndroid Build Coastguard Worker                             "opcode_index": 1,
114*89c4ff92SAndroid Build Coastguard Worker                             "inputs": [ 0, 2 ],
115*89c4ff92SAndroid Build Coastguard Worker                             "outputs": [ 1 ],
116*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options_type": "Conv2DOptions",
117*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options": {
118*89c4ff92SAndroid Build Coastguard Worker                                 "padding": "VALID",
119*89c4ff92SAndroid Build Coastguard Worker                                 "stride_w": 1,
120*89c4ff92SAndroid Build Coastguard Worker                                 "stride_h": 1,
121*89c4ff92SAndroid Build Coastguard Worker                                 "fused_activation_function": "NONE"
122*89c4ff92SAndroid Build Coastguard Worker                             },
123*89c4ff92SAndroid Build Coastguard Worker                             "custom_options_format": "FLEXBUFFERS"
124*89c4ff92SAndroid Build Coastguard Worker                         }
125*89c4ff92SAndroid Build Coastguard Worker                     ],
126*89c4ff92SAndroid Build Coastguard Worker                 }
127*89c4ff92SAndroid Build Coastguard Worker             ],
128*89c4ff92SAndroid Build Coastguard Worker             "description": "Test loading a model",
129*89c4ff92SAndroid Build Coastguard Worker             "buffers" : [ {}, {} ]
130*89c4ff92SAndroid Build Coastguard Worker         })";
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker         ReadStringToBinary();
133*89c4ff92SAndroid Build Coastguard Worker     }
134*89c4ff92SAndroid Build Coastguard Worker 
CheckModelLoadModelFixture135*89c4ff92SAndroid Build Coastguard Worker     void CheckModel(const ModelPtr& model, uint32_t version, size_t opcodeSize,
136*89c4ff92SAndroid Build Coastguard Worker                     const std::vector<tflite::BuiltinOperator>& opcodes,
137*89c4ff92SAndroid Build Coastguard Worker                     size_t subgraphs, const std::string desc, size_t buffers)
138*89c4ff92SAndroid Build Coastguard Worker     {
139*89c4ff92SAndroid Build Coastguard Worker         CHECK(model);
140*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(version, model->version);
141*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(opcodeSize, model->operator_codes.size());
142*89c4ff92SAndroid Build Coastguard Worker         CheckBuiltinOperators(opcodes, model->operator_codes);
143*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(subgraphs, model->subgraphs.size());
144*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(desc, model->description);
145*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(buffers, model->buffers.size());
146*89c4ff92SAndroid Build Coastguard Worker     }
147*89c4ff92SAndroid Build Coastguard Worker 
CheckBuiltinOperatorsLoadModelFixture148*89c4ff92SAndroid Build Coastguard Worker     void CheckBuiltinOperators(const std::vector<tflite::BuiltinOperator>& expectedOperators,
149*89c4ff92SAndroid Build Coastguard Worker                                const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& result)
150*89c4ff92SAndroid Build Coastguard Worker     {
151*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(expectedOperators.size(), result.size());
152*89c4ff92SAndroid Build Coastguard Worker         for (size_t i = 0; i < expectedOperators.size(); i++)
153*89c4ff92SAndroid Build Coastguard Worker         {
154*89c4ff92SAndroid Build Coastguard Worker             CHECK_EQ(expectedOperators[i], result[i]->builtin_code);
155*89c4ff92SAndroid Build Coastguard Worker         }
156*89c4ff92SAndroid Build Coastguard Worker     }
157*89c4ff92SAndroid Build Coastguard Worker 
CheckSubgraphLoadModelFixture158*89c4ff92SAndroid Build Coastguard Worker     void CheckSubgraph(const SubgraphPtr& subgraph, size_t tensors, const std::vector<int32_t>& inputs,
159*89c4ff92SAndroid Build Coastguard Worker                        const std::vector<int32_t>& outputs, size_t operators, const std::string& name)
160*89c4ff92SAndroid Build Coastguard Worker     {
161*89c4ff92SAndroid Build Coastguard Worker         CHECK(subgraph);
162*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(tensors, subgraph->tensors.size());
163*89c4ff92SAndroid Build Coastguard Worker         CHECK(std::equal(inputs.begin(), inputs.end(), subgraph->inputs.begin(), subgraph->inputs.end()));
164*89c4ff92SAndroid Build Coastguard Worker         CHECK(std::equal(outputs.begin(), outputs.end(),
165*89c4ff92SAndroid Build Coastguard Worker                                       subgraph->outputs.begin(), subgraph->outputs.end()));
166*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(operators, subgraph->operators.size());
167*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(name, subgraph->name);
168*89c4ff92SAndroid Build Coastguard Worker     }
169*89c4ff92SAndroid Build Coastguard Worker 
CheckOperatorLoadModelFixture170*89c4ff92SAndroid Build Coastguard Worker     void CheckOperator(const OperatorPtr& operatorPtr, uint32_t opcode,  const std::vector<int32_t>& inputs,
171*89c4ff92SAndroid Build Coastguard Worker                        const std::vector<int32_t>& outputs, tflite::BuiltinOptions optionType,
172*89c4ff92SAndroid Build Coastguard Worker                        tflite::CustomOptionsFormat custom_options_format)
173*89c4ff92SAndroid Build Coastguard Worker     {
174*89c4ff92SAndroid Build Coastguard Worker         CHECK(operatorPtr);
175*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(opcode, operatorPtr->opcode_index);
176*89c4ff92SAndroid Build Coastguard Worker         CHECK(std::equal(inputs.begin(), inputs.end(),
177*89c4ff92SAndroid Build Coastguard Worker                                       operatorPtr->inputs.begin(), operatorPtr->inputs.end()));
178*89c4ff92SAndroid Build Coastguard Worker         CHECK(std::equal(outputs.begin(), outputs.end(),
179*89c4ff92SAndroid Build Coastguard Worker                                       operatorPtr->outputs.begin(), operatorPtr->outputs.end()));
180*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(optionType, operatorPtr->builtin_options.type);
181*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(custom_options_format, operatorPtr->custom_options_format);
182*89c4ff92SAndroid Build Coastguard Worker     }
183*89c4ff92SAndroid Build Coastguard Worker };
184*89c4ff92SAndroid Build Coastguard Worker 
185*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(LoadModelFixture, "LoadModelFromBinary")
186*89c4ff92SAndroid Build Coastguard Worker {
187*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
188*89c4ff92SAndroid Build Coastguard Worker                                                                              m_GraphBinary.size());
189*89c4ff92SAndroid Build Coastguard Worker     CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
190*89c4ff92SAndroid Build Coastguard Worker                2, "Test loading a model", 2);
191*89c4ff92SAndroid Build Coastguard Worker     CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, "");
192*89c4ff92SAndroid Build Coastguard Worker     CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, "");
193*89c4ff92SAndroid Build Coastguard Worker     CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions,
194*89c4ff92SAndroid Build Coastguard Worker                   tflite::CustomOptionsFormat_FLEXBUFFERS);
195*89c4ff92SAndroid Build Coastguard Worker     CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions,
196*89c4ff92SAndroid Build Coastguard Worker                   tflite::CustomOptionsFormat_FLEXBUFFERS);
197*89c4ff92SAndroid Build Coastguard Worker }
198*89c4ff92SAndroid Build Coastguard Worker 
199*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(LoadModelFixture, "LoadModelFromFile")
200*89c4ff92SAndroid Build Coastguard Worker {
201*89c4ff92SAndroid Build Coastguard Worker     using namespace fs;
202*89c4ff92SAndroid Build Coastguard Worker     fs::path fname = armnnUtils::Filesystem::NamedTempFile("Armnn-tfLite-LoadModelFromFile-TempFile.csv");
203*89c4ff92SAndroid Build Coastguard Worker     bool saved = flatbuffers::SaveFile(fname.c_str(),
204*89c4ff92SAndroid Build Coastguard Worker                                        reinterpret_cast<char *>(m_GraphBinary.data()),
205*89c4ff92SAndroid Build Coastguard Worker                                        m_GraphBinary.size(), true);
206*89c4ff92SAndroid Build Coastguard Worker     CHECK_MESSAGE(saved, "Cannot save test file");
207*89c4ff92SAndroid Build Coastguard Worker 
208*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromFile(fname.c_str());
209*89c4ff92SAndroid Build Coastguard Worker     CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
210*89c4ff92SAndroid Build Coastguard Worker                2, "Test loading a model", 2);
211*89c4ff92SAndroid Build Coastguard Worker     CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, "");
212*89c4ff92SAndroid Build Coastguard Worker     CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, "");
213*89c4ff92SAndroid Build Coastguard Worker     CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions,
214*89c4ff92SAndroid Build Coastguard Worker                   tflite::CustomOptionsFormat_FLEXBUFFERS);
215*89c4ff92SAndroid Build Coastguard Worker     CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions,
216*89c4ff92SAndroid Build Coastguard Worker                   tflite::CustomOptionsFormat_FLEXBUFFERS);
217*89c4ff92SAndroid Build Coastguard Worker     remove(fname);
218*89c4ff92SAndroid Build Coastguard Worker }
219*89c4ff92SAndroid Build Coastguard Worker 
220*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("LoadNullBinary")
221*89c4ff92SAndroid Build Coastguard Worker {
222*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromBinary(nullptr, 0), armnn::InvalidArgumentException);
223*89c4ff92SAndroid Build Coastguard Worker }
224*89c4ff92SAndroid Build Coastguard Worker 
225*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("LoadInvalidBinary")
226*89c4ff92SAndroid Build Coastguard Worker {
227*89c4ff92SAndroid Build Coastguard Worker     std::string testData = "invalid data";
228*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromBinary(reinterpret_cast<const uint8_t*>(&testData),
229*89c4ff92SAndroid Build Coastguard Worker                                                         testData.length()), armnn::ParseException);
230*89c4ff92SAndroid Build Coastguard Worker }
231*89c4ff92SAndroid Build Coastguard Worker 
232*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("LoadFileNotFound")
233*89c4ff92SAndroid Build Coastguard Worker {
234*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromFile("invalidfile.tflite"), armnn::FileNotFoundException);
235*89c4ff92SAndroid Build Coastguard Worker }
236*89c4ff92SAndroid Build Coastguard Worker 
237*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("LoadNullPtrFile")
238*89c4ff92SAndroid Build Coastguard Worker {
239*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromFile(nullptr), armnn::InvalidArgumentException);
240*89c4ff92SAndroid Build Coastguard Worker }
241*89c4ff92SAndroid Build Coastguard Worker 
242*89c4ff92SAndroid Build Coastguard Worker }
243