1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_ 17 #define TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_ 18 19 #include "tensorflow/core/framework/dataset.h" 20 #include "tensorflow/core/framework/op_kernel.h" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/framework/types.h" 23 #include "tensorflow/core/lib/io/compression.h" 24 #include "tensorflow/core/lib/io/inputstream_interface.h" 25 #include "tensorflow/core/lib/io/record_reader.h" 26 #include "tensorflow/core/lib/io/record_writer.h" 27 #include "tensorflow/core/platform/env.h" 28 #include "tensorflow/core/platform/file_system.h" 29 #include "tensorflow/core/platform/path.h" 30 #include "tensorflow/core/platform/status.h" 31 32 namespace tensorflow { 33 34 class GraphDef; 35 36 namespace data { 37 38 namespace experimental { 39 40 class SnapshotMetadataRecord; 41 class SnapshotTensorMetadata; 42 43 } // namespace experimental 44 45 namespace snapshot_util { 46 47 constexpr char kMetadataFilename[] = "snapshot.metadata"; 48 49 constexpr char kModeAuto[] = "auto"; 50 constexpr char kModeWrite[] = "write"; 51 constexpr char kModeRead[] = "read"; 52 constexpr char kModePassthrough[] = "passthrough"; 53 constexpr char kShardDirectorySuffix[] = ".shard"; 54 55 enum Mode { READER = 0, WRITER = 1, PASSTHROUGH = 2 }; 56 57 // Returns the name of the "hash" directory for the given base path and hash ID. 58 std::string HashDirectory(const std::string& path, uint64 hash); 59 60 // Returns the name of the "run" directory for the given base path and run ID. 61 std::string RunDirectory(const std::string& hash_directory, uint64 run_id); 62 std::string RunDirectory(const std::string& hash_directory, 63 const std::string& run_id); 64 65 // Returns the name of the "shard" directory for the given base path and shard 66 // ID. 67 std::string ShardDirectory(const std::string& run_directory, int64_t shard_id); 68 69 // Returns the checkpoint file name for the given directory and checkpoint ID. 70 std::string GetCheckpointFileName(const std::string& shard_directory, 71 const uint64 checkpoint_id); 72 73 // This is a interface class that exposes snapshot writing functionality. 74 class Writer { 75 public: 76 // Creates a new writer object. 77 static Status Create(Env* env, const std::string& filename, 78 const std::string& compression_type, int version, 79 const DataTypeVector& dtypes, 80 std::unique_ptr<Writer>* out_writer); 81 82 // Writes a vector of tensors to the snapshot writer file. 83 virtual Status WriteTensors(const std::vector<Tensor>& tensors) = 0; 84 85 // Flushes any in-memory buffers to disk. 86 virtual Status Sync() = 0; 87 88 // Closes and finalizes the snapshot file. All calls to any other method will 89 // be invalid after this call. 90 virtual Status Close() = 0; 91 ~Writer()92 virtual ~Writer() {} 93 94 protected: 95 virtual Status Initialize(tensorflow::Env* env) = 0; 96 }; 97 98 // Writes snapshots with the standard TFRecord file format. 99 class TFRecordWriter : public Writer { 100 public: 101 TFRecordWriter(const std::string& filename, 102 const std::string& compression_type); 103 104 Status WriteTensors(const std::vector<Tensor>& tensors) override; 105 106 Status Sync() override; 107 108 Status Close() override; 109 110 ~TFRecordWriter() override; 111 112 protected: 113 Status Initialize(tensorflow::Env* env) override; 114 115 private: 116 const std::string filename_; 117 const std::string compression_type_; 118 119 std::unique_ptr<WritableFile> dest_; 120 std::unique_ptr<io::RecordWriter> record_writer_; 121 }; 122 123 // Writes snapshot with a custom (legacy) file format. 124 class CustomWriter : public Writer { 125 public: 126 static constexpr const size_t kHeaderSize = sizeof(uint64); 127 128 static constexpr const char* const kClassName = "SnapshotWriter"; 129 static constexpr const char* const kWriteStringPiece = "WriteStringPiece"; 130 static constexpr const char* const kWriteCord = "WriteCord"; 131 static constexpr const char* const kSeparator = "::"; 132 133 CustomWriter(const std::string& filename, const std::string& compression_type, 134 const DataTypeVector& dtypes); 135 136 Status WriteTensors(const std::vector<Tensor>& tensors) override; 137 138 Status Sync() override; 139 140 Status Close() override; 141 142 ~CustomWriter() override; 143 144 protected: 145 Status Initialize(tensorflow::Env* env) override; 146 147 private: 148 Status WriteRecord(const StringPiece& data); 149 150 #if defined(TF_CORD_SUPPORT) 151 Status WriteRecord(const absl::Cord& data); 152 #endif // TF_CORD_SUPPORT 153 154 std::unique_ptr<WritableFile> dest_; 155 const std::string filename_; 156 const std::string compression_type_; 157 const DataTypeVector dtypes_; 158 // We hold zlib_dest_ because we may create a ZlibOutputBuffer and put that 159 // in dest_ if we want compression. ZlibOutputBuffer doesn't own the original 160 // dest_ and so we need somewhere to store the original one. 161 std::unique_ptr<WritableFile> zlib_underlying_dest_; 162 std::vector<bool> simple_tensor_mask_; // true for simple, false for complex. 163 int num_simple_ = 0; 164 int num_complex_ = 0; 165 }; 166 167 // Interface class for reading snapshot files previous written with Writer. 168 class Reader { 169 public: 170 // Op kernel that creates an instance of `Reader::Dataset` needed to support 171 // serialization and deserialization of `Reader::Dataset`. 172 class DatasetOp : public DatasetOpKernel { 173 public: 174 explicit DatasetOp(OpKernelConstruction* ctx); 175 176 protected: 177 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; 178 179 private: 180 DataTypeVector output_types_; 181 std::vector<PartialTensorShape> output_shapes_; 182 std::string compression_; 183 int64_t version_; 184 }; 185 186 // Op kernel that creates an instance of `Reader::NestedDataset` needed to 187 // support serialization and deserialization of `Reader::NestedDataset`. 188 class NestedDatasetOp : public DatasetOpKernel { 189 public: 190 explicit NestedDatasetOp(OpKernelConstruction* ctx); 191 192 protected: 193 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; 194 195 private: 196 DataTypeVector output_types_; 197 std::vector<PartialTensorShape> output_shapes_; 198 }; 199 200 // Creates a new Reader object that reads data from `filename`. Note that 201 // the `version`, `compression_type`, and `dtypes` arguments passed into 202 // `Writer` and `Reader` must be the same for the reading to succeed. 203 static Status Create(Env* env, const std::string& filename, 204 const string& compression_type, int version, 205 const DataTypeVector& dtypes, 206 std::unique_ptr<Reader>* out_reader); 207 208 // Returns a nested dataset for a set of given snapshot file names. 209 // 210 // This function takes a vector of snapshot files, and returns a nested 211 // dataset. Each element within the nested dataset is itself a dataset, and 212 // contains all the elements written out to each individual snapshot file. 213 static Status MakeNestedDataset(Env* env, 214 const std::vector<std::string>& shard_dirs, 215 const string& compression_type, int version, 216 const DataTypeVector& dtypes, 217 const std::vector<PartialTensorShape>& shapes, 218 const int64_t start_index, 219 DatasetBase** output); 220 221 // Reads a vector of Tensors from the snapshot file. 222 virtual Status ReadTensors(std::vector<Tensor>* read_tensors) = 0; 223 224 // Skips `num_records`. Equivalent to calling `ReadTensors` `num_records` 225 // times then discarding the results. 226 virtual Status SkipRecords(int64_t num_records); 227 ~Reader()228 virtual ~Reader() {} 229 230 protected: 231 virtual Status Initialize(Env* env) = 0; 232 233 class Dataset; 234 class NestedDataset; 235 }; 236 237 // Reads snapshots previously written with `TFRecordWriter`. 238 class TFRecordReader : public Reader { 239 public: 240 TFRecordReader(const std::string& filename, const string& compression_type, 241 const DataTypeVector& dtypes); 242 243 Status ReadTensors(std::vector<Tensor>* read_tensors) override; 244 ~TFRecordReader()245 ~TFRecordReader() override {} 246 247 protected: 248 Status Initialize(Env* env) override; 249 250 private: 251 std::string filename_; 252 std::unique_ptr<RandomAccessFile> file_; 253 std::unique_ptr<io::RecordReader> record_reader_; 254 uint64 offset_; 255 256 const string compression_type_; 257 const DataTypeVector dtypes_; 258 }; 259 260 // Reads snapshots previously written with `CustomWriter`. 261 class CustomReader : public Reader { 262 public: 263 // The reader input buffer size is deliberately large because the input reader 264 // will throw an error if the compressed block length cannot fit in the input 265 // buffer. 266 static constexpr const int64_t kSnappyReaderInputBufferSizeBytes = 267 1 << 30; // 1 GiB 268 // TODO(b/148804377): Set this in a smarter fashion. 269 static constexpr const int64_t kSnappyReaderOutputBufferSizeBytes = 270 32 << 20; // 32 MiB 271 static constexpr const size_t kHeaderSize = sizeof(uint64); 272 273 static constexpr const char* const kClassName = "SnapshotReader"; 274 static constexpr const char* const kReadString = "ReadString"; 275 static constexpr const char* const kReadCord = "ReadCord"; 276 static constexpr const char* const kSeparator = "::"; 277 278 CustomReader(const std::string& filename, const string& compression_type, 279 const int version, const DataTypeVector& dtypes); 280 281 Status ReadTensors(std::vector<Tensor>* read_tensors) override; 282 ~CustomReader()283 ~CustomReader() override {} 284 285 protected: 286 Status Initialize(Env* env) override; 287 288 private: 289 Status ReadTensorsV0(std::vector<Tensor>* read_tensors); 290 291 Status SnappyUncompress( 292 const experimental::SnapshotTensorMetadata* metadata, 293 std::vector<Tensor>* simple_tensors, 294 std::vector<std::pair<std::unique_ptr<char[]>, size_t>>* 295 tensor_proto_strs); 296 297 Status ReadRecord(tstring* record); 298 299 #if defined(TF_CORD_SUPPORT) 300 Status ReadRecord(absl::Cord* record); 301 #endif 302 303 std::string filename_; 304 std::unique_ptr<RandomAccessFile> file_; 305 std::unique_ptr<io::InputStreamInterface> input_stream_; 306 const string compression_type_; 307 const int version_; 308 const DataTypeVector dtypes_; 309 int num_simple_ = 0; 310 int num_complex_ = 0; 311 std::vector<bool> simple_tensor_mask_; // true for simple, false for complex. 312 }; 313 314 // Writes snapshot metadata to the given directory. 315 Status WriteMetadataFile(Env* env, const string& dir, 316 const experimental::SnapshotMetadataRecord* metadata); 317 318 // Reads snapshot metadata from the given directory. 319 Status ReadMetadataFile(Env* env, const string& dir, 320 experimental::SnapshotMetadataRecord* metadata, 321 bool* file_exists); 322 323 // Writes a dataset graph to the given directory. 324 Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash, 325 const GraphDef* graph); 326 327 Status DetermineOpState(const std::string& mode_string, bool file_exists, 328 const experimental::SnapshotMetadataRecord* metadata, 329 const uint64 pending_snapshot_expiry_seconds, 330 Mode* mode); 331 332 // Represents a dataset element or EOF. 333 struct ElementOrEOF { 334 std::vector<Tensor> value; 335 bool end_of_sequence = false; 336 }; 337 338 // AsyncWriter provides API for asynchronously writing dataset elements 339 // (each represented as a vector of tensors) to a file. 340 // 341 // The expected use of this API is: 342 // 343 // std::unique_ptr<AsyncWriter> writer = absl_make_unique<AsyncWriter>(...); 344 // 345 // while (data_available()) { 346 // std::vector<Tensor> data = read_data() 347 // writer->Write(data); 348 // } 349 // writer->SignalEOF(); 350 // writer = nullptr; // This will block until writes are flushed. 351 class AsyncWriter { 352 public: 353 explicit AsyncWriter(Env* env, int64_t file_index, 354 const std::string& shard_directory, uint64 checkpoint_id, 355 const std::string& compression, int64_t version, 356 const DataTypeVector& output_types, 357 std::function<void(Status)> done); 358 359 // Writes the given tensors. The method is non-blocking and returns without 360 // waiting for the element to be written. 361 void Write(const std::vector<Tensor>& tensors) TF_LOCKS_EXCLUDED(mu_); 362 363 // Signals the end of input. The method is non-blocking and returns without 364 // waiting for the writer to be closed. 365 void SignalEOF() TF_LOCKS_EXCLUDED(mu_); 366 367 private: 368 void Consume(ElementOrEOF* be) TF_LOCKS_EXCLUDED(mu_); 369 bool ElementAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 370 Status WriterThread(Env* env, const std::string& shard_directory, 371 uint64 checkpoint_id, const std::string& compression, 372 int64_t version, DataTypeVector output_types); 373 374 mutex mu_; 375 std::deque<ElementOrEOF> deque_ TF_GUARDED_BY(mu_); 376 377 // This has to be last. During destruction, we need to make sure that the 378 // Thread object is destroyed first as its destructor blocks on thread 379 // completion. If there are other member variables after this, they may get 380 // destroyed first before the thread finishes, potentially causing the 381 // thread to access invalid memory. 382 std::unique_ptr<Thread> thread_; 383 }; 384 385 } // namespace snapshot_util 386 } // namespace data 387 } // namespace tensorflow 388 389 #endif // TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_ 390