// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/snapshot_utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#ifndef TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_
#define TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_

#include <cstddef>
#include <cstdint>
#include <deque>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/lib/io/inputstream_interface.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/status.h"
31 
32 namespace tensorflow {
33 
34 class GraphDef;
35 
36 namespace data {
37 
38 namespace experimental {
39 
40 class SnapshotMetadataRecord;
41 class SnapshotTensorMetadata;
42 
43 }  // namespace experimental
44 
45 namespace snapshot_util {
46 
// Name of the metadata file stored alongside a snapshot.
constexpr char kMetadataFilename[] = "snapshot.metadata";

// String spellings of the snapshot modes.
constexpr char kModeAuto[] = "auto";
constexpr char kModeWrite[] = "write";
constexpr char kModeRead[] = "read";
constexpr char kModePassthrough[] = "passthrough";

// Suffix appended to shard directory names.
constexpr char kShardDirectorySuffix[] = ".shard";

// Resolved operating state of a snapshot op (see `DetermineOpState`).
enum Mode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
56 
57 // Returns the name of the "hash" directory for the given base path and hash ID.
58 std::string HashDirectory(const std::string& path, uint64 hash);
59 
60 // Returns the name of the "run" directory for the given base path and run ID.
61 std::string RunDirectory(const std::string& hash_directory, uint64 run_id);
62 std::string RunDirectory(const std::string& hash_directory,
63                          const std::string& run_id);
64 
65 // Returns the name of the "shard" directory for the given base path and shard
66 // ID.
67 std::string ShardDirectory(const std::string& run_directory, int64_t shard_id);
68 
69 // Returns the checkpoint file name for the given directory and checkpoint ID.
70 std::string GetCheckpointFileName(const std::string& shard_directory,
71                                   const uint64 checkpoint_id);
72 
73 // This is a interface class that exposes snapshot writing functionality.
74 class Writer {
75  public:
76   // Creates a new writer object.
77   static Status Create(Env* env, const std::string& filename,
78                        const std::string& compression_type, int version,
79                        const DataTypeVector& dtypes,
80                        std::unique_ptr<Writer>* out_writer);
81 
82   // Writes a vector of tensors to the snapshot writer file.
83   virtual Status WriteTensors(const std::vector<Tensor>& tensors) = 0;
84 
85   // Flushes any in-memory buffers to disk.
86   virtual Status Sync() = 0;
87 
88   // Closes and finalizes the snapshot file. All calls to any other method will
89   // be invalid after this call.
90   virtual Status Close() = 0;
91 
~Writer()92   virtual ~Writer() {}
93 
94  protected:
95   virtual Status Initialize(tensorflow::Env* env) = 0;
96 };
97 
98 // Writes snapshots with the standard TFRecord file format.
99 class TFRecordWriter : public Writer {
100  public:
101   TFRecordWriter(const std::string& filename,
102                  const std::string& compression_type);
103 
104   Status WriteTensors(const std::vector<Tensor>& tensors) override;
105 
106   Status Sync() override;
107 
108   Status Close() override;
109 
110   ~TFRecordWriter() override;
111 
112  protected:
113   Status Initialize(tensorflow::Env* env) override;
114 
115  private:
116   const std::string filename_;
117   const std::string compression_type_;
118 
119   std::unique_ptr<WritableFile> dest_;
120   std::unique_ptr<io::RecordWriter> record_writer_;
121 };
122 
123 // Writes snapshot with a custom (legacy) file format.
124 class CustomWriter : public Writer {
125  public:
126   static constexpr const size_t kHeaderSize = sizeof(uint64);
127 
128   static constexpr const char* const kClassName = "SnapshotWriter";
129   static constexpr const char* const kWriteStringPiece = "WriteStringPiece";
130   static constexpr const char* const kWriteCord = "WriteCord";
131   static constexpr const char* const kSeparator = "::";
132 
133   CustomWriter(const std::string& filename, const std::string& compression_type,
134                const DataTypeVector& dtypes);
135 
136   Status WriteTensors(const std::vector<Tensor>& tensors) override;
137 
138   Status Sync() override;
139 
140   Status Close() override;
141 
142   ~CustomWriter() override;
143 
144  protected:
145   Status Initialize(tensorflow::Env* env) override;
146 
147  private:
148   Status WriteRecord(const StringPiece& data);
149 
150 #if defined(TF_CORD_SUPPORT)
151   Status WriteRecord(const absl::Cord& data);
152 #endif  // TF_CORD_SUPPORT
153 
154   std::unique_ptr<WritableFile> dest_;
155   const std::string filename_;
156   const std::string compression_type_;
157   const DataTypeVector dtypes_;
158   // We hold zlib_dest_ because we may create a ZlibOutputBuffer and put that
159   // in dest_ if we want compression. ZlibOutputBuffer doesn't own the original
160   // dest_ and so we need somewhere to store the original one.
161   std::unique_ptr<WritableFile> zlib_underlying_dest_;
162   std::vector<bool> simple_tensor_mask_;  // true for simple, false for complex.
163   int num_simple_ = 0;
164   int num_complex_ = 0;
165 };
166 
167 // Interface class for reading snapshot files previous written with Writer.
168 class Reader {
169  public:
170   // Op kernel that creates an instance of `Reader::Dataset` needed to support
171   // serialization and deserialization of `Reader::Dataset`.
172   class DatasetOp : public DatasetOpKernel {
173    public:
174     explicit DatasetOp(OpKernelConstruction* ctx);
175 
176    protected:
177     void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
178 
179    private:
180     DataTypeVector output_types_;
181     std::vector<PartialTensorShape> output_shapes_;
182     std::string compression_;
183     int64_t version_;
184   };
185 
186   // Op kernel that creates an instance of `Reader::NestedDataset` needed to
187   // support serialization and deserialization of `Reader::NestedDataset`.
188   class NestedDatasetOp : public DatasetOpKernel {
189    public:
190     explicit NestedDatasetOp(OpKernelConstruction* ctx);
191 
192    protected:
193     void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
194 
195    private:
196     DataTypeVector output_types_;
197     std::vector<PartialTensorShape> output_shapes_;
198   };
199 
200   // Creates a new Reader object that reads data from `filename`. Note that
201   // the `version`, `compression_type`, and `dtypes` arguments passed into
202   // `Writer` and `Reader` must be the same for the reading to succeed.
203   static Status Create(Env* env, const std::string& filename,
204                        const string& compression_type, int version,
205                        const DataTypeVector& dtypes,
206                        std::unique_ptr<Reader>* out_reader);
207 
208   // Returns a nested dataset for a set of given snapshot file names.
209   //
210   // This function takes a vector of snapshot files, and returns a nested
211   // dataset. Each element within the nested dataset is itself a dataset, and
212   // contains all the elements written out to each individual snapshot file.
213   static Status MakeNestedDataset(Env* env,
214                                   const std::vector<std::string>& shard_dirs,
215                                   const string& compression_type, int version,
216                                   const DataTypeVector& dtypes,
217                                   const std::vector<PartialTensorShape>& shapes,
218                                   const int64_t start_index,
219                                   DatasetBase** output);
220 
221   // Reads a vector of Tensors from the snapshot file.
222   virtual Status ReadTensors(std::vector<Tensor>* read_tensors) = 0;
223 
224   // Skips `num_records`. Equivalent to calling `ReadTensors` `num_records`
225   // times then discarding the results.
226   virtual Status SkipRecords(int64_t num_records);
227 
~Reader()228   virtual ~Reader() {}
229 
230  protected:
231   virtual Status Initialize(Env* env) = 0;
232 
233   class Dataset;
234   class NestedDataset;
235 };
236 
237 // Reads snapshots previously written with `TFRecordWriter`.
238 class TFRecordReader : public Reader {
239  public:
240   TFRecordReader(const std::string& filename, const string& compression_type,
241                  const DataTypeVector& dtypes);
242 
243   Status ReadTensors(std::vector<Tensor>* read_tensors) override;
244 
~TFRecordReader()245   ~TFRecordReader() override {}
246 
247  protected:
248   Status Initialize(Env* env) override;
249 
250  private:
251   std::string filename_;
252   std::unique_ptr<RandomAccessFile> file_;
253   std::unique_ptr<io::RecordReader> record_reader_;
254   uint64 offset_;
255 
256   const string compression_type_;
257   const DataTypeVector dtypes_;
258 };
259 
260 // Reads snapshots previously written with `CustomWriter`.
261 class CustomReader : public Reader {
262  public:
263   // The reader input buffer size is deliberately large because the input reader
264   // will throw an error if the compressed block length cannot fit in the input
265   // buffer.
266   static constexpr const int64_t kSnappyReaderInputBufferSizeBytes =
267       1 << 30;  // 1 GiB
268   // TODO(b/148804377): Set this in a smarter fashion.
269   static constexpr const int64_t kSnappyReaderOutputBufferSizeBytes =
270       32 << 20;  // 32 MiB
271   static constexpr const size_t kHeaderSize = sizeof(uint64);
272 
273   static constexpr const char* const kClassName = "SnapshotReader";
274   static constexpr const char* const kReadString = "ReadString";
275   static constexpr const char* const kReadCord = "ReadCord";
276   static constexpr const char* const kSeparator = "::";
277 
278   CustomReader(const std::string& filename, const string& compression_type,
279                const int version, const DataTypeVector& dtypes);
280 
281   Status ReadTensors(std::vector<Tensor>* read_tensors) override;
282 
~CustomReader()283   ~CustomReader() override {}
284 
285  protected:
286   Status Initialize(Env* env) override;
287 
288  private:
289   Status ReadTensorsV0(std::vector<Tensor>* read_tensors);
290 
291   Status SnappyUncompress(
292       const experimental::SnapshotTensorMetadata* metadata,
293       std::vector<Tensor>* simple_tensors,
294       std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
295           tensor_proto_strs);
296 
297   Status ReadRecord(tstring* record);
298 
299 #if defined(TF_CORD_SUPPORT)
300   Status ReadRecord(absl::Cord* record);
301 #endif
302 
303   std::string filename_;
304   std::unique_ptr<RandomAccessFile> file_;
305   std::unique_ptr<io::InputStreamInterface> input_stream_;
306   const string compression_type_;
307   const int version_;
308   const DataTypeVector dtypes_;
309   int num_simple_ = 0;
310   int num_complex_ = 0;
311   std::vector<bool> simple_tensor_mask_;  // true for simple, false for complex.
312 };
313 
314 // Writes snapshot metadata to the given directory.
315 Status WriteMetadataFile(Env* env, const string& dir,
316                          const experimental::SnapshotMetadataRecord* metadata);
317 
318 // Reads snapshot metadata from the given directory.
319 Status ReadMetadataFile(Env* env, const string& dir,
320                         experimental::SnapshotMetadataRecord* metadata,
321                         bool* file_exists);
322 
323 // Writes a dataset graph to the given directory.
324 Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash,
325                         const GraphDef* graph);
326 
327 Status DetermineOpState(const std::string& mode_string, bool file_exists,
328                         const experimental::SnapshotMetadataRecord* metadata,
329                         const uint64 pending_snapshot_expiry_seconds,
330                         Mode* mode);
331 
332 // Represents a dataset element or EOF.
333 struct ElementOrEOF {
334   std::vector<Tensor> value;
335   bool end_of_sequence = false;
336 };
337 
338 // AsyncWriter provides API for asynchronously writing dataset elements
339 // (each represented as a vector of tensors) to a file.
340 //
341 // The expected use of this API is:
342 //
343 // std::unique_ptr<AsyncWriter> writer = absl_make_unique<AsyncWriter>(...);
344 //
345 // while (data_available()) {
346 //   std::vector<Tensor> data = read_data()
347 //   writer->Write(data);
348 // }
349 // writer->SignalEOF();
350 // writer = nullptr;  // This will block until writes are flushed.
351 class AsyncWriter {
352  public:
353   explicit AsyncWriter(Env* env, int64_t file_index,
354                        const std::string& shard_directory, uint64 checkpoint_id,
355                        const std::string& compression, int64_t version,
356                        const DataTypeVector& output_types,
357                        std::function<void(Status)> done);
358 
359   // Writes the given tensors. The method is non-blocking and returns without
360   // waiting for the element to be written.
361   void Write(const std::vector<Tensor>& tensors) TF_LOCKS_EXCLUDED(mu_);
362 
363   // Signals the end of input. The method is non-blocking and returns without
364   // waiting for the writer to be closed.
365   void SignalEOF() TF_LOCKS_EXCLUDED(mu_);
366 
367  private:
368   void Consume(ElementOrEOF* be) TF_LOCKS_EXCLUDED(mu_);
369   bool ElementAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
370   Status WriterThread(Env* env, const std::string& shard_directory,
371                       uint64 checkpoint_id, const std::string& compression,
372                       int64_t version, DataTypeVector output_types);
373 
374   mutex mu_;
375   std::deque<ElementOrEOF> deque_ TF_GUARDED_BY(mu_);
376 
377   // This has to be last. During destruction, we need to make sure that the
378   // Thread object is destroyed first as its destructor blocks on thread
379   // completion. If there are other member variables after this, they may get
380   // destroyed first before the thread finishes, potentially causing the
381   // thread to access invalid memory.
382   std::unique_ptr<Thread> thread_;
383 };
384 
385 }  // namespace snapshot_util
386 }  // namespace data
387 }  // namespace tensorflow
388 
389 #endif  // TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_
390