xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/tensor_slice_writer_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/util/tensor_slice_writer.h"
17 
18 #include <algorithm>
19 #include <array>
20 #include <memory>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/tensor_shape.pb.h"
24 #include "tensorflow/core/framework/versions.pb.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/path.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/test.h"
30 #include "tensorflow/core/protobuf/error_codes.pb.h"
31 #include "tensorflow/core/public/version.h"
32 #include "tensorflow/core/util/saved_tensor_slice_util.h"
33 #include "tensorflow/core/util/tensor_slice_reader.h"
34 
35 namespace tensorflow {
36 
37 namespace checkpoint {
38 
// Test-only helper that inspects the table file produced by TensorSliceWriter.
// Its methods are defined below, outside the anonymous namespace, so the
// tests inside that namespace can call them.
class TensorSliceWriteTestHelper {
 public:
  // Opens the checkpoint table at `fname` and verifies the metadata block
  // plus the five data blocks written by the SimpleWrite test above.
  static void CheckEntries(const string& fname);
  // Looks up the entry for (name, slice) in `table`, parses it as a
  // SavedTensorSlices proto (expected to carry data only, no meta), and
  // copies the contained SavedSlice into `*ss`.
  static void GetData(TensorSliceReader::Table* table, const string& name,
                      const TensorSlice& slice, SavedSlice* ss);
};
45 
46 namespace {
47 
48 // Testing that an array is what is expected
ExpectIdenticalFloatArrays(const float * expected,int size,const float * actual)49 void ExpectIdenticalFloatArrays(const float* expected, int size,
50                                 const float* actual) {
51   // TODO(yangke): copy some of the Dump* functions over
52   //  LOG(INFO) << "Expected = " << DumpFloatArray(expected, size);
53   //  LOG(INFO) << "Actual   = " << DumpFloatArray(actual, size);
54   for (int i = 0; i < size; ++i) {
55     EXPECT_NEAR(expected[i], actual[i], 1e-6);
56   }
57 }
58 
// Verifies element-wise equality after converting each `actual` element to
// the expected integer type T.
template <typename T, typename U>
void ExpectIdenticalIntArrays(const T* expected, int size, const U* actual) {
  for (int idx = 0; idx != size; ++idx) {
    const T converted = static_cast<T>(actual[idx]);
    EXPECT_EQ(expected[idx], converted);
  }
}
65 
// Nifty routine to get the size of a C-style array.
// Made constexpr/noexcept so it can also be used in constant expressions;
// equivalent to std::size, kept for readability at existing call sites.
template <typename T, unsigned SIZE>
inline constexpr size_t ArraySize(const T (&v)[SIZE]) noexcept {
  return SIZE;
}
71 
72 // A simple test on writing a few tensor slices
73 // TODO(yangke): refactor into smaller tests: will do as we add more stuff to
74 // the writer.
TEST(TensorSliceWriteTest,SimpleWrite)75 TEST(TensorSliceWriteTest, SimpleWrite) {
76   const string filename = io::JoinPath(testing::TmpDir(), "checkpoint");
77 
78   TensorSliceWriter writer(filename, CreateTableTensorSliceBuilder);
79 
80   // Add some int32 tensor slices
81   {
82     TensorShape shape({5, 10});
83     TensorSlice slice = TensorSlice::ParseOrDie("-:0,1");
84     const int32 data[] = {0, 1, 2, 3, 4};
85     TF_CHECK_OK(writer.Add("test", shape, slice, data));
86   }
87 
88   // Two slices share the same tensor name
89   {
90     TensorShape shape({5, 10});
91     TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
92     const int32 data[] = {10, 11, 12, 13, 14};
93     TF_CHECK_OK(writer.Add("test", shape, slice, data));
94   }
95 
96   // Another slice from a different float tensor -- it has a different name and
97   // should be inserted in front of the previous tensor
98   {
99     TensorShape shape({3, 2});
100     TensorSlice slice = TensorSlice::ParseOrDie("-:-");
101     const float data[] = {1.2, 1.3, 1.4, 2.1, 2.2, 2.3};
102     TF_CHECK_OK(writer.Add("AA", shape, slice, data));
103   }
104 
105   // A slice with int64 data
106   {
107     TensorShape shape({5, 10});
108     TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
109     const int64_t data[] = {10, 11, 12, 13, 14};
110     TF_CHECK_OK(writer.Add("int64", shape, slice, data));
111   }
112 
113   // A slice with int16 data
114   {
115     TensorShape shape({5, 10});
116     TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
117     const int16 data[] = {10, 11, 12, 13, 14};
118     TF_CHECK_OK(writer.Add("int16", shape, slice, data));
119   }
120 
121   TF_CHECK_OK(writer.Finish());
122 
123   // Now we examine the checkpoint file manually.
124   TensorSliceWriteTestHelper::CheckEntries(filename);
125 }
126 
127 }  // namespace
128 
GetData(TensorSliceReader::Table * table,const string & name,const TensorSlice & slice,SavedSlice * ss)129 void TensorSliceWriteTestHelper::GetData(TensorSliceReader::Table* table,
130                                          const string& name,
131                                          const TensorSlice& slice,
132                                          SavedSlice* ss) {
133   string key = EncodeTensorNameSlice(name, slice);
134   string value;
135   EXPECT_TRUE(table->Get(key, &value));
136   SavedTensorSlices sts;
137   EXPECT_TRUE(ParseProtoUnlimited(&sts, value));
138   EXPECT_FALSE(sts.has_meta());
139   *ss = sts.data();
140   EXPECT_EQ(name, ss->name());
141   TensorSlice slice2(ss->slice());
142   EXPECT_EQ(slice.DebugString(), slice2.DebugString());
143 }
144 
CheckEntries(const string & fname)145 void TensorSliceWriteTestHelper::CheckEntries(const string& fname) {
146   TensorSliceReader::Table* tptr;
147   TF_CHECK_OK(OpenTableTensorSliceReader(fname, &tptr));
148   std::unique_ptr<TensorSliceReader::Table> table(tptr);
149   CHECK_NOTNULL(table.get());
150 
151   // We expect a block of SavedTensorSlices
152   string value;
153   ASSERT_TRUE(table->Get(kSavedTensorSlicesKey, &value));
154   {
155     SavedTensorSlices sts;
156     EXPECT_TRUE(ParseProtoUnlimited(&sts, value));
157     // We also expect two entries for the tensors
158     EXPECT_TRUE(sts.has_meta());
159     EXPECT_EQ(4, sts.meta().tensor_size());
160     // We should have written nontrivial version information
161     EXPECT_LT(0, TF_CHECKPOINT_VERSION);
162     EXPECT_EQ(TF_CHECKPOINT_VERSION, sts.meta().versions().producer());
163     EXPECT_EQ(TF_CHECKPOINT_VERSION_MIN_CONSUMER,
164               sts.meta().versions().min_consumer());
165     // We don't expect any data in the first block.
166     EXPECT_FALSE(sts.has_data());
167     // The two tensors should be stored in the same order as they are first
168     // created.
169     {
170       // The two slices of the "test" tensor
171       const SavedSliceMeta& ssm = sts.meta().tensor(0);
172       EXPECT_EQ("test", ssm.name());
173       TensorShapeProto expected_shape_proto;
174       protobuf::TextFormat::ParseFromString(
175           "dim { size: 5 } "
176           "dim { size: 10 }",
177           &expected_shape_proto);
178       EXPECT_EQ(ssm.shape().ShortDebugString(),
179                 expected_shape_proto.ShortDebugString());
180       EXPECT_EQ(DT_INT32, ssm.type());
181       EXPECT_EQ(2, ssm.slice_size());
182       TensorSlice s0(ssm.slice(0));
183       TensorSlice s1(ssm.slice(1));
184       EXPECT_EQ("-:0,1", s0.DebugString());
185       EXPECT_EQ("-:3,1", s1.DebugString());
186     }
187     {
188       // The "AA" tensor
189       const SavedSliceMeta& ssm = sts.meta().tensor(1);
190       EXPECT_EQ("AA", ssm.name());
191       TensorShapeProto expected_shape_proto;
192       protobuf::TextFormat::ParseFromString(
193           "dim { size: 3 } "
194           "dim { size: 2 }",
195           &expected_shape_proto);
196       EXPECT_EQ(ssm.shape().ShortDebugString(),
197                 expected_shape_proto.ShortDebugString());
198       EXPECT_EQ(DT_FLOAT, ssm.type());
199       EXPECT_EQ(1, ssm.slice_size());
200       TensorSlice s0(ssm.slice(0));
201       EXPECT_EQ("-:-", s0.DebugString());
202     }
203     {
204       // The "int64" tensor
205       const SavedSliceMeta& ssm = sts.meta().tensor(2);
206       EXPECT_EQ("int64", ssm.name());
207       TensorShapeProto expected_shape_proto;
208       protobuf::TextFormat::ParseFromString(
209           "dim { size: 5 } "
210           "dim { size: 10 }",
211           &expected_shape_proto);
212       EXPECT_EQ(ssm.shape().ShortDebugString(),
213                 expected_shape_proto.ShortDebugString());
214       EXPECT_EQ(DT_INT64, ssm.type());
215       EXPECT_EQ(1, ssm.slice_size());
216       TensorSlice s0(ssm.slice(0));
217       EXPECT_EQ("-:3,1", s0.DebugString());
218     }
219     {
220       // The "int16" tensor
221       const SavedSliceMeta& ssm = sts.meta().tensor(3);
222       EXPECT_EQ("int16", ssm.name());
223       TensorShapeProto expected_shape_proto;
224       protobuf::TextFormat::ParseFromString(
225           "dim { size: 5 } "
226           "dim { size: 10 }",
227           &expected_shape_proto);
228       EXPECT_EQ(ssm.shape().ShortDebugString(),
229                 expected_shape_proto.ShortDebugString());
230       EXPECT_EQ(DT_INT16, ssm.type());
231       EXPECT_EQ(1, ssm.slice_size());
232       TensorSlice s0(ssm.slice(0));
233       EXPECT_EQ("-:3,1", s0.DebugString());
234     }
235   }
236 
237   // We expect 5 blocks of tensor data
238   {
239     // Block 1: we expect it to be the full slice of the "AA" tensor
240     SavedSlice ss;
241     GetData(table.get(), "AA", TensorSlice(2), &ss);
242     const float data[] = {1.2, 1.3, 1.4, 2.1, 2.2, 2.3};
243     EXPECT_EQ(ArraySize(data), ss.data().float_val_size());
244     ExpectIdenticalFloatArrays(data, ArraySize(data),
245                                ss.data().float_val().data());
246   }
247 
248   {
249     // Block 2: we expect it to be the first slice of the "test" tensor
250     SavedSlice ss;
251     GetData(table.get(), "test", TensorSlice({{0, -1}, {0, 1}}), &ss);
252     const int32 data[] = {0, 1, 2, 3, 4};
253     EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
254     ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
255   }
256 
257   {
258     // Block 3: we expect it to be the second slice of the "test" tensor
259     SavedSlice ss;
260     GetData(table.get(), "test", TensorSlice({{0, -1}, {3, 1}}), &ss);
261     const int32 data[] = {10, 11, 12, 13, 14};
262     EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
263     ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
264   }
265 
266   {
267     // Block 4: we expect it to be the slice of the "int64" tensor
268     SavedSlice ss;
269     GetData(table.get(), "int64", TensorSlice({{0, -1}, {3, 1}}), &ss);
270     const int64_t data[] = {10, 11, 12, 13, 14};
271     EXPECT_EQ(ArraySize(data), ss.data().int64_val_size());
272     ExpectIdenticalIntArrays(data, ArraySize(data),
273                              ss.data().int64_val().data());
274   }
275 
276   {
277     // Block 5: we expect it to be the slice of the "int16" tensor
278     SavedSlice ss;
279     GetData(table.get(), "int16", TensorSlice({{0, -1}, {3, 1}}), &ss);
280     const int16 data[] = {10, 11, 12, 13, 14};
281     EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
282     ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
283   }
284 }
285 
286 template <typename DT>
BytesPerElementHelper(DT value)287 size_t BytesPerElementHelper(DT value) {
288   SavedSlice ss;
289   std::array<DT, 1> lo_data;
290   std::fill(lo_data.begin(), lo_data.end(), value);
291   TF_EXPECT_OK(
292       TensorSliceWriter::SaveData(lo_data.data(), lo_data.size(), &ss));
293   size_t lo_byte_size = ss.ByteSizeLong();
294 
295   std::array<DT, 1001> hi_data;
296   std::fill(hi_data.begin(), hi_data.end(), value);
297   TF_EXPECT_OK(
298       TensorSliceWriter::SaveData(hi_data.data(), hi_data.size(), &ss));
299   size_t hi_byte_size = ss.ByteSizeLong();
300 
301   return (hi_byte_size - lo_byte_size) / (hi_data.size() - lo_data.size());
302 }
303 
// For each supported dtype, checks that the measured per-element serialized
// size (via BytesPerElementHelper) equals the writer's declared
// MaxBytesPerElement() bound.  The sample values -- negative for signed and
// floating types, numeric max for unsigned types -- are presumably chosen to
// produce the largest encoding for each type; confirm against the proto
// varint encoding rules if this ever needs changing.
TEST(TensorSliceWriteTest, CheckpointSize) {
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_BOOL),
            BytesPerElementHelper<bool>(false));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_BOOL),
            BytesPerElementHelper<bool>(true));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_FLOAT),
            BytesPerElementHelper<float>(-1.0));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_DOUBLE),
            BytesPerElementHelper<double>(-1.0));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_COMPLEX64),
            BytesPerElementHelper<complex64>(-1.0));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_COMPLEX128),
            BytesPerElementHelper<complex128>(-1.0));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT32),
            BytesPerElementHelper<int32>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT64),
            BytesPerElementHelper<int64_t>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_UINT16),
            BytesPerElementHelper<uint16>(std::numeric_limits<uint16>::max()));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_UINT8),
            BytesPerElementHelper<uint8>(std::numeric_limits<uint8>::max()));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT8),
            BytesPerElementHelper<int8>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT16),
            BytesPerElementHelper<int16>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QINT8),
            BytesPerElementHelper<qint8>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QUINT8),
            BytesPerElementHelper<quint8>(std::numeric_limits<uint8>::max()));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QINT32),
            BytesPerElementHelper<qint32>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_HALF),
            BytesPerElementHelper<Eigen::half>(Eigen::half(-1.0)));
}
338 
// Verifies that Add() rejects slices whose serialized form would be too
// large, for both numeric and string data.
TEST(TensorSliceWriteTest, SizeErrors) {
  const string filename = io::JoinPath(testing::TmpDir(), "checkpoint");

  TensorSliceWriter writer(filename, CreateTableTensorSliceBuilder);

  // A 300MB int8 tensor slice fails because it expands to 3GB when
  // serialized.
  {
    const TensorShape shape({300, 1000000});
    const TensorSlice slice = TensorSlice::ParseOrDie("-:-");
    const std::vector<int8> data(300000000, -1);
    const Status s = writer.Add("test1", shape, slice, data.data());
    EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
    EXPECT_TRUE(absl::StrContains(s.error_message(),
                                  "Tensor slice is too large to serialize"));
  }

  // A large string tensor slice fails for the same reason.
  {
    const TensorShape shape({256, 1024});
    const TensorSlice slice = TensorSlice::ParseOrDie("-:-");
    const std::vector<tstring> data(256 * 1024, std::string(8192, 'f'));
    const Status s = writer.Add("test2", shape, slice, data.data());
    EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
    EXPECT_TRUE(absl::StrContains(s.error_message(),
                                  "Tensor slice is too large to serialize"));
  }
}
366 
// SaveData() must reject a dtype it has no serialization for (uint32_t
// here).
TEST(TensorSliceWriterTest, InvalidInput) {
  SavedSlice ss;
  std::array<uint32_t, 1> data;
  data.fill(1234);
  const Status s = TensorSliceWriter::SaveData(data.data(), data.size(), &ss);
  EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
  EXPECT_TRUE(absl::StrContains(
      s.error_message(),
      "Tensor slice serialization not implemented for dtype"));
}
377 
378 }  // namespace checkpoint
379 
380 }  // namespace tensorflow
381