# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.training.evaluation."""

import os

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.layers import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import evaluation
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
from tensorflow.python.training import training

_USE_GLOBAL_STEP = 0


def logistic_classifier(inputs):
  """Builds a one-unit dense layer with a sigmoid activation.

  Args:
    inputs: A float `Tensor` of features fed into the dense layer.

  Returns:
    The output `Tensor` of the dense layer, passed through a sigmoid.
  """
  return layers.dense(inputs, 1, activation=math_ops.sigmoid)


def local_variable(init_value, name):
  """Creates a non-trainable float32 variable in LOCAL_VARIABLES.

  Args:
    init_value: Initializer / initial value for the variable.
    name: Name passed to `variable_scope.get_variable`.

  Returns:
    The variable returned by `variable_scope.get_variable`.
  """
  return variable_scope.get_variable(
      name,
      dtype=dtypes.float32,
      initializer=init_value,
      trainable=False,
      collections=[ops.GraphKeys.LOCAL_VARIABLES])


class EvaluateOnceTest(test.TestCase):
  """Tests for `evaluation._evaluate_once`."""

  def setUp(self):
    super(EvaluateOnceTest, self).setUp()

    # Create an easy training set: each of the 16 examples has exactly one
    # active feature, in column 0 or 1 for label 0 and in column 2 or 3 for
    # label 1, so the label is fully determined by the inputs.
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

    for i in range(16):
      # Index the scalar element explicitly: calling int() on a shape-(1,)
      # array is deprecated in recent NumPy releases.
      j = int(2 * self._labels[i, 0] + np.random.randint(0, 2))
      self._inputs[i, j] = 1

  def _train_model(self, checkpoint_dir, num_steps):
    """Trains a simple classification model.

    Note that the data has been configured such that after around 300 steps,
    the model has memorized the dataset (i.e. we can expect 100% accuracy).

    Args:
      checkpoint_dir: The directory where the checkpoint is written to.
      num_steps: The number of steps to train for.
    """
    # Train in a separate graph so the test's default graph only contains the
    # evaluation model built afterwards.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      loss_op = losses.log_loss(labels=tf_labels, predictions=tf_predictions)

      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = optimizer.minimize(loss_op,
                                    training.get_or_create_global_step())

      with monitored_session.MonitoredTrainingSession(
          checkpoint_dir=checkpoint_dir,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps)]) as session:
        loss = None
        while not session.should_stop():
          _, loss = session.run([train_op, loss_op])

      if num_steps >= 300:
        # Use TestCase assertions instead of a bare `assert`: bare asserts are
        # stripped under `python -O`, and a `None` loss (no step ran) would
        # otherwise surface as a confusing TypeError.
        self.assertIsNotNone(loss)
        self.assertLess(loss, .015)

  def testEvaluatePerfectModel(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'evaluate_perfect_model_once')

    # Train a Model to completion:
    self._train_model(checkpoint_dir, num_steps=300)

    # Rebuild the same model in the default graph and evaluate the checkpoint.
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)
    logits = logistic_classifier(inputs)
    predictions = math_ops.round(logits)

    accuracy, update_op = metrics_module.accuracy(labels, predictions)

    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=update_op,
        final_ops={'accuracy': (accuracy, update_op)},
        hooks=[
            evaluation._StopAfterNEvalsHook(1),
        ])
    self.assertGreater(final_ops_values['accuracy'], .99)

  def testEvaluateWithFiniteInputs(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'evaluate_with_finite_inputs')

    # Train a Model to completion:
    self._train_model(checkpoint_dir, num_steps=300)

    # Run evaluation. Inputs are fed through input producer for one epoch.
    all_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    all_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    single_input, single_label = training.slice_input_producer(
        [all_inputs, all_labels], num_epochs=1)
    inputs, labels = training.batch([single_input, single_label], batch_size=6,
                                    allow_smaller_final_batch=True)

    logits = logistic_classifier(inputs)
    predictions = math_ops.round(logits)

    accuracy, update_op = metrics_module.accuracy(labels, predictions)

    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=update_op,
        final_ops={
            'accuracy': (accuracy, update_op),
            'eval_steps': evaluation._get_or_create_eval_step()
        },
        hooks=[
            # No eval limit: evaluation runs until the input is exhausted.
            evaluation._StopAfterNEvalsHook(None),
        ])
    self.assertGreater(final_ops_values['accuracy'], .99)
    # Runs evaluation for 4 iterations. First 2 evaluate full batch of 6 inputs
    # each; the 3rd iter evaluates the remaining 4 inputs, and the last one
    # triggers an error which stops evaluation.
    self.assertEqual(final_ops_values['eval_steps'], 4)

  def testEvalOpAndFinalOp(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops')

    # Train a model for a single step to get a checkpoint.
    self._train_model(checkpoint_dir, num_steps=1)
    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    # Create the model so we have something to restore.
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(inputs)

    num_evals = 5
    final_increment = 9.0

    my_var = local_variable(0.0, name='MyVar')
    eval_ops = state_ops.assign_add(my_var, 1.0)
    final_ops = array_ops.identity(my_var) + final_increment

    final_hooks = [evaluation._StopAfterNEvalsHook(num_evals),]
    # Snapshot the hooks list to verify _evaluate_once does not mutate it.
    initial_hooks = list(final_hooks)
    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=eval_ops,
        final_ops={'value': final_ops},
        hooks=final_hooks)
    self.assertEqual(final_ops_values['value'], num_evals + final_increment)
    self.assertEqual(initial_hooks, final_hooks)

  def testMultiEvalStepIncrements(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'multi_eval_step_increments')

    # Train a model for a single step to get a checkpoint.
    self._train_model(checkpoint_dir, num_steps=1)
    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    # Create the model so we have something to restore.
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(inputs)

    num_evals = 6

    my_var = local_variable(0.0, name='MyVar')
    # In eval ops, we also increase the eval step one more time. Each run thus
    # advances the eval step twice, so the stop hook fires after only half as
    # many runs and MyVar is incremented num_evals // 2 times.
    eval_ops = [state_ops.assign_add(my_var, 1.0),
                state_ops.assign_add(
                    evaluation._get_or_create_eval_step(), 1, use_locking=True)]
    expect_eval_update_counts = num_evals // 2

    final_ops = array_ops.identity(my_var)

    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=eval_ops,
        final_ops={'value': final_ops},
        hooks=[evaluation._StopAfterNEvalsHook(num_evals),])
    self.assertEqual(final_ops_values['value'], expect_eval_update_counts)

  def testOnlyFinalOp(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'only_final_ops')

    # Train a model for a single step to get a checkpoint.
    self._train_model(checkpoint_dir, num_steps=1)
    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    # Create the model so we have something to restore.
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(inputs)

    final_increment = 9.0

    my_var = local_variable(0.0, name='MyVar')
    final_ops = array_ops.identity(my_var) + final_increment

    # No eval_ops are passed, so MyVar keeps its initial value of 0.0.
    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path, final_ops={'value': final_ops})
    self.assertEqual(final_ops_values['value'], final_increment)


if __name__ == '__main__':
  test.main()