# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Eager-graph unified check numerics callback."""
16
17import collections
18import threading
19
20import numpy as np
21
22from tensorflow.core.protobuf import debug_event_pb2
23from tensorflow.python.debug.lib import op_callbacks_common
24from tensorflow.python.debug.lib import source_utils
25from tensorflow.python.eager import monitoring
26from tensorflow.python.framework import op_callbacks
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import gen_debug_ops
30from tensorflow.python.platform import tf_logging as logging
31from tensorflow.python.util import compat
32from tensorflow.python.util.tf_export import tf_export
33
34
# Many ops have benign NaN outputs; running them with check_numerics
# enabled would create unwanted errors.
# TODO(b/142497024): Replace this allowlist with function decorators in the ops.
IGNORE_OP_OUTPUTS = (
    # For FusedBatchNorm, if the input tensor is empty then batch_mean and
    # batch_variance will be NaN. reserve_space holds intermediate values
    # derived from batch_mean and batch_variance used for gradient calculation.
    (b"FusedBatchNorm", 1),  # batch_mean
    (b"FusedBatchNorm", 2),  # batch_variance
    (b"FusedBatchNorm", 3),  # reserve_space_1
    (b"FusedBatchNorm", 4),  # reserve_space_2

    # Same as above.
    (b"FusedBatchNormV2", 1),  # batch_mean
    (b"FusedBatchNormV2", 2),  # batch_variance
    (b"FusedBatchNormV2", 3),  # reserve_space_1
    (b"FusedBatchNormV2", 4),  # reserve_space_2

    # Same as above, but reserve_space_3 holds additional intermediate values.
    (b"FusedBatchNormV3", 1),  # batch_mean
    (b"FusedBatchNormV3", 2),  # batch_variance
    (b"FusedBatchNormV3", 3),  # reserve_space_1
    (b"FusedBatchNormV3", 4),  # reserve_space_2
    (b"FusedBatchNormV3", 5),  # reserve_space_3
)
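
# Each entry above is an (op_type, output_slot) pair; the callback below skips
# numerics instrumentation of an output when `(op_type_bytes, slot)` appears in
# this tuple, e.g. (b"FusedBatchNorm", 1) exempts FusedBatchNorm's batch_mean.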

# Some frequently used ops are generally safe and we can skip them to reduce
# overhead. NOTE: This list is compiled by observing operations called by
# models in practice and is not a comprehensive list of safe operations.
SAFE_OPS = (
    b"Concat",
    b"ConcatV2",
    b"ExpandDims",
    b"Fill",
    b"Gather",
    b"Maximum",
    b"Minimum",
    b"Reshape",
    b"Slice",
    b"Squeeze",
    b"Stack",
    b"StridedSlice",
    b"StridedSliceGrad",
    b"TensorListConcatV2",
    b"TensorListGather",
    b"TensorListGetItem",
    b"TensorListPopBack",
    b"TensorListStack",
    b"Transpose",
    b"Unpack",
)

_state = threading.local()

_check_numerics_callback_create_counter = monitoring.Counter(
    "/tensorflow/api/python/debugging/check_numerics_callback_create_counter",
    "Counter for number of times the check_numerics op callback is created.")


def limit_string_length(string, max_len=50):
  """Limit the length of input string.

  Args:
    string: Input string.
    max_len: (int or None) If int, the length limit. If None, no limit.

  Returns:
    Possibly length-limited string.
  """
  if max_len is None or len(string) <= max_len:
    return string
  else:
    return "..." + string[len(string) - max_len:]


# A per-graph dictionary mapping the name of each check_numerics-instrumented
# tensor back to the original tensor it wraps, so that error messages can
# report the original input tensor names.
_CHECK_NUMERICS_INPUT_LOOKUP = collections.defaultdict(dict)


def _maybe_lookup_original_input_tensor(graph, tensor):
  if (graph and
      graph in _CHECK_NUMERICS_INPUT_LOOKUP and
      tensor.name in _CHECK_NUMERICS_INPUT_LOOKUP[graph]):
    return _CHECK_NUMERICS_INPUT_LOOKUP[graph][tensor.name]
  else:
    return tensor
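
# For example (names hypothetical): after the callback below wraps an op's
# output in CheckNumericsV2, an entry such as
#   _CHECK_NUMERICS_INPUT_LOOKUP[graph]["CheckNumericsV2:0"] = <original tensor>
# lets get_check_numerics_error_message print the op's original inputs instead
# of the instrumented wrapper tensors.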


def get_check_numerics_error_message(slot,
                                     num_outputs,
                                     op_type,
                                     tensor,
                                     inputs,
                                     graph=None,
                                     traceback=None,
                                     stack_height_limit=30,
                                     path_length_limit=50):
  """Create a meaningful and user-friendly error message about offending tensor.

  The error message reveals the following info about the op that outputs
  NaN/Infinity: dtype, shape (to the extent known at graph-construction time),
  input tensors, and the stack trace of the op's creation (in graph mode).

  Args:
    slot: (int) slot index of the tensor output.
    num_outputs: (int) total number of outputs of the op.
    op_type: (str) Type of the op that generates `tensor`.
    tensor: (Tensor) the offending tensor, i.e., the tensor that contains
      Infinities or NaNs.
    inputs: (array of Tensor) inputs to the op that generates `tensor`.
    graph: (tf.Graph) the graph object that `tensor` belongs to. Available only
      under graph mode.
    traceback: (list of trace frames) the stack trace of the op's creation.
      Available only under graph mode.
    stack_height_limit: (int or None) If int, limit to the height of the stack
      trace printed in the error message. If None, no limit to the height.
    path_length_limit: (int or None) Length limit for file paths included in the
      formatted stack trace.

  Returns:
    (str) A formatted error message.
  """
  eager_vs_graph_qualifier = "graph" if graph else "eagerly-executing"
  message = "\n"
  message += (
      "\n!!! Detected Infinity or NaN in output %d of "
      "%s op \"%s\" (# of outputs: %d) !!!\n" %
      (slot, eager_vs_graph_qualifier, op_type, num_outputs))

  message += "  dtype: %s\n" % tensor.dtype
  message += "  shape: %s\n" % (tensor.shape,)

  if not graph:
    # This is an eager tensor. We can get its numpy value and count
    # NaNs and Infs.
    is_inf = np.isinf(tensor)

    num_neg_inf = np.sum(np.logical_and(np.less(tensor, 0.), is_inf))
    num_pos_inf = np.sum(np.logical_and(np.greater(tensor, 0.), is_inf))
    num_nan = np.sum(np.isnan(tensor))
    if num_neg_inf > 0:
      message += "  # of -Inf elements: %s\n" % num_neg_inf
    if num_pos_inf > 0:
      message += "  # of +Inf elements: %s\n" % num_pos_inf
    if num_nan > 0:
      message += "  # of NaN elements: %s\n" % num_nan

  if len(inputs) > 1:
    message += "\n  Input tensors (%d):\n" % len(inputs)
    for slot, input_tensor in enumerate(inputs):
      message += "         %d: %s\n" % (
          slot, _maybe_lookup_original_input_tensor(graph, input_tensor))
  elif len(inputs) == 1:
    message += "\n  Input tensor: %s\n" % (
        _maybe_lookup_original_input_tensor(graph, inputs[0]))
  if graph and hasattr(graph, "name") and graph.name:
    message += "  Graph name: \"%s\"\n" % graph.name

  # Format the stack trace of the op's creation. Frames inferred to be user
  # code (as opposed to tensorflow-internal files) are marked with "->".
  if graph and traceback:
    message += (
        "\n  Stack trace of op's creation (\"->\": inferred user code):\n")
    frames = traceback
    if stack_height_limit is not None and len(traceback) > stack_height_limit:
      num_omitted_frames = len(traceback) - stack_height_limit
      message += "    + ... (Omitted %d frames)\n" % num_omitted_frames
      frames = traceback[-stack_height_limit:]
    for filepath, lineno, function_name, source_line in frames:
      user_code_indicator = "    "
      if not source_utils.guess_is_tensorflow_py_library(filepath):
        user_code_indicator = " -> "

      message += "    + %s (L%d) %s\n" % (
          limit_string_length(filepath, path_length_limit), lineno,
          function_name)
      if source_line is not None:
        message += "%s|   %s\n" % (user_code_indicator, source_line)
  message += "\n"
  return message
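
# For illustration, a message produced by the function above might look like
# this (op, dtype, shape, and counts are hypothetical; the first line is
# wrapped here for readability):
#
#   !!! Detected Infinity or NaN in output 0 of eagerly-executing op "Log"
#       (# of outputs: 1) !!!
#     dtype: <dtype: 'float32'>
#     shape: (2, 2)
#     # of -Inf elements: 1
#
#     Input tensor: tf.Tensor(...)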


def _debug_summary(x):
  return gen_debug_ops.debug_numeric_summary_v2(
      x,
      tensor_debug_mode=(
          debug_event_pb2.TensorDebugMode.REDUCE_INF_NAN_THREE_SLOTS))
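

# As the mode name suggests, REDUCE_INF_NAN_THREE_SLOTS reduces the full tensor
# to a small three-slot summary whose slots flag the presence of -Inf, +Inf,
# and NaN respectively, so check_numerics_v2 only has to inspect the summary
# rather than every element of the original output.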


class CheckNumericsCallback(object):
  """Wrapper for the numerics-checking callback for thread locality."""

  def __init__(self, stack_height_limit, path_length_limit):
    self._stack_height_limit = stack_height_limit
    self._path_length_limit = path_length_limit
    # A dict mapping Placeholder tensors to their instrumenting debug tensors.
    # Used only under V1 graph mode, where we can't rely on auto control
    # dependency to execute the debug tensors and hence need to attach the debug
    # tensors as control dependencies of the ops that consume the Placeholder.
    self._placeholder_to_debug_tensor = dict()

  def callback(self,
               op_type,
               inputs,
               attrs,
               outputs,
               op_name=None,
               graph=None):
    """Eager-function unified callback for checking numerics."""
    del attrs, op_name  # Unused
    op_type_bytes = compat.as_bytes(op_type)
    is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
    if (op_type_bytes in op_callbacks_common.OP_CALLBACK_SKIP_OPS or
        op_type_bytes in SAFE_OPS):
      return None
    if graph:
      # Under graph mode. Insert check_numerics op.
      instrumented_outputs = []
      if is_v1_graph_mode:
        for input_tensor in inputs:
          if input_tensor in self._placeholder_to_debug_tensor and outputs:
            outputs[0].op._add_control_input(  # pylint: disable=protected-access
                self._placeholder_to_debug_tensor[input_tensor].op)
      for slot, output in enumerate(outputs):
        if (output.dtype.is_floating and
            (op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
          checked_output = array_ops.check_numerics_v2(
              # TF v2 has automatic control dependencies added to stateful async
              # ops, which allows us to run check_numerics asynchronously.
              # In that case we use _debug_summary to reduce each output tensor
              # asynchronously from the op being checked, and then process the
              # tensor summary with check_numerics.
              output if is_v1_graph_mode else _debug_summary(output),
              get_check_numerics_error_message(
                  slot,
                  len(outputs),
                  op_type,
                  output,
                  inputs,
                  graph=graph,
                  traceback=output.op.traceback,
                  stack_height_limit=self._stack_height_limit,
                  path_length_limit=self._path_length_limit))
          _CHECK_NUMERICS_INPUT_LOOKUP[graph][checked_output.name] = output
          instrumented_outputs.append(self._get_output_tensor(
              op_type_bytes, output, checked_output, is_v1_graph_mode))
        else:
          instrumented_outputs.append(output)
      return instrumented_outputs
    else:
      if op_type_bytes == b"CheckNumericsV2":
        # TODO(b/140334369): Remove this special-casing logic once op_callback
        # automatically prevents infinite recursion in eager mode.
        return None
      # Under eager mode. Eagerly execute check_numerics op.
      for slot, output in enumerate(outputs):
        if (output.dtype.is_floating and
            (op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
          array_ops.check_numerics_v2(
              output,
              get_check_numerics_error_message(
                  slot, len(outputs), op_type, output, inputs,
                  stack_height_limit=self._stack_height_limit,
                  path_length_limit=self._path_length_limit))

  def _get_output_tensor(self,
                         op_type,
                         tensor,
                         checked_tensor,
                         is_v1_graph_mode):
    """Determine what tensor to output from callback.

    Args:
      op_type: Type of the op that outputs the original symbolic tensor, as
        `bytes`.
      tensor: The original output symbolic tensor.
      checked_tensor: The debugger-instrumented, numerics-checking tensor.
      is_v1_graph_mode: Whether the debugged program is running under V1 graph
        mode.

    Returns:
      A symbolic tensor to be returned by the dumping op_callback.
    """
    if is_v1_graph_mode:
      # Placeholders need special treatment under V1 graph mode. The
      # callback can't simply override the Placeholder tensor with the debug
      # tensor, as that would cause the Placeholder op to lack a value.
      # The debug tensor is remembered and will be attached as a control
      # input to ops that consume the Placeholder later.
      if op_type == b"Placeholder":
        self._placeholder_to_debug_tensor[tensor] = checked_tensor
        return tensor
      else:
        return checked_tensor
    else:
      # Under non-v1 graph mode, rely on auto control dependency to run the
      # checked tensor.
      return tensor
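
  # Sketch of the V1 placeholder handling above (hypothetical user code):
  #   x = tf.compat.v1.placeholder(tf.float32)  # callback records x -> checked_x
  #   y = tf.math.log(x)  # callback attaches checked_x.op as a control input of
  #                       # y's op, so the placeholder's value gets checked when
  #                       # the graph runs.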


@tf_export("debugging.enable_check_numerics")
def enable_check_numerics(stack_height_limit=30,
                          path_length_limit=50):
  r"""Enable tensor numerics checking in an eager/graph unified fashion.

  The numerics checking mechanism will cause any TensorFlow eager execution or
  graph execution to error out as soon as an op's output tensor contains
  infinity or NaN.

  This method is idempotent. Calling it multiple times has the same effect
  as calling it once.

  This method takes effect only on the thread in which it is called.

  When an op's float-type output tensor contains any Infinity or NaN, an
  `tf.errors.InvalidArgumentError` will be thrown, with an error message that
  reveals the following information:
    - The type of the op that generated the tensor with bad numerics.
    - Data type (dtype) of the tensor.
    - Shape of the tensor (to the extent known at the time of eager execution
      or graph construction).
    - Name of the containing graph (if available).
    - (Graph mode only): The stack trace of the intra-graph op's creation,
      with a stack-height limit and a path-length limit for visual clarity.
      The stack frames that belong to the user's code (as opposed to
      tensorflow's internal code) are highlighted with a text arrow ("->").
    - (Eager mode only): How many of the offending tensor's elements are
      `Infinity` and `NaN`, respectively.

  Once enabled, the check-numerics mechanism can be disabled by using
  `tf.debugging.disable_check_numerics()`.

  Example usage:

  1. Catching infinity during the execution of a `tf.function` graph:

     ```py
     import tensorflow as tf

     tf.debugging.enable_check_numerics()

     @tf.function
     def square_log_x_plus_1(x):
       v = tf.math.log(x + 1)
       return tf.math.square(v)

     x = -1.0

     # When the following line runs, a function graph will be compiled
     # from the Python function `square_log_x_plus_1()`. Due to the
     # `enable_check_numerics()` call above, the graph will contain
     # numerics checking ops that will run during the function graph's
     # execution. The function call generates a -Inf when the Log
     # (logarithm) op operates on the output tensor of the Add op.
     # The program errors out at this line, printing an error message.
     y = square_log_x_plus_1(x)
     z = -y
     ```

  2. Catching NaN during eager execution:

     ```py
     import numpy as np
     import tensorflow as tf

     tf.debugging.enable_check_numerics()

     x = np.array([[0.0, -1.0], [4.0, 3.0]])

     # The following line executes the Sqrt op eagerly. Due to the negative
     # element in the input array, a NaN is generated. Due to the
     # `enable_check_numerics()` call above, the program errors immediately
     # at this line, printing an error message.
     y = tf.math.sqrt(x)
     z = tf.matmul(y, y)
     ```

  NOTE: If your code is running on TPUs, be sure to call
  `tf.config.set_soft_device_placement(True)` before calling
  `tf.debugging.enable_check_numerics()` as this API uses automatic outside
  compilation on TPUs. For example:

  ```py
  tf.config.set_soft_device_placement(True)
  tf.debugging.enable_check_numerics()

  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  strategy = tf.distribute.TPUStrategy(resolver)
  with strategy.scope():
    # ...
  ```

  Args:
    stack_height_limit: Limit to the height of the printed stack trace.
      Applicable only to ops in `tf.function`s (graphs).
    path_length_limit: Limit to the file path included in the printed stack
      trace. Applicable only to ops in `tf.function`s (graphs).
  """
  if not hasattr(_state, "check_numerics_callback"):
    _state.check_numerics_callback = CheckNumericsCallback(
        stack_height_limit, path_length_limit)
  op_callbacks.add_op_callback(_state.check_numerics_callback.callback)

  logging.info(
      "Enabled check-numerics callback in thread %s",
      threading.current_thread().name)
  _check_numerics_callback_create_counter.get_cell().increase_by(1)


@tf_export("debugging.disable_check_numerics")
def disable_check_numerics():
  """Disable the eager/graph unified numerics checking mechanism.

  This method can be used after a call to
  `tf.debugging.enable_check_numerics()` to disable the numerics-checking
  mechanism that catches infinity and NaN values output by ops executed
  eagerly or in tf.function-compiled graphs.

  This method is idempotent. Calling it multiple times has the same effect
  as calling it once.

  This method takes effect only on the thread in which it is called.
  """
  if not hasattr(_state, "check_numerics_callback"):
    return
  try:
    op_callbacks.remove_op_callback(_state.check_numerics_callback.callback)
    delattr(_state, "check_numerics_callback")
    logging.info(
        "Disabled check-numerics callback in thread %s",
        threading.current_thread().name)
  except KeyError:
    # Tolerate disabling the check numerics callback without
    # enable_check_numerics() being called first.
    pass
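

# A minimal end-to-end usage sketch of the two public APIs above (the failing
# computation is illustrative; any op producing an Inf or NaN would trip the
# callback):
#
#   import tensorflow as tf
#
#   tf.debugging.enable_check_numerics()
#   try:
#     tf.math.log(tf.constant(0.0))  # log(0) -> -Inf; raises
#   except tf.errors.InvalidArgumentError as e:
#     print(e.message)  # contains the formatted numerics report
#   tf.debugging.disable_check_numerics()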