# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Iterator ops."""

from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.framework import ops
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _convert_external_state_policy_to_enum(external_state_policy):
  """Converts a string `external_state_policy` to the corresponding enum."""
  if isinstance(external_state_policy, options_lib.ExternalStatePolicy):
    return external_state_policy
  if external_state_policy == "warn":
    return options_lib.ExternalStatePolicy.WARN
  if external_state_policy == "ignore":
    return options_lib.ExternalStatePolicy.IGNORE
  if external_state_policy == "fail":
    return options_lib.ExternalStatePolicy.FAIL
  raise ValueError(
      "Invalid `external_state_policy`. Supported values include 'warn', "
      f"'ignore', and 'fail'. Received {external_state_policy}.")
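# Illustrative usage (hypothetical calls, not used elsewhere in this module):
#   _convert_external_state_policy_to_enum("warn")
#   _convert_external_state_policy_to_enum(options_lib.ExternalStatePolicy.WARN)
# both return `options_lib.ExternalStatePolicy.WARN`.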


@tf_export("data.experimental.make_saveable_from_iterator")
@deprecation.deprecated(
    None, "`make_saveable_from_iterator` is intended for use in TF1 with "
    "`tf.compat.v1.Saver`. In TF2, use `tf.train.Checkpoint` instead.")
def make_saveable_from_iterator(iterator, external_state_policy=None):
  """Returns a SaveableObject for saving/restoring iterator state using Saver.

  Args:
    iterator: Iterator.
    external_state_policy: A string that identifies how to handle input
      pipelines that depend on external state. Possible values are
      'ignore': The external state is silently ignored.
      'warn': The external state is ignored, logging a warning.
      'fail': The operation fails upon encountering external state.
      Defaults to 'fail'.

  Returns:
    A SaveableObject for saving/restoring iterator state using Saver.

  Raises:
    ValueError: If iterator does not support checkpointing.
    ValueError: If `external_state_policy` is not one of 'warn', 'ignore' or
      'fail'.

  For example:

  ```python
  with tf.Graph().as_default():
    ds = tf.data.Dataset.range(10)
    iterator = tf.compat.v1.data.make_initializable_iterator(ds)
    # Build the iterator SaveableObject.
    saveable_obj = tf.data.experimental.make_saveable_from_iterator(iterator)
    # Add the SaveableObject to the SAVEABLE_OBJECTS collection so
    # it can be automatically saved using Saver.
    tf.compat.v1.add_to_collection(
        tf.compat.v1.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
    saver = tf.compat.v1.train.Saver()

    while continue_training:
      ... Perform training ...
      if should_save_checkpoint:
        saver.save()
  ```
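
  As the deprecation note above suggests, TF2 code can checkpoint an iterator
  directly with `tf.train.Checkpoint`. A minimal eager-mode sketch
  (illustrative only; the checkpoint directory is a placeholder):

  ```python
  ds = tf.data.Dataset.range(10)
  iterator = iter(ds)
  ckpt = tf.train.Checkpoint(iterator=iterator)
  manager = tf.train.CheckpointManager(ckpt, "/tmp/iter_ckpt", max_to_keep=1)
  print(next(iterator).numpy())  # 0
  manager.save()                 # Saves the position after element 0.
  print(next(iterator).numpy())  # 1
  ckpt.restore(manager.latest_checkpoint)
  print(next(iterator).numpy())  # 1 again: resumed from the saved position.
  ```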

  Note: When restoring the iterator, the existing iterator state is completely
  discarded. This means that any changes you may have made to the Dataset
  graph will be discarded as well! This includes the new Dataset graph
  that you may have built during validation. So, while running validation,
  make sure to run the initializer for the validation input pipeline after
  restoring the checkpoint.

  Note: Not all iterators support checkpointing yet. Attempting to save the
  state of an unsupported iterator will raise an error.
  """
  if external_state_policy is None:
    external_state_policy = "fail"
  policy_enum = _convert_external_state_policy_to_enum(external_state_policy)
  return iterator_ops._IteratorSaveable(  # pylint: disable=protected-access
      iterator._iterator_resource,  # pylint: disable=protected-access
      iterator._iterator_resource.name,  # pylint: disable=protected-access
      external_state_policy=policy_enum)


@tf_export("data.experimental.CheckpointInputPipelineHook")
class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
  """Checkpoints input pipeline state every N steps or seconds.

  This hook saves the state of the iterators in the `Graph` so that when
  training is resumed the input pipeline continues from where it left off.
  This could potentially avoid overfitting in certain pipelines where the
  number of training steps per eval is small compared to the dataset size,
  or when the training pipeline is preempted.

  Differences from `CheckpointSaverHook`:
  1. Saves only the input pipelines in the "iterators" collection and not the
     global variables or other saveable objects.
  2. Does not write the `GraphDef` and `MetaGraphDef` to the summary.

  Example of checkpointing the training pipeline:

  ```python
  est = tf.estimator.Estimator(model_fn)
  while True:
    est.train(
        train_input_fn,
        hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
        steps=train_steps_per_eval)
    # Note: We do not pass the hook here.
    metrics = est.evaluate(eval_input_fn)
    if should_stop_the_training(metrics):
      break
  ```

  This hook should be used if the input pipeline state needs to be saved
  separately from the model checkpoint. Doing so may be useful for a few
  reasons:
  1. The input pipeline checkpoint may be large (for instance, if there are
     large shuffle or prefetch buffers) and may bloat the checkpoint size.
  2. If the input pipeline is shared between training and validation,
     restoring the checkpoint during validation may overwrite the validation
     input pipeline.

  To save the input pipeline checkpoint alongside the model weights, use
  `tf.data.experimental.make_saveable_from_iterator` directly to create a
  `SaveableObject` and add it to the `SAVEABLE_OBJECTS` collection. Note,
  however, that you will need to be careful not to restore the training
  iterator during eval. You can do that by not adding the iterator to the
  `SAVEABLE_OBJECTS` collection when building the eval graph, as sketched
  below.
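
  A minimal sketch of that pattern (illustrative; `build_train_graph` and
  `build_eval_graph` are hypothetical helpers):

  ```python
  def build_train_graph():
    train_iterator = ...  # Create the training iterator.
    saveable = tf.data.experimental.make_saveable_from_iterator(train_iterator)
    # Saved and restored together with the model checkpoint.
    tf.compat.v1.add_to_collection(
        tf.compat.v1.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def build_eval_graph():
    eval_iterator = ...  # Create the eval iterator.
    # Intentionally *not* added to SAVEABLE_OBJECTS, so restoring the model
    # checkpoint cannot clobber the eval input pipeline.
  ```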
  """

  def __init__(self, estimator, external_state_policy=None):
    """Initializes a `CheckpointInputPipelineHook`.

    If the input pipeline depends on external state (e.g. seeds for
    RandomUniform) outside the pipeline itself, this hook cannot serialize
    and deserialize that state. If it is acceptable to ignore that state,
    set the `external_state_policy` argument to 'warn' or 'ignore'. For
    example:

    ```python
    est = tf.estimator.Estimator(model_fn)
    while True:
      est.train(
          train_input_fn,
          hooks=[tf.data.experimental.CheckpointInputPipelineHook(
              est, external_state_policy='warn')],
          steps=train_steps_per_eval)
      # Note: We do not pass the hook here.
      metrics = est.evaluate(eval_input_fn)
      if should_stop_the_training(metrics):
        break
    ```

    Args:
      estimator: Estimator.
      external_state_policy: A string that identifies how to handle input
        pipelines that depend on external state. Possible values are
        'ignore': The external state is silently ignored.
        'warn': The external state is ignored, logging a warning.
        'fail': The operation fails upon encountering external state.
        Defaults to 'fail'.

    Raises:
      ValueError: If neither of `save_steps` or `save_secs` is set.
      ValueError: If both a `saver` and a `scaffold` are set.
      ValueError: If `external_state_policy` is not one of 'warn', 'ignore' or
        'fail'.
    """
    if external_state_policy is None:
      external_state_policy = "fail"
    self._external_state_policy = _convert_external_state_policy_to_enum(
        external_state_policy)
    # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
    # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
    # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
    # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
    # to be different to avoid conflicts with the model checkpoint.

    # pylint: disable=protected-access
    checkpoint_prefix = "input"
    if estimator._config.num_worker_replicas > 1:
      # Distributed setting.
      suffix = "_{}_{}".format(estimator._config.task_type,
                               estimator._config.task_id)
      checkpoint_prefix += suffix
    # pylint: enable=protected-access

    # We use a composition paradigm instead of inheriting from
    # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
    # to check whether a `CheckpointSaverHook` is already present in the list
    # of hooks and, if not, adds one. Inheriting from `CheckpointSaverHook`
    # would thwart this behavior. This hook checkpoints *only the iterators*
    # and not the graph variables.
    self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
        estimator.model_dir,
        save_secs=estimator._config.save_checkpoints_secs,  # pylint: disable=protected-access
        save_steps=estimator._config.save_checkpoints_steps,  # pylint: disable=protected-access
        checkpoint_basename=checkpoint_prefix + ".ckpt")

    # Name for the protocol buffer file that will contain the list of most
    # recent checkpoints stored as a `CheckpointState` protocol buffer.
    # This file, kept in the same directory as the checkpoint files, is
    # automatically managed by the `Saver` to keep track of recent checkpoints.
    # The default name used by the `Saver` for this file is "checkpoint". Here
    # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
    # `checkpoint_dir` is the same as the model checkpoint directory, there are
    # no conflicts during restore.
    self._latest_filename = "checkpoint_" + checkpoint_prefix
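    # For illustration (hypothetical values): a worker with task_type="worker"
    # and task_id=0 writes checkpoints named "input_worker_0.ckpt-<global_step>"
    # and records them in a state file named "checkpoint_input_worker_0".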

  def begin(self):
    # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
    # collection if no `Saver` or `Scaffold` is provided.
    # pylint: disable=protected-access
    if (self._checkpoint_saver_hook._saver is None and
        self._checkpoint_saver_hook._scaffold is None):
      iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
      saveables = [
          iterator_ops._IteratorSaveable(
              i, i.name, external_state_policy=self._external_state_policy)
          for i in iterators
      ]
      self._checkpoint_saver_hook._saver = _CustomSaver(
          saveables, self._latest_filename, sharded=True)
    # pylint: enable=protected-access
    self._checkpoint_saver_hook.begin()

  def after_create_session(self, session, coord):
    # If a new session was created, we set _first_run to True so that we can
    # restore if needed.
    self._first_run = True

  def _restore_or_save_initial_ckpt(self, session):
    # Ideally this should be run in after_create_session but is not for the
    # following reason:
    # Currently there is no way of enforcing an order of running the
    # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
    # is run *after* this hook. That is troublesome because
    # 1. If a checkpoint exists and this hook restores it, the initializer hook
    #    will override it.
    # 2. If no checkpoint exists, this hook will try to save an uninitialized
    #    iterator, which will result in an exception.
    #
    # As a temporary fix, we enter the following implicit contract between this
    # hook and the _DatasetInitializerHook.
    # 1. The _DatasetInitializerHook initializes the iterator in the call to
    #    after_create_session.
    # 2. This hook saves the iterator on the first call to `before_run()`,
    #    which is guaranteed to happen after `after_create_session()` has been
    #    called on all hooks.

    # Check if there is an existing checkpoint. If so, restore from it.
    # pylint: disable=protected-access
    latest_checkpoint_path = checkpoint_management.latest_checkpoint(
        self._checkpoint_saver_hook._checkpoint_dir,
        latest_filename=self._latest_filename)
    if latest_checkpoint_path:
      self._checkpoint_saver_hook._get_saver().restore(session,
                                                       latest_checkpoint_path)
    else:
      # The checkpoint saved here is the state at step "global_step".
      # Note: We do not save the GraphDef or MetaGraphDef here.
      global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
      self._checkpoint_saver_hook._save(session, global_step)
      self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
    # pylint: enable=protected-access

  def before_run(self, run_context):
    if self._first_run:
      self._restore_or_save_initial_ckpt(run_context.session)
      self._first_run = False
    return self._checkpoint_saver_hook.before_run(run_context)

  def after_run(self, run_context, run_values):
    self._checkpoint_saver_hook.after_run(run_context, run_values)

  def end(self, session):
    self._checkpoint_saver_hook.end(session)


class _CustomSaver(saver_lib.Saver):
  """`Saver` with a different default `latest_filename`.

  This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
  the model checkpoint saved by the `CheckpointSaverHook`.
  """

  def __init__(self, var_list, latest_filename, sharded=False):
    super(_CustomSaver, self).__init__(var_list, sharded=sharded)
    self._latest_filename = latest_filename

  def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True,
           strip_default_attrs=False):
    # Fall back to the custom `latest_filename` when the caller does not
    # supply one.
    return super(_CustomSaver, self).save(
        sess, save_path, global_step, latest_filename or self._latest_filename,
        meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
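

# Illustrative only (session, path, and step values are hypothetical): saving
# through `_CustomSaver` records the checkpoint in the custom state file
# instead of the default "checkpoint" file:
#
#   saver = _CustomSaver(saveables, latest_filename="checkpoint_input")
#   saver.save(sess, "/tmp/ckpts/input.ckpt", global_step=100)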
322