xref: /aosp_15_r20/external/executorch/exir/verification/test/test_verifier.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import unittest
10from contextlib import contextmanager
11from typing import Any
12
13import torch
14from executorch.exir import EdgeCompileConfig, to_edge
15
16from executorch.exir.dialects._ops import ops
17from torch._export.verifier import SpecViolationError
18from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
19from torch.export import export
20
21from ..verifier import EXIREdgeDialectVerifier
22
23
class TestEdgeDialectVerifier(unittest.TestCase):
    """Tests for ``EXIREdgeDialectVerifier`` and how ``EdgeCompileConfig`` toggles its checks."""

    @contextmanager
    def assertNotRaises(self, exc_type: Any) -> Any:
        """Fail the test if the body of the ``with`` block raises ``exc_type``.

        Counterpart to ``assertRaises``: an exception of ``exc_type`` (or a
        subclass) raised inside the block is turned into a test failure.
        Exceptions of any other type propagate unchanged.

        Args:
            exc_type: Exception class that must not be raised.

        Raises:
            self.failureException: if ``exc_type`` is raised inside the block.
        """
        try:
            yield None
        except exc_type as err:
            # Chain the original exception so the failure report keeps its traceback.
            raise self.failureException(f"{exc_type.__name__} raised") from err

    def test_edge_verifier_check_valid_op_succeed_given_custom_op(self) -> None:
        """An edge-namespace custom op passes both the edge-op and the generic op check."""
        # quantized_decomposed ops are presumably registered by the
        # `quantized_decomposed_lib` side-effect import at the top of the file.
        edge_op = ops.edge.quantized_decomposed.quantize_per_tensor.default
        verifier = EXIREdgeDialectVerifier()
        with self.assertNotRaises(SpecViolationError):
            verifier.check_valid_edge_op(edge_op)
            verifier.check_valid_op(edge_op)

    def test_edge_verifier_enablement(self) -> None:
        """A verifier built with ``_check_ir_validity=False`` accepts a graph the default verifier rejects."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                # Data-dependent slice start, bounded via torch._check.
                z = y.item()
                torch._check(z > 0)
                torch._check(z < 4)
                return x[z : z + y.shape[0]]

        ep = export(M(), (torch.randn(10), torch.tensor([3])))

        compile_config_with_disable_ir_validity = EdgeCompileConfig(
            _check_ir_validity=False
        )
        edge_manager = to_edge(
            ep, compile_config=compile_config_with_disable_ir_validity
        )

        normal_verifier = EXIREdgeDialectVerifier()
        disable_ir_validity_verifier = EXIREdgeDialectVerifier(
            compile_config_with_disable_ir_validity
        )

        # The exported program cannot pass the default verifier because
        # aten.sym_constrain_range.default is not a valid edge op.
        with self.assertRaises(SpecViolationError):
            normal_verifier(edge_manager.exported_program())

        # The same program passes disable_ir_validity_verifier because IR
        # validity checking is turned off by
        # compile_config_with_disable_ir_validity (_check_ir_validity=False).
        # Note that this verification already ran inside `to_edge`; calling the
        # verifier explicitly here is only for demonstration and is not needed
        # for IR verification in real code.
        disable_ir_validity_verifier(edge_manager.exported_program())

    def test_edge_verifier_check_edge_op(self) -> None:
        """With ``_use_edge_ops=False`` the default verifier rejects the graph; a matching verifier accepts it."""

        class Model(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x.transpose(0, 1)

        m = Model().eval()

        example_input = (torch.zeros([2, 2]),)

        export_model = export(m, example_input)

        compile_config_without_edge_op = EdgeCompileConfig(
            _use_edge_ops=False, _skip_dim_order=False
        )

        edge_manager = to_edge(
            export_model, compile_config=compile_config_without_edge_op
        )

        normal_verifier = EXIREdgeDialectVerifier()
        disable_edge_op_check_verifier = EXIREdgeDialectVerifier(
            compile_config_without_edge_op
        )

        # The exported program cannot pass the default verifier because
        # non-contiguous memory layout tensors are not supported in ET.
        with self.assertRaises(SpecViolationError):
            normal_verifier(edge_manager.exported_program())

        # The same program passes disable_edge_op_check_verifier because the
        # non-contiguous memory layout tensor check is disabled by
        # compile_config_without_edge_op (_use_edge_ops=False). Note that this
        # verification already ran inside `to_edge`; calling the verifier
        # explicitly here is only for demonstration and is not needed for IR
        # verification in real code.
        disable_edge_op_check_verifier(edge_manager.exported_program())

    def test_edge_verifier_check_valid_dim_order_graph(self) -> None:
        """Dim-order and stride lowerings are each accepted only by the verifier built from the same config."""

        class Model(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Switch to channels_last so the graph carries a non-default
                # memory layout.
                t1 = x.to(dtype=torch.double, memory_format=torch.channels_last)
                t2 = t1 + t1
                return t1 * t2

        m = Model().eval()

        example_input = (
            torch.rand_like(
                torch.zeros([2, 2, 2, 2]),
                dtype=torch.float32,
                memory_format=torch.contiguous_format,
            ),
        )

        export_model = export(m, example_input)

        compile_config_with_dim_order = EdgeCompileConfig(_skip_dim_order=False)
        compile_config_with_stride = EdgeCompileConfig(_skip_dim_order=True)

        dim_order_edge_model = to_edge(
            export_model, compile_config=compile_config_with_dim_order
        )
        stride_edge_model = to_edge(
            export_model, compile_config=compile_config_with_stride
        )

        dim_order_verifier = EXIREdgeDialectVerifier(
            edge_compile_config=compile_config_with_dim_order
        )
        stride_verifier = EXIREdgeDialectVerifier(
            edge_compile_config=compile_config_with_stride
        )

        # Each program is accepted by the verifier matching its own config...
        dim_order_verifier(dim_order_edge_model.exported_program())
        stride_verifier(stride_edge_model.exported_program())

        # ...and rejected by the verifier built from the other config.
        with self.assertRaises(SpecViolationError):
            dim_order_verifier(stride_edge_model.exported_program())
        with self.assertRaises(SpecViolationError):
            stride_verifier(dim_order_edge_model.exported_program())
158            stride_verifier(dim_order_edge_model.exported_program())
159