xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/pyct/templates.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""AST conversion templates.
16
17Adapted from Tangent.
18"""
19
20import ast
21import textwrap
22
23import gast
24
25from tensorflow.python.autograph.pyct import anno
26from tensorflow.python.autograph.pyct import ast_util
27from tensorflow.python.autograph.pyct import parser
28from tensorflow.python.autograph.pyct import qual_names
29
30
class ContextAdjuster(gast.NodeTransformer):
  """Adjusts the ctx field of nodes to ensure consistency.

  This transformer can change the ctx fields of a variable, tuple and other
  AST elements that allow one, based on whether the element is being read or
  written.

  The override is a ctx class such as `gast.Load` or `gast.Store`, or None to
  leave ctx untouched. Each node it applies to receives a fresh instance of
  that class. Attribute and Subscript nodes take the override themselves but
  switch it to `gast.Load` for their children, since those are only read.
  Call, Dict, comprehension and Lambda nodes clear the override, so the nodes
  below them must already have their ctx set (this is asserted in `visit`).
  """

  def __init__(self, override_value):
    # ctx class to instantiate on matching nodes, or None for no override.
    self._ctx_override = override_value

  def visit(self, node):
    # Visitors below may change the override for their own subtree; restore it
    # afterwards so that sibling nodes still see the parent's override.
    original_override = self._ctx_override
    node = super(ContextAdjuster, self).visit(node)
    if hasattr(node, 'ctx'):
      assert node.ctx is not None, 'node {} has ctx unset'.format(node)
    self._ctx_override = original_override
    return node

  def _apply_override(self, node):
    """Sets node.ctx to a new instance of the override, if one is set."""
    if self._ctx_override is not None:
      node.ctx = self._ctx_override()

  def visit_Attribute(self, node):
    self._apply_override(node)
    # The object whose attribute is accessed is always read.
    self._ctx_override = gast.Load
    node = self.generic_visit(node)
    return node

  def visit_Tuple(self, node):
    self._apply_override(node)
    return self.generic_visit(node)

  def visit_List(self, node):
    self._apply_override(node)
    return self.generic_visit(node)

  def visit_Name(self, node):
    self._apply_override(node)
    return self.generic_visit(node)

  def visit_Call(self, node):
    # NOTE(review): Python's ast.Call declares no ctx field, so this presumably
    # only attaches an extra attribute; kept as-is for compatibility.
    self._apply_override(node)
    # We may be able to override these to Load(), but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)

  def visit_Dict(self, node):
    # We may be able to override these to Load(), but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)

  def visit_Subscript(self, node):
    self._apply_override(node)
    # The subscripted value and the slice are always read. generic_visit
    # visits both of them under this override. Do not also visit node.value
    # separately beforehand: that traverses it twice, which makes chained
    # subscripts such as a[i][j][k] exponential in their nesting depth.
    self._ctx_override = gast.Load
    return self.generic_visit(node)

  def visit_comprehension(self, node):
    # We may be able to override some of these, but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)

  def visit_Lambda(self, node):
    # We may be able to override some of these, but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)
102
103
class ReplaceTransformer(gast.NodeTransformer):
  """Swaps placeholder identifiers in a template AST for replacement nodes.

  Placeholders are recognized as bare names, keyword argument names, function
  names and attribute names.
  """

  def __init__(self, replacements):
    """Create a new ReplaceTransformer.

    Args:
      replacements: A mapping from placeholder names to (lists of) AST nodes
          that these placeholders will be replaced by.
    """
    self.replacements = replacements
    self.in_replacements = False
    # Annotations carried over onto each copy of a replacement node.
    self.preserved_annos = set([
        anno.Basic.DIRECTIVES,
        anno.Basic.EXTRA_LOOP_TEST,
        anno.Basic.ORIGIN,
        anno.Basic.SKIP_PROCESSING,
        anno.Static.ORIG_DEFINITIONS,
        'function_context_name',
    ])

  def _prepare_replacement(self, replaced, key):
    """Returns a clean copy of the replacement registered under `key`.

    Args:
      replaced: ast.AST, the node being replaced (not used by this method)
      key: Hashable, the key of the replacement AST
    Returns:
      A one-element list when the copy is a single gast node; otherwise the
      copy exactly as returned by `ast_util.copy_clean`.
    """
    cleaned = ast_util.copy_clean(
        self.replacements[key], preserve_annos=self.preserved_annos)
    return [cleaned] if isinstance(cleaned, gast.AST) else cleaned

  def visit_Expr(self, node):
    # When replacing a placeholder with an entire statement, the replacement
    # must stand on its own and not be wrapped in an Expr.
    visited_value = self.visit(node.value)
    return node if visited_value is node.value else visited_value

  def visit_keyword(self, node):
    placeholder = node.arg
    if placeholder not in self.replacements:
      return self.generic_visit(node)

    candidate = self._prepare_replacement(node, placeholder)
    if isinstance(candidate, gast.keyword):
      return candidate
    if candidate and isinstance(candidate, (list, tuple)):
      if all(isinstance(kw, gast.keyword) for kw in candidate):
        return candidate
    # TODO(mdan): We may allow replacing with a string as well.
    # For example, if one wanted to replace foo with bar in foo=baz, then
    # we could allow changing just node arg, so that we end up with bar=baz.
    raise ValueError(
        'a keyword argument may only be replaced by another keyword or a '
        'non-empty list of keywords. Found: {} for keyword {}'.format(
            candidate, placeholder))

  def visit_FunctionDef(self, node):
    node = self.generic_visit(node)
    if node.name in self.replacements:
      new_name = self.replacements[node.name]
      if not isinstance(new_name, (gast.Name, ast.Name)):
        raise ValueError(
            'a function name can only be replaced by a Name node. Found: %s' %
            new_name)
      node.name = new_name.id
    return node

  def visit_Attribute(self, node):
    node = self.generic_visit(node)
    if node.attr in self.replacements:
      new_attr = self.replacements[node.attr]
      if not isinstance(new_attr, gast.Name):
        raise ValueError(
            'An attribute can only be replaced by a Name node. Found: %s' %
            new_attr)
      node.attr = new_attr.id
    return node

  def visit_Name(self, node):
    if node.id not in self.replacements:
      return node

    substitutes = self._prepare_replacement(node, node.id)
    if not substitutes:
      # An empty replacement deletes the placeholder.
      return substitutes

    # The substitutes take over the load/store context of the placeholder.
    adjuster = ContextAdjuster(type(node.ctx))
    for sub in (s for s in substitutes if hasattr(s, 'ctx')):
      adjuster.visit(sub)

    if len(substitutes) == 1:
      (substitutes,) = substitutes
    return substitutes
212
213
def _convert_to_ast(n):
  """Converts a replacement value of a known type into its AST form.

  Strings become `gast.Name` nodes and `qual_names.QN` values become their
  `ast()` form. Lists and tuples are converted element-wise: a list yields a
  new list and a tuple yields a new tuple. Any other value is returned as-is.
  """
  # Note: When generating AST nodes from strings/QNs in isolation, ctx is
  # unknown. ctx must be filled in according to the template being used.
  # See ReplaceTransformer.visit_Name.
  if isinstance(n, str):
    return gast.Name(id=n, ctx=None, annotation=None, type_comment=None)
  if isinstance(n, qual_names.QN):
    return n.ast()
  if isinstance(n, (list, tuple)):
    converted = [_convert_to_ast(item) for item in n]
    return converted if isinstance(n, list) else tuple(converted)
  return n
228
229
def replace(template, **replacements):
  """Replaces placeholders in a Python template.

  AST Name and Tuple nodes always receive the context inferred from the
  template. However, when replacing more complex nodes (that can potentially
  contain Name children), the caller is responsible for setting the
  appropriate context.

  Args:
    template: A string representing Python code. It is dedented before being
        parsed. Any symbol name that appears in the template code can be used
        as a placeholder.
    **replacements: A mapping from placeholder names to (lists of) AST nodes
        that these placeholders will be replaced by. String values are also
        supported as a shorthand for AST Name nodes with the respective ID.

  Returns:
    A list of AST nodes: the top-level statements of the template with the
    replacements made. A top-level statement that is replaced by a list of
    nodes contributes each node of that list (and none if it is empty).

  Raises:
    ValueError: if the template is not a string, or if a replacement is not
        valid for the placeholder it replaces (see ReplaceTransformer).
  """
  if not isinstance(template, str):
    raise ValueError('Expected string template, got %s' % type(template))
  # Normalize shorthand values (strings, QNs, lists/tuples of them) to AST.
  replacements = {k: _convert_to_ast(v) for k, v in replacements.items()}
  template_str = parser.STANDARD_PREAMBLE + textwrap.dedent(template)
  nodes = parser.parse(
      template_str,
      preamble_len=parser.STANDARD_PREAMBLE_LEN,
      single_node=False)
  # ReplaceTransformer keeps no per-node state, so a single instance can be
  # reused for every top-level node.
  transformer = ReplaceTransformer(replacements)
  results = []
  for node in nodes:
    node = transformer.visit(node)
    # A top-level placeholder may expand into several statements, or none.
    if isinstance(node, (list, tuple)):
      results.extend(node)
    else:
      results.append(node)
  return [qual_names.resolve(r) for r in results]
273
274
def replace_as_expression(template, **replacements):
  """Variant of replace that generates expressions, instead of code blocks.

  Args:
    template: A string representing a single Python expression.
    **replacements: Same as for `replace`.

  Returns:
    The expression AST node produced from the template.

  Raises:
    ValueError: if the template does not produce exactly one node, or if that
        node is neither an expression statement nor an expression.
  """
  replacement = replace(template, **replacements)
  if len(replacement) != 1:
    raise ValueError(
        'single expression expected; for more general templates use replace')
  node, = replacement

  if isinstance(node, gast.Expr):
    return node.value
  # When the whole template is a single placeholder, ReplaceTransformer's
  # visit_Expr returns the replacement itself rather than an Expr wrapping it.
  # That replacement may be any expression (a Call, an Attribute, ...), not
  # only a Name, so accept every expression node here.
  elif isinstance(node, gast.expr):
    return node

  raise ValueError(
      'the template is expected to generate an expression or a name node;'
      ' instead found %s' % node)
291