xref: /aosp_15_r20/external/tensorflow/tensorflow/python/training/evaluation_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tf.training.evaluation."""
16
17import os
18
19import numpy as np
20
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import random_seed
25from tensorflow.python.layers import layers
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops import metrics as metrics_module
29from tensorflow.python.ops import state_ops
30from tensorflow.python.ops import variable_scope
31from tensorflow.python.ops.losses import losses
32from tensorflow.python.platform import test
33from tensorflow.python.training import basic_session_run_hooks
34from tensorflow.python.training import evaluation
35from tensorflow.python.training import gradient_descent
36from tensorflow.python.training import monitored_session
37from tensorflow.python.training import saver
38from tensorflow.python.training import training
39
# Module-level constant. NOTE(review): no test in this file references it;
# presumably kept for parity with other evaluation tests — confirm before removal.
_USE_GLOBAL_STEP = 0
41
42
def logistic_classifier(inputs):
  """Applies a single-unit dense layer followed by a sigmoid.

  Args:
    inputs: Input tensor handed to `layers.dense`.

  Returns:
    The one-unit dense layer output passed through `math_ops.sigmoid`, so every
    value lies in (0, 1).
  """
  return layers.dense(inputs, units=1, activation=math_ops.sigmoid)
45
46
def local_variable(init_value, name):
  """Gets a non-trainable float32 variable registered as a local variable.

  The variable is obtained through `variable_scope.get_variable` and placed only
  in the `LOCAL_VARIABLES` collection, so it is not part of the trainable set.

  Args:
    init_value: Value used as the variable's initializer (the tests pass a
      Python float such as 0.0).
    name: Name of the variable.

  Returns:
    The variable returned by `variable_scope.get_variable`.
  """
  local_collections = [ops.GraphKeys.LOCAL_VARIABLES]
  return variable_scope.get_variable(
      name=name,
      initializer=init_value,
      dtype=dtypes.float32,
      trainable=False,
      collections=local_collections)
54
55
class EvaluateOnceTest(test.TestCase):
  """Tests for `evaluation._evaluate_once`."""

  def setUp(self):
    super(EvaluateOnceTest, self).setUp()

    # Create an easy training set: 16 rows of 4 features where exactly one
    # feature is set, at index 2 * label + (0 or 1). The label is therefore
    # fully determined by which half of the feature vector is active.
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

    for i in range(16):
      # Index the scalar label explicitly; calling int() on a 1-element array
      # is deprecated in recent NumPy releases.
      j = int(2 * self._labels[i, 0] + np.random.randint(0, 2))
      self._inputs[i, j] = 1

  def _train_model(self, checkpoint_dir, num_steps):
    """Trains a simple classification model.

    Note that the data has been configured such that after around 300 steps,
    the model has memorized the dataset (i.e. we can expect 100% accuracy).

    Args:
      checkpoint_dir: The directory where the checkpoint is written to.
      num_steps: The number of steps to train for.
    """
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      loss_op = losses.log_loss(labels=tf_labels, predictions=tf_predictions)

      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = optimizer.minimize(loss_op,
                                    training.get_or_create_global_step())

      # StopAtStepHook ends the loop after `num_steps` training steps.
      with monitored_session.MonitoredTrainingSession(
          checkpoint_dir=checkpoint_dir,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps)]) as session:
        loss = None
        while not session.should_stop():
          _, loss = session.run([train_op, loss_op])

    if num_steps >= 300:
      # Use test assertions rather than `assert` (which is stripped under -O),
      # and fail clearly if no training step ran at all.
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)

  def testEvaluatePerfectModel(self):
    """A model trained to convergence scores ~100% accuracy in one eval."""
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'evaluate_perfect_model_once')

    # Train a model to completion.
    self._train_model(checkpoint_dir, num_steps=300)

    # Rebuild the same model in the default graph so it can be restored.
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)
    logits = logistic_classifier(inputs)
    predictions = math_ops.round(logits)

    accuracy, update_op = metrics_module.accuracy(labels, predictions)

    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    # Fetch only the metric's value tensor in `final_ops`; `update_op` already
    # runs as `eval_ops`. Fetching the (value, update) pair would return a
    # tuple, which cannot be compared with a float.
    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=update_op,
        final_ops={'accuracy': accuracy},
        hooks=[
            evaluation._StopAfterNEvalsHook(1),
        ])
    self.assertGreater(final_ops_values['accuracy'], .99)

  def testEvaluateWithFiniteInputs(self):
    """Evaluation stops on its own when a one-epoch input pipeline runs out."""
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'evaluate_with_finite_inputs')

    # Train a model to completion.
    self._train_model(checkpoint_dir, num_steps=300)

    # Run evaluation. Inputs are fed through input producer for one epoch.
    all_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    all_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    single_input, single_label = training.slice_input_producer(
        [all_inputs, all_labels], num_epochs=1)
    inputs, labels = training.batch([single_input, single_label], batch_size=6,
                                    allow_smaller_final_batch=True)

    logits = logistic_classifier(inputs)
    predictions = math_ops.round(logits)

    accuracy, update_op = metrics_module.accuracy(labels, predictions)

    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    # `final_ops` are evaluated after the eval loop stops, i.e. once the
    # one-epoch input queue is exhausted, so they must not depend on
    # `update_op`, which would try to dequeue another batch.
    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=update_op,
        final_ops={
            'accuracy': accuracy,
            'eval_steps': evaluation._get_or_create_eval_step()
        },
        hooks=[
            evaluation._StopAfterNEvalsHook(None),
        ])
    self.assertGreater(final_ops_values['accuracy'], .99)
    # Runs evaluation for 4 iterations. First 2 evaluate full batch of 6 inputs
    # each; the 3rd iter evaluates the remaining 4 inputs, and the last one
    # triggers an error which stops evaluation.
    self.assertEqual(final_ops_values['eval_steps'], 4)

  def testEvalOpAndFinalOp(self):
    """`eval_ops` run once per eval; `final_ops` observe their total effect."""
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops')

    # Train a model for a single step to get a checkpoint.
    self._train_model(checkpoint_dir, num_steps=1)
    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    # Create the model so we have something to restore.
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(inputs)

    num_evals = 5
    final_increment = 9.0

    my_var = local_variable(0.0, name='MyVar')
    eval_ops = state_ops.assign_add(my_var, 1.0)
    final_ops = array_ops.identity(my_var) + final_increment

    final_hooks = [evaluation._StopAfterNEvalsHook(num_evals),]
    initial_hooks = list(final_hooks)
    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=eval_ops,
        final_ops={'value': final_ops},
        hooks=final_hooks)
    self.assertEqual(final_ops_values['value'], num_evals + final_increment)
    # `_evaluate_once` must not mutate the caller's hooks list.
    self.assertEqual(initial_hooks, final_hooks)

  def testMultiEvalStepIncrements(self):
    """Extra eval-step increments in `eval_ops` make the stop hook fire sooner."""
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'multi_eval_step_increments')

    # Train a model for a single step to get a checkpoint.
    self._train_model(checkpoint_dir, num_steps=1)
    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    # Create the model so we have something to restore.
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(inputs)

    num_evals = 6

    my_var = local_variable(0.0, name='MyVar')
    # In eval ops, we also increase the eval step one more time.
    eval_ops = [state_ops.assign_add(my_var, 1.0),
                state_ops.assign_add(
                    evaluation._get_or_create_eval_step(), 1, use_locking=True)]
    # The eval step advances twice per run, so `num_evals` is reached after
    # half as many runs, and `my_var` is incremented only that many times.
    expect_eval_update_counts = num_evals // 2

    final_ops = array_ops.identity(my_var)

    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=eval_ops,
        final_ops={'value': final_ops},
        hooks=[evaluation._StopAfterNEvalsHook(num_evals),])
    self.assertEqual(final_ops_values['value'], expect_eval_update_counts)

  def testOnlyFinalOp(self):
    """Without `eval_ops`, `final_ops` see the local variable's initial value."""
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'only_final_ops')

    # Train a model for a single step to get a checkpoint.
    self._train_model(checkpoint_dir, num_steps=1)
    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    # Create the model so we have something to restore.
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(inputs)

    final_increment = 9.0

    my_var = local_variable(0.0, name='MyVar')
    final_ops = array_ops.identity(my_var) + final_increment

    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path, final_ops={'value': final_ops})
    self.assertEqual(final_ops_values['value'], final_increment)
244
245
if __name__ == '__main__':
  # Run this module's tests through tensorflow.python.platform.test.
  test.main()
248