xref: /aosp_15_r20/external/pytorch/test/dynamo/test_compile.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport inspect
4*da0073e9SAndroid Build Coastguard Workerimport io
5*da0073e9SAndroid Build Coastguard Workerimport os
6*da0073e9SAndroid Build Coastguard Workerimport tempfile
7*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport torch
10*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.test_case import run_tests, TestCase
11*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import CompileCounter
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
class ToyModel(torch.nn.Module):
    """Tiny ``Linear(10, 10) -> ReLU`` module used as a compilation target."""

    def __init__(self) -> None:
        super().__init__()
        # These attribute names become the state_dict keys
        # ("linear.weight", "linear.bias"), which the save/load tests rely on.
        self.linear = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply the linear layer, then ReLU."""
        projected = self.linear(x)
        activated = self.relu(projected)
        return activated
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
class InPlaceCompilationTests(TestCase):
    """Tests for in-place ``nn.Module.compile()`` and dynamo compile callbacks."""

    def test_compilation(self):
        torch._dynamo.reset()
        model = ToyModel()
        cnt = CompileCounter()
        model.compile(backend=cnt)
        x = torch.randn(10, 10)
        model(x)
        self.assertEqual(cnt.frame_count, 1)

    def test_overwrite_call_impl(self):
        torch._dynamo.reset()
        model = ToyModel()
        # compile() works in place: it populates _compiled_call_impl on the
        # same module object instead of returning a wrapper.
        self.assertIsNone(model._compiled_call_impl)
        model.compile()
        self.assertIsNotNone(model._compiled_call_impl)

    def test_save(self):
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        x = torch.randn(1, 10)
        expected = model(x)

        with tempfile.TemporaryDirectory() as tmpdirname:
            path = os.path.join(tmpdirname, "model.pt")
            torch.save(model, path)
            # The file holds a whole pickled nn.Module, not just tensors, so
            # it must be loaded with weights_only=False. Newer torch.load
            # versions default to weights_only=True and would reject it.
            loaded_model = torch.load(path, weights_only=False)
            self.assertEqual(loaded_model(x), expected)

    def test_state_dict_save(self):
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        x = torch.randn(1, 10)
        expected = model(x)
        with tempfile.TemporaryDirectory() as tmpdirname:
            path = os.path.join(tmpdirname, "model.pt")
            # The state dict of an in-place-compiled module must load straight
            # into a fresh, uncompiled ToyModel.
            torch.save(model.state_dict(), path)
            loaded_model = ToyModel()
            # A state dict contains only tensors, so the safe loader suffices.
            loaded_model.load_state_dict(torch.load(path, weights_only=True))
            self.assertEqual(loaded_model(x), expected)

    def test_jit_save(self):
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        x = torch.randn(1, 10)
        expected = model(x)
        scripted_model = torch.jit.script(model)
        with tempfile.TemporaryDirectory() as tmpdirname:
            path = os.path.join(tmpdirname, "model.pt")
            torch.jit.save(scripted_model, path)
            loaded_model = torch.jit.load(path)
            self.assertEqual(loaded_model(x), expected)

    def test_compilation_callback(self):
        torch._dynamo.reset()

        @torch._dynamo.on_compile_start
        def start_callback():
            print("Compilation started.")

        @torch._dynamo.on_compile_end
        def end_callback():
            print("Compilation ended.")

        mod = ToyModel()
        x = torch.randn(10, 10)

        # The callbacks report through print(), so capture stdout to observe them.
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            opt_mod = torch.compile(backend="eager", fullgraph=True)(mod)
            opt_mod(x)
            printed_output = mock_stdout.getvalue().strip()

        self.assertEqual(printed_output, "Compilation started.\nCompilation ended.")

    def test_compile_eager_options(self):
        torch._dynamo.reset()

        # Backend options must be accepted (and not crash) by both the
        # "eager" and "aot_eager" backends.
        @torch.compile(backend="eager", options={"foo": 2})
        def f(x):
            return x + x

        f(torch.randn(3))

        @torch.compile(backend="aot_eager", options={"foo": 2})
        def g(x):
            return x + x

        g(torch.randn(3))

    def test_compilation_callback_with_graph_break(self):
        torch._dynamo.reset()
        counter = 0

        @torch._dynamo.on_compile_start
        def start_callback():
            nonlocal counter
            counter += 1
            print(f"Counter = {counter}")

        @torch._dynamo.on_compile_end
        def end_callback():
            nonlocal counter
            counter += 1
            print(f"Counter = {counter}")

        @torch.compile(backend="eager")
        def fn(x):
            x = x + 1
            torch._dynamo.graph_break()
            return torch.sin(x)

        x = torch.randn(10, 10)

        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            fn(x)
            printed_output = mock_stdout.getvalue().strip()

        # Four increments are expected: the graph break yields two compilations,
        # and each one fires a start callback and an end callback.
        self.assertEqual(
            printed_output, "Counter = 1\nCounter = 2\nCounter = 3\nCounter = 4"
        )
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker
# The private variants of the functions below are tested extensively,
# so it is enough to check that the public signatures match them.
145*da0073e9SAndroid Build Coastguard Workerclass PublicTorchCompilerTests(TestCase):
146*da0073e9SAndroid Build Coastguard Worker    def check_signature(self, public_fn_name, private_fn_name, private_namespace):
147*da0073e9SAndroid Build Coastguard Worker        public_fn = getattr(torch.compiler, public_fn_name)
148*da0073e9SAndroid Build Coastguard Worker        private_fn = getattr(private_namespace, private_fn_name)
149*da0073e9SAndroid Build Coastguard Worker
150*da0073e9SAndroid Build Coastguard Worker        public_sig = inspect.signature(public_fn)
151*da0073e9SAndroid Build Coastguard Worker        private_sig = inspect.signature(private_fn)
152*da0073e9SAndroid Build Coastguard Worker
153*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
154*da0073e9SAndroid Build Coastguard Worker            public_sig,
155*da0073e9SAndroid Build Coastguard Worker            private_sig,
156*da0073e9SAndroid Build Coastguard Worker            f"Signatures do not match for function {public_fn_name}() \n Public: {public_sig} \n Private: {private_sig}",
157*da0073e9SAndroid Build Coastguard Worker        )
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker    def test_dynamo_signatures(self):
160*da0073e9SAndroid Build Coastguard Worker        function_names = [
161*da0073e9SAndroid Build Coastguard Worker            "reset",
162*da0073e9SAndroid Build Coastguard Worker            "allow_in_graph",
163*da0073e9SAndroid Build Coastguard Worker            "list_backends",
164*da0073e9SAndroid Build Coastguard Worker            "assume_constant_result",
165*da0073e9SAndroid Build Coastguard Worker            "disable",
166*da0073e9SAndroid Build Coastguard Worker        ]
167*da0073e9SAndroid Build Coastguard Worker
168*da0073e9SAndroid Build Coastguard Worker        for fn_name in function_names:
169*da0073e9SAndroid Build Coastguard Worker            self.check_signature(fn_name, fn_name, torch._dynamo)
170*da0073e9SAndroid Build Coastguard Worker
171*da0073e9SAndroid Build Coastguard Worker
# Entry point when this file is executed directly: run its tests through
# the run_tests helper imported from torch._dynamo.test_case.
if __name__ == "__main__":
    run_tests()
174