# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import operator
import types
from contextlib import nullcontext
from typing import Any, List, Optional, Tuple, Type

import torch
from executorch.exir.capture._config import EdgeCompileConfig
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.error import ExportError, ExportErrorType
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.verification.arg_validator import (
    EdgeOpArgValidator,
    RunHigherOrderOperatorError,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._export.utils import _detect_fake_mode_from_gm

from torch._export.verifier import SpecViolationError, Verifier
from torch._ops import OpOverload
from torch._subclasses import FakeTensor
from torch.export.exported_program import ExportedProgram
from torch.fx import GraphModule


ALLOWED_META_KEYS = {"spec", "stack_trace"}


def _check_tensors_are_contiguous(gm: GraphModule) -> None:
    # Tensors must be in contiguous format.
    for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
        if isinstance(param, torch.Tensor):
            if not param.is_contiguous():
                raise SpecViolationError(
                    f"Tensors in Aten dialect must be contiguous, {name} is not contiguous"
                )


def _check_valid_dim_order_ops(op, use_dim_order) -> None:
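    """Reject ops that don't belong in the current memory-format mode.

    In dim_order mode, memory-format copies are expected to use the
    dim_order_ops namespace (e.g. dim_order_ops._to_dim_order_copy) rather
    than aten._to_copy; outside dim_order mode, dim_order_ops must not appear.
    """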
    if use_dim_order:
        if op in (torch.ops.aten._to_copy.default,):
            raise SpecViolationError(f"{op} should not be used in dim_order mode")
    else:  # not using dim_order
        if op.namespace in ("dim_order_ops",):
            raise SpecViolationError(f"{op} should not be used in non-dim_order mode")


class EXIRATenDialectVerifierBase(Verifier):
    dialect = "OLD_EXIR_ATEN_DISABLED"

    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
        return (
            torch.fx.GraphModule,
            LoweredBackendModule,
            torch.Tensor,
            torch.ScriptObject,
        )

    def allowed_op_types(self):
        return super().allowed_op_types() + (torch._ops.OpOverloadPacket,)

    def __call__(self, *args, **kwargs):
        if hasattr(self, "_check_graph_module"):
            return self._check_graph_module(*args, **kwargs)
        elif hasattr(self, "check_valid"):
            return self.check_valid(*args, **kwargs)
        else:
            raise RuntimeError(
                "Expected the verifier to define _check_graph_module or check_valid."
            )


def EXIRATenDialectVerifier(  # noqa: C901
    edge_compile_config: Optional[EdgeCompileConfig] = None,
    class_only: bool = False,
    exception_list: Optional[List[torch._ops.OpOverload]] = None,
):
82    """
83    Returns a verifier class that runs ATen dialect specific checks on the graph module.
84    """
    # merge the exception list from edge_compile_config and exception_list
    if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
        exception_list = edge_compile_config._core_aten_ops_exception_list + (
            exception_list or []
        )

    class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
        dialect = "OLD_EXIR_ATEN"

        def __init__(self) -> None:
            super().__init__()
            # Note: this uses the exception list passed into the enclosing EXIRATenDialectVerifier factory!
            self._exception_list = exception_list if exception_list else []

        def _get_exception_list(self) -> List[torch._ops.OpOverload]:
            exception_list = [
                torch.ops.aten.mkldnn_rnn_layer.default,
                torch.ops.aten._upsample_bilinear2d_aa.default,
                torch.ops.aten.quantize_per_tensor.default,
                torch.ops.aten.dequantize.self,
                torch.ops.aten.max.default,  # TODO(T188268054)
                torch.ops.aten.min.default,  # TODO(T188268054)
                torch.ops.aten.full_like.default,  # TODO(T183507359)
            ]
            exception_list += self._exception_list

            return exception_list

        def check_valid_op(self, op):
            if isinstance(op, OpOverload):
                # TODO These special ops should be easily removable.
                if op.namespace != "aten" or op in self._get_exception_list():
                    return
                if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
                    # NOTE(qihan): whether view_copy operators are marked as canonical is
                    # still under discussion.
                    raise SpecViolationError(
                        f"Operator {op.__module__}.{op.__name__} is not Aten Canonical."
                    )

    ret = _EXIRATenDialectVerifier
    if not class_only:
        ret = ret()
    return ret
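

# Example usage (a sketch; `ep` stands in for any torch.export.ExportedProgram):
#
#   verifier = EXIRATenDialectVerifier()
#   verifier(ep.graph_module)  # raises SpecViolationError on an invalid graph
#
# With class_only=True the factory returns the verifier class itself, for
# callers that need to construct instances later:
#
#   VerifierClass = EXIRATenDialectVerifier(class_only=True)
#   VerifierClass()(ep.graph_module)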


def get_aten_verifier(config: EdgeCompileConfig):
    return (
        EXIRATenDialectVerifier(
            class_only=True, exception_list=config._core_aten_ops_exception_list
        )
        if config._check_ir_validity
        else EXIRATenDialectVerifierBase
    )


def _get_inputs(graph_module: GraphModule) -> List[Optional[FakeTensor]]:
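    # Rebuild a callable argument list for the graph: each placeholder's
    # meta["val"] is typically a FakeTensor recorded at export time.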
    def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
        if "val" in node.meta:
            return node.meta["val"]

        if len(node.users) == 0:
            return None

        # TODO(ycao): `val` should always exist after we enable shape environment
        # serialization and deserialization.
        raise ExportError(
            ExportErrorType.VIOLATION_OF_SPEC,
            f"Cannot construct an input for graph module: {graph_module}.",
        )

    return [
        extract_input(node)
        for node in graph_module.graph.nodes
        if node.op == "placeholder"
    ]


def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
    validator = EdgeOpArgValidator(gm)
    inputs = _get_inputs(gm)
    fake_mode = _detect_fake_mode_from_gm(gm) or nullcontext()
    try:
        with enable_python_dispatcher(), fake_mode:
            validator.run(*inputs)
    except RunHigherOrderOperatorError:
        # NB: ignore higher order operators in the graph.
        # If we lower a graph module to a delegate, compose it with another graph
        # module, and retrace it with edge ops and the validator enabled
        # (_check_ir_validity=True), we will run into RunHigherOrderOperatorError.
        # The only thing we can do right now is to ignore this error, since by
        # definition the graph is still valid Edge dialect. This is not ideal
        # because it ignores possible invalidity later in the graph.
        return

    if validator.violating_ops:
        raise SpecViolationError(
            f"These operators are taking Tensor inputs with mismatched dtypes: {validator.violating_ops}"
        )


def EXIREdgeDialectVerifier(  # noqa: C901
    edge_compile_config: Optional[EdgeCompileConfig] = None,
    class_only: bool = False,
    exception_list: Optional[List[torch._ops.OpOverload]] = None,
):
    # merge the exception list from edge_compile_config and exception_list
    if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
        exception_list = edge_compile_config._core_aten_ops_exception_list + (
            exception_list or []
        )

    class _EXIREdgeDialectVerifier(Verifier):
        dialect = "EDGE"

        def __init__(self) -> None:
            _edge_compile_config = edge_compile_config or EdgeCompileConfig()

            self.enable = _edge_compile_config._check_ir_validity
            self.check_edge_ops = _edge_compile_config._use_edge_ops
            self.use_dim_order = not _edge_compile_config._skip_dim_order

            self.aten_op_verifier = EXIRATenDialectVerifier(
                exception_list=exception_list
            )
            self.check_valid_aten_op = self.aten_op_verifier.check_valid_op

            if self.check_edge_ops:
                self.check_valid_op = self.check_valid_edge_op
            else:
                self.check_valid_op = self.check_valid_aten_op
            self._exception_list = exception_list if exception_list else []

        def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
            return (
                torch.fx.GraphModule,
                LoweredBackendModule,
                torch.Tensor,
                torch.ScriptObject,
            )

        def allowed_op_types(self):
            return super().allowed_op_types() + (EdgeOpOverload, types.FunctionType)

        def check_valid_edge_op(self, op):
            if not self.enable:
                return
            if (
                op
                in [
                    operator.getitem,
                    torch.ops.aten.sym_size.int,
                    torch.ops.aten.scalar_tensor.default,
                    torch.ops.aten._assert_async.msg,
                    torch.ops.aten._assert_scalar.default,
                ]
                + self._exception_list
            ):
                return

            if isinstance(op, OpOverload) and not isinstance(op, EdgeOpOverload):
                raise SpecViolationError(
                    "Operator {}.{} is not an Edge operator.".format(
                        op.__module__, op.__name__
                    )
                )
            if isinstance(op, EdgeOpOverload):
                _check_valid_dim_order_ops(op._op, self.use_dim_order)
                self.check_valid_aten_op(op._op)

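            # The only bare Python function expected here is exir's
            # memory-planning `alloc` (hence the name check below).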
            if isinstance(op, types.FunctionType):
                assert op.__name__ in ("alloc",)

        def check_additional(self, gm: GraphModule) -> None:
            if not self.enable:
                return
            if self.check_edge_ops:
                _check_tensors_are_contiguous(gm)
                _check_tensor_args_matching_op_allowed_dtype(gm)

        def check_valid_op(self, op):
            if isinstance(op, OpOverload):
                # TODO These special ops should be easily removable.
                if op.namespace in (
                    "quantized_decomposed",
                    "boltnn_nimble",
                    "nimble",
                    "quantized",
                    "dim_order_ops",
                ) or op in (
                    torch.ops.aten.mkldnn_rnn_layer.default,
                    torch.ops.aten._upsample_bilinear2d_aa.default,
                    torch.ops.aten.quantize_per_tensor.default,
                    torch.ops.aten.dequantize.self,
                    torch.ops.aten.max.default,
                    torch.ops.aten.full_like.default,  # TODO(T183507359)
                ):
                    return
                if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
                    # NOTE(qihan): whether view_copy operators are marked as canonical is
                    # still under discussion.
                    raise SpecViolationError(
                        f"Operator {op.__module__}.{op.__name__} is not Aten Canonical."
                    )

        def is_valid(self, gm: GraphModule) -> bool:
            try:
                self(gm)
                return True
            except SpecViolationError:
                return False

        def __call__(self, ep_or_gm):
            if not self.enable:
                return
            gm = ep_or_gm
            if isinstance(gm, ExportedProgram):
                gm = ep_or_gm.graph_module
            if hasattr(self, "_check_graph_module"):
                return self._check_graph_module(gm)
            elif hasattr(self, "check_valid"):
                return self.check_valid(gm)
            else:
                raise RuntimeError(
                    "Expected the verifier to define _check_graph_module or check_valid."
                )

    ret = _EXIREdgeDialectVerifier
    if not class_only:
        ret = ret()
    return ret
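

# Example usage (a sketch, assuming an edge-dialect program produced via
# exir.to_edge; `model` and `example_inputs` are placeholders):
#
#   from executorch.exir import EdgeCompileConfig, to_edge
#
#   edge = to_edge(torch.export.export(model, example_inputs))
#   verifier = EXIREdgeDialectVerifier(
#       edge_compile_config=EdgeCompileConfig(_check_ir_validity=True)
#   )
#   verifier(edge.exported_program())  # accepts an ExportedProgram or GraphModule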
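

# Presumably an import-time smoke test: instantiate the default verifier once so
# construction errors surface as soon as this module is imported.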
EXIREdgeDialectVerifier()