# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for source_utils."""

import ast
import os
import sys
import tempfile
import zipfile

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import source_utils
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.util import tf_inspect


def line_number_above():
  """Get lineno of the AST node immediately above this function's call site.

  It is assumed that there are no empty lines between the call site and the
  preceding AST node.

  Returns:
    The lineno of the preceding AST node, at the same level of the AST.
    If the preceding AST node spans multiple lines:
      - In Python 3.8+, the lineno of the first line is returned.
      - In older Python versions, the lineno of the last line is returned.
  """
  # https://bugs.python.org/issue12458: In Python 3.8, traceback started
  # to return the lineno of the first line of a multi-line continuation block,
  # instead of that of the last line. Therefore, in Python 3.8+, we use `ast` to
  # get the lineno of the first line.
  call_site_lineno = tf_inspect.stack()[1][2]
  if sys.version_info < (3, 8):
    return call_site_lineno - 1
  else:
    with open(__file__, "rb") as f:
      source_text = f.read().decode("utf-8")
    source_tree = ast.parse(source_text)
    prev_node = _find_preceding_ast_node(source_tree, call_site_lineno)
    return prev_node.lineno

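
# A minimal usage sketch (illustrative only; _demo_line_number_above is not
# called by the tests in this file). Each test records the line number of the
# statement it has just executed by calling line_number_above() on the very
# next line, as shown here.
def _demo_line_number_above():
  anchor = "statement whose line number we want"  # the preceding statement
  anchor_lineno = line_number_above()  # lineno of the `anchor = ...` line above
  return anchor, anchor_lineno
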

def _find_preceding_ast_node(node, lineno):
  """Find the AST node immediately before, and not including, lineno."""
  for i, child_node in enumerate(node.body):
    if child_node.lineno == lineno:
      return node.body[i - 1]
    if hasattr(child_node, "body"):
      found_node = _find_preceding_ast_node(child_node, lineno)
      if found_node:
        return found_node

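
# A self-contained sketch of _find_preceding_ast_node (illustrative only; not
# exercised by the tests below): for the three-line snippet parsed here, the
# node immediately preceding line 3 is the assignment on line 2.
def _demo_find_preceding_ast_node():
  demo_tree = ast.parse("a = 1\nb = 2\nc = 3\n")
  preceding = _find_preceding_ast_node(demo_tree, 3)
  return preceding.lineno  # == 2, i.e. the lineno of `b = 2`
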

class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase):

  def setUp(self):
    self.curr_file_path = os.path.normpath(os.path.abspath(__file__))

  def tearDown(self):
    ops.reset_default_graph()

  def testGuessedBaseDirIsProbablyCorrect(self):
    # In the non-pip world, code resides in "tensorflow/"
    # In the pip world, after virtual pip, code resides in "tensorflow_core/"
    # So, we have to check both of them
    self.assertIn(
        os.path.basename(source_utils._TENSORFLOW_BASEDIR),
        ["tensorflow", "tensorflow_core"])

  def testUnitTestFileReturnsFalse(self):
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(self.curr_file_path))

  def testSourceUtilModuleReturnsTrue(self):
    self.assertTrue(
        source_utils.guess_is_tensorflow_py_library(source_utils.__file__))

  @test_util.run_v1_only("Tensor.op is not available in TF 2.x")
  def testFileInPythonKernelsPathReturnsTrue(self):
    x = constant_op.constant(42.0, name="x")
    self.assertTrue(
        source_utils.guess_is_tensorflow_py_library(x.op.traceback[-1][0]))

  def testDebuggerExampleFilePathReturnsFalse(self):
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(os.path.normpath(
            "site-packages/tensorflow/python/debug/examples/debug_mnist.py")))
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(os.path.normpath(
            "site-packages/tensorflow/python/debug/examples/v1/example_v1.py")))
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(os.path.normpath(
            "site-packages/tensorflow/python/debug/examples/v2/example_v2.py")))
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(os.path.normpath(
            "site-packages/tensorflow/python/debug/examples/v3/example_v3.py")))

  def testReturnsFalseForNonPythonFile(self):
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(
            os.path.join(os.path.dirname(self.curr_file_path), "foo.cc")))

  def testReturnsFalseForStdin(self):
    self.assertFalse(source_utils.guess_is_tensorflow_py_library("<stdin>"))

  def testReturnsFalseForEmptyFileName(self):
    self.assertFalse(source_utils.guess_is_tensorflow_py_library(""))


class SourceHelperTest(test_util.TensorFlowTestCase):

  def createAndRunGraphHelper(self):
    """Create and run a TensorFlow Graph to generate debug dumps.

    This is intentionally done in a separate method to make it easier to test
    the stack-top mode of source annotation.
    """

    self.dump_root = self.get_temp_dir()
    self.curr_file_path = os.path.abspath(
        tf_inspect.getfile(tf_inspect.currentframe()))

    # Run a simple TF graph to generate some debug dumps that can be used in
    # source annotation.
    with session.Session() as sess:
      self.u_init = constant_op.constant(
          np.array([[5.0, 3.0], [-1.0, 0.0]]), shape=[2, 2], name="u_init")
      self.u_init_line_number = line_number_above()

      self.u = variables.Variable(self.u_init, name="u")
      self.u_line_number = line_number_above()

      self.v_init = constant_op.constant(
          np.array([[2.0], [-1.0]]), shape=[2, 1], name="v_init")
      self.v_init_line_number = line_number_above()

      self.v = variables.Variable(self.v_init, name="v")
      self.v_line_number = line_number_above()

      self.w = math_ops.matmul(self.u, self.v, name="w")
      self.w_line_number = line_number_above()

      self.evaluate(self.u.initializer)
      self.evaluate(self.v.initializer)

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
      run_metadata = config_pb2.RunMetadata()
      sess.run(self.w, options=run_options, run_metadata=run_metadata)

      self.dump = debug_data.DebugDumpDir(
          self.dump_root, partition_graphs=run_metadata.partition_graphs)
      self.dump.set_python_graph(sess.graph)

  def setUp(self):
    self.createAndRunGraphHelper()
    self.helper_line_number = line_number_above()

  def tearDown(self):
    if os.path.isdir(self.dump_root):
      file_io.delete_recursively(self.dump_root)
    ops.reset_default_graph()

  def testAnnotateWholeValidSourceFileGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(self.dump,
                                                     self.curr_file_path)

    self.assertIn(self.u_init.op.name,
                  source_annotation[self.u_init_line_number])
    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertIn(self.v_init.op.name,
                  source_annotation[self.v_init_line_number])
    self.assertIn(self.v.op.name, source_annotation[self.v_line_number])
    self.assertIn(self.w.op.name, source_annotation[self.w_line_number])

    # In the non-stack-top (default) mode, the helper line should be annotated
    # with all the ops as well.
    self.assertIn(self.u_init.op.name,
                  source_annotation[self.helper_line_number])
    self.assertIn(self.u.op.name, source_annotation[self.helper_line_number])
    self.assertIn(self.v_init.op.name,
                  source_annotation[self.helper_line_number])
    self.assertIn(self.v.op.name, source_annotation[self.helper_line_number])
    self.assertIn(self.w.op.name, source_annotation[self.helper_line_number])

  def testAnnotateWithStackTopGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(
        self.dump, self.curr_file_path, file_stack_top=True)

    self.assertIn(self.u_init.op.name,
                  source_annotation[self.u_init_line_number])
    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertIn(self.v_init.op.name,
                  source_annotation[self.v_init_line_number])
    self.assertIn(self.v.op.name, source_annotation[self.v_line_number])
    self.assertIn(self.w.op.name, source_annotation[self.w_line_number])

    # In the stack-top mode, the helper line should not have been annotated.
    self.assertNotIn(self.helper_line_number, source_annotation)

  def testAnnotateSubsetOfLinesGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(
        self.dump,
        self.curr_file_path,
        min_line=self.u_line_number,
        max_line=self.u_line_number + 1)

    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertNotIn(self.v_line_number, source_annotation)

  def testAnnotateDumpedTensorsGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(
        self.dump, self.curr_file_path, do_dumped_tensors=True)

    # Note: Constant Tensors u_init and v_init may not get dumped due to
    #   constant-folding.
    self.assertIn(self.u.name, source_annotation[self.u_line_number])
    self.assertIn(self.v.name, source_annotation[self.v_line_number])
    self.assertIn(self.w.name, source_annotation[self.w_line_number])

    self.assertNotIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertNotIn(self.v.op.name, source_annotation[self.v_line_number])
    self.assertNotIn(self.w.op.name, source_annotation[self.w_line_number])

    self.assertIn(self.u.name, source_annotation[self.helper_line_number])
    self.assertIn(self.v.name, source_annotation[self.helper_line_number])
    self.assertIn(self.w.name, source_annotation[self.helper_line_number])

  def testCallingAnnotateSourceWithoutPythonGraphRaisesException(self):
    self.dump.set_python_graph(None)
    with self.assertRaises(ValueError):
      source_utils.annotate_source(self.dump, self.curr_file_path)

  def testCallingAnnotateSourceOnUnrelatedSourceFileDoesNotError(self):
    # Create an unrelated source file.
    fd, unrelated_source_path = tempfile.mkstemp()
    with open(fd, "wt") as source_file:
      source_file.write("print('hello, world')\n")

    self.assertEqual({},
                     source_utils.annotate_source(self.dump,
                                                  unrelated_source_path))

    # Clean up unrelated source file.
    os.remove(unrelated_source_path)

  def testLoadingPythonSourceFileWithNonAsciiChars(self):
    fd, source_path = tempfile.mkstemp()
    with open(fd, "wb") as source_file:
      source_file.write(u"print('\U0001f642')\n".encode("utf-8"))
    source_lines, _ = source_utils.load_source(source_path)
    self.assertEqual(source_lines, [u"print('\U0001f642')", u""])
    # Clean up the source file.
    os.remove(source_path)

  def testLoadNonexistentNonParPathFailsWithIOError(self):
    bad_path = os.path.join(self.get_temp_dir(), "nonexistent.py")
    with self.assertRaisesRegex(IOError,
                                "neither exists nor can be loaded.*par.*"):
      source_utils.load_source(bad_path)

  def testLoadingPythonSourceFileInParFileSucceeds(self):
    # Create the .par file first.
    temp_file_path = os.path.join(self.get_temp_dir(), "model.py")
    with open(temp_file_path, "wb") as f:
      f.write(b"import tensorflow as tf\nx = tf.constant(42.0)\n")
    par_path = os.path.join(self.get_temp_dir(), "train_model.par")
    with zipfile.ZipFile(par_path, "w") as zf:
      zf.write(temp_file_path, os.path.join("tensorflow_models", "model.py"))

    source_path = os.path.join(par_path, "tensorflow_models", "model.py")
    source_lines, _ = source_utils.load_source(source_path)
    self.assertEqual(
        source_lines, ["import tensorflow as tf", "x = tf.constant(42.0)", ""])

  def testLoadingPythonSourceFileInParFileFailsRaisingIOError(self):
    # Create the .par file first.
    temp_file_path = os.path.join(self.get_temp_dir(), "model.py")
    with open(temp_file_path, "wb") as f:
      f.write(b"import tensorflow as tf\nx = tf.constant(42.0)\n")
    par_path = os.path.join(self.get_temp_dir(), "train_model.par")
    with zipfile.ZipFile(par_path, "w") as zf:
      zf.write(temp_file_path, os.path.join("tensorflow_models", "model.py"))

    source_path = os.path.join(par_path, "tensorflow_models", "nonexistent.py")
    with self.assertRaisesRegex(IOError,
                                "neither exists nor can be loaded.*par.*"):
      source_utils.load_source(source_path)


@test_util.run_v1_only("Sessions are not available in TF 2.x")
class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase):

  def createAndRunGraphWithWhileLoop(self):
    """Create and run a TensorFlow Graph with a while loop to generate dumps."""

    self.dump_root = self.get_temp_dir()
    self.curr_file_path = os.path.abspath(
        tf_inspect.getfile(tf_inspect.currentframe()))

    # Run a simple TF graph to generate some debug dumps that can be used in
    # source annotation.
    with session.Session() as sess:
      loop_body = lambda i: math_ops.add(i, 2)
      self.traceback_first_line = line_number_above()

      loop_cond = lambda i: math_ops.less(i, 16)

      i = constant_op.constant(10, name="i")
      loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
      run_metadata = config_pb2.RunMetadata()
      sess.run(loop, options=run_options, run_metadata=run_metadata)

      self.dump = debug_data.DebugDumpDir(
          self.dump_root, partition_graphs=run_metadata.partition_graphs)
      self.dump.set_python_graph(sess.graph)

  def setUp(self):
    self.createAndRunGraphWithWhileLoop()

  def tearDown(self):
    if os.path.isdir(self.dump_root):
      file_io.delete_recursively(self.dump_root)
    ops.reset_default_graph()

  def testGenerateSourceList(self):
    source_list = source_utils.list_source_files_against_dump(self.dump)

    # Assert that the file paths are sorted and unique.
    file_paths = [item[0] for item in source_list]
    self.assertEqual(sorted(file_paths), file_paths)
    self.assertEqual(len(set(file_paths)), len(file_paths))

    # Assert that each item of source_list has length 6.
    for item in source_list:
      self.assertTrue(isinstance(item, tuple))
      self.assertEqual(6, len(item))

    # The while loop body should have executed 3 times. The following table
    # lists the tensors and how many times each of them is dumped.
    #   Tensor name            # of times dumped:
    #   i:0                    1
    #   while/Enter:0          1
    #   while/Merge:0          4
    #   while/Merge:1          4
    #   while/Less/y:0         4
    #   while/Less:0           4
    #   while/LoopCond:0       4
    #   while/Switch:0         1
    #   while/Switch:1         3
    #   while/Identity:0       3
    #   while/Add/y:0          3
    #   while/Add:0            3
    #   while/NextIteration:0  3
    #   while/Exit:0           1
    # ----------------------------
    #   (Total)                39
    #
    # The total number of nodes is 12.
    # The total number of tensors is 14 (2 of the nodes have 2 outputs:
    #   while/Merge, while/Switch).

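    # A hypothetical sanity check of the table above (added for illustration;
    # not part of the original test): the 14 per-tensor dump counts sum to 39.
    dumps_per_tensor = [1, 1, 4, 4, 4, 4, 4, 1, 3, 3, 3, 3, 3, 1]
    self.assertEqual(14, len(dumps_per_tensor))
    self.assertEqual(39, sum(dumps_per_tensor))
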
    _, is_tf_py_library, num_nodes, num_tensors, num_dumps, first_line = (
        source_list[file_paths.index(self.curr_file_path)])
    self.assertFalse(is_tf_py_library)
    self.assertEqual(12, num_nodes)
    self.assertEqual(14, num_tensors)
    self.assertEqual(39, num_dumps)
    self.assertEqual(self.traceback_first_line, first_line)

  def testGenerateSourceListWithNodeNameFilter(self):
    source_list = source_utils.list_source_files_against_dump(
        self.dump, node_name_regex_allowlist=r"while/Add.*")

    # Assert that the file paths are sorted and unique.
    file_paths = [item[0] for item in source_list]
    self.assertEqual(sorted(file_paths), file_paths)
    self.assertEqual(len(set(file_paths)), len(file_paths))

    # Assert that each item of source_list has length 6.
    for item in source_list:
      self.assertTrue(isinstance(item, tuple))
      self.assertEqual(6, len(item))

    # Due to the node-name filtering, the result should contain only 2 nodes
    # and 2 tensors. The total number of dumped tensors should be 6:
    #   while/Add/y:0          3
    #   while/Add:0            3
    _, is_tf_py_library, num_nodes, num_tensors, num_dumps, _ = (
        source_list[file_paths.index(self.curr_file_path)])
    self.assertFalse(is_tf_py_library)
    self.assertEqual(2, num_nodes)
    self.assertEqual(2, num_tensors)
    self.assertEqual(6, num_dumps)

  def testGenerateSourceListWithPathRegexFilter(self):
    curr_file_basename = os.path.basename(self.curr_file_path)
    source_list = source_utils.list_source_files_against_dump(
        self.dump,
        path_regex_allowlist=(".*" + curr_file_basename.replace(".", "\\.") +
                              "$"))

    self.assertEqual(1, len(source_list))
    (file_path, is_tf_py_library, num_nodes, num_tensors, num_dumps,
     first_line) = source_list[0]
    self.assertEqual(self.curr_file_path, file_path)
    self.assertFalse(is_tf_py_library)
    self.assertEqual(12, num_nodes)
    self.assertEqual(14, num_tensors)
    self.assertEqual(39, num_dumps)
    self.assertEqual(self.traceback_first_line, first_line)


if __name__ == "__main__":
  googletest.main()