# Owner(s): ["oncall: jit"]

import torch
import torch._C
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase


class TestGraphRewritePasses(JitTestCase):
    def test_fuse_linear(self):
        class FunctionalLinear(torch.nn.Module):
            def __init__(self, weight, bias):
                super().__init__()
                self.weight = weight
                self.bias = bias

            def forward(self, x):
                res = torch.matmul(x, self.weight.t())
                if self.bias is not None:
                    res.add_(self.bias)
                return res

        x1 = torch.rand(3)
        w1 = torch.rand(5, 3)
        b1 = torch.rand(5)
        for has_bias in [True, False]:
            bias = b1 if has_bias else None
            model = torch.jit.trace(FunctionalLinear(w1, bias), [x1])
            # record the source range of the matmul node before the rewrite
            for node in model.graph.nodes():
                if node.kind() == "aten::matmul":
                    source_range_1 = node.sourceRange()
            torch._C._jit_pass_fuse_linear(model.graph)
            # the fused aten::linear node should inherit the matmul's source range
            for node in model.graph.nodes():
                if node.kind() == "aten::linear":
                    source_range_2 = node.sourceRange()
            FileCheck().check("aten::linear").run(model.graph)
            # the unfused ops should no longer appear in the graph
            check_not = ["aten::matmul", "aten::addmm", "aten::add_", "aten::t("]
            for cn in check_not:
                FileCheck().check_not(cn).run(model.graph)
            self.assertTrue(source_range_1 == source_range_2)
            # make sure it runs
            model(x1)

        # check matmuls are not fused
        class Matmul(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return torch.matmul(x, self.weight)

        x = torch.rand(5, 6, 5)
        w = torch.rand(5, 5, 100)
        model = torch.jit.trace(Matmul(w), [x])
        torch._C._jit_pass_fuse_linear(model.graph)
        # check 3d matmul is not fused
        FileCheck().check("aten::matmul").run(model.graph)
        FileCheck().check_not("aten::linear").run(model.graph)
        # make sure it runs
        model(x)