# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Debugger wrapper sessions that send debug data to gRPC (grpc://) URLs."""
import signal
import sys
import traceback

# Google-internal import(s).
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.wrappers import framework


def publish_traceback(debug_server_urls,
                      graph,
                      feed_dict,
                      fetches,
                      old_graph_version):
  """Publish traceback and source code if graph version is new.

  `graph.version` is compared with `old_graph_version`. If the former is higher
  (i.e., newer), the graph traceback and the associated source code are sent to
  the debug server at the specified gRPC URLs.

  Args:
    debug_server_urls: A single gRPC debug server URL as a `str` or a `list` of
      debug server URLs.
    graph: A Python `tf.Graph` object.
    feed_dict: Feed dictionary given to the `Session.run()` call.
    fetches: Fetches from the `Session.run()` call.
    old_graph_version: Old graph version to compare to.

  Returns:
    If `graph.version > old_graph_version`, the new graph version as an `int`.
    Else, the `old_graph_version` is returned.
  """
  # TODO(cais): Consider moving this back to the top, after grpc becomes a
  # pip dependency of tensorflow or tf_debug.
  # pylint:disable=g-import-not-at-top
  from tensorflow.python.debug.lib import source_remote
  # pylint:enable=g-import-not-at-top
  if graph.version > old_graph_version:
    # The run key identifies this particular (feeds, fetches) combination.
    run_key = common.get_run_key(feed_dict, fetches)
    source_remote.send_graph_tracebacks(
        debug_server_urls, run_key, traceback.extract_stack(), graph,
        send_source=True)
    return graph.version
  else:
    return old_graph_version


class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
  """Debug Session wrapper that sends debug data to gRPC stream(s)."""

  def __init__(self,
               sess,
               grpc_debug_server_addresses,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Constructor of GrpcDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      grpc_debug_server_addresses: (`str` or `list` of `str`) Single or a list
        of the gRPC debug server addresses, in the format of
        <host:port>, with or without the "grpc://" prefix. For example:
          "localhost:7000",
          ["localhost:7000", "192.168.0.2:8000"]
      watch_fn: (`Callable`) A Callable that can be used to define per-run
        debug ops and watched tensors. See the doc of
        `NonInteractiveDebugWrapperSession.__init__()` for details.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      log_usage: (`bool`) whether the usage of this class is to be logged.

    Raises:
      TypeError: If `grpc_debug_server_addresses` is not a `str` or a `list`
        of `str`.
    """

    if log_usage:
      pass  # No logging for open-source.

    framework.NonInteractiveDebugWrapperSession.__init__(
        self, sess, watch_fn=watch_fn, thread_name_filter=thread_name_filter)

    if isinstance(grpc_debug_server_addresses, str):
      self._grpc_debug_server_urls = [
          self._normalize_grpc_url(grpc_debug_server_addresses)]
    elif isinstance(grpc_debug_server_addresses, list):
      self._grpc_debug_server_urls = []
      for address in grpc_debug_server_addresses:
        if not isinstance(address, str):
          raise TypeError(
              "Expected type str in list grpc_debug_server_addresses, "
              "received type %s" % type(address))
        self._grpc_debug_server_urls.append(self._normalize_grpc_url(address))
    else:
      raise TypeError(
          "Expected type str or list in grpc_debug_server_addresses, "
          "received type %s" % type(grpc_debug_server_addresses))

  def prepare_run_debug_urls(self, fetches, feed_dict):
    """Implementation of abstract method in superclass.

    See doc of `NonInteractiveDebugWrapperSession.prepare_run_debug_urls()`
    for details.

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`
      feed_dict: Same as the `feed_dict` argument to `Session.run()`

    Returns:
      debug_urls: (`list` of `str`) grpc:// debug URLs to be used in
        this `Session.run()` call. The same list, computed once in the
        constructor, is returned for every call.
    """

    return self._grpc_debug_server_urls

  def _normalize_grpc_url(self, address):
    """Prepend `common.GRPC_URL_PREFIX` to `address` unless already present."""
    return (common.GRPC_URL_PREFIX + address
            if not address.startswith(common.GRPC_URL_PREFIX) else address)


def _signal_handler(unused_signal, unused_frame):
  """SIGINT handler that asks the user whether to quit the program.

  An empty answer, "Y" or "y" exits the process with status 0; "N" or "n"
  returns control to the interrupted program. Any other answer re-prompts.

  If stdin is exhausted (EOF), no answer can ever be read, so the prompt's
  default answer ("Y") is assumed. Without this, `EOFError` would escape the
  signal handler and surface at whatever point the program was interrupted.
  """
  while True:
    try:
      response = input("\nSIGINT received. Quit program? (Y/n): ").strip()
    except EOFError:
      # No more input is available: behave as if the default was chosen.
      response = ""
    if response in ("", "Y", "y"):
      sys.exit(0)
    elif response in ("N", "n"):
      break


def register_signal_handler():
  """Install `_signal_handler` for SIGINT, best-effort.

  `signal.signal` raises `ValueError` when called outside the main thread; in
  that case the handler is simply not installed.
  """
  try:
    signal.signal(signal.SIGINT, _signal_handler)
  except ValueError:
    # This can happen if we are not in the MainThread.
    pass


class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
  """A tfdbg Session wrapper that can be used with TensorBoard Debugger Plugin.

  This wrapper is the same as `GrpcDebugWrapperSession`, except that it uses a
    predefined `watch_fn` that
    1) uses `DebugIdentity` debug ops with the `gated_grpc` attribute set to
        `True` to allow the interactive enabling and disabling of tensor
       breakpoints.
    2) watches all tensors in the graph.
  This saves the need for the user to define a `watch_fn`.
  """

  def __init__(self,
               sess,
               grpc_debug_server_addresses,
               thread_name_filter=None,
               send_traceback_and_source_code=True,
               log_usage=True):
    """Constructor of TensorBoardDebugWrapperSession.

    Also installs a SIGINT handler (see `register_signal_handler`) that asks
    the user for confirmation before quitting.

    Args:
      sess: The `tf.compat.v1.Session` instance to be wrapped.
      grpc_debug_server_addresses: gRPC address(es) of debug server(s), as a
        `str` or a `list` of `str`s. E.g., "localhost:2333",
        "grpc://localhost:2333", ["192.168.0.7:2333", "192.168.0.8:2333"].
      thread_name_filter: Optional filter for thread names.
      send_traceback_and_source_code: Whether traceback of graph elements and
        the source code are to be sent to the debug server(s).
      log_usage: Whether the usage of this class is to be logged (if
        applicable).
    """
    def _gated_grpc_watch_fn(fetches, feeds):
      del fetches, feeds  # Unused.
      return framework.WatchOptions(
          debug_ops=["DebugIdentity(gated_grpc=true)"])

    super().__init__(
        sess,
        grpc_debug_server_addresses,
        watch_fn=_gated_grpc_watch_fn,
        thread_name_filter=thread_name_filter,
        log_usage=log_usage)

    self._send_traceback_and_source_code = send_traceback_and_source_code
    # Keeps track of the latest version of Python graph object that has been
    # sent to the debug servers.
    self._sent_graph_version = -1

    register_signal_handler()

  def run(self,
          fetches,
          feed_dict=None,
          options=None,
          run_metadata=None,
          callable_runner=None,
          callable_runner_args=None,
          callable_options=None):
    """Run the wrapped session, first publishing tracebacks if enabled.

    When `send_traceback_and_source_code` was set in the constructor, the graph
    traceback and source code are sent via `publish_traceback`. They are sent
    only if the graph version has advanced since the last send. All arguments
    are then forwarded unchanged to the superclass `run()`.
    """
    if self._send_traceback_and_source_code:
      self._sent_graph_version = publish_traceback(
          self._grpc_debug_server_urls, self.graph, feed_dict, fetches,
          self._sent_graph_version)
    return super().run(
        fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata,
        callable_runner=callable_runner,
        callable_runner_args=callable_runner_args,
        callable_options=callable_options)