# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=invalid-name
"""Checkpoint Manager and other utilities for managing checkpoints."""
import collections
import os.path
import re
import time

from google.protobuf import text_format

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _evaluate(tensor):
  """Returns the numpy value of a tensor."""
  if context.executing_eagerly():
    return tensor.numpy()
  return ops.get_default_session().run(tensor)


def _GetCheckpointFilename(save_dir, latest_filename):
  """Returns a filename for storing the CheckpointState.

  Args:
    save_dir: The directory for saving and restoring checkpoints.
    latest_filename: Name of the file in 'save_dir' that is used
      to store the CheckpointState.

  Returns:
    The path of the file that contains the CheckpointState proto.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  return os.path.join(save_dir, latest_filename)


@tf_export(v1=["train.generate_checkpoint_state_proto"])
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None,
                                    all_model_checkpoint_timestamps=None,
                                    last_preserved_timestamp=None):
  """Generates a checkpoint state proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    all_model_checkpoint_timestamps: A list of floats, indicating the number of
      seconds since the Epoch when each checkpoint was generated.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).
  Returns:
    CheckpointState proto with model_checkpoint_path and
    all_model_checkpoint_paths updated to either absolute paths or
    relative paths to the current save_dir.

  Raises:
    ValueError: If `all_model_checkpoint_timestamps` was provided but its length
      does not match `all_model_checkpoint_paths`.
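
  A minimal illustrative sketch (the directory and checkpoint prefixes below
  are hypothetical):

  ```python
  state = tf.compat.v1.train.generate_checkpoint_state_proto(
      save_dir="/tmp/model",
      model_checkpoint_path="/tmp/model/ckpt-2",
      all_model_checkpoint_paths=["/tmp/model/ckpt-1", "/tmp/model/ckpt-2"])
  # With an absolute `save_dir`, the paths are stored as given:
  # state.model_checkpoint_path == "/tmp/model/ckpt-2"
  # state.all_model_checkpoint_paths lists both prefixes, oldest first.
  ```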
91  """
92  if all_model_checkpoint_paths is None:
93    all_model_checkpoint_paths = []
94
95  if (not all_model_checkpoint_paths or
96      all_model_checkpoint_paths[-1] != model_checkpoint_path):
97    logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
98                 model_checkpoint_path)
99    all_model_checkpoint_paths.append(model_checkpoint_path)
100
101  if (all_model_checkpoint_timestamps
102      and (len(all_model_checkpoint_timestamps)
103           != len(all_model_checkpoint_paths))):
104    raise ValueError(
105        ("Checkpoint timestamps, if provided, must match checkpoint paths (got "
106         "paths %s and timestamps %s)")
107        % (all_model_checkpoint_paths, all_model_checkpoint_timestamps))
108
109  # Relative paths need to be rewritten to be relative to the "save_dir"
110  # if model_checkpoint_path already contains "save_dir".
111  if not os.path.isabs(save_dir):
112    if not os.path.isabs(model_checkpoint_path):
113      model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
114    for i, p in enumerate(all_model_checkpoint_paths):
115      if not os.path.isabs(p):
116        all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)
117
118  coord_checkpoint_proto = CheckpointState(
119      model_checkpoint_path=model_checkpoint_path,
120      all_model_checkpoint_paths=all_model_checkpoint_paths,
121      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
122      last_preserved_timestamp=last_preserved_timestamp)
123
124  return coord_checkpoint_proto
125
126
127@deprecation.deprecated(
128    date=None,
129    instructions=("Use `tf.train.CheckpointManager` to manage checkpoints "
130                  "rather than manually editing the Checkpoint proto."))
131@tf_export(v1=["train.update_checkpoint_state"])
132def update_checkpoint_state(save_dir,
133                            model_checkpoint_path,
134                            all_model_checkpoint_paths=None,
135                            latest_filename=None,
136                            all_model_checkpoint_timestamps=None,
137                            last_preserved_timestamp=None):
138  """Updates the content of the 'checkpoint' file.
139
140  This updates the checkpoint file containing a CheckpointState
141  proto.
142
143  Args:
144    save_dir: Directory where the model was saved.
145    model_checkpoint_path: The checkpoint file.
146    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
147      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
148      the last element must be equal to model_checkpoint_path.  These paths
149      are also saved in the CheckpointState proto.
150    latest_filename: Optional name of the checkpoint file.  Default to
151      'checkpoint'.
152    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
153      seconds since the Epoch) indicating when the checkpoints in
154      `all_model_checkpoint_paths` were created.
155    last_preserved_timestamp: A float, indicating the number of seconds since
156      the Epoch when the last preserved checkpoint was written, e.g. due to a
157      `keep_checkpoint_every_n_hours` parameter (see
158      `tf.train.CheckpointManager` for an implementation).
159  Raises:
160    RuntimeError: If any of the model checkpoint paths conflict with the file
161      containing CheckpointSate.
162  """
  update_checkpoint_state_internal(
      save_dir=save_dir,
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      latest_filename=latest_filename,
      save_relative_paths=False,
      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
      last_preserved_timestamp=last_preserved_timestamp)


@tf_export("__internal__.train.update_checkpoint_state", v1=[])
def update_checkpoint_state_internal(save_dir,
                                     model_checkpoint_path,
                                     all_model_checkpoint_paths=None,
                                     latest_filename=None,
                                     save_relative_paths=False,
                                     all_model_checkpoint_timestamps=None,
                                     last_preserved_timestamp=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.
    save_relative_paths: If `True`, will write relative paths to the checkpoint
      state file.
    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
      seconds since the Epoch) indicating when the checkpoints in
      `all_model_checkpoint_paths` were created.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.train.CheckpointManager` for an implementation).

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
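
  A minimal illustrative sketch (paths are hypothetical):

  ```python
  # Record "/tmp/model/ckpt-2" as the latest checkpoint, storing paths
  # relative to "/tmp/model" in the "checkpoint" state file.
  update_checkpoint_state_internal(
      save_dir="/tmp/model",
      model_checkpoint_path="/tmp/model/ckpt-2",
      all_model_checkpoint_paths=["/tmp/model/ckpt-1", "/tmp/model/ckpt-2"],
      save_relative_paths=True)
  ```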
208  """
209  # Writes the "checkpoint" file for the coordinator for later restoration.
210  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
211  if save_relative_paths:
212    if os.path.isabs(model_checkpoint_path):
213      rel_model_checkpoint_path = os.path.relpath(
214          model_checkpoint_path, save_dir)
215    else:
216      rel_model_checkpoint_path = model_checkpoint_path
217    rel_all_model_checkpoint_paths = []
218    for p in all_model_checkpoint_paths:
219      if os.path.isabs(p):
220        rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
221      else:
222        rel_all_model_checkpoint_paths.append(p)
223    ckpt = generate_checkpoint_state_proto(
224        save_dir,
225        rel_model_checkpoint_path,
226        all_model_checkpoint_paths=rel_all_model_checkpoint_paths,
227        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
228        last_preserved_timestamp=last_preserved_timestamp)
229  else:
230    ckpt = generate_checkpoint_state_proto(
231        save_dir,
232        model_checkpoint_path,
233        all_model_checkpoint_paths=all_model_checkpoint_paths,
234        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
235        last_preserved_timestamp=last_preserved_timestamp)
236
237  if coord_checkpoint_filename == ckpt.model_checkpoint_path:
238    raise RuntimeError("Save path '%s' conflicts with path used for "
239                       "checkpoint state.  Please use a different save path." %
240                       model_checkpoint_path)
241
242  # Preventing potential read/write race condition by *atomically* writing to a
243  # file.
244  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
245                                      text_format.MessageToString(ckpt))
246
247
248@tf_export("train.get_checkpoint_state")
249def get_checkpoint_state(checkpoint_dir, latest_filename=None):
250  """Returns CheckpointState proto from the "checkpoint" file.
251
252  If the "checkpoint" file contains a valid CheckpointState
253  proto, returns it.
254
255  Args:
256    checkpoint_dir: The directory of checkpoints.
257    latest_filename: Optional name of the checkpoint file.  Default to
258      'checkpoint'.
259
260  Returns:
261    A CheckpointState if the state was available, None
262    otherwise.
263
264  Raises:
265    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
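
  A minimal usage sketch (the directory name is hypothetical):

  ```python
  state = tf.train.get_checkpoint_state("/tmp/model")
  if state is not None:
    print(state.model_checkpoint_path)        # e.g. "/tmp/model/ckpt-2"
    print(list(state.all_model_checkpoint_paths))
  ```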
266  """
267  if isinstance(checkpoint_dir, os.PathLike):
268    checkpoint_dir = os.fspath(checkpoint_dir)
269  ckpt = None
270  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
271                                                     latest_filename)
272  f = None
273  try:
274    # Check that the file exists before opening it to avoid
275    # many lines of errors from colossus in the logs.
276    if file_io.file_exists(coord_checkpoint_filename):
277      file_content = file_io.read_file_to_string(
278          coord_checkpoint_filename)
279      ckpt = CheckpointState()
280      text_format.Merge(file_content, ckpt)
281      if not ckpt.model_checkpoint_path:
282        raise ValueError("Invalid checkpoint state loaded from "
283                         + checkpoint_dir)
284      # For relative model_checkpoint_path and all_model_checkpoint_paths,
285      # prepend checkpoint_dir.
286      if not os.path.isabs(ckpt.model_checkpoint_path):
287        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
288                                                  ckpt.model_checkpoint_path)
289      for i, p in enumerate(ckpt.all_model_checkpoint_paths):
290        if not os.path.isabs(p):
291          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
292  except errors.OpError as e:
293    # It's ok if the file cannot be read
294    logging.warning("%s: %s", type(e).__name__, e)
295    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
296    return None
297  except text_format.ParseError as e:
298    logging.warning("%s: %s", type(e).__name__, e)
299    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
300    return None
301  finally:
302    if f:
303      f.close()
304  return ckpt
305
306
307def _prefix_to_checkpoint_path(prefix, format_version):
308  """Returns the pathname of a checkpoint file, given the checkpoint prefix.
309
310  For V1 checkpoint, simply returns the prefix itself (the data file).  For V2,
311  returns the pathname to the index file.
312
313  Args:
314    prefix: a string, the prefix of a checkpoint.
315    format_version: the checkpoint format version that corresponds to the
316      prefix.
317  Returns:
318    The pathname of a checkpoint file, taking into account the checkpoint
319      format version.
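
  For example (illustrative prefix), a V2 prefix "model.ckpt-123" maps to
  "model.ckpt-123.index", while a V1 prefix is returned unchanged.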
320  """
321  if format_version == saver_pb2.SaverDef.V2:
322    return prefix + ".index"  # The index file identifies a checkpoint.
323  return prefix  # Just the data file.
324
325
326@tf_export("train.latest_checkpoint")
327def latest_checkpoint(checkpoint_dir, latest_filename=None):
328  """Finds the filename of latest saved checkpoint file.
329
330  Gets the checkpoint state given the provided checkpoint_dir and looks for a
331  corresponding TensorFlow 2 (preferred) or TensorFlow 1.x checkpoint path.
332  The latest_filename argument is only applicable if you are saving checkpoint
333  using `v1.train.Saver.save`
334
335
336  See the [Training Checkpoints
337  Guide](https://www.tensorflow.org/guide/checkpoint) for more details and
338  examples.`
339
340  Args:
341    checkpoint_dir: Directory where the variables were saved.
342    latest_filename: Optional name for the protocol buffer file that
343      contains the list of most recent checkpoint filenames.
344      See the corresponding argument to `v1.train.Saver.save`.
345
346  Returns:
347    The full path to the latest checkpoint or `None` if no checkpoint was found.
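
  A minimal usage sketch (directory is hypothetical; `checkpoint` is assumed
  to be a `tf.train.Checkpoint`):

  ```python
  ckpt_path = tf.train.latest_checkpoint("/tmp/model")
  if ckpt_path:
    checkpoint.restore(ckpt_path)
  ```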
348  """
349  # Pick the latest checkpoint based on checkpoint state.
350  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
351  if ckpt and ckpt.model_checkpoint_path:
352    # Look for either a V2 path or a V1 path, with priority for V2.
353    v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
354                                         saver_pb2.SaverDef.V2)
355    v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
356                                         saver_pb2.SaverDef.V1)
357    if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
358        v1_path):
359      return ckpt.model_checkpoint_path
360    else:
361      logging.error("Couldn't match files for checkpoint %s",
362                    ckpt.model_checkpoint_path)
363  return None
364
365
366def checkpoint_exists_internal(checkpoint_prefix):
367  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.
368
369  This is an internal function to check if a checkpoint exists,
370  since it takes into account the naming difference between V1 and V2 formats.
371
372  Args:
373    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
374      priority.  Typically the result of `Saver.save()` or that of
375      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
376      V1/V2.
377  Returns:
378    A bool, true if a checkpoint referred to by `checkpoint_prefix` exists.
379  """
380  pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
381                                        saver_pb2.SaverDef.V2)
382  if file_io.get_matching_files(pathname):
383    return True
384  elif file_io.get_matching_files(checkpoint_prefix):
385    return True
386  else:
387    return False
388
389
390@deprecation.deprecated(
391    date=None,
392    instructions="Use standard file APIs to check for files with this prefix.")
393@tf_export(v1=["train.checkpoint_exists"])
394def checkpoint_exists(checkpoint_prefix):
395  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.
396
397  This is the recommended way to check if a checkpoint exists, since it takes
398  into account the naming difference between V1 and V2 formats.
399
400  Args:
401    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
402      priority.  Typically the result of `Saver.save()` or that of
403      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
404      V1/V2.
405
406  Returns:
407    A bool, true if a checkpoint referred to by `checkpoint_prefix` exists.
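
  A minimal usage sketch (the prefix is hypothetical):

  ```python
  if tf.compat.v1.train.checkpoint_exists("/tmp/model/ckpt-2"):
    print("checkpoint found")
  ```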
408  """
409  return checkpoint_exists_internal(checkpoint_prefix)
410
411
412@deprecation.deprecated(
413    date=None,
414    instructions="Use standard file utilities to get mtimes.")
415@tf_export(v1=["train.get_checkpoint_mtimes"])
416def get_checkpoint_mtimes(checkpoint_prefixes):
417  """Returns the mtimes (modification timestamps) of the checkpoints.
418
419  Globs for the checkpoints pointed to by `checkpoint_prefixes`.  If the files
420  exist, collect their mtime.  Both V2 and V1 checkpoints are considered, in
421  that priority.
422
423  This is the recommended way to get the mtimes, since it takes into account
424  the naming difference between V1 and V2 formats.
425
426  Note: If not all checkpoints exist, the length of the returned mtimes list
427  will be smaller than the length of `checkpoint_prefixes` list, so mapping
428  checkpoints to corresponding mtimes will not be possible.
429
430  Args:
431    checkpoint_prefixes: a list of checkpoint paths, typically the results of
432      `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
433      sharded/non-sharded or V1/V2.
434  Returns:
435    A list of mtimes (in microseconds) of the found checkpoints.
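
  A minimal usage sketch (the prefixes are hypothetical):

  ```python
  mtimes = tf.compat.v1.train.get_checkpoint_mtimes(
      ["/tmp/model/ckpt-1", "/tmp/model/ckpt-2"])
  # `mtimes` holds one float per checkpoint that was actually found.
  ```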
436  """
437  mtimes = []
438
439  def match_maybe_append(pathname):
440    fnames = file_io.get_matching_files(pathname)
441    if fnames:
442      mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
443      return True
444    return False
445
446  for checkpoint_prefix in checkpoint_prefixes:
447    # Tries V2's metadata file first.
448    pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
449                                          saver_pb2.SaverDef.V2)
450    if match_maybe_append(pathname):
451      continue
452    # Otherwise, tries V1, where the prefix is the complete pathname.
453    match_maybe_append(checkpoint_prefix)
454
455  return mtimes
456
457
458@deprecation.deprecated(
459    date=None,
460    instructions="Use standard file APIs to delete files with this prefix.")
461@tf_export(v1=["train.remove_checkpoint"])
462def remove_checkpoint(checkpoint_prefix,
463                      checkpoint_format_version=saver_pb2.SaverDef.V2,
464                      meta_graph_suffix="meta"):
465  """Removes a checkpoint given by `checkpoint_prefix`.
466
467  Args:
468    checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result
469      of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of
470      sharded/non-sharded or V1/V2.
471    checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
472      `SaverDef.V2`.
473    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
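
  A minimal usage sketch (the prefix is hypothetical):

  ```python
  # Deletes ckpt-1's index/data files and its ".meta" file, if present.
  tf.compat.v1.train.remove_checkpoint("/tmp/model/ckpt-1")
  ```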
474  """
475  _delete_file_if_exists(
476      meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
477  if checkpoint_format_version == saver_pb2.SaverDef.V2:
478    # V2 has a metadata file and some data files.
479    _delete_file_if_exists(checkpoint_prefix + ".index")
480    _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
481  else:
482    # V1, Legacy.  Exact match on the data file.
483    _delete_file_if_exists(checkpoint_prefix)
484
485
486def _delete_file_if_exists(filespec):
487  """Deletes files matching `filespec`."""
488  for pathname in file_io.get_matching_files(filespec):
489    try:
490      file_io.delete_file(pathname)
491    except errors.NotFoundError:
492      logging.warning(
493          "Hit NotFoundError when deleting '%s', possibly because another "
494          "process/thread is also deleting/moving the same file", pathname)
495
496
497def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
498  """Returns the meta graph filename.
499
500  Args:
501    checkpoint_filename: Name of the checkpoint file.
502    meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
503
504  Returns:
505    MetaGraph file name.
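
  For example, a sharded checkpoint filename "model.ckpt-123456-?????-of-00005"
  maps to "model.ckpt-123456.meta".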
506  """
507  # If the checkpoint_filename is sharded, the checkpoint_filename could
508  # be of format model.ckpt-step#-?????-of-shard#. For example,
509  # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
510  basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
511  suffixed_filename = ".".join([basename, meta_graph_suffix])
512  return suffixed_filename
513
514
515# TODO(allenl): Allow tf.keras.Model instances in the constructor directly?
516@tf_export("train.CheckpointManager")
517class CheckpointManager(object):
518  """Manages multiple checkpoints by keeping some and deleting unneeded ones.
519
520  Example usage:
521
522  ```python
523  import tensorflow as tf
524  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
525  manager = tf.train.CheckpointManager(
526      checkpoint, directory="/tmp/model", max_to_keep=5)
527  status = checkpoint.restore(manager.latest_checkpoint)
528  while True:
529    # train
530    manager.save()
531  ```
532
533  `CheckpointManager` preserves its own state across instantiations (see the
534  `__init__` documentation for details). Only one should be active in a
535  particular directory at a time.
536  """

  def __init__(self,
               checkpoint,
               directory,
               max_to_keep,
               keep_checkpoint_every_n_hours=None,
               checkpoint_name="ckpt",
               step_counter=None,
               checkpoint_interval=None,
               init_fn=None):
    """Configure a `CheckpointManager` for use in `directory`.

    If a `CheckpointManager` was previously used in `directory`, its
    state will be restored. This includes the list of managed checkpoints and
    the timestamp bookkeeping necessary to support
    `keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager`
    will be the same as the previous `CheckpointManager`, including cleaning up
    existing checkpoints if appropriate.

    Checkpoints are only considered for deletion just after a new checkpoint has
    been added. At that point, `max_to_keep` checkpoints will remain in an
    "active set". Once a checkpoint is preserved by
    `keep_checkpoint_every_n_hours` it will not be deleted by this
    `CheckpointManager` or any future `CheckpointManager` instantiated in
    `directory` (regardless of the new setting of
    `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the
    active set may be deleted by this `CheckpointManager` or a future
    `CheckpointManager` instantiated in `directory` (subject to its
    `max_to_keep` and `keep_checkpoint_every_n_hours` settings).

    `CheckpointManager` can also be used to initialize the model if
    there are no checkpoints to restore in `directory`. An example usage is:

    >>> import tempfile

    >>> tmp_dir = tempfile.mkdtemp()
    >>> checkpoint = tf.train.Checkpoint()
    >>> init_path = checkpoint.save(os.path.join(tmp_dir, 'init'))

    >>> def init_fn():
    ...   # Partially restore the checkpoint from `init_path`.
    ...   checkpoint.restore(init_path)

    >>> manager = tf.train.CheckpointManager(
    ...     checkpoint,
    ...     directory=os.path.join(tmp_dir, 'ckpt'),
    ...     max_to_keep=None,
    ...     init_fn=init_fn)
    >>> # `restore_or_initialize` will call `init_fn` if there is no existing
    >>> # checkpoint in `directory`.
    >>> manager.restore_or_initialize()

    Args:
      checkpoint: The `tf.train.Checkpoint` instance to save and manage
        checkpoints for.
      directory: The path to a directory in which to write checkpoints. A
        special file named "checkpoint" is also written to this directory (in a
        human-readable text format) which contains the state of the
        `CheckpointManager`.
      max_to_keep: An integer, the number of checkpoints to keep. Unless
        preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
        deleted from the active set, oldest first, until only `max_to_keep`
        checkpoints remain. If `None`, no checkpoints are deleted and everything
        stays in the active set. Note that `max_to_keep=None` will keep all
        checkpoint paths in memory and in the checkpoint state protocol buffer
        on disk.
      keep_checkpoint_every_n_hours: Upon removal from the active set, a
        checkpoint will be preserved if it has been at least
        `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
        default setting of `None` does not preserve any checkpoints in this way.
      checkpoint_name: Custom name for the checkpoint file.
      step_counter: A `tf.Variable` instance for checking the current step
        counter value, in case users want to save checkpoints every N steps.
      checkpoint_interval: An integer, indicates the minimum step interval
        between two checkpoints.
      init_fn: Callable. A function to do customized initialization if no
        checkpoints are in the directory.

    Raises:
      ValueError: If `max_to_keep` is not a positive integer.
    """
    self._checkpoint = checkpoint
    self._save_counter_assign = None
    if max_to_keep is not None and max_to_keep <= 0:
      raise ValueError(
          ("Expected a positive integer or `None` for `max_to_keep`, "
           "got %d.")
          % (max_to_keep,))
    self._max_to_keep = max_to_keep
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    if isinstance(directory, os.PathLike):
      directory = os.fspath(directory)
    self._directory = directory
    self._checkpoint_prefix = os.path.join(directory, checkpoint_name)
    self._init_fn = init_fn

    if checkpoint_interval is not None:
      if step_counter is None:
        raise ValueError("`step_counter` should be passed if "
                         "`checkpoint_interval` is not None.")
      self._last_checkpoint_step = None
      self._step_counter = step_counter
    self._checkpoint_interval = checkpoint_interval

    recovered_state = get_checkpoint_state(directory)
    current_clock = time.time()
    self._maybe_delete = collections.OrderedDict()
    if recovered_state is None:
      self._latest_checkpoint = None
      # Set the clock back slightly to avoid race conditions when quickly
      # re-creating a CheckpointManager.
      self._last_preserved_timestamp = current_clock - 1.
    else:
      self._latest_checkpoint = recovered_state.model_checkpoint_path
      self._last_preserved_timestamp = recovered_state.last_preserved_timestamp
      if current_clock < self._last_preserved_timestamp:
        # Time seems to have reversed itself. In addition to this warning, we'll
        # min() saved checkpoint timestamps with the current time to ensure that
        # old checkpoints don't get deleted accidentally.
        logging.warning(
            ("time.time() returned a value %f seconds behind the last "
             "preserved checkpoint timestamp.")
            % (self._last_preserved_timestamp - current_clock,))
        self._last_preserved_timestamp = current_clock
      all_timestamps = recovered_state.all_model_checkpoint_timestamps
      all_paths = recovered_state.all_model_checkpoint_paths
      del recovered_state  # Uses modified values from now on
      if not all_timestamps:
        all_timestamps = [self._last_preserved_timestamp] * len(all_paths)

      for filename, timestamp in zip(all_paths, all_timestamps):
        timestamp = min(timestamp, current_clock)
        if timestamp > self._last_preserved_timestamp:
          self._maybe_delete[filename] = timestamp

  @property
  def directory(self):
    return self._directory

  @property
  def checkpoint_interval(self):
    return self._checkpoint_interval

  @property
  def latest_checkpoint(self):
    """The prefix of the most recent checkpoint in `directory`.

    Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is
    the constructor argument to `CheckpointManager`.

    Suitable for passing to `tf.train.Checkpoint.restore` to resume training.

    Returns:
      The checkpoint prefix. If there are no checkpoints, returns `None`.
    """
    return self._latest_checkpoint

  @property
  def checkpoints(self):
    """A list of managed checkpoints.

    Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not
    show up in this list (to avoid ever-growing filename lists).

    Returns:
      A list of filenames, sorted from oldest to newest.
    """
    return list(self._maybe_delete.keys())

  def _sweep(self):
    """Deletes or preserves managed checkpoints."""
    if not self._max_to_keep:
      # Does not update self._last_preserved_timestamp, since everything is kept
      # in the active set.
      return
    while len(self._maybe_delete) > self._max_to_keep:
      filename, timestamp = self._maybe_delete.popitem(last=False)
      # Even if we're keeping this checkpoint due to
      # keep_checkpoint_every_n_hours, we won't reference it to avoid
      # infinitely-growing CheckpointState protos.
      if (self._keep_checkpoint_every_n_hours
          and (timestamp - self._keep_checkpoint_every_n_hours * 3600.
               >= self._last_preserved_timestamp)):
        self._last_preserved_timestamp = timestamp
        continue
      _delete_file_if_exists(filename + ".index")
      _delete_file_if_exists(filename + ".data-?????-of-?????")

  def _record_state(self):
    """Saves the `CheckpointManager`'s state in `directory`."""
    filenames, timestamps = zip(*self._maybe_delete.items())
    update_checkpoint_state_internal(
        self._directory,
        model_checkpoint_path=self.latest_checkpoint,
        all_model_checkpoint_paths=filenames,
        all_model_checkpoint_timestamps=timestamps,
        last_preserved_timestamp=self._last_preserved_timestamp,
        save_relative_paths=True)

  @property
  def _prefix(self):
    """A common prefix for all checkpoints saved with this manager.

    For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`,
    `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be
    numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each
    checkpoint has several associated files
    (e.g. `"/tmp/tf-model/ckpt-2.index"`).

    Returns:
      A string prefix.
    """
    return self._checkpoint_prefix

  @property
  def checkpoint(self):
    """Returns the `tf.train.Checkpoint` object."""
    return self._checkpoint

  def save(self, checkpoint_number=None, check_interval=True, options=None):
    """Creates a new checkpoint and manages it.

    Args:
      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
        `Tensor`, used to number the checkpoint. If `None` (default),
        checkpoints are numbered using `checkpoint.save_counter`. Even if
        `checkpoint_number` is provided, `save_counter` is still incremented. A
        user-provided `checkpoint_number` is not incremented even if it is a
        `Variable`.
      check_interval: An optional boolean. The argument is only effective when
        `checkpoint_interval` is passed into the manager. If `True`, the manager
        will only save the checkpoint if the interval between checkpoints is
        larger than `checkpoint_interval`. Otherwise it will always save the
        checkpoint unless a checkpoint has already been saved for the current
        step.
      options: Optional `tf.train.CheckpointOptions` object. This argument only
        works with TF2 checkpoint objects. For example, options =
        tf.train.CheckpointOptions(experimental_io_device='/job:localhost')

    Returns:
      The path to the new checkpoint. It is also recorded in the `checkpoints`
      and `latest_checkpoint` properties. `None` if no checkpoint is saved.
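
    A minimal usage sketch (assuming `manager` is a `CheckpointManager`
    constructed with `step_counter` and `checkpoint_interval`):

    ```python
    path = manager.save(check_interval=True)
    if path is None:
      print("skipped: too few steps since the last checkpoint")
    ```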
779    """
780    if self._checkpoint_interval is not None:
781      current_step = _evaluate(self._step_counter)
782      if self._last_checkpoint_step is not None:
783        if current_step == self._last_checkpoint_step:
784          return None
785        if check_interval and current_step < (
786            self._last_checkpoint_step + self._checkpoint_interval):
787          return None
788      self._last_checkpoint_step = current_step
789
790    # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
791    # slightly with a custom numbering option.
792    if context.executing_eagerly():
793      save_counter = self._checkpoint.save_counter
794      save_counter.assign_add(1)
795      session = None
796    else:
797      session = ops.get_default_session()
798
799      def _initializing_creator(next_creator, **kwargs):
800        """Initialize the save counter if it has been newly created."""
801        v = next_creator(**kwargs)
802        session.run(v.initializer)
803        return v
804
805      with variable_scope.variable_creator_scope(_initializing_creator):
806        save_counter = self._checkpoint.save_counter
807      if self._save_counter_assign is None:
808        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
809      session.run(self._save_counter_assign)
810    if checkpoint_number is None:
811      checkpoint_number = save_counter
812    if not isinstance(checkpoint_number, compat.integral_types):
813      checkpoint_number = training_util.global_step(
814          sess=session, global_step_tensor=checkpoint_number)
815    prefix = "%s-%d" % (self._prefix, checkpoint_number)
816
817    def _record_and_sweep_state(save_path):
818      timestamp = time.time()
819      # If this is an overwritten checkpoint we were previously tracking, delete
820      # and reinsert it to make sure it goes to the end of the queue.
821      if save_path in self._maybe_delete:
822        del self._maybe_delete[save_path]
823      self._maybe_delete[save_path] = timestamp
824      self._latest_checkpoint = save_path
825      # Before deleting anything we update the Checkpoint proto with the new
826      # checkpoint. We'll go back and correct it after cleaning up old files,
827      # but a preemption while deleting will be more likely to see the new
828      # checkpoint this way.
829      self._record_state()
830      self._sweep()
831      # Write out the Checkpoint proto a second time, now without the deleted
832      # checkpoints.
833      self._record_state()
834
835    if options is None:
836      save_path = self._checkpoint._write(  # pylint: disable=protected-access
837          prefix, write_done_callback=_record_and_sweep_state)
838    else:
839      save_path = self._checkpoint._write(  # pylint: disable=protected-access
840          prefix, options=options, write_done_callback=_record_and_sweep_state)
841
842    return save_path

  def restore_or_initialize(self):
    """Restore items in `checkpoint` from the latest checkpoint file.

    This method will first try to restore from the most recent checkpoint in
    `directory`. If no checkpoints exist in `directory`, and `init_fn` is
    specified, this method will call `init_fn` to do customized
    initialization. This can be used to support initialization from pretrained
    models.

    Note that unlike `tf.train.Checkpoint.restore()`, this method doesn't return
    a load status object that users can run assertions on
    (e.g. assert_consumed()). Thus to run assertions, users should directly use
    the `tf.train.Checkpoint.restore()` method.

    Returns:
      The restored checkpoint path if the latest checkpoint is found and
      restored. Otherwise None.
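
    A minimal usage sketch (assuming `manager` was constructed with an
    `init_fn`, as in the `__init__` example above):

    ```python
    restored_path = manager.restore_or_initialize()
    if restored_path is None:
      print("no checkpoint found; init_fn was used instead")
    ```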
861    """
862    # TODO(chienchunh): When AsyncCheckpoint is used, we may need to force to
863    # sync until any ongoing async save is done. Otherwise, if this is the first
864    # checkpoint and _latest_checkpoint has not been updated due to async write,
865    # this would resort to init_fn instead of restoring from the checkpoin file.
866    # This should be fixed once AsyncCheckpoint is integrated with the public
867    # API so that we can rely on CheckpointOptions to tell whether we should
868    # sync for AsyncCheckpoint.
869    if self._latest_checkpoint is not None:
870      self._checkpoint.restore(self._latest_checkpoint)
871      if self._checkpoint_interval is not None:
872        self._last_checkpoint_step = _evaluate(self._step_counter)
873      return self._latest_checkpoint
874
875    if self._init_fn is not None:
876      self._init_fn()
877    return None
878