# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Tests for tpu_function helpers."""

from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import training_loop
from tensorflow.python.tpu.ops import tpu_ops


class TPUContextTest(test.TestCase):

  def testIsInContext(self):
    """Test that control_flow_util can check that we're in a TPU context."""
    with ops.Graph().as_default():
      z1 = array_ops.identity(1)
      pivot = control_flow_ops.no_op()
      context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
      context.Enter()
      z2 = array_ops.identity(1)
      context.Exit()
      self.assertFalse(control_flow_util.IsInXLAContext(z1.op))
      self.assertTrue(control_flow_util.IsInXLAContext(z2.op))

  def testHandlesNameCollision(self):
    """Test AddValue handles name collisions for ops from different graphs."""
    with ops.Graph().as_default():
      z = array_ops.zeros([2, 3], name="a")
      assert z.name == "a:0", "Expected: a:0, Found: %s" % z.name

      @def_function.function
      def f():
        pivot = control_flow_ops.no_op()
        context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
        context.Enter()
        array_ops.identity(z)  # Capture z.
        z1 = array_ops.zeros([3, 2], name="a")
        assert z1.name == "a:0", "Expected: a:0, Found: %s" % z1.name
        z2 = array_ops.zeros([3, 2], name="a")
        # Prior to fixing b/166794533 this would fail with a shape mismatch
        # because context.AddValue would have cached `z` by its name which
        # collides with z1's name.
        result = z1 + z2
        context.Exit()
        return result

      f.get_concrete_function()


class TPULayerRewriteTest(test.TestCase):

  def testUsingInfeedQueueWithRegularizer(self):
    """Test that Layer regularizers can reference data created in loops."""

    with ops.Graph().as_default():

      def make_regularizer(scale):
        def regularizer(inputs):
          return scale * math_ops.reduce_sum(math_ops.square(inputs))
        return regularizer

      def training_step(inputs, scale):
        outputs = convolutional.conv2d(
            inputs,
            filters=16,
            kernel_size=(3, 3),
            data_format="channels_first",
            kernel_regularizer=make_regularizer(scale))
        loss = math_ops.reduce_mean(math_ops.square(outputs))
        return loss.op

      inputs = array_ops.zeros(shape=(128, 32, 32, 16))
      scale = array_ops.ones(shape=())
      infeed = tpu_feed.InfeedQueue(
          tuple_types=[dtypes.float32, dtypes.float32],
          tuple_shapes=[inputs.shape, scale.shape])

      def loop():
        return training_loop.repeat(5, training_step, infeed_queue=infeed)

      # This should not throw an error.
      tpu.rewrite(loop)


class TPUGraphPruneTest(test.TestCase):

  def test_prune_unconnected_ops(self):
    with ops.Graph().as_default():
      a = array_ops.placeholder(dtype=dtypes.float32, name="a")
      b = array_ops.placeholder(dtype=dtypes.float32, name="b")
      constant_op.constant(1.0, name="constant")
      x = variable_scope.get_variable(
          name="x",
          dtype=dtypes.float32,
          shape=[],
          use_resource=True,
          initializer=init_ops.constant_initializer(2.0))
      y = variable_scope.get_variable(
          name="y",
          dtype=dtypes.float32,
          shape=[],
          use_resource=True,
          initializer=init_ops.constant_initializer(3.0))
      math_ops.add(a, b)
      math_ops.add(x, y)
      graph_def = ops.get_default_graph().as_graph_def()

      for node in graph_def.node:
        # Attach a TPU_REPLICATE_ATTR to each node.
        node.attr[tpu._TPU_REPLICATE_ATTR].s = b"0"
        # Rewire inputs so that placeholder "b" and variable "y" are left
        # unconnected.
        for (input_index, node_input) in enumerate(node.input):
          if node_input == "b":
            node.input[input_index] = "constant"
          if node_input == "y":
            node.input[input_index] = "x"

      with ops.Graph().as_default() as graph:
        # Reimport the graph and prune unconnected ops.
        importer.import_graph_def(graph_def)
        tpu.prune_unconnected_ops_from_xla(ops.get_default_graph())

        # Verify that ops "a" and "x" still have TPU_REPLICATE_ATTR.
        a = graph.get_operation_by_name("import/a").get_attr(
            tpu._TPU_REPLICATE_ATTR)
        self.assertEqual(b"0", a)
        x = graph.get_operation_by_name("import/x").get_attr(
            tpu._TPU_REPLICATE_ATTR)
        self.assertEqual(b"0", x)
        # Verify that ops "b" and "y" have TPU_REPLICATE_ATTR removed.
        with self.assertRaisesRegex(
            ValueError,
            "Operation \'import/b\' has no attr named \'_tpu_replicate\'"):
          graph.get_operation_by_name("import/b").get_attr(
              tpu._TPU_REPLICATE_ATTR)
        with self.assertRaisesRegex(
            ValueError,
            "Operation \'import/y\' has no attr named \'_tpu_replicate\'"):
          graph.get_operation_by_name("import/y").get_attr(
              tpu._TPU_REPLICATE_ATTR)


class TPUOpsTest(test.TestCase):

  def test_all_to_all_zero_split_count(self):
    with self.assertRaisesRegex(
        ValueError, "split_count 0 must at least be one"):
      tpu_ops.all_to_all(
          x=[0.0, 0.1652, 0.6543],
          group_assignment=[1, -1],
          concat_dimension=0,
          split_dimension=0,
          split_count=0)

  def test_all_to_all_group_assignment_wrong_shape(self):
    with self.assertRaisesRegex(
        ValueError, "group_assignment must have rank 2"):
      tpu_ops.all_to_all(
          x=[0.0, 0.1652, 0.6543],
          group_assignment=[1, -1],
          concat_dimension=0,
          split_dimension=0,
          split_count=2)

  def test_all_to_all_split_count_not_equal_to_group_assignment_shape(self):
    with self.assertRaisesRegex(
        ValueError, "split_count 1 must equal the size of the second dimension "
        "of group_assignment 2"):
      tpu_ops.all_to_all(
          x=[0.0, 0.1652, 0.6543],
          group_assignment=[[0, 1], [2, 3]],
          concat_dimension=0,
          split_dimension=0,
          split_count=1)

  def test_all_to_all_split_count_not_divide_input_shape(self):
    with self.assertRaisesRegex(
        ValueError, "input dimension 3 not divisible by split_count 2"):
      tpu_ops.all_to_all(
          x=[[0.0], [0.1652], [0.6543]],
          group_assignment=[[0, 1], [2, 3]],
          concat_dimension=1,
          split_dimension=0,
          split_count=2)


def do_einsum():
  a = array_ops.placeholder(dtype=dtypes.float32, name="a", shape=[2, 3, 4])
  b = array_ops.placeholder(dtype=dtypes.float32, name="b", shape=[2, 4, 5])
  return special_math_ops.einsum("abc,acd->abd", a, b)


def find_einsum(g):
  graph_def = g.as_graph_def()
  for node in graph_def.node:
    if node.op == "Einsum":
      return True
  return False


def find_xla_einsum(g):
  graph_def = g.as_graph_def()
  for node in graph_def.node:
    if node.op == "XlaEinsum":
      return True
  return False


class TPUXlaEinsumTest(test.TestCase):

  def test_tpu_rewrite_uses_xla_einsum(self):
    with ops.Graph().as_default() as g:
      tpu.rewrite(do_einsum)
      self.assertTrue(find_einsum(g) or find_xla_einsum(g))

  def test_default_does_not_use_xla_einsum(self):
    with ops.Graph().as_default() as g:
      do_einsum()
      self.assertFalse(find_xla_einsum(g))


if __name__ == "__main__":
  test.main()