# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An analysis that determines the reach of a function definition.

A function definition is said to reach a statement if that function may exist
(and therefore may be called) when that statement executes.
"""

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import transformer


class Definition(object):
  """Definition objects describe a unique definition of a function."""

  def __init__(self, def_node):
    self.def_node = def_node


class _NodeState(object):
  """Abstraction for the state of the CFG walk for reaching definition analysis.

  This is a value type. Only implements the strictly necessary operators; every
  operator returns a new instance and never mutates its operands.

  Attributes:
    value: Set, the function definitions that may reach the current point.
      `Analyzer` adds the `gast.FunctionDef` / `gast.Lambda` AST nodes
      themselves (not `Definition` wrappers) to this set; states seeded from
      `external_defs` contain whatever the caller passed in.
  """

  def __init__(self, init_from=None):
    # Always copy, so that a new state never aliases the source collection.
    if init_from:
      self.value = set(init_from)
    else:
      self.value = set()

  def __eq__(self, other):
    if not isinstance(other, _NodeState):
      return NotImplemented
    return self.value == other.value

  def __ne__(self, other):
    if not isinstance(other, _NodeState):
      return NotImplemented
    return self.value != other.value

  def __or__(self, other):
    """Returns a new state holding the union of both states' definitions."""
    # Return NotImplemented rather than asserting, so the check survives -O
    # and a wrong operand type surfaces as the standard TypeError.
    if not isinstance(other, _NodeState):
      return NotImplemented
    result = _NodeState(self.value)
    result.value.update(other.value)
    return result

  def __add__(self, value):
    """Returns a new state that additionally contains `value`."""
    result = _NodeState(self.value)
    result.value.add(value)
    return result

  def __repr__(self):
    return 'NodeState[%s]=%s' % (id(self), repr(self.value))


class Analyzer(cfg.GraphVisitor):
  """CFG visitor that determines reaching definitions at statement level."""

  def __init__(self, graph, external_defs):
    super(Analyzer, self).__init__(graph)
    # This allows communicating that nodes have extra reaching definitions,
    # e.g. those that a function closes over.
    self.external_defs = external_defs

  def init_state(self, _):
    return _NodeState()

  def visit_node(self, node):
    """Recomputes the in/out states of `node`.

    Args:
      node: cfg.Node, the node being visited.

    Returns:
      bool, whether the out-state of `node` changed on this visit.
    """
    prev_defs_out = self.out[node]

    if node is self.graph.entry:
      # The entry node is seeded with the definitions visible from outside
      # the function, e.g. those it closes over.
      defs_in = _NodeState(self.external_defs)
    else:
      # NOTE(review): seeding from this node's own previous out-state (instead
      # of an empty state) means that when a function definition node is
      # revisited, its in-state also contains the node itself. Confirm this
      # over-approximation is intended.
      defs_in = prev_defs_out

    for n in node.prev:
      defs_in |= self.out[n]

    defs_out = defs_in
    if isinstance(node.ast_node, (gast.Lambda, gast.FunctionDef)):
      # A definition statement makes that function available downstream.
      defs_out += node.ast_node

    self.in_[node] = defs_in
    self.out[node] = defs_out

    # Presumably used by cfg.GraphVisitor to decide whether successors must be
    # revisited until a fixed point is reached.
    return prev_defs_out != defs_out


class TreeAnnotator(transformer.Base):
  """AST visitor that annotates CFG nodes with the definitions reaching them.

  Each node that belongs to the CFG of the function being visited receives an
  `anno.Static.DEFINED_FNS_IN` annotation. Simultaneously, the visitor runs the
  dataflow analysis on each function node, accounting for the effect of
  closures. For example:

    def foo():
      def f():
        pass
      def g():
        # `def f` reaches here
        pass
  """

  def __init__(self, source_info, graphs):
    super(TreeAnnotator, self).__init__(source_info)
    self.graphs = graphs
    self.allow_skips = False
    self.current_analyzer = None

  def _process_function(self, node):
    """Runs the analysis for function/lambda `node`, then visits its body."""
    parent_analyzer = self.current_analyzer
    subgraph = self.graphs[node]

    if (self.current_analyzer is not None
        and node in self.current_analyzer.graph.index):
      # Definitions reaching the point where this function is defined are
      # visible inside its body through closure.
      cfg_node = self.current_analyzer.graph.index[node]
      defined_in = self.current_analyzer.in_[cfg_node].value
    else:
      defined_in = ()

    analyzer = Analyzer(subgraph, defined_in)
    analyzer.visit_forward()

    self.current_analyzer = analyzer
    try:
      node = self.generic_visit(node)
    finally:
      # Restore the enclosing analyzer even if visiting the body fails, so the
      # visitor is never left pointing at the wrong function's analysis.
      self.current_analyzer = parent_analyzer
    return node

  # Backward-compatible alias for the previous, misspelled method name.
  _proces_function = _process_function

  def visit_FunctionDef(self, node):
    return self._process_function(node)

  def visit_Lambda(self, node):
    return self._process_function(node)

  def visit(self, node):
    """Annotates `node` (and its extra loop test, if any), then visits it."""
    analyzer = self.current_analyzer
    # The analyzer is None before entering the top level function; there is
    # no CFG to read reaching definitions from in that case.
    if analyzer is not None:
      if node in analyzer.graph.index:
        cfg_node = analyzer.graph.index[node]
        anno.setanno(node, anno.Static.DEFINED_FNS_IN,
                     analyzer.in_[cfg_node].value)

      extra_node = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
      if extra_node is not None:
        cfg_node = analyzer.graph.index[extra_node]
        anno.setanno(extra_node, anno.Static.DEFINED_FNS_IN,
                     analyzer.in_[cfg_node].value)

    return super(TreeAnnotator, self).visit(node)


def resolve(node, source_info, graphs):
  """Resolves the reaching function definitions for each CFG node.

  Every node that belongs to a function's CFG is annotated with
  `anno.Static.DEFINED_FNS_IN`, the function definitions that may exist when
  that node executes.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[Union[ast.FunctionDef, ast.Lambda], cfg.Graph], must contain
      an entry for every function and lambda visited in `node`.
  Returns:
    ast.AST
  """
  visitor = TreeAnnotator(source_info, graphs)
  node = visitor.visit(node)
  return node