xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/serialization.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/serialization.h"
16 
17 #if defined(_WIN32)
18 #include <fstream>
19 #include <iostream>
20 #else
21 #include <errno.h>
22 #include <fcntl.h>
23 #include <sys/file.h>
24 #include <unistd.h>
25 
26 #include <cstring>
27 #endif  // defined(_WIN32)
28 
29 #include <time.h>
30 
31 #include <algorithm>
32 #include <cstdint>
33 #include <memory>
34 #include <string>
35 #include <vector>
36 
37 #include "tensorflow/lite/c/common.h"
38 #include "tensorflow/lite/minimal_logging.h"
39 #include "utils/hash/farmhash.h"
40 
41 namespace tflite {
42 namespace delegates {
43 namespace {
44 
// Suffix appended to a delegate id to build the cache key under which that
// delegate's list of delegated node ids is stored.
constexpr char kDelegatedNodesSuffix[] = "_dnodes";
46 
47 // Farmhash Fingerprint
// Mixes two 64-bit fingerprints into one 64-bit value using a Murmur-inspired
// multiply / xor-shift sequence. The result is not symmetric in `l` and `h`.
inline uint64_t CombineFingerprints(uint64_t l, uint64_t h) {
  // Multiplier shared by every mixing round below.
  constexpr uint64_t kMixMultiplier = 0x9ddfea08eb382d69ULL;

  // First round: fold both inputs together.
  uint64_t folded = (l ^ h) * kMixMultiplier;
  folded ^= folded >> 47;

  // Second round: mix the high input back in, then avalanche.
  uint64_t mixed = (h ^ folded) * kMixMultiplier;
  mixed ^= mixed >> 44;
  mixed *= kMixMultiplier;
  mixed ^= mixed >> 41;
  return mixed * kMixMultiplier;
}
60 
// Joins `path1` and `path2`, inserting a '/' separator unless `path1` already
// ends with one.
//
// An empty `path1` yields `path2` unchanged; calling back() on an empty string
// would otherwise be undefined behaviour.
inline std::string JoinPath(const std::string& path1,
                            const std::string& path2) {
  if (path1.empty()) return path2;
  return (path1.back() == '/') ? (path1 + path2) : (path1 + "/" + path2);
}
65 
GetFilePath(const std::string & cache_dir,const std::string & model_token,const uint64_t fingerprint)66 inline std::string GetFilePath(const std::string& cache_dir,
67                                const std::string& model_token,
68                                const uint64_t fingerprint) {
69   auto file_name = (model_token + "_" + std::to_string(fingerprint) + ".bin");
70   return JoinPath(cache_dir, file_name);
71 }
72 
73 }  // namespace
74 
StrFingerprint(const void * data,const size_t num_bytes)75 std::string StrFingerprint(const void* data, const size_t num_bytes) {
76   return std::to_string(
77       ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(reinterpret_cast<const char*>(data), num_bytes));
78 }
79 
SerializationEntry(const std::string & cache_dir,const std::string & model_token,const uint64_t fingerprint)80 SerializationEntry::SerializationEntry(const std::string& cache_dir,
81                                        const std::string& model_token,
82                                        const uint64_t fingerprint)
83     : cache_dir_(cache_dir),
84       model_token_(model_token),
85       fingerprint_(fingerprint) {}
86 
SetData(TfLiteContext * context,const char * data,const size_t size) const87 TfLiteStatus SerializationEntry::SetData(TfLiteContext* context,
88                                          const char* data,
89                                          const size_t size) const {
90   auto filepath = GetFilePath(cache_dir_, model_token_, fingerprint_);
91   // Temporary file to write data to.
92   const std::string temp_filepath =
93       JoinPath(cache_dir_, (model_token_ + std::to_string(fingerprint_) +
94                             std::to_string(time(nullptr))));
95 
96 #if defined(_WIN32)
97   std::ofstream out_file(temp_filepath.c_str());
98   if (!out_file) {
99     TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "Could not create file: %s",
100                     temp_filepath.c_str());
101     return kTfLiteDelegateDataWriteError;
102   }
103   out_file.write(data, size);
104   out_file.flush();
105   out_file.close();
106   // rename is an atomic operation in most systems.
107   if (rename(temp_filepath.c_str(), filepath.c_str()) < 0) {
108     TF_LITE_KERNEL_LOG(context, "Failed to rename to %s", filepath.c_str());
109     return kTfLiteDelegateDataWriteError;
110   }
111 #else   // !defined(_WIN32)
112   // This method only works on unix/POSIX systems.
113   const int fd = open(temp_filepath.c_str(),
114                       O_WRONLY | O_APPEND | O_CREAT | O_CLOEXEC, 0600);
115   if (fd < 0) {
116     TF_LITE_KERNEL_LOG(context, "Failed to open for writing: %s",
117                        temp_filepath.c_str());
118     return kTfLiteDelegateDataWriteError;
119   }
120   // Loop until all bytes written.
121   ssize_t len = 0;
122   const char* buf = data;
123   do {
124     ssize_t ret = write(fd, buf, size);
125     if (ret <= 0) {
126       close(fd);
127       TF_LITE_KERNEL_LOG(context, "Failed to write data to: %s, error: %s",
128                          temp_filepath.c_str(), std::strerror(errno));
129       return kTfLiteDelegateDataWriteError;
130     }
131 
132     len += ret;
133     buf += ret;
134   } while (len < static_cast<ssize_t>(size));
135   // Use fsync to ensure data is on disk before renaming temp file.
136   if (fsync(fd) < 0) {
137     close(fd);
138     TF_LITE_KERNEL_LOG(context, "Could not fsync: %s, error: %s",
139                        temp_filepath.c_str(), std::strerror(errno));
140     return kTfLiteDelegateDataWriteError;
141   }
142   if (close(fd) < 0) {
143     TF_LITE_KERNEL_LOG(context, "Could not close fd: %s, error: %s",
144                        temp_filepath.c_str(), std::strerror(errno));
145     return kTfLiteDelegateDataWriteError;
146   }
147   if (rename(temp_filepath.c_str(), filepath.c_str()) < 0) {
148     TF_LITE_KERNEL_LOG(context, "Failed to rename to %s, error: %s",
149                        filepath.c_str(), std::strerror(errno));
150     return kTfLiteDelegateDataWriteError;
151   }
152 #endif  // defined(_WIN32)
153 
154   TFLITE_LOG(TFLITE_LOG_INFO, "Wrote serialized data for model %s (%d B) to %s",
155              model_token_.c_str(), size, filepath.c_str());
156 
157   return kTfLiteOk;
158 }
159 
GetData(TfLiteContext * context,std::string * data) const160 TfLiteStatus SerializationEntry::GetData(TfLiteContext* context,
161                                          std::string* data) const {
162   if (!data) return kTfLiteError;
163   auto filepath = GetFilePath(cache_dir_, model_token_, fingerprint_);
164 
165 #if defined(_WIN32)
166   std::ifstream cache_stream(filepath,
167                              std::ios_base::in | std::ios_base::binary);
168   if (cache_stream.good()) {
169     cache_stream.seekg(0, cache_stream.end);
170     int cache_size = cache_stream.tellg();
171     cache_stream.seekg(0, cache_stream.beg);
172 
173     data->resize(cache_size);
174     cache_stream.read(&(*data)[0], cache_size);
175     cache_stream.close();
176   }
177 #else   // !defined(_WIN32)
178   // This method only works on unix/POSIX systems, but is more optimized & has
179   // lower size overhead for Android binaries.
180   data->clear();
181   // O_CLOEXEC is needed for correctness, as another thread may call
182   // popen() and the callee inherit the lock if it's not O_CLOEXEC.
183   int fd = open(filepath.c_str(), O_RDONLY | O_CLOEXEC, 0600);
184   if (fd < 0) {
185     TF_LITE_KERNEL_LOG(context, "File %s couldn't be opened for reading: %s",
186                        filepath.c_str(), std::strerror(errno));
187     return kTfLiteDelegateDataNotFound;
188   }
189   int lock_status = flock(fd, LOCK_EX);
190   if (lock_status < 0) {
191     close(fd);
192     TF_LITE_KERNEL_LOG(context, "Could not flock %s: %s", filepath.c_str(),
193                        std::strerror(errno));
194     return kTfLiteDelegateDataReadError;
195   }
196   char buffer[512];
197   while (true) {
198     int bytes_read = read(fd, buffer, 512);
199     if (bytes_read == 0) {
200       // EOF
201       close(fd);
202       return kTfLiteOk;
203     } else if (bytes_read < 0) {
204       close(fd);
205       TF_LITE_KERNEL_LOG(context, "Error reading %s: %s", filepath.c_str(),
206                          std::strerror(errno));
207       return kTfLiteDelegateDataReadError;
208     } else {
209       data->append(buffer, bytes_read);
210     }
211   }
212 #endif  // defined(_WIN32)
213 
214   TFLITE_LOG_PROD(TFLITE_LOG_INFO,
215                   "Found serialized data for model %s (%d B) at %s",
216                   model_token_.c_str(), data->size(), filepath.c_str());
217 
218   if (!data->empty()) {
219     TFLITE_LOG(TFLITE_LOG_INFO, "Data found at %s: %d bytes", filepath.c_str(),
220                data->size());
221     return kTfLiteOk;
222   } else {
223     TF_LITE_KERNEL_LOG(context, "No serialized data found: %s",
224                        filepath.c_str());
225     return kTfLiteDelegateDataNotFound;
226   }
227 }
228 
GetEntryImpl(const std::string & custom_key,TfLiteContext * context,const TfLiteDelegateParams * delegate_params)229 SerializationEntry Serialization::GetEntryImpl(
230     const std::string& custom_key, TfLiteContext* context,
231     const TfLiteDelegateParams* delegate_params) {
232   // First incorporate model_token.
233   // We use Fingerprint64 instead of std::hash, since the latter isn't
234   // guaranteed to be stable across runs. See b/172237993.
235   uint64_t fingerprint =
236       ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(model_token_.c_str(), model_token_.size());
237 
238   // Incorporate custom_key.
239   const uint64_t custom_str_fingerprint =
240       ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(custom_key.c_str(), custom_key.size());
241   fingerprint = CombineFingerprints(fingerprint, custom_str_fingerprint);
242 
243   // Incorporate context details, if provided.
244   // A quick heuristic involving graph tensors to 'fingerprint' a
245   // tflite::Subgraph. We don't consider the execution plan, since it could be
246   // in flux if the delegate uses this method during
247   // ReplaceNodeSubsetsWithDelegateKernels (eg in kernel Init).
248   if (context) {
249     std::vector<int32_t> context_data;
250     // Number of tensors can be large.
251     const int tensors_to_consider = std::min<int>(context->tensors_size, 100);
252     context_data.reserve(1 + tensors_to_consider);
253     context_data.push_back(context->tensors_size);
254     for (int i = 0; i < tensors_to_consider; ++i) {
255       context_data.push_back(context->tensors[i].bytes);
256     }
257     const uint64_t context_fingerprint =
258         ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(reinterpret_cast<char*>(context_data.data()),
259                                 context_data.size() * sizeof(int32_t));
260     fingerprint = CombineFingerprints(fingerprint, context_fingerprint);
261   }
262 
263   // Incorporate delegated partition details, if provided.
264   // A quick heuristic that considers the nodes & I/O tensor sizes to
265   // fingerprint TfLiteDelegateParams.
266   if (delegate_params) {
267     std::vector<int32_t> partition_data;
268     auto* nodes = delegate_params->nodes_to_replace;
269     auto* input_tensors = delegate_params->input_tensors;
270     auto* output_tensors = delegate_params->output_tensors;
271     partition_data.reserve(nodes->size + input_tensors->size +
272                            output_tensors->size);
273     partition_data.insert(partition_data.end(), nodes->data,
274                           nodes->data + nodes->size);
275     for (int i = 0; i < input_tensors->size; ++i) {
276       auto& tensor = context->tensors[input_tensors->data[i]];
277       partition_data.push_back(tensor.bytes);
278     }
279     for (int i = 0; i < output_tensors->size; ++i) {
280       auto& tensor = context->tensors[output_tensors->data[i]];
281       partition_data.push_back(tensor.bytes);
282     }
283     const uint64_t partition_fingerprint =
284         ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(reinterpret_cast<char*>(partition_data.data()),
285                                 partition_data.size() * sizeof(int32_t));
286     fingerprint = CombineFingerprints(fingerprint, partition_fingerprint);
287   }
288 
289   // Get a fingerprint-specific lock that is passed to the SerializationKey, to
290   // ensure noone else gets access to an equivalent SerializationKey.
291   return SerializationEntry(cache_dir_, model_token_, fingerprint);
292 }
293 
SaveDelegatedNodes(TfLiteContext * context,Serialization * serialization,const std::string & delegate_id,const TfLiteIntArray * node_ids)294 TfLiteStatus SaveDelegatedNodes(TfLiteContext* context,
295                                 Serialization* serialization,
296                                 const std::string& delegate_id,
297                                 const TfLiteIntArray* node_ids) {
298   if (!node_ids) return kTfLiteError;
299   std::string cache_key = delegate_id + kDelegatedNodesSuffix;
300   auto entry = serialization->GetEntryForDelegate(cache_key, context);
301   return entry.SetData(context, reinterpret_cast<const char*>(node_ids),
302                        (1 + node_ids->size) * sizeof(int));
303 }
304 
GetDelegatedNodes(TfLiteContext * context,Serialization * serialization,const std::string & delegate_id,TfLiteIntArray ** node_ids)305 TfLiteStatus GetDelegatedNodes(TfLiteContext* context,
306                                Serialization* serialization,
307                                const std::string& delegate_id,
308                                TfLiteIntArray** node_ids) {
309   if (!node_ids) return kTfLiteError;
310   std::string cache_key = delegate_id + kDelegatedNodesSuffix;
311   auto entry = serialization->GetEntryForDelegate(cache_key, context);
312 
313   std::string read_buffer;
314   TF_LITE_ENSURE_STATUS(entry.GetData(context, &read_buffer));
315   if (read_buffer.empty()) return kTfLiteOk;
316   *node_ids = TfLiteIntArrayCopy(
317       reinterpret_cast<const TfLiteIntArray*>(read_buffer.data()));
318   return kTfLiteOk;
319 }
320 
321 }  // namespace delegates
322 }  // namespace tflite
323