# mypy: allow-untyped-defs
import operator
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type

import sympy

import torch
import torch.fx
from torch.fx.experimental.symbolic_shapes import (
    compute_unbacked_bindings,
    rebind_unbacked,
    statically_known_true,
    sym_eq,
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map

from .virtualized import V


# Check whether the pattern (nn.module, F.function/torch.Tensor.method) is matched.
# Works for length-2 patterns with 1 module and 1 function/method.
def matches_module_function_pattern(
    pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
    node: torch.fx.node.Node,
    modules: Dict[str, torch.nn.modules.Module],
) -> bool:
    if len(node.args) == 0:
        return False
    if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
        node, torch.fx.Node
    ):
        return False
    # the first node is call_module
    if node.args[0].op != "call_module":
        return False
    if not isinstance(node.args[0].target, str):
        return False
    if node.args[0].target not in modules:
        return False
    if type(modules[node.args[0].target]) is not pattern[0]:
        return False
    # the second node is call_function or call_method
    if node.op != "call_function" and node.op != "call_method":
        return False
    if node.target != pattern[1]:
        return False
    # make sure node.args[0] output is only used by current node.
    if len(node.args[0].users) > 1:
        return False
    return True


class FakeTensorUpdater:
    """
    The main idea here is that it's difficult to maintain accurate fake
    tensors (our primary form of metadata) for each node in our graph as we
    transform it.

    The most reliable way to obtain this information is by rerunning
    faketensor propagation. However, in general, faketensor propagation is
    fairly expensive. So, instead we'd like to only rerun faketensor
    propagation on nodes that have changed.

    In order to detect which nodes have changed, we first hash each node
    together with its target and argument lists (which are immutable in FX).

    Then, whenever we call incremental_update, we check which FX nodes have a
    new hash, and recompute the faketensor metadata for those nodes. We then
    recursively recompute the faketensors for all of their users until the
    fake tensors stop changing.
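
    A minimal usage sketch (illustrative only; it assumes `gm` is a
    torch.fx.GraphModule whose nodes already carry fake tensors in
    node.meta["val"], and that the surrounding inductor context has set
    V.fake_mode, which incremental_update runs ops under):

        updater = FakeTensorUpdater(gm.graph)
        # ... mutate gm.graph: replace nodes, rewrite args, etc. ...
        updater.incremental_update()  # re-propagates fake tensors where needed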
    """

    def __init__(self, graph: torch.fx.Graph) -> None:
        self.processed_hashes = set()
        self.graph = graph

        for node in self.graph.nodes:
            self.processed_hashes.add(self.hash_node(node))

    def hash_node(self, node: torch.fx.Node):
        # todo(chilli): Not a great hash function
        return (node, node.target, id(node.args), id(node.kwargs))

    def incremental_update(self):
        processed = set()
        existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
        for node in self.graph.nodes:
            existing_storages[get_node_storage(node)] += 1

        def is_intlist_same(new, old):
            return statically_known_true(sym_eq(new, old))

        def is_fake_tensor_same(new, old):
            if type(new) != type(old):
                return False
            if isinstance(new, (list, tuple)):
                if len(new) != len(old):
                    return False
                return all(
                    is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
                )
            if new is None:
                return old is None
            if not isinstance(new, torch.Tensor):
                assert isinstance(
                    new, (torch.SymInt, torch.SymBool, torch.SymFloat)
                ), f"Unknown type {type(new)} in {self.graph}"
                return (
                    new.node.shape_env._maybe_evaluate_static(
                        sympy.Eq(new.node.expr, old.node.expr)
                    )
                    == sympy.true
                )
            if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
                return False
            if new.layout == torch.strided and (
                not is_intlist_same(new.stride(), old.stride())
                or not statically_known_true(
                    new.storage_offset() == old.storage_offset()
                )
            ):
                return False

            if new.device != old.device:
                return False

            if get_storage(new) == get_storage(old):
                return True

            # This is the case where the op returns a completely fresh storage
            # that's used nowhere else.
            if (
                existing_storages[get_storage(old)] == 1
                and get_storage(new) not in existing_storages
            ):
                return True
            return False

        def should_process_node(node):
            # The node.target of nodes returning True from this function is
            # called under fake mode, which does not work for inductor
            # lowerings. We check that node.target is an aten operator or
            # operator.getitem, which is used when an op returns multiple
            # tensors.
            return node.op == "call_function" and (
                isinstance(node.target, torch._ops.OpOverload)
                or node.target == operator.getitem
            )

        to_process = set()
        for node in self.graph.nodes:
            if (
                self.hash_node(node) in self.processed_hashes
                and id(node) not in to_process
            ):
                continue

            if not should_process_node(node):
                continue

            is_valid, args, kwargs = get_fake_args_kwargs(node)
            if not is_valid:
                continue
            with V.fake_mode:
                new_fake_tensor = node.target(*args, **kwargs)
            if "val" in node.meta and is_fake_tensor_same(
                new_fake_tensor, node.meta["val"]
            ):
                continue

            rebind_unbacked(V.fake_mode.shape_env, node, new_fake_tensor)

            node.meta["val"] = new_fake_tensor
            if (shape_env := V.fake_mode.shape_env) and (
                symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
            ):
                # Refresh the bindings to the new symbols
                node.meta["unbacked_bindings"] = symbol_to_path

            existing_storages[get_node_storage(node)] += 1

            to_process.update([id(user) for user in node.users])

            self.processed_hashes.add(self.hash_node(node))


def get_storage(t: torch.Tensor) -> int:
    return t.untyped_storage()._cdata


def get_node_storage(node: torch.fx.Node) -> Optional[int]:
    if "val" not in node.meta:
        return None
    if not isinstance(node.meta["val"], torch.Tensor):
        return None
    if not torch._C._has_storage(node.meta["val"]):
        return None
    return get_storage(node.meta["val"])


def get_fake(x):
    if isinstance(x, torch.fx.Node):
        if "val" not in x.meta:
            return x
        return x.meta["val"]
    return x


def get_fake_args_kwargs(
    x: torch.fx.Node,
) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]:
    """
    The first return value is False if any of the input nodes does not have a
    faketensor, and True otherwise.
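
    Example (illustrative sketch; `add_node` stands for any FX node whose
    input nodes all carry FakeTensors in their meta["val"]):

        ok, fake_args, fake_kwargs = get_fake_args_kwargs(add_node)
        if ok:
            # All inputs were resolvable to fake values, so the op can be
            # re-run under fake mode to recompute this node's metadata
            # (this is how FakeTensorUpdater.incremental_update uses it).
            with V.fake_mode:
                new_val = add_node.target(*fake_args, **fake_kwargs)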
    """
    args, kwargs = tree_map(get_fake, (x.args, x.kwargs))
    if any(
        isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs)
    ):
        return False, args, kwargs
    return True, args, kwargs


def is_node_realized(node: torch.fx.Node) -> bool:
    """Returns true if a node is always realized when lowered to inductor IR.

    NOTE: This may return some false negatives. e.g. it doesn't
    handle buffers realized heuristically during lowering, or
    buffers realized indirectly through view ops.
    """
    from torch._inductor.lowering import fallbacks, needs_realized_inputs

    def is_buffer(node: torch.fx.Node) -> bool:
        if node.op == "call_function" and node.target is operator.getitem:
            # For nodes with multiple outputs, we get the fx graph:
            #     foo = torch.ops.aten.foo(...)
            #     getitem = foo[0]
            #     getitem_1 = foo[1]
            # where we need to check if foo is a fallback kernel
            return is_buffer(node.args[0])  # type: ignore[arg-type]
        return node.op in ("placeholder", "output") or node.target in fallbacks

    if is_buffer(node):
        return True

    def realizes_inputs(node: torch.fx.Node) -> bool:
        return node.op == "output" or node.target in needs_realized_inputs

    if any(realizes_inputs(user) for user in node.users):
        return True

    # Otherwise, assume node isn't realized
    return False
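

# Illustrative sketch for is_node_realized (comments only; the op names and the
# registry-membership claims below are assumptions, not guarantees). Given a
# graph like
#
#     foo = torch.ops.aten.foo(x)        # assume foo is in `fallbacks`
#     getitem = foo[0]
#     out = torch.ops.aten.bar(getitem)  # assume bar is in `needs_realized_inputs`
#
# is_node_realized(getitem) returns True: is_buffer follows the getitem back to
# its multi-output source and finds a fallback kernel, and, independently, the
# `bar` user requires realized inputs.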