1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.data.experimental.TFRecordWriter`."""
16import os
17
18from absl.testing import parameterized
19
20from tensorflow.python.data.experimental.ops import grouping
21from tensorflow.python.data.experimental.ops import writers
22from tensorflow.python.data.kernel_tests import test_base
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.data.ops import readers
25from tensorflow.python.eager import function
26from tensorflow.python.framework import combinations
27from tensorflow.python.framework import dtypes
28from tensorflow.python.lib.io import python_io
29from tensorflow.python.lib.io import tf_record
30from tensorflow.python.ops import string_ops
31from tensorflow.python.platform import test
32from tensorflow.python.util import compat
33
34
class TFRecordWriterTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for writing datasets to TFRecord files with `TFRecordWriter`."""

  def setUp(self):
    super(TFRecordWriterTest, self).setUp()
    # Number of records written to the input file by `_createFile`.
    self._num_records = 8

  def writer_fn(self, filename, compression_type=""):
    """Copies the records in `filename` to the output file.

    Args:
      filename: Path of the TFRecord file to read from.
      compression_type: Compression used both for reading `filename` and for
        writing the output file ("" for none, "ZLIB" or "GZIP").

    Returns:
      The value returned by `TFRecordWriter.write`, which callers pass to
      `self.evaluate`.
    """
    input_dataset = readers.TFRecordDataset([filename], compression_type)
    return writers.TFRecordWriter(self._outputFilename(),
                                  compression_type).write(input_dataset)

  def _record(self, i):
    """Returns the bytes payload of the `i`-th test record."""
    return compat.as_bytes("Record %d" % (i))

  def _expectedRecords(self, start=0, step=1):
    """Returns records `start, start + step, ...` below `self._num_records`."""
    return [self._record(i) for i in range(start, self._num_records, step)]

  def _assertRecordsInFile(self, path, expected, options=None):
    """Asserts that the TFRecord file at `path` holds exactly `expected`.

    Comparing full lists (rather than iterating and comparing element-wise)
    also catches missing or extra records, including an empty file.
    """
    actual = list(tf_record.tf_record_iterator(path, options=options))
    self.assertEqual(expected, actual)

  def _createFile(self, options=None):
    """Writes `self._num_records` records to the input file.

    Args:
      options: Optional `TFRecordOptions` (e.g. compression) for the writer.

    Returns:
      The path of the written input file.
    """
    filename = self._inputFilename()
    # The context manager closes (and flushes) the file even if a write fails.
    with python_io.TFRecordWriter(filename, options) as writer:
      for i in range(self._num_records):
        writer.write(self._record(i))
    return filename

  def _inputFilename(self):
    return os.path.join(self.get_temp_dir(), "tf_record.in.txt")

  def _outputFilename(self):
    return os.path.join(self.get_temp_dir(), "tf_record.out.txt")

  @combinations.generate(test_base.default_test_combinations())
  def testWrite(self):
    self.evaluate(self.writer_fn(self._createFile()))
    self._assertRecordsInFile(self._outputFilename(), self._expectedRecords())

  @combinations.generate(test_base.default_test_combinations())
  def testWriteZLIB(self):
    options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
    self.evaluate(
        self.writer_fn(self._createFile(options), compression_type="ZLIB"))
    self._assertRecordsInFile(
        self._outputFilename(), self._expectedRecords(), options=options)

  @combinations.generate(test_base.default_test_combinations())
  def testWriteGZIP(self):
    options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
    self.evaluate(
        self.writer_fn(self._createFile(options), compression_type="GZIP"))
    self._assertRecordsInFile(
        self._outputFilename(), self._expectedRecords(), options=options)

  @combinations.generate(test_base.default_test_combinations())
  def testFailDataset(self):
    with self.assertRaises(TypeError):
      writers.TFRecordWriter(self._outputFilename(), "").write("whoops")

  @combinations.generate(test_base.default_test_combinations())
  def testFailDType(self):
    input_dataset = dataset_ops.Dataset.from_tensors(10)
    with self.assertRaises(TypeError):
      writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testFailShape(self):
    input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]])
    with self.assertRaises(TypeError):
      writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testSideEffect(self):
    """Checks that a write whose result is discarded in a defun still runs."""

    def writer_fn():
      input_dataset = readers.TFRecordDataset(self._createFile())
      return writers.TFRecordWriter(self._outputFilename()).write(input_dataset)

    @function.defun
    def fn():
      _ = writer_fn()
      return "hello"

    self.assertEqual(self.evaluate(fn()), b"hello")
    self._assertRecordsInFile(self._outputFilename(), self._expectedRecords())

  @combinations.generate(test_base.default_test_combinations())
  def testShard(self):
    """Writes even- and odd-indexed records to two separate shard files."""
    filename = self._createFile()
    dataset = readers.TFRecordDataset([filename])

    def reduce_func(key, dataset):
      # Each window writes its records to `<input filename><key>`.
      shard_filename = string_ops.string_join(
          [filename, string_ops.as_string(key)])
      writer = writers.TFRecordWriter(shard_filename)
      writer.write(dataset.map(lambda _, x: x))
      return dataset_ops.Dataset.from_tensors(shard_filename)

    dataset = dataset.enumerate()
    dataset = dataset.apply(
        grouping.group_by_window(lambda i, _: i % 2, reduce_func,
                                 dtypes.int64.max))

    get_next = self.getNext(dataset)
    for i in range(2):
      shard_filename = (filename + str(i)).encode()
      self.assertEqual(self.evaluate(get_next()), shard_filename)
      # Shard `i` must hold exactly the records whose index is `i` mod 2.
      self._assertRecordsInFile(shard_filename,
                                self._expectedRecords(start=i, step=2))
142
143
if __name__ == "__main__":
  # Run the tests in this module through TensorFlow's test entry point.
  run_tests = test.main
  run_tests()
146