# xref: /aosp_15_r20/external/executorch/exir/verification/arg_validator.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
7from collections import defaultdict
8from typing import Any, Dict, Optional, Sequence, Tuple
9
10import torch
11from executorch.exir.dialects.edge._ops import EdgeDialectFunctionSchema, EdgeOpOverload
12from executorch.exir.emit._emitter import _Argument, _Target
13from executorch.exir.error import ExportError, InternalError
14from torch._ops import HigherOrderOperator
15
16
class RunHigherOrderOperatorError(Exception):
    """
    Raised when we try to run a delegate or another HigherOrderOperator in a
    graph module.
    E.g., %executorch_call_delegate : [#users=1] = call_function[
        target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %arg0_1), kwargs = {})
    """

    # No explicit __init__: Exception already stores and renders the message,
    # so the previous pass-through constructor was redundant. Callers can still
    # construct this with a single message string.
27
# pyre-ignore[13]: Attribute `node` is never initialized.
class EdgeOpArgValidator(torch.fx.Interpreter):
    """
    Validate whether all the Tensor arguments passed to an operator are valid
    in terms of allowed dtype.

    Expects all operators to be EdgeOpOverload, whose EdgeDialectFunctionSchema
    carries the allowed-dtype information. Violating operators are collected in
    ``self.violating_ops`` as a mapping from the op to the tensor dtypes
    observed for its arguments and returns.
    """

    # The node currently being executed; assigned by run_node before dispatch
    # so call_function can read the node's metadata (e.g. node.meta["val"]).
    node: torch.fx.Node

    def __init__(self, graph_module: torch.fx.GraphModule) -> None:
        super().__init__(graph_module)
        # op -> {arg/return name: observed dtype (None for an empty tensor list)}
        self.violating_ops: Dict[EdgeOpOverload, Dict[str, Optional[torch.dtype]]] = (
            defaultdict(dict)
        )

    def run_node(self, n: torch.fx.Node) -> Any:
        """Run a single node, wrapping unexpected failures in InternalError.

        Known error types (InternalError, ExportError,
        RunHigherOrderOperatorError) propagate unchanged so callers can handle
        them specifically; anything else is chained into an InternalError.
        """
        self.node = n
        try:
            # NOTE: the original annotated this method `-> None` but it does
            # return the interpreted value; the annotation is fixed to Any.
            return super().run_node(n)
        except (InternalError, ExportError, RunHigherOrderOperatorError):
            raise
        except Exception as e:
            raise InternalError(str(e)) from e

    def _get_kernel_arg(self, schema_arg, schema_arg_idx, args, kwargs):
        """Resolve the runtime value bound to a schema argument.

        Precedence: keyword argument by name, then positional argument (only
        if the schema allows positional passing and one was provided), then
        the schema's default value.
        """
        if schema_arg.name in kwargs:
            return kwargs[schema_arg.name]
        if not schema_arg.kwarg_only and schema_arg_idx < len(args):
            return args[schema_arg_idx]
        return schema_arg.default_value

    def call_function(  # noqa: C901  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> Any:
        """
        Go through all the node.target and validate their Tensor arguments are
        having the allowed dtypes.

        Non-edge ops pass through untouched, except HigherOrderOperators
        (e.g. delegates), which cannot be executed here and raise
        RunHigherOrderOperatorError.
        """
        if not isinstance(target, EdgeOpOverload) or not isinstance(
            target._schema, EdgeDialectFunctionSchema
        ):
            if isinstance(target, HigherOrderOperator):
                raise RunHigherOrderOperatorError("Can't run delegate")
            return super().call_function(target, args, kwargs)  # pyre-fixme[6]

        # Collect the dtype of every Tensor-typed argument by schema name.
        # TODO(gasoonjia): Update Optional[torch.dtype] to a concrete class to support mixed dtypes in tensorlist.
        tensor_arg_types: Dict[str, Optional[torch.dtype]] = {}
        for i, schema_arg in enumerate(target._schema.arguments):
            if (
                isinstance(schema_arg.type, torch.TensorType)
                or schema_arg.type == torch.OptionalType.ofTensor()
            ):
                kernel_arg = self._get_kernel_arg(schema_arg, i, args, kwargs)
                if not isinstance(kernel_arg, torch.Tensor):
                    # e.g. an optional tensor that is None — nothing to check.
                    continue
                tensor_arg_types[schema_arg.name] = kernel_arg.dtype
            elif schema_arg.type == torch.ListType.ofTensors():
                kernel_arg = self._get_kernel_arg(schema_arg, i, args, kwargs)
                if not isinstance(kernel_arg, list) or not all(
                    isinstance(t, torch.Tensor) for t in kernel_arg
                ):
                    continue
                if kernel_arg:
                    # NOTE(review): only the first element's dtype is recorded,
                    # so mixed-dtype tensor lists are not fully validated (see
                    # TODO above).
                    tensor_arg_types[schema_arg.name] = kernel_arg[0].dtype
                else:
                    # If kernel_arg is an empty list, treat its type as None.
                    # FunctionDtypeConstraint.validate will take None as any legal dtype.
                    tensor_arg_types[schema_arg.name] = None

        # Collect the dtype of every Tensor-typed return, pairing schema
        # returns with the values recorded in node.meta["val"].
        ret_index = 0
        kernel_rets = self.node.meta["val"]
        ret_iter = iter(
            kernel_rets if isinstance(kernel_rets, Sequence) else [kernel_rets]
        )
        for schema_ret in target._schema.returns:
            name = schema_ret.name if schema_ret.name else f"__ret_{ret_index}"
            kernel_ret = next(ret_iter)
            # Return value should not be in OptionalTensor type, so only check torch.TensorType here.
            if isinstance(schema_ret.type, torch.TensorType) and isinstance(
                kernel_ret, torch.Tensor
            ):
                tensor_arg_types[name] = kernel_ret.dtype
                ret_index += 1
            elif schema_ret.type == torch.ListType.ofTensors() and all(
                isinstance(t, torch.Tensor) for t in kernel_ret
            ):
                if len(kernel_ret):
                    tensor_arg_types[name] = kernel_ret[0].dtype
                else:
                    tensor_arg_types[name] = None
                ret_index += 1

        # Record the op (with the observed dtypes) if any dtype is disallowed.
        valid = target._schema.dtype_constraint.validate(tensor_arg_types)
        if not valid:
            self.violating_ops[target] = tensor_arg_types
        return super().call_function(target, args, kwargs)  # pyre-fixme[6]
130