xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/serialization_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/serialization.h"
16 
17 #include <cstdint>
18 #include <string>
19 #include <vector>
20 
21 #include <gtest/gtest.h>
22 #include "tensorflow/lite/c/common.h"
23 #include "tensorflow/lite/util.h"
24 
25 namespace tflite {
26 namespace delegates {
27 namespace {
28 
EmptyReportError(TfLiteContext * context,const char * format,...)29 void EmptyReportError(TfLiteContext* context, const char* format, ...) {}
30 
31 class SerializationTest : public ::testing::Test {
32  protected:
TearDown()33   void TearDown() override {
34     for (auto* owned_array : owned_arrays_) {
35       TfLiteIntArrayFree(owned_array);
36     }
37   }
38 
getSerializationDir()39   std::string getSerializationDir() {
40     auto from_env = ::testing::TempDir();
41     if (!from_env.empty()) {
42       return from_env;
43     }
44     return "";
45   }
46 
47   // Unique num_tensors creates unique context fingerprint for testing.
GenerateTfLiteContext(int num_tensors)48   TfLiteContext GenerateTfLiteContext(int num_tensors) {
49     owned_tensor_vecs_.emplace_back();
50     auto& tensors_vec = owned_tensor_vecs_.back();
51     for (int i = 0; i < num_tensors; ++i) {
52       tensors_vec.emplace_back();
53       auto& tensor = tensors_vec.back();
54       tensor.bytes = i + 1;
55     }
56 
57     TfLiteContext context;
58     context.tensors_size = num_tensors;
59     context.tensors = tensors_vec.data();
60     context.ReportError = EmptyReportError;
61     return context;
62   }
63 
GenerateTfLiteDelegateParams(int num_nodes,int num_input_tensors,int num_output_tensors)64   TfLiteDelegateParams GenerateTfLiteDelegateParams(int num_nodes,
65                                                     int num_input_tensors,
66                                                     int num_output_tensors) {
67     // Create a dummy execution plan.
68     auto* nodes_to_replace = TfLiteIntArrayCreate(num_nodes);
69     auto* input_tensors = TfLiteIntArrayCreate(num_input_tensors);
70     auto* output_tensors = TfLiteIntArrayCreate(num_output_tensors);
71     owned_arrays_.push_back(nodes_to_replace);
72     owned_arrays_.push_back(input_tensors);
73     owned_arrays_.push_back(output_tensors);
74     for (int i = 0; i < num_nodes; ++i) {
75       nodes_to_replace->data[i] = i;
76     }
77     for (int i = 0; i < num_input_tensors; ++i) {
78       input_tensors->data[i] = i + 2;
79     }
80     for (int i = 0; i < num_output_tensors; ++i) {
81       output_tensors->data[i] = i + 3;
82     }
83 
84     TfLiteDelegateParams params;
85     params.input_tensors = input_tensors;
86     params.output_tensors = output_tensors;
87     params.nodes_to_replace = nodes_to_replace;
88     return params;
89   }
90 
91   std::vector<TfLiteIntArray*> owned_arrays_;
92   std::vector<std::vector<TfLiteTensor>> owned_tensor_vecs_;
93 };
94 
// StrFingerprint is expected to depend only on the bytes it is given: two
// buffers with equal contents must match, buffers with different contents
// must not.
TEST_F(SerializationTest, StrFingerprint) {
  const std::vector<int> base = {1, 2, 3, 4};
  const std::vector<int> base_copy = {1, 2, 3, 4};
  const std::vector<int> other = {2, 4, 6, 8};

  // Fingerprints the raw bytes of an int vector.
  const auto fingerprint_of = [](const std::vector<int>& values) {
    return StrFingerprint(values.data(), values.size() * sizeof(int));
  };

  const auto base_fp = fingerprint_of(base);
  const auto base_copy_fp = fingerprint_of(base_copy);
  const auto other_fp = fingerprint_of(other);

  EXPECT_EQ(base_fp, base_copy_fp);
  EXPECT_NE(base_fp, other_fp);
}
110 
// Checks how delegate-level entry fingerprints react to the context and the
// custom key, and that they are reproducible across Serialization instances.
TEST_F(SerializationTest, DelegateEntryFingerprint) {
  const std::string token = "mobilenet";
  const std::string cache_dir = "/test/dir";
  const std::string gpu_key = "gpu";
  const std::string nnapi_key = "nnapi";
  TfLiteContext small_context = GenerateTfLiteContext(/*num_tensors=*/20);
  TfLiteContext large_context = GenerateTfLiteContext(/*num_tensors=*/30);

  SerializationParams params = {token.c_str(), cache_dir.c_str()};
  Serialization first_run(params);

  // Same key, different contexts: fingerprints must differ.
  auto gpu_small = first_run.GetEntryForDelegate(gpu_key, &small_context);
  auto gpu_large = first_run.GetEntryForDelegate(gpu_key, &large_context);
  ASSERT_NE(gpu_small.GetFingerprint(), gpu_large.GetFingerprint());

  // Same context, different custom keys: fingerprints must differ.
  auto nnapi_small = first_run.GetEntryForDelegate(nnapi_key, &small_context);
  ASSERT_NE(gpu_small.GetFingerprint(), nnapi_small.GetFingerprint());

  // A second Serialization built from the same params must reproduce the
  // fingerprint for the same key and context.
  Serialization second_run(params);
  auto gpu_large_again =
      second_run.GetEntryForDelegate(gpu_key, &large_context);
  ASSERT_EQ(gpu_large.GetFingerprint(), gpu_large_again.GetFingerprint());
}
137 
TEST_F(SerializationTest,KernelEntryFingerprint)138 TEST_F(SerializationTest, KernelEntryFingerprint) {
139   const std::string model_token = "mobilenet";
140   const std::string dir = "/test/dir";
141   const std::string delegate = "gpu";
142   SerializationParams serialization_params = {model_token.c_str(), dir.c_str()};
143   Serialization serialization(serialization_params);
144 
145   TfLiteContext ref_context = GenerateTfLiteContext(/*num_tensors*/ 30);
146   TfLiteDelegateParams ref_partition = GenerateTfLiteDelegateParams(
147       /*num_nodes=*/3, /*num_input_tensors=*/4, /*num_output_tensors=*/2);
148   auto ref_entry = serialization.GetEntryForKernel(
149       delegate.c_str(), &ref_context, &ref_partition);
150 
151   // Different inputs to delegated partition => different fingerprint.
152   TfLiteDelegateParams diff_input_partition = GenerateTfLiteDelegateParams(
153       /*num_nodes=*/3, /*num_input_tensors=*/3, /*num_output_tensors=*/2);
154   ASSERT_NE(ref_entry.GetFingerprint(),
155             serialization
156                 .GetEntryForKernel(delegate.c_str(), &ref_context,
157                                    &diff_input_partition)
158                 .GetFingerprint());
159 
160   // Different outputs from delegated partition => different fingerprint.
161   TfLiteDelegateParams diff_output_partition = GenerateTfLiteDelegateParams(
162       /*num_nodes=*/3, /*num_input_tensors=*/4, /*num_output_tensors=*/3);
163   ASSERT_NE(ref_entry.GetFingerprint(),
164             serialization
165                 .GetEntryForKernel(delegate.c_str(), &ref_context,
166                                    &diff_output_partition)
167                 .GetFingerprint());
168 
169   // Different nodes from delegated partition => different fingerprint.
170   TfLiteDelegateParams diff_nodes_partition = GenerateTfLiteDelegateParams(
171       /*num_nodes=*/4, /*num_input_tensors=*/4, /*num_output_tensors=*/2);
172   ASSERT_NE(ref_entry.GetFingerprint(),
173             serialization
174                 .GetEntryForKernel(delegate.c_str(), &ref_context,
175                                    &diff_nodes_partition)
176                 .GetFingerprint());
177 
178   // Different contexts, same partition.
179   TfLiteContext other_context = GenerateTfLiteContext(/*num_tensors*/ 60);
180   ASSERT_NE(
181       ref_entry.GetFingerprint(),
182       serialization
183           .GetEntryForKernel(delegate.c_str(), &other_context, &ref_partition)
184           .GetFingerprint());
185 
186   // Same values across runs.
187   ASSERT_EQ(
188       ref_entry.GetFingerprint(),
189       serialization
190           .GetEntryForKernel(delegate.c_str(), &ref_context, &ref_partition)
191           .GetFingerprint());
192 
193   // Same value from a new Serialization instance.
194   Serialization serialization2(serialization_params);
195   ASSERT_EQ(
196       ref_entry.GetFingerprint(),
197       serialization
198           .GetEntryForKernel(delegate.c_str(), &ref_context, &ref_partition)
199           .GetFingerprint());
200 }
201 
// Kernel entry fingerprints are expected to depend on the model token but not
// on the serialization directory.
TEST_F(SerializationTest, ModelTokenFingerprint) {
  const std::string first_token = "model1";
  const std::string second_token = "model2";
  const std::string primary_dir = "/test/dir";
  const std::string delegate_key = "gpu";
  TfLiteContext context = GenerateTfLiteContext(/*num_tensors=*/20);
  TfLiteDelegateParams partition = GenerateTfLiteDelegateParams(
      /*num_nodes=*/2, /*num_input_tensors=*/3, /*num_output_tensors=*/1);

  // Two serializations that differ only in model token.
  SerializationParams first_params = {first_token.c_str(),
                                      primary_dir.c_str()};
  Serialization first_serialization(first_params);
  auto first_entry =
      first_serialization.GetEntryForKernel(delegate_key, &context, &partition);

  SerializationParams second_params = {second_token.c_str(),
                                       primary_dir.c_str()};
  Serialization second_serialization(second_params);
  auto second_entry = second_serialization.GetEntryForKernel(
      delegate_key, &context, &partition);

  ASSERT_NE(first_entry.GetFingerprint(), second_entry.GetFingerprint());

  // Same token as the first serialization, but a different directory: the
  // fingerprint must be unaffected.
  const std::string alternate_dir = "/another/dir";
  SerializationParams relocated_params = {first_token.c_str(),
                                          alternate_dir.c_str()};
  Serialization relocated_serialization(relocated_params);
  auto relocated_entry = relocated_serialization.GetEntryForKernel(
      delegate_key, &context, &partition);
  ASSERT_EQ(first_entry.GetFingerprint(), relocated_entry.GetFingerprint());
}
234 
TEST_F(SerializationTest,SerializationData)235 TEST_F(SerializationTest, SerializationData) {
236   // Sample data to store in serialization.
237   float value1 = 456.24;
238   float value2 = 678.23;
239   std::string model_token = "model1";
240   std::string test_dir = getSerializationDir();
241   const std::string fake_dir = "/test/dir";
242 
243   // Dummy context.
244   TfLiteContext context = GenerateTfLiteContext(/*num_tensors*/ 30);
245   TfLiteDelegateParams partition = GenerateTfLiteDelegateParams(
246       /*num_nodes=*/2, /*num_input_tensors=*/3, /*num_output_tensors=*/1);
247 
248   SerializationParams serialization_params = {model_token.c_str(),
249                                               test_dir.c_str()};
250   Serialization serialization(serialization_params);
251 
252   {
253     std::string custom_str1 = "test1";
254 
255     // Set data.
256     auto entry1 =
257         serialization.GetEntryForKernel(custom_str1, &context, &partition);
258     ASSERT_EQ(entry1.SetData(&context, reinterpret_cast<const char*>(&value1),
259                              sizeof(value1)),
260               kTfLiteOk);
261 
262     // Same key instance should be able to read the data back.
263     std::string read_back1 = "this string should be cleared";
264     ASSERT_EQ(entry1.GetData(&context, &read_back1), kTfLiteOk);
265     auto* retrieved_data1 = reinterpret_cast<float*>(&(read_back1[0]));
266     ASSERT_FLOAT_EQ(*retrieved_data1, value1);
267 
268     // Equivalent key from same serialization should be able to read the same
269     // data back.
270     auto entry2 =
271         serialization.GetEntryForKernel(custom_str1, &context, &partition);
272     std::string read_back2;
273     ASSERT_EQ(entry2.GetData(&context, &read_back2), kTfLiteOk);
274     auto* retrieved_data2 = reinterpret_cast<float*>(&(read_back2[0]));
275     ASSERT_FLOAT_EQ(*retrieved_data2, value1);
276   }
277 
278   {
279     std::string custom_str2 = "test2";
280 
281     // Trying to read data without setting should result in a 'cache miss'.
282     auto entry3 =
283         serialization.GetEntryForKernel(custom_str2, &context, &partition);
284     std::string read_back3;
285     ASSERT_EQ(entry3.GetData(&context, &read_back3),
286               kTfLiteDelegateDataNotFound);
287     // Now insert data.
288     ASSERT_EQ(entry3.SetData(&context, reinterpret_cast<const char*>(&value2),
289                              sizeof(value2)),
290               kTfLiteOk);
291 
292     // Equivalent key from different serialization with same caching dir & model
293     // token should read back the data.
294     Serialization serialization2(serialization_params);
295     std::string read_back4;
296     auto entry4 =
297         serialization2.GetEntryForKernel(custom_str2, &context, &partition);
298     ASSERT_EQ(entry4.GetData(&context, &read_back4), kTfLiteOk);
299     auto* retrieved_data = reinterpret_cast<float*>(&(read_back4[0]));
300     ASSERT_FLOAT_EQ(*retrieved_data, value2);
301 
302     // Same key, but different dir shouldn't find data.
303     SerializationParams new_params = {model_token.c_str(), fake_dir.c_str()};
304     Serialization serialization3(new_params);
305     auto entry5 =
306         serialization3.GetEntryForKernel(custom_str2, &context, &partition);
307     std::string read_back5;
308     ASSERT_EQ(entry5.GetData(&context, &read_back5),
309               kTfLiteDelegateDataNotFound);
310   }
311 }
312 
// Exercises SaveDelegatedNodes / GetDelegatedNodes: a saved node list can be
// read back, an unknown delegate id is reported as a cache miss, an empty
// list round-trips, and null array arguments are rejected.
//
// Every TfLiteIntArray created here is registered in the fixture's
// owned_arrays_ so that TearDown() frees it even when an ASSERT_* macro
// returns from the test body early.
TEST_F(SerializationTest, CachingDelegatedNodes) {
  std::string model_token = "model1";
  std::string test_dir = getSerializationDir();
  SerializationParams serialization_params = {model_token.c_str(),
                                              test_dir.c_str()};
  Serialization serialization(serialization_params);
  TfLiteContext context = GenerateTfLiteContext(/*num_tensors=*/30);
  const std::string test_delegate_id = "dummy_delegate";

  std::vector<int> nodes_to_delegate = {2, 3, 4, 7};
  TfLiteIntArray* nodes_to_delegate_array =
      ConvertVectorToTfLiteIntArray(nodes_to_delegate);
  owned_arrays_.push_back(nodes_to_delegate_array);
  std::vector<int> empty_nodes = {};
  TfLiteIntArray* empty_nodes_array =
      ConvertVectorToTfLiteIntArray(empty_nodes);
  owned_arrays_.push_back(empty_nodes_array);

  {
    ASSERT_EQ(SaveDelegatedNodes(&context, &serialization, test_delegate_id,
                                 nodes_to_delegate_array),
              kTfLiteOk);
  }
  {
    // Start from null so that a lookup which does not produce an array never
    // leaves an indeterminate pointer behind.
    TfLiteIntArray* read_back_array = nullptr;
    ASSERT_EQ(GetDelegatedNodes(&context, &serialization, "unknown_delegate",
                                &read_back_array),
              kTfLiteDelegateDataNotFound);
    ASSERT_EQ(GetDelegatedNodes(&context, &serialization, test_delegate_id,
                                &read_back_array),
              kTfLiteOk);
    ASSERT_NE(read_back_array, nullptr);
    owned_arrays_.push_back(read_back_array);
    ASSERT_EQ(TfLiteIntArrayEqual(nodes_to_delegate_array, read_back_array), 1);
  }
  {
    // Overwrite the stored list with an empty one and read it back.
    ASSERT_EQ(SaveDelegatedNodes(&context, &serialization, test_delegate_id,
                                 empty_nodes_array),
              kTfLiteOk);
    TfLiteIntArray* read_back_array = nullptr;
    ASSERT_EQ(GetDelegatedNodes(&context, &serialization, test_delegate_id,
                                &read_back_array),
              kTfLiteOk);
    ASSERT_NE(read_back_array, nullptr);
    owned_arrays_.push_back(read_back_array);
    ASSERT_EQ(read_back_array->size, 0);
  }
  {
    // nullptr invalid.
    ASSERT_EQ(
        SaveDelegatedNodes(&context, &serialization, test_delegate_id, nullptr),
        kTfLiteError);
    ASSERT_EQ(
        GetDelegatedNodes(&context, &serialization, test_delegate_id, nullptr),
        kTfLiteError);
  }
}
369 
370 }  // namespace
371 }  // namespace delegates
372 }  // namespace tflite
373