1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Base test class for checkpointing datasets.""" 16 17import os 18 19import numpy as np 20from tensorflow.python.checkpoint import checkpoint as tracking_util 21from tensorflow.python.checkpoint import checkpoint_management 22from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops 23from tensorflow.python.data.ops import dataset_ops 24from tensorflow.python.data.ops import options as options_lib 25from tensorflow.python.eager import context 26from tensorflow.python.framework import combinations 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import errors 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import sparse_tensor 31from tensorflow.python.ops import lookup_ops 32from tensorflow.python.ops import variables 33from tensorflow.python.ops.ragged import ragged_tensor_value 34from tensorflow.python.platform import gfile 35from tensorflow.python.platform import test 36from tensorflow.python.training import saver as saver_lib 37from tensorflow.python.util import nest 38 39 40def remove_variants(get_next_op): 41 # TODO(b/72408568): Remove this once session.run can get variant tensors. 42 """Remove variants from a nest structure, so sess.run will execute.""" 43 44 def _remove_variant(x): 45 if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant: 46 return () 47 else: 48 return x 49 50 return nest.map_structure(_remove_variant, get_next_op) 51 52 53def default_test_combinations(): 54 """Returns the default test combinations for testing checkpointing.""" 55 56 def disable_optimizations(ds_fn): 57 options = options_lib.Options() 58 options.experimental_optimization.apply_default_optimizations = False 59 60 def ds_fn_no_opt(): 61 return ds_fn().with_options(options) 62 63 return ds_fn_no_opt 64 65 def verify_unused_iterator(obj, ds_fn, num_outputs, sparse_tensors=False): 66 obj.verify_unused_iterator( 67 ds_fn=disable_optimizations(ds_fn=ds_fn), 68 num_outputs=num_outputs, 69 sparse_tensors=sparse_tensors) 70 71 verify_unused_iterator_combination = combinations.combine( 72 verify_fn=combinations.NamedObject("verify_unused_iterator", 73 verify_unused_iterator)) 74 75 def verify_fully_used_iterator(obj, ds_fn, num_outputs, sparse_tensors=False): 76 obj.verify_fully_used_iterator( 77 ds_fn=disable_optimizations(ds_fn=ds_fn), 78 num_outputs=num_outputs, 79 sparse_tensors=sparse_tensors) 80 81 verify_fully_used_iterator_combination = combinations.combine( 82 verify_fn=combinations.NamedObject("verify_fully_used_iterator", 83 verify_fully_used_iterator)) 84 85 def verify_exhausted_iterator(obj, ds_fn, num_outputs, sparse_tensors=False): 86 obj.verify_exhausted_iterator( 87 ds_fn=disable_optimizations(ds_fn=ds_fn), 88 num_outputs=num_outputs, 89 sparse_tensors=sparse_tensors) 90 91 verify_exhausted_iterator_combination = combinations.combine( 92 verify_fn=combinations.NamedObject("verify_exhausted_iterator", 93 verify_exhausted_iterator)) 94 95 def verify_multiple_breaks(obj, ds_fn, num_outputs, sparse_tensors=False): 96 obj.verify_multiple_breaks( 97 ds_fn=disable_optimizations(ds_fn=ds_fn), 98 num_outputs=num_outputs, 99 sparse_tensors=sparse_tensors) 100 101 verify_multiple_breaks_combination = combinations.combine( 102 verify_fn=combinations.NamedObject("verify_multiple_breaks", 103 verify_multiple_breaks)) 104 105 def verify_reset_restored_iterator(obj, 106 ds_fn, 107 num_outputs, 108 sparse_tensors=False): 109 obj.verify_reset_restored_iterator( 110 ds_fn=disable_optimizations(ds_fn=ds_fn), 111 num_outputs=num_outputs, 112 sparse_tensors=sparse_tensors) 113 114 verify_reset_restored_iterator_combination = combinations.combine( 115 verify_fn=combinations.NamedObject("verify_reset_restored_iterator", 116 verify_reset_restored_iterator)) 117 118 return (verify_unused_iterator_combination + 119 verify_fully_used_iterator_combination + 120 verify_exhausted_iterator_combination + 121 verify_multiple_breaks_combination + 122 verify_reset_restored_iterator_combination) 123 124 125# TODO(b/72657739): Remove sparse_tensor argument, which is to test the 126# (deprecated) saveable `SparseTensorSliceDataset`, once the API 127# `from_sparse_tensor_slices()` and related tests are deleted. 128class CheckpointTestBase(test.TestCase): 129 """Base test class for checkpointing datasets.""" 130 131 def tearDown(self): 132 self._delete_ckpt() 133 super(CheckpointTestBase, self).tearDown() 134 135 def verify_unused_iterator(self, 136 ds_fn, 137 num_outputs, 138 sparse_tensors=False, 139 verify_exhausted=True): 140 """Verifies that saving and restoring an unused iterator works. 141 142 Args: 143 ds_fn: 0-argument function that returns a Dataset. 144 num_outputs: Total number of outputs expected from this Dataset. 145 sparse_tensors: Whether dataset is built from SparseTensor(s). 146 verify_exhausted: Whether to verify that the iterator has been exhausted 147 after producing `num_outputs` elements. 148 149 Raises: 150 AssertionError if any test fails. 151 """ 152 self.verify_run_with_breaks( 153 ds_fn, [0], 154 num_outputs, 155 sparse_tensors=sparse_tensors, 156 verify_exhausted=verify_exhausted) 157 158 def verify_fully_used_iterator(self, 159 ds_fn, 160 num_outputs, 161 sparse_tensors=False): 162 """Verifies that saving and restoring a fully used iterator works. 163 164 Note that this only checks saving and restoring an iterator from which 165 `num_outputs` items have been produced but does not check for an 166 exhausted iterator, i.e., one from which an OutOfRange error has been 167 returned. 168 169 Args: 170 ds_fn: 0-argument function that returns a Dataset. 171 num_outputs: Total number of outputs expected from this Dataset. 172 sparse_tensors: Whether dataset is built from SparseTensor(s). 173 174 Raises: 175 AssertionError if test fails. 176 """ 177 self.verify_run_with_breaks( 178 ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors) 179 180 def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False): 181 """Verifies that saving and restoring an exhausted iterator works. 182 183 An exhausted iterator is one which has returned an OutOfRange error. 184 185 Args: 186 ds_fn: 0-argument function that returns a Dataset. 187 num_outputs: Total number of outputs expected from this Dataset. 188 sparse_tensors: Whether dataset is built from SparseTensor(s). 189 190 Raises: 191 AssertionError if any test fails. 192 """ 193 self.gen_outputs( 194 ds_fn, [], 195 num_outputs, 196 verify_exhausted=True, 197 sparse_tensors=sparse_tensors) 198 actual = self.gen_outputs( 199 ds_fn, [], 200 0, 201 ckpt_saved=True, 202 verify_exhausted=True, 203 sparse_tensors=sparse_tensors) 204 self.assertLen(actual, 0) 205 206 def verify_multiple_breaks(self, 207 ds_fn, 208 num_outputs, 209 num_breaks=10, 210 sparse_tensors=False, 211 verify_exhausted=True): 212 """Attempts to save/restore at multiple break points. 213 214 Args: 215 ds_fn: 0-argument function that returns a Dataset. 216 num_outputs: Total number of outputs expected from this Dataset. 217 num_breaks: The number of break points. These are uniformly spread in [0, 218 num_outputs] both inclusive. 219 sparse_tensors: Whether dataset is built from SparseTensor(s). 220 verify_exhausted: Whether to verify that the iterator has been exhausted 221 after producing `num_outputs` elements. 222 223 Raises: 224 AssertionError if any test fails. 225 """ 226 self.verify_run_with_breaks( 227 ds_fn, 228 self.gen_break_points(num_outputs, num_breaks), 229 num_outputs, 230 sparse_tensors=sparse_tensors, 231 verify_exhausted=verify_exhausted) 232 233 def verify_reset_restored_iterator(self, 234 ds_fn, 235 num_outputs, 236 break_point=None, 237 sparse_tensors=False, 238 verify_exhausted=True): 239 """Attempts to re-initialize a restored iterator. 240 241 This is useful when restoring a training checkpoint during validation. 242 243 Args: 244 ds_fn: 0-argument function that returns a Dataset. 245 num_outputs: Total number of outputs expected from this Dataset. 246 break_point: Break point. Optional. Defaults to num_outputs/2. 247 sparse_tensors: Whether dataset is built from SparseTensor(s). 248 verify_exhausted: Whether to verify that the iterator has been exhausted 249 after producing `num_outputs` elements. 250 251 Raises: 252 AssertionError if any test fails. 253 """ 254 if context.executing_eagerly(): 255 self.skipTest("Eager mode iteration do not support re-initialization.") 256 257 break_point = num_outputs // 2 if not break_point else break_point 258 259 # Collect ground truth containing all outputs. 260 expected = self.gen_outputs( 261 ds_fn, [], 262 num_outputs, 263 sparse_tensors=sparse_tensors, 264 verify_exhausted=verify_exhausted) 265 266 # Skip some items and save checkpoint. 267 self.gen_outputs( 268 ds_fn, [], 269 break_point, 270 sparse_tensors=sparse_tensors, 271 verify_exhausted=False) 272 273 actual = [] 274 # Restore from checkpoint and then run init_op. 275 with ops.Graph().as_default() as g: 276 saver = self._import_meta_graph() 277 init_op, get_next_op = self._get_iterator_ops_from_collection( 278 ds_fn, sparse_tensors=sparse_tensors) 279 get_next_op = remove_variants(get_next_op) 280 with self.session(graph=g) as sess: 281 self._initialize(init_op, sess) 282 self._restore(saver, sess) 283 self._initialize(init_op, sess) 284 for _ in range(num_outputs): 285 actual.append(sess.run(get_next_op)) 286 if verify_exhausted: 287 with self.assertRaises(errors.OutOfRangeError): 288 sess.run(get_next_op) 289 self.match(expected, actual) 290 291 def verify_error_on_save(self, 292 ds_fn, 293 num_outputs, 294 error, 295 break_point=None, 296 sparse_tensors=False): 297 """Attempts to save a non-saveable iterator. 298 299 Args: 300 ds_fn: 0-argument function that returns a Dataset. 301 num_outputs: Total number of outputs expected from this Dataset. 302 error: Declared error when trying to save iterator. 303 break_point: Break point. Optional. Defaults to num_outputs/2. 304 sparse_tensors: Whether dataset is built from SparseTensor(s). 305 306 Raises: 307 AssertionError if any test fails. 308 """ 309 break_point = num_outputs // 2 if not break_point else break_point 310 if context.executing_eagerly(): 311 iterator = iter(ds_fn()) 312 ckpt = tracking_util.Checkpoint(iterator=iterator) 313 for _ in range(break_point): 314 next(iterator) 315 with self.assertRaises(error): 316 ckpt.save(self._ckpt_path()) 317 else: 318 with ops.Graph().as_default() as g: 319 init_op, get_next_op, saver = self._build_graph( 320 ds_fn, sparse_tensors=sparse_tensors) 321 get_next_op = remove_variants(get_next_op) 322 with self.session(graph=g) as sess: 323 self._initialize(init_op, sess) 324 for _ in range(break_point): 325 sess.run(get_next_op) 326 with self.assertRaises(error): 327 self._save(sess, saver) 328 329 def verify_run_with_breaks(self, 330 ds_fn, 331 break_points, 332 num_outputs, 333 sparse_tensors=False, 334 verify_exhausted=True): 335 """Verifies that ds_fn() produces the same outputs with and without breaks. 336 337 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it 338 *without* stopping at break points. 339 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it 340 with stopping at break points. 341 342 Deep matches outputs from 1 and 2. 343 344 Args: 345 ds_fn: 0-argument function that returns a Dataset. 346 break_points: A list of integers. For each `break_point` in 347 `break_points`, we produce outputs till `break_point` number of items 348 have been produced and then checkpoint the state. The current graph and 349 session are destroyed and a new graph and session are used to produce 350 outputs till next checkpoint or till `num_outputs` elements have been 351 produced. `break_point` must be <= `num_outputs`. 352 num_outputs: Total number of outputs expected from this Dataset. 353 sparse_tensors: Whether dataset is built from SparseTensor(s). 354 verify_exhausted: Whether to verify that the iterator has been exhausted 355 after producing `num_outputs` elements. 356 357 Raises: 358 AssertionError if any test fails. 359 """ 360 expected = self.gen_outputs( 361 ds_fn, [], 362 num_outputs, 363 sparse_tensors=sparse_tensors, 364 verify_exhausted=verify_exhausted) 365 366 actual = self.gen_outputs( 367 ds_fn, 368 break_points, 369 num_outputs, 370 sparse_tensors=sparse_tensors, 371 verify_exhausted=verify_exhausted) 372 373 self.match(expected, actual) 374 375 def gen_outputs(self, 376 ds_fn, 377 break_points, 378 num_outputs, 379 ckpt_saved=False, 380 sparse_tensors=False, 381 verify_exhausted=True, 382 save_checkpoint_at_end=True): 383 """Generates elements from input dataset while stopping at break points. 384 385 Produces `num_outputs` outputs and saves the state of the iterator in the 386 Saver checkpoint. 387 388 Args: 389 ds_fn: 0-argument function that returns the dataset. 390 break_points: A list of integers. For each `break_point` in 391 `break_points`, we produce outputs till `break_point` number of items 392 have been produced and then checkpoint the state. The current graph and 393 session are destroyed and a new graph and session are used to produce 394 outputs till next checkpoint or till `num_outputs` elements have been 395 produced. `break_point` must be <= `num_outputs`. 396 num_outputs: The total number of outputs to produce from the iterator. 397 ckpt_saved: Whether a checkpoint already exists. 398 sparse_tensors: Whether dataset is built from SparseTensor(s). 399 verify_exhausted: Whether to verify that the iterator has been exhausted 400 after producing `num_outputs` elements. 401 save_checkpoint_at_end: Whether to save a checkpoint after producing all 402 outputs. If False, checkpoints are saved each break point but not at the 403 end. Note that checkpoints overwrite each other so there is always only 404 a single checkpoint available. Defaults to True. 405 406 Returns: 407 A list of `num_outputs` items. 408 """ 409 outputs = [] 410 411 if context.executing_eagerly(): 412 for i in range(len(break_points) + 1): 413 iterator = iter(ds_fn()) 414 ckpt = tracking_util.Checkpoint(iterator=iterator) 415 if ckpt_saved: 416 ckpt_path = self._latest_ckpt() 417 ckpt.restore(ckpt_path) 418 start = break_points[i - 1] if i > 0 else 0 419 end = break_points[i] if i < len(break_points) else num_outputs 420 num_iters = end - start 421 for _ in range(num_iters): 422 outputs.append(self.evaluate(next(iterator))) 423 if i == len(break_points) and verify_exhausted: 424 with self.assertRaises(StopIteration): 425 next(iterator) 426 if save_checkpoint_at_end or i < len(break_points): 427 ckpt_path = ckpt.save(self._ckpt_path()) 428 ckpt_saved = True 429 else: 430 def get_ops(): 431 if ckpt_saved: 432 saver = self._import_meta_graph() 433 init_op, get_next_op = self._get_iterator_ops_from_collection( 434 ds_fn, sparse_tensors=sparse_tensors) 435 else: 436 init_op, get_next_op, saver = self._build_graph( 437 ds_fn, sparse_tensors=sparse_tensors) 438 return init_op, get_next_op, saver 439 440 for i in range(len(break_points) + 1): 441 with ops.Graph().as_default() as g: 442 init_op, get_next_op, saver = get_ops() 443 get_next_op = remove_variants(get_next_op) 444 with self.session(graph=g) as sess: 445 if ckpt_saved: 446 self._initialize(init_op, sess) 447 self._restore(saver, sess) 448 else: 449 self._initialize(init_op, sess) 450 start = break_points[i - 1] if i > 0 else 0 451 end = break_points[i] if i < len(break_points) else num_outputs 452 num_iters = end - start 453 for _ in range(num_iters): 454 outputs.append(sess.run(get_next_op)) 455 if i == len(break_points) and verify_exhausted: 456 with self.assertRaises(errors.OutOfRangeError): 457 sess.run(get_next_op) 458 if save_checkpoint_at_end or i < len(break_points): 459 self._save(sess, saver) 460 ckpt_saved = True 461 462 return outputs 463 464 def match(self, expected, actual): 465 """Matches nested structures. 466 467 Recursively matches shape and values of `expected` and `actual`. 468 Handles scalars, numpy arrays and other python sequence containers 469 e.g. list, dict, as well as SparseTensorValue and RaggedTensorValue. 470 471 Args: 472 expected: Nested structure 1. 473 actual: Nested structure 2. 474 475 Raises: 476 AssertionError if matching fails. 477 """ 478 if isinstance(expected, np.ndarray): 479 expected = expected.tolist() 480 if isinstance(actual, np.ndarray): 481 actual = actual.tolist() 482 self.assertEqual(type(expected), type(actual)) 483 484 if nest.is_nested(expected): 485 self.assertEqual(len(expected), len(actual)) 486 if isinstance(expected, dict): 487 for key1, key2 in zip(sorted(expected), sorted(actual)): 488 self.assertEqual(key1, key2) 489 self.match(expected[key1], actual[key2]) 490 else: 491 for item1, item2 in zip(expected, actual): 492 self.match(item1, item2) 493 elif isinstance(expected, sparse_tensor.SparseTensorValue): 494 self.match((expected.indices, expected.values, expected.dense_shape), 495 (actual.indices, actual.values, actual.dense_shape)) 496 elif isinstance(expected, ragged_tensor_value.RaggedTensorValue): 497 self.match((expected.values, expected.row_splits), 498 (actual.values, actual.row_splits)) 499 else: 500 self.assertEqual(expected, actual) 501 502 def does_not_match(self, expected, actual): 503 with self.assertRaises(AssertionError): 504 self.match(expected, actual) 505 506 def gen_break_points(self, num_outputs, num_samples=10): 507 """Generates `num_samples` breaks points in [0, num_outputs].""" 508 return np.linspace(0, num_outputs, num_samples, dtype=int) 509 510 def _build_graph(self, ds_fn, sparse_tensors=False): 511 dataset = ds_fn() 512 iterator = dataset_ops.make_initializable_iterator(dataset) 513 external_state_policy = dataset.options().experimental_external_state_policy 514 saveable = contrib_iterator_ops.make_saveable_from_iterator( 515 iterator, external_state_policy=external_state_policy) 516 ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) 517 init_op = iterator.initializer 518 if sparse_tensors: 519 get_next = sparse_tensor.SparseTensor(*iterator.get_next()) 520 else: 521 get_next = iterator.get_next() 522 self._add_iterator_ops_to_collection(init_op, get_next, ds_fn, 523 sparse_tensors) 524 saver = saver_lib.Saver(allow_empty=True) 525 return init_op, get_next, saver 526 527 def _add_iterator_ops_to_collection(self, 528 init_op, 529 get_next, 530 ds_fn, 531 sparse_tensors=False): 532 ops.add_to_collection("iterator_ops", init_op) 533 # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections 534 # do not support tuples we flatten the tensors and restore the shape in 535 # `_get_iterator_ops_from_collection`. 536 if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`. 537 ops.add_to_collection("iterator_ops", get_next.indices) 538 ops.add_to_collection("iterator_ops", get_next.values) 539 ops.add_to_collection("iterator_ops", get_next.dense_shape) 540 return 541 542 get_next_list = nest.flatten(get_next) 543 for i, output_class in enumerate( 544 nest.flatten(self._get_output_classes(ds_fn))): 545 if output_class is sparse_tensor.SparseTensor: 546 ops.add_to_collection("iterator_ops", get_next_list[i].indices) 547 ops.add_to_collection("iterator_ops", get_next_list[i].values) 548 ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape) 549 else: 550 ops.add_to_collection("iterator_ops", get_next_list[i]) 551 552 def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False): 553 all_ops = ops.get_collection("iterator_ops") 554 if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`. 555 init_op, indices, values, dense_shape = all_ops 556 return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape) 557 get_next_list = [] 558 i = 1 559 for output_class in nest.flatten(self._get_output_classes(ds_fn)): 560 if output_class is sparse_tensor.SparseTensor: 561 indices, values, dense_shape = all_ops[i:i + 3] 562 i += 3 563 get_next_list.append( 564 sparse_tensor.SparseTensor(indices, values, dense_shape)) 565 else: 566 get_next_list.append(all_ops[i]) 567 i += 1 568 return all_ops[0], nest.pack_sequence_as( 569 self._get_output_types(ds_fn), get_next_list) 570 571 def _get_output_types(self, ds_fn): 572 assert not context.executing_eagerly() 573 with ops.Graph().as_default(): 574 return dataset_ops.get_legacy_output_types(ds_fn()) 575 576 def _get_output_shapes(self, ds_fn): 577 assert not context.executing_eagerly() 578 with ops.Graph().as_default(): 579 return dataset_ops.get_legacy_output_shapes(ds_fn()) 580 581 def _get_output_classes(self, ds_fn): 582 assert not context.executing_eagerly() 583 with ops.Graph().as_default(): 584 return dataset_ops.get_legacy_output_classes(ds_fn()) 585 586 def _ckpt_path(self): 587 return os.path.join(self.get_temp_dir(), "iterator") 588 589 def _latest_ckpt(self): 590 return checkpoint_management.latest_checkpoint(self.get_temp_dir()) 591 592 def _save(self, sess, saver): 593 saver.save(sess, self._ckpt_path()) 594 595 def _restore(self, saver, sess): 596 sess.run(lookup_ops.tables_initializer()) 597 saver.restore(sess, self._latest_ckpt()) 598 599 def _initialize(self, init_op, sess): 600 sess.run(variables.global_variables_initializer()) 601 sess.run(lookup_ops.tables_initializer()) 602 sess.run(init_op) 603 604 def _import_meta_graph(self): 605 meta_file_path = self._ckpt_path() + ".meta" 606 return saver_lib.import_meta_graph(meta_file_path) 607 608 def _delete_ckpt(self): 609 # Remove all checkpoint files. 610 prefix = self._ckpt_path() 611 pattern = prefix + "*" 612 files = gfile.Glob(pattern) 613 map(gfile.Remove, files) 614