# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for strategy_combinations.

Checks that the named strategy combinations run with the expected number of
replicas, and that they construct the expected strategy classes when TF2
behavior is disabled (V1StrategyTest) or enabled (V2StrategyTest).
"""

from absl.testing import parameterized

from tensorflow.python import tf2
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import one_device_strategy
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


def _set_tf2_behavior(enabled):
  """Turns the global `tf2` behavior toggle on or off.

  Args:
    enabled: If truthy call `tf2.enable()`, otherwise call `tf2.disable()`.
  """
  if enabled:
    tf2.enable()
  else:
    tf2.disable()


class StrategyCombinationsTest(test.TestCase, parameterized.TestCase):
  """Checks the replica counts advertised by the strategy combinations."""

  def _assertReplicaCount(self, strategy, expected):
    """Runs a function returning 1. on every replica and checks the sum.

    Args:
      strategy: The distribution strategy under test.
      expected: The expected number of replicas, as a float.
    """
    with strategy.scope():

      @def_function.function
      def one():
        return array_ops.identity(1.)

      one_per_replica = strategy.run(one)
      # Summing a per-replica 1. across replicas yields the replica count.
      num_replicas = strategy.reduce(
          reduce_util.ReduceOp.SUM, one_per_replica, axis=None)
      self.assertEqual(self.evaluate(num_replicas), expected)

  @combinations.generate(
      combinations.combine(
          strategy=strategy_combinations.two_replica_strategies,
          mode=["graph", "eager"]))
  def testTwoReplicaStrategy(self, strategy):
    self._assertReplicaCount(strategy, 2.)

  @combinations.generate(
      combinations.combine(
          strategy=strategy_combinations.four_replica_strategies,
          mode=["graph", "eager"]))
  def testFourReplicaStrategy(self, strategy):
    self._assertReplicaCount(strategy, 4.)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2
          ],
          mode=["graph", "eager"]))
  def testMirrored2CPUs(self, distribution):
    with distribution.scope():
      one_per_replica = distribution.run(lambda: constant_op.constant(1))
      num_replicas = distribution.reduce(
          reduce_util.ReduceOp.SUM, one_per_replica, axis=None)
      self.assertEqual(2, self.evaluate(num_replicas))


class V1StrategyTest(test.TestCase, parameterized.TestCase):
  """Checks the combinations build V1 strategy classes with TF2 disabled."""

  def setUp(self):
    super().setUp()
    # Save the global toggle so tearDown can restore it; otherwise TF1
    # behavior would leak into whichever tests run after this class.
    self._tf2_was_enabled = tf2.enabled()
    tf2.disable()

  def tearDown(self):
    _set_tf2_behavior(self._tf2_was_enabled)
    super().tearDown()

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.one_device_strategy,
          strategy_combinations.one_device_strategy_gpu,
          strategy_combinations.one_device_strategy_gpu_on_worker_1,
          strategy_combinations.one_device_strategy_on_worker_1
      ]))
  def testOneDevice(self, strategy):
    self.assertIsInstance(strategy, one_device_strategy.OneDeviceStrategyV1)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_one_cpu,
          strategy_combinations.mirrored_strategy_with_one_gpu,
          strategy_combinations.mirrored_strategy_with_two_gpus,
      ]))
  def testMirrored(self, strategy):
    self.assertIsInstance(strategy, mirrored_strategy.MirroredStrategyV1)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.multi_worker_mirrored_2x1_cpu,
          strategy_combinations.multi_worker_mirrored_2x1_gpu,
          strategy_combinations.multi_worker_mirrored_2x2_gpu,
          strategy_combinations.multi_worker_mirrored_4x1_cpu,
      ]))
  def testMultiWorkerMirrored(self, strategy):
    # The MultiWorkerMirroredStrategy combinations only support V2, so the
    # V2 class is expected even with TF2 behavior disabled.
    self.assertIsInstance(
        strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
          strategy_combinations.central_storage_strategy_with_two_gpus,
      ]))
  def testCentralStorage(self, strategy):
    self.assertIsInstance(strategy,
                          central_storage_strategy.CentralStorageStrategyV1)

  @combinations.generate(
      combinations.combine(strategy=strategy_combinations.tpu_strategies))
  def testTPU(self, strategy):
    self.assertIsInstance(strategy, tpu_strategy.TPUStrategyV1)


class V2StrategyTest(test.TestCase, parameterized.TestCase):
  """Checks the combinations build V2 strategy classes with TF2 enabled."""

  def setUp(self):
    super().setUp()
    # Save the global toggle so tearDown can restore it after the test.
    self._tf2_was_enabled = tf2.enabled()
    tf2.enable()

  def tearDown(self):
    _set_tf2_behavior(self._tf2_was_enabled)
    super().tearDown()

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.one_device_strategy,
          strategy_combinations.one_device_strategy_gpu,
          strategy_combinations.one_device_strategy_gpu_on_worker_1,
          strategy_combinations.one_device_strategy_on_worker_1
      ]))
  def testOneDevice(self, strategy):
    self.assertIsInstance(strategy, one_device_strategy.OneDeviceStrategy)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_one_cpu,
          strategy_combinations.mirrored_strategy_with_one_gpu,
          strategy_combinations.mirrored_strategy_with_two_gpus,
      ]))
  def testMirrored(self, strategy):
    self.assertIsInstance(strategy, mirrored_strategy.MirroredStrategy)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.multi_worker_mirrored_2x1_cpu,
          strategy_combinations.multi_worker_mirrored_2x1_gpu,
          strategy_combinations.multi_worker_mirrored_2x2_gpu,
          strategy_combinations.multi_worker_mirrored_4x1_cpu,
      ]))
  def testMultiWorkerMirrored(self, strategy):
    self.assertIsInstance(
        strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
          strategy_combinations.central_storage_strategy_with_two_gpus,
      ]))
  def testCentralStorage(self, strategy):
    self.assertIsInstance(strategy,
                          central_storage_strategy.CentralStorageStrategy)

  @combinations.generate(
      combinations.combine(strategy=strategy_combinations.tpu_strategies))
  def testTPU(self, strategy):
    self.assertIsInstance(
        strategy, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2))

  # Defined once: a second, identical definition previously shadowed this one.
  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.parameter_server_strategy_3worker_2ps_cpu,
          strategy_combinations.parameter_server_strategy_1worker_2ps_cpu,
          strategy_combinations.parameter_server_strategy_3worker_2ps_1gpu,
          strategy_combinations.parameter_server_strategy_1worker_2ps_1gpu,
      ]))
  def testParameterServer(self, strategy):
    self.assertIsInstance(
        strategy, parameter_server_strategy_v2.ParameterServerStrategyV2)


if __name__ == "__main__":
  test_util.main()