# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15"""Test async checkpointing."""

import os

import numpy as np

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import async_checkpoint
from tensorflow.python.tpu import tpu_config
from tensorflow.python.tpu import tpu_estimator
from tensorflow.python.tpu import tpu_optimizer
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import training
from tensorflow_estimator.python.estimator import estimator as estimator_lib
from tensorflow_estimator.python.estimator import model_fn as model_fn_lib

FLAGS = flags.FLAGS
flags.DEFINE_string('tpu', '', 'TPU to use in this test.')
flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
flags.DEFINE_string(
    'model_dir',
    os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'),
    'GCS path to store model and checkpoints.')


def input_fn(params):
  """Return a batch of random features and integer labels for training."""
  return (constant_op.constant(
      np.random.randn(params['batch_size'], 1000), dtype=dtypes.float32),
          constant_op.constant(
              np.random.randint(0, 10, params['batch_size']),
              dtype=dtypes.int32))


def model_fn(features, labels, mode, params):
  del params  # unused
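  # A minimal linear classifier: a single 1000x10 weight matrix mapping
  # features to class logits, trained with softmax cross-entropy.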
  with variable_scope.variable_scope('m', reuse=variable_scope.AUTO_REUSE):
    w = variable_scope.get_variable('W', shape=[1000, 10])
  logits = math_ops.matmul(features, w)
  loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  if mode == model_fn_lib.ModeKeys.TRAIN:
    optimizer = training.RMSPropOptimizer(learning_rate=0.01)
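    # CrossShardOptimizer aggregates gradients across TPU cores (mean
    # reduction by default) before applying updates.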
    optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(loss, training.get_global_step())
    return tpu_estimator.TPUEstimatorSpec(
        mode=model_fn_lib.ModeKeys.TRAIN,
        loss=loss,
        train_op=train_op,
    )
  elif mode == model_fn_lib.ModeKeys.EVAL:

    def metric_fn(labels, logits):
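      # recall_at_k requires int64 labels, so cast from the int32 input.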
      labels = math_ops.cast(labels, dtypes.int64)
      logging.info('LABELS %s %s', labels, logits)
      return {
          'recall@1': metrics_lib.recall_at_k(labels, logits, 1),
          'recall@5': metrics_lib.recall_at_k(labels, logits, 5),
      }

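    # TPUEstimatorSpec takes eval metrics as a (metric_fn, tensors) pair; the
    # metric_fn runs on the host CPU rather than on the TPU.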
    eval_metrics = (metric_fn, [labels, logits])
    return tpu_estimator.TPUEstimatorSpec(
        mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metrics=eval_metrics)


class AsyncCheckpointingTest(test.TestCase):

  def testAsyncCheckpointHookEnabled(self):
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)

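    # Save every 5 steps and match iterations_per_loop to the same interval so
    # control returns to the host whenever a checkpoint is due.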
    checkpoint_interval = 5
    config = tpu_config.RunConfig(
        master=resolver.master(),
        model_dir=os.path.join(FLAGS.model_dir, 'runconfig'),
        save_checkpoints_steps=1000,
        keep_checkpoint_max=11,  # one more than the 10 asserted below
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=checkpoint_interval,))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=config,
        train_batch_size=32,
        eval_batch_size=32,
        predict_batch_size=1,
        params={},
    )

    i = 10
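    # Autospec the listener so the test can assert that before_save and
    # after_save were invoked by the async saver.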
    mock_listener = test.mock.create_autospec(
        basic_session_run_hooks.CheckpointSaverListener)
    estimator.train(
        input_fn=input_fn,
        max_steps=i * 10,
        hooks=[
            async_checkpoint.AsyncCheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=checkpoint_interval,
                listeners=[mock_listener])
        ])

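    # Read the global step back from the latest checkpoint to confirm that
    # training ran to completion.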
    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access

    # TODO(power) -- identify a better way to count the number of checkpoints.
    checkpoints = file_io.get_matching_files(
        FLAGS.model_dir + '/model.ckpt*.meta')
    checkpoint_count = len(checkpoints)
    logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints)
    self.assertLessEqual(checkpoint_count, 10)
    self.assertEqual(current_step, i * 10)
    mock_listener.before_save.assert_called()
    mock_listener.after_save.assert_called()

  def testAsyncCheckpointHookWithoutListeners(self):
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)

    checkpoint_interval = 5
    keep_checkpoint_max = 10
    config = tpu_config.RunConfig(
        master=resolver.master(),
        model_dir=os.path.join(FLAGS.model_dir, 'runconfig'),
        save_checkpoints_steps=1000,
        keep_checkpoint_max=keep_checkpoint_max + 1,  # one more than asserted
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=checkpoint_interval,))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=config,
        train_batch_size=32,
        eval_batch_size=32,
        predict_batch_size=1,
        params={},
    )

    max_steps = 100
    estimator.train(
        input_fn=input_fn,
        max_steps=max_steps,
        hooks=[
            async_checkpoint.AsyncCheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=checkpoint_interval)
        ])

    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access

    # TODO(power) -- identify a better way to count the number of checkpoints.
    checkpoints = file_io.get_matching_files(
        FLAGS.model_dir + '/model.ckpt*.meta')
    checkpoint_count = len(checkpoints)
    logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints)
    self.assertLessEqual(checkpoint_count, keep_checkpoint_max)
    self.assertEqual(current_step, max_steps)


if __name__ == '__main__':
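  # TPUEstimator requires TF1 (graph-mode) behavior.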
  v2_compat.disable_v2_behavior()
  test.main()