# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A wrapper of Session API which runs hooks."""

import abc
import os

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.checkpoint import checkpoint as trackable_util
from tensorflow.python.checkpoint import graph_view
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export

# The list of exceptions that we should recover from. Exceptions not in this
# list may terminate the job.
_PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError)

# Value that indicates no value was provided.
USE_DEFAULT = object()


@tf_export(v1=['train.Scaffold'])
class Scaffold:
  """Structure to create or gather pieces commonly needed to train a model.

  When you build a model for training you usually need ops to initialize
  variables, a `Saver` to checkpoint them, an op to collect summaries for
  the visualizer, and so on.

  Various libraries built on top of the core TensorFlow library take care of
  creating some or all of these pieces and storing them in well known
  collections in the graph. The `Scaffold` class helps pick these pieces from
  the graph collections, creating and adding them to the collections if
  needed.

  If you call the scaffold constructor without any arguments, it will pick
  pieces from the collections, creating default ones if needed when
  `scaffold.finalize()` is called. You can pass arguments to the constructor
  to provide your own pieces. Pieces that you pass to the constructor are not
  added to the graph collections.

  The following pieces are directly accessible as attributes of the `Scaffold`
  object:

  * `saver`: A `tf.compat.v1.train.Saver` object taking care of saving the
    variables. Picked from and stored into the `SAVERS` collection in the
    graph by default.
  * `init_op`: An op to run to initialize the variables. Picked from and
    stored into the `INIT_OP` collection in the graph by default.
  * `ready_op`: An op to verify that the variables are initialized. Picked
    from and stored into the `READY_OP` collection in the graph by default.
  * `ready_for_local_init_op`: An op to verify that global state has been
    initialized and it is alright to run `local_init_op`. Picked from and
    stored into the `READY_FOR_LOCAL_INIT_OP` collection in the graph by
    default. This is needed when the initialization of local variables
    depends on the values of global variables.
  * `local_init_op`: An op to initialize the local variables. Picked from and
    stored into the `LOCAL_INIT_OP` collection in the graph by default.
  * `summary_op`: An op to run and merge the summaries in the graph. Picked
    from and stored into the `SUMMARY_OP` collection in the graph by default.

  You can also pass the following additional pieces to the constructor:

  * `init_feed_dict`: A session feed dictionary that should be used when
    running the init op.
  * `init_fn`: A callable to run after the init op to perform additional
    initializations. The callable will be called as
    `init_fn(scaffold, session)`.
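
  Example, a minimal sketch that overrides one piece and lets the rest
  default (`my_init_fn` is an illustrative placeholder for user code):

  ```python
  def my_init_fn(scaffold, sess):
    # Runs after the init op; `scaffold.saver` is available here for any
    # extra, custom restoration logic.
    ...

  scaffold = tf.compat.v1.train.Scaffold(init_fn=my_init_fn)
  with tf.compat.v1.train.MonitoredTrainingSession(scaffold=scaffold) as sess:
    ...
  ```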
  """

  def __init__(self,
               init_op=None,
               init_feed_dict=None,
               init_fn=None,
               ready_op=None,
               ready_for_local_init_op=None,
               local_init_op=None,
               summary_op=None,
               saver=None,
               copy_from_scaffold=None,
               local_init_feed_dict=None):
    """Create a scaffold.

    Args:
      init_op: Optional op for initializing variables.
      init_feed_dict: Optional session feed dictionary to use when running the
        init_op.
      init_fn: Optional function to use to initialize the model after running
        the init_op. Will be called as `init_fn(scaffold, session)`.
      ready_op: Optional op to verify that the variables are initialized. Must
        return an empty 1D string tensor when the variables are initialized,
        or a non-empty 1D string tensor listing the names of the
        non-initialized variables.
      ready_for_local_init_op: Optional op to verify that the global variables
        are initialized and `local_init_op` can be run. Must return an empty
        1D string tensor when the global variables are initialized, or a
        non-empty 1D string tensor listing the names of the non-initialized
        global variables.
      local_init_op: Optional op to initialize local variables.
      summary_op: Optional op to gather all summaries. Must return a scalar
        string tensor containing a serialized `Summary` proto.
      saver: Optional `tf.compat.v1.train.Saver` object to use to save and
        restore variables. May also be a `tf.train.Checkpoint` object, in
        which case object-based checkpoints are saved. This will also load
        some object-based checkpoints saved from elsewhere, but that loading
        may be fragile since it uses fixed keys rather than performing a full
        graph-based match. For example if a variable has two paths from the
        `Checkpoint` object because two `Model` objects share the `Layer`
        object that owns it, removing one `Model` may change the keys and
        break checkpoint loading through this API, whereas a graph-based match
        would match the variable through the other `Model`.
      copy_from_scaffold: Optional scaffold object to copy fields from. Its
        fields will be overwritten by the provided fields in this function.
      local_init_feed_dict: Optional session feed dictionary to use when
        running the local_init_op.
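
    For example, `copy_from_scaffold` can seed a new scaffold from an existing
    one and override only selected pieces (a sketch; `base_scaffold` and
    `my_summary_op` are illustrative placeholders):

    ```python
    # Fields not passed here are coalesced from `base_scaffold`.
    new_scaffold = tf.compat.v1.train.Scaffold(
        summary_op=my_summary_op, copy_from_scaffold=base_scaffold)
    ```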
    """
    if copy_from_scaffold is not None:
      if not isinstance(copy_from_scaffold, Scaffold):
        raise TypeError('copy_from_scaffold is not a Scaffold instance.')
      # We need coalesce since Tensor is not converted to bool automatically,
      # so the common idiom of (a or b) does not work.
      coalesce = lambda a, b: a if a is not None else b
      init_op = coalesce(init_op, copy_from_scaffold.init_op)
      init_feed_dict = coalesce(init_feed_dict,
                                copy_from_scaffold.init_feed_dict)
      # Use the original init_fn provided by the user to init the new
      # Scaffold.
      init_fn = coalesce(init_fn, copy_from_scaffold._user_init_fn)  # pylint: disable=protected-access
      ready_op = coalesce(ready_op, copy_from_scaffold.ready_op)
      ready_for_local_init_op = coalesce(
          ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op)
      local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op)
      local_init_feed_dict = coalesce(local_init_feed_dict,
                                      copy_from_scaffold.local_init_feed_dict)
      summary_op = coalesce(summary_op, copy_from_scaffold.summary_op)
      saver = coalesce(saver, copy_from_scaffold.saver)

    # NOTE(touts): modifying the init function to be passed the scaffold is a
    # hack to make it easy to find the saver. Is there a better way?
    self._user_init_fn = init_fn
    if init_fn:
      self._init_fn = lambda sess: init_fn(self, sess)
    else:
      self._init_fn = None

    self._init_op = init_op
    self._init_feed_dict = init_feed_dict
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._local_init_op = local_init_op
    self._local_init_feed_dict = local_init_feed_dict
    self._summary_op = summary_op
    self._saver = saver

  def finalize(self):
    """Creates operations if needed and finalizes the graph."""
    if self._init_op is None:

      def default_init_op():
        return control_flow_ops.group(
            variables.global_variables_initializer(),
            resources.initialize_resources(resources.shared_resources()),
            ops.get_collection('saved_model_initializers'))

      self._init_op = Scaffold.get_or_default('init_op', ops.GraphKeys.INIT_OP,
                                              default_init_op)
    if self._ready_op is None:

      def default_ready_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(),
            resources.report_uninitialized_resources()
        ], 0)

      self._ready_op = Scaffold.get_or_default('ready_op',
                                               ops.GraphKeys.READY_OP,
                                               default_ready_op)
    if self._ready_for_local_init_op is None:

      def default_ready_for_local_init_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(
                variables.global_variables()),
            resources.report_uninitialized_resources(
                resources.shared_resources())
        ], 0)

      self._ready_for_local_init_op = Scaffold.get_or_default(
          'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
          default_ready_for_local_init_op)
    if self._local_init_op is None:
      self._local_init_op = Scaffold.get_or_default(
          'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
          Scaffold.default_local_init_op)
    if self._summary_op is None:
      self._summary_op = Scaffold.get_or_default('summary_op',
                                                 ops.GraphKeys.SUMMARY_OP,
                                                 summary.merge_all)
    # pylint: disable=g-long-lambda
    if self._saver is None:
      self._saver = training_saver._get_saver_or_default()  # pylint: disable=protected-access
    # pylint: enable=g-long-lambda
    if isinstance(self._saver, trackable_util.Checkpoint):
      self._saver = training_saver.Saver(
          var_list=graph_view.ObjectGraphView(
              self._saver).frozen_saveable_objects(),
          sharded=True)
    else:
      self._saver.build()

    ops.get_default_graph().finalize()
    logging.info('Graph was finalized.')
    return self

  @property
  def init_fn(self):
    return self._init_fn

  @property
  def init_op(self):
    return self._init_op

  @property
  def ready_op(self):
    return self._ready_op

  @property
  def ready_for_local_init_op(self):
    return self._ready_for_local_init_op

  @property
  def local_init_op(self):
    return self._local_init_op

  @property
  def local_init_feed_dict(self):
    return self._local_init_feed_dict

  @property
  def summary_op(self):
    return self._summary_op

  @property
  def saver(self):
    return self._saver

  @property
  def init_feed_dict(self):
    return self._init_feed_dict

  @staticmethod
  def get_or_default(arg_name, collection_key, default_constructor):
    """Gets an existing op from the collection or creates a default one."""
    elements = ops.get_collection(collection_key)
    if elements:
      if len(elements) > 1:
        raise RuntimeError(
            'More than one item in the collection "%s". '
            'Please indicate which one to use by passing it to '
            'the tf.Scaffold constructor as: '
            'tf.Scaffold(%s=item to use)' % (collection_key, arg_name))
      return elements[0]
    op = default_constructor()
    if op is not None:
      ops.add_to_collection(collection_key, op)
    return op

  @staticmethod
  def default_local_init_op():
    """Returns an op that groups the default local init ops.

    This op is used during session initialization when a Scaffold is
    initialized without specifying the local_init_op arg. It includes
    `tf.compat.v1.local_variables_initializer`,
    `tf.compat.v1.tables_initializer`, and also initializes local session
    resources.

    Returns:
      The default Scaffold local init op.
    """
    return control_flow_ops.group(
        variables.local_variables_initializer(),
        lookup_ops.tables_initializer(),
        resources.initialize_resources(resources.local_resources()))


def _create_monitored_session_with_worker_context(
    worker_context,  # pylint: disable=missing-docstring
    scaffold,
    checkpoint_dir=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=None,
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=None,
    summary_dir=None,
    save_graph_def=True):
  all_hooks = []
  if hooks:
    all_hooks.extend(hooks)
  if chief_only_hooks and worker_context.is_chief:
    all_hooks.extend(chief_only_hooks)

  # We need to call save or summary ops on all workers since these ops may
  # contain collective ops; only running save ops on some workers would make
  # collective ops hang. Therefore, on those workers that don't need to
  # actually write checkpoints or summaries, we let them write to a temp
  # directory.
  # pylint: disable=protected-access
  if type(
      worker_context._strategy).__name__ in ('CollectiveAllReduceStrategy',
                                             'CollectiveAllReduceStrategyV1',
                                             'MultiWorkerMirroredStrategy'):
    if worker_context.task_type:
      tmpdir = 'tmp_%s_%d' % (worker_context.task_type, worker_context.task_id)
    else:
      tmpdir = 'tmp'

    if save_checkpoint_secs:
      logging.warning('Collective ops may deadlock with '
                      '`save_checkpoint_secs`, please use '
                      '`save_checkpoint_steps` instead. Clearing '
                      '`save_checkpoint_secs` and setting '
                      '`save_checkpoint_steps` to 1000 now.')
      save_checkpoint_secs = None
      save_checkpoint_steps = 1000
    if save_summaries_secs:
      logging.warning('Collective ops may run out of sync with '
                      '`save_summaries_secs`, please use '
                      '`save_summaries_steps` instead.')
  else:
    tmpdir = None

  summary_dir = summary_dir or checkpoint_dir
  if summary_dir and log_step_count_steps and log_step_count_steps > 0:
    if worker_context.should_save_summary:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=summary_dir, every_n_steps=log_step_count_steps))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=os.path.join(summary_dir, tmpdir),
              every_n_steps=log_step_count_steps))

  if (((save_summaries_steps and save_summaries_steps > 0) or
       (save_summaries_secs and save_summaries_secs > 0)) and summary_dir):
    if worker_context.should_save_summary:
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=summary_dir))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=os.path.join(summary_dir, tmpdir)))

  if (((save_checkpoint_secs and save_checkpoint_secs > 0) or
       (save_checkpoint_steps and save_checkpoint_steps > 0)) and
      checkpoint_dir):
    if worker_context.should_checkpoint:
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir,
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold,
              save_graph_def=save_graph_def))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              os.path.join(checkpoint_dir, tmpdir),
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold,
              save_graph_def=save_graph_def))

  logging.info('all_hooks %r', all_hooks)
  session_creator = worker_context.session_creator(
      scaffold,
      config=config,
      checkpoint_dir=checkpoint_dir,
      max_wait_secs=max_wait_secs)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None,
    save_graph_def=True):
  """Creates a `MonitoredSession` for training.

  For a chief, this utility sets proper session initializer/restorer. It also
  creates hooks related to checkpoint and summary saving. For workers, this
  utility sets a proper session creator which waits for the chief to
  initialize/restore. Please check `tf.compat.v1.train.MonitoredSession` for
  more information.
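
  Example, a minimal training-loop sketch (assuming a `train_op` built
  beforehand; `/tmp/my_model` is an illustrative path):

  ```python
  with tf.compat.v1.train.MonitoredTrainingSession(
      checkpoint_dir='/tmp/my_model',
      save_checkpoint_secs=600,
      save_summaries_steps=100) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```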

  @compatibility(TF2)
  This API is not compatible with eager execution and `tf.function`. To
  migrate to TF2, rewrite the code to be compatible with eager execution.
  Check the [migration
  guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls)
  on replacing `Session.run` calls. In Keras, session hooks can be replaced by
  Callbacks, e.g. [logging hook notebook](
  https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb).
  For more details please read [Better performance with
  tf.function](https://www.tensorflow.org/guide/function).
  @end_compatibility

  Args:
    master: `String` the TensorFlow master to use.
    is_chief: If `True`, it will take care of initialization and recovery of
      the underlying TensorFlow session. If `False`, it will wait on a chief
      to initialize or recover the TensorFlow session.
    checkpoint_dir: A string. Optional path to a directory where to restore
      variables.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If
      not specified, a default one is created. It's used to finalize the
      graph.
    hooks: Optional list of `SessionRunHook` objects.
    chief_only_hooks: List of `SessionRunHook` objects. Activate these hooks
      if `is_chief==True`, ignore otherwise.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is
      saved using a default checkpoint saver. If both `save_checkpoint_steps`
      and `save_checkpoint_secs` are set to `None`, then the default
      checkpoint saver isn't used. If both are provided, then only
      `save_checkpoint_secs` is used. Default 600.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used. Default 100.
    save_summaries_secs: The frequency, in secs, that the summaries are
      written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used. Default not enabled.
    config: An instance of `tf.compat.v1.ConfigProto` proto used to configure
      the session. It's the `config` argument of the constructor of
      `tf.compat.v1.Session`.
    stop_grace_period_secs: Number of seconds given to threads to stop after
      `close()` has been called.
    log_step_count_steps: The frequency, in number of global steps, that the
      global step/sec is logged.
    max_wait_secs: Maximum time workers should wait for the session to become
      available. This should be kept relatively short to help detect incorrect
      code, but sometimes may need to be increased if the chief takes a while
      to start up.
    save_checkpoint_steps: The frequency, in number of global steps, that a
      checkpoint is saved using a default checkpoint saver. If both
      `save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`,
      then the default checkpoint saver isn't used. If both are provided, then
      only `save_checkpoint_secs` is used. Default not enabled.
    summary_dir: A string. Optional path to a directory where to save
      summaries. If None, checkpoint_dir is used instead.
    save_graph_def: Whether to save the GraphDef and MetaGraphDef to
      `checkpoint_dir`. The GraphDef is saved after the session is created as
      `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
      `model.ckpt-*.meta`.

  Returns:
    A `MonitoredSession` object.
  """
  if save_summaries_steps == USE_DEFAULT and save_summaries_secs == USE_DEFAULT:
    save_summaries_steps = 100
    save_summaries_secs = None
  elif save_summaries_secs == USE_DEFAULT:
    save_summaries_secs = None
  elif save_summaries_steps == USE_DEFAULT:
    save_summaries_steps = None

  if (save_checkpoint_steps == USE_DEFAULT and
      save_checkpoint_secs == USE_DEFAULT):
    save_checkpoint_steps = None
    save_checkpoint_secs = 600
  elif save_checkpoint_secs == USE_DEFAULT:
    save_checkpoint_secs = None
  elif save_checkpoint_steps == USE_DEFAULT:
    save_checkpoint_steps = None

  scaffold = scaffold or Scaffold()
  worker_context = distribute_coordinator_context.get_current_worker_context()

  if worker_context:
    return _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=checkpoint_dir,
        hooks=hooks,
        chief_only_hooks=chief_only_hooks,
        save_checkpoint_secs=save_checkpoint_secs,
        save_summaries_steps=save_summaries_steps,
        save_summaries_secs=save_summaries_secs,
        config=config,
        stop_grace_period_secs=stop_grace_period_secs,
        log_step_count_steps=log_step_count_steps,
        max_wait_secs=max_wait_secs,
        save_checkpoint_steps=save_checkpoint_steps,
        summary_dir=summary_dir,
        save_graph_def=save_graph_def)

  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=hooks or [],
        stop_grace_period_secs=stop_grace_period_secs)

  all_hooks = []
  if chief_only_hooks:
    all_hooks.extend(chief_only_hooks)
  session_creator = ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_dir=checkpoint_dir,
      master=master,
      config=config)

  summary_dir = summary_dir or checkpoint_dir
  if summary_dir:
    if log_step_count_steps and log_step_count_steps > 0:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=summary_dir, every_n_steps=log_step_count_steps))

    if (save_summaries_steps and
        save_summaries_steps > 0) or (save_summaries_secs and
                                      save_summaries_secs > 0):
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=summary_dir))

  if checkpoint_dir:
    if (save_checkpoint_secs and
        save_checkpoint_secs > 0) or (save_checkpoint_steps and
                                      save_checkpoint_steps > 0):
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir,
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold,
              save_graph_def=save_graph_def))

  if hooks:
    all_hooks.extend(hooks)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.SessionCreator'])
class SessionCreator(metaclass=abc.ABCMeta):
  """A factory for tf.Session."""

  @abc.abstractmethod
  def create_session(self):
    raise NotImplementedError(
        'create_session is not implemented for {}.'.format(self))


@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a chief."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               checkpoint_filename_with_path=None):
    """Initializes a chief session creator.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the
        graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string. Optional path to a directory where to restore
        variables.
      checkpoint_filename_with_path: Full file name path to the checkpoint
        file.
    """
    self._checkpoint_dir = checkpoint_dir
    self._checkpoint_filename_with_path = checkpoint_filename_with_path
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config

  def _get_session_manager(self):
    """Gets or creates a SessionManager."""
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        local_init_feed_dict=self._scaffold.local_init_feed_dict,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().prepare_session(
        self._master,
        saver=self._scaffold.saver,
        checkpoint_dir=self._checkpoint_dir,
        checkpoint_filename_with_path=self._checkpoint_filename_with_path,
        config=self._config,
        init_op=self._scaffold.init_op,
        init_feed_dict=self._scaffold.init_feed_dict,
        init_fn=self._scaffold.init_fn)


@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a worker."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               max_wait_secs=30 * 60):
    """Initializes a worker session creator.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the
        graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      max_wait_secs: Maximum time to wait for the session to become available.
    """
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config
    self._max_wait_secs = max_wait_secs

  def _get_session_manager(self):
    """Gets or creates a SessionManager."""
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        local_init_feed_dict=self._scaffold.local_init_feed_dict,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().wait_for_session(
        self._master, config=self._config, max_wait_secs=self._max_wait_secs)


class _MonitoredSession:
  """See `MonitoredSession` or `SingularMonitoredSession`."""

  def __init__(self,
               session_creator,
               hooks,
               should_recover,
               stop_grace_period_secs=120):
    """Sets up a Monitored or Hooked Session.

    Args:
      session_creator: A factory object to create a session. Typically a
        `ChiefSessionCreator` or a `WorkerSessionCreator`.
      hooks: An iterable of `SessionRunHook` objects.
      should_recover: A bool. Indicates whether to recover from `AbortedError`
        and `UnavailableError` or not.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
    """
    self._graph_was_finalized = ops.get_default_graph().finalized
    self._hooks = hooks or []
    for h in self._hooks:
      h.begin()

    worker_context = distribute_coordinator_context.get_current_worker_context(
    )
    if not session_creator and worker_context:
      session_creator = worker_context.session_creator()

    # Create the session.
    self._coordinated_creator = self._CoordinatedSessionCreator(
        session_creator=session_creator or ChiefSessionCreator(),
        hooks=self._hooks,
        stop_grace_period_secs=stop_grace_period_secs)
    if should_recover:
      self._sess = _RecoverableSession(self._coordinated_creator)
    else:
      self._sess = self._coordinated_creator.create_session()

  @property
  def graph(self):
    """The graph that was launched in this session."""
    if self._tf_sess() is None:
      return None
    return self._tf_sess().graph

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Run ops in the monitored session.

    This method is completely compatible with the `tf.Session.run()` method.

    Args:
      fetches: Same as `tf.Session.run()`.
      feed_dict: Same as `tf.Session.run()`.
      options: Same as `tf.Session.run()`.
      run_metadata: Same as `tf.Session.run()`.

    Returns:
      Same as `tf.Session.run()`.
    """
    return self._sess.run(
        fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

  def run_step_fn(self, step_fn):
    """Run ops using a step function.

    Args:
      step_fn: A function or a method with a single argument of type
        `StepContext`. The function may use methods of the argument to perform
        computations with access to a raw session. The returned value of the
        `step_fn` will be returned from `run_step_fn`, unless a stop is
        requested. In that case, the next `should_stop` call will return True.
        Example usage:
        ```python
        with tf.Graph().as_default():
          c = tf.compat.v1.placeholder(tf.float32)
          v = tf.add(c, 4.0)
          w = tf.add(c, 0.5)

          def step_fn(step_context):
            a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
            if a <= 4.5:
              step_context.request_stop()
            return step_context.run_with_hooks(fetches=w,
                                               feed_dict={c: 0.1})

          with tf.compat.v1.train.MonitoredSession() as session:
            while not session.should_stop():
              a = session.run_step_fn(step_fn)
        ```
        Hooks interact with the `run_with_hooks()` call inside the
        `step_fn` as they do with a `MonitoredSession.run` call.

    Returns:
      Returns the returned value of `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`. It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if `step_fn` doesn't have a single argument called
        `step_context`. It may also optionally have `self` for cases when it
        belongs to an object.
    """
    step_fn_arguments = function_utils.fn_args(step_fn)
    if step_fn_arguments != ('step_context',) and step_fn_arguments != (
        'self',
        'step_context',
    ):
      raise ValueError(
          '`step_fn` may either have one `step_context` argument, or'
          ' `self` and `step_context` arguments if it\'s an instance'
          ' method. Got {} instead.'.format(step_fn_arguments))

    # `self._sess` is either a `_RecoverableSession` or a
    # `_CoordinatedSession`. Setting `run_with_hooks` to `None` will cause
    # `run_with_hooks` to be `_CoordinatedSession.run` downstream in either
    # case. This allows `_PREEMPTION_ERRORS` to propagate from within
    # `step_fn` to `_RecoverableSession.run_step_fn`.
    return self._sess.run_step_fn(
        step_fn, self._tf_sess(), run_with_hooks=None)

  class StepContext:
    """Control flow instrument for the `step_fn` from `run_step_fn()`.

    Users of `step_fn` may perform `run()` calls without running hooks
    by accessing the `session`. A `run()` call with hooks may be performed
    using `run_with_hooks()`. Computation flow can be interrupted using
    `request_stop()`.
    """

    def __init__(self, session, run_with_hooks_fn):
      """Initializes the `step_context` argument for a `step_fn` invocation.

      Args:
        session: An instance of `tf.compat.v1.Session`.
        run_with_hooks_fn: A function for running fetches and hooks.
      """
      self._session = session
      self._run_with_hooks_fn = run_with_hooks_fn

    @property
    def session(self):
      return self._session

    def run_with_hooks(self, *args, **kwargs):
      """Same as `MonitoredSession.run`. Accepts the same arguments."""
      return self._run_with_hooks_fn(*args, **kwargs)

    def request_stop(self):
      """Exit the training loop by causing `should_stop()` to return `True`.

      Causes `step_fn` to exit by raising an exception.

      Raises:
        StopIteration
      """
      raise StopIteration('step_fn has requested the iterations to stop.')

  def should_stop(self):
    return self._sess is None or self._sess.should_stop()

  def close(self):
    self._close_internal()

  def __enter__(self):
    return self

  def __exit__(self, exception_type, exception_value, traceback):
    if exception_type in [errors.OutOfRangeError, StopIteration]:
      exception_type = None
    self._close_internal(exception_type)
    # __exit__ should return True to suppress an exception.
    return exception_type is None

  class _CoordinatedSessionCreator(SessionCreator):
    """Factory for _CoordinatedSession."""

    def __init__(self, session_creator, hooks, stop_grace_period_secs):
      self._session_creator = session_creator
      self._hooks = hooks
      self.coord = None
      self.tf_sess = None
      self._stop_grace_period_secs = stop_grace_period_secs

    def create_session(self):
      """Creates a coordinated session."""
      # Keep the tf_sess for unit testing.
      self.tf_sess = self._session_creator.create_session()
      # We don't want coordinator to suppress any exception.
      self.coord = coordinator.Coordinator(clean_stop_exception_types=[])
      if ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
        queue_runner.start_queue_runners(sess=self.tf_sess, coord=self.coord)
      # Inform the hooks that a new session has been created.
      for hook in self._hooks:
        hook.after_create_session(self.tf_sess, self.coord)
      return _CoordinatedSession(
          _HookedSession(self.tf_sess, self._hooks), self.coord,
          self._stop_grace_period_secs)

  def _close_internal(self, exception_type=None):
    try:
      if not exception_type:
        for h in self._hooks:
          h.end(self._coordinated_creator.tf_sess)
    finally:
      try:
        if self._sess is None:
          raise RuntimeError('Session is already closed.')
        self._sess.close()
      finally:
        self._sess = None
        self._coordinated_creator.tf_sess = None
        self._coordinated_creator.coord = None
        if not self._graph_was_finalized:
          ops.get_default_graph()._unsafe_unfinalize()  # pylint: disable=protected-access

  def _is_closed(self):
    """Return True if the monitored session is closed.

    For tests only.

    Returns:
      A boolean.
    """
    return self._coordinated_creator.tf_sess is None

  def _tf_sess(self):
    """Return underlying tf.compat.v1.Session object.

    Warning: accessing the returned object in user code is likely to cause
    races or "flaky tests".

    Returns:
      A tf.compat.v1.Session object.
    """
    return self._coordinated_creator.tf_sess


@tf_export(v1=['train.MonitoredSession'])
class MonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, recovery and hooks.

  Example usage:

  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with MonitoredSession(session_creator=ChiefSessionCreator(...),
                        hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```

  Initialization: At creation time the monitored session does the following
  things, in the given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * creates a session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners
  * calls `hook.after_create_session()`

  Run: When `run()` is called, the monitored session does the following
  things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns the result of `session.run()` asked by the user
  * if `AbortedError` or `UnavailableError` occurs, it recovers or
    reinitializes the session before executing the run() call again

  Exit: At the `close()`, the monitored session does the following things in
  order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses the `OutOfRange` error which indicates that all inputs have
    been processed if the monitored_session is used as a context

  How to set `tf.compat.v1.Session` arguments:

  * In most cases you can set session arguments as follows:

  ```python
  MonitoredSession(
      session_creator=ChiefSessionCreator(master=..., config=...))
  ```

  * In a distributed setting for a non-chief worker, you can use the
    following:

  ```python
  MonitoredSession(
      session_creator=WorkerSessionCreator(master=..., config=...))
  ```

  See `MonitoredTrainingSession` for an example usage based on chief or
  worker.

  Note: This is not a `tf.compat.v1.Session`. For example, it cannot do the
  following:

  * it cannot be set as default session.
  * it cannot be sent to saver.save.
  * it cannot be sent to tf.train.start_queue_runners.

  @compatibility(TF2)
  This API is not compatible with eager execution and `tf.function`. To
  migrate to TF2, rewrite the code to be compatible with eager execution.
  Check the [migration
  guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls)
  on replacing `Session.run` calls. In Keras, session hooks can be replaced by
  Callbacks, e.g. [logging hook notebook](
  https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb).
  For more details please read [Better performance with
  tf.function](https://www.tensorflow.org/guide/function).
  @end_compatibility

  Args:
    session_creator: A factory object to create a session. Typically a
      `ChiefSessionCreator` which is the default one.
    hooks: An iterable of `SessionRunHook` objects.

  Returns:
    A MonitoredSession object.
  """

  def __init__(self,
               session_creator=None,
               hooks=None,
               stop_grace_period_secs=120):
    super(MonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=True,
        stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.SingularMonitoredSession'])
class SingularMonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, restoring, and hooks.

  Please note that this utility is not recommended for distributed settings.
  For distributed settings, please use `tf.compat.v1.train.MonitoredSession`.
  The differences between `MonitoredSession` and `SingularMonitoredSession`
  are:

  * `MonitoredSession` handles `AbortedError` and `UnavailableError` for
    distributed settings, but `SingularMonitoredSession` does not.
  * `MonitoredSession` can be created in `chief` or `worker` modes.
    `SingularMonitoredSession` is always created as `chief`.
  * You can access the raw `tf.compat.v1.Session` object used by
    `SingularMonitoredSession`, whereas in MonitoredSession the raw session is
    private. This can be used:
    - To `run` without hooks.
    - To save and restore.
  * All other functionality is identical.

  Example usage:
  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with SingularMonitoredSession(hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```

  Initialization: At creation time the hooked session does the following
  things, in the given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * creates a session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners

  Run: When `run()` is called, the hooked session does the following things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns the result of `session.run()` asked by the user

  Exit: At the `close()`, the hooked session does the following things in
  order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses the `OutOfRange` error which indicates that all inputs have
    been processed if the `SingularMonitoredSession` is used as a context.

  @compatibility(TF2)
  This API is not compatible with eager execution and `tf.function`. To
  migrate to TF2, rewrite the code to be compatible with eager execution.
  Check the [migration
  guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls)
  on replacing `Session.run` calls. In Keras, session hooks can be replaced by
  Callbacks, e.g. [logging hook notebook](
  https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb).
  For more details please read [Better performance with
  tf.function](https://www.tensorflow.org/guide/function).
  @end_compatibility
  """

  def __init__(self,
               hooks=None,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               stop_grace_period_secs=120,
               checkpoint_filename_with_path=None):
    """Creates a SingularMonitoredSession.

    Args:
      hooks: An iterable of `SessionRunHook` objects.
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the
        graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string. Optional path to a directory where to restore
        variables.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
      checkpoint_filename_with_path: A string. Optional path to a checkpoint
        file from which to restore variables.
    """
    session_creator = ChiefSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
    super(SingularMonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=False,
        stop_grace_period_secs=stop_grace_period_secs)

  def raw_session(self):
    """Returns the underlying `tf.compat.v1.Session` object."""
    return self._tf_sess()


class _WrappedSession:
  """Wrapper around a `tf.compat.v1.Session`.

  This wrapper is used as a base class for various session wrappers
  that provide additional functionality such as monitoring, coordination,
  and recovery.

  In addition to the methods exported by `SessionInterface` the wrapper
  provides a method to check for stop and never raises exceptions from
  calls to `close()`.
  """

  def __init__(self, sess):
    """Creates a `_WrappedSession`.

    Args:
      sess: A `tf.compat.v1.Session` or `_WrappedSession` object. The wrapped
        session.
    """
    self._sess = sess
    self._wrapped_is_stoppable = isinstance(self._sess, _WrappedSession)

  @property
  def graph(self):
    return self._sess.graph

  @property
  def sess_str(self):
    return self._sess.sess_str

  def should_stop(self):
    """Return true if this session should not be used anymore.

    Always return True if the session was closed.

    Returns:
      True if the session should stop, False otherwise.
    """
    if self._check_stop():
      return True
    if self._sess:
      return self._wrapped_is_stoppable and self._sess.should_stop()
    return True

  def _check_stop(self):
    """Hook for subclasses to provide their own stop condition.

    Returns:
      True if the session should stop, False otherwise.
    """
    return False

  def close(self):
    if self._sess:
      try:
        self._sess.close()
      except _PREEMPTION_ERRORS as e:
        logging.error(
            'An error occurred when attempting to close the '
            'session. This may be due to a preemption in a '
            'connected worker or parameter server. Error: %s', e)
      finally:
        self._sess = None

  def run(self, *args, **kwargs):
    return self._sess.run(*args, **kwargs)

  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
    # `_RecoverableSession` sets `run_with_hooks` to `_CoordinatedSession.run`.
    # It is `None` when called from `_CoordinatedSession`. In that case
    # `self.run` is `_CoordinatedSession.run`.
    run_with_hooks = run_with_hooks or self.run
    return step_fn(_MonitoredSession.StepContext(raw_session, run_with_hooks))


class _RecoverableSession(_WrappedSession):
  """A wrapped session that recreates a session upon certain kinds of errors.

  The constructor is passed a SessionCreator object, not a session.

  Calls to `run()` are delegated to the wrapped session. If a call raises the
  exception `tf.errors.AbortedError` or `tf.errors.UnavailableError`, the
  wrapped session is closed, and a new one is created by calling the factory
  again.
  """

  def __init__(self, sess_creator):
    """Create a new `_RecoverableSession`.

    The value returned by calling `sess_creator.create_session()` will be the
    session wrapped by this recoverable session.

    Args:
      sess_creator: A `SessionCreator` to be wrapped by the recoverable
        session.
    """
    self._sess_creator = sess_creator
    _WrappedSession.__init__(self, self._create_session())

  def _create_session(self):
    while True:
      try:
        return self._sess_creator.create_session()
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised while a session was being created. '
            'This may be due to a preemption of a connected worker '
            'or parameter server. A new session will be created. '
            'This error may also occur due to a gRPC failure caused '
            'by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)

  def _check_stop(self):
    try:
      if self._sess:
        return self._sess._check_stop()  # pylint: disable=protected-access
      else:
        return True
    except _PREEMPTION_ERRORS as e:
      logging.info(
          'An error was raised while considering whether the '
          'session is complete. This may be due to a preemption in '
          'a connected worker or parameter server. The current '
          'session will be closed and a new session will be '
          'created. This error may also occur due to a gRPC failure '
          'caused by high memory or network bandwidth usage in the '
          'parameter servers. If this error occurs repeatedly, try '
          'increasing the number of parameter servers assigned to '
          'the job. Error: %s', e)
      self.close()
      self._sess = self._create_session()
      # Since we have just recreated the session, the overall computation
      # should not stop:
      return False
    except Exception:  # pylint: disable=broad-except
      # `should_stop` should return True instead of raising an exception.
      return True

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    while True:
      try:
        if not self._sess:
          self._sess = self._create_session()
        return self._sess.run(
            fetches,
            feed_dict=feed_dict,
            options=options,
            run_metadata=run_metadata)
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised. This may be due to a preemption in '
            'a connected worker or parameter server. The current '
            'session will be closed and a new session will be '
            'created. This error may also occur due to a gRPC failure '
            'caused by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)
        self.close()
        self._sess = None

  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
    while True:
      try:
        if not self._sess:
          self._sess = self._create_session()

        run_with_hooks = self._sess.run
        return self._sess.run_step_fn(step_fn, raw_session, run_with_hooks)
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised. This may be due to a preemption in '
            'a connected worker or parameter server. The current '
            'session will be closed and a new session will be '
            'created. This error may also occur due to a gRPC failure '
            'caused by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)
        self.close()
        self._sess = None


class _CoordinatedSession(_WrappedSession):
  """A wrapped session that works with a `tf.Coordinator`.

  Calls to `run()` are delegated to the wrapped session. If a call
  raises an exception, the exception is reported to the coordinator.

  In addition, after each call to `run()` this session asks the coordinator
  if the session should stop. In that case it will join all the threads
  registered with the coordinator before returning.

  If the coordinator was requested to stop with an exception, that exception
  will be re-raised from the call to `run()`.
  """

  def __init__(self, sess, coord, stop_grace_period_secs=120):
    """Create a new `_CoordinatedSession`.

    Args:
      sess: A `tf.compat.v1.Session` object. The wrapped session.
      coord: A `tf.train.Coordinator` object.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
    """
    _WrappedSession.__init__(self, sess)
    self._coord = coord
    self._stop_grace_period_secs = stop_grace_period_secs

  def _check_stop(self):
    # If the coordinator was asked to stop due to an exception, then it needs
    # to be propagated to this stack.
    self._coord.raise_requested_exception()
    # At this point, no exceptions are recorded in the coordinator.
    return self._coord.should_stop()

  def close(self):
    self._coord.request_stop()
    try:
      self._coord.join(
          stop_grace_period_secs=self._stop_grace_period_secs,
          ignore_live_threads=True)
    finally:
      try:
        _WrappedSession.close(self)
      except Exception:  # pylint: disable=broad-except
        # We intentionally suppress exceptions from the close() here since
        # useful exceptions are already reported by join().
        pass

  def run(self, *args, **kwargs):
    try:
      return self._sess.run(*args, **kwargs)
    except _PREEMPTION_ERRORS:
      raise
    except Exception as original_exception:  # pylint: disable=broad-except
      # A non-preemption error could have been caused by a preemption error
      # in the coordinator. If this is the case, raise that exception instead,
      # since it's the root cause. Otherwise, stick to the
      # `original_exception`.
      try:
        self._coord.raise_requested_exception()
      except _PREEMPTION_ERRORS:
        raise
      except Exception:  # pylint: disable=broad-except
        raise original_exception from None
      else:
        raise


class _HookedSession(_WrappedSession):
  """A _WrappedSession that calls hooks during calls to run().

  The list of hooks to call is passed in the constructor. Before each call
  to `run()` the session calls the `before_run()` method of the hooks, which
  can return additional ops or tensors to run. These are added to the
  arguments of the call to `run()`.

  When the `run()` call finishes, the session calls the `after_run()` methods
  of the hooks, passing the values returned by the `run()` call corresponding
  to the ops and tensors that each hook requested.

  If any call to the hooks requests a stop via the run_context, the session
  will be marked as needing to stop and its `should_stop()` method will now
  return `True`.
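
  For illustration, a sketch of a hook as this class consumes it (`my_tensor`
  is a placeholder for some tensor in the graph):

  ```python
  class _WatchTensorHook(session_run_hook.SessionRunHook):

    def before_run(self, run_context):
      # Request `my_tensor` in addition to the caller's fetches.
      return session_run_hook.SessionRunArgs(my_tensor)

    def after_run(self, run_context, run_values):
      # `run_values.results` holds the fetched value of `my_tensor`.
      if run_values.results > 100.0:
        run_context.request_stop()
  ```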
  """

  def __init__(self, sess, hooks):
    """Initializes a _HookedSession object.

    Args:
      sess: A `tf.compat.v1.Session` or a `_WrappedSession` object.
      hooks: An iterable of `SessionRunHook` objects.
    """

    _WrappedSession.__init__(self, sess)
    self._hooks = hooks
    self._should_stop = False

  def _check_stop(self):
    """See base class."""
    return self._should_stop

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """See base class."""
    if self.should_stop():
      raise RuntimeError('Run called even after should_stop requested.')

    actual_fetches = {'caller': fetches}

    run_context = session_run_hook.SessionRunContext(
        original_args=session_run_hook.SessionRunArgs(fetches, feed_dict),
        session=self._sess)

    options = options or config_pb2.RunOptions()
    feed_dict = self._call_hook_before_run(run_context, actual_fetches,
                                           feed_dict, options)

    # Do session run.
    run_metadata = run_metadata or config_pb2.RunMetadata()
    outputs = _WrappedSession.run(
        self,
        fetches=actual_fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

    for hook in self._hooks:
      hook.after_run(
          run_context,
          session_run_hook.SessionRunValues(
              results=outputs[hook] if hook in outputs else None,
              options=options,
              run_metadata=run_metadata))
    self._should_stop = self._should_stop or run_context.stop_requested

    return outputs['caller']

  def _call_hook_before_run(self, run_context, fetch_dict, user_feed_dict,
                            options):
    """Calls hooks.before_run and handles requests from hooks."""
    hook_feeds = {}
    for hook in self._hooks:
      request = hook.before_run(run_context)
      if request is not None:
        if request.fetches is not None:
          fetch_dict[hook] = request.fetches
        if request.feed_dict:
          self._raise_if_feeds_intersects(hook_feeds, request.feed_dict,
                                          'Same tensor is fed by two hooks.')
          hook_feeds.update(request.feed_dict)
        if request.options:
          self._merge_run_options(options, request.options)

    if not hook_feeds:
      return user_feed_dict

    if not user_feed_dict:
      return hook_feeds

    self._raise_if_feeds_intersects(
        user_feed_dict, hook_feeds,
        'Same tensor is fed by a SessionRunHook and user.')
    hook_feeds.update(user_feed_dict)
    return hook_feeds

  def _raise_if_feeds_intersects(self, feeds1, feeds2, message):
    intersection = set(feeds1.keys()) & set(feeds2.keys())
    if intersection:
      raise RuntimeError(message + ' Conflict(s): ' + str(list(intersection)))

  def _merge_run_options(self, options, incoming_options):
    """Merges two instances of RunOptions into the first one.

    During the merge, the numerical fields including trace_level,
    timeout_in_ms, and inter_op_thread_pool are set to the larger one of the
    two. The boolean values are set to the logical OR of the two.
    debug_tensor_watch_opts of the original options is extended with that
    from the incoming one.

    Args:
      options: The options to merge into.
      incoming_options: The options to be merged into the first argument.
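
    For instance, a sketch of the merge semantics with two `RunOptions`
    protos:

    ```python
    a = config_pb2.RunOptions(timeout_in_ms=500)
    b = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
    self._merge_run_options(a, b)
    # Now a.trace_level == FULL_TRACE and a.timeout_in_ms == 500.
    ```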
    """
    options.trace_level = max(options.trace_level, incoming_options.trace_level)
    options.timeout_in_ms = max(options.timeout_in_ms,
                                incoming_options.timeout_in_ms)
    options.inter_op_thread_pool = max(options.inter_op_thread_pool,
                                       incoming_options.inter_op_thread_pool)
    options.output_partition_graphs = max(
        options.output_partition_graphs,
        incoming_options.output_partition_graphs)
    options.debug_options.debug_tensor_watch_opts.extend(
        incoming_options.debug_options.debug_tensor_watch_opts)
    options.debug_options.reset_disk_byte_usage = (
        options.debug_options.reset_disk_byte_usage or
        incoming_options.debug_options.reset_disk_byte_usage)
    options.report_tensor_allocations_upon_oom = (
        options.report_tensor_allocations_upon_oom or
        incoming_options.report_tensor_allocations_upon_oom)