1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // A tensor bundle is a set of immutable persistent files storing a set of named
17 // tensors. It is designed for checkpointing TensorFlow tensors.
18 //
19 // The paths of the managed files share a common prefix; e.g., with the prefix:
20 // /fs/model/train/ckpt-step/ckpt
21 //
22 // the bundle may contain a metadata file, and sharded data files:
23 // /fs/model/train/ckpt-step/
24 // ckpt.index
25 // ckpt.data-00000-of-00020
26 // ckpt.data-00001-of-00020
27 // ...
28 // ckpt.data-00019-of-00020
29 //
30 // The ".index" file is a string-string immutable table
31 // (tensorflow::table::Table). Each key is a name of a tensor and its value is
32 // a serialized BundleEntryProto. Each BundleEntryProto describes the metadata
33 // of a tensor: which of the "data" files contains the content of a tensor, the
34 // offset into that file, checksum, some auxiliary data, etc.
35 //
36 // A tensor bundle can be accessed randomly using a BundleReader. Usage:
37 //
38 // BundleReader reader(env, "/fs/model/train/ckpt-step/ckpt");
39 // reader.Lookup("name", &tensor);
40 //
41 // A tensor bundle can be built using BundleWriter. Each BundleWriter builds a
42 // single data file bundle. Multiple bundles can then be merged by
43 // MergeBundles() without reading and writing large chunks of data: it reads the
44 // metadata files and outputs a single merged metadata. Typical usage:
45 //
46 // worker 0:
47 // BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker0-step");
48 // writer.Add(...); // Adds the tensors on this worker.
49 // writer.Finish(); // Flushes.
50 // worker 1:
51 // BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker1-step");
52 // writer.Add(...);
53 // writer.Finish();
54 // worker 2:
55 // MergeBundles(env,
56 // {"/fs/model/train/ckpt-step/tmp/worker0-step",
57 // "/fs/model/train/ckpt-step/tmp/worker1-step"},
58 // "/fs/model/train/ckpt-step/ckpt" /* merged prefix */);
59 //
60
61 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
62 #define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
63
64 #include <map>
65 #include <string>
66 #include <unordered_map>
67
68 #include "absl/algorithm/container.h"
69 #include "absl/container/flat_hash_map.h"
70 #include "absl/functional/function_ref.h"
71 #include "tensorflow/core/framework/tensor.h"
72 #include "tensorflow/core/framework/tensor_shape.h"
73 #include "tensorflow/core/framework/tensor_slice.h"
74 #include "tensorflow/core/lib/core/status.h"
75 #include "tensorflow/core/lib/gtl/array_slice.h"
76 #include "tensorflow/core/lib/io/cache.h"
77 #include "tensorflow/core/lib/io/inputbuffer.h"
78 #include "tensorflow/core/lib/io/table.h"
79 #include "tensorflow/core/platform/cord.h"
80 #include "tensorflow/core/platform/env.h"
81 #include "tensorflow/core/platform/file_system.h"
82 #include "tensorflow/core/platform/macros.h"
83 #include "tensorflow/core/platform/types.h"
84 #include "tensorflow/core/protobuf/tensor_bundle.pb.h"
85 #include "tensorflow/core/util/tensor_bundle/naming.h"
86 #include "tensorflow/core/util/tensor_slice_set.h"
87
88 namespace tensorflow {
89
class FileOutputBuffer;

// Versioning of the tensor bundle format.
// Follows the same rules as 3p/tf/core/public/version.h.
//
// History:
// 0. Any tensor bundles produced before this field was added.
// 1. Added this field (2016-09-14).
extern const int kTensorBundleMinProducer;
extern const int kTensorBundleMinConsumer;
extern const int kTensorBundleVersion;

// The empty string, hence always the first key in the metadata table. Its
// corresponding value is a BundleHeaderProto.
extern const char* const kHeaderEntryKey;
105
106 // Builds a string-string table of tensor names to BundleEntryProto (metadata).
107 //
108 // On construction, attempts to create a directory given by the dirname of
109 // "prefix", so "status()" must be checked before calling any member functions.
110 //
111 // All threads accessing the same BundleWriter must synchronize.
112 class BundleWriter {
113 public:
114 struct Options {
OptionsOptions115 Options() {}
116 // Alignment, in bytes, for tensor data.
117 // Must be >= 1. The default size of 1 densely packs tensors.
118 int data_alignment{1};
119 };
120 BundleWriter(Env* env, StringPiece prefix,
121 const Options& options = Options());
122
123 // Adds the tensor "val" under key "key".
124 // Across calls "key" must be unique but can be added in any order.
125 Status Add(StringPiece key, const Tensor& val);
126
127 // Partitioned variables support.
128 // A slice of a full tensor is stored in two entries in the metadata table:
129 //
130 // full_tensor_key -> BundleEntryProto, describing all stored slices
131 // of this full tensor. Does not append to the data
132 // file.
133 // encoded slice key -> BundleEntryProto, describing one particular slice.
134 // Appends values of this slice to the data file.
135 //
136 // Slices of a full tensor can be added in any order.
137 //
138 // If a full tensor has slices placed on N devices and N BundleWriter's are
139 // concurrently used, the caller must use MergeBundles() to ensure that a
140 // consistent entry for "full_tensor_key" is produced.
141 //
142 // Returns an error if the same slice is added the second time.
143 Status AddSlice(StringPiece full_tensor_key,
144 const TensorShape& full_tensor_shape,
145 const TensorSlice& slice_spec, const Tensor& slice_tensor);
146
147 // Finishes the writer and flushes.
148 Status Finish() TF_MUST_USE_RESULT;
149
status()150 Status status() const { return status_; }
151
152 private:
153 Env* const env_; // Not owned.
154 const Options options_;
155 const string prefix_;
156 string metadata_path_;
157 string data_path_;
158 bool use_temp_file_;
159 std::unique_ptr<FileOutputBuffer> out_;
160 int64_t size_; // Number of bytes written into out_.
161 std::map<string, BundleEntryProto> entries_;
162 Status status_;
163
164 TF_DISALLOW_COPY_AND_ASSIGN(BundleWriter);
165 };
166
167 // Merges a set of bundles (given their prefixes) into a single bundle with the
168 // given "merged_prefix". The merged metadata is guaranteed to be consistent.
169 //
170 // If there are N bundles in "prefixes", during the merge the data files will be
171 // renamed to contain a proper sharded file spec, with num_shards set to the sum
172 // of num_shards across the N input bundles.
173 //
174 // The caller should only rely on the metadata file of the merged bundle to
175 // query information about a tensor. In particular, this function does not
176 // guarantee not to re-order the input data files.
177 //
178 // Once merged, makes a best effort to delete the old metadata files.
179 // Returns OK iff all bundles are successfully merged.
180 //
181 // "allow_missing_files": If set to true, merges "prefixes" as long as
182 // at least one file exists. (Defaults to false.)
183 //
184 // Returns an InvalidArgumentError when "allow_missing_files" is set to true
185 // and all data files named in "prefixes" do not exist.
186 //
187 // Returns a NotFoundError when "allow_missing_files" is set to false and
188 // any data file named in "prefixes" does not exist.
189 Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
190 StringPiece merged_prefix,
191 bool allow_missing_files = false);
192
193 // On construction, silently attempts to read the metadata associated with
194 // "prefix". If caller intends to call any function afterwards, "status()"
195 // must be checked.
196 // All threads accessing the same BundleReader must synchronize.
197 class BundleReader {
198 public:
199 BundleReader(Env* const env, StringPiece prefix);
200 ~BundleReader();
201
202 // Is ok() iff the reader construction is successful (completed the read of
203 // the metadata).
status()204 Status status() const { return status_; }
205
206 // Queries whether the bundle contains an entry keyed by "key". Calls Seek()
207 // internally, so this call invalidates the reader's current position.
208 // REQUIRES: status().ok()
209 bool Contains(StringPiece key);
210
211 // Sorts a `container` of tensors to read such that when `Seek(key)` is called
212 // on the elements of the sorted container, the underlying file access is
213 // sequential. Sorting can greatly improve overall read speed.
214 //
215 // `get_key` should be a functon that when passed an element in `container`,
216 // returns the `key` of the tensor.
217 //
218 // REQUIRES: status().ok()
219 template <class T>
220 Status SortForSequentialAccess(std::vector<T>& container,
221 absl::FunctionRef<string(const T&)> get_key);
222
223 // Looks up the dtype and the shape of the tensor keyed by "key".
224 // REQUIRES: status().ok()
225 Status LookupDtypeAndShape(StringPiece key, DataType* dtype,
226 TensorShape* shape) TF_MUST_USE_RESULT;
227
228 // Looks up the shape of the tensor keyed by "key".
229 // Clears "shape" if not found.
230 // REQUIRES: status().ok()
231 Status LookupTensorShape(StringPiece key,
232 TensorShape* shape) TF_MUST_USE_RESULT;
233
234 // Looks up the tensor keyed by "key". If "key" refers to a partitioned
235 // tensor, attempts to look up the full contents using all stored slices.
236 //
237 // Caller must make sure "val" has the same shape and dtype as the
238 // corresponding contents, so that its buffer can be filled without needing
239 // extra allocation. These can be queried via "LookupDtypeAndShape()".
240 //
241 // On error, "val" may contain nonsense data. Returns a NotFound error if
242 // tensor keyed by "key" does not exist in this bundle.
243 //
244 // Validates the stored crc32c checksum against the restored bytes.
245 // REQUIRES: status().ok()
246 Status Lookup(StringPiece key, Tensor* val) TF_MUST_USE_RESULT;
247
248 // Looks up the tensor pointed to by the internal iterator.
249 //
250 // On error, "val" may contain nonsense data.
251 //
252 // Validates the stored crc32c checksum against the restored bytes.
253 // REQUIRES: status().ok() && Valid()
254 Status ReadCurrent(Tensor* val) TF_MUST_USE_RESULT;
255
256 // Looks up the slices of the tensor keyed by "key". On OK, "slices"
257 // is non-empty if and only if the tensor is a partitioned tensor.
258 //
259 // Warning - there is no guaranteed ordering for the returned slices, so
260 // a slice with a larger start index in some dimension could come before
261 // another slice with a smaller start index in the same dimension.
262 // REQUIRES: status().ok()
263 Status LookupTensorSlices(StringPiece key, std::vector<TensorSlice>* slices)
264 TF_MUST_USE_RESULT;
265
266 // Looks up a specific slice of a partitioned tensor.
267 // It is only required that the stored slices cover the requested slice,
268 // namely "slice_spec" is a subset of the union of the stored slices.
269 // REQUIRES: status().ok()
270 Status LookupSlice(StringPiece full_tensor_key, const TensorSlice& slice_spec,
271 Tensor* val) TF_MUST_USE_RESULT;
272
273 // Seeks to the first position in the bundle whose key is no less than "key".
274 // REQUIRES: status().ok()
Seek(StringPiece key)275 void Seek(StringPiece key) { return iter_->Seek(key); }
276 // Moves to the next position in the bundle.
277 // REQUIRES: status().ok()
Next()278 void Next() const { iter_->Next(); }
279 // Returns true iff the reader is positioned to a key/val pair.
280 // REQUIRES: status().ok()
Valid()281 bool Valid() const { return iter_->Valid(); }
282
283 // Returns the key at the current position.
284 // REQUIRES: status().ok() && Valid()
key()285 StringPiece key() const { return iter_->key(); }
286 // Returns the raw value at the current position.
287 // REQUIRES: status().ok() && Valid()
value()288 StringPiece value() const { return iter_->value(); }
289
290 string DebugString();
291
292 private:
293 // Seeks for "key" and reads the metadata proto.
294 // On non-OK return, clears "entry" for the caller.
295 // REQUIRES: status().ok()
296 Status GetBundleEntryProto(StringPiece key,
297 BundleEntryProto* entry) TF_MUST_USE_RESULT;
298
299 // Reads the tensor value described by the metadata proto "entry".
300 // Usage for "val" follows the comment of "Lookup()".
301 Status GetValue(const BundleEntryProto& entry,
302 Tensor* val) TF_MUST_USE_RESULT;
303
304 // Reads the slice described by "slice_spec". The corresponding full tensor
305 // has key "ful_tensor_key" and metadata proto "full_tensor_entry".
306 // REQUIRES: full_tensor_entry.slices_size() > 0
307 Status GetSliceValue(StringPiece full_tensor_key,
308 const BundleEntryProto& full_tensor_entry,
309 const TensorSlice& slice_spec,
310 Tensor* val) TF_MUST_USE_RESULT;
311
312 Env* env_; // Not owned.
313 const string prefix_;
314
315 Status status_;
316 RandomAccessFile* metadata_; // Owned.
317 table::Table* table_;
318 table::Cache* index_cache_;
319 table::Iterator* iter_;
320 // Owned the InputBuffer objects and their underlying RandomAccessFile's.
321 std::unordered_map<int32, io::InputBuffer*> data_;
322
323 // Maps each partitioned tensor's key to its stored slices (represented in a
324 // TensorSliceSet). Populated on-demand.
325 std::unordered_map<string, checkpoint::TensorSliceSet*> tensor_slices_;
326
327 // Expected number of data file shards in the bundle. Extracted by reading
328 // the header entry in the metadata table.
329 int num_shards_;
330
331 // Flag that this class sets to true when the endianness of the target bundle
332 // differs from that of the current system's processor architecture.
333 bool need_to_swap_bytes_;
334
335 friend class TensorBundleAlignmentTest; // For testing data alignment.
336
337 TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);
338 };
339
340 // A buffering wrapper for a WritableFile. Useful if the caller wishes to issue
341 // small writes to a file (e.g. writing out a list of small varints).
342 // External synchronization must be used in the presence of concurrent callers.
343 class FileOutputBuffer {
344 public:
345 FileOutputBuffer(WritableFile* file, size_t buffer_size);
346 ~FileOutputBuffer();
347
348 // Buffered append.
349 Status Append(StringPiece data);
350
351 // Returns the running crc32c checksum of all currently appended bytes.
crc32c()352 uint32 crc32c() { return crc32c_; }
353 // Clears the running crc32c checksum.
clear_crc32c()354 void clear_crc32c() { crc32c_ = 0; }
355
356 // Appends the buffered data, then closes the underlying file.
357 Status Close();
358
359 private:
360 // Appends the buffered data to the underlying file. Does NOT flush the file.
361 Status FlushBuffer(bool closing);
362
363 WritableFile* file_; // Owned.
364
365 // buffer_ptr_[0, position_) holds the buffered data not yet appended to the
366 // underlying file.
367 size_t position_;
368 const size_t buffer_size_;
369 char* buffer_ptr_;
370
371 // Checksum of all appended bytes since construction or last clear_crc32c().
372 uint32 crc32c_ = 0;
373 };
374
375 template <class T>
SortForSequentialAccess(std::vector<T> & container,absl::FunctionRef<string (const T &)> get_key)376 Status BundleReader::SortForSequentialAccess(
377 std::vector<T>& container, absl::FunctionRef<string(const T&)> get_key) {
378 struct FileOffset {
379 int32_t shard_id;
380 int64_t offset;
381 };
382 absl::flat_hash_map<string, FileOffset> file_offsets;
383 for (const T& element : container) {
384 BundleEntryProto entry;
385 TF_RETURN_IF_ERROR(GetBundleEntryProto(get_key(element), &entry));
386 file_offsets[get_key(element)] = {entry.shard_id(), entry.offset()};
387 }
388 absl::c_sort(container, [&get_key, &file_offsets](const T& a, const T& b) {
389 const FileOffset& file_offset_a = file_offsets[get_key(a)];
390 const FileOffset& file_offset_b = file_offsets[get_key(b)];
391 if (file_offset_a.shard_id == file_offset_b.shard_id) {
392 return file_offset_a.offset < file_offset_b.offset;
393 } else {
394 return file_offset_a.shard_id < file_offset_b.shard_id;
395 }
396 });
397 return OkStatus();
398 }
399
400 } // namespace tensorflow
401
402 #endif // TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
403