xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/strategy_combinations_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for a subset of strategy_combinations."""
16
17from absl.testing import parameterized
18
19from tensorflow.python import tf2
20from tensorflow.python.distribute import central_storage_strategy
21from tensorflow.python.distribute import collective_all_reduce_strategy
22from tensorflow.python.distribute import combinations
23from tensorflow.python.distribute import mirrored_strategy
24from tensorflow.python.distribute import one_device_strategy
25from tensorflow.python.distribute import parameter_server_strategy_v2
26from tensorflow.python.distribute import reduce_util
27from tensorflow.python.distribute import strategy_combinations
28from tensorflow.python.distribute import test_util
29from tensorflow.python.distribute import tpu_strategy
30from tensorflow.python.eager import def_function
31from tensorflow.python.framework import constant_op
32from tensorflow.python.ops import array_ops
33from tensorflow.python.platform import test
34
35
class StrategyCombinationsTest(test.TestCase, parameterized.TestCase):
  """Checks that strategy combinations expose the expected replica counts."""

  def _sum_of_ones_across_replicas(self, strategy):
    """Runs a tf.function returning 1.0 on every replica and SUM-reduces it.

    Each replica contributes 1.0, so the reduced value equals the number of
    replicas the strategy actually runs.

    Args:
      strategy: the distribution strategy under test.

    Returns:
      The evaluated SUM-reduced value.
    """
    with strategy.scope():

      @def_function.function
      def one():
        return array_ops.identity(1.)

      per_replica_ones = strategy.run(one)
      total = strategy.reduce(
          reduce_util.ReduceOp.SUM, per_replica_ones, axis=None)
      return self.evaluate(total)

  @combinations.generate(
      combinations.combine(
          strategy=strategy_combinations.two_replica_strategies,
          mode=["graph", "eager"]))
  def testTwoReplicaStrategy(self, strategy):
    self.assertEqual(self._sum_of_ones_across_replicas(strategy), 2.)

  @combinations.generate(
      combinations.combine(
          strategy=strategy_combinations.four_replica_strategies,
          mode=["graph", "eager"]))
  def testFourReplicaStrategy(self, strategy):
    self.assertEqual(self._sum_of_ones_across_replicas(strategy), 4.)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2
          ],
          mode=["graph", "eager"]))
  def testMirrored2CPUs(self, distribution):
    with distribution.scope():
      # Each of the two CPU replicas yields the integer 1; their sum is 2.
      per_replica = distribution.run(lambda: constant_op.constant(1))
      summed = distribution.reduce(
          reduce_util.ReduceOp.SUM, per_replica, axis=None)
      self.assertEqual(2, self.evaluate(summed))
82
83
class V1StrategyTest(test.TestCase, parameterized.TestCase):
  """Checks which strategy classes are built while TF2 behavior is disabled."""

  def setUp(self):
    super().setUp()
    # tf2.disable() changes a module-level setting, not anything scoped to this
    # test. Restore the previous setting on cleanup so it does not leak into
    # tests that run later in the same process.
    self.addCleanup(tf2.enable if tf2.enabled() else tf2.disable)
    tf2.disable()

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.one_device_strategy,
          strategy_combinations.one_device_strategy_gpu,
          strategy_combinations.one_device_strategy_gpu_on_worker_1,
          strategy_combinations.one_device_strategy_on_worker_1
      ]))
  def testOneDevice(self, strategy):
    self.assertIsInstance(strategy, one_device_strategy.OneDeviceStrategyV1)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_one_cpu,
          strategy_combinations.mirrored_strategy_with_one_gpu,
          strategy_combinations.mirrored_strategy_with_two_gpus,
      ]))
  def testMirrored(self, strategy):
    self.assertIsInstance(strategy, mirrored_strategy.MirroredStrategyV1)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.multi_worker_mirrored_2x1_cpu,
          strategy_combinations.multi_worker_mirrored_2x1_gpu,
          strategy_combinations.multi_worker_mirrored_2x2_gpu,
          strategy_combinations.multi_worker_mirrored_4x1_cpu,
      ]))
  def testMultiWorkerMirrored(self, strategy):
    # The MultiWorkerMirroredStrategy combinations only support V2, so the V2
    # class is expected here even though TF2 behavior is disabled.
    self.assertIsInstance(
        strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
          strategy_combinations.central_storage_strategy_with_two_gpus,
      ]))
  def testCentralStorage(self, strategy):
    self.assertIsInstance(strategy,
                          central_storage_strategy.CentralStorageStrategyV1)

  @combinations.generate(
      combinations.combine(strategy=strategy_combinations.tpu_strategies))
  def testTPU(self, strategy):
    self.assertIsInstance(strategy, tpu_strategy.TPUStrategyV1)
136
137
class V2StrategyTest(test.TestCase, parameterized.TestCase):
  """Checks which strategy classes are built while TF2 behavior is enabled."""

  def setUp(self):
    super().setUp()
    # tf2.enable() changes a module-level setting, not anything scoped to this
    # test. Restore the previous setting on cleanup so it does not leak into
    # tests that run later in the same process.
    self.addCleanup(tf2.enable if tf2.enabled() else tf2.disable)
    tf2.enable()

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.one_device_strategy,
          strategy_combinations.one_device_strategy_gpu,
          strategy_combinations.one_device_strategy_gpu_on_worker_1,
          strategy_combinations.one_device_strategy_on_worker_1
      ]))
  def testOneDevice(self, strategy):
    self.assertIsInstance(strategy, one_device_strategy.OneDeviceStrategy)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_one_cpu,
          strategy_combinations.mirrored_strategy_with_one_gpu,
          strategy_combinations.mirrored_strategy_with_two_gpus,
      ]))
  def testMirrored(self, strategy):
    self.assertIsInstance(strategy, mirrored_strategy.MirroredStrategy)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.multi_worker_mirrored_2x1_cpu,
          strategy_combinations.multi_worker_mirrored_2x1_gpu,
          strategy_combinations.multi_worker_mirrored_2x2_gpu,
          strategy_combinations.multi_worker_mirrored_4x1_cpu,
      ]))
  def testMultiWorkerMirrored(self, strategy):
    self.assertIsInstance(
        strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
          strategy_combinations.central_storage_strategy_with_two_gpus,
      ]))
  def testCentralStorage(self, strategy):
    self.assertIsInstance(strategy,
                          central_storage_strategy.CentralStorageStrategy)

  @combinations.generate(
      combinations.combine(strategy=strategy_combinations.tpu_strategies))
  def testTPU(self, strategy):
    # Either TPU strategy class is accepted for the TPU combinations.
    self.assertIsInstance(
        strategy, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2))

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.parameter_server_strategy_3worker_2ps_cpu,
          strategy_combinations.parameter_server_strategy_1worker_2ps_cpu,
          strategy_combinations.parameter_server_strategy_3worker_2ps_1gpu,
          strategy_combinations.parameter_server_strategy_1worker_2ps_1gpu,
      ]))
  def testParameterServer(self, strategy):
    self.assertIsInstance(
        strategy, parameter_server_strategy_v2.ParameterServerStrategyV2)
212
213
if __name__ == "__main__":
  # Uses the distribute-specific test_util entry point rather than test.main().
  # Presumably this is so the multi-worker combinations can launch their worker
  # processes; TODO confirm against test_util.main.
  test_util.main()
216