xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/LoadModel.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 #include <armnnUtils/Filesystem.hpp>
9 
10 using armnnTfLiteParser::TfLiteParserImpl;
11 using ModelPtr = TfLiteParserImpl::ModelPtr;
12 using SubgraphPtr = TfLiteParserImpl::SubgraphPtr;
13 using OperatorPtr = TfLiteParserImpl::OperatorPtr;
14 
15 TEST_SUITE("TensorflowLiteParser_LoadModel")
16 {
17 struct LoadModelFixture : public ParserFlatbuffersFixture
18 {
LoadModelFixtureLoadModelFixture19     explicit LoadModelFixture()
20     {
21         m_JsonString = R"(
22         {
23             "version": 3,
24             "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
25             "subgraphs": [
26             {
27                 "tensors": [
28                 {
29                     "shape": [ 1, 1, 1, 1 ] ,
30                     "type": "UINT8",
31                             "buffer": 0,
32                             "name": "OutputTensor",
33                             "quantization": {
34                                 "min": [ 0.0 ],
35                                 "max": [ 255.0 ],
36                                 "scale": [ 1.0 ],
37                                 "zero_point": [ 0 ]
38                             }
39                 },
40                 {
41                     "shape": [ 1, 2, 2, 1 ] ,
42                     "type": "UINT8",
43                             "buffer": 1,
44                             "name": "InputTensor",
45                             "quantization": {
46                                 "min": [ 0.0 ],
47                                 "max": [ 255.0 ],
48                                 "scale": [ 1.0 ],
49                                 "zero_point": [ 0 ]
50                             }
51                 }
52                 ],
53                 "inputs": [ 1 ],
54                 "outputs": [ 0 ],
55                 "operators": [ {
56                         "opcode_index": 0,
57                         "inputs": [ 1 ],
58                         "outputs": [ 0 ],
59                         "builtin_options_type": "Pool2DOptions",
60                         "builtin_options":
61                         {
62                             "padding": "VALID",
63                             "stride_w": 2,
64                             "stride_h": 2,
65                             "filter_width": 2,
66                             "filter_height": 2,
67                             "fused_activation_function": "NONE"
68                         },
69                         "custom_options_format": "FLEXBUFFERS"
70                     } ]
71                 },
72                 {
73                     "tensors": [
74                         {
75                             "shape": [ 1, 3, 3, 1 ],
76                             "type": "UINT8",
77                             "buffer": 0,
78                             "name": "ConvInputTensor",
79                             "quantization": {
80                                 "scale": [ 1.0 ],
81                                 "zero_point": [ 0 ],
82                             }
83                         },
84                         {
85                             "shape": [ 1, 1, 1, 1 ],
86                             "type": "UINT8",
87                             "buffer": 1,
88                             "name": "ConvOutputTensor",
89                             "quantization": {
90                                 "min": [ 0.0 ],
91                                 "max": [ 511.0 ],
92                                 "scale": [ 2.0 ],
93                                 "zero_point": [ 0 ],
94                             }
95                         },
96                         {
97                             "shape": [ 1, 3, 3, 1 ],
98                             "type": "UINT8",
99                             "buffer": 2,
100                             "name": "filterTensor",
101                             "quantization": {
102                                 "min": [ 0.0 ],
103                                 "max": [ 255.0 ],
104                                 "scale": [ 1.0 ],
105                                 "zero_point": [ 0 ],
106                             }
107                         }
108                     ],
109                     "inputs": [ 0 ],
110                     "outputs": [ 1 ],
111                     "operators": [
112                         {
113                             "opcode_index": 1,
114                             "inputs": [ 0, 2 ],
115                             "outputs": [ 1 ],
116                             "builtin_options_type": "Conv2DOptions",
117                             "builtin_options": {
118                                 "padding": "VALID",
119                                 "stride_w": 1,
120                                 "stride_h": 1,
121                                 "fused_activation_function": "NONE"
122                             },
123                             "custom_options_format": "FLEXBUFFERS"
124                         }
125                     ],
126                 }
127             ],
128             "description": "Test loading a model",
129             "buffers" : [ {}, {} ]
130         })";
131 
132         ReadStringToBinary();
133     }
134 
CheckModelLoadModelFixture135     void CheckModel(const ModelPtr& model, uint32_t version, size_t opcodeSize,
136                     const std::vector<tflite::BuiltinOperator>& opcodes,
137                     size_t subgraphs, const std::string desc, size_t buffers)
138     {
139         CHECK(model);
140         CHECK_EQ(version, model->version);
141         CHECK_EQ(opcodeSize, model->operator_codes.size());
142         CheckBuiltinOperators(opcodes, model->operator_codes);
143         CHECK_EQ(subgraphs, model->subgraphs.size());
144         CHECK_EQ(desc, model->description);
145         CHECK_EQ(buffers, model->buffers.size());
146     }
147 
CheckBuiltinOperatorsLoadModelFixture148     void CheckBuiltinOperators(const std::vector<tflite::BuiltinOperator>& expectedOperators,
149                                const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& result)
150     {
151         CHECK_EQ(expectedOperators.size(), result.size());
152         for (size_t i = 0; i < expectedOperators.size(); i++)
153         {
154             CHECK_EQ(expectedOperators[i], result[i]->builtin_code);
155         }
156     }
157 
CheckSubgraphLoadModelFixture158     void CheckSubgraph(const SubgraphPtr& subgraph, size_t tensors, const std::vector<int32_t>& inputs,
159                        const std::vector<int32_t>& outputs, size_t operators, const std::string& name)
160     {
161         CHECK(subgraph);
162         CHECK_EQ(tensors, subgraph->tensors.size());
163         CHECK(std::equal(inputs.begin(), inputs.end(), subgraph->inputs.begin(), subgraph->inputs.end()));
164         CHECK(std::equal(outputs.begin(), outputs.end(),
165                                       subgraph->outputs.begin(), subgraph->outputs.end()));
166         CHECK_EQ(operators, subgraph->operators.size());
167         CHECK_EQ(name, subgraph->name);
168     }
169 
CheckOperatorLoadModelFixture170     void CheckOperator(const OperatorPtr& operatorPtr, uint32_t opcode,  const std::vector<int32_t>& inputs,
171                        const std::vector<int32_t>& outputs, tflite::BuiltinOptions optionType,
172                        tflite::CustomOptionsFormat custom_options_format)
173     {
174         CHECK(operatorPtr);
175         CHECK_EQ(opcode, operatorPtr->opcode_index);
176         CHECK(std::equal(inputs.begin(), inputs.end(),
177                                       operatorPtr->inputs.begin(), operatorPtr->inputs.end()));
178         CHECK(std::equal(outputs.begin(), outputs.end(),
179                                       operatorPtr->outputs.begin(), operatorPtr->outputs.end()));
180         CHECK_EQ(optionType, operatorPtr->builtin_options.type);
181         CHECK_EQ(custom_options_format, operatorPtr->custom_options_format);
182     }
183 };
184 
185 TEST_CASE_FIXTURE(LoadModelFixture, "LoadModelFromBinary")
186 {
187     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
188                                                                              m_GraphBinary.size());
189     CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
190                2, "Test loading a model", 2);
191     CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, "");
192     CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, "");
193     CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions,
194                   tflite::CustomOptionsFormat_FLEXBUFFERS);
195     CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions,
196                   tflite::CustomOptionsFormat_FLEXBUFFERS);
197 }
198 
199 TEST_CASE_FIXTURE(LoadModelFixture, "LoadModelFromFile")
200 {
201     using namespace fs;
202     fs::path fname = armnnUtils::Filesystem::NamedTempFile("Armnn-tfLite-LoadModelFromFile-TempFile.csv");
203     bool saved = flatbuffers::SaveFile(fname.c_str(),
204                                        reinterpret_cast<char *>(m_GraphBinary.data()),
205                                        m_GraphBinary.size(), true);
206     CHECK_MESSAGE(saved, "Cannot save test file");
207 
208     TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromFile(fname.c_str());
209     CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
210                2, "Test loading a model", 2);
211     CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, "");
212     CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, "");
213     CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions,
214                   tflite::CustomOptionsFormat_FLEXBUFFERS);
215     CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions,
216                   tflite::CustomOptionsFormat_FLEXBUFFERS);
217     remove(fname);
218 }
219 
220 TEST_CASE("LoadNullBinary")
221 {
222     CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromBinary(nullptr, 0), armnn::InvalidArgumentException);
223 }
224 
225 TEST_CASE("LoadInvalidBinary")
226 {
227     std::string testData = "invalid data";
228     CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromBinary(reinterpret_cast<const uint8_t*>(&testData),
229                                                         testData.length()), armnn::ParseException);
230 }
231 
232 TEST_CASE("LoadFileNotFound")
233 {
234     CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromFile("invalidfile.tflite"), armnn::FileNotFoundException);
235 }
236 
237 TEST_CASE("LoadNullPtrFile")
238 {
239     CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromFile(nullptr), armnn::InvalidArgumentException);
240 }
241 
242 }
243