# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for all_reduce."""

import time

import numpy as np

from tensorflow.core.framework import types_pb2
from tensorflow.python.distribute.v1 import all_reduce as ar
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


class AllReduceTest(test_util.TensorFlowTestCase):
  """Tests for the graph-mode all-reduce builders in `all_reduce`."""

  @test_util.run_deprecated_v1
  def testFlattenTensorsShapesDefined(self):
    x = array_ops.placeholder(types_pb2.DT_FLOAT, [None])
    with self.assertRaisesRegex(ValueError, "must have statically known shape"):
      ar._flatten_tensors([x, x])

  def testRingPermutations(self):
    # 0 devices
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 0, [])
    self.assertEqual(pred_by_c_d, [])
    self.assertEqual(rank_by_c_d, [])
    # 1 worker, 1 subchunk cases
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0])
    self.assertEqual(pred_by_c_d, [[0]])
    self.assertEqual(rank_by_c_d, [[0]])
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0, 1, 2])
    self.assertEqual(pred_by_c_d, [[2, 0, 1]])
    self.assertEqual(rank_by_c_d, [[0, 1, 2]])
    # multiple workers, 1 subchunk cases
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 1, [0, 1, 2])
    self.assertEqual(pred_by_c_d, [[5, 0, 1, 2, 3, 4]])
    self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5]])
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(3, 1, [0, 1, 2])
    self.assertEqual(pred_by_c_d, [[8, 0, 1, 2, 3, 4, 5, 6, 7]])
    self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5, 6, 7, 8]])
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 1, [2, 1, 0])
    self.assertEqual(pred_by_c_d, [[1, 2, 3, 4, 5, 0]])
    self.assertEqual(rank_by_c_d, [[2, 1, 0, 5, 4, 3]])
    # 1 worker, multiple subchunk cases
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 2, [0, 1, 2, 3])
    self.assertEqual(pred_by_c_d, [[3, 0, 1, 2], [3, 0, 1, 2]])
    self.assertEqual(rank_by_c_d, [[0, 1, 2, 3], [2, 3, 0, 1]])
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 4, [0, 1, 2, 3])
    self.assertEqual(pred_by_c_d, [[3, 0, 1, 2], [3, 0, 1, 2],
                                   [3, 0, 1, 2], [3, 0, 1, 2]])
    self.assertEqual(rank_by_c_d, [[0, 1, 2, 3], [3, 0, 1, 2],
                                   [2, 3, 0, 1], [1, 2, 3, 0]])
    # multiple worker, multiple subchunk cases
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 2, [0, 1, 2, 3])
    self.assertEqual(pred_by_c_d, [[7, 0, 1, 2, 3, 4, 5, 6],
                                   [3, 0, 5, 2, 7, 4, 1, 6]])
    self.assertEqual(rank_by_c_d, [[0, 1, 2, 3, 4, 5, 6, 7],
                                   [2, 3, 0, 1, 6, 7, 4, 5]])
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(2, 2, [0, 3, 2, 1])
    self.assertEqual(pred_by_c_d, [[5, 2, 3, 0, 1, 6, 7, 4],
                                   [1, 2, 7, 0, 5, 6, 3, 4]])
    self.assertEqual(rank_by_c_d, [[0, 3, 2, 1, 4, 7, 6, 5],
                                   [2, 1, 0, 3, 6, 5, 4, 7]])

  def _buildInput(self, num_workers, num_gpus):
    """Places one identity copy of a 16-element float tensor on each GPU.

    Args:
      num_workers: number of tasks.
      num_gpus: number of GPUs per task.

    Returns:
      A pair (input_tensors, device_names) of equal-length lists, ordered
      worker-major: every GPU of task 0 first, then task 1, and so on.
    """
    t8 = constant_op.constant(
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        types_pb2.DT_FLOAT)
    input_tensors = []
    device_names = []
    for w in range(num_workers):
      for d in range(num_gpus):
        dn = "/replica:0/task:%d/device:GPU:%d" % (w, d)
        device_names.append(dn)
        with ops.device(dn):
          input_tensors.append(array_ops.identity(t8))
    return input_tensors, device_names

  @test_util.run_deprecated_v1
  def testBuildRingGatherPassStructure(self):
    # 1 worker, 1 device
    input_tensors, device_names = self._buildInput(1, 1)
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 1, [0])
    output_tensors = ar._build_ring_gather(input_tensors, device_names, 1,
                                           pred_by_c_d, rank_by_c_d,
                                           math_ops.add)
    self.assertEqual(output_tensors, input_tensors)
    # 1 worker, 4 devices, 2 subchunks
    input_tensors, device_names = self._buildInput(1, 4)
    pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 2, [0, 1, 2, 3])
    output_tensors, pad_len = ar._build_ring_gather(
        input_tensors, device_names, 2, pred_by_c_d, rank_by_c_d, math_ops.add)
    self.assertEqual(0, pad_len)
    # same number outputs as inputs
    self.assertEqual(len(output_tensors), len(input_tensors))
    num_chunks = 2 * len(input_tensors)
    tlen = tensor_shape.dimension_value(input_tensors[0].shape[0])
    for otl in output_tensors:
      self.assertEqual(len(otl), num_chunks)
      for ot in otl:
        self.assertEqual(ot.shape, [tlen//num_chunks])

  def _buildInitialVars(self, shape, dev_list):
    """Creates and initializes one float32 variable of `shape` per device.

    Element i (in row-major order) of the variable placed on `dev_list[d]`
    is set to i + 0.01 * d, so each device holds distinct values.

    Args:
      shape: static shape of every variable; `[]` for a scalar.
      dev_list: device names; one variable is created on each.

    Returns:
      The list of initialized variables, in `dev_list` order.
    """
    # Row-major element indices. np.prod of an empty shape is 1, which is
    # exactly one element for the scalar case.
    flat_index = np.arange(np.prod(shape, dtype=int))
    values = []
    for d, dev in enumerate(dev_list):
      with ops.device(dev):
        # Add in float64 and only then round to float32, matching an
        # element-wise assignment into a float32 buffer. Reshaping the 1-D
        # result keeps it an ndarray (0-d for shape []).
        npt = (flat_index + 0.01 * d).astype(np.float32).reshape(shape)
        var = state_ops.variable_op(shape, types_pb2.DT_FLOAT)
        state_ops.init_variable(var, npt).op.run()
        values.append(var)
    return values

  # pylint: disable=g-long-lambda

  def _buildRing(self, num_workers, num_gpus, subdiv):
    """Returns a builder `f(inputs, un_op)` for a ring all-reduce.

    The ring uses the identity GPU permutation and `subdiv` subchunks.
    """
    gpu_perm = list(range(num_gpus))
    return lambda x, un_op: ar.build_ring_all_reduce(
        x, num_workers, subdiv, gpu_perm, math_ops.add, un_op)

  def _testAllReduce(self, num_workers, num_gpus, shape, build_f):
    """Checks that the outputs of `build_f` sum to the plain input sum.

    `un_op` divides by the device count, so each output should be the mean
    of the inputs and the sum over all outputs should equal the input sum.

    Args:
      num_workers: number of simulated tasks.
      num_gpus: number of simulated devices per task.
      shape: shape of each input variable.
      build_f: callable `(input_tensors, un_op) -> output_tensors`.
    """
    num_devices = num_workers * num_gpus
    # Use local CPU as device for all inputs.
    dev_list = ["/replica:0/task:0/device:CPU:0"] * num_devices
    with self.cached_session():
      input_tensors = self._buildInitialVars(shape, dev_list)
      un_op = lambda x: math_ops.div(
          x, constant_op.constant(num_devices, dtype=types_pb2.DT_FLOAT))
      simple_sum = math_ops.add_n(input_tensors)
      output_tensors = build_f(input_tensors, un_op)
      sum_reduced = math_ops.add_n(output_tensors)
      self.assertAllClose(self.evaluate(sum_reduced),
                          self.evaluate(simple_sum))

  def _testRingAllReduce(self, num_workers, num_gpus, shape, subdiv):
    """Runs `_testAllReduce` with a ring builder and logs the elapsed time."""
    start_time = time.perf_counter()
    build_f = self._buildRing(num_workers, num_gpus, subdiv)
    self._testAllReduce(num_workers, num_gpus, shape, build_f)
    elapsed = time.perf_counter() - start_time
    tf_logging.info("RingAllReduce num_workers=%d num_gpus=%d shape=%s "
                    "subdiv=%d elapsed=%f",
                    num_workers, num_gpus, shape, subdiv, elapsed)

  @test_util.run_deprecated_v1
  def testRingAllReduce(self):
    self._testRingAllReduce(1, 2, [], 1)
    self._testRingAllReduce(1, 2, [8], 1)
    self._testRingAllReduce(1, 2, [4, 4], 1)
    self._testRingAllReduce(6, 1, [8], 1)
    self._testRingAllReduce(1, 8, [32], 1)
    self._testRingAllReduce(1, 8, [120], 1)
    self._testRingAllReduce(2, 8, [7, 13], 1)
    self._testRingAllReduce(2, 8, [8, 8], 2)
    self._testRingAllReduce(2, 8, [8, 8], 4)
    # TODO(tucker): The following test is surprisingly slow.
    # Diagnose and fix before re-enabling.
    # self._testRingAllReduce(4, 8, [8, 8, 2], 4)

  def _buildShuffle(self, num_workers, num_gpus, num_shards):
    """Returns a builder `f(inputs, un_op)` for a shuffle all-reduce.

    `num_shards` gather shards are all placed on the local CPU;
    `num_workers` and `num_gpus` are accepted for signature symmetry only.
    """
    # Use local CPU for all shuffle shards
    gather_devices = ["/replica:0/task:0/device:CPU:0"] * num_shards
    return lambda x, un_op: ar.build_shuffle_all_reduce(
        x, gather_devices, math_ops.add_n, un_op)

  def _testShuffleAllReduce(self, num_workers, num_gpus, shape, num_shards):
    """Runs `_testAllReduce` with a shuffle builder and logs elapsed time."""
    start_time = time.perf_counter()
    build_f = self._buildShuffle(num_workers, num_gpus, num_shards)
    self._testAllReduce(num_workers, num_gpus, shape, build_f)
    elapsed = time.perf_counter() - start_time
    tf_logging.info("ShuffleAllReduce num_workers=%d num_gpus=%d shape=%s "
                    "elapsed=%f", num_workers, num_gpus, shape, elapsed)

  @test_util.run_deprecated_v1
  def testShuffleAllReduce(self):
    self._testShuffleAllReduce(1, 2, [], 1)
    self._testShuffleAllReduce(1, 2, [8], 1)
    self._testShuffleAllReduce(1, 2, [4, 4], 1)
    self._testShuffleAllReduce(1, 8, [32], 1)
    self._testShuffleAllReduce(1, 8, [120], 1)
    self._testShuffleAllReduce(2, 8, [7, 13], 3)
    self._testShuffleAllReduce(2, 8, [8, 8], 2)
    self._testShuffleAllReduce(2, 8, [8, 8], 4)
    self._testShuffleAllReduce(4, 8, [8, 8, 2], 4)

  def _buildRecursiveHD(self, num_workers, num_gpus):
    """Returns a builder `f(inputs, un_op)` for recursive halving-doubling.

    `num_workers` and `num_gpus` are accepted for signature symmetry only.
    """
    return lambda x, un_op: ar.build_recursive_hd_all_reduce(
        x, math_ops.add, un_op)

  # pylint: enable=g-long-lambda

  def _testRecursiveHDAllReduce(self, num_workers, num_gpus, shape):
    """Runs `_testAllReduce` with a recursive-HD builder and logs time."""
    start_time = time.perf_counter()
    build_f = self._buildRecursiveHD(num_workers, num_gpus)
    self._testAllReduce(num_workers, num_gpus, shape, build_f)
    elapsed = time.perf_counter() - start_time
    tf_logging.info("RecursiveHDAllReduce num_workers=%d num_gpus=%d "
                    "shape=%s elapsed=%f",
                    num_workers, num_gpus, shape, elapsed)

  @test_util.run_deprecated_v1
  def testRecursiveHDAllReduce(self):
    self._testRecursiveHDAllReduce(1, 2, [8])
    self._testRecursiveHDAllReduce(1, 2, [4, 4])
    self._testRecursiveHDAllReduce(1, 8, [32])
    self._testRecursiveHDAllReduce(1, 8, [120])
    self._testRecursiveHDAllReduce(2, 8, [8, 8])
    self._testRecursiveHDAllReduce(4, 8, [8, 8, 2])


if __name__ == "__main__":
  test.main()