xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/flex/delegate_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate.h"

#include <cstdint>
#include <future>
#include <memory>
#include <thread>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/flex/test_util.h"
#include "tensorflow/lite/shared_library.h"
25 
26 namespace tflite {
27 namespace flex {
28 namespace {
29 
30 using ::testing::ElementsAre;
31 
32 class DelegateTest : public testing::FlexModelTest {
33  public:
DelegateTest()34   DelegateTest() : delegate_(FlexDelegate::Create()) {
35     flex_delegate_ = static_cast<FlexDelegate*>(delegate_->data_);
36     interpreter_ = std::make_unique<Interpreter>(&error_reporter_);
37   }
38 
~DelegateTest()39   ~DelegateTest() override {
40     // The delegate needs to be destructed after the interpreter because the
41     // interpreter references data contained in the delegate.
42     interpreter_.reset();
43     delegate_.reset();
44   }
45 
ConfigureDelegate()46   void ConfigureDelegate() {
47     interpreter_->SetCancellationFunction(flex_delegate_,
48                                           FlexDelegate::HasCancelled);
49     ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
50               kTfLiteOk);
51   }
52 
Cancel()53   void Cancel() { flex_delegate_->Cancel(); }
54 
55  private:
56   std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)> delegate_;
57   FlexDelegate* flex_delegate_;
58 };
59 
TEST_F(DelegateTest, FullGraph) {
  // Every op below is a TF op:
  //   t0 --Unpack--> (t1, t2)      t3 --Unpack--> (t4, t5)
  //   t6 = t1 + t4   t7 = t2 + t5   t8 = t6 * t7
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  ConfigureDelegate();

  // Both inputs carry the same data, reshaped from the declared {3}.
  const std::vector<float> input = {1.1f, 2.2f, 3.3f, 4.4f};
  SetShape(0, {2, 2, 1});
  SetValues(0, input);
  SetShape(3, {2, 2, 1});
  SetValues(3, input);

  ASSERT_TRUE(Invoke());

  // (1.1 + 1.1) * (3.3 + 3.3) and (2.2 + 2.2) * (4.4 + 4.4).
  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
}
85 
TEST_F(DelegateTest, NonFloatTypeInference) {
  // A lone TF Add on int32 inputs; the output must come back as int32.
  AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});
  AddTfOp(testing::kAdd, {0, 1}, {2});

  ConfigureDelegate();

  const std::vector<int> lhs = {1, 2, 3, 4};
  const std::vector<int> rhs = {4, 3, 2, 1};
  SetShape(0, {2, 2});
  SetTypedValues<int>(0, lhs);
  SetShape(1, {2, 2});
  SetTypedValues<int>(1, rhs);

  ASSERT_TRUE(Invoke());

  // Each pair sums to 5.
  ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
  ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5));
  ASSERT_EQ(GetType(2), kTfLiteInt32);
}
104 
TEST_F(DelegateTest, StringInference) {
  // TF Add on string tensors concatenates element-wise; the output type
  // must be inferred as string.
  AddTensors(3, {0, 1}, {2}, kTfLiteString, {2});
  AddTfOp(testing::kAdd, {0, 1}, {2});

  ConfigureDelegate();

  SetShape(0, {2, 2});
  SetStringValues(0, {"1", "2", "3", "4"});
  SetShape(1, {2, 2});
  SetStringValues(1, {"4", "3", "2", "1"});

  ASSERT_TRUE(Invoke());

  // "1"+"4", "2"+"3", "3"+"2", "4"+"1".
  ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
  ASSERT_THAT(GetStringValues(2), ElementsAre("14", "23", "32", "41"));
  ASSERT_EQ(GetType(2), kTfLiteString);
}
123 
TEST_F(DelegateTest, MixedGraph) {
  // TF ops feed a builtin TFLite Mul at the end:
  //   t0 --Unpack--> (t1, t2)      t3 --Unpack--> (t4, t5)
  //   t6 = t1 + t4   t7 = t2 + t5   t8 = t6 * t7 (TFLite)
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfLiteMulOp({6, 7}, {8});

  ConfigureDelegate();

  const std::vector<float> input = {1.1f, 2.2f, 3.3f, 4.4f};
  SetShape(0, {2, 2, 1});
  SetValues(0, input);
  SetShape(3, {2, 2, 1});
  SetValues(3, input);

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
}
145 
TEST_F(DelegateTest, SplitGraph) {
  // A TFLite Mul sits between two runs of TF ops:
  //   t0 --Unpack--> (t1, t2); t3 = t1 + t2; t3 --Unpack--> (t4, t5)
  //   t6 = t4 * t5 (TFLite)
  //   t6 --Unpack--> (t7, t8); t9 = t7 + t8
  AddTensors(10, {0}, {9}, kTfLiteFloat32, {3});

  // Leading TF segment.
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kAdd, {1, 2}, {3});
  AddTfOp(testing::kUnpack, {3}, {4, 5});

  // Builtin op in the middle.
  AddTfLiteMulOp({4, 5}, {6});

  // Trailing TF segment.
  AddTfOp(testing::kUnpack, {6}, {7, 8});
  AddTfOp(testing::kAdd, {7, 8}, {9});

  ConfigureDelegate();

  SetShape(0, {2, 2, 2, 1});
  SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f});

  ASSERT_TRUE(Invoke());

  // t3 = {3, 2, 2, 2} -> t6 = {3*2, 2*2} = {6, 4} -> t9 = 6 + 4.
  ASSERT_THAT(GetShape(9), ElementsAre(1));
  ASSERT_THAT(GetValues(9), ElementsAre(10.0f));
}
168 
TEST_F(DelegateTest, OnlyTFLite) {
  // A model consisting of a single builtin TFLite Mul and no TF ops; applying
  // the flex delegate must leave it working.
  // Only tensors 0..2 are used, so declare exactly three.
  AddTensors(3, {0, 1}, {2}, kTfLiteFloat32, {3});
  AddTfLiteMulOp({0, 1}, {2});

  ConfigureDelegate();

  SetShape(0, {2, 2, 1});
  SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
  SetShape(1, {2, 2, 1});
  SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f});

  ASSERT_TRUE(Invoke());

  // Element-wise product.
  ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
  ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
}
186 
TEST_F(DelegateTest, MultipleInvokeCalls) {
  // Invoke() the same delegated model twice with different inputs.
  // Only tensors 0..2 are used, so declare exactly three.
  AddTensors(3, {0, 1}, {2}, kTfLiteFloat32, {3});
  AddTfLiteMulOp({0, 1}, {2});

  ConfigureDelegate();

  SetShape(0, {2, 2, 1});
  SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
  SetShape(1, {2, 2, 1});
  SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
  ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));

  // Second round. Each tensor is (re)shaped *before* its values are written:
  // SetShape() may reallocate the tensor, so writing values first only works
  // by accident when the shape is unchanged.
  SetShape(0, {2, 2, 1});
  SetValues(0, {4.4f, 3.3f, 2.2f, 1.1f});
  SetShape(1, {2, 2, 1});
  SetValues(1, {4.0f, 3.0f, 2.0f, 1.0f});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
  ASSERT_THAT(GetValues(2), ElementsAre(17.6f, 9.9f, 4.4f, 1.1f));
}
214 
TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
  // Build a graph on the fixture's interpreter, apply the delegate and set
  // inputs (without invoking yet).
  {
    AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
    AddTfOp(testing::kUnpack, {0}, {1, 2});
    AddTfOp(testing::kUnpack, {3}, {4, 5});
    AddTfOp(testing::kAdd, {1, 4}, {6});
    AddTfOp(testing::kAdd, {2, 5}, {7});
    AddTfOp(testing::kMul, {6, 7}, {8});
    ConfigureDelegate();
    SetShape(0, {2, 2, 1});
    SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
    SetShape(3, {2, 2, 1});
    SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
  }

  // Create a new interpreter, inject it into the test framework and build a
  // different graph using the *same* delegate. After the swap, the local
  // `interpreter` owns the first interpreter. It is destroyed at the end of
  // this test body, before ~DelegateTest releases the delegate.
  auto interpreter = std::make_unique<Interpreter>(&error_reporter_);
  interpreter_.swap(interpreter);
  {
    AddTensors(10, {0}, {9}, kTfLiteFloat32, {3});
    AddTfOp(testing::kUnpack, {0}, {1, 2});
    AddTfOp(testing::kAdd, {1, 2}, {3});
    AddTfOp(testing::kUnpack, {3}, {4, 5});
    AddTfLiteMulOp({4, 5}, {6});
    AddTfOp(testing::kUnpack, {6}, {7, 8});
    AddTfOp(testing::kAdd, {7, 8}, {9});
    ConfigureDelegate();
    SetShape(0, {2, 2, 2, 1});
    SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f});
  }

  // Swap back in the first interpreter and validate inference.
  interpreter_.swap(interpreter);
  {
    ASSERT_TRUE(Invoke());
    EXPECT_THAT(GetShape(8), ElementsAre(2, 1));
    EXPECT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  }

  // Swap in the second interpreter and validate inference.
  interpreter_.swap(interpreter);
  {
    ASSERT_TRUE(Invoke());
    EXPECT_THAT(GetShape(9), ElementsAre(1));
    EXPECT_THAT(GetValues(9), ElementsAre(10.0f));
  }
}
264 
TEST_F(DelegateTest, SingleThreaded) {
  // All-TF graph: t8 = (t1 + t4) * (t2 + t5), where (t1, t2) and (t4, t5)
  // are the unpacked halves of inputs t0 and t3.
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  // Restrict the interpreter to one thread *before* the delegate is applied.
  interpreter_->SetNumThreads(1);
  ConfigureDelegate();

  const std::vector<float> input = {1.1f, 2.2f, 3.3f, 4.4f};
  SetShape(0, {2, 2, 1});
  SetValues(0, input);
  SetShape(3, {2, 2, 1});
  SetValues(3, input);

  // Results must match the default-threading case.
  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
}
289 
TEST_F(DelegateTest, MultiThreaded) {
  // All-TF graph: t8 = (t1 + t4) * (t2 + t5), where (t1, t2) and (t4, t5)
  // are the unpacked halves of inputs t0 and t3.
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  // Ask for four threads *before* the delegate is applied.
  interpreter_->SetNumThreads(4);
  ConfigureDelegate();

  const std::vector<float> input = {1.1f, 2.2f, 3.3f, 4.4f};
  SetShape(0, {2, 2, 1});
  SetValues(0, input);
  SetShape(3, {2, 2, 1});
  SetValues(3, input);

  // Results must match the default-threading case.
  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
}
314 
315 #if !defined(__ANDROID__)
TEST_F(DelegateTest,TF_AcquireFlexDelegate)316 TEST_F(DelegateTest, TF_AcquireFlexDelegate) {
317   auto TF_AcquireFlexDelegate =
318       reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
319           SharedLibrary::GetSymbol("TF_AcquireFlexDelegate"));
320   ASSERT_TRUE(TF_AcquireFlexDelegate);
321   auto delegate_ptr = TF_AcquireFlexDelegate();
322   ASSERT_TRUE(delegate_ptr != nullptr);
323 }
324 #endif  // !defined(__ANDROID__)
325 
TEST_F(DelegateTest, StaticOutput) {
  // Define the graph with input and output shapes of [2]:
  //   t4 = t0 + t2,  t5 = t1 + t3,  t6 = t4 * t5
  AddTensors(7, {0, 1, 2, 3}, {6}, kTfLiteFloat32, {2});

  AddTfOp(testing::kAdd, {0, 2}, {4});
  AddTfOp(testing::kAdd, {1, 3}, {5});
  AddTfOp(testing::kMul, {4, 5}, {6});

  // Apply the delegate.
  ConfigureDelegate();

  // Define inputs which match the originally declared shapes.
  SetShape(0, {2});
  SetShape(1, {2});
  SetShape(2, {2});
  SetShape(3, {2});
  SetValues(0, {1.1f, 2.2f});
  SetValues(1, {3.3f, 4.4f});
  SetValues(2, {1.1f, 2.2f});
  SetValues(3, {3.3f, 4.4f});

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(6), ElementsAre(2));
  ASSERT_THAT(GetValues(6), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(6), kTfLiteFloat32);
  // Since shapes are consistent, the static output tensor is used.
  ASSERT_FALSE(IsDynamicTensor(6));
}
355 
TEST_F(DelegateTest, StaticOutputRFFT) {
  // Define the graph with input/output shapes of [3, 257]:
  //   t2 = RFFT(t0, fft_length = t1),  t3 = Imag(t2)
  AddTensors(4, {0, 1}, {3}, kTfLiteFloat32, {3, 257});
  // Tensor 1 is a constant that references this buffer rather than copying
  // it, so the buffer must stay alive for as long as the graph is invoked.
  const int32_t rfft_length[] = {512};
  SetConstTensor(1, {1}, kTfLiteInt32,
                 reinterpret_cast<const char*>(rfft_length),
                 sizeof(rfft_length));

  AddTfOp(testing::kRfft, {0, 1}, {2});
  AddTfOp(testing::kImag, {2}, {3});

  // Apply the delegate.
  ConfigureDelegate();

  // Define inputs.
  SetShape(0, {3, 512});
  SetValues(0, std::vector<float>(3 * 512, 1.0f));

  ASSERT_TRUE(Invoke());

  // An RFFT of length 512 yields 512 / 2 + 1 = 257 bins per row, which
  // matches the declared output shape.
  ASSERT_THAT(GetShape(3), ElementsAre(3, 257));
  ASSERT_EQ(GetType(3), kTfLiteFloat32);
  // Since shapes are consistent, the static output tensor is used.
  ASSERT_FALSE(IsDynamicTensor(3));
}
380 
TEST_F(DelegateTest, DynamicOutputAfterReshape) {
  // Tensors are declared with shape {3}, but the inputs are later resized to
  // {2, 2, 1}. The output shape therefore differs from its declaration.
  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
  AddTfOp(testing::kUnpack, {0}, {1, 2});
  AddTfOp(testing::kUnpack, {3}, {4, 5});
  AddTfOp(testing::kAdd, {1, 4}, {6});
  AddTfOp(testing::kAdd, {2, 5}, {7});
  AddTfOp(testing::kMul, {6, 7}, {8});

  ConfigureDelegate();

  // Resize both inputs away from their declared shape.
  const std::vector<float> input = {1.1f, 2.2f, 3.3f, 4.4f};
  SetShape(0, {2, 2, 1});
  SetValues(0, input);
  SetShape(3, {2, 2, 1});
  SetValues(3, input);

  ASSERT_TRUE(Invoke());

  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
  ASSERT_EQ(GetType(8), kTfLiteFloat32);
  // The shape mismatch forces the output onto a dynamic tensor.
  ASSERT_TRUE(IsDynamicTensor(8));
}
408 
TEST_F(DelegateTest, TestCancellation1) {
  // Single TF Add on int32; first verify a normal run, then cancel.
  AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});
  AddTfOp(testing::kAdd, {0, 1}, {2});

  ConfigureDelegate();

  const std::vector<int> lhs = {1, 2, 3, 4};
  const std::vector<int> rhs = {4, 3, 2, 1};
  SetShape(0, {2, 2});
  SetTypedValues<int>(0, lhs);
  SetShape(1, {2, 2});
  SetTypedValues<int>(1, rhs);

  // Before cancellation the model runs normally.
  ASSERT_TRUE(Invoke());
  ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
  ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5));
  ASSERT_EQ(GetType(2), kTfLiteInt32);

  // After cancellation the next Invoke() must fail.
  Cancel();
  ASSERT_FALSE(Invoke());
  // TODO(b/205345340): We shouldn't do raw string matching here. Instead we
  // need to introduce fine-grained error codes to represent cancellation
  // status.
  EXPECT_EQ(error_reporter_.error_messages(),
            "Client requested cancel during Invoke()");
}
436 
TEST_F(DelegateTest, TestCancellation2) {
  // Define the graph.
  AddTensors(2, {0}, {1}, kTfLiteBool, {1});

  // We need an op that checks the CancellationManager status.
  AddTfOp(testing::kLoopCond, {0}, {1});

  // Apply the delegate.
  ConfigureDelegate();

  // Define inputs. The input must be given a value; otherwise Invoke() reads
  // uninitialized tensor memory.
  SetShape(0, {1});
  bool* const input = interpreter_->typed_tensor<bool>(0);
  ASSERT_NE(input, nullptr);
  *input = true;

  ASSERT_TRUE(Invoke());

  Cancel();
  // Op should be cancelled.
  ASSERT_FALSE(Invoke());
  // TODO(b/205345340): We shouldn't do raw string matching here. Instead we
  // need to introduce fine-grained error codes to represent cancellation
  // status.
  EXPECT_EQ(error_reporter_.error_messages(),
            "Client requested cancel during Invoke()");
}
461 
TEST_F(DelegateTest, TestCancellationTwoThreads) {
  AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});

  AddTfOp(testing::kAdd, {0, 1}, {2});

  ConfigureDelegate();

  SetShape(0, {2, 2});
  SetTypedValues<int>(0, {1, 2, 3, 4});
  SetShape(1, {2, 2});
  SetTypedValues<int>(1, {4, 3, 2, 1});

  // Fulfilled by cancel_thread once Cancel() has returned. This replaces a
  // fixed sleep, which was slow and could still let the second Invoke() run
  // before the cancellation landed.
  std::promise<void> cancelled;
  std::future<void> cancelled_future = cancelled.get_future();

  std::thread invoke_thread([this, &cancelled_future]() {
    // This Invoke() races with Cancel() on the other thread, so it may or may
    // not succeed; its result is intentionally ignored.
    static_cast<void>(this->Invoke());
    // Once Cancel() has completed, Invoke() must fail.
    cancelled_future.wait();
    ASSERT_FALSE(this->Invoke());
    // TODO(b/205345340): Check returned error code.
  });

  std::thread cancel_thread([this, &cancelled]() {
    this->Cancel();
    cancelled.set_value();
  });

  // Both threads are joined before the locals they capture go out of scope.
  invoke_thread.join();
  cancel_thread.join();
}
488 
489 }  // namespace
490 }  // namespace flex
491 }  // namespace tflite
492