# xref: /aosp_15_r20/external/pytorch/test/jit/test_dce.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: jit"]

3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, make_global
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
class TestDCE(JitTestCase):
    """Dead-code-elimination tests for TorchScript graphs containing ``prim::SetAttr``."""

    def test_setattr_no_aliasdb(self):
        """An attribute write on the module itself must survive scripting.

        ``Net.forward`` returns nothing; its only effect is assigning ``self.x``,
        so the ``prim::SetAttr`` node is the whole point of the method and must
        not be treated as dead code. NOTE(review): the test name suggests this
        guards DCE that runs without an alias database during scripting —
        confirm against the pass pipeline if this test changes.
        """

        class Net(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.empty([2, 2])

            def forward(self):
                x = torch.rand([3, 3])
                # Write to module state: the only observable effect of forward.
                self.x = x

        net = torch.jit.script(Net())

        FileCheck().check("prim::SetAttr").run(net.graph)

    def test_setattr_removed(self):
        """An attribute write on a local, never-used object is removed by freezing.

        ``t1`` is constructed and written inside ``Thing2.forward`` but is never
        returned or read, so once freezing has inlined ``Thing1.__init__`` the
        ``prim::SetAttr`` on it is dead and should be eliminated.
        """

        @torch.jit.script
        class Thing1:
            def __init__(self) -> None:
                self.x = torch.zeros([2, 2])

        # Presumably makes the locally defined Thing1 resolvable when Thing2 is
        # scripted below (make_global is a test helper defined outside this file).
        make_global(Thing1)

        class Thing2(torch.nn.Module):
            def forward(self):
                x = torch.rand([2, 2])
                y = torch.rand([2, 2])
                t1 = Thing1()
                # Write to a local object that never escapes forward.
                t1.x = x
                return y

        unscripted = Thing2()

        t2 = torch.jit.script(unscripted)
        # torch.jit.freeze only accepts modules in eval mode.
        t2.eval()

        # Before freezing the write must still be present; otherwise the
        # check_not below would pass vacuously and prove nothing about DCE.
        FileCheck().check("prim::SetAttr").run(t2.graph)

        # freezing inlines t1.__init__(), after which DCE can occur.
        t2 = torch.jit.freeze(t2)
        FileCheck().check_not("prim::SetAttr").run(t2.graph)
47