# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Hook for asynchronous checkpointing.

This hook dispatches checkpoint writing operations in a separate thread to
allow execution to continue on the main thread.
"""

import os
import threading
import time
from typing import Any, List, Optional, Text

from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.summary_io import SummaryWriterCache


class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
  """Saves checkpoints every N steps or seconds.

  Unlike the parent `CheckpointSaverHook`, the `Saver.save` call itself runs
  on a background thread (see `_save`), so the training loop on the main
  thread is not blocked while a checkpoint is being written.  At most one
  save is in flight at a time; if a trigger fires while a previous save is
  still running, that checkpoint is skipped (see `_save`).
  """

  def __init__(self,
               checkpoint_dir: Text,
               save_secs: Optional[int] = None,
               save_steps: Optional[int] = None,
               saver: Optional[saver_lib.Saver] = None,
               checkpoint_basename: Text = "model.ckpt",
               scaffold: Optional[monitored_session.Scaffold] = None,
               listeners: Optional[List[
                   basic_session_run_hooks.CheckpointSaverListener]] = None):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, use to get saver object.
      listeners: List of `CheckpointSaverListener` subclass instances. Used for
        callbacks that run immediately before or after this hook saves the
        checkpoint.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of `saver` or `scaffold` should be set.
    """
    save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    logging.info("Create AsyncCheckpointSaverHook saving to path\n%s",
                 save_path)
    if listeners:
      logging.info(" with %d listener(s).", len(listeners))
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    # Background thread of the in-flight asynchronous save, or None when no
    # save has been dispatched yet.
    self._save_thread = None
    # Background thread writing graph.pbtxt (started in after_create_session).
    self._write_graph_thread = None
    self._checkpoint_dir = checkpoint_dir
    self._save_path = save_path
    self._scaffold = scaffold
    # SecondOrStepTimer handles the "every N secs or N steps" trigger logic;
    # the ValueErrors documented above for save_secs/save_steps are raised by
    # this constructor.
    self._timer = basic_session_run_hooks.SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    self._listeners = listeners or []
    self._steps_per_run = 1
    # Set lazily in begin(), once the graph is available.
    self._summary_writer = None
    self._global_step_tensor = None

    # Global step of the most recent checkpoint that was written (or
    # dispatched for writing).  Used by end() to avoid re-saving a step that
    # was already checkpointed.
    self._last_checkpoint_step = None

  def _set_steps_per_run(self, steps_per_run):
    # Called by the framework when a single session.run covers multiple
    # global steps (e.g. TPU training loops).
    self._steps_per_run = steps_per_run

  def begin(self):
    """Creates the summary writer and resolves the global-step tensor.

    Raises:
      RuntimeError: If no global step tensor has been created in the graph.
    """
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def after_create_session(self, session: session_lib.Session, coord: Any):
    """Writes the graph, then saves an initial checkpoint synchronously-ish.

    The graph.pbtxt write is dispatched to its own thread; the initial
    checkpoint itself still goes through `_save`, which by default runs the
    actual `Saver.save` asynchronously.
    """
    global_step = session.run(self._global_step_tensor)

    # We do write graph and saver_def at the first call of before_run.
    # We cannot do this in begin, since we let other hooks to change graph and
    # add variables in begin. Graph is finalized after all begin calls.
    def _write_graph_fn(self):
      training_util.write_graph(
          ops.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir, "graph.pbtxt")
    self._write_graph_thread = threading.Thread(target=_write_graph_fn,
                                                args=[self])
    self._write_graph_thread.start()

    saver_def = self._get_saver().saver_def if self._get_saver() else None
    graph = ops.get_default_graph()
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
    self._summary_writer.add_graph(graph)
    self._summary_writer.add_meta_graph(meta_graph_def)
    # The checkpoint saved here is the state at step "global_step".
    self._save(session, global_step)
    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context: Any):  # pylint: disable=unused-argument
    # Request the global step so the run itself fetches it; note after_run
    # below re-reads the step with its own session.run rather than using
    # run_values (matching upstream behavior when steps_per_run > 1).
    return session_run_hook.SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context: session_run_hook.SessionRunContext,
                run_values: Any):
    """Triggers an asynchronous save when the timer says it is due."""
    global_step = run_context.session.run(self._global_step_tensor)
    if self._timer.should_trigger_for_step(global_step):
      # Update the timer before saving so a slow save does not cause
      # back-to-back re-triggers.
      self._timer.update_last_triggered_step(global_step)
      logging.info("Triggering checkpoint. %s", global_step)
      # _save currently has no return statement with a value, so this branch
      # is effectively never taken; kept for interface parity with the
      # parent hook's "returns should_stop" contract.
      if self._save(run_context.session, global_step):
        run_context.request_stop()

  def end(self, session: session_lib.Session):
    """Flushes pending async work and writes a final checkpoint if needed."""
    if self._save_thread:
      logging.info("Waiting for any pending checkpoints to finish.")
      self._save_thread.join()
    if self._write_graph_thread:
      logging.info("Waiting for any pending write_graph to finish.")
      self._write_graph_thread.join()

    last_step = session.run(self._global_step_tensor)

    # Only save the final state if it was not already checkpointed (e.g. the
    # last triggered save happened to land exactly on last_step).  This save
    # is synchronous: there is no later point to join a thread from.
    if self._last_checkpoint_step != last_step:
      self._save(session, last_step, asynchronous=False)

    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step, asynchronous=True):
    """Saves the latest checkpoint, returns should_stop."""

    def _save_fn():
      """Run the saver process."""
      logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

      start_time = time.time()
      # Listener callbacks run on the saving thread in the async case —
      # listeners must tolerate being called off the main thread.
      for l in self._listeners:
        l.before_save(session, step)

      self._get_saver().save(session, self._save_path, global_step=step)
      self._summary_writer.add_session_log(
          event_pb2.SessionLog(
              status=event_pb2.SessionLog.CHECKPOINT,
              checkpoint_path=self._save_path), step)

      for l in self._listeners:
        l.after_save(session, step)

      end_time = time.time()
      logging.info("Checkpoint actual writing time: (%.3f sec)",
                   end_time - start_time)
      logging.info("Checkpoint finished for %d into %s.", step, self._save_path)

    if not asynchronous:
      self._last_checkpoint_step = step
      _save_fn()
      return

    if self._save_thread is not None:
      # Give a nearly-finished save a brief chance to complete; if it is
      # still running, skip this checkpoint rather than queueing up saves.
      self._save_thread.join(timeout=0.1)
      if self._save_thread.is_alive():
        logging.info("Saver thread still in progress, skipping checkpoint.")
        return

    # Record the step before starting the thread so end() sees a consistent
    # value even while the save is still in flight.
    self._last_checkpoint_step = step
    self._save_thread = threading.Thread(target=_save_fn)
    self._save_thread.start()

  def _get_saver(self):
    """Returns the saver to use: explicit, scaffold's, or SAVERS collection.

    The collection lookup result is cached in self._saver on first use.

    Raises:
      RuntimeError: If the SAVERS collection is empty or holds more than one
        saver and neither `saver` nor `scaffold` was provided.
    """
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))

    self._saver = savers[0]
    return savers[0]