xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/cross_device_utils_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for cross_device_utils."""
16
17from absl.testing import parameterized
18
19from tensorflow.python.distribute import combinations
20from tensorflow.python.distribute import cross_device_utils
21from tensorflow.python.distribute import device_util
22from tensorflow.python.eager import test
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import indexed_slices
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import math_ops
30
31
class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
  """Tests for the Tensor / IndexedSlices helpers in cross_device_utils."""

  def _assert_values_equal(self, left, right):
    """Asserts that `left` and `right` evaluate to equal values.

    Both arguments go through `ops.convert_to_tensor` before evaluation, so an
    `IndexedSlices` can be compared directly against a dense expected tensor.

    Args:
      left: a Tensor, IndexedSlices, or value convertible to a Tensor.
      right: a Tensor, IndexedSlices, or value convertible to a Tensor.
    """
    self.assertAllEqual(
        self.evaluate(ops.convert_to_tensor(left)),
        self.evaluate(ops.convert_to_tensor(right)))

  @test_util.run_in_graph_and_eager_modes
  def testAggregateTensors(self):
    t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
    t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]])
    total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
    result = cross_device_utils.aggregate_tensors_or_indexed_slices([t0, t1])
    self._assert_values_equal(total, result)

  @test_util.run_in_graph_and_eager_modes
  def testAggregateIndexedSlices(self):
    t0 = math_ops._as_indexed_slices(
        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    t1 = math_ops._as_indexed_slices(
        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
    total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
    result = cross_device_utils.aggregate_tensors_or_indexed_slices([t0, t1])
    # Aggregating IndexedSlices inputs must keep the sparse representation.
    self.assertIsInstance(result, indexed_slices.IndexedSlices)
    self._assert_values_equal(total, result)

  @test_util.run_in_graph_and_eager_modes
  def testDivideTensor(self):
    t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
    n = 2
    expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
    result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(t, n)
    self._assert_values_equal(expected, result)

  @test_util.run_in_graph_and_eager_modes
  def testDivideIndexedSlices(self):
    t = math_ops._as_indexed_slices(
        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    n = 2
    expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
    result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(t, n)
    # Division must not densify an IndexedSlices input.
    self.assertIsInstance(result, indexed_slices.IndexedSlices)
    self._assert_values_equal(expected, result)

  @test_util.run_in_graph_and_eager_modes
  def testIsIndexedSlices(self):
    t = math_ops._as_indexed_slices(
        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    self.assertTrue(cross_device_utils.is_indexed_slices(t))
    # Negative case: a plain dense tensor must not be classified as
    # IndexedSlices, otherwise the positive check above proves nothing.
    dense = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
    self.assertFalse(cross_device_utils.is_indexed_slices(dense))

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      required_gpus=1))
  def testCopyTensor(self):
    with ops.device("/cpu:0"):
      t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
    destination = "/gpu:0"
    result = cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        t, destination)

    self._assert_values_equal(t, result)
    self.assertEqual(device_util.resolve(destination),
                     device_util.resolve(result.device))

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      required_gpus=1))
  def testCopyIndexedSlices(self):
    with ops.device("/cpu:0"):
      t = math_ops._as_indexed_slices(
          constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
    destination = "/gpu:0"
    result = cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        t, destination)

    self.assertIsInstance(result, indexed_slices.IndexedSlices)
    self._assert_values_equal(t, result)
    self.assertEqual(
        device_util.resolve(destination), device_util.resolve(result.device))

  @combinations.generate(
      combinations.combine(mode=["graph", "eager"], required_gpus=1))
  def testCopyIndexedSlicesNoDenseShape(self):
    with ops.device("/cpu:0"):
      # No dense_shape is passed, so t.dense_shape is None.
      t = indexed_slices.IndexedSlices(
          indices=array_ops.identity([0]), values=array_ops.identity([1.]))
    destination = "/gpu:0"
    result = cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        t, destination)

    self.assertIsInstance(result, indexed_slices.IndexedSlices)
    # Without a dense_shape the slices cannot be densified, so compare the
    # components individually instead of using _assert_values_equal.
    self.assertAllEqual(t.indices, result.indices)
    self.assertAllEqual(t.values, result.values)
    # The copy must preserve the absence of a dense shape.
    self.assertIsNone(result.dense_shape)
    self.assertEqual(
        device_util.resolve(destination), device_util.resolve(result.device))
127
128
class GroupBySizeTest(test.TestCase):
  """Tests for `cross_device_utils.group_by_size`."""

  def testPreferLargerPack(self):
    """Values accumulate into a pack until it reaches `bytes_per_pack`."""
    # Every pack except the last one should be equal to or larger than
    # bytes_per_pack.
    values = [
        # size = 2 * 4 * 4 * 4 = 128
        array_ops.ones([2, 4, 4], dtype=dtypes.float32),
        # size = 8 * 4 = 32
        array_ops.ones([8], dtype=dtypes.int32),
        # size = 10 * 10 * 8 = 800
        array_ops.ones([10, 10], dtype=dtypes.int64),
        # size = 1 * 4 = 4
        array_ops.ones([1], dtype=dtypes.int32),
    ]
    packs = cross_device_utils.group_by_size(values, bytes_per_pack=200)
    self.assertLen(packs, 2)
    self.assertLen(packs[0], 3)
    self.assertEqual(packs[0][0].shape, [2, 4, 4])
    self.assertEqual(packs[0][1].shape, [8])
    self.assertEqual(packs[0][2].shape, [10, 10])
    self.assertLen(packs[1], 1)
    self.assertEqual(packs[1][0].shape, [1])

  def testZeroBytesPerPack(self):
    """With `bytes_per_pack=0`, all values end up in a single pack."""
    values = [
        array_ops.ones([1], dtype=dtypes.float32),
        array_ops.ones([2], dtype=dtypes.float32),
    ]
    packs = cross_device_utils.group_by_size(values, bytes_per_pack=0)
    self.assertLen(packs, 1)
    self.assertLen(packs[0], 2)
    self.assertEqual(packs[0][0].shape, [1])
    self.assertEqual(packs[0][1].shape, [2])

  def testUnknownShape(self):
    """A value with an unknown dimension puts all values into one pack."""

    def create_placeholder(shape, dtype):
      # Placeholders are graph-only, so build one inside a throwaway graph.
      # The test only relies on its partially unknown static shape.
      with ops.Graph().as_default():
        return array_ops.placeholder(dtype=dtype, shape=shape)

    values = [
        array_ops.ones([10, 10], dtype=dtypes.float32),
        create_placeholder([None, 10], dtype=dtypes.float32),
    ]
    # bytes_per_pack=1 would split these values if both shapes were known.
    packs = cross_device_utils.group_by_size(values, bytes_per_pack=1)
    self.assertLen(packs, 1)
    self.assertLen(packs[0], len(values))
    # Compare by identity. `==` on Tensors may be elementwise, so a list
    # equality check only works through Python's identity short-circuit.
    # On a mismatch it would raise a confusing truth-value error instead
    # of a clean assertion failure.
    for packed, original in zip(packs[0], values):
      self.assertIs(packed, original)
176
177
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's test entry point
  # (`tensorflow.python.eager.test.main`).
  test.main()
180