# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for continuous runs using cross-worker collective ops."""

import json
import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope


# TODO(b/151232436): This test doesn't work with check health enabled because
# it relies on a barrier around creating strategies. Check health performs
# communications inside the strategy constructor, which makes the barrier
# ineffective.
CollectiveAllReduceExtended = (
    collective_all_reduce_strategy.CollectiveAllReduceExtended)
CollectiveAllReduceExtended._enable_check_health = False


NUM_WORKERS = 5


# TODO(b/143286947): expand the test to cover fault tolerance and elasticity.
class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    self._gpus = config.list_physical_devices('GPU')
    self._local_device = '/device:GPU:0' if self._gpus else '/device:CPU:0'
    super(MultiWorkerContinuousRunTest, self).setUp()

  def _maybe_setup_gpus(self):
    if self._gpus:
      # Set a virtual GPU with a 64MB memory limit so that multiple worker
      # processes can share the physical GPU.
      config.set_logical_device_configuration(
          self._gpus[0], [context.LogicalDeviceConfiguration(64)])

  @combinations.generate(combinations.combine(mode=['eager']))
  def testAllReduceContinuousRun(self, mode):
    tensor_shape = [2, 2]

    def worker_step_fn(worker_id):
      strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
      # Make sure the processes are in sync after updating the cluster.
      multi_process_runner.get_barrier().wait()

      @def_function.function
      def run_reduce():
        with ops.device(self._local_device):
          t_in = array_ops.ones(tensor_shape) * worker_id
          return strategy.reduce(reduce_util.ReduceOp.MEAN, t_in, axis=None)

      t_out = run_reduce()
      # Element values from the workers are
      # 0, 1, ..., (NUM_WORKERS - 1).
      expected_mean = (NUM_WORKERS - 1) / 2
      expected_out = np.ones(tensor_shape) * expected_mean
      self.assertAllClose(t_out, expected_out)

    def worker_fn():
      self._maybe_setup_gpus()
      tf_config = json.loads(os.environ['TF_CONFIG'])
      worker_id = tf_config['task']['index']
      for _ in range(20):
        worker_step_fn(worker_id)

    with test_util.skip_if_error(self, errors_impl.UnavailableError):
      multi_process_runner.run(
          worker_fn,
          cluster_spec=test_base.create_cluster_spec(num_workers=NUM_WORKERS))

  @combinations.generate(combinations.combine(mode=['eager']))
  def testVariableInitializationWithChangingShape(self, mode):

    def worker_step_fn(worker_id, num_dims):
      strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
      # Make sure the processes are in sync after updating the cluster.
      multi_process_runner.get_barrier().wait()
      tensor_shape = [2] * num_dims

      def variable_fn():
        with ops.device(self._local_device):
          # The initial value will be broadcast from worker 0 to the others.
          initial_value = (array_ops.ones(tensor_shape) if worker_id == 0 else
                           array_ops.zeros(tensor_shape))
          var = variable_scope.get_variable(
              name='x', initializer=initial_value)
          return array_ops.identity(var)

      t_out = strategy.extended.call_for_each_replica(variable_fn)
      expected_out = np.ones(tensor_shape)
      self.assertAllClose(t_out, expected_out)

    def worker_fn():
      self._maybe_setup_gpus()
      tf_config = json.loads(os.environ['TF_CONFIG'])
      worker_id = tf_config['task']['index']
      for i in range(20):
        worker_step_fn(worker_id, num_dims=(i + 1))

    with test_util.skip_if_error(self, errors_impl.UnavailableError):
      multi_process_runner.run(
          worker_fn,
          cluster_spec=test_base.create_cluster_spec(num_workers=NUM_WORKERS))


if __name__ == '__main__':
  multi_process_runner.test_main()