# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for source_utils."""

import ast
import os
import sys
import tempfile
import zipfile

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import source_utils
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.util import tf_inspect


def line_number_above():
  """Get lineno of the AST node immediately above this function's call site.

  It is assumed that there is no empty line(s) between the call site and the
  preceding AST node.

  Returns:
    The lineno of the preceding AST node, at the same level of the AST.
    If the preceding AST spans multiple lines:
      - In Python 3.8+, the lineno of the first line is returned.
      - In older Python versions, the lineno of the last line is returned.
  """
  # https://bugs.python.org/issue12458: In Python 3.8, traceback started
  # to return the lineno of the first line of a multi-line continuation block,
  # instead of that of the last line. Therefore, in Python 3.8+, we use `ast` to
  # get the lineno of the first line.
  call_site_lineno = tf_inspect.stack()[1][2]
  if sys.version_info < (3, 8):
    return call_site_lineno - 1
  else:
    with open(__file__, "rb") as f:
      source_text = f.read().decode("utf-8")
    source_tree = ast.parse(source_text)
    prev_node = _find_preceding_ast_node(source_tree, call_site_lineno)
    return prev_node.lineno


def _find_preceding_ast_node(node, lineno):
  """Find the ast node immediately before and not including lineno."""
  for i, child_node in enumerate(node.body):
    if child_node.lineno == lineno:
      return node.body[i - 1]
    if hasattr(child_node, "body"):
      found_node = _find_preceding_ast_node(child_node, lineno)
      if found_node:
        return found_node


class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase):

  def setUp(self):
    self.curr_file_path = os.path.normpath(os.path.abspath(__file__))

  def tearDown(self):
    ops.reset_default_graph()

  def testGuessedBaseDirIsProbablyCorrect(self):
    # In the non-pip world, code resides in "tensorflow/"
    # In the pip world, after virtual pip, code resides in "tensorflow_core/"
    # So, we have to check both of them
    self.assertIn(
        os.path.basename(source_utils._TENSORFLOW_BASEDIR),
        ["tensorflow", "tensorflow_core"])

  def testUnitTestFileReturnsFalse(self):
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(self.curr_file_path))

  def testSourceUtilModuleReturnsTrue(self):
    self.assertTrue(
        source_utils.guess_is_tensorflow_py_library(source_utils.__file__))

  @test_util.run_v1_only("Tensor.op is not available in TF 2.x")
  def testFileInPythonKernelsPathReturnsTrue(self):
    x = constant_op.constant(42.0, name="x")
    self.assertTrue(
        source_utils.guess_is_tensorflow_py_library(x.op.traceback[-1][0]))

  def testDebuggerExampleFilePathReturnsFalse(self):
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(os.path.normpath(
            "site-packages/tensorflow/python/debug/examples/debug_mnist.py")))
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(os.path.normpath(
            "site-packages/tensorflow/python/debug/examples/v1/example_v1.py")))
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(os.path.normpath(
            "site-packages/tensorflow/python/debug/examples/v2/example_v2.py")))
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(os.path.normpath(
            "site-packages/tensorflow/python/debug/examples/v3/example_v3.py")))

  def testReturnsFalseForNonPythonFile(self):
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(
            os.path.join(os.path.dirname(self.curr_file_path), "foo.cc")))

  def testReturnsFalseForStdin(self):
    self.assertFalse(source_utils.guess_is_tensorflow_py_library("<stdin>"))

  def testReturnsFalseForEmptyFileName(self):
    self.assertFalse(source_utils.guess_is_tensorflow_py_library(""))


class SourceHelperTest(test_util.TensorFlowTestCase):

  def createAndRunGraphHelper(self):
    """Create and run a TensorFlow Graph to generate debug dumps.

    This is intentionally done in separate method, to make it easier to test
    the stack-top mode of source annotation.
    """

    self.dump_root = self.get_temp_dir()
    self.curr_file_path = os.path.abspath(
        tf_inspect.getfile(tf_inspect.currentframe()))

    # Run a simple TF graph to generate some debug dumps that can be used in
    # source annotation.
    with session.Session() as sess:
      self.u_init = constant_op.constant(
          np.array([[5.0, 3.0], [-1.0, 0.0]]), shape=[2, 2], name="u_init")
      self.u_init_line_number = line_number_above()

      self.u = variables.Variable(self.u_init, name="u")
      self.u_line_number = line_number_above()

      self.v_init = constant_op.constant(
          np.array([[2.0], [-1.0]]), shape=[2, 1], name="v_init")
      self.v_init_line_number = line_number_above()

      self.v = variables.Variable(self.v_init, name="v")
      self.v_line_number = line_number_above()

      self.w = math_ops.matmul(self.u, self.v, name="w")
      self.w_line_number = line_number_above()

      self.evaluate(self.u.initializer)
      self.evaluate(self.v.initializer)

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
      run_metadata = config_pb2.RunMetadata()
      sess.run(self.w, options=run_options, run_metadata=run_metadata)

      self.dump = debug_data.DebugDumpDir(
          self.dump_root, partition_graphs=run_metadata.partition_graphs)
      self.dump.set_python_graph(sess.graph)

  def setUp(self):
    self.createAndRunGraphHelper()
    self.helper_line_number = line_number_above()

  def tearDown(self):
    if os.path.isdir(self.dump_root):
      file_io.delete_recursively(self.dump_root)
    ops.reset_default_graph()

  def testAnnotateWholeValidSourceFileGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(self.dump,
                                                     self.curr_file_path)

    self.assertIn(self.u_init.op.name,
                  source_annotation[self.u_init_line_number])
    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertIn(self.v_init.op.name,
                  source_annotation[self.v_init_line_number])
    self.assertIn(self.v.op.name, source_annotation[self.v_line_number])
    self.assertIn(self.w.op.name, source_annotation[self.w_line_number])

    # In the non-stack-top (default) mode, the helper line should be annotated
    # with all the ops as well.
    self.assertIn(self.u_init.op.name,
                  source_annotation[self.helper_line_number])
    self.assertIn(self.u.op.name, source_annotation[self.helper_line_number])
    self.assertIn(self.v_init.op.name,
                  source_annotation[self.helper_line_number])
    self.assertIn(self.v.op.name, source_annotation[self.helper_line_number])
    self.assertIn(self.w.op.name, source_annotation[self.helper_line_number])

  def testAnnotateWithStackTopGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(
        self.dump, self.curr_file_path, file_stack_top=True)

    self.assertIn(self.u_init.op.name,
                  source_annotation[self.u_init_line_number])
    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertIn(self.v_init.op.name,
                  source_annotation[self.v_init_line_number])
    self.assertIn(self.v.op.name, source_annotation[self.v_line_number])
    self.assertIn(self.w.op.name, source_annotation[self.w_line_number])

    # In the stack-top mode, the helper line should not have been annotated.
    self.assertNotIn(self.helper_line_number, source_annotation)

  def testAnnotateSubsetOfLinesGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(
        self.dump,
        self.curr_file_path,
        min_line=self.u_line_number,
        max_line=self.u_line_number + 1)

    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertNotIn(self.v_line_number, source_annotation)

  def testAnnotateDumpedTensorsGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(
        self.dump, self.curr_file_path, do_dumped_tensors=True)

    # Note: Constant Tensors u_init and v_init may not get dumped due to
    # constant-folding.
    self.assertIn(self.u.name, source_annotation[self.u_line_number])
    self.assertIn(self.v.name, source_annotation[self.v_line_number])
    self.assertIn(self.w.name, source_annotation[self.w_line_number])

    self.assertNotIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertNotIn(self.v.op.name, source_annotation[self.v_line_number])
    self.assertNotIn(self.w.op.name, source_annotation[self.w_line_number])

    self.assertIn(self.u.name, source_annotation[self.helper_line_number])
    self.assertIn(self.v.name, source_annotation[self.helper_line_number])
    self.assertIn(self.w.name, source_annotation[self.helper_line_number])

  def testCallingAnnotateSourceWithoutPythonGraphRaisesException(self):
    self.dump.set_python_graph(None)
    with self.assertRaises(ValueError):
      source_utils.annotate_source(self.dump, self.curr_file_path)

  def testCallingAnnotateSourceOnUnrelatedSourceFileDoesNotError(self):
    # Create an unrelated source file.
    fd, unrelated_source_path = tempfile.mkstemp()
    with open(fd, "wt") as source_file:
      source_file.write("print('hello, world')\n")

    self.assertEqual({},
                     source_utils.annotate_source(self.dump,
                                                  unrelated_source_path))

    # Clean up unrelated source file.
    os.remove(unrelated_source_path)

  def testLoadingPythonSourceFileWithNonAsciiChars(self):
    fd, source_path = tempfile.mkstemp()
    with open(fd, "wb") as source_file:
      source_file.write(u"print('\U0001f642')\n".encode("utf-8"))
    source_lines, _ = source_utils.load_source(source_path)
    self.assertEqual(source_lines, [u"print('\U0001f642')", u""])
    # Clean up unrelated source file.
    os.remove(source_path)

  def testLoadNonexistentNonParPathFailsWithIOError(self):
    bad_path = os.path.join(self.get_temp_dir(), "nonexistent.py")
    with self.assertRaisesRegex(IOError,
                                "neither exists nor can be loaded.*par.*"):
      source_utils.load_source(bad_path)

  def testLoadingPythonSourceFileInParFileSucceeds(self):
    # Create the .par file first.
    temp_file_path = os.path.join(self.get_temp_dir(), "model.py")
    with open(temp_file_path, "wb") as f:
      f.write(b"import tensorflow as tf\nx = tf.constant(42.0)\n")
    par_path = os.path.join(self.get_temp_dir(), "train_model.par")
    with zipfile.ZipFile(par_path, "w") as zf:
      zf.write(temp_file_path, os.path.join("tensorflow_models", "model.py"))

    source_path = os.path.join(par_path, "tensorflow_models", "model.py")
    source_lines, _ = source_utils.load_source(source_path)
    self.assertEqual(
        source_lines, ["import tensorflow as tf", "x = tf.constant(42.0)", ""])

  def testLoadingPythonSourceFileInParFileFailsRaisingIOError(self):
    # Create the .par file first.
    temp_file_path = os.path.join(self.get_temp_dir(), "model.py")
    with open(temp_file_path, "wb") as f:
      f.write(b"import tensorflow as tf\nx = tf.constant(42.0)\n")
    par_path = os.path.join(self.get_temp_dir(), "train_model.par")
    with zipfile.ZipFile(par_path, "w") as zf:
      zf.write(temp_file_path, os.path.join("tensorflow_models", "model.py"))

    source_path = os.path.join(par_path, "tensorflow_models", "nonexistent.py")
    with self.assertRaisesRegex(IOError,
                                "neither exists nor can be loaded.*par.*"):
      source_utils.load_source(source_path)


@test_util.run_v1_only("Sessions are not available in TF 2.x")
class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase):

  def createAndRunGraphWithWhileLoop(self):
    """Create and run a TensorFlow Graph with a while loop to generate dumps."""

    self.dump_root = self.get_temp_dir()
    self.curr_file_path = os.path.abspath(
        tf_inspect.getfile(tf_inspect.currentframe()))

    # Run a simple TF graph to generate some debug dumps that can be used in
    # source annotation.
    with session.Session() as sess:
      loop_body = lambda i: math_ops.add(i, 2)
      self.traceback_first_line = line_number_above()

      loop_cond = lambda i: math_ops.less(i, 16)

      i = constant_op.constant(10, name="i")
      loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
      run_metadata = config_pb2.RunMetadata()
      sess.run(loop, options=run_options, run_metadata=run_metadata)

      self.dump = debug_data.DebugDumpDir(
          self.dump_root, partition_graphs=run_metadata.partition_graphs)
      self.dump.set_python_graph(sess.graph)

  def setUp(self):
    self.createAndRunGraphWithWhileLoop()

  def tearDown(self):
    if os.path.isdir(self.dump_root):
      file_io.delete_recursively(self.dump_root)
    ops.reset_default_graph()

  def testGenerateSourceList(self):
    source_list = source_utils.list_source_files_against_dump(self.dump)

    # Assert that the file paths are sorted and unique.
    file_paths = [item[0] for item in source_list]
    self.assertEqual(sorted(file_paths), file_paths)
    self.assertEqual(len(set(file_paths)), len(file_paths))

    # Assert that each item of source_list has length 6.
    for item in source_list:
      self.assertTrue(isinstance(item, tuple))
      self.assertEqual(6, len(item))

    # The while loop body should have executed 3 times. The following table
    # lists the tensors and how many times each of them is dumped.
    #   Tensor name            # of times dumped:
    #   i:0                    1
    #   while/Enter:0          1
    #   while/Merge:0          4
    #   while/Merge:1          4
    #   while/Less/y:0         4
    #   while/Less:0           4
    #   while/LoopCond:0       4
    #   while/Switch:0         1
    #   while/Switch:1         3
    #   while/Identity:0       3
    #   while/Add/y:0          3
    #   while/Add:0            3
    #   while/NextIteration:0  3
    #   while/Exit:0           1
    # ----------------------------
    #   (Total)               39
    #
    # The total number of nodes is 12.
    # The total number of tensors is 14 (2 of the nodes have 2 outputs:
    #   while/Merge, while/Switch).

    _, is_tf_py_library, num_nodes, num_tensors, num_dumps, first_line = (
        source_list[file_paths.index(self.curr_file_path)])
    self.assertFalse(is_tf_py_library)
    self.assertEqual(12, num_nodes)
    self.assertEqual(14, num_tensors)
    self.assertEqual(39, num_dumps)
    self.assertEqual(self.traceback_first_line, first_line)

  def testGenerateSourceListWithNodeNameFilter(self):
    source_list = source_utils.list_source_files_against_dump(
        self.dump, node_name_regex_allowlist=r"while/Add.*")

    # Assert that the file paths are sorted.
    file_paths = [item[0] for item in source_list]
    self.assertEqual(sorted(file_paths), file_paths)
    self.assertEqual(len(set(file_paths)), len(file_paths))

    # Assert that each item of source_list has length 6.
    # (NOTE: the comment previously said "length 4", contradicting the
    # assertion below and the identical check in testGenerateSourceList.)
    for item in source_list:
      self.assertTrue(isinstance(item, tuple))
      self.assertEqual(6, len(item))

    # Due to the node-name filtering the result should only contain 2 nodes
    # and 2 tensors. The total number of dumped tensors should be 6:
    #   while/Add/y:0          3
    #   while/Add:0            3
    _, is_tf_py_library, num_nodes, num_tensors, num_dumps, _ = (
        source_list[file_paths.index(self.curr_file_path)])
    self.assertFalse(is_tf_py_library)
    self.assertEqual(2, num_nodes)
    self.assertEqual(2, num_tensors)
    self.assertEqual(6, num_dumps)

  def testGenerateSourceListWithPathRegexFilter(self):
    curr_file_basename = os.path.basename(self.curr_file_path)
    source_list = source_utils.list_source_files_against_dump(
        self.dump,
        path_regex_allowlist=(".*" + curr_file_basename.replace(".", "\\.") +
                              "$"))

    self.assertEqual(1, len(source_list))
    (file_path, is_tf_py_library, num_nodes, num_tensors, num_dumps,
     first_line) = source_list[0]
    self.assertEqual(self.curr_file_path, file_path)
    self.assertFalse(is_tf_py_library)
    self.assertEqual(12, num_nodes)
    self.assertEqual(14, num_tensors)
    self.assertEqual(39, num_dumps)
    self.assertEqual(self.traceback_first_line, first_line)


if __name__ == "__main__":
  googletest.main()