xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/core/converter.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converter construction support.
16
17This module contains a base class for all converters, as well as supporting
18structures. These structures are referred to as contexts.
19
20The class hierarchy is as follows:
21
22    <your converter>
23      [extends] converter.Base
24        [extends] transformer.Base
            [extends] gast.NodeTransformer
26          [uses] transformer.SourceInfo
27        [uses] converter.EntityContext
28          [uses] converter.ProgramContext
29          [uses] transformer.SourceInfo
30
31converter.Base is a specialization of transformer.Base for AutoGraph. It's a
32very lightweight subclass that adds a `ctx` attribute holding the corresponding
33EntityContext object (see below). Note that converters are not reusable, and
34`visit` will raise an error if called more than once.
35
36converter.EntityContext contains mutable state associated with an entity that
37the converter processes.
38
39converter.ProgramContext contains mutable state across related entities. For
40example, when converting several functions that call one another, the
41ProgramContext should be shared across these entities.
42
43Below is the overall flow at conversion:
44
45    program_ctx = ProgramContext(<entities to convert>, <global settings>, ...)
46    while <program_ctx has more entities to convert>:
47      entity, source_info = <get next entity from program_ctx>
48      entity_ctx = EntityContext(program_ctx, source_info)
49      for <each ConverterClass>:
50        converter = ConverterClass(entity_ctx)
51
52        # May update entity_ctx and program_ctx
53        entity = converter.visit(entity)
54
55      <add entity's dependencies to program_ctx>
56
Note that pyct contains a small number of transformers used for static analysis.
These extend transformer.Base rather than converter.Base, to avoid a
dependency on AutoGraph.
60"""
61
62import enum
63
64from tensorflow.python.autograph.pyct import anno
65from tensorflow.python.autograph.pyct import ast_util
66from tensorflow.python.autograph.pyct import parser
67from tensorflow.python.autograph.pyct import templates
68from tensorflow.python.autograph.pyct import transformer
69from tensorflow.python.util.tf_export import tf_export
70
71# TODO(mdan): These contexts can be refactored into first class objects.
72# For example, we could define Program and Entity abstractions that hold on
73# to the actual entity and have conversion methods.
74
75# TODO(mdan): Add a test specific to this converter.
76
77
@tf_export('autograph.experimental.Feature')
class Feature(enum.Enum):
  """This enumeration represents optional conversion options.

  These conversion options are experimental. They are subject to change without
  notice and offer no guarantees.

  _Example Usage_

  ```python
  optionals = tf.autograph.experimental.Feature.EQUALITY_OPERATORS
  @tf.function(experimental_autograph_options=optionals)
  def f(i):
    if i == 0:  # EQUALITY_OPERATORS allows the use of == here.
      tf.print('i is zero')
  ```

  Attributes:
    ALL: Enable all features.
    AUTO_CONTROL_DEPS: Insert control dependencies in the generated code.
    ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert.
    BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to
      their TF counterparts.
    EQUALITY_OPERATORS: Whether to convert the equality operator ('==') to
      tf.math.equal.
    LISTS: Convert list idioms, like initializers, slices, append, etc.
    NAME_SCOPES: Insert name scopes that name ops according to context, like the
      function they were defined in.
  """

  ALL = 'ALL'

  AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS'
  ASSERT_STATEMENTS = 'ASSERT_STATEMENTS'
  BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS'
  EQUALITY_OPERATORS = 'EQUALITY_OPERATORS'
  LISTS = 'LISTS'
  NAME_SCOPES = 'NAME_SCOPES'

  @classmethod
  def all(cls):
    """Returns a tuple that enables all options.

    The tuple holds every member, `ALL` included, in definition order.
    """
    return tuple(cls.__members__.values())

  @classmethod
  def all_but(cls, exclude):
    """Returns a tuple that enables all but the excluded options.

    Args:
      exclude: a single Feature, or a list, tuple, set or frozenset of
        Features to leave out.

    Returns:
      Tuple of the remaining Features in definition order. `ALL` is never
      included, because ConversionOptions.uses treats it as enabling
      everything, which would undo the exclusion.
    """
    # frozenset must be listed too: otherwise a frozenset argument would be
    # wrapped as a single opaque element and nothing would be excluded.
    if not isinstance(exclude, (list, tuple, set, frozenset)):
      exclude = (exclude,)
    excluded = set(exclude)
    excluded.add(cls.ALL)
    # Iterating the enum class yields canonical members in definition order,
    # giving a deterministic result (unlike iterating a set).
    return tuple(member for member in cls if member not in excluded)
128
129
# Placeholder binding so the name exists while ConversionOptions is being
# defined; ConversionOptions.to_ast reads it at call time. It is rebound to
# the real standard options right after the ConversionOptions class.
STANDARD_OPTIONS = None  # Forward definition.
131
132
class ConversionOptions(object):
  """Container for global conversion flags.

  Instances define value-based `__eq__` and `__hash__`, so they should be
  treated as immutable once constructed.

  Attributes:
    recursive: bool, whether to recursively convert any user functions or
      classes that the converted function may use.
    user_requested: bool, whether the conversion was explicitly requested by
      the user, as opposed to being performed as a result of other logic. This
      value always auto-resets to False in child conversions.
    internal_convert_user_code: bool, internal flag; presumably controls
      whether user code is converted at this level (TODO confirm). Child
      conversions created by `call_options` set it to the parent's
      `recursive` value.
    optional_features: FrozenSet[Feature], controls the use of optional
      features in the conversion process. The constructor also accepts None
      (no features), a single Feature, or an iterable of Features. See Feature
      for available options.
  """

  def __init__(self,
               recursive=False,
               user_requested=False,
               internal_convert_user_code=True,
               optional_features=Feature.ALL):
    self.recursive = recursive
    self.user_requested = user_requested
    # TODO(mdan): Rename to conversion_recursion_depth?
    self.internal_convert_user_code = internal_convert_user_code

    # Normalize to a frozenset so that hashing and equality are well defined.
    if optional_features is None:
      optional_features = ()
    elif isinstance(optional_features, Feature):
      optional_features = (optional_features,)
    self.optional_features = frozenset(optional_features)

  def as_tuple(self):
    """Returns the fields that define equality and hashing."""
    return (self.recursive, self.user_requested,
            self.internal_convert_user_code, self.optional_features)

  def __hash__(self):
    return hash(self.as_tuple())

  def __eq__(self, other):
    # Return NotImplemented rather than asserting, so comparisons with
    # unrelated objects (e.g. `options == None`) evaluate to False instead of
    # raising AssertionError (or AttributeError under `python -O`).
    if not isinstance(other, ConversionOptions):
      return NotImplemented
    return self.as_tuple() == other.as_tuple()

  def __str__(self):
    # Sort the features so the text does not depend on frozenset order.
    features = ', '.join(sorted(str(f) for f in self.optional_features))
    template = ('ConversionOptions[recursive={}, user_requested={}, '
                'internal_convert_user_code={}, optional_features=({})]')
    return template.format(self.recursive, self.user_requested,
                           self.internal_convert_user_code, features)

  def uses(self, feature):
    """Returns True if `feature` is enabled, either directly or via ALL."""
    return (Feature.ALL in self.optional_features or
            feature in self.optional_features)

  def call_options(self):
    """Returns the corresponding options to be used for recursive conversion."""
    return ConversionOptions(
        recursive=self.recursive,
        user_requested=False,
        internal_convert_user_code=self.recursive,
        optional_features=self.optional_features)

  def to_ast(self):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents. Options equal to STANDARD_OPTIONS are encoded as the short
    reference `ag__.STD` instead.

    Returns:
      ast.Node
    """
    if self == STANDARD_OPTIONS:
      return parser.parse_expression('ag__.STD')

    template = """
      ag__.ConversionOptions(
          recursive=recursive_val,
          user_requested=user_requested_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def list_of_features(values):
      # Emit a tuple literal. Every element carries a trailing comma so that a
      # single feature still yields a tuple, `(ag__.Feature.X,)`, rather than
      # a parenthesized expression; an empty set yields `()`. Sorting makes
      # the generated code independent of frozenset iteration order.
      elements = ''.join(
          'ag__.{},'.format(str(v)) for v in sorted(values, key=str))
      return parser.parse_expression('({})'.format(elements))

    expr_ast = templates.replace(
        template,
        recursive_val=parser.parse_expression(str(self.recursive)),
        user_requested_val=parser.parse_expression(str(self.user_requested)),
        internal_convert_user_code_val=parser.parse_expression(
            str(self.internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))
    return expr_ast[0].value
222
223
# Default options for converted code. ConversionOptions.to_ast encodes any
# options equal to these as `ag__.STD`. Passing None for optional_features
# leaves the feature set empty.
STANDARD_OPTIONS = ConversionOptions(
    optional_features=None,
    internal_convert_user_code=True,
    user_requested=False,
    recursive=True)
229
230
class ProgramContext(object):
  """Holds the state shared while converting a hierarchy of functions.

  Attributes:
    options: ConversionOptions, the settings applied to the whole program.
    autograph_module: Deprecated. Do not use.
  """

  def __init__(self, options, autograph_module=None):
    self.options, self.autograph_module = options, autograph_module
242
243
class Base(transformer.Base):
  """All converters should inherit from this class.

  Converter objects are single-use: a second top-level call to `visit` raises
  ValueError.

  Attributes:
    ctx: EntityContext
  """

  def __init__(self, ctx):
    super(Base, self).__init__(ctx)

    # Set by the first top-level `visit`; used to reject reuse.
    self._used = False
    # Nesting depth of `visit` calls; 0 means no traversal is in progress.
    self._ast_depth = 0

  def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive argument for a symbol.

    See lang/directives.py for details on directives.

    Example:
       # Given a directive in the code:
       ag.foo_directive(bar, baz=1)

       # One can write for an AST node Name(id='bar'):
       get_definition_directive(node, ag.foo_directive, 'baz', None)

    Args:
      node: ast.AST, the node representing the symbol for which the directive
        argument is needed.
      directive: Callable[..., Any], the directive to search.
      arg: str, the directive argument to return.
      default: Any, returned when no definition reaching the symbol carries
        `directive` with argument `arg`.

    Returns:
      The directive argument value as stored in the definition's
      `directives` (an AST node), or `default`.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not defs:
      return default

    # Collect the argument from every reaching definition that specifies it.
    arg_values_found = [
        def_.directives[directive][arg]
        for def_ in defs
        if directive in def_.directives and arg in def_.directives[directive]
    ]
    if not arg_values_found:
      return default

    # If multiple annotations reach the symbol, they must all match. If they
    # do, return any of them. With a single value this loop does nothing.
    first_value = arg_values_found[0]
    for other_value in arg_values_found[1:]:
      if not ast_util.matches(first_value, other_value):
        qn = anno.getanno(node, anno.Basic.QN)
        raise ValueError(
            '%s has ambiguous annotations for %s(%s): %s, %s' %
            (qn, directive.__name__, arg, parser.unparse(other_value).strip(),
             parser.unparse(first_value).strip()))
    return first_value

  def visit(self, node):
    # Only the outermost call (depth 0) checks and marks the converter as
    # used; nested calls (depth > 0) skip the check.
    if not self._ast_depth:
      if self._used:
        raise ValueError('converter objects cannot be reused')
      self._used = True

    self._ast_depth += 1
    try:
      return super(Base, self).visit(node)
    finally:
      # Restore the depth even if the traversal raises.
      self._ast_depth -= 1
317