xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/pyct/parser.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converting code to AST.
16
17Adapted from Tangent.
18"""
19
20import ast
21import inspect
22import io
23import linecache
24import re
25import sys
26import textwrap
27import tokenize
28
29import astunparse
30import gast
31
32from tensorflow.python.autograph.pyct import errors
33from tensorflow.python.autograph.pyct import inspect_utils
34from tensorflow.python.util import tf_inspect
35
36
# Preamble prepended to parsed code on Python 2; presumably intended to hold
# `from __future__ import ...` lines (see STANDARD_PREAMBLE_LEN below), empty
# in practice here — TODO confirm against history.
PY2_PREAMBLE = textwrap.dedent("""
""")
# Python 3 needs no preamble.
PY3_PREAMBLE = ''
MAX_SIZE = 0  # Overwritten below with the platform's maximum integer/size.

if sys.version_info >= (3, 9):
  # Python 3.9+ can unparse ASTs natively; substitute the stdlib ast module
  # for the third-party astunparse package (both expose `unparse`).
  astunparse = ast

if sys.version_info >= (3,):
  STANDARD_PREAMBLE = PY3_PREAMBLE
  MAX_SIZE = sys.maxsize
else:
  STANDARD_PREAMBLE = PY2_PREAMBLE
  MAX_SIZE = sys.maxint

# Number of `__future__` imports in the standard preamble; used as the
# `preamble_len` argument when parsing code with the preamble prepended.
STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__')


# Matches the leading whitespace (indentation) of a line.
_LEADING_WHITESPACE = re.compile(r'\s*')
56
57
58def _unfold_continuations(code_string):
59  """Removes any backslash line continuations from the code."""
60  return code_string.replace('\\\n', '')
61
62
def dedent_block(code_string):
  """Dedents a code string so that its first line starts at row zero.

  Unlike textwrap.dedent, this tokenizes the code and strips only the block
  indentation reported by the tokenizer, then re-applies the dedent to the
  original lines so that intra-line whitespace is preserved exactly.

  Args:
    code_string: Text, the code to dedent.

  Returns:
    Text, the dedented code. Returned unchanged if the code is not indented.

  Raises:
    errors.UnsupportedLanguageElementError: if the code mixes tabs and spaces
      for indentation.
  """

  code_string = _unfold_continuations(code_string)

  token_gen = tokenize.generate_tokens(io.StringIO(code_string).readline)

  block_indentation = None
  tokens = []
  try:
    for tok in token_gen:
      tokens.append(tok)
  except tokenize.TokenError:
    # Resolution of lambda functions may yield incomplete code, which can
    # in turn generate this error. We silently ignore this error because the
    # parser may still be able to deal with it.
    pass

  # Find the indentation of the first logical line. Blank lines, strings and
  # comments may legitimately precede it and are skipped.
  for tok in tokens:
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      block_indentation = tok_string
      break
    elif tok_type not in (
        tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT):
      # The first significant token is unindented - nothing to dedent.
      block_indentation = ''
      break

  if not block_indentation:
    return code_string

  block_level = len(block_indentation)
  first_indent_uses_tabs = '\t' in block_indentation
  for i, tok in enumerate(tokens):
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      if ((' ' in tok_string and first_indent_uses_tabs)
          or ('\t' in tok_string and not first_indent_uses_tabs)):
        # TODO(mdan): We could attempt to convert tabs to spaces by unix rule.
        # See:
        # https://docs.python.org/3/reference/lexical_analysis.html#indentation
        raise errors.UnsupportedLanguageElementError(
            'code mixing tabs and spaces for indentation is not allowed')
      if len(tok_string) >= block_level:
        tok_string = tok_string[block_level:]
      # untokenize accepts 2-tuples (type, string) alongside full tokens.
      tokens[i] = (tok_type, tok_string)

  new_code = tokenize.untokenize(tokens)

  # Note: untokenize respects the line structure, but not the whitespace within
  # lines. For example, `def foo()` may be untokenized as `def foo ()`
  # So instead of using the output of dedent, we match the leading whitespace
  # on each line.
  dedented_code = []
  for line, new_line in zip(code_string.split('\n'), new_code.split('\n')):
    original_indent = re.match(_LEADING_WHITESPACE, line).group()
    new_indent = re.match(_LEADING_WHITESPACE, new_line).group()
    if len(original_indent) > len(new_indent):
      dedented_line = line[len(original_indent) - len(new_indent):]
    else:
      dedented_line = line
    dedented_code.append(dedented_line)
  new_code = '\n'.join(dedented_code)

  return new_code
129
130
def parse_entity(entity, future_features):
  """Parses the source of `entity` into an AST.

  Args:
    entity: Any, Python function/method/class
    future_features: Iterable[Text], future features to use (e.g.
      'print_statement'). See
      https://docs.python.org/2/reference/simple_stmts.html#future

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).

  Raises:
    errors.InaccessibleSourceCodeError: if the source of `entity` cannot be
      located (e.g. it was defined in an interactive shell).
  """
  # Lambdas need a dedicated resolution procedure.
  if inspect_utils.islambda(entity):
    return _parse_lambda(entity)

  try:
    raw_source = inspect_utils.getimmediatesource(entity)
  except OSError as e:
    raise errors.InaccessibleSourceCodeError(
        f'Unable to locate the source code of {entity}. Note that functions'
        ' defined in certain environments, like the interactive Python shell,'
        ' do not expose their source code. If that is the case, you should'
        ' define them in a .py source file. If you are certain the code is'
        ' graph-compatible, wrap the call using'
        f' @tf.autograph.experimental.do_not_convert. Original error: {e}')

  dedented_source = dedent_block(raw_source)

  # Prepend one `from __future__ import ...` line per requested feature so
  # that parsing honors the entity's future imports.
  future_lines = [
      'from __future__ import {}'.format(name) for name in future_features
  ]
  full_source = '\n'.join(future_lines + [dedented_source])

  node = parse(full_source, preamble_len=len(future_features))
  return node, full_source
165
166
def _without_context(node, lines, minl, maxl):
  """Returns a clean node and source code without indenting and context."""
  # Rebase line numbers so the extracted snippet starts at line zero.
  for inner in gast.walk(node):
    if getattr(inner, 'lineno', None) is not None:
      inner.lineno -= minl
    if getattr(inner, 'end_lineno', None) is not None:
      inner.end_lineno -= minl

  # Slicing copies, so the trimming below does not mutate `lines`.
  snippet_lines = lines[minl - 1:maxl]

  # Attempt to clean up surrounding context code.

  end_col = getattr(node, 'end_col_offset', None)
  if end_col is not None:
    # end_col_offset is only available in Python 3.8+.
    snippet_lines[-1] = snippet_lines[-1][:end_col]

  start_col = getattr(node, 'col_offset', None)
  if start_col is None:
    # Older Python: try to find the "lambda" token. This is brittle.
    found = re.search(r'(?<!\w)lambda(?!\w)', snippet_lines[0])
    if found is not None:
      start_col = found.start(0)

  if start_col is not None:
    snippet_lines[0] = snippet_lines[0][start_col:]

  code_block = '\n'.join(line.rstrip() for line in snippet_lines)

  return node, code_block
199
200
201def _arg_name(node):
202  if node is None:
203    return None
204  if isinstance(node, gast.Name):
205    return node.id
206  assert isinstance(node, str)
207  return node
208
209
def _node_matches_argspec(node, func):
  """Returns True if `node`'s arguments match the argspec of `func`."""
  # TODO(mdan): Use just inspect once support for Python 2 is dropped.
  spec = tf_inspect.getfullargspec(func)

  # All four argument groups must agree: positional, *args, **kwargs and
  # keyword-only.
  return (
      tuple(_arg_name(a) for a in node.args.args) == tuple(spec.args)
      and spec.varargs == _arg_name(node.args.vararg)
      and spec.varkw == _arg_name(node.args.kwarg)
      and tuple(_arg_name(a) for a in node.args.kwonlyargs) == tuple(
          spec.kwonlyargs))
230
231
def _parse_lambda(lam):
  """Returns the AST and source code of given lambda function.

  Args:
    lam: types.LambdaType, Python function/method/class

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).

  Raises:
    errors.UnsupportedLanguageElementError: if no AST node matching the lambda
      could be found, or if multiple matching nodes with identical signatures
      were found and the ambiguity could not be resolved.
  """
  # TODO(mdan): Use a fast path if the definition is not multi-line.
  # We could detect that the lambda is in a multi-line expression by looking
  # at the surrounding code - a surrounding set of parentheses indicates a
  # potential multi-line definition.

  mod = inspect.getmodule(lam)
  f = inspect.getsourcefile(lam)
  def_line = lam.__code__.co_firstlineno

  # This method is more robust than just calling inspect.getsource(mod), as it
  # works in interactive shells, where getsource would fail. This is the
  # same procedure followed by inspect for non-modules:
  # https://github.com/python/cpython/blob/3.8/Lib/inspect.py#L772
  lines = linecache.getlines(f, mod.__dict__)
  source = ''.join(lines)

  # Narrow down to the last node starting before our definition node.
  all_nodes = parse(source, preamble_len=0, single_node=False)
  search_nodes = []
  for node in all_nodes:
    # Also include nodes without a line number, for safety. This is defensive -
    # we don't know whether such nodes might exist, and if they do, whether
    # they are not safe to skip.
    # TODO(mdan): Replace this check with an assertion or skip such nodes.
    if getattr(node, 'lineno', def_line) <= def_line:
      search_nodes.append(node)
    else:
      # Found a node starting past our lambda - can stop the search.
      break

  # Extract all lambda nodes from the shortlist.
  lambda_nodes = []
  for node in search_nodes:
    lambda_nodes.extend(
        n for n in gast.walk(node) if isinstance(n, gast.Lambda))

  # Filter down to lambda nodes which span our actual lambda, i.e. whose
  # [minl, maxl] line range contains the lambda's definition line.
  candidates = []
  for ln in lambda_nodes:
    minl, maxl = MAX_SIZE, 0
    for n in gast.walk(ln):
      minl = min(minl, getattr(n, 'lineno', minl))
      lineno = getattr(n, 'lineno', maxl)
      end_lineno = getattr(n, 'end_lineno', None)
      if end_lineno is not None:
        # end_lineno is more precise, but lineno should almost always work too.
        lineno = end_lineno
      maxl = max(maxl, lineno)
    if minl <= def_line <= maxl:
      candidates.append((ln, minl, maxl))

  # Happy path: exactly one node found.
  if len(candidates) == 1:
    (node, minl, maxl), = candidates  # pylint:disable=unbalanced-tuple-unpacking
    return _without_context(node, lines, minl, maxl)

  elif not candidates:
    lambda_codes = '\n'.join([unparse(l) for l in lambda_nodes])
    raise errors.UnsupportedLanguageElementError(
        f'could not parse the source code of {lam}:'
        f' no matching AST found among candidates:\n{lambda_codes}')

  # Attempt to narrow down the selection by signature if multiple nodes were
  # found.
  matches = [v for v in candidates if _node_matches_argspec(v[0], lam)]
  if len(matches) == 1:
    (node, minl, maxl), = matches
    return _without_context(node, lines, minl, maxl)

  # Give up if could not narrow down to a single node.
  matches = '\n'.join(
      'Match {}:\n{}\n'.format(i, unparse(node, include_encoding_marker=False))
      for i, (node, _, _) in enumerate(matches))
  raise errors.UnsupportedLanguageElementError(
      f'could not parse the source code of {lam}: found multiple definitions'
      ' with identical signatures at the location. This error'
      ' may be avoided by defining each lambda on a single line and with'
      f' unique argument names. The matching definitions were:\n{matches}')
319
320
321# TODO(mdan): This should take futures as input instead.
def parse(src, preamble_len=0, single_node=True):
  """Parses a piece of source code into gast AST node(s).

  Args:
    src: Text
    preamble_len: Int, number of leading nodes in the parsed AST which should
      be dropped (e.g. `__future__` imports prepended by the caller).
    single_node: Bool, whether `src` is assumed to be represented by exactly
      one AST node.

  Returns:
    ast.AST: the single node if `single_node` is True, else the list of
    remaining top-level nodes.

  Raises:
    ValueError: if `single_node` is True and the parse yields anything other
      than exactly one node after dropping the preamble.
  """
  parsed_body = gast.parse(src).body
  if preamble_len:
    parsed_body = parsed_body[preamble_len:]
  if not single_node:
    return parsed_body
  if len(parsed_body) != 1:
    raise ValueError('expected exactly one node, got {}'.format(parsed_body))
  return parsed_body[0]
344
345
def parse_expression(src):
  """Parses a single Python expression into its AST.

  Args:
    src: Text, a piece of code representing a single Python expression.

  Returns:
    A gast.AST object: the `value` of the parsed Expr node.

  Raises:
    ValueError: if src does not consist of a single Expression.
  """
  normalized_src = STANDARD_PREAMBLE + src.strip()
  expr_node = parse(
      normalized_src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True)
  # This check is stripped when running with -O, like assert statements.
  if __debug__ and not isinstance(expr_node, gast.Expr):
    raise ValueError(
        'expected exactly one node of type Expr, got {}'.format(expr_node))
  return expr_node.value
363
364
def unparse(node, indentation=None, include_encoding_marker=True):
  """Returns the source code of given AST.

  Args:
    node: The code to compile, as an AST object, or a list/tuple of AST
      objects.
    indentation: Unused, deprecated. The returned code is always indented at
      4 spaces.
    include_encoding_marker: Bool, whether to include a comment on the first
      line to explicitly specify UTF-8 encoding.

  Returns:
    Text, the source code generated from the AST object(s).
  """
  del indentation  # astunparse doesn't allow configuring it.
  if not isinstance(node, (list, tuple)):
    node = (node,)

  segments = []
  if include_encoding_marker:
    segments.append('# coding=utf-8')
  for n in node:
    # Convert gast nodes to native ast nodes; pass anything else through.
    ast_n = gast.gast_to_ast(n) if isinstance(n, gast.AST) else n
    if astunparse is ast:
      # Only the stdlib unparser (Python 3.9+) requires location info.
      ast.fix_missing_locations(ast_n)
    segments.append(astunparse.unparse(ast_n).strip())

  return '\n'.join(segments)
397