# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Upgrader for Python scripts according to an API change specification."""

import ast
import collections
import os
import re
import shutil
import sys
import tempfile
import traceback

import pasta


# Some regular expressions we will need for parsing
FIND_OPEN = re.compile(r"^\s*(\[).*$")
FIND_STRING_CHARS = re.compile(r"['\"]")


INFO = "INFO"
WARNING = "WARNING"
ERROR = "ERROR"


ImportRename = collections.namedtuple(
    "ImportRename", ["new_name", "excluded_prefixes"])


def full_name_node(name, ctx=ast.Load()):
  """Make an Attribute or Name node for name.

  Translate a qualified name into nested Attribute nodes (and a Name node).

  Args:
    name: The name to translate to a node.
    ctx: What context this name is used in. Defaults to Load().

  Returns:
    A Name or Attribute node.
  """
  names = name.split(".")
  names.reverse()
  node = ast.Name(id=names.pop(), ctx=ast.Load())
  while names:
    node = ast.Attribute(value=node, attr=names.pop(), ctx=ast.Load())

  # Change outermost ctx to the one given to us (inner ones should be Load).
  node.ctx = ctx
  return node


def get_arg_value(node, arg_name, arg_pos=None):
  """Get the value of an argument from an ast.Call node.

  This function goes through the positional and keyword arguments to check
  whether a given argument was used, and if so, returns its value (the node
  representing its value).

  This cannot introspect *args or **kwargs, but it safely handles *args in
  Python 3.5+.

  Args:
    node: The ast.Call node to extract arg values from.
    arg_name: The name of the argument to extract.
    arg_pos: The position of the argument (in case it's passed as a positional
      argument).

  Returns:
    A tuple (arg_present, arg_value) containing a boolean indicating whether
    the argument is present, and its value in case it is.
  """
  # Check keyword args
  if arg_name is not None:
    for kw in node.keywords:
      if kw.arg == arg_name:
        return (True, kw.value)

  # Check positional args
  if arg_pos is not None:
    idx = 0
    for arg in node.args:
      if sys.version_info[:2] >= (3, 5) and isinstance(arg, ast.Starred):
        continue  # Can't parse Starred
      if idx == arg_pos:
        return (True, arg)
      idx += 1

  return (False, None)


def uses_star_args_in_call(node):
  """Check if an ast.Call node uses arbitrary-length positional *args.

  This function works with the AST call node format of Python 3.5+
  as well as the different AST format of earlier versions of Python.

  Args:
    node: The ast.Call node to check arg values for.

  Returns:
    True if the node uses starred variadic positional args.
    False if it does not.
  """
  if sys.version_info[:2] >= (3, 5):
    # Check for an *args usage in python 3.5+
    for arg in node.args:
      if isinstance(arg, ast.Starred):
        return True
  else:
    if node.starargs:
      return True
  return False


def uses_star_kwargs_in_call(node):
  """Check if an ast.Call node uses arbitrary-length **kwargs.

  This function works with the AST call node format of Python 3.5+
  as well as the different AST format of earlier versions of Python.

  Args:
    node: The ast.Call node to check arg values for.

  Returns:
    True if the node uses starred variadic keyword args.
    False if it does not.
  """
  if sys.version_info[:2] >= (3, 5):
    # Check for a **kwarg usage in python 3.5+
    for keyword in node.keywords:
      if keyword.arg is None:
        return True
  else:
    if node.kwargs:
      return True
  return False


def uses_star_args_or_kwargs_in_call(node):
  """Check if an ast.Call node uses arbitrary-length *args or **kwargs.

  This function works with the AST call node format of Python 3.5+
  as well as the different AST format of earlier versions of Python.

  Args:
    node: The ast.Call node to check arg values for.

  Returns:
    True if the node uses starred variadic positional args or keyword args.
    False if it does not.
  """
  return uses_star_args_in_call(node) or uses_star_kwargs_in_call(node)


def excluded_from_module_rename(module, import_rename_spec):
  """Check if this module import should not be renamed.

  Args:
    module: (string) module name.
    import_rename_spec: ImportRename instance.

  Returns:
    True if this import should not be renamed according to the
    import_rename_spec.
  """
  for excluded_prefix in import_rename_spec.excluded_prefixes:
    if module.startswith(excluded_prefix):
      return True
  return False


class APIChangeSpec:
  """This class defines the transformations that need to happen.

  This class must provide the following fields:

  * `function_keyword_renames`: maps function names to a map of old -> new
    argument names
  * `symbol_renames`: maps function names to new function names
  * `change_to_function`: a set of function names that have changed (for
    notifications)
  * `function_reorders`: maps functions whose argument order has changed to the
    list of arguments in the new order
  * `function_warnings`: maps full names of functions to warnings that will be
    printed out if the function is used. (e.g. tf.nn.convolution())
  * `function_transformers`: maps function names to custom handlers
  * `module_deprecations`: maps module names to warnings that will be printed
    if the module is still used after all other transformations have run
  * `import_renames`: maps import name (must be a short name without '.')
    to ImportRename instance.

  For an example, see `TFAPIChangeSpec`.
  """

  def preprocess(self, root_node):  # pylint: disable=unused-argument
    """Preprocess a parse tree. Return a preprocessed node, logs and errors."""
    return root_node, [], []

  def clear_preprocessing(self):
    """Restore this APIChangeSpec to before it preprocessed a file.

    This is needed if preprocessing a file changed any rewriting rules.
    """
    pass


class NoUpdateSpec(APIChangeSpec):
  """A specification of an API change which doesn't change anything."""

  def __init__(self):
    self.function_handle = {}
    self.function_reorders = {}
    self.function_keyword_renames = {}
    self.symbol_renames = {}
    self.function_warnings = {}
    self.change_to_function = {}
    self.module_deprecations = {}
    self.function_transformers = {}
    self.import_renames = {}


class _PastaEditVisitor(ast.NodeVisitor):
  """AST Visitor that processes function calls.

  Updates function calls from old API version to new API version using a given
  change spec.
  """

  def __init__(self, api_change_spec):
    self._api_change_spec = api_change_spec
    self._log = []  # Holds 4-tuples: severity, line, col, msg.
    self._stack = []  # Allow easy access to parents.

  # Overridden to maintain a stack of nodes to allow for parent access
  def visit(self, node):
    self._stack.append(node)
    super(_PastaEditVisitor, self).visit(node)
    self._stack.pop()

  @property
  def errors(self):
    return [log for log in self._log if log[0] == ERROR]

  @property
  def warnings(self):
    return [log for log in self._log if log[0] == WARNING]

  @property
  def warnings_and_errors(self):
    return [log for log in self._log if log[0] in (WARNING, ERROR)]

  @property
  def info(self):
    return [log for log in self._log if log[0] == INFO]

  @property
  def log(self):
    return self._log

  def add_log(self, severity, lineno, col, msg):
    self._log.append((severity, lineno, col, msg))
    print("%s line %d:%d: %s" % (severity, lineno, col, msg))

  def add_logs(self, logs):
    """Record logs and print them.

    Each log should be a tuple `(severity, lineno, col_offset, msg)`, which
    will be printed and recorded. They become part of the log available in
    the `self.log` property.

    Args:
      logs: The logs to add. Must be a list of tuples
        `(severity, lineno, col_offset, msg)`.
    """
    self._log.extend(logs)
    for log in logs:
      print("%s line %d:%d: %s" % log)

  def _get_applicable_entries(self, transformer_field, full_name, name):
    """Get all list entries indexed by name that apply to full_name or name."""
    # Transformers are indexed to full name, name, or no name
    # as a performance optimization.
    function_transformers = getattr(self._api_change_spec,
                                    transformer_field, {})

    glob_name = "*." + name if name else None
    transformers = []
    if full_name in function_transformers:
      transformers.append(function_transformers[full_name])
    if glob_name in function_transformers:
      transformers.append(function_transformers[glob_name])
    if "*" in function_transformers:
      transformers.append(function_transformers["*"])
    return transformers

  def _get_applicable_dict(self, transformer_field, full_name, name):
    """Get all dict entries indexed by name that apply to full_name or name."""
    # Transformers are indexed to full name, name, or no name
    # as a performance optimization.
    function_transformers = getattr(self._api_change_spec,
                                    transformer_field, {})

    glob_name = "*." + name if name else None
    transformers = function_transformers.get("*", {}).copy()
    transformers.update(function_transformers.get(glob_name, {}))
    transformers.update(function_transformers.get(full_name, {}))
    return transformers

  def _get_full_name(self, node):
    """Traverse an Attribute node to generate a full name, e.g., "tf.foo.bar".

    This is the inverse of `full_name_node`.

    Args:
      node: A Node of type Attribute.

    Returns:
      A '.'-delimited full name, or None if node was not an Attribute or Name.
      i.e. `(foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
    """
    curr = node
    items = []
    while not isinstance(curr, ast.Name):
      if not isinstance(curr, ast.Attribute):
        return None
      items.append(curr.attr)
      curr = curr.value
    items.append(curr.id)
    return ".".join(reversed(items))

  def _maybe_add_warning(self, node, full_name):
    """Adds a warning to be printed about full_name at node."""
    function_warnings = self._api_change_spec.function_warnings
    if full_name in function_warnings:
      level, message = function_warnings[full_name]
      message = message.replace("<function name>", full_name)
      self.add_log(level, node.lineno, node.col_offset,
                   "%s requires manual check. %s" % (full_name, message))
      return True
    else:
      return False

  def _maybe_add_module_deprecation_warning(self, node, full_name, whole_name):
    """Adds a warning if full_name is a deprecated module."""
    warnings = self._api_change_spec.module_deprecations
    if full_name in warnings:
      level, message = warnings[full_name]
      message = message.replace("<function name>", whole_name)
      self.add_log(level, node.lineno, node.col_offset,
                   "Using member %s in deprecated module %s. %s" % (whole_name,
                                                                    full_name,
                                                                    message))
      return True
    else:
      return False

  def _maybe_add_call_warning(self, node, full_name, name):
    """Print a warning when specific functions are called with selected args.

    The method `_maybe_add_warning` matches the full name of the called
    function, e.g., tf.foo.bar(). This function matches the function name that
    is called, as long as the function is an attribute. For example,
    `tf.foo.bar()` and `foo.bar()` are matched, but not `bar()`.

    Args:
      node: ast.Call object
      full_name: The precomputed full name of the callable, if one exists, None
        otherwise.
      name: The precomputed name of the callable, if one exists, None
        otherwise.

    Returns:
      Whether an error was recorded.
    """
    # Only look for *.-warnings here; the others are handled by the Attribute
    # visitor. Also, do not warn for bare functions, only if the call func is
    # an attribute.
    warned = False
    if isinstance(node.func, ast.Attribute):
      warned = self._maybe_add_warning(node, "*." + name)

    # All arg warnings are handled here, since only we have the args
    arg_warnings = self._get_applicable_dict("function_arg_warnings",
                                             full_name, name)

    variadic_args = uses_star_args_or_kwargs_in_call(node)

    for (kwarg, arg), (level, warning) in sorted(arg_warnings.items()):
      present, _ = get_arg_value(node, kwarg, arg)
      if present or variadic_args:
        warned = True
        warning_message = warning.replace("<function name>", full_name or name)
        template = "%s called with %s argument, requires manual check: %s"
        if variadic_args:
          template = ("%s called with *args or **kwargs that may include %s, "
                      "requires manual check: %s")
        self.add_log(level, node.lineno, node.col_offset,
                     template % (full_name or name, kwarg, warning_message))

    return warned

  def _maybe_rename(self, parent, node, full_name):
    """Replace node (Attribute or Name) with a node representing full_name."""
    new_name = self._api_change_spec.symbol_renames.get(full_name, None)
    if new_name:
      self.add_log(INFO, node.lineno, node.col_offset,
                   "Renamed %r to %r" % (full_name, new_name))
      new_node = full_name_node(new_name, node.ctx)
      ast.copy_location(new_node, node)
      pasta.ast_utils.replace_child(parent, node, new_node)
      return True
    else:
      return False

  def _maybe_change_to_function_call(self, parent, node, full_name):
    """Wraps node (typically, an Attribute or Expr) in a Call."""
    if full_name in self._api_change_spec.change_to_function:
      if not isinstance(parent, ast.Call):
        # ast.Call's constructor is really picky about how many arguments it
        # wants, and also, it changed between Py2 and Py3.
        new_node = ast.Call(node, [], [])
        pasta.ast_utils.replace_child(parent, node, new_node)
        ast.copy_location(new_node, node)
        self.add_log(INFO, node.lineno, node.col_offset,
                     "Changed %r to a function call" % full_name)
        return True
    return False

  def _maybe_add_arg_names(self, node, full_name):
    """Make args into keyword args if function called full_name requires it."""
    function_reorders = self._api_change_spec.function_reorders

    if full_name in function_reorders:
      if uses_star_args_in_call(node):
        self.add_log(WARNING, node.lineno, node.col_offset,
                     "(Manual check required) upgrading %s may require "
                     "re-ordering the call arguments, but it was passed "
                     "variable-length positional *args. The upgrade "
                     "script cannot handle these automatically." % full_name)

      reordered = function_reorders[full_name]
      new_keywords = []
      idx = 0
      for arg in node.args:
        if sys.version_info[:2] >= (3, 5) and isinstance(arg, ast.Starred):
          continue  # Can't move Starred to keywords
        keyword_arg = reordered[idx]
        keyword = ast.keyword(arg=keyword_arg, value=arg)
        new_keywords.append(keyword)
        idx += 1

      if new_keywords:
        self.add_log(INFO, node.lineno, node.col_offset,
                     "Added keywords to args of function %r" % full_name)
        node.args = []
        node.keywords = new_keywords + (node.keywords or [])
        return True
    return False

  def _maybe_modify_args(self, node, full_name, name):
    """Rename keyword args if the function called full_name requires it."""
    renamed_keywords = self._get_applicable_dict("function_keyword_renames",
                                                 full_name, name)

    if not renamed_keywords:
      return False

    if uses_star_kwargs_in_call(node):
      self.add_log(WARNING, node.lineno, node.col_offset,
                   "(Manual check required) upgrading %s may require "
                   "renaming or removing call arguments, but it was passed "
                   "variable-length *args or **kwargs. The upgrade "
                   "script cannot handle these automatically." %
                   (full_name or name))
    modified = False
    new_keywords = []
    for keyword in node.keywords:
      argkey = keyword.arg
      if argkey in renamed_keywords:
        modified = True
        if renamed_keywords[argkey] is None:
          lineno = getattr(keyword, "lineno", node.lineno)
          col_offset = getattr(keyword, "col_offset", node.col_offset)
          self.add_log(INFO, lineno, col_offset,
                       "Removed argument %s for function %s" % (
                           argkey, full_name or name))
        else:
          keyword.arg = renamed_keywords[argkey]
          lineno = getattr(keyword, "lineno", node.lineno)
          col_offset = getattr(keyword, "col_offset", node.col_offset)
          self.add_log(INFO, lineno, col_offset,
                       "Renamed keyword argument for %s from %s to %s" % (
                           full_name, argkey, renamed_keywords[argkey]))
          new_keywords.append(keyword)
      else:
        new_keywords.append(keyword)

    if modified:
      node.keywords = new_keywords
    return modified

  def visit_Call(self, node):  # pylint: disable=invalid-name
    """Handle visiting a call node in the AST.

    Args:
      node: Current Node
    """
    assert self._stack[-1] is node

    # Get the name for this call, so we can index stuff with it.
    full_name = self._get_full_name(node.func)
    if full_name:
      name = full_name.split(".")[-1]
    elif isinstance(node.func, ast.Name):
      name = node.func.id
    elif isinstance(node.func, ast.Attribute):
      name = node.func.attr
    else:
      name = None

    # Call standard transformers for this node.
    # Make sure warnings come first, since args or names triggering warnings
    # may be removed by the other transformations.
    self._maybe_add_call_warning(node, full_name, name)
    # Make all args into kwargs
    self._maybe_add_arg_names(node, full_name)
    # Argument name changes or deletions
    self._maybe_modify_args(node, full_name, name)

    # Call transformers. These have the ability to modify the node, and if they
    # do, will return the new node they created (or the same node if they just
    # changed it). They are given the parent, but we will take care of
    # integrating their changes into the parent if they return a new node.
    #
    # These are matched on the old name, since renaming is performed by the
    # Attribute visitor, which happens later.
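    # For illustration only (inferred from the call below, not a separate API):
    # a transformer registered in `function_transformers` is a callable of the
    # form
    #
    #   def my_transformer(parent, node, full_name, name, logs):  # hypothetical name
    #     logs.append((INFO, node.lineno, node.col_offset, "what changed"))
    #     return node  # or a replacement node, or None to leave `node` as-is
    #
    # where `my_transformer` is a hypothetical example, not part of this module.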
    transformers = self._get_applicable_entries("function_transformers",
                                                full_name, name)

    parent = self._stack[-2]

    if transformers:
      if uses_star_args_or_kwargs_in_call(node):
        self.add_log(WARNING, node.lineno, node.col_offset,
                     "(Manual check required) upgrading %s may require "
                     "modifying call arguments, but it was passed "
                     "variable-length *args or **kwargs. The upgrade "
                     "script cannot handle these automatically." %
                     (full_name or name))

    for transformer in transformers:
      logs = []
      new_node = transformer(parent, node, full_name, name, logs)
      self.add_logs(logs)
      if new_node and new_node is not node:
        pasta.ast_utils.replace_child(parent, node, new_node)
        node = new_node
        self._stack[-1] = node

    self.generic_visit(node)

  def visit_Attribute(self, node):  # pylint: disable=invalid-name
    """Handle bare Attributes i.e. [tf.foo, tf.bar]."""
    assert self._stack[-1] is node

    full_name = self._get_full_name(node)
    if full_name:
      parent = self._stack[-2]

      # Make sure the warning comes first, otherwise the name may have changed
      self._maybe_add_warning(node, full_name)

      # Once we did a modification, node is invalid and not worth inspecting
      # further. Also, we only perform modifications for simple nodes, so
      # there'd be no point in descending further.
      if self._maybe_rename(parent, node, full_name):
        return
      if self._maybe_change_to_function_call(parent, node, full_name):
        return

      # The isinstance check is enough -- a bare Attribute is never root.
      i = 2
      while isinstance(self._stack[-i], ast.Attribute):
        i += 1
      whole_name = pasta.dump(self._stack[-(i-1)])

      self._maybe_add_module_deprecation_warning(node, full_name, whole_name)

    self.generic_visit(node)

  def visit_Import(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import node in the AST.

    Args:
      node: Current Node
    """
    new_aliases = []
    import_updated = False
    import_renames = getattr(self._api_change_spec, "import_renames", {})
    max_submodule_depth = getattr(self._api_change_spec, "max_submodule_depth",
                                  1)
    inserts_after_imports = getattr(self._api_change_spec,
                                    "inserts_after_imports", {})

    # This loop processes imports in the format
    # import foo as f, bar as b
    for import_alias in node.names:
      all_import_components = import_alias.name.split(".")
      # Look for rename, starting with longest import levels.
      found_update = False
      for i in reversed(list(range(1, max_submodule_depth + 1))):
        import_component = all_import_components[0]
        for j in range(1, min(i, len(all_import_components))):
          import_component += "." + all_import_components[j]
        import_rename_spec = import_renames.get(import_component, None)

        if not import_rename_spec or excluded_from_module_rename(
            import_alias.name, import_rename_spec):
          continue

        new_name = (
            import_rename_spec.new_name +
            import_alias.name[len(import_component):])

        # If the current import is
        #   import foo
        # then the new import should preserve the imported name:
        #   import new_foo as foo
        # This happens when the module has just one component.
        new_asname = import_alias.asname
        if not new_asname and "." not in import_alias.name:
          new_asname = import_alias.name

        new_alias = ast.alias(name=new_name, asname=new_asname)
        new_aliases.append(new_alias)
        import_updated = True
        found_update = True

        # Insert any followup lines that should happen after this import.
        full_import = (import_alias.name, import_alias.asname)
        insert_offset = 1
        for line_to_insert in inserts_after_imports.get(full_import, []):
          assert self._stack[-1] is node
          parent = self._stack[-2]

          new_line_node = pasta.parse(line_to_insert)
          ast.copy_location(new_line_node, node)
          parent.body.insert(
              parent.body.index(node) + insert_offset, new_line_node)
          insert_offset += 1

          # Insert a newline after the import if necessary
          old_suffix = pasta.base.formatting.get(node, "suffix")
          if old_suffix is None:
            old_suffix = os.linesep
          if os.linesep not in old_suffix:
            pasta.base.formatting.set(node, "suffix", old_suffix + os.linesep)

          # Apply indentation to new node.
          pasta.base.formatting.set(new_line_node, "prefix",
                                    pasta.base.formatting.get(node, "prefix"))
          pasta.base.formatting.set(new_line_node, "suffix", os.linesep)
          self.add_log(
              INFO, node.lineno, node.col_offset,
              "Adding `%s` after import of %s" %
              (new_line_node, import_alias.name))
        # Find one match, break
        if found_update:
          break
      # No rename is found for all levels
      if not found_update:
        new_aliases.append(import_alias)  # no change needed

    # Replace the node if at least one import needs to be updated.
    if import_updated:
      assert self._stack[-1] is node
      parent = self._stack[-2]

      new_node = ast.Import(new_aliases)
      ast.copy_location(new_node, node)
      pasta.ast_utils.replace_child(parent, node, new_node)
      self.add_log(
          INFO, node.lineno, node.col_offset,
          "Changed import from %r to %r." %
          (pasta.dump(node), pasta.dump(new_node)))

    self.generic_visit(node)

  def visit_ImportFrom(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import-from node in the AST.

    Args:
      node: Current Node
    """
    if not node.module:
      self.generic_visit(node)
      return

    from_import = node.module

    # Look for rename based on first component of from-import.
    # i.e. based on foo in foo.bar.
    from_import_first_component = from_import.split(".")[0]
    import_renames = getattr(self._api_change_spec, "import_renames", {})
    import_rename_spec = import_renames.get(from_import_first_component, None)
    if not import_rename_spec:
      self.generic_visit(node)
      return

    # Split module aliases into the ones that require import update
    # and those that don't. For e.g. if we want to rename "a" to "b"
    # unless we import "a.c" in the following:
    #   from a import c, d
    # we want to update import for "d" but not for "c".
    updated_aliases = []
    same_aliases = []
    for import_alias in node.names:
      full_module_name = "%s.%s" % (from_import, import_alias.name)
      if excluded_from_module_rename(full_module_name, import_rename_spec):
        same_aliases.append(import_alias)
      else:
        updated_aliases.append(import_alias)

    if not updated_aliases:
      self.generic_visit(node)
      return

    assert self._stack[-1] is node
    parent = self._stack[-2]

    # Replace first component of from-import with new name.
    new_from_import = (
        import_rename_spec.new_name +
        from_import[len(from_import_first_component):])
    updated_node = ast.ImportFrom(new_from_import, updated_aliases, node.level)
    ast.copy_location(updated_node, node)
    pasta.ast_utils.replace_child(parent, node, updated_node)

    # If some imports had to stay the same, add another import for them.
    additional_import_log = ""
    if same_aliases:
      same_node = ast.ImportFrom(from_import, same_aliases, node.level,
                                 col_offset=node.col_offset, lineno=node.lineno)
      ast.copy_location(same_node, node)
      parent.body.insert(parent.body.index(updated_node), same_node)
      # Apply indentation to new node.
      pasta.base.formatting.set(
          same_node, "prefix",
          pasta.base.formatting.get(updated_node, "prefix"))
      additional_import_log = " and %r" % pasta.dump(same_node)

    self.add_log(
        INFO, node.lineno, node.col_offset,
        "Changed import from %r to %r%s." %
        (pasta.dump(node),
         pasta.dump(updated_node),
         additional_import_log))

    self.generic_visit(node)


class AnalysisResult:
  """This class represents an analysis result and how it should be logged.

  This class must provide the following fields:

  * `log_level`: The log level to which this detection should be logged
  * `log_message`: The message that should be logged for this detection

  For an example, see `VersionedTFImport`.
  """


class APIAnalysisSpec:
  """This class defines how `AnalysisResult`s should be generated.

  It specifies how to map imports and symbols to `AnalysisResult`s.

  This class must provide the following fields:

  * `symbols_to_detect`: maps function names to `AnalysisResult`s
  * `imports_to_detect`: maps imports represented as (full module name, alias)
    tuples to `AnalysisResult`s

  For an example, see `TFAPIImportAnalysisSpec`.
  """


class PastaAnalyzeVisitor(_PastaEditVisitor):
  """AST Visitor that looks for specific API usage without editing anything.

  This is used before any rewriting is done to detect if any symbols are used
  that require changing imports or disabling rewriting altogether.
  """

  def __init__(self, api_analysis_spec):
    super(PastaAnalyzeVisitor, self).__init__(NoUpdateSpec())
    self._api_analysis_spec = api_analysis_spec
    self._results = []  # Holds AnalysisResult objects

  @property
  def results(self):
    return self._results

  def add_result(self, analysis_result):
    self._results.append(analysis_result)

  def visit_Attribute(self, node):  # pylint: disable=invalid-name
    """Handle bare Attributes i.e. [tf.foo, tf.bar]."""
    full_name = self._get_full_name(node)
    if full_name:
      detection = self._api_analysis_spec.symbols_to_detect.get(full_name, None)
      if detection:
        self.add_result(detection)
        self.add_log(
            detection.log_level, node.lineno, node.col_offset,
            detection.log_message)

    self.generic_visit(node)

  def visit_Import(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import node in the AST.

    Args:
      node: Current Node
    """
    for import_alias in node.names:
      # Detect based on full import name and alias.
      full_import = (import_alias.name, import_alias.asname)
      detection = (self._api_analysis_spec
                   .imports_to_detect.get(full_import, None))
      if detection:
        self.add_result(detection)
        self.add_log(
            detection.log_level, node.lineno, node.col_offset,
            detection.log_message)

    self.generic_visit(node)

  def visit_ImportFrom(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import-from node in the AST.

    Args:
      node: Current Node
    """
    if not node.module:
      self.generic_visit(node)
      return

    from_import = node.module

    for import_alias in node.names:
      # Detect based on the full imported module name and alias.
      full_module_name = "%s.%s" % (from_import, import_alias.name)
      full_import = (full_module_name, import_alias.asname)
      detection = (self._api_analysis_spec
                   .imports_to_detect.get(full_import, None))
      if detection:
        self.add_result(detection)
        self.add_log(
            detection.log_level, node.lineno, node.col_offset,
            detection.log_message)

    self.generic_visit(node)


class ASTCodeUpgrader:
  """Handles upgrading a set of Python files using a given API change spec."""

  def __init__(self, api_change_spec):
    if not isinstance(api_change_spec, APIChangeSpec):
      raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" %
                      type(api_change_spec))
    self._api_change_spec = api_change_spec

  def process_file(self,
                   in_filename,
                   out_filename,
                   no_change_to_outfile_on_error=False):
    """Process the given python file for incompatible changes.

    Args:
      in_filename: filename to parse
      out_filename: output file to write to
      no_change_to_outfile_on_error: whether to leave the output file
        unmodified when errors occur
    Returns:
      A tuple representing number of files processed, log of actions, errors
    """

    # Write to a temporary file, just in case we are doing an in-place modify.
    # pylint: disable=g-backslash-continuation
    with open(in_filename, "r") as in_file, \
        tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
      ret = self.process_opened_file(in_filename, in_file, out_filename,
                                     temp_file)
    # pylint: enable=g-backslash-continuation

    if no_change_to_outfile_on_error and ret[0] == 0:
      os.remove(temp_file.name)
    else:
      shutil.move(temp_file.name, out_filename)
    return ret

  def format_log(self, log, in_filename):
    log_string = "%d:%d: %s: %s" % (log[1], log[2], log[0], log[3])
    if in_filename:
      return in_filename + ":" + log_string
    else:
      return log_string

  def update_string_pasta(self, text, in_filename):
    """Updates a file using pasta."""
    try:
      t = pasta.parse(text)
    except (SyntaxError, ValueError, TypeError):
      log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
      return 0, "", log, []

    t, preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)

    visitor = _PastaEditVisitor(self._api_change_spec)
    visitor.visit(t)

    self._api_change_spec.clear_preprocessing()

    logs = [self.format_log(log, None) for log in (preprocess_logs +
                                                   visitor.log)]
    errors = [self.format_log(error, in_filename)
              for error in (preprocess_errors +
                            visitor.warnings_and_errors)]
    return 1, pasta.dump(t), logs, errors

  def _format_log(self, log, in_filename, out_filename):
    text = "-" * 80 + "\n"
    text += "Processing file %r\n outputting to %r\n" % (in_filename,
                                                         out_filename)
    text += "-" * 80 + "\n\n"
    text += "\n".join(log) + "\n"
    text += "-" * 80 + "\n\n"
    return text

  def process_opened_file(self, in_filename, in_file, out_filename, out_file):
    """Process the given python file for incompatible changes.

    This function is split out to facilitate StringIO testing from
    tf_upgrade_test.py.

    Args:
      in_filename: filename to parse
      in_file: opened file (or StringIO)
      out_filename: output file to write to
      out_file: opened file (or StringIO)
    Returns:
      A tuple representing number of files processed, log of actions, errors
    """
    lines = in_file.readlines()
    processed_file, new_file_content, log, process_errors = (
        self.update_string_pasta("".join(lines), in_filename))

    if out_file and processed_file:
      out_file.write(new_file_content)

    return (processed_file,
            self._format_log(log, in_filename, out_filename),
            process_errors)

  def process_tree(self, root_directory, output_root_directory,
                   copy_other_files):
    """Processes upgrades on an entire tree of python files in place.

    Note that only Python files are processed. If you have custom code in
    other languages, you will need to manually upgrade those.

    Args:
      root_directory: Directory to walk and process.
      output_root_directory: Directory to use as base.
      copy_other_files: Copy files that are not touched by this converter.

    Returns:
      A tuple of files processed, the report string for all files, and a dict
        mapping filenames to errors encountered in that file.
    """

    if output_root_directory == root_directory:
      return self.process_tree_inplace(root_directory)

    # make sure output directory doesn't exist
    if output_root_directory and os.path.exists(output_root_directory):
      print("Output directory %r must not already exist." %
            (output_root_directory))
      sys.exit(1)

    # make sure output directory does not overlap with root_directory
    norm_root = os.path.split(os.path.normpath(root_directory))
    norm_output = os.path.split(os.path.normpath(output_root_directory))
    if norm_root == norm_output:
      print("Output directory %r same as input directory %r" %
            (output_root_directory, root_directory))
      sys.exit(1)

    # Collect list of files to process (we do this to correctly handle if the
    # user puts the output directory in some sub directory of the input dir)
    files_to_process = []
    files_to_copy = []
    for dir_name, _, file_list in os.walk(root_directory):
      py_files = [f for f in file_list if f.endswith(".py")]
      copy_files = [f for f in file_list if not f.endswith(".py")]
      for filename in py_files:
        fullpath = os.path.join(dir_name, filename)
        fullpath_output = os.path.join(output_root_directory,
                                       os.path.relpath(fullpath,
                                                       root_directory))
        files_to_process.append((fullpath, fullpath_output))
      if copy_other_files:
        for filename in copy_files:
          fullpath = os.path.join(dir_name, filename)
          fullpath_output = os.path.join(output_root_directory,
                                         os.path.relpath(
                                             fullpath, root_directory))
          files_to_copy.append((fullpath, fullpath_output))

    file_count = 0
    tree_errors = {}
    report = ""
    report += ("=" * 80) + "\n"
    report += "Input tree: %r\n" % root_directory
    report += ("=" * 80) + "\n"

    for input_path, output_path in files_to_process:
      output_directory = os.path.dirname(output_path)
      if not os.path.isdir(output_directory):
        os.makedirs(output_directory)

      if os.path.islink(input_path):
        link_target = os.readlink(input_path)
        link_target_output = os.path.join(
            output_root_directory, os.path.relpath(link_target, root_directory))
        if (link_target, link_target_output) in files_to_process:
          # Create a link to the new location of the target file
          os.symlink(link_target_output, output_path)
        else:
          report += "Copying symlink %s without modifying its target %s\n" % (
              input_path, link_target)
          os.symlink(link_target, output_path)
        continue

      file_count += 1
      _, l_report, l_errors = self.process_file(input_path, output_path)
      tree_errors[input_path] = l_errors
      report += l_report

    for input_path, output_path in files_to_copy:
      output_directory = os.path.dirname(output_path)
      if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
      shutil.copy(input_path, output_path)
    return file_count, report, tree_errors

  def process_tree_inplace(self, root_directory):
    """Process a directory of python files in place."""
    files_to_process = []
    for dir_name, _, file_list in os.walk(root_directory):
      py_files = [
          os.path.join(dir_name, f) for f in file_list if f.endswith(".py")
      ]
      files_to_process += py_files

    file_count = 0
    tree_errors = {}
    report = ""
    report += ("=" * 80) + "\n"
    report += "Input tree: %r\n" % root_directory
    report += ("=" * 80) + "\n"

    for path in files_to_process:
      if os.path.islink(path):
        report += "Skipping symlink %s.\n" % path
        continue
      file_count += 1
      _, l_report, l_errors = self.process_file(path, path)
      tree_errors[path] = l_errors
      report += l_report

    return file_count, report, tree_errors
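

# Illustrative usage sketch, not part of the upgrader library itself: it wires
# the do-nothing NoUpdateSpec into ASTCodeUpgrader and runs it over a single
# file given on the command line. Real drivers pass their own APIChangeSpec
# subclass instead of NoUpdateSpec.
if __name__ == "__main__":
  if len(sys.argv) == 3:
    _upgrader = ASTCodeUpgrader(NoUpdateSpec())
    # process_file returns (files_processed, report_string, errors).
    _, _report, _errors = _upgrader.process_file(sys.argv[1], sys.argv[2])
    print(_report)
  else:
    print("usage: %s <input.py> <output.py>" % sys.argv[0])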