# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions to handle debug-dump data of TensorFlow Debugger."""

import collections
import glob
import json
import os
import platform
import re

import numpy as np

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.debug.lib import debug_graphs
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat


# TODO(cais): Tie these string constants in with C++?
METADATA_FILE_PREFIX = "_tfdbg_"
CORE_METADATA_TAG = "core_metadata_"
GRAPH_FILE_TAG = "graph_"
DEVICE_TAG = "device_"
HASH_TAG = "hash"

FETCHES_INFO_FILE_TAG = "fetches_info_"
FEED_KEYS_INFO_FILE_TAG = "feed_keys_info_"


def _glob(glob_pattern):
  if platform.system() == "Windows":
    return glob.glob(glob_pattern)
  else:
    return gfile.Glob(glob_pattern)


class InconvertibleTensorProto:
  """Represents a TensorProto that cannot be converted to np.ndarray."""

  def __init__(self, tensor_proto, initialized=True):
    """Constructor.

    Args:
      tensor_proto: the `TensorProto` object that cannot be represented as a
        `np.ndarray` object.
      initialized: (`bool`) whether the Tensor is initialized.
    """
    self._tensor_proto = tensor_proto
    self._initialized = initialized

  def __str__(self):
    output = "" if self._initialized else "Uninitialized tensor:\n"
    output += str(self._tensor_proto)
    return output

  @property
  def initialized(self):
    return self._initialized


def load_tensor_from_event_file(event_file_path):
  """Load a tensor from an event file.

  Assumes that the event file contains an `Event` protobuf and the `Event`
  protobuf contains a `Tensor` value.

  Args:
    event_file_path: (`str`) path to the event file.

  Returns:
    The tensor value loaded from the event file, as a `numpy.ndarray`. For
    uninitialized Tensors and for Tensors of data types that cannot be
    converted to `numpy.ndarray` (e.g., `tf.resource`), returns an
    `InconvertibleTensorProto` object (see `load_tensor_from_event`).
  """

  event = event_pb2.Event()
  with gfile.Open(event_file_path, "rb") as f:
    event.ParseFromString(f.read())
    return load_tensor_from_event(event)


def load_tensor_from_event(event):
  """Load a tensor from an Event proto.

  Args:
    event: The Event proto, assumed to hold a tensor value in its
      summary.value[0] field.

  Returns:
    The tensor value loaded from the event file, as a `numpy.ndarray`, if
    representation of the tensor value by a `numpy.ndarray` is possible.
    For uninitialized Tensors, returns an `InconvertibleTensorProto` wrapper.
    For Tensors of data types that
    cannot be represented as `numpy.ndarray` (e.g., `tf.resource`), returns
    the `TensorProto` wrapped in an `InconvertibleTensorProto` object, without
    converting it to a `numpy.ndarray`.
  """

  tensor_proto = event.summary.value[0].tensor
  shape = tensor_util.TensorShapeProtoToList(tensor_proto.tensor_shape)
  num_elements = 1
  for shape_dim in shape:
    num_elements *= shape_dim

  if tensor_proto.tensor_content or tensor_proto.string_val or not num_elements:
    # Initialized tensor or empty tensor.
    if tensor_proto.dtype == types_pb2.DT_RESOURCE:
      tensor_value = InconvertibleTensorProto(tensor_proto)
    else:
      try:
        tensor_value = tensor_util.MakeNdarray(tensor_proto)
      except KeyError:
        tensor_value = InconvertibleTensorProto(tensor_proto)
  else:
    # Uninitialized tensor or tensor of unconvertible data type.
    tensor_value = InconvertibleTensorProto(tensor_proto, False)

  return tensor_value


def _load_graph_def_from_event_file(event_file_path):
  event = event_pb2.Event()
  with gfile.Open(event_file_path, "rb") as f:
    event.ParseFromString(f.read())

  return graph_pb2.GraphDef.FromString(event.graph_def)


def _load_log_message_from_event_file(event_file_path):
  event = event_pb2.Event()
  with gfile.Open(event_file_path, "rb") as f:
    event.ParseFromString(f.read())

  return event.log_message.message


def _is_graph_file(file_name):
  return file_name.startswith(METADATA_FILE_PREFIX + GRAPH_FILE_TAG)


def _is_run_fetches_info_file(file_name):
  return file_name == METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG


def _is_run_feed_keys_info_file(file_name):
  return file_name == METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG


def _get_tensor_name(node_name, output_slot):
  """Get tensor name given node name and output slot index.

  Args:
    node_name: Name of the node that outputs the tensor, as a string.
    output_slot: Output slot index of the tensor, as an integer.

  Returns:
    Name of the tensor, as a string.
  """

  return "%s:%d" % (node_name, output_slot)


def _get_tensor_watch_key(node_name, output_slot, debug_op):
  """Get the string representation of a debug watch on a tensor.

  Args:
    node_name: Name of the node by which the watched tensor is produced, as a
      string.
    output_slot: Output slot index of the tensor, as an integer.
    debug_op: Name of the debug op that is used to watch the tensor, as a
      string.

  Returns:
    A string representing the debug watch on the tensor (i.e., the "watch
    key").
  """
  return "%s:%s" % (_get_tensor_name(node_name, output_slot), debug_op)


def has_inf_or_nan(datum, tensor):
  """A predicate for whether a tensor consists of any bad numerical values.

  This predicate is common enough to merit definition in this module.
  Bad numerical values include `nan`s and `inf`s.
  The signature of this function follows the requirement of the method
  `DebugDumpDir.find()`.

  Args:
    datum: (`DebugTensorDatum`) Datum metadata.
    tensor: (`numpy.ndarray` or None) Value of the tensor. None represents
      an uninitialized tensor.

  Returns:
    (`bool`) True if and only if tensor consists of any nan or inf values.
  """

  _ = datum  # Datum metadata is unused in this predicate.

  if isinstance(tensor, InconvertibleTensorProto):
    # Uninitialized tensor doesn't have bad numerical values.
    # Also return False for data types that cannot be represented as numpy
    # arrays.
    return False
  elif (np.issubdtype(tensor.dtype, np.floating) or
        np.issubdtype(tensor.dtype, np.complexfloating) or
        np.issubdtype(tensor.dtype, np.integer)):
    return np.any(np.isnan(tensor)) or np.any(np.isinf(tensor))
  else:
    return False


_CoreMetadata = collections.namedtuple("CoreMetadata", [
    "global_step", "session_run_index", "executor_step_index", "input_names",
    "output_names", "target_nodes"
])


def extract_core_metadata_from_event_proto(event):
  json_metadata = json.loads(event.log_message.message)
  return _CoreMetadata(json_metadata["global_step"],
                       json_metadata["session_run_index"],
                       json_metadata["executor_step_index"],
                       json_metadata["input_names"],
                       json_metadata["output_names"],
                       json_metadata["target_nodes"])


def device_name_to_device_path(device_name):
  """Convert device name to device path."""
  device_name_items = compat.as_text(device_name).split("/")
  device_name_items = [item.replace(":", "_") for item in device_name_items]
  return METADATA_FILE_PREFIX + DEVICE_TAG + ",".join(device_name_items)


def device_path_to_device_name(device_dir):
  """Parse device name from device path.

  Args:
    device_dir: (str) a directory name for the device.

  Returns:
    (str) parsed device name.
  """
  path_items = os.path.basename(device_dir)[
      len(METADATA_FILE_PREFIX) + len(DEVICE_TAG):].split(",")
  return "/".join([
      path_item.replace("device_", "device:").replace("_", ":", 1)
      for path_item in path_items])


class DebugTensorDatum:
  """A single tensor dumped by TensorFlow Debugger (tfdbg).

  Contains metadata about the dumped tensor, including `timestamp`,
  `node_name`, `output_slot`, `debug_op`, and path to the dump file
  (`file_path`).

  This type does not hold the generally space-expensive tensor value (numpy
  array). Instead, it points to the file from which the tensor value can be
  loaded (with the `get_tensor` method) if needed.
  """

  def __init__(self, dump_root, debug_dump_rel_path):
    """`DebugTensorDatum` constructor.

    Args:
      dump_root: (`str`) Debug dump root directory. This path should not include
        the path component that represents the device name (see also below).
      debug_dump_rel_path: (`str`) Path to a debug dump file, relative to the
        `dump_root`. The first item of this relative path is assumed to be
        a path representing the name of the device that the Tensor belongs to.
        See `device_path_to_device_name` for more details on the device path.
        For example, suppose the debug dump root
        directory is `/tmp/tfdbg_1` and the dump file is at
        `/tmp/tfdbg_1/<device_path>/ns_1/node_a_0_DebugIdentity_123456789`,
        then the value of the debug_dump_rel_path should be
        `<device_path>/ns_1/node_a_0_DebugIdentity_123456789`.

    Raises:
      ValueError: If the base file name of the dump file does not conform to
        the dump file naming pattern:
        `node_name`_`output_slot`_`debug_op`_`timestamp`
    """

    path_components = os.path.normpath(debug_dump_rel_path).split(os.sep)
    self._device_name = device_path_to_device_name(path_components[0])
    base = path_components[-1]
    if base.count("_") < 3:
      raise ValueError(
          "Dump file path does not conform to the naming pattern: %s" % base)

    self._extended_timestamp = base.split("_")[-1]
    # It may include an index suffix at the end if file path collision happened
    # due to identical timestamps.
    if "-" in self._extended_timestamp:
      self._timestamp = int(
          self._extended_timestamp[:self._extended_timestamp.find("-")])
    else:
      self._timestamp = int(self._extended_timestamp)

    self._debug_op = base.split("_")[-2]
    self._output_slot = int(base.split("_")[-3])

    node_base_name = "_".join(base.split("_")[:-3])
    self._node_name = "/".join(path_components[1:-1] + [node_base_name])

    self._file_path = os.path.join(dump_root, debug_dump_rel_path)
    self._dump_size_bytes = (gfile.Stat(self._file_path).length if
                             gfile.Exists(self._file_path) else None)

  def __str__(self):
    return "{DebugTensorDatum (%s) %s:%d @ %s @ %d}" % (self.device_name,
                                                        self.node_name,
                                                        self.output_slot,
                                                        self.debug_op,
                                                        self.timestamp)

  def __repr__(self):
    return self.__str__()

  def get_tensor(self):
    """Get tensor from the dump (`Event`) file.

    Returns:
      The tensor loaded from the dump (`Event`) file.
    """

    return load_tensor_from_event_file(self.file_path)

  # TODO(cais): Add time unit suffix to timestamp and t0 (us).
  @property
  def timestamp(self):
    """Timestamp of when this tensor value was dumped.

    Returns:
      (`int`) The timestamp in microseconds.
    """

    return self._timestamp

  @property
  def extended_timestamp(self):
    """Extended timestamp, possibly with an index suffix.

    The index suffix, e.g., "-1", is for disambiguating multiple dumps of the
    same tensor with the same timestamp, which can occur if the dumping events
    are spaced apart by less than the temporal resolution of the timestamps.

    Returns:
      (`str`) The extended timestamp.
    """

    return self._extended_timestamp

  @property
  def debug_op(self):
    """Name of the debug op.

    Returns:
      (`str`) debug op name (e.g., `DebugIdentity`).
    """

    return self._debug_op

  @property
  def device_name(self):
    """Name of the device that the tensor belongs to.

    Returns:
      (`str`) device name.
    """

    return self._device_name

  @property
  def node_name(self):
    """Name of the node from which the tensor value was dumped.

    Returns:
      (`str`) name of the node watched by the debug op.
    """

    return self._node_name

  @property
  def output_slot(self):
    """Output slot index from which the tensor value was dumped.

    Returns:
      (`int`) output slot index watched by the debug op.
    """

    return self._output_slot

  @property
  def tensor_name(self):
    """Name of the tensor watched by the debug op.

    Returns:
      (`str`) `Tensor` name, in the form of `node_name`:`output_slot`
    """

    return _get_tensor_name(self.node_name, self.output_slot)

  @property
  def watch_key(self):
    """Watch key that identifies a debug watch on a tensor.

    Returns:
      (`str`) A watch key, in the form of `tensor_name`:`debug_op`.
    """

    return _get_tensor_watch_key(self.node_name, self.output_slot,
                                 self.debug_op)

  @property
  def file_path(self):
    """Path to the file which stores the value of the dumped tensor."""

    return self._file_path

  @property
  def dump_size_bytes(self):
    """Size of the dump file.

    Unit: byte.

    Returns:
      If the dump file exists, size of the dump file, in bytes.
      If the dump file does not exist, None.
    """

    return self._dump_size_bytes


class WatchKeyDoesNotExistInDebugDumpDirError(ValueError):
  pass


class DebugDumpDir:
  """Data set from a debug-dump directory on filesystem.

  An instance of `DebugDumpDir` contains all `DebugTensorDatum` instances
  in a tfdbg dump root directory.
  """

  def __init__(self, dump_root, partition_graphs=None, validate=True):
    """`DebugDumpDir` constructor.

    Args:
      dump_root: (`str`) path to the dump root directory.
      partition_graphs: A repeated field of GraphDefs representing the
        partition graphs executed by the TensorFlow runtime.
      validate: (`bool`) whether the dump files are to be validated against the
        partition graphs.

    Raises:
      IOError: If dump_root does not exist as a directory.
      ValueError: If more than one core metadata file is found under the dump
        root directory.
    """

    if not gfile.IsDirectory(dump_root):
      raise IOError("Dump root directory %s does not exist" % dump_root)

    self._core_metadata = []

    # Find the list of devices.
    self._dump_root = dump_root

    self._load_core_metadata()
    self._load_fetches_info()
    self._load_feeds_info()
    self._load_all_device_dumps(partition_graphs, validate)

    self._python_graph = None

  def _load_all_device_dumps(self, partition_graphs, validate):
    """Load the dump data for all devices."""
    device_dirs = _glob(os.path.join(
        self._dump_root, METADATA_FILE_PREFIX + DEVICE_TAG + "*"))

    self._device_names = []
    self._t0s = {}
    self._dump_tensor_data = {}
    self._dump_graph_file_paths = {}
    self._debug_watches = {}
    self._watch_key_to_devices = {}
    self._watch_key_to_datum = {}
    self._watch_key_to_rel_time = {}
    self._watch_key_to_dump_size_bytes = {}
    for device_dir in device_dirs:
      device_name = device_path_to_device_name(device_dir)
      self._device_names.append(device_name)
      self._load_device_dumps(device_name, device_dir)
    self._load_partition_graphs(partition_graphs, validate)
    self._calculate_t0()

    for device_name in self._device_names:
      self._create_tensor_watch_maps(device_name)

  def _load_device_dumps(self, device_name, device_root):
    """Load `DebugTensorDatum` instances from the dump root of a given device.

    Populates a map {device_name: a list of `DebugTensorDatum`}, where the list
    is sorted by ascending timestamp.

    This sorting order reflects the order in which the TensorFlow executor
    processed the nodes of the graph. It is one of the many possible
    topological sorts of the nodes.
    This is useful for displaying tensors in the debugger
    frontend as well as for the use case in which the user wants to find a
    "culprit tensor", i.e., the first tensor in the graph that exhibits certain
    problematic properties, e.g., all-zero values or bad numerical values such
    as nan and inf.

    In addition, creates a map from node name to debug watches. In this map,
    the key is the watched node name; the value is a dictionary that maps each
    watched output slot to the set of debug ops watching that slot.

    This method attempts to load the debug watches from the tensor dump files
    first, before loading the full set of debug watches from the partition
    graphs as done later. This is necessary because sometimes the partition
    graphs may not be available, e.g., when the run errors out.

    Args:
      device_name: (`str`) name of the device.
      device_root: (`str`) dump root directory of the given device.

    Raises:
      ValueError: If GraphDef for the device is not available.
    """

    self._dump_tensor_data[device_name] = []
    self._debug_watches[device_name] = collections.defaultdict(
        lambda: collections.defaultdict(set))

    for root, _, files in gfile.Walk(device_root):
      for f in files:
        if _is_graph_file(f):
          self._dump_graph_file_paths[device_name] = os.path.join(root, f)
        else:
          datum = self._dump_file_name_to_datum(root, f)
          self._dump_tensor_data[device_name].append(datum)
          self._debug_watches[device_name][datum.node_name][
              datum.output_slot].add(datum.debug_op)

    self._dump_tensor_data[device_name] = sorted(
        self._dump_tensor_data[device_name],
        key=lambda x: x.extended_timestamp)

    if self._dump_tensor_data[device_name]:
      self._t0s[device_name] = self._dump_tensor_data[device_name][0].timestamp
    else:
      self._t0s[device_name] = None

  def _calculate_t0(self):
    """Calculate the first timestamp across all devices."""
    t0s = [t0 for t0 in self._t0s.values() if t0 is not None]
    self._t0 = min(t0s) if t0s else None

  def _load_core_metadata(self):
    core_metadata_files = _glob(os.path.join(
        self._dump_root, METADATA_FILE_PREFIX + CORE_METADATA_TAG + "*"))
    for core_metadata_file in core_metadata_files:
      with gfile.Open(core_metadata_file, "rb") as f:
        event = event_pb2.Event()
        event.ParseFromString(f.read())
        self._core_metadata.append(
            extract_core_metadata_from_event_proto(event))

  def _load_fetches_info(self):
    fetches_info_files = _glob(os.path.join(
        self._dump_root, METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG + "*"))
    self._run_fetches_info = []
    for fetches_info_file in fetches_info_files:
      self._run_fetches_info.append(
          _load_log_message_from_event_file(fetches_info_file))

  def _load_feeds_info(self):
    feeds_info_files = _glob(os.path.join(
        self._dump_root, METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG + "*"))
    self._run_feed_keys_info = []
    for feeds_info_file in feeds_info_files:
      self._run_feed_keys_info.append(
          _load_log_message_from_event_file(feeds_info_file))

  def _dump_file_name_to_datum(self, dir_name, file_name):
    """Obtain a DebugTensorDatum from the directory and file name.

    Args:
      dir_name: (`str`) Name of the directory in which the dump file resides.
      file_name: (`str`) Base name of the dump file.

    Returns:
      (`DebugTensorDatum`) The `DebugTensorDatum` loaded from the dump file.
613 """ 614 615 # Calculate the relative path of the dump file with respect to the root. 616 debug_dump_rel_path = os.path.join( 617 os.path.relpath(dir_name, self._dump_root), file_name) 618 return DebugTensorDatum(self._dump_root, debug_dump_rel_path) 619 620 def _create_tensor_watch_maps(self, device_name): 621 """Create maps from tensor watch keys to datum and to timestamps. 622 623 Create a map from watch key (tensor name + debug op) to `DebugTensorDatum` 624 item. Also make a map from watch key to relative timestamp. 625 "relative" means (absolute timestamp - t0). 626 627 Args: 628 device_name: (str) name of the device. 629 """ 630 631 self._watch_key_to_datum[device_name] = {} 632 self._watch_key_to_rel_time[device_name] = {} 633 self._watch_key_to_dump_size_bytes[device_name] = {} 634 for datum in self._dump_tensor_data[device_name]: 635 if datum.watch_key not in self._watch_key_to_devices: 636 self._watch_key_to_devices[datum.watch_key] = {device_name} 637 else: 638 self._watch_key_to_devices[datum.watch_key].add(device_name) 639 640 if datum.watch_key not in self._watch_key_to_datum[device_name]: 641 self._watch_key_to_datum[device_name][datum.watch_key] = [datum] 642 self._watch_key_to_rel_time[device_name][datum.watch_key] = [ 643 datum.timestamp - self._t0] 644 self._watch_key_to_dump_size_bytes[device_name][datum.watch_key] = [ 645 datum.dump_size_bytes] 646 else: 647 self._watch_key_to_datum[device_name][datum.watch_key].append(datum) 648 self._watch_key_to_rel_time[device_name][datum.watch_key].append( 649 datum.timestamp - self._t0) 650 self._watch_key_to_dump_size_bytes[device_name][datum.watch_key].append( 651 datum.dump_size_bytes) 652 653 def set_python_graph(self, python_graph): 654 """Provide Python `Graph` object to the wrapper. 655 656 Unlike the partition graphs, which are protobuf `GraphDef` objects, `Graph` 657 is a Python object and carries additional information such as the traceback 658 of the construction of the nodes in the graph. 659 660 Args: 661 python_graph: (ops.Graph) The Python Graph object. 662 """ 663 664 self._python_graph = python_graph 665 self._node_traceback = {} 666 if self._python_graph: 667 for op in self._python_graph.get_operations(): 668 self._node_traceback[op.name] = tuple(map(tuple, op.traceback)) 669 670 @property 671 def python_graph(self): 672 """Get the Python graph. 673 674 Returns: 675 If the Python graph has been set, returns a `tf.Graph` object. Otherwise, 676 returns None. 677 """ 678 679 return self._python_graph 680 681 @property 682 def core_metadata(self): 683 """Metadata about the `Session.run()` call from the core runtime. 684 685 Of the three counters available in the return value, `global_step` is 686 supplied by the caller of the debugged `Session.run()`, while 687 `session_run_index` and `executor_step_index` are determined by the state 688 of the core runtime, automatically. For the same fetch list, feed keys and 689 debug tensor watch options, the same executor will be used and 690 `executor_step_index` should increase by one at a time. However, runs with 691 different fetch lists, feed keys and debug_tensor watch options that all 692 share the same `Session` object can lead to gaps in `session_run_index`. 693 694 Returns: 695 If core metadata are loaded, a `namedtuple` with the fields: 696 `global_step`: A global step count supplied by the caller of 697 `Session.run()`. It is optional to the caller. If the caller did not 698 supply this parameter, its value will be -1. 
        `session_run_index`: A sorted index for Run() calls to the underlying
          TensorFlow `Session` object.
        `executor_step_index`: A counter for invocations of a given runtime
          executor. The same executor is re-used for the same fetched tensors,
          target nodes, input feed keys and debug tensor watch options.
        `input_names`: Names of the input (feed) Tensors.
        `output_names`: Names of the output (fetched) Tensors.
        `target_nodes`: Names of the target nodes.
      If the core metadata have not been loaded, `None`.
      If more than one core metadata file exists, returns a list of the
        `namedtuple`s described above.
    """

    output = self._core_metadata
    return output[0] if len(output) == 1 else output

  @property
  def dumped_tensor_data(self):
    """Retrieve dumped tensor data."""
    if len(self.devices()) == 1:
      return self._dump_tensor_data[self.devices()[0]]
    else:
      all_devices_data = self._dump_tensor_data.values()
      data = []
      for device_data in all_devices_data:
        data.extend(device_data)
      return sorted(data, key=lambda x: x.extended_timestamp)

  @property
  def t0(self):
    """Absolute timestamp of the first dumped tensor across all devices.

    Returns:
      (`int`) absolute timestamp of the first dumped tensor, in microseconds.
    """
    return self._t0

  @property
  def size(self):
    """Total number of dumped tensors in the dump root directory.

    Returns:
      (`int`) The total number of dumped tensors in the dump root directory.
    """
    return sum(len(self._dump_tensor_data[device_name])
               for device_name in self._dump_tensor_data)

  def _load_partition_graphs(self, client_partition_graphs, validate):
    """Load and process partition graphs.

    Load the graphs; parse the input and control input structure; obtain the
    device and op type of each node; remove the Copy and debug ops inserted
    by the debugger. The gathered information can be used to validate the
    tensor dumps.

    Args:
      client_partition_graphs: A repeated field of GraphDefs representing the
        partition graphs executed by the TensorFlow runtime, from the Python
        client. These partition graphs are used only if partition graphs
        cannot be loaded from the dump directory on the file system.
      validate: (`bool`) Whether the dump files are to be validated against the
        partition graphs.

    Raises:
      ValueError: If the partition GraphDef of one or more devices fails to be
        loaded.
    """
    self._debug_graphs = {}
    self._node_devices = {}

    partition_graphs_and_device_names = []
    for device_name in self._device_names:
      partition_graph = None
      if device_name in self._dump_graph_file_paths:
        partition_graph = _load_graph_def_from_event_file(
            self._dump_graph_file_paths[device_name])
      else:
        logging.warn(
            "Failed to load partition graphs for device %s from disk. "
            "As a fallback, the client graphs will be used. This "
            "may cause mismatches in device names."
            % device_name)
        partition_graph = self._find_partition_graph(client_partition_graphs,
                                                     device_name)

      if partition_graph:
        partition_graphs_and_device_names.append((partition_graph,
                                                  device_name))

    for partition_graph, maybe_device_name in partition_graphs_and_device_names:
      debug_graph = debug_graphs.DebugGraph(partition_graph,
                                            device_name=maybe_device_name)
      self._debug_graphs[debug_graph.device_name] = debug_graph
      self._collect_node_devices(debug_graph)

      if validate and debug_graph.device_name in self._dump_tensor_data:
        self._validate_dump_with_graphs(debug_graph.device_name)

  def _find_partition_graph(self, partition_graphs, device_name):
    if partition_graphs is None:
      return None
    else:
      for graph_def in partition_graphs:
        for node_def in graph_def.node:
          if node_def.device == device_name:
            return graph_def
      return None

  def _collect_node_devices(self, debug_graph):
    for node_name in debug_graph.node_devices:
      if node_name in self._node_devices:
        self._node_devices[node_name] = self._node_devices[node_name].union(
            debug_graph.node_devices[node_name])
      else:
        self._node_devices[node_name] = debug_graph.node_devices[node_name]

  def _validate_dump_with_graphs(self, device_name):
    """Validate the dumped tensor data against the partition graphs.

    Only the watched nodes are validated by this method, because tfdbg allows
    clients to watch only a subset of the nodes.

    Args:
      device_name: (`str`) device name.

    Raises:
      LookupError: If the partition graphs have not been loaded yet.
      ValueError: If dumps contain node names not found in the partition
        graphs, or if the temporal order of the dump's timestamps violates the
        input relations on the partition graphs.
    """
    if not self._debug_graphs:
      raise LookupError(
          "No partition graphs loaded for device %s" % device_name)
    debug_graph = self._debug_graphs[device_name]

    # Verify that the node names in the dump data are all present in the
    # partition graphs.
    for datum in self._dump_tensor_data[device_name]:
      if datum.node_name not in debug_graph.node_inputs:
        raise ValueError("Node name '%s' is not found in partition graphs of "
                         "device %s." % (datum.node_name, device_name))

    pending_inputs = {}
    for node in debug_graph.node_inputs:
      pending_inputs[node] = []
      inputs = debug_graph.node_inputs[node]
      for inp in inputs:
        inp_node = debug_graphs.get_node_name(inp)
        inp_output_slot = debug_graphs.get_output_slot(inp)
        # Inputs from Enter and NextIteration nodes are not validated because
        # DebugNodeInserter::InsertNodes() in the debugger core skips creating
        # control edges from debug ops watching these types of nodes.
        if (inp_node in self._debug_watches[device_name] and
            inp_output_slot in self._debug_watches[device_name][inp_node] and
            debug_graph.node_op_types.get(inp) not in (
                "Enter", "NextIteration") and
            (inp_node, inp_output_slot) not in pending_inputs[node]):
          pending_inputs[node].append((inp_node, inp_output_slot))

    for i, datum in enumerate(self._dump_tensor_data[device_name]):
      node = datum.node_name
      slot = datum.output_slot
      # In some cases (e.g., system clocks with insufficient precision),
      # the upstream and downstream tensors may have identical timestamps; the
      # following check examines this possibility and avoids raising an error
      # in that case.
      if not self._satisfied_at_timestamp(
          device_name, pending_inputs[node], datum.timestamp, start_i=i + 1):
        raise ValueError("Causality violated in timing relations of debug "
                         "dumps: %s (%d): "
                         "these input(s) are not satisfied: %s" %
                         (node, datum.timestamp, repr(pending_inputs[node])))

      recipients = debug_graph.node_recipients[node]
      for recipient in recipients:
        recipient_pending_inputs = pending_inputs[recipient]
        if (node, slot) in recipient_pending_inputs:
          if self.node_op_type(recipient) == "Merge":
            # If this is a Merge op, we automatically clear the list because
            # a Merge node only requires one of its two inputs.
            del recipient_pending_inputs[:]
          else:
            del recipient_pending_inputs[
                recipient_pending_inputs.index((node, slot))]

  def _satisfied_at_timestamp(self, device_name, pending, timestamp, start_i=0):
    """Determine whether pending inputs are satisfied at given timestamp.

    Note: This method mutates the input argument "pending".

    Args:
      device_name: (str) device name.
      pending: A list of 2-tuple (node_name, output_slot): the dependencies to
        check.
      timestamp: (int) the timestamp in question.
      start_i: (int) the index in self._dump_tensor_data to start searching for
        the timestamp.

    Returns:
      (bool) Whether all the dependencies in pending are satisfied at the
        timestamp. If pending is empty to begin with, returns True.
    """
    if not pending:
      return True

    for datum in self._dump_tensor_data[device_name][start_i:]:
      if datum.timestamp > timestamp:
        break
      if (datum.timestamp == timestamp and
          (datum.node_name, datum.output_slot) in pending):
        pending.remove((datum.node_name, datum.output_slot))
        if not pending:
          return True

    return not pending

  def loaded_partition_graphs(self):
    """Test whether partition graphs have been loaded."""
    return bool(self._debug_graphs)

  def partition_graphs(self):
    """Get the partition graphs.

    Returns:
      Partition graphs as a list of GraphDef.

    Raises:
      LookupError: If no partition graphs have been loaded.
    """
    if not self._debug_graphs:
      raise LookupError("No partition graphs have been loaded.")
    return [self._debug_graphs[key].debug_graph_def
            for key in self._debug_graphs]

  def reconstructed_non_debug_partition_graphs(self):
    """Reconstruct partition graphs with the debugger-inserted ops stripped.

    The reconstructed partition graphs are identical to the original (i.e.,
    non-debugger-decorated) partition graphs except in the following respects:
      1) The exact names of the runtime-inserted internal nodes may differ.
         These include _Send, _Recv, _HostSend, _HostRecv, _Retval ops.
      2) As a consequence of 1, the nodes that receive input directly from such
         send- and recv-type ops will have different input names.
      3) The parallel_iteration attribute of while-loop Enter ops is set to 1.

    Returns:
      A dict mapping device names (`str`s) to reconstructed
      `tf.compat.v1.GraphDef`s.
    """
    non_debug_graphs = {}
    for key in self._debug_graphs:
      non_debug_graphs[key] = self._debug_graphs[key].non_debug_graph_def
    return non_debug_graphs

  @property
  def run_fetches_info(self):
    """Get a str representation of the fetches used in the Session.run() call.

    Returns:
      If the information is available from one `Session.run` call, a `str`
        obtained from `repr(fetches)`.
      If the information is available from multiple `Session.run` calls, a
        `list` of `str` from `repr(fetches)`.
      If the information is not available, `None`.
    """

    output = self._run_fetches_info
    return output[0] if len(output) == 1 else output

  @property
  def run_feed_keys_info(self):
    """Get a str representation of the feed_dict used in the Session.run() call.

    Returns:
      If the information is available from one `Session.run` call, a `str`
        obtained from `repr(feed_dict)`.
      If the information is available from multiple `Session.run` calls, a
        `list` of `str` obtained from `repr(feed_dict)`.
      If the information is not available, `None`.
    """

    output = self._run_feed_keys_info
    return output[0] if len(output) == 1 else output

  def _infer_device_name(self, device_name, node_name):
    """Infer the device name given node name.

    If device_name is provided (i.e., not None), it is simply returned right
    away.

    Args:
      device_name: (str or None) name of the device. If None, will try to infer
        the device name by looking at the available nodes.
      node_name: (str) name of the node.

    Returns:
      (str) Inferred name of the device, if available.

    Raises:
      ValueError: If the node name does not exist on any of the available
        devices or if there are multiple devices that contain the node with
        the given name.
    """
    if device_name is None:
      if node_name in self._node_devices:
        if len(self._node_devices[node_name]) == 1:
          return list(self._node_devices[node_name])[0]
        else:
          raise ValueError(
              "There are multiple (%d) devices with nodes named '%s' but "
              "device_name is not specified." %
              (len(self._node_devices[node_name]), node_name))
      else:
        raise ValueError("None of the %d device(s) has a node named '%s'." %
                         (len(self._device_names), node_name))
    else:
      return device_name

  def nodes(self, device_name=None):
    """Get a list of all nodes from the partition graphs.

    Args:
      device_name: (`str`) name of device. If None, all nodes from all available
        devices will be included.

    Returns:
      All nodes' names, as a list of str.

    Raises:
      LookupError: If no partition graphs have been loaded.
      ValueError: If the specified device name does not exist.
    """
    if not self._debug_graphs:
      raise LookupError("No partition graphs have been loaded.")
    if device_name is None:
      nodes = []
      for device_name in self._debug_graphs:
        nodes.extend(self._debug_graphs[device_name].node_inputs.keys())
      return nodes
    else:
      if device_name not in self._debug_graphs:
        raise ValueError("Invalid device name: %s" % device_name)
      return self._debug_graphs[device_name].node_inputs.keys()

  def node_attributes(self, node_name, device_name=None):
    """Get the attributes of a node.

    Args:
      node_name: Name of the node in question.
      device_name: (`str`) name of the device. If there is only one device or if
        node_name exists on only one device, this argument is optional.

    Returns:
      Attributes of the node.

    Raises:
      LookupError: If no partition graphs have been loaded.
1056 """ 1057 if not self._debug_graphs: 1058 raise LookupError("No partition graphs have been loaded.") 1059 1060 device_name = self._infer_device_name(device_name, node_name) 1061 return self._debug_graphs[device_name].node_attributes[node_name] 1062 1063 def node_inputs(self, node_name, is_control=False, device_name=None): 1064 """Get the inputs of given node according to partition graphs. 1065 1066 Args: 1067 node_name: Name of the node. 1068 is_control: (`bool`) Whether control inputs, rather than non-control 1069 inputs, are to be returned. 1070 device_name: (`str`) name of the device. If there is only one device or if 1071 node_name exists on only one device, this argument is optional. 1072 1073 Returns: 1074 (`list` of `str`) inputs to the node, as a list of node names. 1075 1076 Raises: 1077 LookupError: If node inputs and control inputs have not been loaded 1078 from partition graphs yet. 1079 """ 1080 if not self._debug_graphs: 1081 raise LookupError( 1082 "Node inputs are not loaded from partition graphs yet.") 1083 1084 device_name = self._infer_device_name(device_name, node_name) 1085 if is_control: 1086 return self._debug_graphs[device_name].node_ctrl_inputs[node_name] 1087 else: 1088 return self._debug_graphs[device_name].node_inputs[node_name] 1089 1090 def transitive_inputs(self, 1091 node_name, 1092 include_control=True, 1093 include_reversed_ref=False, 1094 device_name=None,): 1095 """Get the transitive inputs of given node according to partition graphs. 1096 1097 Args: 1098 node_name: Name of the node. 1099 include_control: Include control inputs (True by default). 1100 include_reversed_ref: Whether a ref input, say from A to B, is to be also 1101 considered as an input from B to A. The rationale is that ref inputs 1102 generally let the recipient (e.g., B in this case) mutate the value of 1103 the source (e.g., A in this case). So the reverse direction of the ref 1104 edge reflects the direction of information flow. 1105 device_name: (`str`) name of the device. If there is only one device or if 1106 node_name exists on only one device, this argument is optional. 1107 1108 Returns: 1109 (`list` of `str`) all transitive inputs to the node, as a list of node 1110 names. 1111 1112 Raises: 1113 LookupError: If node inputs and control inputs have not been loaded 1114 from partition graphs yet. 
1115 """ 1116 if not self._debug_graphs: 1117 raise LookupError( 1118 "Node inputs are not loaded from partition graphs yet.") 1119 1120 device_name = self._infer_device_name(device_name, node_name) 1121 1122 input_lists = [self._debug_graphs[device_name].node_inputs] 1123 if include_control: 1124 input_lists.append(self._debug_graphs[device_name].node_ctrl_inputs) 1125 if include_reversed_ref: 1126 input_lists.append( 1127 self._debug_graphs[device_name].node_reversed_ref_inputs) 1128 tracer = debug_graphs.DFSGraphTracer( 1129 input_lists, 1130 skip_node_names=self._get_merge_node_names(device_name)) 1131 tracer.trace(node_name) 1132 return tracer.inputs() 1133 1134 def _get_merge_node_names(self, device_name): 1135 """Lazily get a list of Merge nodes on a given device.""" 1136 if device_name not in self._device_names: 1137 raise ValueError("Invalid device name: %s" % device_name) 1138 1139 if not hasattr(self, "_merge_node_names"): 1140 self._merge_node_names = {} 1141 if device_name not in self._merge_node_names: 1142 debug_graph = self._debug_graphs[device_name] 1143 self._merge_node_names[device_name] = [ 1144 node for node in debug_graph.node_op_types 1145 if debug_graph.node_op_types[node] == "Merge"] 1146 return self._merge_node_names[device_name] 1147 1148 def find_some_path(self, 1149 src_node_name, 1150 dst_node_name, 1151 include_control=True, 1152 include_reversed_ref=False, 1153 device_name=None): 1154 """Find a path between a source node and a destination node. 1155 1156 Limitation: the source and destination are required to be on the same 1157 device, i.e., this method does not yet take into account Send/Recv nodes 1158 across devices. 1159 1160 TODO(cais): Make this method work across device edges by tracing Send/Recv 1161 nodes. 1162 1163 Args: 1164 src_node_name: (`str`) name of the source node or name of an output tensor 1165 of the node. 1166 dst_node_name: (`str`) name of the destination node or name of an output 1167 tensor of the node. 1168 include_control: (`bool`) whrther control edges are considered in the 1169 graph tracing. 1170 include_reversed_ref: Whether a ref input, say from A to B, is to be also 1171 considered as an input from B to A. The rationale is that ref inputs 1172 generally let the recipient (e.g., B in this case) mutate the value of 1173 the source (e.g., A in this case). So the reverse direction of the ref 1174 edge reflects the direction of information flow. 1175 device_name: (`str`) name of the device. If there is only one device or if 1176 node_name exists on only one device, this argument is optional. 1177 1178 Returns: 1179 A path from the src_node_name to dst_node_name, as a `list` of `str`, if 1180 it exists. The list includes src_node_name as the first item and 1181 dst_node_name as the last. 1182 If such a path does not exist, `None`. 1183 1184 Raises: 1185 ValueError: If the source and destination nodes are not on the same 1186 device. 1187 """ 1188 src_device_name = self._infer_device_name(device_name, src_node_name) 1189 dst_device_name = self._infer_device_name(device_name, dst_node_name) 1190 1191 if src_device_name != dst_device_name: 1192 raise ValueError( 1193 "Source (%s) and destination (%s) are not on the same device: " 1194 "%s vs. 
%s" % (src_node_name, dst_node_name, src_device_name, 1195 dst_device_name)) 1196 1197 input_lists = [self._debug_graphs[dst_device_name].node_inputs] 1198 debug_graph = self._debug_graphs[dst_device_name] 1199 if include_control: 1200 input_lists.append(debug_graph.node_ctrl_inputs) 1201 if include_reversed_ref: 1202 input_lists.append(debug_graph.node_reversed_ref_inputs) 1203 tracer = debug_graphs.DFSGraphTracer( 1204 input_lists, 1205 skip_node_names=self._get_merge_node_names(dst_device_name), 1206 destination_node_name=src_node_name) 1207 # Here the value of destination_node_name is src_node_name, because we 1208 # are tracing the graph from output to its inputs (i.e., going backwards 1209 # on the graph). 1210 1211 try: 1212 tracer.trace(dst_node_name) 1213 except debug_graphs.GraphTracingReachedDestination: 1214 # Prune nodes not on the path. 1215 inputs = [dst_node_name] + tracer.inputs() 1216 depth_list = [0] + tracer.depth_list() 1217 1218 path = [] 1219 curr_depth = depth_list[-1] 1220 for inp, depth in zip(reversed(inputs), reversed(depth_list)): 1221 if depth == curr_depth: 1222 path.append(inp) 1223 curr_depth -= 1 1224 return path 1225 1226 def node_recipients(self, node_name, is_control=False, device_name=None): 1227 """Get recipient of the given node's output according to partition graphs. 1228 1229 Args: 1230 node_name: (`str`) name of the node. 1231 is_control: (`bool`) whether control outputs, rather than non-control 1232 outputs, are to be returned. 1233 device_name: (`str`) name of the device. If there is only one device or if 1234 node_name exists on only one device, this argument is optional. 1235 1236 Returns: 1237 (`list` of `str`) all inputs to the node, as a list of node names. 1238 1239 Raises: 1240 LookupError: If node inputs and control inputs have not been loaded 1241 from partition graphs yet. 1242 """ 1243 1244 if not self._debug_graphs: 1245 raise LookupError( 1246 "Node recipients are not loaded from partition graphs yet.") 1247 1248 device_name = self._infer_device_name(device_name, node_name) 1249 debug_graph = self._debug_graphs[device_name] 1250 if is_control: 1251 return debug_graph.node_ctrl_recipients[node_name] 1252 else: 1253 return debug_graph.node_recipients[node_name] 1254 1255 def devices(self): 1256 """Get the list of device names. 1257 1258 Returns: 1259 (`list` of `str`) names of the devices. 1260 """ 1261 return self._device_names 1262 1263 def node_exists(self, node_name, device_name=None): 1264 """Test if a node exists in the partition graphs. 1265 1266 Args: 1267 node_name: (`str`) name of the node to be checked. 1268 device_name: optional device name. If None, will search for the node 1269 on all available devices. Otherwise, search for the node only on 1270 the given device. 1271 1272 Returns: 1273 A boolean indicating whether the node exists. 1274 1275 Raises: 1276 LookupError: If no partition graphs have been loaded yet. 1277 ValueError: If device_name is specified but cannot be found. 1278 """ 1279 if not self._debug_graphs: 1280 raise LookupError( 1281 "Nodes have not been loaded from partition graphs yet.") 1282 1283 if (device_name is not None) and device_name not in self._debug_graphs: 1284 raise ValueError( 1285 "The specified device_name '%s' cannot be found." 
          % device_name)

    for _, debug_graph in self._debug_graphs.items():
      if node_name in debug_graph.node_inputs:
        return True
    return False

  def node_device(self, node_name):
    """Get the names of the devices that have nodes of the specified name.

    Args:
      node_name: (`str`) name of the node.

    Returns:
      (`str` or `list` of `str`) name of the device(s) on which the node of the
        given name is found. Returns a `str` if there is only one such device,
        otherwise returns a `list` of `str`.

    Raises:
      LookupError: If node inputs and control inputs have not been loaded
        from partition graphs yet.
      ValueError: If the node does not exist in partition graphs.
    """
    if not self._debug_graphs:
      raise LookupError(
          "Node devices are not loaded from partition graphs yet.")

    if node_name not in self._node_devices:
      raise ValueError("Node '%s' does not exist in partition graphs." %
                       node_name)

    output = list(self._node_devices[node_name])
    return output[0] if len(output) == 1 else output

  def node_op_type(self, node_name, device_name=None):
    """Get the op type of given node.

    Args:
      node_name: (`str`) name of the node.
      device_name: (`str`) name of the device. If there is only one device or if
        node_name exists on only one device, this argument is optional.

    Returns:
      (`str`) op type of the node.

    Raises:
      LookupError: If node op types have not been loaded
        from partition graphs yet.
    """
    if not self._debug_graphs:
      raise LookupError(
          "Node op types are not loaded from partition graphs yet.")

    device_name = self._infer_device_name(device_name, node_name)
    return self._debug_graphs[device_name].node_op_types[node_name]

  def debug_watch_keys(self, node_name, device_name=None):
    """Get all tensor watch keys of given node according to partition graphs.

    Args:
      node_name: (`str`) name of the node.
      device_name: (`str`) name of the device. If there is only one device or if
        node_name exists on only one device, this argument is optional.

    Returns:
      (`list` of `str`) all debug tensor watch keys. Returns an empty list if
        the node name does not correspond to any debug watch keys.

    Raises:
      `LookupError`: If debug watch information has not been loaded from
        partition graphs yet.
    """

    try:
      device_name = self._infer_device_name(device_name, node_name)
    except ValueError:
      return []

    if node_name not in self._debug_watches[device_name]:
      return []

    watch_keys = []
    for watched_slot in self._debug_watches[device_name][node_name]:
      debug_ops = self._debug_watches[device_name][node_name][watched_slot]
      for debug_op in debug_ops:
        watch_keys.append(
            _get_tensor_watch_key(node_name, watched_slot, debug_op))

    return watch_keys

  def watch_key_to_data(self, debug_watch_key, device_name=None):
    """Get all `DebugTensorDatum` instances corresponding to a debug watch key.

    Args:
      debug_watch_key: (`str`) debug watch key.
      device_name: (`str`) name of the device. If there is only one device or if
        the specified debug_watch_key exists on only one device, this argument
        is optional.

    Returns:
      A list of `DebugTensorDatum` instances that correspond to the debug watch
      key.
      If the watch key does not exist, returns an empty list.

    Raises:
      ValueError: If there are multiple devices that have the debug_watch_key,
        but device_name is not specified.
    """
    if device_name is None:
      matching_device_names = [
          name for name in self._watch_key_to_datum
          if debug_watch_key in self._watch_key_to_datum[name]]
      if not matching_device_names:
        return []
      elif len(matching_device_names) == 1:
        device_name = matching_device_names[0]
      else:
        raise ValueError(
            "The debug watch key '%s' exists on multiple (%d) devices, but "
            "device name is not specified." %
            (debug_watch_key, len(matching_device_names)))
    elif device_name not in self._watch_key_to_datum:
      raise ValueError(
          "There is no device named '%s' consisting of debug watch keys." %
          device_name)

    return self._watch_key_to_datum[device_name].get(debug_watch_key, [])

  def find(self,
           predicate,
           first_n=0,
           device_name=None,
           exclude_node_names=None):
    """Find dumped tensor data by a certain predicate.

    Args:
      predicate: A callable that takes two input arguments:

        ```python
        def predicate(debug_tensor_datum, tensor):
          # returns a bool
        ```

        where `debug_tensor_datum` is an instance of `DebugTensorDatum`, which
        carries the metadata, such as the `Tensor`'s node name, output slot,
        timestamp, debug op name, etc.; and `tensor` is the dumped tensor value
        as a `numpy.ndarray`.
      first_n: (`int`) return only the first n `DebugTensorDatum` instances (in
        time order) for which the predicate returns True. To return all the
        `DebugTensorDatum` instances, let first_n be <= 0.
      device_name: optional device name.
      exclude_node_names: Optional regular expression to exclude nodes with
        names matching the regular expression.

    Returns:
      A list of all `DebugTensorDatum` objects in this `DebugDumpDir` object
      for which predicate returns True, sorted in ascending order of the
      timestamp.
    """
    if exclude_node_names:
      exclude_node_names = re.compile(exclude_node_names)

    matched_data = []
    for device in (self._dump_tensor_data if device_name is None
                   else (device_name,)):
      for datum in self._dump_tensor_data[device]:
        if exclude_node_names and exclude_node_names.match(datum.node_name):
          continue

        if predicate(datum, datum.get_tensor()):
          matched_data.append(datum)

          if first_n > 0 and len(matched_data) >= first_n:
            return matched_data

    return matched_data

  def get_tensor_file_paths(self,
                            node_name,
                            output_slot,
                            debug_op,
                            device_name=None):
    """Get the file paths from a debug-dumped tensor.

    Args:
      node_name: (`str`) name of the node that the tensor is produced by.
      output_slot: (`int`) output slot index of tensor.
      debug_op: (`str`) name of the debug op.
      device_name: (`str`) name of the device. If there is only one device or if
        the specified debug_watch_key exists on only one device, this argument
        is optional.

    Returns:
      List of file path(s) loaded. This is a list because each debugged tensor
      may be dumped multiple times.

    Raises:
      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor does not exist in
        the debug-dump data.
1483 """ 1484 1485 device_name = self._infer_device_name(device_name, node_name) 1486 watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op) 1487 if watch_key not in self._watch_key_to_datum[device_name]: 1488 raise WatchKeyDoesNotExistInDebugDumpDirError( 1489 "Watch key \"%s\" does not exist in the debug dump of device %s" % 1490 (watch_key, device_name)) 1491 1492 return [datum.file_path for datum in 1493 self._watch_key_to_datum[device_name][watch_key]] 1494 1495 def get_tensors(self, node_name, output_slot, debug_op, device_name=None): 1496 """Get the tensor value from for a debug-dumped tensor. 1497 1498 The tensor may be dumped multiple times in the dump root directory, so a 1499 list of tensors (`numpy.ndarray`) is returned. 1500 1501 Args: 1502 node_name: (`str`) name of the node that the tensor is produced by. 1503 output_slot: (`int`) output slot index of tensor. 1504 debug_op: (`str`) name of the debug op. 1505 device_name: (`str`) name of the device. If there is only one device or if 1506 the specified debug_watch_key exists on only one device, this argument 1507 is optional. 1508 1509 Returns: 1510 List of tensors (`numpy.ndarray`) loaded from the debug-dump file(s). 1511 1512 Raises: 1513 WatchKeyDoesNotExistInDebugDumpDirError: If the tensor does not exist in 1514 the debug-dump data. 1515 """ 1516 1517 watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op) 1518 try: 1519 device_name = self._infer_device_name(device_name, node_name) 1520 return [datum.get_tensor() for datum in 1521 self._watch_key_to_datum[device_name][watch_key]] 1522 except (ValueError, KeyError): 1523 raise WatchKeyDoesNotExistInDebugDumpDirError( 1524 "Watch key \"%s\" does not exist in the debug dump of device %s" % 1525 (watch_key, device_name)) 1526 1527 def get_rel_timestamps(self, 1528 node_name, 1529 output_slot, 1530 debug_op, 1531 device_name=None): 1532 """Get the relative timestamp from for a debug-dumped tensor. 1533 1534 Relative timestamp means (absolute timestamp - `t0`), where `t0` is the 1535 absolute timestamp of the first dumped tensor in the dump root. The tensor 1536 may be dumped multiple times in the dump root directory, so a list of 1537 relative timestamps (`numpy.ndarray`) is returned. 1538 1539 Args: 1540 node_name: (`str`) name of the node that the tensor is produced by. 1541 output_slot: (`int`) output slot index of tensor. 1542 debug_op: (`str`) name of the debug op. 1543 device_name: (`str`) name of the device. If there is only one device or if 1544 the specified debug_watch_key exists on only one device, this argument 1545 is optional. 1546 1547 Returns: 1548 (`list` of `int`) list of relative timestamps. 1549 1550 Raises: 1551 WatchKeyDoesNotExistInDebugDumpDirError: If the tensor watch key does not 1552 exist in the debug dump data. 1553 """ 1554 1555 device_name = self._infer_device_name(device_name, node_name) 1556 watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op) 1557 if watch_key not in self._watch_key_to_datum[device_name]: 1558 raise WatchKeyDoesNotExistInDebugDumpDirError( 1559 "Watch key \"%s\" does not exist in the debug dump" % watch_key) 1560 1561 # TODO(cais): Figure out whether this should be relative to the global t0. 1562 return self._watch_key_to_rel_time[device_name][watch_key] 1563 1564 def get_dump_sizes_bytes(self, 1565 node_name, 1566 output_slot, 1567 debug_op, 1568 device_name=None): 1569 """Get the sizes of the dump files for a debug-dumped tensor. 1570 1571 Unit of the file size: byte. 

    Args:
      node_name: (`str`) name of the node that the tensor is produced by.
      output_slot: (`int`) output slot index of tensor.
      debug_op: (`str`) name of the debug op.
      device_name: (`str`) name of the device. If there is only one device or if
        the specified debug_watch_key exists on only one device, this argument
        is optional.

    Returns:
      (`list` of `int`): list of dump file sizes in bytes.

    Raises:
      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor watch key does not
        exist in the debug dump data.
    """

    device_name = self._infer_device_name(device_name, node_name)
    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
    if watch_key not in self._watch_key_to_datum[device_name]:
      raise WatchKeyDoesNotExistInDebugDumpDirError(
          "Watch key \"%s\" does not exist in the debug dump of device %s" %
          (watch_key, device_name))

    return self._watch_key_to_dump_size_bytes[device_name][watch_key]

  def node_traceback(self, element_name):
    """Try to retrieve the Python traceback of a node's construction.

    Args:
      element_name: (`str`) Name of a graph element (node or tensor).

    Returns:
      (list) The traceback list object, in the same format as the return value
        of the `extract_stack` function of Python's `traceback` module.

    Raises:
      LookupError: If Python graph is not available for traceback lookup.
      KeyError: If the node cannot be found in the Python graph loaded.
    """

    if self._python_graph is None:
      raise LookupError("Python graph is not available for traceback lookup")

    node_name = debug_graphs.get_node_name(element_name)
    if node_name not in self._node_traceback:
      raise KeyError("Cannot find node \"%s\" in Python graph" % node_name)

    return self._node_traceback[node_name]
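
# Example usage: a minimal sketch, assuming a tfdbg dump directory already
# exists at the hypothetical path "/tmp/tfdbg_1" (e.g., produced by a
# tfdbg-wrapped `Session`); the path and variable names are illustrative only.
#
#   dump = DebugDumpDir("/tmp/tfdbg_1")
#   nan_or_inf = dump.find(has_inf_or_nan)  # `DebugTensorDatum` instances.
#   for datum in nan_or_inf:
#     print(datum.watch_key, datum.timestamp - dump.t0)
#     tensors = dump.get_tensors(datum.node_name, datum.output_slot,
#                                datum.debug_op)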