xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/snapshot_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/snapshot_utils.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <queue>
21 #include <string>
22 #include <utility>
23 
24 #include "absl/memory/memory.h"
25 #include "tensorflow/core/common_runtime/dma_helper.h"
26 #include "tensorflow/core/data/name_utils.h"
27 #include "tensorflow/core/framework/dataset.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/tensor.pb.h"
30 #include "tensorflow/core/lib/io/buffered_inputstream.h"
31 #include "tensorflow/core/lib/io/random_inputstream.h"
32 #include "tensorflow/core/lib/io/record_writer.h"
33 #include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
34 #include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
35 #include "tensorflow/core/lib/io/zlib_compression_options.h"
36 #include "tensorflow/core/lib/io/zlib_inputstream.h"
37 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
38 #include "tensorflow/core/platform/coding.h"
39 #include "tensorflow/core/platform/errors.h"
40 #include "tensorflow/core/platform/file_system.h"
41 #include "tensorflow/core/platform/path.h"
42 #include "tensorflow/core/platform/random.h"
43 #include "tensorflow/core/platform/strcat.h"
44 #include "tensorflow/core/platform/stringprintf.h"
45 #include "tensorflow/core/profiler/lib/traceme.h"
46 #include "tensorflow/core/protobuf/snapshot.pb.h"
47 
48 namespace tensorflow {
49 namespace data {
50 namespace snapshot_util {
51 namespace {
52 
53 constexpr const char* const kOutputTypes = "output_types";
54 constexpr const char* const kOutputShapes = "output_shapes";
55 constexpr const char* const kCompression = "compression";
56 constexpr const char* const kVersion = "version";
57 constexpr const char* const kCurrentCheckpointID = "current_checkpoint_id";
58 constexpr const char* const kIndex = "index";
59 constexpr const char* const kStartIndex = "start_index";
60 
61 }  // namespace
62 
63 /* static */ constexpr const int64_t
64     CustomReader::kSnappyReaderInputBufferSizeBytes;
65 /* static */ constexpr const int64_t
66     CustomReader::kSnappyReaderOutputBufferSizeBytes;
67 
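// Path helpers that compose the on-disk layout of a snapshot:
//   <path>/<hash>/<run_id>/<zero-padded shard_id><kShardDirectorySuffix>/
//       <zero-padded checkpoint_id>.snapshot
// The hash and run_id components are rendered as decimal strings.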
68 std::string HashDirectory(const std::string& path, uint64 hash) {
69   return io::JoinPath(
70       path, strings::Printf("%llu", static_cast<unsigned long long>(hash)));
71 }
72 
73 std::string RunDirectory(const std::string& hash_directory, uint64 run_id) {
74   return RunDirectory(
75       hash_directory,
76       strings::Printf("%llu", static_cast<unsigned long long>(run_id)));
77 }
78 
79 std::string RunDirectory(const std::string& hash_directory,
80                          const std::string& run_id) {
81   return io::JoinPath(hash_directory, run_id);
82 }
83 
84 std::string ShardDirectory(const std::string& run_directory, int64_t shard_id) {
85   return io::JoinPath(
86       run_directory,
87       strings::Printf("%08llu%s", static_cast<unsigned long long>(shard_id),
88                       kShardDirectorySuffix));
89 }
90 std::string GetCheckpointFileName(const std::string& shard_directory,
91                                   uint64 checkpoint_id) {
92   return io::JoinPath(
93       shard_directory,
94       strings::Printf("%08llu.snapshot",
95                       static_cast<unsigned long long>(checkpoint_id)));
96 }
97 
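// Creates a writer for the requested snapshot format version: version 1 uses
// the custom record format (CustomWriter), version 2 uses standard TFRecords
// (TFRecordWriter). Any other version is rejected.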
98 Status Writer::Create(Env* env, const std::string& filename,
99                       const std::string& compression_type, int version,
100                       const DataTypeVector& dtypes,
101                       std::unique_ptr<Writer>* out_writer) {
102   switch (version) {
103     case 1:
104       *out_writer =
105           std::make_unique<CustomWriter>(filename, compression_type, dtypes);
106       break;
107     case 2:
108       *out_writer =
109           std::make_unique<TFRecordWriter>(filename, compression_type);
110       break;
111     default:
112       return errors::InvalidArgument("Snapshot writer version: ", version,
113                                      " is not supported.");
114   }
115 
116   return (*out_writer)->Initialize(env);
117 }
118 
119 TFRecordWriter::TFRecordWriter(const std::string& filename,
120                                const std::string& compression_type)
121     : filename_(filename), compression_type_(compression_type) {}
122 
123 Status TFRecordWriter::Initialize(tensorflow::Env* env) {
124   TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));
125 
126   record_writer_ = std::make_unique<io::RecordWriter>(
127       dest_.get(), io::RecordWriterOptions::CreateRecordWriterOptions(
128                        /*compression_type=*/compression_type_));
129   return OkStatus();
130 }
131 
132 Status TFRecordWriter::WriteTensors(const std::vector<Tensor>& tensors) {
133   for (const auto& tensor : tensors) {
134     TensorProto proto;
135     tensor.AsProtoTensorContent(&proto);
136 #if defined(TF_CORD_SUPPORT)
137     // Use a raw pointer here: std::move()-ing a smart pointer into the Cord's
138     // releaser in OSS TF would transfer ownership when the releaser is created,
139     // leaving proto_buffer == nullptr by the time WriteRecord runs.
140     auto proto_buffer = new std::string();
141     proto.SerializeToString(proto_buffer);
142     absl::Cord proto_serialized = absl::MakeCordFromExternal(
143         *proto_buffer,
144         [proto_buffer](absl::string_view) { delete proto_buffer; });
145     TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto_serialized));
146 #else   // TF_CORD_SUPPORT
147     TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsString()));
148 #endif  // TF_CORD_SUPPORT
149   }
150   return OkStatus();
151 }
152 
153 Status TFRecordWriter::Sync() {
154   TF_RETURN_IF_ERROR(record_writer_->Flush());
155   return dest_->Flush();
156 }
157 
158 Status TFRecordWriter::Close() {
159   if (record_writer_ != nullptr) {
160     TF_RETURN_IF_ERROR(Sync());
161     TF_RETURN_IF_ERROR(record_writer_->Close());
162     TF_RETURN_IF_ERROR(dest_->Close());
163     record_writer_ = nullptr;
164     dest_ = nullptr;
165   }
166   return OkStatus();
167 }
168 
169 TFRecordWriter::~TFRecordWriter() {
170   Status s = Close();
171   if (!s.ok()) {
172     LOG(ERROR) << "Failed to close snapshot file " << filename_ << ": " << s;
173   }
174 }
175 
176 CustomWriter::CustomWriter(const std::string& filename,
177                            const std::string& compression_type,
178                            const DataTypeVector& dtypes)
179     : filename_(filename),
180       compression_type_(compression_type),
181       dtypes_(dtypes) {}
182 
183 Status CustomWriter::Initialize(tensorflow::Env* env) {
184   TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));
185 #if defined(IS_SLIM_BUILD)
186   if (compression_type_ != io::compression::kNone) {
187     LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
188                << "off compression.";
189   }
190 #else   // IS_SLIM_BUILD
191   if (compression_type_ == io::compression::kGzip) {
192     zlib_underlying_dest_.swap(dest_);
193     io::ZlibCompressionOptions zlib_options;
194     zlib_options = io::ZlibCompressionOptions::GZIP();
195 
196     io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
197         zlib_underlying_dest_.get(), zlib_options.input_buffer_size,
198         zlib_options.output_buffer_size, zlib_options);
199     TF_CHECK_OK(zlib_output_buffer->Init());
200     dest_.reset(zlib_output_buffer);
201   }
202 #endif  // IS_SLIM_BUILD
203   simple_tensor_mask_.reserve(dtypes_.size());
204   for (const auto& dtype : dtypes_) {
205     if (DataTypeCanUseMemcpy(dtype)) {
206       simple_tensor_mask_.push_back(true);
207       num_simple_++;
208     } else {
209       simple_tensor_mask_.push_back(false);
210       num_complex_++;
211     }
212   }
213 
214   return OkStatus();
215 }
216 
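// Writes one dataset element. For non-snappy compression the element is
// serialized as a single SnapshotRecord proto. For snappy, two records are
// written: a SnapshotTensorMetadata proto (per-tensor shapes and byte sizes)
// followed by a snappy-compressed buffer that concatenates the raw tensor
// bytes (memcpy-able dtypes) and serialized TensorProtos (all other dtypes)
// in tensor order.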
217 Status CustomWriter::WriteTensors(const std::vector<Tensor>& tensors) {
218   if (compression_type_ != io::compression::kSnappy) {
219     experimental::SnapshotRecord record;
220     for (const auto& tensor : tensors) {
221       TensorProto* t = record.add_tensor();
222       tensor.AsProtoTensorContent(t);
223     }
224 #if defined(TF_CORD_SUPPORT)
225     auto record_buffer = new std::string();
226     record.SerializeToString(record_buffer);
227     absl::Cord record_serialized = absl::MakeCordFromExternal(
228         *record_buffer,
229         [record_buffer](absl::string_view) { delete record_buffer; });
230     return WriteRecord(record_serialized);
231 #else   // TF_CORD_SUPPORT
232     return WriteRecord(record.SerializeAsString());
233 #endif  // TF_CORD_SUPPORT
234   }
235 
236   std::vector<const TensorBuffer*> tensor_buffers;
237   tensor_buffers.reserve(num_simple_);
238   std::vector<TensorProto> tensor_protos;
239   tensor_protos.reserve(num_complex_);
240   experimental::SnapshotTensorMetadata metadata;
241   int64_t total_size = 0;
242   for (int i = 0, end = tensors.size(); i < end; ++i) {
243     const Tensor& tensor = tensors[i];
244     experimental::TensorMetadata* tensor_metadata =
245         metadata.add_tensor_metadata();
246     tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
247     int64_t size = 0;
248     if (simple_tensor_mask_[i]) {
249       auto tensor_buffer = DMAHelper::buffer(&tensor);
250       tensor_buffers.push_back(tensor_buffer);
251       size = tensor_buffer->size();
252     } else {
253       TensorProto proto;
254       tensor.AsProtoTensorContent(&proto);
255       size = proto.ByteSizeLong();
256       tensor_protos.push_back(std::move(proto));
257     }
258     tensor_metadata->set_tensor_size_bytes(size);
259     total_size += size;
260   }
261 
262   std::vector<char> uncompressed(total_size);
263   char* position = uncompressed.data();
264   int buffer_index = 0;
265   int proto_index = 0;
266   for (int i = 0, end = tensors.size(); i < end; ++i) {
267     const auto& tensor_metadata = metadata.tensor_metadata(i);
268     if (simple_tensor_mask_[i]) {
269       memcpy(position, tensor_buffers[buffer_index]->data(),
270              tensor_metadata.tensor_size_bytes());
271       buffer_index++;
272     } else {
273       tensor_protos[proto_index].SerializeToArray(
274           position, tensor_metadata.tensor_size_bytes());
275       proto_index++;
276     }
277     position += tensor_metadata.tensor_size_bytes();
278   }
279   DCHECK_EQ(position, uncompressed.data() + total_size);
280 
281   string output;
282   if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) {
283     return errors::Internal("Failed to compress using snappy.");
284   }
285 
286 #if defined(TF_CORD_SUPPORT)
287   auto metadata_buffer = new std::string();
288   metadata.SerializeToString(metadata_buffer);
289   absl::Cord metadata_serialized = absl::MakeCordFromExternal(
290       *metadata_buffer,
291       [metadata_buffer](absl::string_view) { delete metadata_buffer; });
292 #else
293   std::string metadata_serialized = metadata.SerializeAsString();
294 #endif  // TF_CORD_SUPPORT
295   TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized));
296   TF_RETURN_IF_ERROR(WriteRecord(output));
297   return OkStatus();
298 }
299 
300 Status CustomWriter::Sync() { return dest_->Sync(); }
301 
302 Status CustomWriter::Close() {
303   if (dest_ != nullptr) {
304     TF_RETURN_IF_ERROR(dest_->Close());
305     dest_ = nullptr;
306   }
307   if (zlib_underlying_dest_ != nullptr) {
308     TF_RETURN_IF_ERROR(zlib_underlying_dest_->Close());
309     zlib_underlying_dest_ = nullptr;
310   }
311   return OkStatus();
312 }
313 
314 CustomWriter::~CustomWriter() {
315   Status s = Close();
316   if (!s.ok()) {
317     LOG(ERROR) << "Could not finish writing file: " << s;
318   }
319 }
320 
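// Record framing for the custom format: each record is a fixed-size length
// header of kHeaderSize bytes (encoded with core::EncodeFixed64) followed by
// the payload bytes.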
321 Status CustomWriter::WriteRecord(const StringPiece& data) {
322   char header[kHeaderSize];
323   core::EncodeFixed64(header, data.size());
324   TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
325   return dest_->Append(data);
326 }
327 
328 #if defined(TF_CORD_SUPPORT)
329 Status CustomWriter::WriteRecord(const absl::Cord& data) {
330   char header[kHeaderSize];
331   core::EncodeFixed64(header, data.size());
332   TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
333   return dest_->Append(data);
334 }
335 #endif  // TF_CORD_SUPPORT
336 
337 Status Reader::Create(Env* env, const std::string& filename,
338                       const string& compression_type, int version,
339                       const DataTypeVector& dtypes,
340                       std::unique_ptr<Reader>* out_reader) {
341   switch (version) {
342     // CustomReader can still read the legacy snapshot file format (v0), but
343     // CustomWriter no longer has the ability to write it, since v0 is strictly
344     // worse than v1.
345     case 0:
346     case 1:
347       *out_reader = std::make_unique<CustomReader>(filename, compression_type,
348                                                    version, dtypes);
349       break;
350     case 2:
351       *out_reader =
352           std::make_unique<TFRecordReader>(filename, compression_type, dtypes);
353       break;
354     default:
355       return errors::InvalidArgument("Snapshot reader version: ", version,
356                                      " is not supported.");
357   }
358 
359   return (*out_reader)->Initialize(env);
360 }
361 
362 Status Reader::SkipRecords(int64_t num_records) {
363   // TODO(frankchn): Optimize to not parse the entire Tensor and actually skip.
364   for (int i = 0; i < num_records; ++i) {
365     std::vector<Tensor> unused_tensors;
366     TF_RETURN_IF_ERROR(ReadTensors(&unused_tensors));
367   }
368   return OkStatus();
369 }
370 
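// A dataset that reads back the elements stored in a single shard directory.
// Its iterator reads records from the current checkpoint file and advances to
// the next numbered file once the reader reports OutOfRange.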
371 class Reader::Dataset : public DatasetBase {
372  public:
373   Dataset(DatasetContext&& ctx, const std::string& shard_dir,
374           const std::string& compression, const int64_t version,
375           const DataTypeVector& dtypes,
376           const std::vector<PartialTensorShape>& shapes,
377           const int64_t start_index)
378       : DatasetBase(std::move(ctx)),
379         shard_dir_(shard_dir),
380         compression_(compression),
381         version_(version),
382         dtypes_(dtypes),
383         shapes_(shapes),
384         start_index_(start_index) {}
385 
386   const DataTypeVector& output_dtypes() const override { return dtypes_; }
387 
388   const std::vector<PartialTensorShape>& output_shapes() const override {
389     return shapes_;
390   }
391 
392   std::string DebugString() const override { return "SnapshotDatasetReader"; }
393 
394   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
395     return OkStatus();
396   }
397 
398   Status CheckExternalState() const override { return OkStatus(); }
399 
400  protected:
401   Status AsGraphDefInternal(SerializationContext* ctx,
402                             DatasetGraphDefBuilder* b,
403                             Node** node) const override {
404     Node* shard_dir = nullptr;
405     TF_RETURN_IF_ERROR(b->AddScalar(shard_dir_, &shard_dir));
406 
407     Node* start_index = nullptr;
408     TF_RETURN_IF_ERROR(b->AddScalar(start_index_, &start_index));
409 
410     AttrValue compression;
411     b->BuildAttrValue(compression_, &compression);
412 
413     AttrValue version;
414     b->BuildAttrValue(version_, &version);
415 
416     return b->AddDataset(
417         this,
418         /*inputs=*/
419         {std::make_pair(0, shard_dir), std::make_pair(1, start_index)},
420         /*list_inputs=*/{},
421         /*attrs=*/
422         {{kCompression, compression}, {kVersion, version}},
423         /*use_dataset_name=*/true, node);
424   }
425 
426   std::unique_ptr<IteratorBase> MakeIteratorInternal(
427       const string& prefix) const override {
428     return std::make_unique<Iterator>(Iterator::Params{
429         this, name_utils::IteratorPrefix(node_name(), prefix)});
430   }
431 
432  private:
433   class Iterator : public DatasetIterator<Dataset> {
434    public:
435     explicit Iterator(const Params& params)
436         : DatasetIterator<Dataset>(params),
437           start_index_(dataset()->start_index_) {}
438 
439     Status Initialize(IteratorContext* ctx) override {
440       // TODO(jsimsa): This only needs to happen when we are not restoring, but
441       // the parallel_interleave op implementation caches IteratorContext (and
442       // thus the is_restoring bit ends up being inaccurate).
443       TF_RETURN_IF_ERROR(Reader::Create(
444           ctx->env(), GetCurrentFilename(), dataset()->compression_,
445           dataset()->version_, dataset()->dtypes_, &reader_));
446       return AdvanceToStartIndex(ctx);
447     }
448 
449    protected:
450     Status GetNextInternal(IteratorContext* ctx,
451                            std::vector<Tensor>* out_tensors,
452                            bool* end_of_sequence) override {
453       *end_of_sequence = false;
454       Status s = reader_->ReadTensors(out_tensors);
455       if (!errors::IsOutOfRange(s)) {
456         start_index_++;
457         return s;
458       }
459       Status status = AdvanceToNextFile(ctx->env());
460       if (errors::IsNotFound(status)) {
461         *end_of_sequence = true;
462         return OkStatus();
463       }
464       return status;
465     }
466 
467     Status SaveInternal(SerializationContext* ctx,
468                         IteratorStateWriter* writer) override {
469       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurrentCheckpointID),
470                                              current_checkpoint_id_));
471       TF_RETURN_IF_ERROR(
472           writer->WriteScalar(full_name(kStartIndex), start_index_));
473       return OkStatus();
474     }
475 
476     Status RestoreInternal(IteratorContext* ctx,
477                            IteratorStateReader* reader) override {
478       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurrentCheckpointID),
479                                             &current_checkpoint_id_));
480       TF_RETURN_IF_ERROR(
481           reader->ReadScalar(full_name(kStartIndex), &start_index_));
482       TF_RETURN_IF_ERROR(ctx->env()->FileExists(GetCurrentFilename()));
483       TF_RETURN_IF_ERROR(Reader::Create(
484           ctx->env(), GetCurrentFilename(), dataset()->compression_,
485           dataset()->version_, dataset()->dtypes_, &reader_));
486       return AdvanceToStartIndex(ctx);
487     }
488 
489    private:
490     Status AdvanceToNextFile(Env* env) {
491       start_index_ = 0;
492       current_checkpoint_id_++;
493       TF_RETURN_IF_ERROR(env->FileExists(GetCurrentFilename()));
494       return Reader::Create(env, GetCurrentFilename(), dataset()->compression_,
495                             dataset()->version_, dataset()->dtypes_, &reader_);
496     }
497 
498     std::string GetCurrentFilename() {
499       return GetCheckpointFileName(dataset()->shard_dir_,
500                                    current_checkpoint_id_);
501     }
502 
503     // TODO(frankchn): Optimize this to not parse every single element.
504     Status AdvanceToStartIndex(IteratorContext* ctx) {
505       for (int64_t i = 0; i < start_index_; ++i) {
506         std::vector<Tensor> unused;
507         TF_RETURN_IF_ERROR(reader_->ReadTensors(&unused));
508       }
509       return OkStatus();
510     }
511 
512     std::unique_ptr<Reader> reader_;
513 
514     // Stores the id of the current checkpoint file that we are in the process
515     // of reading (e.g. if the file is currently 00000001.snapshot, then this
516     // will be 1).
517     int64_t current_checkpoint_id_ = 0;
518     int64_t start_index_;
519   };
520 
521   const tstring shard_dir_;
522   const std::string compression_;
523   const int64_t version_;
524   const DataTypeVector dtypes_;
525   const std::vector<PartialTensorShape> shapes_;
526   const int64_t start_index_;
527 };
528 
529 Reader::DatasetOp::DatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
530   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
531   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
532   OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_));
533   OP_REQUIRES_OK(ctx, ctx->GetAttr(kVersion, &version_));
534 }
535 
536 void Reader::DatasetOp::MakeDataset(OpKernelContext* ctx,
537                                     DatasetBase** output) {
538   tstring shard_dir;
539   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "shard_dir", &shard_dir));
540 
541   int64_t start_index;
542   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "start_index", &start_index));
543 
544   *output =
545       new Reader::Dataset(DatasetContext(ctx), shard_dir, compression_,
546                           version_, output_types_, output_shapes_, start_index);
547 }
548 
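// A dataset whose elements are the per-shard datasets themselves, each yielded
// as a DT_VARIANT tensor. This lets downstream ops (e.g. an interleave)
// consume the shards as nested datasets.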
549 class Reader::NestedDataset : public DatasetBase {
550  public:
551   explicit NestedDataset(DatasetContext&& ctx,
552                          std::vector<DatasetBase*> datasets)
553       : DatasetBase(std::move(ctx)), datasets_(datasets) {
554     dtypes_.push_back(DT_VARIANT);
555     gtl::InlinedVector<int64_t, 1> element_dim_sizes;
556     element_dim_sizes.push_back(1);
557     partial_shapes_.emplace_back(element_dim_sizes);
558   }
559 
560   const DataTypeVector& output_dtypes() const override { return dtypes_; }
561 
562   const std::vector<PartialTensorShape>& output_shapes() const override {
563     return partial_shapes_;
564   }
565 
566   std::string DebugString() const override {
567     return "SnapshotNestedDatasetReader";
568   }
569 
570   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
571     inputs->clear();
572     return OkStatus();
573   }
574 
575   Status CheckExternalState() const override { return OkStatus(); }
576 
577  protected:
578   Status AsGraphDefInternal(SerializationContext* ctx,
579                             DatasetGraphDefBuilder* b,
580                             Node** node) const override {
581     std::vector<Node*> input_graph_nodes;
582     input_graph_nodes.reserve(datasets_.size());
583     for (const auto& dataset : datasets_) {
584       Node* input_node;
585       TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, dataset, &input_node));
586       input_graph_nodes.emplace_back(input_node);
587     }
588     TF_RETURN_IF_ERROR(
589         b->AddDataset(this, /*inputs=*/{},
590                       /*list_inputs=*/{std::make_pair(0, input_graph_nodes)},
591                       /*attrs=*/{}, node));
592     return OkStatus();
593   }
594 
595   std::unique_ptr<IteratorBase> MakeIteratorInternal(
596       const string& prefix) const override {
597     return std::make_unique<Iterator>(Iterator::Params{
598         this, name_utils::IteratorPrefix(node_name(), prefix)});
599   }
600 
601  private:
602   std::vector<DatasetBase*> datasets_;
603   DataTypeVector dtypes_;
604   std::vector<PartialTensorShape> partial_shapes_;
605 
606   class Iterator : public DatasetIterator<NestedDataset> {
607    public:
608     explicit Iterator(const Params& params)
609         : DatasetIterator<NestedDataset>(params) {}
610 
611    protected:
612     Status GetNextInternal(IteratorContext* ctx,
613                            std::vector<Tensor>* out_tensors,
614                            bool* end_of_sequence) override {
615       const int64_t num_datasets = dataset()->datasets_.size();
616       *end_of_sequence = num_datasets == index_;
617       if (!*end_of_sequence) {
618         Tensor tensor(DT_VARIANT, TensorShape({}));
619 
620         TF_RETURN_IF_ERROR(
621             StoreDatasetInVariantTensor(dataset()->datasets_[index_], &tensor));
622         out_tensors->clear();
623         out_tensors->push_back(std::move(tensor));
624 
625         index_++;
626       }
627       return OkStatus();
628     }
629 
630     Status SaveInternal(SerializationContext* ctx,
631                         IteratorStateWriter* writer) override {
632       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
633       return OkStatus();
634     }
635 
636     Status RestoreInternal(IteratorContext* ctx,
637                            IteratorStateReader* reader) override {
638       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &index_));
639       return OkStatus();
640     }
641 
642    private:
643     int64_t index_ = 0;
644   };
645 };
646 
647 Reader::NestedDatasetOp::NestedDatasetOp(OpKernelConstruction* ctx)
648     : DatasetOpKernel(ctx) {
649   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
650   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
651 }
652 
653 void Reader::NestedDatasetOp::MakeDataset(OpKernelContext* ctx,
654                                           DatasetBase** output) {
655   std::vector<DatasetBase*> inputs;
656   for (size_t i = 0; i < ctx->num_inputs(); ++i) {
657     DatasetBase* input;
658     OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
659     inputs.push_back(input);
660   }
661   *output = new Reader::NestedDataset(DatasetContext(ctx), inputs);
662   (*output)->Initialize(/*metadata=*/{});
663 }
664 
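// Builds one Reader::Dataset per shard directory and wraps them in a
// NestedDataset. Assuming elements were written to the shards in round-robin
// order, each shard resumes at start_index / num_shards (plus one for shards
// that have already produced an element of the final partial round), and the
// vector is rotated so that the shard holding the next element to produce
// comes first.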
665 Status Reader::MakeNestedDataset(Env* env,
666                                  const std::vector<std::string>& shard_dirs,
667                                  const string& compression_type, int version,
668                                  const DataTypeVector& dtypes,
669                                  const std::vector<PartialTensorShape>& shapes,
670                                  const int64_t start_index,
671                                  DatasetBase** output) {
672   std::vector<DatasetBase*> datasets;
673 
674   datasets.reserve(shard_dirs.size());
675   for (int64_t i = 0; i < shard_dirs.size(); ++i) {
676     // TODO(frankchn): The reading pattern could be controlled in a non-round-
677     // robin fashion, so we cannot assume round-robin order when restoring.
678     int64_t dataset_start_index = start_index / shard_dirs.size();
679     if (start_index % shard_dirs.size() > datasets.size()) {
680       dataset_start_index++;
681     }
682 
683     datasets.push_back(
684         new Dataset(DatasetContext(DatasetContext::Params(
685                         {"SnapshotDatasetReader",
686                          strings::StrCat("SnapshotDatasetReader/_", i)})),
687                     shard_dirs.at(i), compression_type, version, dtypes, shapes,
688                     dataset_start_index));
689     datasets.back()->Initialize(/*metadata=*/{});
690   }
691 
692   // Rotate the vector such that the first dataset contains the next element
693   // to be produced, but not if there are no shards at all (then we just
694   // construct an empty dataset).
695   if (!shard_dirs.empty()) {
696     std::rotate(datasets.begin(),
697                 datasets.begin() + (start_index % shard_dirs.size()),
698                 datasets.end());
699   }
700 
701   *output = new NestedDataset(
702       DatasetContext(DatasetContext::Params(
703           {"SnapshotNestedDatasetReader", "SnapshotNestedDatasetReader"})),
704       datasets);
705   (*output)->Initialize(/*metadata=*/{});
706   return OkStatus();
707 }
708 
709 TFRecordReader::TFRecordReader(const std::string& filename,
710                                const string& compression_type,
711                                const DataTypeVector& dtypes)
712     : filename_(filename),
713       offset_(0),
714       compression_type_(compression_type),
715       dtypes_(dtypes) {}
716 
717 Status TFRecordReader::Initialize(Env* env) {
718   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));
719 
720   record_reader_ = std::make_unique<io::RecordReader>(
721       file_.get(), io::RecordReaderOptions::CreateRecordReaderOptions(
722                        /*compression_type=*/compression_type_));
723   return OkStatus();
724 }
725 
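// Reads the next element: one TFRecord per component (as given by dtypes_),
// each parsed as a TensorProto and converted back into a Tensor.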
726 Status TFRecordReader::ReadTensors(std::vector<Tensor>* read_tensors) {
727   read_tensors->reserve(dtypes_.size());
728   for (int i = 0; i < dtypes_.size(); ++i) {
729     tstring record;
730     TF_RETURN_IF_ERROR(record_reader_->ReadRecord(&offset_, &record));
731 
732     TensorProto proto;
733     proto.ParseFromArray(record.data(), record.size());
734 
735     Tensor tensor;
736     if (!tensor.FromProto(proto)) {
737       return errors::DataLoss("Unable to parse tensor from stored proto.");
738     }
739 
740     read_tensors->push_back(std::move(tensor));
741   }
742   return OkStatus();
743 }
744 
745 CustomReader::CustomReader(const std::string& filename,
746                            const string& compression_type, const int version,
747                            const DataTypeVector& dtypes)
748     : filename_(filename),
749       compression_type_(compression_type),
750       version_(version),
751       dtypes_(dtypes) {}
752 
753 Status CustomReader::Initialize(Env* env) {
754   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));
755   input_stream_ = std::make_unique<io::RandomAccessInputStream>(file_.get());
756 
757 #if defined(IS_SLIM_BUILD)
758   if (compression_type_ != io::compression::kNone) {
759     LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
760                << "off compression.";
761   }
762 #else   // IS_SLIM_BUILD
763   if (compression_type_ == io::compression::kGzip) {
764     io::ZlibCompressionOptions zlib_options;
765     zlib_options = io::ZlibCompressionOptions::GZIP();
766 
767     input_stream_ = std::make_unique<io::ZlibInputStream>(
768         input_stream_.release(), zlib_options.input_buffer_size,
769         zlib_options.output_buffer_size, zlib_options, true);
770   } else if (compression_type_ == io::compression::kSnappy) {
771     if (version_ == 0) {
772       input_stream_ = std::make_unique<io::SnappyInputBuffer>(
773           file_.get(), /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
774           /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
775     } else {
776       input_stream_ =
777           std::make_unique<io::BufferedInputStream>(file_.get(), 64 << 20);
778     }
779   }
780 #endif  // IS_SLIM_BUILD
781   simple_tensor_mask_.reserve(dtypes_.size());
782   for (const auto& dtype : dtypes_) {
783     if (DataTypeCanUseMemcpy(dtype)) {
784       simple_tensor_mask_.push_back(true);
785       num_simple_++;
786     } else {
787       simple_tensor_mask_.push_back(false);
788       num_complex_++;
789     }
790   }
791 
792   return OkStatus();
793 }
794 
795 Status CustomReader::ReadTensors(std::vector<Tensor>* read_tensors) {
796   profiler::TraceMe activity(
797       [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
798       profiler::TraceMeLevel::kInfo);
799   if (version_ == 0 || compression_type_ != io::compression::kSnappy) {
800     return ReadTensorsV0(read_tensors);
801   }
802   if (version_ != 1) {
803     return errors::InvalidArgument("Version: ", version_, " is not supported.");
804   }
805   if (compression_type_ != io::compression::kSnappy) {
806     return errors::InvalidArgument("Compression ", compression_type_,
807                                    " is not supported.");
808   }
809 
810   experimental::SnapshotTensorMetadata metadata;
811   tstring metadata_str;
812   TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
813   if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {
814     return errors::DataLoss("Could not parse SnapshotTensorMetadata");
815   }
816   read_tensors->reserve(metadata.tensor_metadata_size());
817 
818   std::vector<Tensor> simple_tensors;
819   simple_tensors.reserve(num_simple_);
820   std::vector<std::pair<std::unique_ptr<char[]>, size_t>> tensor_proto_strs;
821   tensor_proto_strs.reserve(num_complex_);
822   TF_RETURN_IF_ERROR(
823       SnappyUncompress(&metadata, &simple_tensors, &tensor_proto_strs));
824 
825   int simple_index = 0;
826   int complex_index = 0;
827   for (int i = 0, end = simple_tensor_mask_.size(); i < end; ++i) {
828     if (simple_tensor_mask_[i]) {
829       read_tensors->push_back(std::move(simple_tensors[simple_index]));
830       simple_index++;
831     } else {
832       auto tensor_proto_str = std::move(tensor_proto_strs[complex_index].first);
833       size_t tensor_proto_size = tensor_proto_strs[complex_index].second;
834       TensorProto tp;
835       if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) {
836         return errors::Internal("Could not parse TensorProto");
837       }
838       Tensor t;
839       if (!t.FromProto(tp)) {
840         return errors::Internal("Could not parse Tensor");
841       }
842       read_tensors->push_back(std::move(t));
843       complex_index++;
844     }
845   }
846   return OkStatus();
847 }
848 
849 Status CustomReader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
850   experimental::SnapshotRecord record;
851 #if defined(PLATFORM_GOOGLE)
852   absl::Cord c;
853   TF_RETURN_IF_ERROR(ReadRecord(&c));
854   record.ParseFromCord(c);
855 #else   // PLATFORM_GOOGLE
856   tstring record_bytes;
857   TF_RETURN_IF_ERROR(ReadRecord(&record_bytes));
858   record.ParseFromArray(record_bytes.data(), record_bytes.size());
859 #endif  // PLATFORM_GOOGLE
860   read_tensors->reserve(record.tensor_size());
861   for (int i = 0; i < record.tensor_size(); ++i) {
862     read_tensors->emplace_back();
863     if (!read_tensors->back().FromProto(record.tensor(i))) {
864       return errors::DataLoss("Unable to parse tensor from proto.");
865     }
866   }
867   return OkStatus();
868 }
869 
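// Reads one snappy-compressed record and decompresses it directly into the
// destination buffers via an iovec scatter list: memcpy-able tensors receive
// their bytes in place, while the remaining tensors get a byte array holding a
// serialized TensorProto, avoiding an extra copy of the uncompressed data.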
870 Status CustomReader::SnappyUncompress(
871     const experimental::SnapshotTensorMetadata* metadata,
872     std::vector<Tensor>* simple_tensors,
873     std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
874         tensor_proto_strs) {
875   tstring compressed;
876   TF_RETURN_IF_ERROR(ReadRecord(&compressed));
877   size_t size;
878   if (!port::Snappy_GetUncompressedLength(compressed.data(), compressed.size(),
879                                           &size)) {
880     return errors::Internal("Could not get snappy uncompressed length");
881   }
882 
883   int num_tensors = metadata->tensor_metadata_size();
884   std::vector<struct iovec> iov(num_tensors);
885   int index = 0;
886   int64_t total_size = 0;
887   for (int i = 0, end = simple_tensor_mask_.size(); i < end; ++i) {
888     const auto& tensor_metadata = metadata->tensor_metadata(i);
889     if (simple_tensor_mask_[i]) {
890       TensorShape shape(tensor_metadata.tensor_shape());
891       Tensor simple_tensor(dtypes_[i], shape);
892       TensorBuffer* buffer = DMAHelper::buffer(&simple_tensor);
893       iov[index].iov_base = buffer->data();
894       iov[index].iov_len = buffer->size();
895       simple_tensors->push_back(std::move(simple_tensor));
896     } else {
897       auto tensor_proto_str =
898           std::make_unique<char[]>(tensor_metadata.tensor_size_bytes());
899       iov[index].iov_base = tensor_proto_str.get();
900       iov[index].iov_len = tensor_metadata.tensor_size_bytes();
901       tensor_proto_strs->push_back(std::make_pair(
902           std::move(tensor_proto_str), tensor_metadata.tensor_size_bytes()));
903     }
904     total_size += iov[index].iov_len;
905     index++;
906   }
907   const int64_t size_int = size;
908   if (size_int != total_size) {
909     return errors::Internal("Uncompressed size mismatch. Snappy expects ", size,
910                             " whereas the tensor metadata suggests ",
911                             total_size);
912   }
913   if (!port::Snappy_UncompressToIOVec(compressed.data(), compressed.size(),
914                                       iov.data(), num_tensors)) {
915     return errors::Internal("Failed to perform snappy decompression.");
916   }
917   return OkStatus();
918 }
919 
920 Status CustomReader::ReadRecord(tstring* record) {
921   tstring header;
922   TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
923   uint64 length = core::DecodeFixed64(header.data());
924   return input_stream_->ReadNBytes(length, record);
925 }
926 
927 #if defined(TF_CORD_SUPPORT)
928 Status CustomReader::ReadRecord(absl::Cord* record) {
929   tstring header;
930   TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
931   uint64 length = core::DecodeFixed64(header.data());
932   if (compression_type_ == io::compression::kNone) {
933     return input_stream_->ReadNBytes(length, record);
934   } else {
935     auto tmp_str = new tstring();
936     TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str));
937     absl::string_view tmp_str_view(*tmp_str);
938     record->Append(absl::MakeCordFromExternal(
939         tmp_str_view, [tmp_str](absl::string_view) { delete tmp_str; }));
940     return OkStatus();
941   }
942 }
943 #endif  // TF_CORD_SUPPORT
944 
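// Writes the snapshot metadata proto atomically: the proto is first written to
// a uniquely named temporary file and then renamed to the final metadata path,
// so readers never observe a partially written file.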
945 Status WriteMetadataFile(Env* env, const string& dir,
946                          const experimental::SnapshotMetadataRecord* metadata) {
947   string metadata_filename = io::JoinPath(dir, kMetadataFilename);
948   TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(dir));
949   std::string tmp_filename =
950       absl::StrCat(metadata_filename, "-tmp-", random::New64());
951   TF_RETURN_IF_ERROR(WriteBinaryProto(env, tmp_filename, *metadata));
952   return env->RenameFile(tmp_filename, metadata_filename);
953 }
954 
955 Status ReadMetadataFile(Env* env, const string& dir,
956                         experimental::SnapshotMetadataRecord* metadata,
957                         bool* file_exists) {
958   string metadata_filename = io::JoinPath(dir, kMetadataFilename);
959   Status s = env->FileExists(metadata_filename);
960   *file_exists = s.ok();
961 
962   if (*file_exists) {
963     return ReadBinaryProto(env, metadata_filename, metadata);
964   } else {
965     return OkStatus();
966   }
967 }
968 
969 Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash,
970                         const GraphDef* graph) {
971   std::string hash_hex =
972       strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
973   std::string graph_file =
974       io::JoinPath(path, absl::StrCat(hash_hex, "-graph.pbtxt"));
975 
976   LOG(INFO) << "Graph hash is " << hash_hex << ", writing to " << graph_file;
977   TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(path));
978   return WriteTextProto(env, graph_file, *graph);
979 }
980 
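// Resolves the effective snapshot mode. An explicit mode string (kModeRead,
// kModeWrite, kModePassthrough) always wins; otherwise the decision falls back
// to on-disk state: WRITER if no metadata file exists, READER if the snapshot
// is finalized, PASSTHROUGH while another writer's unfinalized snapshot has
// not yet expired, and WRITER again once it has.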
981 Status DetermineOpState(const std::string& mode_string, bool file_exists,
982                         const experimental::SnapshotMetadataRecord* metadata,
983                         const uint64 pending_snapshot_expiry_seconds,
984                         Mode* mode) {
985   if (mode_string == kModeRead) {
986     // In read mode, we expect the metadata file to have already been written.
987     if (!file_exists) {
988       return errors::NotFound("Metadata file does not exist.");
989     }
990     LOG(INFO) << "Overriding mode to reader.";
991     *mode = READER;
992     return OkStatus();
993   }
994 
995   if (mode_string == kModeWrite) {
996     LOG(INFO) << "Overriding mode to writer.";
997     *mode = WRITER;
998     return OkStatus();
999   }
1000 
1001   if (mode_string == kModePassthrough) {
1002     LOG(INFO) << "Overriding mode to passthrough.";
1003     *mode = PASSTHROUGH;
1004     return OkStatus();
1005   }
1006 
1007   if (!file_exists) {
1008     *mode = WRITER;
1009     return OkStatus();
1010   }
1011 
1012   if (metadata->finalized()) {
1013     // File found, snapshot has been finalized.
1014     *mode = READER;
1015     return OkStatus();
1016   }
1017 
1018   int64_t expiration_timer = static_cast<int64_t>(EnvTime::NowMicros()) -
1019                              pending_snapshot_expiry_seconds * 1000000;
1020 
1021   if (metadata->creation_timestamp() >= expiration_timer) {
1022     // Someone else is already writing and time has not expired.
1023     *mode = PASSTHROUGH;
1024     return OkStatus();
1025   } else {
1026     // Time has expired, we write regardless.
1027     *mode = WRITER;
1028     return OkStatus();
1029   }
1030 }
1031 
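// AsyncWriter decouples element production from file I/O: Write() and
// SignalEOF() enqueue items under a mutex, and a dedicated background thread
// (started in the constructor) dequeues them and writes them to the shard's
// checkpoint file, invoking the done callback with the final status.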
1032 AsyncWriter::AsyncWriter(Env* env, int64_t file_index,
1033                          const std::string& shard_directory,
1034                          uint64 checkpoint_id, const std::string& compression,
1035                          int64_t version, const DataTypeVector& output_types,
1036                          std::function<void(Status)> done) {
1037   thread_ = absl::WrapUnique(env->StartThread(
1038       ThreadOptions(), absl::StrCat("writer_thread_", file_index),
1039       [this, env, shard_directory, checkpoint_id, compression, version,
1040        &output_types, done = std::move(done)] {
1041         done(WriterThread(env, shard_directory, checkpoint_id, compression,
1042                           version, output_types));
1043       }));
1044 }
1045 
1046 void AsyncWriter::Write(const std::vector<Tensor>& tensors) {
1047   mutex_lock l(mu_);
1048   ElementOrEOF element;
1049   element.value = tensors;
1050   deque_.push_back(std::move(element));
1051 }
1052 
1053 void AsyncWriter::SignalEOF() {
1054   mutex_lock l(mu_);
1055   ElementOrEOF be;
1056   be.end_of_sequence = true;
1057   deque_.push_back(std::move(be));
1058 }
1059 
1060 void AsyncWriter::Consume(ElementOrEOF* be) {
1061   mutex_lock l(mu_);
1062   mu_.Await(tensorflow::Condition(this, &AsyncWriter::ElementAvailable));
1063   *be = deque_.front();
1064   deque_.pop_front();
1065 }
1066 
1067 bool AsyncWriter::ElementAvailable() { return !deque_.empty(); }
1068 
1069 Status AsyncWriter::WriterThread(Env* env, const std::string& shard_directory,
1070                                  uint64 checkpoint_id,
1071                                  const std::string& compression,
1072                                  int64_t version, DataTypeVector output_types) {
1073   std::unique_ptr<snapshot_util::Writer> writer;
1074   TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(shard_directory));
1075 
1076   TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
1077       env, GetCheckpointFileName(shard_directory, checkpoint_id), compression,
1078       version, std::move(output_types), &writer));
1079 
1080   while (true) {
1081     ElementOrEOF be;
1082     Consume(&be);
1083 
1084     if (be.end_of_sequence) {
1085       TF_RETURN_IF_ERROR(writer->Close());
1086       break;
1087     }
1088 
1089     TF_RETURN_IF_ERROR(writer->WriteTensors(be.value));
1090   }
1091   return OkStatus();
1092 }
1093 
1094 namespace {
1095 
1096 REGISTER_KERNEL_BUILDER(Name("SnapshotDatasetReader").Device(DEVICE_CPU),
1097                         Reader::DatasetOp);
1098 REGISTER_KERNEL_BUILDER(Name("SnapshotNestedDatasetReader").Device(DEVICE_CPU),
1099                         Reader::NestedDatasetOp);
1100 
1101 }  // namespace
1102 }  // namespace snapshot_util
1103 }  // namespace data
1104 }  // namespace tensorflow
1105