# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for all_reduce."""

import time

import numpy as np

from tensorflow.core.framework import types_pb2
from tensorflow.python.distribute.v1 import all_reduce as ar
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


class AllReduceTest(test_util.TensorFlowTestCase):
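  """Tests for the ring, shuffle, and recursive halving-doubling builders."""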

  @test_util.run_deprecated_v1
  def testFlattenTensorsShapesDefined(self):
    x = array_ops.placeholder(types_pb2.DT_FLOAT, [None])
    with self.assertRaisesRegex(ValueError, "must have statically known shape"):
      ar._flatten_tensors([x, x])

  def testRingPermutations(self):
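    """Spot-checks ar._ring_permutations on small configurations.

    Per the function's contract, pred_by_c_d[c][d] is the global index of
    the device preceding device d in the ring used for subchunk c, and
    rank_by_c_d[c][d] is device d's rank within that ring.
    """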
    # 0 devices
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 0, [])
    self.assertEqual(pred_by_c_d, [])
    self.assertEqual(rank_by_c_d, [])
    # 1 worker, 1 subchunk cases
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0])
    self.assertEqual(pred_by_c_d, [[0]])
    self.assertEqual(rank_by_c_d, [[0]])
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0, 1, 2])
    self.assertEqual(pred_by_c_d, [[2, 0, 1]])
    self.assertEqual(rank_by_c_d, [[0, 1, 2]])
    # multiple workers, 1 subchunk cases
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 1, [0, 1, 2])
    self.assertEqual(pred_by_c_d, [[5, 0, 1, 2, 3, 4]])
    self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5]])
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(3, 1, [0, 1, 2])
    self.assertEqual(pred_by_c_d, [[8, 0, 1, 2, 3, 4, 5, 6, 7]])
    self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5, 6, 7, 8]])
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 1, [2, 1, 0])
    self.assertEqual(pred_by_c_d, [[1, 2, 3, 4, 5, 0]])
    self.assertEqual(rank_by_c_d, [[2, 1, 0, 5, 4, 3]])
    # 1 worker, multiple subchunk cases
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 2, [0, 1, 2, 3])
    self.assertEqual(pred_by_c_d, [[3, 0, 1, 2], [3, 0, 1, 2]])
    self.assertEqual(rank_by_c_d, [[0, 1, 2, 3], [2, 3, 0, 1]])
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 4, [0, 1, 2, 3])
    self.assertEqual(pred_by_c_d, [[3, 0, 1, 2], [3, 0, 1, 2],
                                   [3, 0, 1, 2], [3, 0, 1, 2]])
    self.assertEqual(rank_by_c_d, [[0, 1, 2, 3], [3, 0, 1, 2],
                                   [2, 3, 0, 1], [1, 2, 3, 0]])
    # multiple workers, multiple subchunk cases
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 2, [0, 1, 2, 3])
    self.assertEqual(pred_by_c_d, [[7, 0, 1, 2, 3, 4, 5, 6],
                                   [3, 0, 5, 2, 7, 4, 1, 6]])
    self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5, 6, 7],
                                   [2, 3, 0, 1, 6, 7, 4, 5]])
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 2, [0, 3, 2, 1])
    self.assertEqual(pred_by_c_d, [[5, 2, 3, 0, 1, 6, 7, 4],
                                   [1, 2, 7, 0, 5, 6, 3, 4]])
    self.assertEqual(rank_by_c_d, [[0, 3, 2, 1, 4, 7, 6, 5],
                                   [2, 1, 0, 3, 6, 5, 4, 7]])

  def _buildInput(self, num_workers, num_gpus):
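    """Places an identity of the same 16-element constant on every device.

    Returns the per-device tensors and the matching device names for
    num_workers * num_gpus GPU devices.
    """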
    t8 = constant_op.constant(
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        types_pb2.DT_FLOAT)
    input_tensors = []
    device_names = []
    for w in range(num_workers):
      for d in range(num_gpus):
        dn = "/replica:0/task:%d/device:GPU:%d" % (w, d)
        device_names.append(dn)
        with ops.device(dn):
          input_tensors.append(array_ops.identity(t8))
    return input_tensors, device_names

  @test_util.run_deprecated_v1
  def testBuildRingGatherPassStructure(self):
    # 1 worker, 1 device: the gather pass returns its inputs unchanged.
    input_tensors, device_names = self._buildInput(1, 1)
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0])
    output_tensors = ar._build_ring_gather(input_tensors, device_names, 1,
                                           pred_by_c_d, rank_by_c_d,
                                           math_ops.add)
    self.assertEqual(output_tensors, input_tensors)
    # 1 worker, 4 devices, 2 subchunks
    input_tensors, device_names = self._buildInput(1, 4)
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 2, [0, 1, 2, 3])
    output_tensors, pad_len = ar._build_ring_gather(
        input_tensors, device_names, 2, pred_by_c_d, rank_by_c_d, math_ops.add)
    self.assertEqual(0, pad_len)
    # same number of outputs as inputs
    self.assertEqual(len(output_tensors), len(input_tensors))
    # each output is a list of num_subchunks * num_devices equal-sized chunks
    num_chunks = 2 * len(input_tensors)
    tlen = tensor_shape.dimension_value(input_tensors[0].shape[0])
    for otl in output_tensors:
      self.assertEqual(len(otl), num_chunks)
      for ot in otl:
        self.assertEqual(ot.shape, [tlen // num_chunks])

  def _buildInitialVars(self, shape, dev_list):
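    """Creates one float variable per device, each holding dim elements.

    Element i on device d is initialized to i + 0.01 * d, so every device
    contributes a distinguishable value to the reduction.
    """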
    values = []
    num_devices = len(dev_list)
    dim = np.prod(shape, dtype=int) if shape else 1
    for d in range(num_devices):
      with ops.device(dev_list[d]):
        npt = np.zeros(shape).astype(np.float32)
        alias = np.frombuffer(npt.data, dtype=np.float32)
        for i in range(dim):
          alias[i] = i + 0.01 * d
        var = state_ops.variable_op(shape, types_pb2.DT_FLOAT)
        state_ops.init_variable(var, npt).op.run()
        values.append(var)
    return values

  # pylint: disable=g-long-lambda

  def _buildRing(self, num_workers, num_gpus, subdiv):
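    """Returns a build function wrapping ar.build_ring_all_reduce."""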
    gpu_perm = range(num_gpus)
    return lambda x, un_op: ar.build_ring_all_reduce(
        x, num_workers, subdiv, gpu_perm, math_ops.add, un_op)

  def _testAllReduce(self, num_workers, num_gpus, shape, build_f):
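    """Checks an all-reduce graph built by build_f against a plain add_n.

    The unary op divides each result by the device count, so every
    per-device output is the mean of the inputs; summing those outputs
    across all num_devices devices therefore reproduces the simple sum.
    """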
    # Use local CPU as device for all inputs.
    num_devices = num_workers * num_gpus
    dev_list = ["/replica:0/task:0/device:CPU:0"
                for _ in range(num_devices)]
    with self.cached_session():
      input_tensors = self._buildInitialVars(shape, dev_list)
      un_op = lambda x: math_ops.div(
          x, constant_op.constant(num_devices, dtype=types_pb2.DT_FLOAT))
      simple_sum = math_ops.add_n(input_tensors)
      simple_sum.op.run()
      output_tensors = build_f(input_tensors, un_op)
      sum_reduced = math_ops.add_n(output_tensors)
      sum_reduced.op.run()
      self.assertAllClose(sum_reduced, self.evaluate(simple_sum))

  def _testRingAllReduce(self, num_workers, num_gpus, shape, subdiv):
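    """Builds, runs, and times one ring all-reduce configuration."""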
    start_time = time.time()
    build_f = self._buildRing(num_workers, num_gpus, subdiv)
    self._testAllReduce(num_workers, num_gpus, shape, build_f)
    elapsed = time.time() - start_time
    tf_logging.info("RingAllReduce num_workers=%d num_gpus=%d shape=%s "
                    "subdiv=%d elapsed=%f" %
                    (num_workers, num_gpus, shape, subdiv, elapsed))

  @test_util.run_deprecated_v1
  def testRingAllReduce(self):
    self._testRingAllReduce(1, 2, [], 1)
    self._testRingAllReduce(1, 2, [8], 1)
    self._testRingAllReduce(1, 2, [4, 4], 1)
    self._testRingAllReduce(6, 1, [8], 1)
    self._testRingAllReduce(1, 8, [32], 1)
    self._testRingAllReduce(1, 8, [120], 1)
    self._testRingAllReduce(2, 8, [7, 13], 1)
    self._testRingAllReduce(2, 8, [8, 8], 2)
    self._testRingAllReduce(2, 8, [8, 8], 4)
    # TODO(tucker): The following test is surprisingly slow.
    # Diagnose and fix before re-enabling.
    # self._testRingAllReduce(4, 8, [8, 8, 2], 4)

  def _buildShuffle(self, num_workers, num_gpus, num_shards):
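    """Returns a build function wrapping ar.build_shuffle_all_reduce."""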
    # Use local CPU for all shuffle shards.
    gather_devices = ["/replica:0/task:0/device:CPU:0"
                      for _ in range(num_shards)]
    return lambda x, un_op: ar.build_shuffle_all_reduce(
        x, gather_devices, math_ops.add_n, un_op)

  def _testShuffleAllReduce(self, num_workers, num_gpus, shape, num_shards):
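    """Builds, runs, and times one shuffle all-reduce configuration."""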
    start_time = time.time()
    build_f = self._buildShuffle(num_workers, num_gpus, num_shards)
    self._testAllReduce(num_workers, num_gpus, shape, build_f)
    elapsed = time.time() - start_time
    tf_logging.info("ShuffleAllReduce num_workers=%d num_gpus=%d shape=%s "
                    "num_shards=%d elapsed=%f" %
                    (num_workers, num_gpus, shape, num_shards, elapsed))

  @test_util.run_deprecated_v1
  def testShuffleAllReduce(self):
    self._testShuffleAllReduce(1, 2, [], 1)
    self._testShuffleAllReduce(1, 2, [8], 1)
    self._testShuffleAllReduce(1, 2, [4, 4], 1)
    self._testShuffleAllReduce(1, 8, [32], 1)
    self._testShuffleAllReduce(1, 8, [120], 1)
    self._testShuffleAllReduce(2, 8, [7, 13], 3)
    self._testShuffleAllReduce(2, 8, [8, 8], 2)
    self._testShuffleAllReduce(2, 8, [8, 8], 4)
    self._testShuffleAllReduce(4, 8, [8, 8, 2], 4)

  def _buildRecursiveHD(self, num_workers, num_gpus):
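    """Returns a build function wrapping ar.build_recursive_hd_all_reduce.

    num_workers and num_gpus are unused here; they are kept so all three
    builders share the same signature.
    """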
    return lambda x, un_op: ar.build_recursive_hd_all_reduce(
        x, math_ops.add, un_op)

  # pylint: enable=g-long-lambda

  def _testRecursiveHDAllReduce(self, num_workers, num_gpus, shape):
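    """Builds, runs, and times one recursive halving-doubling all-reduce."""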
    start_time = time.time()
    build_f = self._buildRecursiveHD(num_workers, num_gpus)
    self._testAllReduce(num_workers, num_gpus, shape, build_f)
    elapsed = time.time() - start_time
    tf_logging.info("RecursiveHDAllReduce num_workers=%d num_gpus=%d "
                    "shape=%s elapsed=%f" %
                    (num_workers, num_gpus, shape, elapsed))

  @test_util.run_deprecated_v1
  def testRecursiveHDAllReduce(self):
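    # Every case below uses a power-of-two device count (2, 8, 16, or 32),
    # which the recursive halving/doubling exchange pattern requires.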
    self._testRecursiveHDAllReduce(1, 2, [8])
    self._testRecursiveHDAllReduce(1, 2, [4, 4])
    self._testRecursiveHDAllReduce(1, 8, [32])
    self._testRecursiveHDAllReduce(1, 8, [120])
    self._testRecursiveHDAllReduce(2, 8, [8, 8])
    self._testRecursiveHDAllReduce(4, 8, [8, 8, 2])


if __name__ == "__main__":
  test.main()
239