xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/async_checkpoint.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the 'License');
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an 'AS IS' BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ======================================
15"""Hook for asynchronous checkpointing.
16
17This hook dispatches checkpoint writing operations in a separate thread to
18allow execution to continue on the main thread.
19"""
20
21import os
22import threading
23import time
24from typing import Any, List, Optional, Text
25
26from tensorflow.core.util import event_pb2
27from tensorflow.python.client import session as session_lib
28from tensorflow.python.framework import meta_graph
29from tensorflow.python.framework import ops
30from tensorflow.python.platform import tf_logging as logging
31from tensorflow.python.training import basic_session_run_hooks
32from tensorflow.python.training import monitored_session
33from tensorflow.python.training import saver as saver_lib
34from tensorflow.python.training import session_run_hook
35from tensorflow.python.training import training_util
36from tensorflow.python.training.summary_io import SummaryWriterCache
37
38
class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
  """Saves checkpoints every N steps or seconds.

  Unlike the base `CheckpointSaverHook`, the actual checkpoint write is
  dispatched on a background thread (see `_save`) so the main training loop
  can continue running. If a previous asynchronous save is still in flight
  when the timer next triggers, that checkpoint is skipped rather than
  queued. `end` joins any outstanding threads and writes a final checkpoint
  synchronously if the last step was not already saved.
  """

  def __init__(self,
               checkpoint_dir: Text,
               save_secs: Optional[int] = None,
               save_steps: Optional[int] = None,
               saver: Optional[saver_lib.Saver] = None,
               checkpoint_basename: Text = "model.ckpt",
               scaffold: Optional[monitored_session.Scaffold] = None,
               listeners: Optional[List[
                   basic_session_run_hooks.CheckpointSaverListener]] = None):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, use to get saver object.
      listeners: List of `CheckpointSaverListener` subclass instances. Used for
        callbacks that run immediately before or after this hook saves the
        checkpoint.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of `saver` or `scaffold` should be set.
    """
    save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    logging.info("Create AsyncCheckpointSaverHook saving to path\n%s",
                 save_path)
    if listeners:
      logging.info(" with %d listener(s).", len(listeners))
    # `saver` and `scaffold` are alternative ways to obtain the Saver; see
    # `_get_saver` for the full resolution order.
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    # Background threads for the in-flight checkpoint write and the one-off
    # graph.pbtxt dump; None until the corresponding work is dispatched.
    self._save_thread = None
    self._write_graph_thread = None
    self._checkpoint_dir = checkpoint_dir
    self._save_path = save_path
    self._scaffold = scaffold
    # Triggers on whichever of save_secs/save_steps is configured.
    self._timer = basic_session_run_hooks.SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    self._listeners = listeners or []
    self._steps_per_run = 1
    # Populated in `begin()`, once the graph (and summary dir) are available.
    self._summary_writer = None
    self._global_step_tensor = None

    # Global step at which a checkpoint was most recently *dispatched* (not
    # necessarily finished); `end()` uses this to avoid a duplicate final save.
    self._last_checkpoint_step = None

  def _set_steps_per_run(self, steps_per_run):
    # Stored only; not read anywhere within this class. NOTE(review):
    # presumably consumed by the base class or the training loop when one
    # session.run advances the global step by more than 1 -- confirm.
    self._steps_per_run = steps_per_run

  def begin(self):
    """Creates the summary writer and resolves the global step tensor.

    Raises:
      RuntimeError: If no global step tensor exists in the default graph.
    """
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    # Give each listener a chance to add ops before the graph is finalized.
    for l in self._listeners:
      l.begin()

  def after_create_session(self, session: session_lib.Session, coord: Any):
    """Writes the graph/meta-graph and saves an initial checkpoint."""
    global_step = session.run(self._global_step_tensor)

    # We do write graph and saver_def at the first call of before_run.
    # We cannot do this in begin, since we let other hooks to change graph and
    # add variables in begin. Graph is finalized after all begin calls.
    # The (potentially large) graph.pbtxt dump happens on its own thread so
    # session startup is not blocked; `end()` joins this thread.
    def _write_graph_fn(self):
      training_util.write_graph(
          ops.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir, "graph.pbtxt")
    self._write_graph_thread = threading.Thread(target=_write_graph_fn,
                                                args=[self])
    self._write_graph_thread.start()

    saver_def = self._get_saver().saver_def if self._get_saver() else None
    graph = ops.get_default_graph()
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
    self._summary_writer.add_graph(graph)
    self._summary_writer.add_meta_graph(meta_graph_def)
    # The checkpoint saved here is the state at step "global_step".
    self._save(session, global_step)
    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context: Any):  # pylint: disable=unused-argument
    # Request the global step with each run call. NOTE(review): `after_run`
    # re-fetches the global step via session.run instead of reading the
    # result requested here -- the fetched value appears unused.
    return session_run_hook.SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context: session_run_hook.SessionRunContext,
                run_values: Any):
    """Dispatches a checkpoint save when the timer triggers."""
    global_step = run_context.session.run(self._global_step_tensor)
    if self._timer.should_trigger_for_step(global_step):
      # Mark the step as triggered even if `_save` ends up skipping it
      # because a previous save thread is still running.
      self._timer.update_last_triggered_step(global_step)
      logging.info("Triggering checkpoint. %s", global_step)
      # NOTE(review): as written, `_save` returns None on every path, so this
      # request_stop() branch can never fire.
      if self._save(run_context.session, global_step):
        run_context.request_stop()

  def end(self, session: session_lib.Session):
    """Joins pending threads and writes a final synchronous checkpoint."""
    if self._save_thread:
      logging.info("Waiting for any pending checkpoints to finish.")
      self._save_thread.join()
    if self._write_graph_thread:
      logging.info("Waiting for any pending write_graph to finish.")
      self._write_graph_thread.join()

    last_step = session.run(self._global_step_tensor)

    # Save a final checkpoint synchronously unless one was already dispatched
    # for this exact step.
    if self._last_checkpoint_step != last_step:
      self._save(session, last_step, asynchronous=False)

    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step, asynchronous: bool = True):
    """Saves the latest checkpoint, returns should_stop.

    Args:
      session: The session holding the state to be checkpointed.
      step: Global step value recorded into the checkpoint filename.
      asynchronous: If True, dispatch the save on a background thread and
        skip it entirely when the previous save thread is still alive; if
        False, save inline on the calling thread.

    Returns:
      None on every path (callers treat the value as "should_stop", so the
      effective answer is always False).
    """

    def _save_fn():
      """Run the saver process."""
      logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

      start_time = time.time()
      for l in self._listeners:
        l.before_save(session, step)

      self._get_saver().save(session, self._save_path, global_step=step)
      # Record the checkpoint event so downstream tooling reading the summary
      # log can locate the checkpoint for this step.
      self._summary_writer.add_session_log(
          event_pb2.SessionLog(
              status=event_pb2.SessionLog.CHECKPOINT,
              checkpoint_path=self._save_path), step)

      for l in self._listeners:
        l.after_save(session, step)

      end_time = time.time()
      logging.info("Checkpoint actual writing time: (%.3f sec)",
                   end_time - start_time)
      logging.info("Checkpoint finished for %d into %s.", step, self._save_path)

    if not asynchronous:
      self._last_checkpoint_step = step
      _save_fn()
      return

    # Give a previous in-flight save a brief grace period to finish; if it is
    # still running, drop this checkpoint rather than queueing behind it.
    if self._save_thread is not None:
      self._save_thread.join(timeout=0.1)
      if self._save_thread.is_alive():
        logging.info("Saver thread still in progress, skipping checkpoint.")
        return

    # Recorded at dispatch time: the write may still be running when this
    # method returns.
    self._last_checkpoint_step = step
    self._save_thread = threading.Thread(target=_save_fn)
    self._save_thread.start()

  def _get_saver(self):
    """Returns the Saver: explicit `saver`, else scaffold's, else SAVERS collection.

    The collection lookup result is cached in `self._saver` for later calls.

    Raises:
      RuntimeError: If the SAVERS collection is empty or holds more than one
        saver (and neither `saver` nor `scaffold` was provided).
    """
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))

    self._saver = savers[0]
    return savers[0]
215