xref: /aosp_15_r20/external/pytorch/test/export/test_verifier.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2import unittest
3
4import torch
5from functorch.experimental import control_flow
6from torch import Tensor
7from torch._dynamo.eval_frame import is_dynamo_supported
8from torch._export.verifier import SpecViolationError, Verifier
9from torch.export import export
10from torch.export.exported_program import InputKind, InputSpec, TensorArgument
11from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
12
13
14@unittest.skipIf(not is_dynamo_supported(), "dynamo isn't supported")
15class TestVerifier(TestCase):
16    def test_verifier_basic(self) -> None:
17        class Foo(torch.nn.Module):
18            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
19                return x + y
20
21        f = Foo()
22
23        ep = export(f, (torch.randn(100), torch.randn(100)))
24
25        verifier = Verifier()
26        verifier.check(ep)
27
28    def test_verifier_call_module(self) -> None:
29        class M(torch.nn.Module):
30            def __init__(self) -> None:
31                super().__init__()
32                self.linear = torch.nn.Linear(10, 10)
33
34            def forward(self, x: Tensor) -> Tensor:
35                return self.linear(x)
36
37        gm = torch.fx.symbolic_trace(M())
38
39        verifier = Verifier()
40        with self.assertRaises(SpecViolationError):
41            verifier._check_graph_module(gm)
42
43    def test_verifier_no_functional(self) -> None:
44        class Foo(torch.nn.Module):
45            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
46                return x + y
47
48        f = Foo()
49
50        ep = export(f, (torch.randn(100), torch.randn(100)))
51        for node in ep.graph.nodes:
52            if node.target == torch.ops.aten.add.Tensor:
53                node.target = torch.ops.aten.add_.Tensor
54
55        verifier = Verifier()
56        with self.assertRaises(SpecViolationError):
57            verifier.check(ep)
58
59    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
60    def test_verifier_higher_order(self) -> None:
61        class Foo(torch.nn.Module):
62            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
63                def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
64                    return x + y
65
66                def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
67                    return x - y
68
69                return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])
70
71        f = Foo()
72
73        ep = export(f, (torch.randn(3, 3), torch.randn(3, 3)))
74
75        verifier = Verifier()
76        verifier.check(ep)
77
78    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
79    def test_verifier_nested_invalid_module(self) -> None:
80        class Foo(torch.nn.Module):
81            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
82                def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
83                    return x + y
84
85                def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
86                    return x - y
87
88                return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])
89
90        f = Foo()
91
92        ep = export(f, (torch.randn(3, 3), torch.randn(3, 3)))
93        for node in ep.graph_module.true_graph_0.graph.nodes:
94            if node.target == torch.ops.aten.add.Tensor:
95                node.target = torch.ops.aten.add_.Tensor
96
97        verifier = Verifier()
98        with self.assertRaises(SpecViolationError):
99            verifier.check(ep)
100
101    def test_ep_verifier_basic(self) -> None:
102        class M(torch.nn.Module):
103            def __init__(self) -> None:
104                super().__init__()
105                self.linear = torch.nn.Linear(10, 10)
106
107            def forward(self, x: Tensor) -> Tensor:
108                return self.linear(x)
109
110        ep = export(M(), (torch.randn(10, 10),))
111        ep.validate()
112
113    def test_ep_verifier_invalid_param(self) -> None:
114        class M(torch.nn.Module):
115            def __init__(self) -> None:
116                super().__init__()
117                self.register_parameter(
118                    name="a", param=torch.nn.Parameter(torch.randn(100))
119                )
120
121            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
122                return x + y + self.a
123
124        ep = export(M(), (torch.randn(100), torch.randn(100)))
125
126        # Parameter doesn't exist in the state dict
127        ep.graph_signature.input_specs[0] = InputSpec(
128            kind=InputKind.PARAMETER, arg=TensorArgument(name="p_a"), target="bad_param"
129        )
130        with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
131            ep.validate()
132
133        # Add non-torch.nn.Parameter parameter to the state dict
134        ep.state_dict["bad_param"] = torch.randn(100)
135        with self.assertRaisesRegex(
136            SpecViolationError, "not an instance of torch.nn.Parameter"
137        ):
138            ep.validate()
139
140    def test_ep_verifier_invalid_buffer(self) -> None:
141        class M(torch.nn.Module):
142            def __init__(self) -> None:
143                super().__init__()
144                self.a = torch.tensor(3.0)
145
146            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
147                return x + y + self.a
148
149        ep = export(M(), (torch.randn(100), torch.randn(100)))
150
151        # Buffer doesn't exist in the state dict
152        ep.graph_signature.input_specs[0] = InputSpec(
153            kind=InputKind.BUFFER,
154            arg=TensorArgument(name="c_a"),
155            target="bad_buffer",
156            persistent=True,
157        )
158        with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
159            ep.validate()
160
161    def test_ep_verifier_buffer_mutate(self) -> None:
162        class M(torch.nn.Module):
163            def __init__(self) -> None:
164                super().__init__()
165
166                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
167
168                self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
169                self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))
170
171            def forward(self, x1, x2):
172                # Use the parameter, buffers, and both inputs in the forward method
173                output = (
174                    x1 + self.my_parameter
175                ) * self.my_buffer1 + x2 * self.my_buffer2
176
177                # Mutate one of the buffers (e.g., increment it by 1)
178                self.my_buffer2.add_(1.0)
179                return output
180
181        ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)))
182        ep.validate()
183
184    def test_ep_verifier_invalid_output(self) -> None:
185        class M(torch.nn.Module):
186            def __init__(self) -> None:
187                super().__init__()
188
189                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
190
191                self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
192                self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))
193
194            def forward(self, x1, x2):
195                # Use the parameter, buffers, and both inputs in the forward method
196                output = (
197                    x1 + self.my_parameter
198                ) * self.my_buffer1 + x2 * self.my_buffer2
199
200                # Mutate one of the buffers (e.g., increment it by 1)
201                self.my_buffer2.add_(1.0)
202                return output
203
204        ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)))
205
206        output_node = list(ep.graph.nodes)[-1]
207        output_node.args = (
208            (
209                output_node.args[0][0],
210                next(iter(ep.graph.nodes)),
211                output_node.args[0][1],
212            ),
213        )
214
215        with self.assertRaisesRegex(SpecViolationError, "Number of output nodes"):
216            ep.validate()
217
218
if __name__ == "__main__":
    # Executed as a script: hand off to PyTorch's shared test runner.
    run_tests()
221