xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/cli/readline_ui.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Readline-Based Command-Line Interface of TensorFlow Debugger (tfdbg)."""
16import readline
17
18from tensorflow.python.debug.cli import base_ui
19from tensorflow.python.debug.cli import debugger_cli_common
20
21
class ReadlineUI(base_ui.BaseUI):
  """Readline-based Command-line UI.

  Commands are read one line at a time through the built-in `input` function,
  with line editing and tab completion configured via the `readline` module.
  Command output is written to stdout with `print`.
  """

  def __init__(self, on_ui_exit=None, config=None):
    """Constructor of ReadlineUI.

    Args:
      on_ui_exit: Forwarded unchanged to `base_ui.BaseUI.__init__`.
      config: Forwarded unchanged to `base_ui.BaseUI.__init__`.
    """
    base_ui.BaseUI.__init__(self, on_ui_exit=on_ui_exit, config=config)
    self._init_input()

  def _init_input(self):
    """Configure readline (emacs mode, tab completion) and the input source."""
    readline.parse_and_bind("set editing-mode emacs")

    # Disable default readline delimiter in order to receive the full text
    # (not just the last word) in the completer.
    readline.set_completer_delims("\n")
    readline.set_completer(self._readline_complete)
    readline.parse_and_bind("tab: complete")

    # Kept as an attribute so that the function used to read a line can be
    # replaced on the instance.
    self._input = input

  def _readline_complete(self, text, state):
    """Completer callback registered with readline.

    readline calls this repeatedly with state = 0, 1, 2, ... and stops once a
    non-string value is returned.

    Args:
      text: (str) Text to complete. Because the only completer delimiter is the
        newline character, this is the full input text rather than just the
        last word.
      state: (int) Index of the completion candidate being requested.

    Returns:
      (str) The candidate at index `state`, prepended with the part of `text`
        that precedes the last word; or None if there is no candidate at that
        index.
    """
    context, prefix, except_last_word = self._analyze_tab_complete_input(text)
    candidates, _ = self._tab_completion_registry.get_completions(context,
                                                                  prefix)

    # Signal "no (more) matches" with None instead of letting an IndexError
    # (state past the end) or a TypeError (no candidate list at all, which the
    # registry presumably returns for an unknown context) escape into readline.
    if not candidates or state >= len(candidates):
      return None
    return except_last_word + candidates[state]

  def run_ui(self,
             init_command=None,
             title=None,
             title_color=None,
             enable_mouse_on_start=True):
    """Run the CLI: See the doc of base_ui.BaseUI.run_ui for more details.

    Args:
      init_command: (str) Optional command to dispatch before the UI loop
        starts.
      title: (str) Optional title, printed once before anything else.
      title_color: Not used by this UI.
      enable_mouse_on_start: Not used by this UI.

    Returns:
      The exit token returned by the UI loop.
    """
    # Only print a title if one was supplied; printing the default would emit
    # the literal text "None".
    if title is not None:
      print(title)

    if init_command is not None:
      self._dispatch_command(init_command)

    exit_token = self._ui_loop()

    if self._on_ui_exit:
      self._on_ui_exit()

    return exit_token

  def _ui_loop(self):
    """Read and dispatch commands until one yields a non-None exit token.

    Returns:
      The first non-None exit token returned by `_dispatch_command`.
    """
    while True:
      command = self._get_user_command()

      exit_token = self._dispatch_command(command)
      if exit_token is not None:
        return exit_token

  def _get_user_command(self):
    """Print a blank line, prompt for a command and return it stripped."""
    print("")
    return self._input(self.CLI_PROMPT).strip()

  def _dispatch_command(self, command):
    """Dispatch user command.

    Args:
      command: (str) Command to dispatch.

    Returns:
      An exit token object. None value means that the UI loop should not exit.
      A non-None value means the UI loop should exit.
    """

    if command in self.CLI_EXIT_COMMANDS:
      # Explicit user command-triggered exit: EXPLICIT_USER_EXIT as the exit
      # token.
      return debugger_cli_common.EXPLICIT_USER_EXIT

    try:
      prefix, args, output_file_path = self._parse_command(command)
    except SyntaxError as e:
      # A command that cannot be parsed is reported; the UI keeps running.
      print(str(e))
      return None

    if self._command_handler_registry.is_registered(prefix):
      try:
        screen_output = self._command_handler_registry.dispatch_command(
            prefix, args, screen_info=None)
      except debugger_cli_common.CommandLineExit as e:
        # The handler requested an exit; pass its token on to the UI loop.
        return e.exit_token
    else:
      screen_output = debugger_cli_common.RichTextLines([
          self.ERROR_MESSAGE_PREFIX + "Invalid command prefix \"%s\"" % prefix
      ])

    self._display_output(screen_output)
    if output_file_path:
      # Best effort: a failed write is reported but does not end the UI.
      try:
        screen_output.write_to_file(output_file_path)
        print("Wrote output to %s" % output_file_path)
      except Exception:  # pylint: disable=broad-except
        print("Failed to write output to %s" % output_file_path)

    return None

  def _display_output(self, screen_output):
    """Print each line of the output to stdout.

    Args:
      screen_output: Object with a `lines` attribute that is iterated over;
        each item is printed on its own line.
    """
    for line in screen_output.lines:
      print(line)
122