1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/import.h"
16
17 #include <initializer_list>
18 #include <string>
19
20 #include "flatbuffers/flexbuffers.h"
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/lite/schema/schema_conversion_utils.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 #include "tensorflow/lite/version.h"
26
27 namespace toco {
28
29 namespace tflite {
30 namespace {
31
32 using ::testing::ElementsAre;
33
34 using flatbuffers::Offset;
35 using flatbuffers::Vector;
36 class ImportTest : public ::testing::Test {
37 protected:
38 template <typename T>
CreateDataVector(const std::vector<T> & data)39 Offset<Vector<unsigned char>> CreateDataVector(const std::vector<T>& data) {
40 return builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.data()),
41 sizeof(T) * data.size());
42 }
43
BuildBuffers()44 Offset<Vector<Offset<::tflite::Buffer>>> BuildBuffers() {
45 auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector<float>({}));
46 auto buf1 = ::tflite::CreateBuffer(
47 builder_, CreateDataVector<float>({1.0f, 2.0f, 3.0f, 4.0f}));
48 auto buf2 =
49 ::tflite::CreateBuffer(builder_, CreateDataVector<float>({3.0f, 4.0f}));
50 return builder_.CreateVector(
51 std::vector<Offset<::tflite::Buffer>>({buf0, buf1, buf2}));
52 }
53
BuildTensors()54 Offset<Vector<Offset<::tflite::Tensor>>> BuildTensors() {
55 auto q = ::tflite::CreateQuantizationParameters(
56 builder_,
57 /*min=*/builder_.CreateVector<float>({0.1f}),
58 /*max=*/builder_.CreateVector<float>({0.2f}),
59 /*scale=*/builder_.CreateVector<float>({0.3f}),
60 /*zero_point=*/builder_.CreateVector<int64_t>({100LL}));
61 auto t1 =
62 ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({1, 2, 2}),
63 ::tflite::TensorType_FLOAT32, 1,
64 builder_.CreateString("tensor_one"), q);
65 auto t2 =
66 ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({2, 1}),
67 ::tflite::TensorType_FLOAT32, 0,
68 builder_.CreateString("tensor_two"), q);
69 return builder_.CreateVector(
70 std::vector<Offset<::tflite::Tensor>>({t1, t2}));
71 }
72
BuildOpCodes(std::initializer_list<::tflite::BuiltinOperator> op_codes)73 Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes(
74 std::initializer_list<::tflite::BuiltinOperator> op_codes) {
75 std::vector<Offset<::tflite::OperatorCode>> op_codes_vector;
76 for (auto op : op_codes) {
77 op_codes_vector.push_back(::tflite::CreateOperatorCode(builder_, op, 0));
78 }
79 return builder_.CreateVector(op_codes_vector);
80 }
81
BuildOpCodes()82 Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes() {
83 return BuildOpCodes({::tflite::BuiltinOperator_MAX_POOL_2D,
84 ::tflite::BuiltinOperator_CONV_2D});
85 }
86
BuildOperators(std::initializer_list<int> inputs,std::initializer_list<int> outputs)87 Offset<Vector<Offset<::tflite::Operator>>> BuildOperators(
88 std::initializer_list<int> inputs, std::initializer_list<int> outputs) {
89 auto is = builder_.CreateVector<int>(inputs);
90 if (inputs.size() == 0) is = 0;
91 auto os = builder_.CreateVector<int>(outputs);
92 if (outputs.size() == 0) os = 0;
93 auto op = ::tflite::CreateOperator(
94 builder_, 0, is, os, ::tflite::BuiltinOptions_Conv2DOptions,
95 ::tflite::CreateConv2DOptions(builder_, ::tflite::Padding_VALID, 1, 1,
96 ::tflite::ActivationFunctionType_NONE)
97 .Union(),
98 /*custom_options=*/0, ::tflite::CustomOptionsFormat_FLEXBUFFERS);
99
100 return builder_.CreateVector(std::vector<Offset<::tflite::Operator>>({op}));
101 }
102
BuildOperators()103 Offset<Vector<Offset<::tflite::Operator>>> BuildOperators() {
104 return BuildOperators({0}, {1});
105 }
106
BuildSubGraphs(Offset<Vector<Offset<::tflite::Tensor>>> tensors,Offset<Vector<Offset<::tflite::Operator>>> operators,int num_sub_graphs=1)107 Offset<Vector<Offset<::tflite::SubGraph>>> BuildSubGraphs(
108 Offset<Vector<Offset<::tflite::Tensor>>> tensors,
109 Offset<Vector<Offset<::tflite::Operator>>> operators,
110 int num_sub_graphs = 1) {
111 std::vector<int32_t> inputs = {0};
112 std::vector<int32_t> outputs = {1};
113 std::vector<Offset<::tflite::SubGraph>> v;
114 for (int i = 0; i < num_sub_graphs; ++i) {
115 v.push_back(::tflite::CreateSubGraph(
116 builder_, tensors, builder_.CreateVector(inputs),
117 builder_.CreateVector(outputs), operators,
118 builder_.CreateString("subgraph")));
119 }
120 return builder_.CreateVector(v);
121 }
122
123 // This is a very simplistic model. We are not interested in testing all the
124 // details here, since tf.mini's testing framework will be exercising all the
125 // conversions multiple times, and the conversion of operators is tested by
126 // separate unittests.
BuildTestModel()127 void BuildTestModel() {
128 auto buffers = BuildBuffers();
129 auto tensors = BuildTensors();
130 auto opcodes = BuildOpCodes();
131 auto operators = BuildOperators();
132 auto subgraphs = BuildSubGraphs(tensors, operators);
133 auto s = builder_.CreateString("");
134
135 ::tflite::FinishModelBuffer(
136 builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
137 opcodes, subgraphs, s, buffers));
138
139 input_model_ = ::tflite::GetModel(builder_.GetBufferPointer());
140 }
InputModelAsString()141 std::string InputModelAsString() {
142 return std::string(reinterpret_cast<char*>(builder_.GetBufferPointer()),
143 builder_.GetSize());
144 }
145 flatbuffers::FlatBufferBuilder builder_;
146 const ::tflite::Model* input_model_ = nullptr;
147 };
148
// The tensors table lists every tensor name in serialization order.
TEST_F(ImportTest, LoadTensorsTable) {
  BuildTestModel();

  details::TensorsTable table;
  details::LoadTensorsTable(*input_model_, &table);

  EXPECT_THAT(table, ElementsAre("tensor_one", "tensor_two"));
}
156
// The operators table lists builtin operator names in opcode-index order.
TEST_F(ImportTest, LoadOperatorsTable) {
  BuildTestModel();

  details::OperatorsTable table;
  details::LoadOperatorsTable(*input_model_, &table);

  EXPECT_THAT(table, ElementsAre("MAX_POOL_2D", "CONV_2D"));
}
164
// Importing the test model yields arrays carrying the serialized type, data,
// shape, min/max and quantization parameters.
TEST_F(ImportTest, Tensors) {
  BuildTestModel();

  auto model = Import(ModelFlags(), InputModelAsString());

  // HasArray() is a predicate; assert on it directly instead of comparing a
  // bool against an integer.
  ASSERT_TRUE(model->HasArray("tensor_one"));
  Array& a1 = model->GetArray("tensor_one");
  EXPECT_EQ(ArrayDataType::kFloat, a1.data_type);
  EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data,
              ElementsAre(1.0f, 2.0f, 3.0f, 4.0f));
  ASSERT_TRUE(a1.has_shape());
  EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 2));

  // min/max were serialized as float32, so compare at float precision.
  const auto& mm = a1.minmax;
  ASSERT_NE(mm.get(), nullptr);
  EXPECT_FLOAT_EQ(0.1, mm->min);
  EXPECT_FLOAT_EQ(0.2, mm->max);

  const auto& q = a1.quantization_params;
  ASSERT_NE(q.get(), nullptr);
  EXPECT_FLOAT_EQ(0.3, q->scale);
  EXPECT_EQ(100, q->zero_point);

  // The second serialized tensor must be imported as well, with its own shape.
  ASSERT_TRUE(model->HasArray("tensor_two"));
  const Array& a2 = model->GetArray("tensor_two");
  ASSERT_TRUE(a2.has_shape());
  EXPECT_THAT(a2.shape().dims(), ElementsAre(2, 1));
}
188
// A model without a 'buffers' section must be rejected by the importer.
TEST_F(ImportTest, NoBuffers) {
  // A default-constructed (null) offset leaves the model's 'buffers' field
  // unset. Spelling out the type avoids relying on an implicit int -> Offset
  // conversion from a deduced `int`.
  const Offset<Vector<Offset<::tflite::Buffer>>> buffers{};
  auto tensors = BuildTensors();
  auto opcodes = BuildOpCodes();
  auto operators = BuildOperators();
  auto subgraphs = BuildSubGraphs(tensors, operators);
  auto comment = builder_.CreateString("");
  ::tflite::FinishModelBuffer(
      builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
                                      subgraphs, comment, buffers));
  EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
               "Missing 'buffers' section.");
}
202
// An operator whose 'inputs' field is absent must be rejected by the importer.
TEST_F(ImportTest, NoInputs) {
  auto model_buffers = BuildBuffers();
  auto model_tensors = BuildTensors();
  auto model_opcodes = BuildOpCodes();
  // Empty input list => the operator's 'inputs' field is left unset.
  auto model_subgraphs =
      BuildSubGraphs(model_tensors, BuildOperators({}, {1}));

  auto root = ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
                                    model_opcodes, model_subgraphs,
                                    builder_.CreateString(""), model_buffers);
  ::tflite::FinishModelBuffer(builder_, root);

  const std::string model_bytes = InputModelAsString();
  EXPECT_DEATH(Import(ModelFlags(), model_bytes),
               "Missing 'inputs' for operator.");
}
216
// An operator whose 'outputs' field is absent must be rejected by the importer.
TEST_F(ImportTest, NoOutputs) {
  auto model_buffers = BuildBuffers();
  auto model_tensors = BuildTensors();
  auto model_opcodes = BuildOpCodes();
  // Empty output list => the operator's 'outputs' field is left unset.
  auto model_subgraphs =
      BuildSubGraphs(model_tensors, BuildOperators({0}, {}));

  auto root = ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
                                    model_opcodes, model_subgraphs,
                                    builder_.CreateString(""), model_buffers);
  ::tflite::FinishModelBuffer(builder_, root);

  const std::string model_bytes = InputModelAsString();
  EXPECT_DEATH(Import(ModelFlags(), model_bytes),
               "Missing 'outputs' for operator.");
}
230
// An opcode table entry with an out-of-range builtin operator id (-1) must be
// rejected by the importer.
TEST_F(ImportTest, InvalidOpCode) {
  auto model_buffers = BuildBuffers();
  auto model_tensors = BuildTensors();
  const auto bogus_op = static_cast<::tflite::BuiltinOperator>(-1);
  auto model_opcodes =
      BuildOpCodes({bogus_op, ::tflite::BuiltinOperator_CONV_2D});
  auto model_subgraphs = BuildSubGraphs(model_tensors, BuildOperators());

  auto root = ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
                                    model_opcodes, model_subgraphs,
                                    builder_.CreateString(""), model_buffers);
  ::tflite::FinishModelBuffer(builder_, root);

  const std::string model_bytes = InputModelAsString();
  EXPECT_DEATH(Import(ModelFlags(), model_bytes),
               "Operator id '-1' is out of range.");
}
245
// The importer only supports models with exactly one subgraph; a model with
// two must be rejected.
TEST_F(ImportTest, MultipleSubGraphs) {
  auto buffers = BuildBuffers();
  auto tensors = BuildTensors();
  auto opcodes = BuildOpCodes();
  auto operators = BuildOperators();
  auto subgraphs = BuildSubGraphs(tensors, operators, /*num_sub_graphs=*/2);
  auto comment = builder_.CreateString("");
  ::tflite::FinishModelBuffer(
      builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
                                      subgraphs, comment, buffers));

  // Import works on the serialized bytes, so there is no need to set
  // `input_model_` here (the previous assignment was a dead store).
  EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
               "Number of subgraphs in tflite should be exactly 1.");
}
262
263 // TODO(ahentz): still need tests for Operators and IOTensors.
264
265 } // namespace
266 } // namespace tflite
267
268 } // namespace toco
269