1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16# pylint: disable=invalid-name 17"""Checkpoint Manager and other utilities for managing checkpoints.""" 18import collections 19import os.path 20import re 21import time 22 23from google.protobuf import text_format 24 25from tensorflow.core.protobuf import saver_pb2 26from tensorflow.python.eager import context 27from tensorflow.python.framework import errors 28from tensorflow.python.framework import ops 29from tensorflow.python.lib.io import file_io 30from tensorflow.python.ops import variable_scope 31from tensorflow.python.platform import tf_logging as logging 32from tensorflow.python.training import training_util 33from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState 34from tensorflow.python.util import compat 35from tensorflow.python.util import deprecation 36from tensorflow.python.util.tf_export import tf_export 37 38 39def _evaluate(tensor): 40 """Returns the numpy value of a tensor.""" 41 if context.executing_eagerly(): 42 return tensor.numpy() 43 return ops.get_default_session().run(tensor) 44 45 46def _GetCheckpointFilename(save_dir, latest_filename): 47 """Returns a filename for storing the CheckpointState. 48 49 Args: 50 save_dir: The directory for saving and restoring checkpoints. 51 latest_filename: Name of the file in 'save_dir' that is used 52 to store the CheckpointState. 53 54 Returns: 55 The path of the file that contains the CheckpointState proto. 56 """ 57 if latest_filename is None: 58 latest_filename = "checkpoint" 59 return os.path.join(save_dir, latest_filename) 60 61 62@tf_export(v1=["train.generate_checkpoint_state_proto"]) 63def generate_checkpoint_state_proto(save_dir, 64 model_checkpoint_path, 65 all_model_checkpoint_paths=None, 66 all_model_checkpoint_timestamps=None, 67 last_preserved_timestamp=None): 68 """Generates a checkpoint state proto. 69 70 Args: 71 save_dir: Directory where the model was saved. 72 model_checkpoint_path: The checkpoint file. 73 all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted 74 checkpoints, sorted from oldest to newest. If this is a non-empty list, 75 the last element must be equal to model_checkpoint_path. These paths 76 are also saved in the CheckpointState proto. 77 all_model_checkpoint_timestamps: A list of floats, indicating the number of 78 seconds since the Epoch when each checkpoint was generated. 79 last_preserved_timestamp: A float, indicating the number of seconds since 80 the Epoch when the last preserved checkpoint was written, e.g. due to a 81 `keep_checkpoint_every_n_hours` parameter (see 82 `tf.train.CheckpointManager` for an implementation). 83 Returns: 84 CheckpointState proto with model_checkpoint_path and 85 all_model_checkpoint_paths updated to either absolute paths or 86 relative paths to the current save_dir. 87 88 Raises: 89 ValueError: If `all_model_checkpoint_timestamps` was provided but its length 90 does not match `all_model_checkpoint_paths`. 91 """ 92 if all_model_checkpoint_paths is None: 93 all_model_checkpoint_paths = [] 94 95 if (not all_model_checkpoint_paths or 96 all_model_checkpoint_paths[-1] != model_checkpoint_path): 97 logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.", 98 model_checkpoint_path) 99 all_model_checkpoint_paths.append(model_checkpoint_path) 100 101 if (all_model_checkpoint_timestamps 102 and (len(all_model_checkpoint_timestamps) 103 != len(all_model_checkpoint_paths))): 104 raise ValueError( 105 ("Checkpoint timestamps, if provided, must match checkpoint paths (got " 106 "paths %s and timestamps %s)") 107 % (all_model_checkpoint_paths, all_model_checkpoint_timestamps)) 108 109 # Relative paths need to be rewritten to be relative to the "save_dir" 110 # if model_checkpoint_path already contains "save_dir". 111 if not os.path.isabs(save_dir): 112 if not os.path.isabs(model_checkpoint_path): 113 model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir) 114 for i, p in enumerate(all_model_checkpoint_paths): 115 if not os.path.isabs(p): 116 all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir) 117 118 coord_checkpoint_proto = CheckpointState( 119 model_checkpoint_path=model_checkpoint_path, 120 all_model_checkpoint_paths=all_model_checkpoint_paths, 121 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 122 last_preserved_timestamp=last_preserved_timestamp) 123 124 return coord_checkpoint_proto 125 126 127@deprecation.deprecated( 128 date=None, 129 instructions=("Use `tf.train.CheckpointManager` to manage checkpoints " 130 "rather than manually editing the Checkpoint proto.")) 131@tf_export(v1=["train.update_checkpoint_state"]) 132def update_checkpoint_state(save_dir, 133 model_checkpoint_path, 134 all_model_checkpoint_paths=None, 135 latest_filename=None, 136 all_model_checkpoint_timestamps=None, 137 last_preserved_timestamp=None): 138 """Updates the content of the 'checkpoint' file. 139 140 This updates the checkpoint file containing a CheckpointState 141 proto. 142 143 Args: 144 save_dir: Directory where the model was saved. 145 model_checkpoint_path: The checkpoint file. 146 all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted 147 checkpoints, sorted from oldest to newest. If this is a non-empty list, 148 the last element must be equal to model_checkpoint_path. These paths 149 are also saved in the CheckpointState proto. 150 latest_filename: Optional name of the checkpoint file. Default to 151 'checkpoint'. 152 all_model_checkpoint_timestamps: Optional list of timestamps (floats, 153 seconds since the Epoch) indicating when the checkpoints in 154 `all_model_checkpoint_paths` were created. 155 last_preserved_timestamp: A float, indicating the number of seconds since 156 the Epoch when the last preserved checkpoint was written, e.g. due to a 157 `keep_checkpoint_every_n_hours` parameter (see 158 `tf.train.CheckpointManager` for an implementation). 159 Raises: 160 RuntimeError: If any of the model checkpoint paths conflict with the file 161 containing CheckpointSate. 162 """ 163 update_checkpoint_state_internal( 164 save_dir=save_dir, 165 model_checkpoint_path=model_checkpoint_path, 166 all_model_checkpoint_paths=all_model_checkpoint_paths, 167 latest_filename=latest_filename, 168 save_relative_paths=False, 169 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 170 last_preserved_timestamp=last_preserved_timestamp) 171 172 173@tf_export("__internal__.train.update_checkpoint_state", v1=[]) 174def update_checkpoint_state_internal(save_dir, 175 model_checkpoint_path, 176 all_model_checkpoint_paths=None, 177 latest_filename=None, 178 save_relative_paths=False, 179 all_model_checkpoint_timestamps=None, 180 last_preserved_timestamp=None): 181 """Updates the content of the 'checkpoint' file. 182 183 This updates the checkpoint file containing a CheckpointState 184 proto. 185 186 Args: 187 save_dir: Directory where the model was saved. 188 model_checkpoint_path: The checkpoint file. 189 all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted 190 checkpoints, sorted from oldest to newest. If this is a non-empty list, 191 the last element must be equal to model_checkpoint_path. These paths 192 are also saved in the CheckpointState proto. 193 latest_filename: Optional name of the checkpoint file. Default to 194 'checkpoint'. 195 save_relative_paths: If `True`, will write relative paths to the checkpoint 196 state file. 197 all_model_checkpoint_timestamps: Optional list of timestamps (floats, 198 seconds since the Epoch) indicating when the checkpoints in 199 `all_model_checkpoint_paths` were created. 200 last_preserved_timestamp: A float, indicating the number of seconds since 201 the Epoch when the last preserved checkpoint was written, e.g. due to a 202 `keep_checkpoint_every_n_hours` parameter (see 203 `tf.train.CheckpointManager` for an implementation). 204 205 Raises: 206 RuntimeError: If any of the model checkpoint paths conflict with the file 207 containing CheckpointSate. 208 """ 209 # Writes the "checkpoint" file for the coordinator for later restoration. 210 coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename) 211 if save_relative_paths: 212 if os.path.isabs(model_checkpoint_path): 213 rel_model_checkpoint_path = os.path.relpath( 214 model_checkpoint_path, save_dir) 215 else: 216 rel_model_checkpoint_path = model_checkpoint_path 217 rel_all_model_checkpoint_paths = [] 218 for p in all_model_checkpoint_paths: 219 if os.path.isabs(p): 220 rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir)) 221 else: 222 rel_all_model_checkpoint_paths.append(p) 223 ckpt = generate_checkpoint_state_proto( 224 save_dir, 225 rel_model_checkpoint_path, 226 all_model_checkpoint_paths=rel_all_model_checkpoint_paths, 227 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 228 last_preserved_timestamp=last_preserved_timestamp) 229 else: 230 ckpt = generate_checkpoint_state_proto( 231 save_dir, 232 model_checkpoint_path, 233 all_model_checkpoint_paths=all_model_checkpoint_paths, 234 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 235 last_preserved_timestamp=last_preserved_timestamp) 236 237 if coord_checkpoint_filename == ckpt.model_checkpoint_path: 238 raise RuntimeError("Save path '%s' conflicts with path used for " 239 "checkpoint state. Please use a different save path." % 240 model_checkpoint_path) 241 242 # Preventing potential read/write race condition by *atomically* writing to a 243 # file. 244 file_io.atomic_write_string_to_file(coord_checkpoint_filename, 245 text_format.MessageToString(ckpt)) 246 247 248@tf_export("train.get_checkpoint_state") 249def get_checkpoint_state(checkpoint_dir, latest_filename=None): 250 """Returns CheckpointState proto from the "checkpoint" file. 251 252 If the "checkpoint" file contains a valid CheckpointState 253 proto, returns it. 254 255 Args: 256 checkpoint_dir: The directory of checkpoints. 257 latest_filename: Optional name of the checkpoint file. Default to 258 'checkpoint'. 259 260 Returns: 261 A CheckpointState if the state was available, None 262 otherwise. 263 264 Raises: 265 ValueError: if the checkpoint read doesn't have model_checkpoint_path set. 266 """ 267 if isinstance(checkpoint_dir, os.PathLike): 268 checkpoint_dir = os.fspath(checkpoint_dir) 269 ckpt = None 270 coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, 271 latest_filename) 272 f = None 273 try: 274 # Check that the file exists before opening it to avoid 275 # many lines of errors from colossus in the logs. 276 if file_io.file_exists(coord_checkpoint_filename): 277 file_content = file_io.read_file_to_string( 278 coord_checkpoint_filename) 279 ckpt = CheckpointState() 280 text_format.Merge(file_content, ckpt) 281 if not ckpt.model_checkpoint_path: 282 raise ValueError("Invalid checkpoint state loaded from " 283 + checkpoint_dir) 284 # For relative model_checkpoint_path and all_model_checkpoint_paths, 285 # prepend checkpoint_dir. 286 if not os.path.isabs(ckpt.model_checkpoint_path): 287 ckpt.model_checkpoint_path = os.path.join(checkpoint_dir, 288 ckpt.model_checkpoint_path) 289 for i, p in enumerate(ckpt.all_model_checkpoint_paths): 290 if not os.path.isabs(p): 291 ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p) 292 except errors.OpError as e: 293 # It's ok if the file cannot be read 294 logging.warning("%s: %s", type(e).__name__, e) 295 logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) 296 return None 297 except text_format.ParseError as e: 298 logging.warning("%s: %s", type(e).__name__, e) 299 logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) 300 return None 301 finally: 302 if f: 303 f.close() 304 return ckpt 305 306 307def _prefix_to_checkpoint_path(prefix, format_version): 308 """Returns the pathname of a checkpoint file, given the checkpoint prefix. 309 310 For V1 checkpoint, simply returns the prefix itself (the data file). For V2, 311 returns the pathname to the index file. 312 313 Args: 314 prefix: a string, the prefix of a checkpoint. 315 format_version: the checkpoint format version that corresponds to the 316 prefix. 317 Returns: 318 The pathname of a checkpoint file, taking into account the checkpoint 319 format version. 320 """ 321 if format_version == saver_pb2.SaverDef.V2: 322 return prefix + ".index" # The index file identifies a checkpoint. 323 return prefix # Just the data file. 324 325 326@tf_export("train.latest_checkpoint") 327def latest_checkpoint(checkpoint_dir, latest_filename=None): 328 """Finds the filename of latest saved checkpoint file. 329 330 Gets the checkpoint state given the provided checkpoint_dir and looks for a 331 corresponding TensorFlow 2 (preferred) or TensorFlow 1.x checkpoint path. 332 The latest_filename argument is only applicable if you are saving checkpoint 333 using `v1.train.Saver.save` 334 335 336 See the [Training Checkpoints 337 Guide](https://www.tensorflow.org/guide/checkpoint) for more details and 338 examples.` 339 340 Args: 341 checkpoint_dir: Directory where the variables were saved. 342 latest_filename: Optional name for the protocol buffer file that 343 contains the list of most recent checkpoint filenames. 344 See the corresponding argument to `v1.train.Saver.save`. 345 346 Returns: 347 The full path to the latest checkpoint or `None` if no checkpoint was found. 348 """ 349 # Pick the latest checkpoint based on checkpoint state. 350 ckpt = get_checkpoint_state(checkpoint_dir, latest_filename) 351 if ckpt and ckpt.model_checkpoint_path: 352 # Look for either a V2 path or a V1 path, with priority for V2. 353 v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, 354 saver_pb2.SaverDef.V2) 355 v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, 356 saver_pb2.SaverDef.V1) 357 if file_io.get_matching_files(v2_path) or file_io.get_matching_files( 358 v1_path): 359 return ckpt.model_checkpoint_path 360 else: 361 logging.error("Couldn't match files for checkpoint %s", 362 ckpt.model_checkpoint_path) 363 return None 364 365 366def checkpoint_exists_internal(checkpoint_prefix): 367 """Checks whether a V1 or V2 checkpoint exists with the specified prefix. 368 369 This is an internal function to check if a checkpoint exists, 370 since it takes into account the naming difference between V1 and V2 formats. 371 372 Args: 373 checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking 374 priority. Typically the result of `Saver.save()` or that of 375 `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or 376 V1/V2. 377 Returns: 378 A bool, true if a checkpoint referred to by `checkpoint_prefix` exists. 379 """ 380 pathname = _prefix_to_checkpoint_path(checkpoint_prefix, 381 saver_pb2.SaverDef.V2) 382 if file_io.get_matching_files(pathname): 383 return True 384 elif file_io.get_matching_files(checkpoint_prefix): 385 return True 386 else: 387 return False 388 389 390@deprecation.deprecated( 391 date=None, 392 instructions="Use standard file APIs to check for files with this prefix.") 393@tf_export(v1=["train.checkpoint_exists"]) 394def checkpoint_exists(checkpoint_prefix): 395 """Checks whether a V1 or V2 checkpoint exists with the specified prefix. 396 397 This is the recommended way to check if a checkpoint exists, since it takes 398 into account the naming difference between V1 and V2 formats. 399 400 Args: 401 checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking 402 priority. Typically the result of `Saver.save()` or that of 403 `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or 404 V1/V2. 405 406 Returns: 407 A bool, true if a checkpoint referred to by `checkpoint_prefix` exists. 408 """ 409 return checkpoint_exists_internal(checkpoint_prefix) 410 411 412@deprecation.deprecated( 413 date=None, 414 instructions="Use standard file utilities to get mtimes.") 415@tf_export(v1=["train.get_checkpoint_mtimes"]) 416def get_checkpoint_mtimes(checkpoint_prefixes): 417 """Returns the mtimes (modification timestamps) of the checkpoints. 418 419 Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files 420 exist, collect their mtime. Both V2 and V1 checkpoints are considered, in 421 that priority. 422 423 This is the recommended way to get the mtimes, since it takes into account 424 the naming difference between V1 and V2 formats. 425 426 Note: If not all checkpoints exist, the length of the returned mtimes list 427 will be smaller than the length of `checkpoint_prefixes` list, so mapping 428 checkpoints to corresponding mtimes will not be possible. 429 430 Args: 431 checkpoint_prefixes: a list of checkpoint paths, typically the results of 432 `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of 433 sharded/non-sharded or V1/V2. 434 Returns: 435 A list of mtimes (in microseconds) of the found checkpoints. 436 """ 437 mtimes = [] 438 439 def match_maybe_append(pathname): 440 fnames = file_io.get_matching_files(pathname) 441 if fnames: 442 mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9) 443 return True 444 return False 445 446 for checkpoint_prefix in checkpoint_prefixes: 447 # Tries V2's metadata file first. 448 pathname = _prefix_to_checkpoint_path(checkpoint_prefix, 449 saver_pb2.SaverDef.V2) 450 if match_maybe_append(pathname): 451 continue 452 # Otherwise, tries V1, where the prefix is the complete pathname. 453 match_maybe_append(checkpoint_prefix) 454 455 return mtimes 456 457 458@deprecation.deprecated( 459 date=None, 460 instructions="Use standard file APIs to delete files with this prefix.") 461@tf_export(v1=["train.remove_checkpoint"]) 462def remove_checkpoint(checkpoint_prefix, 463 checkpoint_format_version=saver_pb2.SaverDef.V2, 464 meta_graph_suffix="meta"): 465 """Removes a checkpoint given by `checkpoint_prefix`. 466 467 Args: 468 checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result 469 of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of 470 sharded/non-sharded or V1/V2. 471 checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to 472 `SaverDef.V2`. 473 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. 474 """ 475 _delete_file_if_exists( 476 meta_graph_filename(checkpoint_prefix, meta_graph_suffix)) 477 if checkpoint_format_version == saver_pb2.SaverDef.V2: 478 # V2 has a metadata file and some data files. 479 _delete_file_if_exists(checkpoint_prefix + ".index") 480 _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????") 481 else: 482 # V1, Legacy. Exact match on the data file. 483 _delete_file_if_exists(checkpoint_prefix) 484 485 486def _delete_file_if_exists(filespec): 487 """Deletes files matching `filespec`.""" 488 for pathname in file_io.get_matching_files(filespec): 489 try: 490 file_io.delete_file(pathname) 491 except errors.NotFoundError: 492 logging.warning( 493 "Hit NotFoundError when deleting '%s', possibly because another " 494 "process/thread is also deleting/moving the same file", pathname) 495 496 497def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"): 498 """Returns the meta graph filename. 499 500 Args: 501 checkpoint_filename: Name of the checkpoint file. 502 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. 503 504 Returns: 505 MetaGraph file name. 506 """ 507 # If the checkpoint_filename is sharded, the checkpoint_filename could 508 # be of format model.ckpt-step#-?????-of-shard#. For example, 509 # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002. 510 basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename) 511 suffixed_filename = ".".join([basename, meta_graph_suffix]) 512 return suffixed_filename 513 514 515# TODO(allenl): Allow tf.keras.Model instances in the constructor directly? 516@tf_export("train.CheckpointManager") 517class CheckpointManager(object): 518 """Manages multiple checkpoints by keeping some and deleting unneeded ones. 519 520 Example usage: 521 522 ```python 523 import tensorflow as tf 524 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 525 manager = tf.train.CheckpointManager( 526 checkpoint, directory="/tmp/model", max_to_keep=5) 527 status = checkpoint.restore(manager.latest_checkpoint) 528 while True: 529 # train 530 manager.save() 531 ``` 532 533 `CheckpointManager` preserves its own state across instantiations (see the 534 `__init__` documentation for details). Only one should be active in a 535 particular directory at a time. 536 """ 537 538 def __init__(self, 539 checkpoint, 540 directory, 541 max_to_keep, 542 keep_checkpoint_every_n_hours=None, 543 checkpoint_name="ckpt", 544 step_counter=None, 545 checkpoint_interval=None, 546 init_fn=None): 547 """Configure a `CheckpointManager` for use in `directory`. 548 549 If a `CheckpointManager` was previously used in `directory`, its 550 state will be restored. This includes the list of managed checkpoints and 551 the timestamp bookkeeping necessary to support 552 `keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager` 553 will be the same as the previous `CheckpointManager`, including cleaning up 554 existing checkpoints if appropriate. 555 556 Checkpoints are only considered for deletion just after a new checkpoint has 557 been added. At that point, `max_to_keep` checkpoints will remain in an 558 "active set". Once a checkpoint is preserved by 559 `keep_checkpoint_every_n_hours` it will not be deleted by this 560 `CheckpointManager` or any future `CheckpointManager` instantiated in 561 `directory` (regardless of the new setting of 562 `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the 563 active set may be deleted by this `CheckpointManager` or a future 564 `CheckpointManager` instantiated in `directory` (subject to its 565 `max_to_keep` and `keep_checkpoint_every_n_hours` settings). 566 567 `CheckpointManager` can be also used for initializing the model if 568 there is no checkpoints for restoring in `directory`. An example usage is: 569 570 >>> import tempfile 571 572 >>> tmp_dir = tempfile.mkdtemp() 573 >>> checkpoint = tf.train.Checkpoint() 574 >>> init_path = checkpoint.save(os.path.join(tmp_dir, 'init')) 575 576 >>> def init_fn(): 577 ... # Partially restore the checkpoint from `init_path`. 578 ... checkpoint.restore(init_path) 579 580 >>> manager = tf.train.CheckpointManager( 581 ... checkpoint, 582 ... directory=os.path.join(tmp_dir, 'ckpt'), 583 ... max_to_keep=None, 584 ... init_fn=init_fn) 585 >>> # `restore_or_initialize` will call `init_fn` if there is no existing 586 >>> # checkpoint in `directory`. 587 >>> manager.restore_or_initialize() 588 589 Args: 590 checkpoint: The `tf.train.Checkpoint` instance to save and manage 591 checkpoints for. 592 directory: The path to a directory in which to write checkpoints. A 593 special file named "checkpoint" is also written to this directory (in a 594 human-readable text format) which contains the state of the 595 `CheckpointManager`. 596 max_to_keep: An integer, the number of checkpoints to keep. Unless 597 preserved by `keep_checkpoint_every_n_hours`, checkpoints will be 598 deleted from the active set, oldest first, until only `max_to_keep` 599 checkpoints remain. If `None`, no checkpoints are deleted and everything 600 stays in the active set. Note that `max_to_keep=None` will keep all 601 checkpoint paths in memory and in the checkpoint state protocol buffer 602 on disk. 603 keep_checkpoint_every_n_hours: Upon removal from the active set, a 604 checkpoint will be preserved if it has been at least 605 `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The 606 default setting of `None` does not preserve any checkpoints in this way. 607 checkpoint_name: Custom name for the checkpoint file. 608 step_counter: A `tf.Variable` instance for checking the current step 609 counter value, in case users want to save checkpoints every N steps. 610 checkpoint_interval: An integer, indicates the minimum step interval 611 between two checkpoints. 612 init_fn: Callable. A function to do customized intialization if no 613 checkpoints are in the directory. 614 615 Raises: 616 ValueError: If `max_to_keep` is not a positive integer. 617 """ 618 self._checkpoint = checkpoint 619 self._save_counter_assign = None 620 if max_to_keep is not None and max_to_keep <= 0: 621 raise ValueError( 622 ("Expected a positive integer or `None` for `max_to_keep`, " 623 "got %d.") 624 % (max_to_keep,)) 625 self._max_to_keep = max_to_keep 626 self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours 627 if isinstance(directory, os.PathLike): 628 directory = os.fspath(directory) 629 self._directory = directory 630 self._checkpoint_prefix = os.path.join(directory, checkpoint_name) 631 self._init_fn = init_fn 632 633 if checkpoint_interval is not None: 634 if step_counter is None: 635 raise ValueError("`step_counter` should be passed if " 636 "`checkpoint_interval` is not None.") 637 self._last_checkpoint_step = None 638 self._step_counter = step_counter 639 self._checkpoint_interval = checkpoint_interval 640 641 recovered_state = get_checkpoint_state(directory) 642 current_clock = time.time() 643 self._maybe_delete = collections.OrderedDict() 644 if recovered_state is None: 645 self._latest_checkpoint = None 646 # Set the clock back slightly to avoid race conditions when quickly 647 # re-creating a CheckpointManager. 648 self._last_preserved_timestamp = current_clock - 1. 649 else: 650 self._latest_checkpoint = recovered_state.model_checkpoint_path 651 self._last_preserved_timestamp = recovered_state.last_preserved_timestamp 652 if current_clock < self._last_preserved_timestamp: 653 # Time seems to have reversed itself. In addition to this warning, we'll 654 # min() saved checkpoint timestamps with the current time to ensure that 655 # old checkpoints don't get deleted accidentally. 656 logging.warning( 657 ("time.time() returned a value %f seconds behind the last " 658 "preserved checkpoint timestamp.") 659 % (self._last_preserved_timestamp - current_clock,)) 660 self._last_preserved_timestamp = current_clock 661 all_timestamps = recovered_state.all_model_checkpoint_timestamps 662 all_paths = recovered_state.all_model_checkpoint_paths 663 del recovered_state # Uses modified values from now on 664 if not all_timestamps: 665 all_timestamps = [self._last_preserved_timestamp] * len(all_paths) 666 667 for filename, timestamp in zip(all_paths, all_timestamps): 668 timestamp = min(timestamp, current_clock) 669 if timestamp > self._last_preserved_timestamp: 670 self._maybe_delete[filename] = timestamp 671 672 @property 673 def directory(self): 674 return self._directory 675 676 @property 677 def checkpoint_interval(self): 678 return self._checkpoint_interval 679 680 @property 681 def latest_checkpoint(self): 682 """The prefix of the most recent checkpoint in `directory`. 683 684 Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is 685 the constructor argument to `CheckpointManager`. 686 687 Suitable for passing to `tf.train.Checkpoint.restore` to resume training. 688 689 Returns: 690 The checkpoint prefix. If there are no checkpoints, returns `None`. 691 """ 692 return self._latest_checkpoint 693 694 @property 695 def checkpoints(self): 696 """A list of managed checkpoints. 697 698 Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not 699 show up in this list (to avoid ever-growing filename lists). 700 701 Returns: 702 A list of filenames, sorted from oldest to newest. 703 """ 704 return list(self._maybe_delete.keys()) 705 706 def _sweep(self): 707 """Deletes or preserves managed checkpoints.""" 708 if not self._max_to_keep: 709 # Does not update self._last_preserved_timestamp, since everything is kept 710 # in the active set. 711 return 712 while len(self._maybe_delete) > self._max_to_keep: 713 filename, timestamp = self._maybe_delete.popitem(last=False) 714 # Even if we're keeping this checkpoint due to 715 # keep_checkpoint_every_n_hours, we won't reference it to avoid 716 # infinitely-growing CheckpointState protos. 717 if (self._keep_checkpoint_every_n_hours 718 and (timestamp - self._keep_checkpoint_every_n_hours * 3600. 719 >= self._last_preserved_timestamp)): 720 self._last_preserved_timestamp = timestamp 721 continue 722 _delete_file_if_exists(filename + ".index") 723 _delete_file_if_exists(filename + ".data-?????-of-?????") 724 725 def _record_state(self): 726 """Saves the `CheckpointManager`'s state in `directory`.""" 727 filenames, timestamps = zip(*self._maybe_delete.items()) 728 update_checkpoint_state_internal( 729 self._directory, 730 model_checkpoint_path=self.latest_checkpoint, 731 all_model_checkpoint_paths=filenames, 732 all_model_checkpoint_timestamps=timestamps, 733 last_preserved_timestamp=self._last_preserved_timestamp, 734 save_relative_paths=True) 735 736 @property 737 def _prefix(self): 738 """A common prefix for all checkpoints saved with this manager. 739 740 For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`, 741 `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be 742 numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each 743 checkpoint has several associated files 744 (e.g. `"/tmp/tf-model/ckpt-2.index"`). 745 746 Returns: 747 A string prefix. 748 """ 749 return self._checkpoint_prefix 750 751 @property 752 def checkpoint(self): 753 """Returns the `tf.train.Checkpoint` object.""" 754 return self._checkpoint 755 756 def save(self, checkpoint_number=None, check_interval=True, options=None): 757 """Creates a new checkpoint and manages it. 758 759 Args: 760 checkpoint_number: An optional integer, or an integer-dtype `Variable` or 761 `Tensor`, used to number the checkpoint. If `None` (default), 762 checkpoints are numbered using `checkpoint.save_counter`. Even if 763 `checkpoint_number` is provided, `save_counter` is still incremented. A 764 user-provided `checkpoint_number` is not incremented even if it is a 765 `Variable`. 766 check_interval: An optional boolean. The argument is only effective when 767 `checkpoint_interval` is passed into the manager. If `True`, the manager 768 will only save the checkpoint if the interval between checkpoints is 769 larger than `checkpoint_interval`. Otherwise it will always save the 770 checkpoint unless a checkpoint has already been saved for the current 771 step. 772 options: Optional `tf.train.CheckpointOptions` object. This argument only 773 works with TF2 checkpoint objects. For example, options = 774 tf.saved_model.SaveOptions(experimental_io_device='/job:localhost') 775 776 Returns: 777 The path to the new checkpoint. It is also recorded in the `checkpoints` 778 and `latest_checkpoint` properties. `None` if no checkpoint is saved. 779 """ 780 if self._checkpoint_interval is not None: 781 current_step = _evaluate(self._step_counter) 782 if self._last_checkpoint_step is not None: 783 if current_step == self._last_checkpoint_step: 784 return None 785 if check_interval and current_step < ( 786 self._last_checkpoint_step + self._checkpoint_interval): 787 return None 788 self._last_checkpoint_step = current_step 789 790 # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge 791 # slightly with a custom numbering option. 792 if context.executing_eagerly(): 793 save_counter = self._checkpoint.save_counter 794 save_counter.assign_add(1) 795 session = None 796 else: 797 session = ops.get_default_session() 798 799 def _initializing_creator(next_creator, **kwargs): 800 """Initialize the save counter if it has been newly created.""" 801 v = next_creator(**kwargs) 802 session.run(v.initializer) 803 return v 804 805 with variable_scope.variable_creator_scope(_initializing_creator): 806 save_counter = self._checkpoint.save_counter 807 if self._save_counter_assign is None: 808 self._save_counter_assign = save_counter.assign_add(1, read_value=False) 809 session.run(self._save_counter_assign) 810 if checkpoint_number is None: 811 checkpoint_number = save_counter 812 if not isinstance(checkpoint_number, compat.integral_types): 813 checkpoint_number = training_util.global_step( 814 sess=session, global_step_tensor=checkpoint_number) 815 prefix = "%s-%d" % (self._prefix, checkpoint_number) 816 817 def _record_and_sweep_state(save_path): 818 timestamp = time.time() 819 # If this is an overwritten checkpoint we were previously tracking, delete 820 # and reinsert it to make sure it goes to the end of the queue. 821 if save_path in self._maybe_delete: 822 del self._maybe_delete[save_path] 823 self._maybe_delete[save_path] = timestamp 824 self._latest_checkpoint = save_path 825 # Before deleting anything we update the Checkpoint proto with the new 826 # checkpoint. We'll go back and correct it after cleaning up old files, 827 # but a preemption while deleting will be more likely to see the new 828 # checkpoint this way. 829 self._record_state() 830 self._sweep() 831 # Write out the Checkpoint proto a second time, now without the deleted 832 # checkpoints. 833 self._record_state() 834 835 if options is None: 836 save_path = self._checkpoint._write( # pylint: disable=protected-access 837 prefix, write_done_callback=_record_and_sweep_state) 838 else: 839 save_path = self._checkpoint._write( # pylint: disable=protected-access 840 prefix, options=options, write_done_callback=_record_and_sweep_state) 841 842 return save_path 843 844 def restore_or_initialize(self): 845 """Restore items in `checkpoint` from the latest checkpoint file. 846 847 This method will first try to restore from the most recent checkpoint in 848 `directory`. If no checkpoints exist in `directory`, and `init_fn` is 849 specified, this method will call `init_fn` to do customized 850 initialization. This can be used to support initialization from pretrained 851 models. 852 853 Note that unlike `tf.train.Checkpoint.restore()`, this method doesn't return 854 a load status object that users can run assertions on 855 (e.g. assert_consumed()). Thus to run assertions, users should directly use 856 `tf.train.Checkpoint.restore()` method. 857 858 Returns: 859 The restored checkpoint path if the lastest checkpoint is found and 860 restored. Otherwise None. 861 """ 862 # TODO(chienchunh): When AsyncCheckpoint is used, we may need to force to 863 # sync until any ongoing async save is done. Otherwise, if this is the first 864 # checkpoint and _latest_checkpoint has not been updated due to async write, 865 # this would resort to init_fn instead of restoring from the checkpoin file. 866 # This should be fixed once AsyncCheckpoint is integrated with the public 867 # API so that we can rely on CheckpointOptions to tell whether we should 868 # sync for AsyncCheckpoint. 869 if self._latest_checkpoint is not None: 870 self._checkpoint.restore(self._latest_checkpoint) 871 if self._checkpoint_interval is not None: 872 self._last_checkpoint_step = _evaluate(self._step_counter) 873 return self._latest_checkpoint 874 875 if self._init_fn is not None: 876 self._init_fn() 877 return None 878