xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/lib/dumping_callback.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Dumping op callbacks: Enables dump-based features in tfdbg v2."""
16
17import atexit
18import os
19import re
20import socket
21import threading
22import uuid
23
24
25from tensorflow.core.framework import tensor_pb2
26from tensorflow.core.protobuf import debug_event_pb2
27from tensorflow.core.protobuf import graph_debug_info_pb2
28from tensorflow.python.debug.lib import debug_events_writer
29from tensorflow.python.debug.lib import op_callbacks_common
30from tensorflow.python.debug.lib import source_utils
31from tensorflow.python.eager import function as function_lib
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import op_callbacks
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import tensor_util
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import gen_debug_ops
39from tensorflow.python.platform import tf_logging as logging
40from tensorflow.python.util import compat
41from tensorflow.python.util import tf_stack
42from tensorflow.python.util.tf_export import tf_export
43
# Per-thread state of the dumping callback (e.g., re-entrancy guards).
_state = threading.local()
# Default tensor_debug_mode used by `enable_dump_debug_info()` when the caller
# does not specify one.
DEFAULT_TENSOR_DEBUG_MODE = "NO_TENSOR"

# Byte-string prefixes that mark the op-type names of compiled TF functions
# (forward, backward and inference FuncGraphs). Used to recognize
# eagerly-executed functions in `is_op_type_function()`.
# pylint:disable=protected-access
_FUNCTION_PREFIXES = (
    compat.as_bytes(function_lib._FORWARD_PREFIX),
    compat.as_bytes(function_lib._BACKWARD_PREFIX),
    compat.as_bytes(function_lib._INFERENCE_PREFIX))
# pylint:enable=protected-access
53
54
def is_op_type_function(op_type):
  """Check whether an op-type name refers to a compiled TF function.

  Args:
    op_type: The op type, as a `str` or `bytes` string.

  Returns:
    `True` if `op_type` starts with one of the internal function-name
    prefixes (forward / backward / inference); `False` otherwise.
  """
  op_type_bytes = compat.as_bytes(op_type)
  return op_type_bytes.startswith(_FUNCTION_PREFIXES)
57
58
@ops.RegisterGradient("DebugIdentityV2")
def _debug_identity_v2_grad(op, dy):
  """Gradient function for the DebugIdentityV2 op.

  DebugIdentityV2 behaves as an identity op, so its gradient simply passes
  the incoming gradient `dy` through unchanged.
  """
  del op  # Unused: the gradient does not depend on the forward op.
  return dy
64
65
66def _get_tfdbg_run_id():
67  return str(uuid.uuid4())[:8]
68
69
70def _get_id():
71  """Get a short unique ID."""
72  return str(uuid.uuid4())
73
74
def _concrete_tensor_to_proto(tensor):
  """Convert a concrete (eager) tensor's value to a TensorProto."""
  numpy_value = tensor.numpy()
  return tensor_util.make_tensor_proto(numpy_value)
77
78
class _DumpingCallback(object):
  """An object holding the states surrounding the dumping callback."""

  def __init__(self,
               dump_root,
               tensor_debug_mode,
               circular_buffer_size,
               op_regex,
               tensor_dtypes):
    """Constructor of _DumpingCallback.

    Args:
      dump_root: Root directory to which the debug events are written.
      tensor_debug_mode: A tfdbg TensorDebugMode value determining how much of
        each tensor's value is dumped.
      circular_buffer_size: Size of the circular buffer used by the
        debug-events writer.
      op_regex: Regular expression (pattern string or compiled pattern)
        matched against op type names via `re.match` to select which ops get
        dumped; `None` means no op-type filtering.
      tensor_dtypes: A list/tuple of dtypes, or a callable that takes a
        `dtypes.DType` and returns a bool, used to filter tensors by dtype;
        `None` means no dtype filtering.
    """
    self._dump_root = dump_root
    self._tfdbg_run_id = _get_tfdbg_run_id()
    self._tensor_debug_mode = tensor_debug_mode
    self._circular_buffer_size = circular_buffer_size
    self._op_regex = op_regex
    self._tensor_dtypes = tensor_dtypes

    self._hostname = socket.gethostname()
    # A list of source-file paths. The index of a path in this list serves as
    # its file index in dumped FileLineCol protos.
    self._source_file_paths = []
    # A map from stack frame (FileLineCol) to unique ID.
    self._stack_frame_to_id = dict()
    # Mapping op context to unique ID.
    self._context_to_id = dict()
    # Maps a Function object (e.g., _EagerDefinedFunction) to the ID of its
    # FuncGraph (context). Populated by `function_callback()`.
    self._function_to_graph_id = dict()
    # Caches the mapping from a function's op-type name (e.g.,
    # b"__inference_foo_123") to its graph (context) ID.
    self._op_type_to_context_id = dict()
    # Keeps track of counter for symbolic tensors output by in-graph ops.
    # It is used to make unique names for debugger-generated tensors.
    self._symbolic_tensor_counter = 0
    # A map from the names of debugger-generated Identity and DebugIdentityV2
    # tensors to the names of the original instrumented graph tensors. This is
    # applicable to v1 graph mode only.
    self._tensor_aliases = dict()
    self._source_file_paths_lock = threading.Lock()
    self._stack_frame_to_id_lock = threading.Lock()
    self._context_lock = threading.Lock()
    self._symbolic_tensor_counter_lock = threading.Lock()
    # A dict mapping Placeholder tensors to their instrumenting debug tensors.
    # Used only under V1 graph mode, where we can't rely on auto control
    # dependency to execute the debug tensors and hence need to attach the debug
    # tensors as control dependencies of the ops that consume the Placeholder.
    self._placeholder_to_debug_tensor = dict()
    # Lazily created by `get_writer()`; reset to None when `dump_root` changes.
    self._writer = None

  def function_callback(self, function, name, graph, inputs, outputs):
    """A callback to be called on creation of Functions.

    Used to establish a join between function name and graph (context) ID.

    Args:
      function: The just-created Function.
      name: Name of the function.
      graph: FuncGraph, the graph containing the operations in the function.
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs from the function
    """
    del name, inputs, outputs

    graph_id = self._get_context_id(graph)
    with self._context_lock:
      # NOTE(cais): We currently store the function (_EagerDefinedFunction)
      # as keys of this dict, because weakrefs to them sometimes become
      # unreferenceable by the time the op callback is called. This approach
      # may cause memory leaks due to the holding of the functions. If that's
      # the case, calling `tf.debugging.disable_dump_debug_info()` should
      # cause GC of this object and this dict.
      self._function_to_graph_id[function] = graph_id

  @property
  def dump_root(self):
    return self._dump_root

  @dump_root.setter
  def dump_root(self, dump_root):
    if self._dump_root != dump_root:
      self._dump_root = dump_root
      # Invalidate the cached writer so a new one pointed at the new dump root
      # is created on the next `get_writer()` call.
      self._writer = None

  @property
  def tfdbg_run_id(self):
    return self._tfdbg_run_id

  @property
  def tensor_debug_mode(self):
    return self._tensor_debug_mode

  @property
  def circular_buffer_size(self):
    return self._circular_buffer_size

  def get_writer(self):
    """Get the debug events writer for the currently configured dump root."""
    if not self._writer:
      self._writer = debug_events_writer.DebugEventsWriter(
          self._dump_root,
          self._tfdbg_run_id,
          circular_buffer_size=self._circular_buffer_size)
    return self._writer

  def _get_context_id(self, context):
    """Get a unique ID for an op-construction context (e.g., a graph).

    If the graph has been encountered before, reuse the same unique ID.
    When encountering a new context (graph), this method writes a DebugEvent
    proto with the debugged_graph field to the proper DebugEvent file.

    Args:
      context: A context to get the unique ID for. Must be hashable. E.g., a
        Graph object.

    Returns:
      A unique ID for the context.
    """
    # Use the double-checked lock pattern to optimize the common case.
    if context in self._context_to_id:  # 1st check, without lock.
      return self._context_to_id[context]
    graph_is_new = False
    with self._context_lock:
      if context not in self._context_to_id:  # 2nd check, with lock.
        graph_is_new = True
        context_id = _get_id()
        self._context_to_id[context] = context_id
    if graph_is_new:
      # Write the DebuggedGraph proto outside the lock. `context_id` is bound
      # only on the thread that created the new entry, which is the only
      # thread for which `graph_is_new` is True.
      self.get_writer().WriteDebuggedGraph(debug_event_pb2.DebuggedGraph(
          graph_id=context_id,
          graph_name=getattr(context, "name", None),
          outer_context_id=self._get_outer_context_id(context)))
    return self._context_to_id[context]

  def _get_outer_context_id(self, graph):
    """Get the ID of the immediate outer context of the input graph.

    Args:
      graph: The graph (context) in question.

    Returns:
      If an outer context exists, the immediate outer context name as a string.
      If such an outer context does not exist (i.e., `graph` is itself
      outermost), `None`.
    """
    if hasattr(graph, "outer_graph") and graph.outer_graph:
      return self._get_context_id(graph.outer_graph)
    else:
      return None

  def _write_source_file_content(self, file_path):
    """Send the content of a source file via debug-events writer.

    Args:
      file_path: Path to the source file.

    Returns:
      An int index for the file.
    """
    # Double-checked locking: unlocked fast path for files seen before.
    if file_path in self._source_file_paths:
      return self._source_file_paths.index(file_path)
    with self._source_file_paths_lock:
      if file_path not in self._source_file_paths:
        lines = None
        if source_utils.is_extension_uncompiled_python_source(file_path):
          try:
            lines, _ = source_utils.load_source(file_path)
          except IOError as e:
            # Best effort: an unreadable source file is logged and dumped
            # without its lines, rather than aborting the instrumentation.
            logging.warn(
                "Failed to read source code from path: %s. Reason: %s",
                file_path, e)
        writer = self.get_writer()
        writer.WriteSourceFile(debug_event_pb2.SourceFile(
            file_path=file_path, host_name=self._hostname, lines=lines))
        self._source_file_paths.append(file_path)
      return self._source_file_paths.index(file_path)

  def _process_stack_frames(self):
    """Process stack frames.

    Send the content of source-files, on a best-effort basis.

    Returns:
      A CodeLocation proto carrying the host name and the IDs of the stack
      frames of the current Python call stack.
    """
    stack_frames = tf_stack.extract_stack()
    stack_frame_ids = []
    writer = None
    for file_path, lineno, func, _ in stack_frames:
      abs_path = os.path.abspath(file_path)
      # Double-checked locking: unlocked fast path for frames seen before.
      if (abs_path, lineno, func) in self._stack_frame_to_id:
        stack_frame_ids.append(
            self._stack_frame_to_id[(abs_path, lineno, func)])
        continue
      with self._stack_frame_to_id_lock:
        if (abs_path, lineno, func) not in self._stack_frame_to_id:
          stack_frame_id = _get_id()
          self._stack_frame_to_id[(abs_path, lineno, func)] = stack_frame_id
          file_index = self._write_source_file_content(abs_path)
          file_line_col = graph_debug_info_pb2.GraphDebugInfo.FileLineCol(
              file_index=file_index, line=lineno, func=func)
          stack_frame_with_id = debug_event_pb2.StackFrameWithId(
              id=stack_frame_id, file_line_col=file_line_col)
          writer = self.get_writer()
          writer.WriteStackFrameWithId(stack_frame_with_id)
        stack_frame_ids.append(
            self._stack_frame_to_id[(abs_path, lineno, func)])

    code_location = debug_event_pb2.CodeLocation(
        host_name=self._hostname, stack_frame_ids=stack_frame_ids)
    return code_location

  def _process_v1_graph_mode_tensor(self,
                                    op_type,
                                    tensor,
                                    debug_tensor,
                                    tensor_debug_mode):
    """For V1 graph mode, determine what tensor to output from callback.

    Args:
      op_type: Type of the op that outputs the original symbolic tensor.
      tensor: The original output symbolic tensor.
      debug_tensor: The debugger-instrumented tensor.
      tensor_debug_mode: Debug mode used, a tfdbg TensorDebugMode enum.

    Returns:
      A symbolic tensor to be returned by the dumping op_callback.
    """
    # Placeholders need special treatment under V1 graph mode. The
    # callback can't simply override the Placeholder tensor to a debug tensor,
    # as that would cause the Placeholder op to lack a value.
    if op_type in ("Placeholder", "PlaceholderWithDefault"):
      # Remember the debug tensor; `callback()` later attaches it as a control
      # dependency of the ops that consume this Placeholder.
      self._placeholder_to_debug_tensor[tensor] = debug_tensor
      return tensor
    else:
      # TODO(cais): Evaluate performance optimization options. For the
      # `NO_TENSOR` debug mode, an alternative is to add `debug_tensor` as a
      # control dependency of `tensor.op` without an additional identity op.
      if (tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR and
          op_type != "Const"):
        # NOTE(b/153716279): Under v1 graph mode, overriding the output tensor
        # of Const ops can lead to downstream errors related to shapes. We opt
        # to use an identity op to avoid this issue at the cost of slightly
        # larger graph size.
        self._tensor_aliases[debug_tensor.name] = tensor.name
        return debug_tensor
      else:
        with self._symbolic_tensor_counter_lock:
          identity_name = "tfdbg_identity_%d" % self._symbolic_tensor_counter
        identity = array_ops.identity(tensor, name=identity_name)
        identity.op._add_control_input(  # pylint: disable=protected-access
            debug_tensor.op)
        self._tensor_aliases[identity.name] = tensor.name
        return identity

  def _instrument_symbolic_tensors(self,
                                   tensors,
                                   op_type,
                                   op_name,
                                   tfdbg_context_id,
                                   tensor_ids):
    """Add debugging instrumentation for symbolic (i.e., non-eager) tensors.

    The detailed fashion in which the tensors are instrumented is determined
    by the tensor_debug_mode configured for the currently enabled dumping
    callback.

    Args:
      tensors: A tuple of Tensors to instrument. It is assumed that their
        ordering corresponds to the ordering of output tensors of an original
        op. Output slot indices (0-based) will be generated based on the
        ordering.
      op_type: Type name of the op that emits the Tensors (e.g., "MatMul").
      op_name: Name of the op that emits the Tensors (e.g., "dense_1/MatMul").
      tfdbg_context_id: A unique ID for the context that the op belongs to
        (e.g., a graph).
      tensor_ids: A list of unique ID numbers for the tensors, for tfdbg's
        internal use.

    Returns:
      Non-eager Tensors that override the `tensors` as the output of the op
      that originally generated `tensors`. In some cases (e.g., non-V1 graph
      mode), this may be `None`, as the instrumentation can simply rely on
      automatic control dependencies (see `auto_control_deps.py`) instead of
      tensor overriding.
    """
    tensor_debug_mode = self._tensor_debug_mode
    debug_urls = ["file://%s" % self._dump_root]
    is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
    instrumented_tensors = [] if is_v1_graph_mode else None
    for output_slot, tensor in enumerate(tensors):
      with self._symbolic_tensor_counter_lock:
        debug_identity_name = ("DebugIdentityV2_%d" %
                               self._symbolic_tensor_counter)
      debug_identity_op_kwargs = {
          "tfdbg_context_id": tfdbg_context_id,
          "op_name": op_name,
          "output_slot": output_slot,
          "tensor_debug_mode": self._tensor_debug_mode,
          "debug_urls": debug_urls,
          "name": debug_identity_name,
          "circular_buffer_size": self._circular_buffer_size,
          "tfdbg_run_id": self._tfdbg_run_id,
      }
      if tensor_debug_mode == debug_event_pb2.TensorDebugMode.NO_TENSOR:
        if (not self._should_dump_tensor(op_type, tensor.dtype) or
            not tensor.dtype.is_numpy_compatible):
          if is_v1_graph_mode:
            instrumented_tensors.append(tensor)
          continue
        # NOTE(review): this branch appears unreachable — a
        # non-numpy-compatible dtype already hits the `continue` above.
        # Confirm before removing.
        if is_v1_graph_mode and not tensor.dtype.is_numpy_compatible:
          # Avoid instrumenting Placeholder under is_v1_graph_mode. Doing that
          # would cause runtime complaint about Placeholders not being fed.
          instrumented_tensors.append(tensor)
          continue
        # Except in V1 graph mode + control flow, debug_identity_v2 triggers
        # auto control dependency because it's a stateful op.
        debug_tensor = gen_debug_ops.debug_identity_v2(
            # Use an empty (shape=[0]) float32 tensor for the NO_TENSOR mode
            # as a low-overhead placeholder, since no actual tensor value is
            # traced.
            constant_op.constant([], dtype=dtypes.float32),
            **debug_identity_op_kwargs)
        if is_v1_graph_mode:
          instrumented_tensors.append(self._process_v1_graph_mode_tensor(
              op_type, tensor, debug_tensor, tensor_debug_mode))
      elif tensor_debug_mode in (debug_event_pb2.TensorDebugMode.CURT_HEALTH,
                                 debug_event_pb2.TensorDebugMode.CONCISE_HEALTH,
                                 debug_event_pb2.TensorDebugMode.FULL_HEALTH,
                                 debug_event_pb2.TensorDebugMode.SHAPE):
        dtype = tensor.dtype
        # The health modes summarize only floating-point tensors; SHAPE also
        # covers integer and boolean tensors.
        dtype_is_dumpable = (
            tensor_debug_mode in (
                debug_event_pb2.TensorDebugMode.CURT_HEALTH,
                debug_event_pb2.TensorDebugMode.CONCISE_HEALTH,
                debug_event_pb2.TensorDebugMode.FULL_HEALTH) and
            dtype.is_floating or
            tensor_debug_mode == debug_event_pb2.TensorDebugMode.SHAPE and
            (dtype.is_floating or dtype.is_integer or dtype.is_bool))
        if (not self._should_dump_tensor(op_type, tensor.dtype) or
            not dtype_is_dumpable):
          if is_v1_graph_mode:
            instrumented_tensors.append(tensor)
          continue
        debug_tensor = gen_debug_ops.debug_identity_v2(
            gen_debug_ops.debug_numeric_summary_v2(
                tensor,
                tensor_id=tensor_ids[output_slot],
                tensor_debug_mode=self._tensor_debug_mode,
                output_dtype=dtypes.float64), **debug_identity_op_kwargs)
        if is_v1_graph_mode:
          instrumented_tensors.append(self._process_v1_graph_mode_tensor(
              op_type, tensor, debug_tensor, tensor_debug_mode))
      elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
        if (not self._should_dump_tensor(op_type, tensor.dtype) or
            not tensor.dtype.is_numpy_compatible):
          # Instrumenting DT_VARIANT and DT_RESOURCE type tensors under
          # V1 graph mode is known to have issues. TODO(cais): Investigate.
          if is_v1_graph_mode:
            instrumented_tensors.append(tensor)
          continue
        debug_tensor = gen_debug_ops.debug_identity_v2(
            tensor, **debug_identity_op_kwargs)
        if is_v1_graph_mode:
          instrumented_tensors.append(self._process_v1_graph_mode_tensor(
              op_type, tensor, debug_tensor, tensor_debug_mode))
      else:
        raise NotImplementedError(
            "Symbolic tensor instrumentation is not implemented for debug mode "
            "%s" % self._tensor_debug_mode)
    return instrumented_tensors

  def _dump_eager_tensors(self,
                          tensors,
                          op_type,
                          input_tensor_ids,
                          output_tensor_device_ids,
                          graph_id=None):
    """Dump the value of eager tensors.

    The destination of the dumping is determined by the dump_root of the
    currently enabled dumping callback. The tensors may be transformed prior to
    dumping (e.g., reduced as summary statistics such as minimum, maximum and
    arithmetic mean). The details of this transformation (if any) depends on
    the tensor_debug_mode of the currently enabled dumping callback.

    Args:
      tensors: The EagerTensors whose values are to be dumped, with or without
        value transform.
      op_type: Type of the op that generates the tensors, as a string.
      input_tensor_ids: IDs of the input EagerTensors to the op.
      output_tensor_device_ids: Debugger-generated IDs for the devices on which
        the output tensors are allocated, as a `list` of `int`s. Must match
        `tensors` in length.
      graph_id: ID of the executed graph, applicable only to eager execution of
        a FuncGraph.

    Returns:
      A tfdbg Execution protocol buffer.
    """
    tensor_debug_mode = self._tensor_debug_mode
    output_tensor_ids = [
        t._id for t in tensors]  # pylint:disable=protected-access
    assert len(tensors) == len(output_tensor_device_ids)
    if tensor_debug_mode == debug_event_pb2.TensorDebugMode.NO_TENSOR:
      # NO_TENSOR mode: record the execution event without any tensor values.
      return debug_event_pb2.Execution(
          op_type=op_type,
          graph_id=graph_id,
          num_outputs=len(tensors),
          input_tensor_ids=input_tensor_ids,
          output_tensor_ids=output_tensor_ids,
          output_tensor_device_ids=output_tensor_device_ids,
          tensor_debug_mode=tensor_debug_mode,
          code_location=self._process_stack_frames())
    elif tensor_debug_mode in (debug_event_pb2.TensorDebugMode.CURT_HEALTH,
                               debug_event_pb2.TensorDebugMode.CONCISE_HEALTH,
                               debug_event_pb2.TensorDebugMode.FULL_HEALTH,
                               debug_event_pb2.TensorDebugMode.SHAPE,
                               debug_event_pb2.TensorDebugMode.FULL_TENSOR):
      execution_proto = debug_event_pb2.Execution(
          op_type=op_type,
          num_outputs=len(tensors),
          graph_id=graph_id,
          input_tensor_ids=input_tensor_ids,
          output_tensor_ids=output_tensor_ids,
          output_tensor_device_ids=output_tensor_device_ids,
          tensor_debug_mode=tensor_debug_mode,
          code_location=self._process_stack_frames())
      for tensor in tensors:
        if (self._should_dump_tensor(op_type, tensor.dtype) and
            tensor.dtype.is_numpy_compatible):
          if tensor_debug_mode in (
              debug_event_pb2.TensorDebugMode.CURT_HEALTH,
              debug_event_pb2.TensorDebugMode.CONCISE_HEALTH,
              debug_event_pb2.TensorDebugMode.FULL_HEALTH):
            if tensor.dtype.is_floating:
              tensor_proto = _concrete_tensor_to_proto(
                  gen_debug_ops.debug_numeric_summary_v2(
                      tensor,
                      tensor_debug_mode=tensor_debug_mode,
                      output_dtype=dtypes.float64))
            else:
              # A placeholder for non-floating-type output tensors.
              tensor_proto = tensor_pb2.TensorProto()
          elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.SHAPE:
            if (tensor.dtype.is_floating or tensor.dtype.is_integer or
                tensor.dtype.is_bool):
              tensor_proto = _concrete_tensor_to_proto(
                  gen_debug_ops.debug_numeric_summary_v2(
                      tensor,
                      tensor_debug_mode=tensor_debug_mode,
                      output_dtype=dtypes.float64))
            else:
              # A placeholder for non-floating-type output tensors.
              tensor_proto = tensor_pb2.TensorProto()
          elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
            tensor_proto = _concrete_tensor_to_proto(tensor)
          if tensor_proto:
            execution_proto.tensor_protos.append(tensor_proto)
      return execution_proto
    else:
      raise NotImplementedError(
          "Tensor instrumentation is not implemented for debug mode %s yet " %
          self._tensor_debug_mode)

  def callback(self,
               op_type,
               inputs,
               attrs,
               outputs,
               op_name=None,
               graph=None):
    """Op callback for tracing (dumping) a TF program's execution.

    Args:
      op_type: Type of the op, as a string (e.g., "MatMul"); under eager
        execution this may also be the name of a compiled function.
      inputs: Input tensors to the op.
      attrs: Attributes of the op (unused).
      outputs: Output tensors of the op.
      op_name: Name of the op; applicable to graph construction only.
      graph: The graph being constructed, if any; `None` under eager
        execution.

    Returns:
      Under graph construction, possibly the instrumented tensors that
      override the op's original outputs (see `_instrument_symbolic_tensors`);
      otherwise `None`.
    """
    del attrs  # Unused

    writer = self.get_writer()
    if graph:
      is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
      context_id = self._get_context_id(graph)  # Innermost context ID.
      output_tensor_ids = self._get_symbolic_tensor_ids(len(outputs))
      if op_type in ("Const", "Placeholder", "PlaceholderWithDefault"):
        # In some cases, the op name of a Const or Placeholder op in a graph
        # can be duplicate (e.g., `None` or "resource").
        # When this happens, we use the output tensor name to infer
        # the non-duplicated tensor name.
        op_name = outputs[0].name.split(":")[0]
      if is_v1_graph_mode:
        # Attach the debug tensors of any consumed Placeholders as control
        # inputs, so they execute even without auto control dependencies.
        for input_tensor in inputs:
          if input_tensor in self._placeholder_to_debug_tensor and outputs:
            outputs[0].op._add_control_input(  # pylint: disable=protected-access
                self._placeholder_to_debug_tensor[input_tensor].op)
      graph_op_creation = debug_event_pb2.GraphOpCreation(
          op_type=op_type,
          op_name=op_name,
          graph_name=graph.name if hasattr(graph, "name") else None,
          graph_id=context_id,
          input_names=[
              self._lookup_tensor_name(input_tensor) for input_tensor in inputs
          ],
          num_outputs=len(outputs),
          output_tensor_ids=output_tensor_ids,
          code_location=self._process_stack_frames())
      writer.WriteGraphOpCreation(graph_op_creation)
      if outputs and compat.as_bytes(
          op_type) not in op_callbacks_common.OP_CALLBACK_SKIP_OPS:
        return self._instrument_symbolic_tensors(
            outputs, op_type, op_name, context_id, output_tensor_ids)
    else:
      op_type_bytes = compat.as_bytes(op_type)
      if op_type_bytes == b"DebugNumericSummaryV2":
        # TODO(b/140334369): Remove this special casing logic once op_callback.
        # automatically prevents infinite recursion in eager mode.
        return None
      if op_type_bytes in op_callbacks_common.OP_CALLBACK_SKIP_OPS:
        return None
      context_id = self._func_graph_id_from_func_name(op_type)
      input_ids = [t._id for t in inputs]  # pylint:disable=protected-access
      output_tensor_device_ids = [writer.RegisterDeviceAndGetId(output.device)
                                  for output in outputs] if outputs else []
      writer.WriteExecution(self._dump_eager_tensors(
          outputs, op_type, input_ids, output_tensor_device_ids,
          graph_id=context_id))

  def _lookup_tensor_name(self, tensor):
    """Look up the name of a graph tensor.

    This method maps the name of a debugger-generated Identity or
    DebugIdentityV2 tensor to the name of the original instrumented tensor,
    if `tensor` is such a debugger-created tensor.
    Otherwise, it returns the name of `tensor` as is.

    Args:
      tensor: The graph tensor to look up the name for.

    Returns:
      Name of the original instrumented tensor as known to the debugger.
    """
    return self._tensor_aliases.get(tensor.name, tensor.name)

  def _func_graph_id_from_func_name(self, op_type):
    """Attempt to get the ID of a FuncGraph based on an op type name.

    Also caches the ID for faster access later.

    Args:
      op_type: Op type string, which may be the name of a function.

    Returns:
      If the op_type name does not fit the pattern of a function name (e.g.,
      one that starts with "__inference_"), `None` is returned immediately.
      Else, if the FuncGraph is found, ID of the underlying FuncGraph is
      returned as a string.
      Else, `None` is returned.
    """
    op_type = compat.as_bytes(op_type)
    if is_op_type_function(op_type):
      # op_type for eagerly-executed FuncGraphs have the prefixed and suffixed
      # form such as "__inference_my_function_13579", wherein the middle part
      # "my_function" is the name of the Python function from which the
      # FuncGraph is compiled. Due to the suffix, the op_type is unique for
      # - duplicate Python function names
      # - multiple compilation of the same Python function
      if op_type in self._op_type_to_context_id:
        return self._op_type_to_context_id[op_type]
      with self._context_lock:
        # Linear scan over the registered functions; the result is cached in
        # `_op_type_to_context_id` so the scan happens at most once per name.
        for function in self._function_to_graph_id:
          if function.name == op_type:
            graph_id = self._function_to_graph_id[function]
            self._op_type_to_context_id[op_type] = graph_id
            return graph_id
      return None
    else:
      return None

  def _get_symbolic_tensor_ids(self, num_tensors):
    """Assign `num_tensors` unique IDs by advancing the symbolic counter."""
    tensor_ids = []
    if num_tensors:
      with self._symbolic_tensor_counter_lock:
        for _ in range(num_tensors):
          self._symbolic_tensor_counter += 1
          tensor_ids.append(self._symbolic_tensor_counter)
    return tensor_ids

  def _should_dump_tensor(self, op_type, dtype):
    """Determine if the given tensor's value will be dumped.

    The determination is made given the configurations such as `op_regex`,
    `tensor_dtypes`.

    Args:
      op_type: Name of the op's type, as a string (e.g., "MatMul").
      dtype: The dtype of the tensor, as a `dtypes.DType` object.

    Returns:
      A truthy/falsy value indicating whether the tensor's value will be
      dumped. (Note: when `op_regex` is in effect, this may be a truthy
      `re.Match` object rather than a strict bool.)
    """
    should_dump = True
    if self._op_regex:
      should_dump = (should_dump and
                     re.match(self._op_regex, op_type))
    if self._tensor_dtypes:
      if isinstance(self._tensor_dtypes, (list, tuple)):
        should_dump = (should_dump and
                       any(dtype == dtype_item for dtype_item
                           in self._tensor_dtypes))
      else:  # A callable that takes a DType argument and return a boolean.
        should_dump = should_dump and self._tensor_dtypes(dtype)
    return should_dump
681
682
@tf_export("debugging.experimental.enable_dump_debug_info")
def enable_dump_debug_info(dump_root,
                           tensor_debug_mode=DEFAULT_TENSOR_DEBUG_MODE,
                           circular_buffer_size=1000,
                           op_regex=None,
                           tensor_dtypes=None):
  """Enable dumping debugging information from a TensorFlow program.

  The debugging information is dumped to a directory on the file system
  specified as `dump_root`.

  The dumped debugging information can be ingested by debugger UIs.

  The files in the dump directory contain the following information:
    - TensorFlow Function construction (e.g., compilation of Python functions
      decorated with @tf.function), the op types, names (if available), context,
      the input and output tensors, and the associated stack traces.
    - Execution of TensorFlow operations (ops) and Functions and their stack
      traces, op types, names (if available) and contexts. In addition,
      depending on the value of the `tensor_debug_mode` argument (see Args
      section below), the value(s) of the output tensors or more concise
      summaries of the tensor values will be dumped.
    - A snapshot of Python source files involved in the execution of the
      TensorFlow program.

  Once enabled, the dumping can be disabled with the corresponding
  `disable_dump_debug_info()` method under the same Python namespace.
  Calling this method more than once with the same `dump_root` is idempotent.
  Calling this method more than once with different `tensor_debug_mode`s
  leads to a `ValueError`.
  Calling this method more than once with different `circular_buffer_size`s
  leads to a `ValueError`.
  Calling this method with a different `dump_root` abolishes the
  previously-enabled `dump_root`.

  Usage example:

  ```py
  tf.debugging.experimental.enable_dump_debug_info('/tmp/my-tfdbg-dumps')

  # Code to build, train and run your TensorFlow model...
  ```

  NOTE: If your code is running on TPUs, be sure to call
  `tf.config.set_soft_device_placement(True)` before calling
  `tf.debugging.experimental.enable_dump_debug_info()` as this API uses
  automatic outside compilation on TPUs. For example:

  ```py
  tf.config.set_soft_device_placement(True)
  tf.debugging.experimental.enable_dump_debug_info(
      logdir, tensor_debug_mode="FULL_HEALTH")

  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  strategy = tf.distribute.TPUStrategy(resolver)
  with strategy.scope():
    # ...
  ```

  Args:
    dump_root: The directory path where the dumping information will be written.
    tensor_debug_mode: Debug mode for tensor values, as a string.
      The currently supported options are:
      - "NO_TENSOR": (Default) Only traces the output tensors of all executed
        ops (including those executed eagerly at the Python level or as a part
        of a TensorFlow graph) and functions, while not extracting any
        information from the values of the tensors.
      - "CURT_HEALTH": For each floating-dtype tensor (e.g., tensors of dtypes
        such as `float32`, `float64` and `bfloat16`), extracts a binary bit
        indicating whether it contains any -infinity, +infinity or NaN.
      - "CONCISE_HEALTH": For each floating-dtype tensor, extract total
        element count, and counts of -infinity, +infinity and NaN elements.
      - "FULL_HEALTH": For each floating-dtype tensor, extracts the dtype,
        rank (number of dimensions), total element count, and counts of
        -infinity, +infinity and NaN elements.
      - "SHAPE": For each tensor (regardless of dtype), extracts its dtype,
        rank, total element count and shape.
    circular_buffer_size: Size of the circular buffers for execution events.
      These circular buffers are designed to reduce the overhead of debugging
      dumping. They hold the most recent debug events concerning eager execution
      of ops and `tf.function`s and traces of tensor values computed inside
      `tf.function`s. They are written to the file system only when the proper
      flushing method is called (see description of return values below).
      Expected to be an integer. If <= 0, the circular-buffer behavior will be
      disabled, i.e., the execution debug events will be written to the file
      writers in the same way as non-execution events such as op creations and
      source-file snapshots.
    op_regex: Dump data from only the tensors from op types that matches to the
      regular expression (through Python's `re.match()`).
      "Op type" refers to the names of the TensorFlow operations (e.g.,
      "MatMul", "LogSoftmax"), which may repeat in a TensorFlow
      function. It does *not* refer to the names of nodes (e.g.,
      "dense/MatMul", "dense_1/MatMul_1") which are unique within a function.
      - Example 1: Dump tensor data from only MatMul and Relu ops
        `op_regex="^(MatMul|Relu)$"`.
      - Example 2: Dump tensors from all ops *except* Relu:
        `op_regex="(?!^Relu$)"`.
      This filter operates in a logical AND relation with `tensor_dtypes`.
    tensor_dtypes: Dump data from only the tensors of which the specified
      dtypes. This optional argument can be in any of the following format:
      - a list or tuple of `DType` objects or strings that can be converted
        to `DType` objects via `tf.as_dtype()`. Examples:
        - `tensor_dtype=[tf.float32, tf.float64]`,
        - `tensor_dtype=["float32", "float64"]`,
        - `tensor_dtypes=(tf.int32, tf.bool)`,
        - `tensor_dtypes=("int32", "bool")`
      - a callable that takes a single `DType` argument and returns a Python
        `boolean` indicating whether the dtype is to be included in the data
        dumping. Examples:
        - `tensor_dtype=lambda dtype: dtype.is_integer`.
      This filter operates in a logical AND relation with `op_regex`.
  Returns:
    A DebugEventsWriter instance used by the dumping callback. The caller
    may use its flushing methods, including `FlushNonExecutionFiles()` and
    `FlushExecutionFiles()`.
  """
  # TODO(cais): Revise the "UIs (currently under construction)" part of the doc
  # string above.
  # TODO(cais): Add Python code example to the doc string above.
  global _state

  # Reject mode strings that are not members of the TensorDebugMode enum.
  mode_names = debug_event_pb2.TensorDebugMode.keys()
  if tensor_debug_mode not in mode_names:
    raise ValueError(
        "Invalid value in tensor_debug_mode ('%s'). Valid options are: %s" %
        (tensor_debug_mode, mode_names))

  # Convert the string to its numeric enum value; only a subset of the enum
  # members have a dumping implementation so far.
  tensor_debug_mode = debug_event_pb2.TensorDebugMode.Value(tensor_debug_mode)
  implemented_modes = (debug_event_pb2.TensorDebugMode.NO_TENSOR,
                       debug_event_pb2.TensorDebugMode.CURT_HEALTH,
                       debug_event_pb2.TensorDebugMode.CONCISE_HEALTH,
                       debug_event_pb2.TensorDebugMode.FULL_HEALTH,
                       debug_event_pb2.TensorDebugMode.SHAPE,
                       debug_event_pb2.TensorDebugMode.FULL_TENSOR)
  if tensor_debug_mode not in implemented_modes:
    raise NotImplementedError(
        "tfdbg dumping: support for tensor debug mode %s is not "
        "implemented yet" %
        debug_event_pb2.TensorDebugMode.Name(tensor_debug_mode))

  # Validate and normalize tensor_dtypes: a list/tuple is coerced to a list
  # of `DType` objects; otherwise it must be a callable.
  if tensor_dtypes is not None:
    if isinstance(tensor_dtypes, (list, tuple)):
      tensor_dtypes = [
          dtypes.as_dtype(dtype_item) for dtype_item in tensor_dtypes]
    elif not callable(tensor_dtypes):
      raise ValueError(
          "If specified, tensor_dtypes is expected to be a list, a tuple, or "
          "a callable that takes a DType argument and returns a boolean, "
          "but received %s" % (tensor_dtypes,))

  existing_callback = getattr(_state, "dumping_callback", None)
  if existing_callback is not None:
    # A callback is already active in this thread: the buffer size and the
    # tensor-debug mode are immutable once set.
    if existing_callback.circular_buffer_size != circular_buffer_size:
      raise ValueError(
          "There is already a dumping callback configured with a different "
          "circular-buffer size (%d). Therefore the newly request "
          "circular-buffer size (%d) will not be honored." %
          (existing_callback.circular_buffer_size, circular_buffer_size))
    if existing_callback.tensor_debug_mode != tensor_debug_mode:
      raise ValueError(
          "There is already a dumping callback configured for dump root "
          "%s with a different "
          "tensor-debug mode (%s). Therefore the newly request "
          "tensor-debug mode (%s) size will not be honored." %
          (existing_callback.dump_root,
           mode_names[existing_callback.tensor_debug_mode],
           mode_names[tensor_debug_mode]))
  else:
    # First activation in this thread: create the callback and hook it into
    # both the op-callback and the function-callback machinery.
    _state.dumping_callback = _DumpingCallback(dump_root,
                                               tensor_debug_mode,
                                               circular_buffer_size,
                                               op_regex,
                                               tensor_dtypes)
    op_callbacks.add_op_callback(_state.dumping_callback.callback)
    function_lib.add_function_callback(
        _state.dumping_callback.function_callback)

  # A different dump_root supersedes the previously-enabled one.
  if _state.dumping_callback.dump_root != dump_root:
    _state.dumping_callback.dump_root = dump_root

  logging.info(
      "Enabled dumping callback in thread %s "
      "(dump root: %s, tensor debug mode: %s)",
      threading.current_thread().name,
      _state.dumping_callback.dump_root,
      debug_event_pb2.TensorDebugMode.Name(tensor_debug_mode))

  # Ensure pending debug events are flushed/closed at interpreter exit.
  atexit.register(disable_dump_debug_info)
  return _state.dumping_callback.get_writer()
872
873
@tf_export("debugging.experimental.disable_dump_debug_info")
def disable_dump_debug_info():
  """Disable the currently-enabled debugging dumping.

  If the `enable_dump_debug_info()` method under the same Python namespace
  has been invoked before, calling this method disables it. If no call to
  `enable_dump_debug_info()` has been made, calling this method is a no-op.
  Calling this method more than once is idempotent.
  """
  dumping_callback = getattr(_state, "dumping_callback", None)
  if dumping_callback is None:
    # Nothing was enabled in this thread; no-op.
    return
  dump_root = dumping_callback.dump_root
  tfdbg_run_id = dumping_callback.tfdbg_run_id
  # Close the writer to flush any remaining events for this run, then
  # unhook the callback from the op- and function-callback registries.
  debug_events_writer.DebugEventsWriter(dump_root, tfdbg_run_id).Close()
  op_callbacks.remove_op_callback(dumping_callback.callback)
  function_lib.remove_function_callback(dumping_callback.function_callback)
  delattr(_state, "dumping_callback")
  logging.info("Disabled dumping callback in thread %s (dump root: %s)",
               threading.current_thread().name, dump_root)
893