# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.TFRecordWriter`."""
import os

from absl.testing import parameterized

from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.experimental.ops import writers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.lib.io import python_io
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class TFRecordWriterTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Round-trip and argument-validation tests for `writers.TFRecordWriter`.

  Each round-trip test creates an input TFRecord file holding
  `self._num_records` records, copies it through the dataset writer, and then
  reads the result back with `tf_record.tf_record_iterator`.
  """

  def setUp(self):
    super(TFRecordWriterTest, self).setUp()
    # Number of records written to the input file by `_createFile`.
    self._num_records = 8

  def writer_fn(self, filename, compression_type=""):
    """Returns the op that copies `filename` to the output file.

    Args:
      filename: Path of the TFRecord file to read.
      compression_type: Compression type passed to both the reading dataset
        and the writer.

    Returns:
      The result of `TFRecordWriter.write` for the output file.
    """
    input_dataset = readers.TFRecordDataset([filename], compression_type)
    return writers.TFRecordWriter(self._outputFilename(),
                                  compression_type).write(input_dataset)

  def _record(self, i):
    """Returns the bytes payload of the `i`-th test record."""
    return compat.as_bytes("Record %d" % (i))

  def _createFile(self, options=None):
    """Writes `self._num_records` records to the input file.

    Args:
      options: Optional `TFRecordOptions` forwarded to the Python writer.

    Returns:
      The path of the file that was written.
    """
    filename = self._inputFilename()
    # The context manager closes the writer even if a write raises.
    with python_io.TFRecordWriter(filename, options) as writer:
      for i in range(self._num_records):
        writer.write(self._record(i))
    return filename

  def _inputFilename(self):
    """Returns the path of the input file inside the test temp directory."""
    return os.path.join(self.get_temp_dir(), "tf_record.in.txt")

  def _outputFilename(self):
    """Returns the path of the output file inside the test temp directory."""
    return os.path.join(self.get_temp_dir(), "tf_record.out.txt")

  def _assertRecords(self, filename, indices, options=None):
    """Asserts `filename` holds exactly the records for `indices`, in order.

    The records are materialized before comparing so that their count is
    checked as well; comparing only inside a loop over the iterator would pass
    vacuously for an empty or truncated file.

    Args:
      filename: Path of the TFRecord file to read back.
      indices: Iterable of record indices expected in the file, in order.
      options: Optional `TFRecordOptions` used to read the file.
    """
    expected = [self._record(i) for i in indices]
    actual = list(tf_record.tf_record_iterator(filename, options=options))
    self.assertLen(actual, len(expected))
    for want, got in zip(expected, actual):
      self.assertAllEqual(want, got)

  @combinations.generate(test_base.default_test_combinations())
  def testWrite(self):
    """Copies an uncompressed file and checks every record arrives."""
    self.evaluate(self.writer_fn(self._createFile()))
    self._assertRecords(self._outputFilename(), range(self._num_records))

  @combinations.generate(test_base.default_test_combinations())
  def testWriteZLIB(self):
    """Copies a ZLIB-compressed file and checks every record arrives."""
    options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
    self.evaluate(
        self.writer_fn(self._createFile(options), compression_type="ZLIB"))
    self._assertRecords(
        self._outputFilename(), range(self._num_records), options=options)

  @combinations.generate(test_base.default_test_combinations())
  def testWriteGZIP(self):
    """Copies a GZIP-compressed file and checks every record arrives."""
    options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
    self.evaluate(
        self.writer_fn(self._createFile(options), compression_type="GZIP"))
    self._assertRecords(
        self._outputFilename(), range(self._num_records), options=options)

  @combinations.generate(test_base.default_test_combinations())
  def testFailDataset(self):
    """Passing a plain string instead of a dataset raises `TypeError`."""
    with self.assertRaises(TypeError):
      writers.TFRecordWriter(self._outputFilename(), "").write("whoops")

  @combinations.generate(test_base.default_test_combinations())
  def testFailDType(self):
    """A dataset built from the integer 10 is rejected with `TypeError`."""
    input_dataset = dataset_ops.Dataset.from_tensors(10)
    with self.assertRaises(TypeError):
      writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testFailShape(self):
    """A dataset whose element is a nested list of strings is rejected."""
    input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]])
    with self.assertRaises(TypeError):
      writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testSideEffect(self):
    """The write still happens when its result is discarded inside a defun."""

    def writer_fn():
      input_dataset = readers.TFRecordDataset(self._createFile())
      return writers.TFRecordWriter(self._outputFilename()).write(input_dataset)

    @function.defun
    def fn():
      # The write result is intentionally unused; only "hello" is returned.
      _ = writer_fn()
      return "hello"

    self.assertEqual(self.evaluate(fn()), b"hello")
    self._assertRecords(self._outputFilename(), range(self._num_records))

  @combinations.generate(test_base.default_test_combinations())
  def testShard(self):
    """Writes even- and odd-indexed records to two separate shard files."""
    filename = self._createFile()
    dataset = readers.TFRecordDataset([filename])

    def reduce_func(key, dataset):
      # The shard file name is the input file name with the key appended.
      shard_filename = string_ops.string_join(
          [filename, string_ops.as_string(key)])
      writer = writers.TFRecordWriter(shard_filename)
      # Drop the enumeration index so that only the record itself is written.
      writer.write(dataset.map(lambda _, x: x))
      return dataset_ops.Dataset.from_tensors(shard_filename)

    dataset = dataset.enumerate()
    dataset = dataset.apply(
        grouping.group_by_window(lambda i, _: i % 2, reduce_func,
                                 dtypes.int64.max))

    get_next = self.getNext(dataset)
    for i in range(2):
      shard_filename = (filename + str(i)).encode()
      self.assertEqual(self.evaluate(get_next()), shard_filename)
      # Shard `i` must hold exactly records i, i + 2, i + 4, ...
      self._assertRecords(shard_filename, range(i, self._num_records, 2))


if __name__ == "__main__":
  test.main()