# mypy: ignore-errors


from typing import Callable

import torch
import torch.fx as fx
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten


aten = torch.ops.aten


def get_aten_target(node: fx.Node) -> Callable:
    if hasattr(node.target, "overloadpacket"):
        return node.target.overloadpacket
    return node.target


rand_ops = [
    aten.dropout,
    aten._fused_dropout,
    aten._standard_gamma,
    aten.bernoulli,
    aten.multinomial,
    aten.native_dropout,
    aten.normal,
    aten.poisson,
    aten.binomial,
    aten.rrelu,
    aten.rand_like,
    aten.rand,
    aten.randint,
    aten.randn,
    aten.randperm,
]


# return a new copy of torch.fx.graph.Graph with CSE applied to the input graph
def fx_graph_cse(fx_g: torch.fx.graph.Graph):
    new_graph = fx.Graph()
    env = {}  # map from node in the old graph to node in the new graph
    hash_env = {}  # map from hash to a node in the new graph
    token_map = {}  # map from hash to token

    from torch._inductor.pattern_matcher import (
        compute_mutation_region_ids,
        same_mutation_regions,
    )

    compute_mutation_region_ids(fx_g)  # type: ignore[arg-type]

    # Make a set of separate storages returned from the output, which will be preserved
    # when pruning. This prevents us from deduplicating returned tensors which have
    # experienced identical operations, but are separate data structures in eager mode.
    output_node: fx.Node = list(fx_g.nodes)[-1]
    assert output_node.op == "output"

    def checkable_node(node: fx.Node) -> bool:
        """We can evaluate only nodes that represent tensors with defined storage."""
        if "val" not in node.meta or not isinstance(node.meta["val"], torch.Tensor):
            return False

        try:
            node.meta["val"].untyped_storage()
        except NotImplementedError:
            return False

        return True

    output_storages = {
        StorageWeakRef(n.meta["val"].untyped_storage())
        for n in output_node.all_input_nodes
        if checkable_node(n)
    }
    nodes_that_alias_outputs = {
        n
        for n in fx_g.nodes
        if checkable_node(n)
        and StorageWeakRef(n.meta["val"].untyped_storage()) in output_storages
    }

    for n in fx_g.nodes:
        # The placeholder, output, and get_attr nodes are copied to the new graph without change;
        # do not CSE away random operations.
        if (
            n.op == "placeholder"
            or n.op == "output"
            or n.op == "get_attr"
            or get_aten_target(n) in rand_ops
            # aten.empty is non-deterministic, so don't CSE it.
            # Also, aten.empty is almost always fusible into its consumer,
            # so it's not worth CSEing.
            or get_aten_target(n) is aten.empty
            or n in nodes_that_alias_outputs
        ):
            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
        else:  # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
            # Substitute args and kwargs members with their mapping in env if one exists;
            # the specs can be used to reconstruct nested lists/dictionaries.
            def substitute(arg_list):
                arg_list, spec = tree_flatten(arg_list)
                for i in range(len(arg_list)):
                    v = arg_list[i]
                    if isinstance(v, torch.fx.node.Node) and v in env:
                        arg_list[i] = env[v]
                    if isinstance(v, (torch.SymBool, torch.SymInt, torch.SymFloat)):
                        arg_list[i] = v.node
                return tuple(arg_list), spec

            args, args_spec = substitute(n.args)
            kwargs, kwargs_spec = substitute(n.kwargs)

            # Each token corresponds to a unique node;
            # nodes with the same token can be substituted.
            token = {
                "target": n.target,
                "args": args,
                "args_spec": args_spec,
                "kwargs": kwargs,
                "kwargs_spec": kwargs_spec,
            }

            # Hash substituted args to a number; do not hash specs because specs are not hashable.
            # We need to add the type into the hash to avoid situations like:
            # hash((primals_2, 1.0)) == hash((primals_2, 1))
            hash_arg = hash(
                (tuple((a, type(a)) for a in args), tuple((a, type(a)) for a in kwargs))
            )
            hash_val = (n.target, hash_arg)

            # Check if the node has a substitute and can be eliminated.
            hash_val_in_hash_env = hash_val in hash_env
            overwrite_due_to_mutation = False
            if hash_val_in_hash_env and token_map[hash_val] == token:
                duplicate_n_prev = hash_env[hash_val]
                if same_mutation_regions(n, duplicate_n_prev):
                    env[n] = duplicate_n_prev
                    continue
                else:
                    # Any future duplicates should be replaced with n, not duplicate_n_prev.
                    overwrite_due_to_mutation = True

            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
            if overwrite_due_to_mutation or not hash_val_in_hash_env:
                hash_env[hash_val] = new_node
                token_map[hash_val] = token

    return new_graph


def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.

    Args:
        gm (fx.GraphModule): The input FX graph module to be modified.
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()


def get_placeholders(graph):
    return graph.find_nodes(op="placeholder")


def get_outputs(graph):
    for node in graph.find_nodes(op="output"):
        return pytree.tree_leaves(node.args[0])
    raise AssertionError("No output node found")
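

# A minimal usage sketch of fx_graph_cse (an illustrative addition, not part
# of the library API). It assumes torch.fx.experimental.proxy_tensor.make_fx
# is available to produce an aten-level graph whose nodes carry the `val`
# metadata that fx_graph_cse inspects; the names `f`, `gm`, and `deduped` are
# just for this example. The two identical aten.cos calls traced from `f`
# should be merged into a single node.
if __name__ == "__main__":
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        return x.cos() + x.cos()

    gm = make_fx(f)(torch.randn(3))  # trace f into an aten-level fx.GraphModule
    deduped = fx.GraphModule(gm, fx_graph_cse(gm.graph))
    print(deduped.code)  # expected: a single call to torch.ops.aten.cos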