# xref: /aosp_15_r20/external/executorch/exir/tests/test_verification.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
8
9import unittest
10
11import torch
12from executorch.exir import to_edge
13from executorch.exir.passes.const_prop_pass import ConstPropPass
14from executorch.exir.schema import Tensor, TensorList
15
16from executorch.exir.verification.interpreter import Interpreter
17from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
18from torch._export.verifier import SpecViolationError
19from torch.export import export
20
21
class WrapperModule(torch.nn.Module):
    """Adapt a plain callable into a ``torch.nn.Module``.

    Tests use this so a free function can be handed to ``export`` like any
    other module. Every positional and keyword argument given to ``forward``
    is passed through unchanged to the wrapped callable, and its return value
    is returned as-is.
    """

    def __init__(self, fn):
        super(WrapperModule, self).__init__()
        # Kept as a plain attribute; the callable is invoked on each forward.
        self.fn = fn

    def forward(self, *args, **kwargs):
        wrapped = self.fn
        return wrapped(*args, **kwargs)
29
30
class TestVerification(unittest.TestCase):
    """Tests for the program ``Interpreter`` and for verification during lowering."""

    def test_constant_buffer(self) -> None:
        """Constant-propagated ``torch.ones(2)`` values are loadable as tensors.

        ``ConstPropPass`` runs before emission. The test loads every value the
        interpreter can load and expects exactly two tensors, both equal to
        ``torch.ones(2)``.
        """

        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.ones(2) + x + torch.ones(2)

        # Generate program
        program = (
            to_edge(export(WrapperModule(f), (torch.randn(2),)))
            .transform([ConstPropPass()])
            .to_executorch()
            ._emitter_output.program
        )

        test = Interpreter(program)
        for val_idx, value in enumerate(test.execution_plan.values):
            val = value.val
            # Tensors with data_buffer_idx == 0 are skipped. Presumably buffer 0
            # is the reserved "no constant data" slot (e.g. the runtime input);
            # confirm against the ExecuTorch schema. Tensor lists are skipped too.
            has_no_constant_data = (
                isinstance(val, Tensor) and val.data_buffer_idx == 0
            )
            if not has_no_constant_data and not isinstance(val, TensorList):
                test.load_value(val_idx)

        loaded_tensors = [
            e for e in test.get_value_list() if isinstance(e, torch.Tensor)
        ]
        for e in loaded_tensors:
            self.assertTrue(torch.allclose(e, torch.ones(2)))

        # asserting only 2 constant Tensors exist in value list
        self.assertEqual(len(loaded_tensors), 2)

    def test_operator_list(self) -> None:
        """``Interpreter.get_operators_list`` reports exactly the out-variant ops used."""

        class Op1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a * x  # mul
                    y = z - self.b  # sub
                return y

        class Op2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a % x  # remainder
                    y = z / self.b  # div
                    z = z + z  # add
                return y + z

        # Generate a program with Op1's operations (mul, sub)
        model1 = Op1()
        inputs = (torch.ones(2, 2),)
        program = (
            to_edge(export(model1, inputs)).to_executorch()._emitter_output.program
        )

        # Initialize and test Interpreter -- assert that the operators are same as above
        test = Interpreter(program)
        self.assertEqual(
            set(test.get_operators_list()),
            {torch.ops.aten.mul.out, torch.ops.aten.sub.out},
        )

        # Generate a program with Op2's operations (remainder, div, add)
        model2 = Op2()
        inputs = (torch.ones(2, 2),)
        program = (
            to_edge(export(model2, inputs)).to_executorch()._emitter_output.program
        )

        # Initialize and test Interpreter -- assert that the operators are same as above
        test = Interpreter(program)
        self.assertEqual(
            set(test.get_operators_list()),
            {
                torch.ops.aten.remainder.Tensor_out,
                torch.ops.aten.div.out,
                torch.ops.aten.add.out,
            },
        )

    def test_verification(self) -> None:
        """Lowering passes verification and the lowered program matches eager mode."""

        class Op2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a % x  # remainder
                    y = z / self.b  # div
                    z = z + z  # add
                return y + z

        # Generate a program with Op2's operations (remainder, div, add)
        model2 = Op2()
        inputs = torch.ones(2, 2)
        # Verifiers are run internally in export, to_edge and to_executorch.
        # If this call returns, no errors were thrown in verification.
        exec_prog = to_edge(export(model2, (inputs,))).to_executorch()

        # Previously the result was computed and discarded; compare it against
        # eager execution so a miscompiled program is caught too.
        exported_prog = exec_prog.exported_program()
        res = exported_prog.module()(inputs)[0]
        expected = model2(inputs)[0]
        self.assertTrue(torch.allclose(res, expected))
144
145
class TestEdgeVerification(unittest.TestCase):
    """Exercises ``EXIREdgeDialectVerifier`` on graphs inside and outside the Edge dialect."""

    def _lower_to_edge_graph_module(self, model, example_inputs):
        """Export ``model`` on ``example_inputs``, lower it with ``to_edge`` and
        return the graph module of the resulting exported program."""
        edge_program = to_edge(export(model, example_inputs)).exported_program()
        return edge_program.graph_module

    def _verify_and_assert_valid(self, egm) -> None:
        """Run a fresh Edge-dialect verifier over ``egm`` and assert it is valid.

        Calling the verifier is expected to raise ``SpecViolationError`` on a
        violation; ``is_valid`` is asserted afterwards as well.
        """
        edge_verifier = EXIREdgeDialectVerifier()
        edge_verifier(egm)
        self.assertTrue(edge_verifier.is_valid(egm))

    def test_edge_happy(self) -> None:
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("a", torch.randn(1, 3, 100, 100))

            def forward(self, x):
                b = self.a + x
                return torch._to_cpu([b, x])

        model = TestModel()
        int_input = torch.randn(1, 3, 100, 100).to(dtype=torch.int)
        self._verify_and_assert_valid(
            self._lower_to_edge_graph_module(model, (int_input,))
        )

    def test_edge_happy_with_optional_tensor_input(self) -> None:
        class TestModel(torch.nn.Module):
            def forward(self, x, weight, bias):
                # weight and bias here are optional tensor inputs.
                return torch.group_norm(x, 4, weight, bias)

        model = TestModel()
        example_inputs = (torch.rand(16, 8, 32, 32), torch.rand(8), torch.rand(8))
        self._verify_and_assert_valid(
            self._lower_to_edge_graph_module(model, example_inputs)
        )

    def test_edge_happy_with_empty_tensorlist_input(self) -> None:
        class TestModel(torch.nn.Module):
            def forward(self, x):
                return torch._to_cpu(x)

        model = TestModel()
        self._verify_and_assert_valid(
            self._lower_to_edge_graph_module(model, ([],))
        )

    def test_edge_sad(self) -> None:
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("a", torch.randn(1, 3, 100, 100))

            def forward(self, x):
                b = self.a + x
                return torch._to_cpu([b, x])

        model = TestModel()
        # Only exported, never lowered with to_edge, so the Edge-dialect
        # verifier is expected to reject this graph.
        aten_gm = export(
            model,
            (torch.randn(1, 3, 100, 100).to(dtype=torch.int),),
        ).graph_module
        edge_verifier = EXIREdgeDialectVerifier()
        with self.assertRaises(SpecViolationError):
            edge_verifier(aten_gm)

    def test_edge_happy_with_edge_ops(self) -> None:
        class TestModel(torch.nn.Module):
            def forward(self, x):
                return x + x

        model = TestModel()
        int_input = torch.randn(1, 3, 100, 100).to(dtype=torch.int)
        self._verify_and_assert_valid(
            self._lower_to_edge_graph_module(model, (int_input,))
        )

    def test_edge_sad_with_edge_ops(self) -> None:
        # log_softmax only takes float or double Tensor, so lowering a bfloat16
        # input is expected to fail verification.
        log_softmax = torch.nn.LogSoftmax(dim=1)
        with self.assertRaises(SpecViolationError):
            self._lower_to_edge_graph_module(
                log_softmax,
                (torch.randn(1, 3, 100, 100).to(dtype=torch.bfloat16),),
            )
275