# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for testing reader datasets."""

import os

from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import dtypes
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import parsing_ops
from tensorflow.python.util import compat


class FeaturesTestBase(test_base.DatasetTestBase):
  """Base class for testing TFRecord-based features."""

  def setUp(self):
    super(FeaturesTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7
    self._filenames = self._createFiles()

  def make_batch_feature(self,
                         filenames,
                         num_epochs,
                         batch_size,
                         label_key=None,
                         reader_num_threads=1,
                         parser_num_threads=1,
                         shuffle=False,
                         shuffle_seed=None,
                         drop_final_batch=False):
    """Builds a `make_batched_features_dataset` over the test files."""
    self.filenames = filenames
    self.num_epochs = num_epochs
    self.batch_size = batch_size

    return readers.make_batched_features_dataset(
        file_pattern=self.filenames,
        batch_size=self.batch_size,
        features={
            "file": parsing_ops.FixedLenFeature([], dtypes.int64),
            "record": parsing_ops.FixedLenFeature([], dtypes.int64),
            "keywords": parsing_ops.VarLenFeature(dtypes.string),
            "label": parsing_ops.FixedLenFeature([], dtypes.string),
        },
        label_key=label_key,
        reader=core_readers.TFRecordDataset,
        num_epochs=self.num_epochs,
        shuffle=shuffle,
        shuffle_seed=shuffle_seed,
        reader_num_threads=reader_num_threads,
        parser_num_threads=parser_num_threads,
        drop_final_batch=drop_final_batch)
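
  # NOTE: Each element of the dataset built above is a dict mapping feature
  # names to batched tensors; "keywords" comes back as a sparse tensor since
  # it is declared as a `VarLenFeature`. When `label_key` is set, elements
  # are `(features, label)` tuples instead, which is the case that
  # `_run_actual_batch` below handles via `label_key_provided=True`.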

  def _record(self, f, r, l):
    """Returns a serialized `tf.train.Example` for file `f`, record `r`."""
    example = example_pb2.Example(
        features=feature_pb2.Features(
            feature={
                "file":
                    feature_pb2.Feature(
                        int64_list=feature_pb2.Int64List(value=[f])),
                "record":
                    feature_pb2.Feature(
                        int64_list=feature_pb2.Int64List(value=[r])),
                "keywords":
                    feature_pb2.Feature(
                        bytes_list=feature_pb2.BytesList(
                            value=self._get_keywords(f, r))),
                "label":
                    feature_pb2.Feature(
                        bytes_list=feature_pb2.BytesList(
                            value=[compat.as_bytes(l)]))
            }))
    return example.SerializeToString()

  def _get_keywords(self, f, r):
    """Returns 1 or 2 keywords, depending on the parity of `f + r`."""
    num_keywords = 1 + (f + r) % 2
    keywords = []
    for index in range(num_keywords):
      keywords.append(compat.as_bytes("keyword%d" % index))
    return keywords

  def _sum_keywords(self, num_files):
    """Returns the total keyword count across the first `num_files` files."""
    sum_keywords = 0
    for i in range(num_files):
      for j in range(self._num_records):
        sum_keywords += 1 + (i + j) % 2
    return sum_keywords

  def _createFiles(self):
    """Writes `self._num_files` TFRecord files of serialized examples."""
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._record(i, j, "fake-label"))
      writer.close()
    return filenames

  def _run_actual_batch(self, outputs, label_key_provided=False):
    if label_key_provided:
      # The dataset element is a (feature dict, label) tuple.
      features, label = self.evaluate(outputs())
    else:
      features = self.evaluate(outputs())
      label = features["label"]
    file_out = features["file"]
    keywords_indices = features["keywords"].indices
    keywords_values = features["keywords"].values
    keywords_dense_shape = features["keywords"].dense_shape
    record = features["record"]
    return [
        file_out, keywords_indices, keywords_values, keywords_dense_shape,
        record, label
    ]

  def _next_actual_batch(self, label_key_provided=False):
    return self._run_actual_batch(self.outputs, label_key_provided)

  def _interleave(self, iterators, cycle_length):
    """Interleaves `iterators`, mimicking `Dataset.interleave` semantics."""
    pending_iterators = iterators
    open_iterators = []
    num_open = 0
    for i in range(cycle_length):
      if pending_iterators:
        open_iterators.append(pending_iterators.pop(0))
        num_open += 1

    while num_open:
      for i in range(min(cycle_length, len(open_iterators))):
        if open_iterators[i] is None:
          continue
        try:
          yield next(open_iterators[i])
        except StopIteration:
          if pending_iterators:
            open_iterators[i] = pending_iterators.pop(0)
          else:
            open_iterators[i] = None
            num_open -= 1

  def _next_expected_batch(self,
                           file_indices,
                           batch_size,
                           num_epochs,
                           cycle_length=1):
    """Yields the expected batches, each as a six-element list."""

    def _next_record(file_indices):
      for j in file_indices:
        for i in range(self._num_records):
          yield j, i, compat.as_bytes("fake-label")

    def _next_record_interleaved(file_indices, cycle_length):
      return self._interleave([_next_record([i]) for i in file_indices],
                              cycle_length)

    file_batch = []
    keywords_batch_indices = []
    keywords_batch_values = []
    keywords_batch_max_len = 0
    record_batch = []
    batch_index = 0
    label_batch = []
    for _ in range(num_epochs):
      if cycle_length == 1:
        next_records = _next_record(file_indices)
      else:
        next_records = _next_record_interleaved(file_indices, cycle_length)
      for record in next_records:
        f = record[0]
        r = record[1]
        label_batch.append(record[2])
        file_batch.append(f)
        record_batch.append(r)
        keywords = self._get_keywords(f, r)
        keywords_batch_values.extend(keywords)
        keywords_batch_indices.extend(
            [[batch_index, i] for i in range(len(keywords))])
        batch_index += 1
        keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
        if len(file_batch) == batch_size:
          yield [
              file_batch, keywords_batch_indices, keywords_batch_values,
              [batch_size, keywords_batch_max_len], record_batch, label_batch
          ]
          file_batch = []
          keywords_batch_indices = []
          keywords_batch_values = []
          keywords_batch_max_len = 0
          record_batch = []
          batch_index = 0
          label_batch = []
    if file_batch:
      # A final partial batch.
      yield [
          file_batch, keywords_batch_indices, keywords_batch_values,
          [len(file_batch), keywords_batch_max_len], record_batch, label_batch
      ]
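
  # The six entries yielded by `_next_expected_batch` line up with the list
  # returned by `_run_actual_batch`: file ids, keyword sparse indices,
  # keyword sparse values, the keywords dense shape (i.e. [rows in batch,
  # max keywords per row]), record ids, and labels.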

  def _verify_records(self,
                      batch_size,
                      file_index=None,
                      num_epochs=1,
                      label_key_provided=False,
                      interleave_cycle_length=1):
    """Checks that actual batches match the expected batches."""
    if file_index is not None:
      file_indices = [file_index]
    else:
      file_indices = range(self._num_files)

    for expected_batch in self._next_expected_batch(
        file_indices,
        batch_size,
        num_epochs,
        cycle_length=interleave_cycle_length):
      actual_batch = self._next_actual_batch(
          label_key_provided=label_key_provided)
      for i in range(len(expected_batch)):
        self.assertAllEqual(expected_batch[i], actual_batch[i])


class TFRecordTestBase(test_base.DatasetTestBase):
  """Base class for TFRecord-based tests."""

  def setUp(self):
    super(TFRecordTestBase, self).setUp()
    self._num_files = 2
    self._num_records = 7
    self._filenames = self._createFiles()

  def _interleave(self, iterators, cycle_length):
    """Interleaves `iterators`, mimicking `Dataset.interleave` semantics."""
    pending_iterators = iterators
    open_iterators = []
    num_open = 0
    for i in range(cycle_length):
      if pending_iterators:
        open_iterators.append(pending_iterators.pop(0))
        num_open += 1

    while num_open:
      for i in range(min(cycle_length, len(open_iterators))):
        if open_iterators[i] is None:
          continue
        try:
          yield next(open_iterators[i])
        except StopIteration:
          if pending_iterators:
            open_iterators[i] = pending_iterators.pop(0)
          else:
            open_iterators[i] = None
            num_open -= 1

  def _next_expected_batch(self, file_indices, batch_size, num_epochs,
                           cycle_length, drop_final_batch, use_parser_fn):
    """Yields the expected batches of raw (or parser-mapped) records."""

    def _next_record(file_indices):
      for j in file_indices:
        for i in range(self._num_records):
          yield j, i

    def _next_record_interleaved(file_indices, cycle_length):
      return self._interleave([_next_record([i]) for i in file_indices],
                              cycle_length)

    record_batch = []
    batch_index = 0
    for _ in range(num_epochs):
      if cycle_length == 1:
        next_records = _next_record(file_indices)
      else:
        next_records = _next_record_interleaved(file_indices, cycle_length)
      for f, r in next_records:
        record = self._record(f, r)
        if use_parser_fn:
          # Tests that use a parser function expect it to strip the first
          # byte of each record.
          record = record[1:]
        record_batch.append(record)
        batch_index += 1
        if len(record_batch) == batch_size:
          yield record_batch
          record_batch = []
          batch_index = 0
    if record_batch and not drop_final_batch:
      yield record_batch

  def _verify_records(self, outputs, batch_size, file_index, num_epochs,
                      interleave_cycle_length, drop_final_batch,
                      use_parser_fn):
    """Checks that actual batches match the expected batches."""
    if file_index is not None:
      if isinstance(file_index, list):
        file_indices = file_index
      else:
        file_indices = [file_index]
    else:
      file_indices = range(self._num_files)

    for expected_batch in self._next_expected_batch(
        file_indices, batch_size, num_epochs, interleave_cycle_length,
        drop_final_batch, use_parser_fn):
      actual_batch = self.evaluate(outputs())
      self.assertAllEqual(expected_batch, actual_batch)

  def _record(self, f, r):
    """Returns the contents of record `r` in file `f`."""
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _createFiles(self):
    """Writes `self._num_files` TFRecord files of string records."""
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = python_io.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._record(i, j))
      writer.close()
    return filenames

  def _writeFile(self, name, data):
    """Writes a TFRecord file of stringified `data` and returns its path."""
    filename = os.path.join(self.get_temp_dir(), name)
    writer = python_io.TFRecordWriter(filename)
    for d in data:
      writer.write(compat.as_bytes(str(d)))
    writer.close()
    return filename
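

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original file). A concrete
# test would subclass one of the bases above, build a dataset, point
# `self.outputs` at a get-next callable, and compare actual batches against
# the expected ones. The class and test names below are hypothetical:
#
# class MakeBatchedFeaturesTest(FeaturesTestBase):
#
#   def testReadBatchFeatures(self):
#     batch_size = 2
#     dataset = self.make_batch_feature(
#         filenames=self._filenames, num_epochs=1, batch_size=batch_size)
#     self.outputs = self.getNext(dataset)
#     self._verify_records(batch_size, num_epochs=1)
# ---------------------------------------------------------------------------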