xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/wrappers/local_cli_wrapper.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Debugger Wrapper Session Consisting of a Local Curses-based CLI."""
16import argparse
17import os
18import sys
19import tempfile
20
21# Google-internal import(s).
22from tensorflow.python.debug.cli import analyzer_cli
23from tensorflow.python.debug.cli import cli_config
24from tensorflow.python.debug.cli import cli_shared
25from tensorflow.python.debug.cli import command_parser
26from tensorflow.python.debug.cli import debugger_cli_common
27from tensorflow.python.debug.cli import profile_analyzer_cli
28from tensorflow.python.debug.cli import ui_factory
29from tensorflow.python.debug.lib import common
30from tensorflow.python.debug.lib import debug_data
31from tensorflow.python.debug.wrappers import framework
32from tensorflow.python.lib.io import file_io
33
34
# Prefix for the temporary dump-root directories created under the system
# temp directory when the caller does not supply an explicit dump_root.
_DUMP_ROOT_PREFIX = "tfdbg_"
36
37
# TODO(donglin) Remove use_random_config_path after b/137652456 is fixed.
class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
  """Concrete subclass of BaseDebugWrapperSession implementing a local CLI.

  This class has all the methods that a `session.Session` object has, in order
  to support debugging with minimal code changes. Invoking its `run()` method
  will launch the command-line interface (CLI) of tfdbg.
  """

  def __init__(self,
               sess,
               dump_root=None,
               log_usage=True,
               ui_type="curses",
               thread_name_filter=None,
               config_file_path=False):
    """Constructor of LocalCLIDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      dump_root: (`str`) optional path to the dump root directory. Must be a
        directory that does not exist or an empty directory. If the directory
        does not exist, it will be created by the debugger core during debug
        `run()` calls and removed afterwards. If `None`, the debug dumps will
        be at tfdbg_<random_string> under the system temp directory.
      log_usage: (`bool`) whether the usage of this class is to be logged
        (no-op in the open-source build).
      ui_type: (`str`) requested UI type. Currently supported:
        (curses | readline)
      thread_name_filter: Regular-expression white list for thread name. See
        the doc of `BaseDebugWrapperSession` for details.
      config_file_path: Optional override to the default configuration file
        path, which is at `${HOME}/.tfdbg_config`. Note: the default value is
        `False` (falsy), meaning "use the default path".

    Raises:
      ValueError: If dump_root is an existing and non-empty directory or if
        dump_root is a file.
    """

    if log_usage:
      pass  # No logging for open-source.

    framework.BaseDebugWrapperSession.__init__(
        self, sess, thread_name_filter=thread_name_filter)

    if not dump_root:
      # No dump root given: create a fresh temporary directory per wrapper.
      self._dump_root = tempfile.mkdtemp(prefix=_DUMP_ROOT_PREFIX)
    else:
      dump_root = os.path.expanduser(dump_root)
      # An existing file or a non-empty directory cannot serve as dump root.
      if os.path.isfile(dump_root):
        raise ValueError("dump_root path points to a file: %s" % dump_root)
      elif os.path.isdir(dump_root) and os.listdir(dump_root):
        raise ValueError("dump_root path points to a non-empty directory: %s" %
                         dump_root)

      self._dump_root = dump_root

    self._initialize_argparsers()

    # Registered tensor filters.
    self._tensor_filters = {}
    # Register frequently-used filter(s).
    self.add_tensor_filter("has_inf_or_nan", debug_data.has_inf_or_nan)

    # Below are the state variables of this wrapper object.
    # _active_tensor_filter: what (if any) tensor filter is in effect. If such
    #   a filter is in effect, this object will call run() method of the
    #   underlying TensorFlow Session object until the filter passes. This is
    #   activated by the "-f" flag of the "run" command.
    # _active_filter_exclude_node_names: regular expression for node names to
    #   be excluded while the active tensor filter is applied ("-fenn" flag).
    # _run_through_times: keeps track of how many times the wrapper needs to
    #   run through without stopping at the run-end CLI. It is activated by the
    #   "-t" option of the "run" command.
    # _skip_debug: keeps track of whether the current run should be executed
    #   without debugging. It is activated by the "-n" option of the "run"
    #   command.
    #
    # _run_start_response: keeps track what OnRunStartResponse the wrapper
    #   should return at the next run-start callback. If this information is
    #   unavailable (i.e., is None), the run-start CLI will be launched to ask
    #   the user. This is the case, e.g., right before the first run starts.
    self._active_tensor_filter = None
    self._active_filter_exclude_node_names = None
    self._active_tensor_filter_run_start_response = None
    self._run_through_times = 1
    self._skip_debug = False
    self._run_start_response = None
    self._is_run_start = True
    self._ui_type = ui_type
    self._config = None
    if config_file_path:
      self._config = cli_config.CLIConfig(config_file_path=config_file_path)
128
129  def _is_disk_usage_reset_each_run(self):
130    # The dumped tensors are all cleaned up after every Session.run
131    # in a command-line wrapper.
132    return True
133
134  def _initialize_argparsers(self):
135    self._argparsers = {}
136    ap = argparse.ArgumentParser(
137        description="Run through, with or without debug tensor watching.",
138        usage=argparse.SUPPRESS)
139    ap.add_argument(
140        "-t",
141        "--times",
142        dest="times",
143        type=int,
144        default=1,
145        help="How many Session.run() calls to proceed with.")
146    ap.add_argument(
147        "-n",
148        "--no_debug",
149        dest="no_debug",
150        action="store_true",
151        help="Run through without debug tensor watching.")
152    ap.add_argument(
153        "-f",
154        "--till_filter_pass",
155        dest="till_filter_pass",
156        type=str,
157        default="",
158        help="Run until a tensor in the graph passes the specified filter.")
159    ap.add_argument(
160        "-fenn",
161        "--filter_exclude_node_names",
162        dest="filter_exclude_node_names",
163        type=str,
164        default="",
165        help="When applying the tensor filter, exclude node with names "
166        "matching the regular expression. Applicable only if --tensor_filter "
167        "or -f is used.")
168    ap.add_argument(
169        "--node_name_filter",
170        dest="node_name_filter",
171        type=str,
172        default="",
173        help="Regular-expression filter for node names to be watched in the "
174        "run, e.g., loss, reshape.*")
175    ap.add_argument(
176        "--op_type_filter",
177        dest="op_type_filter",
178        type=str,
179        default="",
180        help="Regular-expression filter for op type to be watched in the run, "
181        "e.g., (MatMul|Add), Variable.*")
182    ap.add_argument(
183        "--tensor_dtype_filter",
184        dest="tensor_dtype_filter",
185        type=str,
186        default="",
187        help="Regular-expression filter for tensor dtype to be watched in the "
188        "run, e.g., (float32|float64), int.*")
189    ap.add_argument(
190        "-p",
191        "--profile",
192        dest="profile",
193        action="store_true",
194        help="Run and profile TensorFlow graph execution.")
195    self._argparsers["run"] = ap
196
197    ap = argparse.ArgumentParser(
198        description="Display information about this Session.run() call.",
199        usage=argparse.SUPPRESS)
200    self._argparsers["run_info"] = ap
201
202    self._argparsers["print_feed"] = command_parser.get_print_tensor_argparser(
203        "Print the value of a feed in feed_dict.")
204
205  def add_tensor_filter(self, filter_name, tensor_filter):
206    """Add a tensor filter.
207
208    Args:
209      filter_name: (`str`) name of the filter.
210      tensor_filter: (`callable`) the filter callable. See the doc string of
211        `DebugDumpDir.find()` for more details about its signature.
212    """
213
214    self._tensor_filters[filter_name] = tensor_filter
215
216  def on_session_init(self, request):
217    """Overrides on-session-init callback.
218
219    Args:
220      request: An instance of `OnSessionInitRequest`.
221
222    Returns:
223      An instance of `OnSessionInitResponse`.
224    """
225
226    return framework.OnSessionInitResponse(
227        framework.OnSessionInitAction.PROCEED)
228
  def on_run_start(self, request):
    """Overrides on-run-start callback.

    Decides, based on the wrapper's state ("-t"/"-n"/"-f" flags from prior
    "run" commands), whether to launch the run-start CLI or to reuse a
    previously determined response.

    Args:
      request: An instance of `OnRunStartRequest`.

    Returns:
      An instance of `OnRunStartResponse`.
    """
    self._is_run_start = True
    self._update_run_calls_state(
        request.run_call_count, request.fetches, request.feed_dict,
        is_callable_runner=request.is_callable_runner)

    if self._active_tensor_filter:
      # If we are running until a filter passes, we just need to keep running
      # with the previous `OnRunStartResponse`.
      return self._active_tensor_filter_run_start_response

    self._exit_if_requested_by_user()

    if self._run_call_count > 1 and not self._skip_debug:
      # "-t N" handling: _run_through_times was decremented in
      # _update_run_calls_state above.
      if self._run_through_times > 0:
        # Just run through without debugging.
        return framework.OnRunStartResponse(
            framework.OnRunStartAction.NON_DEBUG_RUN, [])
      elif self._run_through_times == 0:
        # It is the run at which the run-end CLI will be launched: activate
        # debugging.
        return (self._run_start_response or
                framework.OnRunStartResponse(
                    framework.OnRunStartAction.DEBUG_RUN,
                    self._get_run_debug_urls()))

    if self._run_start_response is None:
      # No cached decision: bring up the run-start CLI and ask the user.
      self._prep_cli_for_run_start()

      self._run_start_response = self._launch_cli()
      if self._active_tensor_filter:
        # Remember the response so subsequent run-till-filter-pass runs can
        # reuse it (see the early return above).
        self._active_tensor_filter_run_start_response = self._run_start_response
      if self._run_through_times > 1:
        self._run_through_times -= 1

    self._exit_if_requested_by_user()
    return self._run_start_response
274
275  def _exit_if_requested_by_user(self):
276    if self._run_start_response == debugger_cli_common.EXPLICIT_USER_EXIT:
277      # Explicit user "exit" command leads to sys.exit(1).
278      print(
279          "Note: user exited from debugger CLI: Calling sys.exit(1).",
280          file=sys.stderr)
281      sys.exit(1)
282
283  def _prep_cli_for_run_start(self):
284    """Prepare (but not launch) the CLI for run-start."""
285    self._run_cli = ui_factory.get_ui(self._ui_type, config=self._config)
286
287    help_intro = debugger_cli_common.RichTextLines([])
288    if self._run_call_count == 1:
289      # Show logo at the onset of the first run.
290      help_intro.extend(cli_shared.get_tfdbg_logo())
291      help_intro.extend(debugger_cli_common.get_tensorflow_version_lines())
292    help_intro.extend(debugger_cli_common.RichTextLines("Upcoming run:"))
293    help_intro.extend(self._run_info)
294
295    self._run_cli.set_help_intro(help_intro)
296
297    # Create initial screen output detailing the run.
298    self._title = "run-start: " + self._run_description
299    self._init_command = "run_info"
300    self._title_color = "blue_on_white"
301
  def on_run_end(self, request):
    """Overrides on-run-end callback.

    Actions taken:
      1) Load the debug dump.
      2) Bring up the Analyzer CLI.

    Args:
      request: An instance of `OnRunEndRequest`.

    Returns:
      An instance of `OnRunEndResponse`.
    """

    self._is_run_start = False
    if request.performed_action == framework.OnRunStartAction.DEBUG_RUN:
      partition_graphs = None
      if request.run_metadata and request.run_metadata.partition_graphs:
        partition_graphs = request.run_metadata.partition_graphs
      elif request.client_graph_def:
        # Fall back on the client-side graph when run_metadata carries no
        # partition graphs.
        partition_graphs = [request.client_graph_def]

      if request.tf_error and not os.path.isdir(self._dump_root):
        # It is possible that the dump root may not exist due to errors that
        # have occurred prior to graph execution (e.g., invalid device
        # assignments), in which case we will just raise the exception as the
        # unwrapped Session does.
        raise request.tf_error

      debug_dump = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=partition_graphs)
      debug_dump.set_python_graph(self._sess.graph)

      passed_filter = None
      passed_filter_exclude_node_names = None
      if self._active_tensor_filter:
        # Run-till-filter-pass ("run -f") mode: only stop at the run-end CLI
        # once at least one dumped tensor passes the active filter.
        if not debug_dump.find(
            self._tensor_filters[self._active_tensor_filter], first_n=1,
            exclude_node_names=self._active_filter_exclude_node_names):
          # No dumped tensor passes the filter in this run. Clean up the dump
          # directory and move on.
          self._remove_dump_root()
          return framework.OnRunEndResponse()
        else:
          # Some dumped tensor(s) from this run passed the filter.
          passed_filter = self._active_tensor_filter
          passed_filter_exclude_node_names = (
              self._active_filter_exclude_node_names)
          # Deactivate the filter so subsequent runs prompt the user again.
          self._active_tensor_filter = None
          self._active_filter_exclude_node_names = None

      self._prep_debug_cli_for_run_end(
          debug_dump, request.tf_error, passed_filter,
          passed_filter_exclude_node_names)

      self._run_start_response = self._launch_cli()

      # Clean up the dump generated by this run.
      self._remove_dump_root()
    elif request.performed_action == framework.OnRunStartAction.PROFILE_RUN:
      self._prep_profile_cli_for_run_end(self._sess.graph, request.run_metadata)
      self._run_start_response = self._launch_cli()
    else:
      # No debug information to show following a non-debug run() call.
      self._run_start_response = None

    # Return placeholder response that currently holds no additional
    # information.
    return framework.OnRunEndResponse()
371
372  def _remove_dump_root(self):
373    if os.path.isdir(self._dump_root):
374      file_io.delete_recursively(self._dump_root)
375
376  def _prep_debug_cli_for_run_end(self,
377                                  debug_dump,
378                                  tf_error,
379                                  passed_filter,
380                                  passed_filter_exclude_node_names):
381    """Prepare (but not launch) CLI for run-end, with debug dump from the run.
382
383    Args:
384      debug_dump: (debug_data.DebugDumpDir) The debug dump directory from this
385        run.
386      tf_error: (None or OpError) OpError that happened during the run() call
387        (if any).
388      passed_filter: (None or str) Name of the tensor filter that just passed
389        and caused the preparation of this run-end CLI (if any).
390      passed_filter_exclude_node_names: (None or str) Regular expression used
391        with the tensor filter to exclude ops with names matching the regular
392        expression.
393    """
394
395    if tf_error:
396      help_intro = cli_shared.get_error_intro(tf_error)
397
398      self._init_command = "help"
399      self._title_color = "red_on_white"
400    else:
401      help_intro = None
402      self._init_command = "lt"
403
404      self._title_color = "black_on_white"
405      if passed_filter is not None:
406        # Some dumped tensor(s) from this run passed the filter.
407        self._init_command = "lt -f %s" % passed_filter
408        if passed_filter_exclude_node_names:
409          self._init_command += (" --filter_exclude_node_names %s" %
410                                 passed_filter_exclude_node_names)
411        self._title_color = "red_on_white"
412
413    self._run_cli = analyzer_cli.create_analyzer_ui(
414        debug_dump,
415        self._tensor_filters,
416        ui_type=self._ui_type,
417        on_ui_exit=self._remove_dump_root,
418        config=self._config)
419
420    # Get names of all dumped tensors.
421    dumped_tensor_names = []
422    for datum in debug_dump.dumped_tensor_data:
423      dumped_tensor_names.append("%s:%d" %
424                                 (datum.node_name, datum.output_slot))
425
426    # Tab completions for command "print_tensors".
427    self._run_cli.register_tab_comp_context(["print_tensor", "pt"],
428                                            dumped_tensor_names)
429
430    # Tab completion for commands "node_info", "list_inputs" and
431    # "list_outputs". The list comprehension is used below because nodes()
432    # output can be unicodes and they need to be converted to strs.
433    self._run_cli.register_tab_comp_context(
434        ["node_info", "ni", "list_inputs", "li", "list_outputs", "lo"],
435        [str(node_name) for node_name in debug_dump.nodes()])
436    # TODO(cais): Reduce API surface area for aliases vis-a-vis tab
437    #    completion contexts and registered command handlers.
438
439    self._title = "run-end: " + self._run_description
440
441    if help_intro:
442      self._run_cli.set_help_intro(help_intro)
443
444  def _prep_profile_cli_for_run_end(self, py_graph, run_metadata):
445    self._init_command = "lp"
446    self._run_cli = profile_analyzer_cli.create_profiler_ui(
447        py_graph, run_metadata, ui_type=self._ui_type,
448        config=self._run_cli.config)
449    self._title = "run-end (profiler mode): " + self._run_description
450
451  def _launch_cli(self):
452    """Launch the interactive command-line interface.
453
454    Returns:
455      The OnRunStartResponse specified by the user using the "run" command.
456    """
457
458    self._register_this_run_info(self._run_cli)
459    response = self._run_cli.run_ui(
460        init_command=self._init_command,
461        title=self._title,
462        title_color=self._title_color)
463
464    return response
465
466  def _run_info_handler(self, args, screen_info=None):
467    output = debugger_cli_common.RichTextLines([])
468
469    if self._run_call_count == 1:
470      output.extend(cli_shared.get_tfdbg_logo())
471      output.extend(debugger_cli_common.get_tensorflow_version_lines())
472    output.extend(self._run_info)
473
474    if (not self._is_run_start and
475        debugger_cli_common.MAIN_MENU_KEY in output.annotations):
476      menu = output.annotations[debugger_cli_common.MAIN_MENU_KEY]
477      if "list_tensors" not in menu.captions():
478        menu.insert(
479            0, debugger_cli_common.MenuItem("list_tensors", "list_tensors"))
480
481    return output
482
483  def _print_feed_handler(self, args, screen_info=None):
484    np_printoptions = cli_shared.numpy_printoptions_from_screen_info(
485        screen_info)
486
487    if not self._feed_dict:
488      return cli_shared.error(
489          "The feed_dict of the current run is None or empty.")
490
491    parsed = self._argparsers["print_feed"].parse_args(args)
492    tensor_name, tensor_slicing = (
493        command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))
494
495    feed_key = None
496    feed_value = None
497    for key in self._feed_dict:
498      key_name = common.get_graph_element_name(key)
499      if key_name == tensor_name:
500        feed_key = key_name
501        feed_value = self._feed_dict[key]
502        break
503
504    if feed_key is None:
505      return cli_shared.error(
506          "The feed_dict of the current run does not contain the key %s" %
507          tensor_name)
508    else:
509      return cli_shared.format_tensor(
510          feed_value,
511          feed_key + " (feed)",
512          np_printoptions,
513          print_all=parsed.print_all,
514          tensor_slicing=tensor_slicing,
515          highlight_options=cli_shared.parse_ranges_highlight(parsed.ranges),
516          include_numeric_summary=parsed.numeric_summary)
517
518  def _run_handler(self, args, screen_info=None):
519    """Command handler for "run" command during on-run-start."""
520
521    del screen_info  # Currently unused.
522
523    parsed = self._argparsers["run"].parse_args(args)
524    parsed.node_name_filter = parsed.node_name_filter or None
525    parsed.op_type_filter = parsed.op_type_filter or None
526    parsed.tensor_dtype_filter = parsed.tensor_dtype_filter or None
527
528    if parsed.filter_exclude_node_names and not parsed.till_filter_pass:
529      raise ValueError(
530          "The --filter_exclude_node_names (or -feon) flag is valid only if "
531          "the --till_filter_pass (or -f) flag is used.")
532
533    if parsed.profile:
534      raise debugger_cli_common.CommandLineExit(
535          exit_token=framework.OnRunStartResponse(
536              framework.OnRunStartAction.PROFILE_RUN, []))
537
538    self._skip_debug = parsed.no_debug
539    self._run_through_times = parsed.times
540
541    if parsed.times > 1 or parsed.no_debug:
542      # If requested -t times > 1, the very next run will be a non-debug run.
543      action = framework.OnRunStartAction.NON_DEBUG_RUN
544      debug_urls = []
545    else:
546      action = framework.OnRunStartAction.DEBUG_RUN
547      debug_urls = self._get_run_debug_urls()
548    run_start_response = framework.OnRunStartResponse(
549        action,
550        debug_urls,
551        node_name_regex_allowlist=parsed.node_name_filter,
552        op_type_regex_allowlist=parsed.op_type_filter,
553        tensor_dtype_regex_allowlist=parsed.tensor_dtype_filter)
554
555    if parsed.till_filter_pass:
556      # For the run-till-filter-pass (run -f) mode, use the DEBUG_RUN
557      # option to access the intermediate tensors, and set the corresponding
558      # state flag of the class itself to True.
559      if parsed.till_filter_pass in self._tensor_filters:
560        action = framework.OnRunStartAction.DEBUG_RUN
561        self._active_tensor_filter = parsed.till_filter_pass
562        self._active_filter_exclude_node_names = (
563            parsed.filter_exclude_node_names)
564        self._active_tensor_filter_run_start_response = run_start_response
565      else:
566        # Handle invalid filter name.
567        return debugger_cli_common.RichTextLines(
568            ["ERROR: tensor filter \"%s\" does not exist." %
569             parsed.till_filter_pass])
570
571    # Raise CommandLineExit exception to cause the CLI to exit.
572    raise debugger_cli_common.CommandLineExit(exit_token=run_start_response)
573
574  def _register_this_run_info(self, curses_cli):
575    curses_cli.register_command_handler(
576        "run",
577        self._run_handler,
578        self._argparsers["run"].format_help(),
579        prefix_aliases=["r"])
580    curses_cli.register_command_handler(
581        "run_info",
582        self._run_info_handler,
583        self._argparsers["run_info"].format_help(),
584        prefix_aliases=["ri"])
585    curses_cli.register_command_handler(
586        "print_feed",
587        self._print_feed_handler,
588        self._argparsers["print_feed"].format_help(),
589        prefix_aliases=["pf"])
590
591    if self._tensor_filters:
592      # Register tab completion for the filter names.
593      curses_cli.register_tab_comp_context(["run", "r"],
594                                           list(self._tensor_filters.keys()))
595    if self._feed_dict and hasattr(self._feed_dict, "keys"):
596      # Register tab completion for feed_dict keys.
597      feed_keys = [common.get_graph_element_name(key)
598                   for key in self._feed_dict.keys()]
599      curses_cli.register_tab_comp_context(["print_feed", "pf"], feed_keys)
600
601  def _get_run_debug_urls(self):
602    """Get the debug_urls value for the current run() call.
603
604    Returns:
605      debug_urls: (list of str) Debug URLs for the current run() call.
606        Currently, the list consists of only one URL that is a file:// URL.
607    """
608
609    return ["file://" + self._dump_root]
610
611  def _update_run_calls_state(self,
612                              run_call_count,
613                              fetches,
614                              feed_dict,
615                              is_callable_runner=False):
616    """Update the internal state with regard to run() call history.
617
618    Args:
619      run_call_count: (int) Number of run() calls that have occurred.
620      fetches: a node/tensor or a list of node/tensor that are the fetches of
621        the run() call. This is the same as the fetches argument to the run()
622        call.
623      feed_dict: None of a dict. This is the feed_dict argument to the run()
624        call.
625      is_callable_runner: (bool) whether a runner returned by
626        Session.make_callable is being run.
627    """
628
629    self._run_call_count = run_call_count
630    self._feed_dict = feed_dict
631    self._run_description = cli_shared.get_run_short_description(
632        run_call_count,
633        fetches,
634        feed_dict,
635        is_callable_runner=is_callable_runner)
636    self._run_through_times -= 1
637
638    self._run_info = cli_shared.get_run_start_intro(
639        run_call_count,
640        fetches,
641        feed_dict,
642        self._tensor_filters,
643        is_callable_runner=is_callable_runner)
644