# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test DistributionStrategy in the zero batch case."""

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent


class NormalizationTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.one_device_strategy,
          ],
          mode=["graph"],
          fused=[True, False]))
  def testBNWithZeroBatchInputGraph(self, distribution, fused):
    distribution.extended.experimental_enable_get_next_as_optional = True
    with distribution.scope(), self.cached_session() as sess:
      bn_list = []
      inputs = np.random.random((0, 4, 4, 3)) + 100
      targets = np.random.random((0, 4, 4, 3))
      inputs_placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 4, 4, 3])
      targets_placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 4, 4, 3])

      def step_fn(is_training, inputs, targets=None):
        bn = normalization.BatchNormalization(
            axis=3, epsilon=1e-3, momentum=0.9, fused=fused)
        bn_list.append(bn)
        outputs = bn.apply(inputs, training=is_training)
        if not is_training:
          return outputs

        loss = losses.mean_squared_error(targets, outputs)
        optimizer = gradient_descent.GradientDescentOptimizer(0.01)
        train_op = optimizer.minimize(loss)
        with ops.control_dependencies([train_op]):
          return array_ops.identity(loss)

      train_op = distribution.extended.call_for_each_replica(
          step_fn, args=(True, inputs_placeholder, targets_placeholder))
      predict_op = distribution.extended.call_for_each_replica(
          step_fn, args=(False, inputs_placeholder))
      bn = bn_list[0]

      self.evaluate(variables.global_variables_initializer())

      # Check for initial statistics and weights.
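      # BatchNormalization initializes moving_mean and beta to zeros, and
      # moving_variance and gamma to ones, so these are the expected values
      # before any training step runs.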
      moving_mean, moving_var = self.evaluate(
          [bn.moving_mean, bn.moving_variance])
      self.assertAllEqual([0, 0, 0], moving_mean)
      self.assertAllEqual([1, 1, 1], moving_var)

      np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
      self.assertAllEqual([1, 1, 1], np_gamma)
      self.assertAllEqual([0, 0, 0], np_beta)

      for _ in range(100):
        np_output, _, _ = sess.run([train_op] + bn.updates, {
            inputs_placeholder: inputs,
            targets_placeholder: targets
        })
        self.assertEqual(0.0, np_output)

      # Verify that the statistics and weights are not changed after training.
      moving_mean, moving_var = self.evaluate(
          [bn.moving_mean, bn.moving_variance])
      self.assertAllEqual([0, 0, 0], moving_mean)
      self.assertAllEqual([1, 1, 1], moving_var)

      np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
      self.assertAllEqual([1, 1, 1], np_gamma)
      self.assertAllEqual([0, 0, 0], np_beta)

      # Test inference.
      np_output = sess.run(predict_op, {inputs_placeholder: inputs})
      self.assertEqual([], np_output.tolist())

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.one_device_strategy,
          ],
          mode=["eager"],
          fused=[True, False]))
  def testBNWithZeroBatchInput(self, distribution, fused):
    distribution.extended.experimental_enable_get_next_as_optional = True
    with distribution.scope():
      inputs = np.random.random((0, 4, 4, 3)).astype(np.float32) + 100
      targets = np.random.random((0, 4, 4, 3)).astype(np.float32)
      bn = normalization.BatchNormalization(
          axis=3, epsilon=1e-3, momentum=0.9, fused=fused)
      optimizer = gradient_descent.GradientDescentOptimizer(0.01)

      @def_function.function
      def train_step():
        def step_fn(inputs, targets):
          with backprop.GradientTape() as tape:
            outputs = bn.apply(inputs, training=True)
            loss = losses.mean_squared_error(targets, outputs)
          grads = tape.gradient(loss, bn.variables)
          optimizer.apply_gradients(zip(grads, bn.variables))
          return loss

        return distribution.run(step_fn, args=(inputs, targets))

      for _ in range(100):
        np_output = train_step().numpy()
        self.assertEqual(0.0, np_output)

      # Verify that the statistics and weights are not changed after training.
      self.assertAllEqual([0, 0, 0], bn.moving_mean.numpy())
      self.assertAllEqual([1, 1, 1], bn.moving_variance.numpy())
      self.assertAllEqual([1, 1, 1], bn.gamma.numpy())
      self.assertAllEqual([0, 0, 0], bn.beta.numpy())

      @def_function.function
      def test_step():
        def step_fn(inputs):
          outputs = bn.apply(inputs, training=False)
          return outputs

        return distribution.run(step_fn, args=(inputs,))

      # Test inference.
      self.assertAllEqual(np.zeros(shape=(0, 4, 4, 3), dtype=np.float32),
                          test_step().numpy())

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.one_device_strategy,
          ],
          mode=["eager"],
          fused=[True, False]))
  def testBNWithDynamicBatchInputEager(self, distribution, fused):
    distribution.extended.experimental_enable_get_next_as_optional = True
    with distribution.scope():
      # Explicitly create dataset with drop_remainder=False.
      # This would make batch size unknown.
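      # With 11 examples and a batch size of 10 (drop_remainder=False), the
      # final batch holds a single example, so the per-replica batch dimension
      # cannot be statically known.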
      inputs = np.random.random((11, 4, 4, 3)).astype(np.float32) + 100
      targets = np.random.random((11, 4, 4, 3)).astype(np.float32)
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).batch(
          10, drop_remainder=False).repeat()
      dataset_iterator = iter(
          distribution.experimental_distribute_dataset(dataset))

      bn = normalization.BatchNormalization(
          axis=-1, epsilon=1e-3, momentum=0.9, fused=fused)
      optimizer = gradient_descent.GradientDescentOptimizer(0.01)

      @def_function.function
      def train_step(iterator):

        def step_fn(inputs):
          features, targets = inputs
          with backprop.GradientTape() as tape:
            outputs = bn(features, training=True)
            loss = losses.mean_squared_error(targets, outputs)

          grads = tape.gradient(loss, bn.variables)
          optimizer.apply_gradients(zip(grads, bn.variables))
          return loss

        return distribution.run(step_fn, args=(next(iterator),))

      for _ in range(100):
        train_step(dataset_iterator).numpy()

      # Verify that the statistics and weights are updated.
      # Note: compare against np.array (not np.ndarray, which would build an
      # empty array from a shape argument and make the checks vacuous).
      self.assertNotAllEqual(np.array([0, 0, 0]), bn.moving_mean.numpy())
      self.assertNotAllEqual(np.array([1, 1, 1]), bn.moving_variance.numpy())
      self.assertNotAllEqual(np.array([1, 1, 1]), bn.gamma.numpy())
      self.assertNotAllEqual(np.array([0, 0, 0]), bn.beta.numpy())


if __name__ == "__main__":
  test_util.main()