# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests of the Analyzer CLI Backend."""
16import os
17import tempfile
18
19import numpy as np
20
21from tensorflow.core.protobuf import config_pb2
22from tensorflow.core.protobuf import rewriter_config_pb2
23from tensorflow.python.client import session
24from tensorflow.python.debug.cli import analyzer_cli
25from tensorflow.python.debug.cli import cli_config
26from tensorflow.python.debug.cli import cli_shared
27from tensorflow.python.debug.cli import cli_test_utils
28from tensorflow.python.debug.cli import command_parser
29from tensorflow.python.debug.cli import debugger_cli_common
30from tensorflow.python.debug.lib import debug_data
31from tensorflow.python.debug.lib import debug_utils
32from tensorflow.python.debug.lib import source_utils
33from tensorflow.python.framework import constant_op
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import test_util
36from tensorflow.python.lib.io import file_io
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import control_flow_ops
39from tensorflow.python.ops import math_ops
40from tensorflow.python.ops import variables
41from tensorflow.python.platform import googletest
42from tensorflow.python.platform import test
43from tensorflow.python.util import tf_inspect
44

# Helper function to accommodate MKL-enabled TensorFlow:
# the MatMul op is supported by MKL for some data types, and its name is
# prefixed with "_Mkl" during the MKL graph rewrite pass.
def _matmul_op_name():
  if (test_util.IsMklEnabled() and
      _get_graph_matmul_dtype() in _mkl_matmul_supported_types()):
    return "_MklMatMul"
  else:
    return "MatMul"


# Helper function to get the data types supported by MklMatMul.
def _mkl_matmul_supported_types():
  return {"float32", "bfloat16"}


# Helper function to get the dtype used in the graph built by setUpClass().
def _get_graph_matmul_dtype():
  # The matmul op in setUpClass() is created with the default dtype, float64.
  return "float64"


def _cli_config_from_temp_file():
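  """Create a CLIConfig backed by a file in a fresh temporary directory."""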
  return cli_config.CLIConfig(
      config_file_path=os.path.join(tempfile.mkdtemp(), ".tfdbg_config"))


def no_rewrite_session_config():
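  """Build a ConfigProto with Grappler graph rewrites disabled.

  Keeping pruning, constant folding and the other rewrites off preserves the
  graph exactly as the tests construct it, so dumped node names stay
  predictable.
  """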
  rewriter_config = rewriter_config_pb2.RewriterConfig(
      disable_model_pruning=True,
      constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
      arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
      dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
      pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)

  graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
  return config_pb2.ConfigProto(graph_options=graph_options)


def line_number_above():
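  """Return the (1-based) line number of the line above the call site."""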
  return tf_inspect.stack()[1][2] - 1


def parse_op_and_node(line):
  """Parse a line containing an op node followed by a node name.

  For example, if the line is
    "  [Variable] hidden/weights",
  this function will return ("Variable", "hidden/weights")

  Args:
    line: The line to be parsed, as a str.

  Returns:
    Name of the parsed op type.
    Name of the parsed node.
  """

  op_type = line.strip().split(" ")[0].replace("[", "").replace("]", "")

  # Not using [-1], to tolerate any other items that might be present after
  # the node name.
  node_name = line.strip().split(" ")[1]

  return op_type, node_name


def assert_column_header_command_shortcut(tst,
                                          command,
                                          reverse,
                                          node_name_regex,
                                          op_type_regex,
                                          tensor_filter_name):
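  """Check a column-header sort shortcut for flags that should be absent."""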
  tst.assertFalse(reverse and "-r" in command)
  tst.assertFalse(not(op_type_regex) and ("-t %s" % op_type_regex) in command)
  tst.assertFalse(
      not(node_name_regex) and ("-n %s" % node_name_regex) in command)
  tst.assertFalse(
      not(tensor_filter_name) and ("-f %s" % tensor_filter_name) in command)


def assert_listed_tensors(tst,
                          out,
                          expected_tensor_names,
                          expected_op_types,
                          node_name_regex=None,
                          op_type_regex=None,
                          tensor_filter_name=None,
                          sort_by="timestamp",
                          reverse=False):
  """Check RichTextLines output for list_tensors commands.

  Args:
    tst: A test_util.TensorFlowTestCase instance.
    out: The RichTextLines object to be checked.
    expected_tensor_names: (list of str) Expected tensor names in the list.
    expected_op_types: (list of str) Expected op types of the tensors, in the
      same order as the expected_tensor_names.
    node_name_regex: Optional: node name regex filter.
    op_type_regex: Optional: op type regex filter.
    tensor_filter_name: Optional: name of the tensor filter.
    sort_by: (str) (timestamp | dump_size | op_type | tensor_name) the field
      by which the tensors in the list are sorted.
    reverse: (bool) whether the sorting is in reverse (i.e., descending) order.
  """

  line_iter = iter(out.lines)
  attr_segs = out.font_attr_segs
  line_counter = 0

  num_dumped_tensors = int(next(line_iter).split(" ")[0])
  line_counter += 1
  tst.assertGreaterEqual(num_dumped_tensors, len(expected_tensor_names))

  if op_type_regex is not None:
    tst.assertEqual("Op type regex filter: \"%s\"" % op_type_regex,
                    next(line_iter))
    line_counter += 1

  if node_name_regex is not None:
    tst.assertEqual("Node name regex filter: \"%s\"" % node_name_regex,
                    next(line_iter))
    line_counter += 1

  tst.assertEqual("", next(line_iter))
  line_counter += 1

  # Verify the column heads "t (ms)", "Op type" and "Tensor name" are present.
  line = next(line_iter)
  tst.assertIn("t (ms)", line)
  tst.assertIn("Op type", line)
  tst.assertIn("Tensor name", line)

  # Verify the command shortcuts in the top row.
  attr_segs = out.font_attr_segs[line_counter]
  attr_seg = attr_segs[0]
  tst.assertEqual(0, attr_seg[0])
  tst.assertEqual(len("t (ms)"), attr_seg[1])
  command = attr_seg[2][0].content
  tst.assertIn("-s timestamp", command)
  assert_column_header_command_shortcut(
      tst, command, reverse, node_name_regex, op_type_regex,
      tensor_filter_name)
  tst.assertEqual("bold", attr_seg[2][1])

  idx0 = line.index("Size")
  attr_seg = attr_segs[1]
  tst.assertEqual(idx0, attr_seg[0])
  tst.assertEqual(idx0 + len("Size (B)"), attr_seg[1])
  command = attr_seg[2][0].content
  tst.assertIn("-s dump_size", command)
  assert_column_header_command_shortcut(tst, command, reverse, node_name_regex,
                                        op_type_regex, tensor_filter_name)
  tst.assertEqual("bold", attr_seg[2][1])

  idx0 = line.index("Op type")
  attr_seg = attr_segs[2]
  tst.assertEqual(idx0, attr_seg[0])
  tst.assertEqual(idx0 + len("Op type"), attr_seg[1])
  command = attr_seg[2][0].content
  tst.assertIn("-s op_type", command)
  assert_column_header_command_shortcut(
      tst, command, reverse, node_name_regex, op_type_regex,
      tensor_filter_name)
  tst.assertEqual("bold", attr_seg[2][1])

  idx0 = line.index("Tensor name")
  attr_seg = attr_segs[3]
  tst.assertEqual(idx0, attr_seg[0])
  tst.assertEqual(idx0 + len("Tensor name"), attr_seg[1])
  command = attr_seg[2][0].content
  tst.assertIn("-s tensor_name", command)
  assert_column_header_command_shortcut(
      tst, command, reverse, node_name_regex, op_type_regex,
      tensor_filter_name)
  tst.assertEqual("bold", attr_seg[2][1])

  # Verify the listed tensors and their timestamps.
  tensor_timestamps = []
  dump_sizes_bytes = []
  op_types = []
  tensor_names = []
  for line in line_iter:
    items = line.split(" ")
    items = [item for item in items if item]

    rel_time = float(items[0][1:-1])
    tst.assertGreaterEqual(rel_time, 0.0)

    tensor_timestamps.append(rel_time)
    dump_sizes_bytes.append(command_parser.parse_readable_size_str(items[1]))
    op_types.append(items[2])
    tensor_names.append(items[3])

  # Verify that the tensors are listed in the order specified by sort_by,
  # ascending by default or descending if reverse is set.
  if sort_by == "timestamp":
    sorted_timestamps = sorted(tensor_timestamps)
    if reverse:
      sorted_timestamps.reverse()
    tst.assertEqual(sorted_timestamps, tensor_timestamps)
  elif sort_by == "dump_size":
    sorted_dump_sizes_bytes = sorted(dump_sizes_bytes)
    if reverse:
      sorted_dump_sizes_bytes.reverse()
    tst.assertEqual(sorted_dump_sizes_bytes, dump_sizes_bytes)
  elif sort_by == "op_type":
    sorted_op_types = sorted(op_types)
    if reverse:
      sorted_op_types.reverse()
    tst.assertEqual(sorted_op_types, op_types)
  elif sort_by == "tensor_name":
    sorted_tensor_names = sorted(tensor_names)
    if reverse:
      sorted_tensor_names.reverse()
    tst.assertEqual(sorted_tensor_names, tensor_names)
  else:
    tst.fail("Invalid value in sort_by: %s" % sort_by)

  # Verify that the tensors are all listed.
  for tensor_name, op_type in zip(expected_tensor_names, expected_op_types):
    tst.assertIn(tensor_name, tensor_names)
    index = tensor_names.index(tensor_name)
    tst.assertEqual(op_type, op_types[index])


def assert_node_attribute_lines(tst,
                                out,
                                node_name,
                                op_type,
                                device,
                                input_op_type_node_name_pairs,
                                ctrl_input_op_type_node_name_pairs,
                                recipient_op_type_node_name_pairs,
                                ctrl_recipient_op_type_node_name_pairs,
                                attr_key_val_pairs=None,
                                num_dumped_tensors=None,
                                show_stack_trace=False,
                                stack_trace_available=False):
  """Check RichTextLines output for node_info commands.

  Args:
    tst: A test_util.TensorFlowTestCase instance.
    out: The RichTextLines object to be checked.
    node_name: Name of the node.
    op_type: Op type of the node, as a str.
    device: Name of the device on which the node resides.
    input_op_type_node_name_pairs: A list of 2-tuples of op type and node name,
      for the (non-control) inputs to the node.
    ctrl_input_op_type_node_name_pairs: A list of 2-tuples of op type and node
      name, for the control inputs to the node.
    recipient_op_type_node_name_pairs: A list of 2-tuples of op type and node
      name, for the (non-control) output recipients of the node.
    ctrl_recipient_op_type_node_name_pairs: A list of 2-tuples of op type and
      node name, for the control output recipients of the node.
    attr_key_val_pairs: Optional: attribute key-value pairs of the node, as a
      list of 2-tuples.
    num_dumped_tensors: Optional: number of tensor dumps from the node.
    show_stack_trace: (bool) whether the stack trace of the node's
      construction is asserted to be present.
    stack_trace_available: (bool) whether the Python stack trace is available.
  """

  line_iter = iter(out.lines)

  tst.assertEqual("Node %s" % node_name, next(line_iter))
  tst.assertEqual("", next(line_iter))
  tst.assertEqual("  Op: %s" % op_type, next(line_iter))
  tst.assertEqual("  Device: %s" % device, next(line_iter))
  tst.assertEqual("", next(line_iter))
  tst.assertEqual("  %d input(s) + %d control input(s):" %
                  (len(input_op_type_node_name_pairs),
                   len(ctrl_input_op_type_node_name_pairs)), next(line_iter))

  # Check inputs.
  tst.assertEqual("    %d input(s):" % len(input_op_type_node_name_pairs),
                  next(line_iter))
  for op_type, node_name in input_op_type_node_name_pairs:
    tst.assertEqual("      [%s] %s" % (op_type, node_name), next(line_iter))

  tst.assertEqual("", next(line_iter))

  # Check control inputs.
  if ctrl_input_op_type_node_name_pairs:
    tst.assertEqual("    %d control input(s):" %
                    len(ctrl_input_op_type_node_name_pairs), next(line_iter))
    for op_type, node_name in ctrl_input_op_type_node_name_pairs:
      tst.assertEqual("      [%s] %s" % (op_type, node_name), next(line_iter))

    tst.assertEqual("", next(line_iter))

  tst.assertEqual("  %d recipient(s) + %d control recipient(s):" %
                  (len(recipient_op_type_node_name_pairs),
                   len(ctrl_recipient_op_type_node_name_pairs)),
                  next(line_iter))

  # Check recipients, the order of which is not deterministic.
  tst.assertEqual("    %d recipient(s):" %
                  len(recipient_op_type_node_name_pairs), next(line_iter))

  t_recs = []
  for _ in recipient_op_type_node_name_pairs:
    line = next(line_iter)

    op_type, node_name = parse_op_and_node(line)
    t_recs.append((op_type, node_name))

  tst.assertItemsEqual(recipient_op_type_node_name_pairs, t_recs)

  # Check control recipients, the order of which is not deterministic.
  if ctrl_recipient_op_type_node_name_pairs:
    tst.assertEqual("", next(line_iter))

    tst.assertEqual("    %d control recipient(s):" %
                    len(ctrl_recipient_op_type_node_name_pairs),
                    next(line_iter))

    t_ctrl_recs = []
    for _ in ctrl_recipient_op_type_node_name_pairs:
      line = next(line_iter)

      op_type, node_name = parse_op_and_node(line)
      t_ctrl_recs.append((op_type, node_name))

    tst.assertItemsEqual(ctrl_recipient_op_type_node_name_pairs, t_ctrl_recs)

  # The order of multiple attributes can be non-deterministic.
  if attr_key_val_pairs:
    tst.assertEqual("", next(line_iter))

    tst.assertEqual("Node attributes:", next(line_iter))

    kv_pairs = []
    for key, val in attr_key_val_pairs:
      key = next(line_iter).strip().replace(":", "")

      val = next(line_iter).strip()

      kv_pairs.append((key, val))

      tst.assertEqual("", next(line_iter))

  if num_dumped_tensors is not None:
    tst.assertEqual("%d dumped tensor(s):" % num_dumped_tensors,
                    next(line_iter))
    tst.assertEqual("", next(line_iter))

    dump_timestamps_ms = []
    for _ in range(num_dumped_tensors):
      line = next(line_iter)

      tst.assertStartsWith(line.strip(), "Slot 0 @ DebugIdentity @")
      tst.assertTrue(line.strip().endswith(" ms"))

      dump_timestamp_ms = float(line.strip().split(" @ ")[-1].replace("ms", ""))
      tst.assertGreaterEqual(dump_timestamp_ms, 0.0)

      dump_timestamps_ms.append(dump_timestamp_ms)

    tst.assertEqual(sorted(dump_timestamps_ms), dump_timestamps_ms)

  if show_stack_trace:
    tst.assertEqual("", next(line_iter))
    tst.assertEqual("", next(line_iter))
    tst.assertEqual("Traceback of node construction:", next(line_iter))
    if stack_trace_available:
      try:
        depth_counter = 0
        while True:
          for i in range(5):
            line = next(line_iter)
            if i == 0:
              tst.assertEqual(depth_counter, int(line.split(":")[0]))
            elif i == 1:
              tst.assertStartsWith(line, "  Line:")
            elif i == 2:
              tst.assertStartsWith(line, "  Function:")
            elif i == 3:
              tst.assertStartsWith(line, "  Text:")
            elif i == 4:
              tst.assertEqual("", line)

          depth_counter += 1
      except StopIteration:
        tst.assertEqual(0, i)
    else:
      tst.assertEqual("(Unavailable because no Python graph has been loaded)",
                      next(line_iter))


def check_syntax_error_output(tst, out, command_prefix):
  """Check RichTextLines output for valid command prefix but invalid syntax."""

  tst.assertEqual([
      "Syntax error for command: %s" % command_prefix,
      "For help, do \"help %s\"" % command_prefix
  ], out.lines)


def check_error_output(tst, out, command_prefix, args):
  """Check RichTextLines output from invalid/erroneous commands.

  Args:
    tst: A test_util.TensorFlowTestCase instance.
    out: The RichTextLines object to be checked.
    command_prefix: The command prefix of the command that caused the error.
    args: The arguments (excluding prefix) of the command that caused the error.
  """

  tst.assertGreater(len(out.lines), 2)
  tst.assertStartsWith(out.lines[0],
                       "Error occurred during handling of command: %s %s" %
                       (command_prefix, " ".join(args)))


def check_main_menu(tst,
                    out,
                    list_tensors_enabled=False,
                    node_info_node_name=None,
                    print_tensor_node_name=None,
                    list_inputs_node_name=None,
                    list_outputs_node_name=None):
  """Check the main menu annotation of an output."""

  tst.assertIn(debugger_cli_common.MAIN_MENU_KEY, out.annotations)

  menu = out.annotations[debugger_cli_common.MAIN_MENU_KEY]
  tst.assertEqual(list_tensors_enabled,
                  menu.caption_to_item("list_tensors").is_enabled())

  menu_item = menu.caption_to_item("node_info")
  if node_info_node_name:
    tst.assertTrue(menu_item.is_enabled())
    tst.assertTrue(menu_item.content.endswith(node_info_node_name))
  else:
    tst.assertFalse(menu_item.is_enabled())

  menu_item = menu.caption_to_item("print_tensor")
  if print_tensor_node_name:
    tst.assertTrue(menu_item.is_enabled())
    tst.assertTrue(menu_item.content.endswith(print_tensor_node_name))
  else:
    tst.assertFalse(menu_item.is_enabled())

  menu_item = menu.caption_to_item("list_inputs")
  if list_inputs_node_name:
    tst.assertTrue(menu_item.is_enabled())
    tst.assertTrue(menu_item.content.endswith(list_inputs_node_name))
  else:
    tst.assertFalse(menu_item.is_enabled())

  menu_item = menu.caption_to_item("list_outputs")
  if list_outputs_node_name:
    tst.assertTrue(menu_item.is_enabled())
    tst.assertTrue(menu_item.content.endswith(list_outputs_node_name))
  else:
    tst.assertFalse(menu_item.is_enabled())

  tst.assertTrue(menu.caption_to_item("run_info").is_enabled())
  tst.assertTrue(menu.caption_to_item("help").is_enabled())


def check_menu_item(tst, out, line_index, expected_begin, expected_end,
                    expected_command):
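  """Check that line line_index of out carries the expected MenuItem.

  The MenuItem's begin index, end index and command content are compared
  against expected_begin, expected_end and expected_command, respectively.
  """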
  attr_segs = out.font_attr_segs[line_index]
  found_menu_item = False
  for begin, end, attribute in attr_segs:
    attributes = [attribute] if not isinstance(attribute, list) else attribute
    menu_item = [attribute for attribute in attributes if
                 isinstance(attribute, debugger_cli_common.MenuItem)]
    if menu_item:
      tst.assertEqual(expected_begin, begin)
      tst.assertEqual(expected_end, end)
      tst.assertEqual(expected_command, menu_item[0].content)
      found_menu_item = True
      break
  tst.assertTrue(found_menu_item)


def create_analyzer_cli(dump):
  """Create an analyzer CLI.

  Args:
    dump: A `DebugDumpDir` object to base the analyzer CLI on.

  Returns:
    1) A `DebugAnalyzer` object created based on `dump`.
    2) A `CommandHandlerRegistry` that is based on the `DebugAnalyzer` object
       and has the common tfdbg commands, e.g., lt, ni, li, lo, registered.
  """
  # Construct the analyzer.
  analyzer = analyzer_cli.DebugAnalyzer(dump, _cli_config_from_temp_file())

  # Construct the handler registry.
  registry = debugger_cli_common.CommandHandlerRegistry()

  # Register command handlers.
  registry.register_command_handler(
      "list_tensors",
      analyzer.list_tensors,
      analyzer.get_help("list_tensors"),
      prefix_aliases=["lt"])
  registry.register_command_handler(
      "node_info",
      analyzer.node_info,
      analyzer.get_help("node_info"),
      prefix_aliases=["ni"])
  registry.register_command_handler(
      "list_inputs",
      analyzer.list_inputs,
      analyzer.get_help("list_inputs"),
      prefix_aliases=["li"])
  registry.register_command_handler(
      "list_outputs",
      analyzer.list_outputs,
      analyzer.get_help("list_outputs"),
      prefix_aliases=["lo"])
  registry.register_command_handler(
      "print_tensor",
      analyzer.print_tensor,
      analyzer.get_help("print_tensor"),
      prefix_aliases=["pt"])
  registry.register_command_handler(
      "print_source",
      analyzer.print_source,
      analyzer.get_help("print_source"),
      prefix_aliases=["ps"])
  registry.register_command_handler(
      "list_source",
      analyzer.list_source,
      analyzer.get_help("list_source"),
      prefix_aliases=["ls"])
  registry.register_command_handler(
      "eval",
      analyzer.evaluate_expression,
      analyzer.get_help("eval"),
      prefix_aliases=["ev"])

  return analyzer, registry


@test_util.run_v1_only("b/120545219")
class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):

  @classmethod
  def setUpClass(cls):
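    """Build a small graph, run it once under tfdbg, and load the dump.

    The graph computes x = w + w with w = matmul(u, v) under the
    simple_mul_add/ name prefix. Every node is watched with DebugIdentity,
    and the resulting dump directory backs the analyzer CLI used by these
    tests.
    """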
    cls._dump_root = tempfile.mkdtemp()
    cls._dump_root_for_unique = tempfile.mkdtemp()

    cls._is_gpu_available = test.is_gpu_available()
    if cls._is_gpu_available:
      gpu_name = test_util.gpu_device_name()
      cls._main_device = "/job:localhost/replica:0/task:0" + gpu_name
    else:
      cls._main_device = "/job:localhost/replica:0/task:0/device:CPU:0"

    cls._curr_file_path = os.path.abspath(
        tf_inspect.getfile(tf_inspect.currentframe()))

    cls._sess = session.Session(config=no_rewrite_session_config())
    with cls._sess as sess:
      u_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
      v_init_val = np.array([[2.0], [-1.0]])

      u_name = "simple_mul_add/u"
      v_name = "simple_mul_add/v"

      u_init = constant_op.constant(u_init_val, shape=[2, 2], name="u_init")
      u = variables.VariableV1(u_init, name=u_name)
      cls._u_line_number = line_number_above()

      v_init = constant_op.constant(v_init_val, shape=[2, 1], name="v_init")
      v = variables.VariableV1(v_init, name=v_name)
      cls._v_line_number = line_number_above()

      w = math_ops.matmul(u, v, name="simple_mul_add/matmul")
      cls._w_line_number = line_number_above()

      x = math_ops.add(w, w, name="simple_mul_add/add")
      cls._x_line_number = line_number_above()

      a = variables.VariableV1([1, 3, 3, 7], name="a")

      u.initializer.run()
      v.initializer.run()
      a.initializer.run()

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
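      # Watch the whole graph with DebugIdentity ops and write the dumps
      # under cls._dump_root, so that the analyzer CLI has data to inspect.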
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugIdentity"],
          debug_urls="file://%s" % cls._dump_root)

      # Invoke Session.run().
      run_metadata = config_pb2.RunMetadata()
      sess.run([x], options=run_options, run_metadata=run_metadata)
      cls._debug_dump = debug_data.DebugDumpDir(
          cls._dump_root, partition_graphs=run_metadata.partition_graphs)
      cls._analyzer, cls._registry = create_analyzer_cli(cls._debug_dump)

  @classmethod
  def tearDownClass(cls):
    # Tear down temporary dump directory.
    file_io.delete_recursively(cls._dump_root)
    file_io.delete_recursively(cls._dump_root_for_unique)

  def testMeasureTensorListColumnWidthsGivesRightAnswerForEmptyData(self):
    timestamp_col_width, dump_size_col_width, op_type_col_width = (
        self._analyzer._measure_tensor_list_column_widths([]))
    self.assertEqual(len("t (ms)") + 1, timestamp_col_width)
    self.assertEqual(len("Size (B)") + 1, dump_size_col_width)
    self.assertEqual(len("Op type") + 1, op_type_col_width)

  def testMeasureTensorListColumnWidthsGivesRightAnswerForData(self):
    dump = self._debug_dump.dumped_tensor_data[0]
    self.assertLess(dump.dump_size_bytes, 1000)
    self.assertEqual(
        "VariableV2", self._debug_dump.node_op_type(dump.node_name))
    _, dump_size_col_width, op_type_col_width = (
        self._analyzer._measure_tensor_list_column_widths([dump]))
    # The length of str(dump.dump_size_bytes) is less than the length of
    # "Size (B)" (8), so the column width should be determined by the length
    # of "Size (B)".
    self.assertEqual(len("Size (B)") + 1, dump_size_col_width)
    # The length of "VariableV2" is greater than the length of "Op type", so
    # the column width should be determined by the length of "VariableV2".
    self.assertEqual(len("VariableV2") + 1, op_type_col_width)

  def testListTensors(self):
    # Use shorthand alias for the command prefix.
    out = self._registry.dispatch_command("lt", [])

    assert_listed_tensors(self, out, [
        "simple_mul_add/u:0", "simple_mul_add/v:0", "simple_mul_add/u/read:0",
        "simple_mul_add/v/read:0", "simple_mul_add/matmul:0",
        "simple_mul_add/add:0"
    ], [
        "VariableV2", "VariableV2", "Identity", "Identity",
        _matmul_op_name(), "AddV2"
    ])

    # Check the main menu.
    check_main_menu(self, out, list_tensors_enabled=False)

  def testListTensorsInReverseTimeOrderWorks(self):
    # Use shorthand alias for the command prefix.
    out = self._registry.dispatch_command("lt", ["-s", "timestamp", "-r"])
    assert_listed_tensors(
        self,
        out, [
            "simple_mul_add/u:0", "simple_mul_add/v:0",
            "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
            "simple_mul_add/matmul:0", "simple_mul_add/add:0"
        ], [
            "VariableV2", "VariableV2", "Identity", "Identity",
            _matmul_op_name(), "AddV2"
        ],
        sort_by="timestamp",
        reverse=True)
    check_main_menu(self, out, list_tensors_enabled=False)

  def testListTensorsInDumpSizeOrderWorks(self):
    out = self._registry.dispatch_command("lt", ["-s", "dump_size"])
    assert_listed_tensors(
        self,
        out, [
            "simple_mul_add/u:0", "simple_mul_add/v:0",
            "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
            "simple_mul_add/matmul:0", "simple_mul_add/add:0"
        ], [
            "VariableV2", "VariableV2", "Identity", "Identity",
            _matmul_op_name(), "AddV2"
        ],
        sort_by="dump_size")
    check_main_menu(self, out, list_tensors_enabled=False)

  def testListTensorsInReverseDumpSizeOrderWorks(self):
    out = self._registry.dispatch_command("lt", ["-s", "dump_size", "-r"])
    assert_listed_tensors(
        self,
        out, [
            "simple_mul_add/u:0", "simple_mul_add/v:0",
            "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
            "simple_mul_add/matmul:0", "simple_mul_add/add:0"
        ], [
            "VariableV2", "VariableV2", "Identity", "Identity",
            _matmul_op_name(), "AddV2"
        ],
        sort_by="dump_size",
        reverse=True)
    check_main_menu(self, out, list_tensors_enabled=False)

  def testListTensorsWithInvalidSortByFieldGivesError(self):
    out = self._registry.dispatch_command("lt", ["-s", "foobar"])
    self.assertIn("ValueError: Unsupported key to sort tensors by: foobar",
                  out.lines)

  def testListTensorsInOpTypeOrderWorks(self):
    # Use shorthand alias for the command prefix.
    out = self._registry.dispatch_command("lt", ["-s", "op_type"])
    assert_listed_tensors(
        self,
        out, [
            "simple_mul_add/u:0", "simple_mul_add/v:0",
            "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
            "simple_mul_add/matmul:0", "simple_mul_add/add:0"
        ], [
            "VariableV2", "VariableV2", "Identity", "Identity",
            _matmul_op_name(), "AddV2"
        ],
        sort_by="op_type",
        reverse=False)
    check_main_menu(self, out, list_tensors_enabled=False)

  def testListTensorsInReverseOpTypeOrderWorks(self):
    # Use shorthand alias for the command prefix.
    out = self._registry.dispatch_command("lt", ["-s", "op_type", "-r"])
    assert_listed_tensors(
        self,
        out, [
            "simple_mul_add/u:0", "simple_mul_add/v:0",
            "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
            "simple_mul_add/matmul:0", "simple_mul_add/add:0"
        ], [
            "VariableV2", "VariableV2", "Identity", "Identity",
            _matmul_op_name(), "AddV2"
        ],
        sort_by="op_type",
        reverse=True)
    check_main_menu(self, out, list_tensors_enabled=False)

  def testListTensorsInTensorNameOrderWorks(self):
    # Use shorthand alias for the command prefix.
    out = self._registry.dispatch_command("lt", ["-s", "tensor_name"])
    assert_listed_tensors(
        self,
        out, [
            "simple_mul_add/u:0", "simple_mul_add/v:0",
            "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
            "simple_mul_add/matmul:0", "simple_mul_add/add:0"
        ], [
            "VariableV2", "VariableV2", "Identity", "Identity",
            _matmul_op_name(), "AddV2"
        ],
        sort_by="tensor_name",
        reverse=False)
    check_main_menu(self, out, list_tensors_enabled=False)

  def testListTensorsInReverseTensorNameOrderWorks(self):
    # Use shorthand alias for the command prefix.
    out = self._registry.dispatch_command("lt", ["-s", "tensor_name", "-r"])
    assert_listed_tensors(
        self,
        out, [
            "simple_mul_add/u:0", "simple_mul_add/v:0",
            "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
            "simple_mul_add/matmul:0", "simple_mul_add/add:0"
        ], [
            "VariableV2", "VariableV2", "Identity", "Identity",
            _matmul_op_name(), "AddV2"
        ],
        sort_by="tensor_name",
        reverse=True)
    check_main_menu(self, out, list_tensors_enabled=False)

  def testListTensorsFilterByNodeNameRegex(self):
    out = self._registry.dispatch_command("list_tensors",
                                          ["--node_name_filter", ".*read.*"])
    assert_listed_tensors(
        self,
        out, ["simple_mul_add/u/read:0", "simple_mul_add/v/read:0"],
        ["Identity", "Identity"],
        node_name_regex=".*read.*")

    out = self._registry.dispatch_command("list_tensors", ["-n", "^read"])
    assert_listed_tensors(self, out, [], [], node_name_regex="^read")
    check_main_menu(self, out, list_tensors_enabled=False)

  def testListTensorFilterByOpTypeRegex(self):
    out = self._registry.dispatch_command("list_tensors",
                                          ["--op_type_filter", "Identity"])
    assert_listed_tensors(
        self,
        out, ["simple_mul_add/u/read:0", "simple_mul_add/v/read:0"],
        ["Identity", "Identity"],
        op_type_regex="Identity")

    out = self._registry.dispatch_command(
        "list_tensors", ["-t", "(Add|" + _matmul_op_name() + ")"])
    assert_listed_tensors(
        self,
        out, ["simple_mul_add/add:0", "simple_mul_add/matmul:0"],
        ["AddV2", _matmul_op_name()],
        op_type_regex=("(Add|" + _matmul_op_name() + ")"))
    check_main_menu(self, out, list_tensors_enabled=False)

  def testListTensorFilterByNodeNameRegexAndOpTypeRegex(self):
    out = self._registry.dispatch_command(
        "list_tensors", ["-t", "(Add|MatMul)", "-n", ".*add$"])
    assert_listed_tensors(
        self,
        out, ["simple_mul_add/add:0"], ["AddV2"],
        node_name_regex=".*add$",
        op_type_regex="(Add|MatMul)")
    check_main_menu(self, out, list_tensors_enabled=False)

  def testListTensorWithFilterAndNodeNameExclusionWorks(self):
    # First, create and register the filter.
    def is_2x1_vector(datum, tensor):
      del datum  # Unused.
      return list(tensor.shape) == [2, 1]
    self._analyzer.add_tensor_filter("is_2x1_vector", is_2x1_vector)

    # Use shorthand alias for the command prefix.
    out = self._registry.dispatch_command(
        "lt", ["-f", "is_2x1_vector", "--filter_exclude_node_names", ".*v.*"])

    # If the --filter_exclude_node_names flag were not used, the matching
    # tensors would be:
    #   - simple_mul_add/v:0
    #   - simple_mul_add/v/read:0
    #   - simple_mul_add/matmul:0
    #   - simple_mul_add/add:0
    #
    # With the --filter_exclude_node_names option, only the last two should
    # show up in the result.
    assert_listed_tensors(
        self,
        out, ["simple_mul_add/matmul:0", "simple_mul_add/add:0"],
        [_matmul_op_name(), "AddV2"],
        tensor_filter_name="is_2x1_vector")

    check_main_menu(self, out, list_tensors_enabled=False)

  def testListTensorsFilterNanOrInf(self):
882    """Test register and invoke a tensor filter."""

    # First, register the filter.
    self._analyzer.add_tensor_filter("has_inf_or_nan",
                                     debug_data.has_inf_or_nan)

    # Use shorthand alias for the command prefix.
    out = self._registry.dispatch_command("lt", ["-f", "has_inf_or_nan"])

    # This TF graph run did not generate any bad numerical values.
    assert_listed_tensors(
        self, out, [], [], tensor_filter_name="has_inf_or_nan")
    # TODO(cais): A test with some actual bad numerical values.

    check_main_menu(self, out, list_tensors_enabled=False)

  def testListTensorNonexistentFilter(self):
    """Test attempt to use a nonexistent tensor filter."""

    out = self._registry.dispatch_command("lt", ["-f", "foo_filter"])

    self.assertEqual(["ERROR: There is no tensor filter named \"foo_filter\"."],
                     out.lines)
    check_main_menu(self, out, list_tensors_enabled=False)

  def testListTensorsInvalidOptions(self):
    out = self._registry.dispatch_command("list_tensors", ["--bar"])
    check_syntax_error_output(self, out, "list_tensors")

  def testNodeInfoByNodeName(self):
    node_name = "simple_mul_add/matmul"
    out = self._registry.dispatch_command("node_info", [node_name])

    recipients = [("AddV2", "simple_mul_add/add"),
                  ("AddV2", "simple_mul_add/add")]

    assert_node_attribute_lines(self, out, node_name, _matmul_op_name(),
                                self._main_device,
                                [("Identity", "simple_mul_add/u/read"),
                                 ("Identity", "simple_mul_add/v/read")], [],
                                recipients, [])
    check_main_menu(
        self,
        out,
        list_tensors_enabled=True,
        list_inputs_node_name=node_name,
        print_tensor_node_name=node_name,
        list_outputs_node_name=node_name)

    # Verify that the node name is bold in the first line.
    self.assertEqual(
        [(len(out.lines[0]) - len(node_name), len(out.lines[0]), "bold")],
        out.font_attr_segs[0])

  def testNodeInfoShowAttributes(self):
    node_name = "simple_mul_add/matmul"
    out = self._registry.dispatch_command("node_info", ["-a", node_name])

    test_attr_key_val_pairs = [("transpose_a", "b: false"),
                               ("transpose_b", "b: false"),
                               ("T", "type: DT_DOUBLE")]
    if test_util.IsMklEnabled():
      test_attr_key_val_pairs.append(("_kernel", 's: "MklNameChangeOp"'))

    assert_node_attribute_lines(
        self,
        out,
        node_name,
        _matmul_op_name(),
        self._main_device, [("Identity", "simple_mul_add/u/read"),
                            ("Identity", "simple_mul_add/v/read")], [],
        [("AddV2", "simple_mul_add/add"), ("AddV2", "simple_mul_add/add")], [],
        attr_key_val_pairs=test_attr_key_val_pairs)
    check_main_menu(
        self,
        out,
        list_tensors_enabled=True,
        list_inputs_node_name=node_name,
        print_tensor_node_name=node_name,
        list_outputs_node_name=node_name)

  def testNodeInfoShowDumps(self):
    node_name = "simple_mul_add/matmul"
    out = self._registry.dispatch_command("node_info", ["-d", node_name])

    assert_node_attribute_lines(
        self,
        out,
        node_name,
        _matmul_op_name(),
        self._main_device, [("Identity", "simple_mul_add/u/read"),
                            ("Identity", "simple_mul_add/v/read")], [],
        [("AddV2", "simple_mul_add/add"), ("AddV2", "simple_mul_add/add")], [],
        num_dumped_tensors=1)
    check_main_menu(
        self,
        out,
        list_tensors_enabled=True,
        list_inputs_node_name=node_name,
        print_tensor_node_name=node_name,
        list_outputs_node_name=node_name)
    check_menu_item(self, out, 16,
                    len(out.lines[16]) - len(out.lines[16].strip()),
                    len(out.lines[16]), "pt %s:0 -n 0" % node_name)

  def testNodeInfoShowStackTraceUnavailableIsIndicated(self):
    self._debug_dump.set_python_graph(None)

    node_name = "simple_mul_add/matmul"
    out = self._registry.dispatch_command("node_info", ["-t", node_name])

    assert_node_attribute_lines(
        self,
        out,
        node_name,
        _matmul_op_name(),
        self._main_device, [("Identity", "simple_mul_add/u/read"),
                            ("Identity", "simple_mul_add/v/read")], [],
        [("AddV2", "simple_mul_add/add"), ("AddV2", "simple_mul_add/add")], [],
        show_stack_trace=True,
        stack_trace_available=False)
    check_main_menu(
        self,
        out,
        list_tensors_enabled=True,
        list_inputs_node_name=node_name,
        print_tensor_node_name=node_name,
        list_outputs_node_name=node_name)

  def testNodeInfoShowStackTraceAvailableWorks(self):
    self._debug_dump.set_python_graph(self._sess.graph)

    node_name = "simple_mul_add/matmul"
    out = self._registry.dispatch_command("node_info", ["-t", node_name])

    assert_node_attribute_lines(
        self,
        out,
        node_name,
        _matmul_op_name(),
        self._main_device, [("Identity", "simple_mul_add/u/read"),
                            ("Identity", "simple_mul_add/v/read")], [],
        [("AddV2", "simple_mul_add/add"), ("AddV2", "simple_mul_add/add")], [],
        show_stack_trace=True,
        stack_trace_available=True)
    check_main_menu(
        self,
        out,
        list_tensors_enabled=True,
        list_inputs_node_name=node_name,
        print_tensor_node_name=node_name,
        list_outputs_node_name=node_name)

  def testNodeInfoByTensorName(self):
    node_name = "simple_mul_add/u/read"
    tensor_name = node_name + ":0"
    out = self._registry.dispatch_command("node_info", [tensor_name])

    assert_node_attribute_lines(self, out, node_name, "Identity",
                                self._main_device,
                                [("VariableV2", "simple_mul_add/u")], [],
                                [(_matmul_op_name(), "simple_mul_add/matmul")],
                                [])
    check_main_menu(
        self,
        out,
        list_tensors_enabled=True,
        list_inputs_node_name=node_name,
        print_tensor_node_name=node_name,
        list_outputs_node_name=node_name)

  def testNodeInfoNonexistentNodeName(self):
    out = self._registry.dispatch_command("node_info", ["bar"])
    self.assertEqual(
        ["ERROR: There is no node named \"bar\" in the partition graphs"],
        out.lines)
    # Check color indicating error.
    self.assertEqual({0: [(0, 59, cli_shared.COLOR_RED)]}, out.font_attr_segs)
    check_main_menu(self, out, list_tensors_enabled=True)

  def testPrintTensor(self):
    node_name = "simple_mul_add/matmul"
    tensor_name = node_name + ":0"
    out = self._registry.dispatch_command(
        "print_tensor", [tensor_name], screen_info={"cols": 80})

    self.assertEqual([
        "Tensor \"%s:DebugIdentity\":" % tensor_name,
        "  dtype: float64",
        "  shape: (2, 1)",
        "",
        "array([[ 7.],",
        "       [-2.]])",
    ], out.lines)

    self.assertIn("tensor_metadata", out.annotations)
    self.assertIn(4, out.annotations)
    self.assertIn(5, out.annotations)
    check_main_menu(
        self,
        out,
        list_tensors_enabled=True,
        node_info_node_name=node_name,
        list_inputs_node_name=node_name,
        list_outputs_node_name=node_name)

  def testPrintTensorAndWriteToNpyFile(self):
    node_name = "simple_mul_add/matmul"
    tensor_name = node_name + ":0"
    npy_path = os.path.join(self._dump_root, "matmul.npy")
    out = self._registry.dispatch_command(
        "print_tensor", [tensor_name, "-w", npy_path],
        screen_info={"cols": 80})

    self.assertEqual([
        "Tensor \"%s:DebugIdentity\":" % tensor_name,
        "  dtype: float64",
        "  shape: (2, 1)",
        "",
    ], out.lines[:4])
    self.assertTrue(out.lines[4].startswith("Saved value to: %s (" % npy_path))
    # Load the numpy file and verify its contents.
    self.assertAllClose([[7.0], [-2.0]], np.load(npy_path))

  def testPrintTensorHighlightingRanges(self):
    node_name = "simple_mul_add/matmul"
    tensor_name = node_name + ":0"
    out = self._registry.dispatch_command(
        "print_tensor", [tensor_name, "--ranges", "[-inf, 0.0]"],
        screen_info={"cols": 80})

    self.assertEqual([
        "Tensor \"%s:DebugIdentity\": " % tensor_name +
        "Highlighted([-inf, 0.0]): 1 of 2 element(s) (50.00%)",
        "  dtype: float64",
        "  shape: (2, 1)",
        "",
        "array([[ 7.],",
        "       [-2.]])",
    ], out.lines)

    self.assertIn("tensor_metadata", out.annotations)
    self.assertIn(4, out.annotations)
    self.assertIn(5, out.annotations)
    self.assertEqual([(8, 11, "bold")], out.font_attr_segs[5])

    out = self._registry.dispatch_command(
        "print_tensor", [tensor_name, "--ranges", "[[-inf, -5.5], [5.5, inf]]"],
        screen_info={"cols": 80})

    self.assertEqual([
        "Tensor \"%s:DebugIdentity\": " % tensor_name +
        "Highlighted([[-inf, -5.5], [5.5, inf]]): "
        "1 of 2 element(s) (50.00%)",
        "  dtype: float64",
        "  shape: (2, 1)",
        "",
        "array([[ 7.],",
        "       [-2.]])",
    ], out.lines)

    self.assertIn("tensor_metadata", out.annotations)
    self.assertIn(4, out.annotations)
    self.assertIn(5, out.annotations)
    self.assertEqual([(9, 11, "bold")], out.font_attr_segs[4])
    self.assertNotIn(5, out.font_attr_segs)
    check_main_menu(
        self,
        out,
        list_tensors_enabled=True,
        node_info_node_name=node_name,
        list_inputs_node_name=node_name,
        list_outputs_node_name=node_name)

  def testPrintTensorHighlightingRangesAndIncludingNumericSummary(self):
    node_name = "simple_mul_add/matmul"
    tensor_name = node_name + ":0"
    out = self._registry.dispatch_command(
        "print_tensor", [tensor_name, "--ranges", "[-inf, 0.0]", "-s"],
        screen_info={"cols": 80})

    self.assertEqual([
        "Tensor \"%s:DebugIdentity\": " % tensor_name +
        "Highlighted([-inf, 0.0]): 1 of 2 element(s) (50.00%)",
        "  dtype: float64",
        "  shape: (2, 1)",
        "",
        "Numeric summary:",
        "| - + | total |",
        "| 1 1 |     2 |",
        "|  min  max mean  std |",
        "| -2.0  7.0  2.5  4.5 |",
        "",
        "array([[ 7.],",
        "       [-2.]])",
    ], out.lines)

    self.assertIn("tensor_metadata", out.annotations)
    self.assertIn(10, out.annotations)
    self.assertIn(11, out.annotations)
    self.assertEqual([(8, 11, "bold")], out.font_attr_segs[11])

  def testPrintTensorWithSlicing(self):
    node_name = "simple_mul_add/matmul"
    tensor_name = node_name + ":0"
    out = self._registry.dispatch_command(
        "print_tensor", [tensor_name + "[1, :]"], screen_info={"cols": 80})

    self.assertEqual([
        "Tensor \"%s:DebugIdentity[1, :]\":" % tensor_name, "  dtype: float64",
        "  shape: (1,)", "", "array([-2.])"
    ], out.lines)

    self.assertIn("tensor_metadata", out.annotations)
    self.assertIn(4, out.annotations)
    check_main_menu(
        self,
        out,
        list_tensors_enabled=True,
        node_info_node_name=node_name,
        list_inputs_node_name=node_name,
        list_outputs_node_name=node_name)

  def testPrintTensorInvalidSlicingString(self):
    node_name = "simple_mul_add/matmul"
    tensor_name = node_name + ":0"
    out = self._registry.dispatch_command(
        "print_tensor", [tensor_name + "[1, foo()]"], screen_info={"cols": 80})

    self.assertEqual("Error occurred during handling of command: print_tensor "
                     + tensor_name + "[1, foo()]:", out.lines[0])
    self.assertEqual("ValueError: Invalid tensor-slicing string.",
                     out.lines[-2])

  def testPrintTensorValidExplicitNumber(self):
    node_name = "simple_mul_add/matmul"
    tensor_name = node_name + ":0"
    out = self._registry.dispatch_command(
        "print_tensor", [tensor_name, "-n", "0"], screen_info={"cols": 80})

    self.assertEqual([
        "Tensor \"%s:DebugIdentity\":" % tensor_name,
        "  dtype: float64",
        "  shape: (2, 1)",
        "",
        "array([[ 7.],",
        "       [-2.]])",
    ], out.lines)

    self.assertIn("tensor_metadata", out.annotations)
    self.assertIn(4, out.annotations)
    self.assertIn(5, out.annotations)
    check_main_menu(
        self,
        out,
        list_tensors_enabled=True,
        node_info_node_name=node_name,
        list_inputs_node_name=node_name,
        list_outputs_node_name=node_name)

  def testPrintTensorInvalidExplicitNumber(self):
    node_name = "simple_mul_add/matmul"
    tensor_name = node_name + ":0"
    out = self._registry.dispatch_command(
        "print_tensor", [tensor_name, "-n", "1"], screen_info={"cols": 80})

    self.assertEqual([
        "ERROR: Invalid number (1) for tensor simple_mul_add/matmul:0, "
        "which generated one dump."
    ], out.lines)

    self.assertNotIn("tensor_metadata", out.annotations)

    check_main_menu(
        self,
        out,
        list_tensors_enabled=True,
        node_info_node_name=node_name,
        list_inputs_node_name=node_name,
        list_outputs_node_name=node_name)

  def testPrintTensorMissingOutputSlotLeadsToOnlyDumpedTensorPrinted(self):
    node_name = "simple_mul_add/matmul"
    out = self._registry.dispatch_command("print_tensor", [node_name])

    self.assertEqual([
        "Tensor \"%s:0:DebugIdentity\":" % node_name, "  dtype: float64",
        "  shape: (2, 1)", "", "array([[ 7.],", "       [-2.]])"
    ], out.lines)
    check_main_menu(
        self,
        out,
        list_tensors_enabled=True,
        node_info_node_name=node_name,
        list_inputs_node_name=node_name,
        list_outputs_node_name=node_name)

  def testPrintTensorNonexistentNodeName(self):
    out = self._registry.dispatch_command(
        "print_tensor", ["simple_mul_add/matmul/foo:0"])

    self.assertEqual([
        "ERROR: Node \"simple_mul_add/matmul/foo\" does not exist in partition "
        "graphs"
    ], out.lines)
    check_main_menu(self, out, list_tensors_enabled=True)

  def testEvalExpression(self):
    node_name = "simple_mul_add/matmul"
    tensor_name = node_name + ":0"
    out = self._registry.dispatch_command(
        "eval", ["np.matmul(`%s`, `%s`.T)" % (tensor_name, tensor_name)],
        screen_info={"cols": 80})

    cli_test_utils.assert_lines_equal_ignoring_whitespace(
        self,
        ["Tensor \"from eval of expression "
         "'np.matmul(`simple_mul_add/matmul:0`, "
         "`simple_mul_add/matmul:0`.T)'\":",
         "  dtype: float64",
         "  shape: (2, 2)",
         "",
         "Numeric summary:",
         "| - + | total |",
         "| 2 2 |     4 |",
         "|           min           max          mean           std |"],
        out.lines[:8])
    cli_test_utils.assert_array_lines_close(
        self, [-14.0, 49.0, 6.25, 25.7524270701], out.lines[8:9])
    cli_test_utils.assert_array_lines_close(
        self, [[49.0, -14.0], [-14.0, 4.0]], out.lines[10:])

  def testEvalExpressionAndWriteToNpyFile(self):
    node_name = "simple_mul_add/matmul"
    tensor_name = node_name + ":0"
    npy_path = os.path.join(self._dump_root, "matmul_eval.npy")
    out = self._registry.dispatch_command(
        "eval",
        ["np.matmul(`%s`, `%s`.T)" % (tensor_name, tensor_name), "-w",
         npy_path], screen_info={"cols": 80})

    self.assertEqual([
        "Tensor \"from eval of expression "
        "'np.matmul(`simple_mul_add/matmul:0`, "
        "`simple_mul_add/matmul:0`.T)'\":",
        "  dtype: float64",
        "  shape: (2, 2)",
        ""], out.lines[:4])

    self.assertTrue(out.lines[4].startswith("Saved value to: %s (" % npy_path))
    # Load the numpy file and verify its contents.
    self.assertAllClose([[49.0, -14.0], [-14.0, 4.0]], np.load(npy_path))

  def testAddGetTensorFilterLambda(self):
    analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
                                          _cli_config_from_temp_file())
    analyzer.add_tensor_filter("foo_filter", lambda x, y: True)
    self.assertTrue(analyzer.get_tensor_filter("foo_filter")(None, None))

  def testAddGetTensorFilterNestedFunction(self):
    analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
                                          _cli_config_from_temp_file())

    def foo_filter(unused_arg_0, unused_arg_1):
      return True

    analyzer.add_tensor_filter("foo_filter", foo_filter)
    self.assertTrue(analyzer.get_tensor_filter("foo_filter")(None, None))

  def testAddTensorFilterEmptyName(self):
    analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
                                          _cli_config_from_temp_file())

    with self.assertRaisesRegex(ValueError,
                                "Input argument filter_name cannot be empty."):
      analyzer.add_tensor_filter("", lambda datum, tensor: True)

  def testAddTensorFilterNonStrName(self):
    analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
                                          _cli_config_from_temp_file())

    with self.assertRaisesRegex(
        TypeError, "Input argument filter_name is expected to be str, "
        "but is not"):
      analyzer.add_tensor_filter(1, lambda datum, tensor: True)

  def testAddGetTensorFilterNonCallable(self):
    analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
                                          _cli_config_from_temp_file())

    with self.assertRaisesRegex(
        TypeError, "Input argument filter_callable is expected to be callable, "
        "but is not."):
      analyzer.add_tensor_filter("foo_filter", "bar")

  def testGetNonexistentTensorFilter(self):
    analyzer = analyzer_cli.DebugAnalyzer(self._debug_dump,
                                          _cli_config_from_temp_file())

    analyzer.add_tensor_filter("foo_filter", lambda datum, tensor: True)
    with self.assertRaisesRegex(ValueError,
                                "There is no tensor filter named \"bar\""):
      analyzer.get_tensor_filter("bar")

  def _findSourceLine(self, annotated_source, line_number):
    """Find line of given line number in annotated source.

    Args:
      annotated_source: (debugger_cli_common.RichTextLines) the annotated source
      line_number: (int) 1-based line number

    Returns:
      (int) If line_number is found, 0-based line index in
        annotated_source.lines. Otherwise, None.
    """

    index = None
    for i, line in enumerate(annotated_source.lines):
      if line.startswith("L%d " % line_number):
        index = i
        break
    return index

  def testPrintSourceForOpNamesWholeFileWorks(self):
    self._debug_dump.set_python_graph(self._sess.graph)
    out = self._registry.dispatch_command(
        "print_source", [self._curr_file_path], screen_info={"cols": 80})

    # Verify the annotation of the line that creates u.
    index = self._findSourceLine(out, self._u_line_number)
    self.assertEqual(
        ["L%d         u = variables.VariableV1(u_init, name=u_name)" %
         self._u_line_number,
         "    simple_mul_add/u",
         "    simple_mul_add/u/Assign",
         "    simple_mul_add/u/read"],
        out.lines[index : index + 4])
    self.assertEqual("pt simple_mul_add/u",
                     out.font_attr_segs[index + 1][0][2].content)
    # simple_mul_add/u/Assign is not used in this run because the Variable has
    # already been initialized.
    self.assertEqual(cli_shared.COLOR_BLUE, out.font_attr_segs[index + 2][0][2])
    self.assertEqual("pt simple_mul_add/u/read",
                     out.font_attr_segs[index + 3][0][2].content)

    # Verify the annotation of the line that creates v.
    index = self._findSourceLine(out, self._v_line_number)
    self.assertEqual(
        ["L%d         v = variables.VariableV1(v_init, name=v_name)" %
         self._v_line_number,
         "    simple_mul_add/v"],
        out.lines[index : index + 2])
    self.assertEqual("pt simple_mul_add/v",
                     out.font_attr_segs[index + 1][0][2].content)

    # Verify the annotation of the line that creates w.
    index = self._findSourceLine(out, self._w_line_number)
    self.assertEqual(
        ["L%d         " % self._w_line_number +
         "w = math_ops.matmul(u, v, name=\"simple_mul_add/matmul\")",
         "    simple_mul_add/matmul"],
        out.lines[index : index + 2])
    self.assertEqual("pt simple_mul_add/matmul",
                     out.font_attr_segs[index + 1][0][2].content)

    # Verify the annotation of the line that creates x.
    index = self._findSourceLine(out, self._x_line_number)
    self.assertEqual(
        ["L%d         " % self._x_line_number +
         "x = math_ops.add(w, w, name=\"simple_mul_add/add\")",
         "    simple_mul_add/add"],
        out.lines[index : index + 2])
    self.assertEqual("pt simple_mul_add/add",
                     out.font_attr_segs[index + 1][0][2].content)

  def testPrintSourceForTensorNamesWholeFileWorks(self):
    self._debug_dump.set_python_graph(self._sess.graph)
    out = self._registry.dispatch_command(
        "print_source",
        [self._curr_file_path, "--tensors"],
        screen_info={"cols": 80})

    # Verify the annotation of the line that creates u.
    index = self._findSourceLine(out, self._u_line_number)
    self.assertEqual(
        ["L%d         u = variables.VariableV1(u_init, name=u_name)" %
         self._u_line_number,
         "    simple_mul_add/u/read:0",
         "    simple_mul_add/u:0"],
        out.lines[index : index + 3])
    self.assertEqual("pt simple_mul_add/u/read:0",
                     out.font_attr_segs[index + 1][0][2].content)
    self.assertEqual("pt simple_mul_add/u:0",
                     out.font_attr_segs[index + 2][0][2].content)

  def testPrintSourceForOpNamesStartingAtSpecifiedLineWorks(self):
    self._debug_dump.set_python_graph(self._sess.graph)
    out = self._registry.dispatch_command(
        "print_source",
        [self._curr_file_path, "-b", "3"],
        screen_info={"cols": 80})

    self.assertEqual(
        2, out.annotations[debugger_cli_common.INIT_SCROLL_POS_KEY])

    index = self._findSourceLine(out, self._u_line_number)
    self.assertEqual(
        ["L%d         u = variables.VariableV1(u_init, name=u_name)" %
         self._u_line_number,
         "    simple_mul_add/u",
1492         "    simple_mul_add/u/Assign",
1493         "    simple_mul_add/u/read"],
1494        out.lines[index : index + 4])
1495    self.assertEqual("pt simple_mul_add/u",
1496                     out.font_attr_segs[index + 1][0][2].content)
1497    # simple_mul_add/u/Assign is not used in this run because the Variable has
1498    # already been initialized.
1499    self.assertEqual(cli_shared.COLOR_BLUE, out.font_attr_segs[index + 2][0][2])
1500    self.assertEqual("pt simple_mul_add/u/read",
1501                     out.font_attr_segs[index + 3][0][2].content)
1502
1503  def testPrintSourceForOpNameSettingMaximumElementCountWorks(self):
1504    self._debug_dump.set_python_graph(self._sess.graph)
1505    out = self._registry.dispatch_command(
1506        "print_source",
1507        [self._curr_file_path, "-m", "1"],
1508        screen_info={"cols": 80})
1509
1510    index = self._findSourceLine(out, self._u_line_number)
1511    self.assertEqual(
1512        ["L%d         u = variables.VariableV1(u_init, name=u_name)" %
1513         self._u_line_number,
1514         "    simple_mul_add/u",
1515         "    (... Omitted 2 of 3 op(s) ...) +5"],
1516        out.lines[index : index + 3])
1517    self.assertEqual("pt simple_mul_add/u",
1518                     out.font_attr_segs[index + 1][0][2].content)
1519    more_elements_command = out.font_attr_segs[index + 2][-1][2].content
1520    self.assertStartsWith(more_elements_command,
1521                          "ps %s " % self._curr_file_path)
1522    self.assertIn(" -m 6", more_elements_command)
1523
1524  def testListSourceWorks(self):
1525    self._debug_dump.set_python_graph(self._sess.graph)
1526    out = self._registry.dispatch_command("list_source", [])
1527
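    # Locate the block of non-library source files: it starts right after the
    # "Source file path" header line and ends just before the
    # "TensorFlow Python library file(s):" header.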
1528    non_tf_lib_files_start = [
1529        i for i in range(len(out.lines))
1530        if out.lines[i].startswith("Source file path")
1531    ][0] + 1
1532    non_tf_lib_files_end = [
1533        i for i in range(len(out.lines))
1534        if out.lines[i].startswith("TensorFlow Python library file(s):")
1535    ][0] - 1
1536    non_tf_lib_files = [
1537        line.split(" ")[0] for line
1538        in out.lines[non_tf_lib_files_start : non_tf_lib_files_end]]
1539    self.assertIn(self._curr_file_path, non_tf_lib_files)
1540
1541    # Check that the TF library files are marked with a special color attribute.
1542    for i in range(non_tf_lib_files_end + 1, len(out.lines)):
1543      if not out.lines[i]:
1544        continue
1545      for attr_seg in out.font_attr_segs[i]:
1546        self.assertTrue(cli_shared.COLOR_GRAY in attr_seg[2] or
1547                        attr_seg[2] == cli_shared.COLOR_GRAY)
1548
1549  def testListSourceWithNodeNameFilterWithMatchesWorks(self):
1550    self._debug_dump.set_python_graph(self._sess.graph)
1551    out = self._registry.dispatch_command("list_source", ["-n", ".*/read"])
1552
1553    self.assertStartsWith(out.lines[1], "Node name regex filter: \".*/read\"")
1554
1555    non_tf_lib_files_start = [
1556        i for i in range(len(out.lines))
1557        if out.lines[i].startswith("Source file path")
1558    ][0] + 1
1559    non_tf_lib_files_end = [
1560        i for i in range(len(out.lines))
1561        if out.lines[i].startswith("TensorFlow Python library file(s):")
1562    ][0] - 1
1563    non_tf_lib_files = [
1564        line.split(" ")[0] for line
1565        in out.lines[non_tf_lib_files_start : non_tf_lib_files_end]]
1566    self.assertIn(self._curr_file_path, non_tf_lib_files)
1567
1568    # Check that the TF library files are marked with a special color attribute.
1569    for i in range(non_tf_lib_files_end + 1, len(out.lines)):
1570      if not out.lines[i]:
1571        continue
1572      for attr_seg in out.font_attr_segs[i]:
1573        self.assertTrue(cli_shared.COLOR_GRAY in attr_seg[2] or
1574                        attr_seg[2] == cli_shared.COLOR_GRAY)
1575
1576  def testListSourceWithNodeNameFilterWithNoMatchesWorks(self):
1577    self._debug_dump.set_python_graph(self._sess.graph)
1578    out = self._registry.dispatch_command("list_source", ["-n", "^$"])
1579
1580    self.assertEqual([
1581        "List of source files that created nodes in this run",
1582        "Node name regex filter: \"^$\"", "",
1583        "[No source file information.]"], out.lines)
1584
1585  def testListSourceWithPathAndNodeNameFiltersWorks(self):
1586    self._debug_dump.set_python_graph(self._sess.graph)
1587    out = self._registry.dispatch_command(
1588        "list_source", ["-p", self._curr_file_path, "-n", ".*read"])
1589
1590    self.assertEqual([
1591        "List of source files that created nodes in this run",
1592        "File path regex filter: \"%s\"" % self._curr_file_path,
1593        "Node name regex filter: \".*read\"", ""], out.lines[:4])
1594
1595  def testListSourceWithCompiledPythonSourceWorks(self):
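    # Stub out source_utils.list_source_files_against_dump with fixed entries.
    # Each tuple is assumed to mirror the real return format, roughly
    # (file_path, is_tf_py_library, num_nodes, num_tensors, num_dumps,
    # first_line); only the path suffix (.pyc/.pyo/.py) and the final line
    # number, which feeds the "-b" argument of the generated "ps" command,
    # matter for the assertions below.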
1596    def fake_list_source_files_against_dump(dump,
1597                                            path_regex_allowlist=None,
1598                                            node_name_regex_allowlist=None):
1599      del dump, path_regex_allowlist, node_name_regex_allowlist
1600      return [("compiled_1.pyc", False, 10, 20, 30, 4),
1601              ("compiled_2.pyo", False, 10, 20, 30, 5),
1602              ("uncompiled.py", False, 10, 20, 30, 6)]
1603
1604    with test.mock.patch.object(
1605        source_utils, "list_source_files_against_dump",
1606        side_effect=fake_list_source_files_against_dump):
1607      out = self._registry.dispatch_command("list_source", [])
1608
1609      self.assertStartsWith(out.lines[4], "compiled_1.pyc")
1610      self.assertEqual((0, 14, [cli_shared.COLOR_WHITE]),
1611                       out.font_attr_segs[4][0])
1612      self.assertStartsWith(out.lines[5], "compiled_2.pyo")
1613      self.assertEqual((0, 14, [cli_shared.COLOR_WHITE]),
1614                       out.font_attr_segs[5][0])
1615      self.assertStartsWith(out.lines[6], "uncompiled.py")
1616      self.assertEqual(0, out.font_attr_segs[6][0][0])
1617      self.assertEqual(13, out.font_attr_segs[6][0][1])
1618      self.assertEqual(cli_shared.COLOR_WHITE, out.font_attr_segs[6][0][2][0])
1619      self.assertEqual("ps uncompiled.py -b 6",
1620                       out.font_attr_segs[6][0][2][1].content)
1621
1622  def testListInputInvolvingNodesWithMultipleOutputs(self):
1623    """List an input tree containing tensors from non-:0 output slot."""
1624
1625    with session.Session(config=no_rewrite_session_config()) as sess:
1626      with ops.device("CPU:0"):
1627        x = variables.VariableV1([1, 3, 3, 7], name="x")
1628        _, idx = array_ops.unique(x, name="x_unique")
1629        idx_times_two = math_ops.multiply(idx, 2, name="idx_times_two")
1630        self.evaluate(x.initializer)
1631
1632        run_options = config_pb2.RunOptions(output_partition_graphs=True)
1633        debug_utils.watch_graph(
1634            run_options,
1635            sess.graph,
1636            debug_ops=["DebugIdentity"],
1637            debug_urls="file://%s" % self._dump_root_for_unique)
1638        run_metadata = config_pb2.RunMetadata()
1639        self.assertAllEqual([0, 2, 2, 4],
1640                            sess.run(
1641                                idx_times_two,
1642                                options=run_options,
1643                                run_metadata=run_metadata))
1644        debug_dump = debug_data.DebugDumpDir(
1645            self._dump_root_for_unique,
1646            partition_graphs=run_metadata.partition_graphs)
1647        _, registry = create_analyzer_cli(debug_dump)
1648
1649        out = registry.dispatch_command("li", ["idx_times_two"])
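        # idx is the second output of the Unique op (x_unique), so the input
        # tree for idx_times_two should list the tensor x_unique:1 (output
        # slot 1) rather than the default slot 0.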
1650        self.assertEqual([
1651            "Inputs to node \"idx_times_two\" (Depth limit = 1):",
1652            "|- (1) x_unique:1"
1653        ], out.lines[:2])
1654
1655
1656class AnalyzerCLIPrintLargeTensorTest(test_util.TensorFlowTestCase):
1657
1658  @classmethod
1659  def setUpClass(cls):
1660    cls._dump_root = tempfile.mkdtemp()
1661
1662    with session.Session(config=no_rewrite_session_config()) as sess:
1663      # 2400 elements should exceed the default threshold (2000).
1664      x = constant_op.constant(np.zeros([300, 8]), name="large_tensors/x")
1665
1666      run_options = config_pb2.RunOptions(output_partition_graphs=True)
1667      debug_utils.watch_graph(
1668          run_options,
1669          sess.graph,
1670          debug_ops=["DebugIdentity"],
1671          debug_urls="file://%s" % cls._dump_root)
1672
1673      # Invoke Session.run().
1674      run_metadata = config_pb2.RunMetadata()
1675      sess.run(x, options=run_options, run_metadata=run_metadata)
1676
1677    cls._debug_dump = debug_data.DebugDumpDir(
1678        cls._dump_root, partition_graphs=run_metadata.partition_graphs)
1679
1680    # Construct the analyzer and command registry.
1681    cls._analyzer, cls._registry = create_analyzer_cli(cls._debug_dump)
1682
1683  @classmethod
1684  def tearDownClass(cls):
1685    # Tear down temporary dump directory.
1686    file_io.delete_recursively(cls._dump_root)
1687
1688  def testPrintLargeTensorWithoutAllOption(self):
1689    out = self._registry.dispatch_command(
1690        "print_tensor", ["large_tensors/x:0"], screen_info={"cols": 80})
1691
1692    # Assert that ellipses are present in the tensor value printout.
1693    self.assertIn("...,", out.lines[4])
1694
1695    # 2100 still exceeds 2000.
1696    out = self._registry.dispatch_command(
1697        "print_tensor", ["large_tensors/x:0[:, 0:7]"],
1698        screen_info={"cols": 80})
1699
1700    self.assertIn("...,", out.lines[4])
1701
1702  def testPrintLargeTensorWithAllOption(self):
1703    out = self._registry.dispatch_command(
1704        "print_tensor", ["large_tensors/x:0", "-a"],
1705        screen_info={"cols": 80})
1706
1707    # Assert that ellipses are not present in the tensor value printout.
1708    self.assertNotIn("...,", out.lines[4])
1709
1710    out = self._registry.dispatch_command(
1711        "print_tensor", ["large_tensors/x:0[:, 0:7]", "--all"],
1712        screen_info={"cols": 80})
1713    self.assertNotIn("...,", out.lines[4])
1714
1715
1716@test_util.run_v1_only("b/120545219")
1717class AnalyzerCLIControlDepTest(test_util.TensorFlowTestCase):
1718
1719  @classmethod
1720  def setUpClass(cls):
1721    cls._dump_root = tempfile.mkdtemp()
1722
1723    cls._is_gpu_available = test.is_gpu_available()
1724    if cls._is_gpu_available:
1725      gpu_name = test_util.gpu_device_name()
1726      cls._main_device = "/job:localhost/replica:0/task:0" + gpu_name
1727    else:
1728      cls._main_device = "/job:localhost/replica:0/task:0/device:CPU:0"
1729
1730    with session.Session(config=no_rewrite_session_config()) as sess:
1731      x_init_val = np.array([5.0, 3.0])
1732      x_init = constant_op.constant(x_init_val, shape=[2])
1733      x = variables.VariableV1(x_init, name="control_deps/x")
1734
1735      y = math_ops.add(x, x, name="control_deps/y")
1736      y = control_flow_ops.with_dependencies(
1737          [x], y, name="control_deps/ctrl_dep_y")
1738
1739      z = math_ops.multiply(x, y, name="control_deps/z")
1740
1741      z = control_flow_ops.with_dependencies(
1742          [x, y], z, name="control_deps/ctrl_dep_z")
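      # ctrl_dep_y is an Identity of y with a control dependency on x, and
      # ctrl_dep_z is an Identity of z with control dependencies on x and y.
      # The node_info / list_inputs / list_outputs tests below assert against
      # this structure.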
1743
1744      x.initializer.run()
1745
1746      run_options = config_pb2.RunOptions(output_partition_graphs=True)
1747      debug_utils.watch_graph(
1748          run_options,
1749          sess.graph,
1750          debug_ops=["DebugIdentity"],
1751          debug_urls="file://%s" % cls._dump_root)
1752
1753      # Invoke Session.run().
1754      run_metadata = config_pb2.RunMetadata()
1755      sess.run(z, options=run_options, run_metadata=run_metadata)
1756
1757    debug_dump = debug_data.DebugDumpDir(
1758        cls._dump_root, partition_graphs=run_metadata.partition_graphs)
1759
1760    # Construct the analyzer and command handler registry.
1761    _, cls._registry = create_analyzer_cli(debug_dump)
1762
1763  @classmethod
1764  def tearDownClass(cls):
1765    # Tear down temporary dump directory.
1766    file_io.delete_recursively(cls._dump_root)
1767
1768  def testNodeInfoWithControlDependencies(self):
1769    # Call node_info on a node with control inputs.
1770    out = self._registry.dispatch_command("node_info",
1771                                          ["control_deps/ctrl_dep_y"])
1772
1773    assert_node_attribute_lines(self, out, "control_deps/ctrl_dep_y",
1774                                "Identity", self._main_device,
1775                                [("AddV2", "control_deps/y")],
1776                                [("VariableV2", "control_deps/x")],
1777                                [("Mul", "control_deps/z")],
1778                                [("Identity", "control_deps/ctrl_dep_z")])
1779
1780    # Call node_info on a node with control recipients.
1781    out = self._registry.dispatch_command("ni", ["control_deps/x"])
1782
1783    assert_node_attribute_lines(self, out, "control_deps/x", "VariableV2",
1784                                self._main_device, [], [],
1785                                [("Identity", "control_deps/x/read")],
1786                                [("Identity", "control_deps/ctrl_dep_y"),
1787                                 ("Identity", "control_deps/ctrl_dep_z")])
1788
1789    # Verify the menu items (command shortcuts) in the output.
1790    check_menu_item(self, out, 10,
1791                    len(out.lines[10]) - len("control_deps/x/read"),
1792                    len(out.lines[10]), "ni -a -d -t control_deps/x/read")
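    # The two control recipients of control_deps/x may be listed in either
    # order, so determine which output line holds ctrl_dep_y before checking
    # the menu items.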
1793    if out.lines[13].endswith("control_deps/ctrl_dep_y"):
1794      y_line = 13
1795      z_line = 14
1796    else:
1797      y_line = 14
1798      z_line = 13
1799    check_menu_item(self, out, y_line,
1800                    len(out.lines[y_line]) - len("control_deps/ctrl_dep_y"),
1801                    len(out.lines[y_line]),
1802                    "ni -a -d -t control_deps/ctrl_dep_y")
1803    check_menu_item(self, out, z_line,
1804                    len(out.lines[z_line]) - len("control_deps/ctrl_dep_z"),
1805                    len(out.lines[z_line]),
1806                    "ni -a -d -t control_deps/ctrl_dep_z")
1807
1808  def testListInputsNonRecursiveNoControl(self):
1809    """List inputs non-recursively, without any control inputs."""
1810
1811    # Do not include node op types.
1812    node_name = "control_deps/z"
1813    out = self._registry.dispatch_command("list_inputs", [node_name])
1814
1815    self.assertEqual([
1816        "Inputs to node \"%s\" (Depth limit = 1):" % node_name,
1817        "|- (1) control_deps/x/read", "|  |- ...",
1818        "|- (1) control_deps/ctrl_dep_y", "   |- ...", "", "Legend:",
1819        "  (d): recursion depth = d."
1820    ], out.lines)
1821
1822    # Include node op types.
1823    out = self._registry.dispatch_command("li", ["-t", node_name])
1824
1825    self.assertEqual([
1826        "Inputs to node \"%s\" (Depth limit = 1):" % node_name,
1827        "|- (1) [Identity] control_deps/x/read", "|  |- ...",
1828        "|- (1) [Identity] control_deps/ctrl_dep_y", "   |- ...", "", "Legend:",
1829        "  (d): recursion depth = d.", "  [Op]: Input node has op type Op."
1830    ], out.lines)
1831    check_main_menu(
1832        self,
1833        out,
1834        list_tensors_enabled=True,
1835        node_info_node_name=node_name,
1836        print_tensor_node_name=node_name,
1837        list_outputs_node_name=node_name)
1838
1839    # Verify that the node name has the bold attribute.
1840    self.assertEqual([(16, 16 + len(node_name), "bold")], out.font_attr_segs[0])
1841
1842    # Verify the menu items (command shortcuts) in the output.
1843    check_menu_item(self, out, 1,
1844                    len(out.lines[1]) - len("control_deps/x/read"),
1845                    len(out.lines[1]), "li -c -r control_deps/x/read")
1846    check_menu_item(self, out, 3,
1847                    len(out.lines[3]) - len("control_deps/ctrl_dep_y"),
1848                    len(out.lines[3]), "li -c -r control_deps/ctrl_dep_y")
1849
1850  def testListInputsNonRecursiveNoControlUsingTensorName(self):
1851    """List inputs using the name of an output tensor of the node."""
1852
1853    # Do not include node op types.
1854    node_name = "control_deps/z"
1855    tensor_name = node_name + ":0"
1856    out = self._registry.dispatch_command("list_inputs", [tensor_name])
1857
1858    self.assertEqual([
1859        "Inputs to node \"%s\" (Depth limit = 1):" % node_name,
1860        "|- (1) control_deps/x/read", "|  |- ...",
1861        "|- (1) control_deps/ctrl_dep_y", "   |- ...", "", "Legend:",
1862        "  (d): recursion depth = d."
1863    ], out.lines)
1864    check_main_menu(
1865        self,
1866        out,
1867        list_tensors_enabled=True,
1868        node_info_node_name=node_name,
1869        print_tensor_node_name=node_name,
1870        list_outputs_node_name=node_name)
1871    check_menu_item(self, out, 1,
1872                    len(out.lines[1]) - len("control_deps/x/read"),
1873                    len(out.lines[1]), "li -c -r control_deps/x/read")
1874    check_menu_item(self, out, 3,
1875                    len(out.lines[3]) - len("control_deps/ctrl_dep_y"),
1876                    len(out.lines[3]), "li -c -r control_deps/ctrl_dep_y")
1877
1878  def testListInputsNonRecursiveWithControls(self):
1879    """List inputs non-recursively, with control inputs."""
1880    node_name = "control_deps/ctrl_dep_z"
1881    out = self._registry.dispatch_command("li", ["-t", node_name, "-c"])
1882
1883    self.assertEqual([
1884        "Inputs to node \"%s\" (Depth limit = 1, " % node_name +
1885        "control inputs included):", "|- (1) [Mul] control_deps/z", "|  |- ...",
1886        "|- (1) (Ctrl) [Identity] control_deps/ctrl_dep_y", "|  |- ...",
1887        "|- (1) (Ctrl) [VariableV2] control_deps/x", "", "Legend:",
1888        "  (d): recursion depth = d.", "  (Ctrl): Control input.",
1889        "  [Op]: Input node has op type Op."
1890    ], out.lines)
1891    check_main_menu(
1892        self,
1893        out,
1894        list_tensors_enabled=True,
1895        node_info_node_name=node_name,
1896        print_tensor_node_name=node_name,
1897        list_outputs_node_name=node_name)
1898    check_menu_item(self, out, 1,
1899                    len(out.lines[1]) - len("control_deps/z"),
1900                    len(out.lines[1]), "li -c -r control_deps/z")
1901    check_menu_item(self, out, 3,
1902                    len(out.lines[3]) - len("control_deps/ctrl_dep_y"),
1903                    len(out.lines[3]), "li -c -r control_deps/ctrl_dep_y")
1904    check_menu_item(self, out, 5,
1905                    len(out.lines[5]) - len("control_deps/x"),
1906                    len(out.lines[5]), "li -c -r control_deps/x")
1907
1908  def testListInputsRecursiveWithControls(self):
1909    """List inputs recursively, with control inputs."""
1910    node_name = "control_deps/ctrl_dep_z"
1911    out = self._registry.dispatch_command("li", ["-c", "-r", "-t", node_name])
1912
1913    self.assertEqual([
1914        "Inputs to node \"%s\" (Depth limit = 20, " % node_name +
1915        "control inputs included):", "|- (1) [Mul] control_deps/z",
1916        "|  |- (2) [Identity] control_deps/x/read",
1917        "|  |  |- (3) [VariableV2] control_deps/x",
1918        "|  |- (2) [Identity] control_deps/ctrl_dep_y",
1919        "|     |- (3) [AddV2] control_deps/y",
1920        "|     |  |- (4) [Identity] control_deps/x/read",
1921        "|     |  |  |- (5) [VariableV2] control_deps/x",
1922        "|     |  |- (4) [Identity] control_deps/x/read",
1923        "|     |     |- (5) [VariableV2] control_deps/x",
1924        "|     |- (3) (Ctrl) [VariableV2] control_deps/x",
1925        "|- (1) (Ctrl) [Identity] control_deps/ctrl_dep_y",
1926        "|  |- (2) [AddV2] control_deps/y",
1927        "|  |  |- (3) [Identity] control_deps/x/read",
1928        "|  |  |  |- (4) [VariableV2] control_deps/x",
1929        "|  |  |- (3) [Identity] control_deps/x/read",
1930        "|  |     |- (4) [VariableV2] control_deps/x",
1931        "|  |- (2) (Ctrl) [VariableV2] control_deps/x",
1932        "|- (1) (Ctrl) [VariableV2] control_deps/x", "", "Legend:",
1933        "  (d): recursion depth = d.", "  (Ctrl): Control input.",
1934        "  [Op]: Input node has op type Op."
1935    ], out.lines)
1936    check_main_menu(
1937        self,
1938        out,
1939        list_tensors_enabled=True,
1940        node_info_node_name=node_name,
1941        print_tensor_node_name=node_name,
1942        list_outputs_node_name=node_name)
1943    check_menu_item(self, out, 1,
1944                    len(out.lines[1]) - len("control_deps/z"),
1945                    len(out.lines[1]), "li -c -r control_deps/z")
1946    check_menu_item(self, out, 11,
1947                    len(out.lines[11]) - len("control_deps/ctrl_dep_y"),
1948                    len(out.lines[11]), "li -c -r control_deps/ctrl_dep_y")
1949    check_menu_item(self, out, 18,
1950                    len(out.lines[18]) - len("control_deps/x"),
1951                    len(out.lines[18]), "li -c -r control_deps/x")
1952
1953  def testListInputsRecursiveWithControlsWithDepthLimit(self):
1954    """List inputs recursively, with control inputs and a depth limit."""
1955    node_name = "control_deps/ctrl_dep_z"
1956    out = self._registry.dispatch_command(
1957        "li", ["-c", "-r", "-t", "-d", "2", node_name])
1958
1959    self.assertEqual([
1960        "Inputs to node \"%s\" (Depth limit = 2, " % node_name +
1961        "control inputs included):", "|- (1) [Mul] control_deps/z",
1962        "|  |- (2) [Identity] control_deps/x/read", "|  |  |- ...",
1963        "|  |- (2) [Identity] control_deps/ctrl_dep_y", "|     |- ...",
1964        "|- (1) (Ctrl) [Identity] control_deps/ctrl_dep_y",
1965        "|  |- (2) [AddV2] control_deps/y", "|  |  |- ...",
1966        "|  |- (2) (Ctrl) [VariableV2] control_deps/x",
1967        "|- (1) (Ctrl) [VariableV2] control_deps/x", "", "Legend:",
1968        "  (d): recursion depth = d.", "  (Ctrl): Control input.",
1969        "  [Op]: Input node has op type Op."
1970    ], out.lines)
1971    check_main_menu(
1972        self,
1973        out,
1974        list_tensors_enabled=True,
1975        node_info_node_name=node_name,
1976        print_tensor_node_name=node_name,
1977        list_outputs_node_name=node_name)
1978    check_menu_item(self, out, 1,
1979                    len(out.lines[1]) - len("control_deps/z"),
1980                    len(out.lines[1]), "li -c -r control_deps/z")
1981    check_menu_item(self, out, 10,
1982                    len(out.lines[10]) - len("control_deps/x"),
1983                    len(out.lines[10]), "li -c -r control_deps/x")
1984
1985  def testListInputsNodeWithoutInputs(self):
1986    """List the inputs to a node without any input."""
1987    node_name = "control_deps/x"
1988    out = self._registry.dispatch_command("li", ["-c", "-r", "-t", node_name])
1989
1990    self.assertEqual([
1991        "Inputs to node \"%s\" (Depth limit = 20, control " % node_name +
1992        "inputs included):", "  [None]", "", "Legend:",
1993        "  (d): recursion depth = d.", "  (Ctrl): Control input.",
1994        "  [Op]: Input node has op type Op."
1995    ], out.lines)
1996    check_main_menu(
1997        self,
1998        out,
1999        list_tensors_enabled=True,
2000        node_info_node_name=node_name,
2001        print_tensor_node_name=node_name,
2002        list_outputs_node_name=node_name)
2003
2004  def testListInputsNonexistentNode(self):
2005    out = self._registry.dispatch_command(
2006        "list_inputs", ["control_deps/z/foo"])
2007
2008    self.assertEqual([
2009        "ERROR: There is no node named \"control_deps/z/foo\" in the "
2010        "partition graphs"], out.lines)
2011
2012  def testListRecipientsRecursiveWithControlsWithDepthLimit(self):
2013    """List recipients recursively, with control inputs and a depth limit."""
2014
2015    out = self._registry.dispatch_command(
2016        "lo", ["-c", "-r", "-t", "-d", "1", "control_deps/x"])
2017
2018    self.assertEqual([
2019        "Recipients of node \"control_deps/x\" (Depth limit = 1, control "
2020        "recipients included):",
2021        "|- (1) [Identity] control_deps/x/read",
2022        "|  |- ...",
2023        "|- (1) (Ctrl) [Identity] control_deps/ctrl_dep_y",
2024        "|  |- ...",
2025        "|- (1) (Ctrl) [Identity] control_deps/ctrl_dep_z",
2026        "", "Legend:", "  (d): recursion depth = d.",
2027        "  (Ctrl): Control input.",
2028        "  [Op]: Input node has op type Op."], out.lines)
2029    check_menu_item(self, out, 1,
2030                    len(out.lines[1]) - len("control_deps/x/read"),
2031                    len(out.lines[1]), "lo -c -r control_deps/x/read")
2032    check_menu_item(self, out, 3,
2033                    len(out.lines[3]) - len("control_deps/ctrl_dep_y"),
2034                    len(out.lines[3]), "lo -c -r control_deps/ctrl_dep_y")
2035    check_menu_item(self, out, 5,
2036                    len(out.lines[5]) - len("control_deps/ctrl_dep_z"),
2037                    len(out.lines[5]), "lo -c -r control_deps/ctrl_dep_z")
2038
2039    # Verify the bold attribute of the node name.
2040    self.assertEqual([(20, 20 + len("control_deps/x"), "bold")],
2041                     out.font_attr_segs[0])
2042
2043
2044@test_util.run_v1_only("b/120545219")
2045class AnalyzerCLIWhileLoopTest(test_util.TensorFlowTestCase):
2046
2047  @classmethod
2048  def setUpClass(cls):
2049    cls._dump_root = tempfile.mkdtemp()
2050
2051    with session.Session(config=no_rewrite_session_config()) as sess:
2052      loop_var = constant_op.constant(0, name="while_loop_test/loop_var")
2053      cond = lambda loop_var: math_ops.less(loop_var, 10)
2054      body = lambda loop_var: math_ops.add(loop_var, 1)
2055      while_loop = control_flow_ops.while_loop(
2056          cond, body, [loop_var], parallel_iterations=1)
2057
2058      run_options = config_pb2.RunOptions(output_partition_graphs=True)
2059      debug_url = "file://%s" % cls._dump_root
2060
2061      watch_opts = run_options.debug_options.debug_tensor_watch_opts
2062
2063      # Add debug tensor watch for "while/Identity".
2064      watch = watch_opts.add()
2065      watch.node_name = "while/Identity"
2066      watch.output_slot = 0
2067      watch.debug_ops.append("DebugIdentity")
2068      watch.debug_urls.append(debug_url)
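      # With parallel_iterations=1, the loop increments loop_var from 0 until
      # it reaches 10, so the watched tensor while/Identity:0 is expected to
      # produce 10 dumps (one per iteration); the print_tensor tests below
      # rely on this count.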
2069
2070      # Invoke Session.run().
2071      run_metadata = config_pb2.RunMetadata()
2072      sess.run(while_loop, options=run_options, run_metadata=run_metadata)
2073
2074    cls._debug_dump = debug_data.DebugDumpDir(
2075        cls._dump_root, partition_graphs=run_metadata.partition_graphs)
2076
2077    cls._analyzer, cls._registry = create_analyzer_cli(cls._debug_dump)
2078
2079  @classmethod
2080  def tearDownClass(cls):
2081    # Tear down temporary dump directory.
2082    file_io.delete_recursively(cls._dump_root)
2083
2084  def testMultipleDumpsPrintTensorNoNumber(self):
2085    output = self._registry.dispatch_command("pt", ["while/Identity:0"])
2086
2087    self.assertEqual("Tensor \"while/Identity:0\" generated 10 dumps:",
2088                     output.lines[0])
2089
2090    for i in range(10):
2091      self.assertTrue(output.lines[i + 1].startswith("#%d" % i))
2092      self.assertTrue(output.lines[i + 1].endswith(
2093          " ms] while/Identity:0:DebugIdentity"))
2094
2095    self.assertEqual(
2096        "You can use the -n (--number) flag to specify which dump to print.",
2097        output.lines[-3])
2098    self.assertEqual("For example:", output.lines[-2])
2099    self.assertEqual("  print_tensor while/Identity:0 -n 0", output.lines[-1])
2100
2101  def testMultipleDumpsPrintTensorWithNumber(self):
2102    for i in range(5):
2103      output = self._registry.dispatch_command(
2104          "pt", ["while/Identity:0", "-n", "%d" % i])
2105
2106      self.assertEqual("Tensor \"while/Identity:0:DebugIdentity (dump #%d)\":" %
2107                       i, output.lines[0])
2108      self.assertEqual("  dtype: int32", output.lines[1])
2109      self.assertEqual("  shape: ()", output.lines[2])
2110      self.assertEqual("", output.lines[3])
2111      self.assertTrue(output.lines[4].startswith("array(%d" % i))
2112      self.assertTrue(output.lines[4].endswith(")"))
2113
2114  def testMultipleDumpsPrintTensorInvalidNumber(self):
2115    output = self._registry.dispatch_command("pt",
2116                                             ["while/Identity:0", "-n", "10"])
2117
2118    self.assertEqual([
2119        "ERROR: Specified number (10) exceeds the number of available dumps "
2120        "(10) for tensor while/Identity:0"
2121    ], output.lines)
2122
2123
2124if __name__ == "__main__":
2125  googletest.main()
2126