xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/pyct/static_analysis/liveness.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Live variable analysis.
16
17See https://en.wikipedia.org/wiki/Live_variable_analysis for a definition of
18the following idioms: live variable, live in, live out, which are used
19throughout this file.
20
21This analysis attaches the following:
22 * symbols that are live at the exit of control flow statements
23 * symbols that are live at the entry of control flow statements
24
25Requires activity analysis.
26"""
27
28import gast
29
30from tensorflow.python.autograph.pyct import anno
31from tensorflow.python.autograph.pyct import cfg
32from tensorflow.python.autograph.pyct import transformer
33from tensorflow.python.autograph.pyct.static_analysis import annos
34
35
class Analyzer(cfg.GraphVisitor):
  """CFG visitor that performs liveness analysis at statement level.

  The per-node state is a set of symbols. After visiting, `self.in_[node]`
  holds the symbols live at the entry of `node` and `self.out[node]` the
  symbols live at its exit.
  """

  def __init__(self, graph, include_annotations):
    """Creates the analyzer.

    Args:
      graph: the control flow graph to analyze; forwarded to the superclass.
      include_annotations: bool, whether symbols recorded in a scope's
        `annotations` set count as reads (and can therefore become live).
    """
    super(Analyzer, self).__init__(graph)
    self.include_annotations = include_annotations

  def init_state(self, _):
    """Returns the initial state of a node: an empty set of live symbols."""
    return set()

  def visit_node(self, node):
    """Recomputes the live-in and live-out sets of a single CFG node.

    Uses the backward transfer function
    `live_in = gen | (live_out - kill)`, where `live_out` is the union of
    the live-in sets of the node's successors.

    Args:
      node: the CFG node to process. Its `ast_node` must either carry a
        `anno.Static.SCOPE` annotation or be ignorable per `can_ignore`.

    Returns:
      True if the live-in set of `node` changed during this visit, False
      otherwise.
    """
    prev_live_in = self.in_[node]

    # Live-out is the union of what is live at the entry of every successor.
    live_out = set()
    for n in node.next:
      live_out |= self.in_[n]

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)

      if self.include_annotations:
        gen = node_scope.read
      else:
        # Build a new set instead of using `-=`: an in-place subtraction
        # would modify the scope's own `read` set when it is mutable, leaking
        # this analysis' filtering into the shared activity annotations.
        gen = node_scope.read - node_scope.annotations
      # TODO(mdan): verify whether composites' parents need to be added.
      # E.g. whether x needs to be added if x.y is live. Theoretically the
      # activity analysis should have both so that wouldn't be needed.
      kill = node_scope.modified | node_scope.deleted

      # `|` yields a fresh set, so the in-place updates below never touch
      # `gen` or `live_out`.
      live_in = gen | (live_out - kill)

      reaching_functions = anno.getanno(
          node.ast_node, anno.Static.DEFINED_FNS_IN)
      for fn_ast_node in reaching_functions:
        if isinstance(fn_ast_node, gast.Lambda):
          # Exception: lambda functions are assumed to be used only in the
          # place where they are defined, and not later.
          continue
        fn_scope = anno.getanno(fn_ast_node, annos.NodeAnno.ARGS_AND_BODY_SCOPE)
        # Any closure of a reaching function definition is conservatively
        # considered live.
        live_in |= (fn_scope.read - fn_scope.bound)

    else:
      assert self.can_ignore(node), (node.ast_node, node)

      # Nodes without a scope neither read nor write symbols: liveness passes
      # through unchanged.
      live_in = live_out

    self.in_[node] = live_in
    self.out[node] = live_out

    # TODO(mdan): Move this to the superclass?
    return prev_live_in != live_in
90
91
class TreeAnnotator(transformer.Base):
  """Annotates the AST with the results of per-function liveness analysis.

  If a function defines other local functions, those will have separate CFGs.
  Dataflow analysis still has to connect these CFGs to model the effect of
  closures: the live variables of the parent function must account for the
  variables that are live at the entry of each subfunction. For example:

    def foo():
      # baz is live from here on
      def bar():
        print(baz)

  This annotator runs an `Analyzer` on the graph of each function or lambda
  it encounters and copies the results onto the AST nodes as
  `anno.Static.LIVE_VARS_IN` and `anno.Static.LIVE_VARS_OUT` annotations.
  """

  def __init__(self, source_info, graphs, include_annotations):
    """Creates the annotator.

    Args:
      source_info: forwarded to the `transformer.Base` constructor.
      graphs: mapping indexed by function and lambda AST nodes, yielding the
        graph that is handed to `Analyzer` for that node.
      include_annotations: bool, forwarded to every `Analyzer` created.
    """
    super().__init__(source_info)
    self.include_annotations = include_annotations
    self.allow_skips = False
    self.graphs = graphs
    # Analyzer of the function currently being traversed; None outside of any
    # function.
    self.current_analyzer = None

  def visit(self, node):
    """Visits `node`, then records live-in symbols on indexed statements."""
    node = super().visit(node)
    analyzer = self.current_analyzer
    if analyzer is None or not isinstance(node, gast.stmt):
      return node
    cfg_index = analyzer.graph.index
    if node in cfg_index:
      anno.setanno(node, anno.Static.LIVE_VARS_IN,
                   frozenset(analyzer.in_[cfg_index[node]]))
    return node

  def _analyze_function(self, node, is_lambda):
    """Analyzes the graph of `node` and annotates the function's subtree.

    Args:
      node: function or lambda AST node; must be a key of `self.graphs`.
      is_lambda: bool; accepted for the callers' benefit, not used here.

    Returns:
      The result of `generic_visit(node)`.
    """
    outer_analyzer = self.current_analyzer

    fn_analyzer = Analyzer(self.graphs[node], self.include_annotations)
    fn_analyzer.visit_reverse()
    # The nested traversal reads its results from the function's own analyzer.
    self.current_analyzer = fn_analyzer
    node = self.generic_visit(node)

    # Back in the enclosing function (or at top level).
    self.current_analyzer = outer_analyzer
    return node

  def visit_Lambda(self, node):
    return self._analyze_function(node, is_lambda=True)

  def visit_FunctionDef(self, node):
    return self._analyze_function(node, is_lambda=False)

  def _block_statement_live_out(self, node):
    """Sets LIVE_VARS_OUT on `node` from the live-in sets of its successors.

    The successors are taken from the graph's `stmt_next` mapping.
    """
    analyzer = self.current_analyzer
    live_out = frozenset().union(
        *(analyzer.in_[succ] for succ in analyzer.graph.stmt_next[node]))
    anno.setanno(node, anno.Static.LIVE_VARS_OUT, live_out)
    return node

  def _block_statement_live_in(self, node, entry_node):
    """Sets LIVE_VARS_IN on `node` to the symbols live at `entry_node`.

    If `entry_node` is indexed by the graph, its computed live-in set is used;
    otherwise it must already carry a LIVE_VARS_IN annotation, which is
    reused as-is.
    """
    analyzer = self.current_analyzer
    cfg_index = analyzer.graph.index
    if entry_node not in cfg_index:
      assert anno.hasanno(entry_node, anno.Static.LIVE_VARS_IN), (
          'If not matching a CFG node, must be a block statement:'
          ' {}'.format(entry_node))
      live_in = anno.getanno(entry_node, anno.Static.LIVE_VARS_IN)
    else:
      live_in = frozenset(analyzer.in_[cfg_index[entry_node]])
    anno.setanno(node, anno.Static.LIVE_VARS_IN, live_in)
    return node

  def _annotate_block_statement(self, node, get_entry):
    """Visits the children of `node`, then sets its live-out and live-in.

    Args:
      node: the block statement to annotate.
      get_entry: callable mapping the visited node to its entry node. It is
        invoked only after the children have been visited.

    Returns:
      The annotated node.
    """
    node = self.generic_visit(node)
    node = self._block_statement_live_out(node)
    return self._block_statement_live_in(node, get_entry(node))

  def visit_If(self, node):
    return self._annotate_block_statement(node, lambda n: n.test)

  def visit_For(self, node):
    return self._annotate_block_statement(node, lambda n: n.iter)

  def visit_While(self, node):
    return self._annotate_block_statement(node, lambda n: n.test)

  def visit_Try(self, node):
    return self._annotate_block_statement(node, lambda n: n.body[0])

  def visit_ExceptHandler(self, node):
    return self._annotate_block_statement(node, lambda n: n.body[0])

  def visit_With(self, node):
    # Only the live-in set is recorded for `with` statements, taken at the
    # first item.
    node = self.generic_visit(node)
    return self._block_statement_live_in(node, node.items[0])

  def visit_Expr(self, node):
    """Sets LIVE_VARS_OUT on an expression statement from its live-out set."""
    node = self.generic_visit(node)
    analyzer = self.current_analyzer
    expr_cfg_node = analyzer.graph.index[node]
    anno.setanno(node, anno.Static.LIVE_VARS_OUT,
                 frozenset(analyzer.out[expr_cfg_node]))
    return node
199
200
201# TODO(mdan): Investigate the possibility of removing include_annotations.
def resolve(node, source_info, graphs, include_annotations=True):
  """Annotates `node` with liveness information.

  Runs a `TreeAnnotator` over the tree, which attaches the symbols that are
  live at the entry and at the exit of control flow statements.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    include_annotations: Bool, whether type annotations should be included in
      the analysis.
  Returns:
    ast.AST, the node returned by the annotator's `visit`.
  """
  annotator = TreeAnnotator(source_info, graphs, include_annotations)
  return annotator.visit(node)
216