# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts function definitions and lambdas by adding necessary boilerplate."""

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos


class _Function(object):
  """Transformer state tracked for each Lambda / FunctionDef being visited."""

  def __init__(self):
    # Name of the generated symbol that refers to the function scope object
    # (the `lscope*` / `fscope*` name created by the visitors below).
    self.context_name = None


def _split_docstring(body):
  """Separates a leading docstring statement from a function body.

  Only a string constant counts as a docstring, matching Python's own rule.
  Other constant expression statements (e.g. a bare `...` in a stub) are
  ordinary statements and stay in the body.

  Args:
    body: list of statement nodes forming a function body.

  Returns:
    A tuple `(docstring_node, remaining_body)`. `docstring_node` is None and
    `remaining_body` is `body` unchanged when there is no leading docstring.
  """
  if not body:
    return None, body
  first_statement = body[0]
  if (isinstance(first_statement, gast.Expr) and
      isinstance(first_statement.value, gast.Constant) and
      isinstance(first_statement.value.value, str)):
    return first_statement, body[1:]
  return None, body


class FunctionTransformer(converter.Base):
  """Wraps function bodies around autograph-specific boilerplate."""

  def _function_scope_options(self, fn_scope):
    """Returns the options with which to create function scopes."""
    # Top-level functions receive the options that were directly requested.
    # All others receive the options corresponding to a recursive conversion.
    # Note: this mainly controls the user_requested flag, which is important
    # primarily because the FunctionScope context also creates a
    # ControlStatusCtx(autograph=ENABLED) when user_requested is True. See
    # function_wrappers.py.
    if fn_scope.level == 2:
      return self.ctx.user.options
    return self.ctx.user.options.call_options()

  def visit_Lambda(self, node):
    """Wraps a lambda body in `ag__.with_function_scope`.

    Nested lambdas (level > 2) are instead only marked with
    `ag__.autograph_artifact`.
    """
    with self.state[_Function] as fn_scope:
      node = self.generic_visit(node)

      # TODO(mdan): Fix the tests so that we can always add this decorator.
      if fn_scope.level > 2:
        return templates.replace_as_expression(
            'ag__.autograph_artifact(l)', l=node)

      scope = anno.getanno(node, anno.Static.SCOPE)
      # Pick a name that does not clash with any symbol the lambda references.
      function_context_name = self.ctx.namer.new_symbol('lscope',
                                                        scope.referenced)
      fn_scope.context_name = function_context_name
      anno.setanno(node, 'function_context_name', function_context_name)

      template = """
        ag__.with_function_scope(
            lambda function_context: body, function_context_name, options)
      """
      node.body = templates.replace_as_expression(
          template,
          options=self._function_scope_options(fn_scope).to_ast(),
          function_context=function_context_name,
          function_context_name=gast.Constant(function_context_name, kind=None),
          body=node.body)

      return node

  def visit_FunctionDef(self, node):
    """Wraps a function body in `with ag__.FunctionScope(...)`.

    Top-level functions have their decorators removed. Inner functions get an
    extra `ag__.autograph_artifact` decorator. A leading docstring is kept
    outside the `with` block.
    """
    with self.state[_Function] as fn_scope:
      scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

      # Pick a name that does not clash with any symbol the body references.
      function_context_name = self.ctx.namer.new_symbol('fscope',
                                                        scope.referenced)
      fn_scope.context_name = function_context_name
      anno.setanno(node, 'function_context_name', function_context_name)

      node = self.generic_visit(node)

      if fn_scope.level <= 2:
        # Top-level functions lose their decorator because the conversion is
        # always just-in-time and by the time it happens the decorators are
        # already set to be applied.
        node.decorator_list = []
      else:
        # TODO(mdan): Fix the tests so that we can always add this decorator.
        # Inner functions are converted already, so we insert a decorator to
        # prevent double conversion. Double conversion would work too, but this
        # saves the overhead.
        node.decorator_list.append(
            parser.parse_expression('ag__.autograph_artifact'))

      # Keep the docstring as the function's first statement so it remains
      # the function's docstring after wrapping.
      docstring_node, body = _split_docstring(node.body)
      if not body:
        # A `with` statement needs a non-empty suite. The body is empty when
        # the function consists only of a docstring.
        body = [gast.Pass()]

      template = """
        with ag__.FunctionScope(
            function_name, context_name, options) as function_context:
          body
      """
      wrapped_body = templates.replace(
          template,
          function_name=gast.Constant(node.name, kind=None),
          context_name=gast.Constant(function_context_name, kind=None),
          options=self._function_scope_options(fn_scope).to_ast(),
          function_context=function_context_name,
          body=body)

      if docstring_node is not None:
        wrapped_body = [docstring_node] + wrapped_body

      node.body = wrapped_body

      return node


def transform(node, ctx):
  """Applies FunctionTransformer to `node`.

  Qualified-name and activity analysis are re-run first. The transformer reads
  the scope annotations (`anno.Static.SCOPE`, `annos.NodeAnno.BODY_SCOPE`),
  presumably set by `activity.resolve`.
  """
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)

  return FunctionTransformer(ctx).visit(node)