1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/delegates/serialization.h"

#if defined(_WIN32)
#include <fstream>
#include <iostream>
#else
#include <errno.h>
#include <fcntl.h>
#include <sys/file.h>
#include <unistd.h>

#include <cstring>
#endif  // defined(_WIN32)

#include <time.h>

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/minimal_logging.h"
#include "utils/hash/farmhash.h"
40
41 namespace tflite {
42 namespace delegates {
43 namespace {
44
// Appended to a delegate id to build the cache key under which that
// delegate's node list is saved and looked up.
constexpr char kDelegatedNodesSuffix[] = "_dnodes";
46
// Mixes two 64-bit fingerprints into a single 64-bit value using a
// Murmur-inspired multiply/xor-shift sequence. The result depends on the
// order of the arguments.
inline uint64_t CombineFingerprints(uint64_t l, uint64_t h) {
  constexpr uint64_t kMul = 0x9ddfea08eb382d69ULL;

  // First round: fold both inputs together.
  uint64_t mixed = (l ^ h) * kMul;
  mixed ^= mixed >> 47;

  // Second round: fold `h` back in, then two xor-shift/multiply steps.
  uint64_t result = (h ^ mixed) * kMul;
  result = (result ^ (result >> 44)) * kMul;
  result = (result ^ (result >> 41)) * kMul;
  return result;
}
60
// Joins two path components with exactly one '/' between them when `path1`
// does not already end in '/'.
//
// An empty `path1` yields `path2` unchanged: calling back() on an empty
// string is undefined behaviour, and prefixing "/" would silently redirect
// the result to the filesystem root.
inline std::string JoinPath(const std::string& path1,
                            const std::string& path2) {
  if (path1.empty()) return path2;
  return (path1.back() == '/') ? (path1 + path2) : (path1 + "/" + path2);
}
65
GetFilePath(const std::string & cache_dir,const std::string & model_token,const uint64_t fingerprint)66 inline std::string GetFilePath(const std::string& cache_dir,
67 const std::string& model_token,
68 const uint64_t fingerprint) {
69 auto file_name = (model_token + "_" + std::to_string(fingerprint) + ".bin");
70 return JoinPath(cache_dir, file_name);
71 }
72
73 } // namespace
74
StrFingerprint(const void * data,const size_t num_bytes)75 std::string StrFingerprint(const void* data, const size_t num_bytes) {
76 return std::to_string(
77 ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(reinterpret_cast<const char*>(data), num_bytes));
78 }
79
SerializationEntry(const std::string & cache_dir,const std::string & model_token,const uint64_t fingerprint)80 SerializationEntry::SerializationEntry(const std::string& cache_dir,
81 const std::string& model_token,
82 const uint64_t fingerprint)
83 : cache_dir_(cache_dir),
84 model_token_(model_token),
85 fingerprint_(fingerprint) {}
86
SetData(TfLiteContext * context,const char * data,const size_t size) const87 TfLiteStatus SerializationEntry::SetData(TfLiteContext* context,
88 const char* data,
89 const size_t size) const {
90 auto filepath = GetFilePath(cache_dir_, model_token_, fingerprint_);
91 // Temporary file to write data to.
92 const std::string temp_filepath =
93 JoinPath(cache_dir_, (model_token_ + std::to_string(fingerprint_) +
94 std::to_string(time(nullptr))));
95
96 #if defined(_WIN32)
97 std::ofstream out_file(temp_filepath.c_str());
98 if (!out_file) {
99 TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "Could not create file: %s",
100 temp_filepath.c_str());
101 return kTfLiteDelegateDataWriteError;
102 }
103 out_file.write(data, size);
104 out_file.flush();
105 out_file.close();
106 // rename is an atomic operation in most systems.
107 if (rename(temp_filepath.c_str(), filepath.c_str()) < 0) {
108 TF_LITE_KERNEL_LOG(context, "Failed to rename to %s", filepath.c_str());
109 return kTfLiteDelegateDataWriteError;
110 }
111 #else // !defined(_WIN32)
112 // This method only works on unix/POSIX systems.
113 const int fd = open(temp_filepath.c_str(),
114 O_WRONLY | O_APPEND | O_CREAT | O_CLOEXEC, 0600);
115 if (fd < 0) {
116 TF_LITE_KERNEL_LOG(context, "Failed to open for writing: %s",
117 temp_filepath.c_str());
118 return kTfLiteDelegateDataWriteError;
119 }
120 // Loop until all bytes written.
121 ssize_t len = 0;
122 const char* buf = data;
123 do {
124 ssize_t ret = write(fd, buf, size);
125 if (ret <= 0) {
126 close(fd);
127 TF_LITE_KERNEL_LOG(context, "Failed to write data to: %s, error: %s",
128 temp_filepath.c_str(), std::strerror(errno));
129 return kTfLiteDelegateDataWriteError;
130 }
131
132 len += ret;
133 buf += ret;
134 } while (len < static_cast<ssize_t>(size));
135 // Use fsync to ensure data is on disk before renaming temp file.
136 if (fsync(fd) < 0) {
137 close(fd);
138 TF_LITE_KERNEL_LOG(context, "Could not fsync: %s, error: %s",
139 temp_filepath.c_str(), std::strerror(errno));
140 return kTfLiteDelegateDataWriteError;
141 }
142 if (close(fd) < 0) {
143 TF_LITE_KERNEL_LOG(context, "Could not close fd: %s, error: %s",
144 temp_filepath.c_str(), std::strerror(errno));
145 return kTfLiteDelegateDataWriteError;
146 }
147 if (rename(temp_filepath.c_str(), filepath.c_str()) < 0) {
148 TF_LITE_KERNEL_LOG(context, "Failed to rename to %s, error: %s",
149 filepath.c_str(), std::strerror(errno));
150 return kTfLiteDelegateDataWriteError;
151 }
152 #endif // defined(_WIN32)
153
154 TFLITE_LOG(TFLITE_LOG_INFO, "Wrote serialized data for model %s (%d B) to %s",
155 model_token_.c_str(), size, filepath.c_str());
156
157 return kTfLiteOk;
158 }
159
GetData(TfLiteContext * context,std::string * data) const160 TfLiteStatus SerializationEntry::GetData(TfLiteContext* context,
161 std::string* data) const {
162 if (!data) return kTfLiteError;
163 auto filepath = GetFilePath(cache_dir_, model_token_, fingerprint_);
164
165 #if defined(_WIN32)
166 std::ifstream cache_stream(filepath,
167 std::ios_base::in | std::ios_base::binary);
168 if (cache_stream.good()) {
169 cache_stream.seekg(0, cache_stream.end);
170 int cache_size = cache_stream.tellg();
171 cache_stream.seekg(0, cache_stream.beg);
172
173 data->resize(cache_size);
174 cache_stream.read(&(*data)[0], cache_size);
175 cache_stream.close();
176 }
177 #else // !defined(_WIN32)
178 // This method only works on unix/POSIX systems, but is more optimized & has
179 // lower size overhead for Android binaries.
180 data->clear();
181 // O_CLOEXEC is needed for correctness, as another thread may call
182 // popen() and the callee inherit the lock if it's not O_CLOEXEC.
183 int fd = open(filepath.c_str(), O_RDONLY | O_CLOEXEC, 0600);
184 if (fd < 0) {
185 TF_LITE_KERNEL_LOG(context, "File %s couldn't be opened for reading: %s",
186 filepath.c_str(), std::strerror(errno));
187 return kTfLiteDelegateDataNotFound;
188 }
189 int lock_status = flock(fd, LOCK_EX);
190 if (lock_status < 0) {
191 close(fd);
192 TF_LITE_KERNEL_LOG(context, "Could not flock %s: %s", filepath.c_str(),
193 std::strerror(errno));
194 return kTfLiteDelegateDataReadError;
195 }
196 char buffer[512];
197 while (true) {
198 int bytes_read = read(fd, buffer, 512);
199 if (bytes_read == 0) {
200 // EOF
201 close(fd);
202 return kTfLiteOk;
203 } else if (bytes_read < 0) {
204 close(fd);
205 TF_LITE_KERNEL_LOG(context, "Error reading %s: %s", filepath.c_str(),
206 std::strerror(errno));
207 return kTfLiteDelegateDataReadError;
208 } else {
209 data->append(buffer, bytes_read);
210 }
211 }
212 #endif // defined(_WIN32)
213
214 TFLITE_LOG_PROD(TFLITE_LOG_INFO,
215 "Found serialized data for model %s (%d B) at %s",
216 model_token_.c_str(), data->size(), filepath.c_str());
217
218 if (!data->empty()) {
219 TFLITE_LOG(TFLITE_LOG_INFO, "Data found at %s: %d bytes", filepath.c_str(),
220 data->size());
221 return kTfLiteOk;
222 } else {
223 TF_LITE_KERNEL_LOG(context, "No serialized data found: %s",
224 filepath.c_str());
225 return kTfLiteDelegateDataNotFound;
226 }
227 }
228
GetEntryImpl(const std::string & custom_key,TfLiteContext * context,const TfLiteDelegateParams * delegate_params)229 SerializationEntry Serialization::GetEntryImpl(
230 const std::string& custom_key, TfLiteContext* context,
231 const TfLiteDelegateParams* delegate_params) {
232 // First incorporate model_token.
233 // We use Fingerprint64 instead of std::hash, since the latter isn't
234 // guaranteed to be stable across runs. See b/172237993.
235 uint64_t fingerprint =
236 ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(model_token_.c_str(), model_token_.size());
237
238 // Incorporate custom_key.
239 const uint64_t custom_str_fingerprint =
240 ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(custom_key.c_str(), custom_key.size());
241 fingerprint = CombineFingerprints(fingerprint, custom_str_fingerprint);
242
243 // Incorporate context details, if provided.
244 // A quick heuristic involving graph tensors to 'fingerprint' a
245 // tflite::Subgraph. We don't consider the execution plan, since it could be
246 // in flux if the delegate uses this method during
247 // ReplaceNodeSubsetsWithDelegateKernels (eg in kernel Init).
248 if (context) {
249 std::vector<int32_t> context_data;
250 // Number of tensors can be large.
251 const int tensors_to_consider = std::min<int>(context->tensors_size, 100);
252 context_data.reserve(1 + tensors_to_consider);
253 context_data.push_back(context->tensors_size);
254 for (int i = 0; i < tensors_to_consider; ++i) {
255 context_data.push_back(context->tensors[i].bytes);
256 }
257 const uint64_t context_fingerprint =
258 ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(reinterpret_cast<char*>(context_data.data()),
259 context_data.size() * sizeof(int32_t));
260 fingerprint = CombineFingerprints(fingerprint, context_fingerprint);
261 }
262
263 // Incorporate delegated partition details, if provided.
264 // A quick heuristic that considers the nodes & I/O tensor sizes to
265 // fingerprint TfLiteDelegateParams.
266 if (delegate_params) {
267 std::vector<int32_t> partition_data;
268 auto* nodes = delegate_params->nodes_to_replace;
269 auto* input_tensors = delegate_params->input_tensors;
270 auto* output_tensors = delegate_params->output_tensors;
271 partition_data.reserve(nodes->size + input_tensors->size +
272 output_tensors->size);
273 partition_data.insert(partition_data.end(), nodes->data,
274 nodes->data + nodes->size);
275 for (int i = 0; i < input_tensors->size; ++i) {
276 auto& tensor = context->tensors[input_tensors->data[i]];
277 partition_data.push_back(tensor.bytes);
278 }
279 for (int i = 0; i < output_tensors->size; ++i) {
280 auto& tensor = context->tensors[output_tensors->data[i]];
281 partition_data.push_back(tensor.bytes);
282 }
283 const uint64_t partition_fingerprint =
284 ::NAMESPACE_FOR_HASH_FUNCTIONS::Fingerprint64(reinterpret_cast<char*>(partition_data.data()),
285 partition_data.size() * sizeof(int32_t));
286 fingerprint = CombineFingerprints(fingerprint, partition_fingerprint);
287 }
288
289 // Get a fingerprint-specific lock that is passed to the SerializationKey, to
290 // ensure noone else gets access to an equivalent SerializationKey.
291 return SerializationEntry(cache_dir_, model_token_, fingerprint);
292 }
293
SaveDelegatedNodes(TfLiteContext * context,Serialization * serialization,const std::string & delegate_id,const TfLiteIntArray * node_ids)294 TfLiteStatus SaveDelegatedNodes(TfLiteContext* context,
295 Serialization* serialization,
296 const std::string& delegate_id,
297 const TfLiteIntArray* node_ids) {
298 if (!node_ids) return kTfLiteError;
299 std::string cache_key = delegate_id + kDelegatedNodesSuffix;
300 auto entry = serialization->GetEntryForDelegate(cache_key, context);
301 return entry.SetData(context, reinterpret_cast<const char*>(node_ids),
302 (1 + node_ids->size) * sizeof(int));
303 }
304
GetDelegatedNodes(TfLiteContext * context,Serialization * serialization,const std::string & delegate_id,TfLiteIntArray ** node_ids)305 TfLiteStatus GetDelegatedNodes(TfLiteContext* context,
306 Serialization* serialization,
307 const std::string& delegate_id,
308 TfLiteIntArray** node_ids) {
309 if (!node_ids) return kTfLiteError;
310 std::string cache_key = delegate_id + kDelegatedNodesSuffix;
311 auto entry = serialization->GetEntryForDelegate(cache_key, context);
312
313 std::string read_buffer;
314 TF_LITE_ENSURE_STATUS(entry.GetData(context, &read_buffer));
315 if (read_buffer.empty()) return kTfLiteOk;
316 *node_ids = TfLiteIntArrayCopy(
317 reinterpret_cast<const TfLiteIntArray*>(read_buffer.data()));
318 return kTfLiteOk;
319 }
320
321 } // namespace delegates
322 } // namespace tflite
323