xref: /aosp_15_r20/external/pytorch/test/jit/test_dce.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import torch
4from torch.testing import FileCheck
5from torch.testing._internal.jit_utils import JitTestCase, make_global
6
7
8class TestDCE(JitTestCase):
9    def test_setattr_no_aliasdb(self):
10        class Net(torch.nn.Module):
11            def __init__(self) -> None:
12                super().__init__()
13                self.x = torch.empty([2, 2])
14
15            def forward(self):
16                x = torch.rand([3, 3])
17                self.x = x
18
19        net = torch.jit.script(Net())
20
21        FileCheck().check("prim::SetAttr").run(net.graph)
22
23    def test_setattr_removed(self):
24        @torch.jit.script
25        class Thing1:
26            def __init__(self) -> None:
27                self.x = torch.zeros([2, 2])
28
29        make_global(Thing1)
30
31        class Thing2(torch.nn.Module):
32            def forward(self):
33                x = torch.rand([2, 2])
34                y = torch.rand([2, 2])
35                t1 = Thing1()
36                t1.x = x
37                return y
38
39        unscripted = Thing2()
40
41        t2 = torch.jit.script(unscripted)
42        t2.eval()
43
44        # freezing inlines t1.__init__(), after which DCE can occur.
45        t2 = torch.jit.freeze(t2)
46        FileCheck().check_not("prim::SetAttr").run(t2.graph)
47