xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/cli/cli_shared.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Shared functions and classes for tfdbg command-line interface."""
16import math
17
18import numpy as np
19
20from tensorflow.python.debug.cli import command_parser
21from tensorflow.python.debug.cli import debugger_cli_common
22from tensorflow.python.debug.cli import tensor_format
23from tensorflow.python.debug.lib import common
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import variables
26from tensorflow.python.platform import gfile
27
# Shorthand alias for the rich-text line class, used throughout this module.
RL = debugger_cli_common.RichLine

# Default threshold number of elements above which ellipses will be used
# when printing the value of the tensor.
DEFAULT_NDARRAY_DISPLAY_THRESHOLD = 2000

# Color names used as font attributes in rich-text screen output
# (e.g. COLOR_RED for error messages).
COLOR_BLACK = "black"
COLOR_BLUE = "blue"
COLOR_CYAN = "cyan"
COLOR_GRAY = "gray"
COLOR_GREEN = "green"
COLOR_MAGENTA = "magenta"
COLOR_RED = "red"
COLOR_WHITE = "white"
COLOR_YELLOW = "yellow"

# Time-unit suffixes, ordered from smallest to largest. time_to_readable_str()
# treats the index i of a unit as a scale factor of 10**(3*i) microseconds, so
# this ordering must be preserved.
TIME_UNIT_US = "us"
TIME_UNIT_MS = "ms"
TIME_UNIT_S = "s"
TIME_UNITS = [TIME_UNIT_US, TIME_UNIT_MS, TIME_UNIT_S]
48
49
def bytes_to_readable_str(num_bytes, include_b=False):
  """Format a byte count as a short human-readable string.

  Counts below 1024 are printed as whole numbers. Larger counts are scaled to
  the 1024-based units k, M or G and printed with two decimals.

  Args:
    num_bytes: (`int` or None) Number of bytes.
    include_b: (`bool`) Whether to append the letter "B" to the result.

  Returns:
    (`str`) The formatted size, e.g. "512", "1.50k" or "2.00MB". `None` is
      returned as the string "None", without any "B" suffix.
  """
  if num_bytes is None:
    return str(num_bytes)

  if num_bytes < 1024:
    text = "%d" % num_bytes
  else:
    # Pick the first unit whose upper bound exceeds the count; anything at or
    # beyond 1 GiB falls through to the "G" unit.
    for limit, scale, template in ((1048576, 1024.0, "%.2fk"),
                                   (1073741824, 1048576.0, "%.2fM")):
      if num_bytes < limit:
        break
    else:
      scale, template = 1073741824.0, "%.2fG"
    text = template % (num_bytes / scale)

  return text + "B" if include_b else text
78
79
def time_to_readable_str(value_us, force_time_unit=None):
  """Convert time value to human-readable string.

  Args:
    value_us: time value in microseconds.
    force_time_unit: force the output to use the specified time unit. Must be
      in TIME_UNITS.

  Returns:
    Human-readable string representation of the time value. A falsy value
    (e.g. 0 or None) yields "0". With force_time_unit, the value is converted
    to that unit and printed with up to 10 significant digits. Otherwise the
    largest unit in TIME_UNITS not exceeding the value's magnitude is chosen
    (microseconds for magnitudes below 1us), and the value is printed with 3
    significant digits.

  Raises:
    ValueError: if force_time_unit value is not in TIME_UNITS.
  """
  if not value_us:
    return "0"
  if force_time_unit:
    if force_time_unit not in TIME_UNITS:
      raise ValueError("Invalid time unit: %s" % force_time_unit)
    order = TIME_UNITS.index(force_time_unit)
    time_unit = force_time_unit
    return "{:.10g}{}".format(value_us / math.pow(10.0, 3*order), time_unit)

  # Choose the unit from the number of decimal digits.
  # - math.log10 is exact for powers of ten. math.log(x, 10) is not:
  #   math.log(1000, 10) == 2.9999999999999996, which rendered 1000us as
  #   "1e+03us" instead of "1ms".
  # - abs() lets negative durations be formatted instead of raising a math
  #   domain error.
  # - Clamping at 0 keeps magnitudes below 0.001us from selecting
  #   TIME_UNITS[-1] through a negative index.
  order = int(math.log10(abs(value_us)) / 3)
  order = max(0, min(len(TIME_UNITS) - 1, order))
  time_unit = TIME_UNITS[order]
  return "{:.3g}{}".format(value_us / math.pow(10.0, 3*order), time_unit)
106
107
def parse_ranges_highlight(ranges_string):
  """Build highlight options from a numerical-ranges string.

  Args:
    ranges_string: (str) A string representing a numerical range or a list of
      numerical ranges. See the help info of the -r flag of the print_tensor
      command for more details.

  Returns:
    An instance of tensor_format.HighlightOptions whose filter marks the
      elements falling inside any of the (inclusive) ranges, or None if
      ranges_string is empty or None.
  """
  if not ranges_string:
    return None

  parsed_ranges = command_parser.parse_ranges(ranges_string)

  def ranges_filter(x):
    # Element-wise OR of the per-range inclusive membership masks.
    in_any_range = np.zeros(x.shape, dtype=bool)
    for lower, upper in parsed_ranges:
      in_range = np.logical_and(x >= lower, x <= upper)
      in_any_range = np.logical_or(in_any_range, in_range)
    return in_any_range

  return tensor_format.HighlightOptions(
      ranges_filter, description=ranges_string)
136
137
def numpy_printoptions_from_screen_info(screen_info):
  """Derive numpy print options from the screen info of the CLI frontend.

  Args:
    screen_info: (dict or None) Screen information. Only its "cols" entry, if
      present, is used.

  Returns:
    (dict) {"linewidth": screen_info["cols"]} when a "cols" entry is present,
      otherwise an empty dict.
  """
  options = {}
  if screen_info and "cols" in screen_info:
    options["linewidth"] = screen_info["cols"]
  return options
143
144
def format_tensor(tensor,
                  tensor_name,
                  np_printoptions,
                  print_all=False,
                  tensor_slicing=None,
                  highlight_options=None,
                  include_numeric_summary=False,
                  write_path=None):
  """Generate formatted str to represent a tensor or its slices.

  Args:
    tensor: (numpy ndarray) The tensor value.
    tensor_name: (str) Name of the tensor, e.g., the tensor's debug watch key.
    np_printoptions: (dict or None) Numpy tensor formatting options. The dict
      passed in is not modified; a copy with its "threshold" entry set
      according to `print_all` is forwarded to the formatter.
    print_all: (bool) Whether the tensor is to be displayed in its entirety,
      instead of printing ellipses, even if its number of elements exceeds
      the default numpy display threshold.
      (Note: Even if this is set to true, the screen output can still be cut
       off by the UI frontend if it consists of more lines than the frontend
       can handle.)
    tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
      None, no slicing will be performed on the tensor.
    highlight_options: (tensor_format.HighlightOptions) options to highlight
      elements of the tensor. See the doc of tensor_format.format_tensor()
      for more details.
    include_numeric_summary: Whether a text summary of the numeric values (if
      applicable) will be included.
    write_path: A path to save the tensor value (after any slicing) to
      (optional). `numpy.save()` is used to save the value.

  Returns:
    An instance of `debugger_cli_common.RichTextLines` representing the
    (potentially sliced) tensor.
  """

  if tensor_slicing:
    # Validate the indexing.
    value = command_parser.evaluate_tensor_slice(tensor, tensor_slicing)
    sliced_name = tensor_name + tensor_slicing
  else:
    value = tensor
    sliced_name = tensor_name

  auxiliary_message = None
  if write_path:
    with gfile.Open(write_path, "wb") as output_file:
      np.save(output_file, value)
    # Report the saved path and the on-disk size of the written file.
    line = debugger_cli_common.RichLine("Saved value to: ")
    line += debugger_cli_common.RichLine(write_path, font_attr="bold")
    line += " (%sB)" % bytes_to_readable_str(gfile.Stat(write_path).length)
    auxiliary_message = debugger_cli_common.rich_text_lines_from_rich_line_list(
        [line, debugger_cli_common.RichLine("")])

  # Work on a copy so the caller's options dict is not mutated. The
  # "threshold" entry is specific to this call (print_all and value size).
  np_printoptions = dict(np_printoptions) if np_printoptions else {}
  if print_all:
    np_printoptions["threshold"] = value.size
  else:
    np_printoptions["threshold"] = DEFAULT_NDARRAY_DISPLAY_THRESHOLD

  return tensor_format.format_tensor(
      value,
      sliced_name,
      include_metadata=True,
      include_numeric_summary=include_numeric_summary,
      auxiliary_message=auxiliary_message,
      np_printoptions=np_printoptions,
      highlight_options=highlight_options)
211
212
def error(msg):
  """Wrap an error message as red, "ERROR: "-prefixed screen output.

  Args:
    msg: (str) The error message.

  Returns:
    (debugger_cli_common.RichTextLines) A single-line representation of the
      error message for screen output.
  """
  error_line = RL("ERROR: " + msg, COLOR_RED)
  return debugger_cli_common.rich_text_lines_from_rich_line_list([error_line])
226
227
def _recommend_command(command, description, indent=2, create_link=False):
  """Build a two-line RichTextLines block suggesting a command to the user.

  The first line shows the command in bold, optionally as a clickable command
  link, followed by a colon. The second line shows the description, indented
  two extra spaces.

  Args:
    command: (str) The command to recommend.
    description: (str) A description of what the command does.
    indent: (int) How many spaces to indent in the beginning.
    create_link: (bool) Whether a command link is to be applied to the command
      string.

  Returns:
    (RichTextLines) Formatted text (with font attributes) for recommending the
      command.
  """
  leading = " " * indent

  if not create_link:
    font_attr = "bold"
  else:
    font_attr = [debugger_cli_common.MenuItem("", command), "bold"]

  command_line = RL(leading) + RL(command, font_attr) + ":"
  description_line = leading + "  " + description
  return debugger_cli_common.rich_text_lines_from_rich_line_list(
      [command_line, description_line])
254
255
def get_tfdbg_logo():
  """Build the ASCII-art "TFDBG" banner with a blank line above and below.

  Returns:
    (RichTextLines) The logo lines, without font attributes.
  """
  art = [
      "TTTTTT FFFF DDD  BBBB   GGG ",
      "  TT   F    D  D B   B G    ",
      "  TT   FFF  D  D BBBB  G  GG",
      "  TT   F    D  D B   B G   G",
      "  TT   F    DDD  BBBB   GGG ",
  ]
  return debugger_cli_common.RichTextLines([""] + art + [""])
269
270
# Separator line that frames the session/run summary in the run-start intro.
_HORIZONTAL_BAR = "======================================"
272
273
def _feed_dict_rich_lines(feed_dict):
  """Render the keys of a feed dict as indented `pf` command links."""
  if not feed_dict:
    return debugger_cli_common.rich_text_lines_from_rich_line_list(
        [RL("  (Empty)")])

  lines = []
  for feed_key in feed_dict:
    feed_key_name = common.get_graph_element_name(feed_key)
    # Surround the name string with quotes, because feed_key_name may contain
    # spaces in some cases, e.g., SparseTensors.
    print_command = debugger_cli_common.MenuItem(
        None, "pf '%s'" % feed_key_name)
    lines.append(RL("  ") + RL(feed_key_name, print_command))
  return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)


def _registered_filter_rich_lines(tensor_filters):
  """List the registered tensor filters as `run -f` command links."""
  lines = ["    Registered filter(s):"]
  if tensor_filters:
    for filter_name in tensor_filters:
      command_menu_node = debugger_cli_common.MenuItem(
          "", "run -f %s" % filter_name)
      lines.append(RL("        * ") + RL(filter_name, command_menu_node))
  else:
    lines.append("        (None)")
  return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)


def get_run_start_intro(run_call_count,
                        fetches,
                        feed_dict,
                        tensor_filters,
                        is_callable_runner=False):
  """Generate formatted intro for run-start UI.

  Args:
    run_call_count: (int) Run call counter.
    fetches: Fetches of the `Session.run()` call. See doc of `Session.run()`
      for more details.
    feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()`
      for more details.
    tensor_filters: (dict) A dict from tensor-filter name to tensor-filter
      callable.
    is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run. If True, the fetch and feed
        listings are omitted from the intro.

  Returns:
    (RichTextLines) Formatted intro message about the `Session.run()` call,
      with the run/exit main menu attached as an annotation.
  """
  out = debugger_cli_common.RichTextLines(_HORIZONTAL_BAR)
  if is_callable_runner:
    out.append("Running a runner returned by Session.make_callable()")
  else:
    # Fetch and feed listings are only displayed for Session.run() calls, so
    # they are computed only here.
    fetch_lines = common.get_flattened_names(fetches)
    out.append("Session.run() call #%d:" % run_call_count)
    out.append("")
    out.append("Fetch(es):")
    out.extend(debugger_cli_common.RichTextLines(
        ["  " + line for line in fetch_lines]))
    out.append("")
    out.append("Feed dict:")
    out.extend(_feed_dict_rich_lines(feed_dict))
  out.append(_HORIZONTAL_BAR)
  out.append("")
  out.append("Select one of the following commands to proceed ---->")

  out.extend(
      _recommend_command(
          "run",
          "Execute the run() call with debug tensor-watching",
          create_link=True))
  out.extend(
      _recommend_command(
          "run -n",
          "Execute the run() call without debug tensor-watching",
          create_link=True))
  out.extend(
      _recommend_command(
          "run -t <T>",
          "Execute run() calls (T - 1) times without debugging, then "
          "execute run() once more with debugging and drop back to the CLI"))
  out.extend(
      _recommend_command(
          "run -f <filter_name>",
          "Keep executing run() calls until a dumped tensor passes a given, "
          "registered filter (conditional breakpoint mode)"))

  out.extend(_registered_filter_rich_lines(tensor_filters))

  out.append("")

  out.append_rich_line(RL("For more details, see ") +
                       RL("help.", debugger_cli_common.MenuItem("", "help")) +
                       ".")
  out.append("")

  # Make main menu for the run-start intro.
  menu = debugger_cli_common.Menu()
  menu.append(debugger_cli_common.MenuItem("run", "run"))
  menu.append(debugger_cli_common.MenuItem("exit", "exit"))
  out.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu

  return out
379
380
def get_run_short_description(run_call_count,
                              fetches,
                              feed_dict,
                              is_callable_runner=False):
  """Get a short description of the run() call.

  Args:
    run_call_count: (int) Run call counter.
    fetches: Fetches of the `Session.run()` call. See doc of `Session.run()`
      for more details.
    feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()`
      for more details.
    is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.

  Returns:
    (str) A short description of the run() call, including information about
      the fetch(es) and feed(s).
  """
  if is_callable_runner:
    return "runner from make_callable()"

  description = "run #%d: " % run_call_count

  if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)):
    description += "1 fetch (%s); " % common.get_graph_element_name(fetches)
  else:
    # Could be (nested) list, tuple, dict or namedtuple.
    num_fetches = len(common.get_flattened_names(fetches))
    if num_fetches > 1:
      description += "%d fetches; " % num_fetches
    else:
      description += "%d fetch; " % num_fetches

  if not feed_dict:
    description += "0 feeds"
  elif len(feed_dict) == 1:
    key = next(iter(feed_dict))
    key_name = (
        key if isinstance(key, str) or not hasattr(key, "name") else key.name)
    # Pass the name as a 1-tuple. If the feed key is itself a tuple, a bare
    # tuple on the right of `%` would be unpacked into several format
    # arguments and raise TypeError.
    description += "1 feed (%s)" % (key_name,)
  else:
    description += "%d feeds" % len(feed_dict)

  return description
427
428
def get_error_intro(tf_error):
  """Build the screen shown when a TensorFlow run-time error occurs.

  Args:
    tf_error: (errors.OpError) TensorFlow run-time error object.

  Returns:
    (RichTextLines) An error banner, followed by debugging command suggestions
      if the failing op's name is available as `tf_error.op.name` (otherwise a
      warning), followed by the op name, error type and error message.
  """
  failing_op = getattr(tf_error, "op", None)
  op_name = getattr(failing_op, "name", None)

  out = debugger_cli_common.rich_text_lines_from_rich_line_list([
      "--------------------------------------",
      RL("!!! An error occurred during the run !!!", "blink"),
      "",
  ])

  if op_name is None:
    out.extend(debugger_cli_common.RichTextLines([
        "WARNING: Cannot determine the name of the op that caused the error."]))
  else:
    out.extend(debugger_cli_common.RichTextLines(
        ["You may use the following commands to debug:"]))
    suggestions = (
        ("ni -a -d -t %s" % op_name,
         "Inspect information about the failing op."),
        ("li -r %s" % op_name,
         "List inputs to the failing op, recursively."),
        ("lt",
         "List all tensors dumped during the failing run() call."),
    )
    for command, description in suggestions:
      out.extend(_recommend_command(command, description, create_link=True))

  out.extend(debugger_cli_common.RichTextLines([
      "",
      "Op name:    %s" % op_name,
      "Error type: " + str(type(tf_error)),
      "",
      "Details:",
      str(tf_error),
      "",
      "--------------------------------------",
      "",
  ]))

  return out
489