# mypy: allow-untyped-defs
import functools

import torch
from torch._inductor.compile_fx import fake_tensor_prop

from ..._dynamo.utils import counters
from .. import config
from ..pattern_matcher import (
    _return_true,
    CallFunction,
    fwd_only,
    Ignored,
    init_once_fakemode,
    KeywordArg,
    Match,
    PatternMatcherPass,
    register_graph_pattern,
    register_replacement,
    stable_topological_sort,
)


aten = torch.ops.aten

# First pass_patterns[0] is applied, then [1], then [2]
pass_patterns = [
    PatternMatcherPass(),
    PatternMatcherPass(),
    PatternMatcherPass(),
]

binary_folding_pass = PatternMatcherPass()


def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs):
    """
    Passes applied to the graph during freezing.
    """

    from ..freezing import constant_fold

    lazy_init()
    # We need a few rounds of binary folding to get rid of all the unnecessary
    # nodes, though ideally we'd have a good method to choose the number of
    # rounds. Handles chains like conv + binary + binary.
    binary_folding = counters["inductor"]["binary_folding"]
    fake_tensor_prop(gm, aot_example_inputs, True)

    torch._inductor.fx_passes.binary_folding.mark_mixed_dtype_allowed_convs(gm)
    for _ in range(4):
        constant_fold(gm)
        # Make sure meta['val'] is properly set for all nodes
        fake_tensor_prop(gm, aot_example_inputs, True)
        binary_folding_pass.apply(gm.graph)  # type: ignore[arg-type]
        # If no folding happened this round, running the pass again won't help.
        # TODO: remove the need to run fake_tensor_prop on the whole model.
        if counters["inductor"]["binary_folding"] == binary_folding:
            break
        binary_folding = counters["inductor"]["binary_folding"]

    torch._inductor.fx_passes.binary_folding.recover_original_precision_folded_convs(gm)

    constant_fold(gm)
    fake_tensor_prop(gm, aot_example_inputs, True)

    for pattern in pass_patterns:
        pattern.apply(gm.graph)  # type: ignore[arg-type]

    # The CPU weight packing always assumes the conv's weight is channels last,
    # so make sure layout optimization is on when doing it.
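    # With weight prepacking enabled, the same constant can end up behind
    # several packing nodes; _eliminate_duplicate_packed_nodes deduplicates
    # them (see mkldnn_fusion for details).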
    if (
        torch._C._has_mkldnn
        and config.cpp.weight_prepack
        and config.layout_optimization
    ):
        from .mkldnn_fusion import _eliminate_duplicate_packed_nodes

        _eliminate_duplicate_packed_nodes(gm)

    stable_topological_sort(gm.graph)
    gm.recompile()
    gm.graph.lint()


@init_once_fakemode
def lazy_init():
    if torch._C._has_mkldnn and config.cpp.weight_prepack:
        from .mkldnn_fusion import _mkldnn_weight_pack_init

        _mkldnn_weight_pack_init()

    from .binary_folding import binary_folding_init

    addmm_patterns_init()
    binary_folding_init()


def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0):
    return register_graph_pattern(
        pattern,
        extra_check=extra_check,
        pass_dict=pass_patterns[pass_number],
    )


def register_binary_folding_pattern(pattern, extra_check=_return_true):
    return register_graph_pattern(
        pattern,
        extra_check=extra_check,
        pass_dict=binary_folding_pass,
    )


@functools.lru_cache(None)
def addmm_patterns_init():
    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"
    val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False)

    def check_concat_weights(match):
        # All weights (and biases, if present) must be constants (get_attr)
        # with identical shapes so they can be concatenated.
        weight_inputs = ["w1", "w2"]
        if "w3" in match.kwargs:
            weight_inputs.append("w3")

        equal_shape_inputs = [weight_inputs]

        if "b1" in match.kwargs:
            bias_inputs = ["b1", "b2"]
            if "b3" in match.kwargs:
                bias_inputs.append("b3")

            equal_shape_inputs.append(bias_inputs)

        for equal_shape_group in equal_shape_inputs:
            inps = [match.kwargs[name] for name in equal_shape_group]

            if not all(
                inp.op == "get_attr"
                and inp.meta["val"].shape == inps[0].meta["val"].shape
                for inp in inps
            ):
                return False

        return True

    def matmul_fuse_pattern(inp, w1, w2, w3):
        return (inp @ w1, inp @ w2, inp @ w3)

    def matmul_replacement(inp, w1, w2, w3):
        cat_t = torch.cat((w1, w2, w3), dim=1)
        mm = inp @ cat_t
        return mm.chunk(3, dim=1)

    register_replacement(
        matmul_fuse_pattern,
        matmul_replacement,
        [val(), val(), val(), val()],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2", "w3"),
    )

    def matmul_fuse_pattern_two(inp, w1, w2):
        return (inp @ w1, inp @ w2)

    def matmul_replacement_two(inp, w1, w2):
        cat_t = torch.cat((w1, w2), dim=1)
        mm = inp @ cat_t
        return mm.chunk(2, dim=1)

    register_replacement(
        matmul_fuse_pattern_two,
        matmul_replacement_two,
        [val(), val(), val()],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2"),
    )

    def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3):
        return (
            aten.addmm(b1, inp, w1),
            aten.addmm(b2, inp, w2),
            aten.addmm(b3, inp, w3),
        )

    def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3):
        cat_w = torch.cat((w1, w2, w3), dim=1)
        cat_b = torch.cat((b1, b2, b3))
        return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1)

    register_replacement(
        addmm_fuse_pattern_second,
        addmm_fuse_replacement_second,
        [val() for _ in range(7)],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"),
    )


def same_dtype(match):
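    """Return True when the conversion is a no-op, i.e. the input already has the target dtype."""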
    return match.output_node().args[0].meta["val"].dtype == match.kwargs["dtype"]


@register_graph_pattern(
    CallFunction(
        torch.ops.prims.convert_element_type.default,
        Ignored(),
        KeywordArg("dtype"),
    ),
    pass_dict=pass_patterns[0],
    extra_check=same_dtype,
)
def unnecessary_dtype_convert(match: Match, **kwargs):
    """Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding"""
    graph = match.graph
    node = match.output_node()
    node.replace_all_uses_with(node.args[0])  # type: ignore[arg-type]
    graph.erase_node(node)
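
# A minimal, self-contained sketch (not used by the passes above) verifying the
# algebraic identity the matmul/addmm concat fusions rely on: a matmul against
# weights concatenated along dim=1, chunked back along dim=1, reproduces the
# individual matmuls. The helper name below is illustrative only and is not
# part of Inductor's API.
def _matmul_concat_identity_sketch():
    inp = torch.randn(10, 10)
    w1, w2, w3 = (torch.randn(10, 10) for _ in range(3))
    b1, b2, b3 = (torch.randn(10) for _ in range(3))

    # mm fusion: (inp @ w1, inp @ w2, inp @ w3) == chunk(inp @ cat(w), 3, dim=1)
    fused_mm = (inp @ torch.cat((w1, w2, w3), dim=1)).chunk(3, dim=1)
    for fused, ref in zip(fused_mm, (inp @ w1, inp @ w2, inp @ w3)):
        torch.testing.assert_close(fused, ref)

    # addmm fusion: biases concatenate along dim=0, since each bias broadcasts
    # over the rows of its (10, 10) slice of the fused output
    fused_addmm = aten.addmm(
        torch.cat((b1, b2, b3)), inp, torch.cat((w1, w2, w3), dim=1)
    ).chunk(3, dim=1)
    for fused, ref in zip(
        fused_addmm,
        (aten.addmm(b1, inp, w1), aten.addmm(b2, inp, w2), aten.addmm(b3, inp, w3)),
    ):
        torch.testing.assert_close(fused, ref)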