xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/pyct/origin_info.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Container for origin source code information before AutoGraph compilation."""
16import collections
17import difflib
18import io
19import os
20import tokenize
21
22import gast
23
24from tensorflow.python.autograph.pyct import anno
25from tensorflow.python.autograph.pyct import ast_util
26from tensorflow.python.autograph.pyct import parser
27from tensorflow.python.autograph.pyct import pretty_printer
28from tensorflow.python.util import tf_inspect
29
30
class LineLocation(
    collections.namedtuple('LineLocation', ['filename', 'lineno'])):
  """A code position made of a file name and a line, with no column.

  Counterpart of Location for cases where the column is not relevant.

  Attributes:
    filename: Text
    lineno: int, 1-based
  """
40
41
class Location(
    collections.namedtuple('Location', ['filename', 'lineno', 'col_offset'])):
  """A code position made of a file name, a line and a column offset.

  Attributes:
    filename: Text
    lineno: int, 1-based
    col_offset: int
    line_loc: LineLocation, built from filename and lineno
  """

  @property
  def line_loc(self):
    """The same position as a LineLocation, i.e. without the column."""
    return LineLocation(filename=self.filename, lineno=self.lineno)
56
57
class OriginInfo(
    collections.namedtuple(
        'OriginInfo',
        ['loc', 'function_name', 'source_code_line', 'comment'])):
  """Describes where a piece of code was located before conversion.

  Attributes:
    loc: Location
    function_name: Optional[Text]
    source_code_line: Text
    comment: Optional[Text]
  """

  def as_frame(self):
    """Returns a 4-tuple consistent with the return of traceback.extract_tb.

    The tuple is (filename, lineno, function_name, source_code_line).
    """
    loc = self.loc
    return (loc.filename, loc.lineno, self.function_name,
            self.source_code_line)

  def __repr__(self):
    """Formats the origin as '<file base name>:<lineno>:<col_offset>'."""
    loc = self.loc
    if not loc.filename:
      # No file name is known; use a fixed placeholder instead.
      return '<no file>:{}:{}'.format(loc.lineno, loc.col_offset)
    # Only the last path component is shown, not the full path.
    base_name = os.path.basename(loc.filename)
    return '{}:{}:{}'.format(base_name, loc.lineno, loc.col_offset)
82
83
84# TODO(mdan): This source map should be a class - easier to refer to.
def create_source_map(nodes, code, filepath):
  """Creates a source map between an annotated AST and the code it compiles to.

  Note: this function assumes nodes, code and filepath correspond to the
  same code.

  Args:
    nodes: Iterable[ast.AST, ...], one or more AST nodes.
    code: Text, the source code in which nodes are found.
    filepath: Text

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.

  Raises:
    ValueError: if a ValueError is raised while walking nodes in parallel with
      the AST obtained by re-parsing code. The message contains the original
      cause followed by a diff of the two ASTs.
  """
  # Re-parse the code and annotate the resulting nodes with their own
  # location inside `code`; these become the keys of the source map.
  reparsed_nodes = parser.parse(code, preamble_len=0, single_node=False)
  for node in reparsed_nodes:
    resolve(node, code, filepath, node.lineno, node.col_offset)

  source_map = {}

  try:
    for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
      # Note: generated code might not be mapped back to its origin.
      # TODO(mdan): Generated code should always be mapped to something.
      origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
      final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
      if origin_info is None or final_info is None:
        continue

      # Note: the keys are by line only, excluding the column offset.
      line_loc = LineLocation(final_info.loc.filename, final_info.loc.lineno)

      existing_origin = source_map.get(line_loc)
      if existing_origin is not None:
        # Overlaps may exist because of child nodes, but almost never to
        # different line locations. An exception are decorated functions,
        # where both lines are mapped to the same line in the AST.

        # Line overlaps: keep bottom node.
        # NOTE(review): equal line_locs imply equal linenos, so the inner
        # comparison looks always true here - confirm the intended condition.
        if existing_origin.loc.line_loc == origin_info.loc.line_loc:
          if existing_origin.loc.lineno >= origin_info.loc.lineno:
            continue

        # In case of column overlaps, keep the leftmost node.
        if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
          continue

      source_map[line_loc] = origin_info

  except ValueError as err:
    # The cause is separated from the 'Diff:' header by a newline so that the
    # two do not run together on a single line.
    msg_parts = [
        'Inconsistent ASTs detected. This is a bug. Cause: \n',
        str(err),
        '\nDiff:\n',
    ]

    for n, rn in zip(nodes, reparsed_nodes):
      nodes_str = pretty_printer.fmt(n, color=False, noanno=True)
      reparsed_nodes_str = pretty_printer.fmt(rn, color=False, noanno=True)
      diff = difflib.context_diff(
          nodes_str.split('\n'),
          reparsed_nodes_str.split('\n'),
          fromfile='Original nodes',
          tofile='Reparsed nodes',
          n=7)
      msg_parts.append('\n'.join(diff) + '\n')
    # Chain explicitly so the original error is reported as the direct cause.
    raise ValueError(''.join(msg_parts)) from err

  return source_map
154
155
class _Function:
  """Record for a function definition entered during AST traversal.

  OriginResolver keeps a stack of these to know the innermost enclosing
  function of the node being visited.

  Attributes:
    name: the function name given at construction time.
  """

  def __init__(self, name):
    """Stores the given function name."""
    self.name = name
160
161
class OriginResolver(gast.NodeVisitor):
  """Visitor that attaches OriginInfo annotations to the nodes of an AST."""

  def __init__(self, root_node, source_lines, comments_map,
               context_lineno, context_col_offset,
               filepath):
    """Sets up the line/column translation for the tree under root_node.

    Args:
      root_node: the root of the AST that will be visited.
      source_lines: sequence of source lines, indexed by (lineno - 1).
      comments_map: mapping from line number to the comment on that line.
      context_lineno: line number of the root in the original context.
      context_col_offset: column offset of the root in the original context.
      filepath: file name recorded in every Location that is created.
    """
    self._source_lines = source_lines
    self._comments_map = comments_map
    self._filepath = filepath
    self._function_stack = []

    decorators = getattr(root_node, 'decorator_list', None)
    if decorators and hasattr(decorators[0], 'lineno'):
      # Typical case: functions. The line number of the first decorator is
      # more accurate than the line number of the function itself in 3.8+.
      # In earlier versions they coincide.
      anchor_lineno = decorators[0].lineno
    else:
      # Otherwise the root node's own line number is the reference.
      anchor_lineno = root_node.lineno

    self._lineno_offset = context_lineno - anchor_lineno
    self._col_offset = context_col_offset - root_node.col_offset

  def _absolute_lineno(self, lineno):
    """Shifts a node line number by the line offset of the context."""
    return lineno + self._lineno_offset

  def _absolute_col_offset(self, col_offset):
    """Shifts a node column by the context column offset; None maps to 0."""
    return 0 if col_offset is None else col_offset + self._col_offset

  def _attach_origin_info(self, node):
    """Annotates node with an OriginInfo, if it has a line number."""
    lineno = getattr(node, 'lineno', None)
    if lineno is None:
      # Without a line number there is nothing to map.
      return

    # The innermost function being visited, if any, names the origin.
    enclosing = self._function_stack[-1].name if self._function_stack else None

    source_line = self._source_lines[lineno - 1]
    comment = self._comments_map.get(lineno)
    loc = Location(
        self._filepath,
        self._absolute_lineno(lineno),
        self._absolute_col_offset(getattr(node, 'col_offset', None)))

    # The untranslated line number is stored under the 'lineno' annotation
    # key, next to the origin itself.
    anno.setanno(node, 'lineno', lineno)
    anno.setanno(
        node, anno.Basic.ORIGIN,
        OriginInfo(loc, enclosing, source_line, comment))

  def visit(self, node):
    """Annotates node and its children, tracking enclosing FunctionDefs."""
    is_function = isinstance(node, gast.FunctionDef)
    if is_function:
      # Pushed before annotating, so the FunctionDef node itself is
      # attributed to its own name.
      self._function_stack.append(_Function(node.name))

    self._attach_origin_info(node)
    self.generic_visit(node)

    if is_function:
      self._function_stack.pop()
227
228
def resolve(node, source, context_filepath, context_lineno, context_col_offset):
  """Annotates an AST with origin information taken from its source code.

  The annotations make it possible to map line numbers of the original source
  code to generated source code.

  Note: the AST may belong to a larger context (for instance a function inside
  a module that holds other things). This function does not assume that the
  source argument contains that entire context, nor that it contains only the
  code of node itself. It does assume that node was parsed from the given
  source code. For this reason, the location of the node in the original
  context must be supplied through the context_* arguments.

  Args:
    node: gast.AST, the AST to annotate.
    source: Text, the source code representing node.
    context_filepath: Text
    context_lineno: int
    context_col_offset: int

  Raises:
    tokenize.TokenError: if source cannot be tokenized and node is not a
      gast.Lambda.
  """
  # TODO(mdan): Pull this to a separate utility.
  comments_map = {}
  readline = io.StringIO(source).readline
  try:
    for tok in tokenize.generate_tokens(readline):
      if tok.type == tokenize.COMMENT:
        # Keyed by the starting row of the comment token; the text is stored
        # without the leading '#' and the surrounding whitespace.
        comments_map[tok.start[0]] = tok.string.strip()[1:].strip()
  except tokenize.TokenError:
    # Source code resolution in older Python versions is brittle for lambda
    # functions, and may contain garbage. The failure is tolerated only in
    # that case; comments collected up to this point are kept.
    if not isinstance(node, gast.Lambda):
      raise

  resolver = OriginResolver(
      node, source.split('\n'), comments_map,
      context_lineno, context_col_offset, context_filepath)
  resolver.visit(node)
272
273
def resolve_entity(node, source, entity):
  """Like resolve, but extracts the context information from an entity.

  Args:
    node: gast.AST, the AST to annotate.
    source: Text, the source code representing node.
    entity: the object whose source lines, starting line number and source
      file are looked up with tf_inspect to provide the context.
  """
  entity_lines, first_lineno = tf_inspect.getsourcelines(entity)
  entity_filepath = tf_inspect.getsourcefile(entity)

  # Poor man's attempt at guessing the column offset: the amount of leading
  # whitespace on the first source line. This might not work well with tabs.
  first_line = entity_lines[0]
  indent_width = len(first_line) - len(first_line.lstrip())

  resolve(node, source, entity_filepath, first_lineno, indent_width)
285
286
def copy_origin(from_node, to_node):
  """Copies the origin info from a node to another, recursively.

  Args:
    from_node: the node whose origin annotation is read. If it has none,
      nothing is done.
    to_node: a node, or a list or tuple of nodes. Every node reached by
      gast.walk from each of them receives the origin annotation.
  """
  origin = anno.Basic.ORIGIN.of(from_node, default=None)
  if origin is None:
    return

  # Accept a single node as well as a list or tuple of nodes.
  targets = to_node if isinstance(to_node, (list, tuple)) else (to_node,)
  for target in targets:
    for descendant in gast.walk(target):
      anno.setanno(descendant, anno.Basic.ORIGIN, origin)
297