# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for continuous runs using cross-worker collective ops."""

import json
import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope


# TODO(b/151232436): This test doesn't work with check health enabled because
# it relies on a barrier around creating strategies. Check health performs
# communications inside the strategy constructor, which makes the barrier
# ineffective.
CollectiveAllReduceExtended = (
    collective_all_reduce_strategy.CollectiveAllReduceExtended)
CollectiveAllReduceExtended._enable_check_health = False


NUM_WORKERS = 5


# TODO(b/143286947): expand the test to cover fault tolerance and elasticity
class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    self._gpus = config.list_physical_devices('GPU')
    self._local_device = '/device:GPU:0' if self._gpus else '/device:CPU:0'
    super(MultiWorkerContinuousRunTest, self).setUp()

  def _maybe_setup_gpus(self):
    if self._gpus:
      # Set a virtual GPU with a 64MB memory limit so that multiple worker
      # processes can share the physical GPU.
      config.set_logical_device_configuration(
          self._gpus[0], [context.LogicalDeviceConfiguration(64)])

  @combinations.generate(combinations.combine(mode=['eager']))
  def testAllReduceContinuousRun(self, mode):
    tensor_shape = [2, 2]

    def worker_step_fn(worker_id):
      strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
      # Make sure the processes are in sync after updating the cluster.
      multi_process_runner.get_barrier().wait()

      @def_function.function
      def run_reduce():
        with ops.device(self._local_device):
          t_in = array_ops.ones(tensor_shape) * worker_id
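          # axis=None: only the cross-replica MEAN reduction is applied; no
          # additional reduction over the tensor's own dimensions.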
          return strategy.reduce(reduce_util.ReduceOp.MEAN, t_in, axis=None)

      t_out = run_reduce()
      # Element values from the workers are
      #     0, 1, ..., (NUM_WORKERS - 1)
      expected_mean = (NUM_WORKERS - 1) / 2
      expected_out = np.ones(tensor_shape) * expected_mean
      self.assertAllClose(t_out, expected_out)

    def worker_fn():
      self._maybe_setup_gpus()
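      # multi_process_runner sets TF_CONFIG for each worker subprocess from the
      # cluster spec; illustrative shape:
      #   {"cluster": {"worker": [...]}, "task": {"type": "worker", "index": 0}}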
      tf_config = json.loads(os.environ['TF_CONFIG'])
      worker_id = tf_config['task']['index']
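      # Run many back-to-back steps; each one constructs a fresh strategy, so
      # this exercises repeated (continuous) collective runs across workers.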
      for _ in range(20):
        worker_step_fn(worker_id)

    with test_util.skip_if_error(self, errors_impl.UnavailableError):
      multi_process_runner.run(
          worker_fn,
          cluster_spec=test_base.create_cluster_spec(num_workers=NUM_WORKERS))

  @combinations.generate(combinations.combine(mode=['eager']))
  def testVariableInitializationWithChangingShape(self, mode):

    def worker_step_fn(worker_id, num_dims):
      strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
      # Make sure the processes are in sync after updating the cluster.
      multi_process_runner.get_barrier().wait()
      tensor_shape = [2] * num_dims

      def variable_fn():
        with ops.device(self._local_device):
          # The initial value will be broadcast from worker 0 to the others.
          initial_value = (array_ops.ones(tensor_shape) if worker_id == 0 else
                           array_ops.zeros(tensor_shape))
          var = variable_scope.get_variable(name='x', initializer=initial_value)
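          # Return the variable's current value as a tensor so the caller can
          # compare it directly against a numpy array.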
          return array_ops.identity(var)

      t_out = strategy.extended.call_for_each_replica(variable_fn)
      expected_out = np.ones(tensor_shape)
      self.assertAllClose(t_out, expected_out)

    def worker_fn():
      self._maybe_setup_gpus()
      tf_config = json.loads(os.environ['TF_CONFIG'])
      worker_id = tf_config['task']['index']
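      # Each iteration uses one more dimension ([2], [2, 2], [2, 2, 2], ...),
      # exercising variable initialization with a changing shape every step.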
      for i in range(20):
        worker_step_fn(worker_id, num_dims=(i + 1))

    with test_util.skip_if_error(self, errors_impl.UnavailableError):
      multi_process_runner.run(
          worker_fn,
          cluster_spec=test_base.create_cluster_spec(num_workers=NUM_WORKERS))


if __name__ == '__main__':
  multi_process_runner.test_main()