# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Any, Optional

import torch
import torch.testing._internal.common_dtype as common_dtype

from executorch.exir.dialects.edge.arg.type import ArgType


class ArgMode(Enum):
    """Selects how a value is produced when none was supplied explicitly."""

    ONES = 0
    RANDOM = 1


class BaseArg:
    """Description of one operator argument and a factory for sample values.

    Every keyword option is remembered twice: once as the effective value
    (with a default substituted for ``None``) and once as a ``*_given`` flag
    recording whether the caller actually passed it.
    """

    def __init__(
        self,
        argtype,
        *,
        value=None,
        size=None,
        fill=None,
        dtype=None,
        nonzero=False,
        nonneg=False,
        bounded=False,
    ):
        self.type: ArgType = argtype

        # Effective settings, each paired with its "was it passed?" flag.
        self.value = value
        self.value_given = value is not None

        self.size = (2, 2) if size is None else tuple(size)
        self.size_given = size is not None

        self.fill = 1 if fill is None else fill
        self.fill_given = fill is not None

        self.dtype = torch.float if dtype is None else dtype
        self.dtype_given = dtype is not None

        # Constraints consulted by get_random_tensor.
        self.nonzero = nonzero
        self.nonneg = nonneg
        self.bounded = bounded

        # Backing fields for the validated properties below.
        self._mode: ArgMode = ArgMode.ONES
        self._kw: bool = False
        self._out: bool = False

    @property
    def mode(self):
        """Current ArgMode."""
        return self._mode

    @mode.setter
    def mode(self, new_mode):
        if isinstance(new_mode, ArgMode):
            self._mode = new_mode
            return
        raise ValueError("mode property should be type ArgMode")

    @property
    def kw(self):
        """Boolean ``kw`` flag."""
        return self._kw

    @kw.setter
    def kw(self, flag):
        if isinstance(flag, bool):
            self._kw = flag
            return
        raise ValueError("kw property should be boolean")

    @property
    def out(self):
        """Boolean ``out`` flag."""
        return self._out

    @out.setter
    def out(self, flag):
        if isinstance(flag, bool):
            self._out = flag
            return
        raise ValueError("out property should be boolean")

    def get_random_tensor(self, size, dtype):
        """Return a random tensor of the given size and dtype.

        Honors ``self.nonzero`` (no zero / False entries) and treats
        ``self.nonneg`` and ``self.bounded`` alike by taking absolute values.
        Floating-point results are the drawn integers divided by 8.

        Raises:
            ValueError: if dtype is not bool and is in neither
                ``common_dtype.integral_types()`` nor
                ``common_dtype.floating_types()``.
        """
        shape = tuple(size)

        if dtype == torch.bool:
            if self.nonzero:
                return torch.full(shape, True, dtype=dtype)
            return torch.randint(low=0, high=2, size=shape, dtype=dtype)

        is_integral = dtype in common_dtype.integral_types()
        if is_integral:
            bound = 100
        elif dtype in common_dtype.floating_types():
            bound = 800
        else:
            raise ValueError(f"Unsupported Dtype: {dtype}")

        if dtype == torch.uint8:
            # Unsigned: only the lower limit depends on the nonzero request.
            lowest = 1 if self.nonzero else 0
            return torch.randint(low=lowest, high=bound, size=shape, dtype=dtype)

        sample = torch.randint(low=-bound, high=bound, size=shape, dtype=dtype)
        if self.nonzero:
            # Swap every zero for a freshly drawn strictly positive entry.
            substitutes = torch.randint(low=1, high=bound, size=shape, dtype=dtype)
            sample = torch.where(sample == 0, substitutes, sample)
        if self.nonneg or self.bounded:
            sample = torch.abs(sample)

        # Past the check above the dtype is either integral or floating.
        return sample if is_integral else sample / 8

    def get_random_scalar(self, dtype):
        """Return a random Python scalar drawn via get_random_tensor."""
        zero_dim = self.get_random_tensor([], dtype)
        return zero_dim.item()

    def get_converted_scalar(self, value, dtype):
        """Convert value to the Python scalar type that corresponds to dtype.

        Raises:
            ValueError: if dtype is not bool, integral or floating.
        """
        if dtype == torch.bool:
            return bool(value)
        if dtype in common_dtype.integral_types():
            return int(value)
        if dtype in common_dtype.floating_types():
            return float(value)
        raise ValueError(f"Unsupported Dtype: {dtype}")

    def get_scalar_val_with_dtype(self, dtype):
        """Return a scalar for dtype: the given value, a random one, or 1."""
        if self.value_given:
            return self.get_converted_scalar(self.value, dtype)
        if self._mode == ArgMode.RANDOM:
            return self.get_random_scalar(dtype)
        if self._mode == ArgMode.ONES:
            return self.get_converted_scalar(1, dtype)
        raise ValueError(f"Unsupported Mode: {self._mode}")

    def get_tensor_val_with_dtype(self, dtype):
        """Return a tensor of dtype.

        Precedence: explicit value, then explicit fill, then the current
        mode (random or all ones).
        """
        if self.value_given:
            return torch.tensor(self.value, dtype=dtype)
        if self.fill_given:
            return torch.full(self.size, self.fill, dtype=dtype)
        if self._mode == ArgMode.RANDOM:
            return self.get_random_tensor(self.size, dtype=dtype)
        if self._mode == ArgMode.ONES:
            return torch.full(self.size, 1, dtype=dtype)
        # Only reached when _mode is neither RANDOM nor ONES.
        if self.size_given:
            return torch.full(self.size, self.fill, dtype=dtype)
        raise ValueError(f"Unsupported Mode: {self._mode}")

    def get_val_with_dtype(self, dtype):
        """Produce this argument's value for dtype, dispatching on self.type.

        Returns None when dtype is None, and dtype itself when the argument
        is a scalar type.

        Raises:
            ValueError: if self.type matches none of the handled categories.
        """
        if dtype is None:
            return None

        argtype = self.type
        if argtype.is_scalar_type():
            return dtype
        if argtype.is_scalar():
            return self.get_scalar_val_with_dtype(dtype)
        if argtype.is_tensor():
            return self.get_tensor_val_with_dtype(dtype)
        if argtype.is_tensor_list():
            # Each element of value is asked for its own value with dtype.
            if self.value_given:
                return [elem.get_val_with_dtype(dtype) for elem in self.value]
            return []
        raise ValueError(f"Unsupported Type: {self.type}")

    def get_val(self):
        """Return the value for self.dtype, or the raw value for untyped args."""
        if not self.type.has_dtype():
            return self.value
        return self.get_val_with_dtype(self.dtype)


class BaseKwarg(BaseArg):
    """A BaseArg that carries an argument name and has ``kw`` set to True."""

    def __init__(self, argtype, argname, **kwargs):
        BaseArg.__init__(self, argtype, **kwargs)
        self.argname = argname
        self._kw = True

    @property
    def kw(self):
        # Redefined without a setter, so ``kw`` is read-only on this class.
        return super().kw


class InArg(BaseArg):
    """A BaseArg whose ``out`` flag is False."""

    def __init__(self, *args, **kwargs):
        BaseArg.__init__(self, *args, **kwargs)
        self._out = False

    @property
    def out(self):
        # Redefined without a setter, so ``out`` is read-only on this class.
        return self._out


class InKwarg(BaseKwarg, InArg):
    """Named argument combining BaseKwarg and InArg."""

    def __init__(self, *args, **kwargs):
        BaseKwarg.__init__(self, *args, **kwargs)


class OutArg(BaseKwarg):
    """Named argument with ``out`` set to True; name defaults to "out", fill to 0."""

    def __init__(self, argtype, *, argname="out", fill=0, **kwargs):
        BaseKwarg.__init__(self, argtype, argname, fill=fill, **kwargs)
        self._out = True

    @property
    def out(self):
        # Redefined without a setter, so ``out`` is read-only on this class.
        return self._out


class Return(BaseKwarg):
    """Model for returns of operators"""

    RETURN_NAME_PREFIX = "__ret"

    def __init__(self, argtype, *, argname=RETURN_NAME_PREFIX, fill=0, **kwargs):
        BaseKwarg.__init__(self, argtype, argname=argname, fill=fill, **kwargs)

    def is_expected(self, result: Any) -> bool:
        """Tell whether result matches this return's expectation.

        Only tensors are supported, and only their dtype is compared against
        self.dtype.

        Raises:
            NotImplementedError: if result is not a torch.Tensor.
        """
        if not isinstance(result, torch.Tensor):
            raise NotImplementedError(f"Not implemented for {type(result)}")
        return result.dtype == self.dtype

    def to_out(self, *, name: Optional[str] = None) -> OutArg:
        """Build an OutArg carrying this return's type, fill, size, dtype and value.

        A falsy name falls back to self.argname.
        """
        return OutArg(
            self.type,
            argname=name or self.argname,
            fill=self.fill,
            size=self.size,
            dtype=self.dtype,
            value=self.value,
        )