1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/util/tensor_slice_writer.h"
17
18 #include <utility>
19
20 #include "tensorflow/core/framework/versions.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/io/table_builder.h"
23 #include "tensorflow/core/lib/random/random.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/env.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/public/version.h"
28 #include "tensorflow/core/util/saved_tensor_slice_util.h"
29
30 namespace tensorflow {
31
32 namespace checkpoint {
33
34 namespace {
35
36 class TableBuilder : public TensorSliceWriter::Builder {
37 public:
TableBuilder(const string & name,WritableFile * f)38 TableBuilder(const string& name, WritableFile* f) : name_(name), file_(f) {
39 table::Options option;
40 option.compression = table::kNoCompression;
41 builder_.reset(new table::TableBuilder(option, f));
42 }
Add(StringPiece key,StringPiece val)43 void Add(StringPiece key, StringPiece val) override {
44 builder_->Add(key, val);
45 }
Finish(int64_t * file_size)46 Status Finish(int64_t* file_size) override {
47 *file_size = -1;
48 Status s = builder_->Finish();
49 if (s.ok()) {
50 s = file_->Close();
51 if (s.ok()) {
52 *file_size = builder_->FileSize();
53 }
54 }
55 if (!s.ok()) {
56 s = errors::Internal("Error writing (tmp) checkpoint file: ", name_, ": ",
57 s.error_message());
58 }
59 builder_.reset();
60 file_.reset();
61 return s;
62 }
63
64 private:
65 string name_;
66 std::unique_ptr<WritableFile> file_;
67 std::unique_ptr<table::TableBuilder> builder_;
68 };
69 } // anonymous namespace
70
CreateTableTensorSliceBuilder(const string & name,TensorSliceWriter::Builder ** builder)71 Status CreateTableTensorSliceBuilder(const string& name,
72 TensorSliceWriter::Builder** builder) {
73 *builder = nullptr;
74 std::unique_ptr<WritableFile> f;
75 Status s = Env::Default()->NewWritableFile(name, &f);
76 if (s.ok()) {
77 *builder = new TableBuilder(name, f.release());
78 return OkStatus();
79 } else {
80 return s;
81 }
82 }
83
TensorSliceWriter(const string & filename,CreateBuilderFunction create_builder)84 TensorSliceWriter::TensorSliceWriter(const string& filename,
85 CreateBuilderFunction create_builder)
86 : filename_(filename),
87 create_builder_(std::move(create_builder)),
88 slices_(0) {
89 Env* env = Env::Default();
90 Status status = env->CanCreateTempFile(filename_, &use_temp_file_);
91 if (!status.ok()) {
92 LOG(ERROR) << "Failed to get CanCreateTempFile attribute: " << filename_;
93 use_temp_file_ = true;
94 }
95 data_filename_ = filename_;
96 if (use_temp_file_) {
97 data_filename_ = strings::StrCat(filename_, ".tempstate", random::New64());
98 }
99 VersionDef* versions = sts_.mutable_meta()->mutable_versions();
100 versions->set_producer(TF_CHECKPOINT_VERSION);
101 versions->set_min_consumer(TF_CHECKPOINT_VERSION_MIN_CONSUMER);
102 }
103
Finish()104 Status TensorSliceWriter::Finish() {
105 Builder* b;
106 Status s = create_builder_(data_filename_, &b);
107 if (!s.ok()) {
108 delete b;
109 return s;
110 }
111 std::unique_ptr<Builder> builder(b);
112
113 // We save the saved tensor slice metadata as the first element.
114 string meta;
115 sts_.AppendToString(&meta);
116 builder->Add(kSavedTensorSlicesKey, meta);
117
118 // Go through all the data and add them
119 for (const auto& x : data_) {
120 builder->Add(x.first, x.second);
121 }
122
123 int64_t file_size;
124 s = builder->Finish(&file_size);
125 // If use temp file, we need to rename the file to the proper name.
126 if (use_temp_file_) {
127 if (s.ok()) {
128 s = Env::Default()->RenameFile(data_filename_, filename_);
129 if (s.ok()) {
130 VLOG(1) << "Written " << slices_ << " slices for "
131 << sts_.meta().tensor_size() << " tensors (" << file_size
132 << " bytes) to " << filename_;
133 } else {
134 LOG(ERROR) << "Failed to rename file " << data_filename_ << " to "
135 << filename_;
136 }
137 } else {
138 Env::Default()->DeleteFile(data_filename_).IgnoreError();
139 }
140 }
141 return s;
142 }
143
144 /* static */
MaxBytesPerElement(DataType dt)145 size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) {
146 size_t max_bytes_per_element =
147 TensorSliceWriter::MaxBytesPerElementOrZero(dt);
148 if (max_bytes_per_element == 0) {
149 LOG(FATAL) << "MaxBytesPerElement not implemented for dtype: " << dt;
150 }
151 return max_bytes_per_element;
152 }
153
154 /* static */
MaxBytesPerElementOrZero(DataType dt)155 size_t TensorSliceWriter::MaxBytesPerElementOrZero(DataType dt) {
156 switch (dt) {
157 case DT_FLOAT:
158 return 4;
159 case DT_DOUBLE:
160 return 8;
161 case DT_INT32:
162 return 10;
163 case DT_UINT8:
164 return 2;
165 case DT_INT16:
166 return 10;
167 case DT_INT8:
168 return 10;
169 case DT_COMPLEX64:
170 return 8;
171 case DT_INT64:
172 return 10;
173 case DT_BOOL:
174 return 1;
175 case DT_QINT8:
176 return 10;
177 case DT_QUINT8:
178 return 2;
179 case DT_QINT32:
180 return 10;
181 case DT_QINT16:
182 return 10;
183 case DT_QUINT16:
184 return 3;
185 case DT_UINT16:
186 return 3;
187 case DT_COMPLEX128:
188 return 16;
189 case DT_HALF:
190 return 3;
191 case DT_INVALID:
192 case DT_STRING:
193 case DT_BFLOAT16:
194 default:
195 return 0;
196 }
197 }
198
199 template <>
SaveData(const tstring * data,int64_t num_elements,SavedSlice * ss)200 Status TensorSliceWriter::SaveData(const tstring* data, int64_t num_elements,
201 SavedSlice* ss) {
202 size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes +
203 (num_elements * MaxBytesPerElement(DT_INT32));
204 for (int64_t i = 0; i < num_elements; ++i) {
205 size_bound += data[i].size();
206 }
207 if (size_bound > kMaxMessageBytes) {
208 return errors::InvalidArgument(
209 "Tensor slice is too large to serialize (conservative estimate: ",
210 size_bound, " bytes)");
211 }
212 Fill(data, num_elements, ss->mutable_data());
213 DCHECK_GE(ss->ByteSize(), 0);
214 DCHECK_LE(ss->ByteSize(), size_bound);
215 return OkStatus();
216 }
217
218 } // namespace checkpoint
219
220 } // namespace tensorflow
221