xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/error_interpolation_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.python.framework.errors."""
16
17import collections
18import os
19import re
20
21from tensorflow.python.eager import def_function
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import error_interpolation
25from tensorflow.python.framework import errors_impl
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import test_util
28from tensorflow.python.framework import traceable_stack
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import script_ops
32from tensorflow.python.platform import test
33
# Stand-in for ``tf_stack.FrameSummary``: a plain record exposing the same four
# attributes, so tests can fabricate stack frames with arbitrary filenames.
FrameSummary = collections.namedtuple(
    typename="StackFrame", field_names="filename lineno name line")
37
38# TODO(feyu): convert tests to tf function from graph when appropriate.
39
40
def _make_frame_with_filename(tb, idx, filename):
  """Copy frame `idx` of traceback `tb`, swapping in `filename`.

  Args:
    tb: An indexable sequence of stack frames whose elements expose
      `lineno`, `name` and `line` attributes (e.g. an op's `traceback`).
    idx: Index of the frame to copy.
    filename: Filename stored in the copy.

  Returns:
    A `FrameSummary` equal to `tb[idx]` except for its filename.
  """
  source_frame = tb[idx]
  return FrameSummary(
      filename=filename,
      lineno=source_frame.lineno,
      name=source_frame.name,
      line=source_frame.line)
49
50
def _modify_op_stack_with_filenames(tb, num_user_frames, user_filename,
                                    num_inner_tf_frames):
  """Return a copy of `tb` whose innermost frames carry synthetic filenames.

  The outermost frames of `tb` are kept untouched. The next `num_user_frames`
  frames are relabelled as user code, and the final `num_inner_tf_frames`
  frames are relabelled with paths under a TensorFlow framework prefix.

  Args:
    tb: An indexable stack of frames, outermost frame at index 0 (e.g. an op's
      `traceback`).
    num_user_frames: How many frames to relabel as user code.
    user_filename: Base filename for the relabelled user frames. Each such
      frame is placed in a directory named after its index, giving
      ``os.path.join(str(idx), user_filename)``, so every path is distinct.
    num_inner_tf_frames: How many innermost frames to relabel as framework
      frames.

  Returns:
    A list of frames with the same length as `tb`.

  Raises:
    AssertionError: If `tb` has fewer than
      `num_user_frames + num_inner_tf_frames` frames.
  """
  num_requested_frames = num_user_frames + num_inner_tf_frames
  num_actual_frames = len(tb)
  assert num_requested_frames <= num_actual_frames, "Too few real frames."

  # Layout (outermost first): [outer frames | user frames | inner TF frames].
  first_user_frame = num_actual_frames - num_requested_frames
  first_inner_tf_frame = first_user_frame + num_user_frames
  # Concatenate rather than %-format so a '%' in the install path is harmless.
  tf_path_prefix = error_interpolation._FRAMEWORK_PATH_PREFIXES[0]

  # Outer frames are copied as-is.
  stack = [tb[idx] for idx in range(first_user_frame)]
  # User frames honor the caller's `user_filename` (previously it was
  # overwritten with a hard-coded name and the argument was ignored).
  stack.extend(
      _make_frame_with_filename(tb, idx, os.path.join(str(idx), user_filename))
      for idx in range(first_user_frame, first_inner_tf_frame))
  # Inner frames get a path under a framework prefix so they look like TF code.
  stack.extend(
      _make_frame_with_filename(tb, idx, f"{tf_path_prefix}{idx}.py")
      for idx in range(first_inner_tf_frame, num_actual_frames))
  return stack
71
72
class ComputeDeviceSummaryFromOpTest(test.TestCase):
  """Tests for `error_interpolation._compute_device_summary_from_list`."""

  def testCorrectFormatWithActiveDeviceAssignments(self):
    assignments = [
        traceable_stack.TraceableObject(
            "/cpu:0", filename="hope.py", lineno=24),
        traceable_stack.TraceableObject(
            "/gpu:2", filename="please.py", lineno=42),
    ]

    summary = error_interpolation._compute_device_summary_from_list(
        "nodename", assignments, prefix="  ")

    # Each assignment is reported with its device and the <file:line> where
    # it was recorded.
    self.assertIn("nodename", summary)
    self.assertIn("tf.device(/cpu:0)", summary)
    self.assertIn("<hope.py:24>", summary)
    self.assertIn("tf.device(/gpu:2)", summary)
    self.assertIn("<please.py:42>", summary)

  def testCorrectFormatWhenNoDeviceAssignmentsWereActive(self):
    summary = error_interpolation._compute_device_summary_from_list(
        "nodename", [], prefix="  ")
    self.assertIn("nodename", summary)
    self.assertIn("No device assignments", summary)
99
100
class ComputeColocationSummaryFromOpTest(test.TestCase):
  """Tests for `error_interpolation._compute_colocation_summary_from_dict`."""

  def testCorrectFormatWithActiveColocations(self):
    colocation_dict = {
        "test_node_1": traceable_stack.TraceableObject(
            None, filename="test_1.py", lineno=27),
        "test_node_2": traceable_stack.TraceableObject(
            None, filename="test_2.py", lineno=38),
    }
    summary = error_interpolation._compute_colocation_summary_from_dict(
        "node_name", colocation_dict, prefix="  ")

    expected_fragments = (
        "node_name",
        "colocate_with(test_node_1)",
        "<test_1.py:27>",
        "colocate_with(test_node_2)",
        "<test_2.py:38>",
    )
    for fragment in expected_fragments:
      self.assertIn(fragment, summary)

  def testCorrectFormatWhenNoColocationsWereActive(self):
    summary = error_interpolation._compute_colocation_summary_from_dict(
        "node_name", {}, prefix="  ")
    for fragment in ("node_name", "No node-device colocations"):
      self.assertIn(fragment, summary)
126
127
# `create_graph_debug_info_def` operates on graph-mode ops, so this test is
# excluded from eager runs. Even eager code only reaches it through
# FunctionGraphs, so verifying it directly in graph mode is the narrowest way
# to unit test it.
class CreateGraphDebugInfoDefTest(test.TestCase):
  """Tests for `error_interpolation.create_graph_debug_info_def`."""

  def _getFirstStackTraceForFile(self, graph_debug_info, key, file_index):
    """Returns the first frame of trace `key` that lies in file `file_index`.

    Fails the test if `graph_debug_info` has no trace for `key`, or if that
    trace has no frame in the requested file.
    """
    self.assertIn(key, graph_debug_info.traces)
    stack_trace = graph_debug_info.traces[key]
    found_flc = next((flc for flc in stack_trace.file_line_cols
                      if flc.file_index == file_index), None)
    self.assertIsNotNone(
        found_flc, "Could not find a stack trace entry for file index "
        "{} in trace {!r}".format(file_index, key))
    return found_flc

  def testStackTraceExtraction(self):
    # Stack traces are only attached to ops built in graph mode.
    with ops.Graph().as_default():
      # create_graph_debug_info_def() does nothing function-specific beyond
      # name mangling, so loose ops with hand-picked function names suffice.
      # The next three ops *must* stay on consecutive lines: their recorded
      # line numbers are compared against each other below.
      # pyformat: disable
      global_op = constant_op.constant(0, name="Global").op
      op1 = constant_op.constant(1, name="One").op
      op2 = constant_op.constant(2, name="Two").op
      # pyformat: enable

      # An op created without a traceback must not break the export.
      node_def_copy = type(op1.node_def)()
      node_def_copy.CopyFrom(op1.node_def)
      node_def_copy.name = "NonTraceback"
      c_op = ops._create_c_op(
          ops.get_default_graph(),
          node_def=node_def_copy,
          inputs=[],
          control_inputs=[],
          extract_traceback=False)

      non_traceback_op = ops.Operation._from_c_op(c_op, ops.get_default_graph())
      self.assertIsNone(non_traceback_op.traceback)

      export_ops = [("", global_op), ("func1", op1), ("func2", op2),
                    ("func2", non_traceback_op)]
      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          export_ops)
      this_file_index = -1
      for file_index, file_name in enumerate(graph_debug_info.files):
        if "{}error_interpolation_test.py".format(os.sep) in file_name:
          this_file_index = file_index
      self.assertGreaterEqual(
          this_file_index, 0,
          "Could not find this file in trace: " + repr(graph_debug_info))

      # Each op with a traceback gets a trace keyed "<op name>@<function>".
      global_flc = self._getFirstStackTraceForFile(graph_debug_info, "Global@",
                                                   this_file_index)
      op1_flc = self._getFirstStackTraceForFile(graph_debug_info, "One@func1",
                                                this_file_index)
      op2_flc = self._getFirstStackTraceForFile(graph_debug_info, "Two@func2",
                                                this_file_index)

      # The op created without a traceback is left out of the traces.
      self.assertNotIn("NonTraceback@func2", graph_debug_info.traces)

      global_line = global_flc.line
      self.assertEqual(op1_flc.line, global_line + 1,
                       "op1 not on the line after Global")
      self.assertEqual(op2_flc.line, global_line + 2,
                       "op2 not two lines after Global")
200
201
class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
  """Tests for defining-frame lookup and `{{node ...}}` tag interpolation."""

  # Every op in these tests is created in this file, so a resolved node tag
  # must point back here. Expected patterns are assembled from this fragment
  # plus the node name, so no single source line spells out a whole match.
  _DEFINED_AT = r"defined at.*error_interpolation_test\.py"

  def _assertNodesDescribedInOrder(self, interpolated_string, *node_names):
    """Asserts each named node is reported, in order, with its origin."""
    pattern = ".*".join(
        rf"node '{name}'.*{self._DEFINED_AT}" for name in node_names)
    self.assertRegex(interpolated_string, re.compile(pattern, re.DOTALL))

  def testFindIndexOfDefiningFrameForOp(self):
    with ops.Graph().as_default():
      op_traceback = constant_op.constant(42).op.traceback
      modified_tb = _modify_op_stack_with_filenames(
          op_traceback,
          num_user_frames=3,
          user_filename="hope.py",
          num_inner_tf_frames=5)
      idx = error_interpolation._find_index_of_defining_frame(modified_tb)
      # The 5 innermost frames carry TF filenames, so the defining (user)
      # frame is the 6th one from the end.
      self.assertEqual(len(modified_tb) - 6, idx)

  def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self):
    with ops.Graph().as_default():
      op_traceback = constant_op.constant(43).op.traceback
      # Truncate to a known length and relabel every frame as a TF frame, so
      # there is no user frame to find.
      modified_tb = _modify_op_stack_with_filenames(
          op_traceback[:7],
          num_user_frames=0,
          user_filename="user_file.py",
          num_inner_tf_frames=7)
      self.assertEqual(
          0, error_interpolation._find_index_of_defining_frame(modified_tb))

  def testNothingToDo(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      normal_string = "This is just a normal string"
      self.assertIn(
          normal_string,
          error_interpolation.interpolate(normal_string,
                                          ops.get_default_graph()))

  def testOneTagWithAFakeNameResultsInPlaceholders(self):
    with ops.Graph().as_default():
      # No op named "MinusOne" exists; the tag must survive in the output.
      one_tag_string = "{{node MinusOne}}"
      self.assertIn(
          one_tag_string,
          error_interpolation.interpolate(one_tag_string,
                                          ops.get_default_graph()))

  def testOneTagWithAFakeFunctionTag(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      constant_op.constant(2, name="Two")
      interpolated_string = error_interpolation.interpolate(
          "{{function_node fake}}{{node One}}", ops.get_default_graph())
      self._assertNodesDescribedInOrder(interpolated_string, "One")
      # The unknown function tag is dropped and only the tagged node appears.
      self.assertNotIn("function_node", interpolated_string)
      self.assertNotIn("node 'Two'", interpolated_string)

  def testTwoTagsNoSeps(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      constant_op.constant(2, name="Two")
      constant_op.constant(3, name="Three")
      interpolated_string = error_interpolation.interpolate(
          "{{node One}}{{node Three}}", ops.get_default_graph())
      self._assertNodesDescribedInOrder(interpolated_string, "One", "Three")

  def testTwoTagsWithSeps(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      constant_op.constant(2, name="Two")
      constant_op.constant(3, name="Three")
      interpolated_string = error_interpolation.interpolate(
          ";;;{{node Two}},,,{{node Three}};;;", ops.get_default_graph())
      self._assertNodesDescribedInOrder(interpolated_string, "Two", "Three")

  def testNewLine(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      constant_op.constant(2, name="Two")
      interpolated_string = error_interpolation.interpolate(
          "\n\n;;;{{node One}};;;", ops.get_default_graph())
      self._assertNodesDescribedInOrder(interpolated_string, "One")
298
299
class OperationDefinedAtTraceTest(test.TestCase):
  """Checks that runtime errors report where the failing op was defined."""

  @staticmethod
  def _defined_at_regex(*function_names):
    """Returns a regex requiring `function_names`, in order, after "defined at".

    For example, ("testFoo", "func") yields the pattern
    "defined at.*in testFoo.*in func", matched across newlines.
    """
    frames = ".*".join("in " + name for name in function_names)
    return re.compile("defined at.*" + frames, re.DOTALL)

  @test_util.run_v2_only
  def testSimpleCall(self):

    @def_function.function
    def func():
      x = constant_op.constant([[1, 2, 3]])
      y = script_ops.eager_py_func(lambda: [[1, 2, 3]], (), dtypes.int32)
      # Inner dimensions do not match, so executing this matmul fails.
      return math_ops.matmul(x, y)

    with self.assertRaisesRegex(
        errors_impl.InvalidArgumentError,
        self._defined_at_regex("testSimpleCall", "func")):
      func()

  @test_util.run_v2_only
  def testNestedCall(self):

    def inner():
      x = constant_op.constant([[1, 2, 3]])
      y = script_ops.eager_py_func(lambda: [[1, 2, 3]], (), dtypes.int32)
      return math_ops.matmul(x, y)

    @def_function.function
    def func():
      return inner()

    # The trace must name every user frame down to the failing op.
    with self.assertRaisesRegex(
        errors_impl.InvalidArgumentError,
        self._defined_at_regex("testNestedCall", "func", "inner")):
      func()

  @test_util.run_v2_only
  def testAssert(self):
    @def_function.function
    def func():
      control_flow_ops.Assert(False, [False])
      return

    with self.assertRaisesRegex(
        errors_impl.InvalidArgumentError,
        self._defined_at_regex("testAssert", "func")):
      func()

  @test_util.run_v2_only
  def testControlFlow(self):
    @def_function.function
    def func():
      if constant_op.constant(False):
        return constant_op.constant(1)
      else:
        x = constant_op.constant([[1, 2, 3]])
        y = script_ops.eager_py_func(lambda: [[1, 2, 3]], (), dtypes.int32)
        return math_ops.matmul(x, y)

    with self.assertRaisesRegex(
        errors_impl.InvalidArgumentError,
        self._defined_at_regex("testControlFlow", "func")):
      func()
370
371
class IsFrameworkFilenameTest(test.TestCase):
  """Tests for `error_interpolation._is_framework_filename`."""

  def testAllowsUnitTests(self):
    # A *_test.py file counts as user code even under a framework prefix.
    test_file = (
        error_interpolation._FRAMEWORK_PATH_PREFIXES[0] + "foobar_test.py")
    self.assertFalse(error_interpolation._is_framework_filename(test_file))

  def testFrameworkPythonFile(self):
    framework_file = error_interpolation.__file__
    self.assertTrue(error_interpolation._is_framework_filename(framework_file))

  def testEmbedded(self):
    embedded_file = "<embedded stdlib>/context_lib.py"
    self.assertTrue(error_interpolation._is_framework_filename(embedded_file))
388
389
if __name__ == "__main__":
  # Hand control to TensorFlow's test entry point to run this module's tests.
  test.main()
392