xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/cli/tensor_format.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Format tensors (ndarrays) for screen display and navigation."""
16import copy
17import re
18
19import numpy as np
20
21from tensorflow.python.debug.cli import debugger_cli_common
22from tensorflow.python.debug.lib import debug_data
23
24_NUMPY_OMISSION = "...,"
25_NUMPY_DEFAULT_EDGE_ITEMS = 3
26
27_NUMBER_REGEX = re.compile(r"[-+]?([0-9][-+0-9eE\.]+|nan|inf)(\s|,|\])")
28
29BEGIN_INDICES_KEY = "i0"
30OMITTED_INDICES_KEY = "omitted"
31
32DEFAULT_TENSOR_ELEMENT_HIGHLIGHT_FONT_ATTR = "bold"
33
34
class HighlightOptions:
  """Options for highlighting elements of a tensor.

  Attributes:
    criterion: Callable selecting the elements to highlight.
    description: Optional human-readable description of `criterion`.
    font_attr: Font attribute applied to the highlighted elements.
  """

  def __init__(self,
               criterion,
               description=None,
               font_attr=DEFAULT_TENSOR_ELEMENT_HIGHLIGHT_FONT_ATTR):
    """Create a HighlightOptions.

    Args:
      criterion: (callable) Invoked as `criterion(x)` on the tensor `x` being
        formatted. It must return a boolean ndarray with the same shape as
        `x`, whose True entries mark the elements to highlight. That return
        value (not the callable itself) is what gets passed to
        `np.argwhere()`.
      description: (str or None) Description of the highlight criterion,
        shown in the highlight summary when set.
      font_attr: (str) Font attribute to apply to the highlighted elements.
    """
    self.font_attr = font_attr
    self.description = description
    self.criterion = criterion
65
66
def format_tensor(tensor,
                  tensor_label,
                  include_metadata=False,
                  auxiliary_message=None,
                  include_numeric_summary=False,
                  np_printoptions=None,
                  highlight_options=None):
  """Generate a RichTextLines object showing a tensor in formatted style.

  Args:
    tensor: The tensor to be displayed, as a numpy ndarray or other
      appropriate format (e.g., None representing uninitialized tensors).
    tensor_label: A label for the tensor, as a string. If set to None, will
      suppress the tensor name line in the return value.
    include_metadata: Whether metadata such as dtype and shape are to be
      included in the formatted text.
    auxiliary_message: An auxiliary message to display under the tensor label,
      dtype and shape information lines.
    include_numeric_summary: Whether a text summary of the numeric values (if
      applicable) will be included.
    np_printoptions: A dictionary of keyword arguments for numpy's print
      options (as accepted by np.set_printoptions()) used to format the
      ndarray text. They are applied only while the tensor text is generated;
      the global numpy print options are restored afterwards.
    highlight_options: (HighlightOptions) options for highlighting elements
      of the tensor.

  Returns:
    A RichTextLines object. Its annotation field has line-by-line markups to
    indicate which indices in the array the first element of each line
    corresponds to.
  """
  lines = []
  font_attr_segs = {}

  if tensor_label is not None:
    lines.append("Tensor \"%s\":" % tensor_label)
    # The label starts right after the 'Tensor "' prefix (column 8).
    label_start = len("Tensor \"")
    label_end = label_start + len(tensor_label)
    base_name, colon, suffix = tensor_label.rpartition(":")
    if not colon or suffix.isdigit():
      # No suffix, or a numeric suffix (assumed to be the output slot index):
      # bold the whole label.
      font_attr_segs[0] = [(label_start, label_end, "bold")]
    else:
      # Suffix is not a number. It is auxiliary information such as the debug
      # op type. In this case, highlight the suffix with a different color.
      name_end = label_start + len(base_name)
      font_attr_segs[0] = [
          (label_start, name_end, "bold"),
          (name_end + 1, label_end, "yellow"),
      ]

  if isinstance(tensor, debug_data.InconvertibleTensorProto):
    if lines:
      lines.append("")
    lines.extend(str(tensor).split("\n"))
    return debugger_cli_common.RichTextLines(lines)
  elif not isinstance(tensor, np.ndarray):
    # If tensor is not a np.ndarray, return simple text-line representation of
    # the object without annotations.
    if lines:
      lines.append("")
    lines.extend(repr(tensor).split("\n"))
    return debugger_cli_common.RichTextLines(lines)

  if include_metadata:
    lines.append("  dtype: %s" % str(tensor.dtype))
    lines.append("  shape: %s" % str(tensor.shape).replace("L", ""))

  if lines:
    lines.append("")
  formatted = debugger_cli_common.RichTextLines(
      lines, font_attr_segs=font_attr_segs)

  if auxiliary_message:
    formatted.extend(auxiliary_message)

  if include_numeric_summary:
    formatted.append("Numeric summary:")
    formatted.extend(numeric_summary(tensor))
    formatted.append("")

  # Apply the custom numpy print options only while generating the text, so
  # the process-wide numpy print state is not permanently altered.
  with np.printoptions(**(np_printoptions or {})):
    array_lines = repr(tensor).split("\n")

  # np.bytes_ is the type formerly aliased as np.string_ (removed in NumPy 2).
  if tensor.dtype.type is not np.bytes_:
    # Parse array lines to get beginning indices for each line.

    # TODO(cais): Currently, we do not annotate string-type tensors due to
    #   difficulty in escaping sequences. Address this issue.
    annotations = _annotate_ndarray_lines(
        array_lines, tensor, np_printoptions=np_printoptions)
  else:
    annotations = None
  formatted_array = debugger_cli_common.RichTextLines(
      array_lines, annotations=annotations)
  formatted.extend(formatted_array)

  # Perform optional highlighting.
  if highlight_options is not None:
    indices_list = [
        list(indices)
        for indices in np.argwhere(highlight_options.criterion(tensor))
    ]
    num_highlighted = len(indices_list)

    total_elements = np.size(tensor)
    # Guard against division by zero for empty tensors.
    percentage = (num_highlighted / float(total_elements) * 100.0
                  if total_elements else 0.0)
    highlight_summary = "Highlighted%s: %d of %d element(s) (%.2f%%)" % (
        "(%s)" % highlight_options.description if highlight_options.description
        else "", num_highlighted, total_elements, percentage)

    formatted.lines[0] += " " + highlight_summary

    # Element positions can only be resolved through per-line index
    # annotations, which exist only for non-string tensors of rank >= 1.
    if indices_list and annotations is not None and tensor.ndim > 0:
      are_omitted, rows, start_cols, end_cols = locate_tensor_element(
          formatted, indices_list)
      for is_omitted, row, start_col, end_col in zip(are_omitted, rows,
                                                     start_cols, end_cols):
        if is_omitted or start_col is None or end_col is None:
          continue
        formatted.font_attr_segs.setdefault(row, []).append(
            (start_col, end_col, highlight_options.font_attr))

  return formatted
195
196
def _annotate_ndarray_lines(
    array_lines, tensor, np_printoptions=None, offset=0):
  """Generate annotations for line-by-line begin indices of tensor text.

  Parse the numpy-generated text representation of a numpy ndarray to
  determine the indices of the first element of each text line (if any
  element is present in the line).

  For example, given the following multi-line ndarray text representation:
      ["array([[ 0.    ,  0.0625,  0.125 ,  0.1875],",
       "       [ 0.25  ,  0.3125,  0.375 ,  0.4375],",
       "       [ 0.5   ,  0.5625,  0.625 ,  0.6875],",
       "       [ 0.75  ,  0.8125,  0.875 ,  0.9375]])"]
  the generated annotations will be (in addition to the "tensor_metadata"
  entry described below):
      {0: {BEGIN_INDICES_KEY: [0, 0]},
       1: {BEGIN_INDICES_KEY: [1, 0]},
       2: {BEGIN_INDICES_KEY: [2, 0]},
       3: {BEGIN_INDICES_KEY: [3, 0]}}

  A line consisting only of numpy's omission marker ("...,") is annotated
  with OMITTED_INDICES_KEY instead of BEGIN_INDICES_KEY.

  Args:
    array_lines: Text lines representing the tensor, as a list of str.
    tensor: The tensor being formatted as string.
    np_printoptions: A dictionary of keyword arguments that are passed to a
      call of np.set_printoptions(). Only its "edgeitems" entry is read here.
    offset: Line number offset applied to the line indices in the returned
      annotation.

  Returns:
    An annotation as a dict. It always contains the key "tensor_metadata",
    mapping to {"dtype": tensor.dtype, "shape": tensor.shape}; for tensors of
    rank >= 1 it also has one int key (offset + line index) per non-empty
    line.
  """

  # Number of items numpy shows at each end of a summarized dimension.
  if np_printoptions and "edgeitems" in np_printoptions:
    edge_items = np_printoptions["edgeitems"]
  else:
    edge_items = _NUMPY_DEFAULT_EDGE_ITEMS

  annotations = {}

  # Put metadata about the tensor in the annotations["tensor_metadata"].
  annotations["tensor_metadata"] = {
      "dtype": tensor.dtype, "shape": tensor.shape}

  dims = np.shape(tensor)
  ndims = len(dims)
  if ndims == 0:
    # No indices for a 0D tensor.
    return annotations

  # curr_indices: indices of the next element to appear in the text.
  # curr_dim: number of currently open "[" brackets, i.e. the nesting depth.
  curr_indices = [0] * len(dims)
  curr_dim = 0
  for i, raw_line in enumerate(array_lines):
    line = raw_line.strip()

    if not line:
      # Skip empty lines, which can appear for >= 3D arrays.
      continue

    if line == _NUMPY_OMISSION:
      annotations[offset + i] = {OMITTED_INDICES_KEY: copy.copy(curr_indices)}
      # After the ellipsis numpy resumes with the last `edge_items` entries of
      # the enclosing dimension, so jump the index there.
      curr_indices[curr_dim - 1] = dims[curr_dim - 1] - edge_items
    else:
      num_lbrackets = line.count("[")  # TODO(cais): String array escaping.
      num_rbrackets = line.count("]")

      curr_dim += num_lbrackets - num_rbrackets

      annotations[offset + i] = {BEGIN_INDICES_KEY: copy.copy(curr_indices)}
      if num_rbrackets == 0:
        # The row continues on the next line (wrapped by linewidth): advance
        # the innermost index by the number of comma-terminated elements here.
        line_content = line[line.rfind("[") + 1:]
        num_elements = line_content.count(",")
        curr_indices[curr_dim - 1] += num_elements
      else:
        # The line closes one or more brackets: move to the next entry of the
        # enclosing dimension and reset all deeper indices to 0.
        if curr_dim > 0:
          curr_indices[curr_dim - 1] += 1
          for k in range(curr_dim, ndims):
            curr_indices[k] = 0

  return annotations
275
276
def locate_tensor_element(formatted, indices):
  """Locate a tensor element in formatted text lines, given element indices.

  Given a RichTextLines object representing a tensor and indices of the sought
  element, return the row number at which the element is located (if exists).

  Args:
    formatted: A RichTextLines object containing formatted text lines
      representing the tensor, carrying the "tensor_metadata" annotation and
      the per-line index annotations (as produced by `format_tensor()`).
    indices: Indices of the sought element, as a list of int or a list of list
      of int. The former case is for a single set of indices to look up,
      whereas the latter case is for looking up a batch of indices sets at once.
      In the latter case, the indices must be in ascending order, or a
      ValueError will be raised.

  Returns:
    1) A boolean indicating whether the element falls into an omitted line.
    2) Row index.
    3) Column start index, i.e., the first column in which the representation
       of the specified tensor starts, if it can be determined. If it cannot
       be determined (e.g., due to ellipsis), None.
    4) Column end index, i.e., the column right after the last column that
       represents the specified tensor. Iff it cannot be determined, None.

  The return values described above are for a single set of indices. In
    batch mode (multiple sets of indices), each return value is instead a
    list with one entry per set of indices.

  Raises:
    AttributeError: If:
      Input argument "formatted" does not have the required annotations.
    ValueError: If:
      1) Indices do not match the dimensions of the tensor, or
      2) Indices exceed sizes of the tensor, or
      3) Indices contain negative value(s).
      4) If in batch mode, and if not all sets of indices are in ascending
         order.
  """

  if isinstance(indices[0], list):
    indices_list = indices
    input_batch = True
  else:
    indices_list = [indices]
    input_batch = False

  # Check that tensor_metadata is available.
  if "tensor_metadata" not in formatted.annotations:
    raise AttributeError("tensor_metadata is not available in annotations.")

  # Sanity check on input argument.
  _validate_indices_list(indices_list, formatted)

  dims = formatted.annotations["tensor_metadata"]["shape"]
  batch_size = len(indices_list)
  lines = formatted.lines
  annot = formatted.annotations

  # Initialize return values
  are_omitted = [None] * batch_size
  row_indices = [None] * batch_size
  start_columns = [None] * batch_size
  end_columns = [None] * batch_size

  def _record_matches(batch_pos, matching_indices_list, row, line,
                      ref_indices):
    """Fill the return slots for consecutive index sets found in `row`."""
    num_matches = len(matching_indices_list)
    match_start_columns, match_end_columns = _locate_elements_in_line(
        line, matching_indices_list, ref_indices)
    window = slice(batch_pos, batch_pos + num_matches)
    start_columns[window] = match_start_columns
    end_columns[window] = match_end_columns
    # `row` may lack an annotation only if no annotated row was seen at all.
    are_omitted[window] = [
        OMITTED_INDICES_KEY in annot.get(row, {})
    ] * num_matches
    row_indices[window] = [row] * num_matches
    return batch_pos + num_matches

  batch_pos = 0  # Current position in the batch.
  prev_r = 0
  prev_line = ""
  prev_indices = [0] * len(dims)

  for r, line in enumerate(lines):
    row_annot = annot.get(r)
    if row_annot is None:
      continue

    if BEGIN_INDICES_KEY in row_annot:
      row_indices_value = row_annot[BEGIN_INDICES_KEY]
    elif OMITTED_INDICES_KEY in row_annot:
      row_indices_value = row_annot[OMITTED_INDICES_KEY]
    else:
      # Not a tensor-element line; previously this left the lookup key
      # unbound or stale.
      continue

    # Index sets that fall between the previous annotated row's first element
    # and this row's first element belong to the previous row.
    matching_indices_list = [
        ind for ind in indices_list[batch_pos:]
        if prev_indices <= ind < row_indices_value
    ]

    if matching_indices_list:
      batch_pos = _record_matches(batch_pos, matching_indices_list, prev_r,
                                  prev_line, prev_indices)
      if batch_pos >= batch_size:
        break

    prev_r = r
    prev_line = line
    prev_indices = row_indices_value

  # Whatever remains must lie in the last annotated row.
  if batch_pos < batch_size:
    _record_matches(batch_pos, indices_list[batch_pos:], prev_r, prev_line,
                    prev_indices)

  if input_batch:
    return are_omitted, row_indices, start_columns, end_columns
  else:
    return are_omitted[0], row_indices[0], start_columns[0], end_columns[0]
399
400
401def _validate_indices_list(indices_list, formatted):
402  prev_ind = None
403  for ind in indices_list:
404    # Check indices match tensor dimensions.
405    dims = formatted.annotations["tensor_metadata"]["shape"]
406    if len(ind) != len(dims):
407      raise ValueError("Dimensions mismatch: requested: %d; actual: %d" %
408                       (len(ind), len(dims)))
409
410    # Check indices is within size limits.
411    for req_idx, siz in zip(ind, dims):
412      if req_idx >= siz:
413        raise ValueError("Indices exceed tensor dimensions.")
414      if req_idx < 0:
415        raise ValueError("Indices contain negative value(s).")
416
417    # Check indices are in ascending order.
418    if prev_ind and ind < prev_ind:
419      raise ValueError("Input indices sets are not in ascending order.")
420
421    prev_ind = ind
422
423
def _locate_elements_in_line(line, indices_list, ref_indices):
  """Determine the start and end columns of elements in a line.

  Elements are identified by counting numeric tokens (matches of
  `_NUMBER_REGEX`) from the start of `line`: the element whose last index is
  `ref_indices[-1] + k` is taken to be the k-th numeric token. Tokens starting
  after a numpy ellipsis ("...,") are not searched.

  Args:
    line: (str) the line in which the elements are to be sought.
    indices_list: (list of list of int) list of indices of the elements to
      search for. Assumes that the indices lie in the row represented by
      `line` and are sorted in ascending order; duplicate index sets resolve
      to the same columns.
    ref_indices: (list of int) reference indices, i.e., the indices of the
      first element represented in the line.

  Returns:
    start_columns: (list of int or None) start column of each element, or
      None where it was not found.
    end_columns: (list of int or None) end column of each element, or None
      where it was not found.
    If found, the element is represented in the left-closed-right-open
      interval [start_column, end_column).
  """

  batch_size = len(indices_list)
  start_columns = [None] * batch_size
  end_columns = [None] * batch_size
  if not batch_size:
    return start_columns, end_columns

  offsets = [indices[-1] - ref_indices[-1] for indices in indices_list]

  ellipsis_index = line.find(_NUMPY_OMISSION)
  if ellipsis_index == -1:
    ellipsis_index = len(line)

  batch_pos = 0
  offset_counter = 0
  for match in _NUMBER_REGEX.finditer(line):
    if match.start() > ellipsis_index:
      # Do not attempt to search beyond ellipsis.
      break

    # Resolve every requested offset equal to the current token position, and
    # step over any that can no longer match (e.g. negative offsets), so that
    # one unmatched request does not block the rest of the batch.
    while batch_pos < batch_size and offsets[batch_pos] <= offset_counter:
      if offsets[batch_pos] == offset_counter:
        start_columns[batch_pos] = match.start()
        # Remove the final comma, right bracket, or whitespace.
        end_columns[batch_pos] = match.end() - 1
      batch_pos += 1

    if batch_pos >= batch_size:
      break

    offset_counter += 1

  return start_columns, end_columns
477
478
479def _pad_string_to_length(string, length):
480  return " " * (length - len(string)) + string
481
482
def numeric_summary(tensor):
  """Get a text summary of a numeric tensor.

  This summary is only available for real numeric (int*, float*) and Boolean
  tensors. Complex tensors are not summarized: the sign-based categories used
  below (-inf, -, +, +inf) are not defined for complex values, and
  `np.isneginf` / `np.isposinf` raise `TypeError` on complex input.

  Args:
    tensor: (`numpy.ndarray`) the tensor value object to be summarized.

  Returns:
    The summary text as a `RichTextLines` object. If `tensor` is not a
    non-empty ndarray, or its dtype is not real-numeric or Boolean, a
    single-line `RichTextLines` object containing a message explaining why no
    summary is available is returned instead.
  """

  def _counts_summary(counts, skip_zeros=True, total_count=None):
    """Format (key, value) pairs as a two-row, right-aligned table.

    Args:
      counts: list of (str, value) pairs; keys form the first row and the
        str() of the values the second row.
      skip_zeros: If True, pairs whose value is falsy are dropped.
      total_count: If not None, an extra "total" column holding this value is
        appended.

    Returns:
      A two-line `RichTextLines` object.
    """
    if skip_zeros:
      counts = [(count_key, count_val) for count_key, count_val in counts
                if count_val]
    # All columns share one width: the widest key or value plus one space.
    max_common_len = 0
    for count_key, count_val in counts:
      common_len = max(len(count_key), len(str(count_val))) + 1
      max_common_len = max(common_len, max_common_len)

    key_line = debugger_cli_common.RichLine("|")
    val_line = debugger_cli_common.RichLine("|")
    for count_key, count_val in counts:
      key_line += _pad_string_to_length(count_key, max_common_len)
      val_line += _pad_string_to_length(str(count_val), max_common_len)
    key_line += " |"
    val_line += " |"

    if total_count is not None:
      total_key_str = "total"
      total_val_str = str(total_count)
      # Pad the value by one as well, like the other columns, so that a long
      # total (6+ digits) does not abut the preceding "|".
      total_len = max(len(total_key_str), len(total_val_str)) + 1
      key_line += _pad_string_to_length(total_key_str, total_len) + " |"
      val_line += _pad_string_to_length(total_val_str, total_len) + " |"

    return debugger_cli_common.rich_text_lines_from_rich_line_list(
        [key_line, val_line])

  if not isinstance(tensor, np.ndarray) or not np.size(tensor):
    return debugger_cli_common.RichTextLines([
        "No numeric summary available due to empty tensor."])
  elif (np.issubdtype(tensor.dtype, np.floating) or
        np.issubdtype(tensor.dtype, np.integer)):
    counts = [
        ("nan", np.sum(np.isnan(tensor))),
        ("-inf", np.sum(np.isneginf(tensor))),
        ("-", np.sum(np.logical_and(
            tensor < 0.0, np.logical_not(np.isneginf(tensor))))),
        ("0", np.sum(tensor == 0.0)),
        ("+", np.sum(np.logical_and(
            tensor > 0.0, np.logical_not(np.isposinf(tensor))))),
        ("+inf", np.sum(np.isposinf(tensor)))]
    output = _counts_summary(counts, total_count=np.size(tensor))

    # Statistics are computed over finite elements only.
    valid_array = tensor[
        np.logical_not(np.logical_or(np.isinf(tensor), np.isnan(tensor)))]
    if np.size(valid_array):
      stats = [
          ("min", np.min(valid_array)),
          ("max", np.max(valid_array)),
          ("mean", np.mean(valid_array)),
          ("std", np.std(valid_array))]
      output.extend(_counts_summary(stats, skip_zeros=False))
    return output
  elif tensor.dtype == np.bool_:
    counts = [
        ("False", np.sum(tensor == 0)),
        ("True", np.sum(tensor > 0)),]
    return _counts_summary(counts, total_count=np.size(tensor))
  else:
    # Includes complex dtypes (see docstring).
    return debugger_cli_common.RichTextLines([
        "No numeric summary available due to tensor dtype: %s." % tensor.dtype])
565