1# Owner(s): ["module: inductor"] 2from typing import Any, Callable 3 4import torch 5from torch._inductor.fx_passes.pre_grad import ( 6 linear_permute_fusion, 7 linear_transpose, 8 permute_linear_fusion, 9 permute_matmul_fusion, 10 sink_cat_after_pointwise, 11 transpose_linear, 12 transpose_matmul, 13) 14from torch._inductor.test_case import run_tests, TestCase 15from torch.fx.passes.shape_prop import ShapeProp 16 17 18PassFunc = Callable[[torch.fx.GraphModule, Any], torch.fx.GraphModule] 19 20 21def chain_passes(*passes: PassFunc) -> PassFunc: 22 def parent_pass(module: torch.fx.GraphModule, input: Any) -> torch.fx.GraphModule: 23 for pass_ in passes: 24 if isinstance(module, torch.fx.GraphModule): 25 ShapeProp(module).propagate(*input) 26 module = pass_(module) 27 return module 28 29 return parent_pass 30 31 32def count_call(module: torch.fx.GraphModule, op: str, target_op: Any) -> int: 33 return sum( 34 1 if (n.op == op and n.target == target_op) else 0 for n in module.graph.nodes 35 ) 36 37 38def count_call_function(module: torch.fx.GraphModule, target_op: Any) -> int: 39 return count_call(module, "call_function", target_op) 40 41 42def count_call_method(module: torch.fx.GraphModule, target_op: Any) -> int: 43 return count_call(module, "call_method", target_op) 44 45 46class TestFxFusion(TestCase): 47 def test_sink_cat_after_pointwise(self): 48 def test_kwarg(x, y): 49 return torch.cat([x, y], dim=-1).view(-1).view(128).tanh() 50 51 def test_arg(x, y): 52 return torch.cat([x, y], -1).view(-1).view(128).tanh() 53 54 def test_arg2(x, y): 55 return torch.cat([x, y]).view(-1).view(128).tanh() 56 57 def test_kwarg2(x, y): 58 return torch.cat(tensors=[x, y], dim=0).tanh() 59 60 def test_kwarg3(x, y): 61 return torch.cat(tensors=[x, y], dim=0).view(128).tanh() 62 63 trace_func = chain_passes(torch.fx.symbolic_trace, sink_cat_after_pointwise) 64 inputs = [ 65 torch.randn(8, 8), 66 torch.randn(8, 8), 67 ] 68 for f in [test_kwarg, test_arg, test_arg2, test_kwarg2, test_kwarg3]: 69 traced = trace_func(f, inputs) 70 torch.testing.assert_close(f(*inputs), traced(*inputs)) 71 self.assertEqual(count_call_method(traced, "tanh"), 2) 72 73 def test_linear_permute_fusion(self): 74 class TestModule(torch.nn.Module): 75 def __init__(self, k: int, n: int, has_bias: bool): 76 super().__init__() 77 self.weight = torch.nn.Parameter(torch.randn(n, k)) 78 self.has_bias = has_bias 79 if has_bias: 80 self.bias = torch.nn.Parameter(torch.randn(n)) 81 82 def forward(self, input: torch.Tensor): 83 if self.has_bias: 84 a0 = torch.nn.functional.linear(input, self.weight, self.bias) 85 else: 86 a0 = torch.nn.functional.linear(input, self.weight) 87 b0 = a0.permute(0, 2, 1) 88 return b0 89 90 m, k, n = 16, 8, 4 91 trace_func = chain_passes(torch.fx.symbolic_trace, linear_permute_fusion) 92 for has_bias in [True, False]: 93 module = TestModule(k, n, has_bias).eval() 94 input = torch.randn(6, m, k) 95 traced = trace_func(module, [input]) 96 num_linear = count_call_function(traced, torch.nn.functional.linear) 97 num_linear_transpose = count_call_function(traced, linear_transpose) 98 self.assertEqual(num_linear, 0) 99 self.assertEqual(num_linear_transpose, 1) 100 101 torch.testing.assert_close(module(input), traced(input)) 102 103 def test_permute_linear_fusion(self): 104 class TestModule(torch.nn.Module): 105 def __init__(self, k: int, n: int, has_bias: bool): 106 super().__init__() 107 self.weight = torch.nn.Parameter(torch.randn(n, k)) 108 self.has_bias = has_bias 109 if has_bias: 110 self.bias = torch.nn.Parameter(torch.randn(n)) 111 112 def forward(self, input: torch.Tensor): 113 input1 = input.permute(0, 2, 1) 114 if self.has_bias: 115 return torch.nn.functional.linear(input1, self.weight, self.bias) 116 return torch.nn.functional.linear(input1, self.weight) 117 118 m, k, n = 16, 8, 4 119 120 trace_func = chain_passes(torch.fx.symbolic_trace, permute_linear_fusion) 121 for has_bias in [True, False]: 122 module = TestModule(k, n, has_bias).eval() 123 input = torch.randn(6, k, m) 124 traced = trace_func(module, [input]) 125 num_linear = count_call_function(traced, torch.nn.functional.linear) 126 num_transpose_linear = count_call_function(traced, transpose_linear) 127 self.assertEqual(num_linear, 0) 128 self.assertEqual(num_transpose_linear, 1) 129 130 torch.testing.assert_close(module(input), traced(input)) 131 132 def test_permute_bmm_fusion(self): 133 class TestModule(torch.nn.Module): 134 def __init__(self, batch: int, k: int, n: int): 135 super().__init__() 136 self.other = torch.randn(batch, k, n) 137 138 def forward(self, input: torch.Tensor): 139 input1 = input.permute(0, 2, 1) 140 output = torch.bmm(input1, self.other) 141 return output 142 143 batch, m, k, n = 6, 16, 8, 4 144 145 trace_func = chain_passes(torch.fx.symbolic_trace, permute_matmul_fusion) 146 module = TestModule(batch, k, n).eval() 147 input = torch.randn(batch, k, m) 148 traced = trace_func(module, [input]) 149 num_bmm = count_call_function(traced, torch.bmm) 150 num_transpose_matmul = count_call_function(traced, transpose_matmul) 151 self.assertEqual(num_bmm, 0) 152 self.assertEqual(num_transpose_matmul, 1) 153 154 torch.testing.assert_close(module(input), traced(input)) 155 156 157if __name__ == "__main__": 158 run_tests() 159