# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Debugger Wrapper Session Consisting of a Local Curses-based CLI."""
import argparse
import os
import sys
import tempfile

# Google-internal import(s).
from tensorflow.python.debug.cli import analyzer_cli
from tensorflow.python.debug.cli import cli_config
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import profile_analyzer_cli
from tensorflow.python.debug.cli import ui_factory
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import framework
from tensorflow.python.lib.io import file_io


_DUMP_ROOT_PREFIX = "tfdbg_"


# TODO(donglin) Remove use_random_config_path after b/137652456 is fixed.
class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
  """Concrete subclass of BaseDebugWrapperSession implementing a local CLI.

  This class has all the methods that a `session.Session` object has, in order
  to support debugging with minimal code changes. Invoking its `run()` method
  will launch the command-line interface (CLI) of tfdbg.
  """

  def __init__(self,
               sess,
               dump_root=None,
               log_usage=True,
               ui_type="curses",
               thread_name_filter=None,
               config_file_path=False):
    """Constructor of LocalCLIDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      dump_root: (`str`) optional path to the dump root directory. Must be a
        directory that does not exist or an empty directory. If the directory
        does not exist, it will be created by the debugger core during debug
        `run()` calls and removed afterwards. If `None`, the debug dumps will
        be at tfdbg_<random_string> under the system temp directory.
      log_usage: (`bool`) whether the usage of this class is to be logged.
      ui_type: (`str`) requested UI type. Currently supported:
        (curses | readline)
      thread_name_filter: Regular-expression white list for thread name. See
        the doc of `BaseDebugWrapperSession` for details.
      config_file_path: Optional override to the default configuration file
        path, which is at `${HOME}/.tfdbg_config`.

    Raises:
      ValueError: If dump_root is an existing and non-empty directory or if
        dump_root is a file.
    """

    if log_usage:
      pass  # No logging for open-source.

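    # Delegate the generic session-wrapping setup (bookkeeping and
    # thread-name filtering) to the base wrapper class.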
    framework.BaseDebugWrapperSession.__init__(
        self, sess, thread_name_filter=thread_name_filter)

    if not dump_root:
      self._dump_root = tempfile.mkdtemp(prefix=_DUMP_ROOT_PREFIX)
    else:
      dump_root = os.path.expanduser(dump_root)
      if os.path.isfile(dump_root):
        raise ValueError("dump_root path points to a file: %s" % dump_root)
      elif os.path.isdir(dump_root) and os.listdir(dump_root):
        raise ValueError("dump_root path points to a non-empty directory: %s" %
                         dump_root)

      self._dump_root = dump_root

    self._initialize_argparsers()

    # Registered tensor filters.
    self._tensor_filters = {}
    # Register frequently-used filter(s).
    self.add_tensor_filter("has_inf_or_nan", debug_data.has_inf_or_nan)

    # Below are the state variables of this wrapper object.
    # _active_tensor_filter: what (if any) tensor filter is in effect. If such
    #   a filter is in effect, this object will call run() method of the
    #   underlying TensorFlow Session object until the filter passes. This is
    #   activated by the "-f" flag of the "run" command.
    # _run_through_times: keeps track of how many times the wrapper needs to
    #   run through without stopping at the run-end CLI. It is activated by the
    #   "-t" option of the "run" command.
    # _skip_debug: keeps track of whether the current run should be executed
    #   without debugging. It is activated by the "-n" option of the "run"
    #   command.
    #
    # _run_start_response: keeps track of what OnRunStartResponse the wrapper
    #   should return at the next run-start callback. If this information is
    #   unavailable (i.e., is None), the run-start CLI will be launched to ask
    #   the user. This is the case, e.g., right before the first run starts.
    self._active_tensor_filter = None
    self._active_filter_exclude_node_names = None
    self._active_tensor_filter_run_start_response = None
    self._run_through_times = 1
    self._skip_debug = False
    self._run_start_response = None
    self._is_run_start = True
    self._ui_type = ui_type
    self._config = None
    if config_file_path:
      self._config = cli_config.CLIConfig(config_file_path=config_file_path)

  def _is_disk_usage_reset_each_run(self):
    # The dumped tensors are all cleaned up after every Session.run
    # in a command-line wrapper.
    return True

  def _initialize_argparsers(self):
    self._argparsers = {}
    ap = argparse.ArgumentParser(
        description="Run through, with or without debug tensor watching.",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "-t",
        "--times",
        dest="times",
        type=int,
        default=1,
        help="How many Session.run() calls to proceed with.")
    ap.add_argument(
        "-n",
        "--no_debug",
        dest="no_debug",
        action="store_true",
        help="Run through without debug tensor watching.")
    ap.add_argument(
        "-f",
        "--till_filter_pass",
        dest="till_filter_pass",
        type=str,
        default="",
        help="Run until a tensor in the graph passes the specified filter.")
    ap.add_argument(
        "-fenn",
        "--filter_exclude_node_names",
        dest="filter_exclude_node_names",
        type=str,
        default="",
        help="When applying the tensor filter, exclude nodes with names "
        "matching the regular expression. Applicable only if "
        "--till_filter_pass (-f) is used.")
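    # The three regular-expression flags below restrict which nodes, op types
    # and tensor dtypes get debug watches attached during the run.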
    ap.add_argument(
        "--node_name_filter",
        dest="node_name_filter",
        type=str,
        default="",
        help="Regular-expression filter for node names to be watched in the "
        "run, e.g., loss, reshape.*")
    ap.add_argument(
        "--op_type_filter",
        dest="op_type_filter",
        type=str,
        default="",
        help="Regular-expression filter for op type to be watched in the run, "
        "e.g., (MatMul|Add), Variable.*")
    ap.add_argument(
        "--tensor_dtype_filter",
        dest="tensor_dtype_filter",
        type=str,
        default="",
        help="Regular-expression filter for tensor dtype to be watched in the "
        "run, e.g., (float32|float64), int.*")
    ap.add_argument(
        "-p",
        "--profile",
        dest="profile",
        action="store_true",
        help="Run and profile TensorFlow graph execution.")
    self._argparsers["run"] = ap

    ap = argparse.ArgumentParser(
        description="Display information about this Session.run() call.",
        usage=argparse.SUPPRESS)
    self._argparsers["run_info"] = ap

    self._argparsers["print_feed"] = command_parser.get_print_tensor_argparser(
        "Print the value of a feed in feed_dict.")

  def add_tensor_filter(self, filter_name, tensor_filter):
    """Add a tensor filter.

    Args:
      filter_name: (`str`) name of the filter.
      tensor_filter: (`callable`) the filter callable. See the doc string of
        `DebugDumpDir.find()` for more details about its signature.
    """

    self._tensor_filters[filter_name] = tensor_filter

  def on_session_init(self, request):
    """Overrides on-session-init callback.

    Args:
      request: An instance of `OnSessionInitRequest`.

    Returns:
      An instance of `OnSessionInitResponse`.
    """

    return framework.OnSessionInitResponse(
        framework.OnSessionInitAction.PROCEED)

  def on_run_start(self, request):
    """Overrides on-run-start callback.

    Args:
      request: An instance of `OnRunStartRequest`.

    Returns:
      An instance of `OnRunStartResponse`.
    """
    self._is_run_start = True
    self._update_run_calls_state(
        request.run_call_count, request.fetches, request.feed_dict,
        is_callable_runner=request.is_callable_runner)

    if self._active_tensor_filter:
      # If we are running until a filter passes, we just need to keep running
      # with the previous `OnRunStartResponse`.
      return self._active_tensor_filter_run_start_response

    self._exit_if_requested_by_user()

    if self._run_call_count > 1 and not self._skip_debug:
      if self._run_through_times > 0:
        # Just run through without debugging.
        return framework.OnRunStartResponse(
            framework.OnRunStartAction.NON_DEBUG_RUN, [])
      elif self._run_through_times == 0:
        # It is the run at which the run-end CLI will be launched: activate
        # debugging.
        return (self._run_start_response or
                framework.OnRunStartResponse(
                    framework.OnRunStartAction.DEBUG_RUN,
                    self._get_run_debug_urls()))

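    # No run-start response has been prepared yet (e.g., this is the first
    # run): bring up the run-start CLI to obtain the user's directive.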
    if self._run_start_response is None:
      self._prep_cli_for_run_start()

      self._run_start_response = self._launch_cli()
      if self._active_tensor_filter:
        self._active_tensor_filter_run_start_response = self._run_start_response
      if self._run_through_times > 1:
        self._run_through_times -= 1

    self._exit_if_requested_by_user()
    return self._run_start_response

  def _exit_if_requested_by_user(self):
    if self._run_start_response == debugger_cli_common.EXPLICIT_USER_EXIT:
      # Explicit user "exit" command leads to sys.exit(1).
      print(
          "Note: user exited from debugger CLI: Calling sys.exit(1).",
          file=sys.stderr)
      sys.exit(1)

  def _prep_cli_for_run_start(self):
    """Prepare (but not launch) the CLI for run-start."""
    self._run_cli = ui_factory.get_ui(self._ui_type, config=self._config)

    help_intro = debugger_cli_common.RichTextLines([])
    if self._run_call_count == 1:
      # Show logo at the onset of the first run.
      help_intro.extend(cli_shared.get_tfdbg_logo())
      help_intro.extend(debugger_cli_common.get_tensorflow_version_lines())
    help_intro.extend(debugger_cli_common.RichTextLines("Upcoming run:"))
    help_intro.extend(self._run_info)

    self._run_cli.set_help_intro(help_intro)

    # Create initial screen output detailing the run.
    self._title = "run-start: " + self._run_description
    self._init_command = "run_info"
    self._title_color = "blue_on_white"

  def on_run_end(self, request):
    """Overrides on-run-end callback.

    Actions taken:
      1) Load the debug dump.
      2) Bring up the Analyzer CLI.

    Args:
      request: An instance of OnRunEndRequest.

    Returns:
      An instance of OnRunEndResponse.
    """

    self._is_run_start = False
    if request.performed_action == framework.OnRunStartAction.DEBUG_RUN:
      partition_graphs = None
      if request.run_metadata and request.run_metadata.partition_graphs:
        partition_graphs = request.run_metadata.partition_graphs
      elif request.client_graph_def:
        partition_graphs = [request.client_graph_def]

      if request.tf_error and not os.path.isdir(self._dump_root):
        # It is possible that the dump root may not exist due to errors that
        # have occurred prior to graph execution (e.g., invalid device
        # assignments), in which case we will just raise the exception as the
        # unwrapped Session does.
        raise request.tf_error

      debug_dump = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=partition_graphs)
      debug_dump.set_python_graph(self._sess.graph)

      passed_filter = None
      passed_filter_exclude_node_names = None
      if self._active_tensor_filter:
        if not debug_dump.find(
            self._tensor_filters[self._active_tensor_filter], first_n=1,
            exclude_node_names=self._active_filter_exclude_node_names):
          # No dumped tensor passes the filter in this run. Clean up the dump
          # directory and move on.
          self._remove_dump_root()
          return framework.OnRunEndResponse()
        else:
          # Some dumped tensor(s) from this run passed the filter.
          passed_filter = self._active_tensor_filter
          passed_filter_exclude_node_names = (
              self._active_filter_exclude_node_names)
          self._active_tensor_filter = None
          self._active_filter_exclude_node_names = None

      self._prep_debug_cli_for_run_end(
          debug_dump, request.tf_error, passed_filter,
          passed_filter_exclude_node_names)

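      # The run-end CLI blocks until the user issues a "run" command; its
      # return value becomes the response used at the next run-start callback.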
      self._run_start_response = self._launch_cli()

      # Clean up the dump generated by this run.
      self._remove_dump_root()
    elif request.performed_action == framework.OnRunStartAction.PROFILE_RUN:
      self._prep_profile_cli_for_run_end(self._sess.graph, request.run_metadata)
      self._run_start_response = self._launch_cli()
    else:
      # No debug information to show following a non-debug run() call.
      self._run_start_response = None

    # Return placeholder response that currently holds no additional
    # information.
    return framework.OnRunEndResponse()

  def _remove_dump_root(self):
    if os.path.isdir(self._dump_root):
      file_io.delete_recursively(self._dump_root)

  def _prep_debug_cli_for_run_end(self,
                                  debug_dump,
                                  tf_error,
                                  passed_filter,
                                  passed_filter_exclude_node_names):
    """Prepare (but not launch) CLI for run-end, with debug dump from the run.

    Args:
      debug_dump: (debug_data.DebugDumpDir) The debug dump directory from this
        run.
      tf_error: (None or OpError) OpError that happened during the run() call
        (if any).
      passed_filter: (None or str) Name of the tensor filter that just passed
        and caused the preparation of this run-end CLI (if any).
      passed_filter_exclude_node_names: (None or str) Regular expression used
        with the tensor filter to exclude ops with names matching the regular
        expression.
    """

    if tf_error:
      help_intro = cli_shared.get_error_intro(tf_error)

      self._init_command = "help"
      self._title_color = "red_on_white"
    else:
      help_intro = None
      self._init_command = "lt"

      self._title_color = "black_on_white"
      if passed_filter is not None:
        # Some dumped tensor(s) from this run passed the filter.
        self._init_command = "lt -f %s" % passed_filter
        if passed_filter_exclude_node_names:
          self._init_command += (" --filter_exclude_node_names %s" %
                                 passed_filter_exclude_node_names)
        self._title_color = "red_on_white"

    self._run_cli = analyzer_cli.create_analyzer_ui(
        debug_dump,
        self._tensor_filters,
        ui_type=self._ui_type,
        on_ui_exit=self._remove_dump_root,
        config=self._config)

    # Get names of all dumped tensors.
    dumped_tensor_names = []
    for datum in debug_dump.dumped_tensor_data:
      dumped_tensor_names.append("%s:%d" %
                                 (datum.node_name, datum.output_slot))

    # Tab completions for the "print_tensor" command.
    self._run_cli.register_tab_comp_context(["print_tensor", "pt"],
                                            dumped_tensor_names)

    # Tab completion for commands "node_info", "list_inputs" and
    # "list_outputs". The list comprehension is used below because nodes()
    # output can be unicode strings and they need to be converted to strs.
    self._run_cli.register_tab_comp_context(
        ["node_info", "ni", "list_inputs", "li", "list_outputs", "lo"],
        [str(node_name) for node_name in debug_dump.nodes()])
    # TODO(cais): Reduce API surface area for aliases vis-a-vis tab
    # completion contexts and registered command handlers.
    self._title = "run-end: " + self._run_description

    if help_intro:
      self._run_cli.set_help_intro(help_intro)

  def _prep_profile_cli_for_run_end(self, py_graph, run_metadata):
    self._init_command = "lp"
    self._run_cli = profile_analyzer_cli.create_profiler_ui(
        py_graph, run_metadata, ui_type=self._ui_type,
        config=self._run_cli.config)
    self._title = "run-end (profiler mode): " + self._run_description

  def _launch_cli(self):
    """Launch the interactive command-line interface.

    Returns:
      The OnRunStartResponse specified by the user using the "run" command.
    """

    self._register_this_run_info(self._run_cli)
    response = self._run_cli.run_ui(
        init_command=self._init_command,
        title=self._title,
        title_color=self._title_color)

    return response

  def _run_info_handler(self, args, screen_info=None):
    output = debugger_cli_common.RichTextLines([])

    if self._run_call_count == 1:
      output.extend(cli_shared.get_tfdbg_logo())
      output.extend(debugger_cli_common.get_tensorflow_version_lines())
    output.extend(self._run_info)

    if (not self._is_run_start and
        debugger_cli_common.MAIN_MENU_KEY in output.annotations):
      menu = output.annotations[debugger_cli_common.MAIN_MENU_KEY]
      if "list_tensors" not in menu.captions():
        menu.insert(
            0, debugger_cli_common.MenuItem("list_tensors", "list_tensors"))

    return output

  def _print_feed_handler(self, args, screen_info=None):
    np_printoptions = cli_shared.numpy_printoptions_from_screen_info(
        screen_info)

    if not self._feed_dict:
      return cli_shared.error(
          "The feed_dict of the current run is None or empty.")

    parsed = self._argparsers["print_feed"].parse_args(args)
    tensor_name, tensor_slicing = (
        command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))

    feed_key = None
    feed_value = None
    for key in self._feed_dict:
      key_name = common.get_graph_element_name(key)
      if key_name == tensor_name:
        feed_key = key_name
        feed_value = self._feed_dict[key]
        break

    if feed_key is None:
      return cli_shared.error(
          "The feed_dict of the current run does not contain the key %s" %
          tensor_name)
    else:
      return cli_shared.format_tensor(
          feed_value,
          feed_key + " (feed)",
          np_printoptions,
          print_all=parsed.print_all,
          tensor_slicing=tensor_slicing,
          highlight_options=cli_shared.parse_ranges_highlight(parsed.ranges),
          include_numeric_summary=parsed.numeric_summary)

  def _run_handler(self, args, screen_info=None):
    """Command handler for "run" command during on-run-start."""

    del screen_info  # Currently unused.

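    # Translate the parsed "run" flags into an OnRunStartResponse that tells
    # the wrapper how to perform the next Session.run() call.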
    parsed = self._argparsers["run"].parse_args(args)
    parsed.node_name_filter = parsed.node_name_filter or None
    parsed.op_type_filter = parsed.op_type_filter or None
    parsed.tensor_dtype_filter = parsed.tensor_dtype_filter or None

    if parsed.filter_exclude_node_names and not parsed.till_filter_pass:
      raise ValueError(
          "The --filter_exclude_node_names (or -fenn) flag is valid only if "
          "the --till_filter_pass (or -f) flag is used.")

    if parsed.profile:
      raise debugger_cli_common.CommandLineExit(
          exit_token=framework.OnRunStartResponse(
              framework.OnRunStartAction.PROFILE_RUN, []))

    self._skip_debug = parsed.no_debug
    self._run_through_times = parsed.times

    if parsed.times > 1 or parsed.no_debug:
      # If requested -t times > 1, the very next run will be a non-debug run.
      action = framework.OnRunStartAction.NON_DEBUG_RUN
      debug_urls = []
    else:
      action = framework.OnRunStartAction.DEBUG_RUN
      debug_urls = self._get_run_debug_urls()
    run_start_response = framework.OnRunStartResponse(
        action,
        debug_urls,
        node_name_regex_allowlist=parsed.node_name_filter,
        op_type_regex_allowlist=parsed.op_type_filter,
        tensor_dtype_regex_allowlist=parsed.tensor_dtype_filter)

    if parsed.till_filter_pass:
      # For the run-till-filter-pass (run -f) mode, use the DEBUG_RUN
      # option to access the intermediate tensors, and set the corresponding
      # state flag of the class itself to True.
      if parsed.till_filter_pass in self._tensor_filters:
        action = framework.OnRunStartAction.DEBUG_RUN
        self._active_tensor_filter = parsed.till_filter_pass
        self._active_filter_exclude_node_names = (
            parsed.filter_exclude_node_names)
        self._active_tensor_filter_run_start_response = run_start_response
      else:
        # Handle invalid filter name.
        return debugger_cli_common.RichTextLines(
            ["ERROR: tensor filter \"%s\" does not exist." %
             parsed.till_filter_pass])

    # Raise CommandLineExit exception to cause the CLI to exit.
    raise debugger_cli_common.CommandLineExit(exit_token=run_start_response)

  def _register_this_run_info(self, curses_cli):
    curses_cli.register_command_handler(
        "run",
        self._run_handler,
        self._argparsers["run"].format_help(),
        prefix_aliases=["r"])
    curses_cli.register_command_handler(
        "run_info",
        self._run_info_handler,
        self._argparsers["run_info"].format_help(),
        prefix_aliases=["ri"])
    curses_cli.register_command_handler(
        "print_feed",
        self._print_feed_handler,
        self._argparsers["print_feed"].format_help(),
        prefix_aliases=["pf"])

    if self._tensor_filters:
      # Register tab completion for the filter names.
      curses_cli.register_tab_comp_context(["run", "r"],
                                           list(self._tensor_filters.keys()))
    if self._feed_dict and hasattr(self._feed_dict, "keys"):
      # Register tab completion for feed_dict keys.
      feed_keys = [common.get_graph_element_name(key)
                   for key in self._feed_dict.keys()]
      curses_cli.register_tab_comp_context(["print_feed", "pf"], feed_keys)

  def _get_run_debug_urls(self):
    """Get the debug_urls value for the current run() call.

    Returns:
      debug_urls: (list of str) Debug URLs for the current run() call.
        Currently, the list consists of only one URL that is a file:// URL.
    """

    return ["file://" + self._dump_root]

  def _update_run_calls_state(self,
                              run_call_count,
                              fetches,
                              feed_dict,
                              is_callable_runner=False):
    """Update the internal state with regard to run() call history.

    Args:
      run_call_count: (int) Number of run() calls that have occurred.
      fetches: a node/tensor or a list of nodes/tensors that are the fetches of
        the run() call. This is the same as the fetches argument to the run()
        call.
      feed_dict: None or a dict. This is the feed_dict argument to the run()
        call.
      is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.
    """

    self._run_call_count = run_call_count
    self._feed_dict = feed_dict
    self._run_description = cli_shared.get_run_short_description(
        run_call_count,
        fetches,
        feed_dict,
        is_callable_runner=is_callable_runner)
    self._run_through_times -= 1

    self._run_info = cli_shared.get_run_start_intro(
        run_call_count,
        fetches,
        feed_dict,
        self._tensor_filters,
        is_callable_runner=is_callable_runner)
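

# Illustrative usage sketch (comment only, not executed): wrap an existing
# `tf.compat.v1.Session` so that every `run()` call drops into the tfdbg CLI.
# The `fetches`/`feed_dict` names below are hypothetical placeholders.
#
#   from tensorflow.python import debug as tf_debug
#
#   sess = tf.compat.v1.Session()
#   sess = tf_debug.LocalCLIDebugWrapperSession(sess)
#   sess.run(fetches, feed_dict=feed_dict)  # Launches the run-start CLI.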