# xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/one_device_strategy_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for class OneDeviceStrategy."""

from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import device as tf_device


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.one_device_strategy,
            strategy_combinations.one_device_strategy_gpu
        ],
        mode=["eager", "graph"]))
class OneDeviceStrategyTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.OneDeviceDistributionTestBase):
  """Tests for OneDeviceStrategy on a local device (CPU or GPU)."""

  def _make_range10_input_fn(self):
    """Builds an input_fn over Dataset.range(10) plus its expected output.

    The input_fn additionally validates the input context it receives:
    exactly one replica in sync and a single input pipeline with id 0.

    Returns:
      A `(input_fn, expected_values)` pair, where `expected_values` is the
      per-step list of elements (one element per replica, so `[[0], [1], ...]`).
    """
    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i] for i in range(10)]
    input_fn = self._input_fn_to_test_input_context(
        dataset_fn,
        expected_num_replicas_in_sync=1,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    return input_fn, expected_values

  def _get_distributed_item(self, distribution, input_options):
    """Distributes a batched range(100) dataset and returns its first element.

    Args:
      distribution: the strategy under test.
      input_options: `distribute_lib.InputOptions` controlling prefetching.

    Returns:
      The first element of the distributed dataset. Skips the test when the
      dataset can be consumed neither eagerly nor via a V1 initializable
      iterator (e.g. graph mode with a V2 distributed dataset).
    """
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(
        dataset, options=input_options)
    if context.executing_eagerly():
      return next(iter(dataset))
    if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
      return dataset.make_initializable_iterator().get_next()
    # skipTest raises, so control never falls through to callers.
    self.skipTest("unsupported test combination")

  def testMinimizeLoss(self, distribution):
    # Eager and graph mode use separate minimize-loss harnesses.
    if context.executing_eagerly():
      self._test_minimize_loss_eager(distribution)
    else:
      self._test_minimize_loss_graph(distribution)

  def testReplicaId(self, distribution):
    self._test_replica_id(distribution)

  def testCallAndMergeExceptions(self, distribution):
    self._test_call_and_merge_exceptions(distribution)

  def testReplicateDataset(self, distribution):
    if tf2.enabled() and not context.executing_eagerly():
      self.skipTest("Skipping test since we do not support graph mode in TF 2")
    input_fn, expected_values = self._make_range10_input_fn()
    self._test_input_fn_iterable(distribution, input_fn, expected_values)

  def testMakeInputFnIteratorWithDataset(self, distribution):
    input_fn, expected_values = self._make_range10_input_fn()
    iterator = distribution.make_input_fn_iterator(input_fn)
    self._test_input_fn_iterator(
        iterator, distribution.extended.worker_devices, expected_values)

  def testMakeInputFnIteratorWithCallable(self, distribution):
    # Unlike the dataset case above, the input_fn here returns a callable
    # (the iterator's get_next), so reinitialization is not supported.
    def fn():
      dataset = dataset_ops.Dataset.range(10)
      it = dataset_ops.make_one_shot_iterator(dataset)
      return it.get_next
    expected_values = [[i] for i in range(10)]
    input_fn = self._input_fn_to_test_input_context(
        fn,
        expected_num_replicas_in_sync=1,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    iterator = distribution.make_input_fn_iterator(input_fn)
    self._test_input_fn_iterator(
        iterator, distribution.extended.worker_devices, expected_values,
        test_reinitialize=False, ignore_order=True)

  def testNumpyDataset(self, distribution):
    self._test_numpy_dataset(distribution)

  def testRun(self, distribution):
    self._test_run(distribution)

  def testAllReduceSum(self, distribution):
    self._test_all_reduce_sum(distribution)

  def testAllReduceSumGradients(self, distribution):
    self._test_all_reduce_sum_gradients(distribution)

  def testAllReduceSumGradientTape(self, distribution):
    self._test_all_reduce_sum_gradient_tape(distribution)

  def testAllReduceMean(self, distribution):
    self._test_all_reduce_mean(distribution)

  def testAllReduceMeanGradients(self, distribution):
    self._test_all_reduce_mean_gradients(distribution)

  def testAllReduceMeanGradientTape(self, distribution):
    self._test_all_reduce_mean_gradient_tape(distribution)

  def testTrainableVariables(self, distribution):
    self._test_trainable_variable(distribution)

  def test_prefetch_to_device_dataset(self, distribution):
    # With experimental_fetch_to_device=True, elements should land on the
    # strategy's worker device (e.g. the GPU when one is present).
    item = self._get_distributed_item(
        distribution,
        distribute_lib.InputOptions(experimental_fetch_to_device=True))
    device_types = (
        tf_device.DeviceSpec.from_string(item.device).device_type)
    expected_device_types = (
        tf_device.DeviceSpec.from_string(
            distribution.extended.worker_devices[0]).device_type)
    self.assertAllEqual(device_types, expected_device_types)

  def test_prefetch_to_host_dataset(self, distribution):
    # With experimental_fetch_to_device=False, elements stay on the host CPU
    # regardless of the strategy's worker device.
    item = self._get_distributed_item(
        distribution,
        distribute_lib.InputOptions(experimental_fetch_to_device=False))
    self.assertAllEqual(
        tf_device.DeviceSpec.from_string(item.device).device_type, "CPU")


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.one_device_strategy_on_worker_1,
            strategy_combinations.one_device_strategy_gpu_on_worker_1
        ],
        mode=["eager", "graph"]))
class OneDeviceStrategyOnRemoteWorkerTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.OneDeviceDistributionTestBase):
  """Tests for OneDeviceStrategy when the device is on a remote worker."""

  def testDeviceAndInputDeviceAreColocated(self, distribution):
    self._test_device_and_input_device_are_colocated(distribution)

  def testDeviceAndInputDeviceAreColocatedWithFunction(self, distribution):
    self._test_device_and_input_device_are_colocated_with_function(distribution)


# Standard TF test entry point: runs all test cases in this module.
if __name__ == "__main__":
  test.main()