# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Tests for tpu and tpu_ops helpers."""

from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import training_loop
from tensorflow.python.tpu.ops import tpu_ops


class TPUContextTest(test.TestCase):

  def testIsInContext(self):
    """Test that control_flow_util can check that we're in a TPU context."""
    with ops.Graph().as_default():
      z1 = array_ops.identity(1)
      pivot = control_flow_ops.no_op()
      context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
      context.Enter()
      z2 = array_ops.identity(1)
      context.Exit()
      self.assertFalse(control_flow_util.IsInXLAContext(z1.op))
      self.assertTrue(control_flow_util.IsInXLAContext(z2.op))

  def testHandlesNameCollision(self):
    """Test AddValue handles name collisions for ops from different graphs."""
    with ops.Graph().as_default():
      z = array_ops.zeros([2, 3], name="a")
      assert z.name == "a:0", "Expected: a:0, Found: %s" % z.name

      @def_function.function
      def f():
        pivot = control_flow_ops.no_op()
        context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
        context.Enter()
        array_ops.identity(z)  # Capture z.
        z1 = array_ops.zeros([3, 2], name="a")
        assert z1.name == "a:0", "Expected: a:0, Found: %s" % z1.name
        z2 = array_ops.zeros([3, 2], name="a")
        # Prior to fixing b/166794533 this would fail with a shape mismatch
        # because context.AddValue would have cached `z` by its name which
        # collides with z1's name.
        result = z1 + z2
        context.Exit()
        return result

      f.get_concrete_function()


class TPULayerRewriteTest(test.TestCase):

  def testUsingInfeedQueueWithRegularizer(self):
    """Test that Layer regularizers can reference data created in loops."""

    with ops.Graph().as_default():

      def make_regularizer(scale):
        def regularizer(inputs):
          return scale * math_ops.reduce_sum(math_ops.square(inputs))
        return regularizer

      def training_step(inputs, scale):
        outputs = convolutional.conv2d(
            inputs,
            filters=16,
            kernel_size=(3, 3),
            data_format="channels_first",
            kernel_regularizer=make_regularizer(scale))
        loss = math_ops.reduce_mean(math_ops.square(outputs))
        return loss.op

      inputs = array_ops.zeros(shape=(128, 32, 32, 16))
      scale = array_ops.ones(shape=())
      infeed = tpu_feed.InfeedQueue(
          tuple_types=[dtypes.float32, dtypes.float32],
          tuple_shapes=[inputs.shape, scale.shape])

      def loop():
        return training_loop.repeat(5, training_step, infeed_queue=infeed)

      # This should not throw an error.
      tpu.rewrite(loop)

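
# An illustrative sketch (not one of the original tests): tpu.rewrite can also
# be driven with explicit `inputs` instead of an infeed queue. Only graph
# construction happens here, so no TPU devices are required.
def rewrite_with_explicit_inputs_sketch():
  def computation(x):
    return math_ops.reduce_sum(math_ops.square(x))

  with ops.Graph().as_default():
    return tpu.rewrite(computation, inputs=[array_ops.ones(shape=(2, 2))])
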

class TPUGraphPruneTest(test.TestCase):

  def test_prune_unconnected_ops(self):
    """Test that pruning removes _tpu_replicate from unconnected ops."""
    with ops.Graph().as_default():
      a = array_ops.placeholder(dtype=dtypes.float32, name="a")
      b = array_ops.placeholder(dtype=dtypes.float32, name="b")
      constant_op.constant(1.0, name="constant")
      x = variable_scope.get_variable(
          name="x",
          dtype=dtypes.float32,
          shape=[],
          use_resource=True,
          initializer=init_ops.constant_initializer(2.0))
      y = variable_scope.get_variable(
          name="y",
          dtype=dtypes.float32,
          shape=[],
          use_resource=True,
          initializer=init_ops.constant_initializer(3.0))
      math_ops.add(a, b)
      math_ops.add(x, y)
      graph_def = ops.get_default_graph().as_graph_def()

      for node in graph_def.node:
        # Attach a TPU_REPLICATE_ATTR to each node.
        node.attr[tpu._TPU_REPLICATE_ATTR].s = b"0"
        # Rewire placeholder "b" and variable "y" leaving them unconnected.
        for (input_index, node_input) in enumerate(node.input):
          if node_input == "b":
            node.input[input_index] = "constant"
          if node_input == "y":
            node.input[input_index] = "x"

    with ops.Graph().as_default() as graph:
      # Reimport the graph and prune unconnected ops.
      importer.import_graph_def(graph_def)
      tpu.prune_unconnected_ops_from_xla(ops.get_default_graph())

      # Verify that ops "a" and "x" still have TPU_REPLICATE_ATTR.
      a = graph.get_operation_by_name("import/a").get_attr(
          tpu._TPU_REPLICATE_ATTR)
      self.assertEqual(b"0", a)
      x = graph.get_operation_by_name("import/x").get_attr(
          tpu._TPU_REPLICATE_ATTR)
      self.assertEqual(b"0", x)
      # Verify that ops "b" and "y" have TPU_REPLICATE_ATTR removed.
      with self.assertRaisesRegex(
          ValueError,
          "Operation \'import/b\' has no attr named \'_tpu_replicate\'"):
        graph.get_operation_by_name("import/b").get_attr(
            tpu._TPU_REPLICATE_ATTR)
      with self.assertRaisesRegex(
          ValueError,
          "Operation \'import/y\' has no attr named \'_tpu_replicate\'"):
        graph.get_operation_by_name("import/y").get_attr(
            tpu._TPU_REPLICATE_ATTR)

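
# An illustrative helper (an addition, not used by the test above) that makes
# the attribute check explicit: Operation.get_attr raises ValueError once the
# _tpu_replicate attribute has been pruned from an op.
def op_has_tpu_replicate_attr(op):
  try:
    op.get_attr(tpu._TPU_REPLICATE_ATTR)
    return True
  except ValueError:
    return False
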

class TPUOpsTest(test.TestCase):

  def test_all_to_all_zero_split_count(self):
    with self.assertRaisesRegex(
        ValueError, "split_count 0 must at least be one"):
      tpu_ops.all_to_all(
          x=[0.0, 0.1652, 0.6543],
          group_assignment=[1, -1],
          concat_dimension=0,
          split_dimension=0,
          split_count=0)

  def test_all_to_all_group_assignment_wrong_shape(self):
    with self.assertRaisesRegex(
        ValueError, "group_assignment must have rank 2"):
      tpu_ops.all_to_all(
          x=[0.0, 0.1652, 0.6543],
          group_assignment=[1, -1],
          concat_dimension=0,
          split_dimension=0,
          split_count=2)

  def test_all_to_all_split_count_not_equal_to_group_assignment_shape(self):
    with self.assertRaisesRegex(
        ValueError, "split_count 1 must equal the size of the second dimension "
        "of group_assignment 2"):
      tpu_ops.all_to_all(
          x=[0.0, 0.1652, 0.6543],
          group_assignment=[[0, 1], [2, 3]],
          concat_dimension=0,
          split_dimension=0,
          split_count=1)

  def test_all_to_all_split_count_not_divide_input_shape(self):
    with self.assertRaisesRegex(
        ValueError, "input dimension 3 not divisible by split_count 2"):
      tpu_ops.all_to_all(
          x=[[0.0], [0.1652], [0.6543]],
          group_assignment=[[0, 1], [2, 3]],
          concat_dimension=1,
          split_dimension=0,
          split_count=2)

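
# An illustrative sketch (not one of the original tests) of an all_to_all call
# that passes the validations above: group_assignment has rank 2, its second
# dimension equals split_count, and the split dimension of `x` is divisible by
# split_count. Running the op would still require TPU devices, so this helper
# is never invoked here.
def build_valid_all_to_all_sketch():
  return tpu_ops.all_to_all(
      x=[[0.0], [0.1652]],
      group_assignment=[[0, 1]],
      concat_dimension=1,
      split_dimension=0,
      split_count=2)
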

def do_einsum():
  """Builds a batched einsum between two float32 placeholders."""
  a = array_ops.placeholder(dtype=dtypes.float32, name="a", shape=[2, 3, 4])
  b = array_ops.placeholder(dtype=dtypes.float32, name="b", shape=[2, 4, 5])
  return special_math_ops.einsum("abc,acd->abd", a, b)


def find_einsum(g):
  """Returns True if graph `g` contains an Einsum node."""
  graph_def = g.as_graph_def()
  for node in graph_def.node:
    if node.op == "Einsum":
      return True
  return False


def find_xla_einsum(g):
  """Returns True if graph `g` contains an XlaEinsum node."""
  graph_def = g.as_graph_def()
  for node in graph_def.node:
    if node.op == "XlaEinsum":
      return True
  return False

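
# An illustrative generalization (an addition, not used by the tests) of the
# two finder helpers above: scan a graph for any node with the given op type.
def find_op_type_sketch(g, op_type):
  return any(node.op == op_type for node in g.as_graph_def().node)
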

class TPUXlaEinsumTest(test.TestCase):

  def test_tpu_rewrite_uses_xla_einsum(self):
    with ops.Graph().as_default() as g:
      tpu.rewrite(do_einsum)
      self.assertTrue(find_einsum(g) or find_xla_einsum(g))

  def test_default_does_not_use_xla_einsum(self):
    with ops.Graph().as_default() as g:
      do_einsum()
      self.assertFalse(find_xla_einsum(g))


if __name__ == "__main__":
  test.main()