1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport torch 4*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck 5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, make_global 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerclass TestDCE(JitTestCase): 9*da0073e9SAndroid Build Coastguard Worker def test_setattr_no_aliasdb(self): 10*da0073e9SAndroid Build Coastguard Worker class Net(torch.nn.Module): 11*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12*da0073e9SAndroid Build Coastguard Worker super().__init__() 13*da0073e9SAndroid Build Coastguard Worker self.x = torch.empty([2, 2]) 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker def forward(self): 16*da0073e9SAndroid Build Coastguard Worker x = torch.rand([3, 3]) 17*da0073e9SAndroid Build Coastguard Worker self.x = x 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker net = torch.jit.script(Net()) 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::SetAttr").run(net.graph) 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker def test_setattr_removed(self): 24*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 25*da0073e9SAndroid Build Coastguard Worker class Thing1: 26*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 27*da0073e9SAndroid Build Coastguard Worker self.x = torch.zeros([2, 2]) 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker make_global(Thing1) 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Worker class Thing2(torch.nn.Module): 32*da0073e9SAndroid Build Coastguard Worker def forward(self): 33*da0073e9SAndroid Build Coastguard Worker x = torch.rand([2, 2]) 34*da0073e9SAndroid Build Coastguard Worker y = torch.rand([2, 2]) 35*da0073e9SAndroid Build Coastguard Worker t1 = Thing1() 36*da0073e9SAndroid Build Coastguard Worker t1.x = x 37*da0073e9SAndroid Build Coastguard Worker return y 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker unscripted = Thing2() 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker t2 = torch.jit.script(unscripted) 42*da0073e9SAndroid Build Coastguard Worker t2.eval() 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker # freezing inlines t1.__init__(), after which DCE can occur. 45*da0073e9SAndroid Build Coastguard Worker t2 = torch.jit.freeze(t2) 46*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("prim::SetAttr").run(t2.graph) 47