1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Some common SessionRunHook classes. 16 17Note that the symbols that are exported to v1 tf.train namespace are also 18exported to v2 in tf.estimator namespace. See 19https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py 20""" 21 22import os 23import time 24 25import numpy as np 26 27from tensorflow.core.framework.summary_pb2 import Summary 28from tensorflow.core.protobuf import config_pb2 29from tensorflow.core.util.event_pb2 import SessionLog 30from tensorflow.python.client import timeline 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import errors 33from tensorflow.python.framework import meta_graph 34from tensorflow.python.framework import ops 35from tensorflow.python.ops import init_ops 36from tensorflow.python.ops import variable_scope 37from tensorflow.python.platform import gfile 38from tensorflow.python.platform import tf_logging as logging 39from tensorflow.python.training import session_run_hook 40from tensorflow.python.training import training_util 41from tensorflow.python.training.session_run_hook import SessionRunArgs 42from tensorflow.python.training.summary_io import SummaryWriterCache 43from tensorflow.python.util.tf_export import tf_export 44 45_HOOKS = "hooks" 46_STEPS_PER_RUN_VAR = "steps_per_run" 47 48 49class _HookTimer: 50 """Base timer for determining when Hooks should trigger. 51 52 Should not be instantiated directly. 53 """ 54 55 def __init__(self): 56 pass 57 58 def reset(self): 59 """Resets the timer.""" 60 pass 61 62 def should_trigger_for_step(self, step): 63 """Return true if the timer should trigger for the specified step.""" 64 raise NotImplementedError 65 66 def update_last_triggered_step(self, step): 67 """Update the last triggered time and step number. 68 69 Args: 70 step: The current step. 71 72 Returns: 73 A pair `(elapsed_time, elapsed_steps)`, where `elapsed_time` is the number 74 of seconds between the current trigger and the last one (a float), and 75 `elapsed_steps` is the number of steps between the current trigger and 76 the last one. Both values will be set to `None` on the first trigger. 77 """ 78 raise NotImplementedError 79 80 def last_triggered_step(self): 81 """Returns the last triggered time step or None if never triggered.""" 82 raise NotImplementedError 83 84 85@tf_export(v1=["train.SecondOrStepTimer"]) 86class SecondOrStepTimer(_HookTimer): 87 """Timer that triggers at most once every N seconds or once every N steps. 88 89 This symbol is also exported to v2 in tf.estimator namespace. See 90 https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py 91 """ 92 93 def __init__(self, every_secs=None, every_steps=None): 94 self.reset() 95 self._every_secs = every_secs 96 self._every_steps = every_steps 97 98 if self._every_secs is None and self._every_steps is None: 99 raise ValueError("Either every_secs or every_steps should be provided.") 100 if (self._every_secs is not None) and (self._every_steps is not None): 101 raise ValueError("Can not provide both every_secs and every_steps.") 102 103 super(SecondOrStepTimer, self).__init__() 104 105 def reset(self): 106 self._last_triggered_step = None 107 self._last_triggered_time = None 108 109 def should_trigger_for_step(self, step): 110 """Return true if the timer should trigger for the specified step. 111 112 Args: 113 step: Training step to trigger on. 114 115 Returns: 116 True if the difference between the current time and the time of the last 117 trigger exceeds `every_secs`, or if the difference between the current 118 step and the last triggered step exceeds `every_steps`. False otherwise. 119 """ 120 if self._last_triggered_step is None: 121 return True 122 123 if self._last_triggered_step == step: 124 return False 125 126 if self._every_secs is not None: 127 if time.time() >= self._last_triggered_time + self._every_secs: 128 return True 129 130 if self._every_steps is not None: 131 if step >= self._last_triggered_step + self._every_steps: 132 return True 133 134 return False 135 136 def update_last_triggered_step(self, step): 137 current_time = time.time() 138 if self._last_triggered_time is None: 139 elapsed_secs = None 140 elapsed_steps = None 141 else: 142 elapsed_secs = current_time - self._last_triggered_time 143 elapsed_steps = step - self._last_triggered_step 144 145 self._last_triggered_time = current_time 146 self._last_triggered_step = step 147 return (elapsed_secs, elapsed_steps) 148 149 def last_triggered_step(self): 150 return self._last_triggered_step 151 152 153class NeverTriggerTimer(_HookTimer): 154 """Timer that never triggers.""" 155 156 def should_trigger_for_step(self, step): 157 _ = step 158 return False 159 160 def update_last_triggered_step(self, step): 161 _ = step 162 return (None, None) 163 164 def last_triggered_step(self): 165 return None 166 167 168@tf_export(v1=["train.LoggingTensorHook"]) 169class LoggingTensorHook(session_run_hook.SessionRunHook): 170 """Prints the given tensors every N local steps, every N seconds, or at end. 171 172 The tensors will be printed to the log, with `INFO` severity. If you are not 173 seeing the logs, you might want to add the following line after your imports: 174 175 ```python 176 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 177 ``` 178 179 Note that if `at_end` is True, `tensors` should not include any tensor 180 whose evaluation produces a side effect such as consuming additional inputs. 181 182 @compatibility(TF2) 183 Please check this [notebook][notebook] on how to migrate the API to TF2. 184 185 [notebook]:https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb 186 187 @end_compatibility 188 189 """ 190 191 def __init__(self, 192 tensors, 193 every_n_iter=None, 194 every_n_secs=None, 195 at_end=False, 196 formatter=None): 197 """Initializes a `LoggingTensorHook`. 198 199 Args: 200 tensors: `dict` that maps string-valued tags to tensors/tensor names, or 201 `iterable` of tensors/tensor names. 202 every_n_iter: `int`, print the values of `tensors` once every N local 203 steps taken on the current worker. 204 every_n_secs: `int` or `float`, print the values of `tensors` once every N 205 seconds. Exactly one of `every_n_iter` and `every_n_secs` should be 206 provided. 207 at_end: `bool` specifying whether to print the values of `tensors` at the 208 end of the run. 209 formatter: function, takes dict of `tag`->`Tensor` and returns a string. 210 If `None` uses default printing all tensors. 211 212 Raises: 213 ValueError: if `every_n_iter` is non-positive. 214 """ 215 only_log_at_end = ( 216 at_end and (every_n_iter is None) and (every_n_secs is None)) 217 if (not only_log_at_end and 218 (every_n_iter is None) == (every_n_secs is None)): 219 raise ValueError( 220 "either at_end and/or exactly one of every_n_iter and every_n_secs " 221 "must be provided.") 222 if every_n_iter is not None and every_n_iter <= 0: 223 raise ValueError("invalid every_n_iter=%s." % every_n_iter) 224 if not isinstance(tensors, dict): 225 self._tag_order = tensors 226 tensors = {item: item for item in tensors} 227 else: 228 self._tag_order = sorted(tensors.keys()) 229 self._tensors = tensors 230 self._formatter = formatter 231 self._timer = ( 232 NeverTriggerTimer() if only_log_at_end else SecondOrStepTimer( 233 every_secs=every_n_secs, every_steps=every_n_iter)) 234 self._log_at_end = at_end 235 236 def begin(self): 237 self._timer.reset() 238 self._iter_count = 0 239 # Convert names to tensors if given 240 self._current_tensors = { 241 tag: _as_graph_element(tensor) 242 for (tag, tensor) in self._tensors.items() 243 } 244 245 def before_run(self, run_context): # pylint: disable=unused-argument 246 self._should_trigger = self._timer.should_trigger_for_step(self._iter_count) 247 if self._should_trigger: 248 return SessionRunArgs(self._current_tensors) 249 else: 250 return None 251 252 def _log_tensors(self, tensor_values): 253 original = np.get_printoptions() 254 np.set_printoptions(suppress=True) 255 elapsed_secs, _ = self._timer.update_last_triggered_step(self._iter_count) 256 if self._formatter: 257 logging.info(self._formatter(tensor_values)) 258 else: 259 stats = [] 260 for tag in self._tag_order: 261 stats.append("%s = %s" % (tag, tensor_values[tag])) 262 if elapsed_secs is not None: 263 logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs) 264 else: 265 logging.info("%s", ", ".join(stats)) 266 np.set_printoptions(**original) 267 268 def after_run(self, run_context, run_values): 269 _ = run_context 270 if self._should_trigger: 271 self._log_tensors(run_values.results) 272 273 self._iter_count += 1 274 275 def end(self, session): 276 if self._log_at_end: 277 values = session.run(self._current_tensors) 278 self._log_tensors(values) 279 280 281def get_or_create_steps_per_run_variable(): 282 """Gets or creates the steps_per_run variable. 283 284 In Estimator, the user provided computation, the model_fn, is wrapped 285 inside a tf.while_loop for peak performance. The iterations of the loop are 286 specified by this variable, which adjusts its value on the CPU after each 287 device program execution and before the next execution. 288 289 The purpose of using a variable, rather than a constant, is to allow 290 Estimator adapt the device training iterations according to the final steps 291 specified by users. For example, if the user sets the steps_per_run as 292 4 and steps as 10 in Estimator.train(), the steps_per_run 293 variable will have the following value before each training run. 294 295 - 1-st execution: steps_per_run = 4 296 - 2-nd execution: steps_per_run = 4 297 - 3-rd execution: steps_per_run = 2 298 299 As model_fn increases the global step once per train_op invocation, the global 300 step is 10 after all executions, matching the steps=10 inputs passed in by 301 users. 302 303 Returns: 304 A TF non-trainable resource variable. 305 306 Raises: 307 RuntimeError: If multi steps_per_run variables were found. 308 """ 309 graph = ops.get_default_graph() 310 collection_name = "{}_{}".format(_HOOKS, _STEPS_PER_RUN_VAR) 311 steps_per_run_vars = graph.get_collection(collection_name) 312 if len(steps_per_run_vars) == 1: 313 return steps_per_run_vars[0] 314 elif len(steps_per_run_vars) > 1: 315 raise RuntimeError("Multiple steps_per_run_var in collection.") 316 317 with variable_scope.variable_scope(_HOOKS, reuse=variable_scope.AUTO_REUSE): 318 return variable_scope.get_variable( 319 _STEPS_PER_RUN_VAR, 320 initializer=init_ops.ones_initializer(), 321 shape=[], 322 dtype=dtypes.int32, 323 trainable=False, 324 collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES], 325 use_resource=True) 326 327 328class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook): 329 """Hook that requests stop at a specified step.""" 330 331 def __init__(self, num_steps=None, last_step=None, steps_per_run=1): 332 """Initializes a `MultiStepStopAtStepHook`. 333 334 This hook requests stop after either a number of steps have been 335 executed or a last step has been reached. Only one of the two options can be 336 specified. 337 338 if `num_steps` is specified, it indicates the number of steps to execute 339 after `begin()` is called. If instead `last_step` is specified, it 340 indicates the last step we want to execute, as passed to the `after_run()` 341 call. 342 343 In Estimator, the user provided computation, the model_fn, is wrapped 344 inside a tf.while_loop for peak performance. The steps_per_run variable 345 determines the number of iterations of the loop before returning to the CPU. 346 347 Args: 348 num_steps: Number of steps to execute. 349 last_step: Step after which to stop. 350 steps_per_run: Number of steps executed per run call. 351 352 Raises: 353 ValueError: If one of the arguments is invalid. 354 """ 355 if num_steps is None and last_step is None: 356 raise ValueError("One of num_steps or last_step must be specified.") 357 if num_steps is not None and last_step is not None: 358 raise ValueError("Only one of num_steps or last_step can be specified.") 359 if steps_per_run is None or steps_per_run < 1: 360 raise ValueError("steps_per_run should be greater than 0") 361 self._num_steps = num_steps 362 self._last_step = last_step 363 self._steps_per_run_initial_value = steps_per_run 364 365 def begin(self): 366 self._global_step_tensor = training_util.get_global_step() 367 if self._global_step_tensor is None: 368 raise RuntimeError("Global step should be created to use StopAtStepHook.") 369 self._steps_per_run_variable = get_or_create_steps_per_run_variable() 370 371 def _update_steps_per_run_variable(self, global_step, session): 372 steps = min(self._last_step - global_step, 373 self._steps_per_run_initial_value) 374 self._steps_per_run_variable.load(steps, session=session) 375 376 def after_create_session(self, session, coord): 377 global_step = session.run(self._global_step_tensor) 378 if self._last_step is None: 379 self._last_step = global_step + self._num_steps 380 self._update_steps_per_run_variable(global_step, session) 381 382 def after_run(self, run_context, run_values): 383 # Global step cannot be retrieved via SessionRunArgs and before_run due to 384 # race condition in hook execution. 385 global_step = run_context.session.run(self._global_step_tensor) 386 if global_step >= self._last_step: 387 run_context.request_stop() 388 else: 389 self._update_steps_per_run_variable(global_step, run_context.session) 390 391 392@tf_export(v1=["train.StopAtStepHook"]) 393class StopAtStepHook(session_run_hook.SessionRunHook): 394 """Hook that requests stop at a specified step. 395 396 @compatibility(TF2) 397 Please check this [notebook][notebook] on how to migrate the API to TF2. 398 399 [notebook]:https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb 400 401 @end_compatibility 402 """ 403 404 def __init__(self, num_steps=None, last_step=None): 405 """Initializes a `StopAtStepHook`. 406 407 This hook requests stop after either a number of steps have been 408 executed or a last step has been reached. Only one of the two options can be 409 specified. 410 411 if `num_steps` is specified, it indicates the number of steps to execute 412 after `begin()` is called. If instead `last_step` is specified, it 413 indicates the last step we want to execute, as passed to the `after_run()` 414 call. 415 416 Args: 417 num_steps: Number of steps to execute. 418 last_step: Step after which to stop. 419 420 Raises: 421 ValueError: If one of the arguments is invalid. 422 """ 423 if num_steps is None and last_step is None: 424 raise ValueError("One of num_steps or last_step must be specified.") 425 if num_steps is not None and last_step is not None: 426 raise ValueError("Only one of num_steps or last_step can be specified.") 427 self._num_steps = num_steps 428 self._last_step = last_step 429 430 def begin(self): 431 self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access 432 if self._global_step_tensor is None: 433 raise RuntimeError("Global step should be created to use StopAtStepHook.") 434 435 def after_create_session(self, session, coord): 436 if self._last_step is None: 437 global_step = session.run(self._global_step_tensor) 438 self._last_step = global_step + self._num_steps 439 440 def before_run(self, run_context): # pylint: disable=unused-argument 441 return SessionRunArgs(self._global_step_tensor) 442 443 def after_run(self, run_context, run_values): 444 global_step = run_values.results + 1 445 if global_step >= self._last_step: 446 # Check latest global step to ensure that the targeted last step is 447 # reached. global_step read tensor is the value of global step 448 # before running the operation. We're not sure whether current session.run 449 # incremented the global_step or not. Here we're checking it. 450 451 step = run_context.session.run(self._global_step_tensor) 452 if step >= self._last_step: 453 run_context.request_stop() 454 455 456@tf_export(v1=["train.CheckpointSaverListener"]) 457class CheckpointSaverListener: 458 """Interface for listeners that take action before or after checkpoint save. 459 460 `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook` is 461 triggered, and provides callbacks at the following points: 462 - before using the session 463 - before each call to `Saver.save()` 464 - after each call to `Saver.save()` 465 - at the end of session 466 467 To use a listener, implement a class and pass the listener to a 468 `CheckpointSaverHook`, as in this example: 469 470 ```python 471 class ExampleCheckpointSaverListener(CheckpointSaverListener): 472 def begin(self): 473 # You can add ops to the graph here. 474 print('Starting the session.') 475 self.your_tensor = ... 476 477 def before_save(self, session, global_step_value): 478 print('About to write a checkpoint') 479 480 def after_save(self, session, global_step_value): 481 print('Done writing checkpoint.') 482 if decided_to_stop_training(): 483 return True 484 485 def end(self, session, global_step_value): 486 print('Done with the session.') 487 488 ... 489 listener = ExampleCheckpointSaverListener() 490 saver_hook = tf.estimator.CheckpointSaverHook( 491 checkpoint_dir, listeners=[listener]) 492 with 493 tf.compat.v1.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]): 494 ... 495 ``` 496 497 A `CheckpointSaverListener` may simply take some action after every 498 checkpoint save. It is also possible for the listener to use its own schedule 499 to act less frequently, e.g. based on global_step_value. In this case, 500 implementors should implement the `end()` method to handle actions related to 501 the last checkpoint save. But the listener should not act twice if 502 `after_save()` already handled this last checkpoint save. 503 504 A `CheckpointSaverListener` can request training to be stopped, by returning 505 True in `after_save`. Please note that, in replicated distributed training 506 setting, only `chief` should use this behavior. Otherwise each worker will do 507 their own evaluation, which may be wasteful of resources. 508 """ 509 510 def begin(self): 511 pass 512 513 def before_save(self, session, global_step_value): 514 pass 515 516 def after_save(self, session, global_step_value): 517 pass 518 519 def end(self, session, global_step_value): 520 pass 521 522 523@tf_export(v1=["train.CheckpointSaverHook"]) 524class CheckpointSaverHook(session_run_hook.SessionRunHook): 525 """Saves checkpoints every N steps or seconds.""" 526 527 def __init__(self, 528 checkpoint_dir, 529 save_secs=None, 530 save_steps=None, 531 saver=None, 532 checkpoint_basename="model.ckpt", 533 scaffold=None, 534 listeners=None, 535 save_graph_def=True): 536 """Initializes a `CheckpointSaverHook`. 537 538 Args: 539 checkpoint_dir: `str`, base directory for the checkpoint files. 540 save_secs: `int`, save every N secs. 541 save_steps: `int`, save every N steps. 542 saver: `Saver` object, used for saving. 543 checkpoint_basename: `str`, base name for the checkpoint files. 544 scaffold: `Scaffold`, use to get saver object. 545 listeners: List of `CheckpointSaverListener` subclass instances. Used for 546 callbacks that run immediately before or after this hook saves the 547 checkpoint. 548 save_graph_def: Whether to save the GraphDef and MetaGraphDef to 549 `checkpoint_dir`. The GraphDef is saved after the session is created as 550 `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as 551 `model.ckpt-*.meta`. 552 553 Raises: 554 ValueError: One of `save_steps` or `save_secs` should be set. 555 ValueError: At most one of `saver` or `scaffold` should be set. 556 """ 557 logging.info("Create CheckpointSaverHook.") 558 if saver is not None and scaffold is not None: 559 raise ValueError("You cannot provide both saver and scaffold.") 560 self._saver = saver 561 self._checkpoint_dir = checkpoint_dir 562 self._save_path = os.path.join(checkpoint_dir, checkpoint_basename) 563 self._scaffold = scaffold 564 self._timer = SecondOrStepTimer( 565 every_secs=save_secs, every_steps=save_steps) 566 self._listeners = listeners or [] 567 # Set sufficiently high default that it never skips checking the actual 568 # global step counter -- unless the user overrides it with the right value 569 # for the steps_per_run. 570 self._steps_per_run = 1000000 571 self._save_graph_def = save_graph_def 572 573 def _set_steps_per_run(self, steps_per_run): 574 self._steps_per_run = steps_per_run 575 576 def begin(self): 577 self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir) 578 self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access 579 if self._global_step_tensor is None: 580 raise RuntimeError( 581 "Global step should be created to use CheckpointSaverHook.") 582 for l in self._listeners: 583 l.begin() 584 585 def after_create_session(self, session, coord): 586 global_step = session.run(self._global_step_tensor) 587 if self._save_graph_def: 588 # We do write graph and saver_def at the first call of before_run. 589 # We cannot do this in begin, since we let other hooks to change graph and 590 # add variables in begin. Graph is finalized after all begin calls. 591 training_util.write_graph( 592 ops.get_default_graph().as_graph_def(add_shapes=True), 593 self._checkpoint_dir, "graph.pbtxt") 594 saver_def = self._get_saver().saver_def if self._get_saver() else None 595 graph = ops.get_default_graph() 596 meta_graph_def = meta_graph.create_meta_graph_def( 597 graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def) 598 self._summary_writer.add_graph(graph) 599 self._summary_writer.add_meta_graph(meta_graph_def) 600 # The checkpoint saved here is the state at step "global_step". 601 self._save(session, global_step) 602 self._timer.update_last_triggered_step(global_step) 603 604 def before_run(self, run_context): # pylint: disable=unused-argument 605 return SessionRunArgs(self._global_step_tensor) 606 607 def after_run(self, run_context, run_values): 608 stale_global_step = run_values.results 609 if self._timer.should_trigger_for_step(stale_global_step + 610 self._steps_per_run): 611 # get the real value after train op. 612 global_step = run_context.session.run(self._global_step_tensor) 613 if self._timer.should_trigger_for_step(global_step): 614 self._timer.update_last_triggered_step(global_step) 615 if self._save(run_context.session, global_step): 616 run_context.request_stop() 617 618 def end(self, session): 619 last_step = session.run(self._global_step_tensor) 620 if last_step != self._timer.last_triggered_step(): 621 self._save(session, last_step) 622 for l in self._listeners: 623 l.end(session, last_step) 624 625 def _save(self, session, step): 626 """Saves the latest checkpoint, returns should_stop.""" 627 logging.info("Calling checkpoint listeners before saving checkpoint %d...", 628 step) 629 for l in self._listeners: 630 l.before_save(session, step) 631 632 logging.info("Saving checkpoints for %d into %s.", step, self._save_path) 633 self._get_saver().save(session, self._save_path, global_step=step, 634 write_meta_graph=self._save_graph_def) 635 self._summary_writer.add_session_log( 636 SessionLog( 637 status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), 638 step) 639 logging.info("Calling checkpoint listeners after saving checkpoint %d...", 640 step) 641 should_stop = False 642 for l in self._listeners: 643 if l.after_save(session, step): 644 logging.info( 645 "A CheckpointSaverListener requested that training be stopped. " 646 "listener: {}".format(l)) 647 should_stop = True 648 return should_stop 649 650 def _get_saver(self): 651 if self._saver is not None: 652 return self._saver 653 elif self._scaffold is not None: 654 return self._scaffold.saver 655 656 # Get saver from the SAVERS collection if present. 657 collection_key = ops.GraphKeys.SAVERS 658 savers = ops.get_collection(collection_key) 659 if not savers: 660 raise RuntimeError( 661 "No items in collection {}. Please add a saver to the collection " 662 "or provide a saver or scaffold.".format(collection_key)) 663 elif len(savers) > 1: 664 raise RuntimeError( 665 "More than one item in collection {}. " 666 "Please indicate which one to use by passing it to the constructor." 667 .format(collection_key)) 668 669 self._saver = savers[0] 670 return savers[0] 671 672 673@tf_export(v1=["train.StepCounterHook"]) 674class StepCounterHook(session_run_hook.SessionRunHook): 675 """Hook that counts steps per second.""" 676 677 def __init__(self, 678 every_n_steps=100, 679 every_n_secs=None, 680 output_dir=None, 681 summary_writer=None): 682 683 if (every_n_steps is None) == (every_n_secs is None): 684 raise ValueError( 685 "exactly one of every_n_steps and every_n_secs should be provided.") 686 self._timer = SecondOrStepTimer( 687 every_steps=every_n_steps, every_secs=every_n_secs) 688 689 self._summary_writer = summary_writer 690 self._output_dir = output_dir 691 self._last_global_step = None 692 self._steps_per_run = 1 693 694 def _set_steps_per_run(self, steps_per_run): 695 self._steps_per_run = steps_per_run 696 697 def begin(self): 698 if self._summary_writer is None and self._output_dir: 699 self._summary_writer = SummaryWriterCache.get(self._output_dir) 700 self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access 701 if self._global_step_tensor is None: 702 raise RuntimeError( 703 "Global step should be created to use StepCounterHook.") 704 self._summary_tag = training_util.get_global_step().op.name + "/sec" 705 706 def before_run(self, run_context): # pylint: disable=unused-argument 707 return SessionRunArgs(self._global_step_tensor) 708 709 def _log_and_record(self, elapsed_steps, elapsed_time, global_step): 710 steps_per_sec = elapsed_steps / elapsed_time 711 if self._summary_writer is not None: 712 summary = Summary(value=[ 713 Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec) 714 ]) 715 self._summary_writer.add_summary(summary, global_step) 716 logging.info("%s: %g", self._summary_tag, steps_per_sec) 717 718 def after_run(self, run_context, run_values): 719 _ = run_context 720 721 stale_global_step = run_values.results 722 if self._timer.should_trigger_for_step(stale_global_step + 723 self._steps_per_run): 724 # get the real value after train op. 725 global_step = run_context.session.run(self._global_step_tensor) 726 if self._timer.should_trigger_for_step(global_step): 727 elapsed_time, elapsed_steps = self._timer.update_last_triggered_step( 728 global_step) 729 if elapsed_time is not None: 730 self._log_and_record(elapsed_steps, elapsed_time, global_step) 731 732 # Check whether the global step has been increased. Here, we do not use the 733 # timer.last_triggered_step as the timer might record a different global 734 # step value such that the comparison could be unreliable. For simplicity, 735 # we just compare the stale_global_step with previously recorded version. 736 if stale_global_step == self._last_global_step: 737 # Here, we give a warning in the first 5 times if we have observed that 738 # the global step has not been increased. For some Optimizers, the global 739 # step is not increased each time by design. For example, 740 # SyncReplicaOptimizer doesn't increase the global step in worker's main 741 # train step. 742 logging.log_first_n( 743 logging.WARN, 744 "It seems that global step (tf.train.get_global_step) has not " 745 "been increased. Current value (could be stable): %s vs previous " 746 "value: %s. You could increase the global step by passing " 747 "tf.train.get_global_step() to Optimizer.apply_gradients or " 748 "Optimizer.minimize.", 5, stale_global_step, self._last_global_step) 749 750 self._last_global_step = stale_global_step 751 752 753@tf_export(v1=["train.NanLossDuringTrainingError"]) 754class NanLossDuringTrainingError(RuntimeError): 755 756 def __str__(self): 757 return "NaN loss during training." 758 759 760@tf_export(v1=["train.NanTensorHook"]) 761class NanTensorHook(session_run_hook.SessionRunHook): 762 """Monitors the loss tensor and stops training if loss is NaN. 763 764 Can either fail with exception or just stop training. 765 """ 766 767 def __init__(self, loss_tensor, fail_on_nan_loss=True): 768 """Initializes a `NanTensorHook`. 769 770 Args: 771 loss_tensor: `Tensor`, the loss tensor. 772 fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN. 773 """ 774 self._loss_tensor = loss_tensor 775 self._fail_on_nan_loss = fail_on_nan_loss 776 777 def before_run(self, run_context): # pylint: disable=unused-argument 778 return SessionRunArgs(self._loss_tensor) 779 780 def after_run(self, run_context, run_values): 781 if np.isnan(run_values.results): 782 failure_message = "Model diverged with loss = NaN." 783 if self._fail_on_nan_loss: 784 logging.error(failure_message) 785 raise NanLossDuringTrainingError 786 else: 787 logging.warning(failure_message) 788 # We don't raise an error but we request stop without an exception. 789 run_context.request_stop() 790 791 792@tf_export(v1=["train.SummarySaverHook"]) 793class SummarySaverHook(session_run_hook.SessionRunHook): 794 """Saves summaries every N steps.""" 795 796 def __init__(self, 797 save_steps=None, 798 save_secs=None, 799 output_dir=None, 800 summary_writer=None, 801 scaffold=None, 802 summary_op=None): 803 """Initializes a `SummarySaverHook`. 804 805 Args: 806 save_steps: `int`, save summaries every N steps. Exactly one of 807 `save_secs` and `save_steps` should be set. 808 save_secs: `int`, save summaries every N seconds. 809 output_dir: `string`, the directory to save the summaries to. Only used if 810 no `summary_writer` is supplied. 811 summary_writer: `SummaryWriter`. If `None` and an `output_dir` was passed, 812 one will be created accordingly. 813 scaffold: `Scaffold` to get summary_op if it's not provided. 814 summary_op: `Tensor` of type `string` containing the serialized `Summary` 815 protocol buffer or a list of `Tensor`. They are most likely an output by 816 TF summary methods like `tf.compat.v1.summary.scalar` or 817 `tf.compat.v1.summary.merge_all`. It can be passed in as one tensor; if 818 more than one, they must be passed in as a list. 819 820 Raises: 821 ValueError: Exactly one of scaffold or summary_op should be set. 822 """ 823 if ((scaffold is None and summary_op is None) or 824 (scaffold is not None and summary_op is not None)): 825 raise ValueError( 826 "Exactly one of scaffold or summary_op must be provided.") 827 self._summary_op = summary_op 828 self._summary_writer = summary_writer 829 self._output_dir = output_dir 830 self._scaffold = scaffold 831 self._timer = SecondOrStepTimer( 832 every_secs=save_secs, every_steps=save_steps) 833 # TODO(mdan): Throw an error if output_dir and summary_writer are None. 834 835 def begin(self): 836 if self._summary_writer is None and self._output_dir: 837 self._summary_writer = SummaryWriterCache.get(self._output_dir) 838 self._next_step = None 839 self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access 840 if self._global_step_tensor is None: 841 raise RuntimeError( 842 "Global step should be created to use SummarySaverHook.") 843 844 def before_run(self, run_context): # pylint: disable=unused-argument 845 self._request_summary = ( 846 self._next_step is None or 847 self._timer.should_trigger_for_step(self._next_step)) 848 requests = {"global_step": self._global_step_tensor} 849 if self._request_summary: 850 if self._get_summary_op() is not None: 851 requests["summary"] = self._get_summary_op() 852 853 return SessionRunArgs(requests) 854 855 def after_run(self, run_context, run_values): 856 _ = run_context 857 if not self._summary_writer: 858 return 859 860 stale_global_step = run_values.results["global_step"] 861 global_step = stale_global_step + 1 862 if self._next_step is None or self._request_summary: 863 global_step = run_context.session.run(self._global_step_tensor) 864 865 if self._next_step is None: 866 self._summary_writer.add_session_log( 867 SessionLog(status=SessionLog.START), global_step) 868 869 if self._request_summary: 870 self._timer.update_last_triggered_step(global_step) 871 if "summary" in run_values.results: 872 for summary in run_values.results["summary"]: 873 self._summary_writer.add_summary(summary, global_step) 874 875 self._next_step = global_step + 1 876 877 def end(self, session=None): 878 if self._summary_writer: 879 self._summary_writer.flush() 880 881 def _get_summary_op(self): 882 """Fetches the summary op either from self._summary_op or self._scaffold. 883 884 Returns: 885 Returns a list of summary `Tensor`. 886 """ 887 summary_op = None 888 if self._summary_op is not None: 889 summary_op = self._summary_op 890 elif self._scaffold.summary_op is not None: 891 summary_op = self._scaffold.summary_op 892 893 if summary_op is None: 894 return None 895 896 if not isinstance(summary_op, list): 897 return [summary_op] 898 return summary_op 899 900 901@tf_export(v1=["train.GlobalStepWaiterHook"]) 902class GlobalStepWaiterHook(session_run_hook.SessionRunHook): 903 """Delays execution until global step reaches `wait_until_step`. 904 905 This hook delays execution until global step reaches to `wait_until_step`. It 906 is used to gradually start workers in distributed settings. One example usage 907 would be setting `wait_until_step=int(K*log(task_id+1))` assuming that 908 task_id=0 is the chief. 909 """ 910 911 def __init__(self, wait_until_step): 912 """Initializes a `GlobalStepWaiterHook`. 913 914 Args: 915 wait_until_step: an `int` shows until which global step should we wait. 916 """ 917 self._wait_until_step = wait_until_step 918 919 def begin(self): 920 self._worker_is_started = False 921 self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access 922 if self._global_step_tensor is None: 923 raise RuntimeError( 924 "Global step should be created to use _GlobalStepWaiterHook.") 925 926 def before_run(self, run_context): 927 if self._worker_is_started: 928 return None 929 930 if self._wait_until_step <= 0: 931 self._worker_is_started = True 932 return None 933 934 logging.info("Waiting for global step %d before starting training.", 935 self._wait_until_step) 936 last_logged_step = 0 937 while True: 938 current_step = run_context.session.run(self._global_step_tensor) 939 if current_step >= self._wait_until_step: 940 self._worker_is_started = True 941 return None 942 if current_step - last_logged_step > 1000: 943 logging.info( 944 "Waiting for global step %d before starting training. " 945 "Current step is %d.", self._wait_until_step, current_step) 946 last_logged_step = current_step 947 time.sleep(0.5) 948 949 950@tf_export(v1=["train.FinalOpsHook"]) 951class FinalOpsHook(session_run_hook.SessionRunHook): 952 """A hook which evaluates `Tensors` at the end of a session.""" 953 954 def __init__(self, final_ops, final_ops_feed_dict=None): 955 """Initializes `FinalOpHook` with ops to run at the end of the session. 956 957 Args: 958 final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names 959 to `Tensors`. 960 final_ops_feed_dict: A feed dictionary to use when running 961 `final_ops_dict`. 962 """ 963 self._final_ops = final_ops 964 self._final_ops_feed_dict = final_ops_feed_dict 965 self._final_ops_values = None 966 967 @property 968 def final_ops_values(self): 969 return self._final_ops_values 970 971 def end(self, session): 972 if self._final_ops is not None: 973 try: 974 self._final_ops_values = session.run( 975 self._final_ops, feed_dict=self._final_ops_feed_dict) 976 except (errors.OutOfRangeError, StopIteration) as e: 977 logging.warning( 978 "An OutOfRangeError or StopIteration exception is raised by the " 979 "code in FinalOpsHook. This typically means the Ops running by the " 980 "FinalOpsHook have a dependency back to some input source, which " 981 "should not happen. For example, for metrics in " 982 "tf.estimator.Estimator, all metrics functions return two Ops: " 983 "`value_op` and `update_op`. Estimator.evaluate calls the " 984 "`update_op` for each batch of the data in input source and, once " 985 "it is exhausted, it call the `value_op` to get the metric values. " 986 "The `value_op` here should have dependency back to variables " 987 "reading only, rather than reading another batch from input. " 988 "Otherwise, the `value_op`, executed by `FinalOpsHook`, triggers " 989 "another data reading, which ends OutOfRangeError/StopIteration. " 990 "Please fix that.") 991 raise e 992 993 994@tf_export(v1=["train.FeedFnHook"]) 995class FeedFnHook(session_run_hook.SessionRunHook): 996 """Runs `feed_fn` and sets the `feed_dict` accordingly.""" 997 998 def __init__(self, feed_fn): 999 """Initializes a `FeedFnHook`. 1000 1001 Args: 1002 feed_fn: function that takes no arguments and returns `dict` of `Tensor` 1003 to feed. 1004 """ 1005 self.feed_fn = feed_fn 1006 1007 def before_run(self, run_context): # pylint: disable=unused-argument 1008 return session_run_hook.SessionRunArgs( 1009 fetches=None, feed_dict=self.feed_fn()) 1010 1011 1012@tf_export(v1=["train.ProfilerHook"]) 1013class ProfilerHook(session_run_hook.SessionRunHook): 1014 """Captures CPU/GPU profiling information every N steps or seconds. 1015 1016 This produces files called "timeline-<step>.json", which are in Chrome 1017 Trace format. 1018 1019 For more information see: 1020 https://github.com/catapult-project/catapult/blob/master/tracing/README.md 1021 """ 1022 1023 def __init__(self, 1024 save_steps=None, 1025 save_secs=None, 1026 output_dir="", 1027 show_dataflow=True, 1028 show_memory=False): 1029 """Initializes a hook that takes periodic profiling snapshots. 1030 1031 `options.run_metadata` argument of `tf.Session.Run` is used to collect 1032 metadata about execution. This hook sets the metadata and dumps it in Chrome 1033 Trace format. 1034 1035 1036 Args: 1037 save_steps: `int`, save profile traces every N steps. Exactly one of 1038 `save_secs` and `save_steps` should be set. 1039 save_secs: `int` or `float`, save profile traces every N seconds. 1040 output_dir: `string`, the directory to save the profile traces to. 1041 Defaults to the current directory. 1042 show_dataflow: `bool`, if True, add flow events to the trace connecting 1043 producers and consumers of tensors. 1044 show_memory: `bool`, if True, add object snapshot events to the trace 1045 showing the sizes and lifetimes of tensors. 1046 """ 1047 self._output_file = os.path.join(output_dir, "timeline-{}.json") 1048 self._file_writer = SummaryWriterCache.get(output_dir) 1049 self._show_dataflow = show_dataflow 1050 self._show_memory = show_memory 1051 self._timer = SecondOrStepTimer( 1052 every_secs=save_secs, every_steps=save_steps) 1053 1054 def begin(self): 1055 self._next_step = None 1056 self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access 1057 if self._global_step_tensor is None: 1058 raise RuntimeError("Global step should be created to use ProfilerHook.") 1059 1060 def before_run(self, run_context): 1061 self._request_summary = ( 1062 self._next_step is not None and 1063 self._timer.should_trigger_for_step(self._next_step)) 1064 requests = {"global_step": self._global_step_tensor} 1065 opts = ( 1066 config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) 1067 if self._request_summary else None) 1068 1069 return SessionRunArgs(requests, options=opts) 1070 1071 def after_run(self, run_context, run_values): 1072 stale_global_step = run_values.results["global_step"] 1073 if self._next_step is None: 1074 # Update the timer so that it does not activate until N steps or seconds 1075 # have passed. 1076 self._timer.update_last_triggered_step(stale_global_step) 1077 global_step = stale_global_step + 1 1078 if self._request_summary: 1079 global_step = run_context.session.run(self._global_step_tensor) 1080 self._timer.update_last_triggered_step(global_step) 1081 self._save(global_step, self._output_file.format(global_step), 1082 run_values.run_metadata.step_stats) 1083 self._file_writer.add_run_metadata(run_values.run_metadata, 1084 "step_%d" % global_step) 1085 1086 self._next_step = global_step + 1 1087 1088 def _save(self, step, save_path, step_stats): 1089 logging.info("Saving timeline for %d into '%s'.", step, save_path) 1090 with gfile.Open(save_path, "w") as f: 1091 trace = timeline.Timeline(step_stats) 1092 f.write( 1093 trace.generate_chrome_trace_format( 1094 show_dataflow=self._show_dataflow, show_memory=self._show_memory)) 1095 1096 1097def _as_graph_element(obj): 1098 """Retrieves Graph element.""" 1099 graph = ops.get_default_graph() 1100 if not isinstance(obj, str): 1101 if not hasattr(obj, "graph") or obj.graph != graph: 1102 raise ValueError("Passed %s should have graph attribute that is equal " 1103 "to current graph %s." % (obj, graph)) 1104 return obj 1105 if ":" in obj: 1106 element = graph.as_graph_element(obj) 1107 else: 1108 element = graph.as_graph_element(obj + ":0") 1109 # Check that there is no :1 (e.g. it's single output). 1110 try: 1111 graph.as_graph_element(obj + ":1") 1112 except (KeyError, ValueError): 1113 pass 1114 else: 1115 raise ValueError("Name %s is ambiguous, " 1116 "as this `Operation` has multiple outputs " 1117 "(at least 2)." % obj) 1118 return element 1119