# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles directives.

This converter removes the directive functions from the code and moves the
information they specify into AST annotations. It is a specialized form of
static analysis, one that is specific to AutoGraph.

Note that this requires that the actual directive functions are static - that
is, they do not change at runtime. So if you do something like this:

  tf.autograph.set_loop_options = <new function>

Then the directive may no longer be recognized. Furthermore, if the
converted function is cached, such an action may be irreversible.
"""

import inspect

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.util import tf_inspect


STATIC_VALUE = 'static_value'
"""Used for AST annotations, see visit_Name."""


class _LoopScope(object):
  """Per-loop state used while visiting loop bodies."""

  def __init__(self):
    # The loop node (gast.While / gast.For) being visited; set by
    # DirectivesTransformer._track_and_visit_loop.
    self.ast_node = None
    # Number of Assign, AugAssign and Expr statements visited so far in this
    # scope; used to require that statement directives come first.
    self.statements_visited = 0


def _map_args(call_node, function):
  """Maps AST call nodes to the actual function's arguments.

  Args:
    call_node: ast.Call
    function: Callable[..., Any], the actual function matching call_node
  Returns:
    Dict[Text, ast.AST], mapping each of the function's argument names to
    the respective AST node.
  Raises:
    ValueError: if the default arguments are not correctly set
  """
  args = call_node.args
  kwds = {kwd.arg: kwd.value for kwd in call_node.keywords}
  call_args = tf_inspect.getcallargs(function, *args, **kwds)

  # Keyword arguments not specified in kwds will be mapped to their defaults,
  # which are Python values. Since we don't currently have a way to transform
  # those into AST references, we simply remove them. By convention, directives
  # use UNSPECIFIED as default value for optional arguments. No other
  # defaults should be present.
  unexpected_defaults = [
      k for k, v in call_args.items()
      if (k not in kwds
          and v not in args
          and v is not directives.UNSPECIFIED)]
  if unexpected_defaults:
    # zip() is a lazy iterator in Python 3; materialize it so the message lists
    # the offending (name, value) pairs rather than "<zip object at ...>".
    raise ValueError('Unexpected keyword argument values, %s, for function %s'
                     % (list(zip(unexpected_defaults,
                                 [call_args[k] for k in unexpected_defaults])),
                        function))
  return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED}


class DirectivesTransformer(converter.Base):
  """Parses compiler directives and converts them into AST annotations."""

  def _process_symbol_directive(self, call_node, directive):
    """Records a symbol directive on the definitions of its first argument.

    Args:
      call_node: ast.Call, the directive call (e.g. set_element_type(x, ...)).
      directive: the directive function that call_node invokes.
    Returns:
      call_node, unchanged.
    Raises:
      ValueError: if the call has no positional target argument, or if
        _map_args rejects the call's arguments.
    """
    if len(call_node.args) < 1:
      raise ValueError('"%s" requires a positional first argument'
                       ' as the target' % directive.__name__)
    target = call_node.args[0]
    # Every definition recorded in the target's ORIG_DEFINITIONS annotation
    # receives its own mapping of the directive's arguments.
    defs = anno.getanno(target, anno.Static.ORIG_DEFINITIONS)
    for def_ in defs:
      def_.directives[directive] = _map_args(call_node, directive)
    return call_node

  def _process_statement_directive(self, call_node, directive):
    """Records a statement directive on the innermost enclosing loop node.

    The mapped arguments are stored in the loop node's anno.Basic.DIRECTIVES
    annotation, keyed by the directive function.

    Args:
      call_node: ast.Call, the directive call (e.g. set_loop_options(...)).
      directive: the directive function that call_node invokes.
    Returns:
      call_node, unchanged.
    Raises:
      ValueError: if the directive is not the first counted statement of the
        loop body, is not inside a loop, or has invalid arguments.
    """
    # visit_Expr increments the counter before calling this method, so a
    # directive that is the first counted statement sees a count of 1.
    if self.state[_LoopScope].statements_visited > 1:
      raise ValueError(
          '"%s" must be the first statement in the loop block' % (
              directive.__name__))
    # NOTE(review): assumes the _LoopScope state stack starts with a root
    # entry, so a level below 2 means no loop has been entered -- confirm.
    if self.state[_LoopScope].level < 2:
      raise ValueError(
          '"%s" must be used inside a statement' % directive.__name__)
    target = self.state[_LoopScope].ast_node
    node_anno = anno.getanno(target, anno.Basic.DIRECTIVES, {})
    node_anno[directive] = _map_args(call_node, directive)
    anno.setanno(target, anno.Basic.DIRECTIVES, node_anno)
    return call_node

  def visit_Name(self, node):
    """Annotates loaded, locally-undefined names with their namespace value.

    A Load-context name with no Static.DEFINITIONS that appears in the
    converted function's namespace gets that value stored under STATIC_VALUE.
    """
    node = self.generic_visit(node)
    if isinstance(node.ctx, gast.Load):
      defs = anno.getanno(node, anno.Static.DEFINITIONS, ())
      is_defined = bool(defs)
      if not is_defined and node.id in self.ctx.info.namespace:
        anno.setanno(node, STATIC_VALUE, self.ctx.info.namespace[node.id])
    return node

  def visit_Attribute(self, node):
    """Propagates STATIC_VALUE through attribute access on module objects."""
    node = self.generic_visit(node)
    parent_val = anno.getanno(node.value, STATIC_VALUE, default=None)
    if parent_val is not None and inspect.ismodule(parent_val):
      if hasattr(parent_val, node.attr):
        anno.setanno(node, STATIC_VALUE, getattr(parent_val, node.attr))
    return node

  def visit_Assign(self, node):
    """Counts the assignment toward the current loop scope."""
    self.state[_LoopScope].statements_visited += 1
    return self.generic_visit(node)

  def visit_AugAssign(self, node):
    """Counts the augmented assignment toward the current loop scope."""
    self.state[_LoopScope].statements_visited += 1
    return self.generic_visit(node)

  def visit_Expr(self, node):
    """Counts the statement and strips recognized directive calls.

    Returns None (removing the statement) when the expression is a call to
    directives.set_element_type or directives.set_loop_options; otherwise
    returns the visited node.
    """
    self.state[_LoopScope].statements_visited += 1
    node = self.generic_visit(node)
    if isinstance(node.value, gast.Call):
      call_node = node.value
      static_val = anno.getanno(call_node.func, STATIC_VALUE, default=None)
      if static_val is not None:
        # Note: directive calls are not output in the generated code, hence
        # the removal from the code by returning None.

        if static_val is directives.set_element_type:
          self._process_symbol_directive(call_node, static_val)
          return None
        elif static_val is directives.set_loop_options:
          self._process_statement_directive(call_node, static_val)
          return None
    return node

  # TODO(mdan): This will be insufficient for other control flow.
  # That means that if we ever have a directive that affects things other than
  # loops, we'll need support for parallel scopes, or have multiple converters.
  def _track_and_visit_loop(self, node):
    """Visits a loop inside a fresh _LoopScope that points at the loop node."""
    self.state[_LoopScope].enter()
    self.state[_LoopScope].ast_node = node
    node = self.generic_visit(node)
    # Edge case: a loop with just one directive statement would become empty.
    if not node.body:
      node.body = [gast.Pass()]
    self.state[_LoopScope].exit()
    return node

  def visit_While(self, node):
    return self._track_and_visit_loop(node)

  def visit_For(self, node):
    return self._track_and_visit_loop(node)


def transform(node, ctx):
  """Runs DirectivesTransformer over node with converter context ctx."""
  return DirectivesTransformer(ctx).visit(node)