xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/converters/directives.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles directives.
16
17This converter removes the directive functions from the code and moves the
18information they specify into AST annotations. It is a specialized form of
19static analysis, one that is specific to AutoGraph.
20
21Note that this requires that the actual directive functions are static - that
22is, they do not change at runtime. So if you do something like this:
23
24  tf.autograph.set_loop_options = <new function>
25
Then the directive may no longer be recognized. Furthermore, if the
converted function is cached, such an action may be irreversible.
28"""
29
30import inspect
31
32import gast
33
34from tensorflow.python.autograph.core import converter
35from tensorflow.python.autograph.lang import directives
36from tensorflow.python.autograph.pyct import anno
37from tensorflow.python.util import tf_inspect
38
39
# Annotation key under which DirectivesTransformer stores the Python value a
# Name or module Attribute node statically refers to (see visit_Name and
# visit_Attribute).
STATIC_VALUE = 'static_value'
42
43
44class _LoopScope(object):
45
46  def __init__(self):
47    self.ast_node = None
48    self.statements_visited = 0
49
50
def _map_args(call_node, function):
  """Maps AST call nodes to the actual function's arguments.

  Args:
    call_node: ast.Call
    function: Callable[..., Any], the actual function matching call_node
  Returns:
    Dict[Text, ast.AST], mapping each of the function's argument names to
    the respective AST node. Arguments left at the directives.UNSPECIFIED
    default are omitted.
  Raises:
      ValueError: if the default arguments are not correctly set
  """
  args = call_node.args
  kwds = {kwd.arg: kwd.value for kwd in call_node.keywords}
  call_args = tf_inspect.getcallargs(function, *args, **kwds)

  # Keyword arguments not specified in kwds will be mapped to their defaults,
  # which are Python values. Since we don't currently have a way to transform
  # those into AST references, we simply remove them. By convention, directives
  # use UNSPECIFIED as default value for optional arguments. No other
  # defaults should be present.
  # Positional arguments are recognized by identity, so a default value's own
  # __eq__ is never consulted when comparing it against the AST nodes.
  unexpected_defaults = [
      (k, v) for k, v in call_args.items()
      if (k not in kwds
          and not any(v is arg for arg in args)
          and v is not directives.UNSPECIFIED)]
  if unexpected_defaults:
    # Format a concrete list of (name, value) pairs: a bare zip() is a lazy
    # iterator in Python 3 and would render as "<zip object at ...>".
    raise ValueError('Unexpected keyword argument values, %s, for function %s'
                     % (unexpected_defaults, function))
  return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED}
84
85
86class DirectivesTransformer(converter.Base):
87  """Parses compiler directives and converts them into AST annotations."""
88
89  def _process_symbol_directive(self, call_node, directive):
90    if len(call_node.args) < 1:
91      raise ValueError('"%s" requires a positional first argument'
92                       ' as the target' % directive.__name__)
93    target = call_node.args[0]
94    defs = anno.getanno(target, anno.Static.ORIG_DEFINITIONS)
95    for def_ in defs:
96      def_.directives[directive] = _map_args(call_node, directive)
97    return call_node
98
99  def _process_statement_directive(self, call_node, directive):
100    if self.state[_LoopScope].statements_visited > 1:
101      raise ValueError(
102          '"%s" must be the first statement in the loop block' % (
103              directive.__name__))
104    if self.state[_LoopScope].level < 2:
105      raise ValueError(
106          '"%s" must be used inside a statement' % directive.__name__)
107    target = self.state[_LoopScope].ast_node
108    node_anno = anno.getanno(target, anno.Basic.DIRECTIVES, {})
109    node_anno[directive] = _map_args(call_node, directive)
110    anno.setanno(target, anno.Basic.DIRECTIVES, node_anno)
111    return call_node
112
113  def visit_Name(self, node):
114    node = self.generic_visit(node)
115    if isinstance(node.ctx, gast.Load):
116      defs = anno.getanno(node, anno.Static.DEFINITIONS, ())
117      is_defined = bool(defs)
118      if not is_defined and node.id in self.ctx.info.namespace:
119        anno.setanno(node, STATIC_VALUE, self.ctx.info.namespace[node.id])
120    return node
121
122  def visit_Attribute(self, node):
123    node = self.generic_visit(node)
124    parent_val = anno.getanno(node.value, STATIC_VALUE, default=None)
125    if parent_val is not None and inspect.ismodule(parent_val):
126      if hasattr(parent_val, node.attr):
127        anno.setanno(node, STATIC_VALUE, getattr(parent_val, node.attr))
128    return node
129
130  def visit_Assign(self, node):
131    self.state[_LoopScope].statements_visited += 1
132    return self.generic_visit(node)
133
134  def visit_AugAssign(self, node):
135    self.state[_LoopScope].statements_visited += 1
136    return self.generic_visit(node)
137
138  def visit_Expr(self, node):
139    self.state[_LoopScope].statements_visited += 1
140    node = self.generic_visit(node)
141    if isinstance(node.value, gast.Call):
142      call_node = node.value
143      static_val = anno.getanno(call_node.func, STATIC_VALUE, default=None)
144      if static_val is not None:
145        # Note: directive calls are not output in the generated code, hence
146        # the removal from the code by returning None.
147
148        if static_val is directives.set_element_type:
149          self._process_symbol_directive(call_node, static_val)
150          return None
151        elif static_val is directives.set_loop_options:
152          self._process_statement_directive(call_node, static_val)
153          return None
154    return node
155
156  # TODO(mdan): This will be insufficient for other control flow.
157  # That means that if we ever have a directive that affects things other than
158  # loops, we'll need support for parallel scopes, or have multiple converters.
159  def _track_and_visit_loop(self, node):
160    self.state[_LoopScope].enter()
161    self.state[_LoopScope].ast_node = node
162    node = self.generic_visit(node)
163    # Edge case: a loop with just one directive statement would become empty.
164    if not node.body:
165      node.body = [gast.Pass()]
166    self.state[_LoopScope].exit()
167    return node
168
169  def visit_While(self, node):
170    return self._track_and_visit_loop(node)
171
172  def visit_For(self, node):
173    return self._track_and_visit_loop(node)
174
175
176def transform(node, ctx):
177  return DirectivesTransformer(ctx).visit(node)
178