1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16import re 17 18import numpy as np 19 20from tensorflow.python.data.ops import dataset_ops 21from tensorflow.python.debug.lib import check_numerics_callback 22from tensorflow.python.eager import backprop 23from tensorflow.python.eager import context 24from tensorflow.python.eager import def_function 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import errors 28from tensorflow.python.framework import ops 29from tensorflow.python.framework import test_util 30from tensorflow.python.ops import array_grad # pylint: disable=unused-import 31from tensorflow.python.ops import custom_gradient 32from tensorflow.python.ops import gen_nn_ops 33from tensorflow.python.ops import gradient_checker_v2 34from tensorflow.python.ops import math_grad # pylint: disable=unused-import 35from tensorflow.python.ops import math_ops 36from tensorflow.python.ops import variables 37from tensorflow.python.platform import googletest 38from tensorflow.python.platform import test 39 40 41class LimitStringLengthTest(test_util.TensorFlowTestCase): 42 43 def testLimitStringLengthWithExplicitLimit(self): 44 self.assertEqual( 45 check_numerics_callback.limit_string_length("", max_len=2), "") 46 self.assertEqual( 47 check_numerics_callback.limit_string_length("e", max_len=2), "e") 48 self.assertEqual( 49 check_numerics_callback.limit_string_length("de", max_len=2), "de") 50 self.assertEqual( 51 check_numerics_callback.limit_string_length("abcde", max_len=2), 52 "...de") 53 54 def testLimitStringLengthWithNoLimit(self): 55 self.assertEqual(check_numerics_callback.limit_string_length( 56 "A" * 100 + "B", max_len=None), "A" * 100 + "B") 57 self.assertEqual( 58 check_numerics_callback.limit_string_length("", max_len=None), "") 59 60 def testLimitStringLengthWithDefaultLimit(self): 61 self.assertEqual( 62 check_numerics_callback.limit_string_length("A" * 50 + "B"), 63 "..." + "A" * 49 + "B") 64 65 66class CheckNumericsCallbackTest(test_util.TensorFlowTestCase): 67 68 def tearDown(self): 69 check_numerics_callback.disable_check_numerics() 70 super(CheckNumericsCallbackTest, self).tearDown() 71 72 def testCallingDisableCheckNumericsWithoutEnablingFirstIsTolerated(self): 73 check_numerics_callback.disable_check_numerics() 74 75 def testNoCatchEagerOpExecution(self): 76 """Test running multiple steps of eager execution without Inf/NaN.""" 77 check_numerics_callback.enable_check_numerics() 78 x = constant_op.constant([2.0, 3.0]) 79 y = constant_op.constant([1.0, 0.0]) 80 self.assertAllClose((x + y) * (x - y), [3.0, 9.0]) 81 82 @test_util.run_in_graph_and_eager_modes 83 def testDatasetMapHealthyResults(self): 84 check_numerics_callback.enable_check_numerics() 85 86 tensor = constant_op.constant( 87 [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0]) 88 89 def map_fn(x): 90 return math_ops.log(math_ops.square(x) + 1) 91 92 dataset = dataset_ops.Dataset.from_tensor_slices(tensor).batch(2).map( 93 map_fn) 94 95 @def_function.function 96 def get_batches(): 97 iterator = iter(dataset) 98 return [next(iterator), next(iterator)] 99 100 batches = self.evaluate(get_batches()) 101 self.assertLen(batches, 2) 102 self.assertAllClose(batches[0], np.log([1.25, 2])) 103 self.assertAllClose(batches[1], np.log([3.25, 5])) 104 105 @test_util.run_in_graph_and_eager_modes 106 def testGraphModeUsesCorrectPathLengthAndStackHeightLimits(self): 107 check_numerics_callback.enable_check_numerics( 108 stack_height_limit=123, path_length_limit=1200) 109 110 @def_function.function 111 def add_fn(x, y): 112 return x + y 113 114 fake_get_check_numerics_error_message = test.mock.MagicMock( 115 return_value="dummy_message") 116 with test.mock.patch.object(check_numerics_callback, 117 "get_check_numerics_error_message", 118 fake_get_check_numerics_error_message): 119 x = constant_op.constant(2.0) 120 y = constant_op.constant(3.0) 121 self.assertAllClose(self.evaluate(add_fn(x, y)), 5.0) 122 (_, call_kwargs) = fake_get_check_numerics_error_message.call_args 123 self.assertEqual(call_kwargs["stack_height_limit"], 123) 124 self.assertEqual(call_kwargs["path_length_limit"], 1200) 125 126 127class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase): 128 """Test for cases in which enable_check_numerics() catches infs or nans.""" 129 130 def tearDown(self): 131 check_numerics_callback.disable_check_numerics() 132 super(CheckNumericsCallbackUnhealthyTest, self).tearDown() 133 134 def _assertRaisesInvalidArgumentErrorAndGetMessage(self, func): 135 caught = None 136 try: 137 func() 138 except errors.InvalidArgumentError as error: 139 caught = error 140 self.assertTrue(caught, "Failed to catch expected InvalidArgumentError") 141 return caught.message 142 143 def testCatchEagerOpFloat32Inf(self): 144 """Test catching Infinity in eager op execution: float32.""" 145 check_numerics_callback.enable_check_numerics() 146 147 x = constant_op.constant([2.0, 3.0]) 148 y = constant_op.constant([1.0, 0.0]) 149 message = self._assertRaisesInvalidArgumentErrorAndGetMessage( 150 lambda: x / y) 151 152 # Check the content of the error message. 153 self.assertTrue(re.search(r"eagerly-executing op.*\"RealDiv\"", message)) 154 self.assertTrue(re.search(r"dtype.*float32", message)) 155 self.assertIn("shape: (2,)\n", message) 156 self.assertIn("# of +Inf elements: 1\n", message) 157 self.assertIn("0: %s" % x, message) 158 self.assertIn("1: %s" % y, message) 159 160 def testEnableCheckNumericsIsIdempotent(self): 161 """Two calls to enable_check_numerics() have same effect as one call.""" 162 check_numerics_callback.enable_check_numerics() 163 check_numerics_callback.enable_check_numerics() 164 165 x = constant_op.constant([2.0, 3.0]) 166 y = constant_op.constant([1.0, 0.0]) 167 message = self._assertRaisesInvalidArgumentErrorAndGetMessage( 168 lambda: x / y) 169 170 # Check the content of the error message. 171 self.assertTrue(re.search(r"eagerly-executing op.*\"RealDiv\"", message)) 172 self.assertTrue(re.search(r"dtype.*float32", message)) 173 self.assertIn("shape: (2,)\n", message) 174 self.assertIn("# of +Inf elements: 1\n", message) 175 self.assertIn("0: %s" % x, message) 176 self.assertIn("1: %s" % y, message) 177 178 def testCatchEagerOpFloat16NaN(self): 179 """Test catching Infinity in eager op execution: float16.""" 180 check_numerics_callback.enable_check_numerics() 181 def log1p(x): 182 y = 1.0 + x 183 return math_ops.log(y) 184 x = constant_op.constant([[-1.0]], dtype=dtypes.float16) 185 message = self._assertRaisesInvalidArgumentErrorAndGetMessage( 186 lambda: log1p(x)) 187 188 # Check the content of the error message. 189 self.assertTrue(re.search(r"eagerly-executing op.*\"Log\"", message)) 190 self.assertTrue(re.search(r"dtype.*float16", message)) 191 self.assertIn("shape: (1, 1)\n", message) 192 self.assertIn("# of -Inf elements: 1\n", message) 193 self.assertTrue(re.search(r"Input tensor.*0\.", message)) 194 195 @test_util.enable_eager_op_as_function 196 def testCatchEagerOpFloat16NaNWithEagerOpAsFunctionEnabled(self): 197 self.testCatchEagerOpFloat16NaN() 198 199 @test_util.run_in_graph_and_eager_modes 200 def testCatchFunctionOpInfFloat64(self): 201 """Test catching infinites generated in a FuncGraph.""" 202 203 check_numerics_callback.enable_check_numerics() 204 @def_function.function 205 def divide_sum_with_diff(x, y): 206 w1 = x + y 207 w2 = x - y 208 u = w1 / w2 209 return u * 2.0 210 x = constant_op.constant(2.0, dtype=dtypes.float64) 211 y = constant_op.constant(2.0, dtype=dtypes.float64) 212 message = self._assertRaisesInvalidArgumentErrorAndGetMessage( 213 lambda: self.evaluate(divide_sum_with_diff(x, y))) 214 215 # Check the content of the error message. 216 self.assertTrue(re.search(r"graph op.*\"RealDiv\"", message)) 217 self.assertTrue(re.search(r"dtype.*float64", message)) 218 self.assertIn("shape: ()\n", message) 219 self.assertIn("Input tensors (2):", message) 220 # Check that the correct input ops are printed. 221 self.assertTrue(re.search(r"0:.*Tensor.*add:0", message)) 222 self.assertTrue(re.search(r"1:.*Tensor.*sub:0", message)) 223 # Check that the correct line for op creation is printed. 224 self.assertTrue(re.search(r"Stack trace of op's creation", message)) 225 self.assertIn("u = w1 / w2", message) 226 227 @test_util.run_in_graph_and_eager_modes 228 @test_util.disable_xla( 229 "TODO(b/141100809): XLA has no way to assert inside of a kernel.") 230 def testControlFlowGraphWithNaNBFloat16(self): 231 """Test catching bfloat16 NaNs in a control-flow-v2 FuncGraph.""" 232 check_numerics_callback.enable_check_numerics() 233 234 @def_function.function 235 def my_conditional(x): 236 if math_ops.less(math_ops.reduce_sum(x), 0.0): 237 return math_ops.log(x) 238 else: 239 return math_ops.log(-x) 240 241 x = constant_op.constant([1.0, 2.0, 3.0], dtype=dtypes.bfloat16) 242 message = self._assertRaisesInvalidArgumentErrorAndGetMessage( 243 lambda: self.evaluate(my_conditional(x))) 244 # Check the content of the error message. 245 self.assertTrue(re.search(r"graph op.*\"Log\"", message)) 246 self.assertTrue(re.search(r"dtype.*bfloat16", message)) 247 self.assertIn("shape: (3,)\n", message) 248 # Check that the correct input op is printed. 249 self.assertTrue(re.search(r"Input tensor.*Tensor.*Neg", message)) 250 # Check that the correct line for op creation is printed. 251 self.assertTrue(re.search(r"Stack trace of op's creation", message)) 252 self.assertIn("return math_ops.log(-x)", message) 253 254 @test_util.run_in_graph_and_eager_modes 255 @test_util.disable_xla( 256 "There is a small inconsistency in the step at which overflow happens: " 257 "128 (without XLA) and 127 (with XLA).") 258 @test_util.disable_tfrt("b/177261532: TFRT cannot detect overflow yet.") 259 def testOverflowInTfFunction(self): 260 """Test catching Infinity caused by overflow in a tf.function with while.""" 261 check_numerics_callback.enable_check_numerics() 262 263 @def_function.function 264 def accumulation_function(counter, lim, accum): 265 while math_ops.less(counter, lim): 266 accum.assign(accum * 2.0) 267 counter.assign_add(1) 268 269 counter = variables.Variable(0, dtype=dtypes.int32) 270 # Repeated `* 2.0` overflows a float32 tensor in 128 steps. So the 271 # 1000-step limit is sufficient. 272 lim = constant_op.constant(1000, dtype=dtypes.int32) 273 accum = variables.Variable(1.0) 274 275 if not context.executing_eagerly(): 276 self.evaluate([counter.initializer, accum.initializer]) 277 278 message = self._assertRaisesInvalidArgumentErrorAndGetMessage( 279 lambda: self.evaluate(accumulation_function(counter, lim, accum))) 280 281 self.assertAllClose(self.evaluate(counter), 128) 282 # Check the content of the error message. 283 # The overflow to +Infinity happens during the `* 2.0` operation. 284 self.assertTrue(re.search(r"graph op.*\"Mul\"", message)) 285 self.assertTrue(re.search(r"dtype.*float32", message)) 286 self.assertIn("shape: ()\n", message) 287 # Check that the correct input op is printed. 288 self.assertIn("Input tensors (2):", message) 289 # Check that the correct input ops are printed. 290 self.assertTrue(re.search(r"0:.*Tensor.*ReadVariableOp:0", message)) 291 self.assertTrue(re.search(r"1:.*Tensor.*mul/y:0", message)) 292 # Check that the correct line for op creation is printed. 293 self.assertTrue(re.search(r"Stack trace of op's creation", message)) 294 self.assertIn("accum.assign(accum * 2.0)", message) 295 296 @test_util.run_in_graph_and_eager_modes 297 def testNanInConstIsCaptured(self): 298 check_numerics_callback.enable_check_numerics() 299 v = variables.Variable(3.0, dtype=dtypes.float32) 300 @def_function.function 301 def add_a_bad_constant(x): 302 c = constant_op.constant(np.nan) 303 return x + c 304 if not context.executing_eagerly(): 305 self.evaluate(v.initializer) 306 message = self._assertRaisesInvalidArgumentErrorAndGetMessage( 307 lambda: self.evaluate(add_a_bad_constant(v))) 308 self.assertTrue(re.search(r"graph op.*\"Const\"", message)) 309 self.assertTrue(re.search(r"dtype:.*float32", message)) 310 self.assertTrue(re.search(r"shape:.*\(\)", message)) 311 self.assertTrue(re.search(r"Graph name:.*add_a_bad_constant", message)) 312 313 @test_util.run_in_graph_and_eager_modes 314 def testCatchInfinityInDatasetMapFunction(self): 315 """Test that callback catches NaN in a tf.dataset map function.""" 316 check_numerics_callback.enable_check_numerics() 317 318 def generate_nan(x): 319 """Intentionally generates NaNs by taking log of negative number.""" 320 casted_x = math_ops.cast(x, dtypes.float32) 321 return math_ops.log([[-1.0, 1.0], [3.0, 5.0]]) + casted_x 322 323 dataset = dataset_ops.Dataset.range(10).map(generate_nan) 324 iterator = dataset_ops.make_one_shot_iterator(dataset) 325 326 message = self._assertRaisesInvalidArgumentErrorAndGetMessage( 327 lambda: self.evaluate(iterator.get_next())) 328 329 # Check the content of the error message. 330 self.assertTrue(re.search(r"graph op.*\"Log\"", message)) 331 self.assertTrue(re.search(r"dtype.*float32", message)) 332 self.assertIn("shape: (2, 2)\n", message) 333 self.assertTrue(re.search(r"Input tensor.*Tensor.*Log/x:0", message)) 334 self.assertIn( 335 "-> | return math_ops.log([[-1.0, 1.0], [3.0, 5.0]]) + casted_x", 336 message) 337 338 @test_util.run_in_graph_and_eager_modes 339 def testCustomGradientWithNaNWithTfFunction(self): 340 """Test that callback catches NaN in a gradient function during backprop.""" 341 check_numerics_callback.enable_check_numerics() 342 343 @custom_gradient.custom_gradient 344 def func_with_bad_grad(x): 345 output = math_ops.sin(x) 346 @def_function.function 347 def grad(dy): 348 # `dy` will come in as 1.0. Taking log of -1.0 leads to NaN. 349 return math_ops.log(-dy) 350 return output, grad 351 352 x = constant_op.constant(-2.0, dtype=dtypes.float16) 353 def f(x): 354 return func_with_bad_grad(x) 355 356 message = self._assertRaisesInvalidArgumentErrorAndGetMessage( 357 lambda: gradient_checker_v2.compute_gradient(f, [x])) 358 359 # Check the content of the error message. 360 self.assertTrue(re.search(r"graph op.*\"Log\"", message)) 361 self.assertTrue(re.search(r"dtype.*float16", message)) 362 if context.executing_eagerly(): 363 self.assertIn("shape: ()\n", message) 364 self.assertTrue(re.search(r"Input tensor.*Tensor.*Neg:0", message)) 365 self.assertIn("-> | return math_ops.log(-dy)", message) 366 367 @test_util.run_in_graph_and_eager_modes 368 def testNestedFunctionGradientCall(self): 369 """Catching inf in the inner nested tf.function during backprop.""" 370 check_numerics_callback.enable_check_numerics() 371 372 x = constant_op.constant(1.0 - 1e-8, dtype=dtypes.float32) 373 374 @def_function.function 375 def asinp1(x): 376 # asin()'s gradient overflows at the value close to 1.0. 377 return math_ops.asin(x) + 1.0 378 379 @def_function.function 380 def loss(x): 381 return math_ops.square(asinp1(x)) 382 383 with backprop.GradientTape() as tape: 384 tape.watch(x) 385 y = loss(x) 386 message = self._assertRaisesInvalidArgumentErrorAndGetMessage( 387 lambda: self.evaluate(tape.gradient(y, x))) 388 # Check the content of the error message. 389 # Assume the op Reciprocal or Xdivy is used in the gradient function for 390 # asin(). 391 self.assertTrue((re.search(r"graph op.*\"Reciprocal\"", message) or 392 re.search(r"graph op.*\"Xdivy\"", message))) 393 self.assertTrue(re.search(r"dtype.*float32", message)) 394 395 def testEagerModeUsesCorrectPathLengthAndStackHeightLimits(self): 396 check_numerics_callback.enable_check_numerics( 397 stack_height_limit=123, path_length_limit=1200) 398 fake_get_check_numerics_error_message = test.mock.MagicMock( 399 return_value="dummy_message") 400 with test.mock.patch.object(check_numerics_callback, 401 "get_check_numerics_error_message", 402 fake_get_check_numerics_error_message): 403 x = constant_op.constant(2.0) 404 y = constant_op.constant(0.0) 405 self._assertRaisesInvalidArgumentErrorAndGetMessage( 406 lambda: x / y) # Expected to generate an inf. 407 (_, call_kwargs) = fake_get_check_numerics_error_message.call_args 408 self.assertEqual(call_kwargs["stack_height_limit"], 123) 409 self.assertEqual(call_kwargs["path_length_limit"], 1200) 410 411 @test_util.run_in_graph_and_eager_modes 412 def testExpectedNaNOpOutputs(self): 413 """Test calling operations with benign NaN output.""" 414 check_numerics_callback.enable_check_numerics() 415 416 # Empty input tensor 417 x = constant_op.constant(1, dtype=dtypes.float32, shape=[0, 1, 1, 1]) 418 scale = constant_op.constant([1], dtype=dtypes.float32) 419 offset = constant_op.constant([1], dtype=dtypes.float32) 420 421 # Calling fused_batch_norm with an empty input should output a NaN in the 422 # latter four outputs without triggering the check_numerics callback 423 batch_norm_res = gen_nn_ops._fused_batch_norm( 424 x=x, scale=scale, offset=offset, mean=[], variance=[]) 425 426 _, batch_mean, batch_variance, _, _ = self.evaluate(batch_norm_res) 427 428 self.assertTrue(np.isnan(batch_mean.squeeze())) 429 self.assertTrue(np.isnan(batch_variance.squeeze())) 430 431 432if __name__ == "__main__": 433 ops.enable_eager_execution() 434 googletest.main() 435