# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Base class for testing reader datasets."""
16
17import os
18
19from tensorflow.core.example import example_pb2
20from tensorflow.core.example import feature_pb2
21from tensorflow.python.data.experimental.ops import readers
22from tensorflow.python.data.kernel_tests import test_base
23from tensorflow.python.data.ops import readers as core_readers
24from tensorflow.python.framework import dtypes
25from tensorflow.python.lib.io import python_io
26from tensorflow.python.ops import parsing_ops
27from tensorflow.python.util import compat
28
29
30class FeaturesTestBase(test_base.DatasetTestBase):
31  """Base class for testing TFRecord-based features."""
32
33  def setUp(self):
34    super(FeaturesTestBase, self).setUp()
35    self._num_files = 2
36    self._num_records = 7
37    self._filenames = self._createFiles()

  def make_batch_feature(self,
                         filenames,
                         num_epochs,
                         batch_size,
                         label_key=None,
                         reader_num_threads=1,
                         parser_num_threads=1,
                         shuffle=False,
                         shuffle_seed=None,
                         drop_final_batch=False):
    """Builds a `make_batched_features_dataset` dataset over `filenames`."""
    self.filenames = filenames
    self.num_epochs = num_epochs
    self.batch_size = batch_size

    return readers.make_batched_features_dataset(
        file_pattern=self.filenames,
        batch_size=self.batch_size,
        features={
            "file": parsing_ops.FixedLenFeature([], dtypes.int64),
            "record": parsing_ops.FixedLenFeature([], dtypes.int64),
            "keywords": parsing_ops.VarLenFeature(dtypes.string),
            "label": parsing_ops.FixedLenFeature([], dtypes.string),
        },
        label_key=label_key,
        reader=core_readers.TFRecordDataset,
        num_epochs=self.num_epochs,
        shuffle=shuffle,
        shuffle_seed=shuffle_seed,
        reader_num_threads=reader_num_threads,
        parser_num_threads=parser_num_threads,
        drop_final_batch=drop_final_batch)
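
  # A minimal usage sketch (hypothetical test body, not part of this base
  # class): subclasses typically build the dataset, capture a get-next
  # callable via `DatasetTestBase.getNext`, and compare against the
  # expected batches, e.g.
  #
  #   dataset = self.make_batch_feature(
  #       filenames=self._filenames, num_epochs=1, batch_size=4)
  #   self.outputs = self.getNext(dataset)
  #   self._verify_records(batch_size=4, num_epochs=1)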

  def _record(self, f, r, label):
    """Returns a serialized `Example` for record `r` of file `f`."""
    example = example_pb2.Example(
        features=feature_pb2.Features(
            feature={
                "file":
                    feature_pb2.Feature(
                        int64_list=feature_pb2.Int64List(value=[f])),
                "record":
                    feature_pb2.Feature(
                        int64_list=feature_pb2.Int64List(value=[r])),
                "keywords":
                    feature_pb2.Feature(
                        bytes_list=feature_pb2.BytesList(
                            value=self._get_keywords(f, r))),
                "label":
                    feature_pb2.Feature(
                        bytes_list=feature_pb2.BytesList(
                            value=[compat.as_bytes(label)]))
            }))
    return example.SerializeToString()

  def _get_keywords(self, f, r):
    """Returns 1 or 2 keyword strings, alternating with `(f + r) % 2`."""
    num_keywords = 1 + (f + r) % 2
    keywords = []
    for index in range(num_keywords):
      keywords.append(compat.as_bytes("keyword%d" % index))
    return keywords
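
  # Worked example (derived from the formula above): _get_keywords(0, 0)
  # returns [b"keyword0"], while _get_keywords(0, 1) returns
  # [b"keyword0", b"keyword1"]; the count alternates between 1 and 2 as
  # (f + r) flips parity.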

  def _sum_keywords(self, num_files):
    """Counts the keywords across `num_files` files of `_num_records` each."""
    sum_keywords = 0
    for i in range(num_files):
      for j in range(self._num_records):
        sum_keywords += 1 + (i + j) % 2
    return sum_keywords
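
  # For instance, with the defaults above (2 files x 7 records),
  # _sum_keywords(2) is 10 + 11 = 21: file 0 contributes 1+2+1+2+1+2+1
  # keywords and file 1 contributes 2+1+2+1+2+1+2.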

  def _createFiles(self):
    """Writes `_num_files` TFRecord files of `_num_records` examples each."""
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._record(i, j, "fake-label"))
      writer.close()
    return filenames

  def _run_actual_batch(self, outputs, label_key_provided=False):
    if label_key_provided:
      # The outputs are a tuple of (feature dict, label).
      features, label = self.evaluate(outputs())
    else:
      features = self.evaluate(outputs())
      label = features["label"]
    file_out = features["file"]
    keywords_indices = features["keywords"].indices
    keywords_values = features["keywords"].values
    keywords_dense_shape = features["keywords"].dense_shape
    record = features["record"]
    return [
        file_out, keywords_indices, keywords_values, keywords_dense_shape,
        record, label
    ]

  def _next_actual_batch(self, label_key_provided=False):
    return self._run_actual_batch(self.outputs, label_key_provided)

  def _interleave(self, iterators, cycle_length):
    """Reference implementation of cycle-based interleaving over iterators."""
    pending_iterators = iterators
    open_iterators = []
    num_open = 0
    for _ in range(cycle_length):
      if pending_iterators:
        open_iterators.append(pending_iterators.pop(0))
        num_open += 1

    while num_open:
      for i in range(min(cycle_length, len(open_iterators))):
        if open_iterators[i] is None:
          continue
        try:
          yield next(open_iterators[i])
        except StopIteration:
          if pending_iterators:
            open_iterators[i] = pending_iterators.pop(0)
          else:
            open_iterators[i] = None
            num_open -= 1
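
  # Worked example (hand-derived): interleaving
  # [iter([0, 1]), iter([2, 3]), iter([4])] with cycle_length=2 yields
  # 0, 2, 1, 3, 4: a round-robin over at most cycle_length open iterators,
  # with exhausted slots refilled from the pending list.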

  def _next_expected_batch(self,
                           file_indices,
                           batch_size,
                           num_epochs,
                           cycle_length=1):
    """Yields the batches that the features dataset should produce."""

    def _next_record(file_indices):
      for j in file_indices:
        for i in range(self._num_records):
          yield j, i, compat.as_bytes("fake-label")

    def _next_record_interleaved(file_indices, cycle_length):
      return self._interleave([_next_record([i]) for i in file_indices],
                              cycle_length)

    file_batch = []
    keywords_batch_indices = []
    keywords_batch_values = []
    keywords_batch_max_len = 0
    record_batch = []
    batch_index = 0
    label_batch = []
    for _ in range(num_epochs):
      if cycle_length == 1:
        next_records = _next_record(file_indices)
      else:
        next_records = _next_record_interleaved(file_indices, cycle_length)
      for record in next_records:
        f, r, label = record
        label_batch.append(label)
        file_batch.append(f)
        record_batch.append(r)
        keywords = self._get_keywords(f, r)
        keywords_batch_values.extend(keywords)
        keywords_batch_indices.extend(
            [[batch_index, i] for i in range(len(keywords))])
        batch_index += 1
        keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
        if len(file_batch) == batch_size:
          yield [
              file_batch, keywords_batch_indices, keywords_batch_values,
              [batch_size, keywords_batch_max_len], record_batch, label_batch
          ]
          file_batch = []
          keywords_batch_indices = []
          keywords_batch_values = []
          keywords_batch_max_len = 0
          record_batch = []
          batch_index = 0
          label_batch = []
    if file_batch:
      yield [
          file_batch, keywords_batch_indices, keywords_batch_values,
          [len(file_batch), keywords_batch_max_len], record_batch, label_batch
      ]
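
  # Each yielded batch mirrors _run_actual_batch's layout: [file, keyword
  # sparse indices, keyword sparse values, keyword dense_shape, record,
  # label]. As a worked example (hand-derived from the helpers above), for
  # file_indices=[0] and batch_size=2 the first batch is
  # [[0, 0], [[0, 0], [1, 0], [1, 1]],
  #  [b"keyword0", b"keyword0", b"keyword1"], [2, 2], [0, 1],
  #  [b"fake-label", b"fake-label"]].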

  def _verify_records(self,
                      batch_size,
                      file_index=None,
                      num_epochs=1,
                      label_key_provided=False,
                      interleave_cycle_length=1):
    """Checks that the actual batches match the expected batches."""
    if file_index is not None:
      file_indices = [file_index]
    else:
      file_indices = range(self._num_files)

    for expected_batch in self._next_expected_batch(
        file_indices,
        batch_size,
        num_epochs,
        cycle_length=interleave_cycle_length):
      actual_batch = self._next_actual_batch(
          label_key_provided=label_key_provided)
      for i in range(len(expected_batch)):
        self.assertAllEqual(expected_batch[i], actual_batch[i])


class TFRecordTestBase(test_base.DatasetTestBase):
  """Base class for TFRecord-based tests."""

  def setUp(self):
    super(TFRecordTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7
    self._filenames = self._createFiles()

  def _interleave(self, iterators, cycle_length):
    """Reference implementation of cycle-based interleaving over iterators."""
    pending_iterators = iterators
    open_iterators = []
    num_open = 0
    for _ in range(cycle_length):
      if pending_iterators:
        open_iterators.append(pending_iterators.pop(0))
        num_open += 1

    while num_open:
      for i in range(min(cycle_length, len(open_iterators))):
        if open_iterators[i] is None:
          continue
        try:
          yield next(open_iterators[i])
        except StopIteration:
          if pending_iterators:
            open_iterators[i] = pending_iterators.pop(0)
          else:
            open_iterators[i] = None
            num_open -= 1

  def _next_expected_batch(self, file_indices, batch_size, num_epochs,
                           cycle_length, drop_final_batch, use_parser_fn):
    """Yields the batches of raw records the dataset should produce."""

    def _next_record(file_indices):
      for j in file_indices:
        for i in range(self._num_records):
          yield j, i

    def _next_record_interleaved(file_indices, cycle_length):
      return self._interleave([_next_record([i]) for i in file_indices],
                              cycle_length)

    record_batch = []
    for _ in range(num_epochs):
      if cycle_length == 1:
        next_records = _next_record(file_indices)
      else:
        next_records = _next_record_interleaved(file_indices, cycle_length)
      for f, r in next_records:
        record = self._record(f, r)
        if use_parser_fn:
          # Mirrors a parser fn that strips the first byte of each record.
          record = record[1:]
        record_batch.append(record)
        if len(record_batch) == batch_size:
          yield record_batch
          record_batch = []
    if record_batch and not drop_final_batch:
      yield record_batch
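
  # For example (hand-derived): with file_indices=[0], batch_size=3,
  # num_epochs=1, cycle_length=1, and use_parser_fn=False, this yields
  # [b"Record 0 of file 0", b"Record 1 of file 0", b"Record 2 of file 0"],
  # then records 3-5, and finally the short batch [b"Record 6 of file 0"]
  # unless drop_final_batch is set.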

  def _verify_records(self, outputs, batch_size, file_index, num_epochs,
                      interleave_cycle_length, drop_final_batch, use_parser_fn):
    """Checks that the actual batches match the expected batches."""
    if file_index is not None:
      if isinstance(file_index, list):
        file_indices = file_index
      else:
        file_indices = [file_index]
    else:
      file_indices = range(self._num_files)

    for expected_batch in self._next_expected_batch(
        file_indices, batch_size, num_epochs, interleave_cycle_length,
        drop_final_batch, use_parser_fn):
      actual_batch = self.evaluate(outputs())
      self.assertAllEqual(expected_batch, actual_batch)

  def _record(self, f, r):
    """Returns the contents of record `r` of file `f`."""
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _createFiles(self):
    """Writes `_num_files` TFRecord files of `_num_records` records each."""
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._record(i, j))
      writer.close()
    return filenames

  def _writeFile(self, name, data):
    """Writes `str(d)` for each `d` in `data` as records of a new file."""
    filename = os.path.join(self.get_temp_dir(), name)
    writer = python_io.TFRecordWriter(filename)
    for d in data:
      writer.write(compat.as_bytes(str(d)))
    writer.close()
    return filename
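
  # Usage sketch (hypothetical file name and values):
  # self._writeFile("text", [1, 2, 3]) writes the records b"1", b"2",
  # b"3" to <temp_dir>/text and returns that path.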