# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shared functions and classes for tfdbg command-line interface."""
import math

import numpy as np

from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import tensor_format
from tensorflow.python.debug.lib import common
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile

RL = debugger_cli_common.RichLine

# Default threshold number of elements above which ellipses will be used
# when printing the value of the tensor.
DEFAULT_NDARRAY_DISPLAY_THRESHOLD = 2000

COLOR_BLACK = "black"
COLOR_BLUE = "blue"
COLOR_CYAN = "cyan"
COLOR_GRAY = "gray"
COLOR_GREEN = "green"
COLOR_MAGENTA = "magenta"
COLOR_RED = "red"
COLOR_WHITE = "white"
COLOR_YELLOW = "yellow"

TIME_UNIT_US = "us"
TIME_UNIT_MS = "ms"
TIME_UNIT_S = "s"
# Ordered from smallest to largest; consecutive units differ by a factor of
# 1000, which time_to_readable_str() relies on.
TIME_UNITS = [TIME_UNIT_US, TIME_UNIT_MS, TIME_UNIT_S]


def bytes_to_readable_str(num_bytes, include_b=False):
  """Generate a human-readable string representing number of bytes.

  The units B, kB, MB and GB are used. Values below 1024 are printed as a
  plain integer; larger values are printed with two decimals and a binary
  (1024-based) prefix.

  Args:
    num_bytes: (`int` or None) Number of bytes.
    include_b: (`bool`) Include the letter B at the end of the unit.

  Returns:
    (`str`) A string representing the number of bytes in a human-readable way,
      including a unit at the end. If `num_bytes` is None, the string "None"
      is returned regardless of `include_b`.
  """

  if num_bytes is None:
    return str(num_bytes)

  if num_bytes < 1024:
    result = "%d" % num_bytes
  elif num_bytes < 1048576:
    result = "%.2fk" % (num_bytes / 1024.0)
  elif num_bytes < 1073741824:
    result = "%.2fM" % (num_bytes / 1048576.0)
  else:
    result = "%.2fG" % (num_bytes / 1073741824.0)

  return result + "B" if include_b else result


def time_to_readable_str(value_us, force_time_unit=None):
  """Convert time value to human-readable string.

  Args:
    value_us: time value in microseconds. A falsy value (0 or None) is
      rendered as "0" without a unit.
    force_time_unit: force the output to use the specified time unit. Must be
      in TIME_UNITS.

  Returns:
    Human-readable string representation of the time value.

  Raises:
    ValueError: if force_time_unit value is not in TIME_UNITS (only checked
      when value_us is non-zero).
  """
  if not value_us:
    return "0"

  if force_time_unit:
    if force_time_unit not in TIME_UNITS:
      raise ValueError("Invalid time unit: %s" % force_time_unit)
    order = TIME_UNITS.index(force_time_unit)
    return "{:.10g}{}".format(
        value_us / math.pow(10.0, 3 * order), force_time_unit)

  # math.log10() is exact at powers of ten, whereas math.log(x, 10) is not
  # (math.log(1000, 10) == 2.9999999999999996), which used to render 1000us as
  # "1e+03us" instead of "1ms". abs() keeps negative durations from raising a
  # math domain error.
  order = int(math.log10(abs(value_us)) / 3)
  # Clamp to the available units. Without the lower bound, sub-nanosecond
  # values yield a negative order and TIME_UNITS[-1] would silently pick "s".
  order = max(0, min(len(TIME_UNITS) - 1, order))
  return "{:.3g}{}".format(
      value_us / math.pow(10.0, 3 * order), TIME_UNITS[order])


def parse_ranges_highlight(ranges_string):
  """Process ranges highlight string.

  Args:
    ranges_string: (str) A string representing a numerical range or a list of
      numerical ranges. See the help info of the -r flag of the print_tensor
      command for more details.

  Returns:
    An instance of tensor_format.HighlightOptions, if ranges_string is a valid
    representation of a range or a list of ranges. None if ranges_string is
    empty or None.
  """

  if not ranges_string:
    return None

  ranges = command_parser.parse_ranges(ranges_string)

  def ranges_filter(x):
    """Return a boolean mask of the elements of x inside any of the ranges."""
    mask = np.zeros(x.shape, dtype=bool)
    for range_start, range_end in ranges:
      # Both ends of each range are inclusive.
      mask = np.logical_or(
          mask, np.logical_and(x >= range_start, x <= range_end))
    return mask

  return tensor_format.HighlightOptions(
      ranges_filter, description=ranges_string)


def numpy_printoptions_from_screen_info(screen_info):
  """Derive numpy print options from screen information.

  Args:
    screen_info: (dict or None) Screen info. If it contains the key "cols",
      that value is used as the numpy "linewidth" print option.

  Returns:
    (dict) Numpy print options: {"linewidth": <cols>} if the column count is
      available, otherwise an empty dict.
  """
  if screen_info and "cols" in screen_info:
    return {"linewidth": screen_info["cols"]}
  return {}


def format_tensor(tensor,
                  tensor_name,
                  np_printoptions,
                  print_all=False,
                  tensor_slicing=None,
                  highlight_options=None,
                  include_numeric_summary=False,
                  write_path=None):
  """Generate formatted str to represent a tensor or its slices.

  Args:
    tensor: (numpy ndarray) The tensor value.
    tensor_name: (str) Name of the tensor, e.g., the tensor's debug watch key.
    np_printoptions: (dict or None) Numpy tensor formatting options. The dict
      is not modified; the "threshold" option is set on an internal copy.
    print_all: (bool) Whether the tensor is to be displayed in its entirety,
      instead of printing ellipses, even if its number of elements exceeds
      the default numpy display threshold.
      (Note: Even if this is set to true, the screen output can still be cut
       off by the UI frontend if it consist of more lines than the frontend
       can handle.)
    tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
      None, no slicing will be performed on the tensor.
    highlight_options: (tensor_format.HighlightOptions) options to highlight
      elements of the tensor. See the doc of tensor_format.format_tensor()
      for more details.
    include_numeric_summary: Whether a text summary of the numeric values (if
      applicable) will be included.
    write_path: A path to save the tensor value (after any slicing) to
      (optional). `numpy.save()` is used to save the value.

  Returns:
    An instance of `debugger_cli_common.RichTextLines` representing the
    (potentially sliced) tensor.
  """

  if tensor_slicing:
    # Validate the indexing.
    value = command_parser.evaluate_tensor_slice(tensor, tensor_slicing)
    sliced_name = tensor_name + tensor_slicing
  else:
    value = tensor
    sliced_name = tensor_name

  auxiliary_message = None
  if write_path:
    with gfile.Open(write_path, "wb") as output_file:
      np.save(output_file, value)
    # Report the destination and the size of the file that was just written.
    line = RL("Saved value to: ")
    line += RL(write_path, font_attr="bold")
    line += " (%sB)" % bytes_to_readable_str(gfile.Stat(write_path).length)
    auxiliary_message = debugger_cli_common.rich_text_lines_from_rich_line_list(
        [line, RL("")])

  # Work on a private copy so that the caller's dict is not mutated by the
  # "threshold" override below (and so that None is tolerated).
  printoptions = dict(np_printoptions) if np_printoptions else {}
  if print_all:
    printoptions["threshold"] = value.size
  else:
    printoptions["threshold"] = DEFAULT_NDARRAY_DISPLAY_THRESHOLD

  return tensor_format.format_tensor(
      value,
      sliced_name,
      include_metadata=True,
      include_numeric_summary=include_numeric_summary,
      auxiliary_message=auxiliary_message,
      np_printoptions=printoptions,
      highlight_options=highlight_options)


def error(msg):
  """Generate a RichTextLines output for error.

  Args:
    msg: (str) The error message.

  Returns:
    (debugger_cli_common.RichTextLines) A representation of the error message
      for screen output: a single red line prefixed with "ERROR: ".
  """

  return debugger_cli_common.rich_text_lines_from_rich_line_list([
      RL("ERROR: " + msg, COLOR_RED)])


def _recommend_command(command, description, indent=2, create_link=False):
  """Generate a RichTextLines object that describes a recommended command.

  Args:
    command: (str) The command to recommend.
    description: (str) A description of what the command does.
    indent: (int) How many spaces to indent in the beginning.
    create_link: (bool) Whether a command link is to be applied to the command
      string.

  Returns:
    (RichTextLines) Formatted text (with font attributes) for recommending the
      command: one line with the bold command followed by a colon, and one
      further-indented line with the description.
  """

  indent_str = " " * indent

  if create_link:
    font_attr = [debugger_cli_common.MenuItem("", command), "bold"]
  else:
    font_attr = "bold"

  lines = [
      RL(indent_str) + RL(command, font_attr) + ":",
      indent_str + " " + description,
  ]

  return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)


def get_tfdbg_logo():
  """Make an ASCII representation of the tfdbg logo."""

  # Every row is padded to the same width so that the letters line up.
  lines = [
      "",
      "TTTTTT FFFF DDD  BBBB   GGG ",
      "  TT   F    D  D B   B G    ",
      "  TT   FFF  D  D BBBB  G  GG",
      "  TT   F    D  D B   B G   G",
      "  TT   F    DDD  BBBB   GGG ",
      "",
  ]
  return debugger_cli_common.RichTextLines(lines)


_HORIZONTAL_BAR = "======================================"


def get_run_start_intro(run_call_count,
                        fetches,
                        feed_dict,
                        tensor_filters,
                        is_callable_runner=False):
  """Generate formatted intro for run-start UI.

  Args:
    run_call_count: (int) Run call counter.
    fetches: Fetches of the `Session.run()` call. See doc of `Session.run()`
      for more details.
    feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()`
      for more details.
    tensor_filters: (dict) A dict from tensor-filter name to tensor-filter
      callable.
    is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.

  Returns:
    (RichTextLines) Formatted intro message about the `Session.run()` call.
      The main menu ("run" and "exit") is attached to its annotations under
      `debugger_cli_common.MAIN_MENU_KEY`.
  """

  fetch_lines = common.get_flattened_names(fetches)

  if not feed_dict:
    feed_dict_lines = [RL(" (Empty)")]
  else:
    feed_dict_lines = []
    for feed_key in feed_dict:
      feed_key_name = common.get_graph_element_name(feed_key)
      feed_dict_line = RL(" ")
      # Surround the name string with quotes, because feed_key_name may contain
      # spaces in some cases, e.g., SparseTensors.
      feed_dict_line += RL(
          feed_key_name,
          debugger_cli_common.MenuItem(None, "pf '%s'" % feed_key_name))
      feed_dict_lines.append(feed_dict_line)
  feed_dict_lines = debugger_cli_common.rich_text_lines_from_rich_line_list(
      feed_dict_lines)

  out = debugger_cli_common.RichTextLines(_HORIZONTAL_BAR)
  if is_callable_runner:
    out.append("Running a runner returned by Session.make_callable()")
  else:
    out.append("Session.run() call #%d:" % run_call_count)
    out.append("")
    out.append("Fetch(es):")
    out.extend(debugger_cli_common.RichTextLines(
        [" " + line for line in fetch_lines]))
    out.append("")
    out.append("Feed dict:")
    out.extend(feed_dict_lines)
  out.append(_HORIZONTAL_BAR)
  out.append("")
  out.append("Select one of the following commands to proceed ---->")

  # (command, description, create_link). Commands that need a user-supplied
  # argument are not turned into clickable links.
  recommended_commands = (
      ("run",
       "Execute the run() call with debug tensor-watching",
       True),
      ("run -n",
       "Execute the run() call without debug tensor-watching",
       True),
      ("run -t <T>",
       "Execute run() calls (T - 1) times without debugging, then "
       "execute run() once more with debugging and drop back to the CLI",
       False),
      ("run -f <filter_name>",
       "Keep executing run() calls until a dumped tensor passes a given, "
       "registered filter (conditional breakpoint mode)",
       False),
  )
  for command, description, create_link in recommended_commands:
    out.extend(
        _recommend_command(command, description, create_link=create_link))

  more_lines = [" Registered filter(s):"]
  if tensor_filters:
    for filter_name in tensor_filters:
      # Each registered filter is a link that runs until the filter passes.
      command_menu_node = debugger_cli_common.MenuItem(
          "", "run -f %s" % filter_name)
      more_lines.append(RL(" * ") + RL(filter_name, command_menu_node))
  else:
    more_lines.append(" (None)")

  out.extend(
      debugger_cli_common.rich_text_lines_from_rich_line_list(more_lines))

  out.append("")

  out.append_rich_line(RL("For more details, see ") +
                       RL("help.", debugger_cli_common.MenuItem("", "help")) +
                       ".")
  out.append("")

  # Make main menu for the run-start intro.
  menu = debugger_cli_common.Menu()
  menu.append(debugger_cli_common.MenuItem("run", "run"))
  menu.append(debugger_cli_common.MenuItem("exit", "exit"))
  out.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu

  return out


def get_run_short_description(run_call_count,
                              fetches,
                              feed_dict,
                              is_callable_runner=False):
  """Get a short description of the run() call.

  Args:
    run_call_count: (int) Run call counter.
    fetches: Fetches of the `Session.run()` call. See doc of `Session.run()`
      for more details.
    feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()`
      for more details.
    is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.

  Returns:
    (str) A short description of the run() call, including information about
      the fetche(s) and feed(s).
  """
  if is_callable_runner:
    return "runner from make_callable()"

  description = "run #%d: " % run_call_count

  if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)):
    description += "1 fetch (%s); " % common.get_graph_element_name(fetches)
  else:
    # Could be (nested) list, tuple, dict or namedtuple.
    num_fetches = len(common.get_flattened_names(fetches))
    if num_fetches > 1:
      description += "%d fetches; " % num_fetches
    else:
      description += "%d fetch; " % num_fetches

  if not feed_dict:
    description += "0 feeds"
  elif len(feed_dict) == 1:
    key = next(iter(feed_dict))
    # String keys are shown verbatim; graph elements are shown by name.
    if isinstance(key, str) or not hasattr(key, "name"):
      feed_name = key
    else:
      feed_name = key.name
    description += "1 feed (%s)" % feed_name
  else:
    description += "%d feeds" % len(feed_dict)

  return description


def get_error_intro(tf_error):
  """Generate formatted intro for TensorFlow run-time error.

  Args:
    tf_error: (errors.OpError) TensorFlow run-time error object.

  Returns:
    (RichTextLines) Formatted intro message about the run-time OpError, with
      sample commands for debugging.
  """

  # The failing op's name is only available if the error carries an op that
  # has a name attribute.
  if hasattr(tf_error, "op") and hasattr(tf_error.op, "name"):
    op_name = tf_error.op.name
  else:
    op_name = None

  intro_lines = [
      "--------------------------------------",
      RL("!!! An error occurred during the run !!!", "blink"),
      "",
  ]

  out = debugger_cli_common.rich_text_lines_from_rich_line_list(intro_lines)

  if op_name is not None:
    out.extend(debugger_cli_common.RichTextLines(
        ["You may use the following commands to debug:"]))
    out.extend(
        _recommend_command("ni -a -d -t %s" % op_name,
                           "Inspect information about the failing op.",
                           create_link=True))
    out.extend(
        _recommend_command("li -r %s" % op_name,
                           "List inputs to the failing op, recursively.",
                           create_link=True))

    out.extend(
        _recommend_command(
            "lt",
            "List all tensors dumped during the failing run() call.",
            create_link=True))
  else:
    out.extend(debugger_cli_common.RichTextLines([
        "WARNING: Cannot determine the name of the op that caused the error."]))

  more_lines = [
      "",
      "Op name: %s" % op_name,
      "Error type: " + str(type(tf_error)),
      "",
      "Details:",
      str(tf_error),
      "",
      "--------------------------------------",
      "",
  ]

  out.extend(debugger_cli_common.RichTextLines(more_lines))

  return out