xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/pyct/anno.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""AST node annotation support.
16
17Adapted from Tangent.
18"""
19
20import enum
21
22# pylint:disable=g-bad-import-order
23
24import gast
25# pylint:enable=g-bad-import-order
26
27
28# TODO(mdan): Shorten the names.
# These names are heavily used, and spelling out `anno.Basic.*` / `anno.Static.*` at every call site is verbose.
30# TODO(mdan): Replace the attr-dict mechanism with a more typed solution.
31
32
class NoValue(enum.Enum):
  """Base enum for AST annotation keys.

  Members of subclasses are used directly as annotation keys: the helper
  methods below forward to the module-level annotation functions, passing
  the member itself as the key.
  """

  def of(self, node, default=None):
    """Returns this annotation's value on `node`, or `default` if absent."""
    value = getanno(node, self, default=default)
    return value

  def add_to(self, node, value):
    """Attaches `value` to `node` under this annotation key."""
    setanno(node, self, value)

  def exists(self, node):
    """Returns whether `node` carries this annotation."""
    present = hasanno(node, self)
    return present

  def __repr__(self):
    # Show just the member name (e.g. `QN`) instead of the default
    # `<Basic.QN: '...'>`, which would include the long description string.
    member_name = self.name
    return str(member_name)
47
48
class Basic(NoValue):
  """General-purpose annotation keys.

  Each member's value is a human-readable description of the key; the
  values serve only as documentation.
  """

  # Qualified name of the annotated node.
  QN = 'Qualified name, as it appeared in the code. See qual_names.py.'
  # Marks a node that should be preserved as is.
  SKIP_PROCESSING = ('This node should be preserved as is and not processed '
                     'any further.')
  # Holds a (new_body, name_map) tuple; see the description string.
  INDENT_BLOCK_REMAINDER = (
      'When a node is annotated with this, the remainder of the block '
      'should be indented below it. The annotation contains a tuple '
      '(new_body, name_map), where `new_body` is the new indented block '
      'and `name_map` allows renaming symbols.')
  # Source-origin information for converted code.
  ORIGIN = ('Information about the source code that converted code '
            'originated from. See origin_information.py.')
  # User directives attached to a statement or variable.
  DIRECTIVES = ('User directives associated with a statement or a '
                'variable. Typically, they affect the immediately-enclosing '
                'statement.')

  # Extra test code for `for` loops.
  EXTRA_LOOP_TEST = ('A special annotation containing additional test code '
                     'to be executed in for loops.')
71
72
class Static(NoValue):
  """Annotation keys populated by static analysis.

  Each member's value is a human-readable description of the key; the
  values serve only as documentation.
  """

  # --- Symbols: these flags are boolean. ---
  IS_PARAM = 'Symbol is a parameter to the function being analyzed.'

  # --- Scopes: represented by objects of type activity.Scope. ---
  SCOPE = 'The scope for the annotated node. See activity.py.'
  # TODO(mdan): Drop these in favor of accessing the child's SCOPE.
  ARGS_SCOPE = 'The scope for the argument list of a function call.'
  COND_SCOPE = 'The scope for the test node of a conditional statement.'
  BODY_SCOPE = ('The scope for the main body of a statement (True branch '
                'for if statements, main body for loops).')
  ORELSE_SCOPE = ('The scope for the orelse body of a statement (False '
                  'branch for if statements, orelse body for loops).')

  # --- Static analysis annotations. ---
  DEFINITIONS = ('Reaching definition information. '
                 'See reaching_definitions.py.')
  ORIG_DEFINITIONS = ('The value of DEFINITIONS that applied to the original '
                      'code before any conversion.')
  DEFINED_FNS_IN = ('Local function definitions that may exist when exiting '
                    'the node. See reaching_fndefs.py')
  DEFINED_VARS_IN = ('Symbols defined when entering the node. '
                     'See reaching_definitions.py.')
  LIVE_VARS_OUT = 'Symbols live when exiting the node. See liveness.py.'
  LIVE_VARS_IN = 'Symbols live when entering the node. See liveness.py.'
  TYPES = 'Static type information. See type_inference.py.'
  CLOSURE_TYPES = 'Types of closure symbols at each detected call site.'
  VALUE = 'Static value information. See type_inference.py.'
112
113
# Sentinel meaning "no default supplied" for getanno(). When getanno() sees
# it, a missing annotation is looked up anyway (and therefore raises) instead
# of returning a fallback. A fresh object() cannot collide with any real
# default value, including None.
FAIL = object()
115
116
def keys(node, field_name='___pyct_anno'):
  """Returns the annotation keys attached to `node`.

  Args:
    node: the annotated object (typically an AST node).
    field_name: name of the attribute holding the annotation dict.

  Returns:
    A frozenset of keys; empty when `node` has no `field_name` attribute.
  """
  if hasattr(node, field_name):
    return frozenset(getattr(node, field_name).keys())
  return frozenset()
121
122
def getanno(node, key, default=FAIL, field_name='___pyct_anno'):
  """Returns the annotation stored under `key` on `node`.

  Args:
    node: the annotated object (typically an AST node).
    key: the annotation key.
    default: value returned when the annotation is missing. When left as
      `FAIL`, a missing annotation is not caught: the lookup raises
      AttributeError (no `field_name` attribute) or KeyError (key absent).
    field_name: name of the attribute holding the annotation dict.
  """
  if default is not FAIL:
    if not hasattr(node, field_name) or key not in getattr(node, field_name):
      return default
  return getattr(node, field_name)[key]
128
129
def hasanno(node, key, field_name='___pyct_anno'):
  """Returns True if `node` has an annotation stored under `key`.

  Args:
    node: the annotated object (typically an AST node).
    key: the annotation key.
    field_name: name of the attribute holding the annotation dict.
  """
  if not hasattr(node, field_name):
    return False
  return key in getattr(node, field_name)
132
133
def setanno(node, key, value, field_name='___pyct_anno'):
  """Stores `value` on `node` under annotation `key`.

  The annotation dict is created on first use and stored as the
  `field_name` attribute of `node`.

  Args:
    node: the object to annotate (typically an AST node with `_fields`).
    key: the annotation key.
    value: the annotation value; stored by reference.
    field_name: name of the attribute holding the annotation dict.
  """
  annotations = getattr(node, field_name, {})
  setattr(node, field_name, annotations)
  annotations[key] = value

  # Listing the attribute in `_fields` is what lets the annotations survive
  # gast_to_ast() and ast_to_gast().
  if field_name in node._fields:
    return
  node._fields += (field_name,)
142
143
def delanno(node, key, field_name='___pyct_anno'):
  """Removes annotation `key` from `node`.

  When the last annotation is removed, the `field_name` attribute itself is
  deleted and its name is dropped from `node._fields`.

  Args:
    node: the annotated object (typically an AST node).
    key: the annotation key to remove.
    field_name: name of the attribute holding the annotation dict.

  Raises:
    AttributeError: if `node` has no `field_name` attribute.
    KeyError: if `key` is not among the annotations.
  """
  annotations = getattr(node, field_name)
  del annotations[key]
  if annotations:
    return
  delattr(node, field_name)
  node._fields = tuple(name for name in node._fields if name != field_name)
150
151
def copyanno(from_node, to_node, key, field_name='___pyct_anno'):
  """Copies annotation `key` from `from_node` to `to_node`, if present.

  Does nothing when `from_node` lacks the annotation. The value is shared
  by reference, not copied.

  Args:
    from_node: the annotated source object.
    to_node: the object receiving the annotation.
    key: the annotation key.
    field_name: name of the attribute holding the annotation dict.
  """
  if not hasanno(from_node, key, field_name=field_name):
    return
  value = getanno(from_node, key, field_name=field_name)
  setanno(to_node, key, value, field_name=field_name)
159
160
def dup(node, copy_map, field_name='___pyct_anno'):
  """Recursively copies annotations in an AST tree.

  Every node reachable from `node` (via gast.walk) that is annotated with a
  source key from `copy_map` also receives an annotation under the mapped
  destination key, holding the same value. Keys are processed in `copy_map`
  iteration order, so a destination key that is also a source key may be
  copied again within the same node.

  Args:
    node: ast.AST
    copy_map: Dict[Hashable, Hashable], maps a source anno key to a destination
        key. All annotations with the source key will be copied to identical
        annotations with the destination key.
    field_name: str, name of the attribute holding the annotation dict.
  """
  for n in gast.walk(node):
    for src_key, dst_key in copy_map.items():
      if hasanno(n, src_key, field_name=field_name):
        # field_name must be passed by keyword: getanno's third positional
        # parameter is `default`, so passing it positionally would read from
        # the default annotation field instead of `field_name`.
        value = getanno(n, src_key, field_name=field_name)
        setanno(n, dst_key, value, field_name=field_name)
175