# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test DistributionStrategy in the zero batch case."""

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent


class NormalizationTest(test.TestCase, parameterized.TestCase):
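  """Tests BatchNormalization with zero-size and dynamic batches."""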

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.one_device_strategy,
          ],
          mode=["graph"],
          fused=[True, False]))
  def testBNWithZeroBatchInputGraph(self, distribution, fused):
    distribution.extended.experimental_enable_get_next_as_optional = True
    with distribution.scope(), self.cached_session() as sess:
      bn_list = []
      inputs = np.random.random((0, 4, 4, 3)) + 100
      targets = np.random.random((0, 4, 4, 3))
      inputs_placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 4, 4, 3])
      targets_placeholder = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 4, 4, 3])

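      # step_fn builds a per-replica BatchNormalization layer. In training mode
      # it minimizes the MSE loss with SGD and returns the loss; in inference
      # mode it returns the layer outputs.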
      def step_fn(is_training, inputs, targets=None):
        bn = normalization.BatchNormalization(
            axis=3, epsilon=1e-3, momentum=0.9, fused=fused)
        bn_list.append(bn)
        outputs = bn.apply(inputs, training=is_training)
        if not is_training:
          return outputs

        loss = losses.mean_squared_error(targets, outputs)
        optimizer = gradient_descent.GradientDescentOptimizer(0.01)
        train_op = optimizer.minimize(loss)
        with ops.control_dependencies([train_op]):
          return array_ops.identity(loss)

      train_op = distribution.extended.call_for_each_replica(
          step_fn, args=(True, inputs_placeholder, targets_placeholder))
      predict_op = distribution.extended.call_for_each_replica(
          step_fn, args=(False, inputs_placeholder))
      bn = bn_list[0]

      self.evaluate(variables.global_variables_initializer())

      # Check for initial statistics and weights.
      moving_mean, moving_var = self.evaluate(
          [bn.moving_mean, bn.moving_variance])
      self.assertAllEqual([0, 0, 0], moving_mean)
      self.assertAllEqual([1, 1, 1], moving_var)

      np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
      self.assertAllEqual([1, 1, 1], np_gamma)
      self.assertAllEqual([0, 0, 0], np_beta)

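      # With a zero-size batch, every training step should report a loss of
      # exactly 0.0.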
      for _ in range(100):
        np_output, _, _ = sess.run([train_op] + bn.updates, {
            inputs_placeholder: inputs,
            targets_placeholder: targets
        })
        self.assertEqual(0.0, np_output)

      # Verify that the statistics and weights are not changed after training.
      moving_mean, moving_var = self.evaluate(
          [bn.moving_mean, bn.moving_variance])
      self.assertAllEqual([0, 0, 0], moving_mean)
      self.assertAllEqual([1, 1, 1], moving_var)

      np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
      self.assertAllEqual([1, 1, 1], np_gamma)
      self.assertAllEqual([0, 0, 0], np_beta)

      # Test inference.
      np_output = sess.run(predict_op, {inputs_placeholder: inputs})
      self.assertEqual([], np_output.tolist())

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.one_device_strategy,
          ],
          mode=["eager"],
          fused=[True, False]))
  def testBNWithZeroBatchInput(self, distribution, fused):
    distribution.extended.experimental_enable_get_next_as_optional = True
    with distribution.scope():
      inputs = np.random.random((0, 4, 4, 3)).astype(np.float32) + 100
      targets = np.random.random((0, 4, 4, 3)).astype(np.float32)
      bn = normalization.BatchNormalization(
          axis=3, epsilon=1e-3, momentum=0.9, fused=fused)
      optimizer = gradient_descent.GradientDescentOptimizer(0.01)

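      # train_step applies one SGD update per replica; with a zero-size batch
      # the loss should stay at 0.0 and the layer variables should not change.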
      @def_function.function
      def train_step():
        def step_fn(inputs, targets):
          with backprop.GradientTape() as tape:
            outputs = bn.apply(inputs, training=True)
            loss = losses.mean_squared_error(targets, outputs)
          grads = tape.gradient(loss, bn.variables)
          optimizer.apply_gradients(zip(grads, bn.variables))
          return loss

        return distribution.run(step_fn, args=(inputs, targets))

      for _ in range(100):
        np_output = train_step().numpy()
        self.assertEqual(0.0, np_output)

      # Verify that the statistics and weights are not changed after training.
      self.assertAllEqual([0, 0, 0], bn.moving_mean.numpy())
      self.assertAllEqual([1, 1, 1], bn.moving_variance.numpy())
      self.assertAllEqual([1, 1, 1], bn.gamma.numpy())
      self.assertAllEqual([0, 0, 0], bn.beta.numpy())

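      # test_step runs the layer in inference mode; with a zero-size batch it
      # should simply return an empty tensor of the same inner shape.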
      @def_function.function
      def test_step():
        def step_fn(inputs):
          outputs = bn.apply(inputs, training=False)
          return outputs

        return distribution.run(step_fn, args=(inputs,))

      # Test inference.
      self.assertAllEqual(np.zeros(shape=(0, 4, 4, 3), dtype=np.float32),
                          test_step().numpy())

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.one_device_strategy,
          ],
          mode=["eager"],
          fused=[True, False]))
  def testBNWithDynamicBatchInputEager(self, distribution, fused):
    distribution.extended.experimental_enable_get_next_as_optional = True
    with distribution.scope():
      # Explicitly create the dataset with drop_remainder=False so that the
      # batch size is not statically known. With 11 examples and a batch size
      # of 10, the final batch contains only a single example.
      inputs = np.random.random((11, 4, 4, 3)).astype(np.float32) + 100
      targets = np.random.random((11, 4, 4, 3)).astype(np.float32)
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).batch(
          10, drop_remainder=False).repeat()
      dataset_iterator = iter(
          distribution.experimental_distribute_dataset(dataset))

      bn = normalization.BatchNormalization(
          axis=-1, epsilon=1e-3, momentum=0.9, fused=fused)
      optimizer = gradient_descent.GradientDescentOptimizer(0.01)

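      # train_step consumes one (possibly partial) batch from the distributed
      # iterator and applies a single SGD update to the layer variables.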
      @def_function.function
      def train_step(iterator):

        def step_fn(inputs):
          features, targets = inputs
          with backprop.GradientTape() as tape:
            outputs = bn(features, training=True)
            loss = losses.mean_squared_error(targets, outputs)

          grads = tape.gradient(loss, bn.variables)
          optimizer.apply_gradients(zip(grads, bn.variables))
          return loss

        return distribution.run(step_fn, args=(next(iterator),))

      for _ in range(100):
        train_step(dataset_iterator).numpy()

      # Verify that the statistics and weights are updated.
      self.assertNotAllEqual(np.array([0, 0, 0]), bn.moving_mean.numpy())
      self.assertNotAllEqual(np.array([1, 1, 1]), bn.moving_variance.numpy())
      self.assertNotAllEqual(np.array([1, 1, 1]), bn.gamma.numpy())
      self.assertNotAllEqual(np.array([0, 0, 0]), bn.beta.numpy())


if __name__ == "__main__":
  test_util.main()