# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cross_device_utils."""

from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
  """Exercises the tensor / IndexedSlices helpers of cross_device_utils."""

  def _assert_values_equal(self, left, right):
    """Asserts that `left` and `right` evaluate to equal values.

    Both arguments are passed through `ops.convert_to_tensor` and then
    `self.evaluate` before being compared with `assertAllEqual`.

    Args:
      left: first value to compare.
      right: second value to compare.
    """
    left_value = self.evaluate(ops.convert_to_tensor(left))
    right_value = self.evaluate(ops.convert_to_tensor(right))
    self.assertAllEqual(left_value, right_value)

  @test_util.run_in_graph_and_eager_modes
  def testAggregateTensors(self):
    # Aggregating two dense tensors should match their element-wise sum.
    first = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
    second = constant_op.constant([[0., 0.], [5, 6], [7., 8.]])
    expected_sum = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
    aggregated = cross_device_utils.aggregate_tensors_or_indexed_slices(
        [first, second])
    self._assert_values_equal(expected_sum, aggregated)

  @test_util.run_in_graph_and_eager_modes
  def testAggregateIndexedSlices(self):
    # Same inputs as the dense case, but wrapped as IndexedSlices; the
    # aggregate is expected to remain an IndexedSlices.
    first = math_ops._as_indexed_slices(
        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    second = math_ops._as_indexed_slices(
        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
    expected_sum = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
    aggregated = cross_device_utils.aggregate_tensors_or_indexed_slices(
        [first, second])
    self.assertIsInstance(aggregated, indexed_slices.IndexedSlices)
    self._assert_values_equal(expected_sum, aggregated)

  @test_util.run_in_graph_and_eager_modes
  def testDivideTensor(self):
    # Every element of a dense tensor is divided by the divisor.
    dividend = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
    divisor = 2
    expected_quotient = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
    quotient = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
        dividend, divisor)
    self._assert_values_equal(expected_quotient, quotient)

  @test_util.run_in_graph_and_eager_modes
  def testDivideIndexedSlices(self):
    # Dividing an IndexedSlices keeps the IndexedSlices type.
    dividend = math_ops._as_indexed_slices(
        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    divisor = 2
    expected_quotient = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
    quotient = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
        dividend, divisor)
    self.assertIsInstance(quotient, indexed_slices.IndexedSlices)
    self._assert_values_equal(expected_quotient, quotient)

  @test_util.run_in_graph_and_eager_modes
  def testIsIndexedSlices(self):
    slices = math_ops._as_indexed_slices(
        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    self.assertTrue(cross_device_utils.is_indexed_slices(slices))

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      required_gpus=1))
  def testCopyTensor(self):
    # Build the source tensor on the CPU, then copy it to the GPU.
    with ops.device("/cpu:0"):
      source = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
    destination = "/gpu:0"
    copied = cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        source, destination)

    self._assert_values_equal(source, copied)
    expected_device = device_util.resolve(destination)
    actual_device = device_util.resolve(copied.device)
    self.assertEqual(expected_device, actual_device)

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      required_gpus=1))
  def testCopyIndexedSlices(self):
    # Build the source IndexedSlices on the CPU, then copy it to the GPU.
    with ops.device("/cpu:0"):
      source = math_ops._as_indexed_slices(
          constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    destination = "/gpu:0"
    copied = cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        source, destination)

    self.assertIsInstance(copied, indexed_slices.IndexedSlices)
    self._assert_values_equal(source, copied)
    expected_device = device_util.resolve(destination)
    actual_device = device_util.resolve(copied.device)
    self.assertEqual(expected_device, actual_device)

  @combinations.generate(
      combinations.combine(mode=["graph", "eager"], required_gpus=1))
  def testCopyIndexedSlicesNoDenseShape(self):
    # The IndexedSlices is constructed from indices and values only; no
    # dense_shape argument is supplied.
    with ops.device("/cpu:0"):
      source = indexed_slices.IndexedSlices(
          indices=array_ops.identity([0]), values=array_ops.identity([1.]))
    destination = "/gpu:0"
    copied = cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        source, destination)

    self.assertIsInstance(copied, indexed_slices.IndexedSlices)
    self.assertAllEqual(source.indices, copied.indices)
    self.assertAllEqual(source.values, copied.values)
    expected_device = device_util.resolve(destination)
    actual_device = device_util.resolve(copied.device)
    self.assertEqual(expected_device, actual_device)


class GroupBySizeTest(test.TestCase):
  """Exercises cross_device_utils.group_by_size."""

  def testPreferLargerPack(self):
    # Each pack except the last one should be equal to or larger than
    # bytes_per_pack.
    tensors = [
        # size = 2 * 4 * 4 * 4 = 128
        array_ops.ones([2, 4, 4], dtype=dtypes.float32),
        # size = 8 * 4 = 32
        array_ops.ones([8], dtype=dtypes.int32),
        # size = 10 * 10 * 8 = 800
        array_ops.ones([10, 10], dtype=dtypes.int64),
        # size = 1 * 4 = 4
        array_ops.ones([1], dtype=dtypes.int32),
    ]
    packs = cross_device_utils.group_by_size(tensors, bytes_per_pack=200)
    self.assertLen(packs, 2)

    head_pack, tail_pack = packs[0], packs[1]
    self.assertLen(head_pack, 3)
    for packed, wanted_shape in zip(head_pack, ([2, 4, 4], [8], [10, 10])):
      self.assertEqual(packed.shape, wanted_shape)
    self.assertLen(tail_pack, 1)
    self.assertEqual(tail_pack[0].shape, [1])

  def testZeroBytesPerPack(self):
    # With bytes_per_pack=0 both inputs are expected in one single pack.
    tensors = [
        array_ops.ones([1], dtype=dtypes.float32),
        array_ops.ones([2], dtype=dtypes.float32),
    ]
    packs = cross_device_utils.group_by_size(tensors, bytes_per_pack=0)
    self.assertLen(packs, 1)

    only_pack = packs[0]
    self.assertLen(only_pack, 2)
    for packed, wanted_shape in zip(only_pack, ([1], [2])):
      self.assertEqual(packed.shape, wanted_shape)

  def testUnknownShape(self):
    # One input has an unknown leading dimension; all inputs are expected to
    # come back together as a single pack.
    def create_placeholder(shape, dtype):
      # The placeholder is created inside a fresh graph.
      with ops.Graph().as_default():
        return array_ops.placeholder(dtype=dtype, shape=shape)

    tensors = [
        array_ops.ones([10, 10], dtype=dtypes.float32),
        create_placeholder([None, 10], dtype=dtypes.float32),
    ]
    packs = cross_device_utils.group_by_size(tensors, bytes_per_pack=1)
    self.assertLen(packs, 1)
    self.assertEqual(packs[0], tensors)


if __name__ == "__main__":
  test.main()