# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Test async checkpointing."""

import os

import numpy as np

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import async_checkpoint
from tensorflow.python.tpu import tpu_config
from tensorflow.python.tpu import tpu_estimator
from tensorflow.python.tpu import tpu_optimizer
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import training
from tensorflow_estimator.python.estimator import estimator as estimator_lib
from tensorflow_estimator.python.estimator import model_fn as model_fn_lib

FLAGS = flags.FLAGS
flags.DEFINE_string('tpu', '', 'TPU to use in this test.')
flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
flags.DEFINE_string(
    'model_dir',
    os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'),
    'GCS path to store model and checkpoints.')


def input_fn(params):
  """Return a batch of random features and integer labels for training."""
  return (constant_op.constant(
      np.random.randn(params['batch_size'], 1000), dtype=dtypes.float32),
          constant_op.constant(
              np.random.randint(0, 10, params['batch_size']),
              dtype=dtypes.int32))


def model_fn(features, labels, mode, params):
  del params  # unused
  with variable_scope.variable_scope('m', reuse=variable_scope.AUTO_REUSE):
    w = variable_scope.get_variable('W', shape=[1000, 10])
  logits = math_ops.matmul(features, w)
  loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  if mode == model_fn_lib.ModeKeys.TRAIN:
    optimizer = training.RMSPropOptimizer(learning_rate=0.01)
    optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(loss, training.get_global_step())
    return tpu_estimator.TPUEstimatorSpec(
        mode=model_fn_lib.ModeKeys.TRAIN,
        loss=loss,
        train_op=train_op,
    )
  elif mode == model_fn_lib.ModeKeys.EVAL:

    def metric_fn(labels, logits):
      labels = math_ops.cast(labels, dtypes.int64)
      logging.info('LABELS %s %s', labels, logits)
      return {
          'recall@1': metrics_lib.recall_at_k(labels, logits, 1),
          'recall@5': metrics_lib.recall_at_k(labels, logits, 5),
      }
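    # Note: unlike a plain EstimatorSpec, TPUEstimatorSpec takes eval metrics
    # as a (metric_fn, tensors) pair rather than a dict of metric ops; the
    # listed tensors are shipped from the TPU back to the host, where
    # metric_fn runs to build the actual metric ops.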
    loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    eval_metrics = (metric_fn, [labels, logits])
    return tpu_estimator.TPUEstimatorSpec(
        mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metrics=eval_metrics)


class AsyncCheckpointingTest(test.TestCase):

  def testAsyncCheckpointHookEnabled(self):
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)

    checkpoint_interval = 5
    config = tpu_config.RunConfig(
        master=resolver.master(),
        model_dir=os.path.join(FLAGS.model_dir, 'runconfig'),
        save_checkpoints_steps=1000,
        keep_checkpoint_max=11,  # off by one
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=checkpoint_interval,))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=config,
        train_batch_size=32,
        eval_batch_size=32,
        predict_batch_size=1,
        params={},
    )

    i = 10
    mock_listener = test.mock.create_autospec(
        basic_session_run_hooks.CheckpointSaverListener)
    estimator.train(
        input_fn=input_fn,
        max_steps=i * 10,
        hooks=[
            async_checkpoint.AsyncCheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=checkpoint_interval,
                listeners=[mock_listener])
        ])

    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access

    # TODO(power) -- identify a better way to count the number of checkpoints.
    checkpoints = file_io.get_matching_files(
        FLAGS.model_dir + '/model.ckpt*.meta')
    checkpoint_count = len(checkpoints)
    logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints)
    self.assertLessEqual(checkpoint_count, 10)
    self.assertEqual(current_step, i * 10)
    mock_listener.before_save.assert_called()
    mock_listener.after_save.assert_called()

  def testAsyncCheckpointHookWithoutListeners(self):
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)

    checkpoint_interval = 5
    keep_checkpoint_max = 10
    config = tpu_config.RunConfig(
        master=resolver.master(),
        model_dir=os.path.join(FLAGS.model_dir, 'runconfig'),
        save_checkpoints_steps=1000,
        keep_checkpoint_max=keep_checkpoint_max + 1,  # off by one
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=checkpoint_interval,))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=config,
        train_batch_size=32,
        eval_batch_size=32,
        predict_batch_size=1,
        params={},
    )

    max_steps = 100
    estimator.train(
        input_fn=input_fn,
        max_steps=max_steps,
        hooks=[
            async_checkpoint.AsyncCheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=checkpoint_interval)
        ])

    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access

    # TODO(power) -- identify a better way to count the number of checkpoints.
    checkpoints = file_io.get_matching_files(
        FLAGS.model_dir + '/model.ckpt*.meta')
    checkpoint_count = len(checkpoints)
    logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints)
    self.assertLessEqual(checkpoint_count, keep_checkpoint_max)
    self.assertEqual(current_step, max_steps)


if __name__ == '__main__':
  v2_compat.disable_v2_behavior()
  test.main()
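
# Example invocation (a sketch; the TPU name, zone, project, and GCS path
# below are placeholders, not real resources):
#   python async_checkpoint_test.py --tpu=my-tpu --zone=us-central1-b \
#       --project=my-project --model_dir=gs://my-bucket/model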