1# Owner(s): ["oncall: jit"] 2 3import torch 4from torch.testing import FileCheck 5from torch.testing._internal.jit_utils import JitTestCase, make_global 6 7 8class TestDCE(JitTestCase): 9 def test_setattr_no_aliasdb(self): 10 class Net(torch.nn.Module): 11 def __init__(self) -> None: 12 super().__init__() 13 self.x = torch.empty([2, 2]) 14 15 def forward(self): 16 x = torch.rand([3, 3]) 17 self.x = x 18 19 net = torch.jit.script(Net()) 20 21 FileCheck().check("prim::SetAttr").run(net.graph) 22 23 def test_setattr_removed(self): 24 @torch.jit.script 25 class Thing1: 26 def __init__(self) -> None: 27 self.x = torch.zeros([2, 2]) 28 29 make_global(Thing1) 30 31 class Thing2(torch.nn.Module): 32 def forward(self): 33 x = torch.rand([2, 2]) 34 y = torch.rand([2, 2]) 35 t1 = Thing1() 36 t1.x = x 37 return y 38 39 unscripted = Thing2() 40 41 t2 = torch.jit.script(unscripted) 42 t2.eval() 43 44 # freezing inlines t1.__init__(), after which DCE can occur. 45 t2 = torch.jit.freeze(t2) 46 FileCheck().check_not("prim::SetAttr").run(t2.graph) 47