# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AST conversion templates.

Adapted from Tangent.
"""

import ast
import textwrap

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names


class ContextAdjuster(gast.NodeTransformer):
  """Adjusts the ctx field of nodes to ensure consistency.

  This transformer can change the ctx fields of a variable, tuple and other
  AST elements that allow one, based on whether the element is being read or
  written.

  The current override is a ctx class (for example `gast.Load` or `gast.Store`)
  that is instantiated and assigned to each overridable node, or None, in which
  case existing ctx values are left untouched. After every node is visited, any
  node that has a `ctx` attribute is asserted to have it set.
  """

  def __init__(self, override_value):
    """Creates a new ContextAdjuster.

    Args:
      override_value: a ctx class such as `gast.Load` or `gast.Store`, or None.
        It is called with no arguments to build each ctx instance it assigns.
    """
    self._ctx_override = override_value

  def visit(self, node):
    """Visits `node`, restoring the current override afterwards."""
    # Handlers below change the override for their children (e.g. the value of
    # an Attribute is always read). Saving and restoring it around each visit
    # keeps a child's setting from leaking into its siblings.
    original_override = self._ctx_override
    node = super().visit(node)
    if hasattr(node, 'ctx'):
      assert node.ctx is not None, 'node {} has ctx unset'.format(node)
    self._ctx_override = original_override
    return node

  def _apply_override(self, node):
    """Sets `node.ctx` to a fresh instance of the override, if there is one."""
    if self._ctx_override is not None:
      node.ctx = self._ctx_override()

  def visit_Attribute(self, node):
    self._apply_override(node)
    # The object whose attribute is accessed is only ever read.
    self._ctx_override = gast.Load
    return self.generic_visit(node)

  def visit_Tuple(self, node):
    # Elements inherit the tuple's override (e.g. all are stored to in
    # `a, b = ...`).
    self._apply_override(node)
    return self.generic_visit(node)

  def visit_List(self, node):
    # Elements inherit the list's override, as for tuples.
    self._apply_override(node)
    return self.generic_visit(node)

  def visit_Name(self, node):
    self._apply_override(node)
    return self.generic_visit(node)

  def visit_Call(self, node):
    # NOTE(review): this also assigns a ctx attribute to the Call node itself,
    # although Python's ast.Call declares no ctx field; confirm this is
    # intended.
    self._apply_override(node)
    # We may be able to override these to Load(), but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)

  def visit_Dict(self, node):
    # We may be able to override these to Load(), but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)

  def visit_Subscript(self, node):
    self._apply_override(node)
    # Both the subscripted value and the slice are only ever read. A single
    # generic_visit adjusts both; visiting node.value separately beforehand
    # would only repeat the same, idempotent adjustment.
    self._ctx_override = gast.Load
    return self.generic_visit(node)

  def visit_comprehension(self, node):
    # We may be able to override some of these, but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)

  def visit_Lambda(self, node):
    # We may be able to override some of these, but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)


class ReplaceTransformer(gast.NodeTransformer):
  """Replace AST nodes.

  Placeholders may appear as names, keyword argument names, function names or
  attribute names; each position accepts a different kind of replacement (see
  the individual visit_* methods).
  """

  def __init__(self, replacements):
    """Create a new ReplaceTransformer.

    Args:
      replacements: A mapping from placeholder names to (lists of) AST nodes
          that these placeholders will be replaced by.
    """
    self.replacements = replacements
    # Not read by this class; kept for compatibility with external code that
    # may inspect it.
    self.in_replacements = False
    # Passed to ast_util.copy_clean when copying a replacement; presumably
    # these are the only annotations kept on the copy (verify in ast_util).
    self.preserved_annos = {
        anno.Basic.DIRECTIVES,
        anno.Basic.EXTRA_LOOP_TEST,
        anno.Basic.ORIGIN,
        anno.Basic.SKIP_PROCESSING,
        anno.Static.ORIG_DEFINITIONS,
        'function_context_name',
    }

  def _prepare_replacement(self, replaced, key):
    """Prepares a replacement AST that's safe to swap in for a node.

    Args:
      replaced: ast.AST, the node being replaced. Currently unused.
      key: Hashable, the key of the replacement AST

    Returns:
      The result of ast_util.copy_clean applied to the replacement. A single
      AST node is wrapped in a one-element list; a list/tuple replacement is
      returned as copy_clean produced it.
    """
    repl = self.replacements[key]

    new_nodes = ast_util.copy_clean(repl, preserve_annos=self.preserved_annos)
    if isinstance(new_nodes, gast.AST):
      new_nodes = [new_nodes]

    return new_nodes

  def visit_Expr(self, node):
    # When replacing a placeholder with an entire statement, the replacement
    # must stand on its own and not be wrapped in an Expr.
    new_value = self.visit(node.value)
    if new_value is node.value:
      return node
    return new_value

  def visit_keyword(self, node):
    """Replaces a keyword argument whose name is a placeholder.

    Raises:
      ValueError: if the replacement is not a keyword or a non-empty list of
        keywords.
    """
    if node.arg not in self.replacements:
      return self.generic_visit(node)

    repl = self._prepare_replacement(node, node.arg)
    if isinstance(repl, gast.keyword):
      return repl
    elif (repl and isinstance(repl, (list, tuple)) and
          all(isinstance(r, gast.keyword) for r in repl)):
      return repl
    # TODO(mdan): We may allow replacing with a string as well.
    # For example, if one wanted to replace foo with bar in foo=baz, then
    # we could allow changing just node arg, so that we end up with bar=baz.
    raise ValueError(
        'a keyword argument may only be replaced by another keyword or a '
        'non-empty list of keywords. Found: {} for keyword {}'.format(
            repl, node.arg))

  def visit_FunctionDef(self, node):
    """Renames a function whose name is a placeholder.

    Raises:
      ValueError: if the replacement is not a Name node.
    """
    node = self.generic_visit(node)
    if node.name not in self.replacements:
      return node

    repl = self.replacements[node.name]
    if not isinstance(repl, (gast.Name, ast.Name)):
      raise ValueError(
          'a function name can only be replaced by a Name node. Found: %s' %
          repl)
    node.name = repl.id
    return node

  def visit_Attribute(self, node):
    """Renames an attribute whose name is a placeholder.

    Raises:
      ValueError: if the replacement is not a Name node.
    """
    node = self.generic_visit(node)
    if node.attr not in self.replacements:
      return node

    repl = self.replacements[node.attr]
    # Accept both gast and ast Name nodes, consistent with visit_FunctionDef;
    # only the `id` field is used.
    if not isinstance(repl, (gast.Name, ast.Name)):
      raise ValueError(
          'An attribute can only be replaced by a Name node. Found: %s' % repl)
    node.attr = repl.id
    return node

  def visit_Name(self, node):
    """Replaces a placeholder name with copies of its replacement.

    Returns:
      The node unchanged if it is not a placeholder; otherwise the single
      replacement node, or the list of replacement nodes (possibly empty).
    """
    if node.id not in self.replacements:
      return node

    new_nodes = self._prepare_replacement(node, node.id)

    if not new_nodes:
      return new_nodes

    # Preserve the target context: replacements that have a ctx field take
    # the placeholder's context (e.g. Store when it was an assignment target).
    adjuster = ContextAdjuster(type(node.ctx))
    for n in new_nodes:
      if hasattr(n, 'ctx'):
        adjuster.visit(n)

    if len(new_nodes) == 1:
      new_nodes, = new_nodes

    return new_nodes


def _convert_to_ast(n):
  """Converts from a known data type to AST.

  Strings become Name nodes, qual_names.QN values are converted with their
  `ast()` method, and lists/tuples are converted element-wise (preserving the
  container type). Anything else is returned unchanged.
  """
  # Note: When generating AST nodes from strings/QNs in isolation, ctx is
  # unknown. ctx must be filled in according to the template being used.
  # See ReplaceTransformer.visit_Name.
  if isinstance(n, str):
    return gast.Name(id=n, ctx=None, annotation=None, type_comment=None)
  if isinstance(n, qual_names.QN):
    return n.ast()
  if isinstance(n, list):
    return [_convert_to_ast(e) for e in n]
  if isinstance(n, tuple):
    return tuple(_convert_to_ast(e) for e in n)
  return n


def replace(template, **replacements):
  """Replaces placeholders in a Python template.

  When a placeholder name is replaced, each replacement node that has a ctx
  field (for example Name, Tuple, List, Attribute or Subscript) receives the
  context the placeholder had in the template. Nodes nested inside calls,
  dicts, comprehensions or lambdas are not adjusted, so when replacing with
  such more complex nodes the caller is responsible for setting the
  appropriate context.

  Args:
    template: A string representing Python code. Any symbol name that appears
      in the template code can be used as a placeholder.
    **replacements: A mapping from placeholder names to (lists of) AST nodes
      that these placeholders will be replaced by. String values are also
      supported as a shorthand for AST Name nodes with the respective ID, and
      qual_names.QN values are converted with their `ast()` method.

  Returns:
    A list of AST nodes: the nodes produced from each top-level node parsed
    from the template. A statement consisting only of a placeholder is
    replaced by the replacement node(s) themselves, not wrapped in an Expr,
    and list replacements are spliced into the result.

  Raises:
    ValueError: if the template is not a string, or if a replacement is not
      valid for the position of its placeholder.
  """
  if not isinstance(template, str):
    raise ValueError('Expected string template, got %s' % type(template))
  replacements = {k: _convert_to_ast(v) for k, v in replacements.items()}
  template_str = parser.STANDARD_PREAMBLE + textwrap.dedent(template)
  nodes = parser.parse(
      template_str,
      preamble_len=parser.STANDARD_PREAMBLE_LEN,
      single_node=False)
  # The transformer keeps no per-visit state, so one instance serves every
  # top-level node.
  transformer = ReplaceTransformer(replacements)
  results = []
  for node in nodes:
    node = transformer.visit(node)
    if isinstance(node, (list, tuple)):
      results.extend(node)
    else:
      results.append(node)
  return [qual_names.resolve(r) for r in results]


def replace_as_expression(template, **replacements):
  """Variant of replace that generates expressions, instead of code blocks.

  Args:
    template: A string representing a single Python expression.
    **replacements: Same as for `replace`.

  Returns:
    The expression node: the value of the single resulting Expr statement, or
    the Name node when the template was a bare placeholder replaced by a Name.

  Raises:
    ValueError: if the template does not produce exactly one node, or that
      node is neither an Expr nor a Name.
  """
  replacement = replace(template, **replacements)
  if len(replacement) != 1:
    raise ValueError(
        'single expression expected; for more general templates use replace')
  node, = replacement

  if isinstance(node, gast.Expr):
    return node.value
  elif isinstance(node, gast.Name):
    return node

  # NOTE(review): a bare placeholder replaced by a non-Name expression (e.g. an
  # Attribute built from a dotted QN) reaches here unwrapped and is rejected.
  raise ValueError(
      'the template is expected to generate an expression or a name node;'
      ' instead found %s' % node)