xref: /aosp_15_r20/external/executorch/exir/tests/test_verification.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-unsafe
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport unittest
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerimport torch
12*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import to_edge
13*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.const_prop_pass import ConstPropPass
14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.schema import Tensor, TensorList
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.verification.interpreter import Interpreter
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.verification.verifier import EXIREdgeDialectVerifier
18*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.verifier import SpecViolationError
19*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import export
20*523fa7a6SAndroid Build Coastguard Worker
21*523fa7a6SAndroid Build Coastguard Worker
class WrapperModule(torch.nn.Module):
    """Adapter that exposes a plain callable as a ``torch.nn.Module``.

    The callable is stored on ``self.fn`` and invoked by ``forward`` with
    whatever positional and keyword arguments the module receives.
    """

    def __init__(self, fn):
        """Store ``fn`` as the callable that ``forward`` delegates to."""
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        """Call the wrapped callable with the given arguments and return its result."""
        result = self.fn(*args, **kwargs)
        return result
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Worker
class _RemainderDivAddModule(torch.nn.Module):
    """Module whose forward applies ``%``, ``/`` and ``+`` to 2x2 tensors.

    Shared by ``TestVerification.test_operator_list`` and
    ``TestVerification.test_verification`` so the model is defined only once.
    """

    def __init__(self) -> None:
        super().__init__()
        self.a = torch.ones(2, 2)
        self.b = 2 * torch.ones(2, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for _ in range(10):
            z = self.a % x  # remainder
            y = z / self.b  # div
            z = z + z  # add
        return y + z


class TestVerification(unittest.TestCase):
    """Tests for the verification ``Interpreter`` on emitted programs."""

    def test_constant_buffer(self) -> None:
        """Loaded tensor values are all ``ones(2)`` and exactly two exist.

        The program is produced by export -> ``to_edge`` -> ``ConstPropPass``
        -> ``to_executorch``; its values are then loaded through the
        ``Interpreter`` and inspected.
        """

        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.ones(2) + x + torch.ones(2)

        # Generate program
        program = (
            to_edge(export(WrapperModule(f), (torch.randn(2),)))
            .transform(
                [
                    ConstPropPass(),
                ]
            )
            .to_executorch()
            ._emitter_output.program
        )

        test = Interpreter(program)
        for val_idx, evalue in enumerate(test.execution_plan.values):
            val = evalue.val
            # Skip tensors whose data_buffer_idx is 0 and skip TensorLists;
            # every other value is loaded into the interpreter.
            if isinstance(val, Tensor) and val.data_buffer_idx == 0:
                continue
            if isinstance(val, TensorList):
                continue
            test.load_value(val_idx)

        # Filter once and reuse for both the value check and the count check.
        loaded_tensors = [
            e for e in test.get_value_list() if isinstance(e, torch.Tensor)
        ]
        for tensor in loaded_tensors:
            self.assertTrue(torch.allclose(tensor, torch.ones(2)))

        # asserting only 2 constant Tensors exist in value list
        self.assertEqual(len(loaded_tensors), 2)

    def test_operator_list(self) -> None:
        """``get_operators_list`` reports exactly the operators each model uses."""

        class Op1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a * x  # mul
                    y = z - self.b  # sub
                return y

        # Generate a program with Op1's operations (mul, sub)
        model1 = Op1()
        inputs = (torch.ones(2, 2),)
        program = (
            to_edge(export(model1, inputs)).to_executorch()._emitter_output.program
        )

        # Initialize and test Interpreter -- assert that the operators are same as above
        test = Interpreter(program)
        self.assertEqual(
            set(test.get_operators_list()),
            {torch.ops.aten.mul.out, torch.ops.aten.sub.out},
        )

        # Generate a program with the shared module's operations (remainder, div, add)
        model2 = _RemainderDivAddModule()
        inputs = (torch.ones(2, 2),)
        program = (
            to_edge(export(model2, inputs)).to_executorch()._emitter_output.program
        )

        # Initialize and test Interpreter -- assert that the operators are same as above
        test = Interpreter(program)
        self.assertEqual(
            set(test.get_operators_list()),
            {
                torch.ops.aten.remainder.Tensor_out,
                torch.ops.aten.div.out,
                torch.ops.aten.add.out,
            },
        )

    def test_verification(self) -> None:
        """Smoke test: the full lowering pipeline runs and the result is callable."""
        # Generate a program with the shared module's operations (remainder, div, add)
        model2 = _RemainderDivAddModule()
        inputs = torch.ones(2, 2)
        exec_prog = to_edge(export(model2, (inputs,))).to_executorch()

        exported_prog = exec_prog.exported_program()
        # The value is not inspected; the call and index just need to succeed.
        _ = exported_prog.module()(inputs)[0]
        # Verifiers are run internally in to_edge, export, and to_executorch.
        # If we make it this far then no errors were thrown in verification
144*523fa7a6SAndroid Build Coastguard Worker
145*523fa7a6SAndroid Build Coastguard Worker
class TestEdgeVerification(unittest.TestCase):
    """Runs ``EXIREdgeDialectVerifier`` on graphs expected to pass or to fail."""

    @staticmethod
    def _lower_to_edge_graph(module, example_args):
        """Export ``module`` with ``example_args``, apply ``to_edge``, return its graph module."""
        edge_program = to_edge(export(module, example_args))
        return edge_program.exported_program().graph_module

    def _check_verifier_accepts(self, graph_module) -> None:
        """Call a fresh verifier on ``graph_module`` and assert ``is_valid`` is true."""
        edge_verifier = EXIREdgeDialectVerifier()
        edge_verifier(graph_module)
        self.assertTrue(edge_verifier.is_valid(graph_module))

    def test_edge_happy(self) -> None:
        """A buffer add followed by ``_to_cpu`` passes after ``to_edge``."""

        class BufferAddToCpu(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("a", torch.randn(1, 3, 100, 100))

            def forward(self, x):
                summed = self.a + x
                return torch._to_cpu([summed, x])

        graph_module = self._lower_to_edge_graph(
            BufferAddToCpu(),
            (torch.randn(1, 3, 100, 100).to(dtype=torch.int),),
        )
        self._check_verifier_accepts(graph_module)

    def test_edge_happy_with_optional_tensor_input(self) -> None:
        """``group_norm`` with weight and bias supplied passes after ``to_edge``."""

        class GroupNormModel(torch.nn.Module):
            def forward(self, x, weight, bias):
                # weight and bias here are optional tensor inputs.
                return torch.group_norm(x, 4, weight, bias)

        graph_module = self._lower_to_edge_graph(
            GroupNormModel(),
            (torch.rand(16, 8, 32, 32), torch.rand(8), torch.rand(8)),
        )
        self._check_verifier_accepts(graph_module)

    def test_edge_happy_with_empty_tensorlist_input(self) -> None:
        """``_to_cpu`` on an empty list input passes after ``to_edge``."""

        class ToCpuModel(torch.nn.Module):
            def forward(self, x):
                return torch._to_cpu(x)

        graph_module = self._lower_to_edge_graph(ToCpuModel(), ([],))
        self._check_verifier_accepts(graph_module)

    def test_edge_sad(self) -> None:
        """A graph exported without ``to_edge`` raises ``SpecViolationError``."""

        class BufferAddToCpu(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("a", torch.randn(1, 3, 100, 100))

            def forward(self, x):
                summed = self.a + x
                return torch._to_cpu([summed, x])

        # Deliberately skip to_edge: the verifier is given the plain export graph.
        exported = export(
            BufferAddToCpu(),
            (torch.randn(1, 3, 100, 100).to(dtype=torch.int),),
        )
        graph_module = exported.graph_module
        edge_verifier = EXIREdgeDialectVerifier()
        with self.assertRaises(SpecViolationError):
            edge_verifier(graph_module)

    def test_edge_happy_with_edge_ops(self) -> None:
        """A simple ``x + x`` model on an int tensor passes after ``to_edge``."""

        class SelfAddModel(torch.nn.Module):
            def forward(self, x):
                return x + x

        graph_module = self._lower_to_edge_graph(
            SelfAddModel(),
            (torch.randn(1, 3, 100, 100).to(dtype=torch.int),),
        )
        self._check_verifier_accepts(graph_module)

    def test_edge_sad_with_edge_ops(self) -> None:
        """Lowering ``LogSoftmax`` on a bfloat16 input raises ``SpecViolationError``."""
        # log_softmax only takes float or double Tensor
        log_softmax = torch.nn.LogSoftmax(dim=1)
        with self.assertRaises(SpecViolationError):
            _ = self._lower_to_edge_graph(
                log_softmax,
                (torch.randn(1, 3, 100, 100).to(dtype=torch.bfloat16),),
            )
275