# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import re

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.debug.lib import check_numerics_callback
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test


class LimitStringLengthTest(test_util.TensorFlowTestCase):

  def testLimitStringLengthWithExplicitLimit(self):
    self.assertEqual(
        check_numerics_callback.limit_string_length("", max_len=2), "")
    self.assertEqual(
        check_numerics_callback.limit_string_length("e", max_len=2), "e")
    self.assertEqual(
        check_numerics_callback.limit_string_length("de", max_len=2), "de")
    self.assertEqual(
        check_numerics_callback.limit_string_length("abcde", max_len=2),
        "...de")

  def testLimitStringLengthWithNoLimit(self):
    self.assertEqual(check_numerics_callback.limit_string_length(
        "A" * 100 + "B", max_len=None), "A" * 100 + "B")
    self.assertEqual(
        check_numerics_callback.limit_string_length("", max_len=None), "")

  def testLimitStringLengthWithDefaultLimit(self):
    self.assertEqual(
        check_numerics_callback.limit_string_length("A" * 50 + "B"),
        "..." + "A" * 49 + "B")


class CheckNumericsCallbackTest(test_util.TensorFlowTestCase):

  def tearDown(self):
    check_numerics_callback.disable_check_numerics()
    super(CheckNumericsCallbackTest, self).tearDown()

  def testCallingDisableCheckNumericsWithoutEnablingFirstIsTolerated(self):
    check_numerics_callback.disable_check_numerics()

  def testNoCatchEagerOpExecution(self):
    """Test running multiple steps of eager execution without Inf/NaN."""
    check_numerics_callback.enable_check_numerics()
    x = constant_op.constant([2.0, 3.0])
    y = constant_op.constant([1.0, 0.0])
    self.assertAllClose((x + y) * (x - y), [3.0, 9.0])

  @test_util.run_in_graph_and_eager_modes
  def testDatasetMapHealthyResults(self):
    check_numerics_callback.enable_check_numerics()

    tensor = constant_op.constant(
        [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0])

    def map_fn(x):
      return math_ops.log(math_ops.square(x) + 1)

    dataset = dataset_ops.Dataset.from_tensor_slices(tensor).batch(2).map(
        map_fn)

    @def_function.function
    def get_batches():
      iterator = iter(dataset)
      return [next(iterator), next(iterator)]

    batches = self.evaluate(get_batches())
    self.assertLen(batches, 2)
    self.assertAllClose(batches[0], np.log([1.25, 2]))
    self.assertAllClose(batches[1], np.log([3.25, 5]))

  @test_util.run_in_graph_and_eager_modes
  def testGraphModeUsesCorrectPathLengthAndStackHeightLimits(self):
    check_numerics_callback.enable_check_numerics(
        stack_height_limit=123, path_length_limit=1200)

    @def_function.function
    def add_fn(x, y):
      return x + y

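    # Patch the error-message formatter so the test can inspect the limits
    # that the callback passes through to it.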
    fake_get_check_numerics_error_message = test.mock.MagicMock(
        return_value="dummy_message")
    with test.mock.patch.object(check_numerics_callback,
                                "get_check_numerics_error_message",
                                fake_get_check_numerics_error_message):
      x = constant_op.constant(2.0)
      y = constant_op.constant(3.0)
      self.assertAllClose(self.evaluate(add_fn(x, y)), 5.0)
      (_, call_kwargs) = fake_get_check_numerics_error_message.call_args
      self.assertEqual(call_kwargs["stack_height_limit"], 123)
      self.assertEqual(call_kwargs["path_length_limit"], 1200)


class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase):
  """Test for cases in which enable_check_numerics() catches infs or nans."""

  def tearDown(self):
    check_numerics_callback.disable_check_numerics()
    super(CheckNumericsCallbackUnhealthyTest, self).tearDown()

  def _assertRaisesInvalidArgumentErrorAndGetMessage(self, func):
    """Asserts `func` raises InvalidArgumentError; returns its message."""
    caught = None
    try:
      func()
    except errors.InvalidArgumentError as error:
      caught = error
    self.assertTrue(caught, "Failed to catch expected InvalidArgumentError")
    return caught.message

  def testCatchEagerOpFloat32Inf(self):
    """Test catching Infinity in eager op execution: float32."""
    check_numerics_callback.enable_check_numerics()

    x = constant_op.constant([2.0, 3.0])
    y = constant_op.constant([1.0, 0.0])
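    # The second element of y is 0.0, so x / y contains a +Inf, which the
    # enabled numerics check reports as an InvalidArgumentError.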
    message = self._assertRaisesInvalidArgumentErrorAndGetMessage(
        lambda: x / y)

    # Check the content of the error message.
    self.assertTrue(re.search(r"eagerly-executing op.*\"RealDiv\"", message))
    self.assertTrue(re.search(r"dtype.*float32", message))
    self.assertIn("shape: (2,)\n", message)
    self.assertIn("# of +Inf elements: 1\n", message)
    self.assertIn("0: %s" % x, message)
    self.assertIn("1: %s" % y, message)

  def testEnableCheckNumericsIsIdempotent(self):
    """Two enable_check_numerics() calls have the same effect as one call."""
    check_numerics_callback.enable_check_numerics()
    check_numerics_callback.enable_check_numerics()

    x = constant_op.constant([2.0, 3.0])
    y = constant_op.constant([1.0, 0.0])
    message = self._assertRaisesInvalidArgumentErrorAndGetMessage(
        lambda: x / y)

    # Check the content of the error message.
    self.assertTrue(re.search(r"eagerly-executing op.*\"RealDiv\"", message))
    self.assertTrue(re.search(r"dtype.*float32", message))
    self.assertIn("shape: (2,)\n", message)
    self.assertIn("# of +Inf elements: 1\n", message)
    self.assertIn("0: %s" % x, message)
    self.assertIn("1: %s" % y, message)

  def testCatchEagerOpFloat16NaN(self):
    """Test catching Infinity in eager op execution: float16."""
    check_numerics_callback.enable_check_numerics()
    def log1p(x):
      y = 1.0 + x
      return math_ops.log(y)
    x = constant_op.constant([[-1.0]], dtype=dtypes.float16)
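    # log1p(-1.0) computes log(0.0), which is -Inf, so the enabled numerics
    # check flags the Log op below.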
    message = self._assertRaisesInvalidArgumentErrorAndGetMessage(
        lambda: log1p(x))

    # Check the content of the error message.
    self.assertTrue(re.search(r"eagerly-executing op.*\"Log\"", message))
    self.assertTrue(re.search(r"dtype.*float16", message))
    self.assertIn("shape: (1, 1)\n", message)
    self.assertIn("# of -Inf elements: 1\n", message)
    self.assertTrue(re.search(r"Input tensor.*0\.", message))

  @test_util.enable_eager_op_as_function
  def testCatchEagerOpFloat16NaNWithEagerOpAsFunctionEnabled(self):
    self.testCatchEagerOpFloat16NaN()

  @test_util.run_in_graph_and_eager_modes
  def testCatchFunctionOpInfFloat64(self):
    """Test catching infinities generated in a FuncGraph."""

    check_numerics_callback.enable_check_numerics()
    @def_function.function
    def divide_sum_with_diff(x, y):
      w1 = x + y
      w2 = x - y
      u = w1 / w2
      return u * 2.0
    x = constant_op.constant(2.0, dtype=dtypes.float64)
    y = constant_op.constant(2.0, dtype=dtypes.float64)
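    # With x == y, w2 is 0.0, so w1 / w2 overflows to +Inf inside the
    # FuncGraph, which the enabled numerics check reports for the RealDiv op.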
    message = self._assertRaisesInvalidArgumentErrorAndGetMessage(
        lambda: self.evaluate(divide_sum_with_diff(x, y)))

    # Check the content of the error message.
    self.assertTrue(re.search(r"graph op.*\"RealDiv\"", message))
    self.assertTrue(re.search(r"dtype.*float64", message))
    self.assertIn("shape: ()\n", message)
    self.assertIn("Input tensors (2):", message)
    # Check that the correct input ops are printed.
    self.assertTrue(re.search(r"0:.*Tensor.*add:0", message))
    self.assertTrue(re.search(r"1:.*Tensor.*sub:0", message))
    # Check that the correct line for op creation is printed.
    self.assertTrue(re.search(r"Stack trace of op's creation", message))
    self.assertIn("u = w1 / w2", message)

  @test_util.run_in_graph_and_eager_modes
  @test_util.disable_xla(
      "TODO(b/141100809): XLA has no way to assert inside of a kernel.")
  def testControlFlowGraphWithNaNBFloat16(self):
    """Test catching bfloat16 NaNs in a control-flow-v2 FuncGraph."""
    check_numerics_callback.enable_check_numerics()

    @def_function.function
    def my_conditional(x):
      if math_ops.less(math_ops.reduce_sum(x), 0.0):
        return math_ops.log(x)
      else:
        return math_ops.log(-x)

    x = constant_op.constant([1.0, 2.0, 3.0], dtype=dtypes.bfloat16)
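    # reduce_sum(x) is positive here, so the else branch runs and log(-x)
    # takes the log of negative values, producing NaNs.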
    message = self._assertRaisesInvalidArgumentErrorAndGetMessage(
        lambda: self.evaluate(my_conditional(x)))
    # Check the content of the error message.
    self.assertTrue(re.search(r"graph op.*\"Log\"", message))
    self.assertTrue(re.search(r"dtype.*bfloat16", message))
    self.assertIn("shape: (3,)\n", message)
    # Check that the correct input op is printed.
    self.assertTrue(re.search(r"Input tensor.*Tensor.*Neg", message))
    # Check that the correct line for op creation is printed.
    self.assertTrue(re.search(r"Stack trace of op's creation", message))
    self.assertIn("return math_ops.log(-x)", message)

  @test_util.run_in_graph_and_eager_modes
  @test_util.disable_xla(
      "There is a small inconsistency in the step at which overflow happens: "
      "128 (without XLA) and 127 (with XLA).")
  @test_util.disable_tfrt("b/177261532: TFRT cannot detect overflow yet.")
  def testOverflowInTfFunction(self):
    """Test catching Infinity caused by overflow in a tf.function with while."""
    check_numerics_callback.enable_check_numerics()

    @def_function.function
    def accumulation_function(counter, lim, accum):
      while math_ops.less(counter, lim):
        accum.assign(accum * 2.0)
        counter.assign_add(1)

    counter = variables.Variable(0, dtype=dtypes.int32)
    # Starting from 1.0, repeated `* 2.0` reaches 2**128, which exceeds the
    # float32 maximum (~3.4e38), so the accumulator overflows to +Inf within
    # 128 steps. The 1000-step limit is therefore sufficient.
    lim = constant_op.constant(1000, dtype=dtypes.int32)
    accum = variables.Variable(1.0)

    if not context.executing_eagerly():
      self.evaluate([counter.initializer, accum.initializer])

    message = self._assertRaisesInvalidArgumentErrorAndGetMessage(
        lambda: self.evaluate(accumulation_function(counter, lim, accum)))

    self.assertAllClose(self.evaluate(counter), 128)
    # Check the content of the error message.
    # The overflow to +Infinity happens during the `* 2.0` operation.
    self.assertTrue(re.search(r"graph op.*\"Mul\"", message))
    self.assertTrue(re.search(r"dtype.*float32", message))
    self.assertIn("shape: ()\n", message)
    # Check that the correct input op is printed.
    self.assertIn("Input tensors (2):", message)
    # Check that the correct input ops are printed.
    self.assertTrue(re.search(r"0:.*Tensor.*ReadVariableOp:0", message))
    self.assertTrue(re.search(r"1:.*Tensor.*mul/y:0", message))
    # Check that the correct line for op creation is printed.
    self.assertTrue(re.search(r"Stack trace of op's creation", message))
    self.assertIn("accum.assign(accum * 2.0)", message)

  @test_util.run_in_graph_and_eager_modes
  def testNanInConstIsCaptured(self):
    check_numerics_callback.enable_check_numerics()
    v = variables.Variable(3.0, dtype=dtypes.float32)
    @def_function.function
    def add_a_bad_constant(x):
      c = constant_op.constant(np.nan)
      return x + c
    if not context.executing_eagerly():
      self.evaluate(v.initializer)
    message = self._assertRaisesInvalidArgumentErrorAndGetMessage(
        lambda: self.evaluate(add_a_bad_constant(v)))
    self.assertTrue(re.search(r"graph op.*\"Const\"", message))
    self.assertTrue(re.search(r"dtype:.*float32", message))
    self.assertTrue(re.search(r"shape:.*\(\)", message))
    self.assertTrue(re.search(r"Graph name:.*add_a_bad_constant", message))

  @test_util.run_in_graph_and_eager_modes
  def testCatchInfinityInDatasetMapFunction(self):
    """Test that the callback catches NaN in a tf.dataset map function."""
    check_numerics_callback.enable_check_numerics()

    def generate_nan(x):
      """Intentionally generates NaNs via the log of a negative number."""
      casted_x = math_ops.cast(x, dtypes.float32)
      return math_ops.log([[-1.0, 1.0], [3.0, 5.0]]) + casted_x

    dataset = dataset_ops.Dataset.range(10).map(generate_nan)
    iterator = dataset_ops.make_one_shot_iterator(dataset)

    message = self._assertRaisesInvalidArgumentErrorAndGetMessage(
        lambda: self.evaluate(iterator.get_next()))

    # Check the content of the error message.
    self.assertTrue(re.search(r"graph op.*\"Log\"", message))
    self.assertTrue(re.search(r"dtype.*float32", message))
    self.assertIn("shape: (2, 2)\n", message)
    self.assertTrue(re.search(r"Input tensor.*Tensor.*Log/x:0", message))
    self.assertIn(
        "-> |   return math_ops.log([[-1.0, 1.0], [3.0, 5.0]]) + casted_x",
        message)

  @test_util.run_in_graph_and_eager_modes
  def testCustomGradientWithNaNWithTfFunction(self):
    """Test that callback catches NaN in a gradient function during backprop."""
    check_numerics_callback.enable_check_numerics()

    @custom_gradient.custom_gradient
    def func_with_bad_grad(x):
      output = math_ops.sin(x)
      @def_function.function
      def grad(dy):
        # `dy` will come in as 1.0. Taking log of -1.0 leads to NaN.
        return math_ops.log(-dy)
      return output, grad

    x = constant_op.constant(-2.0, dtype=dtypes.float16)
    def f(x):
      return func_with_bad_grad(x)

    message = self._assertRaisesInvalidArgumentErrorAndGetMessage(
        lambda: gradient_checker_v2.compute_gradient(f, [x]))

    # Check the content of the error message.
    self.assertTrue(re.search(r"graph op.*\"Log\"", message))
    self.assertTrue(re.search(r"dtype.*float16", message))
    if context.executing_eagerly():
      self.assertIn("shape: ()\n", message)
    self.assertTrue(re.search(r"Input tensor.*Tensor.*Neg:0", message))
    self.assertIn("-> |   return math_ops.log(-dy)", message)

  @test_util.run_in_graph_and_eager_modes
  def testNestedFunctionGradientCall(self):
    """Catching inf in the inner nested tf.function during backprop."""
    check_numerics_callback.enable_check_numerics()

    x = constant_op.constant(1.0 - 1e-8, dtype=dtypes.float32)

    @def_function.function
    def asinp1(x):
      # asin()'s gradient 1 / sqrt(1 - x**2) overflows for x close to 1.0.
      return math_ops.asin(x) + 1.0

    @def_function.function
    def loss(x):
      return math_ops.square(asinp1(x))

    with backprop.GradientTape() as tape:
      tape.watch(x)
      y = loss(x)
      message = self._assertRaisesInvalidArgumentErrorAndGetMessage(
          lambda: self.evaluate(tape.gradient(y, x)))
      # Check the content of the error message.
      # Assume the op Reciprocal or Xdivy is used in the gradient function for
      # asin().
      self.assertTrue((re.search(r"graph op.*\"Reciprocal\"", message) or
                       re.search(r"graph op.*\"Xdivy\"", message)))
      self.assertTrue(re.search(r"dtype.*float32", message))

  def testEagerModeUsesCorrectPathLengthAndStackHeightLimits(self):
    check_numerics_callback.enable_check_numerics(
        stack_height_limit=123, path_length_limit=1200)
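    # Patch the error-message formatter so the test can inspect the limits
    # that the callback passes through when the division below makes an Inf.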
    fake_get_check_numerics_error_message = test.mock.MagicMock(
        return_value="dummy_message")
    with test.mock.patch.object(check_numerics_callback,
                                "get_check_numerics_error_message",
                                fake_get_check_numerics_error_message):
      x = constant_op.constant(2.0)
      y = constant_op.constant(0.0)
      self._assertRaisesInvalidArgumentErrorAndGetMessage(
          lambda: x / y)  # Expected to generate an inf.
      (_, call_kwargs) = fake_get_check_numerics_error_message.call_args
      self.assertEqual(call_kwargs["stack_height_limit"], 123)
      self.assertEqual(call_kwargs["path_length_limit"], 1200)

  @test_util.run_in_graph_and_eager_modes
  def testExpectedNaNOpOutputs(self):
    """Test calling operations with benign NaN output."""
    check_numerics_callback.enable_check_numerics()

    # Empty input tensor
    x = constant_op.constant(1, dtype=dtypes.float32, shape=[0, 1, 1, 1])
    scale = constant_op.constant([1], dtype=dtypes.float32)
    offset = constant_op.constant([1], dtype=dtypes.float32)

    # Calling fused_batch_norm with an empty input should output a NaN in the
    # latter four outputs without triggering the check_numerics callback
    batch_norm_res = gen_nn_ops._fused_batch_norm(
        x=x, scale=scale, offset=offset, mean=[], variance=[])

    _, batch_mean, batch_variance, _, _ = self.evaluate(batch_norm_res)

    self.assertTrue(np.isnan(batch_mean.squeeze()))
    self.assertTrue(np.isnan(batch_variance.squeeze()))


if __name__ == "__main__":
  ops.enable_eager_execution()
  googletest.main()