xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/cli/analyzer_cli.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""CLI Backend for the Analyzer Part of the Debugger.
16
17The analyzer performs post hoc analysis of dumped intermediate tensors and
18graph structure information from debugged Session.run() calls.
19"""
20import argparse
21import copy
22import re
23
24
25from tensorflow.python.debug.cli import cli_config
26from tensorflow.python.debug.cli import cli_shared
27from tensorflow.python.debug.cli import command_parser
28from tensorflow.python.debug.cli import debugger_cli_common
29from tensorflow.python.debug.cli import evaluator
30from tensorflow.python.debug.cli import ui_factory
31from tensorflow.python.debug.lib import debug_graphs
32from tensorflow.python.debug.lib import source_utils
33
# Short alias for debugger_cli_common.RichLine, used in this module to build
# lines of rich (attributed) text, e.g. RL("text", "bold").
RL = debugger_cli_common.RichLine
35
# Fragments of the depth-dependent hanging indent drawn at the beginning of
# each line of a tree-style listing.
#   HANG_UNFINISHED: drawn for recursion depths that are still unfinished.
#   HANG_FINISHED:   blank filler of the same width (by its name, for depths
#                    whose recursion has finished).
#   HANG_SUFFIX:     connector drawn right before an entry.
HANG_UNFINISHED, HANG_FINISHED, HANG_SUFFIX = "|  ", "   ", "|- "

# printf-style templates for displaying a recursion depth and an op type.
DEPTH_TEMPLATE, OP_TYPE_TEMPLATE = "(%d) ", "[%s] "

# Label for control (as opposed to data) inputs/outputs, and an ellipsis
# string.
CTRL_LABEL, ELLIPSIS = "(Ctrl) ", "..."

# Accepted values of the -s/--sort_by flag of the list_tensors command.
(SORT_TENSORS_BY_TIMESTAMP,
 SORT_TENSORS_BY_DUMP_SIZE,
 SORT_TENSORS_BY_OP_TYPE,
 SORT_TENSORS_BY_TENSOR_NAME) = ("timestamp", "dump_size", "op_type",
                                 "tensor_name")
54
55
def _add_main_menu(output,
                   node_name=None,
                   enable_list_tensors=True,
                   enable_node_info=True,
                   enable_print_tensor=True,
                   enable_list_inputs=True,
                   enable_list_outputs=True):
  """Generate main menu for the screen output from a command.

  The menu contains, in this order: list_tensors, node_info, print_tensor,
  list_inputs, list_outputs, run_info and help. It is stored in
  `output.annotations` under `debugger_cli_common.MAIN_MENU_KEY`.

  Args:
    output: (debugger_cli_common.RichTextLines) the output object to modify.
    node_name: (str or None) name of the node involved (if any). If None (or
      empty), the menu items node_info, print_tensor, list_inputs and
      list_outputs are disabled and carry no command, overriding the values
      of arguments enable_node_info, enable_print_tensor, enable_list_inputs
      and enable_list_outputs.
    enable_list_tensors: (bool) whether the list_tensor menu item will be
      enabled.
    enable_node_info: (bool) whether the node_info item will be enabled.
    enable_print_tensor: (bool) whether the print_tensor item will be enabled.
    enable_list_inputs: (bool) whether the item list_inputs will be enabled.
    enable_list_outputs: (bool) whether the item list_outputs will be enabled.
  """
  menu = debugger_cli_common.Menu()

  menu.append(
      debugger_cli_common.MenuItem(
          "list_tensors", "list_tensors", enabled=enable_list_tensors))

  # Node-specific items: (caption, command template, enabled flag). The
  # template is filled in with the node name.
  node_items = (
      ("node_info", "node_info -a -d -t %s", enable_node_info),
      ("print_tensor", "print_tensor %s", enable_print_tensor),
      ("list_inputs", "list_inputs -c -r %s", enable_list_inputs),
      ("list_outputs", "list_outputs -c -r %s", enable_list_outputs),
  )
  for caption, command_template, enabled in node_items:
    if node_name:
      item = debugger_cli_common.MenuItem(
          caption, command_template % node_name, enabled=enabled)
    else:
      # Without a node there is no command to run, so show the item disabled.
      item = debugger_cli_common.MenuItem(caption, None, enabled=False)
    menu.append(item)

  menu.append(debugger_cli_common.MenuItem("run_info", "run_info"))
  menu.append(debugger_cli_common.MenuItem("help", "help"))

  output.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu
123
124
125class DebugAnalyzer(object):
126  """Analyzer for debug data from dump directories."""
127
128  _TIMESTAMP_COLUMN_HEAD = "t (ms)"
129  _DUMP_SIZE_COLUMN_HEAD = "Size (B)"
130  _OP_TYPE_COLUMN_HEAD = "Op type"
131  _TENSOR_NAME_COLUMN_HEAD = "Tensor name"
132
133  # Op types to be omitted when generating descriptions of graph structure.
134  _GRAPH_STRUCT_OP_TYPE_DENYLIST = ("_Send", "_Recv", "_HostSend", "_HostRecv",
135                                    "_Retval")
136
137  def __init__(self, debug_dump, config):
138    """DebugAnalyzer constructor.
139
140    Args:
141      debug_dump: A DebugDumpDir object.
142      config: A `cli_config.CLIConfig` object that carries user-facing
143        configurations.
144    """
145
146    self._debug_dump = debug_dump
147    self._evaluator = evaluator.ExpressionEvaluator(self._debug_dump)
148
149    # Initialize tensor filters state.
150    self._tensor_filters = {}
151
152    self._build_argument_parsers(config)
153    config.set_callback("graph_recursion_depth",
154                        self._build_argument_parsers)
155
156    # TODO(cais): Implement list_nodes.
157
158  def _build_argument_parsers(self, config):
159    """Build argument parsers for DebugAnalayzer.
160
161    Args:
162      config: A `cli_config.CLIConfig` object.
163
164    Returns:
165      A dict mapping command handler name to `ArgumentParser` instance.
166    """
167    # Argument parsers for command handlers.
168    self._arg_parsers = {}
169
170    # Parser for list_tensors.
171    ap = argparse.ArgumentParser(
172        description="List dumped intermediate tensors.",
173        usage=argparse.SUPPRESS)
174    ap.add_argument(
175        "-f",
176        "--tensor_filter",
177        dest="tensor_filter",
178        type=str,
179        default="",
180        help="List only Tensors passing the filter of the specified name")
181    ap.add_argument(
182        "-fenn",
183        "--filter_exclude_node_names",
184        dest="filter_exclude_node_names",
185        type=str,
186        default="",
187        help="When applying the tensor filter, exclude node with names "
188        "matching the regular expression. Applicable only if --tensor_filter "
189        "or -f is used.")
190    ap.add_argument(
191        "-n",
192        "--node_name_filter",
193        dest="node_name_filter",
194        type=str,
195        default="",
196        help="filter node name by regex.")
197    ap.add_argument(
198        "-t",
199        "--op_type_filter",
200        dest="op_type_filter",
201        type=str,
202        default="",
203        help="filter op type by regex.")
204    ap.add_argument(
205        "-s",
206        "--sort_by",
207        dest="sort_by",
208        type=str,
209        default=SORT_TENSORS_BY_TIMESTAMP,
210        help=("the field to sort the data by: (%s | %s | %s | %s)" %
211              (SORT_TENSORS_BY_TIMESTAMP, SORT_TENSORS_BY_DUMP_SIZE,
212               SORT_TENSORS_BY_OP_TYPE, SORT_TENSORS_BY_TENSOR_NAME)))
213    ap.add_argument(
214        "-r",
215        "--reverse",
216        dest="reverse",
217        action="store_true",
218        help="sort the data in reverse (descending) order")
219    self._arg_parsers["list_tensors"] = ap
220
221    # Parser for node_info.
222    ap = argparse.ArgumentParser(
223        description="Show information about a node.", usage=argparse.SUPPRESS)
224    ap.add_argument(
225        "node_name",
226        type=str,
227        help="Name of the node or an associated tensor, e.g., "
228        "hidden1/Wx_plus_b/MatMul, hidden1/Wx_plus_b/MatMul:0")
229    ap.add_argument(
230        "-a",
231        "--attributes",
232        dest="attributes",
233        action="store_true",
234        help="Also list attributes of the node.")
235    ap.add_argument(
236        "-d",
237        "--dumps",
238        dest="dumps",
239        action="store_true",
240        help="Also list dumps available from the node.")
241    ap.add_argument(
242        "-t",
243        "--traceback",
244        dest="traceback",
245        action="store_true",
246        help="Also include the traceback of the node's creation "
247        "(if available in Python).")
248    self._arg_parsers["node_info"] = ap
249
250    # Parser for list_inputs.
251    ap = argparse.ArgumentParser(
252        description="Show inputs to a node.", usage=argparse.SUPPRESS)
253    ap.add_argument(
254        "node_name",
255        type=str,
256        help="Name of the node or an output tensor from the node, e.g., "
257        "hidden1/Wx_plus_b/MatMul, hidden1/Wx_plus_b/MatMul:0")
258    ap.add_argument(
259        "-c", "--control", action="store_true", help="Include control inputs.")
260    ap.add_argument(
261        "-d",
262        "--depth",
263        dest="depth",
264        type=int,
265        default=config.get("graph_recursion_depth"),
266        help="Maximum depth of recursion used when showing the input tree.")
267    ap.add_argument(
268        "-r",
269        "--recursive",
270        dest="recursive",
271        action="store_true",
272        help="Show inputs to the node recursively, i.e., the input tree.")
273    ap.add_argument(
274        "-t",
275        "--op_type",
276        action="store_true",
277        help="Show op types of input nodes.")
278    self._arg_parsers["list_inputs"] = ap
279
280    # Parser for list_outputs.
281    ap = argparse.ArgumentParser(
282        description="Show the nodes that receive the outputs of given node.",
283        usage=argparse.SUPPRESS)
284    ap.add_argument(
285        "node_name",
286        type=str,
287        help="Name of the node or an output tensor from the node, e.g., "
288        "hidden1/Wx_plus_b/MatMul, hidden1/Wx_plus_b/MatMul:0")
289    ap.add_argument(
290        "-c", "--control", action="store_true", help="Include control inputs.")
291    ap.add_argument(
292        "-d",
293        "--depth",
294        dest="depth",
295        type=int,
296        default=config.get("graph_recursion_depth"),
297        help="Maximum depth of recursion used when showing the output tree.")
298    ap.add_argument(
299        "-r",
300        "--recursive",
301        dest="recursive",
302        action="store_true",
303        help="Show recipients of the node recursively, i.e., the output "
304        "tree.")
305    ap.add_argument(
306        "-t",
307        "--op_type",
308        action="store_true",
309        help="Show op types of recipient nodes.")
310    self._arg_parsers["list_outputs"] = ap
311
312    # Parser for print_tensor.
313    self._arg_parsers["print_tensor"] = (
314        command_parser.get_print_tensor_argparser(
315            "Print the value of a dumped tensor."))
316
317    # Parser for print_source.
318    ap = argparse.ArgumentParser(
319        description="Print a Python source file with overlaid debug "
320        "information, including the nodes (ops) or Tensors created at the "
321        "source lines.",
322        usage=argparse.SUPPRESS)
323    ap.add_argument(
324        "source_file_path",
325        type=str,
326        help="Path to the source file.")
327    ap.add_argument(
328        "-t",
329        "--tensors",
330        dest="tensors",
331        action="store_true",
332        help="Label lines with dumped Tensors, instead of ops.")
333    ap.add_argument(
334        "-m",
335        "--max_elements_per_line",
336        type=int,
337        default=10,
338        help="Maximum number of elements (ops or Tensors) to show per source "
339             "line.")
340    ap.add_argument(
341        "-b",
342        "--line_begin",
343        type=int,
344        default=1,
345        help="Print source beginning at line number (1-based.)")
346    self._arg_parsers["print_source"] = ap
347
348    # Parser for list_source.
349    ap = argparse.ArgumentParser(
350        description="List source files responsible for constructing nodes and "
351        "tensors present in the run().",
352        usage=argparse.SUPPRESS)
353    ap.add_argument(
354        "-p",
355        "--path_filter",
356        type=str,
357        default="",
358        help="Regular expression filter for file path.")
359    ap.add_argument(
360        "-n",
361        "--node_name_filter",
362        type=str,
363        default="",
364        help="Regular expression filter for node name.")
365    self._arg_parsers["list_source"] = ap
366
367    # Parser for eval.
368    ap = argparse.ArgumentParser(
369        description="""Evaluate an arbitrary expression. Can use tensor values
370        from the current debug dump. The debug tensor names should be enclosed
371        in pairs of backticks. Expressions with spaces should be enclosed in
372        a pair of double quotes or a pair of single quotes. By default, numpy
373        is imported as np and can be used in the expressions. E.g.,
374          1) eval np.argmax(`Softmax:0`),
375          2) eval 'np.sum(`Softmax:0`, axis=1)',
376          3) eval "np.matmul((`output/Identity:0`/`Softmax:0`).T, `Softmax:0`)".
377        """,
378        usage=argparse.SUPPRESS)
379    ap.add_argument(
380        "expression",
381        type=str,
382        help="""Expression to be evaluated.
383        1) in the simplest case, use <node_name>:<output_slot>, e.g.,
384          hidden_0/MatMul:0.
385
386        2) if the default debug op "DebugIdentity" is to be overridden, use
387          <node_name>:<output_slot>:<debug_op>, e.g.,
388          hidden_0/MatMul:0:DebugNumericSummary.
389
390        3) if the tensor of the same name exists on more than one device, use
391          <device_name>:<node_name>:<output_slot>[:<debug_op>], e.g.,
392          /job:worker/replica:0/task:0/gpu:0:hidden_0/MatMul:0
393          /job:worker/replica:0/task:2/cpu:0:hidden_0/MatMul:0:DebugNanCount.
394
395        4) if the tensor is executed multiple times in a given `Session.run`
396        call, specify the execution index with a 0-based integer enclose in a
397        pair of brackets at the end, e.g.,
398          RNN/tanh:0[0]
399          /job:worker/replica:0/task:0/gpu:0:RNN/tanh:0[0].""")
400    ap.add_argument(
401        "-a",
402        "--all",
403        dest="print_all",
404        action="store_true",
405        help="Print the tensor in its entirety, i.e., do not use ellipses "
406        "(may be slow for large results).")
407    ap.add_argument(
408        "-w",
409        "--write_path",
410        default="",
411        help="Path of the numpy file to write the evaluation result to, "
412        "using numpy.save()")
413    self._arg_parsers["eval"] = ap
414
415  def add_tensor_filter(self, filter_name, filter_callable):
416    """Add a tensor filter.
417
418    A tensor filter is a named callable of the signature:
419      filter_callable(dump_datum, tensor),
420
421    wherein dump_datum is an instance of debug_data.DebugTensorDatum carrying
422    metadata about the dumped tensor, including tensor name, timestamps, etc.
423    tensor is the value of the dumped tensor as an numpy.ndarray object.
424    The return value of the function is a bool.
425    This is the same signature as the input argument to
426    debug_data.DebugDumpDir.find().
427
428    Args:
429      filter_name: (str) name of the filter. Cannot be empty.
430      filter_callable: (callable) a filter function of the signature described
431        as above.
432
433    Raises:
434      ValueError: If filter_name is an empty str.
435      TypeError: If filter_name is not a str.
436                 Or if filter_callable is not callable.
437    """
438
439    if not isinstance(filter_name, str):
440      raise TypeError("Input argument filter_name is expected to be str, "
441                      "but is not.")
442
443    # Check that filter_name is not an empty str.
444    if not filter_name:
445      raise ValueError("Input argument filter_name cannot be empty.")
446
447    # Check that filter_callable is callable.
448    if not callable(filter_callable):
449      raise TypeError(
450          "Input argument filter_callable is expected to be callable, "
451          "but is not.")
452
453    self._tensor_filters[filter_name] = filter_callable
454
455  def get_tensor_filter(self, filter_name):
456    """Retrieve filter function by name.
457
458    Args:
459      filter_name: Name of the filter set during add_tensor_filter() call.
460
461    Returns:
462      The callable associated with the filter name.
463
464    Raises:
465      ValueError: If there is no tensor filter of the specified filter name.
466    """
467
468    if filter_name not in self._tensor_filters:
469      raise ValueError("There is no tensor filter named \"%s\"" % filter_name)
470
471    return self._tensor_filters[filter_name]
472
473  def get_help(self, handler_name):
474    return self._arg_parsers[handler_name].format_help()
475
476  def list_tensors(self, args, screen_info=None):
477    """Command handler for list_tensors.
478
479    List tensors dumped during debugged Session.run() call.
480
481    Args:
482      args: Command-line arguments, excluding the command prefix, as a list of
483        str.
484      screen_info: Optional dict input containing screen information such as
485        cols.
486
487    Returns:
488      Output text lines as a RichTextLines object.
489
490    Raises:
491      ValueError: If `--filter_exclude_node_names` is used without `-f` or
492        `--tensor_filter` being used.
493    """
494
495    # TODO(cais): Add annotations of substrings for dumped tensor names, to
496    # facilitate on-screen highlighting/selection of node names.
497    _ = screen_info
498
499    parsed = self._arg_parsers["list_tensors"].parse_args(args)
500
501    output = []
502
503    filter_strs = []
504    if parsed.op_type_filter:
505      op_type_regex = re.compile(parsed.op_type_filter)
506      filter_strs.append("Op type regex filter: \"%s\"" % parsed.op_type_filter)
507    else:
508      op_type_regex = None
509
510    if parsed.node_name_filter:
511      node_name_regex = re.compile(parsed.node_name_filter)
512      filter_strs.append("Node name regex filter: \"%s\"" %
513                         parsed.node_name_filter)
514    else:
515      node_name_regex = None
516
517    output = debugger_cli_common.RichTextLines(filter_strs)
518    output.append("")
519
520    if parsed.tensor_filter:
521      try:
522        filter_callable = self.get_tensor_filter(parsed.tensor_filter)
523      except ValueError:
524        output = cli_shared.error("There is no tensor filter named \"%s\"." %
525                                  parsed.tensor_filter)
526        _add_main_menu(output, node_name=None, enable_list_tensors=False)
527        return output
528
529      data_to_show = self._debug_dump.find(
530          filter_callable,
531          exclude_node_names=parsed.filter_exclude_node_names)
532    else:
533      if parsed.filter_exclude_node_names:
534        raise ValueError(
535            "The flag --filter_exclude_node_names is valid only when "
536            "the flag -f or --tensor_filter is used.")
537
538      data_to_show = self._debug_dump.dumped_tensor_data
539
540    # TODO(cais): Implement filter by lambda on tensor value.
541
542    max_timestamp_width, max_dump_size_width, max_op_type_width = (
543        self._measure_tensor_list_column_widths(data_to_show))
544
545    # Sort the data.
546    data_to_show = self._sort_dump_data_by(
547        data_to_show, parsed.sort_by, parsed.reverse)
548
549    output.extend(
550        self._tensor_list_column_heads(parsed, max_timestamp_width,
551                                       max_dump_size_width, max_op_type_width))
552
553    dump_count = 0
554    for dump in data_to_show:
555      if node_name_regex and not node_name_regex.match(dump.node_name):
556        continue
557
558      if op_type_regex:
559        op_type = self._debug_dump.node_op_type(dump.node_name)
560        if not op_type_regex.match(op_type):
561          continue
562
563      rel_time = (dump.timestamp - self._debug_dump.t0) / 1000.0
564      dump_size_str = cli_shared.bytes_to_readable_str(dump.dump_size_bytes)
565      dumped_tensor_name = "%s:%d" % (dump.node_name, dump.output_slot)
566      op_type = self._debug_dump.node_op_type(dump.node_name)
567
568      line = "[%.3f]" % rel_time
569      line += " " * (max_timestamp_width - len(line))
570      line += dump_size_str
571      line += " " * (max_timestamp_width + max_dump_size_width - len(line))
572      line += op_type
573      line += " " * (max_timestamp_width + max_dump_size_width +
574                     max_op_type_width - len(line))
575      line += dumped_tensor_name
576
577      output.append(
578          line,
579          font_attr_segs=[(
580              len(line) - len(dumped_tensor_name), len(line),
581              debugger_cli_common.MenuItem("", "pt %s" % dumped_tensor_name))])
582      dump_count += 1
583
584    if parsed.tensor_filter:
585      output.prepend([
586          "%d dumped tensor(s) passing filter \"%s\":" %
587          (dump_count, parsed.tensor_filter)
588      ])
589    else:
590      output.prepend(["%d dumped tensor(s):" % dump_count])
591
592    _add_main_menu(output, node_name=None, enable_list_tensors=False)
593    return output
594
595  def _measure_tensor_list_column_widths(self, data):
596    """Determine the maximum widths of the timestamp and op-type column.
597
598    This method assumes that data is sorted in the default order, i.e.,
599    by ascending timestamps.
600
601    Args:
602      data: (list of DebugTensorDaum) the data based on which the maximum
603        column widths will be determined.
604
605    Returns:
606      (int) maximum width of the timestamp column. 0 if data is empty.
607      (int) maximum width of the dump size column. 0 if data is empty.
608      (int) maximum width of the op type column. 0 if data is empty.
609    """
610
611    max_timestamp_width = 0
612    if data:
613      max_rel_time_ms = (data[-1].timestamp - self._debug_dump.t0) / 1000.0
614      max_timestamp_width = len("[%.3f] " % max_rel_time_ms) + 1
615    max_timestamp_width = max(max_timestamp_width,
616                              len(self._TIMESTAMP_COLUMN_HEAD) + 1)
617
618    max_dump_size_width = 0
619    for dump in data:
620      dump_size_str = cli_shared.bytes_to_readable_str(dump.dump_size_bytes)
621      if len(dump_size_str) + 1 > max_dump_size_width:
622        max_dump_size_width = len(dump_size_str) + 1
623    max_dump_size_width = max(max_dump_size_width,
624                              len(self._DUMP_SIZE_COLUMN_HEAD) + 1)
625
626    max_op_type_width = 0
627    for dump in data:
628      op_type = self._debug_dump.node_op_type(dump.node_name)
629      if len(op_type) + 1 > max_op_type_width:
630        max_op_type_width = len(op_type) + 1
631    max_op_type_width = max(max_op_type_width,
632                            len(self._OP_TYPE_COLUMN_HEAD) + 1)
633
634    return max_timestamp_width, max_dump_size_width, max_op_type_width
635
636  def _sort_dump_data_by(self, data, sort_by, reverse):
637    """Sort a list of DebugTensorDatum in specified order.
638
639    Args:
640      data: (list of DebugTensorDatum) the data to be sorted.
641      sort_by: The field to sort data by.
642      reverse: (bool) Whether to use reversed (descending) order.
643
644    Returns:
645      (list of DebugTensorDatum) in sorted order.
646
647    Raises:
648      ValueError: given an invalid value of sort_by.
649    """
650
651    if sort_by == SORT_TENSORS_BY_TIMESTAMP:
652      return sorted(
653          data,
654          reverse=reverse,
655          key=lambda x: x.timestamp)
656    elif sort_by == SORT_TENSORS_BY_DUMP_SIZE:
657      return sorted(data, reverse=reverse, key=lambda x: x.dump_size_bytes)
658    elif sort_by == SORT_TENSORS_BY_OP_TYPE:
659      return sorted(
660          data,
661          reverse=reverse,
662          key=lambda x: self._debug_dump.node_op_type(x.node_name))
663    elif sort_by == SORT_TENSORS_BY_TENSOR_NAME:
664      return sorted(
665          data,
666          reverse=reverse,
667          key=lambda x: "%s:%d" % (x.node_name, x.output_slot))
668    else:
669      raise ValueError("Unsupported key to sort tensors by: %s" % sort_by)
670
671  def _tensor_list_column_heads(self, parsed, max_timestamp_width,
672                                max_dump_size_width, max_op_type_width):
673    """Generate a line containing the column heads of the tensor list.
674
675    Args:
676      parsed: Parsed arguments (by argparse) of the list_tensors command.
677      max_timestamp_width: (int) maximum width of the timestamp column.
678      max_dump_size_width: (int) maximum width of the dump size column.
679      max_op_type_width: (int) maximum width of the op type column.
680
681    Returns:
682      A RichTextLines object.
683    """
684
685    base_command = "list_tensors"
686    if parsed.tensor_filter:
687      base_command += " -f %s" % parsed.tensor_filter
688    if parsed.op_type_filter:
689      base_command += " -t %s" % parsed.op_type_filter
690    if parsed.node_name_filter:
691      base_command += " -n %s" % parsed.node_name_filter
692
693    attr_segs = {0: []}
694    row = self._TIMESTAMP_COLUMN_HEAD
695    command = "%s -s %s" % (base_command, SORT_TENSORS_BY_TIMESTAMP)
696    if parsed.sort_by == SORT_TENSORS_BY_TIMESTAMP and not parsed.reverse:
697      command += " -r"
698    attr_segs[0].append(
699        (0, len(row), [debugger_cli_common.MenuItem(None, command), "bold"]))
700    row += " " * (max_timestamp_width - len(row))
701
702    prev_len = len(row)
703    row += self._DUMP_SIZE_COLUMN_HEAD
704    command = "%s -s %s" % (base_command, SORT_TENSORS_BY_DUMP_SIZE)
705    if parsed.sort_by == SORT_TENSORS_BY_DUMP_SIZE and not parsed.reverse:
706      command += " -r"
707    attr_segs[0].append((prev_len, len(row),
708                         [debugger_cli_common.MenuItem(None, command), "bold"]))
709    row += " " * (max_dump_size_width + max_timestamp_width - len(row))
710
711    prev_len = len(row)
712    row += self._OP_TYPE_COLUMN_HEAD
713    command = "%s -s %s" % (base_command, SORT_TENSORS_BY_OP_TYPE)
714    if parsed.sort_by == SORT_TENSORS_BY_OP_TYPE and not parsed.reverse:
715      command += " -r"
716    attr_segs[0].append((prev_len, len(row),
717                         [debugger_cli_common.MenuItem(None, command), "bold"]))
718    row += " " * (
719        max_op_type_width + max_dump_size_width + max_timestamp_width - len(row)
720    )
721
722    prev_len = len(row)
723    row += self._TENSOR_NAME_COLUMN_HEAD
724    command = "%s -s %s" % (base_command, SORT_TENSORS_BY_TENSOR_NAME)
725    if parsed.sort_by == SORT_TENSORS_BY_TENSOR_NAME and not parsed.reverse:
726      command += " -r"
727    attr_segs[0].append((prev_len, len(row),
728                         [debugger_cli_common.MenuItem("", command), "bold"]))
729    row += " " * (
730        max_op_type_width + max_dump_size_width + max_timestamp_width - len(row)
731    )
732
733    return debugger_cli_common.RichTextLines([row], font_attr_segs=attr_segs)
734
735  def node_info(self, args, screen_info=None):
736    """Command handler for node_info.
737
738    Query information about a given node.
739
740    Args:
741      args: Command-line arguments, excluding the command prefix, as a list of
742        str.
743      screen_info: Optional dict input containing screen information such as
744        cols.
745
746    Returns:
747      Output text lines as a RichTextLines object.
748    """
749
750    # TODO(cais): Add annotation of substrings for node names, to facilitate
751    # on-screen highlighting/selection of node names.
752    _ = screen_info
753
754    parsed = self._arg_parsers["node_info"].parse_args(args)
755
756    # Get a node name, regardless of whether the input is a node name (without
757    # output slot attached) or a tensor name (with output slot attached).
758    node_name, unused_slot = debug_graphs.parse_node_or_tensor_name(
759        parsed.node_name)
760
761    if not self._debug_dump.node_exists(node_name):
762      output = cli_shared.error(
763          "There is no node named \"%s\" in the partition graphs" % node_name)
764      _add_main_menu(
765          output,
766          node_name=None,
767          enable_list_tensors=True,
768          enable_node_info=False,
769          enable_list_inputs=False,
770          enable_list_outputs=False)
771      return output
772
773    # TODO(cais): Provide UI glossary feature to explain to users what the
774    # term "partition graph" means and how it is related to TF graph objects
775    # in Python. The information can be along the line of:
776    # "A tensorflow graph defined in Python is stripped of unused ops
777    # according to the feeds and fetches and divided into a number of
778    # partition graphs that may be distributed among multiple devices and
779    # hosts. The partition graphs are what's actually executed by the C++
780    # runtime during a run() call."
781
782    lines = ["Node %s" % node_name]
783    font_attr_segs = {
784        0: [(len(lines[-1]) - len(node_name), len(lines[-1]), "bold")]
785    }
786    lines.append("")
787    lines.append("  Op: %s" % self._debug_dump.node_op_type(node_name))
788    lines.append("  Device: %s" % self._debug_dump.node_device(node_name))
789    output = debugger_cli_common.RichTextLines(
790        lines, font_attr_segs=font_attr_segs)
791
792    # List node inputs (non-control and control).
793    inputs = self._exclude_denylisted_ops(
794        self._debug_dump.node_inputs(node_name))
795    ctrl_inputs = self._exclude_denylisted_ops(
796        self._debug_dump.node_inputs(node_name, is_control=True))
797    output.extend(self._format_neighbors("input", inputs, ctrl_inputs))
798
799    # List node output recipients (non-control and control).
800    recs = self._exclude_denylisted_ops(
801        self._debug_dump.node_recipients(node_name))
802    ctrl_recs = self._exclude_denylisted_ops(
803        self._debug_dump.node_recipients(node_name, is_control=True))
804    output.extend(self._format_neighbors("recipient", recs, ctrl_recs))
805
806    # Optional: List attributes of the node.
807    if parsed.attributes:
808      output.extend(self._list_node_attributes(node_name))
809
810    # Optional: List dumps available from the node.
811    if parsed.dumps:
812      output.extend(self._list_node_dumps(node_name))
813
814    if parsed.traceback:
815      output.extend(self._render_node_traceback(node_name))
816
817    _add_main_menu(output, node_name=node_name, enable_node_info=False)
818    return output
819
820  def _exclude_denylisted_ops(self, node_names):
821    """Exclude all nodes whose op types are in _GRAPH_STRUCT_OP_TYPE_DENYLIST.
822
823    Args:
824      node_names: An iterable of node or graph element names.
825
826    Returns:
827      A list of node names that are not denylisted.
828    """
829    return [
830        node_name for node_name in node_names
831        if self._debug_dump.node_op_type(debug_graphs.get_node_name(node_name))
832        not in self._GRAPH_STRUCT_OP_TYPE_DENYLIST
833    ]
834
835  def _render_node_traceback(self, node_name):
836    """Render traceback of a node's creation in Python, if available.
837
838    Args:
839      node_name: (str) name of the node.
840
841    Returns:
842      A RichTextLines object containing the stack trace of the node's
843      construction.
844    """
845
846    lines = [RL(""), RL(""), RL("Traceback of node construction:", "bold")]
847
848    try:
849      node_stack = self._debug_dump.node_traceback(node_name)
850      for depth, (file_path, line, function_name, text) in enumerate(
851          node_stack):
852        lines.append("%d: %s" % (depth, file_path))
853
854        attribute = debugger_cli_common.MenuItem(
855            "", "ps %s -b %d" % (file_path, line)) if text else None
856        line_number_line = RL("  ")
857        line_number_line += RL("Line:     %d" % line, attribute)
858        lines.append(line_number_line)
859
860        lines.append("  Function: %s" % function_name)
861        lines.append("  Text:     " + (("\"%s\"" % text) if text else "None"))
862        lines.append("")
863    except KeyError:
864      lines.append("(Node unavailable in the loaded Python graph)")
865    except LookupError:
866      lines.append("(Unavailable because no Python graph has been loaded)")
867
868    return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
869
870  def list_inputs(self, args, screen_info=None):
871    """Command handler for inputs.
872
873    Show inputs to a given node.
874
875    Args:
876      args: Command-line arguments, excluding the command prefix, as a list of
877        str.
878      screen_info: Optional dict input containing screen information such as
879        cols.
880
881    Returns:
882      Output text lines as a RichTextLines object.
883    """
884
885    # Screen info not currently used by this handler. Include this line to
886    # mute pylint.
887    _ = screen_info
888    # TODO(cais): Use screen info to format the output lines more prettily,
889    # e.g., hanging indent of long node names.
890
891    parsed = self._arg_parsers["list_inputs"].parse_args(args)
892
893    output = self._list_inputs_or_outputs(
894        parsed.recursive,
895        parsed.node_name,
896        parsed.depth,
897        parsed.control,
898        parsed.op_type,
899        do_outputs=False)
900
901    node_name = debug_graphs.get_node_name(parsed.node_name)
902    _add_main_menu(output, node_name=node_name, enable_list_inputs=False)
903
904    return output
905
906  def print_tensor(self, args, screen_info=None):
907    """Command handler for print_tensor.
908
909    Print value of a given dumped tensor.
910
911    Args:
912      args: Command-line arguments, excluding the command prefix, as a list of
913        str.
914      screen_info: Optional dict input containing screen information such as
915        cols.
916
917    Returns:
918      Output text lines as a RichTextLines object.
919    """
920
921    parsed = self._arg_parsers["print_tensor"].parse_args(args)
922
923    np_printoptions = cli_shared.numpy_printoptions_from_screen_info(
924        screen_info)
925
926    # Determine if any range-highlighting is required.
927    highlight_options = cli_shared.parse_ranges_highlight(parsed.ranges)
928
929    tensor_name, tensor_slicing = (
930        command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))
931
932    node_name, output_slot = debug_graphs.parse_node_or_tensor_name(tensor_name)
933    if (self._debug_dump.loaded_partition_graphs() and
934        not self._debug_dump.node_exists(node_name)):
935      output = cli_shared.error(
936          "Node \"%s\" does not exist in partition graphs" % node_name)
937      _add_main_menu(
938          output,
939          node_name=None,
940          enable_list_tensors=True,
941          enable_print_tensor=False)
942      return output
943
944    watch_keys = self._debug_dump.debug_watch_keys(node_name)
945    if output_slot is None:
946      output_slots = set()
947      for watch_key in watch_keys:
948        output_slots.add(int(watch_key.split(":")[1]))
949
950      if len(output_slots) == 1:
951        # There is only one dumped tensor from this node, so there is no
952        # ambiguity. Proceed to show the only dumped tensor.
953        output_slot = list(output_slots)[0]
954      else:
955        # There are more than one dumped tensors from this node. Indicate as
956        # such.
957        # TODO(cais): Provide an output screen with command links for
958        # convenience.
959        lines = [
960            "Node \"%s\" generated debug dumps from %s output slots:" %
961            (node_name, len(output_slots)),
962            "Please specify the output slot: %s:x." % node_name
963        ]
964        output = debugger_cli_common.RichTextLines(lines)
965        _add_main_menu(
966            output,
967            node_name=node_name,
968            enable_list_tensors=True,
969            enable_print_tensor=False)
970        return output
971
972    # Find debug dump data that match the tensor name (node name + output
973    # slot).
974    matching_data = []
975    for watch_key in watch_keys:
976      debug_tensor_data = self._debug_dump.watch_key_to_data(watch_key)
977      for datum in debug_tensor_data:
978        if datum.output_slot == output_slot:
979          matching_data.append(datum)
980
981    if not matching_data:
982      # No dump for this tensor.
983      output = cli_shared.error("Tensor \"%s\" did not generate any dumps." %
984                                parsed.tensor_name)
985    elif len(matching_data) == 1:
986      # There is only one dump for this tensor.
987      if parsed.number <= 0:
988        output = cli_shared.format_tensor(
989            matching_data[0].get_tensor(),
990            matching_data[0].watch_key,
991            np_printoptions,
992            print_all=parsed.print_all,
993            tensor_slicing=tensor_slicing,
994            highlight_options=highlight_options,
995            include_numeric_summary=parsed.numeric_summary,
996            write_path=parsed.write_path)
997      else:
998        output = cli_shared.error(
999            "Invalid number (%d) for tensor %s, which generated one dump." %
1000            (parsed.number, parsed.tensor_name))
1001
1002      _add_main_menu(output, node_name=node_name, enable_print_tensor=False)
1003    else:
1004      # There are more than one dumps for this tensor.
1005      if parsed.number < 0:
1006        lines = [
1007            "Tensor \"%s\" generated %d dumps:" % (parsed.tensor_name,
1008                                                   len(matching_data))
1009        ]
1010        font_attr_segs = {}
1011
1012        for i, datum in enumerate(matching_data):
1013          rel_time = (datum.timestamp - self._debug_dump.t0) / 1000.0
1014          lines.append("#%d [%.3f ms] %s" % (i, rel_time, datum.watch_key))
1015          command = "print_tensor %s -n %d" % (parsed.tensor_name, i)
1016          font_attr_segs[len(lines) - 1] = [(
1017              len(lines[-1]) - len(datum.watch_key), len(lines[-1]),
1018              debugger_cli_common.MenuItem(None, command))]
1019
1020        lines.append("")
1021        lines.append(
1022            "You can use the -n (--number) flag to specify which dump to "
1023            "print.")
1024        lines.append("For example:")
1025        lines.append("  print_tensor %s -n 0" % parsed.tensor_name)
1026
1027        output = debugger_cli_common.RichTextLines(
1028            lines, font_attr_segs=font_attr_segs)
1029      elif parsed.number >= len(matching_data):
1030        output = cli_shared.error(
1031            "Specified number (%d) exceeds the number of available dumps "
1032            "(%d) for tensor %s" %
1033            (parsed.number, len(matching_data), parsed.tensor_name))
1034      else:
1035        output = cli_shared.format_tensor(
1036            matching_data[parsed.number].get_tensor(),
1037            matching_data[parsed.number].watch_key + " (dump #%d)" %
1038            parsed.number,
1039            np_printoptions,
1040            print_all=parsed.print_all,
1041            tensor_slicing=tensor_slicing,
1042            highlight_options=highlight_options,
1043            write_path=parsed.write_path)
1044      _add_main_menu(output, node_name=node_name, enable_print_tensor=False)
1045
1046    return output
1047
1048  def list_outputs(self, args, screen_info=None):
1049    """Command handler for inputs.
1050
1051    Show inputs to a given node.
1052
1053    Args:
1054      args: Command-line arguments, excluding the command prefix, as a list of
1055        str.
1056      screen_info: Optional dict input containing screen information such as
1057        cols.
1058
1059    Returns:
1060      Output text lines as a RichTextLines object.
1061    """
1062
1063    # Screen info not currently used by this handler. Include this line to
1064    # mute pylint.
1065    _ = screen_info
1066    # TODO(cais): Use screen info to format the output lines more prettily,
1067    # e.g., hanging indent of long node names.
1068
1069    parsed = self._arg_parsers["list_outputs"].parse_args(args)
1070
1071    output = self._list_inputs_or_outputs(
1072        parsed.recursive,
1073        parsed.node_name,
1074        parsed.depth,
1075        parsed.control,
1076        parsed.op_type,
1077        do_outputs=True)
1078
1079    node_name = debug_graphs.get_node_name(parsed.node_name)
1080    _add_main_menu(output, node_name=node_name, enable_list_outputs=False)
1081
1082    return output
1083
1084  def evaluate_expression(self, args, screen_info=None):
1085    parsed = self._arg_parsers["eval"].parse_args(args)
1086
1087    eval_res = self._evaluator.evaluate(parsed.expression)
1088
1089    np_printoptions = cli_shared.numpy_printoptions_from_screen_info(
1090        screen_info)
1091    return cli_shared.format_tensor(
1092        eval_res,
1093        "from eval of expression '%s'" % parsed.expression,
1094        np_printoptions,
1095        print_all=parsed.print_all,
1096        include_numeric_summary=True,
1097        write_path=parsed.write_path)
1098
1099  def _reconstruct_print_source_command(self,
1100                                        parsed,
1101                                        line_begin,
1102                                        max_elements_per_line_increase=0):
1103    return "ps %s %s -b %d -m %d" % (
1104        parsed.source_file_path, "-t" if parsed.tensors else "", line_begin,
1105        parsed.max_elements_per_line + max_elements_per_line_increase)
1106
  def print_source(self, args, screen_info=None):
    """Print the content of a source file.

    Command handler for print_source (alias `ps`). Each source line is
    prefixed with a yellow "L<n>" label padded to a common width. If the debug
    dump associates names with a line, they are listed beneath it in sorted
    order. Those names are op names or, when the tensors flag (-t) is set,
    tensor names. At most `max_elements_per_line` (-m) names are listed per
    line. The rest are summarized by an "Omitted" line carrying a "+5" link
    that reruns the command with a limit larger by 5. Names that have debug
    watch keys are rendered as `pt` links; the others are shown in blue.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Unused.

    Returns:
      A RichTextLines object. Its initial scroll position is the row of
      source line `line_begin` (-b) if that line exists, else 0.
    """
    del screen_info  # Unused.

    parsed = self._arg_parsers["print_source"].parse_args(args)

    # Looked up below by 1-based line number, yielding the names for that line.
    source_annotation = source_utils.annotate_source(
        self._debug_dump,
        parsed.source_file_path,
        do_dumped_tensors=parsed.tensors)

    source_lines, line_num_width = source_utils.load_source(
        parsed.source_file_path)

    labeled_source_lines = []
    actual_initial_scroll_target = 0
    for i, line in enumerate(source_lines):
      # Pad the "L<n>" label to line_num_width so the source text aligns.
      annotated_line = RL("L%d" % (i + 1), cli_shared.COLOR_YELLOW)
      annotated_line += " " * (line_num_width - len(annotated_line))
      annotated_line += line
      labeled_source_lines.append(annotated_line)

      # Record the output row of the requested start line. Annotation rows
      # inserted above it shift that row, so it is not simply i.
      if i + 1 == parsed.line_begin:
        actual_initial_scroll_target = len(labeled_source_lines) - 1

      if i + 1 in source_annotation:
        sorted_elements = sorted(source_annotation[i + 1])
        for k, element in enumerate(sorted_elements):
          if k >= parsed.max_elements_per_line:
            # Summarize the remaining names and offer a link that reruns this
            # command, at this line, with a per-line limit larger by 5.
            omitted_info_line = RL("    (... Omitted %d of %d %s ...) " % (
                len(sorted_elements) - parsed.max_elements_per_line,
                len(sorted_elements),
                "tensor(s)" if parsed.tensors else "op(s)"))
            omitted_info_line += RL(
                "+5",
                debugger_cli_common.MenuItem(
                    None,
                    self._reconstruct_print_source_command(
                        parsed, i + 1, max_elements_per_line_increase=5)))
            labeled_source_lines.append(omitted_info_line)
            break

          label = RL(" " * 4)
          # Names whose node has debug watch keys link to print_tensor; the
          # rest are only colored.
          if self._debug_dump.debug_watch_keys(
              debug_graphs.get_node_name(element)):
            attribute = debugger_cli_common.MenuItem("", "pt %s" % element)
          else:
            attribute = cli_shared.COLOR_BLUE

          label += RL(element, attribute)
          labeled_source_lines.append(label)

    output = debugger_cli_common.rich_text_lines_from_rich_line_list(
        labeled_source_lines,
        annotations={debugger_cli_common.INIT_SCROLL_POS_KEY:
                     actual_initial_scroll_target})
    _add_main_menu(output, node_name=None)
    return output
1165
  def _make_source_table(self, source_list, is_tf_py_library):
    """Make a table summarizing the source files that create nodes and tensors.

    The table consists of a colored title line, a header row, and one row per
    file with left-aligned columns: path, #(nodes), #(tensors), #(tensor
    dumps). A blank line ends the table. Paths of uncompiled Python source
    files are rendered as `ps <path> -b <first_line>` links. An empty
    `source_list` yields only the title, "[No files.]" and a blank line.

    Args:
      source_list: List of source files and related information as a list of
        tuples (file_path, is_tf_library, num_nodes, num_tensors, num_dumps,
        first_line).
      is_tf_py_library: (`bool`) whether this table is for files that belong
        to the TensorFlow Python library.

    Returns:
      The table as a `debugger_cli_common.RichTextLines` object.
    """
    path_head = "Source file path"
    num_nodes_head = "#(nodes)"
    num_tensors_head = "#(tensors)"
    num_dumps_head = "#(tensor dumps)"

    if is_tf_py_library:
      # Use color to mark files that are guessed to belong to TensorFlow Python
      # library.
      color = cli_shared.COLOR_GRAY
      lines = [RL("TensorFlow Python library file(s):", color)]
    else:
      color = cli_shared.COLOR_WHITE
      lines = [RL("File(s) outside TensorFlow Python library:", color)]

    if not source_list:
      lines.append(RL("[No files.]"))
      lines.append(RL())
      return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)

    # Each column is as wide as its longest entry or header, plus one space of
    # padding. The last (#dumps) column is not padded.
    path_column_width = max(
        max(len(item[0]) for item in source_list), len(path_head)) + 1
    num_nodes_column_width = max(
        max(len(str(item[2])) for item in source_list),
        len(num_nodes_head)) + 1
    num_tensors_column_width = max(
        max(len(str(item[3])) for item in source_list),
        len(num_tensors_head)) + 1

    head = RL(path_head + " " * (path_column_width - len(path_head)), color)
    head += RL(num_nodes_head + " " * (
        num_nodes_column_width - len(num_nodes_head)), color)
    head += RL(num_tensors_head + " " * (
        num_tensors_column_width - len(num_tensors_head)), color)
    head += RL(num_dumps_head, color)

    lines.append(head)

    for (file_path, _, num_nodes, num_tensors, num_dumps,
         first_line_num) in source_list:
      path_attributes = [color]
      # Only files that print_source can display get a clickable `ps` link.
      if source_utils.is_extension_uncompiled_python_source(file_path):
        path_attributes.append(
            debugger_cli_common.MenuItem(None, "ps %s -b %d" %
                                         (file_path, first_line_num)))

      line = RL(file_path, path_attributes)
      line += " " * (path_column_width - len(line))
      line += RL(
          str(num_nodes) + " " * (num_nodes_column_width - len(str(num_nodes))),
          color)
      line += RL(
          str(num_tensors) + " " *
          (num_tensors_column_width - len(str(num_tensors))), color)
      line += RL(str(num_dumps), color)
      lines.append(line)
    lines.append(RL())

    return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
1237
1238  def list_source(self, args, screen_info=None):
1239    """List Python source files that constructed nodes and tensors."""
1240    del screen_info  # Unused.
1241
1242    parsed = self._arg_parsers["list_source"].parse_args(args)
1243    source_list = source_utils.list_source_files_against_dump(
1244        self._debug_dump,
1245        path_regex_allowlist=parsed.path_filter,
1246        node_name_regex_allowlist=parsed.node_name_filter)
1247
1248    top_lines = [
1249        RL("List of source files that created nodes in this run", "bold")]
1250    if parsed.path_filter:
1251      top_lines.append(
1252          RL("File path regex filter: \"%s\"" % parsed.path_filter))
1253    if parsed.node_name_filter:
1254      top_lines.append(
1255          RL("Node name regex filter: \"%s\"" % parsed.node_name_filter))
1256    top_lines.append(RL())
1257    output = debugger_cli_common.rich_text_lines_from_rich_line_list(top_lines)
1258    if not source_list:
1259      output.append("[No source file information.]")
1260      return output
1261
1262    output.extend(self._make_source_table(
1263        [item for item in source_list if not item[1]], False))
1264    output.extend(self._make_source_table(
1265        [item for item in source_list if item[1]], True))
1266    _add_main_menu(output, node_name=None)
1267    return output
1268
  def _list_inputs_or_outputs(self,
                              recursive,
                              node_name,
                              depth,
                              control,
                              op_type,
                              do_outputs=False):
    """Helper function used by list_inputs and list_outputs.

    Format a list of lines to display the inputs or output recipients of a
    given node.

    Args:
      recursive: Whether the listing is to be done recursively, as a boolean.
      node_name: The name of the node in question, as a str. A tensor name is
        also accepted; it is reduced to its node name.
      depth: Maximum recursion depth, applies only if recursive == True, as an
        int.
      control: Whether control inputs or control recipients are included, as a
        boolean.
      op_type: Whether the op types of the nodes are to be included, as a
        boolean.
      do_outputs: Whether recipients, instead of input nodes are to be
        listed, as a boolean.

    Returns:
      Input or recipient tree formatted as a RichTextLines object, followed by
      a legend. If the node does not exist, an error RichTextLines instead.
    """

    if do_outputs:
      tracker = self._debug_dump.node_recipients
      type_str = "Recipients of"
      short_type_str = "recipients"
    else:
      tracker = self._debug_dump.node_inputs
      type_str = "Inputs to"
      short_type_str = "inputs"

    lines = []
    font_attr_segs = {}

    # Check if this is a tensor name, instead of a node name.
    node_name, _ = debug_graphs.parse_node_or_tensor_name(node_name)

    # Check if node exists.
    if not self._debug_dump.node_exists(node_name):
      return cli_shared.error(
          "There is no node named \"%s\" in the partition graphs" % node_name)

    # Non-recursive listing is simply a depth-1 traversal.
    if recursive:
      max_depth = depth
    else:
      max_depth = 1

    if control:
      include_ctrls_str = ", control %s included" % short_type_str
    else:
      include_ctrls_str = ""

    line = "%s node \"%s\"" % (type_str, node_name)
    # Bold the node name in the title, excluding the closing quote.
    font_attr_segs[0] = [(len(line) - 1 - len(node_name), len(line) - 1, "bold")
                        ]
    lines.append(line + " (Depth limit = %d%s):" % (max_depth, include_ctrls_str
                                                   ))

    # Clicking a listed node re-runs the same listing, recursively and with
    # control edges, rooted at that node.
    command_template = "lo -c -r %s" if do_outputs else "li -c -r %s"
    self._dfs_from_node(
        lines,
        font_attr_segs,
        node_name,
        tracker,
        max_depth,
        1, [],
        control,
        op_type,
        command_template=command_template)

    # Include legend.
    # NOTE(review): the (Ctrl) and [Op] entries say "input" even when
    # do_outputs is True.
    lines.append("")
    lines.append("Legend:")
    lines.append("  (d): recursion depth = d.")

    if control:
      lines.append("  (Ctrl): Control input.")
    if op_type:
      lines.append("  [Op]: Input node has op type Op.")

    # TODO(cais): Consider appending ":0" at the end of 1st outputs of nodes.

    return debugger_cli_common.RichTextLines(
        lines, font_attr_segs=font_attr_segs)
1359
  def _dfs_from_node(self,
                     lines,
                     attr_segs,
                     node_name,
                     tracker,
                     max_depth,
                     depth,
                     unfinished,
                     include_control=False,
                     show_op_type=False,
                     command_template=None):
    """Perform depth-first search (DFS) traversal of a node's input tree.

    It recursively tracks the inputs (or output recipients) of the node called
    node_name, and append these inputs (or output recipients) to a list of text
    lines (lines) with proper indentation that reflects the recursion depth,
    together with some formatting attributes (to attr_segs). The formatting
    attributes can include command shortcuts, for example.

    No visited set is kept: a node reachable along several paths is expanded
    once per path, and recursion stops only where the depth exceeds max_depth
    (an "..." line is printed there).

    Args:
      lines: Text lines to append to, as a list of str.
      attr_segs: (dict) Attribute segments dictionary to append to.
      node_name: Name of the node, as a str. This arg is updated during the
        recursion.
      tracker: A callable that takes one str as the node name input and
        returns a list of str as the inputs/outputs.
        This makes it this function general enough to be used with both
        node-input and node-output tracking.
      max_depth: Maximum recursion depth, as an int.
      depth: Current recursion depth. This arg is updated during the
        recursion.
      unfinished: A stack of unfinished recursion depths, as a list of int.
        A depth stays on the stack until its last child has been printed; it
        controls whether a "|" continuation is drawn in that depth's column.
      include_control: Whether control dependencies are to be included as
        inputs (and marked as such).
      show_op_type: Whether op type of the input nodes are to be displayed
        alongside the nodes' names.
      command_template: (str) Template for command shortcut of the node names.
    """

    # Make a shallow copy of the list because it may be extended later.
    # (_exclude_denylisted_ops also returns a new list, so the tracker's own
    # list is never mutated by the extend() below.)
    all_inputs = self._exclude_denylisted_ops(
        copy.copy(tracker(node_name, is_control=False)))
    # is_ctrl[i] tells whether all_inputs[i] is a control edge.
    is_ctrl = [False] * len(all_inputs)
    if include_control:
      # Sort control inputs or recipients in alphabetical order of the node
      # names.
      ctrl_inputs = self._exclude_denylisted_ops(
          sorted(tracker(node_name, is_control=True)))
      all_inputs.extend(ctrl_inputs)
      is_ctrl.extend([True] * len(ctrl_inputs))

    if not all_inputs:
      # Only the root reports "[None]"; leaves deeper in the tree print nothing.
      if depth == 1:
        lines.append("  [None]")

      return

    unfinished.append(depth)

    # Create depth-dependent hanging indent for the line.
    # Ancestor columns get "|  " if that depth still has children left to
    # print, "   " otherwise; the current column gets "|- ".
    hang = ""
    for k in range(depth):
      if k < depth - 1:
        if k + 1 in unfinished:
          hang += HANG_UNFINISHED
        else:
          hang += HANG_FINISHED
      else:
        hang += HANG_SUFFIX

    # Depth limit reached: show that more exists without expanding it.
    if all_inputs and depth > max_depth:
      lines.append(hang + ELLIPSIS)
      unfinished.pop()
      return

    hang += DEPTH_TEMPLATE % depth

    for i, inp in enumerate(all_inputs):
      op_type = self._debug_dump.node_op_type(debug_graphs.get_node_name(inp))
      # NOTE(review): all_inputs was already filtered by
      # _exclude_denylisted_ops, so this never fires. If it ever skipped the
      # last element, the unfinished.pop() below would be skipped as well.
      if op_type in self._GRAPH_STRUCT_OP_TYPE_DENYLIST:
        continue

      if is_ctrl[i]:
        ctrl_str = CTRL_LABEL
      else:
        ctrl_str = ""

      op_type_str = ""
      if show_op_type:
        op_type_str = OP_TYPE_TEMPLATE % op_type

      # After the last child, this depth is finished: pop it before recursing
      # so the last child's subtree draws blank space in this column.
      if i == len(all_inputs) - 1:
        unfinished.pop()

      line = hang + ctrl_str + op_type_str + inp
      lines.append(line)
      if command_template:
        # Make the trailing name clickable.
        attr_segs[len(lines) - 1] = [(
            len(line) - len(inp), len(line),
            debugger_cli_common.MenuItem(None, command_template % inp))]

      # Recursive call.
      # The input's/output's name can be a tensor name, in the case of node
      # with >1 output slots.
      inp_node_name, _ = debug_graphs.parse_node_or_tensor_name(inp)
      self._dfs_from_node(
          lines,
          attr_segs,
          inp_node_name,
          tracker,
          max_depth,
          depth + 1,
          unfinished,
          include_control=include_control,
          show_op_type=show_op_type,
          command_template=command_template)
1476
1477  def _format_neighbors(self, neighbor_type, non_ctrls, ctrls):
1478    """List neighbors (inputs or recipients) of a node.
1479
1480    Args:
1481      neighbor_type: ("input" | "recipient")
1482      non_ctrls: Non-control neighbor node names, as a list of str.
1483      ctrls: Control neighbor node names, as a list of str.
1484
1485    Returns:
1486      A RichTextLines object.
1487    """
1488
1489    # TODO(cais): Return RichTextLines instead, to allow annotation of node
1490    # names.
1491    lines = []
1492    font_attr_segs = {}
1493
1494    lines.append("")
1495    lines.append("  %d %s(s) + %d control %s(s):" %
1496                 (len(non_ctrls), neighbor_type, len(ctrls), neighbor_type))
1497    lines.append("    %d %s(s):" % (len(non_ctrls), neighbor_type))
1498    for non_ctrl in non_ctrls:
1499      line = "      [%s] %s" % (self._debug_dump.node_op_type(non_ctrl),
1500                                non_ctrl)
1501      lines.append(line)
1502      font_attr_segs[len(lines) - 1] = [(
1503          len(line) - len(non_ctrl), len(line),
1504          debugger_cli_common.MenuItem(None, "ni -a -d -t %s" % non_ctrl))]
1505
1506    if ctrls:
1507      lines.append("")
1508      lines.append("    %d control %s(s):" % (len(ctrls), neighbor_type))
1509      for ctrl in ctrls:
1510        line = "      [%s] %s" % (self._debug_dump.node_op_type(ctrl), ctrl)
1511        lines.append(line)
1512        font_attr_segs[len(lines) - 1] = [(
1513            len(line) - len(ctrl), len(line),
1514            debugger_cli_common.MenuItem(None, "ni -a -d -t %s" % ctrl))]
1515
1516    return debugger_cli_common.RichTextLines(
1517        lines, font_attr_segs=font_attr_segs)
1518
1519  def _list_node_attributes(self, node_name):
1520    """List neighbors (inputs or recipients) of a node.
1521
1522    Args:
1523      node_name: Name of the node of which the attributes are to be listed.
1524
1525    Returns:
1526      A RichTextLines object.
1527    """
1528
1529    lines = []
1530    lines.append("")
1531    lines.append("Node attributes:")
1532
1533    attrs = self._debug_dump.node_attributes(node_name)
1534    for attr_key in attrs:
1535      lines.append("  %s:" % attr_key)
1536      attr_val_str = repr(attrs[attr_key]).strip().replace("\n", " ")
1537      lines.append("    %s" % attr_val_str)
1538      lines.append("")
1539
1540    return debugger_cli_common.RichTextLines(lines)
1541
1542  def _list_node_dumps(self, node_name):
1543    """List dumped tensor data from a node.
1544
1545    Args:
1546      node_name: Name of the node of which the attributes are to be listed.
1547
1548    Returns:
1549      A RichTextLines object.
1550    """
1551
1552    lines = []
1553    font_attr_segs = {}
1554
1555    watch_keys = self._debug_dump.debug_watch_keys(node_name)
1556
1557    dump_count = 0
1558    for watch_key in watch_keys:
1559      debug_tensor_data = self._debug_dump.watch_key_to_data(watch_key)
1560      for datum in debug_tensor_data:
1561        line = "  Slot %d @ %s @ %.3f ms" % (
1562            datum.output_slot, datum.debug_op,
1563            (datum.timestamp - self._debug_dump.t0) / 1000.0)
1564        lines.append(line)
1565        command = "pt %s:%d -n %d" % (node_name, datum.output_slot, dump_count)
1566        font_attr_segs[len(lines) - 1] = [(
1567            2, len(line), debugger_cli_common.MenuItem(None, command))]
1568        dump_count += 1
1569
1570    output = debugger_cli_common.RichTextLines(
1571        lines, font_attr_segs=font_attr_segs)
1572    output_with_header = debugger_cli_common.RichTextLines(
1573        ["%d dumped tensor(s):" % dump_count, ""])
1574    output_with_header.extend(output)
1575    return output_with_header
1576
1577
def create_analyzer_ui(debug_dump,
                       tensor_filters=None,
                       ui_type="curses",
                       on_ui_exit=None,
                       config=None):
  """Create a debugger UI with the analyzer commands for a DebugDumpDir.

  The UI class is chosen by `ui_type` through `ui_factory.get_ui`. The
  analyzer commands and their short aliases are registered on it, and
  print_tensor gets tab completion over the names of all dumped tensors.

  Args:
    debug_dump: (debug_data.DebugDumpDir) The debug dump to use.
    tensor_filters: (dict) A dict mapping tensor filter name (str) to tensor
      filter (Callable).
    ui_type: (str) requested UI type, e.g., "curses", "readline".
    on_ui_exit: (`Callable`) the callback to be called when the UI exits.
    config: A `cli_config.CLIConfig` object. A default one is created if None.

  Returns:
    (base_ui.BaseUI) A BaseUI subtype object with a set of standard analyzer
      commands and tab-completions registered.
  """
  if config is None:
    config = cli_config.CLIConfig()

  analyzer = DebugAnalyzer(debug_dump, config=config)
  if tensor_filters:
    for tensor_filter_name in tensor_filters:
      analyzer.add_tensor_filter(
          tensor_filter_name, tensor_filters[tensor_filter_name])

  cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit, config=config)

  # (command name, handler, prefix alias), registered in this order.
  command_specs = (
      ("list_tensors", analyzer.list_tensors, "lt"),
      ("node_info", analyzer.node_info, "ni"),
      ("list_inputs", analyzer.list_inputs, "li"),
      ("list_outputs", analyzer.list_outputs, "lo"),
      ("print_tensor", analyzer.print_tensor, "pt"),
      ("print_source", analyzer.print_source, "ps"),
      ("list_source", analyzer.list_source, "ls"),
      ("eval", analyzer.evaluate_expression, "ev"),
  )
  for command_name, handler, alias in command_specs:
    cli.register_command_handler(
        command_name,
        handler,
        analyzer.get_help(command_name),
        prefix_aliases=[alias])

  # Tab completions for command "print_tensor". A tensor dumped several times
  # (e.g. multiple debug ops or executions) is offered only once; the
  # first-seen order is kept.
  dumped_tensor_names = list(dict.fromkeys(
      "%s:%d" % (datum.node_name, datum.output_slot)
      for datum in debug_dump.dumped_tensor_data))

  cli.register_tab_comp_context(["print_tensor", "pt"], dumped_tensor_names)

  return cli
1656