# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for class OneDeviceStrategy."""

from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import device as tf_device


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.one_device_strategy,
            strategy_combinations.one_device_strategy_gpu
        ],
        mode=["eager", "graph"]))
class OneDeviceStrategyTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.OneDeviceDistributionTestBase):
  """Runs the shared distribution checks against single-device strategies.

  Every test is generated for `one_device_strategy` and
  `one_device_strategy_gpu`, in both eager and graph mode.
  """

  def _single_pipeline_input_fn(self, fn):
    """Wraps `fn` in an input fn that expects exactly one replica/pipeline.

    Args:
      fn: Zero-argument callable producing the input (a dataset, or a
        `get_next` callable) to be wrapped.

    Returns:
      The input fn built by `_input_fn_to_test_input_context`, configured to
      expect 1 replica in sync, 1 input pipeline and input pipeline id 0.
    """
    return self._input_fn_to_test_input_context(
        fn,
        expected_num_replicas_in_sync=1,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)

  def _get_first_element(self, distribution, fetch_to_device):
    """Distributes a small batched dataset and returns its first element.

    Args:
      distribution: The strategy under test.
      fetch_to_device: Value passed as
        `InputOptions.experimental_fetch_to_device`.

    Returns:
      The first element of the distributed dataset. In eager mode it is
      obtained by Python iteration; in graph mode it is the result of
      `make_initializable_iterator().get_next()`.

    Raises:
      unittest.SkipTest: (raised by `skipTest`) in graph mode when the
        distributed dataset is not an `input_lib_v1.DistributedDatasetV1`.
    """
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=fetch_to_device)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(
        dataset, options=input_options)
    if context.executing_eagerly():
      return next(iter(dataset))
    # In graph mode only the V1 distributed dataset is handled here; any
    # other dataset type is skipped rather than failed.
    if not isinstance(dataset, input_lib_v1.DistributedDatasetV1):
      self.skipTest("unsupported test combination")
    return dataset.make_initializable_iterator().get_next()

  def testMinimizeLoss(self, distribution):
    # The shared library has separate eager and graph implementations.
    if context.executing_eagerly():
      self._test_minimize_loss_eager(distribution)
    else:
      self._test_minimize_loss_graph(distribution)

  def testReplicaId(self, distribution):
    self._test_replica_id(distribution)

  def testCallAndMergeExceptions(self, distribution):
    self._test_call_and_merge_exceptions(distribution)

  def testReplicateDataset(self, distribution):
    if tf2.enabled() and not context.executing_eagerly():
      self.skipTest("Skipping test since we do not support graph mode in TF 2")
    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i] for i in range(10)]
    input_fn = self._single_pipeline_input_fn(dataset_fn)
    self._test_input_fn_iterable(distribution, input_fn, expected_values)

  def testMakeInputFnIteratorWithDataset(self, distribution):
    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i] for i in range(10)]
    input_fn = self._single_pipeline_input_fn(dataset_fn)
    iterator = distribution.make_input_fn_iterator(input_fn)
    self._test_input_fn_iterator(
        iterator, distribution.extended.worker_devices, expected_values)

  def testMakeInputFnIteratorWithCallable(self, distribution):
    def fn():
      dataset = dataset_ops.Dataset.range(10)
      it = dataset_ops.make_one_shot_iterator(dataset)
      # The input fn returns the bound `get_next` callable itself, not a
      # tensor.
      return it.get_next
    expected_values = [[i] for i in range(10)]
    input_fn = self._single_pipeline_input_fn(fn)
    iterator = distribution.make_input_fn_iterator(input_fn)
    # Re-initialization is disabled, presumably because the underlying
    # one-shot iterator cannot be re-initialized.
    self._test_input_fn_iterator(
        iterator, distribution.extended.worker_devices, expected_values,
        test_reinitialize=False, ignore_order=True)

  def testNumpyDataset(self, distribution):
    self._test_numpy_dataset(distribution)

  def testRun(self, distribution):
    self._test_run(distribution)

  def testAllReduceSum(self, distribution):
    self._test_all_reduce_sum(distribution)

  def testAllReduceSumGradients(self, distribution):
    self._test_all_reduce_sum_gradients(distribution)

  def testAllReduceSumGradientTape(self, distribution):
    self._test_all_reduce_sum_gradient_tape(distribution)

  def testAllReduceMean(self, distribution):
    self._test_all_reduce_mean(distribution)

  def testAllReduceMeanGradients(self, distribution):
    self._test_all_reduce_mean_gradients(distribution)

  def testAllReduceMeanGradientTape(self, distribution):
    self._test_all_reduce_mean_gradient_tape(distribution)

  def testTrainableVariables(self, distribution):
    self._test_trainable_variable(distribution)

  def test_prefetch_to_device_dataset(self, distribution):
    # With fetch-to-device enabled, elements should land on the same device
    # type as the strategy's worker device.
    item = self._get_first_element(distribution, fetch_to_device=True)
    actual_device_type = tf_device.DeviceSpec.from_string(
        item.device).device_type
    expected_device_type = tf_device.DeviceSpec.from_string(
        distribution.extended.worker_devices[0]).device_type
    self.assertEqual(actual_device_type, expected_device_type)

  def test_prefetch_to_host_dataset(self, distribution):
    # With fetch-to-device disabled, elements should stay on the CPU.
    item = self._get_first_element(distribution, fetch_to_device=False)
    actual_device_type = tf_device.DeviceSpec.from_string(
        item.device).device_type
    self.assertEqual(actual_device_type, "CPU")


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.one_device_strategy_on_worker_1,
            strategy_combinations.one_device_strategy_gpu_on_worker_1
        ],
        mode=["eager", "graph"]))
class OneDeviceStrategyOnRemoteWorkerTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.OneDeviceDistributionTestBase):
  """Runs the shared colocation checks for one-device strategies on worker 1.

  Tests are generated for the CPU and GPU `*_on_worker_1` strategy variants,
  in both eager and graph mode.
  """

  def testDeviceAndInputDeviceAreColocated(self, distribution):
    self._test_device_and_input_device_are_colocated(distribution)

  def testDeviceAndInputDeviceAreColocatedWithFunction(self, distribution):
    self._test_device_and_input_device_are_colocated_with_function(distribution)


if __name__ == "__main__":
  test.main()