xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/pyct/static_analysis/reaching_fndefs.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""An analysis that determines the reach of a function definition.
16
17A function definition is said to reach a statement if that function may exist
18(and therefore may be called) when that statement executes.
19"""
20
21import gast
22
23from tensorflow.python.autograph.pyct import anno
24from tensorflow.python.autograph.pyct import cfg
25from tensorflow.python.autograph.pyct import transformer
26
27
class Definition:
  """Describes one unique definition of a function.

  Attributes:
    def_node: the node that introduces the definition. Not constructed in this
        module; presumably the function's AST node (confirm with callers).
  """

  def __init__(self, def_node):
    self.def_node = def_node
33
34
35class _NodeState(object):
36  """Abstraction for the state of the CFG walk for reaching definition analysis.
37
38  This is a value type. Only implements the strictly necessary operators.
39
40  Attributes:
41    value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
42        their possible definitions
43  """
44
45  def __init__(self, init_from=None):
46    if init_from:
47      self.value = set(init_from)
48    else:
49      self.value = set()
50
51  def __eq__(self, other):
52    return self.value == other.value
53
54  def __ne__(self, other):
55    return self.value != other.value
56
57  def __or__(self, other):
58    assert isinstance(other, _NodeState)
59    result = _NodeState(self.value)
60    result.value.update(other.value)
61    return result
62
63  def __add__(self, value):
64    result = _NodeState(self.value)
65    result.value.add(value)
66    return result
67
68  def __repr__(self):
69    return 'NodeState[%s]=%s' % (id(self), repr(self.value))
70
71
class Analyzer(cfg.GraphVisitor):
  """Forward CFG pass computing which function definitions reach each node.

  The state attached to every CFG node is a `_NodeState` whose value holds the
  function definition nodes (e.g. `gast.FunctionDef` / `gast.Lambda`) that may
  exist when that node executes.
  """

  def __init__(self, graph, external_defs):
    super(Analyzer, self).__init__(graph)
    # Definitions that come from outside this graph but are visible at its
    # entry, e.g. those that a function closes over.
    self.external_defs = external_defs

  def init_state(self, _):
    # Every node starts out with no known definitions.
    return _NodeState()

  def visit_node(self, node):
    """Recomputes the in/out states of `node`; returns whether `out` changed."""
    old_out = self.out[node]

    # The entry is seeded with the external definitions; any other node
    # accumulates on top of its own previous output.
    incoming = (_NodeState(self.external_defs)
                if node is self.graph.entry else old_out)
    for predecessor in node.prev:
      incoming = incoming | self.out[predecessor]

    # A function or lambda definition makes that function reachable
    # downstream of this node.
    is_fn_def = isinstance(node.ast_node, (gast.Lambda, gast.FunctionDef))
    outgoing = incoming + node.ast_node if is_fn_def else incoming

    self.in_[node] = incoming
    self.out[node] = outgoing
    return old_out != outgoing
103
104
class TreeAnnotator(transformer.Base):
  """AST visitor that annotates nodes with the function definitions reaching them.

  Every visited node that has a counterpart in the CFG of its enclosing function
  receives an `anno.Static.DEFINED_FNS_IN` annotation: the set of function
  definitions that may exist when that node executes.

  Simultaneously, the visitor runs the dataflow analysis on each function node,
  accounting for the effect of closures. For example:

    def foo():
      def f():
        pass
      def g():
        # `def f` reaches here
  """

  def __init__(self, source_info, graphs):
    super(TreeAnnotator, self).__init__(source_info)
    self.graphs = graphs
    self.allow_skips = False
    # Analyzer of the innermost function being visited; None outside of any
    # function (e.g. before entering the top level function).
    self.current_analyzer = None

  def _reaching_fns(self, node):
    """Returns the definitions reaching `node`, or None if it has no CFG node."""
    analyzer = self.current_analyzer
    if analyzer is None or node not in analyzer.graph.index:
      return None
    return analyzer.in_[analyzer.graph.index[node]].value

  def _process_function(self, node):
    """Analyzes the CFG of a function or lambda, then visits its body."""
    subgraph = self.graphs[node]

    # Definitions reaching the function definition itself are visible inside
    # its body, since the function closes over them.
    defined_in = self._reaching_fns(node)
    if defined_in is None:
      defined_in = ()

    analyzer = Analyzer(subgraph, defined_in)
    analyzer.visit_forward()

    parent_analyzer = self.current_analyzer
    self.current_analyzer = analyzer
    try:
      node = self.generic_visit(node)
    finally:
      # Restore the enclosing function's analyzer even if visiting fails.
      self.current_analyzer = parent_analyzer
    return node

  # Backward-compatible alias for the original (misspelled) method name.
  _proces_function = _process_function

  def visit_FunctionDef(self, node):
    return self._process_function(node)

  def visit_Lambda(self, node):
    return self._process_function(node)

  def visit(self, node):
    # No analyzer exists before entering the top level function, and not every
    # node has a CFG counterpart; such nodes are left unannotated.
    reaching = self._reaching_fns(node)
    if reaching is not None:
      anno.setanno(node, anno.Static.DEFINED_FNS_IN, reaching)

    # A node may carry an extra loop test node that is looked up in the CFG
    # and annotated here explicitly. It gets the same guard as `node`, since
    # there may be no current analyzer (e.g. for a loop outside any function).
    extra_node = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
    if extra_node is not None:
      extra_reaching = self._reaching_fns(extra_node)
      if extra_reaching is not None:
        anno.setanno(extra_node, anno.Static.DEFINED_FNS_IN, extra_reaching)

    return super(TreeAnnotator, self).visit(node)
163    return super(TreeAnnotator, self).visit(node)
164
165
def resolve(node, source_info, graphs):
  """Annotates `node` with the function definitions reaching its statements.

  Args:
    node: ast.AST, the tree to annotate.
    source_info: transformer.SourceInfo
    graphs: Dict[Union[ast.FunctionDef, ast.Lambda], cfg.Graph], the CFG of
      each function and lambda found in `node`.
  Returns:
    ast.AST, the tree returned by the annotating visitor.
  """
  annotator = TreeAnnotator(source_info, graphs)
  return annotator.visit(node)
178  return node
179