# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""GRPC debug server for testing."""
import collections
import errno
import functools
import hashlib
import json
import os
import re
import tempfile
import threading
import time

import portpicker

from tensorflow.core.debug import debug_service_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import grpc_debug_server
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.util import compat


def _get_dump_file_path(dump_root, device_name, debug_node_name):
  """Get the file path of the dump file for a debug node.

  Args:
    dump_root: (str) Root dump directory.
    device_name: (str) Name of the device that the debug node resides on.
    debug_node_name: (str) Name of the debug node, e.g.,
      cross_entropy/Log:0:DebugIdentity.

  Returns:
    (str) Full path of the dump file.
  """

  dump_root = os.path.join(
      dump_root, debug_data.device_name_to_device_path(device_name))
  if "/" in debug_node_name:
    dump_dir = os.path.join(dump_root, os.path.dirname(debug_node_name))
    dump_file_name = re.sub(":", "_", os.path.basename(debug_node_name))
  else:
    dump_dir = dump_root
    dump_file_name = re.sub(":", "_", debug_node_name)

  now_microsec = int(round(time.time() * 1000 * 1000))
  dump_file_name += "_%d" % now_microsec

  return os.path.join(dump_dir, dump_file_name)


class EventListenerTestStreamHandler(
    grpc_debug_server.EventListenerBaseStreamHandler):
  """Implementation of EventListenerBaseStreamHandler that dumps to file."""

  def __init__(self, dump_dir, event_listener_servicer):
    super(EventListenerTestStreamHandler, self).__init__()
    self._dump_dir = dump_dir
    self._event_listener_servicer = event_listener_servicer
    if self._dump_dir:
      self._try_makedirs(self._dump_dir)

    self._grpc_path = None
    self._cached_graph_defs = []
    self._cached_graph_def_device_names = []
    self._cached_graph_def_wall_times = []

  def on_core_metadata_event(self, event):
    self._event_listener_servicer.toggle_watch()

    core_metadata = json.loads(event.log_message.message)

    if not self._grpc_path:
      grpc_path = core_metadata["grpc_path"]
      if grpc_path:
        if grpc_path.startswith("/"):
          grpc_path = grpc_path[1:]
      if self._dump_dir:
        self._dump_dir = os.path.join(self._dump_dir, grpc_path)

        # Write cached graph defs to filesystem.
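        # (These graph defs may have been received via on_graph_def before
        # this core metadata event arrived, i.e., before the final dump
        # directory was known.)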
        for graph_def, device_name, wall_time in zip(
            self._cached_graph_defs,
            self._cached_graph_def_device_names,
            self._cached_graph_def_wall_times):
          self._write_graph_def(graph_def, device_name, wall_time)

    if self._dump_dir:
      self._write_core_metadata_event(event)
    else:
      self._event_listener_servicer.core_metadata_json_strings.append(
          event.log_message.message)

  def on_graph_def(self, graph_def, device_name, wall_time):
    """Implementation of the GraphDef-carrying Event proto callback.

    Args:
      graph_def: A GraphDef object.
      device_name: Name of the device on which the graph was created.
      wall_time: An epoch timestamp (in microseconds) for the graph.
    """
    if self._dump_dir:
      if self._grpc_path:
        self._write_graph_def(graph_def, device_name, wall_time)
      else:
        self._cached_graph_defs.append(graph_def)
        self._cached_graph_def_device_names.append(device_name)
        self._cached_graph_def_wall_times.append(wall_time)
    else:
      self._event_listener_servicer.partition_graph_defs.append(graph_def)

  def on_value_event(self, event):
    """Implementation of the tensor value-carrying Event proto callback.

    Writes the Event proto to the file system for testing. The path written to
    follows the same pattern as the file:// debug URLs of tfdbg, i.e., the
    name scope of the op becomes the directory structure under the dump root
    directory.

    Args:
      event: The Event proto carrying a tensor value.

    Returns:
      If the debug node belongs to the set of currently activated breakpoints,
      an `EventReply` proto will be returned.
    """
    if self._dump_dir:
      self._write_value_event(event)
    else:
      value = event.summary.value[0]
      tensor_value = debug_data.load_tensor_from_event(event)
      self._event_listener_servicer.debug_tensor_values[value.node_name].append(
          tensor_value)

      items = event.summary.value[0].node_name.split(":")
      node_name = items[0]
      output_slot = int(items[1])
      debug_op = items[2]
      if ((node_name, output_slot, debug_op) in
          self._event_listener_servicer.breakpoints):
        return debug_service_pb2.EventReply()

  def _try_makedirs(self, dir_path):
    if not os.path.isdir(dir_path):
      try:
        os.makedirs(dir_path)
      except OSError as error:
        if error.errno != errno.EEXIST:
          raise

  def _write_core_metadata_event(self, event):
    core_metadata_path = os.path.join(
        self._dump_dir,
        debug_data.METADATA_FILE_PREFIX + debug_data.CORE_METADATA_TAG +
        "_%d" % event.wall_time)
    self._try_makedirs(self._dump_dir)
    with open(core_metadata_path, "wb") as f:
      f.write(event.SerializeToString())

  def _write_graph_def(self, graph_def, device_name, wall_time):
    encoded_graph_def = graph_def.SerializeToString()
    graph_hash = int(hashlib.sha1(encoded_graph_def).hexdigest(), 16)
    event = event_pb2.Event(graph_def=encoded_graph_def, wall_time=wall_time)
    graph_file_path = os.path.join(
        self._dump_dir,
        debug_data.device_name_to_device_path(device_name),
        debug_data.METADATA_FILE_PREFIX + debug_data.GRAPH_FILE_TAG +
        debug_data.HASH_TAG + "%d_%d" % (graph_hash, wall_time))
    self._try_makedirs(os.path.dirname(graph_file_path))
    with open(graph_file_path, "wb") as f:
      f.write(event.SerializeToString())

  def _write_value_event(self, event):
    value = event.summary.value[0]

    # Obtain the device name from the metadata.
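    # (The plugin_data content is expected to be a JSON string carrying a
    # "device" key, as parsed below.)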
    summary_metadata = event.summary.value[0].metadata
    if not summary_metadata.plugin_data:
      raise ValueError("The value lacks plugin data.")
    try:
      content = json.loads(compat.as_text(summary_metadata.plugin_data.content))
    except ValueError as err:
      raise ValueError("Could not parse content into JSON: %r, %r" %
                       (summary_metadata.plugin_data.content, err))
    device_name = content["device"]

    dump_full_path = _get_dump_file_path(
        self._dump_dir, device_name, value.node_name)
    self._try_makedirs(os.path.dirname(dump_full_path))
    with open(dump_full_path, "wb") as f:
      f.write(event.SerializeToString())


class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
  """An implementation of EventListenerBaseServicer for testing."""

  def __init__(self, server_port, dump_dir, toggle_watch_on_core_metadata=None):
    """Constructor of EventListenerTestServicer.

    Args:
      server_port: (int) The server port number.
      dump_dir: (str) The root directory to which the data files will be
        dumped. If empty or None, the received debug data will not be dumped
        to the file system: it will be stored in memory instead.
      toggle_watch_on_core_metadata: A list of
        (node_name, output_slot, debug_op) tuples to toggle the
        watchpoint status during the on_core_metadata calls (optional).
    """
    self.core_metadata_json_strings = []
    self.partition_graph_defs = []
    self.debug_tensor_values = collections.defaultdict(list)
    self._initialize_toggle_watch_state(toggle_watch_on_core_metadata)

    grpc_debug_server.EventListenerBaseServicer.__init__(
        self, server_port,
        functools.partial(EventListenerTestStreamHandler, dump_dir, self))

    # Members for storing the graph ops traceback and source files.
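    # (Appended to by the SendTracebacks and SendSourceFiles handlers and
    # read back by the query_* methods below.)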
    self._call_types = []
    self._call_keys = []
    self._origin_stacks = []
    self._origin_id_to_strings = []
    self._graph_tracebacks = []
    self._graph_versions = []
    self._source_files = []

  def _initialize_toggle_watch_state(self, toggle_watches):
    self._toggle_watches = toggle_watches
    self._toggle_watch_state = {}
    if self._toggle_watches:
      for watch_key in self._toggle_watches:
        self._toggle_watch_state[watch_key] = False

  def toggle_watch(self):
    for watch_key in self._toggle_watch_state:
      node_name, output_slot, debug_op = watch_key
      if self._toggle_watch_state[watch_key]:
        self.request_unwatch(node_name, output_slot, debug_op)
      else:
        self.request_watch(node_name, output_slot, debug_op)
      self._toggle_watch_state[watch_key] = (
          not self._toggle_watch_state[watch_key])

  def clear_data(self):
    self.core_metadata_json_strings = []
    self.partition_graph_defs = []
    self.debug_tensor_values = collections.defaultdict(list)
    self._call_types = []
    self._call_keys = []
    self._origin_stacks = []
    self._origin_id_to_strings = []
    self._graph_tracebacks = []
    self._graph_versions = []
    self._source_files = []

  def SendTracebacks(self, request, context):
    self._call_types.append(request.call_type)
    self._call_keys.append(request.call_key)
    self._origin_stacks.append(request.origin_stack)
    self._origin_id_to_strings.append(request.origin_id_to_string)
    self._graph_tracebacks.append(request.graph_traceback)
    self._graph_versions.append(request.graph_version)
    return debug_service_pb2.EventReply()

  def SendSourceFiles(self, request, context):
    self._source_files.append(request)
    return debug_service_pb2.EventReply()

  def query_op_traceback(self, op_name):
    """Query the traceback of an op.

    Args:
      op_name: Name of the op to query.

    Returns:
      The traceback of the op, as a list of 3-tuples:
        (filename, lineno, function_name)

    Raises:
      ValueError: If the op cannot be found in the tracebacks received by the
        server so far.
    """
    for op_log_proto in self._graph_tracebacks:
      for log_entry in op_log_proto.log_entries:
        if log_entry.name == op_name:
          return self._code_def_to_traceback(log_entry.code_def,
                                             op_log_proto.id_to_string)
    raise ValueError(
        "Op '%s' does not exist in the tracebacks received by the debug "
        "server." % op_name)

  def query_origin_stack(self):
    """Query the stack of the origin of the execution call.

    Returns:
      A `list` of all tracebacks. Each item corresponds to an execution call,
      i.e., a `SendTracebacks` request. Each item is a `list` of 3-tuples:
      (filename, lineno, function_name).
    """
    ret = []
    for stack, id_to_string in zip(
        self._origin_stacks, self._origin_id_to_strings):
      ret.append(self._code_def_to_traceback(stack, id_to_string))
    return ret

  def query_call_types(self):
    return self._call_types

  def query_call_keys(self):
    return self._call_keys

  def query_graph_versions(self):
    return self._graph_versions

  def query_source_file_line(self, file_path, lineno):
    """Query the content of a given line in a source file.

    Args:
      file_path: Path to the source file.
      lineno: Line number as an `int`.

    Returns:
      Content of the line as a string.

    Raises:
      ValueError: If no source file is found at the given file_path.
    """
    if not self._source_files:
      raise ValueError(
          "This debug server has not received any source file contents yet.")
    for source_files in self._source_files:
      for source_file_proto in source_files.source_files:
        if source_file_proto.file_path == file_path:
          return source_file_proto.lines[lineno - 1]
    raise ValueError(
        "Source file at path %s has not been received by the debug server." %
        file_path)

  def _code_def_to_traceback(self, code_def, id_to_string):
    return [(id_to_string[trace.file_id],
             trace.lineno,
             id_to_string[trace.function_id]) for trace in code_def.traces]


def start_server_on_separate_thread(dump_to_filesystem=True,
                                    server_start_delay_sec=0.0,
                                    poll_server=False,
                                    blocking=True,
                                    toggle_watch_on_core_metadata=None):
  """Create a test gRPC debug server and run on a separate thread.

  Args:
    dump_to_filesystem: (bool) whether the debug server will dump debug data
      to the filesystem.
    server_start_delay_sec: (float) amount of time (in sec) to delay the server
      start up for.
    poll_server: (bool) whether the server will be polled till success on
      startup.
    blocking: (bool) whether the server should be started in a blocking mode.
    toggle_watch_on_core_metadata: A list of
      (node_name, output_slot, debug_op) tuples to toggle the
      watchpoint status during the on_core_metadata calls (optional).

  Returns:
    server_port: (int) Port on which the server runs.
    debug_server_url: (str) grpc:// URL to the server.
    server_dump_dir: (str) The debug server's dump directory.
    server_thread: The server Thread object.
    server: The `EventListenerTestServicer` object.

  Raises:
    ValueError: If polling the server process for ready state is not successful
      within maximum polling count.
  """
  server_port = portpicker.pick_unused_port()
  debug_server_url = "grpc://localhost:%d" % server_port

  server_dump_dir = tempfile.mkdtemp() if dump_to_filesystem else None
  server = EventListenerTestServicer(
      server_port=server_port,
      dump_dir=server_dump_dir,
      toggle_watch_on_core_metadata=toggle_watch_on_core_metadata)

  def delay_then_run_server():
    time.sleep(server_start_delay_sec)
    server.run_server(blocking=blocking)

  server_thread = threading.Thread(target=delay_then_run_server)
  server_thread.start()

  if poll_server:
    if not _poll_server_till_success(
        50,
        0.2,
        debug_server_url,
        server_dump_dir,
        server,
        gpu_memory_fraction=0.1):
      raise ValueError(
          "Failed to start test gRPC debug server at port %d" % server_port)
    server.clear_data()
  return server_port, debug_server_url, server_dump_dir, server_thread, server


def _poll_server_till_success(max_attempts,
                              sleep_per_poll_sec,
                              debug_server_url,
                              dump_dir,
                              server,
                              gpu_memory_fraction=1.0):
  """Poll server until success or exceeding max polling count.

  Args:
    max_attempts: (int) How many times to poll at maximum.
    sleep_per_poll_sec: (float) How many seconds to sleep for after each
      unsuccessful poll.
    debug_server_url: (str) gRPC URL to the debug server.
    dump_dir: (str) Dump directory to look for files in. If None, will directly
      check data from the server object.
    server: The server object.
    gpu_memory_fraction: (float) Fraction of GPU memory to be
      allocated for the Session used in server polling.

  Returns:
    (bool) Whether the polling succeeded within `max_attempts` attempts.
  """
  poll_count = 0

  config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions(
      per_process_gpu_memory_fraction=gpu_memory_fraction))
  with session.Session(config=config) as sess:
    for poll_count in range(max_attempts):
      server.clear_data()
      print("Polling: poll_count = %d" % poll_count)

      x_init_name = "x_init_%d" % poll_count
      x_init = constant_op.constant([42.0], shape=[1], name=x_init_name)
      x = variables.Variable(x_init, name=x_init_name)

      run_options = config_pb2.RunOptions()
      debug_utils.add_debug_tensor_watch(
          run_options, x_init_name, 0, debug_urls=[debug_server_url])
      try:
        sess.run(x.initializer, options=run_options)
      except errors.FailedPreconditionError:
        pass

      if dump_dir:
        if os.path.isdir(
            dump_dir) and debug_data.DebugDumpDir(dump_dir).size > 0:
          file_io.delete_recursively(dump_dir)
          print("Poll succeeded.")
          return True
        else:
          print("Poll failed. Sleeping for %f s" % sleep_per_poll_sec)
          time.sleep(sleep_per_poll_sec)
      else:
        if server.debug_tensor_values:
          print("Poll succeeded.")
          return True
        else:
          print("Poll failed. Sleeping for %f s" % sleep_per_poll_sec)
          time.sleep(sleep_per_poll_sec)

  return False
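

# Illustrative usage sketch: shows one way a test might start the gRPC debug
# server defined above in in-memory mode, stream a single watched tensor to
# it, and inspect the received values. The node name "x_init" and the constant
# value are arbitrary choices for this example, and stop_server() is assumed
# to be provided by the EventListenerBaseServicer base class.
if __name__ == "__main__":
  _, debug_url, _, server_thread, server = start_server_on_separate_thread(
      dump_to_filesystem=False, poll_server=True)

  with session.Session() as sess:
    x_init = constant_op.constant([10.0], name="x_init")
    x = variables.Variable(x_init, name="x")

    # Watch the output of the "x_init" node and stream it to the debug server.
    run_options = config_pb2.RunOptions()
    debug_utils.add_debug_tensor_watch(
        run_options, "x_init", 0, debug_urls=[debug_url])
    sess.run(x.initializer, options=run_options)

  # With no dump directory, received tensors are kept in memory, keyed by
  # "<node_name>:<output_slot>:<debug_op>".
  print(server.debug_tensor_values["x_init:0:DebugIdentity"])

  server.stop_server().wait()
  server_thread.join()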