/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include <cstddef>
16 #include <memory>
17 #include <vector>
18
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
22 #include "tensorflow/lite/builtin_ops.h"
23 #include "tensorflow/lite/c/builtin_op_data.h"
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/core/subgraph.h"
26 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/call_register.h"
27 #include "tensorflow/lite/interpreter.h"
28 #include "tensorflow/lite/interpreter_test_util.h"
29 #include "tensorflow/lite/kernels/builtin_op_kernels.h"
30 #include "tensorflow/lite/kernels/subgraph_test_util.h"
31 #include "tensorflow/lite/testing/util.h"
32
33 namespace tflite {
34
35 namespace {
36
37 class CallTest : public subgraph_test_util::ControlFlowOpTest {
38 public:
CallTest()39 CallTest() { interpreter_ = std::make_unique<Interpreter>(&error_reporter_); }
40 ~CallTest() override = default;
SetupTensor(Subgraph * subgraph,int tensor_index,TfLiteType type)41 void SetupTensor(Subgraph* subgraph, int tensor_index, TfLiteType type) {
42 ASSERT_EQ(subgraph->SetTensorParametersReadWrite(tensor_index, type, "", 0,
43 nullptr, {}, false),
44 kTfLiteOk);
45 }
BuildCallSubgraph(Subgraph * subgraph,std::vector<uint8_t> params_buffer,std::vector<int> inputs,std::vector<int> outputs,int expected_node_index,bool single_node_subgraph)46 void BuildCallSubgraph(Subgraph* subgraph, std::vector<uint8_t> params_buffer,
47 std::vector<int> inputs, std::vector<int> outputs,
48 int expected_node_index, bool single_node_subgraph) {
49 if (single_node_subgraph) {
50 int first_new_tensor_index;
51 ASSERT_EQ(subgraph->AddTensors(inputs.size() + outputs.size(),
52 &first_new_tensor_index),
53 kTfLiteOk);
54 ASSERT_EQ(first_new_tensor_index, 0);
55 ASSERT_EQ(subgraph->SetInputs(inputs), kTfLiteOk);
56 ASSERT_EQ(subgraph->SetOutputs(outputs), kTfLiteOk);
57 }
58 for (const int& idx : inputs) {
59 SetupTensor(subgraph, idx, kTfLiteInt32);
60 }
61 for (const int& idx : outputs) {
62 SetupTensor(subgraph, idx, kTfLiteInt32);
63 }
64 int node_index;
65 subgraph->AddNodeWithParameters(
66 inputs, outputs, {},
67 reinterpret_cast<const char*>(params_buffer.data()),
68 params_buffer.size(), nullptr, acceleration::ops::Register_CALL(),
69 &node_index);
70 ASSERT_EQ(node_index, expected_node_index);
71 }
BuildCallSubgraph(Subgraph * subgraph,int index,int loop_count,std::vector<int> inputs,std::vector<int> outputs,int expected_node_index=0,bool single_node_subgraph=true)72 void BuildCallSubgraph(Subgraph* subgraph, int index, int loop_count,
73 std::vector<int> inputs, std::vector<int> outputs,
74 int expected_node_index = 0,
75 bool single_node_subgraph = true) {
76 flexbuffers::Builder fbb;
77 fbb.Map([&] {
78 fbb.Int("subgraph_index", index);
79 fbb.Int("loop_count", loop_count);
80 });
81 fbb.Finish();
82 BuildCallSubgraph(subgraph, fbb.GetBuffer(), inputs, outputs,
83 expected_node_index, single_node_subgraph);
84 }
85
BuildGraphWithMultipleOutputs(Subgraph * subgraph)86 void BuildGraphWithMultipleOutputs(Subgraph* subgraph) {
87 const int kInput1 = 0;
88 const int kInput2 = 1;
89 const int kMulOutput = 2;
90 const int kAddOutput = 3;
91 const int kTensorCount = 4;
92 // kInput1(0) --> +---+
93 // |MUL| --> kOutput(2)
94 // kInput2(1) --> +---+
95 //
96 // kInput1(0) --> +---+
97 // |ADD| --> kOutput(3)
98 // kInput2(1) --> +---+
99
100 int first_new_tensor_index;
101 ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
102 kTfLiteOk);
103 ASSERT_EQ(first_new_tensor_index, 0);
104 ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk);
105 ASSERT_EQ(subgraph->SetOutputs({kMulOutput, kAddOutput}), kTfLiteOk);
106
107 SetupTensor(subgraph, kInput1, kTfLiteInt32);
108 SetupTensor(subgraph, kInput2, kTfLiteInt32);
109 SetupTensor(subgraph, kMulOutput, kTfLiteInt32);
110 SetupTensor(subgraph, kAddOutput, kTfLiteInt32);
111 TfLiteMulParams* params_mul =
112 reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
113 params_mul->activation = kTfLiteActNone;
114 int node_index;
115 subgraph->AddNodeWithParameters(
116 {kInput1, kInput2}, {kMulOutput}, {}, nullptr, 0, params_mul,
117 ::tflite::ops::builtin::Register_MUL(), &node_index);
118 TfLiteAddParams* params_add =
119 reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
120 params_add->activation = kTfLiteActNone;
121 subgraph->AddNodeWithParameters(
122 {kInput1, kInput2}, {kAddOutput}, {}, nullptr, 0, params_add,
123 ::tflite::ops::builtin::Register_ADD(), &node_index);
124 }
BuildMultiNodeGraph(Subgraph * this_subgraph)125 void BuildMultiNodeGraph(Subgraph* this_subgraph) {
126 // kIn1(0)----------------
127 // |
128 // | +----+
129 // +---+ -------->| | +---+
130 // kIn2(1)--> |PAD|-->kOut1(4)--->|CALL|-->kOut2(5)-->|MUL|-->kOut3(6)
131 // kIn3(2)--> | | | | | |
132 // +---+ +----+ ---->| |
133 // | +---+
134 // |
135 // |
136 // kIn4(3)----------------------------------------
137 const int kInput1 = 0, kInput2 = 1, kInput3 = 2, kInput4 = 3;
138 const int kOutput1 = 4, kOutput2 = 5, kOutput3 = 6;
139 const int kTensorCount = 7;
140 int first_new_tensor_index;
141 ASSERT_EQ(this_subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
142 kTfLiteOk);
143 ASSERT_EQ(first_new_tensor_index, 0);
144 std::vector<int> inputs = {kInput1, kInput2, kInput3, kInput4};
145 std::vector<int> outputs = {kOutput3};
146 ASSERT_EQ(this_subgraph->SetInputs(inputs), kTfLiteOk);
147 ASSERT_EQ(this_subgraph->SetOutputs({kOutput3}), kTfLiteOk);
148 for (int idx = 0; idx < kTensorCount; ++idx) {
149 SetupTensor(this_subgraph, idx, kTfLiteInt32);
150 }
151 int expected_node_index = 0, node_index;
152 // Node 1: Pad op.
153 auto* pad_reg = ops::builtin::Register_PAD();
154 pad_reg->builtin_code = kTfLiteBuiltinPad;
155 this_subgraph->AddNodeWithParameters(
156 {kInput2, kInput3}, {kOutput1}, {}, nullptr, 0,
157 reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLitePadParams))),
158 pad_reg, &node_index);
159 ASSERT_EQ(node_index, expected_node_index++);
160 // Node 2: Call op, calls subgraph that contains Add op.
161 AddSubgraphs(1);
162 const int kLoopCount = 1;
163 const int kSubgraphIndex = 1;
164 builder_->BuildAddSubgraph(interpreter_->subgraph(1));
165 CallTest::BuildCallSubgraph(this_subgraph, kSubgraphIndex, kLoopCount,
166 {kInput1, kOutput1}, {kOutput2},
167 expected_node_index++, false);
168 // Node 3: Mul op.
169 TfLiteMulParams* mul_params =
170 reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
171 mul_params->activation = kTfLiteActNone;
172 auto* mul_reg = ops::builtin::Register_MUL();
173 mul_reg->builtin_code = kTfLiteBuiltinMul;
174 this_subgraph->AddNodeWithParameters({kInput4, kOutput2}, {kOutput3}, {},
175 nullptr, 0, mul_params, mul_reg,
176 &node_index);
177 ASSERT_EQ(node_index, expected_node_index++);
178 }
179 TestErrorReporter error_reporter_;
180 };
181
182 /** Tests the happy path for `call` op. **/
TEST_F(CallTest,SubgraphMultipleInputsSingleOutput)183 TEST_F(CallTest, SubgraphMultipleInputsSingleOutput) {
184 std::vector<std::vector<int>> test_shapes = {
185 {3, 2}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
186 // Will loop over and will be fed to the subgraph as {1,2}, {1,3}, {1,1,3},
187 // {1,3,1,2}.
188 for (size_t i = 0; i < test_shapes.size(); ++i) {
189 interpreter_ = std::make_unique<Interpreter>();
190 AddSubgraphs(1);
191 int loop_count = test_shapes[i][0];
192 builder_->BuildMulSubgraph(interpreter_->subgraph(1));
193 CallTest::BuildCallSubgraph(&interpreter_->primary_subgraph(), 1,
194 loop_count, {0, 1}, {2});
195 interpreter_->ResizeInputTensor(interpreter_->inputs()[0], test_shapes[i]);
196 interpreter_->ResizeInputTensor(interpreter_->inputs()[1], test_shapes[i]);
197 ASSERT_EQ(interpreter_->subgraph(1)->AllocateTensors(), kTfLiteOk);
198 ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
199
200 subgraph_test_util::FillIntTensor(
201 interpreter_->tensor(interpreter_->inputs()[0]), {-1, 2, -3, 4, -5, 6});
202 subgraph_test_util::FillIntTensor(
203 interpreter_->tensor(interpreter_->inputs()[1]), {-1, 2, -3, 4, -5, 6});
204 ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
205
206 TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
207 subgraph_test_util::CheckIntTensor(output, test_shapes[i],
208 {1, 4, 9, 16, 25, 36});
209 }
210 }
211
// A loop_count of zero must make the CALL node a no-op: the op succeeds and
// produces an empty (batch 0) output.
TEST_F(CallTest, ShouldBeANoOpWhenLoopCountIsZero) {
  AddSubgraphs(1);
  int loop_count = 0;
  builder_->BuildMulSubgraph(interpreter_->subgraph(1));
  CallTest::BuildCallSubgraph(&interpreter_->primary_subgraph(), 1, loop_count,
                              {0, 1}, {2});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {0, 3});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {0, 3});
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
  subgraph_test_util::CheckIntTensor(output, {0, 3}, {});
}
225
// The invoked subgraph has a fixed batch size of 1; the CALL node's batch
// dimension (== loop_count) is sliced one batch at a time into it.
TEST_F(CallTest, SubgraphWithFixedInputShapes) {
  AddSubgraphs(1);
  const int kLoopCount = 2;
  const int kBatchSizeSubgraph = 1;
  const int kFixedInputLen = 3;
  const std::vector<int> kCallOpInputShape = {kLoopCount, kFixedInputLen};
  const std::vector<int> kSubgraphInputShape = {kBatchSizeSubgraph,
                                                kFixedInputLen};

  Subgraph* subgraph = interpreter_->subgraph(1);
  builder_->BuildMulSubgraph(subgraph);
  CallTest::BuildCallSubgraph(&interpreter_->primary_subgraph(), 1, kLoopCount,
                              {0, 1}, {2});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[0], kCallOpInputShape);
  interpreter_->ResizeInputTensor(interpreter_->inputs()[1], kCallOpInputShape);

  subgraph->ResizeInputTensor(subgraph->inputs()[0], kSubgraphInputShape);
  subgraph->ResizeInputTensor(subgraph->inputs()[1], kSubgraphInputShape);

  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
  subgraph_test_util::FillIntTensor(
      interpreter_->tensor(interpreter_->inputs()[0]), {-1, 2, -3, 4, -5, 6});
  subgraph_test_util::FillIntTensor(
      interpreter_->tensor(interpreter_->inputs()[1]), {-1, 2, -3, 4, -5, 6});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
  subgraph_test_util::CheckIntTensor(output, kCallOpInputShape,
                                     {1, 4, 9, 16, 25, 36});
}
256
// CALL into a subgraph with two outputs (MUL and ADD over the same inputs);
// both outputs must come back with the caller's shape.
TEST_F(CallTest, SubgraphWithMultipleInputsAndOutputs) {
  std::vector<std::vector<int>> test_shapes = {
      {3, 2, 1}, {1, 2, 3}, {2, 1, 3}, {2, 3, 1, 1}, {2, 3}};
  for (size_t i = 0; i < test_shapes.size(); ++i) {
    // Fresh interpreter per shape so tensor/subgraph indices start clean.
    interpreter_ = std::make_unique<Interpreter>();
    AddSubgraphs(1);
    int loop_count = test_shapes[i][0];
    CallTest::BuildGraphWithMultipleOutputs(interpreter_->subgraph(1));
    CallTest::BuildCallSubgraph(&interpreter_->primary_subgraph(), 1,
                                loop_count, {0, 1}, {2, 3});
    interpreter_->ResizeInputTensor(interpreter_->inputs()[0], test_shapes[i]);
    interpreter_->ResizeInputTensor(interpreter_->inputs()[1], test_shapes[i]);
    ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

    subgraph_test_util::FillIntTensor(
        interpreter_->tensor(interpreter_->inputs()[0]), {-1, 2, -3, 4, -5, 6});
    subgraph_test_util::FillIntTensor(
        interpreter_->tensor(interpreter_->inputs()[1]), {-1, 2, -3, 4, -5, 6});
    ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

    TfLiteTensor* output_mul = interpreter_->tensor(interpreter_->outputs()[0]);
    subgraph_test_util::CheckIntTensor(output_mul, test_shapes[i],
                                       {1, 4, 9, 16, 25, 36});
    TfLiteTensor* output_add = interpreter_->tensor(interpreter_->outputs()[1]);
    subgraph_test_util::CheckIntTensor(output_add, test_shapes[i],
                                       {-2, 4, -6, 8, -10, 12});
  }
}
285
// Malformed custom options (a flexbuffer vector instead of the expected map)
// must not crash init; the op falls back to default params (all zeros).
TEST_F(CallTest, ShouldHandleInvalidParamsAndSetToDefault) {
  flexbuffers::Builder fbb;
  fbb.Vector([&]() {
    fbb.String("hi");
    fbb.String("hello");
  });
  fbb.Finish();
  AddSubgraphs(1);

  CallTest::BuildCallSubgraph(&interpreter_->primary_subgraph(),
                              fbb.GetBuffer(), {0}, {1}, 0, true);
  const int kNodeIndex = 0;
  const TfLiteNode* call_node = &interpreter_->primary_subgraph()
                                     .nodes_and_registration()[kNodeIndex]
                                     .first;
  tflite::acceleration::ops::TfLiteCallParams* op_data =
      reinterpret_cast<tflite::acceleration::ops::TfLiteCallParams*>(
          call_node->user_data);

  EXPECT_EQ(op_data->subgraph_index, 0);
  EXPECT_EQ(op_data->loop_count, 0);
}
// End-to-end run of the PAD -> CALL(Add) -> MUL graph built by
// BuildMultiNodeGraph: pads the 2x2 input to 4x4, adds the all-ones tensor
// through the CALL'd subgraph, then multiplies by the all-twos tensor.
TEST_F(CallTest, MultiNodeGraph) {
  CallTest::BuildMultiNodeGraph(&interpreter_->primary_subgraph());
  interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1, 4, 4, 1});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1, 2, 2, 1});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {4, 2});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[3], {1, 4, 4, 1});

  ASSERT_EQ(interpreter_->subgraph(1)->AllocateTensors(), kTfLiteOk);
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  subgraph_test_util::FillIntTensor(
      interpreter_->tensor(interpreter_->inputs()[0]), std::vector<int>(16, 1));
  subgraph_test_util::FillIntTensor(
      interpreter_->tensor(interpreter_->inputs()[1]), {1, 2, 3, 4});
  // Padding spec: one row/column of zeros around the 2x2 spatial dims.
  subgraph_test_util::FillIntTensor(
      interpreter_->tensor(interpreter_->inputs()[2]),
      {0, 0, 1, 1, 1, 1, 0, 0});
  subgraph_test_util::FillIntTensor(
      interpreter_->tensor(interpreter_->inputs()[3]), std::vector<int>(16, 2));

  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
  subgraph_test_util::CheckIntTensor(
      output, {1, 4, 4, 1}, {2, 2, 2, 2, 2, 4, 6, 2, 2, 8, 10, 2, 2, 2, 2, 2});
}
333
// Note: For the tests below, the error messages returned by the error reporter
// are of the following format:
// "<filename>:<line number> <error message>. Node <node name> failed to
// prepare.\n"
// It's sufficient to test whether the string returned by the error reporter
// contains the expected error message.
// Scalar (0-D) inputs to the invoked subgraph must be rejected at prepare
// time, since the CALL op slices along the batch dimension.
TEST_F(CallTest, ShouldFailWith0DInputs) {
  AddSubgraphs(1);
  int loop_count = 5;
  builder_->BuildMulSubgraph(interpreter_->subgraph(1));
  interpreter_->subgraph(1)->ResizeInputTensor(0, {});
  interpreter_->subgraph(1)->ResizeInputTensor(1, {});
  CallTest::BuildCallSubgraph(&interpreter_->primary_subgraph(), 1, loop_count,
                              {0, 1}, {2});

  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteError);

  EXPECT_THAT(
      error_reporter_.error_messages(),
      testing::HasSubstr(
          "Dimensions of all of call node's inputs should be non-zero."));
}
356
// The CALL node's batch dimension must equal loop_count; a mismatch
// (batch 5 vs loop_count 7) fails at prepare time.
TEST_F(CallTest, ShouldFailWhenLoopCountDoesNotMatchBatchSize) {
  AddSubgraphs(1);
  int loop_count = 7;
  builder_->BuildMulSubgraph(interpreter_->subgraph(1));
  CallTest::BuildCallSubgraph(&interpreter_->primary_subgraph(), 1, loop_count,
                              {0, 1}, {2});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {5, 3});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {5, 3});

  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteError);
  EXPECT_THAT(
      error_reporter_.error_messages(),
      testing::HasSubstr("node_input->dims->data[0] != loop_count (5 != 7)"));
}
371
// Non-batch dimensions of the CALL inputs (3) and the invoked subgraph's
// inputs (7) disagree, which must fail at prepare time.
TEST_F(CallTest, ShouldFailForSubgraphWithIncompatibleInputShapes) {
  AddSubgraphs(1);
  const int kLoopCount = 5;
  const int kBatchSizeSubgraph = 1;
  std::vector<int> call_op_input = {kLoopCount, 3};
  std::vector<int> subgraph_input = {kBatchSizeSubgraph, 7};
  Subgraph* subgraph = interpreter_->subgraph(1);
  builder_->BuildMulSubgraph(subgraph);
  CallTest::BuildCallSubgraph(&interpreter_->primary_subgraph(), 1, kLoopCount,
                              {0, 1}, {2});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[0], call_op_input);
  interpreter_->ResizeInputTensor(interpreter_->inputs()[1], call_op_input);
  subgraph->ResizeInputTensor(subgraph->inputs()[0], subgraph_input);
  subgraph->ResizeInputTensor(subgraph->inputs()[1], subgraph_input);

  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteError);

  EXPECT_THAT(
      error_reporter_.error_messages(),
      testing::HasSubstr("All dimensions except the batch size should match "
                         "for call node and the subgraph to invoke"));
}
394
// A CALL node must not invoke the very subgraph it lives in (here the
// primary subgraph, index 0), otherwise it would recurse.
TEST_F(CallTest, ShouldFailWhenSubgraphIndexMatchesInvokedSubgraph) {
  const int kPrimarySubgraphIndex = 0;
  CallTest::BuildCallSubgraph(&interpreter_->primary_subgraph(),
                              kPrimarySubgraphIndex, 1, {0}, {1});

  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteError);

  EXPECT_THAT(
      error_reporter_.error_messages(),
      testing::HasSubstr(
          "Subgraph to invoke must be different from the invoking graph."));
}
407
// A negative loop_count is invalid and must be rejected at prepare time.
TEST_F(CallTest, ShouldFailWithNegativeLoopCount) {
  AddSubgraphs(1);
  CallTest::BuildCallSubgraph(&interpreter_->primary_subgraph(), 1, -1, {0},
                              {1});

  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteError);

  EXPECT_THAT(error_reporter_.error_messages(),
              testing::HasSubstr("Loop count must be positive."));
}
418
419 } // namespace
420 } // namespace tflite
421