xref: /aosp_15_r20/external/executorch/exir/dialects/edge/dtype/runner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import itertools
8from typing import Any, Dict, List, Optional, Tuple
9
10import torch
11import torch.testing._internal.common_dtype as common_dtype
12from executorch.exir.dialects.edge.arg.model import ArgMode, BaseArg, BaseKwarg
13from executorch.exir.dialects.edge.arg.type import ArgType
14from executorch.exir.dialects.edge.dtype.utils import extract_return_dtype
15from executorch.exir.dialects.edge.op.api import get_callable
16
17
class DtypeRunner:
    """Probe which dtype combinations an edge-dialect op accepts.

    For every argument spec that carries a dtype, a list of candidate dtypes
    is built; the op is then invoked once per element of the cartesian
    product of those lists, and each outcome is recorded as
    ``(success, op_name, dtypes, args, kwargs)``.
    """

    def __init__(self):
        # Candidates for tensor-like specs (tensor, tensor list, ScalarType):
        # the dtypes from common_dtype.all_types_and(...) plus bool and half.
        self.tensor_dtypes = list(common_dtype.all_types_and(torch.bool, torch.half))
        # Candidates for Python scalar specs.
        self.scalar_dtypes = [torch.bool, torch.int, torch.float]

    @staticmethod
    def _get_types(inputs: Dict[str, List[BaseArg]]) -> List[ArgType]:
        """Return the types of the entries in ``inputs["args"]`` that carry a dtype.

        The order matches ``inputs["args"]``, which is also the order in which
        ``_get_args_kwargs`` consumes the entries of ``dtypes``.
        """
        return [arg.type for arg in inputs["args"] if arg.type.has_dtype()]

    @staticmethod
    def _get_args_kwargs(
        inputs: Dict[str, List[BaseArg]],
        dtypes: Tuple[Optional[torch.dtype], ...],
        mode: ArgMode,
    ) -> Tuple[List[Any], Dict[str, Any]]:
        """Materialize concrete call arguments for the op.

        Args:
            inputs: Mapping whose ``"args"`` entry lists the argument specs.
            dtypes: One dtype per dtype-bearing spec, in spec order.
            mode: Value-generation mode. It is assigned to every spec's
                ``mode`` attribute, so this mutates the specs in ``inputs``.

        Returns:
            ``(args, kwargs)``. Specs that are ``BaseKwarg`` instances with
            ``kw`` set go into ``kwargs`` under their ``argname``; every other
            spec is appended to ``args``.
        """
        args = []
        kwargs = {}
        counter = 0
        for arg in inputs["args"]:
            arg.mode = mode
            val = arg.get_val()
            if arg.type.has_dtype():
                # Dtype-bearing specs are regenerated with the dtype chosen
                # for them; `counter` walks `dtypes` in the same order as
                # _get_types enumerated them.
                val = arg.get_val_with_dtype(dtypes[counter])
                counter += 1
            if arg.kw and isinstance(arg, BaseKwarg):
                kwargs[arg.argname] = val
            else:
                args.append(val)
        return args, kwargs

    def _get_type_tuples(
        self, inputs: Dict[str, List[BaseArg]]
    ) -> List[List[Optional[torch.dtype]]]:
        """Return, per dtype-bearing spec, the list of candidate dtypes to try.

        Optional specs additionally get ``None`` as their first candidate.

        Raises:
            ValueError: if a spec reports ``has_dtype()`` but is neither a
                scalar, a ScalarType, a tensor, nor a tensor list.
        """
        types = DtypeRunner._get_types(inputs)

        def mapping(t):
            type_dtypes = []
            if t.is_optional():
                type_dtypes = [None]
            if t.is_scalar():
                return type_dtypes + self.scalar_dtypes
            elif t.is_scalar_type() or t.is_tensor() or t.is_tensor_list():
                return type_dtypes + self.tensor_dtypes
            else:
                # f-prefix required so the offending type name is interpolated.
                raise ValueError(f"Type {t.name} does not have dtype")

        return list(map(mapping, types))

    @staticmethod
    def _invoke(
        op: Any,
        args: List[Any],
        kwargs: Dict[str, Any],
        inputs: Dict[str, Any],
    ) -> Tuple[Optional[torch.dtype], ...]:
        """Call ``op(*args, **kwargs)`` and return the dtypes of its outputs.

        Returns an empty tuple when ``inputs`` has no ``"returns"`` entry.
        Any exception raised by the op or by dtype extraction propagates.
        """
        res = op(*args, **kwargs)
        if "returns" not in inputs:
            return ()
        return tuple(extract_return_dtype(res, inputs["returns"]))

    def run_dtypes(
        self,
        name: str,
        inputs: Dict[str, List[BaseArg]],
        dtypes: Tuple[Optional[torch.dtype], ...],
        argmode: ArgMode = ArgMode.RANDOM,
    ) -> Tuple[
        bool, str, Tuple[Optional[torch.dtype], ...], List[Any], Dict[str, Any]
    ]:
        """Run op ``name`` once with arguments generated for ``dtypes``.

        If the first attempt fails with a non-assertion exception and
        ``argmode`` is not ``ArgMode.ONES``, the call is retried with
        arguments generated in ``ArgMode.ONES``. This separates failures
        caused by the generated values from failures caused by the dtypes.

        Returns:
            ``(success, name, dtypes_used, args, kwargs)``. On success,
            ``dtypes_used`` is ``dtypes`` followed by the return dtypes, and
            args/kwargs are the ones of the call that succeeded. On failure,
            ``dtypes_used`` is just ``dtypes``.

        Raises:
            RuntimeError: chained from any ``AssertionError`` raised during
                the first attempt, with the op name, inputs, dtypes and mode.
        """
        args, kwargs = DtypeRunner._get_args_kwargs(inputs, dtypes, argmode)
        op = get_callable(name)
        try:
            ret_dtypes = DtypeRunner._invoke(op, args, kwargs, inputs)
            return (True, name, dtypes + ret_dtypes, args, kwargs)
        except AssertionError as e:
            raise RuntimeError(
                f"opname: {name}, inputs: {inputs}, dtypes: {dtypes}, argmode {argmode}"
            ) from e
        except Exception as e:
            if argmode == ArgMode.ONES:
                # Already ran with ONES values; there is nothing left to retry.
                return (False, name, dtypes, args, kwargs)
            ones_args, ones_kwargs = DtypeRunner._get_args_kwargs(
                inputs, dtypes, ArgMode.ONES
            )
            try:
                # The retry must use the ONES-mode arguments. Re-running with
                # the original args could not tell whether the first failure
                # was caused by the generated values.
                ret_dtypes = DtypeRunner._invoke(op, ones_args, ones_kwargs, inputs)
                # Report the original failure that the ONES retry recovered from.
                print(e)
                print(name, dtypes, args, kwargs)
                return (True, name, dtypes + ret_dtypes, ones_args, ones_kwargs)
            except Exception:
                return (False, name, dtypes, ones_args, ones_kwargs)

    def run(
        self,
        name: str,
        inputs: Dict[str, Any],
        argmode: ArgMode = ArgMode.ONES,
    ) -> List[
        Tuple[
            bool, str, Tuple[Optional[torch.dtype], ...], List[Any], Dict[str, Any]
        ]
    ]:
        """Run op ``name`` over every combination of candidate dtypes.

        Each combination is one element of the cartesian product of the
        per-spec candidate lists from ``_get_type_tuples``. When no spec
        carries a dtype, the product yields a single empty tuple, so the op
        is run once with ``dtypes=()``.

        Returns:
            One ``run_dtypes`` result per combination, in product order.
        """
        type_tuples = self._get_type_tuples(inputs)
        return [
            self.run_dtypes(name, inputs, combo, argmode)
            for combo in itertools.product(*type_tuples)
        ]
120        return results
121