xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/lib/debug_data.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and functions to handle debug-dump data of TensorFlow Debugger."""
16
17import collections
18import glob
19import json
20import os
21import platform
22import re
23
24import numpy as np
25
26from tensorflow.core.framework import graph_pb2
27from tensorflow.core.framework import types_pb2
28from tensorflow.core.util import event_pb2
29from tensorflow.python.debug.lib import debug_graphs
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.platform import gfile
32from tensorflow.python.platform import tf_logging as logging
33from tensorflow.python.util import compat
34
35
36# TODO(cais): Tie these string constants in with C++?
37METADATA_FILE_PREFIX = "_tfdbg_"
38CORE_METADATA_TAG = "core_metadata_"
39GRAPH_FILE_TAG = "graph_"
40DEVICE_TAG = "device_"
41HASH_TAG = "hash"
42
43FETCHES_INFO_FILE_TAG = "fetches_info_"
44FEED_KEYS_INFO_FILE_TAG = "feed_keys_info_"
45
46
47def _glob(glob_pattern):
48  if platform.system() == "Windows":
49    return glob.glob(glob_pattern)
50  else:
51    return gfile.Glob(glob_pattern)
52
53
54class InconvertibleTensorProto:
55  """Represents a TensorProto that cannot be converted to np.ndarray."""
56
57  def __init__(self, tensor_proto, initialized=True):
58    """Constructor.
59
60    Args:
61      tensor_proto: the `TensorProto` object that cannot be represented as a
62        `np.ndarray` object.
63      initialized: (`bool`) whether the Tensor is initialized.
64    """
65    self._tensor_proto = tensor_proto
66    self._initialized = initialized
67
68  def __str__(self):
69    output = "" if self._initialized else "Uninitialized tensor:\n"
70    output += str(self._tensor_proto)
71    return output
72
73  @property
74  def initialized(self):
75    return self._initialized
76
77
78def load_tensor_from_event_file(event_file_path):
79  """Load a tensor from an event file.
80
81  Assumes that the event file contains a `Event` protobuf and the `Event`
82  protobuf contains a `Tensor` value.
83
84  Args:
85    event_file_path: (`str`) path to the event file.
86
87  Returns:
88    The tensor value loaded from the event file, as a `numpy.ndarray`. For
89    uninitialized Tensors, returns `None`. For Tensors of data types that
90    cannot be converted to `numpy.ndarray` (e.g., `tf.resource`), return
91    `None`.
92  """
93
94  event = event_pb2.Event()
95  with gfile.Open(event_file_path, "rb") as f:
96    event.ParseFromString(f.read())
97    return load_tensor_from_event(event)
98
99
100def load_tensor_from_event(event):
101  """Load a tensor from an Event proto.
102
103  Args:
104    event: The Event proto, assumed to hold a tensor value in its
105        summary.value[0] field.
106
107  Returns:
108    The tensor value loaded from the event file, as a `numpy.ndarray`, if
109    representation of the tensor value by a `numpy.ndarray` is possible.
110    For uninitialized Tensors, returns `None`. For Tensors of data types that
111    cannot be represented as `numpy.ndarray` (e.g., `tf.resource`), return
112    the `TensorProto` protobuf object without converting it to a
113    `numpy.ndarray`.
114  """
115
116  tensor_proto = event.summary.value[0].tensor
117  shape = tensor_util.TensorShapeProtoToList(tensor_proto.tensor_shape)
118  num_elements = 1
119  for shape_dim in shape:
120    num_elements *= shape_dim
121
122  if tensor_proto.tensor_content or tensor_proto.string_val or not num_elements:
123    # Initialized tensor or empty tensor.
124    if tensor_proto.dtype == types_pb2.DT_RESOURCE:
125      tensor_value = InconvertibleTensorProto(tensor_proto)
126    else:
127      try:
128        tensor_value = tensor_util.MakeNdarray(tensor_proto)
129      except KeyError:
130        tensor_value = InconvertibleTensorProto(tensor_proto)
131  else:
132    # Uninitialized tensor or tensor of unconvertible data type.
133    tensor_value = InconvertibleTensorProto(tensor_proto, False)
134
135  return tensor_value
136
137
138def _load_graph_def_from_event_file(event_file_path):
139  event = event_pb2.Event()
140  with gfile.Open(event_file_path, "rb") as f:
141    event.ParseFromString(f.read())
142
143  return graph_pb2.GraphDef.FromString(event.graph_def)
144
145
146def _load_log_message_from_event_file(event_file_path):
147  event = event_pb2.Event()
148  with gfile.Open(event_file_path, "rb") as f:
149    event.ParseFromString(f.read())
150
151  return event.log_message.message
152
153
154def _is_graph_file(file_name):
155  return file_name.startswith(METADATA_FILE_PREFIX + GRAPH_FILE_TAG)
156
157
158def _is_run_fetches_info_file(file_name):
159  return file_name == METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG
160
161
162def _is_run_feed_keys_info_file(file_name):
163  return file_name == METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG
164
165
166def _get_tensor_name(node_name, output_slot):
167  """Get tensor name given node name and output slot index.
168
169  Args:
170    node_name: Name of the node that outputs the tensor, as a string.
171    output_slot: Output slot index of the tensor, as an integer.
172
173  Returns:
174    Name of the tensor, as a string.
175  """
176
177  return "%s:%d" % (node_name, output_slot)
178
179
180def _get_tensor_watch_key(node_name, output_slot, debug_op):
181  """Get the string representation of a debug watch on a tensor.
182
183  Args:
184    node_name: Name of the node by which the watched tensor is produced, as a
185        string.
186    output_slot: Output slot index of the tensor, as an integer.
187    debug_op: Name of the debug op that is used to watch the tensor, as a
188        string.
189
190  Returns:
191    A string representing the debug watch on the tensor (i.e., the "watch
192        key").
193  """
194  return "%s:%s" % (_get_tensor_name(node_name, output_slot), debug_op)
195
196
197def has_inf_or_nan(datum, tensor):
198  """A predicate for whether a tensor consists of any bad numerical values.
199
200  This predicate is common enough to merit definition in this module.
201  Bad numerical values include `nan`s and `inf`s.
202  The signature of this function follows the requirement of the method
203  `DebugDumpDir.find()`.
204
205  Args:
206    datum: (`DebugTensorDatum`) Datum metadata.
207    tensor: (`numpy.ndarray` or None) Value of the tensor. None represents
208      an uninitialized tensor.
209
210  Returns:
211    (`bool`) True if and only if tensor consists of any nan or inf values.
212  """
213
214  _ = datum  # Datum metadata is unused in this predicate.
215
216  if isinstance(tensor, InconvertibleTensorProto):
217    # Uninitialized tensor doesn't have bad numerical values.
218    # Also return False for data types that cannot be represented as numpy
219    # arrays.
220    return False
221  elif (np.issubdtype(tensor.dtype, np.floating) or
222        np.issubdtype(tensor.dtype, np.complexfloating) or
223        np.issubdtype(tensor.dtype, np.integer)):
224    return np.any(np.isnan(tensor)) or np.any(np.isinf(tensor))
225  else:
226    return False
227
228
229_CoreMetadata = collections.namedtuple("CoreMetadata", [
230    "global_step", "session_run_index", "executor_step_index", "input_names",
231    "output_names", "target_nodes"
232])
233
234
235def extract_core_metadata_from_event_proto(event):
236  json_metadata = json.loads(event.log_message.message)
237  return _CoreMetadata(json_metadata["global_step"],
238                       json_metadata["session_run_index"],
239                       json_metadata["executor_step_index"],
240                       json_metadata["input_names"],
241                       json_metadata["output_names"],
242                       json_metadata["target_nodes"])
243
244
245def device_name_to_device_path(device_name):
246  """Convert device name to device path."""
247  device_name_items = compat.as_text(device_name).split("/")
248  device_name_items = [item.replace(":", "_") for item in device_name_items]
249  return METADATA_FILE_PREFIX + DEVICE_TAG + ",".join(device_name_items)
250
251
252def device_path_to_device_name(device_dir):
253  """Parse device name from device path.
254
255  Args:
256    device_dir: (str) a directory name for the device.
257
258  Returns:
259    (str) parsed device name.
260  """
261  path_items = os.path.basename(device_dir)[
262      len(METADATA_FILE_PREFIX) + len(DEVICE_TAG):].split(",")
263  return "/".join([
264      path_item.replace("device_", "device:").replace("_", ":", 1)
265      for path_item in path_items])
266
267
268class DebugTensorDatum:
269  """A single tensor dumped by TensorFlow Debugger (tfdbg).
270
271  Contains metadata about the dumped tensor, including `timestamp`,
272  `node_name`, `output_slot`, `debug_op`, and path to the dump file
273  (`file_path`).
274
275  This type does not hold the generally space-expensive tensor value (numpy
276  array). Instead, it points to the file from which the tensor value can be
277  loaded (with the `get_tensor` method) if needed.
278  """
279
280  def __init__(self, dump_root, debug_dump_rel_path):
281    """`DebugTensorDatum` constructor.
282
283    Args:
284      dump_root: (`str`) Debug dump root directory. This path should not include
285        the path component that represents the device name (see also below).
286      debug_dump_rel_path: (`str`) Path to a debug dump file, relative to the
287        `dump_root`. The first item of this relative path is assumed to be
288        a path representing the name of the device that the Tensor belongs to.
289        See `device_path_to_device_name` for more details on the device path.
290        For example, suppose the debug dump root
291        directory is `/tmp/tfdbg_1` and the dump file is at
292        `/tmp/tfdbg_1/<device_path>/>ns_1/node_a_0_DebugIdentity_123456789`,
293        then the value of the debug_dump_rel_path should be
294        `<device_path>/ns_1/node_a_0_DebugIdentity_1234456789`.
295
296    Raises:
297      ValueError: If the base file name of the dump file does not conform to
298        the dump file naming pattern:
299        `node_name`_`output_slot`_`debug_op`_`timestamp`
300    """
301
302    path_components = os.path.normpath(debug_dump_rel_path).split(os.sep)
303    self._device_name = device_path_to_device_name(path_components[0])
304    base = path_components[-1]
305    if base.count("_") < 3:
306      raise ValueError(
307          "Dump file path does not conform to the naming pattern: %s" % base)
308
309    self._extended_timestamp = base.split("_")[-1]
310    # It may include an index suffix at the end if file path collision happened
311    # due to identical timestamps.
312    if "-" in self._extended_timestamp:
313      self._timestamp = int(
314          self._extended_timestamp[:self._extended_timestamp.find("-")])
315    else:
316      self._timestamp = int(self._extended_timestamp)
317
318    self._debug_op = base.split("_")[-2]
319    self._output_slot = int(base.split("_")[-3])
320
321    node_base_name = "_".join(base.split("_")[:-3])
322    self._node_name = "/".join(path_components[1:-1] + [node_base_name])
323
324    self._file_path = os.path.join(dump_root, debug_dump_rel_path)
325    self._dump_size_bytes = (gfile.Stat(self._file_path).length if
326                             gfile.Exists(self._file_path) else None)
327
328  def __str__(self):
329    return "{DebugTensorDatum (%s) %s:%d @ %s @ %d}" % (self.device_name,
330                                                        self.node_name,
331                                                        self.output_slot,
332                                                        self.debug_op,
333                                                        self.timestamp)
334
335  def __repr__(self):
336    return self.__str__()
337
338  def get_tensor(self):
339    """Get tensor from the dump (`Event`) file.
340
341    Returns:
342      The tensor loaded from the dump (`Event`) file.
343    """
344
345    return load_tensor_from_event_file(self.file_path)
346
347  # TODO(cais): Add time unit suffix to timestamp and t0 (us).
348  @property
349  def timestamp(self):
350    """Timestamp of when this tensor value was dumped.
351
352    Returns:
353      (`int`) The timestamp in microseconds.
354    """
355
356    return self._timestamp
357
358  @property
359  def extended_timestamp(self):
360    """Extended timestamp, possibly with an index suffix.
361
362    The index suffix, e.g., "-1", is for disambiguating multiple dumps of the
363    same tensor with the same timestamp, which can occur if the dumping events
364    are spaced by shorter than the temporal resolution of the timestamps.
365
366    Returns:
367      (`str`) The extended timestamp.
368    """
369
370    return self._extended_timestamp
371
372  @property
373  def debug_op(self):
374    """Name of the debug op.
375
376    Returns:
377      (`str`) debug op name (e.g., `DebugIdentity`).
378    """
379
380    return self._debug_op
381
382  @property
383  def device_name(self):
384    """Name of the device that the tensor belongs to.
385
386    Returns:
387      (`str`) device name.
388    """
389
390    return self._device_name
391
392  @property
393  def node_name(self):
394    """Name of the node from which the tensor value was dumped.
395
396    Returns:
397      (`str`) name of the node watched by the debug op.
398    """
399
400    return self._node_name
401
402  @property
403  def output_slot(self):
404    """Output slot index from which the tensor value was dumped.
405
406    Returns:
407      (`int`) output slot index watched by the debug op.
408    """
409
410    return self._output_slot
411
412  @property
413  def tensor_name(self):
414    """Name of the tensor watched by the debug op.
415
416    Returns:
417      (`str`) `Tensor` name, in the form of `node_name`:`output_slot`
418    """
419
420    return _get_tensor_name(self.node_name, self.output_slot)
421
422  @property
423  def watch_key(self):
424    """Watch key identities a debug watch on a tensor.
425
426    Returns:
427      (`str`) A watch key, in the form of `tensor_name`:`debug_op`.
428    """
429
430    return _get_tensor_watch_key(self.node_name, self.output_slot,
431                                 self.debug_op)
432
433  @property
434  def file_path(self):
435    """Path to the file which stores the value of the dumped tensor."""
436
437    return self._file_path
438
439  @property
440  def dump_size_bytes(self):
441    """Size of the dump file.
442
443    Unit: byte.
444
445    Returns:
446      If the dump file exists, size of the dump file, in bytes.
447      If the dump file does not exist, None.
448    """
449
450    return self._dump_size_bytes
451
452
453class WatchKeyDoesNotExistInDebugDumpDirError(ValueError):
454  pass
455
456
457class DebugDumpDir:
458  """Data set from a debug-dump directory on filesystem.
459
460  An instance of `DebugDumpDir` contains all `DebugTensorDatum` instances
461  in a tfdbg dump root directory.
462  """
463
464  def __init__(self, dump_root, partition_graphs=None, validate=True):
465    """`DebugDumpDir` constructor.
466
467    Args:
468      dump_root: (`str`) path to the dump root directory.
469      partition_graphs: A repeated field of GraphDefs representing the
470          partition graphs executed by the TensorFlow runtime.
471      validate: (`bool`) whether the dump files are to be validated against the
472          partition graphs.
473
474    Raises:
475      IOError: If dump_root does not exist as a directory.
476      ValueError: If more than one core metadata file is found under the dump
477        root directory.
478    """
479
480    if not gfile.IsDirectory(dump_root):
481      raise IOError("Dump root directory %s does not exist" % dump_root)
482
483    self._core_metadata = []
484
485    # Find the list of devices.
486    self._dump_root = dump_root
487
488    self._load_core_metadata()
489    self._load_fetches_info()
490    self._load_feeds_info()
491    self._load_all_device_dumps(partition_graphs, validate)
492
493    self._python_graph = None
494
495  def _load_all_device_dumps(self, partition_graphs, validate):
496    """Load the dump data for all devices."""
497    device_dirs = _glob(os.path.join(
498        self._dump_root, METADATA_FILE_PREFIX + DEVICE_TAG + "*"))
499
500    self._device_names = []
501    self._t0s = {}
502    self._dump_tensor_data = {}
503    self._dump_graph_file_paths = {}
504    self._debug_watches = {}
505    self._watch_key_to_devices = {}
506    self._watch_key_to_datum = {}
507    self._watch_key_to_rel_time = {}
508    self._watch_key_to_dump_size_bytes = {}
509    for device_dir in device_dirs:
510      device_name = device_path_to_device_name(device_dir)
511      self._device_names.append(device_name)
512      self._load_device_dumps(device_name, device_dir)
513    self._load_partition_graphs(partition_graphs, validate)
514    self._calculate_t0()
515
516    for device_name in self._device_names:
517      self._create_tensor_watch_maps(device_name)
518
519  def _load_device_dumps(self, device_name, device_root):
520    """Load `DebugTensorDatum` instances from the dump root of a given device.
521
522    Populates a map {device_name: a list of `DebugTensorDatum`}, where the list
523    is sorted by ascending timestamp.
524
525    This sorting order reflects the order in which the TensorFlow executor
526    processed the nodes of the graph. It is (one of many possible) topological
527    sort of the nodes. This is useful for displaying tensors in the debugger
528    frontend as well as for the use case in which the user wants to find a
529    "culprit tensor", i.e., the first tensor in the graph that exhibits certain
530    problematic properties, i.e., all zero values, or bad numerical values such
531    as nan and inf.
532
533    In addition, creates a map from node name to debug watches. In this Map,
534    the key is the watched node name; the value is a dictionary.
535    Of this dictionary, the key is the watched_output_slot.
536
537    This method attempts to load the debug watches from the tensor dump files
538    first, before loading the full set of debug watches from the partition
539    graphs as done later. This is necessary because sometimes the partition
540    graphs may not be available, e.g., when the run errors out.
541
542    Args:
543      device_name: (`str`) name of the device.
544      device_root: (`str`) dump root directory of the given device.
545
546    Raises:
547      ValueError: If GraphDef for the device is not available.
548    """
549
550    self._dump_tensor_data[device_name] = []
551    self._debug_watches[device_name] = collections.defaultdict(
552        lambda: collections.defaultdict(set))
553
554    for root, _, files in gfile.Walk(device_root):
555      for f in files:
556        if _is_graph_file(f):
557          self._dump_graph_file_paths[device_name] = os.path.join(root, f)
558        else:
559          datum = self._dump_file_name_to_datum(root, f)
560          self._dump_tensor_data[device_name].append(datum)
561          self._debug_watches[device_name][datum.node_name][
562              datum.output_slot].add(datum.debug_op)
563
564    self._dump_tensor_data[device_name] = sorted(
565        self._dump_tensor_data[device_name],
566        key=lambda x: x.extended_timestamp)
567
568    if self._dump_tensor_data[device_name]:
569      self._t0s[device_name] = self._dump_tensor_data[device_name][0].timestamp
570    else:
571      self._t0s[device_name] = None
572
573  def _calculate_t0(self):
574    """Calculate the first timestamp across all devices."""
575    t0s = [t0 for t0 in self._t0s.values() if t0 is not None]
576    self._t0 = min(t0s) if t0s else None
577
578  def _load_core_metadata(self):
579    core_metadata_files = _glob(os.path.join(
580        self._dump_root, METADATA_FILE_PREFIX + CORE_METADATA_TAG + "*"))
581    for core_metadata_file in core_metadata_files:
582      with gfile.Open(core_metadata_file, "rb") as f:
583        event = event_pb2.Event()
584        event.ParseFromString(f.read())
585        self._core_metadata.append(
586            extract_core_metadata_from_event_proto(event))
587
588  def _load_fetches_info(self):
589    fetches_info_files = _glob(os.path.join(
590        self._dump_root, METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG + "*"))
591    self._run_fetches_info = []
592    for fetches_info_file in fetches_info_files:
593      self._run_fetches_info.append(
594          _load_log_message_from_event_file(fetches_info_file))
595
596  def _load_feeds_info(self):
597    feeds_info_files = _glob(os.path.join(
598        self._dump_root, METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG + "*"))
599    self._run_feed_keys_info = []
600    for feeds_info_file in feeds_info_files:
601      self._run_feed_keys_info.append(
602          _load_log_message_from_event_file(feeds_info_file))
603
604  def _dump_file_name_to_datum(self, dir_name, file_name):
605    """Obtain a DebugTensorDatum from the directory and file name.
606
607    Args:
608      dir_name: (`str`) Name of the directory in which the dump file resides.
609      file_name: (`str`) Base name of the dump file.
610
611    Returns:
612      (`DebugTensorDatum`) The `DebugTensorDatum` loaded from the dump file.
613    """
614
615    # Calculate the relative path of the dump file with respect to the root.
616    debug_dump_rel_path = os.path.join(
617        os.path.relpath(dir_name, self._dump_root), file_name)
618    return DebugTensorDatum(self._dump_root, debug_dump_rel_path)
619
620  def _create_tensor_watch_maps(self, device_name):
621    """Create maps from tensor watch keys to datum and to timestamps.
622
623    Create a map from watch key (tensor name + debug op) to `DebugTensorDatum`
624    item. Also make a map from watch key to relative timestamp.
625    "relative" means (absolute timestamp - t0).
626
627    Args:
628      device_name: (str) name of the device.
629    """
630
631    self._watch_key_to_datum[device_name] = {}
632    self._watch_key_to_rel_time[device_name] = {}
633    self._watch_key_to_dump_size_bytes[device_name] = {}
634    for datum in self._dump_tensor_data[device_name]:
635      if datum.watch_key not in self._watch_key_to_devices:
636        self._watch_key_to_devices[datum.watch_key] = {device_name}
637      else:
638        self._watch_key_to_devices[datum.watch_key].add(device_name)
639
640      if datum.watch_key not in self._watch_key_to_datum[device_name]:
641        self._watch_key_to_datum[device_name][datum.watch_key] = [datum]
642        self._watch_key_to_rel_time[device_name][datum.watch_key] = [
643            datum.timestamp - self._t0]
644        self._watch_key_to_dump_size_bytes[device_name][datum.watch_key] = [
645            datum.dump_size_bytes]
646      else:
647        self._watch_key_to_datum[device_name][datum.watch_key].append(datum)
648        self._watch_key_to_rel_time[device_name][datum.watch_key].append(
649            datum.timestamp - self._t0)
650        self._watch_key_to_dump_size_bytes[device_name][datum.watch_key].append(
651            datum.dump_size_bytes)
652
653  def set_python_graph(self, python_graph):
654    """Provide Python `Graph` object to the wrapper.
655
656    Unlike the partition graphs, which are protobuf `GraphDef` objects, `Graph`
657    is a Python object and carries additional information such as the traceback
658    of the construction of the nodes in the graph.
659
660    Args:
661      python_graph: (ops.Graph) The Python Graph object.
662    """
663
664    self._python_graph = python_graph
665    self._node_traceback = {}
666    if self._python_graph:
667      for op in self._python_graph.get_operations():
668        self._node_traceback[op.name] = tuple(map(tuple, op.traceback))
669
670  @property
671  def python_graph(self):
672    """Get the Python graph.
673
674    Returns:
675      If the Python graph has been set, returns a `tf.Graph` object. Otherwise,
676      returns None.
677    """
678
679    return self._python_graph
680
681  @property
682  def core_metadata(self):
683    """Metadata about the `Session.run()` call from the core runtime.
684
685    Of the three counters available in the return value, `global_step` is
686    supplied by the caller of the debugged `Session.run()`, while
687    `session_run_index` and `executor_step_index` are determined by the state
688    of the core runtime, automatically. For the same fetch list, feed keys and
689    debug tensor watch options, the same executor will be used and
690    `executor_step_index` should increase by one at a time. However, runs with
691    different fetch lists, feed keys and debug_tensor watch options that all
692    share the same `Session` object can lead to gaps in `session_run_index`.
693
694    Returns:
695      If core metadata are loaded, a `namedtuple` with the fields:
696        `global_step`: A global step count supplied by the caller of
697          `Session.run()`. It is optional to the caller. If the caller did not
698          supply this parameter, its value will be -1.
699        `session_run_index`: A sorted index for Run() calls to the underlying
700          TensorFlow `Session` object.
701        `executor_step_index`: A counter for invocations of a given runtime
702          executor. The same executor is re-used for the same fetched tensors,
703          target nodes, input feed keys and debug tensor watch options.
704        `input_names`: Names of the input (feed) Tensors.
705        `output_names`: Names of the output (fetched) Tensors.
706        `target_nodes`: Names of the target nodes.
707      If the core metadata have not been loaded, `None`.
708      If more than one core metadata files exist, return a list of the
709        `nametuple` described above.
710    """
711
712    output = self._core_metadata
713    return output[0] if len(output) == 1 else output
714
715  @property
716  def dumped_tensor_data(self):
717    """Retrieve dumped tensor data."""
718    if len(self.devices()) == 1:
719      return self._dump_tensor_data[self.devices()[0]]
720    else:
721      all_devices_data = self._dump_tensor_data.values()
722      data = []
723      for device_data in all_devices_data:
724        data.extend(device_data)
725      return sorted(data, key=lambda x: x.extended_timestamp)
726
727  @property
728  def t0(self):
729    """Absolute timestamp of the first dumped tensor across all devices.
730
731    Returns:
732      (`int`) absolute timestamp of the first dumped tensor, in microseconds.
733    """
734    return self._t0
735
736  @property
737  def size(self):
738    """Total number of dumped tensors in the dump root directory.
739
740    Returns:
741      (`int`) The total number of dumped tensors in the dump root directory.
742    """
743    return sum(len(self._dump_tensor_data[device_name])
744               for device_name in self._dump_tensor_data)
745
746  def _load_partition_graphs(self, client_partition_graphs, validate):
747    """Load and process partition graphs.
748
749    Load the graphs; parse the input and control input structure; obtain the
750    device and op type of each node; remove the Copy and debug ops inserted
751    by the debugger. The gathered information can be used to validate the
752    tensor dumps.
753
754    Args:
755      client_partition_graphs: A repeated field of GraphDefs representing the
756        partition graphs executed by the TensorFlow runtime, from the Python
757        client. These partition graphs are used only if partition graphs
758        cannot be loaded from the dump directory on the file system.
759      validate: (`bool`) Whether the dump files are to be validated against the
760        partition graphs.
761
762    Raises:
763      ValueError: If the partition GraphDef of one or more devices fail to be
764        loaded.
765    """
766    self._debug_graphs = {}
767    self._node_devices = {}
768
769    partition_graphs_and_device_names = []
770    for device_name in self._device_names:
771      partition_graph = None
772      if device_name in self._dump_graph_file_paths:
773        partition_graph = _load_graph_def_from_event_file(
774            self._dump_graph_file_paths[device_name])
775      else:
776        logging.warn(
777            "Failed to load partition graphs for device %s from disk. "
778            "As a fallback, the client graphs will be used. This "
779            "may cause mismatches in device names." % device_name)
780        partition_graph = self._find_partition_graph(client_partition_graphs,
781                                                     device_name)
782
783      if partition_graph:
784        partition_graphs_and_device_names.append((partition_graph,
785                                                  device_name))
786
787    for partition_graph, maybe_device_name in partition_graphs_and_device_names:
788      debug_graph = debug_graphs.DebugGraph(partition_graph,
789                                            device_name=maybe_device_name)
790      self._debug_graphs[debug_graph.device_name] = debug_graph
791      self._collect_node_devices(debug_graph)
792
793      if validate and debug_graph.device_name in self._dump_tensor_data:
794        self._validate_dump_with_graphs(debug_graph.device_name)
795
796  def _find_partition_graph(self, partition_graphs, device_name):
797    if partition_graphs is None:
798      return None
799    else:
800      for graph_def in partition_graphs:
801        for node_def in graph_def.node:
802          if node_def.device == device_name:
803            return graph_def
804      return None
805
806  def _collect_node_devices(self, debug_graph):
807    for node_name in debug_graph.node_devices:
808      if node_name in self._node_devices:
809        self._node_devices[node_name] = self._node_devices[node_name].union(
810            debug_graph.node_devices[node_name])
811      else:
812        self._node_devices[node_name] = debug_graph.node_devices[node_name]
813
814  def _validate_dump_with_graphs(self, device_name):
815    """Validate the dumped tensor data against the partition graphs.
816
817    Only the watched nodes are validated by this method, because tfdbg allows
818    clients to watch only a subset of the nodes.
819
820    Args:
821      device_name: (`str`) device name.
822
823    Raises:
824      LookupError: If the partition graphs have not been loaded yet.
825      ValueError: If dumps contain node names not found in partition graph.
826        Or if the temporal order of the dump's timestamps violate the
827        input relations on the partition graphs.
828    """
829    if not self._debug_graphs:
830      raise LookupError(
831          "No partition graphs loaded for device %s" % device_name)
832    debug_graph = self._debug_graphs[device_name]
833
834    # Verify that the node names in the dump data are all present in the
835    # partition graphs.
836    for datum in self._dump_tensor_data[device_name]:
837      if datum.node_name not in debug_graph.node_inputs:
838        raise ValueError("Node name '%s' is not found in partition graphs of "
839                         "device %s." % (datum.node_name, device_name))
840
841    pending_inputs = {}
842    for node in debug_graph.node_inputs:
843      pending_inputs[node] = []
844      inputs = debug_graph.node_inputs[node]
845      for inp in inputs:
846        inp_node = debug_graphs.get_node_name(inp)
847        inp_output_slot = debug_graphs.get_output_slot(inp)
848        # Inputs from Enter and NextIteration nodes are not validated because
849        # DebugNodeInserter::InsertNodes() in the debugger core skips creating
850        # control edges from debug ops watching these types of nodes.
851        if (inp_node in self._debug_watches[device_name] and
852            inp_output_slot in self._debug_watches[device_name][inp_node] and
853            debug_graph.node_op_types.get(inp) not in (
854                "Enter", "NextIteration") and
855            (inp_node, inp_output_slot) not in pending_inputs[node]):
856          pending_inputs[node].append((inp_node, inp_output_slot))
857
858    for i, datum in enumerate(self._dump_tensor_data[device_name]):
859      node = datum.node_name
860      slot = datum.output_slot
861      # In some cases (e.g., system clocks with insufficient precision),
862      # the upstream and downstream tensors may have identical timestamps, the
863      # following check examines this possibility and avoids raising an error if
864      # that is the case.
865      if not self._satisfied_at_timestamp(
866          device_name, pending_inputs[node], datum.timestamp, start_i=i + 1):
867        raise ValueError("Causality violated in timing relations of debug "
868                         "dumps: %s (%d): "
869                         "these input(s) are not satisfied: %s" %
870                         (node, datum.timestamp, repr(pending_inputs[node])))
871
872      recipients = debug_graph.node_recipients[node]
873      for recipient in recipients:
874        recipient_pending_inputs = pending_inputs[recipient]
875        if (node, slot) in recipient_pending_inputs:
876          if self.node_op_type(recipient) == "Merge":
877            # If this is a Merge op, we automatically clear the list because
878            # a Merge node only requires one of its two inputs.
879            del recipient_pending_inputs[:]
880          else:
881            del recipient_pending_inputs[
882                recipient_pending_inputs.index((node, slot))]
883
884  def _satisfied_at_timestamp(self, device_name, pending, timestamp, start_i=0):
885    """Determine whether pending inputs are satisfied at given timestamp.
886
887    Note: This method mutates the input argument "pending".
888
889    Args:
890      device_name: (str) device name.
891      pending: A list of 2-tuple (node_name, output_slot): the dependencies to
892        check.
893      timestamp: (int) the timestamp in question.
894      start_i: (int) the index in self._dump_tensor_data to start searching for
895        the timestamp.
896
897    Returns:
898      (bool) Whether all the dependencies in pending are satisfied at the
899        timestamp. If pending is empty to begin with, return True.
900    """
901    if not pending:
902      return True
903
904    for datum in self._dump_tensor_data[device_name][start_i:]:
905      if datum.timestamp > timestamp:
906        break
907      if (datum.timestamp == timestamp and
908          (datum.node_name, datum.output_slot) in pending):
909        pending.remove((datum.node_name, datum.output_slot))
910        if not pending:
911          return True
912
913    return not pending
914
915  def loaded_partition_graphs(self):
916    """Test whether partition graphs have been loaded."""
917    return bool(self._debug_graphs)
918
919  def partition_graphs(self):
920    """Get the partition graphs.
921
922    Returns:
923      Partition graphs as a list of GraphDef.
924
925    Raises:
926      LookupError: If no partition graphs have been loaded.
927    """
928    if not self._debug_graphs:
929      raise LookupError("No partition graphs have been loaded.")
930    return [self._debug_graphs[key].debug_graph_def
931            for key in self._debug_graphs]
932
933  def reconstructed_non_debug_partition_graphs(self):
934    """Reconstruct partition graphs with the debugger-inserted ops stripped.
935
936    The reconstructed partition graphs are identical to the original (i.e.,
937    non-debugger-decorated) partition graphs except in the following respects:
938      1) The exact names of the runtime-inserted internal nodes may differ.
939         These include _Send, _Recv, _HostSend, _HostRecv, _Retval ops.
940      2) As a consequence of 1, the nodes that receive input directly from such
941         send- and recv-type ops will have different input names.
942      3) The parallel_iteration attribute of while-loop Enter ops are set to 1.
943
944    Returns:
945      A dict mapping device names (`str`s) to reconstructed
946      `tf.compat.v1.GraphDef`s.
947    """
948    non_debug_graphs = {}
949    for key in self._debug_graphs:
950      non_debug_graphs[key] = self._debug_graphs[key].non_debug_graph_def
951    return non_debug_graphs
952
953  @property
954  def run_fetches_info(self):
955    """Get a str representation of the fetches used in the Session.run() call.
956
957    Returns:
958      If the information is available from one `Session.run` call, a `str`
959        obtained from `repr(fetches)`.
960      If the information is available from multiple `Session.run` calls, a
961        `list` of `str` from `repr(fetches)`.
962      If the information is not available, `None`.
963    """
964
965    output = self._run_fetches_info
966    return output[0] if len(output) == 1 else output
967
968  @property
969  def run_feed_keys_info(self):
970    """Get a str representation of the feed_dict used in the Session.run() call.
971
972    Returns:
973      If the information is available from one `Session.run` call, a `str`
974        obtained from `repr(feed_dict)`.
975      If the information is available from multiple `Session.run` calls, a
976        `list` of `str` obtained from `repr(feed_dict)`.
977      If the information is not available, `None`.
978    """
979
980    output = self._run_feed_keys_info
981    return output[0] if len(output) == 1 else output
982
983  def _infer_device_name(self, device_name, node_name):
984    """Infer the device name given node name.
985
986    If device_name is provided (i.e., not None), it'll be simply returned right
987    away.
988
989    Args:
990      device_name: (str or None) name of the device. If None, will try to infer
991        the device name by looking at the available nodes.
992      node_name: (str) name of the node.
993
994    Returns:
995      (str) Inferred name of the device, if available.
996
997    Raises:
998      ValueError: If the node name does not exist on any of the available
999        devices or if there are multiple devices that contain the node with
1000        the given name.
1001    """
1002    if device_name is None:
1003      if node_name in self._node_devices:
1004        if len(self._node_devices[node_name]) == 1:
1005          return list(self._node_devices[node_name])[0]
1006        else:
1007          raise ValueError(
1008              "There are multiple (%d) devices with nodes named '%s' but "
1009              "device_name is not specified." %
1010              (len(self._node_devices[node_name]), node_name))
1011      else:
1012        raise ValueError("None of the %d device(s) has a node named '%s'." %
1013                         (len(self._device_names), node_name))
1014    else:
1015      return device_name
1016
1017  def nodes(self, device_name=None):
1018    """Get a list of all nodes from the partition graphs.
1019
1020    Args:
1021      device_name: (`str`) name of device. If None, all nodes from all available
1022        devices will be included.
1023
1024    Returns:
1025      All nodes' names, as a list of str.
1026
1027    Raises:
1028      LookupError: If no partition graphs have been loaded.
1029      ValueError: If specified node name does not exist.
1030    """
1031    if not self._debug_graphs:
1032      raise LookupError("No partition graphs have been loaded.")
1033    if device_name is None:
1034      nodes = []
1035      for device_name in self._debug_graphs:
1036        nodes.extend(self._debug_graphs[device_name].node_inputs.keys())
1037      return nodes
1038    else:
1039      if device_name not in self._debug_graphs:
1040        raise ValueError("Invalid device name: %s" % device_name)
1041      return self._debug_graphs[device_name].node_inputs.keys()
1042
1043  def node_attributes(self, node_name, device_name=None):
1044    """Get the attributes of a node.
1045
1046    Args:
1047      node_name: Name of the node in question.
1048      device_name: (`str`) name of the device. If there is only one device or if
1049        node_name exists on only one device, this argument is optional.
1050
1051    Returns:
1052      Attributes of the node.
1053
1054    Raises:
1055      LookupError: If no partition graphs have been loaded.
1056    """
1057    if not self._debug_graphs:
1058      raise LookupError("No partition graphs have been loaded.")
1059
1060    device_name = self._infer_device_name(device_name, node_name)
1061    return self._debug_graphs[device_name].node_attributes[node_name]
1062
1063  def node_inputs(self, node_name, is_control=False, device_name=None):
1064    """Get the inputs of given node according to partition graphs.
1065
1066    Args:
1067      node_name: Name of the node.
1068      is_control: (`bool`) Whether control inputs, rather than non-control
1069        inputs, are to be returned.
1070      device_name: (`str`) name of the device. If there is only one device or if
1071        node_name exists on only one device, this argument is optional.
1072
1073    Returns:
1074      (`list` of `str`) inputs to the node, as a list of node names.
1075
1076    Raises:
1077      LookupError: If node inputs and control inputs have not been loaded
1078         from partition graphs yet.
1079    """
1080    if not self._debug_graphs:
1081      raise LookupError(
1082          "Node inputs are not loaded from partition graphs yet.")
1083
1084    device_name = self._infer_device_name(device_name, node_name)
1085    if is_control:
1086      return self._debug_graphs[device_name].node_ctrl_inputs[node_name]
1087    else:
1088      return self._debug_graphs[device_name].node_inputs[node_name]
1089
1090  def transitive_inputs(self,
1091                        node_name,
1092                        include_control=True,
1093                        include_reversed_ref=False,
1094                        device_name=None,):
1095    """Get the transitive inputs of given node according to partition graphs.
1096
1097    Args:
1098      node_name: Name of the node.
1099      include_control: Include control inputs (True by default).
1100      include_reversed_ref: Whether a ref input, say from A to B, is to be also
1101        considered as an input from B to A. The rationale is that ref inputs
1102        generally let the recipient (e.g., B in this case) mutate the value of
1103        the source (e.g., A in this case). So the reverse direction of the ref
1104        edge reflects the direction of information flow.
1105      device_name: (`str`) name of the device. If there is only one device or if
1106        node_name exists on only one device, this argument is optional.
1107
1108    Returns:
1109      (`list` of `str`) all transitive inputs to the node, as a list of node
1110        names.
1111
1112    Raises:
1113      LookupError: If node inputs and control inputs have not been loaded
1114         from partition graphs yet.
1115    """
1116    if not self._debug_graphs:
1117      raise LookupError(
1118          "Node inputs are not loaded from partition graphs yet.")
1119
1120    device_name = self._infer_device_name(device_name, node_name)
1121
1122    input_lists = [self._debug_graphs[device_name].node_inputs]
1123    if include_control:
1124      input_lists.append(self._debug_graphs[device_name].node_ctrl_inputs)
1125    if include_reversed_ref:
1126      input_lists.append(
1127          self._debug_graphs[device_name].node_reversed_ref_inputs)
1128    tracer = debug_graphs.DFSGraphTracer(
1129        input_lists,
1130        skip_node_names=self._get_merge_node_names(device_name))
1131    tracer.trace(node_name)
1132    return tracer.inputs()
1133
1134  def _get_merge_node_names(self, device_name):
1135    """Lazily get a list of Merge nodes on a given device."""
1136    if device_name not in self._device_names:
1137      raise ValueError("Invalid device name: %s" % device_name)
1138
1139    if not hasattr(self, "_merge_node_names"):
1140      self._merge_node_names = {}
1141    if device_name not in self._merge_node_names:
1142      debug_graph = self._debug_graphs[device_name]
1143      self._merge_node_names[device_name] = [
1144          node for node in debug_graph.node_op_types
1145          if debug_graph.node_op_types[node] == "Merge"]
1146    return self._merge_node_names[device_name]
1147
1148  def find_some_path(self,
1149                     src_node_name,
1150                     dst_node_name,
1151                     include_control=True,
1152                     include_reversed_ref=False,
1153                     device_name=None):
1154    """Find a path between a source node and a destination node.
1155
1156    Limitation: the source and destination are required to be on the same
1157    device, i.e., this method does not yet take into account Send/Recv nodes
1158    across devices.
1159
1160    TODO(cais): Make this method work across device edges by tracing Send/Recv
1161      nodes.
1162
1163    Args:
1164      src_node_name: (`str`) name of the source node or name of an output tensor
1165        of the node.
1166      dst_node_name: (`str`) name of the destination node or name of an output
1167        tensor of the node.
1168      include_control: (`bool`) whrther control edges are considered in the
1169        graph tracing.
1170      include_reversed_ref: Whether a ref input, say from A to B, is to be also
1171        considered as an input from B to A. The rationale is that ref inputs
1172        generally let the recipient (e.g., B in this case) mutate the value of
1173        the source (e.g., A in this case). So the reverse direction of the ref
1174        edge reflects the direction of information flow.
1175      device_name: (`str`) name of the device. If there is only one device or if
1176        node_name exists on only one device, this argument is optional.
1177
1178    Returns:
1179      A path from the src_node_name to dst_node_name, as a `list` of `str`, if
1180      it exists. The list includes src_node_name as the first item and
1181      dst_node_name as the last.
1182      If such a path does not exist, `None`.
1183
1184    Raises:
1185      ValueError: If the source and destination nodes are not on the same
1186        device.
1187    """
1188    src_device_name = self._infer_device_name(device_name, src_node_name)
1189    dst_device_name = self._infer_device_name(device_name, dst_node_name)
1190
1191    if src_device_name != dst_device_name:
1192      raise ValueError(
1193          "Source (%s) and destination (%s) are not on the same device: "
1194          "%s vs. %s" % (src_node_name, dst_node_name, src_device_name,
1195                         dst_device_name))
1196
1197    input_lists = [self._debug_graphs[dst_device_name].node_inputs]
1198    debug_graph = self._debug_graphs[dst_device_name]
1199    if include_control:
1200      input_lists.append(debug_graph.node_ctrl_inputs)
1201    if include_reversed_ref:
1202      input_lists.append(debug_graph.node_reversed_ref_inputs)
1203    tracer = debug_graphs.DFSGraphTracer(
1204        input_lists,
1205        skip_node_names=self._get_merge_node_names(dst_device_name),
1206        destination_node_name=src_node_name)
1207    # Here the value of destination_node_name is src_node_name, because we
1208    # are tracing the graph from output to its inputs (i.e., going backwards
1209    # on the graph).
1210
1211    try:
1212      tracer.trace(dst_node_name)
1213    except debug_graphs.GraphTracingReachedDestination:
1214      # Prune nodes not on the path.
1215      inputs = [dst_node_name] + tracer.inputs()
1216      depth_list = [0] + tracer.depth_list()
1217
1218      path = []
1219      curr_depth = depth_list[-1]
1220      for inp, depth in zip(reversed(inputs), reversed(depth_list)):
1221        if depth == curr_depth:
1222          path.append(inp)
1223          curr_depth -= 1
1224      return path
1225
1226  def node_recipients(self, node_name, is_control=False, device_name=None):
1227    """Get recipient of the given node's output according to partition graphs.
1228
1229    Args:
1230      node_name: (`str`) name of the node.
1231      is_control: (`bool`) whether control outputs, rather than non-control
1232        outputs, are to be returned.
1233      device_name: (`str`) name of the device. If there is only one device or if
1234        node_name exists on only one device, this argument is optional.
1235
1236    Returns:
1237      (`list` of `str`) all inputs to the node, as a list of node names.
1238
1239    Raises:
1240      LookupError: If node inputs and control inputs have not been loaded
1241         from partition graphs yet.
1242    """
1243
1244    if not self._debug_graphs:
1245      raise LookupError(
1246          "Node recipients are not loaded from partition graphs yet.")
1247
1248    device_name = self._infer_device_name(device_name, node_name)
1249    debug_graph = self._debug_graphs[device_name]
1250    if is_control:
1251      return debug_graph.node_ctrl_recipients[node_name]
1252    else:
1253      return debug_graph.node_recipients[node_name]
1254
1255  def devices(self):
1256    """Get the list of device names.
1257
1258    Returns:
1259      (`list` of `str`) names of the devices.
1260    """
1261    return self._device_names
1262
1263  def node_exists(self, node_name, device_name=None):
1264    """Test if a node exists in the partition graphs.
1265
1266    Args:
1267      node_name: (`str`) name of the node to be checked.
1268      device_name: optional device name. If None, will search for the node
1269        on all available devices. Otherwise, search for the node only on
1270        the given device.
1271
1272    Returns:
1273      A boolean indicating whether the node exists.
1274
1275    Raises:
1276      LookupError: If no partition graphs have been loaded yet.
1277      ValueError: If device_name is specified but cannot be found.
1278    """
1279    if not self._debug_graphs:
1280      raise LookupError(
1281          "Nodes have not been loaded from partition graphs yet.")
1282
1283    if (device_name is not None) and device_name not in self._debug_graphs:
1284      raise ValueError(
1285          "The specified device_name '%s' cannot be found." % device_name)
1286
1287    for _, debug_graph in self._debug_graphs.items():
1288      if node_name in debug_graph.node_inputs:
1289        return True
1290    return False
1291
1292  def node_device(self, node_name):
1293    """Get the names of the devices that has nodes of the specified name.
1294
1295    Args:
1296      node_name: (`str`) name of the node.
1297
1298    Returns:
1299      (`str` or `list` of `str`) name of the device(s) on which the node of the
1300        given name is found. Returns a `str` if there is only one such device,
1301        otherwise return a `list` of `str`.
1302
1303    Raises:
1304      LookupError: If node inputs and control inputs have not been loaded
1305         from partition graphs yet.
1306      ValueError: If the node does not exist in partition graphs.
1307    """
1308    if not self._debug_graphs:
1309      raise LookupError(
1310          "Node devices are not loaded from partition graphs yet.")
1311
1312    if node_name not in self._node_devices:
1313      raise ValueError("Node '%s' does not exist in partition graphs." %
1314                       node_name)
1315
1316    output = list(self._node_devices[node_name])
1317    return output[0] if len(output) == 1 else output
1318
1319  def node_op_type(self, node_name, device_name=None):
1320    """Get the op type of given node.
1321
1322    Args:
1323      node_name: (`str`) name of the node.
1324      device_name: (`str`) name of the device. If there is only one device or if
1325        node_name exists on only one device, this argument is optional.
1326
1327    Returns:
1328      (`str`) op type of the node.
1329
1330    Raises:
1331      LookupError: If node op types have not been loaded
1332         from partition graphs yet.
1333    """
1334    if not self._debug_graphs:
1335      raise LookupError(
1336          "Node op types are not loaded from partition graphs yet.")
1337
1338    device_name = self._infer_device_name(device_name, node_name)
1339    return self._debug_graphs[device_name].node_op_types[node_name]
1340
1341  def debug_watch_keys(self, node_name, device_name=None):
1342    """Get all tensor watch keys of given node according to partition graphs.
1343
1344    Args:
1345      node_name: (`str`) name of the node.
1346      device_name: (`str`) name of the device. If there is only one device or if
1347        node_name exists on only one device, this argument is optional.
1348
1349    Returns:
1350      (`list` of `str`) all debug tensor watch keys. Returns an empty list if
1351        the node name does not correspond to any debug watch keys.
1352
1353    Raises:
1354      `LookupError`: If debug watch information has not been loaded from
1355        partition graphs yet.
1356    """
1357
1358    try:
1359      device_name = self._infer_device_name(device_name, node_name)
1360    except ValueError:
1361      return []
1362
1363    if node_name not in self._debug_watches[device_name]:
1364      return []
1365
1366    watch_keys = []
1367    for watched_slot in self._debug_watches[device_name][node_name]:
1368      debug_ops = self._debug_watches[device_name][node_name][watched_slot]
1369      for debug_op in debug_ops:
1370        watch_keys.append(
1371            _get_tensor_watch_key(node_name, watched_slot, debug_op))
1372
1373    return watch_keys
1374
1375  def watch_key_to_data(self, debug_watch_key, device_name=None):
1376    """Get all `DebugTensorDatum` instances corresponding to a debug watch key.
1377
1378    Args:
1379      debug_watch_key: (`str`) debug watch key.
1380      device_name: (`str`) name of the device. If there is only one device or if
1381        the specified debug_watch_key exists on only one device, this argument
1382        is optional.
1383
1384    Returns:
1385      A list of `DebugTensorDatum` instances that correspond to the debug watch
1386      key. If the watch key does not exist, returns an empty list.
1387
1388    Raises:
1389      ValueError: If there are multiple devices that have the debug_watch_key,
1390        but device_name is not specified.
1391    """
1392    if device_name is None:
1393      matching_device_names = [
1394          name for name in self._watch_key_to_datum
1395          if debug_watch_key in self._watch_key_to_datum[name]]
1396      if not matching_device_names:
1397        return []
1398      elif len(matching_device_names) == 1:
1399        device_name = matching_device_names[0]
1400      else:
1401        raise ValueError(
1402            "The debug watch key '%s' exists on multiple (%d) devices, but "
1403            "device name is not specified." %
1404            (debug_watch_key, len(matching_device_names)))
1405    elif device_name not in self._debug_key_to_datum:
1406      raise ValueError(
1407          "There is no device named '%s' consisting of debug watch keys." %
1408          device_name)
1409
1410    return self._watch_key_to_datum[device_name].get(debug_watch_key, [])
1411
1412  def find(self,
1413           predicate,
1414           first_n=0,
1415           device_name=None,
1416           exclude_node_names=None):
1417    """Find dumped tensor data by a certain predicate.
1418
1419    Args:
1420      predicate: A callable that takes two input arguments:
1421
1422        ```python
1423        def predicate(debug_tensor_datum, tensor):
1424          # returns a bool
1425        ```
1426
1427        where `debug_tensor_datum` is an instance of `DebugTensorDatum`, which
1428        carries the metadata, such as the `Tensor`'s node name, output slot
1429        timestamp, debug op name, etc.; and `tensor` is the dumped tensor value
1430        as a `numpy.ndarray`.
1431      first_n: (`int`) return only the first n `DebugTensotDatum` instances (in
1432        time order) for which the predicate returns True. To return all the
1433        `DebugTensotDatum` instances, let first_n be <= 0.
1434      device_name: optional device name.
1435      exclude_node_names: Optional regular expression to exclude nodes with
1436        names matching the regular expression.
1437
1438    Returns:
1439      A list of all `DebugTensorDatum` objects in this `DebugDumpDir` object
1440       for which predicate returns True, sorted in ascending order of the
1441       timestamp.
1442    """
1443    if exclude_node_names:
1444      exclude_node_names = re.compile(exclude_node_names)
1445
1446    matched_data = []
1447    for device in (self._dump_tensor_data if device_name is None
1448                   else (self._dump_tensor_data[device_name],)):
1449      for datum in self._dump_tensor_data[device]:
1450        if exclude_node_names and exclude_node_names.match(datum.node_name):
1451          continue
1452
1453        if predicate(datum, datum.get_tensor()):
1454          matched_data.append(datum)
1455
1456          if first_n > 0 and len(matched_data) >= first_n:
1457            return matched_data
1458
1459    return matched_data
1460
1461  def get_tensor_file_paths(self,
1462                            node_name,
1463                            output_slot,
1464                            debug_op,
1465                            device_name=None):
1466    """Get the file paths from a debug-dumped tensor.
1467
1468    Args:
1469      node_name: (`str`) name of the node that the tensor is produced by.
1470      output_slot: (`int`) output slot index of tensor.
1471      debug_op: (`str`) name of the debug op.
1472      device_name: (`str`) name of the device. If there is only one device or if
1473        the specified debug_watch_key exists on only one device, this argument
1474        is optional.
1475
1476    Returns:
1477      List of file path(s) loaded. This is a list because each debugged tensor
1478        may be dumped multiple times.
1479
1480    Raises:
1481      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor does not exist in
1482        the debug-dump data.
1483    """
1484
1485    device_name = self._infer_device_name(device_name, node_name)
1486    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1487    if watch_key not in self._watch_key_to_datum[device_name]:
1488      raise WatchKeyDoesNotExistInDebugDumpDirError(
1489          "Watch key \"%s\" does not exist in the debug dump of device %s" %
1490          (watch_key, device_name))
1491
1492    return [datum.file_path for datum in
1493            self._watch_key_to_datum[device_name][watch_key]]
1494
1495  def get_tensors(self, node_name, output_slot, debug_op, device_name=None):
1496    """Get the tensor value from for a debug-dumped tensor.
1497
1498    The tensor may be dumped multiple times in the dump root directory, so a
1499    list of tensors (`numpy.ndarray`) is returned.
1500
1501    Args:
1502      node_name: (`str`) name of the node that the tensor is produced by.
1503      output_slot: (`int`) output slot index of tensor.
1504      debug_op: (`str`) name of the debug op.
1505      device_name: (`str`) name of the device. If there is only one device or if
1506        the specified debug_watch_key exists on only one device, this argument
1507        is optional.
1508
1509    Returns:
1510      List of tensors (`numpy.ndarray`) loaded from the debug-dump file(s).
1511
1512    Raises:
1513      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor does not exist in
1514        the debug-dump data.
1515    """
1516
1517    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1518    try:
1519      device_name = self._infer_device_name(device_name, node_name)
1520      return [datum.get_tensor() for datum in
1521              self._watch_key_to_datum[device_name][watch_key]]
1522    except (ValueError, KeyError):
1523      raise WatchKeyDoesNotExistInDebugDumpDirError(
1524          "Watch key \"%s\" does not exist in the debug dump of device %s" %
1525          (watch_key, device_name))
1526
1527  def get_rel_timestamps(self,
1528                         node_name,
1529                         output_slot,
1530                         debug_op,
1531                         device_name=None):
1532    """Get the relative timestamp from for a debug-dumped tensor.
1533
1534    Relative timestamp means (absolute timestamp - `t0`), where `t0` is the
1535    absolute timestamp of the first dumped tensor in the dump root. The tensor
1536    may be dumped multiple times in the dump root directory, so a list of
1537    relative timestamps (`numpy.ndarray`) is returned.
1538
1539    Args:
1540      node_name: (`str`) name of the node that the tensor is produced by.
1541      output_slot: (`int`) output slot index of tensor.
1542      debug_op: (`str`) name of the debug op.
1543      device_name: (`str`) name of the device. If there is only one device or if
1544        the specified debug_watch_key exists on only one device, this argument
1545        is optional.
1546
1547    Returns:
1548      (`list` of `int`) list of relative timestamps.
1549
1550    Raises:
1551      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor watch key does not
1552        exist in the debug dump data.
1553    """
1554
1555    device_name = self._infer_device_name(device_name, node_name)
1556    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1557    if watch_key not in self._watch_key_to_datum[device_name]:
1558      raise WatchKeyDoesNotExistInDebugDumpDirError(
1559          "Watch key \"%s\" does not exist in the debug dump" % watch_key)
1560
1561    # TODO(cais): Figure out whether this should be relative to the global t0.
1562    return self._watch_key_to_rel_time[device_name][watch_key]
1563
1564  def get_dump_sizes_bytes(self,
1565                           node_name,
1566                           output_slot,
1567                           debug_op,
1568                           device_name=None):
1569    """Get the sizes of the dump files for a debug-dumped tensor.
1570
1571    Unit of the file size: byte.
1572
1573    Args:
1574      node_name: (`str`) name of the node that the tensor is produced by.
1575      output_slot: (`int`) output slot index of tensor.
1576      debug_op: (`str`) name of the debug op.
1577      device_name: (`str`) name of the device. If there is only one device or if
1578        the specified debug_watch_key exists on only one device, this argument
1579        is optional.
1580
1581    Returns:
1582      (`list` of `int`): list of dump file sizes in bytes.
1583
1584    Raises:
1585      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor watch key does not
1586        exist in the debug dump data.
1587    """
1588
1589    device_name = self._infer_device_name(device_name, node_name)
1590    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1591    if watch_key not in self._watch_key_to_datum[device_name]:
1592      raise WatchKeyDoesNotExistInDebugDumpDirError(
1593          "Watch key \"%s\" does not exist in the debug dump of device %s" %
1594          (watch_key, device_name))
1595
1596    return self._watch_key_to_dump_size_bytes[device_name][watch_key]
1597
1598  def node_traceback(self, element_name):
1599    """Try to retrieve the Python traceback of node's construction.
1600
1601    Args:
1602      element_name: (`str`) Name of a graph element (node or tensor).
1603
1604    Returns:
1605      (list) The traceback list object as returned by the `extract_trace`
1606        method of Python's traceback module.
1607
1608    Raises:
1609      LookupError: If Python graph is not available for traceback lookup.
1610      KeyError: If the node cannot be found in the Python graph loaded.
1611    """
1612
1613    if self._python_graph is None:
1614      raise LookupError("Python graph is not available for traceback lookup")
1615
1616    node_name = debug_graphs.get_node_name(element_name)
1617    if node_name not in self._node_traceback:
1618      raise KeyError("Cannot find node \"%s\" in Python graph" % node_name)
1619
1620    return self._node_traceback[node_name]
1621