# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and methods for processing debugger-decorated graphs."""
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import op_def_registry
from tensorflow.python.platform import tf_logging as logging


def parse_node_or_tensor_name(name):
  """Get the node name from a string that can be a node or tensor name.

  Args:
    name: An input node name (e.g., "node_a") or tensor name (e.g.,
      "node_a:0"), as a str.

  Returns:
    1) The node name, as a str. If the input name is a tensor name, i.e.,
      contains a colon, the final colon and the following output slot
      will be stripped.
    2) If the input name is a tensor name, the output slot, as an int. If
      the input name is not a tensor name, None.
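
  Example (illustrative; the names below are made up):

  >>> parse_node_or_tensor_name("node_a:0")
  ('node_a', 0)
  >>> parse_node_or_tensor_name("node_a")
  ('node_a', None)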
34  """
35
36  if ":" in name and not name.endswith(":"):
37    node_name = name[:name.rfind(":")]
38    output_slot = int(name[name.rfind(":") + 1:])
39
40    return node_name, output_slot
41  else:
42    return name, None
43
44
def get_node_name(element_name):
  """Get the node name from the name of a node or a tensor."""
  node_name, _ = parse_node_or_tensor_name(element_name)
  return node_name


def get_output_slot(element_name):
  """Get the output slot number from the name of a graph element.

  If element_name is a node name without an output slot at the end, 0 will be
  assumed.

  Args:
    element_name: (`str`) name of the graph element in question.

  Returns:
    (`int`) output slot number.
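
  Example (illustrative):

  >>> get_output_slot("node_a:2")
  2
  >>> get_output_slot("node_a")
  0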
61  """
62  _, output_slot = parse_node_or_tensor_name(element_name)
63  return output_slot if output_slot is not None else 0
64
65
66def is_copy_node(node_name):
67  """Determine whether a node name is that of a debug Copy node.
68
69  Such nodes are inserted by TensorFlow core upon request in
70  RunOptions.debug_options.debug_tensor_watch_opts.
71
72  Args:
73    node_name: Name of the node.
74
75  Returns:
76    A bool indicating whether the input argument is the name of a debug Copy
77    node.
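
  Example (illustrative; the first name below is a made-up Copy-node name):

  >>> is_copy_node("__copy_node_a_0")
  True
  >>> is_copy_node("node_a")
  False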
78  """
79  return node_name.startswith("__copy_")
80
81
82def is_debug_node(node_name):
83  """Determine whether a node name is that of a debug node.
84
85  Such nodes are inserted by TensorFlow core upon request in
86  RunOptions.debug_options.debug_tensor_watch_opts.
87
88  Args:
89    node_name: Name of the node.
90
91  Returns:
92    A bool indicating whether the input argument is the name of a debug node.
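
  Example (illustrative; the first name below is a made-up debug-node name):

  >>> is_debug_node("__dbg_node_a:0_0_DebugIdentity")
  True
  >>> is_debug_node("node_a")
  False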
93  """
94  return node_name.startswith("__dbg_")
95
96
97def parse_debug_node_name(node_name):
98  """Parse the name of a debug node.
99
100  Args:
101    node_name: Name of the debug node.
102
103  Returns:
104    1. Name of the watched node, as a str.
105    2. Output slot index of the watched tensor, as an int.
106    3. Index of the debug node, as an int.
107    4. Name of the debug op, as a str, e.g, "DebugIdentity".
108
109  Raises:
110    ValueError: If the input node name is not a valid debug node name.
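
  Example (illustrative; the watched node name is made up):

  >>> parse_debug_node_name("__dbg_node_a:0_0_DebugIdentity")
  ('node_a', 0, 0, 'DebugIdentity')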
111  """
112  prefix = "__dbg_"
113
114  name = node_name
115  if not name.startswith(prefix):
116    raise ValueError("Invalid prefix in debug node name: '%s'" % node_name)
117
118  name = name[len(prefix):]
119
120  if name.count("_") < 2:
121    raise ValueError("Invalid debug node name: '%s'" % node_name)
122
123  debug_op = name[name.rindex("_") + 1:]
124  name = name[:name.rindex("_")]
125
126  debug_op_index = int(name[name.rindex("_") + 1:])
127  name = name[:name.rindex("_")]
128
129  if name.count(":") != 1:
130    raise ValueError("Invalid tensor name in debug node name: '%s'" % node_name)
131
132  watched_node_name = name[:name.index(":")]
133  watched_output_slot = int(name[name.index(":") + 1:])
134
135  return watched_node_name, watched_output_slot, debug_op_index, debug_op
136
137
138class GraphTracingReachedDestination(Exception):
139  pass
140
141
142class DFSGraphTracer(object):
143  """Graph input tracer using depth-first search."""
144
145  def __init__(self,
146               input_lists,
147               skip_node_names=None,
148               destination_node_name=None):
149    """Constructor of _DFSGraphTracer.
150
151    Args:
152      input_lists: A list of dicts. Each dict is an adjacency (input) map from
153        the recipient node name as the key and the list of input node names
154        as the value.
155      skip_node_names: Optional: a list of node names to skip tracing.
156      destination_node_name: Optional: destination node name. If not `None`, it
157        should be the name of a destination not as a str and the graph tracing
158        will raise GraphTracingReachedDestination as soon as the node has been
159        reached.
160
161    Raises:
162      GraphTracingReachedDestination: if stop_at_node_name is not None and
163        the specified node is reached.
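
    Example (illustrative; the adjacency map below is hand-written, not taken
    from a real graph):

    >>> tracer = DFSGraphTracer(
    ...     input_lists=[{"c": ["b"], "b": ["a"]}], skip_node_names=[])
    >>> tracer.trace("c")
    >>> tracer.inputs()
    ['b', 'a']
    >>> tracer.depth_list()
    [1, 2]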
164    """
165
166    self._input_lists = input_lists
167    self._skip_node_names = skip_node_names
168
169    self._inputs = []
170    self._visited_nodes = []
171    self._depth_count = 0
172    self._depth_list = []
173
174    self._destination_node_name = destination_node_name
175
  def trace(self, graph_element_name):
    """Trace inputs.

    Args:
      graph_element_name: Name of the node or an output tensor of the node, as a
        str.

    Raises:
      GraphTracingReachedDestination: if destination_node_name of this tracer
        object is not None and the specified node is reached.
    """
    self._depth_count += 1

    node_name = get_node_name(graph_element_name)
    if node_name == self._destination_node_name:
      raise GraphTracingReachedDestination()

    if node_name in self._skip_node_names:
      return
    if node_name in self._visited_nodes:
      return

    self._visited_nodes.append(node_name)

    for input_list in self._input_lists:
      if node_name not in input_list:
        continue
      for inp in input_list[node_name]:
        if get_node_name(inp) in self._visited_nodes:
          continue
        self._inputs.append(inp)
        self._depth_list.append(self._depth_count)
        self.trace(inp)

    self._depth_count -= 1

  def inputs(self):
    return self._inputs

  def depth_list(self):
    return self._depth_list


def _infer_device_name(graph_def):
  """Infer device name from a partition GraphDef."""
  device_name = None
  for node in graph_def.node:
    if node.device:
      device_name = node.device
      break
  if device_name is None:
    logging.warn(
        "Failed to infer device name from partition GraphDef: none of the "
        "nodes of the GraphDef has a non-empty device name.")
  return device_name


class DebugGraph(object):
  """Represents a debugger-decorated graph.

  def __init__(self, debug_graph_def, device_name=None):
    self._debug_graph_def = debug_graph_def
    self._non_debug_graph_def = None

    self._node_attributes = {}
    self._node_inputs = {}
    self._node_reversed_ref_inputs = {}
    self._node_ctrl_inputs = {}
    self._node_recipients = {}
    self._node_ctrl_recipients = {}
    self._node_devices = {}
    self._node_op_types = {}
    self._copy_send_nodes = []
    self._ref_args = {}

    self._device_name = device_name
    if not self._device_name:
      self._device_name = _infer_device_name(debug_graph_def)

    for node in debug_graph_def.node:
      self._process_debug_graph_node(node)

    self._prune_non_control_edges_of_debug_ops()
    self._prune_control_edges_of_debug_ops()
    self._prune_nodes_from_input_and_recipient_maps(self._get_copy_nodes())

    self._populate_recipient_maps()

  def _process_debug_graph_node(self, node):
    """Process a node from the debug GraphDef.

    Args:
      node: (NodeDef) A partition-graph node to be processed.

    Raises:
      ValueError: If duplicate node names are encountered.
    """
    if is_debug_node(node.name):
      # This is a debug node. Parse the node name and retrieve the
      # information about debug watches on tensors. But do not include
      # the node in the graph.
      return

    if node.name in self._node_inputs:
      raise ValueError("Duplicate node name on device %s: '%s'" %
                       (self._device_name, node.name))

    self._node_attributes[node.name] = node.attr

    self._node_inputs[node.name] = []
    self._node_ctrl_inputs[node.name] = []
    self._node_recipients[node.name] = []
    self._node_ctrl_recipients[node.name] = []

    if node.name not in self._node_devices:
      self._node_devices[node.name] = set()
    self._node_devices[node.name].add(
        node.device if node.device else self._device_name)
    self._node_op_types[node.name] = node.op
    self._ref_args[node.name] = self._get_ref_args(node)

    for inp in node.input:
      if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"):
        self._copy_send_nodes.append(node.name)

      if inp.startswith("^"):
        cinp = inp[1:]
        self._node_ctrl_inputs[node.name].append(cinp)
      else:
        self._node_inputs[node.name].append(inp)

  def _get_ref_args(self, node):
    """Determine which output tensors of a node are ref-type.

    Args:
      node: A `NodeDef`.

    Returns:
      A list of the ref-type output tensor names (as strs) of the node.
    """
    op_def = op_def_registry.get(node.op)
    if op_def is None:
      return []

    ref_args = []
    for i, output_arg in enumerate(op_def.output_arg):
      if output_arg.is_ref:
        arg_name = node.name if i == 0 else ("%s:%d" % (node.name, i))
        ref_args.append(arg_name)
    return ref_args

  def _get_copy_nodes(self):
    """Find all Copy nodes in the loaded graph."""
    copy_nodes = []
    for node in self._node_inputs:
      if is_copy_node(node):
        copy_nodes.append(node)
    return copy_nodes

  def _prune_non_control_edges_of_debug_ops(self):
    """Prune (non-control) edges related to debug ops.

    Prune the Copy ops and associated _Send ops inserted by the debugger out
    from the non-control inputs and output recipients map. Replace the inputs
    and recipients with original ones.
    """
    for node in self._node_inputs:
      inputs = self._node_inputs[node]

      for i, inp in enumerate(inputs):
        if is_copy_node(inp):
          # Find the input to the Copy node, which should be the original
          # input to the node.
          orig_inp = self._node_inputs[inp][0]
          inputs[i] = orig_inp

  def _prune_control_edges_of_debug_ops(self):
    """Prune control edges related to the debug ops."""
    for node in self._node_ctrl_inputs:
      ctrl_inputs = self._node_ctrl_inputs[node]
      debug_op_inputs = []
      for ctrl_inp in ctrl_inputs:
        if is_debug_node(ctrl_inp):
          debug_op_inputs.append(ctrl_inp)
      for debug_op_inp in debug_op_inputs:
        ctrl_inputs.remove(debug_op_inp)

  def _populate_recipient_maps(self):
    """Populate the map from node name to recipient(s) of its output(s).

    This method also populates the input map based on reversed ref edges.
    """
    for node in self._node_inputs:
      inputs = self._node_inputs[node]
      for inp in inputs:
        inp = get_node_name(inp)
        if inp not in self._node_recipients:
          self._node_recipients[inp] = []
        self._node_recipients[inp].append(node)

        if inp in self._ref_args:
          if inp not in self._node_reversed_ref_inputs:
            self._node_reversed_ref_inputs[inp] = []
          self._node_reversed_ref_inputs[inp].append(node)

    for node in self._node_ctrl_inputs:
      ctrl_inputs = self._node_ctrl_inputs[node]
      for ctrl_inp in ctrl_inputs:
        if ctrl_inp in self._copy_send_nodes:
          continue

        if ctrl_inp not in self._node_ctrl_recipients:
          self._node_ctrl_recipients[ctrl_inp] = []
        self._node_ctrl_recipients[ctrl_inp].append(node)

  def _prune_nodes_from_input_and_recipient_maps(self, nodes_to_prune):
    """Prune nodes out of input and recipient maps.

    Args:
      nodes_to_prune: (`list` of `str`) Names of the nodes to be pruned.
    """
    for node in nodes_to_prune:
      del self._node_inputs[node]
      del self._node_ctrl_inputs[node]
      del self._node_recipients[node]
      del self._node_ctrl_recipients[node]

  def _reconstruct_non_debug_graph_def(self):
    """Reconstruct non-debug GraphDef.

    Non-debug GraphDef means the original GraphDef without the Copy* and Debug
    nodes inserted by the debugger.
    """
    if self._non_debug_graph_def:
      return

    self._non_debug_graph_def = graph_pb2.GraphDef()
    for node in self._debug_graph_def.node:
      if is_copy_node(node.name) or is_debug_node(node.name):
        continue

      new_node = self._non_debug_graph_def.node.add()
      new_node.CopyFrom(node)

      # Redo the list of inputs, because in _debug_graph_def, the list can
      # consist of Copy* and Debug* nodes inserted by the debugger. Those will
      # be replaced with the original inputs here.
      del new_node.input[:]
      for inp in self._node_inputs[node.name]:
        new_node.input.append(inp)
      for ctrl_inp in self._node_ctrl_inputs[node.name]:
        new_node.input.append("^" + ctrl_inp)

  @property
  def device_name(self):
    return self._device_name

  @property
  def debug_graph_def(self):
    """The debugger-decorated GraphDef."""
    return self._debug_graph_def

  @property
  def non_debug_graph_def(self):
    """The GraphDef without the Copy* and Debug* nodes added by the debugger."""
    self._reconstruct_non_debug_graph_def()
    return self._non_debug_graph_def

  @property
  def node_devices(self):
    return self._node_devices

  @property
  def node_op_types(self):
    return self._node_op_types

  @property
  def node_attributes(self):
    return self._node_attributes

  @property
  def node_inputs(self):
    return self._node_inputs

  @property
  def node_ctrl_inputs(self):
    return self._node_ctrl_inputs

  @property
  def node_reversed_ref_inputs(self):
    return self._node_reversed_ref_inputs

  @property
  def node_recipients(self):
    return self._node_recipients

  @property
  def node_ctrl_recipients(self):
    return self._node_ctrl_recipients


def reconstruct_non_debug_graph_def(debug_graph_def):
  """Reconstruct original (non-debugger-decorated) partition GraphDef.

  This method strips the input `tf.compat.v1.GraphDef` of the Copy* and
  Debug*-type nodes inserted by the debugger.

  The reconstructed partition graph is identical to the original (i.e.,
    non-debugger-decorated) partition graph except in the following respects:
      1) The exact names of the runtime-inserted internal nodes may differ.
         These include _Send, _Recv, _HostSend, _HostRecv, _Retval ops.
      2) As a consequence of 1, the nodes that receive input directly from such
         send- and recv-type ops will have different input names.
      3) The parallel_iteration attribute of while-loop Enter ops is set to 1.

  Args:
    debug_graph_def: The debugger-decorated `tf.compat.v1.GraphDef`, with the
      debugger-inserted Copy* and Debug* nodes.

  Returns:
    The reconstructed `tf.compat.v1.GraphDef` stripped of the debugger-inserted
    nodes.
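
  Example (illustrative; a hand-built debugger-decorated GraphDef with made-up
  node names, sketching how the Copy*/Debug* nodes are stripped and the
  original input of the consumer node is restored):

  >>> from tensorflow.core.framework import graph_pb2
  >>> gd = graph_pb2.GraphDef()
  >>> _ = gd.node.add(name="a", op="Const", device="/cpu:0")
  >>> _ = gd.node.add(name="__copy_a_0", op="Copy", input=["a"])
  >>> _ = gd.node.add(name="__dbg_a:0_0_DebugIdentity", op="DebugIdentity",
  ...                 input=["__copy_a_0"])
  >>> _ = gd.node.add(name="b", op="Identity", input=["__copy_a_0"])
  >>> recon = reconstruct_non_debug_graph_def(gd)
  >>> [n.name for n in recon.node]
  ['a', 'b']
  >>> list(recon.node[-1].input)
  ['a']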
498  """
499  return DebugGraph(debug_graph_def).non_debug_graph_def
500