1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/util/tensor_slice_writer.h"
17
18 #include <algorithm>
19 #include <array>
20 #include <memory>
21 #include <vector>
22
23 #include "tensorflow/core/framework/tensor_shape.pb.h"
24 #include "tensorflow/core/framework/versions.pb.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/path.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/test.h"
30 #include "tensorflow/core/protobuf/error_codes.pb.h"
31 #include "tensorflow/core/public/version.h"
32 #include "tensorflow/core/util/saved_tensor_slice_util.h"
33 #include "tensorflow/core/util/tensor_slice_reader.h"
34
35 namespace tensorflow {
36
37 namespace checkpoint {
38
// Test helper that inspects a written checkpoint file at the raw table level.
class TensorSliceWriteTestHelper {
 public:
  // Opens the checkpoint file `fname` and verifies every entry that the
  // SimpleWrite test is expected to have produced (metadata + data blocks).
  static void CheckEntries(const string& fname);
  // Fetches the SavedSlice stored in `table` under the key derived from
  // (`name`, `slice`) into `ss`, verifying it parses and matches the
  // requested name and slice spec.
  static void GetData(TensorSliceReader::Table* table, const string& name,
                      const TensorSlice& slice, SavedSlice* ss);
};
45
46 namespace {
47
48 // Testing that an array is what is expected
ExpectIdenticalFloatArrays(const float * expected,int size,const float * actual)49 void ExpectIdenticalFloatArrays(const float* expected, int size,
50 const float* actual) {
51 // TODO(yangke): copy some of the Dump* functions over
52 // LOG(INFO) << "Expected = " << DumpFloatArray(expected, size);
53 // LOG(INFO) << "Actual = " << DumpFloatArray(actual, size);
54 for (int i = 0; i < size; ++i) {
55 EXPECT_NEAR(expected[i], actual[i], 1e-6);
56 }
57 }
58
// Verifies element-wise equality of two integer arrays, converting each
// element of `actual` to the expected type T before comparing.
template <typename T, typename U>
void ExpectIdenticalIntArrays(const T* expected, int size, const U* actual) {
  int idx = 0;
  while (idx < size) {
    const T converted = static_cast<T>(actual[idx]);
    EXPECT_EQ(expected[idx], converted);
    ++idx;
  }
}
65
// Nifty routine to get the number of elements in a fixed-size C array.
// constexpr so it can be used in constant expressions (e.g. static_assert);
// the parameter is unnamed because only the deduced SIZE is needed.
template <typename T, unsigned SIZE>
constexpr size_t ArraySize(const T (&)[SIZE]) noexcept {
  return SIZE;
}
71
72 // A simple test on writing a few tensor slices
73 // TODO(yangke): refactor into smaller tests: will do as we add more stuff to
74 // the writer.
TEST(TensorSliceWriteTest,SimpleWrite)75 TEST(TensorSliceWriteTest, SimpleWrite) {
76 const string filename = io::JoinPath(testing::TmpDir(), "checkpoint");
77
78 TensorSliceWriter writer(filename, CreateTableTensorSliceBuilder);
79
80 // Add some int32 tensor slices
81 {
82 TensorShape shape({5, 10});
83 TensorSlice slice = TensorSlice::ParseOrDie("-:0,1");
84 const int32 data[] = {0, 1, 2, 3, 4};
85 TF_CHECK_OK(writer.Add("test", shape, slice, data));
86 }
87
88 // Two slices share the same tensor name
89 {
90 TensorShape shape({5, 10});
91 TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
92 const int32 data[] = {10, 11, 12, 13, 14};
93 TF_CHECK_OK(writer.Add("test", shape, slice, data));
94 }
95
96 // Another slice from a different float tensor -- it has a different name and
97 // should be inserted in front of the previous tensor
98 {
99 TensorShape shape({3, 2});
100 TensorSlice slice = TensorSlice::ParseOrDie("-:-");
101 const float data[] = {1.2, 1.3, 1.4, 2.1, 2.2, 2.3};
102 TF_CHECK_OK(writer.Add("AA", shape, slice, data));
103 }
104
105 // A slice with int64 data
106 {
107 TensorShape shape({5, 10});
108 TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
109 const int64_t data[] = {10, 11, 12, 13, 14};
110 TF_CHECK_OK(writer.Add("int64", shape, slice, data));
111 }
112
113 // A slice with int16 data
114 {
115 TensorShape shape({5, 10});
116 TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
117 const int16 data[] = {10, 11, 12, 13, 14};
118 TF_CHECK_OK(writer.Add("int16", shape, slice, data));
119 }
120
121 TF_CHECK_OK(writer.Finish());
122
123 // Now we examine the checkpoint file manually.
124 TensorSliceWriteTestHelper::CheckEntries(filename);
125 }
126
127 } // namespace
128
// Looks up the SavedSlice stored in `table` under the key encoded from
// (`name`, `slice`) and copies it into `ss`.  Verifies that the entry
// exists, parses, carries only data (no metadata), and that its name and
// slice spec round-trip to the requested ones.
void TensorSliceWriteTestHelper::GetData(TensorSliceReader::Table* table,
                                         const string& name,
                                         const TensorSlice& slice,
                                         SavedSlice* ss) {
  string key = EncodeTensorNameSlice(name, slice);
  string value;
  EXPECT_TRUE(table->Get(key, &value));
  SavedTensorSlices sts;
  EXPECT_TRUE(ParseProtoUnlimited(&sts, value));
  // Data entries must not carry the `meta` field; that belongs only to the
  // leading metadata block.
  EXPECT_FALSE(sts.has_meta());
  *ss = sts.data();
  EXPECT_EQ(name, ss->name());
  // Re-parse the stored slice spec and compare canonical (DebugString) forms.
  TensorSlice slice2(ss->slice());
  EXPECT_EQ(slice.DebugString(), slice2.DebugString());
}
144
// Opens the checkpoint file written by SimpleWrite and validates its raw
// contents: one leading metadata block describing all four tensors, followed
// by one data block per written slice (five in total).
void TensorSliceWriteTestHelper::CheckEntries(const string& fname) {
  TensorSliceReader::Table* tptr;
  TF_CHECK_OK(OpenTableTensorSliceReader(fname, &tptr));
  std::unique_ptr<TensorSliceReader::Table> table(tptr);
  CHECK_NOTNULL(table.get());

  // We expect a block of SavedTensorSlices
  string value;
  ASSERT_TRUE(table->Get(kSavedTensorSlicesKey, &value));
  {
    SavedTensorSlices sts;
    EXPECT_TRUE(ParseProtoUnlimited(&sts, value));
    // We expect metadata entries for the four distinct tensors written by
    // SimpleWrite ("test", "AA", "int64", "int16").
    EXPECT_TRUE(sts.has_meta());
    EXPECT_EQ(4, sts.meta().tensor_size());
    // We should have written nontrivial version information
    EXPECT_LT(0, TF_CHECKPOINT_VERSION);
    EXPECT_EQ(TF_CHECKPOINT_VERSION, sts.meta().versions().producer());
    EXPECT_EQ(TF_CHECKPOINT_VERSION_MIN_CONSUMER,
              sts.meta().versions().min_consumer());
    // We don't expect any data in the first block.
    EXPECT_FALSE(sts.has_data());
    // The tensors should be stored in the same order as they were first
    // created.
    {
      // The two slices of the "test" tensor
      const SavedSliceMeta& ssm = sts.meta().tensor(0);
      EXPECT_EQ("test", ssm.name());
      TensorShapeProto expected_shape_proto;
      protobuf::TextFormat::ParseFromString(
          "dim { size: 5 } "
          "dim { size: 10 }",
          &expected_shape_proto);
      EXPECT_EQ(ssm.shape().ShortDebugString(),
                expected_shape_proto.ShortDebugString());
      EXPECT_EQ(DT_INT32, ssm.type());
      EXPECT_EQ(2, ssm.slice_size());
      TensorSlice s0(ssm.slice(0));
      TensorSlice s1(ssm.slice(1));
      EXPECT_EQ("-:0,1", s0.DebugString());
      EXPECT_EQ("-:3,1", s1.DebugString());
    }
    {
      // The "AA" tensor
      const SavedSliceMeta& ssm = sts.meta().tensor(1);
      EXPECT_EQ("AA", ssm.name());
      TensorShapeProto expected_shape_proto;
      protobuf::TextFormat::ParseFromString(
          "dim { size: 3 } "
          "dim { size: 2 }",
          &expected_shape_proto);
      EXPECT_EQ(ssm.shape().ShortDebugString(),
                expected_shape_proto.ShortDebugString());
      EXPECT_EQ(DT_FLOAT, ssm.type());
      EXPECT_EQ(1, ssm.slice_size());
      TensorSlice s0(ssm.slice(0));
      EXPECT_EQ("-:-", s0.DebugString());
    }
    {
      // The "int64" tensor
      const SavedSliceMeta& ssm = sts.meta().tensor(2);
      EXPECT_EQ("int64", ssm.name());
      TensorShapeProto expected_shape_proto;
      protobuf::TextFormat::ParseFromString(
          "dim { size: 5 } "
          "dim { size: 10 }",
          &expected_shape_proto);
      EXPECT_EQ(ssm.shape().ShortDebugString(),
                expected_shape_proto.ShortDebugString());
      EXPECT_EQ(DT_INT64, ssm.type());
      EXPECT_EQ(1, ssm.slice_size());
      TensorSlice s0(ssm.slice(0));
      EXPECT_EQ("-:3,1", s0.DebugString());
    }
    {
      // The "int16" tensor
      const SavedSliceMeta& ssm = sts.meta().tensor(3);
      EXPECT_EQ("int16", ssm.name());
      TensorShapeProto expected_shape_proto;
      protobuf::TextFormat::ParseFromString(
          "dim { size: 5 } "
          "dim { size: 10 }",
          &expected_shape_proto);
      EXPECT_EQ(ssm.shape().ShortDebugString(),
                expected_shape_proto.ShortDebugString());
      EXPECT_EQ(DT_INT16, ssm.type());
      EXPECT_EQ(1, ssm.slice_size());
      TensorSlice s0(ssm.slice(0));
      EXPECT_EQ("-:3,1", s0.DebugString());
    }
  }

  // We expect 5 blocks of tensor data
  {
    // Block 1: we expect it to be the full slice of the "AA" tensor.
    // TensorSlice(2) denotes the all-inclusive slice over 2 dimensions,
    // matching the "-:-" spec written by SimpleWrite.
    SavedSlice ss;
    GetData(table.get(), "AA", TensorSlice(2), &ss);
    const float data[] = {1.2, 1.3, 1.4, 2.1, 2.2, 2.3};
    EXPECT_EQ(ArraySize(data), ss.data().float_val_size());
    ExpectIdenticalFloatArrays(data, ArraySize(data),
                               ss.data().float_val().data());
  }

  {
    // Block 2: we expect it to be the first slice of the "test" tensor
    SavedSlice ss;
    GetData(table.get(), "test", TensorSlice({{0, -1}, {0, 1}}), &ss);
    const int32 data[] = {0, 1, 2, 3, 4};
    EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
    ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
  }

  {
    // Block 3: we expect it to be the second slice of the "test" tensor
    SavedSlice ss;
    GetData(table.get(), "test", TensorSlice({{0, -1}, {3, 1}}), &ss);
    const int32 data[] = {10, 11, 12, 13, 14};
    EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
    ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
  }

  {
    // Block 4: we expect it to be the slice of the "int64" tensor
    SavedSlice ss;
    GetData(table.get(), "int64", TensorSlice({{0, -1}, {3, 1}}), &ss);
    const int64_t data[] = {10, 11, 12, 13, 14};
    EXPECT_EQ(ArraySize(data), ss.data().int64_val_size());
    ExpectIdenticalIntArrays(data, ArraySize(data),
                             ss.data().int64_val().data());
  }

  {
    // Block 5: we expect it to be the slice of the "int16" tensor.
    // Note: int16 data is stored in the int_val field (no dedicated field).
    SavedSlice ss;
    GetData(table.get(), "int16", TensorSlice({{0, -1}, {3, 1}}), &ss);
    const int16 data[] = {10, 11, 12, 13, 14};
    EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
    ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
  }
}
285
286 template <typename DT>
BytesPerElementHelper(DT value)287 size_t BytesPerElementHelper(DT value) {
288 SavedSlice ss;
289 std::array<DT, 1> lo_data;
290 std::fill(lo_data.begin(), lo_data.end(), value);
291 TF_EXPECT_OK(
292 TensorSliceWriter::SaveData(lo_data.data(), lo_data.size(), &ss));
293 size_t lo_byte_size = ss.ByteSizeLong();
294
295 std::array<DT, 1001> hi_data;
296 std::fill(hi_data.begin(), hi_data.end(), value);
297 TF_EXPECT_OK(
298 TensorSliceWriter::SaveData(hi_data.data(), hi_data.size(), &ss));
299 size_t hi_byte_size = ss.ByteSizeLong();
300
301 return (hi_byte_size - lo_byte_size) / (hi_data.size() - lo_data.size());
302 }
303
// Checks that MaxBytesPerElement() agrees with the measured per-element cost
// of SaveData() for every dtype below.  The probe values (-1 for signed
// types, numeric max for unsigned types) presumably exercise the largest
// encoding for each dtype -- consistent with "Max" in the name.
TEST(TensorSliceWriteTest, CheckpointSize) {
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_BOOL),
            BytesPerElementHelper<bool>(false));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_BOOL),
            BytesPerElementHelper<bool>(true));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_FLOAT),
            BytesPerElementHelper<float>(-1.0));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_DOUBLE),
            BytesPerElementHelper<double>(-1.0));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_COMPLEX64),
            BytesPerElementHelper<complex64>(-1.0));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_COMPLEX128),
            BytesPerElementHelper<complex128>(-1.0));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT32),
            BytesPerElementHelper<int32>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT64),
            BytesPerElementHelper<int64_t>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_UINT16),
            BytesPerElementHelper<uint16>(std::numeric_limits<uint16>::max()));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_UINT8),
            BytesPerElementHelper<uint8>(std::numeric_limits<uint8>::max()));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT8),
            BytesPerElementHelper<int8>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT16),
            BytesPerElementHelper<int16>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QINT8),
            BytesPerElementHelper<qint8>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QUINT8),
            BytesPerElementHelper<quint8>(std::numeric_limits<uint8>::max()));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QINT32),
            BytesPerElementHelper<qint32>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_HALF),
            BytesPerElementHelper<Eigen::half>(Eigen::half(-1.0)));
}
338
// Verifies that Add() rejects slices whose serialized form would be too
// large, returning INVALID_ARGUMENT instead of producing a corrupt
// checkpoint.
TEST(TensorSliceWriteTest, SizeErrors) {
  const string filename = io::JoinPath(testing::TmpDir(), "checkpoint");

  TensorSliceWriter writer(filename, CreateTableTensorSliceBuilder);

  // Add a 300MB int8 tensor slice, which will fail because it expands to 3GB.
  {
    TensorShape shape({300, 1000000});
    TensorSlice slice = TensorSlice::ParseOrDie("-:-");
    const std::vector<int8> data(300000000, -1);
    Status s = writer.Add("test1", shape, slice, data.data());
    EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
    EXPECT_TRUE(absl::StrContains(s.error_message(),
                                  "Tensor slice is too large to serialize"));
  }

  // Add a large string tensor slice, which will fail.
  // 256K strings of 8KB each -- roughly 2GB of payload.
  {
    TensorShape shape({256, 1024});
    TensorSlice slice = TensorSlice::ParseOrDie("-:-");
    const std::vector<tstring> data(256 * 1024, std::string(8192, 'f'));
    Status s = writer.Add("test2", shape, slice, data.data());
    EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
    EXPECT_TRUE(absl::StrContains(s.error_message(),
                                  "Tensor slice is too large to serialize"));
  }
}
366
// Verifies that SaveData() rejects a dtype it cannot serialize (uint32 here)
// with INVALID_ARGUMENT and a descriptive message, rather than writing
// malformed data.
TEST(TensorSliceWriterTest, InvalidInput) {
  SavedSlice ss;
  std::array<uint32_t, 1> data;
  std::fill(data.begin(), data.end(), 1234);
  Status s = TensorSliceWriter::SaveData(data.data(), data.size(), &ss);
  EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
  EXPECT_TRUE(absl::StrContains(
      s.error_message(),
      "Tensor slice serialization not implemented for dtype"));
}
377
378 } // namespace checkpoint
379
380 } // namespace tensorflow
381