# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Format tensors (ndarrays) for screen display and navigation."""
import copy
import re

import numpy as np

from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.lib import debug_data

# Text line that numpy emits in place of elided rows of a summarized array.
_NUMPY_OMISSION = "...,"
# numpy's documented default for the `edgeitems` print option; used only as a
# last-resort fallback when the option cannot be read from numpy itself.
_NUMPY_DEFAULT_EDGE_ITEMS = 3

# Matches a single numeric element in numpy's text representation of an
# array, together with the separator that follows it (whitespace, comma or
# right bracket). The quantifier after the leading digit is `*` rather than
# `+` so that single-digit elements (e.g. the "1" in "array([1, 2])") are
# matched as well; otherwise they would be skipped and the element offsets
# counted within a line would be wrong.
_NUMBER_REGEX = re.compile(r"[-+]?([0-9][-+0-9eE\.]*|nan|inf)(\s|,|\])")

# Annotation key: indices of the first element displayed on a line.
BEGIN_INDICES_KEY = "i0"
# Annotation key: indices of the first element hidden by an omission line.
OMITTED_INDICES_KEY = "omitted"

DEFAULT_TENSOR_ELEMENT_HIGHLIGHT_FONT_ATTR = "bold"


class HighlightOptions(object):
  """Options for highlighting elements of a tensor."""

  def __init__(self,
               criterion,
               description=None,
               font_attr=DEFAULT_TENSOR_ELEMENT_HIGHLIGHT_FONT_ATTR):
    """Constructor of HighlightOptions.

    Args:
      criterion: (callable) A callable of the following signature:
        def to_highlight(X):
          # Args:
          #   X: The tensor to highlight elements in.
          #
          # Returns:
          #   (boolean ndarray) A boolean ndarray of the same shape as X
          #   indicating which elements are to be highlighted (iff True).
        This callable will be used as the argument of np.argwhere() to
        determine which elements of the tensor are to be highlighted.
      description: (str) Description of the highlight criterion embodied by
        criterion. If provided, format_tensor() shows it in parentheses in
        the highlight summary.
      font_attr: (str) Font attribute to be applied to the
        highlighted elements.
    """

    self.criterion = criterion
    self.description = description
    self.font_attr = font_attr


def format_tensor(tensor,
                  tensor_label,
                  include_metadata=False,
                  auxiliary_message=None,
                  include_numeric_summary=False,
                  np_printoptions=None,
                  highlight_options=None):
  """Generate a RichTextLines object showing a tensor in formatted style.

  Args:
    tensor: The tensor to be displayed, as a numpy ndarray or other
      appropriate format (e.g., None representing uninitialized tensors).
    tensor_label: A label for the tensor, as a string. If set to None, will
      suppress the tensor name line in the return value.
    include_metadata: Whether metadata such as dtype and shape are to be
      included in the formatted text.
    auxiliary_message: An auxiliary message to display under the tensor label,
      dtype and shape information lines.
    include_numeric_summary: Whether a text summary of the numeric values (if
      applicable) will be included.
    np_printoptions: A dictionary of keyword arguments that are passed to a
      call of np.set_printoptions() to set the text format for display numpy
      ndarrays. NOTE: this changes numpy's global print options and they are
      not restored afterwards.
    highlight_options: (HighlightOptions) options for highlighting elements
      of the tensor. When given, a highlight summary is appended to the first
      output line.

  Returns:
    A RichTextLines object. Its annotation field has line-by-line markups to
    indicate which indices in the array the first element of each line
    corresponds to.
  """
  lines = []
  font_attr_segs = {}

  if tensor_label is not None:
    lines.append("Tensor \"%s\":" % tensor_label)
    # Column 8 is where the label starts, right after the prefix 'Tensor "'.
    suffix = tensor_label.split(":")[-1]
    if suffix.isdigit():
      # Suffix is a number. Assume it is the output slot index.
      font_attr_segs[0] = [(8, 8 + len(tensor_label), "bold")]
    else:
      # Suffix is not a number. It is auxiliary information such as the debug
      # op type. In this case, highlight the suffix with a different color.
      debug_op_len = len(suffix)
      proper_len = len(tensor_label) - debug_op_len - 1
      font_attr_segs[0] = [
          (8, 8 + proper_len, "bold"),
          (8 + proper_len + 1, 8 + proper_len + 1 + debug_op_len, "yellow")
      ]

  if isinstance(tensor, debug_data.InconvertibleTensorProto):
    if lines:
      lines.append("")
    lines.extend(str(tensor).split("\n"))
    return debugger_cli_common.RichTextLines(lines)
  elif not isinstance(tensor, np.ndarray):
    # If tensor is not a np.ndarray, return simple text-line representation of
    # the object without annotations.
    if lines:
      lines.append("")
    lines.extend(repr(tensor).split("\n"))
    return debugger_cli_common.RichTextLines(lines)

  if include_metadata:
    lines.append("  dtype: %s" % str(tensor.dtype))
    # Strip the Python 2 long-integer suffix (e.g. "(3L, 4L)").
    lines.append("  shape: %s" % str(tensor.shape).replace("L", ""))

  if lines:
    lines.append("")
  formatted = debugger_cli_common.RichTextLines(
      lines, font_attr_segs=font_attr_segs)

  if auxiliary_message:
    formatted.extend(auxiliary_message)

  if include_numeric_summary:
    formatted.append("Numeric summary:")
    formatted.extend(numeric_summary(tensor))
    formatted.append("")

  # Apply custom string formatting options for numpy ndarray.
  if np_printoptions is not None:
    np.set_printoptions(**np_printoptions)

  array_lines = repr(tensor).split("\n")
  # np.bytes_ is the canonical name of the former np.string_ alias, which was
  # removed in NumPy 2.0.
  if tensor.dtype.type is not np.bytes_:
    # Parse array lines to get beginning indices for each line.

    # TODO(cais): Currently, we do not annotate string-type tensors due to
    #   difficulty in escaping sequences. Address this issue.
    annotations = _annotate_ndarray_lines(
        array_lines, tensor, np_printoptions=np_printoptions)
  else:
    annotations = None
  formatted_array = debugger_cli_common.RichTextLines(
      array_lines, annotations=annotations)
  formatted.extend(formatted_array)

  # Perform optional highlighting.
  if highlight_options is not None:
    indices_list = list(np.argwhere(highlight_options.criterion(tensor)))

    total_elements = np.size(tensor)
    num_highlighted = len(indices_list)
    # An empty tensor has nothing to highlight; avoid dividing by zero.
    if total_elements:
      highlighted_percent = num_highlighted / float(total_elements) * 100.0
    else:
      highlighted_percent = 0.0
    highlight_summary = "Highlighted%s: %d of %d element(s) (%.2f%%)" % (
        "(%s)" % highlight_options.description
        if highlight_options.description else "",
        num_highlighted, total_elements, highlighted_percent)

    formatted.lines[0] += " " + highlight_summary

    # Element positions can only be resolved when index annotations were
    # generated (non-string dtype) and the tensor has at least one dimension:
    # np.argwhere() yields empty index lists for a 0-D tensor, which cannot
    # be located in the text.
    if indices_list and annotations is not None and tensor.ndim > 0:
      indices_list = [list(indices) for indices in indices_list]

      are_omitted, rows, start_cols, end_cols = locate_tensor_element(
          formatted, indices_list)
      for is_omitted, row, start_col, end_col in zip(are_omitted, rows,
                                                     start_cols, end_cols):
        if is_omitted or start_col is None or end_col is None:
          continue

        formatted.font_attr_segs.setdefault(row, []).append(
            (start_col, end_col, highlight_options.font_attr))

  return formatted


def _annotate_ndarray_lines(
    array_lines, tensor, np_printoptions=None, offset=0):
  """Generate annotations for line-by-line begin indices of tensor text.

  Parse the numpy-generated text representation of a numpy ndarray to
  determine the indices of the first element of each text line (if any
  element is present in the line).

  For example, given the following multi-line ndarray text representation:
      ["array([[ 0.    ,  0.0625,  0.125 ,  0.1875],",
       "       [ 0.25  ,  0.3125,  0.375 ,  0.4375],",
       "       [ 0.5   ,  0.5625,  0.625 ,  0.6875],",
       "       [ 0.75  ,  0.8125,  0.875 ,  0.9375]])"]
  the generated annotation will be:
      {0: {BEGIN_INDICES_KEY: [0, 0]},
       1: {BEGIN_INDICES_KEY: [1, 0]},
       2: {BEGIN_INDICES_KEY: [2, 0]},
       3: {BEGIN_INDICES_KEY: [3, 0]}}

  Args:
    array_lines: Text lines representing the tensor, as a list of str.
    tensor: The tensor being formatted as string.
    np_printoptions: A dictionary of keyword arguments that are passed to a
      call of np.set_printoptions(). Its "edgeitems" entry, if any, is used to
      resolve omission ("...,") lines; otherwise numpy's current global
      "edgeitems" option is used, since that is what repr() applied.
    offset: Line number offset applied to the line indices in the returned
      annotation.

  Returns:
    An annotation as a dict. Besides the per-line entries keyed by (offset)
    line index, it holds a "tensor_metadata" entry with the tensor's dtype
    and shape.
  """

  if np_printoptions and "edgeitems" in np_printoptions:
    edge_items = np_printoptions["edgeitems"]
  else:
    edge_items = np.get_printoptions().get(
        "edgeitems", _NUMPY_DEFAULT_EDGE_ITEMS)

  annotations = {}

  # Put metadata about the tensor in the annotations["tensor_metadata"].
  annotations["tensor_metadata"] = {
      "dtype": tensor.dtype, "shape": tensor.shape}

  dims = np.shape(tensor)
  ndims = len(dims)
  if ndims == 0:
    # No indices for a 0D tensor.
    return annotations

  curr_indices = [0] * len(dims)
  # Bracket nesting depth at the current position of the text.
  curr_dim = 0
  for i, raw_line in enumerate(array_lines):
    line = raw_line.strip()

    if not line:
      # Skip empty lines, which can appear for >= 3D arrays.
      continue

    if line == _NUMPY_OMISSION:
      annotations[offset + i] = {OMITTED_INDICES_KEY: copy.copy(curr_indices)}
      # numpy shows only the last `edge_items` entries after the omission.
      curr_indices[curr_dim - 1] = dims[curr_dim - 1] - edge_items
    else:
      num_lbrackets = line.count("[")  # TODO(cais): String array escaping.
      num_rbrackets = line.count("]")

      curr_dim += num_lbrackets - num_rbrackets

      annotations[offset + i] = {BEGIN_INDICES_KEY: copy.copy(curr_indices)}
      if num_rbrackets == 0:
        # The innermost row continues on the next line: advance by the number
        # of comma-terminated elements on this line.
        line_content = line[line.rfind("[") + 1:]
        num_elements = line_content.count(",")
        curr_indices[curr_dim - 1] += num_elements
      else:
        # One or more rows closed: move to the next row at the enclosing
        # level and reset all deeper indices.
        if curr_dim > 0:
          curr_indices[curr_dim - 1] += 1
          for k in range(curr_dim, ndims):
            curr_indices[k] = 0

  return annotations


def locate_tensor_element(formatted, indices):
  """Locate a tensor element in formatted text lines, given element indices.

  Given a RichTextLines object representing a tensor and indices of the sought
  element, return the row number at which the element is located (if exists).

  Args:
    formatted: A RichTextLines object containing formatted text lines
      representing the tensor.
    indices: Indices of the sought element, as a list of int or a list of list
      of int. The former case is for a single set of indices to look up,
      whereas the latter case is for looking up a batch of indices sets at once.
      In the latter case, the indices must be in ascending order, or a
      ValueError will be raised.

  Returns:
    1) A boolean indicating whether the element falls into an omitted line.
    2) Row index.
    3) Column start index, i.e., the first column in which the representation
       of the specified tensor starts, if it can be determined. If it cannot
       be determined (e.g., due to ellipsis), None.
    4) Column end index, i.e., the column right after the last column that
       represents the specified tensor. Iff it cannot be determined, None.

  The return values described above are based on a single set of indices to
    look up. In the case of batch mode (multiple sets of indices), the return
    values will be lists of the types described above.

  Raises:
    AttributeError: If:
      Input argument "formatted" does not have the required annotations.
    ValueError: If:
      1) Indices do not match the dimensions of the tensor, or
      2) Indices exceed sizes of the tensor, or
      3) Indices contain negative value(s).
      4) If in batch mode, and if not all sets of indices are in ascending
         order.
  """

  if isinstance(indices[0], list):
    indices_list = indices
    input_batch = True
  else:
    indices_list = [indices]
    input_batch = False

  # Check that tensor_metadata is available.
  if "tensor_metadata" not in formatted.annotations:
    raise AttributeError("tensor_metadata is not available in annotations.")

  # Sanity check on input argument.
  _validate_indices_list(indices_list, formatted)

  dims = formatted.annotations["tensor_metadata"]["shape"]
  batch_size = len(indices_list)
  lines = formatted.lines
  annot = formatted.annotations
  # Row number, text and first-element indices of the most recently visited
  # annotated line.
  prev_r = 0
  prev_line = ""
  prev_indices = [0] * len(dims)

  # Initialize return values
  are_omitted = [None] * batch_size
  row_indices = [None] * batch_size
  start_columns = [None] * batch_size
  end_columns = [None] * batch_size

  batch_pos = 0  # Current position in the batch.

  for r, line in enumerate(lines):
    if r not in annot:
      continue

    if BEGIN_INDICES_KEY in annot[r]:
      indices_key = BEGIN_INDICES_KEY
    elif OMITTED_INDICES_KEY in annot[r]:
      indices_key = OMITTED_INDICES_KEY
    else:
      # The row carries an annotation without element indices (e.g. merged in
      # from other text); it cannot bound any tensor element.
      continue

    # Sought elements at or after the previous line's first element but
    # before this line's first element must be on the previous line.
    matching_indices_list = [
        ind for ind in indices_list[batch_pos:]
        if prev_indices <= ind < annot[r][indices_key]
    ]

    if matching_indices_list:
      num_matches = len(matching_indices_list)

      match_start_columns, match_end_columns = _locate_elements_in_line(
          prev_line, matching_indices_list, prev_indices)

      start_columns[batch_pos:batch_pos + num_matches] = match_start_columns
      end_columns[batch_pos:batch_pos + num_matches] = match_end_columns
      are_omitted[batch_pos:batch_pos + num_matches] = [
          OMITTED_INDICES_KEY in annot[prev_r]
      ] * num_matches
      row_indices[batch_pos:batch_pos + num_matches] = [prev_r] * num_matches

      batch_pos += num_matches
      if batch_pos >= batch_size:
        break

    prev_r = r
    prev_line = line
    prev_indices = annot[r][indices_key]

  # Whatever remains lies on or after the last annotated line.
  if batch_pos < batch_size:
    matching_indices_list = indices_list[batch_pos:]
    num_matches = len(matching_indices_list)

    match_start_columns, match_end_columns = _locate_elements_in_line(
        prev_line, matching_indices_list, prev_indices)

    start_columns[batch_pos:batch_pos + num_matches] = match_start_columns
    end_columns[batch_pos:batch_pos + num_matches] = match_end_columns
    are_omitted[batch_pos:batch_pos + num_matches] = [
        OMITTED_INDICES_KEY in annot[prev_r]
    ] * num_matches
    row_indices[batch_pos:batch_pos + num_matches] = [prev_r] * num_matches

  if input_batch:
    return are_omitted, row_indices, start_columns, end_columns
  else:
    return are_omitted[0], row_indices[0], start_columns[0], end_columns[0]


def _validate_indices_list(indices_list, formatted):
  """Check a list of index sets against the tensor's annotated shape.

  Args:
    indices_list: (list of list of int) index sets to validate.
    formatted: RichTextLines whose annotations hold "tensor_metadata".

  Raises:
    ValueError: On rank mismatch, out-of-range or negative indices, or when
      the index sets are not in ascending order.
  """
  dims = formatted.annotations["tensor_metadata"]["shape"]
  prev_ind = None
  for ind in indices_list:
    # Check indices match tensor dimensions.
    if len(ind) != len(dims):
      raise ValueError("Dimensions mismatch: requested: %d; actual: %d" %
                       (len(ind), len(dims)))

    # Check indices is within size limits.
    for req_idx, siz in zip(ind, dims):
      if req_idx >= siz:
        raise ValueError("Indices exceed tensor dimensions.")
      if req_idx < 0:
        raise ValueError("Indices contain negative value(s).")

    # Check indices are in ascending order.
    if prev_ind and ind < prev_ind:
      raise ValueError("Input indices sets are not in ascending order.")

    prev_ind = ind


def _locate_elements_in_line(line, indices_list, ref_indices):
  """Determine the start and end indices of an element in a line.

  Args:
    line: (str) the line in which the element is to be sought.
    indices_list: (list of list of int) list of indices of the element to
      search for. Assumes that the indices in the batch are unique and sorted
      in ascending order.
    ref_indices: (list of int) reference indices, i.e., the indices of the
      first element represented in the line.

  Returns:
    start_columns: (list of int) start column indices, if found. If not found,
      None.
    end_columns: (list of int) end column indices, if found. If not found,
      None.
    If found, the element is represented in the left-closed-right-open interval
      [start_column, end_column).
  """

  batch_size = len(indices_list)
  # Position of each sought element among the numbers printed on this line.
  offsets = [indices[-1] - ref_indices[-1] for indices in indices_list]

  start_columns = [None] * batch_size
  end_columns = [None] * batch_size

  if _NUMPY_OMISSION in line:
    ellipsis_index = line.find(_NUMPY_OMISSION)
  else:
    ellipsis_index = len(line)

  batch_pos = 0

  offset_counter = 0
  for match in _NUMBER_REGEX.finditer(line):
    if match.start() > ellipsis_index:
      # Do not attempt to search beyond ellipsis.
      break

    if offset_counter == offsets[batch_pos]:
      start_columns[batch_pos] = match.start()
      # Remove the final comma, right bracket, or whitespace.
      end_columns[batch_pos] = match.end() - 1

      batch_pos += 1
      if batch_pos >= batch_size:
        break

    offset_counter += 1

  return start_columns, end_columns


def _pad_string_to_length(string, length):
  """Right-align `string` in a field of width `length` (never truncates)."""
  return string.rjust(length)


def numeric_summary(tensor):
  """Get a text summary of a numeric tensor.

  This summary is only available for numeric (int*, float*, complex*) and
  Boolean tensors.

  Args:
    tensor: (`numpy.ndarray`) the tensor value object to be summarized.

  Returns:
    The summary text as a `RichTextLines` object. If the type of `tensor` is not
    numeric or Boolean, a single-line `RichTextLines` object containing a
    warning message will reflect that.
  """

  def _counts_summary(counts, skip_zeros=True, total_count=None):
    """Format (key, value) pairs as a two-row table.

    Args:
      counts: list of (str, value) pairs, one per table column.
      skip_zeros: whether to drop columns whose value is falsy (e.g. 0).
      total_count: if not None, appended as an extra "total" column.

    Returns:
      A RichTextLines object with a key row and a value row.
    """
    if skip_zeros:
      counts = [(count_key, count_val) for count_key, count_val in counts
                if count_val]
    # All regular columns share one width: the widest key or value plus one
    # leading space.
    max_common_len = 0
    for count_key, count_val in counts:
      count_val_str = str(count_val)
      common_len = max(len(count_key) + 1, len(count_val_str) + 1)
      max_common_len = max(common_len, max_common_len)

    key_line = debugger_cli_common.RichLine("|")
    val_line = debugger_cli_common.RichLine("|")
    for count_key, count_val in counts:
      count_val_str = str(count_val)
      key_line += _pad_string_to_length(count_key, max_common_len)
      val_line += _pad_string_to_length(count_val_str, max_common_len)
    key_line += " |"
    val_line += " |"

    if total_count is not None:
      total_key_str = "total"
      total_val_str = str(total_count)
      # Reserve one leading space for the value as well, like the other
      # columns, so a long total does not abut the preceding "|".
      max_common_len = max(len(total_key_str) + 1, len(total_val_str) + 1)
      total_key_str = _pad_string_to_length(total_key_str, max_common_len)
      total_val_str = _pad_string_to_length(total_val_str, max_common_len)
      key_line += total_key_str + " |"
      val_line += total_val_str + " |"

    return debugger_cli_common.rich_text_lines_from_rich_line_list(
        [key_line, val_line])

  if not isinstance(tensor, np.ndarray) or not np.size(tensor):
    return debugger_cli_common.RichTextLines([
        "No numeric summary available due to empty tensor."])
  elif (np.issubdtype(tensor.dtype, np.floating) or
        np.issubdtype(tensor.dtype, np.complexfloating) or
        np.issubdtype(tensor.dtype, np.integer)):
    # NOTE(review): np.isneginf/np.isposinf appear to reject complex input in
    # NumPy, so complex tensors may raise here -- confirm.
    counts = [
        ("nan", np.sum(np.isnan(tensor))),
        ("-inf", np.sum(np.isneginf(tensor))),
        ("-", np.sum(np.logical_and(
            tensor < 0.0, np.logical_not(np.isneginf(tensor))))),
        ("0", np.sum(tensor == 0.0)),
        ("+", np.sum(np.logical_and(
            tensor > 0.0, np.logical_not(np.isposinf(tensor))))),
        ("+inf", np.sum(np.isposinf(tensor)))]
    output = _counts_summary(counts, total_count=np.size(tensor))

    # Statistics are computed over the finite elements only.
    valid_array = tensor[
        np.logical_not(np.logical_or(np.isinf(tensor), np.isnan(tensor)))]
    if np.size(valid_array):
      stats = [
          ("min", np.min(valid_array)),
          ("max", np.max(valid_array)),
          ("mean", np.mean(valid_array)),
          ("std", np.std(valid_array))]
      output.extend(_counts_summary(stats, skip_zeros=False))
    return output
  elif tensor.dtype == np.bool_:
    counts = [
        ("False", np.sum(tensor == 0)),
        ("True", np.sum(tensor > 0)),]
    return _counts_summary(counts, total_count=np.size(tensor))
  else:
    return debugger_cli_common.RichTextLines([
        "No numeric summary available due to tensor dtype: %s." % tensor.dtype])