xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/converters/functions.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converts function definitions and lambdas by adding necessary boilerplate."""
16
17import gast
18
19from tensorflow.python.autograph.core import converter
20from tensorflow.python.autograph.pyct import anno
21from tensorflow.python.autograph.pyct import parser
22from tensorflow.python.autograph.pyct import qual_names
23from tensorflow.python.autograph.pyct import templates
24from tensorflow.python.autograph.pyct.static_analysis import activity
25from tensorflow.python.autograph.pyct.static_analysis import annos
26
27
class _Function:
  """Mutable per-function state used by FunctionTransformer.

  One instance is entered for every function or lambda the transformer visits
  (through `self.state[_Function]`).
  """

  def __init__(self):
    # Symbol name of the function-scope context object. It stays None until
    # the visitor generates a fresh name for the function being converted.
    self.context_name = None
32
33
class FunctionTransformer(converter.Base):
  """Wraps function bodies around autograph-specific boilerplate.

  Function bodies are wrapped in a `with ag__.FunctionScope(...)` block.
  Lambda bodies are wrapped in an `ag__.with_function_scope(...)` call.
  Functions and lambdas nested below the top level are additionally marked
  with `ag__.autograph_artifact`.
  """

  def _function_scope_options(self, fn_scope):
    """Returns the options with which to create function scopes.

    Args:
      fn_scope: the `_Function` state entered by the caller through
        `self.state[_Function]`; only its `level` is read here.

    Returns:
      `self.ctx.user.options` when `fn_scope.level == 2` (the level the
      visitors below treat as top-level), otherwise
      `self.ctx.user.options.call_options()`.
    """
    # Top-level function receive the options that were directly requested.
    # All others receive the options corresponding to a recursive conversion.
    # Note: this mainly controls the user_requested flag, which is important
    # primarily because the FunctionScope context also creates a
    # ControlStatusCtx(autograph=ENABLED) when user_requested is True. See
    # function_wrappers.py.
    if fn_scope.level == 2:
      return self.ctx.user.options
    return self.ctx.user.options.call_options()

  @staticmethod
  def _split_docstring(body):
    """Separates a leading docstring statement from a function body.

    Only a string literal counts as a docstring, matching Python's own rule
    (see `ast.get_docstring`). Other constant expressions, such as `...` or a
    number, are ordinary statements and stay in the body.

    Args:
      body: list of statement nodes of a function definition.

    Returns:
      A tuple `(docstring_node, remaining_body)`. `docstring_node` is None
      and `remaining_body` is `body` itself when there is no docstring.
    """
    if body:
      first_statement = body[0]
      if (isinstance(first_statement, gast.Expr) and
          isinstance(first_statement.value, gast.Constant) and
          isinstance(first_statement.value.value, str)):
        return first_statement, body[1:]
    return None, body

  def visit_Lambda(self, node):
    """Wraps a lambda's body in an `ag__.with_function_scope` call.

    Lambdas nested deeper than the top level are not given a scope. Instead,
    the whole (already visited) lambda is wrapped in
    `ag__.autograph_artifact(...)`.

    Args:
      node: the lambda node being visited.

    Returns:
      The lambda node with its body replaced. For nested lambdas, the
      `ag__.autograph_artifact(...)` expression wrapping it is returned.
    """
    with self.state[_Function] as fn_scope:
      node = self.generic_visit(node)

      # TODO(mdan): Fix the tests so that we can always add this decorator.
      if fn_scope.level > 2:
        return templates.replace_as_expression(
            'ag__.autograph_artifact(l)', l=node)

      scope = anno.getanno(node, anno.Static.SCOPE)
      function_context_name = self.ctx.namer.new_symbol('lscope',
                                                        scope.referenced)
      fn_scope.context_name = function_context_name
      anno.setanno(node, 'function_context_name', function_context_name)

      # Only the body expression is replaced; the lambda's own parameters are
      # left as they are. The original body moves into an inner lambda whose
      # single parameter is named after the generated context symbol.
      # Presumably with_function_scope calls it with the scope object —
      # confirm in the ag__ runtime.
      template = """
        ag__.with_function_scope(
            lambda function_context: body, function_context_name, options)
      """
      node.body = templates.replace_as_expression(
          template,
          options=self._function_scope_options(fn_scope).to_ast(),
          function_context=function_context_name,
          function_context_name=gast.Constant(function_context_name, kind=None),
          body=node.body)

      return node

  def visit_FunctionDef(self, node):
    """Wraps a function's body in a `with ag__.FunctionScope(...)` block.

    Top-level functions (level <= 2) have their decorator list cleared. Nested
    functions get an `ag__.autograph_artifact` decorator appended. A leading
    docstring stays as the first statement, outside the `with` block.

    Args:
      node: the function definition node being visited.

    Returns:
      The transformed function definition node.
    """
    with self.state[_Function] as fn_scope:
      scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

      function_context_name = self.ctx.namer.new_symbol('fscope',
                                                        scope.referenced)
      fn_scope.context_name = function_context_name
      anno.setanno(node, 'function_context_name', function_context_name)

      node = self.generic_visit(node)

      if fn_scope.level <= 2:
        # Top-level functions lose their decorator because the conversion is
        # always just-in-time and by the time it happens the decorators are
        # already set to be applied.
        node.decorator_list = []
      else:
        # TODO(mdan): Fix the tests so that we can always add this decorator.
        # Inner functions are converted already, so we insert a decorator to
        # prevent double conversion. Double conversion would work too, but this
        # saves the overhead.
        node.decorator_list.append(
            parser.parse_expression('ag__.autograph_artifact'))

      # The docstring is kept outside the scope wrapper so it remains the
      # first statement of the function, which keeps it the docstring.
      docstring_node, node.body = self._split_docstring(node.body)

      template = """
        with ag__.FunctionScope(
            function_name, context_name, options) as function_context:
          body
      """
      wrapped_body = templates.replace(
          template,
          function_name=gast.Constant(node.name, kind=None),
          context_name=gast.Constant(function_context_name, kind=None),
          options=self._function_scope_options(fn_scope).to_ast(),
          function_context=function_context_name,
          body=node.body)

      if docstring_node is not None:
        wrapped_body = [docstring_node] + wrapped_body

      node.body = wrapped_body

      return node
128
129
def transform(node, ctx):
  """Runs FunctionTransformer on `node` after the analyses it relies on.

  Qualified names are resolved first, then activity analysis is run. The
  annotated tree is then handed to a fresh FunctionTransformer.

  Args:
    node: the AST to convert.
    ctx: the converter context, forwarded to activity analysis and to the
      transformer.

  Returns:
    The converted AST, as returned by `FunctionTransformer(ctx).visit`.
  """
  annotated = activity.resolve(qual_names.resolve(node), ctx, None)
  transformer = FunctionTransformer(ctx)
  return transformer.visit(annotated)
135