xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/toco/tflite/import_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/import.h"
16 
17 #include <initializer_list>
18 #include <string>
19 
20 #include "flatbuffers/flexbuffers.h"
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/lite/schema/schema_conversion_utils.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 #include "tensorflow/lite/version.h"
26 
27 namespace toco {
28 
29 namespace tflite {
30 namespace {
31 
32 using ::testing::ElementsAre;
33 
34 using flatbuffers::Offset;
35 using flatbuffers::Vector;
36 class ImportTest : public ::testing::Test {
37  protected:
38   template <typename T>
CreateDataVector(const std::vector<T> & data)39   Offset<Vector<unsigned char>> CreateDataVector(const std::vector<T>& data) {
40     return builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.data()),
41                                  sizeof(T) * data.size());
42   }
43 
BuildBuffers()44   Offset<Vector<Offset<::tflite::Buffer>>> BuildBuffers() {
45     auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector<float>({}));
46     auto buf1 = ::tflite::CreateBuffer(
47         builder_, CreateDataVector<float>({1.0f, 2.0f, 3.0f, 4.0f}));
48     auto buf2 =
49         ::tflite::CreateBuffer(builder_, CreateDataVector<float>({3.0f, 4.0f}));
50     return builder_.CreateVector(
51         std::vector<Offset<::tflite::Buffer>>({buf0, buf1, buf2}));
52   }
53 
BuildTensors()54   Offset<Vector<Offset<::tflite::Tensor>>> BuildTensors() {
55     auto q = ::tflite::CreateQuantizationParameters(
56         builder_,
57         /*min=*/builder_.CreateVector<float>({0.1f}),
58         /*max=*/builder_.CreateVector<float>({0.2f}),
59         /*scale=*/builder_.CreateVector<float>({0.3f}),
60         /*zero_point=*/builder_.CreateVector<int64_t>({100LL}));
61     auto t1 =
62         ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({1, 2, 2}),
63                                ::tflite::TensorType_FLOAT32, 1,
64                                builder_.CreateString("tensor_one"), q);
65     auto t2 =
66         ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({2, 1}),
67                                ::tflite::TensorType_FLOAT32, 0,
68                                builder_.CreateString("tensor_two"), q);
69     return builder_.CreateVector(
70         std::vector<Offset<::tflite::Tensor>>({t1, t2}));
71   }
72 
BuildOpCodes(std::initializer_list<::tflite::BuiltinOperator> op_codes)73   Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes(
74       std::initializer_list<::tflite::BuiltinOperator> op_codes) {
75     std::vector<Offset<::tflite::OperatorCode>> op_codes_vector;
76     for (auto op : op_codes) {
77       op_codes_vector.push_back(::tflite::CreateOperatorCode(builder_, op, 0));
78     }
79     return builder_.CreateVector(op_codes_vector);
80   }
81 
BuildOpCodes()82   Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes() {
83     return BuildOpCodes({::tflite::BuiltinOperator_MAX_POOL_2D,
84                          ::tflite::BuiltinOperator_CONV_2D});
85   }
86 
BuildOperators(std::initializer_list<int> inputs,std::initializer_list<int> outputs)87   Offset<Vector<Offset<::tflite::Operator>>> BuildOperators(
88       std::initializer_list<int> inputs, std::initializer_list<int> outputs) {
89     auto is = builder_.CreateVector<int>(inputs);
90     if (inputs.size() == 0) is = 0;
91     auto os = builder_.CreateVector<int>(outputs);
92     if (outputs.size() == 0) os = 0;
93     auto op = ::tflite::CreateOperator(
94         builder_, 0, is, os, ::tflite::BuiltinOptions_Conv2DOptions,
95         ::tflite::CreateConv2DOptions(builder_, ::tflite::Padding_VALID, 1, 1,
96                                       ::tflite::ActivationFunctionType_NONE)
97             .Union(),
98         /*custom_options=*/0, ::tflite::CustomOptionsFormat_FLEXBUFFERS);
99 
100     return builder_.CreateVector(std::vector<Offset<::tflite::Operator>>({op}));
101   }
102 
BuildOperators()103   Offset<Vector<Offset<::tflite::Operator>>> BuildOperators() {
104     return BuildOperators({0}, {1});
105   }
106 
BuildSubGraphs(Offset<Vector<Offset<::tflite::Tensor>>> tensors,Offset<Vector<Offset<::tflite::Operator>>> operators,int num_sub_graphs=1)107   Offset<Vector<Offset<::tflite::SubGraph>>> BuildSubGraphs(
108       Offset<Vector<Offset<::tflite::Tensor>>> tensors,
109       Offset<Vector<Offset<::tflite::Operator>>> operators,
110       int num_sub_graphs = 1) {
111     std::vector<int32_t> inputs = {0};
112     std::vector<int32_t> outputs = {1};
113     std::vector<Offset<::tflite::SubGraph>> v;
114     for (int i = 0; i < num_sub_graphs; ++i) {
115       v.push_back(::tflite::CreateSubGraph(
116           builder_, tensors, builder_.CreateVector(inputs),
117           builder_.CreateVector(outputs), operators,
118           builder_.CreateString("subgraph")));
119     }
120     return builder_.CreateVector(v);
121   }
122 
123   // This is a very simplistic model. We are not interested in testing all the
124   // details here, since tf.mini's testing framework will be exercising all the
125   // conversions multiple times, and the conversion of operators is tested by
126   // separate unittests.
BuildTestModel()127   void BuildTestModel() {
128     auto buffers = BuildBuffers();
129     auto tensors = BuildTensors();
130     auto opcodes = BuildOpCodes();
131     auto operators = BuildOperators();
132     auto subgraphs = BuildSubGraphs(tensors, operators);
133     auto s = builder_.CreateString("");
134 
135     ::tflite::FinishModelBuffer(
136         builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
137                                         opcodes, subgraphs, s, buffers));
138 
139     input_model_ = ::tflite::GetModel(builder_.GetBufferPointer());
140   }
InputModelAsString()141   std::string InputModelAsString() {
142     return std::string(reinterpret_cast<char*>(builder_.GetBufferPointer()),
143                        builder_.GetSize());
144   }
145   flatbuffers::FlatBufferBuilder builder_;
146   const ::tflite::Model* input_model_ = nullptr;
147 };
148 
// The tensors table must list tensor names in the order they appear in the
// model.
TEST_F(ImportTest, LoadTensorsTable) {
  BuildTestModel();
  const ::tflite::Model& model = *input_model_;

  details::TensorsTable table;
  details::LoadTensorsTable(model, &table);

  EXPECT_THAT(table, ElementsAre("tensor_one", "tensor_two"));
}
156 
// The operators table must list operator names in the order of the model's
// opcode table.
TEST_F(ImportTest, LoadOperatorsTable) {
  BuildTestModel();
  const ::tflite::Model& model = *input_model_;

  details::OperatorsTable table;
  details::LoadOperatorsTable(model, &table);

  EXPECT_THAT(table, ElementsAre("MAX_POOL_2D", "CONV_2D"));
}
164 
// Imports the standard test model and checks that "tensor_one" keeps its
// data type, constant data, shape, min/max and quantization parameters.
TEST_F(ImportTest, Tensors) {
  BuildTestModel();

  auto model = Import(ModelFlags(), InputModelAsString());
  // Fail cleanly instead of crashing on a null dereference below.
  ASSERT_TRUE(model != nullptr);

  // HasArray() is a presence check; assert on it directly rather than
  // ordering-comparing it against 0.
  ASSERT_TRUE(model->HasArray("tensor_one"));
  Array& a1 = model->GetArray("tensor_one");
  EXPECT_EQ(ArrayDataType::kFloat, a1.data_type);
  EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data,
              ElementsAre(1.0f, 2.0f, 3.0f, 4.0f));
  ASSERT_TRUE(a1.has_shape());
  EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 2));

  const auto& mm = a1.minmax;
  ASSERT_TRUE(mm != nullptr);
  EXPECT_FLOAT_EQ(0.1, mm->min);
  EXPECT_FLOAT_EQ(0.2, mm->max);

  const auto& q = a1.quantization_params;
  ASSERT_TRUE(q != nullptr);
  EXPECT_FLOAT_EQ(0.3, q->scale);
  EXPECT_EQ(100, q->zero_point);
}
188 
// A model whose 'buffers' table is absent must make the importer die.
TEST_F(ImportTest, NoBuffers) {
  // A default-constructed (null) offset leaves the model's `buffers` field
  // unset. Spelling out the type makes that explicit, instead of deducing
  // `int` from a literal 0 and relying on an implicit conversion later.
  const Offset<Vector<Offset<::tflite::Buffer>>> buffers{};
  auto tensors = BuildTensors();
  auto opcodes = BuildOpCodes();
  auto operators = BuildOperators();
  auto subgraphs = BuildSubGraphs(tensors, operators);
  auto comment = builder_.CreateString("");
  ::tflite::FinishModelBuffer(
      builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
                                      subgraphs, comment, buffers));
  EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
               "Missing 'buffers' section.");
}
202 
// An operator whose input list is empty must make the importer die.
TEST_F(ImportTest, NoInputs) {
  auto model_buffers = BuildBuffers();
  auto model_tensors = BuildTensors();
  auto op_codes = BuildOpCodes();
  // No inputs: only tensor 1 is wired up, as the output.
  auto ops = BuildOperators({}, {1});
  auto graphs = BuildSubGraphs(model_tensors, ops);
  auto description = builder_.CreateString("");
  auto model_offset =
      ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, op_codes, graphs,
                            description, model_buffers);
  ::tflite::FinishModelBuffer(builder_, model_offset);

  const std::string serialized = InputModelAsString();
  EXPECT_DEATH(Import(ModelFlags(), serialized),
               "Missing 'inputs' for operator.");
}
216 
// An operator whose output list is empty must make the importer die.
TEST_F(ImportTest, NoOutputs) {
  auto model_buffers = BuildBuffers();
  auto model_tensors = BuildTensors();
  auto op_codes = BuildOpCodes();
  // No outputs: only tensor 0 is wired up, as the input.
  auto ops = BuildOperators({0}, {});
  auto graphs = BuildSubGraphs(model_tensors, ops);
  auto description = builder_.CreateString("");
  auto model_offset =
      ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, op_codes, graphs,
                            description, model_buffers);
  ::tflite::FinishModelBuffer(builder_, model_offset);

  const std::string serialized = InputModelAsString();
  EXPECT_DEATH(Import(ModelFlags(), serialized),
               "Missing 'outputs' for operator.");
}
230 
// An opcode table entry outside the BuiltinOperator range must make the
// importer die with an "out of range" message.
TEST_F(ImportTest, InvalidOpCode) {
  const auto bogus_op = static_cast<::tflite::BuiltinOperator>(-1);

  auto model_buffers = BuildBuffers();
  auto model_tensors = BuildTensors();
  auto op_codes = BuildOpCodes({bogus_op, ::tflite::BuiltinOperator_CONV_2D});
  auto ops = BuildOperators();
  auto graphs = BuildSubGraphs(model_tensors, ops);
  auto description = builder_.CreateString("");
  auto model_offset =
      ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, op_codes, graphs,
                            description, model_buffers);
  ::tflite::FinishModelBuffer(builder_, model_offset);

  const std::string serialized = InputModelAsString();
  EXPECT_DEATH(Import(ModelFlags(), serialized),
               "Operator id '-1' is out of range.");
}
245 
// A model with more than one subgraph must be rejected by the importer.
TEST_F(ImportTest, MultipleSubGraphs) {
  auto buffers = BuildBuffers();
  auto tensors = BuildTensors();
  auto opcodes = BuildOpCodes();
  auto operators = BuildOperators();
  auto subgraphs = BuildSubGraphs(tensors, operators, 2);
  auto comment = builder_.CreateString("");
  ::tflite::FinishModelBuffer(
      builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
                                      subgraphs, comment, buffers));

  // Import() consumes the serialized bytes directly, so there is no need to
  // set input_model_ here (the original assignment was never read).
  EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
               "Number of subgraphs in tflite should be exactly 1.");
}
262 
263 // TODO(ahentz): still need tests for Operators and IOTensors.
264 
265 }  // namespace
266 }  // namespace tflite
267 
268 }  // namespace toco
269