# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Container for origin source code information before AutoGraph compilation."""
import collections
import difflib
import io
import os
import tokenize

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.util import tf_inspect


class LineLocation(
    collections.namedtuple('LineLocation', ('filename', 'lineno'))):
  """Similar to Location, but without column information.

  Attributes:
    filename: Text
    lineno: int, 1-based
  """
  pass


class Location(
    collections.namedtuple('Location', ('filename', 'lineno', 'col_offset'))):
  """Encodes code location information.

  Attributes:
    filename: Text
    lineno: int, 1-based
    col_offset: int
    line_loc: LineLocation
  """

  @property
  def line_loc(self):
    """Returns this location as a LineLocation, dropping the column offset."""
    return LineLocation(self.filename, self.lineno)


class OriginInfo(
    collections.namedtuple(
        'OriginInfo',
        ('loc', 'function_name', 'source_code_line', 'comment'))):
  """Container for information about the source code before conversion.

  Attributes:
    loc: Location
    function_name: Optional[Text]
    source_code_line: Text
    comment: Optional[Text]
  """

  def as_frame(self):
    """Returns a 4-tuple consistent with the return of traceback.extract_tb."""
    return (self.loc.filename, self.loc.lineno, self.function_name,
            self.source_code_line)

  def __repr__(self):
    """Returns `<file base name>:<lineno>:<col_offset>` for this origin.

    When the location has no file name, `<no file>` is used in its place.
    """
    if self.loc.filename:
      return '{}:{}:{}'.format(
          os.path.split(self.loc.filename)[1], self.loc.lineno,
          self.loc.col_offset)
    return '<no file>:{}:{}'.format(self.loc.lineno, self.loc.col_offset)


# TODO(mdan): This source map should be a class - easier to refer to.
def create_source_map(nodes, code, filepath):
  """Creates a source map between an annotated AST and the code it compiles to.

  Note: this function assumes nodes, code and filepath correspond to the
  same code.

  Args:
    nodes: Iterable[ast.AST, ...], one or more AST nodes.
    code: Text, the source code in which nodes are found.
    filepath: Text

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.

  Raises:
    ValueError: if walking `nodes` in parallel with the AST reparsed from
      `code` fails. The message includes a context diff of the two trees and
      the original error is attached as the cause.
  """
  reparsed_nodes = parser.parse(code, preamble_len=0, single_node=False)
  for node in reparsed_nodes:
    # Annotate the reparsed tree with locations relative to `code` itself.
    resolve(node, code, filepath, node.lineno, node.col_offset)

  source_map = {}

  try:
    for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
      # Note: generated code might not be mapped back to its origin.
      # TODO(mdan): Generated code should always be mapped to something.
      origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
      final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
      if origin_info is None or final_info is None:
        continue

      # Note: the keys are by line only, excluding the column offset.
      line_loc = LineLocation(final_info.loc.filename, final_info.loc.lineno)

      existing_origin = source_map.get(line_loc)
      if existing_origin is not None:
        # Overlaps may exist because of child nodes, but almost never to
        # different line locations. Exception make decorated functions, where
        # both lines are mapped to the same line in the AST.

        # Line overlaps: keep bottom node.
        if existing_origin.loc.line_loc == origin_info.loc.line_loc:
          if existing_origin.loc.lineno >= origin_info.loc.lineno:
            continue

        # In case of column overlaps, keep the leftmost node.
        if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
          continue

      source_map[line_loc] = origin_info

  except ValueError as err:
    # NOTE(review): presumably raised by parallel_walk when the two trees do
    # not have the same structure - confirm against ast_util.
    msg_parts = [
        'Inconsistent ASTs detected. This is a bug. Cause: \n',
        str(err),
        # Keep the cause and the diff header on separate lines.
        '\nDiff:\n',
    ]

    for n, rn in zip(nodes, reparsed_nodes):
      nodes_str = pretty_printer.fmt(n, color=False, noanno=True)
      reparsed_nodes_str = pretty_printer.fmt(rn, color=False, noanno=True)
      diff = difflib.context_diff(
          nodes_str.split('\n'),
          reparsed_nodes_str.split('\n'),
          fromfile='Original nodes',
          tofile='Reparsed nodes',
          n=7)
      msg_parts.append('\n'.join(diff) + '\n')
    raise ValueError(''.join(msg_parts)) from err

  return source_map


class _Function:
  """Entry of the resolver's function stack; holds the function's name."""

  def __init__(self, name):
    self.name = name


class OriginResolver(gast.NodeVisitor):
  """Annotates an AST with additional source information like file name."""

  def __init__(self, root_node, source_lines, comments_map,
               context_lineno, context_col_offset,
               filepath):
    """Initializes the resolver.

    Args:
      root_node: gast.AST, the root of the tree that will be visited.
      source_lines: List[Text], the lines of the source that root_node was
        parsed from, indexed by (lineno - 1).
      comments_map: Dict[int, Text], comment text keyed by line number in
        source_lines.
      context_lineno: int, line number of root_node in its original context.
      context_col_offset: int, column offset of root_node in its original
        context.
      filepath: Text, file name recorded in the generated locations.
    """
    self._source_lines = source_lines
    self._comments_map = comments_map

    if (hasattr(root_node, 'decorator_list') and root_node.decorator_list and
        hasattr(root_node.decorator_list[0], 'lineno')):
      # Typical case: functions. The line number of the first decorator
      # is more accurate than the line number of the function itself in
      # 3.8+. In earlier versions they coincide.
      self._lineno_offset = context_lineno - root_node.decorator_list[0].lineno
    else:
      # Fall back to the line number of the root node.
      self._lineno_offset = context_lineno - root_node.lineno

    self._col_offset = context_col_offset - root_node.col_offset

    self._filepath = filepath

    # Stack of _Function entries for the function definitions being visited.
    self._function_stack = []

  def _absolute_lineno(self, lineno):
    """Translates a line number in the parsed source to the context's."""
    return lineno + self._lineno_offset

  def _absolute_col_offset(self, col_offset):
    """Translates a column offset to the context's; None maps to 0."""
    if col_offset is None:
      return 0
    return col_offset + self._col_offset

  def _attach_origin_info(self, node):
    """Sets the ORIGIN annotation on node, if node has a line number."""
    lineno = getattr(node, 'lineno', None)
    col_offset = getattr(node, 'col_offset', None)

    # Nodes without position information cannot be mapped to a location.
    if lineno is None:
      return

    if self._function_stack:
      function_name = self._function_stack[-1].name
    else:
      function_name = None

    # Line numbers are 1-based, while source_lines is a 0-based list.
    source_code_line = self._source_lines[lineno - 1]
    comment = self._comments_map.get(lineno)

    loc = Location(self._filepath, self._absolute_lineno(lineno),
                   self._absolute_col_offset(col_offset))
    origin = OriginInfo(loc, function_name, source_code_line, comment)
    # The un-offset line number is kept as a separate annotation.
    anno.setanno(node, 'lineno', lineno)
    anno.setanno(node, anno.Basic.ORIGIN, origin)

  def visit(self, node):
    """Annotates node and its descendants, tracking the enclosing function."""
    entered_function = False
    if isinstance(node, gast.FunctionDef):
      entered_function = True
      # Pushed before annotating, so a function definition node is labeled
      # with its own name.
      self._function_stack.append(_Function(node.name))

    self._attach_origin_info(node)
    self.generic_visit(node)

    if entered_function:
      self._function_stack.pop()


def resolve(node, source, context_filepath, context_lineno, context_col_offset):
  """Adds origin information to an AST, based on the source it was loaded from.

  This allows us to map the original source code line numbers to generated
  source code.

  Note: the AST may be a part of a larger context (e.g. a function is part of
  a module that may contain other things). However, this function does not
  assume the source argument contains the entire context, nor that it contains
  only code corresponding to node itself. However, it assumes that node was
  parsed from the given source code.
  For this reason, two extra arguments are required, and they indicate the
  location of the node in the original context.

  Args:
    node: gast.AST, the AST to annotate.
    source: Text, the source code representing node.
    context_filepath: Text
    context_lineno: int
    context_col_offset: int

  Raises:
    tokenize.TokenError: if source cannot be tokenized and node is not a
      gast.Lambda.
  """
  # TODO(mdan): Pull this to a separate utility.
  code_reader = io.StringIO(source)
  comments_map = {}
  try:
    for token in tokenize.generate_tokens(code_reader.readline):
      tok_type, tok_string, loc, _, _ = token
      srow, _ = loc
      if tok_type == tokenize.COMMENT:
        # Store the comment text without the leading '#' and extra whitespace.
        comments_map[srow] = tok_string.strip()[1:].strip()
  except tokenize.TokenError:
    if isinstance(node, gast.Lambda):
      # Source code resolution in older Python versions is brittle for
      # lambda functions, and may contain garbage.
      # Comments collected before the error are kept.
      pass
    else:
      raise

  source_lines = source.split('\n')
  visitor = OriginResolver(node, source_lines, comments_map,
                           context_lineno, context_col_offset,
                           context_filepath)
  visitor.visit(node)


def resolve_entity(node, source, entity):
  """Like resolve, but extracts the context information from an entity.

  Args:
    node: gast.AST, the AST to annotate.
    source: Text, the source code representing node.
    entity: the object whose source lines and source file, as reported by
      tf_inspect, supply the context file path, line number and column offset.
  """
  lines, lineno = tf_inspect.getsourcelines(entity)
  filepath = tf_inspect.getsourcefile(entity)

  # Poor man's attempt at guessing the column offset: count the leading
  # whitespace. This might not work well with tabs.
  definition_line = lines[0]
  col_offset = len(definition_line) - len(definition_line.lstrip())

  resolve(node, source, filepath, lineno, col_offset)


def copy_origin(from_node, to_node):
  """Copies the origin info from a node to another, recursively.

  Args:
    from_node: the node whose ORIGIN annotation is read. If it has none,
      nothing is copied.
    to_node: a node, or a list or tuple of nodes. Each of them, together with
      every node reachable from it through gast.walk, receives the annotation.
  """
  origin = anno.Basic.ORIGIN.of(from_node, default=None)
  if origin is None:
    return
  if not isinstance(to_node, (list, tuple)):
    to_node = (to_node,)
  for node in to_node:
    for n in gast.walk(node):
      anno.setanno(n, anno.Basic.ORIGIN, origin)