# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A wrapper of Session API which runs hooks."""

import abc
import os

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.checkpoint import checkpoint as trackable_util
from tensorflow.python.checkpoint import graph_view
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export

# The list of exceptions that we should recover from. Exceptions not in this
# list may terminate the job.
_PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError)

# Value that indicates no value was provided.
USE_DEFAULT = object()


@tf_export(v1=['train.Scaffold'])
class Scaffold:
  """Structure to create or gather pieces commonly needed to train a model.

  When you build a model for training you usually need ops to initialize
  variables, a `Saver` to checkpoint them, an op to collect summaries for
  the visualizer, and so on.

  Various libraries built on top of the core TensorFlow library take care of
  creating some or all of these pieces and storing them in well known
  collections in the graph.  The `Scaffold` class helps pick these pieces from
  the graph collections, creating and adding them to the collections if needed.

  If you call the scaffold constructor without any arguments, it will pick
  pieces from the collections, creating default ones if needed when
  `scaffold.finalize()` is called.  You can pass arguments to the constructor to
  provide your own pieces.  Pieces that you pass to the constructor are not
  added to the graph collections.

  The following pieces are directly accessible as attributes of the `Scaffold`
  object:

  * `saver`: A `tf.compat.v1.train.Saver` object taking care of saving the
    variables.  Picked from and stored into the `SAVERS` collection in the
    graph by default.
  * `init_op`: An op to run to initialize the variables.  Picked from and
    stored into the `INIT_OP` collection in the graph by default.
  * `ready_op`: An op to verify that the variables are initialized.  Picked
    from and stored into the `READY_OP` collection in the graph by default.
  * `ready_for_local_init_op`: An op to verify that global state has been
    initialized and it is alright to run `local_init_op`.  Picked from and
    stored into the `READY_FOR_LOCAL_INIT_OP` collection in the graph by
    default. This is needed when the initialization of local variables depends
    on the values of global variables.
  * `local_init_op`: An op to initialize the local variables.  Picked
    from and stored into the `LOCAL_INIT_OP` collection in the graph by default.
  * `summary_op`: An op to run and merge the summaries in the graph.  Picked
    from and stored into the `SUMMARY_OP` collection in the graph by default.

  You can also pass the following additional pieces to the constructor:

  * `init_feed_dict`: A session feed dictionary that should be used when
    running the init op.
  * `init_fn`: A callable to run after the init op to perform additional
    initializations.  The callable will be called as
    `init_fn(scaffold, session)`.

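  As a short, illustrative sketch of custom initialization (the op name
  `my_extra_init_op` is a placeholder, not part of this API):

  ```python
  def init_fn(scaffold, session):
    # Runs once, after variables have been initialized by `init_op`.
    session.run(my_extra_init_op)

  scaffold = tf.compat.v1.train.Scaffold(init_fn=init_fn)
  with tf.compat.v1.train.MonitoredSession(
      session_creator=tf.compat.v1.train.ChiefSessionCreator(
          scaffold=scaffold)) as session:
    ...
  ```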
98  """
99
100  def __init__(self,
101               init_op=None,
102               init_feed_dict=None,
103               init_fn=None,
104               ready_op=None,
105               ready_for_local_init_op=None,
106               local_init_op=None,
107               summary_op=None,
108               saver=None,
109               copy_from_scaffold=None,
110               local_init_feed_dict=None):
111    """Create a scaffold.
112
113    Args:
114      init_op: Optional op for initializing variables.
115      init_feed_dict: Optional session feed dictionary to use when running the
116        init_op.
117      init_fn: Optional function to use to initialize the model after running
118        the init_op.  Will be called as `init_fn(scaffold, session)`.
119      ready_op: Optional op to verify that the variables are initialized.  Must
120        return an empty 1D string tensor when the variables are initialized, or
121        a non-empty 1D string tensor listing the names of the non-initialized
122        variables.
123      ready_for_local_init_op: Optional op to verify that the global variables
124        are initialized and `local_init_op` can be run. Must return an empty 1D
125        string tensor when the global variables are initialized, or a non-empty
126        1D string tensor listing the names of the non-initialized global
127        variables.
128      local_init_op: Optional op to initialize local variables.
129      summary_op: Optional op to gather all summaries.  Must return a scalar
130        string tensor containing a serialized `Summary` proto.
131      saver: Optional `tf.compat.v1.train.Saver` object to use to save and
132        restore variables.  May also be a `tf.train.Checkpoint` object, in which
133        case object-based checkpoints are saved. This will also load some
134        object-based checkpoints saved from elsewhere, but that loading may be
135        fragile since it uses fixed keys rather than performing a full
136        graph-based match. For example if a variable has two paths from the
137        `Checkpoint` object because two `Model` objects share the `Layer` object
138        that owns it, removing one `Model` may change the keys and break
139        checkpoint loading through this API, whereas a graph-based match would
140        match the variable through the other `Model`.
141      copy_from_scaffold: Optional scaffold object to copy fields from. Its
142        fields will be overwritten by the provided fields in this function.
143      local_init_feed_dict: Optional session feed dictionary to use when running
144        the local_init_op.
145    """
146    if copy_from_scaffold is not None:
147      if not isinstance(copy_from_scaffold, Scaffold):
148        raise TypeError('copy_from_scaffold is not a Scaffold instance.')
      # We need `coalesce` since a Tensor is not converted to bool
      # automatically, so the common idiom of (a or b) does not work.
      coalesce = lambda a, b: a if a is not None else b
      init_op = coalesce(init_op, copy_from_scaffold.init_op)
      init_feed_dict = coalesce(init_feed_dict,
                                copy_from_scaffold.init_feed_dict)
      # Use the original init_fn provided by the user to init the new Scaffold.
      init_fn = coalesce(init_fn, copy_from_scaffold._user_init_fn)  # pylint: disable=protected-access
      ready_op = coalesce(ready_op, copy_from_scaffold.ready_op)
      ready_for_local_init_op = coalesce(
          ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op)
      local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op)
      local_init_feed_dict = coalesce(local_init_feed_dict,
                                      copy_from_scaffold.local_init_feed_dict)
      summary_op = coalesce(summary_op, copy_from_scaffold.summary_op)
      saver = coalesce(saver, copy_from_scaffold.saver)

    # NOTE(touts): modifying the init function to be passed the scaffold is a
    # hack to make it easy to find the saver.  Is there a better way?
    self._user_init_fn = init_fn
    if init_fn:
      self._init_fn = lambda sess: init_fn(self, sess)
    else:
      self._init_fn = None

    self._init_op = init_op
    self._init_feed_dict = init_feed_dict
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._local_init_op = local_init_op
    self._local_init_feed_dict = local_init_feed_dict
    self._summary_op = summary_op
    self._saver = saver

  def finalize(self):
    """Creates operations if needed and finalizes the graph."""
    if self._init_op is None:

      def default_init_op():
        return control_flow_ops.group(
            variables.global_variables_initializer(),
            resources.initialize_resources(resources.shared_resources()),
            ops.get_collection('saved_model_initializers'))

      self._init_op = Scaffold.get_or_default('init_op', ops.GraphKeys.INIT_OP,
                                              default_init_op)
    if self._ready_op is None:

      def default_ready_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(),
            resources.report_uninitialized_resources()
        ], 0)

      self._ready_op = Scaffold.get_or_default('ready_op',
                                               ops.GraphKeys.READY_OP,
                                               default_ready_op)
    if self._ready_for_local_init_op is None:

      def default_ready_for_local_init_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(
                variables.global_variables()),
            resources.report_uninitialized_resources(
                resources.shared_resources())
        ], 0)

      self._ready_for_local_init_op = Scaffold.get_or_default(
          'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
          default_ready_for_local_init_op)
    if self._local_init_op is None:
      self._local_init_op = Scaffold.get_or_default(
          'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
          Scaffold.default_local_init_op)
    if self._summary_op is None:
      self._summary_op = Scaffold.get_or_default('summary_op',
                                                 ops.GraphKeys.SUMMARY_OP,
                                                 summary.merge_all)
    # pylint: disable=g-long-lambda
    if self._saver is None:
      self._saver = training_saver._get_saver_or_default()  # pylint: disable=protected-access
    # pylint: enable=g-long-lambda
    if isinstance(self._saver, trackable_util.Checkpoint):
      self._saver = training_saver.Saver(
          var_list=graph_view.ObjectGraphView(
              self._saver).frozen_saveable_objects(),
          sharded=True)
    else:
      self._saver.build()

    ops.get_default_graph().finalize()
    logging.info('Graph was finalized.')
    return self

  @property
  def init_fn(self):
    return self._init_fn

  @property
  def init_op(self):
    return self._init_op

  @property
  def ready_op(self):
    return self._ready_op

  @property
  def ready_for_local_init_op(self):
    return self._ready_for_local_init_op

  @property
  def local_init_op(self):
    return self._local_init_op

  @property
  def local_init_feed_dict(self):
    return self._local_init_feed_dict

  @property
  def summary_op(self):
    return self._summary_op

  @property
  def saver(self):
    return self._saver

  @property
  def init_feed_dict(self):
    return self._init_feed_dict

  @staticmethod
  def get_or_default(arg_name, collection_key, default_constructor):
281    """Get from cache or create a default operation."""
    elements = ops.get_collection(collection_key)
    if elements:
      if len(elements) > 1:
        raise RuntimeError(
            'More than one item in the collection "%s". '
            'Please indicate which one to use by passing it to '
            'the tf.Scaffold constructor as: '
            'tf.Scaffold(%s=item to use)' % (collection_key, arg_name))
      return elements[0]
    op = default_constructor()
    if op is not None:
      ops.add_to_collection(collection_key, op)
    return op

  @staticmethod
  def default_local_init_op():
    """Returns an op that groups the default local init ops.

    This op is used during session initialization when a Scaffold is
    initialized without specifying the local_init_op arg. It includes
    `tf.compat.v1.local_variables_initializer`,
    `tf.compat.v1.tables_initializer`, and also
    initializes local session resources.

    Returns:
      The default Scaffold local init op.
    """
    return control_flow_ops.group(
        variables.local_variables_initializer(),
        lookup_ops.tables_initializer(),
        resources.initialize_resources(resources.local_resources()))


def _create_monitored_session_with_worker_context(
    worker_context,  # pylint: disable=missing-docstring
    scaffold,
    checkpoint_dir=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=None,
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=None,
    summary_dir=None,
    save_graph_def=True):
  all_hooks = []
  if hooks:
    all_hooks.extend(hooks)
  if chief_only_hooks and worker_context.is_chief:
    all_hooks.extend(chief_only_hooks)

  # We need to call save or summary ops on all workers since these ops may
  # contain collective ops; only running save ops on some workers would make
  # collective ops hang. Therefore, on those workers that don't need to
  # actually write checkpoints or summaries, we let them write to a temp
  # directory.
  # pylint: disable=protected-access
  if type(
      worker_context._strategy).__name__ in ('CollectiveAllReduceStrategy',
                                             'CollectiveAllReduceStrategyV1',
                                             'MultiWorkerMirroredStrategy'):
    if worker_context.task_type:
      tmpdir = 'tmp_%s_%d' % (worker_context.task_type, worker_context.task_id)
    else:
      tmpdir = 'tmp'

    if save_checkpoint_secs:
      logging.warning('Collective ops may deadlock with '
                      '`save_checkpoint_secs`, please use '
                      '`save_checkpoint_steps` instead. Clearing '
                      '`save_checkpoint_secs` and setting '
                      '`save_checkpoint_steps` to 1000 now.')
      save_checkpoint_secs = None
      save_checkpoint_steps = 1000
    if save_summaries_secs:
      logging.warning('Collective ops may run out of sync with '
                      '`save_summaries_secs`, please use '
                      '`save_summaries_steps` instead.')
  else:
    tmpdir = None

  summary_dir = summary_dir or checkpoint_dir
  if summary_dir and log_step_count_steps and log_step_count_steps > 0:
    if worker_context.should_save_summary:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=summary_dir, every_n_steps=log_step_count_steps))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=os.path.join(summary_dir, tmpdir),
              every_n_steps=log_step_count_steps))

  if (((save_summaries_steps and save_summaries_steps > 0) or
       (save_summaries_secs and save_summaries_secs > 0)) and summary_dir):
    if worker_context.should_save_summary:
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=summary_dir))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=os.path.join(summary_dir, tmpdir)))

  if (((save_checkpoint_secs and save_checkpoint_secs > 0) or
       (save_checkpoint_steps and save_checkpoint_steps > 0)) and
      checkpoint_dir):
    if worker_context.should_checkpoint:
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir,
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold,
              save_graph_def=save_graph_def))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              os.path.join(checkpoint_dir, tmpdir),
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold,
              save_graph_def=save_graph_def))

  logging.info('all_hooks %r', all_hooks)
  session_creator = worker_context.session_creator(
      scaffold,
      config=config,
      checkpoint_dir=checkpoint_dir,
      max_wait_secs=max_wait_secs)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None,
    save_graph_def=True):
  """Creates a `MonitoredSession` for training.

  For a chief, this utility sets the proper session initializer/restorer. It
  also creates hooks related to checkpoint and summary saving. For workers,
  this utility sets a proper session creator which waits for the chief to
  initialize/restore. Please check `tf.compat.v1.train.MonitoredSession` for
  more information.

  @compatibility(TF2)
  This API is not compatible with eager execution and `tf.function`. To migrate
  to TF2, rewrite the code to be compatible with eager execution. Check the
  [migration
  guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls)
  on replacing `Session.run` calls. In Keras, session hooks can be replaced by
  Callbacks e.g. [logging hook notebook](
  https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb)
  For more details please read [Better
  performance with tf.function](https://www.tensorflow.org/guide/function).
  @end_compatibility

  Args:
    master: `String` the TensorFlow master to use.
    is_chief: If `True`, it will take care of initialization and recovery of
      the underlying TensorFlow session. If `False`, it will wait on a chief
      to initialize or recover the TensorFlow session.
    checkpoint_dir: A string.  Optional path to a directory where to restore
      variables.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If not
      specified, a default one is created. It's used to finalize the graph.
    hooks: Optional list of `SessionRunHook` objects.
    chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks if
      `is_chief==True`, ignore otherwise.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If both `save_checkpoint_steps` and
      `save_checkpoint_secs` are set to `None`, then the default checkpoint
      saver isn't used. If both are provided, then only `save_checkpoint_secs`
      is used. Default 600.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used. Default 100.
    save_summaries_secs: The frequency, in secs, that the summaries are written
      to disk using a default summary saver.  If both `save_summaries_steps` and
      `save_summaries_secs` are set to `None`, then the default summary saver
      isn't used. Default not enabled.
    config: an instance of `tf.compat.v1.ConfigProto` proto used to configure
      the session. It's the `config` argument of constructor of
      `tf.compat.v1.Session`.
    stop_grace_period_secs: Number of seconds given to threads to stop after
      `close()` has been called.
    log_step_count_steps: The frequency, in number of global steps, that the
      global step/sec is logged.
    max_wait_secs: Maximum time workers should wait for the session to become
      available. This should be kept relatively short to help detect incorrect
      code, but sometimes may need to be increased if the chief takes a while to
      start up.
    save_checkpoint_steps: The frequency, in number of global steps, that a
      checkpoint is saved using a default checkpoint saver. If both
      `save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`, then
      the default checkpoint saver isn't used. If both are provided, then only
      `save_checkpoint_secs` is used. Default not enabled.
    summary_dir: A string.  Optional path to a directory where to save
      summaries. If None, checkpoint_dir is used instead.
    save_graph_def: Whether to save the GraphDef and MetaGraphDef to
      `checkpoint_dir`. The GraphDef is saved after the session is created as
      `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
      `model.ckpt-*.meta`.

  Returns:
    A `MonitoredSession` object.
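
  Example (an illustrative sketch; `train_op` and the directory path are
  placeholders):

  ```python
  with tf.compat.v1.train.MonitoredTrainingSession(
      checkpoint_dir='/tmp/my_model',
      save_checkpoint_secs=600,
      save_summaries_steps=100) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```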
516  """
517  if save_summaries_steps == USE_DEFAULT and save_summaries_secs == USE_DEFAULT:
518    save_summaries_steps = 100
519    save_summaries_secs = None
520  elif save_summaries_secs == USE_DEFAULT:
521    save_summaries_secs = None
522  elif save_summaries_steps == USE_DEFAULT:
523    save_summaries_steps = None
524
525  if (save_checkpoint_steps == USE_DEFAULT and
526      save_checkpoint_secs == USE_DEFAULT):
527    save_checkpoint_steps = None
528    save_checkpoint_secs = 600
529  elif save_checkpoint_secs == USE_DEFAULT:
530    save_checkpoint_secs = None
531  elif save_checkpoint_steps == USE_DEFAULT:
532    save_checkpoint_steps = None
533
534  scaffold = scaffold or Scaffold()
535  worker_context = distribute_coordinator_context.get_current_worker_context()
536
537  if worker_context:
538    return _create_monitored_session_with_worker_context(
539        worker_context,
540        scaffold,
541        checkpoint_dir=checkpoint_dir,
542        hooks=hooks,
543        chief_only_hooks=chief_only_hooks,
544        save_checkpoint_secs=save_checkpoint_secs,
545        save_summaries_steps=save_summaries_steps,
546        save_summaries_secs=save_summaries_secs,
547        config=config,
548        stop_grace_period_secs=stop_grace_period_secs,
549        log_step_count_steps=log_step_count_steps,
550        max_wait_secs=max_wait_secs,
551        save_checkpoint_steps=save_checkpoint_steps,
552        summary_dir=summary_dir,
553        save_graph_def=save_graph_def)
554
555  if not is_chief:
556    session_creator = WorkerSessionCreator(
557        scaffold=scaffold,
558        master=master,
559        config=config,
560        max_wait_secs=max_wait_secs)
561    return MonitoredSession(
562        session_creator=session_creator,
563        hooks=hooks or [],
564        stop_grace_period_secs=stop_grace_period_secs)
565
566  all_hooks = []
  if chief_only_hooks:
    all_hooks.extend(chief_only_hooks)
  session_creator = ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_dir=checkpoint_dir,
      master=master,
      config=config)

  summary_dir = summary_dir or checkpoint_dir
  if summary_dir:
    if log_step_count_steps and log_step_count_steps > 0:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=summary_dir, every_n_steps=log_step_count_steps))

    if (save_summaries_steps and
        save_summaries_steps > 0) or (save_summaries_secs and
                                      save_summaries_secs > 0):
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=summary_dir))

  if checkpoint_dir:
    if (save_checkpoint_secs and
        save_checkpoint_secs > 0) or (save_checkpoint_steps and
                                      save_checkpoint_steps > 0):
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir,
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold,
              save_graph_def=save_graph_def))

  if hooks:
    all_hooks.extend(hooks)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.SessionCreator'])
class SessionCreator(metaclass=abc.ABCMeta):
614  """A factory for tf.Session."""

  @abc.abstractmethod
  def create_session(self):
    raise NotImplementedError(
        'create_session is not implemented for {}.'.format(self))


@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a chief."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               checkpoint_filename_with_path=None):
    """Initializes a chief session creator.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string.  Optional path to a directory where to restore
        variables.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
    """
    self._checkpoint_dir = checkpoint_dir
    self._checkpoint_filename_with_path = checkpoint_filename_with_path
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config

  def _get_session_manager(self):
    """Gets or creates a SessionManager."""
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        local_init_feed_dict=self._scaffold.local_init_feed_dict,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().prepare_session(
        self._master,
        saver=self._scaffold.saver,
        checkpoint_dir=self._checkpoint_dir,
        checkpoint_filename_with_path=self._checkpoint_filename_with_path,
        config=self._config,
        init_op=self._scaffold.init_op,
        init_feed_dict=self._scaffold.init_feed_dict,
        init_fn=self._scaffold.init_fn)


@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a worker."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               max_wait_secs=30 * 60):
    """Initializes a worker session creator.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      max_wait_secs: Maximum time to wait for the session to become available.
    """
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config
    self._max_wait_secs = max_wait_secs

  def _get_session_manager(self):
    """Gets or creates a SessionManager."""
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        local_init_feed_dict=self._scaffold.local_init_feed_dict,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().wait_for_session(
        self._master, config=self._config, max_wait_secs=self._max_wait_secs)


class _MonitoredSession:
  """See `MonitoredSession` or `SingularMonitoredSession`."""

  def __init__(self,
               session_creator,
               hooks,
               should_recover,
               stop_grace_period_secs=120):
    """Sets up a Monitored or Hooked Session.

    Args:
      session_creator: A factory object to create session. Typically a
        `ChiefSessionCreator` or a `WorkerSessionCreator`.
      hooks: An iterable of `SessionRunHook` objects.
      should_recover: A bool. Indicates whether to recover from `AbortedError`
        and `UnavailableError` or not.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
    """
    self._graph_was_finalized = ops.get_default_graph().finalized
    self._hooks = hooks or []
    for h in self._hooks:
      h.begin()

    worker_context = distribute_coordinator_context.get_current_worker_context()
    if not session_creator and worker_context:
      session_creator = worker_context.session_creator()

    # Create the session.
    self._coordinated_creator = self._CoordinatedSessionCreator(
        session_creator=session_creator or ChiefSessionCreator(),
        hooks=self._hooks,
        stop_grace_period_secs=stop_grace_period_secs)
    if should_recover:
      self._sess = _RecoverableSession(self._coordinated_creator)
    else:
      self._sess = self._coordinated_creator.create_session()

  @property
  def graph(self):
    """The graph that was launched in this session."""
    if self._tf_sess() is None:
      return None
    return self._tf_sess().graph

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Run ops in the monitored session.

    This method is completely compatible with the `tf.Session.run()` method.

    Args:
      fetches: Same as `tf.Session.run()`.
      feed_dict: Same as `tf.Session.run()`.
      options: Same as `tf.Session.run()`.
      run_metadata: Same as `tf.Session.run()`.

    Returns:
      Same as `tf.Session.run()`.
    """
    return self._sess.run(
        fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

  def run_step_fn(self, step_fn):
    """Run ops using a step function.

    Args:
      step_fn: A function or a method with a single argument of type
        `StepContext`.  The function may use methods of the argument to perform
        computations with access to a raw session.  The returned value of the
        `step_fn` will be returned from `run_step_fn`, unless a stop is
        requested.  In that case, the next `should_stop` call will return True.
        Example usage:
            ```python
            with tf.Graph().as_default():
              c = tf.compat.v1.placeholder(dtypes.float32)
              v = tf.add(c, 4.0)
              w = tf.add(c, 0.5)
              def step_fn(step_context):
                a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
                if a <= 4.5:
                  step_context.request_stop()
                  return step_context.run_with_hooks(fetches=w,
                                                     feed_dict={c: 0.1})

              with tf.MonitoredSession() as session:
                while not session.should_stop():
                  a = session.run_step_fn(step_fn)
            ```
            Hooks interact with the `run_with_hooks()` call inside the
            `step_fn` as they do with a `MonitoredSession.run` call.
    Returns:
      The returned value of `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`.  It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if `step_fn` doesn't have a single argument called
        `step_context`. It may also optionally have `self` for cases when it
        belongs to an object.
    """
    step_fn_arguments = function_utils.fn_args(step_fn)
    if step_fn_arguments != ('step_context',) and step_fn_arguments != (
        'self',
        'step_context',
    ):
      raise ValueError(
          '`step_fn` may either have one `step_context` argument, or'
          ' `self` and `step_context` arguments if it\'s an instance'
          ' method. Got {} instead.'.format(step_fn_arguments))

    # `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`.
    # Setting `run_with_hooks` to `None` will cause `run_with_hooks` to be
    # `_CoordinatedSession.run` downstream in either case. This allows
    # `_PREEMPTION_ERRORS` to propagate from within `step_fn` to
    # `_RecoverableSession.run_step_fn`.
    return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)

  class StepContext:
    """Control flow instrument for the `step_fn` from `run_step_fn()`.

       Users of `step_fn` may perform `run()` calls without running hooks
       by accessing the `session`.  A `run()` call with hooks may be performed
       using `run_with_hooks()`.  Computation flow can be interrupted using
       `request_stop()`.
    """

    def __init__(self, session, run_with_hooks_fn):
      """Initializes the `step_context` argument for a `step_fn` invocation.

      Args:
        session: An instance of `tf.compat.v1.Session`.
        run_with_hooks_fn: A function for running fetches and hooks.
      """
      self._session = session
      self._run_with_hooks_fn = run_with_hooks_fn

    @property
    def session(self):
      return self._session

    def run_with_hooks(self, *args, **kwargs):
      """Same as `MonitoredSession.run`. Accepts the same arguments."""
      return self._run_with_hooks_fn(*args, **kwargs)

    def request_stop(self):
      """Exit the training loop by causing `should_stop()` to return `True`.

         Causes `step_fn` to exit by raising an exception.

      Raises:
        StopIteration
      """
      raise StopIteration('step_fn has requested the iterations to stop.')

  def should_stop(self):
    return self._sess is None or self._sess.should_stop()

  def close(self):
    self._close_internal()

  def __enter__(self):
    return self

  def __exit__(self, exception_type, exception_value, traceback):
    if exception_type in [errors.OutOfRangeError, StopIteration]:
      exception_type = None
    self._close_internal(exception_type)
    # __exit__ should return True to suppress an exception.
    return exception_type is None

  class _CoordinatedSessionCreator(SessionCreator):
    """Factory for _CoordinatedSession."""

    def __init__(self, session_creator, hooks, stop_grace_period_secs):
      self._session_creator = session_creator
      self._hooks = hooks
      self.coord = None
      self.tf_sess = None
      self._stop_grace_period_secs = stop_grace_period_secs

    def create_session(self):
      """Creates a coordinated session."""
      # Keep the tf_sess for unit testing.
      self.tf_sess = self._session_creator.create_session()
      # We don't want coordinator to suppress any exception.
      self.coord = coordinator.Coordinator(clean_stop_exception_types=[])
      if ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
        queue_runner.start_queue_runners(sess=self.tf_sess, coord=self.coord)
      # Inform the hooks that a new session has been created.
      for hook in self._hooks:
        hook.after_create_session(self.tf_sess, self.coord)
      return _CoordinatedSession(
          _HookedSession(self.tf_sess, self._hooks), self.coord,
          self._stop_grace_period_secs)

  def _close_internal(self, exception_type=None):
    try:
      if not exception_type:
        for h in self._hooks:
          h.end(self._coordinated_creator.tf_sess)
    finally:
      try:
        if self._sess is None:
          raise RuntimeError('Session is already closed.')
        self._sess.close()
      finally:
        self._sess = None
        self._coordinated_creator.tf_sess = None
        self._coordinated_creator.coord = None
        if not self._graph_was_finalized:
          ops.get_default_graph()._unsafe_unfinalize()  # pylint: disable=protected-access

  def _is_closed(self):
    """Return True if the monitored session is closed.

    For tests only.

    Returns:
      A boolean.
    """
    return self._coordinated_creator.tf_sess is None

  def _tf_sess(self):
    """Return underlying tf.compat.v1.Session object.

    Warning: accessing the returned object in user code is likely to cause races
    or "flaky tests".

    Returns:
      A tf.compat.v1.Session object.
    """
    return self._coordinated_creator.tf_sess


@tf_export(v1=['train.MonitoredSession'])
class MonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, recovery and hooks.

  Example usage:

  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with MonitoredSession(session_creator=ChiefSessionCreator(...),
                        hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```

  Initialization: At creation time the monitored session does the following
  things in the given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * creates the session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners
  * calls `hook.after_create_session()`

  Run: When `run()` is called, the monitored session does the following things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns the result of `session.run()` asked by user
  * if `AbortedError` or `UnavailableError` occurs, it recovers or
    reinitializes the session before executing the run() call again


  Exit: At the `close()`, the monitored session does the following things in
  order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses the `OutOfRange` error, which indicates that all inputs have
    been processed, if the monitored session is used as a context manager

  How to set `tf.compat.v1.Session` arguments:

  * In most cases you can set session arguments as follows:

  ```python
  MonitoredSession(
    session_creator=ChiefSessionCreator(master=..., config=...))
  ```

  * In a distributed setting, for a non-chief worker, you can use the
    following:

  ```python
  MonitoredSession(
    session_creator=WorkerSessionCreator(master=..., config=...))
  ```

  See `MonitoredTrainingSession` for an example usage based on chief or worker.

  Note: This is not a `tf.compat.v1.Session`. For example, it cannot do the
  following:

  * it cannot be set as the default session.
  * it cannot be sent to `saver.save`.
  * it cannot be sent to `tf.train.start_queue_runners`.

  @compatibility(TF2)
  This API is not compatible with eager execution and `tf.function`. To migrate
  to TF2, rewrite the code to be compatible with eager execution. Check the
  [migration
  guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls)
  on replacing `Session.run` calls. In Keras, session hooks can be replaced by
  Callbacks e.g. [logging hook notebook](
  https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb)
  For more details please read [Better
  performance with tf.function](https://www.tensorflow.org/guide/function).
  @end_compatibility

  Args:
    session_creator: A factory object to create session. Typically a
      `ChiefSessionCreator` which is the default one.
    hooks: An iterable of `SessionRunHook` objects.

  Returns:
    A MonitoredSession object.
  """

  def __init__(self,
               session_creator=None,
               hooks=None,
               stop_grace_period_secs=120):
    super(MonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=True,
        stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.SingularMonitoredSession'])
class SingularMonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, restoring, and hooks.

  Please note that this utility is not recommended for distributed settings.
  For distributed settings, please use `tf.compat.v1.train.MonitoredSession`.
  The differences between `MonitoredSession` and `SingularMonitoredSession`
  are:

  * `MonitoredSession` handles `AbortedError` and `UnavailableError` for
    distributed settings, but `SingularMonitoredSession` does not.
  * `MonitoredSession` can be created in `chief` or `worker` modes.
    `SingularMonitoredSession` is always created as `chief`.
  * You can access the raw `tf.compat.v1.Session` object used by
    `SingularMonitoredSession`, whereas in MonitoredSession the raw session is
    private. This can be used:
      - To `run` without hooks.
      - To save and restore.
  * All other functionality is identical.

  Example usage:
  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with SingularMonitoredSession(hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```
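
  Because the raw session is accessible, a checkpoint can also be written
  without going through hooks. A minimal sketch (illustrative; `train_op` and
  the checkpoint path are placeholders):

  ```python
  saver = tf.compat.v1.train.Saver()
  with SingularMonitoredSession() as sess:
    for _ in range(1000):
      sess.run(train_op)
    # `raw_session()` exposes the underlying tf.compat.v1.Session, so it can
    # be handed to `Saver.save` (a plain MonitoredSession cannot).
    saver.save(sess.raw_session(), '/tmp/my_model/model.ckpt')
  ```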

  Initialization: At creation time the hooked session does the following things
  in the given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * creates the session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners

  Run: When `run()` is called, the hooked session does the following things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns the result of `session.run()` asked by user

  Exit: At the `close()`, the hooked session does the following things in
  order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses the `OutOfRange` error, which indicates that all inputs have
    been processed, if the `SingularMonitoredSession` is used as a context
    manager.

  @compatibility(TF2)
  This API is not compatible with eager execution and `tf.function`. To migrate
  to TF2, rewrite the code to be compatible with eager execution. Check the
  [migration
  guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls)
  on replacing `Session.run` calls. In Keras, session hooks can be replaced by
  Callbacks e.g. [logging hook notebook](
  https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb)
  For more details please read [Better
  performance with tf.function](https://www.tensorflow.org/guide/function).
  @end_compatibility
  """

  def __init__(self,
               hooks=None,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               stop_grace_period_secs=120,
               checkpoint_filename_with_path=None):
    """Creates a SingularMonitoredSession.

    Args:
      hooks: An iterable of `SessionRunHook` objects.
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string.  Optional path to a directory where to restore
        variables.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
      checkpoint_filename_with_path: A string. Optional path to a checkpoint
        file from which to restore variables.
    """
    session_creator = ChiefSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
    super(SingularMonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=False,
        stop_grace_period_secs=stop_grace_period_secs)

  def raw_session(self):
1159    """Returns underlying `TensorFlow.Session` object."""
    return self._tf_sess()


class _WrappedSession:
  """Wrapper around a `tf.compat.v1.Session`.

  This wrapper is used as a base class for various session wrappers
  that provide additional functionality such as monitoring, coordination,
  and recovery.

  In addition to the methods exported by `SessionInterface` the wrapper
  provides a method to check for stop and never raises exceptions from
  calls to `close()`.
  """

  def __init__(self, sess):
    """Creates a `_WrappedSession`.

    Args:
      sess: A `tf.compat.v1.Session` or `_WrappedSession` object.  The wrapped
        session.
    """
    self._sess = sess
    self._wrapped_is_stoppable = isinstance(self._sess, _WrappedSession)

  @property
  def graph(self):
    return self._sess.graph

  @property
  def sess_str(self):
    return self._sess.sess_str

  def should_stop(self):
    """Return true if this session should not be used anymore.

    Always return True if the session was closed.

    Returns:
      True if the session should stop, False otherwise.
    """
    if self._check_stop():
      return True
    if self._sess:
      return self._wrapped_is_stoppable and self._sess.should_stop()
    return True

  def _check_stop(self):
    """Hook for subclasses to provide their own stop condition.

    Returns:
      True if the session should stop, False otherwise.
    """
    return False

  def close(self):
    if self._sess:
      try:
        self._sess.close()
      except _PREEMPTION_ERRORS as e:
        logging.error(
            'An error occurred when attempting to close the '
            'session. This may be due to a preemption in a '
            'connected worker or parameter server. Error: %s', e)
      finally:
        self._sess = None

  def run(self, *args, **kwargs):
    return self._sess.run(*args, **kwargs)

  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
    # `_RecoverableSession` sets `run_with_hooks` to `_CoordinatedSession.run`.
    # It is `None` when called from `_CoordinatedSession`. In that case
    # `self.run` is `_CoordinatedSession.run`.
    run_with_hooks = run_with_hooks or self.run
    return step_fn(_MonitoredSession.StepContext(raw_session, run_with_hooks))


class _RecoverableSession(_WrappedSession):
  """A wrapped session that recreates a session upon certain kinds of errors.

  The constructor is passed a SessionCreator object, not a session.

  Calls to `run()` are delegated to the wrapped session.  If a call raises the
  exception `tf.errors.AbortedError` or `tf.errors.UnavailableError`, the
  wrapped session is closed, and a new one is created by calling the factory
  again.
  """

  def __init__(self, sess_creator):
    """Create a new `_RecoverableSession`.

    The value returned by calling `sess_creator.create_session()` will be the
    session wrapped by this recoverable session.

    Args:
      sess_creator: A `SessionCreator` to be wrapped by this recoverable
        session.
1257    """
1258    self._sess_creator = sess_creator
1259    _WrappedSession.__init__(self, self._create_session())
1260
1261  def _create_session(self):
1262    while True:
1263      try:
1264        return self._sess_creator.create_session()
1265      except _PREEMPTION_ERRORS as e:
1266        logging.info(
1267            'An error was raised while a session was being created. '
1268            'This may be due to a preemption of a connected worker '
1269            'or parameter server. A new session will be created. '
1270            'This error may also occur due to a gRPC failure caused '
1271            'by high memory or network bandwidth usage in the '
1272            'parameter servers. If this error occurs repeatedly, try '
1273            'increasing the number of parameter servers assigned to '
1274            'the job. Error: %s', e)
1275
1276  def _check_stop(self):
1277    try:
1278      if self._sess:
1279        return self._sess._check_stop()  # pylint: disable=protected-access
1280      else:
1281        return True
1282    except _PREEMPTION_ERRORS as e:
1283      logging.info(
1284          'An error was raised while considering whether the '
1285          'session is complete. This may be due to a preemption in '
1286          'a connected worker or parameter server. The current '
1287          'session will be closed and a new session will be '
1288          'created. This error may also occur due to a gRPC failure '
1289          'caused by high memory or network bandwidth usage in the '
1290          'parameter servers. If this error occurs repeatedly, try '
1291          'increasing the number of parameter servers assigned to '
1292          'the job. Error: %s', e)
1293      self.close()
1294      self._sess = self._create_session()
1295      # Since we have just recreated the session, the overall computation should
1296      # not stop:
1297      return False
1298    except Exception:  # pylint: disable=broad-except
1299      # `should_stop` should return True instead of raising an exception.
1300      return True
1301
1302  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
1303    while True:
1304      try:
1305        if not self._sess:
1306          self._sess = self._create_session()
1307        return self._sess.run(
1308            fetches,
1309            feed_dict=feed_dict,
1310            options=options,
1311            run_metadata=run_metadata)
1312      except _PREEMPTION_ERRORS as e:
1313        logging.info(
1314            'An error was raised. This may be due to a preemption in '
1315            'a connected worker or parameter server. The current '
1316            'session will be closed and a new session will be '
1317            'created. This error may also occur due to a gRPC failure '
1318            'caused by high memory or network bandwidth usage in the '
1319            'parameter servers. If this error occurs repeatedly, try '
1320            'increasing the number of parameter servers assigned to '
1321            'the job. Error: %s', e)
1322        self.close()
1323        self._sess = None
1324
1325  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
1326    while True:
1327      try:
1328        if not self._sess:
1329          self._sess = self._create_session()
1330
1331        run_with_hooks = self._sess.run
1332        return self._sess.run_step_fn(step_fn, raw_session, run_with_hooks)
1333      except _PREEMPTION_ERRORS as e:
1334        logging.info(
1335            'An error was raised. This may be due to a preemption in '
1336            'a connected worker or parameter server. The current '
1337            'session will be closed and a new session will be '
1338            'created. This error may also occur due to a gRPC failure '
1339            'caused by high memory or network bandwidth usage in the '
1340            'parameter servers. If this error occurs repeatedly, try '
1341            'increasing the number of parameter servers assigned to '
1342            'the job. Error: %s', e)
1343        self.close()
1344        self._sess = None
1345
1346
1347class _CoordinatedSession(_WrappedSession):
1348  """A wrapped session that works with a `tf.Coordinator`.
1349
1350  Calls to `run()` are delegated to the wrapped session.  If a call
1351  raises an exception, the exception is reported to the coordinator.
1352
1353  In addition, after each call to `run()` this session ask the coordinator if
1354  the session should stop.  In that case it will join all the threads
1355  registered with the coordinator before returning.
1356
1357  If the coordinator was requested to stop with an exception, that exception
1358  will be re-raised from the call to `run()`.
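
  For example, a minimal usage sketch (assuming `sess` is an existing
  `tf.compat.v1.Session` and `train_op` was built elsewhere):

  ```python
  coord = tf.train.Coordinator()
  # Start queue-runner threads, registering them with the coordinator.
  tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
  coord_sess = _CoordinatedSession(sess, coord)
  while not coord_sess.should_stop():
    coord_sess.run(train_op)
  # Requests a stop and joins the coordinated threads.
  coord_sess.close()
  ```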
  """

  def __init__(self, sess, coord, stop_grace_period_secs=120):
    """Create a new `_CoordinatedSession`.

    Args:
      sess: A `tf.compat.v1.Session` object.  The wrapped session.
      coord: A `tf.train.Coordinator` object.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
    """
    _WrappedSession.__init__(self, sess)
    self._coord = coord
    self._stop_grace_period_secs = stop_grace_period_secs

  def _check_stop(self):
    # If the coordinator was asked to stop due to an exception, then it needs
    # to be propagated to this stack.
    self._coord.raise_requested_exception()
    # At this point, no exceptions are recorded in the coordinator.
    return self._coord.should_stop()

  def close(self):
    self._coord.request_stop()
    try:
      self._coord.join(
          stop_grace_period_secs=self._stop_grace_period_secs,
          ignore_live_threads=True)
    finally:
      try:
        _WrappedSession.close(self)
      except Exception:  # pylint: disable=broad-except
        # We intentionally suppress exceptions from the close() here since
        # useful exceptions are already reported by join().
        pass

  def run(self, *args, **kwargs):
    try:
      return self._sess.run(*args, **kwargs)
    except _PREEMPTION_ERRORS:
      raise
    except Exception as original_exception:  # pylint: disable=broad-except
      # A non-preemption error could have been caused by a preemption error
      # in the coordinator. If this is the case, raise that exception instead,
      # since it's the root cause. Otherwise, stick to the `original_exception`.
      try:
        self._coord.raise_requested_exception()
      except _PREEMPTION_ERRORS:
        raise
      except Exception:  # pylint: disable=broad-except
        raise original_exception from None
      else:
        raise


class _HookedSession(_WrappedSession):
  """A _WrappedSession that calls hooks during calls to run().

  The list of hooks to call is passed in the constructor.  Before each call
  to `run()` the session calls the `before_run()` method of the hooks, which
  can return additional ops or tensors to run.  These are added to the
  arguments of the call to `run()`.

  When the `run()` call finishes, the session calls the `after_run()` methods
  of the hooks, passing the values returned by the `run()` call corresponding
  to the ops and tensors that each hook requested.

  If any of the hooks requests a stop via the `run_context`, the session will
  be marked as needing to stop, and its `should_stop()` method will then
  return `True`.
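
  For example, a minimal sketch (assuming `sess` is an existing
  `tf.compat.v1.Session` and `loss` and `train_op` were built elsewhere):

  ```python
  class _StopOnLowLossHook(tf.compat.v1.train.SessionRunHook):
    """Requests a stop once `loss` drops below a threshold."""

    def before_run(self, run_context):
      # Ask the underlying `run()` call to also fetch `loss`.
      return tf.compat.v1.train.SessionRunArgs(loss)

    def after_run(self, run_context, run_values):
      # `run_values.results` holds the fetched value of `loss`.
      if run_values.results < 0.01:
        run_context.request_stop()

  hooked_sess = _HookedSession(sess, [_StopOnLowLossHook()])
  while not hooked_sess.should_stop():
    hooked_sess.run(train_op)
  ```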
  """

  def __init__(self, sess, hooks):
    """Initializes a _HookedSession object.

    Args:
      sess: A `tf.compat.v1.Session` or a `_WrappedSession` object.
      hooks: An iterable of `SessionRunHook` objects.
    """

    _WrappedSession.__init__(self, sess)
    self._hooks = hooks
    self._should_stop = False

  def _check_stop(self):
    """See base class."""
    return self._should_stop

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """See base class."""
    if self.should_stop():
      raise RuntimeError('Run called even after should_stop requested.')

    actual_fetches = {'caller': fetches}

    run_context = session_run_hook.SessionRunContext(
        original_args=session_run_hook.SessionRunArgs(fetches, feed_dict),
        session=self._sess)

    options = options or config_pb2.RunOptions()
    feed_dict = self._call_hook_before_run(run_context, actual_fetches,
                                           feed_dict, options)

    # Do session run.
    run_metadata = run_metadata or config_pb2.RunMetadata()
    outputs = _WrappedSession.run(
        self,
        fetches=actual_fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

    for hook in self._hooks:
      hook.after_run(
          run_context,
          session_run_hook.SessionRunValues(
              results=outputs[hook] if hook in outputs else None,
              options=options,
              run_metadata=run_metadata))
    self._should_stop = self._should_stop or run_context.stop_requested

    return outputs['caller']

  def _call_hook_before_run(self, run_context, fetch_dict, user_feed_dict,
                            options):
    """Calls hooks.before_run and handles requests from hooks."""
    hook_feeds = {}
    for hook in self._hooks:
      request = hook.before_run(run_context)
      if request is not None:
        if request.fetches is not None:
          fetch_dict[hook] = request.fetches
        if request.feed_dict:
          self._raise_if_feeds_intersects(hook_feeds, request.feed_dict,
                                          'Same tensor is fed by two hooks.')
          hook_feeds.update(request.feed_dict)
        if request.options:
          self._merge_run_options(options, request.options)

    if not hook_feeds:
      return user_feed_dict

    if not user_feed_dict:
      return hook_feeds

    self._raise_if_feeds_intersects(
        user_feed_dict, hook_feeds,
        'Same tensor is fed by a SessionRunHook and user.')
    hook_feeds.update(user_feed_dict)
    return hook_feeds

  def _raise_if_feeds_intersects(self, feeds1, feeds2, message):
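    # Rejects ambiguous feeds: e.g. two hooks, or a hook and the caller,
    # feeding the same tensor, possibly with different values.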
    intersection = set(feeds1.keys()) & set(feeds2.keys())
    if intersection:
      raise RuntimeError(message + ' Conflict(s): ' + str(list(intersection)))

  def _merge_run_options(self, options, incoming_options):
    """Merge two instances of RunOptions into the first one.

    During the merge, numerical fields (trace_level, timeout_in_ms,
    inter_op_thread_pool) are set to the larger of the two values. Boolean
    fields are set to the logical OR of the two. debug_tensor_watch_opts of
    the original options is extended with that of the incoming one.

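    For example, merging incoming options that have
    `trace_level=RunOptions.FULL_TRACE` into options that have
    `trace_level=RunOptions.NO_TRACE` leaves `trace_level` at `FULL_TRACE`,
    since `FULL_TRACE` (3) is numerically larger than `NO_TRACE` (0).
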
    Args:
      options: The options to merge into.
      incoming_options: The options to be merged into the first argument.
    """
    options.trace_level = max(options.trace_level, incoming_options.trace_level)
    options.timeout_in_ms = max(options.timeout_in_ms,
                                incoming_options.timeout_in_ms)
    options.inter_op_thread_pool = max(options.inter_op_thread_pool,
                                       incoming_options.inter_op_thread_pool)
    options.output_partition_graphs = max(
        options.output_partition_graphs,
        incoming_options.output_partition_graphs)
    options.debug_options.debug_tensor_watch_opts.extend(
        incoming_options.debug_options.debug_tensor_watch_opts)
    options.debug_options.reset_disk_byte_usage = (
        options.debug_options.reset_disk_byte_usage or
        incoming_options.debug_options.reset_disk_byte_usage)
    options.report_tensor_allocations_upon_oom = (
        options.report_tensor_allocations_upon_oom or
        incoming_options.report_tensor_allocations_upon_oom)
