/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate.h"

#include <chrono>
#include <cstdint>
#include <memory>
#include <thread>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/flex/test_util.h"
#include "tensorflow/lite/shared_library.h"
25
26 namespace tflite {
27 namespace flex {
28 namespace {
29
30 using ::testing::ElementsAre;
31
// Test fixture that owns one FlexDelegate and one Interpreter and wires them
// together on demand. Construction/destruction order between the two is
// significant; see the destructor.
class DelegateTest : public testing::FlexModelTest {
 public:
  DelegateTest() : delegate_(FlexDelegate::Create()) {
    // FlexDelegate::Create() returns a TfLiteDelegate whose `data_` field
    // points back at the owning FlexDelegate; keep a typed alias so tests can
    // call FlexDelegate-specific methods such as Cancel().
    flex_delegate_ = static_cast<FlexDelegate*>(delegate_->data_);
    interpreter_ = std::make_unique<Interpreter>(&error_reporter_);
  }

  ~DelegateTest() override {
    // The delegate needs to be destructed after the interpreter because the
    // interpreter references data contained in the delegate.
    interpreter_.reset();
    delegate_.reset();
  }

  // Registers the delegate's cancellation hook with the interpreter and then
  // installs the delegate on the current graph. Call after the graph is
  // fully defined.
  void ConfigureDelegate() {
    interpreter_->SetCancellationFunction(flex_delegate_,
                                          FlexDelegate::HasCancelled);
    ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
              kTfLiteOk);
  }

  // Requests cancellation; subsequent Invoke() calls are expected to fail.
  void Cancel() { flex_delegate_->Cancel(); }

 private:
  // Owning handle; the custom deleter comes from FlexDelegate::Create().
  std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)> delegate_;
  // Non-owning alias into *delegate_; valid only while delegate_ is alive.
  FlexDelegate* flex_delegate_;
};
59
TEST_F(DelegateTest, FullGraph) {
  // Build a graph in which every node is a TF (flex) op: unpack both inputs,
  // add the pieces pairwise, then multiply the sums.
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});

  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  // Install the flex delegate over the whole graph.
  ConfigureDelegate();

  // Feed identical data into both graph inputs.
  for (int input : {0, 3}) {
    SetShape(input, {2, 2, 1});
    SetValues(input, {1.1f, 2.2f, 3.3f, 4.4f});
  }

  ASSERT_TRUE(Invoke());

  // Validate shape, values and type of the single output tensor.
  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
}
85
TEST_F(DelegateTest, NonFloatTypeInference) {
  // A single TF add op on int32 tensors: the non-float type must propagate
  // through the delegated graph.
  AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});
  AddTfOp(testing::kAdd, {0, 1}, {2});

  ConfigureDelegate();

  SetShape(0, {2, 2});
  SetShape(1, {2, 2});
  SetTypedValues<int>(0, {1, 2, 3, 4});
  SetTypedValues<int>(1, {4, 3, 2, 1});

  ASSERT_TRUE(Invoke());

  // Every element sums to 5 and the int32 type is preserved.
  ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
  ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5));
  ASSERT_EQ(GetType(2), kTfLiteInt32);
}
104
TEST_F(DelegateTest, StringInference) {
  // A single TF add op on string tensors; the expected outputs below show
  // element-wise concatenation of the two inputs.
  AddTensors(3, {0, 1}, {2}, kTfLiteString, {2});
  AddTfOp(testing::kAdd, {0, 1}, {2});

  ConfigureDelegate();

  SetShape(0, {2, 2});
  SetShape(1, {2, 2});
  SetStringValues(0, {"1", "2", "3", "4"});
  SetStringValues(1, {"4", "3", "2", "1"});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
  ASSERT_THAT(GetStringValues(2), ElementsAre("14", "23", "32", "41"));
  ASSERT_EQ(GetType(2), kTfLiteString);
}
123
TEST_F(DelegateTest, MixedGraph) {
  // Same computation as FullGraph, except the final multiply is a builtin
  // TFLite op, so the delegate only claims the four TF ops before it.
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});

  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfLiteMulOp({6, 7}, {8});

  ConfigureDelegate();

  // Both inputs get the same data.
  for (int input : {0, 3}) {
    SetShape(input, {2, 2, 1});
    SetValues(input, {1.1f, 2.2f, 3.3f, 4.4f});
  }

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
}
145
TEST_F(DelegateTest, SplitGraph) {
  // A builtin TFLite op sits between two runs of TF ops, so the delegate
  // has to claim two disjoint partitions of the graph.
  AddTensors(10, {0}, {9}, kTfLiteFloat32, {3});

  // First run of TF ops.
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kAdd, {1, 2}, {3});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  // Builtin op in the middle.
  AddTfLiteMulOp({4, 5}, {6});
  // Second run of TF ops.
  AddTfOp(testing::kUnpack, {6}, {7, 8});
  AddTfOp(testing::kAdd, {7, 8}, {9});

  ConfigureDelegate();

  SetShape(0, {2, 2, 2, 1});
  SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(9), ElementsAre(1));
  ASSERT_THAT(GetValues(9), ElementsAre(10.0f));
}
168
TEST_F(DelegateTest, OnlyTFLite) {
  // A model with a single builtin TFLite op: the delegate has nothing to
  // claim, but inference must still succeed after installing it.
  AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3});
  AddTfLiteMulOp({0, 1}, {2});

  ConfigureDelegate();

  SetShape(0, {2, 2, 1});
  SetShape(1, {2, 2, 1});
  SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
  SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
  ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
}
186
TEST_F(DelegateTest, MultipleInvokeCalls) {
  // Invoking the same delegated model repeatedly must work, including when
  // the input values are replaced between invocations.
  AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3});
  AddTfLiteMulOp({0, 1}, {2});

  ConfigureDelegate();

  // First invocation.
  SetShape(0, {2, 2, 1});
  SetShape(1, {2, 2, 1});
  SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
  SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
  ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));

  // Second invocation with the element order of both inputs reversed.
  SetShape(0, {2, 2, 1});
  SetShape(1, {2, 2, 1});
  SetValues(0, {4.4f, 3.3f, 2.2f, 1.1f});
  SetValues(1, {4.0f, 3.0f, 2.0f, 1.0f});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
  ASSERT_THAT(GetValues(2), ElementsAre(17.6f, 9.9f, 4.4f, 1.1f));
}
214
// Verifies that one FlexDelegate instance can be installed on two different
// interpreters/graphs, and that each interpreter still produces correct
// results after swapping back and forth between them.
TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
  // Build a graph, configure the delegate and set inputs.
  {
    AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
    AddTfOp(testing::kUnpack, {0}, {1, 2});
    AddTfOp(testing::kUnpack, {3}, {4, 5});
    AddTfOp(testing::kAdd, {1, 4}, {6});
    AddTfOp(testing::kAdd, {2, 5}, {7});
    AddTfOp(testing::kMul, {6, 7}, {8});
    ConfigureDelegate();
    SetShape(0, {2, 2, 1});
    SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
    SetShape(3, {2, 2, 1});
    SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
  }

  // Create a new interpreter, inject into the test framework and build
  // a different graph using the *same* delegate.
  std::unique_ptr<Interpreter> interpreter(new Interpreter(&error_reporter_));
  interpreter_.swap(interpreter);
  {
    AddTensors(10, {0}, {9}, kTfLiteFloat32, {3});
    AddTfOp(testing::kUnpack, {0}, {1, 2});
    AddTfOp(testing::kAdd, {1, 2}, {3});
    AddTfOp(testing::kUnpack, {3}, {4, 5});
    AddTfLiteMulOp({4, 5}, {6});
    AddTfOp(testing::kUnpack, {6}, {7, 8});
    AddTfOp(testing::kAdd, {7, 8}, {9});
    ConfigureDelegate();
    SetShape(0, {2, 2, 2, 1});
    SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f});
  }

  // Swap back in the first interpreter and validate inference.
  interpreter_.swap(interpreter);
  {
    ASSERT_TRUE(Invoke());
    EXPECT_THAT(GetShape(8), ElementsAre(2, 1));
    EXPECT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  }

  // Swap in the second interpreter and validate inference.
  interpreter_.swap(interpreter);
  {
    ASSERT_TRUE(Invoke());
    EXPECT_THAT(GetShape(9), ElementsAre(1));
    EXPECT_THAT(GetValues(9), ElementsAre(10.0f));
  }
}
264
TEST_F(DelegateTest, SingleThreaded) {
  // Same fully-delegated graph as FullGraph, run with the interpreter
  // explicitly restricted to one thread.
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  // Explicitly disable multi-threading before installing the delegate.
  interpreter_->SetNumThreads(1);
  ConfigureDelegate();

  for (int input : {0, 3}) {
    SetShape(input, {2, 2, 1});
    SetValues(input, {1.1f, 2.2f, 3.3f, 4.4f});
  }

  // Inference results must match the default configuration.
  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
}
289
TEST_F(DelegateTest, MultiThreaded) {
  // Same fully-delegated graph as FullGraph, run with the interpreter
  // explicitly configured for four threads.
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  // Explicitly enable multi-threading before installing the delegate.
  interpreter_->SetNumThreads(4);
  ConfigureDelegate();

  for (int input : {0, 3}) {
    SetShape(input, {2, 2, 1});
    SetValues(input, {1.1f, 2.2f, 3.3f, 4.4f});
  }

  // Inference results must match the single-threaded configuration.
  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
}
314
315 #if !defined(__ANDROID__)
TEST_F(DelegateTest, TF_AcquireFlexDelegate) {
  // The flex delegate exposes a "TF_AcquireFlexDelegate" symbol that can be
  // looked up at runtime; resolving and calling it must yield a delegate.
  using AcquireFn = Interpreter::TfLiteDelegatePtr (*)();
  const auto acquire_flex_delegate = reinterpret_cast<AcquireFn>(
      SharedLibrary::GetSymbol("TF_AcquireFlexDelegate"));
  ASSERT_TRUE(acquire_flex_delegate);
  const auto delegate = acquire_flex_delegate();
  ASSERT_TRUE(delegate != nullptr);
}
324 #endif // !defined(__ANDROID__)
325
// When the runtime input shapes agree with the declared tensor shapes, the
// delegate should write into a static (non-dynamic) output tensor.
TEST_F(DelegateTest, StaticOutput) {
  // Define the graph with input, output shapes of [2].
  AddTensors(7, {0, 1, 2, 3}, {6}, kTfLiteFloat32, {2});

  AddTfOp(testing::kAdd, {0, 2}, {4});
  AddTfOp(testing::kAdd, {1, 3}, {5});
  AddTfOp(testing::kMul, {4, 5}, {6});

  // Apply the delegate.
  ConfigureDelegate();

  // Define inputs which match the original shapes.
  SetShape(0, {2});
  SetShape(1, {2});
  SetShape(2, {2});
  SetShape(3, {2});
  SetValues(0, {1.1f, 2.2f});
  SetValues(1, {3.3f, 4.4f});
  SetValues(2, {1.1f, 2.2f});
  SetValues(3, {3.3f, 4.4f});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(6), ElementsAre(2));
  ASSERT_THAT(GetValues(6), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(6), kTfLiteFloat32);
  // Since shapes are consistent, static output tensor is used.
  ASSERT_FALSE(IsDynamicTensor(6));
}
355
TEST_F(DelegateTest, StaticOutputRFFT) {
  // RFFT over a [3, 512] signal with fft_length 512 produces [3, 257]
  // (512 / 2 + 1 bins); Imag extracts the imaginary component. The declared
  // output shape of [3, 257] therefore matches the runtime shape.
  AddTensors(4, {0, 1}, {3}, kTfLiteFloat32, {3, 257});
  const int32_t fft_length[] = {512};
  SetConstTensor(1, {1}, kTfLiteInt32,
                 reinterpret_cast<const char*>(fft_length),
                 sizeof(fft_length));

  AddTfOp(testing::kRfft, {0, 1}, {2});
  AddTfOp(testing::kImag, {2}, {3});

  ConfigureDelegate();

  // A constant signal of 3 * 512 ones.
  SetShape(0, {3, 512});
  SetValues(0, std::vector<float>(3 * 512, 1.0f));

  ASSERT_TRUE(Invoke());

  ASSERT_EQ(GetType(3), kTfLiteFloat32);
  // Shapes are consistent with the declaration, so the output stays static.
  ASSERT_FALSE(IsDynamicTensor(3));
}
380
TEST_F(DelegateTest, DynamicOutputAfterReshape) {
  // Tensors are declared with shape {3}, but the inputs are resized to
  // {2, 2, 1} below. The mismatch forces a dynamic output tensor.
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});

  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  ConfigureDelegate();

  // Resize both inputs away from their declared shapes.
  for (int input : {0, 3}) {
    SetShape(input, {2, 2, 1});
    SetValues(input, {1.1f, 2.2f, 3.3f, 4.4f});
  }

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
  // Since shapes are inconsistent, dynamic output tensor is used.
  ASSERT_TRUE(IsDynamicTensor(8));
}
408
TEST_F(DelegateTest, TestCancellation1) {
  // After a successful Invoke(), calling Cancel() must make the next
  // Invoke() fail with the cancellation error message.
  AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});
  AddTfOp(testing::kAdd, {0, 1}, {2});

  ConfigureDelegate();

  SetShape(0, {2, 2});
  SetShape(1, {2, 2});
  SetTypedValues<int>(0, {1, 2, 3, 4});
  SetTypedValues<int>(1, {4, 3, 2, 1});

  // Sanity check: the model runs fine before cancellation.
  ASSERT_TRUE(Invoke());
  ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
  ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5));
  ASSERT_EQ(GetType(2), kTfLiteInt32);

  Cancel();
  // Op should be cancelled.
  ASSERT_FALSE(Invoke());
  // TODO(b/205345340): We shouldn't do raw string matching here. Instead we
  // need to introduce fine-grained error codes to represent cancellation
  // status.
  EXPECT_EQ(error_reporter_.error_messages(),
            "Client requested cancel during Invoke()");
}
436
TEST_F(DelegateTest, TestCancellation2) {
  // Single-op graph using kLoopCond, an op that checks the
  // CancellationManager status while executing.
  AddTensors(2, {0}, {1}, kTfLiteBool, {1});
  AddTfOp(testing::kLoopCond, {0}, {1});

  ConfigureDelegate();

  SetShape(0, {1});

  // Runs fine before cancellation is requested.
  ASSERT_TRUE(Invoke());

  Cancel();
  // Op should be cancelled.
  ASSERT_FALSE(Invoke());
  // TODO(b/205345340): We shouldn't do raw string matching here. Instead we
  // need to introduce fine-grained error codes to represent cancellation
  // status.
  EXPECT_EQ(error_reporter_.error_messages(),
            "Client requested cancel during Invoke()");
}
461
// Cancels from a second thread while another thread invokes. The first
// Invoke() races with Cancel(), so only the post-sleep Invoke() — which is
// guaranteed to observe the cancellation — is asserted on.
TEST_F(DelegateTest, TestCancellationTwoThreads) {
  AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});
  AddTfOp(testing::kAdd, {0, 1}, {2});

  ConfigureDelegate();

  SetShape(0, {2, 2});
  SetTypedValues<int>(0, {1, 2, 3, 4});
  SetShape(1, {2, 2});
  SetTypedValues<int>(1, {4, 3, 2, 1});

  std::thread invoke_thread([this]() {
    // Result intentionally unchecked: this call races with Cancel() on the
    // other thread, so it may either succeed or be cancelled.
    this->Invoke();
    // Give the cancellation request ample time to be delivered.
    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
    // By now the delegate has been cancelled, so Invoke() must fail.
    // TODO(b/205345340): Check returned error code.
    ASSERT_FALSE(this->Invoke());
  });

  std::thread cancel_thread([this]() { this->Cancel(); });

  invoke_thread.join();
  cancel_thread.join();
}
488
489 } // namespace
490 } // namespace flex
491 } // namespace tflite
492