# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.testing._internal.common_dtype as common_dtype
from executorch.exir.dialects.edge.arg.model import ArgMode, BaseArg, BaseKwarg
from executorch.exir.dialects.edge.arg.type import ArgType
from executorch.exir.dialects.edge.dtype.utils import extract_return_dtype
from executorch.exir.dialects.edge.op.api import get_callable


class DtypeRunner:
    """Run an operator over every combination of candidate input dtypes.

    For each dtype-bearing argument a list of candidate dtypes is built
    (see ``_get_type_tuples``); ``run`` then tries the operator on the
    cartesian product of those lists and records, per combination, whether
    the call succeeded.
    """

    def __init__(self):
        # Candidate dtypes tried for tensor / tensor-list / scalar-type args.
        self.tensor_dtypes = list(common_dtype.all_types_and(torch.bool, torch.half))
        # Candidate dtypes tried for scalar args.
        self.scalar_dtypes = [torch.bool, torch.int, torch.float]

    @staticmethod
    def _get_types(inputs: Dict[str, List[BaseArg]]) -> List[ArgType]:
        """Given inputs, return a list of argument types.

        Only arguments in ``inputs["args"]`` whose type reports
        ``has_dtype()`` are included, in their original order.
        """
        return [arg.type for arg in inputs["args"] if arg.type.has_dtype()]

    @staticmethod
    def _get_args_kwargs(
        inputs: Dict[str, List[BaseArg]],
        dtypes: Tuple[Optional[torch.dtype], ...],
        mode: ArgMode,
    ) -> Tuple[List[Any], Dict[str, Any]]:
        """Construct args and kwargs for op given dtypes.

        Args:
            inputs: Mapping whose ``"args"`` entry lists the op's arguments.
            dtypes: One dtype (or ``None``) per dtype-bearing argument, in the
                same order as those arguments appear in ``inputs["args"]``.
            mode: Value-generation mode; it is written onto every argument
                (``arg.mode``) before its value is produced.

        Returns:
            ``(args, kwargs)``: positional values, and keyword values keyed by
            ``arg.argname``.
        """
        args = []
        kwargs = {}
        # Index into `dtypes`; advances only on dtype-bearing arguments.
        counter = 0
        for arg in inputs["args"]:
            arg.mode = mode
            val = arg.get_val()
            if arg.type.has_dtype():
                val = arg.get_val_with_dtype(dtypes[counter])
                counter += 1
            if arg.kw and isinstance(arg, BaseKwarg):
                kwargs[arg.argname] = val
            else:
                args.append(val)
        return args, kwargs

    def _get_type_tuples(
        self, inputs: Dict[str, List[BaseArg]]
    ) -> List[List[Optional[torch.dtype]]]:
        """Return, per dtype-bearing argument, the list of dtypes to try.

        Optional arguments additionally get ``None`` as their first candidate.

        Raises:
            ValueError: If an argument type is neither a scalar, scalar type,
                tensor, nor tensor list.
        """
        types = DtypeRunner._get_types(inputs)

        def mapping(t):
            type_dtypes = []
            if t.is_optional():
                type_dtypes = [None]
            if t.is_scalar():
                return type_dtypes + self.scalar_dtypes
            elif t.is_scalar_type() or t.is_tensor() or t.is_tensor_list():
                return type_dtypes + self.tensor_dtypes
            else:
                # f-string so the offending type's name is actually reported.
                raise ValueError(f"Type {t.name} does not have dtype")

        return list(map(mapping, types))

    def run_dtypes(
        self,
        name: str,
        inputs: Dict[str, List[BaseArg]],
        dtypes: Tuple[Optional[torch.dtype], ...],
        argmode: ArgMode = ArgMode.RANDOM,
    ) -> Tuple[
        bool,
        str,
        Tuple[Optional[torch.dtype], ...],
        List[Any],
        Dict[str, Any],
    ]:
        """Run op ``name`` once with the given input dtypes.

        Args:
            name: Operator name, resolved through ``get_callable``.
            inputs: Mapping with an ``"args"`` list and, optionally, a
                ``"returns"`` entry used to extract the return dtypes.
            dtypes: One dtype (or ``None``) per dtype-bearing argument.
            argmode: How argument values are generated for the first attempt.

        Returns:
            ``(success, name, dtypes, args, kwargs)``. On success ``dtypes`` is
            the input dtypes followed by the extracted return dtypes (when
            ``inputs`` has ``"returns"``); on failure it is the input dtypes
            only. ``args``/``kwargs`` are the values of the attempt reported.

        Raises:
            RuntimeError: If the op raises ``AssertionError``; the original
                error is chained as the cause.
        """
        args, kwargs = DtypeRunner._get_args_kwargs(inputs, dtypes, argmode)
        op = get_callable(name)
        try:
            res = op(*args, **kwargs)
            ret_dtypes = ()
            if "returns" in inputs:
                ret_dtypes = tuple(extract_return_dtype(res, inputs["returns"]))
            return (True, name, dtypes + ret_dtypes, args, kwargs)
        except AssertionError as e:
            raise RuntimeError(
                f"opname: {name}, inputs: {inputs}, dtypes: {dtypes}, argmode {argmode}"
            ) from e
        except Exception as e:
            # Already running with ONES inputs: nothing left to fall back to.
            if argmode == ArgMode.ONES:
                return (False, name, dtypes, args, kwargs)
            ones_args, ones_kwargs = DtypeRunner._get_args_kwargs(
                inputs, dtypes, ArgMode.ONES
            )
            try:
                # Retry with the ONES-mode inputs. Re-running the original
                # `args`/`kwargs` here would only repeat the failing call.
                res = op(*ones_args, **ones_kwargs)
                ret_dtypes = ()
                if "returns" in inputs:
                    ret_dtypes = tuple(extract_return_dtype(res, inputs["returns"]))
                # Surface the first failure and the inputs that caused it.
                print(e)
                print(name, dtypes, args, kwargs)
                return (True, name, dtypes + ret_dtypes, ones_args, ones_kwargs)
            except Exception:
                return (False, name, dtypes, ones_args, ones_kwargs)

    def run(
        self,
        name: str,
        inputs: Dict[str, Any],
        argmode: ArgMode = ArgMode.ONES,
    ) -> List[
        Tuple[
            bool,
            str,
            Tuple[Optional[torch.dtype], ...],
            List[Any],
            Dict[str, Any],
        ]
    ]:
        """Run op ``name`` on every combination of candidate input dtypes.

        Returns:
            One ``run_dtypes`` result per element of the cartesian product of
            the per-argument candidate dtype lists.
        """
        type_tuples = self._get_type_tuples(inputs)
        return [
            self.run_dtypes(name, inputs, element, argmode)
            for element in itertools.product(*type_tuples)
        ]