# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.error_interpolation."""

import collections
import os
import re

from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import traceable_stack
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test

# A mock for ``tf_stack.FrameSummary`` exposing the filename, lineno, name and
# line fields. The typename matches the module-level name so that reprs read
# ``FrameSummary(...)`` and instances can be pickled.
FrameSummary = collections.namedtuple(
    "FrameSummary", ["filename", "lineno", "name", "line"])

# TODO(feyu): convert tests to tf function from graph when appropriate.
def _make_frame_with_filename(tb, idx, filename):
  """Return a copy of an existing stack frame with a new filename.

  Args:
    tb: Indexable sequence of stack frames.
    idx: Index of the frame in `tb` to copy.
    filename: Filename to give the copy.

  Returns:
    A `FrameSummary` carrying `filename` plus the lineno, name and line of
    `tb[idx]`.
  """
  frame = tb[idx]
  return FrameSummary(filename, frame.lineno, frame.name, frame.line)


def _modify_op_stack_with_filenames(tb, num_user_frames, user_filename,
                                    num_inner_tf_frames):
  """Replace traceback with a new traceback using special filenames.

  The op's traceback has its outermost frame at index 0. The outermost frames
  are kept unchanged. The next `num_user_frames` frames are relabelled as user
  code, with filenames of the form `<idx>/<user_filename>`. The innermost
  `num_inner_tf_frames` frames are relabelled as framework code, with
  filenames of the form `<first framework path prefix><idx>.py`.

  Args:
    tb: Indexable sequence of stack frames, outermost first.
    num_user_frames: Number of frames to relabel as user frames.
    user_filename: Base filename used for the relabelled user frames.
    num_inner_tf_frames: Number of innermost frames to relabel as framework
      frames.

  Returns:
    A list of frames with the same length as `tb`.

  Raises:
    AssertionError: If `tb` has fewer frames than were requested.
  """
  framework_prefix = error_interpolation._FRAMEWORK_PATH_PREFIXES[0]

  num_actual_frames = len(tb)
  num_requested_frames = num_user_frames + num_inner_tf_frames
  assert num_requested_frames <= num_actual_frames, "Too few real frames."
  num_outer_frames = num_actual_frames - num_requested_frames
  first_tf_frame = num_outer_frames + num_user_frames

  # Outermost frames are copied verbatim.
  stack = [tb[idx] for idx in range(num_outer_frames)]
  # Filenames are built directly rather than via %-formatting so that a '%'
  # in `user_filename` or in the framework prefix cannot break them.
  stack.extend(
      _make_frame_with_filename(tb, idx, os.path.join(str(idx), user_filename))
      for idx in range(num_outer_frames, first_tf_frame))
  stack.extend(
      _make_frame_with_filename(tb, idx, f"{framework_prefix}{idx}.py")
      for idx in range(first_tf_frame, num_actual_frames))
  return stack


class ComputeDeviceSummaryFromOpTest(test.TestCase):

  def testCorrectFormatWithActiveDeviceAssignments(self):
    assignments = [
        traceable_stack.TraceableObject(
            "/cpu:0", filename="hope.py", lineno=24),
        traceable_stack.TraceableObject(
            "/gpu:2", filename="please.py", lineno=42),
    ]

    summary = error_interpolation._compute_device_summary_from_list(
        "nodename", assignments, prefix=" ")

    self.assertIn("nodename", summary)
    self.assertIn("tf.device(/cpu:0)", summary)
    self.assertIn("<hope.py:24>", summary)
    self.assertIn("tf.device(/gpu:2)", summary)
    self.assertIn("<please.py:42>", summary)

  def testCorrectFormatWhenNoDeviceAssignmentsWereActive(self):
    # An empty assignment list should still name the node and say that no
    # device assignments were active.
    device_assignment_list = []
    summary = error_interpolation._compute_device_summary_from_list(
        "nodename", device_assignment_list, prefix=" ")
    self.assertIn("nodename", summary)
    self.assertIn("No device assignments", summary)


class ComputeColocationSummaryFromOpTest(test.TestCase):

  def testCorrectFormatWithActiveColocations(self):
    t_obj_1 = traceable_stack.TraceableObject(
        None, filename="test_1.py", lineno=27)
    t_obj_2 = traceable_stack.TraceableObject(
        None, filename="test_2.py", lineno=38)
    colocation_dict = {
        "test_node_1": t_obj_1,
        "test_node_2": t_obj_2,
    }
    summary = error_interpolation._compute_colocation_summary_from_dict(
        "node_name", colocation_dict, prefix=" ")
    self.assertIn("node_name", summary)
    self.assertIn("colocate_with(test_node_1)", summary)
    self.assertIn("<test_1.py:27>", summary)
    self.assertIn("colocate_with(test_node_2)", summary)
    self.assertIn("<test_2.py:38>", summary)

  def testCorrectFormatWhenNoColocationsWereActive(self):
    colocation_dict = {}
    summary = error_interpolation._compute_colocation_summary_from_dict(
        "node_name", colocation_dict, prefix=" ")
    self.assertIn("node_name", summary)
    self.assertIn("No node-device colocations", summary)


# Note that create_graph_debug_info_def() needs to run on graph-mode ops,
# so it is excluded from eager tests. Even in eager mode it is only reached
# through FunctionGraphs, so verifying it directly in graph mode is the
# narrowest way to unit test the functionality.
class CreateGraphDebugInfoDefTest(test.TestCase):

  def _getFirstStackTraceForFile(self, graph_debug_info, key, file_index):
    """Returns the first trace entry of `key` that points into `file_index`.

    Fails the test if `graph_debug_info.traces` has no entry for `key`, or if
    none of that trace's `file_line_cols` refer to `file_index`.
    """
    self.assertIn(key, graph_debug_info.traces)
    stack_trace = graph_debug_info.traces[key]
    found_flc = next((flc for flc in stack_trace.file_line_cols
                      if flc.file_index == file_index), None)
    self.assertIsNotNone(found_flc,
                         "Could not find a stack trace entry for file")
    return found_flc

  def testStackTraceExtraction(self):
    # This test is verifying stack trace information added in graph mode, so
    # only makes sense in graph mode.
    with ops.Graph().as_default():
      # Since the create_graph_debug_info_def() function does not actually
      # do anything special with functions except name mangling, just verify
      # it with a loose op and manually provided function name.
      # The following ops *must* be on consecutive lines (it will be verified
      # in the resulting trace).
      # pyformat: disable
      global_op = constant_op.constant(0, name="Global").op
      op1 = constant_op.constant(1, name="One").op
      op2 = constant_op.constant(2, name="Two").op
      # pyformat: enable

      # Ensure op without traceback does not fail
      node_def_copy = type(op1.node_def)()
      node_def_copy.CopyFrom(op1.node_def)
      node_def_copy.name = "NonTraceback"
      c_op = ops._create_c_op(
          ops.get_default_graph(),
          node_def=node_def_copy,
          inputs=[],
          control_inputs=[],
          extract_traceback=False)

      non_traceback_op = ops.Operation._from_c_op(c_op, ops.get_default_graph())
      self.assertIsNone(non_traceback_op.traceback)

      export_ops = [("", global_op), ("func1", op1), ("func2", op2),
                    ("func2", non_traceback_op)]
      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          export_ops)
      this_file_index = -1
      for file_index, file_name in enumerate(graph_debug_info.files):
        # No early exit: if several entries match, the last one wins.
        if f"{os.sep}error_interpolation_test.py" in file_name:
          this_file_index = file_index
      self.assertGreaterEqual(
          this_file_index, 0,
          "Could not find this file in trace:" + repr(graph_debug_info))

      # Verify the traces exist for each op.
      global_flc = self._getFirstStackTraceForFile(graph_debug_info, "Global@",
                                                   this_file_index)
      op1_flc = self._getFirstStackTraceForFile(graph_debug_info, "One@func1",
                                                this_file_index)
      op2_flc = self._getFirstStackTraceForFile(graph_debug_info, "Two@func2",
                                                this_file_index)

      # The op created with extract_traceback=False must not get a trace.
      self.assertNotIn("NonTraceback@func2", graph_debug_info.traces)

      global_line = global_flc.line
      self.assertEqual(op1_flc.line, global_line + 1, "op1 not on next line")
      self.assertEqual(op2_flc.line, global_line + 2, "op2 not on next line")


class InterpolateFilenamesAndLineNumbersTest(test.TestCase):

  # Matches the "defined at" trace that interpolation emits for a node, as long
  # as that trace points into this test file.
  _DEFINED_AT = r"defined at.*error_interpolation_test\.py"

  def testFindIndexOfDefiningFrameForOp(self):
    with ops.Graph().as_default():
      local_op = constant_op.constant(42).op
      user_filename = "hope.py"
      modified_tb = _modify_op_stack_with_filenames(
          local_op.traceback,
          num_user_frames=3,
          user_filename=user_filename,
          num_inner_tf_frames=5)
      idx = error_interpolation._find_index_of_defining_frame(modified_tb)
      # Expected frame is 6th from the end because there are 5 inner frames with
      # TF filenames.
      expected_frame = len(modified_tb) - 6
      self.assertEqual(expected_frame, idx)

  def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self):
    with ops.Graph().as_default():
      local_op = constant_op.constant(43).op
      # Ensure all frames look like TF frames.
      modified_tb = _modify_op_stack_with_filenames(
          local_op.traceback[:7],  # Truncate stack to known length.
          num_user_frames=0,
          user_filename="user_file.py",
          num_inner_tf_frames=7)
      idx = error_interpolation._find_index_of_defining_frame(modified_tb)
      self.assertEqual(0, idx)

  def testNothingToDo(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      normal_string = "This is just a normal string"
      interpolated_string = error_interpolation.interpolate(
          normal_string, ops.get_default_graph())
      self.assertIn(normal_string, interpolated_string)

  def testOneTagWithAFakeNameResultsInPlaceholders(self):
    with ops.Graph().as_default():
      one_tag_string = "{{node MinusOne}}"
      interpolated_string = error_interpolation.interpolate(
          one_tag_string, ops.get_default_graph())
      self.assertIn(one_tag_string, interpolated_string)

  def testOneTagWithAFakeFunctionTag(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      constant_op.constant(2, name="Two")
      one_tag_with_a_fake_function_tag = "{{function_node fake}}{{node One}}"
      interpolated_string = error_interpolation.interpolate(
          one_tag_with_a_fake_function_tag, ops.get_default_graph())
      # The node tag expands to the node name plus its defining location; the
      # function tag is dropped and the unreferenced node is not mentioned.
      expected_regex = re.compile(
          rf"node 'One'.*{self._DEFINED_AT}", re.DOTALL)
      self.assertRegex(interpolated_string, expected_regex)
      self.assertNotIn("function_node", interpolated_string)
      self.assertNotIn("node 'Two'", interpolated_string)

  def testTwoTagsNoSeps(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      constant_op.constant(2, name="Two")
      constant_op.constant(3, name="Three")
      two_tags_no_seps = "{{node One}}{{node Three}}"
      interpolated_string = error_interpolation.interpolate(
          two_tags_no_seps, ops.get_default_graph())
      # Both tagged nodes appear, in order, each followed by its location.
      expected_regex = re.compile(
          rf"node 'One'.*{self._DEFINED_AT}.*node 'Three'.*{self._DEFINED_AT}",
          re.DOTALL)
      self.assertRegex(interpolated_string, expected_regex)

  def testTwoTagsWithSeps(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      constant_op.constant(2, name="Two")
      constant_op.constant(3, name="Three")
      two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;"
      interpolated_string = error_interpolation.interpolate(
          two_tags_with_seps, ops.get_default_graph())
      # Both tagged nodes appear, in order, each followed by its location.
      expected_regex = re.compile(
          rf"node 'Two'.*{self._DEFINED_AT}.*node 'Three'.*{self._DEFINED_AT}",
          re.DOTALL)
      self.assertRegex(interpolated_string, expected_regex)

  def testNewLine(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      constant_op.constant(2, name="Two")
      newline = "\n\n;;;{{node One}};;;"
      interpolated_string = error_interpolation.interpolate(
          newline, ops.get_default_graph())
      expected_regex = re.compile(
          rf"node 'One'.*{self._DEFINED_AT}", re.DOTALL)
      self.assertRegex(interpolated_string, expected_regex)


class OperationDefinedAtTraceTest(test.TestCase):
  # The regexes below match the names of the enclosing test method and the
  # nested Python functions, so those names must not be changed.

  @test_util.run_v2_only
  def testSimpleCall(self):

    @def_function.function
    def func():
      x = constant_op.constant([[1, 2, 3]])
      y = script_ops.eager_py_func(lambda: [[1, 2, 3]], (), dtypes.int32)
      return math_ops.matmul(x, y)

    with self.assertRaisesRegex(
        errors_impl.InvalidArgumentError,
        re.compile(r"defined at.*"
                   r"in testSimpleCall.*"
                   r"in func", re.DOTALL)):
      func()

  @test_util.run_v2_only
  def testNestedCall(self):

    def inner():
      x = constant_op.constant([[1, 2, 3]])
      y = script_ops.eager_py_func(lambda: [[1, 2, 3]], (), dtypes.int32)
      return math_ops.matmul(x, y)

    @def_function.function
    def func():
      return inner()

    with self.assertRaisesRegex(
        errors_impl.InvalidArgumentError,
        re.compile(r"defined at.*"
                   r"in testNestedCall.*"
                   r"in func.*"
                   r"in inner", re.DOTALL)):
      func()

  @test_util.run_v2_only
  def testAssert(self):

    @def_function.function
    def func():
      control_flow_ops.Assert(False, [False])

    with self.assertRaisesRegex(
        errors_impl.InvalidArgumentError,
        re.compile(r"defined at.*"
                   r"in testAssert.*"
                   r"in func", re.DOTALL)):
      func()

  @test_util.run_v2_only
  def testControlFlow(self):

    @def_function.function
    def func():
      if constant_op.constant(False):
        return constant_op.constant(1)
      else:
        x = constant_op.constant([[1, 2, 3]])
        y = script_ops.eager_py_func(lambda: [[1, 2, 3]], (), dtypes.int32)
        return math_ops.matmul(x, y)

    with self.assertRaisesRegex(
        errors_impl.InvalidArgumentError,
        re.compile(r"defined at.*"
                   r"in testControlFlow.*"
                   r"in func", re.DOTALL)):
      func()


class IsFrameworkFilenameTest(test.TestCase):

  def testAllowsUnitTests(self):
    self.assertFalse(
        error_interpolation._is_framework_filename(
            error_interpolation._FRAMEWORK_PATH_PREFIXES[0] + "foobar_test.py"))

  def testFrameworkPythonFile(self):
    self.assertTrue(
        error_interpolation._is_framework_filename(
            error_interpolation.__file__))

  def testEmbedded(self):
    self.assertTrue(
        error_interpolation._is_framework_filename(
            "<embedded stdlib>/context_lib.py"))


if __name__ == "__main__":
  test.main()