# Owner(s): ["oncall: jit"]

import torch
from torch._C import parse_ir
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestAliasAnalysis(JitTestCase):
    def test_becomes_wildcard_annotations(self):
        graph_str = """
        graph(%a.1 : Tensor, %b.1 : Tensor):
            %11 : NoneType = prim::Constant()
            %8 : int = prim::Constant[value=0]()
            %7 : int = prim::Constant[value=1]()
            %x.1 : Tensor = aten::add(%a.1, %b.1, %7)
            %y.1 : Tensor[] = aten::split(%x.1, %7, %8)
            return ()
        """
        graph = parse_ir(graph_str)
        alias_db = graph.alias_db()
        split_node = graph.findNode("aten::split")
        # the split input enters the wildcard set; the output list is initialized as containing the wildcard set
        self.assertTrue(
            alias_db.may_contain_alias(next(split_node.inputs()), split_node.output())
        )
        # because %x.1 enters the wildcard set, it now aliases the other members of the wildcard set (the graph inputs)
        self.assertTrue(
            alias_db.may_contain_alias(next(split_node.inputs()), next(graph.inputs()))
        )

    def test_nested_list_construct_not_wildcard(self):
        @torch.jit.script
        def foo(x):
            y = torch.rand([2, 2])
            return [y]

        graph = foo.graph
        graph.alias_db()
        alias_db = graph.alias_db()
        ten_construct = graph.findNode("aten::rand").output()
        output = next(graph.outputs())
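        # the returned list contains the freshly created tensor, so the graph output may contain an alias of it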
        self.assertTrue(alias_db.may_contain_alias(ten_construct, output))
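        # the ListConstruct is not a wildcard container, so the fresh tensor cannot alias the graph input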
        self.assertFalse(
            alias_db.may_contain_alias(next(graph.inputs()), ten_construct)
        )

    def test_recursive_calls(self):
        @torch.jit.script
        def foo(x, y):
            x.add_(1)
            return x + y

        @torch.jit.script
        def caller():
            a = torch.rand([2, 2])
            b = torch.ones([2, 2])
            out1 = foo(a, b)
            c = torch.rand([1])
            d = torch.ones([2])
            out2 = foo(d, c)
            return out1, out2

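        # build the AliasDb while descending into function calls so the mutation
        # inside foo is visible at its call sites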
        isFrozen = False
        descend_function_calls = True
        alias_db = caller.graph.alias_db(isFrozen, descend_function_calls)
        func_calls = caller.graph.findAllNodes("prim::CallFunction")
        self.assertEqual(len(func_calls), 2)
        for node in func_calls:
            inps = list(node.inputs())
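            # input 0 is the callee; input 1 is the argument mutated by foo's add_,
            # input 2 is the argument that is only read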
            self.assertTrue(alias_db.has_writers(inps[1]))
            self.assertFalse(alias_db.has_writers(inps[2]))

        class Mod(torch.nn.Module):
            def forward(self):
                a = torch.rand([2, 2])
                b = torch.ones([2, 2])
                out1 = self.foo2(a, b)
                c = torch.rand([1])
                d = torch.ones([2])
                out2 = self.foo2(d, c)
                return out1, out2

            def foo2(self, x, y):
                x.add_(1)
                return x + y

        mod = torch.jit.script(Mod())
        alias_db = mod.graph.alias_db(isFrozen, descend_function_calls)
        func_calls = mod.graph.findAllNodes("prim::CallMethod")
        self.assertEqual(len(func_calls), 2)
        for node in func_calls:
            inps = list(node.inputs())
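            # input 0 is %self; input 1 is the argument mutated by foo2's add_,
            # input 2 is the argument that is only read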
            self.assertTrue(alias_db.has_writers(inps[1]))
            self.assertFalse(alias_db.has_writers(inps[2]))

    def test_multiple_compilation_units(self):
        # This is a repro of an internal issue we saw.
        # Here, we have a large number (40) of modules, each with the same name (MyModuleCUTest).
        # AliasDb uses some hash tables that hash on types; these 40 module types are not
        # identical because they come from different compilation units, but they all share the
        # same name. Therefore, if we hash only on the module name (which we previously did),
        # all of these module types collide.
        #
        # flat_hash_map has very bad (exponential) performance under this kind of hash collision.
        # This test OOMs prior to the fix.
        N = 40

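        # context manager that opens N temporary ".py" files at once and yields their names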
        class MultiTmpFile:
            def __init__(self, N):
                self.N = N
                self.ctxs = [
                    TemporaryFileName(mode="w", suffix=".py") for _ in range(N)
                ]

            def __enter__(self):
                return [x.__enter__() for x in self.ctxs]

            def __exit__(self, exc_type, exc_value, traceback):
                return [x.__exit__(exc_type, exc_value, traceback) for x in self.ctxs]

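        # wrapper module whose forward runs the input through every loaded copy in sequence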
        class ModuleWrapper(torch.nn.Module):
            def __init__(self, module_list):
                super().__init__()
                self.module_list = module_list

            def forward(self, x):
                for mod in self.module_list:
                    x = mod(x)
                return x

        with MultiTmpFile(N) as fnames:
            module_list = torch.nn.ModuleList()
            global MyModuleCUTest

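            # declared global above so the class is visible at module scope, presumably
            # so torch.jit.script / torch.jit.save can resolve its qualified name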
            class MyModuleCUTest(torch.nn.Module):
                def forward(self, x):
                    return x + 2

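            # each save/load round trip creates a fresh CompilationUnit, so the list
            # ends up holding N distinct module types that all share the same name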
            for fname in fnames:
                mod = torch.jit.script(MyModuleCUTest())
                torch.jit.save(mod, fname)
                loaded_mod = torch.jit.load(fname)
                module_list.append(loaded_mod)

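            # scripting and running the wrapper builds an AliasDb over a graph that
            # calls all N distinct module types; this is the step that OOMed before the fix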
            mod = ModuleWrapper(module_list)
            mod = torch.jit.script(mod)
            mod(torch.zeros((2, 2)))
157