# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions that help to inspect Python source w.r.t. TF graphs."""

import collections
import os
import re
import zipfile

import absl
import numpy as np

from tensorflow.python.debug.lib import profiling


# Directory of the `tensorflow` Python package, obtained by walking four
# directory levels up from this file
# (tensorflow/python/debug/lib/source_utils.py).
_TENSORFLOW_BASEDIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.dirname(
        os.path.normpath(os.path.abspath(__file__))))))

_ABSL_BASEDIR = os.path.dirname(absl.__file__)


# NOTE: The trailing comma is required to make this a tuple. A bare (".py")
# would be a plain `str`, which would turn the `in` checks below into
# substring tests.
UNCOMPILED_SOURCE_SUFFIXES = (".py",)
COMPILED_SOURCE_SUFFIXES = (".pyc", ".pyo")


def _norm_abs_path(file_path):
  return os.path.normpath(os.path.abspath(file_path))


def is_extension_uncompiled_python_source(file_path):
  _, extension = os.path.splitext(file_path)
  return extension.lower() in UNCOMPILED_SOURCE_SUFFIXES


def is_extension_compiled_python_source(file_path):
  _, extension = os.path.splitext(file_path)
  return extension.lower() in COMPILED_SOURCE_SUFFIXES
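
# A minimal usage sketch of the two extension checks above (the file paths are
# hypothetical):
#
#   is_extension_uncompiled_python_source("/tmp/model.py")    # -> True
#   is_extension_compiled_python_source("/tmp/model.pyc")     # -> True
#   is_extension_uncompiled_python_source("/tmp/model.json")  # -> False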


def _convert_watch_key_to_tensor_name(watch_key):
  # A debug watch key has the form "<node_name>:<output_slot>:<debug_op>";
  # dropping the trailing ":<debug_op>" component yields the tensor name
  # "<node_name>:<output_slot>".
  return watch_key[:watch_key.rfind(":")]


def guess_is_tensorflow_py_library(py_file_path):
  """Guess whether a Python source file is a part of the tensorflow library.

  Special cases:
    1) Returns False for unit-test files in the library (*_test.py),
    2) Returns False for files under python/debug/examples.

  Args:
    py_file_path: full path of the Python source file in question.

  Returns:
    (`bool`) Whether the file is inferred to be a part of the tensorflow
      library.
  """
  if (not is_extension_uncompiled_python_source(py_file_path) and
      not is_extension_compiled_python_source(py_file_path)):
    return False
  py_file_path = _norm_abs_path(py_file_path)
  return ((py_file_path.startswith(_TENSORFLOW_BASEDIR) or
           py_file_path.startswith(_ABSL_BASEDIR)) and
          not py_file_path.endswith("_test.py") and
          (os.path.normpath("tensorflow/python/debug/examples") not in
           os.path.normpath(py_file_path)))
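
# A hedged sketch of what the guess amounts to in practice (the example paths
# are hypothetical):
#
#   guess_is_tensorflow_py_library(
#       _TENSORFLOW_BASEDIR + "/python/ops/math_ops.py")       # -> True
#   guess_is_tensorflow_py_library(
#       _TENSORFLOW_BASEDIR + "/python/ops/math_ops_test.py")  # -> False
#   guess_is_tensorflow_py_library("/home/user/train.py")      # -> False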


def load_source(source_file_path):
  """Load the content of a Python source code file.

  This function covers the following cases:
    1. source_file_path points to an existing Python (.py) file on the
       file system.
    2. source_file_path is a path within a .par file (i.e., a zip-compressed,
       self-contained Python executable).

  Args:
    source_file_path: Path to the Python source file to read.

  Returns:
    A length-2 tuple:
      - Lines of the source file, as a `list` of `str`s.
      - The width of the string needed to show the line number in the file.
        This is calculated based on the number of lines in the source file.

  Raises:
    IOError: if loading is unsuccessful.
  """
  if os.path.isfile(source_file_path):
    with open(source_file_path, "rb") as f:
      source_text = f.read().decode("utf-8")
    source_lines = source_text.split("\n")
  else:
    # One possible reason why the file doesn't exist is that it's a path
    # inside a .par file. Try that possibility.
    source_lines = _try_load_par_source(source_file_path)
    if source_lines is None:
      raise IOError(
          "Source path neither exists nor can be loaded as a .par file: %s" %
          source_file_path)
  line_num_width = int(np.ceil(np.log10(len(source_lines)))) + 3
  return source_lines, line_num_width
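
# A minimal sketch of using the returned width to render right-aligned line
# numbers (the path is hypothetical):
#
#   lines, width = load_source("/tmp/my_model.py")
#   for i, line in enumerate(lines):
#     print("%*d %s" % (width, i + 1, line))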


def _try_load_par_source(source_file_path):
  """Try loading the source code inside a .par file.

  A .par file is a zip-compressed, self-contained Python executable.
  It contains the content of individual Python source files that can
  be read only through extracting from the zip file.

  Args:
    source_file_path: The full path to the file inside the .par file. This
      path should include the path to the .par file itself, followed by the
      intra-par path, e.g.,
      "/tmp/my_executable.par/org-tensorflow/tensorflow/python/foo/bar.py".

  Returns:
    If successful, lines of the source file as a `list` of `str`s.
    Else, `None`.
  """
  prefix_path = source_file_path
  while True:
    prefix_path, basename = os.path.split(prefix_path)
    if not basename:
      # Walked all the way up to the filesystem root without finding a .par
      # file; fall through and (implicitly) return None.
      break
    suffix_path = os.path.normpath(
        os.path.relpath(source_file_path, start=prefix_path))
    if prefix_path.endswith(".par") and os.path.isfile(prefix_path):
      with zipfile.ZipFile(prefix_path) as z:
        norm_names = [os.path.normpath(name) for name in z.namelist()]
        if suffix_path in norm_names:
          with z.open(z.namelist()[norm_names.index(suffix_path)]) as zf:
            source_text = zf.read().decode("utf-8")
            return source_text.split("\n")
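
# For the example path in the docstring above, the loop probes ever-shorter
# prefixes until one names an existing .par file (paths hypothetical):
#
#   prefix_path = ".../my_executable.par/org-tensorflow/tensorflow/python/foo"
#   prefix_path = ".../my_executable.par/org-tensorflow/tensorflow/python"
#   ...
#   prefix_path = "/tmp/my_executable.par"  # ends with ".par" and is a file,
#                                           # so the zip member is extracted.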


def annotate_source(dump,
                    source_file_path,
                    do_dumped_tensors=False,
                    file_stack_top=False,
                    min_line=None,
                    max_line=None):
  """Annotate a Python source file with a list of ops created at each line.

  (The annotation doesn't change the source file itself.)

  Args:
    dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph
      has been loaded.
    source_file_path: (`str`) Path to the source file being annotated.
    do_dumped_tensors: (`bool`) Whether dumped Tensors, instead of ops, are to
      be used to annotate the source file.
    file_stack_top: (`bool`) Whether only the top stack trace in the
      specified source file is to be annotated.
    min_line: (`None` or `int`) The 1-based line number at which to start
      annotating the source file (inclusive).
    max_line: (`None` or `int`) The 1-based line number at which to end the
      annotation (exclusive).

  Returns:
    A `dict` mapping 1-based line number to a list of op name(s) created at
      that line, or tensor names if `do_dumped_tensors` is True.

  Raises:
    ValueError: If the dump object does not have a Python graph set.
  """

  py_graph = dump.python_graph
  if not py_graph:
    raise ValueError("Cannot perform source annotation because no Python "
                     "graph is set in the dump object")

  source_file_path = _norm_abs_path(source_file_path)

  line_to_op_names = {}
  for op in py_graph.get_operations():
    for file_path, line_number, _, _ in reversed(dump.node_traceback(op.name)):
      if (min_line is not None and line_number < min_line or
          max_line is not None and line_number >= max_line):
        continue

      if _norm_abs_path(file_path) != source_file_path:
        continue

      if do_dumped_tensors:
        watch_keys = dump.debug_watch_keys(op.name)
        # Convert watch keys to unique Tensor names.
        items_to_append = list(
            set(map(_convert_watch_key_to_tensor_name, watch_keys)))
      else:
        items_to_append = [op.name]

      if line_number in line_to_op_names:
        line_to_op_names[line_number].extend(items_to_append)
      else:
        line_to_op_names[line_number] = items_to_append

      if file_stack_top:
        break

  return line_to_op_names
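
# A hedged usage sketch, assuming a tfdbg dump directory loaded via
# `debug_data.DebugDumpDir` (the paths are hypothetical):
#
#   from tensorflow.python.debug.lib import debug_data
#
#   dump = debug_data.DebugDumpDir("/tmp/tfdbg_dump")
#   dump.set_python_graph(graph)  # `graph` is the graph that built the ops.
#   line_to_ops = annotate_source(dump, "/tmp/my_model.py")
#   for line_number in sorted(line_to_ops):
#     print(line_number, line_to_ops[line_number])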


def list_source_files_against_dump(dump,
                                   path_regex_allowlist=None,
                                   node_name_regex_allowlist=None):
  """Generate a list of source files with information regarding ops and tensors.

  Args:
    dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph
      has been loaded.
    path_regex_allowlist: A regular-expression filter for source file path.
    node_name_regex_allowlist: A regular-expression filter for node names.

  Returns:
    A list of tuples regarding the Python source files involved in constructing
    the ops and tensors contained in `dump`. Each tuple is:
      (source_file_path, is_tf_library, num_nodes, num_tensors, num_dumps,
       first_line)

      is_tf_library: (`bool`) A guess of whether the file belongs to the
        TensorFlow Python library.
      num_nodes: How many nodes were created by lines of this source file.
        These include nodes with dumps and those without.
      num_tensors: How many Tensors were created by lines of this source file.
        These include Tensors with dumps and those without.
      num_dumps: How many debug Tensor dumps were from nodes (and Tensors)
        that were created by this source file.
      first_line: The first line number (1-based) that created any nodes or
        Tensors in this source file.

    The list is sorted by ascending order of source_file_path.

  Raises:
    ValueError: If the dump object does not have a Python graph set.
  """

  py_graph = dump.python_graph
  if not py_graph:
    raise ValueError("Cannot generate source list because no Python graph "
                     "is set in the dump object")

  path_to_node_names = collections.defaultdict(set)
  path_to_tensor_names = collections.defaultdict(set)
  path_to_first_line = {}
  tensor_name_to_num_dumps = {}

  path_regex = (
      re.compile(path_regex_allowlist) if path_regex_allowlist else None)
  node_name_regex = (
      re.compile(node_name_regex_allowlist)
      if node_name_regex_allowlist else None)

  to_skip_file_paths = set()
  for op in py_graph.get_operations():
    if node_name_regex and not node_name_regex.match(op.name):
      continue

    for file_path, line_number, _, _ in dump.node_traceback(op.name):
      file_path = _norm_abs_path(file_path)
      if (file_path in to_skip_file_paths or
          path_regex and not path_regex.match(file_path) or
          not os.path.isfile(file_path)):
        to_skip_file_paths.add(file_path)
        continue

      path_to_node_names[file_path].add(op.name)
      if file_path in path_to_first_line:
        if path_to_first_line[file_path] > line_number:
          path_to_first_line[file_path] = line_number
      else:
        path_to_first_line[file_path] = line_number

      for output_tensor in op.outputs:
        tensor_name = output_tensor.name
        path_to_tensor_names[file_path].add(tensor_name)

      watch_keys = dump.debug_watch_keys(op.name)
      for watch_key in watch_keys:
        node_name, output_slot, debug_op = watch_key.split(":")
        tensor_name = "%s:%s" % (node_name, output_slot)
        if tensor_name not in tensor_name_to_num_dumps:
          tensor_name_to_num_dumps[tensor_name] = len(
              dump.get_tensors(node_name, int(output_slot), debug_op))

  path_to_num_dumps = {}
  for path in path_to_tensor_names:
    path_to_num_dumps[path] = sum(
        tensor_name_to_num_dumps.get(tensor_name, 0)
        for tensor_name in path_to_tensor_names[path])

  output = []
  for file_path in path_to_node_names:
    output.append((
        file_path,
        guess_is_tensorflow_py_library(file_path),
        len(path_to_node_names.get(file_path, {})),
        len(path_to_tensor_names.get(file_path, {})),
        path_to_num_dumps.get(file_path, 0),
        path_to_first_line[file_path]))

  return sorted(output, key=lambda x: x[0])
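
# A minimal sketch of consuming the result; the tuple fields follow the
# docstring above, and `dump` is assumed to be an already-loaded
# `DebugDumpDir`:
#
#   for (path, is_tf_lib, num_nodes, num_tensors, num_dumps,
#        first_line) in list_source_files_against_dump(dump):
#     if not is_tf_lib:
#       print("%s: %d nodes, %d tensors, %d dumps (first line: %d)" %
#             (path, num_nodes, num_tensors, num_dumps, first_line))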


def annotate_source_against_profile(profile_data,
                                    source_file_path,
                                    node_name_filter=None,
                                    op_type_filter=None,
                                    min_line=None,
                                    max_line=None):
  """Annotate a Python source file with profiling information at each line.

  (The annotation doesn't change the source file itself.)

  Args:
    profile_data: (`list` of `ProfileDatum`) A list of `ProfileDatum`.
    source_file_path: (`str`) Path to the source file being annotated.
    node_name_filter: Regular expression to filter by node name.
    op_type_filter: Regular expression to filter by op type.
    min_line: (`None` or `int`) The 1-based line number at which to start
      annotating the source file (inclusive).
    max_line: (`None` or `int`) The 1-based line number at which to end the
      annotation (exclusive).

  Returns:
    A `dict` mapping 1-based line number to a `profiling.AggregateProfile`
      summarizing the profile data from that line.
  """

  source_file_path = _norm_abs_path(source_file_path)

  node_name_regex = re.compile(node_name_filter) if node_name_filter else None
  op_type_regex = re.compile(op_type_filter) if op_type_filter else None

  line_to_profile_summary = {}
  for profile_datum in profile_data:
    if not profile_datum.file_path:
      continue

    if _norm_abs_path(profile_datum.file_path) != source_file_path:
      continue

    if (min_line is not None and profile_datum.line_number < min_line or
        max_line is not None and profile_datum.line_number >= max_line):
      continue

    if (node_name_regex and
        not node_name_regex.match(profile_datum.node_exec_stats.node_name)):
      continue

    if op_type_regex and not op_type_regex.match(profile_datum.op_type):
      continue

    if profile_datum.line_number not in line_to_profile_summary:
      line_to_profile_summary[profile_datum.line_number] = (
          profiling.AggregateProfile(profile_datum))
    else:
      line_to_profile_summary[profile_datum.line_number].add(profile_datum)

  return line_to_profile_summary
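
# A hedged usage sketch, assuming `profile_data` is a list of `ProfileDatum`
# objects collected by the tfdbg profiling tooling; `total_op_time` is a
# property of `profiling.AggregateProfile`:
#
#   summaries = annotate_source_against_profile(
#       profile_data, "/tmp/my_model.py", op_type_filter="MatMul.*")
#   for line_number in sorted(summaries):
#     print(line_number, summaries[line_number].total_op_time)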