# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converting code to AST.

Adapted from Tangent.
"""

import ast
import inspect
import io
import linecache
import re
import sys
import textwrap
import tokenize

import astunparse
import gast

from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.util import tf_inspect


PY2_PREAMBLE = textwrap.dedent("""
""")
PY3_PREAMBLE = ''
MAX_SIZE = 0

# Python 3.9+ provides ast.unparse, making the third-party astunparse package
# unnecessary there; alias the name so the rest of the module is agnostic.
if sys.version_info >= (3, 9):
  astunparse = ast

if sys.version_info >= (3,):
  STANDARD_PREAMBLE = PY3_PREAMBLE
  MAX_SIZE = sys.maxsize
else:
  STANDARD_PREAMBLE = PY2_PREAMBLE
  MAX_SIZE = sys.maxint

# Number of leading __future__ import nodes contributed by the preamble; ASTs
# parsed with the preamble attached drop this many leading nodes.
STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__')


_LEADING_WHITESPACE = re.compile(r'\s*')


def _unfold_continuations(code_string):
  """Removes any backslash line continuations from the code."""
  return code_string.replace('\\\n', '')


def dedent_block(code_string):
  """Dedents a code so that its first line starts at row zero."""

  code_string = _unfold_continuations(code_string)

  token_gen = tokenize.generate_tokens(io.StringIO(code_string).readline)

  block_indentation = None
  tokens = []
  try:
    for tok in token_gen:
      tokens.append(tok)
  except tokenize.TokenError:
    # Resolution of lambda functions may yield incomplete code, which can
    # in turn generate this error. We silently ignore this error because the
    # parser may still be able to deal with it.
    pass

  # Determine the block's indentation from the first INDENT token. Leading
  # blank lines, strings and comments do not establish indentation.
  for tok in tokens:
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      block_indentation = tok_string
      break
    elif tok_type not in (
        tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT):
      # A substantive token before any INDENT means the code already starts
      # at column zero.
      block_indentation = ''
      break

  if not block_indentation:
    return code_string

  block_level = len(block_indentation)
  first_indent_uses_tabs = '\t' in block_indentation
  for i, tok in enumerate(tokens):
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      if ((' ' in tok_string and first_indent_uses_tabs)
          or ('\t' in tok_string and not first_indent_uses_tabs)):
        # TODO(mdan): We could attempt to convert tabs to spaces by unix rule.
        # See:
        # https://docs.python.org/3/reference/lexical_analysis.html#indentation
        raise errors.UnsupportedLanguageElementError(
            'code mixing tabs and spaces for indentation is not allowed')
      if len(tok_string) >= block_level:
        tok_string = tok_string[block_level:]
      # untokenize only needs the (type, string) pair of each token.
      tokens[i] = (tok_type, tok_string)

  new_code = tokenize.untokenize(tokens)

  # Note: untokenize respects the line structure, but not the whitespace within
  # lines. For example, `def foo()` may be untokenized as `def foo ()`
  # So instead of using the output of dedent, we match the leading whitespace
  # on each line.
  dedented_code = []
  for line, new_line in zip(code_string.split('\n'), new_code.split('\n')):
    original_indent = re.match(_LEADING_WHITESPACE, line).group()
    new_indent = re.match(_LEADING_WHITESPACE, new_line).group()
    if len(original_indent) > len(new_indent):
      dedented_line = line[len(original_indent) - len(new_indent):]
    else:
      dedented_line = line
    dedented_code.append(dedented_line)
  new_code = '\n'.join(dedented_code)

  return new_code


def parse_entity(entity, future_features):
  """Returns the AST and source code of given entity.

  Args:
    entity: Any, Python function/method/class
    future_features: Iterable[Text], future features to use (e.g.
      'print_statement'). See
      https://docs.python.org/2/reference/simple_stmts.html#future

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).
  """
  if inspect_utils.islambda(entity):
    return _parse_lambda(entity)

  try:
    original_source = inspect_utils.getimmediatesource(entity)
  except OSError as e:
    raise errors.InaccessibleSourceCodeError(
        f'Unable to locate the source code of {entity}. Note that functions'
        ' defined in certain environments, like the interactive Python shell,'
        ' do not expose their source code. If that is the case, you should'
        ' define them in a .py source file. If you are certain the code is'
        ' graph-compatible, wrap the call using'
        f' @tf.autograph.experimental.do_not_convert. Original error: {e}')

  source = dedent_block(original_source)

  future_statements = tuple(
      'from __future__ import {}'.format(name) for name in future_features)
  source = '\n'.join(future_statements + (source,))

  return parse(source, preamble_len=len(future_features)), source


def _without_context(node, lines, minl, maxl):
  """Returns a clean node and source code without indenting and context."""
  # Rebase line numbers so that the node's first line becomes line zero.
  for n in gast.walk(node):
    lineno = getattr(n, 'lineno', None)
    if lineno is not None:
      n.lineno = lineno - minl
    end_lineno = getattr(n, 'end_lineno', None)
    if end_lineno is not None:
      n.end_lineno = end_lineno - minl

  code_lines = lines[minl - 1:maxl]

  # Attempt to clean up surrounding context code.

  end_col_offset = getattr(node, 'end_col_offset', None)
  if end_col_offset is not None:
    # This is only available in 3.8.
    code_lines[-1] = code_lines[-1][:end_col_offset]

  col_offset = getattr(node, 'col_offset', None)
  if col_offset is None:
    # Older Python: try to find the "lambda" token. This is brittle.
    match = re.search(r'(?<!\w)lambda(?!\w)', code_lines[0])
    if match is not None:
      col_offset = match.start(0)

  if col_offset is not None:
    code_lines[0] = code_lines[0][col_offset:]

  code_block = '\n'.join([c.rstrip() for c in code_lines])

  return node, code_block


def _arg_name(node):
  """Returns the name of an argument node, or None if the node is None."""
  if node is None:
    return None
  if isinstance(node, gast.Name):
    return node.id
  assert isinstance(node, str)
  return node


def _node_matches_argspec(node, func):
  """Returns True if node fits the argspec of func."""
  # TODO(mdan): Use just inspect once support for Python 2 is dropped.
  arg_spec = tf_inspect.getfullargspec(func)

  # Positional argument names must match exactly, in order.
  node_args = tuple(_arg_name(arg) for arg in node.args.args)
  if node_args != tuple(arg_spec.args):
    return False

  if arg_spec.varargs != _arg_name(node.args.vararg):
    return False

  if arg_spec.varkw != _arg_name(node.args.kwarg):
    return False

  node_kwonlyargs = tuple(_arg_name(arg) for arg in node.args.kwonlyargs)
  if node_kwonlyargs != tuple(arg_spec.kwonlyargs):
    return False

  return True


def _parse_lambda(lam):
  """Returns the AST and source code of given lambda function.

  Args:
    lam: types.LambdaType, the lambda function whose source to parse.

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).
  """
  # TODO(mdan): Use a fast path if the definition is not multi-line.
  # We could detect that the lambda is in a multi-line expression by looking
  # at the surrounding code - a surrounding set of parentheses indicates a
  # potential multi-line definition.

  mod = inspect.getmodule(lam)
  f = inspect.getsourcefile(lam)
  def_line = lam.__code__.co_firstlineno

  # This method is more robust than just calling inspect.getsource(mod), as it
  # works in interactive shells, where getsource would fail. This is the
  # same procedure followed by inspect for non-modules:
  # https://github.com/python/cpython/blob/3.8/Lib/inspect.py#L772
  lines = linecache.getlines(f, mod.__dict__)
  source = ''.join(lines)

  # Narrow down to the last node starting before our definition node.
  all_nodes = parse(source, preamble_len=0, single_node=False)
  search_nodes = []
  for node in all_nodes:
    # Also include nodes without a line number, for safety. This is defensive -
    # we don't know whether such nodes might exist, and if they do, whether
    # they are not safe to skip.
    # TODO(mdan): Replace this check with an assertion or skip such nodes.
    if getattr(node, 'lineno', def_line) <= def_line:
      search_nodes.append(node)
    else:
      # Found a node starting past our lambda - can stop the search.
      break

  # Extract all lambda nodes from the shortlist.
  lambda_nodes = []
  for node in search_nodes:
    lambda_nodes.extend(
        n for n in gast.walk(node) if isinstance(n, gast.Lambda))

  # Filter down to lambda nodes which span our actual lambda, i.e. whose
  # [min line, max line] range contains the definition line.
  candidates = []
  for ln in lambda_nodes:
    minl, maxl = MAX_SIZE, 0
    for n in gast.walk(ln):
      minl = min(minl, getattr(n, 'lineno', minl))
      lineno = getattr(n, 'lineno', maxl)
      end_lineno = getattr(n, 'end_lineno', None)
      if end_lineno is not None:
        # end_lineno is more precise, but lineno should almost always work too.
        lineno = end_lineno
      maxl = max(maxl, lineno)
    if minl <= def_line <= maxl:
      candidates.append((ln, minl, maxl))

  # Happy path: exactly one node found.
  if len(candidates) == 1:
    (node, minl, maxl), = candidates  # pylint:disable=unbalanced-tuple-unpacking
    return _without_context(node, lines, minl, maxl)

  elif not candidates:
    lambda_codes = '\n'.join([unparse(l) for l in lambda_nodes])
    raise errors.UnsupportedLanguageElementError(
        f'could not parse the source code of {lam}:'
        f' no matching AST found among candidates:\n{lambda_codes}')

  # Attempt to narrow down selection by signature if multiple nodes are found.
  matches = [v for v in candidates if _node_matches_argspec(v[0], lam)]
  if len(matches) == 1:
    (node, minl, maxl), = matches
    return _without_context(node, lines, minl, maxl)

  # Give up if could not narrow down to a single node.
  matches = '\n'.join(
      'Match {}:\n{}\n'.format(i, unparse(node, include_encoding_marker=False))
      for i, (node, _, _) in enumerate(matches))
  raise errors.UnsupportedLanguageElementError(
      f'could not parse the source code of {lam}: found multiple definitions'
      ' with identical signatures at the location. This error'
      ' may be avoided by defining each lambda on a single line and with'
      f' unique argument names. The matching definitions were:\n{matches}')


# TODO(mdan): This should take futures as input instead.
def parse(src, preamble_len=0, single_node=True):
  """Returns the AST of given piece of code.

  Args:
    src: Text
    preamble_len: Int, indicates leading nodes in the parsed AST which should be
      dropped.
    single_node: Bool, whether `src` is assumed to be represented by exactly one
      AST node.

  Returns:
    ast.AST

  Raises:
    ValueError: if `single_node` is True and the code parsed to more or fewer
      than one node (after dropping the preamble).
  """
  module_node = gast.parse(src)
  nodes = module_node.body
  if preamble_len:
    nodes = nodes[preamble_len:]
  if single_node:
    if len(nodes) != 1:
      raise ValueError('expected exactly one node, got {}'.format(nodes))
    return nodes[0]
  return nodes


def parse_expression(src):
  """Returns the AST of a single Python expression.

  Args:
    src: A piece of code that represents a single Python expression
  Returns:
    A gast.AST object.
  Raises:
    ValueError: if src does not consist of a single Expression.
  """
  src = STANDARD_PREAMBLE + src.strip()
  node = parse(src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True)
  if __debug__:
    if not isinstance(node, gast.Expr):
      raise ValueError(
          'expected exactly one node of type Expr, got {}'.format(node))
  return node.value


def unparse(node, indentation=None, include_encoding_marker=True):
  """Returns the source code of given AST.

  Args:
    node: The code to compile, as an AST object, or a list/tuple of AST
      objects.
    indentation: Unused, deprecated. The returning code will always be indented
      at 4 spaces.
    include_encoding_marker: Bool, whether to include a comment on the first
      line to explicitly specify UTF-8 encoding.

  Returns:
    Text: the source code generated from the AST object(s), one per line,
    joined with newlines.
  """
  del indentation  # astunparse doesn't allow configuring it.
  if not isinstance(node, (list, tuple)):
    node = (node,)

  codes = []
  if include_encoding_marker:
    codes.append('# coding=utf-8')
  for n in node:
    # gast nodes must be converted back to native ast before unparsing.
    if isinstance(n, gast.AST):
      ast_n = gast.gast_to_ast(n)
    else:
      ast_n = n

    if astunparse is ast:
      ast.fix_missing_locations(ast_n)  # Only ast needs to call this.
    codes.append(astunparse.unparse(ast_n).strip())

  return '\n'.join(codes)