xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/compatibility/ast_edits.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Upgrader for Python scripts according to an API change specification."""

import ast
import collections
import os
import re
import shutil
import sys
import tempfile
import traceback

import pasta


# Some regular expressions we will need for parsing
FIND_OPEN = re.compile(r"^\s*(\[).*$")
FIND_STRING_CHARS = re.compile(r"['\"]")


INFO = "INFO"
WARNING = "WARNING"
ERROR = "ERROR"


ImportRename = collections.namedtuple(
    "ImportRename", ["new_name", "excluded_prefixes"])


def full_name_node(name, ctx=ast.Load()):
  """Make an Attribute or Name node for name.

  Translate a qualified name into nested Attribute nodes (and a Name node).

  Args:
    name: The name to translate to a node.
    ctx: What context this name is used in. Defaults to Load()

  Returns:
    A Name or Attribute node.
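
  For example (illustrative), full_name_node("a.b.c") is roughly the node for
  the expression `a.b.c`, i.e.
  Attribute(value=Attribute(value=Name(id="a"), attr="b"), attr="c").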
54  """
55  names = name.split(".")
56  names.reverse()
57  node = ast.Name(id=names.pop(), ctx=ast.Load())
58  while names:
59    node = ast.Attribute(value=node, attr=names.pop(), ctx=ast.Load())
60
61  # Change outermost ctx to the one given to us (inner ones should be Load).
62  node.ctx = ctx
63  return node
64
65
66def get_arg_value(node, arg_name, arg_pos=None):
67  """Get the value of an argument from a ast.Call node.

  This function goes through the positional and keyword arguments to check
  whether a given argument was used, and if so, returns its value (the node
  representing its value).

  This cannot introspect *args or **kwargs, but it safely handles *args in
  Python3.5+.

  Args:
    node: The ast.Call node to extract arg values from.
    arg_name: The name of the argument to extract.
    arg_pos: The position of the argument (in case it's passed as a positional
      argument).

  Returns:
    A tuple (arg_present, arg_value) containing a boolean indicating whether
    the argument is present, and its value in case it is.
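
  For example (illustrative), for a call parsed from `f(1, b=2)`,
  get_arg_value(call, "b") returns (True, <the node for 2>), while
  get_arg_value(call, "c", arg_pos=2) returns (False, None).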
85  """
86  # Check keyword args
87  if arg_name is not None:
88    for kw in node.keywords:
89      if kw.arg == arg_name:
90        return (True, kw.value)
91
92  # Check positional args
93  if arg_pos is not None:
94    idx = 0
95    for arg in node.args:
96      if sys.version_info[:2] >= (3, 5) and isinstance(arg, ast.Starred):
97        continue  # Can't parse Starred
98      if idx == arg_pos:
99        return (True, arg)
100      idx += 1
101
102  return (False, None)
103
104
105def uses_star_args_in_call(node):
106  """Check if an ast.Call node uses arbitrary-length positional *args.
107
108  This function works with the AST call node format of Python3.5+
109  as well as the different AST format of earlier versions of Python.
110
111  Args:
112    node: The ast.Call node to check arg values for.
113
114  Returns:
    True if the node uses arbitrary-length positional *args,
    False if it does not.
  """
  if sys.version_info[:2] >= (3, 5):
    # Check for an *args usage in python 3.5+
    for arg in node.args:
      if isinstance(arg, ast.Starred):
        return True
  else:
    if node.starargs:
      return True
  return False


def uses_star_kwargs_in_call(node):
  """Check if an ast.Call node uses arbitrary-length **kwargs.

  This function works with the AST call node format of Python3.5+
  as well as the different AST format of earlier versions of Python.

  Args:
    node: The ast.Call node to check arg values for.

  Returns:
    True if the node uses arbitrary-length **kwargs,
    False if it does not.
  """
  if sys.version_info[:2] >= (3, 5):
    # Check for a **kwarg usage in python 3.5+
    for keyword in node.keywords:
      if keyword.arg is None:
        return True
  else:
    if node.kwargs:
      return True
  return False


def uses_star_args_or_kwargs_in_call(node):
  """Check if an ast.Call node uses arbitrary-length *args or **kwargs.

  This function works with the AST call node format of Python3.5+
  as well as the different AST format of earlier versions of Python.

  Args:
    node: The ast.Call node to check arg values for.

  Returns:
    True if the node uses starred variadic positional args or keyword args.
    False if it does not.
  """
  return uses_star_args_in_call(node) or uses_star_kwargs_in_call(node)


def excluded_from_module_rename(module, import_rename_spec):
  """Check if this module import should not be renamed.

  Args:
    module: (string) module name.
    import_rename_spec: ImportRename instance.

  Returns:
    True if this import should not be renamed according to the
    import_rename_spec.
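
  For example (illustrative names), with
  ImportRename("bar", excluded_prefixes=["foo.contrib"]), the module
  "foo.contrib.x" is excluded from renaming, while "foo.ops" is not.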
179  """
180  for excluded_prefix in import_rename_spec.excluded_prefixes:
181    if module.startswith(excluded_prefix):
182      return True
183  return False
184
185
186class APIChangeSpec:
187  """This class defines the transformations that need to happen.
188
189  This class must provide the following fields:
190
191  * `function_keyword_renames`: maps function names to a map of old -> new
192    argument names
193  * `symbol_renames`: maps function names to new function names
194  * `change_to_function`: a set of function names that have changed (for
195    notifications)
196  * `function_reorders`: maps functions whose argument order has changed to the
197    list of arguments in the new order
198  * `function_warnings`: maps full names of functions to warnings that will be
199    printed out if the function is used. (e.g. tf.nn.convolution())
200  * `function_transformers`: maps function names to custom handlers
201  * `module_deprecations`: maps module names to warnings that will be printed
202    if the module is still used after all other transformations have run
203  * `import_renames`: maps import name (must be a short name without '.')
204    to ImportRename instance.
205
206  For an example, see `TFAPIChangeSpec`.
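
  A minimal spec might look like this (an illustrative sketch only; the names
  "tf.foo", "tf.bar" and "tf.baz" are hypothetical):

    class MyRenames(NoUpdateSpec):
      def __init__(self):
        super().__init__()
        self.symbol_renames = {"tf.foo": "tf.bar"}
        self.function_warnings = {
            "tf.baz": (WARNING, "<function name> requires manual migration."),
        }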
207  """
208
209  def preprocess(self, root_node):  # pylint: disable=unused-argument
210    """Preprocess a parse tree. Return a preprocessed node, logs and errors."""
211    return root_node, [], []
212
213  def clear_preprocessing(self):
214    """Restore this APIChangeSpec to before it preprocessed a file.
215
216    This is needed if preprocessing a file changed any rewriting rules.
217    """
218    pass
219
220
221class NoUpdateSpec(APIChangeSpec):
222  """A specification of an API change which doesn't change anything."""
223
224  def __init__(self):
225    self.function_handle = {}
226    self.function_reorders = {}
227    self.function_keyword_renames = {}
228    self.symbol_renames = {}
229    self.function_warnings = {}
230    self.change_to_function = {}
231    self.module_deprecations = {}
232    self.function_transformers = {}
233    self.import_renames = {}
234
235
236class _PastaEditVisitor(ast.NodeVisitor):
237  """AST Visitor that processes function calls.
238
239  Updates function calls from old API version to new API version using a given
240  change spec.
241  """
242
243  def __init__(self, api_change_spec):
244    self._api_change_spec = api_change_spec
245    self._log = []   # Holds 4-tuples: severity, line, col, msg.
246    self._stack = []  # Allow easy access to parents.
247
248  # Overridden to maintain a stack of nodes to allow for parent access
249  def visit(self, node):
250    self._stack.append(node)
251    super(_PastaEditVisitor, self).visit(node)
252    self._stack.pop()
253
254  @property
255  def errors(self):
256    return [log for log in self._log if log[0] == ERROR]
257
258  @property
259  def warnings(self):
260    return [log for log in self._log if log[0] == WARNING]
261
262  @property
263  def warnings_and_errors(self):
264    return [log for log in self._log if log[0] in (WARNING, ERROR)]
265
266  @property
267  def info(self):
268    return [log for log in self._log if log[0] == INFO]
269
270  @property
271  def log(self):
272    return self._log
273
274  def add_log(self, severity, lineno, col, msg):
275    self._log.append((severity, lineno, col, msg))
276    print("%s line %d:%d: %s" % (severity, lineno, col, msg))
277
278  def add_logs(self, logs):
279    """Record a log and print it.
280
281    The log should be a tuple `(severity, lineno, col_offset, msg)`, which will
282    be printed and recorded. It is part of the log available in the `self.log`
283    property.

    Args:
      logs: The logs to add. Must be a list of tuples
        `(severity, lineno, col_offset, msg)`.
    """
    self._log.extend(logs)
    for log in logs:
      print("%s line %d:%d: %s" % log)

  def _get_applicable_entries(self, transformer_field, full_name, name):
    """Get all list entries indexed by name that apply to full_name or name."""
    # Transformers are indexed by full name, name, or no name
    # as a performance optimization.
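    # For illustration (hypothetical names): a transformer for calls to
    # "tf.foo.bar" could be registered under "tf.foo.bar", "*.bar", or "*".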
    function_transformers = getattr(self._api_change_spec,
                                    transformer_field, {})

    glob_name = "*." + name if name else None
    transformers = []
    if full_name in function_transformers:
      transformers.append(function_transformers[full_name])
    if glob_name in function_transformers:
      transformers.append(function_transformers[glob_name])
    if "*" in function_transformers:
      transformers.append(function_transformers["*"])
    return transformers

  def _get_applicable_dict(self, transformer_field, full_name, name):
    """Get all dict entries indexed by name that apply to full_name or name."""
    # Transformers are indexed by full name, name, or no name
    # as a performance optimization.
    function_transformers = getattr(self._api_change_spec,
                                    transformer_field, {})

    glob_name = "*." + name if name else None
    transformers = function_transformers.get("*", {}).copy()
    transformers.update(function_transformers.get(glob_name, {}))
    transformers.update(function_transformers.get(full_name, {}))
    return transformers

  def _get_full_name(self, node):
    """Traverse an Attribute node to generate a full name, e.g., "tf.foo.bar".

    This is the inverse of `full_name_node`.

    Args:
      node: A Node of type Attribute.

    Returns:
      a '.'-delimited full-name or None if node was not Attribute or Name.
      i.e. `(foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
334    """
335    curr = node
336    items = []
337    while not isinstance(curr, ast.Name):
338      if not isinstance(curr, ast.Attribute):
339        return None
340      items.append(curr.attr)
341      curr = curr.value
342    items.append(curr.id)
343    return ".".join(reversed(items))
344
345  def _maybe_add_warning(self, node, full_name):
346    """Adds an error to be printed about full_name at node."""
347    function_warnings = self._api_change_spec.function_warnings
348    if full_name in function_warnings:
349      level, message = function_warnings[full_name]
350      message = message.replace("<function name>", full_name)
351      self.add_log(level, node.lineno, node.col_offset,
352                   "%s requires manual check. %s" % (full_name, message))
353      return True
354    else:
355      return False
356
357  def _maybe_add_module_deprecation_warning(self, node, full_name, whole_name):
358    """Adds a warning if full_name is a deprecated module."""
359    warnings = self._api_change_spec.module_deprecations
360    if full_name in warnings:
361      level, message = warnings[full_name]
362      message = message.replace("<function name>", whole_name)
363      self.add_log(level, node.lineno, node.col_offset,
364                   "Using member %s in deprecated module %s. %s" % (whole_name,
365                                                                    full_name,
366                                                                    message))
367      return True
368    else:
369      return False
370
371  def _maybe_add_call_warning(self, node, full_name, name):
372    """Print a warning when specific functions are called with selected args.
373
    Warnings keyed on the full name of the called function, e.g., tf.foo.bar(),
    are added by `_maybe_add_warning`. This function instead matches on the
    short function name that is called, as long as the function is an
    attribute. For example, `tf.foo.bar()` and `foo.bar()` are matched, but not
    `bar()`.

    Args:
      node: ast.Call object
      full_name: The precomputed full name of the callable, if one exists, None
        otherwise.
      name: The precomputed name of the callable, if one exists, None otherwise.

    Returns:
      Whether an error was recorded.
    """
    # Only look for *.-warnings here; the others are handled by the Attribute
    # visitor. Also, do not warn for bare functions, only if the call func is
    # an attribute.
    warned = False
    if isinstance(node.func, ast.Attribute):
      warned = self._maybe_add_warning(node, "*." + name)

    # All arg warnings are handled here, since only we have the args
    arg_warnings = self._get_applicable_dict("function_arg_warnings",
                                             full_name, name)

    variadic_args = uses_star_args_or_kwargs_in_call(node)

    for (kwarg, arg), (level, warning) in sorted(arg_warnings.items()):
      present, _ = get_arg_value(node, kwarg, arg)
      # A *args/**kwargs call may pass this argument implicitly.
      present = present or variadic_args
      if present:
        warned = True
        warning_message = warning.replace("<function name>", full_name or name)
        template = "%s called with %s argument, requires manual check: %s"
        if variadic_args:
          template = ("%s called with *args or **kwargs that may include %s, "
                      "requires manual check: %s")
        self.add_log(level, node.lineno, node.col_offset,
                     template % (full_name or name, kwarg, warning_message))

    return warned

  def _maybe_rename(self, parent, node, full_name):
    """Replace node (Attribute or Name) with a node representing full_name."""
    new_name = self._api_change_spec.symbol_renames.get(full_name, None)
    if new_name:
      self.add_log(INFO, node.lineno, node.col_offset,
                   "Renamed %r to %r" % (full_name, new_name))
      new_node = full_name_node(new_name, node.ctx)
      ast.copy_location(new_node, node)
      pasta.ast_utils.replace_child(parent, node, new_node)
      return True
    else:
      return False

  def _maybe_change_to_function_call(self, parent, node, full_name):
    """Wraps node (typically, an Attribute or Expr) in a Call."""
    if full_name in self._api_change_spec.change_to_function:
      if not isinstance(parent, ast.Call):
        # ast.Call's constructor is really picky about how many arguments it
        # wants, and also, it changed between Py2 and Py3.
        new_node = ast.Call(node, [], [])
        pasta.ast_utils.replace_child(parent, node, new_node)
        ast.copy_location(new_node, node)
        self.add_log(INFO, node.lineno, node.col_offset,
                     "Changed %r to a function call" % full_name)
        return True
    return False

  def _maybe_add_arg_names(self, node, full_name):
    """Make args into keyword args if function called full_name requires it."""
    function_reorders = self._api_change_spec.function_reorders

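    # For illustration (hypothetical name): if function_reorders maps "tf.foo"
    # to ["a", "b"], the call `tf.foo(1, 2)` becomes `tf.foo(a=1, b=2)`.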
    if full_name in function_reorders:
      if uses_star_args_in_call(node):
        self.add_log(WARNING, node.lineno, node.col_offset,
                     "(Manual check required) upgrading %s may require "
                     "re-ordering the call arguments, but it was passed "
                     "variable-length positional *args. The upgrade "
                     "script cannot handle these automatically." % full_name)

      reordered = function_reorders[full_name]
      new_keywords = []
      idx = 0
      for arg in node.args:
        if sys.version_info[:2] >= (3, 5) and isinstance(arg, ast.Starred):
          continue  # Can't move Starred to keywords
        keyword_arg = reordered[idx]
        keyword = ast.keyword(arg=keyword_arg, value=arg)
        new_keywords.append(keyword)
        idx += 1

      if new_keywords:
        self.add_log(INFO, node.lineno, node.col_offset,
                     "Added keywords to args of function %r" % full_name)
        node.args = []
        node.keywords = new_keywords + (node.keywords or [])
        return True
    return False

  def _maybe_modify_args(self, node, full_name, name):
    """Rename keyword args if the function called full_name requires it."""
    renamed_keywords = self._get_applicable_dict("function_keyword_renames",
                                                 full_name, name)

    if not renamed_keywords:
      return False

    if uses_star_kwargs_in_call(node):
      self.add_log(WARNING, node.lineno, node.col_offset,
                   "(Manual check required) upgrading %s may require "
                   "renaming or removing call arguments, but it was passed "
                   "variable-length *args or **kwargs. The upgrade "
                   "script cannot handle these automatically." %
                   (full_name or name))
    modified = False
    new_keywords = []
    for keyword in node.keywords:
      argkey = keyword.arg
      if argkey in renamed_keywords:
        modified = True
        if renamed_keywords[argkey] is None:
          lineno = getattr(keyword, "lineno", node.lineno)
          col_offset = getattr(keyword, "col_offset", node.col_offset)
          self.add_log(INFO, lineno, col_offset,
                       "Removed argument %s for function %s" % (
                           argkey, full_name or name))
        else:
          keyword.arg = renamed_keywords[argkey]
          lineno = getattr(keyword, "lineno", node.lineno)
          col_offset = getattr(keyword, "col_offset", node.col_offset)
          self.add_log(INFO, lineno, col_offset,
                       "Renamed keyword argument for %s from %s to %s" % (
                           full_name, argkey, renamed_keywords[argkey]))
          new_keywords.append(keyword)
      else:
        new_keywords.append(keyword)

    if modified:
      node.keywords = new_keywords
    return modified

  def visit_Call(self, node):  # pylint: disable=invalid-name
    """Handle visiting a call node in the AST.

    Args:
      node: Current Node
    """
    assert self._stack[-1] is node

    # Get the name for this call, so we can index stuff with it.
    full_name = self._get_full_name(node.func)
    if full_name:
      name = full_name.split(".")[-1]
    elif isinstance(node.func, ast.Name):
      name = node.func.id
    elif isinstance(node.func, ast.Attribute):
      name = node.func.attr
    else:
      name = None

    # Call standard transformers for this node.
    # Make sure warnings come first, since args or names triggering warnings
    # may be removed by the other transformations.
    self._maybe_add_call_warning(node, full_name, name)
    # Make all args into kwargs
    self._maybe_add_arg_names(node, full_name)
    # Argument name changes or deletions
    self._maybe_modify_args(node, full_name, name)

    # Call transformers. These have the ability to modify the node, and if they
    # do, will return the new node they created (or the same node if they just
    # changed it). They are given the parent, but we will take care of
    # integrating their changes into the parent if they return a new node.
    #
    # These are matched on the old name, since renaming is performed by the
    # Attribute visitor, which happens later.
    transformers = self._get_applicable_entries("function_transformers",
                                                full_name, name)

    parent = self._stack[-2]

    if transformers:
      if uses_star_args_or_kwargs_in_call(node):
        self.add_log(WARNING, node.lineno, node.col_offset,
                     "(Manual check required) upgrading %s may require "
                     "modifying call arguments, but it was passed "
                     "variable-length *args or **kwargs. The upgrade "
                     "script cannot handle these automatically." %
                     (full_name or name))

    for transformer in transformers:
      logs = []
      new_node = transformer(parent, node, full_name, name, logs)
      self.add_logs(logs)
      if new_node and new_node is not node:
        pasta.ast_utils.replace_child(parent, node, new_node)
        node = new_node
        self._stack[-1] = node

    self.generic_visit(node)

  def visit_Attribute(self, node):  # pylint: disable=invalid-name
    """Handle bare Attributes i.e. [tf.foo, tf.bar]."""
    assert self._stack[-1] is node

    full_name = self._get_full_name(node)
    if full_name:
      parent = self._stack[-2]

      # Make sure the warning comes first, otherwise the name may have changed
      self._maybe_add_warning(node, full_name)

      # Once we've made a modification, the node is invalid and not worth
      # inspecting further. Also, we only perform modifications on simple
      # nodes, so there would be no point in descending further.
      if self._maybe_rename(parent, node, full_name):
        return
      if self._maybe_change_to_function_call(parent, node, full_name):
        return

      # The isinstance check is enough -- a bare Attribute is never root.
      i = 2
      while isinstance(self._stack[-i], ast.Attribute):
        i += 1
      whole_name = pasta.dump(self._stack[-(i-1)])

      self._maybe_add_module_deprecation_warning(node, full_name, whole_name)

    self.generic_visit(node)

  def visit_Import(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import node in the AST.

    Args:
      node: Current Node
    """
    new_aliases = []
    import_updated = False
    import_renames = getattr(self._api_change_spec, "import_renames", {})
    max_submodule_depth = getattr(self._api_change_spec, "max_submodule_depth",
                                  1)
    inserts_after_imports = getattr(self._api_change_spec,
                                    "inserts_after_imports", {})

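    # For illustration (hypothetical names): with
    #   import_renames = {"foo": ImportRename("bar.foo", excluded_prefixes=[])}
    # the statement `import foo.baz` is rewritten to `import bar.foo.baz`.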
    # This loop processes imports in the format
    # import foo as f, bar as b
    for import_alias in node.names:
      all_import_components = import_alias.name.split(".")
      # Look for rename, starting with longest import levels.
      found_update = False
      for i in reversed(list(range(1, max_submodule_depth + 1))):
        import_component = all_import_components[0]
        for j in range(1, min(i, len(all_import_components))):
          import_component += "." + all_import_components[j]
        import_rename_spec = import_renames.get(import_component, None)

        if not import_rename_spec or excluded_from_module_rename(
            import_alias.name, import_rename_spec):
          continue

        new_name = (
            import_rename_spec.new_name +
            import_alias.name[len(import_component):])

        # If current import is
        #   import foo
        # then new import should preserve imported name:
        #   import new_foo as foo
        # This happens when module has just one component.
        new_asname = import_alias.asname
        if not new_asname and "." not in import_alias.name:
          new_asname = import_alias.name

        new_alias = ast.alias(name=new_name, asname=new_asname)
        new_aliases.append(new_alias)
        import_updated = True
        found_update = True

        # Insert any followup lines that should happen after this import.
        full_import = (import_alias.name, import_alias.asname)
        insert_offset = 1
        for line_to_insert in inserts_after_imports.get(full_import, []):
          assert self._stack[-1] is node
          parent = self._stack[-2]

          new_line_node = pasta.parse(line_to_insert)
          ast.copy_location(new_line_node, node)
          parent.body.insert(
              parent.body.index(node) + insert_offset, new_line_node)
          insert_offset += 1

          # Insert a newline after the import if necessary
          old_suffix = pasta.base.formatting.get(node, "suffix")
          if old_suffix is None:
            old_suffix = os.linesep
          if os.linesep not in old_suffix:
            pasta.base.formatting.set(node, "suffix", old_suffix + os.linesep)

          # Apply indentation to new node.
          pasta.base.formatting.set(new_line_node, "prefix",
                                    pasta.base.formatting.get(node, "prefix"))
          pasta.base.formatting.set(new_line_node, "suffix", os.linesep)
          self.add_log(
              INFO, node.lineno, node.col_offset,
              "Adding `%s` after import of %s" %
              (new_line_node, import_alias.name))
        # Found a match at this level; stop checking shallower prefixes.
        if found_update:
          break
      # No rename was found at any level.
      if not found_update:
        new_aliases.append(import_alias)  # no change needed

    # Replace the node if at least one import needs to be updated.
    if import_updated:
      assert self._stack[-1] is node
      parent = self._stack[-2]

      new_node = ast.Import(new_aliases)
      ast.copy_location(new_node, node)
      pasta.ast_utils.replace_child(parent, node, new_node)
      self.add_log(
          INFO, node.lineno, node.col_offset,
          "Changed import from %r to %r." %
          (pasta.dump(node), pasta.dump(new_node)))

    self.generic_visit(node)

  def visit_ImportFrom(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import-from node in the AST.

    Args:
      node: Current Node
    """
    if not node.module:
      self.generic_visit(node)
      return

    from_import = node.module

    # Look for rename based on first component of from-import.
    # i.e. based on foo in foo.bar.
    from_import_first_component = from_import.split(".")[0]
    import_renames = getattr(self._api_change_spec, "import_renames", {})
    import_rename_spec = import_renames.get(from_import_first_component, None)
    if not import_rename_spec:
      self.generic_visit(node)
      return

    # Split module aliases into the ones that require import update
    # and those that don't. For example, if we want to rename "a" to "b"
    # unless we import "a.c" in the following:
    # from a import c, d
    # we want to update import for "d" but not for "c".
    updated_aliases = []
    same_aliases = []
    for import_alias in node.names:
      full_module_name = "%s.%s" % (from_import, import_alias.name)
      if excluded_from_module_rename(full_module_name, import_rename_spec):
        same_aliases.append(import_alias)
      else:
        updated_aliases.append(import_alias)

    if not updated_aliases:
      self.generic_visit(node)
      return

    assert self._stack[-1] is node
    parent = self._stack[-2]

    # Replace first component of from-import with new name.
    new_from_import = (
        import_rename_spec.new_name +
        from_import[len(from_import_first_component):])
    updated_node = ast.ImportFrom(new_from_import, updated_aliases, node.level)
    ast.copy_location(updated_node, node)
    pasta.ast_utils.replace_child(parent, node, updated_node)

    # If some imports had to stay the same, add another import for them.
    additional_import_log = ""
    if same_aliases:
      same_node = ast.ImportFrom(from_import, same_aliases, node.level,
                                 col_offset=node.col_offset, lineno=node.lineno)
      ast.copy_location(same_node, node)
      parent.body.insert(parent.body.index(updated_node), same_node)
      # Apply indentation to new node.
      pasta.base.formatting.set(
          same_node, "prefix",
          pasta.base.formatting.get(updated_node, "prefix"))
      additional_import_log = " and %r" % pasta.dump(same_node)

    self.add_log(
        INFO, node.lineno, node.col_offset,
        "Changed import from %r to %r%s." %
        (pasta.dump(node),
         pasta.dump(updated_node),
         additional_import_log))

    self.generic_visit(node)


class AnalysisResult:
  """This class represents an analysis result and how it should be logged.

  This class must provide the following fields:

  * `log_level`: The log level to which this detection should be logged
  * `log_message`: The message that should be logged for this detection

  For an example, see `VersionedTFImport`.
  """


class APIAnalysisSpec:
  """This class defines how `AnalysisResult`s should be generated.

  It specifies how to map imports and symbols to `AnalysisResult`s.

  This class must provide the following fields:

  * `symbols_to_detect`: maps function names to `AnalysisResult`s
  * `imports_to_detect`: maps imports represented as (full module name, alias)
    tuples to `AnalysisResult`s

  For an example, see `TFAPIImportAnalysisSpec`.
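
  A minimal spec might look like this (an illustrative sketch only; `MyResult`
  is a hypothetical AnalysisResult subclass providing `log_level` and
  `log_message`):

    class DetectFooImports(APIAnalysisSpec):
      def __init__(self):
        self.symbols_to_detect = {"foo.bar": MyResult()}
        self.imports_to_detect = {("foo", None): MyResult()}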
800  """
801
802
803class PastaAnalyzeVisitor(_PastaEditVisitor):
804  """AST Visitor that looks for specific API usage without editing anything.
805
806  This is used before any rewriting is done to detect if any symbols are used
807  that require changing imports or disabling rewriting altogether.
808  """
809
810  def __init__(self, api_analysis_spec):
811    super(PastaAnalyzeVisitor, self).__init__(NoUpdateSpec())
812    self._api_analysis_spec = api_analysis_spec
813    self._results = []   # Holds AnalysisResult objects
814
815  @property
816  def results(self):
817    return self._results
818
819  def add_result(self, analysis_result):
820    self._results.append(analysis_result)
821
822  def visit_Attribute(self, node):  # pylint: disable=invalid-name
823    """Handle bare Attributes i.e. [tf.foo, tf.bar]."""
824    full_name = self._get_full_name(node)
825    if full_name:
826      detection = self._api_analysis_spec.symbols_to_detect.get(full_name, None)
827      if detection:
828        self.add_result(detection)
829        self.add_log(
830            detection.log_level, node.lineno, node.col_offset,
831            detection.log_message)
832
833    self.generic_visit(node)
834
835  def visit_Import(self, node):  # pylint: disable=invalid-name
836    """Handle visiting an import node in the AST.
837
838    Args:
839      node: Current Node
840    """
841    for import_alias in node.names:
      # Detect based on the full import name and alias.
      full_import = (import_alias.name, import_alias.asname)
      detection = (self._api_analysis_spec
                   .imports_to_detect.get(full_import, None))
      if detection:
        self.add_result(detection)
        self.add_log(
            detection.log_level, node.lineno, node.col_offset,
            detection.log_message)

    self.generic_visit(node)

  def visit_ImportFrom(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import-from node in the AST.

    Args:
      node: Current Node
    """
    if not node.module:
      self.generic_visit(node)
      return

    from_import = node.module

    for import_alias in node.names:
      # Detect based on the full module name and alias.
      full_module_name = "%s.%s" % (from_import, import_alias.name)
      full_import = (full_module_name, import_alias.asname)
      detection = (self._api_analysis_spec
                   .imports_to_detect.get(full_import, None))
      if detection:
        self.add_result(detection)
        self.add_log(
            detection.log_level, node.lineno, node.col_offset,
            detection.log_message)

    self.generic_visit(node)


class ASTCodeUpgrader:
  """Handles upgrading a set of Python files using a given API change spec."""

  def __init__(self, api_change_spec):
    if not isinstance(api_change_spec, APIChangeSpec):
      raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" %
                      type(api_change_spec))
    self._api_change_spec = api_change_spec

  def process_file(self,
                   in_filename,
                   out_filename,
                   no_change_to_outfile_on_error=False):
    """Process the given python file for incompatible changes.

    Args:
      in_filename: filename to parse
      out_filename: output file to write to
      no_change_to_outfile_on_error: if True, do not modify the output file on
        errors
    Returns:
      A tuple representing number of files processed, log of actions, errors
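
    Example (illustrative; MyChangeSpec, "old.py" and "new.py" are
    hypothetical):
      upgrader = ASTCodeUpgrader(MyChangeSpec())
      count, report, errors = upgrader.process_file("old.py", "new.py")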
902    """
903
    # Write to a temporary file, just in case we are doing an in-place modify.
    # pylint: disable=g-backslash-continuation
    with open(in_filename, "r") as in_file, \
        tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
      ret = self.process_opened_file(in_filename, in_file, out_filename,
                                     temp_file)
    # pylint: enable=g-backslash-continuation

    if no_change_to_outfile_on_error and ret[0] == 0:
      os.remove(temp_file.name)
    else:
      shutil.move(temp_file.name, out_filename)
    return ret

  def format_log(self, log, in_filename):
    log_string = "%d:%d: %s: %s" % (log[1], log[2], log[0], log[3])
    if in_filename:
      return in_filename + ":" + log_string
    else:
      return log_string

  def update_string_pasta(self, text, in_filename):
    """Updates a file using pasta."""
    try:
      t = pasta.parse(text)
    except (SyntaxError, ValueError, TypeError):
      log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
      return 0, "", log, []

    t, preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)

    visitor = _PastaEditVisitor(self._api_change_spec)
    visitor.visit(t)

    self._api_change_spec.clear_preprocessing()

    logs = [self.format_log(log, None) for log in (preprocess_logs +
                                                   visitor.log)]
    errors = [self.format_log(error, in_filename)
              for error in (preprocess_errors +
                            visitor.warnings_and_errors)]
    return 1, pasta.dump(t), logs, errors

  def _format_log(self, log, in_filename, out_filename):
    text = "-" * 80 + "\n"
    text += "Processing file %r\n outputting to %r\n" % (in_filename,
                                                         out_filename)
    text += "-" * 80 + "\n\n"
    text += "\n".join(log) + "\n"
    text += "-" * 80 + "\n\n"
    return text

  def process_opened_file(self, in_filename, in_file, out_filename, out_file):
    """Process the given python file for incompatible changes.

    This function is split out to facilitate StringIO testing from
    tf_upgrade_test.py.

    Args:
      in_filename: filename to parse
      in_file: opened file (or StringIO)
      out_filename: output file to write to
      out_file: opened file (or StringIO)
    Returns:
      A tuple representing number of files processed, log of actions, errors
    """
    lines = in_file.readlines()
    processed_file, new_file_content, log, process_errors = (
        self.update_string_pasta("".join(lines), in_filename))

    if out_file and processed_file:
      out_file.write(new_file_content)

    return (processed_file,
            self._format_log(log, in_filename, out_filename),
            process_errors)

  def process_tree(self, root_directory, output_root_directory,
                   copy_other_files):
    """Processes upgrades on an entire tree of python files in place.

    Note that only Python files are processed. If you have custom code in other
    languages, you will need to upgrade it manually.

    Args:
      root_directory: Directory to walk and process.
      output_root_directory: Directory to use as base.
      copy_other_files: Copy files that are not touched by this converter.

    Returns:
      A tuple of files processed, the report string for all files, and a dict
        mapping filenames to errors encountered in that file.
    """

    if output_root_directory == root_directory:
      return self.process_tree_inplace(root_directory)

    # make sure output directory doesn't exist
    if output_root_directory and os.path.exists(output_root_directory):
      print("Output directory %r must not already exist." %
            (output_root_directory))
      sys.exit(1)

    # make sure output directory does not overlap with root_directory
    norm_root = os.path.split(os.path.normpath(root_directory))
    norm_output = os.path.split(os.path.normpath(output_root_directory))
    if norm_root == norm_output:
1011      print("Output directory %r same as input directory %r" %
1012            (root_directory, output_root_directory))
      sys.exit(1)

    # Collect the list of files to process (we do this to correctly handle the
    # case where the output directory is a subdirectory of the input directory)
    files_to_process = []
    files_to_copy = []
    for dir_name, _, file_list in os.walk(root_directory):
      py_files = [f for f in file_list if f.endswith(".py")]
      copy_files = [f for f in file_list if not f.endswith(".py")]
      for filename in py_files:
        fullpath = os.path.join(dir_name, filename)
        fullpath_output = os.path.join(output_root_directory,
                                       os.path.relpath(fullpath,
                                                       root_directory))
        files_to_process.append((fullpath, fullpath_output))
      if copy_other_files:
        for filename in copy_files:
          fullpath = os.path.join(dir_name, filename)
          fullpath_output = os.path.join(output_root_directory,
                                         os.path.relpath(
                                             fullpath, root_directory))
          files_to_copy.append((fullpath, fullpath_output))

    file_count = 0
    tree_errors = {}
    report = ""
    report += ("=" * 80) + "\n"
    report += "Input tree: %r\n" % root_directory
    report += ("=" * 80) + "\n"

    for input_path, output_path in files_to_process:
      output_directory = os.path.dirname(output_path)
      if not os.path.isdir(output_directory):
        os.makedirs(output_directory)

      if os.path.islink(input_path):
        link_target = os.readlink(input_path)
        link_target_output = os.path.join(
            output_root_directory, os.path.relpath(link_target, root_directory))
        if (link_target, link_target_output) in files_to_process:
          # Create a link to the new location of the target file
          os.symlink(link_target_output, output_path)
        else:
          report += "Copying symlink %s without modifying its target %s\n" % (
              input_path, link_target)
          os.symlink(link_target, output_path)
        continue

      file_count += 1
      _, l_report, l_errors = self.process_file(input_path, output_path)
      tree_errors[input_path] = l_errors
      report += l_report

    for input_path, output_path in files_to_copy:
      output_directory = os.path.dirname(output_path)
      if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
      shutil.copy(input_path, output_path)
    return file_count, report, tree_errors

  def process_tree_inplace(self, root_directory):
    """Process a directory of python files in place."""
    files_to_process = []
    for dir_name, _, file_list in os.walk(root_directory):
      py_files = [
          os.path.join(dir_name, f) for f in file_list if f.endswith(".py")
      ]
      files_to_process += py_files

    file_count = 0
    tree_errors = {}
    report = ""
    report += ("=" * 80) + "\n"
    report += "Input tree: %r\n" % root_directory
    report += ("=" * 80) + "\n"

    for path in files_to_process:
      if os.path.islink(path):
        report += "Skipping symlink %s.\n" % path
        continue
      file_count += 1
      _, l_report, l_errors = self.process_file(path, path)
      tree_errors[path] = l_errors
      report += l_report

    return file_count, report, tree_errors
1099