# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for tfdbg v2 dumping callback."""

import collections
import os
import shutil
import socket
import tempfile
import threading

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import debug_event_pb2
from tensorflow.python.debug.lib import debug_events_reader
from tensorflow.python.debug.lib import dumping_callback
from tensorflow.python.debug.lib import dumping_callback_test_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


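# The host name and this file's absolute path are recorded in the debug-event
# dumps; the tests below use them to verify the dumped source files and stack
# traces.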
_host_name = socket.gethostname()
_current_file_full_path = os.path.abspath(__file__)


class DumpingCallbackTest(
    dumping_callback_test_lib.DumpingCallbackTestBase, parameterized.TestCase):

  def setUp(self):
    super(DumpingCallbackTest, self).setUp()
    self.dump_root = tempfile.mkdtemp()

  def tearDown(self):
    if os.path.isdir(self.dump_root):
      shutil.rmtree(self.dump_root, ignore_errors=True)
    dumping_callback.disable_dump_debug_info()
    super(DumpingCallbackTest, self).tearDown()

  def _verifyStackFrames(self, stack_frames):
    """Verify the correctness of the stack frames.

    Currently, it simply asserts that the current file is found in the stack
    frames.
    TODO(cais): Perhaps implement a stricter check later.

    Args:
      stack_frames: The stack frames to verify.
    """
    self.assertTrue(
        any(frame[0] == _current_file_full_path for frame in stack_frames))

  def _expectedDefaultDeviceName(self):
    gpu_name = test_util.gpu_device_name()
    if gpu_name:
      return "/job:localhost/replica:0/task:0" + gpu_name
    else:
      return "/job:localhost/replica:0/task:0/device:CPU:0"

  def testInvalidTensorDebugModeCausesError(self):
    with self.assertRaisesRegex(
        ValueError, r"Invalid value in tensor_debug_mode \(\'NONSENSICAL\'\).*"
        r"Valid options.*NO_TENSOR.*"):
      dumping_callback.enable_dump_debug_info(
          self.dump_root, tensor_debug_mode="NONSENSICAL")

  @parameterized.named_parameters(
      ("NoTensor", "NO_TENSOR"),
      ("CurtHealth", "CURT_HEALTH"),
      ("ConciseHealth", "CONCISE_HEALTH"),
      ("Shape", "SHAPE"),
94      ("FulHealth", "FULL_HEALTH"),
95      ("FullTensor", "FULL_TENSOR"),
96  )
97  def testEnableDumpDebugInfoLogsTensorDebugModeAsStringName(self,
98                                                             tensor_debug_mode):
99    log_messages = []
100    def fake_logging_info(*args):
101      log_messages.append(args)
102    with test.mock.patch.object(
103        tf_logging, "info", side_effect=fake_logging_info):
104      dumping_callback.enable_dump_debug_info(
105          self.dump_root, tensor_debug_mode=tensor_debug_mode)
106      self.assertLen(log_messages, 1)
107      self.assertIn(self.dump_root, log_messages[0])
108      self.assertIn(tensor_debug_mode, log_messages[0])
109
110  def testDisablingTracingCallbackWithoutEnablingFirstIsTolerated(self):
111    dumping_callback.disable_dump_debug_info()
112
113  @parameterized.named_parameters(
114      ("NoTensor", "NO_TENSOR"),
115      ("CurtHealth", "CURT_HEALTH"),
116      ("ConciseHealth", "CONCISE_HEALTH"),
117      ("Shape", "SHAPE"),
118      ("FullHealth", "FULL_HEALTH"),
119      ("FullTensor", "FULL_TENSOR"),
120  )
121  def testPureEagerOpExecution(self, tensor_debug_mode):
122    """Test dumping data from eager op execution: float32."""
123
124    x = constant_op.constant(10.0)
125    zero = constant_op.constant(0.0)
126    one = constant_op.constant(1.0)
127    two = constant_op.constant(2.0)
128    three = constant_op.constant(3.0)
129    writer = dumping_callback.enable_dump_debug_info(
130        self.dump_root, tensor_debug_mode=tensor_debug_mode)
131    # Use Collatz conjecture as a test case.
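    # Starting from 10, the Collatz iterates are 10 -> 5 -> 16 -> 8 -> 4 -> 2
    # -> 1, which fixes the exact sequence of executed ops asserted below.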
    while x > one:
      if math_ops.equal(x % two, zero):
        x = x / two
      else:
        x = x * three + one

    writer.FlushNonExecutionFiles()
    self._readAndCheckMetadataFile()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      # Before FlushExecutionFiles() is called, the .execution file should be
      # empty.
      self.assertFalse(reader.executions())

      # After the flushing, the .execution file should hold the appropriate
      # contents.
      writer.FlushExecutionFiles()
      reader.update()
      executions = reader.executions()
      prev_wall_time = 1
      executed_op_types = []
      tensor_values = collections.defaultdict(list)
      for execution in executions:
        self.assertGreaterEqual(execution.wall_time, prev_wall_time)
        prev_wall_time = execution.wall_time
        executed_op_types.append(execution.op_type)
        # Check the device name.
        if execution.op_type in ("AddV2", "Mul", "RealDiv"):
          self.assertLen(execution.output_tensor_device_ids, 1)
          self.assertEqual(
              reader.device_name_by_id(execution.output_tensor_device_ids[0]),
              self._expectedDefaultDeviceName(),
              "Unexpected device name from eager op %s" % execution.op_type)

        # No graph IDs should have been logged for eager op executions.
        self.assertFalse(execution.graph_id)
        self.assertTrue(execution.input_tensor_ids)
        self.assertTrue(execution.output_tensor_ids)
        self.assertEqual(
            debug_event_pb2.TensorDebugMode.keys()[execution.tensor_debug_mode],
            tensor_debug_mode)
        if tensor_debug_mode == "NO_TENSOR":
          # Due to the NO_TENSOR tensor debug mode, debug_tensor_values ought
          # to be empty.
          self.assertFalse(execution.debug_tensor_values)
        elif tensor_debug_mode == "CURT_HEALTH":
          self.assertLen(execution.debug_tensor_values, 1)
          if execution.op_type in ("AddV2", "Mul", "RealDiv"):
            # 1st element: -1 is the unset tensor_id for eager op execution.
            # 2nd element: 0 means there is no inf or nan.
            self.assertAllClose(execution.debug_tensor_values, [[-1.0, 0.0]])
        elif tensor_debug_mode == "CONCISE_HEALTH":
          if execution.op_type in ("AddV2", "Mul", "RealDiv"):
            # 1st element: -1 is the unset tensor_id for eager op execution.
            # 2nd element: element count (1, since each tensor is a scalar).
            # Remaining elements: the counts of -inf, inf and nan (all zero).
            self.assertAllClose(
                execution.debug_tensor_values, [[-1, 1, 0, 0, 0]])
        elif tensor_debug_mode == "FULL_HEALTH":
          if execution.op_type in ("AddV2", "Mul", "RealDiv"):
            # Elements: [
            #   -1 is the unset tensor_id for eager op execution,
            #   device ID (set to -1 for now),
            #   dtype, rank, element_count,
            #   neg_inf_count, pos_inf_count, nan_count,
            #   neg_finite_count, zero_count, pos_finite_count]
            self.assertAllClose(
                execution.debug_tensor_values,
                [[-1, -1, 1, 0, 1, 0, 0, 0, 0, 0, 1]])
        elif tensor_debug_mode == "SHAPE":
          if execution.op_type in ("AddV2", "Mul", "RealDiv"):
            # 1st element: -1 is the unset tensor_id for eager op execution.
            # 2nd element: dtype enum value (float32).
            # 3rd element: rank (scalar).
            # 4th element: element count (1).
            # Remaining elements: shape at fixed length (6).
            self.assertAllClose(execution.debug_tensor_values,
                                [[-1, 1, 0, 1, 0, 0, 0, 0, 0, 0]])
        elif tensor_debug_mode == "FULL_TENSOR":
          tensor_values[execution.op_type].append(
              reader.execution_to_tensor_values(execution)[0])

        host_name, stack_frames = reader.read_execution_stack_trace(execution)
        self.assertEqual(host_name, _host_name)
        self._verifyStackFrames(stack_frames)

      if tensor_debug_mode == "FULL_TENSOR":
        self.assertAllClose(tensor_values["Greater"], [1, 1, 1, 1, 1, 1, 0])
        self.assertAllClose(tensor_values["RealDiv"], [5, 8, 4, 2, 1])
        self.assertAllClose(tensor_values["Mul"], [15])
        self.assertAllClose(tensor_values["AddV2"], [16])

      self.assertEqual(
          executed_op_types,
          [
              "Greater",
              "FloorMod",
              "Equal",
              "RealDiv",  # 10 --> 5
              "Greater",
              "FloorMod",
              "Equal",
              "Mul",
              "AddV2",  # 5 --> 16
              "Greater",
              "FloorMod",
              "Equal",
              "RealDiv",  # 16 --> 8
              "Greater",
              "FloorMod",
              "Equal",
              "RealDiv",  # 8 --> 4
              "Greater",
              "FloorMod",
              "Equal",
              "RealDiv",  # 4 --> 2
              "Greater",
              "FloorMod",
              "Equal",
              "RealDiv",  # 2 --> 1
              "Greater"
          ])

      # Due to the pure eager op execution, the .graph file and the
      # .graph_execution_traces file ought to be empty.
      self.assertFalse(reader.outermost_graphs())
      self.assertEqual(reader.num_graph_execution_traces(), 0)

  @parameterized.named_parameters(
      ("CurtHealth", "CURT_HEALTH"),
      ("ConciseHealth", "CONCISE_HEALTH"),
      ("FullHealth", "FULL_HEALTH"),
      ("Shape", "SHAPE"),
  )
  @test_util.run_in_graph_and_eager_modes
  def testModesSummarizingBadNumericalValue(self, tensor_debug_mode):
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode=tensor_debug_mode)

    @def_function.function
    def func(x, y):
      return (x + y) / (x - y)

    x = np.array([-3, -1, 0, 0, 1, 1, 1, 2], dtype=np.float16)
    y = np.array([2, -1, 0, 0, 1, 1, 1, 3], dtype=np.float16)
    # x - y = [-5, 0, 0, 0, 0, 0, 0, -1]
    # (x + y) / (x - y) = [0.2, -inf, nan, nan, inf, inf, inf, -5].
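    # The quotient thus contains one -inf, three +infs and two nans, which the
    # health modes below are expected to report.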
    self.evaluate(func(x, y))
    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      graph_exec_traces = reader.graph_execution_traces()
      executed_op_types = [trace.op_type for trace in graph_exec_traces
                           if trace.op_type != "Const"]
      self.assertCountEqual(
          executed_op_types,
          ["Placeholder", "Placeholder", "AddV2", "Sub", "RealDiv"])
      if tensor_debug_mode == "CURT_HEALTH":
        for trace in graph_exec_traces:
          # 1st element: tensor_id, should be >= 0.
          # 2nd element: indicates if there is any inf or nan.
          tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
          self.assertGreaterEqual(tensor_id, 0)
          if trace.op_type == "RealDiv":
            self.assertAllClose(trace.debug_tensor_value, [tensor_id, 1])
          else:
            self.assertAllClose(trace.debug_tensor_value, [tensor_id, 0])
      elif tensor_debug_mode == "CONCISE_HEALTH":
        for trace in graph_exec_traces:
          # 1st element: tensor_id, should be >= 0.
          # 2nd element: element count (8).
          # Remaining 3 elements: the counts of -inf, inf and nan.
          tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
          self.assertGreaterEqual(tensor_id, 0)
          if trace.op_type == "RealDiv":
            self.assertAllClose(trace.debug_tensor_value,
                                [tensor_id, 8, 1, 3, 2])
          else:
            self.assertAllClose(trace.debug_tensor_value,
                                [tensor_id, 8, 0, 0, 0])
      elif tensor_debug_mode == "FULL_HEALTH":
        for trace in graph_exec_traces:
          # Elements: [
          #   tensor_id (>= 0 for graph execution),
          #   device ID (set to -1 for now),
          #   dtype, rank, element_count,
          #   neg_inf_count, pos_inf_count, nan_count,
          #   neg_finite_count, zero_count, pos_finite_count]
          tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
          self.assertGreaterEqual(tensor_id, 0)
          if trace.op_type == "RealDiv":
            self.assertAllClose(trace.debug_tensor_value,
                                [tensor_id, -1, 19, 1, 8, 1, 3, 2, 1, 0, 1])
          elif trace.op_type == "Sub":
            self.assertAllClose(trace.debug_tensor_value,
                                [tensor_id, -1, 19, 1, 8, 0, 0, 0, 2, 6, 0])
      else:  # SHAPE.
        for trace in graph_exec_traces:
          # 1st element: tensor_id, should be >= 0.
          # 2nd element: dtype enum value (float16 = 19).
          # 3rd element: rank (1).
          # 4th element: element count (8).
          # Remaining elements: shape at fixed length (6).
          tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
          self.assertGreaterEqual(tensor_id, 0)
          self.assertAllClose(trace.debug_tensor_value,
                              [tensor_id, 19, 1, 8, 8, 0, 0, 0, 0, 0])

  @parameterized.named_parameters(
      ("CurtHealth", "CURT_HEALTH"),
      ("FullTensor", "FULL_TENSOR"),
  )
  @test_util.run_in_graph_and_eager_modes
  def testConstTensorsAreCaptured(self, tensor_debug_mode):
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode=tensor_debug_mode)
    @def_function.function
    def times_two_plus_three(x):
      return x * constant_op.constant(2.0) + constant_op.constant(3.0)
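    # Called with the Python float 10.0 below, the input itself is expected to
    # be captured as a constant in the traced graph as well, hence the test
    # checks for at least three Const ops (10.0, 2.0 and 3.0).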
    self.assertAllEqual(
        self.evaluate(times_two_plus_three(10.0)), 23.0)
    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      const_traces = [trace for trace in reader.graph_execution_traces()
                      if trace.op_type == "Const"]
      self.assertGreaterEqual(len(const_traces), 3)
      if tensor_debug_mode == "CURT_HEALTH":
        # Under CURT_HEALTH, each debug tensor value has the form
        # [tensor_id, has_inf_or_nan].
        self.assertLen(const_traces[0].debug_tensor_value, 2)
        self.assertEqual(const_traces[0].debug_tensor_value[1], 0)
        self.assertLen(const_traces[1].debug_tensor_value, 2)
        self.assertEqual(const_traces[1].debug_tensor_value[1], 0)
        self.assertLen(const_traces[2].debug_tensor_value, 2)
        self.assertEqual(const_traces[2].debug_tensor_value[1], 0)
      else:  # FULL_TENSOR.
        const_tensor_values = [
            reader.graph_execution_trace_to_tensor_value(const_trace)
            for const_trace in const_traces]
        # Avoid making assertion on the particular order of the debug tensors
        # for the three Consts because it may be indeterminate.
        self.assertIn(10.0, const_tensor_values)
        self.assertIn(2.0, const_tensor_values)
        self.assertIn(3.0, const_tensor_values)

  @parameterized.named_parameters(
      ("Shape", "SHAPE"),
  )
  @test_util.run_in_graph_and_eager_modes
  def testBooleanTensors(self, tensor_debug_mode):
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode=tensor_debug_mode)

    @def_function.function
    def func(x, y):
      return math_ops.logical_not(math_ops.logical_and(x, y))

    x = np.array([[False, False], [True, True]], dtype=np.bool_)
    y = np.array([[False, True], [False, True]], dtype=np.bool_)
    self.assertAllEqual(
        self.evaluate(func(x, y)), [[True, True], [True, False]])

    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      graph_exec_traces = reader.graph_execution_traces()
      executed_op_types = [trace.op_type for trace in graph_exec_traces
                           if trace.op_type != "Const"]
      self.assertEqual(
          executed_op_types,
          ["Placeholder", "Placeholder", "LogicalAnd", "LogicalNot"])
      for trace in graph_exec_traces:
        tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
        self.assertGreaterEqual(tensor_id, 0)
        # 1st element: tensor_id, should be >= 0.
        # 2nd element: dtype enum value (bool).
        # 3rd element: rank (2).
        # 4th element: element count (4).
        # Remaining elements: shape at fixed length.
        self.assertAllClose(
            trace.debug_tensor_value, [tensor_id, 10, 2, 4, 2, 2, 0, 0, 0, 0])

  def testListingSourceFiles(self):
    writer = dumping_callback.enable_dump_debug_info(self.dump_root)
    # Run a simple eager execution event, so that the source files are dumped.
    self.assertAllClose(math_ops.truediv(7.0, 1.0 / 6.0), 42.0)
    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()
    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      source_file_list = reader.source_file_list()
      self.assertIsInstance(source_file_list, tuple)
      for item in source_file_list:
        self.assertIsInstance(item, tuple)
        self.assertLen(item, 2)
      self.assertIn((_host_name, _current_file_full_path), source_file_list)

  def testReadingSourceLines(self):
    writer = dumping_callback.enable_dump_debug_info(self.dump_root)
    # Run a simple eager execution event, so that the source-file contents are
    # dumped.
    self.assertAllClose(math_ops.truediv(7.0, 1.0 / 6.0), 42.0)
    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()
    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      with open(_current_file_full_path, "rt") as f:
        file_lines = f.read().split("\n")
      self.assertEqual(
          reader.source_lines(_host_name, _current_file_full_path), file_lines)

  @parameterized.named_parameters(
      ("NoTensor", "NO_TENSOR"),
      ("CurtHealth", "CURT_HEALTH"),
      ("ConciseHealth", "CONCISE_HEALTH"),
      ("FullHealth", "FULL_HEALTH"),
      ("Shape", "SHAPE"),
      ("FullTensor", "FULL_TENSOR"),
  )
  @test_util.run_in_graph_and_eager_modes
  def testNestedFunctionExecutionWithoutControlFlow(self, tensor_debug_mode):
    x = constant_op.constant(2.0)
    y = constant_op.constant(3.0)
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode=tensor_debug_mode)

    @def_function.function
    def log_sum(x, y):
      return math_ops.log(x + y)

    @def_function.function
    def sin1p_log_sum(x, y):
      return math_ops.sin(1.0 + log_sum(x, y))

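    # With x = 2 and y = 3, log_sum yields log(5), so the overall result is
    # sin(1 + log(5)).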
    self.assertAllClose(sin1p_log_sum(x, y), np.sin(1.0 + np.log(5.0)))
    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      outermost_graphs = reader.outermost_graphs()
      self.assertLen(outermost_graphs, 1)

      if context.executing_eagerly():
        # NOTE(b/142486213): Execution of the TF function happens with
        # Session.run() in v1 graph mode, so it doesn't get logged to the
        # .execution file.
        executions = reader.executions()
        self.assertLen(executions, 1)
        self.assertIn("sin1p_log_sum", executions[0].op_type)
        # Get the executed graph and verify its identity and inner graph.
        graph = reader.graph_by_id(executions[0].graph_id)
        self.assertEqual(graph.name, "sin1p_log_sum")
        self.assertLen(graph.inner_graph_ids, 1)
        inner_graph = reader.graph_by_id(graph.inner_graph_ids[0])
        self.assertEqual(inner_graph.name, "log_sum")
        # Check device names.
        self.assertLen(executions[0].output_tensor_device_ids, 1)
        self.assertEqual(
            reader.device_name_by_id(executions[0].output_tensor_device_ids[0]),
            self._expectedDefaultDeviceName())
        self.assertIn(self._expectedDefaultDeviceName(),
                      set(reader.device_name_map().values()))

      # Verify the recorded graph-building history.
      placeholder_op_digests = reader.graph_op_digests(op_type="Placeholder")
      add_op_digests = reader.graph_op_digests(op_type="AddV2")
      self.assertLen(add_op_digests, 2)
      self.assertEqual(
          reader.graph_by_id(add_op_digests[0].graph_id).name, "log_sum")
      self.assertEqual(
          reader.graph_by_id(add_op_digests[1].graph_id).name, "sin1p_log_sum")
      log_op_digests = reader.graph_op_digests(op_type="Log")
      self.assertLen(log_op_digests, 1)
      self.assertEqual(
          reader.graph_by_id(log_op_digests[0].graph_id).name, "log_sum")
      sin_op_digests = reader.graph_op_digests(op_type="Sin")
      self.assertLen(sin_op_digests, 1)
      self.assertEqual(
          reader.graph_by_id(sin_op_digests[0].graph_id).name, "sin1p_log_sum")

      # Verify the output tensor IDs and the stack traces.
      for op_digest in add_op_digests + log_op_digests + sin_op_digests:
        # These are all single-output ops.
        self.assertLen(op_digest.output_tensor_ids, 1)
        self.assertGreaterEqual(op_digest.output_tensor_ids[0], 0)
        _, stack_frames = reader.read_graph_op_creation_stack_trace(op_digest)
        self._verifyStackFrames(stack_frames)

      graph_exec_traces = [trace for trace in reader.graph_execution_traces()
                           if trace.op_type != "Const"]
      executed_op_types = [digest.op_type for digest in graph_exec_traces]
      self.assertEqual(
          executed_op_types,
          ["Placeholder", "Placeholder", "Placeholder", "Placeholder",
           "AddV2", "Log", "AddV2", "Sin"])
      placeholder_traces = graph_exec_traces[:4]
      non_placeholder_traces = graph_exec_traces[4:]

      # Verify the graph ID stack of each op.
      # The outer function's 1st Placeholder.
      self.assertEqual(
          reader.graph_by_id(placeholder_traces[0].graph_ids[-1]).name,
          "sin1p_log_sum")
      # The outer function's 2nd Placeholder.
      self.assertEqual(
          reader.graph_by_id(placeholder_traces[1].graph_ids[-1]).name,
          "sin1p_log_sum")
      # The inner function's 1st Placeholder.
      self.assertEqual(
          reader.graph_by_id(placeholder_traces[2].graph_ids[-1]).name,
          "log_sum")
      self.assertEqual(
          reader.graph_by_id(placeholder_traces[2].graph_ids[-2]).name,
          "sin1p_log_sum")
      # The inner function's 2nd Placeholder.
      self.assertEqual(
          reader.graph_by_id(placeholder_traces[3].graph_ids[-1]).name,
          "log_sum")
      self.assertEqual(
          reader.graph_by_id(placeholder_traces[3].graph_ids[-2]).name,
          "sin1p_log_sum")
      # 1st AddV2 op.
      self.assertEqual(
          reader.graph_by_id(non_placeholder_traces[0].graph_ids[-1]).name,
          "log_sum")
      self.assertEqual(
          reader.graph_by_id(non_placeholder_traces[0].graph_ids[-2]).name,
          "sin1p_log_sum")
      # Log op.
      self.assertEqual(
          reader.graph_by_id(non_placeholder_traces[1].graph_ids[-1]).name,
          "log_sum")
      self.assertEqual(
          reader.graph_by_id(non_placeholder_traces[1].graph_ids[-2]).name,
          "sin1p_log_sum")
      # 2nd AddV2 op.
      self.assertEqual(
          reader.graph_by_id(non_placeholder_traces[2].graph_ids[-1]).name,
          "sin1p_log_sum")
      # Sin op.
      self.assertEqual(
          reader.graph_by_id(non_placeholder_traces[3].graph_ids[-1]).name,
          "sin1p_log_sum")

      if tensor_debug_mode == "NO_TENSOR":
        # Under the default NO_TENSOR tensor-debug mode, no tensor value is
        # summarized, so debug_tensor_value ought to be None.
        for trace in graph_exec_traces:
          self.assertIsNone(trace.debug_tensor_value)
      elif tensor_debug_mode == "CURT_HEALTH":
        # Test the association between graph exec and prior graph building.
        # In each case, the 1st element of debug_tensor_value is the ID of the
        # symbolic tensor and the 2nd element is a zero indicating there is no
        # inf or nan.
        self.assertAllClose(  # 1st outer placeholder.
            placeholder_traces[0].debug_tensor_value,
            [placeholder_op_digests[0].output_tensor_ids[0], 0.0])
        self.assertAllClose(  # 2nd outer placeholder.
            placeholder_traces[1].debug_tensor_value,
            [placeholder_op_digests[1].output_tensor_ids[0], 0.0])
        self.assertAllClose(  # 1st inner placeholder.
            placeholder_traces[2].debug_tensor_value,
            [placeholder_op_digests[2].output_tensor_ids[0], 0.0])
        self.assertAllClose(  # 2nd inner placeholder.
            placeholder_traces[3].debug_tensor_value,
            [placeholder_op_digests[3].output_tensor_ids[0], 0.0])
        self.assertAllClose(  # 1st AddV2 op.
            non_placeholder_traces[0].debug_tensor_value,
            [add_op_digests[0].output_tensor_ids[0], 0.0])
        self.assertAllClose(  # Log op.
            non_placeholder_traces[1].debug_tensor_value,
            [log_op_digests[0].output_tensor_ids[0], 0.0])
        self.assertAllClose(  # 2nd AddV2 op.
            non_placeholder_traces[2].debug_tensor_value,
            [add_op_digests[1].output_tensor_ids[0], 0.0])
        self.assertAllClose(  # Sin op.
            non_placeholder_traces[3].debug_tensor_value,
            [sin_op_digests[0].output_tensor_ids[0], 0.0])
      elif tensor_debug_mode == "CONCISE_HEALTH":
        # 1st element: tensor_id.
        # 2nd element: element count. Remaining elements: all zero because there
        # is no -inf, inf or nan.
        self.assertAllClose(  # 1st outer placeholder.
            placeholder_traces[0].debug_tensor_value,
            [placeholder_op_digests[0].output_tensor_ids[0], 1., 0., 0., 0.])
        self.assertAllClose(  # 2nd outer placeholder.
            placeholder_traces[1].debug_tensor_value,
            [placeholder_op_digests[1].output_tensor_ids[0], 1., 0., 0., 0.])
        self.assertAllClose(  # 1st inner placeholder.
            placeholder_traces[2].debug_tensor_value,
            [placeholder_op_digests[2].output_tensor_ids[0], 1., 0., 0., 0.])
        self.assertAllClose(  # 2nd inner placeholder.
            placeholder_traces[3].debug_tensor_value,
            [placeholder_op_digests[3].output_tensor_ids[0], 1., 0., 0., 0.])
        # 1st AddV2 op.
        self.assertAllClose(
            non_placeholder_traces[0].debug_tensor_value,
            [add_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
        # Log op.
        self.assertAllClose(
            non_placeholder_traces[1].debug_tensor_value,
            [log_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
        # 2nd AddV2 op.
        self.assertAllClose(
            non_placeholder_traces[2].debug_tensor_value,
            [add_op_digests[1].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
        # Sin op.
        self.assertAllClose(
            non_placeholder_traces[3].debug_tensor_value,
            [sin_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
      elif tensor_debug_mode == "FULL_HEALTH":
        # Elements: [
        #   tensor_id (the ID of the symbolic tensor),
        #   device ID (set to -1 for now),
        #   dtype, rank, element_count,
        #   neg_inf_count, pos_inf_count, nan_count,
        #   neg_finite_count, zero_count, pos_finite_count]
        self.assertAllClose(  # 1st outer placeholder.
            placeholder_traces[0].debug_tensor_value,
            [placeholder_op_digests[0].output_tensor_ids[0],
             -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
        self.assertAllClose(  # 2nd outer placeholder.
            placeholder_traces[1].debug_tensor_value,
            [placeholder_op_digests[1].output_tensor_ids[0],
             -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
        self.assertAllClose(  # 1st inner placeholder.
            placeholder_traces[2].debug_tensor_value,
            [placeholder_op_digests[2].output_tensor_ids[0],
             -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
        self.assertAllClose(  # 2nd inner placeholder.
            placeholder_traces[3].debug_tensor_value,
            [placeholder_op_digests[3].output_tensor_ids[0],
             -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
        # 1st AddV2 op.
        self.assertAllClose(
            non_placeholder_traces[0].debug_tensor_value,
            [add_op_digests[0].output_tensor_ids[0],
             -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
        # Log op.
        self.assertAllClose(
            non_placeholder_traces[1].debug_tensor_value,
            [log_op_digests[0].output_tensor_ids[0],
             -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
        # 2nd AddV2 op.
        self.assertAllClose(
            non_placeholder_traces[2].debug_tensor_value,
            [add_op_digests[1].output_tensor_ids[0],
             -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
        # Sin op.
        self.assertAllClose(
            non_placeholder_traces[3].debug_tensor_value,
            [sin_op_digests[0].output_tensor_ids[0],
             -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
      elif tensor_debug_mode == "SHAPE":
        # 1st element: tensor_id.
        # 2nd element: dtype (float32).
        # 3rd element: rank (scalar).
        # 4th element: element count (1).
        # Remaining elements: shape padded to fixed length (6).
        self.assertAllClose(  # 1st outer placeholder.
            placeholder_traces[0].debug_tensor_value,
            [placeholder_op_digests[0].output_tensor_ids[0],
             1, 0, 1, 0, 0, 0, 0, 0, 0])
        self.assertAllClose(  # 2nd outer placeholder.
            placeholder_traces[1].debug_tensor_value,
            [placeholder_op_digests[1].output_tensor_ids[0],
             1, 0, 1, 0, 0, 0, 0, 0, 0])
        self.assertAllClose(  # 1st inner placeholder.
            placeholder_traces[2].debug_tensor_value,
            [placeholder_op_digests[2].output_tensor_ids[0],
             1, 0, 1, 0, 0, 0, 0, 0, 0])
        self.assertAllClose(  # 2nd inner placeholder.
            placeholder_traces[3].debug_tensor_value,
            [placeholder_op_digests[3].output_tensor_ids[0],
             1, 0, 1, 0, 0, 0, 0, 0, 0])
        # 1st AddV2 op.
        self.assertAllClose(
            non_placeholder_traces[0].debug_tensor_value,
            [add_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
        # Log op.
        self.assertAllClose(
            non_placeholder_traces[1].debug_tensor_value,
            [log_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
        # 2nd AddV2 op.
        self.assertAllClose(
            non_placeholder_traces[2].debug_tensor_value,
            [add_op_digests[1].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
        # Sin op.
        self.assertAllClose(
            non_placeholder_traces[3].debug_tensor_value,
            [sin_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
      else:  # FULL_TENSOR.
        placeholder_full_tensor_values = [
            reader.graph_execution_trace_to_tensor_value(trace)
            for trace in placeholder_traces]
        self.assertAllClose(placeholder_full_tensor_values[0], x)  # Input x.
        self.assertAllClose(placeholder_full_tensor_values[1], y)  # Input y.
        self.assertAllClose(placeholder_full_tensor_values[2], x)  # Input x.
        self.assertAllClose(placeholder_full_tensor_values[3], y)  # Input y.
        non_placeholder_full_tensor_values = [
            reader.graph_execution_trace_to_tensor_value(trace)
            for trace in non_placeholder_traces]
        self.assertAllClose(
            non_placeholder_full_tensor_values[0], 5.0)  # 1st AddV2 op.
        self.assertAllClose(
            non_placeholder_full_tensor_values[1], np.log(5.0))  # Log op.
        self.assertAllClose(
            non_placeholder_full_tensor_values[2],
            np.log(5.0) + 1.0)  # 2nd AddV2 op.
        self.assertAllClose(
            non_placeholder_full_tensor_values[3],
            np.sin(np.log(5.0) + 1.0))  # Sin op.

  @parameterized.named_parameters(
      ("NoTensor", "NO_TENSOR"),
      ("FullTensor", "FULL_TENSOR"),
  )
  @test_util.run_in_graph_and_eager_modes
  def testGraphOpConsumingRelationIsCaptured(self, tensor_debug_mode):
    x = constant_op.constant([2.0, 2.0])
    y = constant_op.constant([3.0, 3.0])
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode=tensor_debug_mode)

    @def_function.function
    def log_sum(x, y):
      return math_ops.log(x + y)

    @def_function.function
    def maxindex_sin1p_log_sum(x, y):
      _, indices = array_ops.unique(math_ops.sin(1.0 + log_sum(x, y)))
      return math_ops.reduce_max(indices)

    maxindex = maxindex_sin1p_log_sum(x, y)
    self.assertAllEqual(maxindex, 0)
    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      traces = reader.graph_execution_traces()
      add_traces = [trace for trace in traces if trace.op_type == "AddV2"]
      log_traces = [trace for trace in traces if trace.op_type == "Log"]
      sin_traces = [trace for trace in traces if trace.op_type == "Sin"]
      unique_traces = [trace for trace in traces if trace.op_type == "Unique"]
      max_traces = [trace for trace in traces if trace.op_type == "Max"]
      self.assertLen(add_traces, 2)
      self.assertLen(log_traces, 1)
      self.assertLen(sin_traces, 1)
      self.assertLen(unique_traces, 2)  # The Unique op outputs two tensors.
      self.assertLen(max_traces, 1)
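      # Each consumer reported by get_op_consumers() is a 3-tuple of
      # (src_output_slot, consumer_op_name, consumer_input_slot), as the
      # assertions below exercise.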
      graph = reader.graph_by_id(add_traces[0].graph_id)
      # The first AddV2 op is consumed by the Log op.
      self.assertEqual(
          graph.get_op_consumers(add_traces[0].op_name),
          [(0, log_traces[0].op_name, 0)])
      graph = reader.graph_by_id(add_traces[1].graph_id)
      # The second AddV2 op is consumed by the Sin op.
      self.assertEqual(
          graph.get_op_consumers(add_traces[1].op_name),
          [(0, sin_traces[0].op_name, 0)])
      # The Sin op is consumed by the Unique op.
      self.assertEqual(
          graph.get_op_consumers(sin_traces[0].op_name),
          [(0, unique_traces[0].op_name, 0)])
      # The Unique op's 2nd output tensor is consumed by the Max op.
      self.assertEqual(
          graph.get_op_consumers(unique_traces[0].op_name),
          [(1, max_traces[0].op_name, 0)])

  def testCapturingExecutedGraphIdsOfTwoCompilationsOfSameFunction(self):
    """Test correct executed IDs of two FuncGraphs from the same Py function."""
    x_float32 = constant_op.constant(np.array(3.5, dtype=np.float32))
    x_float64 = constant_op.constant(np.array(4.5, dtype=np.float64))
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode="NO_TENSOR")

    @def_function.function
    def ceil_times_two(x):
      return math_ops.ceil(x) * 2.0

    # Four executions, with two different FuncGraphs, which should lead
    # to two unique executed graph IDs (see assertion below).
    self.assertAllClose(ceil_times_two(x_float32), 8.0)
    self.assertAllClose(ceil_times_two(x_float64), 10.0)
    self.assertAllClose(ceil_times_two(x_float32), 8.0)
    self.assertAllClose(ceil_times_two(x_float64), 10.0)
    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()

      executions = reader.executions()
      self.assertLen(executions, 4)
      for execution in executions:
        self.assertStartsWith(execution.op_type, "__inference_ceil_times_two_")
      executed_graph_ids = [execution.graph_id for execution in executions]
      self.assertEqual(executed_graph_ids[0], executed_graph_ids[2])
      self.assertEqual(executed_graph_ids[1], executed_graph_ids[3])
      self.assertNotEqual(executed_graph_ids[0], executed_graph_ids[1])
      self.assertNotEqual(executed_graph_ids[2], executed_graph_ids[3])
      for executed_graph_id in executed_graph_ids:
        self.assertEqual(
            reader.graph_by_id(executed_graph_id).name, "ceil_times_two")

  def testCapturingExecutedGraphIdsOfDuplicateFunctionNames(self):
    """Two FuncGraphs compiled from Python functions with identical names."""
    x = constant_op.constant(np.array(3.5, dtype=np.float32))
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode="NO_TENSOR")

    class TestClass(object):

      @def_function.function
      def ceil_times_two(self, x):
        return math_ops.ceil(x) * 2.0

    # The `ceil_times_two` method of the two objects will be compiled
    # into separate FuncGraphs.
    test_object_1 = TestClass()
    test_object_2 = TestClass()

    # Four executions, with two different FuncGraphs, which should lead
    # to two unique executed graph IDs (see assertion below).
    self.assertAllClose(test_object_1.ceil_times_two(x), 8.0)
    self.assertAllClose(test_object_2.ceil_times_two(x), 8.0)
    self.assertAllClose(test_object_1.ceil_times_two(x), 8.0)
    self.assertAllClose(test_object_2.ceil_times_two(x), 8.0)
    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      executions = reader.executions()
      self.assertLen(executions, 4)
      for execution in executions:
        self.assertStartsWith(execution.op_type, "__inference_ceil_times_two_")
      executed_graph_ids = [execution.graph_id for execution in executions]
      self.assertEqual(executed_graph_ids[0], executed_graph_ids[2])
      self.assertEqual(executed_graph_ids[1], executed_graph_ids[3])
      self.assertNotEqual(executed_graph_ids[0], executed_graph_ids[1])
      self.assertNotEqual(executed_graph_ids[2], executed_graph_ids[3])
      for executed_graph_id in executed_graph_ids:
        self.assertEqual(
            reader.graph_by_id(executed_graph_id).name, "ceil_times_two")

  @parameterized.named_parameters(
      ("AddV2", "AddV2"),
      ("Log", "Log"),
      ("AddV2AndLog", "(AddV2|Log)"),
  )
  @test_util.run_in_graph_and_eager_modes
  def testOpRegex(self, op_regex):
    x = constant_op.constant(2.0)
    y = constant_op.constant(3.0)
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode="FULL_TENSOR",
        op_regex=op_regex)

    @def_function.function
    def log_sum(x, y):
      return math_ops.log(x + y)

    @def_function.function
    def sin1p_log_sum(x, y):
      return math_ops.sin(1.0 + log_sum(x, y))

    self.assertAllClose(
        self.evaluate(sin1p_log_sum(x, y)), np.sin(1.0 + np.log(5.0)))
    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      graph_op_digests = reader.graph_op_digests()
      op_types = [digest.op_type for digest in graph_op_digests]
      self.assertIn("AddV2", op_types)
      self.assertIn("Log", op_types)
      self.assertIn("Sin", op_types)

      graph_exec_digests = reader.graph_execution_traces(digest=True)
      executed_op_types = [digest.op_type for digest in graph_exec_digests]
      tensor_values = [reader.graph_execution_trace_to_tensor_value(digest)
                       for digest in graph_exec_digests]
      if op_regex == "AddV2":
        self.assertEqual(executed_op_types, ["AddV2", "AddV2"])
        self.assertLen(tensor_values, 2)
        self.assertAllClose(tensor_values[0], 5.0)  # 1st AddV2 op.
        self.assertAllClose(
            tensor_values[1], np.log(5.0) + 1.0)  # 2nd AddV2 op.
      elif op_regex == "Log":
        self.assertEqual(executed_op_types, ["Log"])
        self.assertLen(tensor_values, 1)
        self.assertAllClose(tensor_values[0], np.log(5.0))  # Log op.
      else:  # "(AddV2|Log)"
        self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2"])
        self.assertLen(tensor_values, 3)
        self.assertAllClose(tensor_values[0], 5.0)  # 1st AddV2 op.
        self.assertAllClose(tensor_values[1], np.log(5.0))  # Log op.
        self.assertAllClose(
            tensor_values[2], np.log(5.0) + 1.0)  # 2nd AddV2 op.

  def testIncorrectTensorDTypeArgFormatLeadsToError(self):
    with self.assertRaisesRegex(
        ValueError, r".*expected.*list.*tuple.*callable.*but received.*\{\}"):
      dumping_callback.enable_dump_debug_info(self.dump_root,
                                              tensor_dtypes=dict())
    with self.assertRaisesRegex(
        ValueError, r".*expected.*list.*tuple.*callable.*but received.*"):
      dumping_callback.enable_dump_debug_info(self.dump_root,
                                              tensor_dtypes="float32")
    with self.assertRaisesRegex(
        ValueError, r".*expected.*list.*tuple.*callable.*but received.*"):
      dumping_callback.enable_dump_debug_info(
          self.dump_root, tensor_dtypes=dtypes.float32)
    with self.assertRaises(TypeError):
      dumping_callback.enable_dump_debug_info(self.dump_root, tensor_dtypes=[
          lambda dtype: dtype.is_floating, lambda dtype: dtype.is_integer])

  @parameterized.named_parameters(
      ("float", [dtypes.float32], None),
      ("float_only_sum", ["float32"], "Sum"),
      ("float_no_sum", (dtypes.float32,), "(?!Sum)"),
      ("int", [dtypes.int32], None),
      ("int_via_lambda", lambda dtype: dtype.is_integer, None),
      ("exclude_Sum", None, "(?!Sum)"),
      ("All", None, None),
  )
  @test_util.run_in_graph_and_eager_modes
  def testTensorDTypesAndOpRegexFilters(self,
                                        tensor_dtypes,
                                        op_regex):
    xs = constant_op.constant([2., 6., 8., 1., 2.], dtype=dtypes.float32)
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode="FULL_TENSOR",
        tensor_dtypes=tensor_dtypes,
        op_regex=op_regex)

    @def_function.function
    def unique_sum(xs):
      """Sum over the unique values, for testing."""
      unique_xs, indices = array_ops.unique(xs)
      return math_ops.reduce_sum(unique_xs), indices

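    # For xs = [2., 6., 8., 1., 2.], unique yields the values [2., 6., 8., 1.]
    # (which sum to 17) and the indices [0, 1, 2, 3, 0].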
    y, indices = self.evaluate(unique_sum(xs))
    self.assertAllClose(y, 17.)
    self.assertAllEqual(indices, [0, 1, 2, 3, 0])

    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      graph_exec_digests = reader.graph_execution_traces(digest=True)
      executed_op_types = [digest.op_type for digest in graph_exec_digests
                           if digest.op_type not in ("Const", "Placeholder")]
      tensor_values = [reader.graph_execution_trace_to_tensor_value(digest)
                       for digest in graph_exec_digests
                       if digest.op_type not in ("Const", "Placeholder")]

      if tensor_dtypes == [dtypes.float32] and not op_regex:
        self.assertEqual(executed_op_types, ["Unique", "Sum"])
        self.assertLen(tensor_values, 2)
        self.assertAllClose(tensor_values[0], [2, 6, 8, 1])  # Unique values.
        self.assertAllClose(tensor_values[1], 17.)  # Sum.
      elif tensor_dtypes == ["float32"] and op_regex == "Sum":
        self.assertEqual(executed_op_types, ["Sum"])
        self.assertLen(tensor_values, 1)
        self.assertAllClose(tensor_values[0], 17.)  # Sum.
      elif tensor_dtypes == (dtypes.float32,) and op_regex == "(?!Sum)":
        self.assertEqual(executed_op_types, ["Unique"])
        self.assertLen(tensor_values, 1)
        self.assertAllClose(tensor_values[0], [2, 6, 8, 1])  # Unique values.
      elif tensor_dtypes == [dtypes.int32] and not op_regex:
        self.assertEqual(executed_op_types, ["Unique"])
        self.assertLen(tensor_values, 1)
        self.assertAllEqual(
            tensor_values[0], [0, 1, 2, 3, 0])  # Unique indices.
      elif callable(tensor_dtypes) and not op_regex:
        self.assertEqual(executed_op_types, ["Unique"])
        self.assertLen(tensor_values, 1)
        self.assertAllEqual(
            tensor_values[0], [0, 1, 2, 3, 0])  # Unique indices.
      elif not tensor_dtypes and op_regex == "(?!Sum)":
        self.assertEqual(executed_op_types, ["Unique", "Unique"])
        self.assertLen(tensor_values, 2)
        self.assertAllClose(tensor_values[0], [2, 6, 8, 1])  # Unique values.
        self.assertAllEqual(
            tensor_values[1], [0, 1, 2, 3, 0])  # Unique indices.
      else:  # "All".
        self.assertEqual(executed_op_types, ["Unique", "Unique", "Sum"])
        self.assertLen(tensor_values, 3)
        self.assertAllClose(tensor_values[0], [2, 6, 8, 1])  # Unique values.
        self.assertAllEqual(
            tensor_values[1], [0, 1, 2, 3, 0])  # Unique indices.
        self.assertAllClose(tensor_values[2], 17)  # Sum.

  @parameterized.named_parameters(
      ("NoTensor", "NO_TENSOR"),
      ("CurtHealth", "CURT_HEALTH"),
      ("FullTensor", "FULL_TENSOR"),
  )
  @test_util.run_in_graph_and_eager_modes
  def testFunctionExecutionWithControlFlow(self, tensor_debug_mode):
    x = constant_op.constant(0.5, dtype=dtypes.float32)
    times = constant_op.constant(4, dtype=dtypes.int32)
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode=tensor_debug_mode)

    @def_function.function
    def iterative_doubling(x, times):
      i = constant_op.constant(0, dtype=dtypes.int32)
      while i < times:
        x = x * 2.0
        i += 1
      return x

    self.assertAllClose(self.evaluate(iterative_doubling(x, times)), 8.0)
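    # Doubling 0.5 four times gives 0.5 -> 1.0 -> 2.0 -> 4.0 -> 8.0, so the
    # Mul op runs 4 times and the loop-condition Less op runs 5 times.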

    writer.FlushNonExecutionFiles()
    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      graph_op_digests = reader.graph_op_digests()
      op_types = [digest.op_type for digest in graph_op_digests]
      self.assertIn("Less", op_types)
      self.assertIn("Mul", op_types)
      self.assertIn("AddV2", op_types)

      # Before FlushExecutionFiles() is called, the .execution and
      # .graph_execution_traces files should both be empty.
      self.assertEqual(reader.num_executions(), 0)
      self.assertEqual(reader.num_graph_execution_traces(), 0)

      # TODO(cais): Backport execution instrumentation to tf.Session.
      writer.FlushExecutionFiles()
      # After the flushing, the .execution file should hold the appropriate
      # contents.
      reader.update()
      if context.executing_eagerly():
        # NOTE(b/142486213): Execution of the TF function happens with
        # Session.run() in v1 graph mode, hence it doesn't get logged to the
        # .execution file.
        executions = reader.executions()
        self.assertLen(executions, 1)
1086        self.assertIn("iterative_doubling", executions[0].op_type)
1087        execution = executions[0]
1088        self.assertLen(execution.input_tensor_ids, 2)
1089        self.assertLen(execution.output_tensor_ids, 1)
1090        self.assertEqual(
1091            debug_event_pb2.TensorDebugMode.keys()[execution.tensor_debug_mode],
1092            tensor_debug_mode)
1093        if tensor_debug_mode == "FULL_TENSOR":
1094          tensor_values = reader.execution_to_tensor_values(execution)
1095          self.assertAllClose(tensor_values, [8.0])
1096
1097      graph_exec_traces = reader.graph_execution_traces()
1098      executed_op_types = [trace.op_type for trace in graph_exec_traces
1099                           if trace.op_type != "Const"]
1100      if tensor_debug_mode != "CURT_HEALTH":
1101        # Less outputs a boolean tensor, which is not tracked under CURT_HEALTH.
1102        # The Less op should have been executed 5 times.
1103        self.assertEqual(executed_op_types.count("Less"), 5)
1104        # The last executed op should be Less.
1105        self.assertEqual(executed_op_types[-1], "Less")
1106        # AddV2 produces an int tensor, which is not tracked under CURT_HEALTH.
1107        # The AddV2 op should have been run, but we refrain from asserting on
1108        # how many times it's executed.
1109        self.assertIn("AddV2", executed_op_types)
1110        for trace in graph_exec_traces:
1111          self.assertEqual(trace.output_slot, 0)
1112      # The Mul op should have been executed 4 times.
1113      self.assertEqual(executed_op_types.count("Mul"), 4)
1114
1115      tensor_values = [reader.graph_execution_trace_to_tensor_value(trace)
1116                       for trace in graph_exec_traces]
1117      if tensor_debug_mode == "NO_TENSOR":
1118        # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought
1119        # to be an empty float32 tensor.
1120        for tensor_value in tensor_values:
1121          self.assertAllEqual(tensor_value, [])
1122      elif tensor_debug_mode == "CURT_HEALTH":
1123        for trace in graph_exec_traces:
1124          tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
1125          # 1st element: tensor_id; 2nd element: 0 indicating no inf or nan.
1126          self.assertAllClose(trace.debug_tensor_value, [tensor_id, 0.0])
1127      elif tensor_debug_mode == "FULL_TENSOR":
1128        less_values = [
1129            reader.graph_execution_trace_to_tensor_value(trace)
1130            for trace in graph_exec_traces if trace.op_type == "Less"]
1131        self.assertAllEqual(less_values, [True, True, True, True, False])
1132        mul_values = [
1133            reader.graph_execution_trace_to_tensor_value(trace)
1134            for trace in graph_exec_traces if trace.op_type == "Mul"]
1135        self.assertAllClose(mul_values, [1.0, 2.0, 4.0, 8.0])
1136
1137  def testCallingEnableTracingTwiceWithTheSameDumpRootIsIdempotent(self):
1138    x = constant_op.constant([10.0, 12.0, 10.0])
1139    dumping_callback.enable_dump_debug_info(self.dump_root)
1140    writer = dumping_callback.enable_dump_debug_info(self.dump_root)
1141
1142    for _ in range(2):
1143      array_ops.unique(x)
1144
1145    writer.FlushNonExecutionFiles()
1146    writer.FlushExecutionFiles()
1147
1148    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
1149      reader.update()
1150      executions = reader.executions()
1151      self.assertLen(executions, 2)
1152      for execution in executions:
1153        self.assertGreater(execution.wall_time, 0)
1154        self.assertEqual(execution.op_type, "Unique")
1155        self.assertEqual(execution.num_outputs, 2)
1156        _, stack_frames = reader.read_execution_stack_trace(execution)
1157        self._verifyStackFrames(stack_frames)
1158
1159  def testCallingEnableTracingTwiceWithDifferentDumpRootsOverwrites(self):
1160    x = constant_op.constant([10.0, 12.0, 10.0])
1161    dumping_callback.enable_dump_debug_info(self.dump_root)
1162    new_dump_root = self.dump_root + "_new_dump_root"
1163    writer = dumping_callback.enable_dump_debug_info(new_dump_root)
1164
1165    for _ in range(2):
1166      array_ops.unique(x)
1167
1168    writer.FlushNonExecutionFiles()
1169    writer.FlushExecutionFiles()
1170
1171    with debug_events_reader.DebugDataReader(new_dump_root) as reader:
1172      reader.update()
1173      executions = reader.executions()
1174      self.assertLen(executions, 2)
1175      for execution in executions:
1176        self.assertGreater(execution.wall_time, 0)
1177        self.assertEqual(execution.op_type, "Unique")
1178        self.assertEqual(execution.num_outputs, 2)
1179        _, stack_frames = reader.read_execution_stack_trace(execution)
1180        self._verifyStackFrames(stack_frames)
1181
1182    with debug_events_reader.DebugDataReader(
1183        self.dump_root) as old_dump_root_reader:
1184      old_dump_root_reader.update()
1185      # The old dump root shouldn't have been written to.
1186      self.assertEqual(old_dump_root_reader.num_executions(), 0)
1187      self.assertFalse(old_dump_root_reader.outermost_graphs())
1188
1189  def testCallingEnableRepeatedlyWithDifferentTensorDebugMode(self):
1190    """Assert calling enable_dump_debug_info() with two tensor-debug modes.
1191
1192    It should lead to overwriting of the previously-configured mode.
1193    """
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode="NO_TENSOR")

    @def_function.function
    def add_1_divide_by_2(x):
      return (x + 1.0) / 2.0

    self.assertAllClose(add_1_divide_by_2(constant_op.constant(4.0)), 2.5)
    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      graph_exec_digests = reader.graph_execution_traces(digest=True)
      tensor_values = [reader.graph_execution_trace_to_tensor_value(digest)
                       for digest in graph_exec_digests]
      for tensor_value in tensor_values:
        # Under NO_TENSOR mode, each tensor is summarized as an empty float32
        # array.
        self.assertAllEqual(tensor_value, [])
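        # Hedged illustration (not among the original assertions): because the
        # NO_TENSOR summary is an empty array, its element count is 0 no
        # matter what shape or dtype the instrumented tensor had.
        self.assertEqual(np.asarray(tensor_value).size, 0)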

    with self.assertRaisesRegex(
        ValueError, r"already.*NO_TENSOR.*FULL_TENSOR.*not be honored"):
      dumping_callback.enable_dump_debug_info(
          self.dump_root, tensor_debug_mode="FULL_TENSOR")

  @parameterized.named_parameters(
      ("NoTensor", "NO_TENSOR"),
      ("FullTensor", "FULL_TENSOR"),
  )
  def testDisableTracingWorks(self, tensor_debug_mode):
    x = constant_op.constant([10.0, 12.0, 10.0])
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode=tensor_debug_mode)
    dumping_callback.disable_dump_debug_info()

    for _ in range(2):
      array_ops.unique(x)

    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      self.assertEqual(reader.num_executions(), 0)
      self.assertEqual(reader.num_graph_execution_traces(), 0)
      self.assertFalse(reader.outermost_graphs())

  @parameterized.named_parameters(
      ("NoTensor", "NO_TENSOR"),
      ("CurtHealth", "CURT_HEALTH"),
      ("ConciseHealth", "CONCISE_HEALTH"),
      ("FullHealth", "FULL_HEALTH"),
      ("Shape", "SHAPE"),
      ("FullTensor", "FULL_TENSOR"),
  )
  def testMultiThreadedExecutionWithSameSetting(self, tensor_debug_mode):
    """Dumping from multiple threads using the same setting."""
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode=tensor_debug_mode)
    x = variables.Variable(10.0, dtype=dtypes.float32)
    y = variables.Variable(3.0, dtype=dtypes.float32)

    @def_function.function
    def increase_x():
      return x.assign_add(y * 2.0)

    increase_x()

    num_threads = 3
    threads = []
    for _ in range(num_threads):
      threads.append(threading.Thread(target=increase_x))
    for thread in threads:
      thread.start()
    for thread in threads:
      thread.join()
    # Each of the 4 calls adds y * 2.0 = 6.0: 10 --> 16 --> 22 --> 28 --> 34.
    self.assertAllClose(x.read_value(), 34.0)

    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      exec_digests = reader.executions(digest=True)
      # Wall times are seconds since the epoch and should be monotonically
      # non-decreasing across the execution digests.
      prev_wall_time = 1
      for exec_digest in exec_digests:
        self.assertGreaterEqual(exec_digest.wall_time, prev_wall_time)
        prev_wall_time = exec_digest.wall_time

      graph_exec_traces = reader.graph_execution_traces()
      executed_op_types = [trace.op_type for trace in graph_exec_traces]
      self.assertEqual(executed_op_types.count("Mul"), 1 + num_threads)
      self.assertEqual(
          executed_op_types.count("ReadVariableOp"), 2 * (1 + num_threads))
      for trace in graph_exec_traces:
        # These are all single-output tensors.
        self.assertEqual(trace.output_slot, 0)

    tensor_values = [reader.graph_execution_trace_to_tensor_value(trace)
                     for trace in graph_exec_traces]
    if tensor_debug_mode == "NO_TENSOR":
      for tensor_value in tensor_values:
        self.assertAllEqual(tensor_value, [])
    elif tensor_debug_mode == "CURT_HEALTH":
      for trace in graph_exec_traces:
        tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
        # 1st element: tensor ID; 2nd element: 0 indicating no inf or nan.
        self.assertAllClose(trace.debug_tensor_value, [tensor_id, 0])
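        # Worked example (hedged, illustration only): for a tensor that did
        # contain an inf or nan, e.g. [1.0, np.inf], CURT_HEALTH would yield
        # [tensor_id, 1.0], with the trailing 1.0 flagging the anomaly.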
    elif tensor_debug_mode == "CONCISE_HEALTH":
      for trace in graph_exec_traces:
        tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
        # 1st element: tensor ID. 2nd element: element count. The remaining
        # elements (counts of -inf, +inf and nan) are all zero because the
        # traced tensors contain no -inf, inf or nan.
        self.assertAllClose(trace.debug_tensor_value, [tensor_id, 1, 0, 0, 0])
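        # Worked example (hedged, illustration only): a tensor holding
        # [-np.inf, 7.0, np.nan] would be summarized as
        # [tensor_id, 3, 1, 0, 1]: 3 elements, one -inf, no +inf, one nan.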
    elif tensor_debug_mode == "FULL_HEALTH":
      for trace in graph_exec_traces:
        tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
        # Elements: [
        #   tensor_id, device ID (set to -1 for now),
        #   dtype, rank, element_count,
        #   neg_inf_count, pos_inf_count, nan_count,
        #   neg_finite_count, zero_count, pos_finite_count]
        self.assertAllClose(
            trace.debug_tensor_value,
            [tensor_id, -1, 1, 0, 1, 0, 0, 0, 0, 0, 1])
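        # Reading the asserted vector above (hedged interpretation): device ID
        # -1 (unavailable), dtype enum 1 (float32), rank 0, 1 element, zero
        # inf/nan counts, and a single positive finite value, consistent with
        # the positive scalar float32 tensors this test produces.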
    elif tensor_debug_mode == "SHAPE":
      for trace in graph_exec_traces:
        if trace.op_type == "Mul":
          tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
          mul_value = reader.graph_execution_trace_to_tensor_value(trace)
          # 1st element: tensor_id, should be >= 0.
          # 2nd element: dtype enum value (float32).
          # 3rd element: rank (0 for this scalar).
          # 4th element: element count (1).
          # Remaining elements: the shape, zero-padded to 6 entries (all
          # zeros here because the tensor is a scalar).
          self.assertAllClose(mul_value, [tensor_id, 1, 0, 1, 0, 0, 0, 0, 0, 0])
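          # Worked example (hedged, illustration only): a float32 tensor of
          # shape [3, 4] would be summarized under SHAPE mode as
          # [tensor_id, 1, 2, 12, 3, 4, 0, 0, 0, 0]: rank 2, 12 elements,
          # and the shape [3, 4] zero-padded to six entries.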
    elif tensor_debug_mode == "FULL_TENSOR":
      mul_values = [
          reader.graph_execution_trace_to_tensor_value(trace)
          for trace in graph_exec_traces if trace.op_type == "Mul"]
      # y * 2.0 == 6.0 in each of the four calls to increase_x().
      self.assertAllClose(mul_values, [6.0, 6.0, 6.0, 6.0])

  def testMultiThreadedDumpingWithDifferentSettings(self):
    gpu_name = test_util.gpu_device_name()
    if gpu_name:
      self.skipTest("b/153671240: test is flaky on GPUs")
    dump_root_1 = os.path.join(self.dump_root, "dump_root_1")
    dump_root_2 = os.path.join(self.dump_root, "dump_root_2")
    v1 = variables.Variable(10.0, dtype=dtypes.float32)
    v2 = variables.Variable(3.0, dtype=dtypes.float32)

    def add_negative_v1_squared_to_itself():
      writer = dumping_callback.enable_dump_debug_info(
          dump_root_1, tensor_debug_mode="FULL_TENSOR")
      # Run in a loop to facilitate interleaving between threads.
      for _ in range(3):
        v1.assign_add(-(v1 ** 2.0))
      writer.FlushNonExecutionFiles()
      writer.FlushExecutionFiles()

    def add_negative_v2_squared_to_itself():
      writer = dumping_callback.enable_dump_debug_info(
          dump_root_2, tensor_debug_mode="FULL_TENSOR")
      v2_squared = v2 ** 2.0
      # Since dumping is disabled before the Neg op is called, no tensor data
      # should be dumped from the op, but this shouldn't affect the dumping of
      # the tensor data from the Neg op in `add_negative_v1_squared_to_itself`.
      # Both behaviors are checked below.
      dumping_callback.disable_dump_debug_info()
      negative_v2_squared = -v2_squared
      v2.assign_add(negative_v2_squared)
      writer.FlushNonExecutionFiles()
      writer.FlushExecutionFiles()

    # v2 is mutated on a sub-thread.
    sub_thread = threading.Thread(target=add_negative_v2_squared_to_itself)
    sub_thread.start()
    add_negative_v1_squared_to_itself()  # v1 is mutated on the main thread.
    sub_thread.join()
    # 10 - 10 * 10 = -90.
    # -90 - (-90 * -90) = -8190.
    # -8190 - (-8190 * -8190) = -67084290.
    self.assertAllClose(v1.read_value(), -67084290.0)
    # 3 - 3 * 3 = -6.
    self.assertAllClose(v2.read_value(), -6.0)

    with debug_events_reader.DebugDataReader(dump_root_1) as reader:
      reader.update()
      exec_digests = reader.executions(digest=True)
      v1_squared_values = [
          reader.execution_to_tensor_values(digest)
          for digest in exec_digests if digest.op_type == "Pow"]
      negative_v1_squared_values = [
          reader.execution_to_tensor_values(digest)
          for digest in exec_digests if digest.op_type == "Neg"]
      self.assertAllClose(v1_squared_values, [[100.0], [8100.0], [67076100.0]])
      self.assertAllClose(
          negative_v1_squared_values, [[-100.0], [-8100.0], [-67076100.0]])

    with debug_events_reader.DebugDataReader(dump_root_2) as reader:
      reader.update()
      exec_digests = reader.executions(digest=True)
      executed_op_types = [digest.op_type for digest in exec_digests]
      self.assertNotIn("Neg", executed_op_types)
      v2_squared_values = [
          reader.execution_to_tensor_values(digest)
          for digest in exec_digests if digest.op_type == "Pow"]
      self.assertAllClose(v2_squared_values, [[9.0]])

  @test_util.run_in_graph_and_eager_modes
  def testNestedContextIsCapturedByGraphOpCreationHistory(self):
    x = constant_op.constant(2.0, dtype=dtypes.float32)
    times = constant_op.constant(4, dtype=dtypes.int32)
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode="NO_TENSOR")

    @def_function.function
    def iterative_doubling(x, times):
      i = constant_op.constant(0, dtype=dtypes.int32)
      while i < times:
        x = x * 2.0 - 1.0
        i += 1
      return x

    # 2 * 2 - 1 = 3; 3 * 2 - 1 = 5; 5 * 2 - 1 = 9; 9 * 2 - 1 = 17.
    self.assertAllClose(self.evaluate(iterative_doubling(x, times)), 17.0)

    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()
    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      less_op_digest = reader.graph_op_digests(op_type="Less")[-1]
      mul_op_digest = reader.graph_op_digests(op_type="Mul")[-1]
      sub_op_digest = reader.graph_op_digests(op_type="Sub")[-1]
      # The Less op is from the while-loop cond context and hence should have
      # a different innermost context ID from the Mul and Sub ops, which are
      # both from the while-loop body context.
      self.assertNotEqual(less_op_digest.graph_id, mul_op_digest.graph_id)
      self.assertNotEqual(less_op_digest.graph_id, sub_op_digest.graph_id)
      # The Mul and Sub ops are from the same innermost context.
      self.assertEqual(mul_op_digest.graph_id, sub_op_digest.graph_id)

  @parameterized.named_parameters(
      ("NoTensor", "NO_TENSOR"),
      ("Shape", "SHAPE"),
      ("FullTensor", "FULL_TENSOR"),
  )
  @test_util.run_in_graph_and_eager_modes
  def testGraphInputTracingWorksWithConstAndPlaceholderTensors(
      self, tensor_debug_mode):
    x = constant_op.constant(2.0)
    writer = dumping_callback.enable_dump_debug_info(
        self.dump_root, tensor_debug_mode=tensor_debug_mode)

    @def_function.function
    def func(x):
      return (x + constant_op.constant(4.0)) / x

    # (2 + 4) / 2 = 3.
    self.assertAllClose(self.evaluate(func(x)), 3.0)
    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      graph_op_digests = reader.graph_op_digests()
      placeholder_op_name = None
      const_op_name = None
      add_op_name = None
      div_op_name = None
      for op_digest in graph_op_digests:
        if op_digest.op_type == "Placeholder":
          placeholder_op_name = op_digest.op_name
        elif op_digest.op_type == "Const":
          const_op_name = op_digest.op_name
        elif op_digest.op_type == "AddV2":
          add_op_name = op_digest.op_name
          self.assertLen(op_digest.input_names, 2)
          self.assertEqual(op_digest.input_names[0], placeholder_op_name + ":0")
          self.assertEqual(op_digest.input_names[1], const_op_name + ":0")
        elif op_digest.op_type == "RealDiv":
          div_op_name = op_digest.op_name
          self.assertLen(op_digest.input_names, 2)
          self.assertEqual(op_digest.input_names[0], add_op_name + ":0")
          self.assertEqual(op_digest.input_names[1], placeholder_op_name + ":0")
      self.assertTrue(add_op_name)
      self.assertTrue(div_op_name)


if __name__ == "__main__":
  ops.enable_eager_execution()
  googletest.main()
