xref: /aosp_15_r20/external/tensorflow/tensorflow/python/training/basic_session_run_hooks.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some common SessionRunHook classes.

Note that the symbols that are exported to v1 tf.train namespace are also
exported to v2 in tf.estimator namespace. See
https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py
"""

import os
import time

import numpy as np

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.client import timeline
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
from tensorflow.python.util.tf_export import tf_export

_HOOKS = "hooks"
_STEPS_PER_RUN_VAR = "steps_per_run"


class _HookTimer:
  """Base timer for determining when Hooks should trigger.

  Should not be instantiated directly.
  """

  def __init__(self):
    pass

  def reset(self):
    """Resets the timer."""
    pass

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step."""
    raise NotImplementedError

  def update_last_triggered_step(self, step):
    """Update the last triggered time and step number.

    Args:
      step: The current step.

    Returns:
      A pair `(elapsed_time, elapsed_steps)`, where `elapsed_time` is the number
      of seconds between the current trigger and the last one (a float), and
      `elapsed_steps` is the number of steps between the current trigger and
      the last one. Both values will be set to `None` on the first trigger.
    """
    raise NotImplementedError

  def last_triggered_step(self):
    """Returns the last triggered step, or None if never triggered."""
    raise NotImplementedError


@tf_export(v1=["train.SecondOrStepTimer"])
class SecondOrStepTimer(_HookTimer):
  """Timer that triggers at most once every N seconds or once every N steps.

  This symbol is also exported to v2 in tf.estimator namespace. See
  https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py
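
  For illustration, a minimal sketch of driving this timer directly (the plain
  Python loop below stands in for a hook's `before_run`/`after_run` cycle):

  ```python
  timer = SecondOrStepTimer(every_steps=100)
  for step in range(1000):
    if timer.should_trigger_for_step(step):
      # Both values are None on the very first trigger.
      elapsed_secs, elapsed_steps = timer.update_last_triggered_step(step)
  ```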
  """

  def __init__(self, every_secs=None, every_steps=None):
    self.reset()
    self._every_secs = every_secs
    self._every_steps = every_steps

    if self._every_secs is None and self._every_steps is None:
      raise ValueError("Either every_secs or every_steps should be provided.")
    if (self._every_secs is not None) and (self._every_steps is not None):
      raise ValueError("Can not provide both every_secs and every_steps.")

    super(SecondOrStepTimer, self).__init__()

  def reset(self):
    self._last_triggered_step = None
    self._last_triggered_time = None

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step.

    Args:
      step: Training step to trigger on.

    Returns:
      True if the difference between the current time and the time of the last
      trigger exceeds `every_secs`, or if the difference between the current
      step and the last triggered step exceeds `every_steps`. False otherwise.
    """
    if self._last_triggered_step is None:
      return True

    if self._last_triggered_step == step:
      return False

    if self._every_secs is not None:
      if time.time() >= self._last_triggered_time + self._every_secs:
        return True

    if self._every_steps is not None:
      if step >= self._last_triggered_step + self._every_steps:
        return True

    return False

  def update_last_triggered_step(self, step):
    current_time = time.time()
    if self._last_triggered_time is None:
      elapsed_secs = None
      elapsed_steps = None
    else:
      elapsed_secs = current_time - self._last_triggered_time
      elapsed_steps = step - self._last_triggered_step

    self._last_triggered_time = current_time
    self._last_triggered_step = step
    return (elapsed_secs, elapsed_steps)

  def last_triggered_step(self):
    return self._last_triggered_step


class NeverTriggerTimer(_HookTimer):
  """Timer that never triggers."""

  def should_trigger_for_step(self, step):
    _ = step
    return False

  def update_last_triggered_step(self, step):
    _ = step
    return (None, None)

  def last_triggered_step(self):
    return None


@tf_export(v1=["train.LoggingTensorHook"])
class LoggingTensorHook(session_run_hook.SessionRunHook):
  """Prints the given tensors every N local steps, every N seconds, or at end.

  The tensors will be printed to the log, with `INFO` severity. If you are not
  seeing the logs, you might want to add the following line after your imports:

  ```python
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  ```

  Note that if `at_end` is True, `tensors` should not include any tensor
  whose evaluation produces a side effect such as consuming additional inputs.

  @compatibility(TF2)
  Please check this [notebook][notebook] on how to migrate the API to TF2.

  [notebook]:https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb

  @end_compatibility

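  For illustration, a minimal sketch of attaching this hook to a
  `MonitoredTrainingSession` (the `loss` tensor and `train_op` below are
  placeholders for your own graph):

  ```python
  logging_hook = tf.compat.v1.train.LoggingTensorHook(
      tensors={"loss": loss}, every_n_iter=100)
  with tf.compat.v1.train.MonitoredTrainingSession(
      hooks=[logging_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```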
  """

  def __init__(self,
               tensors,
               every_n_iter=None,
               every_n_secs=None,
               at_end=False,
               formatter=None):
    """Initializes a `LoggingTensorHook`.

    Args:
      tensors: `dict` that maps string-valued tags to tensors/tensor names, or
        `iterable` of tensors/tensor names.
      every_n_iter: `int`, print the values of `tensors` once every N local
        steps taken on the current worker.
      every_n_secs: `int` or `float`, print the values of `tensors` once every N
        seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
        provided.
      at_end: `bool` specifying whether to print the values of `tensors` at the
        end of the run.
      formatter: function, takes a dict of `tag`->`Tensor` and returns a string.
        If `None`, uses the default formatting, which prints all tensors.

    Raises:
      ValueError: If `every_n_iter` is non-positive.
    """
    only_log_at_end = (
        at_end and (every_n_iter is None) and (every_n_secs is None))
    if (not only_log_at_end and
        (every_n_iter is None) == (every_n_secs is None)):
      raise ValueError(
          "either at_end and/or exactly one of every_n_iter and every_n_secs "
          "must be provided.")
    if every_n_iter is not None and every_n_iter <= 0:
      raise ValueError("invalid every_n_iter=%s." % every_n_iter)
    if not isinstance(tensors, dict):
      self._tag_order = tensors
      tensors = {item: item for item in tensors}
    else:
      self._tag_order = sorted(tensors.keys())
    self._tensors = tensors
    self._formatter = formatter
    self._timer = (
        NeverTriggerTimer() if only_log_at_end else SecondOrStepTimer(
            every_secs=every_n_secs, every_steps=every_n_iter))
    self._log_at_end = at_end

  def begin(self):
    self._timer.reset()
    self._iter_count = 0
    # Convert names to tensors if given.
    self._current_tensors = {
        tag: _as_graph_element(tensor)
        for (tag, tensor) in self._tensors.items()
    }

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def _log_tensors(self, tensor_values):
    original = np.get_printoptions()
    np.set_printoptions(suppress=True)
    elapsed_secs, _ = self._timer.update_last_triggered_step(self._iter_count)
    if self._formatter:
      logging.info(self._formatter(tensor_values))
    else:
      stats = []
      for tag in self._tag_order:
        stats.append("%s = %s" % (tag, tensor_values[tag]))
      if elapsed_secs is not None:
        logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
      else:
        logging.info("%s", ", ".join(stats))
    np.set_printoptions(**original)

  def after_run(self, run_context, run_values):
    _ = run_context
    if self._should_trigger:
      self._log_tensors(run_values.results)

    self._iter_count += 1

  def end(self, session):
    if self._log_at_end:
      values = session.run(self._current_tensors)
      self._log_tensors(values)


def get_or_create_steps_per_run_variable():
  """Gets or creates the steps_per_run variable.

  In Estimator, the user-provided computation, the model_fn, is wrapped
  inside a tf.while_loop for peak performance. The iterations of the loop are
  specified by this variable, which adjusts its value on the CPU after each
  device program execution and before the next execution.

  The purpose of using a variable, rather than a constant, is to allow
  Estimator to adapt the device training iterations to the total number of
  steps specified by the user. For example, if the user sets steps_per_run to
  4 and steps to 10 in Estimator.train(), the steps_per_run
  variable will take the following values before each training run:

      - 1st execution: steps_per_run = 4
      - 2nd execution: steps_per_run = 4
      - 3rd execution: steps_per_run = 2

  As model_fn increases the global step once per train_op invocation, the
  global step is 10 after all executions, matching the steps=10 input passed
  in by the user.

  Returns:
    A TF non-trainable resource variable.

  Raises:
    RuntimeError: If multiple steps_per_run variables are found.
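
  For illustration, a minimal sketch of how a hook can adjust this variable
  once a session exists (mirroring `_MultiStepStopAtStepHook`; `session` below
  is assumed to be an active session):

  ```python
  steps_per_run = get_or_create_steps_per_run_variable()
  # Near the end of training, shrink the number of in-loop iterations.
  steps_per_run.load(2, session=session)
  ```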
  """
  graph = ops.get_default_graph()
  collection_name = "{}_{}".format(_HOOKS, _STEPS_PER_RUN_VAR)
  steps_per_run_vars = graph.get_collection(collection_name)
  if len(steps_per_run_vars) == 1:
    return steps_per_run_vars[0]
  elif len(steps_per_run_vars) > 1:
    raise RuntimeError("Multiple steps_per_run_var in collection.")

  with variable_scope.variable_scope(_HOOKS, reuse=variable_scope.AUTO_REUSE):
    return variable_scope.get_variable(
        _STEPS_PER_RUN_VAR,
        initializer=init_ops.ones_initializer(),
        shape=[],
        dtype=dtypes.int32,
        trainable=False,
        collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES],
        use_resource=True)


class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None, steps_per_run=1):
    """Initializes a `MultiStepStopAtStepHook`.

    This hook requests stop after either a number of steps have been
    executed or a last step has been reached. Only one of the two options can be
    specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the `after_run()`
    call.

    In Estimator, the user-provided computation, the model_fn, is wrapped
    inside a tf.while_loop for peak performance. The steps_per_run variable
    determines the number of iterations of the loop before returning to the CPU.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.
      steps_per_run: Number of steps executed per run call.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    if steps_per_run is None or steps_per_run < 1:
      raise ValueError("steps_per_run should be greater than 0")
    self._num_steps = num_steps
    self._last_step = last_step
    self._steps_per_run_initial_value = steps_per_run

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")
    self._steps_per_run_variable = get_or_create_steps_per_run_variable()

  def _update_steps_per_run_variable(self, global_step, session):
    steps = min(self._last_step - global_step,
                self._steps_per_run_initial_value)
    self._steps_per_run_variable.load(steps, session=session)

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._last_step is None:
      self._last_step = global_step + self._num_steps
    self._update_steps_per_run_variable(global_step, session)

  def after_run(self, run_context, run_values):
    # The global step cannot be retrieved via SessionRunArgs and before_run due
    # to a race condition in hook execution.
    global_step = run_context.session.run(self._global_step_tensor)
    if global_step >= self._last_step:
      run_context.request_stop()
    else:
      self._update_steps_per_run_variable(global_step, run_context.session)


@tf_export(v1=["train.StopAtStepHook"])
class StopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step.

  @compatibility(TF2)
  Please check this [notebook][notebook] on how to migrate the API to TF2.

  [notebook]:https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb

  @end_compatibility
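
  For illustration, a minimal sketch (the `train_op` below is a placeholder;
  a global step must exist in the graph):

  ```python
  stop_hook = tf.compat.v1.train.StopAtStepHook(last_step=10000)
  with tf.compat.v1.train.MonitoredTrainingSession(hooks=[stop_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```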
  """

  def __init__(self, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.

    This hook requests stop after either a number of steps have been
    executed or a last step has been reached. Only one of the two options can be
    specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the `after_run()`
    call.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self):
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")

  def after_create_session(self, session, coord):
    if self._last_step is None:
      global_step = session.run(self._global_step_tensor)
      self._last_step = global_step + self._num_steps

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results + 1
    if global_step >= self._last_step:
      # Check the latest global step to ensure that the targeted last step is
      # reached. The global_step read tensor holds the value of the global step
      # before the operation ran, so we do not know whether this session.run
      # incremented the global step; check its latest value here.

      step = run_context.session.run(self._global_step_tensor)
      if step >= self._last_step:
        run_context.request_stop()


@tf_export(v1=["train.CheckpointSaverListener"])
class CheckpointSaverListener:
  """Interface for listeners that take action before or after checkpoint save.

  `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook` is
  triggered, and provides callbacks at the following points:
   - before using the session
   - before each call to `Saver.save()`
   - after each call to `Saver.save()`
   - at the end of the session

  To use a listener, implement a class and pass the listener to a
  `CheckpointSaverHook`, as in this example:

  ```python
  class ExampleCheckpointSaverListener(CheckpointSaverListener):
    def begin(self):
      # You can add ops to the graph here.
      print('Starting the session.')
      self.your_tensor = ...

    def before_save(self, session, global_step_value):
      print('About to write a checkpoint')

    def after_save(self, session, global_step_value):
      print('Done writing checkpoint.')
      if decided_to_stop_training():
        return True

    def end(self, session, global_step_value):
      print('Done with the session.')

  ...
  listener = ExampleCheckpointSaverListener()
  saver_hook = tf.estimator.CheckpointSaverHook(
      checkpoint_dir, listeners=[listener])
  with tf.compat.v1.train.MonitoredTrainingSession(
      chief_only_hooks=[saver_hook]):
    ...
  ```

  A `CheckpointSaverListener` may simply take some action after every
  checkpoint save. It is also possible for the listener to use its own schedule
  to act less frequently, e.g. based on global_step_value. In this case,
  implementors should implement the `end()` method to handle actions related to
  the last checkpoint save. But the listener should not act twice if
  `after_save()` already handled this last checkpoint save.

  A `CheckpointSaverListener` can request training to be stopped by returning
  True from `after_save`. Please note that, in a replicated distributed
  training setting, only the `chief` should use this behavior. Otherwise each
  worker will do its own evaluation, which may be wasteful of resources.
  """

  def begin(self):
    pass

  def before_save(self, session, global_step_value):
    pass

  def after_save(self, session, global_step_value):
    pass

  def end(self, session, global_step_value):
    pass


@tf_export(v1=["train.CheckpointSaverHook"])
class CheckpointSaverHook(session_run_hook.SessionRunHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None,
               save_graph_def=True):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N seconds.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, used to get the saver object.
      listeners: List of `CheckpointSaverListener` subclass instances. Used for
        callbacks that run immediately before or after this hook saves the
        checkpoint.
      save_graph_def: Whether to save the GraphDef and MetaGraphDef to
        `checkpoint_dir`. The GraphDef is saved after the session is created as
        `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
        `model.ckpt-*.meta`.

    Raises:
      ValueError: If neither or both of `save_secs` and `save_steps` are set.
      ValueError: If both `saver` and `scaffold` are set.
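
    For illustration, a minimal sketch (the `train_op` below is a placeholder;
    a global step must exist in the graph, e.g. created via
    `tf.compat.v1.train.get_or_create_global_step()`):

    ```python
    saver_hook = tf.compat.v1.train.CheckpointSaverHook(
        checkpoint_dir="/tmp/model", save_steps=500)
    with tf.compat.v1.train.MonitoredTrainingSession(
        chief_only_hooks=[saver_hook]) as sess:
      while not sess.should_stop():
        sess.run(train_op)
    ```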
    """
    logging.info("Create CheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    self._listeners = listeners or []
    # Set a sufficiently high default so that it never skips checking the
    # actual global step counter -- unless the user overrides it with the right
    # value for the steps_per_run.
    self._steps_per_run = 1000000
    self._save_graph_def = save_graph_def

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._save_graph_def:
      # We write the graph and saver_def once, right after session creation.
      # We cannot do this in begin, since we let other hooks change the graph
      # and add variables in their begin calls. The graph is finalized after
      # all begin calls.
      training_util.write_graph(
          ops.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir, "graph.pbtxt")
    saver_def = self._get_saver().saver_def if self._get_saver() else None
    graph = ops.get_default_graph()
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
    self._summary_writer.add_graph(graph)
    self._summary_writer.add_meta_graph(meta_graph_def)
    # The checkpoint saved here is the state at step "global_step".
    self._save(session, global_step)
    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # Get the real value after the train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        self._timer.update_last_triggered_step(global_step)
        if self._save(run_context.session, global_step):
          run_context.request_stop()

  def end(self, session):
    last_step = session.run(self._global_step_tensor)
    if last_step != self._timer.last_triggered_step():
      self._save(session, last_step)
    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step):
    """Saves the latest checkpoint, returns should_stop."""
    logging.info("Calling checkpoint listeners before saving checkpoint %d...",
                 step)
    for l in self._listeners:
      l.before_save(session, step)

    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
    self._get_saver().save(session, self._save_path, global_step=step,
                           write_meta_graph=self._save_graph_def)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)
    logging.info("Calling checkpoint listeners after saving checkpoint %d...",
                 step)
    should_stop = False
    for l in self._listeners:
      if l.after_save(session, step):
        logging.info(
            "A CheckpointSaverListener requested that training be stopped. "
            "listener: {}".format(l))
        should_stop = True
    return should_stop

  def _get_saver(self):
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get the saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))

    self._saver = savers[0]
    return savers[0]


@tf_export(v1=["train.StepCounterHook"])
class StepCounterHook(session_run_hook.SessionRunHook):
  """Hook that counts steps per second."""

  def __init__(self,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):

    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError(
          "exactly one of every_n_steps and every_n_secs should be provided.")
    self._timer = SecondOrStepTimer(
        every_steps=every_n_steps, every_secs=every_n_secs)

    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._last_global_step = None
    self._steps_per_run = 1

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")
    self._summary_tag = training_util.get_global_step().op.name + "/sec"

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    steps_per_sec = elapsed_steps / elapsed_time
    if self._summary_writer is not None:
      summary = Summary(value=[
          Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
      ])
      self._summary_writer.add_summary(summary, global_step)
    logging.info("%s: %g", self._summary_tag, steps_per_sec)

  def after_run(self, run_context, run_values):
    _ = run_context

    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # Get the real value after the train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
            global_step)
        if elapsed_time is not None:
          self._log_and_record(elapsed_steps, elapsed_time, global_step)

    # Check whether the global step has been increased. Here, we do not use the
    # timer.last_triggered_step as the timer might record a different global
    # step value such that the comparison could be unreliable. For simplicity,
    # we just compare the stale_global_step with the previously recorded
    # version.
    if stale_global_step == self._last_global_step:
      # Here, we give a warning for the first 5 times we observe that the
      # global step has not been increased. For some Optimizers, the global
      # step is not increased each time by design. For example,
      # SyncReplicasOptimizer doesn't increase the global step in the worker's
      # main train step.
      logging.log_first_n(
          logging.WARN,
          "It seems that global step (tf.train.get_global_step) has not "
          "been increased. Current value (could be stale): %s vs previous "
          "value: %s. You could increase the global step by passing "
          "tf.train.get_global_step() to Optimizer.apply_gradients or "
          "Optimizer.minimize.", 5, stale_global_step, self._last_global_step)

    self._last_global_step = stale_global_step


@tf_export(v1=["train.NanLossDuringTrainingError"])
class NanLossDuringTrainingError(RuntimeError):

  def __str__(self):
    return "NaN loss during training."


@tf_export(v1=["train.NanTensorHook"])
class NanTensorHook(session_run_hook.SessionRunHook):
  """Monitors the loss tensor and stops training if loss is NaN.

  Can either fail with an exception or just stop training.
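
  For illustration, a minimal sketch (`loss` below is a placeholder for your
  own loss tensor):

  ```python
  nan_hook = tf.compat.v1.train.NanTensorHook(loss, fail_on_nan_loss=False)
  # Pass it to a monitored session, e.g.
  # tf.compat.v1.train.MonitoredTrainingSession(hooks=[nan_hook]).
  ```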
  """

  def __init__(self, loss_tensor, fail_on_nan_loss=True):
    """Initializes a `NanTensorHook`.

    Args:
      loss_tensor: `Tensor`, the loss tensor.
      fail_on_nan_loss: `bool`, whether to raise an exception when the loss is
        NaN.
    """
    self._loss_tensor = loss_tensor
    self._fail_on_nan_loss = fail_on_nan_loss

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._loss_tensor)

  def after_run(self, run_context, run_values):
    if np.isnan(run_values.results):
      failure_message = "Model diverged with loss = NaN."
      if self._fail_on_nan_loss:
        logging.error(failure_message)
        raise NanLossDuringTrainingError
      else:
        logging.warning(failure_message)
        # We don't raise an error; we just request stop without an exception.
        run_context.request_stop()


@tf_export(v1=["train.SummarySaverHook"])
class SummarySaverHook(session_run_hook.SessionRunHook):
  """Saves summaries every N steps."""

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir=None,
               summary_writer=None,
               scaffold=None,
               summary_op=None):
    """Initializes a `SummarySaverHook`.

    Args:
      save_steps: `int`, save summaries every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int`, save summaries every N seconds.
      output_dir: `string`, the directory to save the summaries to. Only used if
        no `summary_writer` is supplied.
      summary_writer: `SummaryWriter`. If `None` and an `output_dir` was passed,
        one will be created accordingly.
      scaffold: `Scaffold` from which to get the summary_op if it is not
        provided directly.
      summary_op: `Tensor` of type `string` containing the serialized `Summary`
        protocol buffer, or a list of such `Tensor`s. These are most likely
        output by TF summary methods like `tf.compat.v1.summary.scalar` or
        `tf.compat.v1.summary.merge_all`. A single tensor can be passed as is;
        multiple tensors must be passed as a list.

    Raises:
      ValueError: If neither or both of `scaffold` and `summary_op` are set.
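
    For illustration, a minimal sketch using a merged summary op (`loss` is a
    placeholder for your own scalar; a global step must exist in the graph):

    ```python
    tf.compat.v1.summary.scalar("loss", loss)
    summary_hook = tf.compat.v1.train.SummarySaverHook(
        save_steps=100,
        output_dir="/tmp/summaries",
        summary_op=tf.compat.v1.summary.merge_all())
    # Pass it to a monitored session via hooks=[summary_hook].
    ```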
    """
    if ((scaffold is None and summary_op is None) or
        (scaffold is not None and summary_op is not None)):
      raise ValueError(
          "Exactly one of scaffold or summary_op must be provided.")
    self._summary_op = summary_op
    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    # TODO(mdan): Throw an error if output_dir and summary_writer are None.

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SummarySaverHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    if self._request_summary:
      if self._get_summary_op() is not None:
        requests["summary"] = self._get_summary_op()

    return SessionRunArgs(requests)

  def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
      return

    stale_global_step = run_values.results["global_step"]
    global_step = stale_global_step + 1
    if self._next_step is None or self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)

    if self._next_step is None:
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
      self._timer.update_last_triggered_step(global_step)
      if "summary" in run_values.results:
        for summary in run_values.results["summary"]:
          self._summary_writer.add_summary(summary, global_step)

    self._next_step = global_step + 1

  def end(self, session=None):
    if self._summary_writer:
      self._summary_writer.flush()

  def _get_summary_op(self):
    """Fetches the summary op either from self._summary_op or self._scaffold.

    Returns:
      A list of summary `Tensor`s, or `None` if no summary op is available.
    """
    summary_op = None
    if self._summary_op is not None:
      summary_op = self._summary_op
    elif self._scaffold.summary_op is not None:
      summary_op = self._scaffold.summary_op

    if summary_op is None:
      return None

    if not isinstance(summary_op, list):
      return [summary_op]
    return summary_op


@tf_export(v1=["train.GlobalStepWaiterHook"])
class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
  """Delays execution until global step reaches `wait_until_step`.

  This hook delays execution until the global step reaches `wait_until_step`.
  It is used to gradually start workers in distributed settings. One example
  usage would be setting `wait_until_step=int(K*log(task_id+1))`, assuming that
  task_id=0 is the chief.
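
  For illustration, a minimal sketch for a non-chief worker (the concrete wait
  schedule below is a placeholder; a global step must exist in the graph):

  ```python
  wait_hook = tf.compat.v1.train.GlobalStepWaiterHook(wait_until_step=1000)
  # Pass it to the worker's monitored session via hooks=[wait_hook].
  ```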
  """

  def __init__(self, wait_until_step):
    """Initializes a `GlobalStepWaiterHook`.

    Args:
      wait_until_step: an `int`, the global step to wait for before starting.
    """
    self._wait_until_step = wait_until_step

  def begin(self):
    self._worker_is_started = False
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use GlobalStepWaiterHook.")

  def before_run(self, run_context):
    if self._worker_is_started:
      return None

    if self._wait_until_step <= 0:
      self._worker_is_started = True
      return None

    logging.info("Waiting for global step %d before starting training.",
                 self._wait_until_step)
    last_logged_step = 0
    while True:
      current_step = run_context.session.run(self._global_step_tensor)
      if current_step >= self._wait_until_step:
        self._worker_is_started = True
        return None
      if current_step - last_logged_step > 1000:
        logging.info(
            "Waiting for global step %d before starting training. "
            "Current step is %d.", self._wait_until_step, current_step)
        last_logged_step = current_step
      time.sleep(0.5)


@tf_export(v1=["train.FinalOpsHook"])
class FinalOpsHook(session_run_hook.SessionRunHook):
  """A hook which evaluates `Tensors` at the end of a session."""

  def __init__(self, final_ops, final_ops_feed_dict=None):
    """Initializes `FinalOpsHook` with ops to run at the end of the session.

    Args:
      final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
        to `Tensors`.
      final_ops_feed_dict: A feed dictionary to use when running `final_ops`.
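
    For illustration, a minimal sketch (`accuracy_value_op` and `train_op`
    below are placeholders for your own graph):

    ```python
    final_hook = tf.compat.v1.train.FinalOpsHook(
        final_ops={"accuracy": accuracy_value_op})
    with tf.compat.v1.train.MonitoredTrainingSession(
        hooks=[final_hook]) as sess:
      while not sess.should_stop():
        sess.run(train_op)
    print(final_hook.final_ops_values["accuracy"])
    ```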
    """
    self._final_ops = final_ops
    self._final_ops_feed_dict = final_ops_feed_dict
    self._final_ops_values = None

  @property
  def final_ops_values(self):
    return self._final_ops_values

  def end(self, session):
    if self._final_ops is not None:
      try:
        self._final_ops_values = session.run(
            self._final_ops, feed_dict=self._final_ops_feed_dict)
      except (errors.OutOfRangeError, StopIteration) as e:
        logging.warning(
            "An OutOfRangeError or StopIteration exception was raised by the "
            "code in FinalOpsHook. This typically means the Ops run by the "
            "FinalOpsHook have a dependency back to some input source, which "
            "should not happen. For example, for metrics in "
            "tf.estimator.Estimator, all metrics functions return two Ops: "
            "`value_op` and `update_op`. Estimator.evaluate calls the "
            "`update_op` for each batch of the data in the input source and, "
            "once it is exhausted, it calls the `value_op` to get the metric "
            "values. The `value_op` here should only have a dependency back to "
            "variable reads, rather than reading another batch from the input. "
            "Otherwise, the `value_op`, executed by `FinalOpsHook`, triggers "
            "another data read, which ends in OutOfRangeError/StopIteration. "
            "Please fix that.")
        raise e


@tf_export(v1=["train.FeedFnHook"])
class FeedFnHook(session_run_hook.SessionRunHook):
  """Runs `feed_fn` and sets the `feed_dict` accordingly."""

  def __init__(self, feed_fn):
    """Initializes a `FeedFnHook`.

    Args:
      feed_fn: function that takes no arguments and returns a `dict` of
        `Tensor`s to feed.
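
    For illustration, a minimal sketch (`x_placeholder` and `next_batch` below
    are placeholders for your own input pipeline):

    ```python
    feed_hook = tf.compat.v1.train.FeedFnHook(
        feed_fn=lambda: {x_placeholder: next_batch()})
    # Pass it to a monitored session via hooks=[feed_hook].
    ```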
    """
    self.feed_fn = feed_fn

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return session_run_hook.SessionRunArgs(
        fetches=None, feed_dict=self.feed_fn())


@tf_export(v1=["train.ProfilerHook"])
class ProfilerHook(session_run_hook.SessionRunHook):
  """Captures CPU/GPU profiling information every N steps or seconds.

  This produces files called "timeline-<step>.json", which are in Chrome
  Trace format.

  For more information see:
  https://github.com/catapult-project/catapult/blob/master/tracing/README.md
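
  For illustration, a minimal sketch (`train_op` below is a placeholder; a
  global step must exist in the graph):

  ```python
  profiler_hook = tf.compat.v1.train.ProfilerHook(
      save_steps=1000, output_dir="/tmp/profiles")
  with tf.compat.v1.train.MonitoredTrainingSession(
      hooks=[profiler_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```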
  """

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir="",
               show_dataflow=True,
               show_memory=False):
    """Initializes a hook that takes periodic profiling snapshots.

    The `options` and `run_metadata` arguments of `tf.Session.run` are used to
    collect metadata about execution. This hook requests full tracing via
    `options` and dumps the collected metadata in Chrome Trace format.

    Args:
      save_steps: `int`, save profile traces every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int` or `float`, save profile traces every N seconds.
      output_dir: `string`, the directory to save the profile traces to.
        Defaults to the current directory.
      show_dataflow: `bool`, if True, add flow events to the trace connecting
        producers and consumers of tensors.
      show_memory: `bool`, if True, add object snapshot events to the trace
        showing the sizes and lifetimes of tensors.
    """
    self._output_file = os.path.join(output_dir, "timeline-{}.json")
    self._file_writer = SummaryWriterCache.get(output_dir)
    self._show_dataflow = show_dataflow
    self._show_memory = show_memory
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)

  def begin(self):
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use ProfilerHook.")

  def before_run(self, run_context):
    self._request_summary = (
        self._next_step is not None and
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    opts = (
        config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
        if self._request_summary else None)

    return SessionRunArgs(requests, options=opts)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results["global_step"]
    if self._next_step is None:
      # Update the timer so that it does not activate until N steps or seconds
      # have passed.
      self._timer.update_last_triggered_step(stale_global_step)
    global_step = stale_global_step + 1
    if self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)
      self._timer.update_last_triggered_step(global_step)
      self._save(global_step, self._output_file.format(global_step),
                 run_values.run_metadata.step_stats)
      self._file_writer.add_run_metadata(run_values.run_metadata,
                                         "step_%d" % global_step)

    self._next_step = global_step + 1

  def _save(self, step, save_path, step_stats):
    logging.info("Saving timeline for %d into '%s'.", step, save_path)
    with gfile.Open(save_path, "w") as f:
      trace = timeline.Timeline(step_stats)
      f.write(
          trace.generate_chrome_trace_format(
              show_dataflow=self._show_dataflow, show_memory=self._show_memory))


def _as_graph_element(obj):
  """Retrieves the Graph element corresponding to `obj`."""
  graph = ops.get_default_graph()
  if not isinstance(obj, str):
    if not hasattr(obj, "graph") or obj.graph != graph:
      raise ValueError("Passed %s should have graph attribute that is equal "
                       "to current graph %s." % (obj, graph))
    return obj
  if ":" in obj:
    element = graph.as_graph_element(obj)
  else:
    element = graph.as_graph_element(obj + ":0")
    # Check that there is no :1 (i.e. it has a single output).
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      pass
    else:
      raise ValueError("Name %s is ambiguous, "
                       "as this `Operation` has multiple outputs "
                       "(at least 2)." % obj)
  return element