# Owner(s): ["module: inductor"]
import contextlib
import operator
from collections import defaultdict

import torch
import torch._inductor.pattern_matcher as pattern_matcher
import torch.fx as fx
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.lowering import lowerings as L
from torch._inductor.pattern_matcher import Arg, CallFunction, PatternMatcherPass
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CPU


@config.patch({"freezing": True})
class TestCustomPassBase(TestCase):
    def _clone_inputs(self, inputs):
        def clone(x):
            if not isinstance(x, torch.Tensor):
                return x
            return x.clone()

        return tuple(clone(x) for x in inputs)

    def _test_common(
        self,
        mod,
        inputs,
        matcher_count,
        matcher_nodes,
        atol=1e-5,
        rtol=1.3e-6,
    ):
        counters.clear()
        maybe_autocast = contextlib.nullcontext()
        with torch.no_grad(), maybe_autocast:
            clone_inputs = self._clone_inputs(inputs)
            expected = mod(*inputs)
            actual = torch.compile(mod)(*clone_inputs)
            torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
            self.assertEqual(
                counters["inductor"]["pattern_matcher_count"], matcher_count
            )
            self.assertEqual(
                counters["inductor"]["pattern_matcher_nodes"],
                matcher_nodes,
            )


aten = torch.ops.aten
mkldnn = torch.ops.mkldnn


def change_cos_pass(graph):
    for node in graph.nodes:
        if node.op == "call_function" and node.target == aten.cos.default:
            node.target = aten.sin.default


class TestPostGradCustomPrePostPass(TestCustomPassBase):
    # Mimic the mkldnn conv+relu fusion pattern matcher
    # (torch/_inductor/fx_passes/mkldnn_fusion.py) and register it
    # as a custom post_grad pass.
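    # The helper below builds a CallFunction pattern matching
    # relu(mkldnn._convolution_pointwise(...)) and registers a lowering that
    # re-emits the convolution with "relu" as its fused pointwise attr, so
    # each hit replaces the two matched nodes with a single fused op.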
    def _register_mkldnn_conv_relu_fusion(self, custom_pass_dict):
        # pattern
        def _mkldnn_conv_relu_pattern():
            return CallFunction(
                aten.relu,
                CallFunction(
                    mkldnn._convolution_pointwise.default,
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    _users=1,
                ),
            )

        # helpers for pattern matcher registration
        def _register_fusion_lowering(pattern, custom_pass_dict):
            def dummy_check(m):
                return True

            def register_custom_lowering_pattern(
                pattern, extra_check, custom_pass_dict
            ):
                return pattern_matcher.register_lowering_pattern(
                    pattern, extra_check, pass_dict=custom_pass_dict
                )

            @register_custom_lowering_pattern(pattern, dummy_check, custom_pass_dict)
            def fn(match, *args, **kwargs):
                # swap the trailing (attr, scalars, algorithm) args for a fused relu
                computation_args = list(args)[:-3] + ["relu", [], ""]
                return L[mkldnn._convolution_pointwise.default](*computation_args)

            return fn

        _register_fusion_lowering(_mkldnn_conv_relu_pattern(), custom_pass_dict)

    # custom post grad pass
    class _CustomPass(PatternMatcherPass):
        def __init__(self) -> None:
            super().__init__()

        def __call__(self, g: torch.fx.graph.Graph):
            self.apply(g)

    # test model
    class _ConvReLU(torch.nn.Module):
        def __init__(self, ic, oc):
            super().__init__()
            self.conv = torch.nn.Conv2d(ic, oc, kernel_size=3, stride=1, padding=1)

        def forward(self, x):
            x1 = self.conv(x)
            return x1.relu()

    def test_custom_joint_pass_pre(self):
        with config.patch(joint_custom_pre_pass=change_cos_pass):

            def g(x):
                return x.sin().sin().sin()

            def f(x):
                return x.cos().cos().cos()

            x = torch.randn(8, dtype=torch.float32)
            torch.testing.assert_close(torch.compile(f)(x), g(x))

    def test_custom_joint_pass_post(self):
        with config.patch(joint_custom_post_pass=change_cos_pass):

            def g(x):
                return x.sin().sin().sin()

            def f(x):
                return x.cos().cos().cos()

            x = torch.randn(8, dtype=torch.float32)
            torch.testing.assert_close(torch.compile(f)(x), g(x))

    def test_custom_pre_pass(self):
        with config.patch(
            # leave the custom pass as the only pass in post_grad_passes()
            pattern_matcher=False,
            # run the pattern match as a custom post-grad pre pass
            post_grad_custom_pre_pass=self._CustomPass(),
            post_grad_custom_post_pass=None,
        ):
            # register the mkldnn conv+relu fusion on the custom pass
            self._register_mkldnn_conv_relu_fusion(config.post_grad_custom_pre_pass)

            mod = self._ConvReLU(16, 16).eval()
            x = torch.randn((1, 16, 56, 56), dtype=torch.float32)

            match_count = 1
            match_nodes = 2
            other_match_count = 1  # conv prepack weight
            other_match_nodes = 1  # conv prepack weight
            self._test_common(
                mod,
                (x,),
                match_count + other_match_count,
                match_nodes + other_match_nodes,
            )

    def test_custom_post_pass(self):
        with config.patch(
            # leave the custom pass as the only pass in post_grad_passes()
            pattern_matcher=False,
            # run the pattern match as a custom post-grad post pass
            post_grad_custom_pre_pass=None,
            post_grad_custom_post_pass=self._CustomPass(),
        ):
            # register the mkldnn conv+relu fusion on the custom pass
            self._register_mkldnn_conv_relu_fusion(config.post_grad_custom_post_pass)

            mod = self._ConvReLU(16, 16).eval()
            x = torch.randn((1, 16, 56, 56), dtype=torch.float32)

            match_count = 1
            match_nodes = 2
            other_match_count = 1  # conv prepack weight
            other_match_nodes = 1  # conv prepack weight
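            # expect one hit for the conv+relu fusion (2 nodes) plus one for
            # the weight prepack constant fold (1 node), as in test_custom_pre_pass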
            self._test_common(
                mod,
                (x,),
                match_count + other_match_count,
                match_nodes + other_match_nodes,
            )

    def test_custom_pre_grad_pass(self):
        saved_graph = [None]

        def merge_mm_shared_rhs(graph: fx.Graph):
            """
            Toy proof-of-concept pass that merges mms sharing an RHS,
            i.e. [mm(x, W), mm(x2, W)] => mm(cat(x, x2), W).split()

            It isn't actually safe, for a couple of reasons. For example, it
            doesn't handle the case where the LHS inputs depend on each other.
            """
            saved_graph[0] = graph
            matmuls = [n for n in graph.nodes if n.target == torch.mm]
            rhs_vals = defaultdict(set)
            for m in matmuls:
                rhs_vals[m.args[1]].add(m)

            order = {}
            for idx, n in enumerate(graph.nodes):
                order[n] = idx

            for rhs, matmuls in rhs_vals.items():
                if len(matmuls) == 1:
                    continue
                matmuls = sorted(matmuls, key=lambda x: order[x])
                with graph.inserting_before(matmuls[0]):
                    lhs_vals = [m.args[0] for m in matmuls]
                    new_cat = graph.create_node(
                        "call_function", torch.cat, args=(lhs_vals, 0)
                    )
                    new_mm = graph.create_node(
                        "call_function", torch.mm, args=(new_cat, rhs)
                    )
                    split_vals = graph.create_node(
                        "call_function",
                        torch.split,
                        args=(
                            new_mm,
                            [l.meta["example_value"].shape[0] for l in lhs_vals],
                        ),
                    )
                # rewrite each original mm into a getitem on the split output
                for idx, m in enumerate(matmuls):
                    m.target = operator.getitem
                    m.args = (split_vals, idx)

        @config.patch(pre_grad_custom_pass=merge_mm_shared_rhs)
        def inner_test():
            @torch.compile
            def f(W, nested_seqs):
                outs = [torch.mm(s, W) for s in nested_seqs]
                return outs

            W = torch.randn(16, 16, dtype=torch.bfloat16)
            nested_seqs = [
                torch.randn(l, 16, dtype=torch.bfloat16) for l in [4, 8, 5, 3]
            ]

            f(W, nested_seqs)
            assert saved_graph[0] is not None
            matmuls = [n for n in saved_graph[0].nodes if n.target == torch.mm]
            assert len(matmuls) == 1

        inner_test()


if __name__ == "__main__":
    if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():
        run_tests()