# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Iterator ops."""

from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.framework import ops
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _convert_external_state_policy_to_enum(external_state_policy):
  """Converts a user-supplied external state policy to its enum value.

  Args:
    external_state_policy: Either an `options_lib.ExternalStatePolicy`, which
      is returned unchanged, or one of the strings 'warn', 'ignore' or 'fail'.

  Returns:
    The corresponding `options_lib.ExternalStatePolicy` value.

  Raises:
    ValueError: If `external_state_policy` is neither an
      `ExternalStatePolicy` nor one of the supported strings.
  """
  if isinstance(external_state_policy, options_lib.ExternalStatePolicy):
    return external_state_policy
  if external_state_policy == "warn":
    return options_lib.ExternalStatePolicy.WARN
  if external_state_policy == "ignore":
    return options_lib.ExternalStatePolicy.IGNORE
  if external_state_policy == "fail":
    return options_lib.ExternalStatePolicy.FAIL
  raise ValueError(
      "Invalid `ExternalStatePolicy.` Supported values include 'warn', "
      f"'ignore', and 'fail.' Received {external_state_policy}."
  )


@tf_export("data.experimental.make_saveable_from_iterator")
@deprecation.deprecated(
    None, "`make_saveable_from_iterator` is intended for use in TF1 with "
    "`tf.compat.v1.Saver`. In TF2, use `tf.train.Checkpoint` instead.")
def make_saveable_from_iterator(iterator, external_state_policy=None):
  """Returns a SaveableObject for saving/restoring iterator state using Saver.

  Args:
    iterator: Iterator.
    external_state_policy: A string (or `ExternalStatePolicy`) that identifies
      how to handle input pipelines that depend on external state. Possible
      values are
      'ignore': The external state is silently ignored.
      'warn': The external state is ignored, logging a warning.
      'fail': The operation fails upon encountering external state.
      By default we set it to 'fail'.

  Returns:
    A SaveableObject for saving/restoring iterator state using Saver.

  Raises:
    ValueError: If iterator does not support checkpointing.
    ValueError: If `external_state_policy` is not one of 'warn', 'ignore' or
      'fail'.

  For example:

  ```python
  with tf.Graph().as_default():
    ds = tf.data.Dataset.range(10)
    iterator = tf.compat.v1.data.make_initializable_iterator(ds)
    # Build the iterator SaveableObject.
    saveable_obj = tf.data.experimental.make_saveable_from_iterator(iterator)
    # Add the SaveableObject to the SAVEABLE_OBJECTS collection so
    # it can be automatically saved using Saver.
    tf.compat.v1.add_to_collection(
        tf.compat.v1.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
    saver = tf.compat.v1.train.Saver()

    while continue_training:
      ... Perform training ...
      if should_save_checkpoint:
        saver.save(sess, checkpoint_path)
  ```

  Note: When restoring the iterator, the existing iterator state is completely
  discarded. This means that any changes you may have made to the Dataset
  graph will be discarded as well! This includes the new Dataset graph
  that you may have built during validation. So, while running validation,
  make sure to run the initializer for the validation input pipeline after
  restoring the checkpoint.

  Note: Not all iterators support checkpointing yet. Attempting to save the
  state of an unsupported iterator will throw an error.
  """
  if external_state_policy is None:
    external_state_policy = "fail"
  policy_enum = _convert_external_state_policy_to_enum(external_state_policy)
  return iterator_ops._IteratorSaveable(  # pylint: disable=protected-access
      iterator._iterator_resource,  # pylint: disable=protected-access
      iterator._iterator_resource.name,  # pylint: disable=protected-access
      external_state_policy=policy_enum)


@tf_export("data.experimental.CheckpointInputPipelineHook")
class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
  """Checkpoints input pipeline state every N steps or seconds.

  This hook saves the state of the iterators in the `Graph` so that when
  training is resumed the input pipeline continues from where it left off.
  This could potentially avoid overfitting in certain pipelines where the
  number of training steps per eval is small compared to the dataset
  size or if the training pipeline is pre-empted.

  Differences from `CheckpointSaverHook`:
  1. Saves only the input pipelines in the "iterators" collection and not the
     global variables or other saveable objects.
  2. Does not write the `GraphDef` and `MetaGraphDef` to the summary.

  Example of checkpointing the training pipeline:

  ```python
  est = tf.estimator.Estimator(model_fn)
  while True:
    est.train(
        train_input_fn,
        hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
        steps=train_steps_per_eval)
    # Note: We do not pass the hook here.
    metrics = est.evaluate(eval_input_fn)
    if should_stop_the_training(metrics):
      break
  ```

  This hook should be used if the input pipeline state needs to be saved
  separate from the model checkpoint. Doing so may be useful for a few reasons:
  1. The input pipeline checkpoint may be large, if there are large shuffle
     or prefetch buffers for instance, and may bloat the checkpoint size.
  2. If the input pipeline is shared between training and validation, restoring
     the checkpoint during validation may override the validation input
     pipeline.

  For saving the input pipeline checkpoint alongside the model weights use
  `tf.data.experimental.make_saveable_from_iterator` directly to create a
  `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however,
  that you will need to be careful not to restore the training iterator during
  eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS
  collection when building the eval graph.
  """

  def __init__(self, estimator, external_state_policy=None):
    """Initializes a `CheckpointInputPipelineHook`.

    If the input pipeline depends on external state (e.g. seeds for
    RandomUniform) beyond the input pipeline, this hook would be unable to
    serialize and deserialize that state. If it's acceptable to ignore that
    state, change the external_state_policy argument to 'warn' or 'ignore'.
    For example:

    ```python
    est = tf.estimator.Estimator(model_fn)
    while True:
      est.train(
          train_input_fn,
          hooks=[tf.data.experimental.CheckpointInputPipelineHook(
              est, external_state_policy='warn')],
          steps=train_steps_per_eval)
      # Note: We do not pass the hook here.
      metrics = est.evaluate(eval_input_fn)
      if should_stop_the_training(metrics):
        break
    ```

    Args:
      estimator: Estimator. Its `model_dir` and the checkpoint frequency in its
        config (`save_checkpoints_secs` / `save_checkpoints_steps`) are reused
        for the input pipeline checkpoints.
      external_state_policy: A string (or `ExternalStatePolicy`) that
        identifies how to handle input pipelines that depend on external state.
        Possible values are
        'ignore': The external state is silently ignored.
        'warn': The external state is ignored, logging a warning.
        'fail': The operation fails upon encountering external state.
        By default we set it to 'fail'.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set (taken from
        the estimator config's `save_checkpoints_steps` and
        `save_checkpoints_secs`).
      ValueError: If `external_state_policy` is not one of 'warn', 'ignore' or
        'fail'.
    """
    if external_state_policy is None:
      external_state_policy = "fail"
    self._external_state_policy = _convert_external_state_policy_to_enum(
        external_state_policy)
    # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
    # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
    # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
    # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
    # to be different from it to avoid conflicts with the model checkpoint.

    # pylint: disable=protected-access
    checkpoint_prefix = "input"
    if estimator._config.num_worker_replicas > 1:
      # Distributed setting.
      suffix = "_{}_{}".format(estimator._config.task_type,
                               estimator._config.task_id)
      checkpoint_prefix += suffix
    # pylint: enable=protected-access

    # We use a composition paradigm instead of inheriting from
    # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
    # to check whether a `CheckpointSaverHook` is already present in the list
    # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
    # would thwart this behavior. This hook checkpoints *only the iterators*
    # and not the graph variables.
    self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
        estimator.model_dir,
        save_secs=estimator._config.save_checkpoints_secs,  # pylint: disable=protected-access
        save_steps=estimator._config.save_checkpoints_steps,  # pylint: disable=protected-access
        checkpoint_basename=checkpoint_prefix + ".ckpt")

    # Name for the protocol buffer file that will contain the list of most
    # recent checkpoints stored as a `CheckpointState` protocol buffer.
    # This file, kept in the same directory as the checkpoint files, is
    # automatically managed by the `Saver` to keep track of recent checkpoints.
    # The default name used by the `Saver` for this file is "checkpoint". Here
    # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
    # `checkpoint_dir` is the same as the model checkpoint directory, there are
    # no conflicts during restore.
    self._latest_filename = "checkpoint_" + checkpoint_prefix

    # Whether the next `before_run` must restore (or write the initial)
    # checkpoint. Defined here so `before_run` never hits an AttributeError;
    # `after_create_session` resets it to True for every new session.
    self._first_run = True

  def begin(self):
    """Builds the iterator `Saver` (if needed) and begins the wrapped hook."""
    # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
    # collection if no `Saver` or `Scaffold` is provided.
    # pylint: disable=protected-access
    if (self._checkpoint_saver_hook._saver is None and
        self._checkpoint_saver_hook._scaffold is None):
      iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
      saveables = [
          iterator_ops._IteratorSaveable(
              i, i.name, external_state_policy=self._external_state_policy)
          for i in iterators
      ]
      self._checkpoint_saver_hook._saver = _CustomSaver(
          saveables, self._latest_filename, sharded=True)
    # pylint: enable=protected-access
    self._checkpoint_saver_hook.begin()

  def after_create_session(self, session, coord):
    """Marks that the next `before_run` must restore or save a checkpoint."""
    # If a new session was created, we set _first_run to True so that we can
    # restore if needed.
    self._first_run = True

  def _restore_or_save_initial_ckpt(self, session):
    """Restores the latest input checkpoint, or saves one if none exists."""
    # Ideally this should be run in after_create_session but is not for the
    # following reason:
    # Currently there is no way of enforcing an order of running the
    # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
    # is run *after* this hook. That is troublesome because
    # 1. If a checkpoint exists and this hook restores it, the initializer hook
    #    will override it.
    # 2. If no checkpoint exists, this hook will try to save an uninitialized
    #    iterator which will result in an exception.
    #
    # As a temporary fix we enter the following implicit contract between this
    # hook and the _DatasetInitializerHook.
    # 1. The _DatasetInitializerHook initializes the iterator in the call to
    #    after_create_session.
    # 2. This hook saves the iterator on the first call to `before_run()`, which
    #    is guaranteed to happen after `after_create_session()` of all hooks
    #    have been run.

    # Check if there is an existing checkpoint. If so, restore from it.
    # pylint: disable=protected-access
    latest_checkpoint_path = checkpoint_management.latest_checkpoint(
        self._checkpoint_saver_hook._checkpoint_dir,
        latest_filename=self._latest_filename)
    if latest_checkpoint_path:
      self._checkpoint_saver_hook._get_saver().restore(session,
                                                       latest_checkpoint_path)
    else:
      # The checkpoint saved here is the state at step "global_step".
      # Note: We do not save the GraphDef or MetaGraphDef here.
      global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
      self._checkpoint_saver_hook._save(session, global_step)
      self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
    # pylint: enable=protected-access

  def before_run(self, run_context):
    """Restores/saves on the first step of a session, then delegates."""
    if self._first_run:
      self._restore_or_save_initial_ckpt(run_context.session)
      self._first_run = False
    return self._checkpoint_saver_hook.before_run(run_context)

  def after_run(self, run_context, run_values):
    """Delegates to the wrapped `CheckpointSaverHook`."""
    self._checkpoint_saver_hook.after_run(run_context, run_values)

  def end(self, session):
    """Delegates to the wrapped `CheckpointSaverHook`."""
    self._checkpoint_saver_hook.end(session)


class _CustomSaver(saver_lib.Saver):
  """`Saver` with a different default `latest_filename`.

  This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
  the model ckpt saved by the `CheckpointSaverHook`.
  """

  def __init__(self, var_list, latest_filename, sharded=False):
    super().__init__(var_list, sharded=sharded)
    self._latest_filename = latest_filename

  def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True,
           strip_default_attrs=False,
           **kwargs):
    """Saves like `Saver.save`, defaulting `latest_filename` to our own.

    Any additional keyword arguments accepted by `saver_lib.Saver.save`
    (e.g. `save_debug_info`) are forwarded unchanged instead of raising a
    `TypeError` because this override's signature is narrower.

    Returns:
      Whatever `saver_lib.Saver.save` returns.
    """
    return super().save(
        sess,
        save_path,
        global_step=global_step,
        latest_filename=latest_filename or self._latest_filename,
        meta_graph_suffix=meta_graph_suffix,
        write_meta_graph=write_meta_graph,
        write_state=write_state,
        strip_default_attrs=strip_default_attrs,
        **kwargs)