# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base test class for checkpointing datasets."""

import os

import numpy as np
from tensorflow.python.checkpoint import checkpoint as tracking_util
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import nest


def remove_variants(get_next_op):
  """Removes variants from a nest structure, so sess.run will execute."""
  # TODO(b/72408568): Remove this once session.run can get variant tensors.
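  # A minimal sketch of the intended effect (the structure below is
  # illustrative): given
  #   {"data": <float32 Tensor>, "state": <variant Tensor>}
  # this returns
  #   {"data": <float32 Tensor>, "state": ()}
  # which sess.run can evaluate without trying to fetch the variant tensor.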

  def _remove_variant(x):
    if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
      return ()
    else:
      return x

  return nest.map_structure(_remove_variant, get_next_op)


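# Hypothetical usage sketch (the test method and dataset builder below are
# illustrative, not part of this file): a concrete checkpoint test typically
# crosses these combinations with eager/graph modes and dispatches to the
# chosen `verify_fn`:
#
#   @combinations.generate(
#       combinations.times(test_base.default_test_combinations(),
#                          checkpoint_test_base.default_test_combinations()))
#   def testCheckpoint(self, verify_fn):
#     verify_fn(self, self._build_dataset, num_outputs=10)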
def default_test_combinations():
  """Returns the default test combinations for testing checkpointing."""

  def disable_optimizations(ds_fn):
    options = options_lib.Options()
    options.experimental_optimization.apply_default_optimizations = False

    def ds_fn_no_opt():
      return ds_fn().with_options(options)

    return ds_fn_no_opt

  def verify_unused_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
    obj.verify_unused_iterator(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_unused_iterator_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_unused_iterator",
                                         verify_unused_iterator))

  def verify_fully_used_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
    obj.verify_fully_used_iterator(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_fully_used_iterator_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_fully_used_iterator",
                                         verify_fully_used_iterator))

  def verify_exhausted_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
    obj.verify_exhausted_iterator(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_exhausted_iterator_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_exhausted_iterator",
                                         verify_exhausted_iterator))

  def verify_multiple_breaks(obj, ds_fn, num_outputs, sparse_tensors=False):
    obj.verify_multiple_breaks(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_multiple_breaks_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_multiple_breaks",
                                         verify_multiple_breaks))

  def verify_reset_restored_iterator(obj,
                                     ds_fn,
                                     num_outputs,
                                     sparse_tensors=False):
    obj.verify_reset_restored_iterator(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_reset_restored_iterator_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_reset_restored_iterator",
                                         verify_reset_restored_iterator))

  return (verify_unused_iterator_combination +
          verify_fully_used_iterator_combination +
          verify_exhausted_iterator_combination +
          verify_multiple_breaks_combination +
          verify_reset_restored_iterator_combination)


# TODO(b/72657739): Remove sparse_tensor argument, which is to test the
# (deprecated) saveable `SparseTensorSliceDataset`, once the API
# `from_sparse_tensor_slices()` and related tests are deleted.
class CheckpointTestBase(test.TestCase):
  """Base test class for checkpointing datasets."""

  def tearDown(self):
    self._delete_ckpt()
    super(CheckpointTestBase, self).tearDown()

  def verify_unused_iterator(self,
                             ds_fn,
                             num_outputs,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that saving and restoring an unused iterator works.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.

    Raises:
      AssertionError if any test fails.
    """
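    # A single break point at 0 exercises saving and restoring the iterator
    # before any elements have been produced.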
    self.verify_run_with_breaks(
        ds_fn, [0],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_fully_used_iterator(self,
                                 ds_fn,
                                 num_outputs,
                                 sparse_tensors=False):
    """Verifies that saving and restoring a fully used iterator works.

    Note that this only checks saving and restoring an iterator from which
    `num_outputs` items have been produced, but does not check for an
    exhausted iterator, i.e., one that has already raised an OutOfRange error.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).

    Raises:
      AssertionError if any test fails.
    """
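    # A single break point at `num_outputs` saves and restores the iterator
    # after all elements have been produced, but before OutOfRange is raised.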
    self.verify_run_with_breaks(
        ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)

  def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
    """Verifies that saving and restoring an exhausted iterator works.

    An exhausted iterator is one that has already raised an OutOfRange error.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).

    Raises:
      AssertionError if any test fails.
    """
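    # Run the iterator to exhaustion (which leaves a checkpoint saved at the
    # end), then restore from that checkpoint and verify that it yields no
    # further elements.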
    self.gen_outputs(
        ds_fn, [],
        num_outputs,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    actual = self.gen_outputs(
        ds_fn, [],
        0,
        ckpt_saved=True,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    self.assertLen(actual, 0)

  def verify_multiple_breaks(self,
                             ds_fn,
                             num_outputs,
                             num_breaks=10,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Attempts to save/restore at multiple break points.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      num_breaks: The number of break points. These are spread uniformly over
        [0, num_outputs], with both endpoints inclusive.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn,
        self.gen_break_points(num_outputs, num_breaks),
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_reset_restored_iterator(self,
                                     ds_fn,
                                     num_outputs,
                                     break_point=None,
                                     sparse_tensors=False,
                                     verify_exhausted=True):
    """Attempts to re-initialize a restored iterator.

    This is useful when restoring a training checkpoint during validation.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.

    Raises:
      AssertionError if any test fails.
    """
    if context.executing_eagerly():
      self.skipTest("Eager mode iteration does not support re-initialization.")

    break_point = num_outputs // 2 if break_point is None else break_point

    # Collect ground truth containing all outputs.
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Skip some items and save checkpoint.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Restore from checkpoint and then run init_op.
    with ops.Graph().as_default() as g:
      saver = self._import_meta_graph()
      init_op, get_next_op = self._get_iterator_ops_from_collection(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._initialize(init_op, sess)
        self._restore(saver, sess)
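        # Running the initializer again after the restore re-initializes the
        # iterator, discarding the restored position, so the full stream of
        # `num_outputs` elements is produced from the start.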
        self._initialize(init_op, sess)
        for _ in range(num_outputs):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)
    self.match(expected, actual)

  def verify_error_on_save(self,
                           ds_fn,
                           num_outputs,
                           error,
                           break_point=None,
                           sparse_tensors=False):
    """Attempts to save a non-saveable iterator.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      error: The error expected to be raised when saving the iterator.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: Whether dataset is built from SparseTensor(s).

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if break_point is None else break_point
    if context.executing_eagerly():
      iterator = iter(ds_fn())
      ckpt = tracking_util.Checkpoint(iterator=iterator)
      for _ in range(break_point):
        next(iterator)
      with self.assertRaises(error):
        ckpt.save(self._ckpt_path())
    else:
      with ops.Graph().as_default() as g:
        init_op, get_next_op, saver = self._build_graph(
            ds_fn, sparse_tensors=sparse_tensors)
        get_next_op = remove_variants(get_next_op)
        with self.session(graph=g) as sess:
          self._initialize(init_op, sess)
          for _ in range(break_point):
            sess.run(get_next_op)
          with self.assertRaises(error):
            self._save(sess, saver)

  def verify_run_with_breaks(self,
                             ds_fn,
                             break_points,
                             num_outputs,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that ds_fn() produces the same outputs with and without breaks.

    1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       *without* stopping at break points.
    2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       *with* stopping at break points.

    The outputs from 1 and 2 are then deep-matched.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      break_points: A list of integers. For each `break_point` in
        `break_points`, we produce outputs until `break_point` items have been
        produced and then checkpoint the state. The current graph and session
        are destroyed, and a new graph and session are used to produce outputs
        until the next break point or until `num_outputs` elements have been
        produced. `break_point` must be <= `num_outputs`.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.

    Raises:
      AssertionError if any test fails.
    """
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    actual = self.gen_outputs(
        ds_fn,
        break_points,
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    self.match(expected, actual)

  def gen_outputs(self,
                  ds_fn,
                  break_points,
                  num_outputs,
                  ckpt_saved=False,
                  sparse_tensors=False,
                  verify_exhausted=True,
                  save_checkpoint_at_end=True):
    """Generates elements from input dataset while stopping at break points.

    Produces `num_outputs` outputs and saves the state of the iterator in the
    Saver checkpoint.

    Args:
      ds_fn: 0-argument function that returns the dataset.
      break_points: A list of integers. For each `break_point` in
        `break_points`, we produce outputs until `break_point` items have been
        produced and then checkpoint the state. The current graph and session
        are destroyed, and a new graph and session are used to produce outputs
        until the next break point or until `num_outputs` elements have been
        produced. `break_point` must be <= `num_outputs`.
      num_outputs: The total number of outputs to produce from the iterator.
      ckpt_saved: Whether a checkpoint already exists.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.
      save_checkpoint_at_end: Whether to save a checkpoint after producing all
        outputs. If False, a checkpoint is saved at each break point but not
        at the end. Note that checkpoints overwrite each other, so there is
        only ever a single checkpoint available. Defaults to True.

    Returns:
      A list of `num_outputs` items.
    """
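    # For example, with break_points=[3, 7] and num_outputs=10, there are
    # three runs: elements [0, 3) are produced and a checkpoint is saved, the
    # graph and session are rebuilt and elements [3, 7) are produced, and
    # finally elements [7, 10) are produced after another rebuild.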
    outputs = []

    if context.executing_eagerly():
      for i in range(len(break_points) + 1):
        iterator = iter(ds_fn())
        ckpt = tracking_util.Checkpoint(iterator=iterator)
        if ckpt_saved:
          ckpt_path = self._latest_ckpt()
          ckpt.restore(ckpt_path)
        start = break_points[i - 1] if i > 0 else 0
        end = break_points[i] if i < len(break_points) else num_outputs
        num_iters = end - start
        for _ in range(num_iters):
          outputs.append(self.evaluate(next(iterator)))
        if i == len(break_points) and verify_exhausted:
          with self.assertRaises(StopIteration):
            next(iterator)
        if save_checkpoint_at_end or i < len(break_points):
          ckpt_path = ckpt.save(self._ckpt_path())
          ckpt_saved = True
    else:
      def get_ops():
        if ckpt_saved:
          saver = self._import_meta_graph()
          init_op, get_next_op = self._get_iterator_ops_from_collection(
              ds_fn, sparse_tensors=sparse_tensors)
        else:
          init_op, get_next_op, saver = self._build_graph(
              ds_fn, sparse_tensors=sparse_tensors)
        return init_op, get_next_op, saver

      for i in range(len(break_points) + 1):
        with ops.Graph().as_default() as g:
          init_op, get_next_op, saver = get_ops()
          get_next_op = remove_variants(get_next_op)
          with self.session(graph=g) as sess:
            if ckpt_saved:
              self._initialize(init_op, sess)
              self._restore(saver, sess)
            else:
              self._initialize(init_op, sess)
            start = break_points[i - 1] if i > 0 else 0
            end = break_points[i] if i < len(break_points) else num_outputs
            num_iters = end - start
            for _ in range(num_iters):
              outputs.append(sess.run(get_next_op))
            if i == len(break_points) and verify_exhausted:
              with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next_op)
            if save_checkpoint_at_end or i < len(break_points):
              self._save(sess, saver)
              ckpt_saved = True

    return outputs

  def match(self, expected, actual):
    """Matches nested structures.

    Recursively matches shape and values of `expected` and `actual`.
    Handles scalars, numpy arrays, and other Python containers such as lists
    and dicts, as well as SparseTensorValue and RaggedTensorValue.

    Args:
      expected: Nested structure 1.
      actual: Nested structure 2.

    Raises:
      AssertionError if matching fails.
    """
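    # Numpy arrays are converted to lists so that the nested comparison below
    # checks values (and, implicitly, shapes) element by element.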
    if isinstance(expected, np.ndarray):
      expected = expected.tolist()
    if isinstance(actual, np.ndarray):
      actual = actual.tolist()
    self.assertEqual(type(expected), type(actual))

    if nest.is_nested(expected):
      self.assertEqual(len(expected), len(actual))
      if isinstance(expected, dict):
        for key1, key2 in zip(sorted(expected), sorted(actual)):
          self.assertEqual(key1, key2)
          self.match(expected[key1], actual[key2])
      else:
        for item1, item2 in zip(expected, actual):
          self.match(item1, item2)
    elif isinstance(expected, sparse_tensor.SparseTensorValue):
      self.match((expected.indices, expected.values, expected.dense_shape),
                 (actual.indices, actual.values, actual.dense_shape))
    elif isinstance(expected, ragged_tensor_value.RaggedTensorValue):
      self.match((expected.values, expected.row_splits),
                 (actual.values, actual.row_splits))
    else:
      self.assertEqual(expected, actual)

  def does_not_match(self, expected, actual):
    with self.assertRaises(AssertionError):
      self.match(expected, actual)

  def gen_break_points(self, num_outputs, num_samples=10):
    """Generates `num_samples` break points in [0, num_outputs]."""
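    # For example, gen_break_points(10, 5) returns [0, 2, 5, 7, 10]; dtype=int
    # truncates np.linspace's fractional sample points.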
    return np.linspace(0, num_outputs, num_samples, dtype=int)

  def _build_graph(self, ds_fn, sparse_tensors=False):
    dataset = ds_fn()
    iterator = dataset_ops.make_initializable_iterator(dataset)
    external_state_policy = dataset.options().experimental_external_state_policy
    saveable = contrib_iterator_ops.make_saveable_from_iterator(
        iterator, external_state_policy=external_state_policy)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    init_op = iterator.initializer
    if sparse_tensors:
      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    else:
      get_next = iterator.get_next()
    self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
                                         sparse_tensors)
    saver = saver_lib.Saver(allow_empty=True)
    return init_op, get_next, saver

  def _add_iterator_ops_to_collection(self,
                                      init_op,
                                      get_next,
                                      ds_fn,
                                      sparse_tensors=False):
    ops.add_to_collection("iterator_ops", init_op)
    # `get_next` may be a tuple, e.g. in TensorSliceDataset. Since collections
    # do not support tuples, we flatten the tensors and restore the shape in
    # `_get_iterator_ops_from_collection`.
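    # For example, a flat nest (dense_t, sparse_t) is stored as the five
    # collection entries [init_op, dense_t, sparse_t.indices, sparse_t.values,
    # sparse_t.dense_shape]; `_get_iterator_ops_from_collection` reverses
    # this expansion.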
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      ops.add_to_collection("iterator_ops", get_next.indices)
      ops.add_to_collection("iterator_ops", get_next.values)
      ops.add_to_collection("iterator_ops", get_next.dense_shape)
      return

    get_next_list = nest.flatten(get_next)
    for i, output_class in enumerate(
        nest.flatten(self._get_output_classes(ds_fn))):
      if output_class is sparse_tensor.SparseTensor:
        ops.add_to_collection("iterator_ops", get_next_list[i].indices)
        ops.add_to_collection("iterator_ops", get_next_list[i].values)
        ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
      else:
        ops.add_to_collection("iterator_ops", get_next_list[i])

  def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
    all_ops = ops.get_collection("iterator_ops")
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      init_op, indices, values, dense_shape = all_ops
      return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
    get_next_list = []
    i = 1
    for output_class in nest.flatten(self._get_output_classes(ds_fn)):
      if output_class is sparse_tensor.SparseTensor:
        indices, values, dense_shape = all_ops[i:i + 3]
        i += 3
        get_next_list.append(
            sparse_tensor.SparseTensor(indices, values, dense_shape))
      else:
        get_next_list.append(all_ops[i])
        i += 1
    return all_ops[0], nest.pack_sequence_as(
        self._get_output_types(ds_fn), get_next_list)

  def _get_output_types(self, ds_fn):
    assert not context.executing_eagerly()
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_types(ds_fn())

  def _get_output_shapes(self, ds_fn):
    assert not context.executing_eagerly()
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_shapes(ds_fn())

  def _get_output_classes(self, ds_fn):
    assert not context.executing_eagerly()
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_classes(ds_fn())

  def _ckpt_path(self):
    return os.path.join(self.get_temp_dir(), "iterator")

  def _latest_ckpt(self):
    return checkpoint_management.latest_checkpoint(self.get_temp_dir())

  def _save(self, sess, saver):
    saver.save(sess, self._ckpt_path())

  def _restore(self, saver, sess):
    sess.run(lookup_ops.tables_initializer())
    saver.restore(sess, self._latest_ckpt())

  def _initialize(self, init_op, sess):
    sess.run(variables.global_variables_initializer())
    sess.run(lookup_ops.tables_initializer())
    sess.run(init_op)

  def _import_meta_graph(self):
    meta_file_path = self._ckpt_path() + ".meta"
    return saver_lib.import_meta_graph(meta_file_path)

  def _delete_ckpt(self):
    # Remove all checkpoint files.
    prefix = self._ckpt_path()
    pattern = prefix + "*"
    files = gfile.Glob(pattern)
    # `map` is lazy in Python 3, so delete the files with an explicit loop.
    for f in files:
      gfile.Remove(f)