# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Any, Optional

import torch
import torch.testing._internal.common_dtype as common_dtype

from executorch.exir.dialects.edge.arg.type import ArgType


class ArgMode(Enum):
    """Strategy for materializing argument values: fill with ones or draw random values."""

    ONES = 0
    RANDOM = 1


class BaseArg:
    """Model of a single operator argument, used to generate sample values."""

    def __init__(
        self,
        argtype,
        *,
        value=None,
        size=None,
        fill=None,
        dtype=None,
        nonzero=False,
        nonneg=False,
        bounded=False,
    ):
        self.type: ArgType = argtype

        # Record which settings were passed explicitly.
        self.value_given = value is not None
        self.size_given = size is not None
        self.fill_given = fill is not None
        self.dtype_given = dtype is not None

        # Defaults: 2x2 tensors filled with 1, float dtype.
        self.value = value
        self.size = (2, 2) if size is None else tuple(size)
        self.fill = 1 if fill is None else fill
        self.dtype = torch.float if dtype is None else dtype

        # Constraints on generated values.
        self.nonzero = nonzero
        self.nonneg = nonneg
        self.bounded = bounded

        # Generation mode and argument-role flags (keyword argument, out argument).
        self._mode: ArgMode = ArgMode.ONES
        self._kw: bool = False
        self._out: bool = False

    @property
    def mode(self):
        return self._mode

    @mode.setter
    def mode(self, v):
        if not isinstance(v, ArgMode):
            raise ValueError("mode property should be type ArgMode")
        self._mode = v

    @property
    def kw(self):
        return self._kw

    @kw.setter
    def kw(self, v):
        if not isinstance(v, bool):
            raise ValueError("kw property should be boolean")
        self._kw = v

    @property
    def out(self):
        return self._out

    @out.setter
    def out(self, v):
        if not isinstance(v, bool):
            raise ValueError("out property should be boolean")
        self._out = v

    def get_random_tensor(self, size, dtype):
        """Build a random tensor of the given size and dtype, honoring nonzero/nonneg/bounded."""
        size = tuple(size)
        if dtype == torch.bool:
            if self.nonzero:
                return torch.full(size, True, dtype=dtype)
            else:
                return torch.randint(low=0, high=2, size=size, dtype=dtype)

        if dtype in common_dtype.integral_types():
            high = 100
        elif dtype in common_dtype.floating_types():
            high = 800
        else:
            raise ValueError(f"Unsupported Dtype: {dtype}")

        if dtype == torch.uint8:
            # uint8 cannot hold negative values, so sample directly from [0, high).
            if self.nonzero:
                return torch.randint(low=1, high=high, size=size, dtype=dtype)
            else:
                return torch.randint(low=0, high=high, size=size, dtype=dtype)

        t = torch.randint(low=-high, high=high, size=size, dtype=dtype)
        if self.nonzero:
            # Replace any zeros with positive values drawn from [1, high).
            pos = torch.randint(low=1, high=high, size=size, dtype=dtype)
            t = torch.where(t == 0, pos, t)
        if self.nonneg or self.bounded:
            t = torch.abs(t)

        if dtype in common_dtype.integral_types():
            return t
        if dtype in common_dtype.floating_types():
            # Scale the integer-valued floats down to get fractional values.
            return t / 8

    def get_random_scalar(self, dtype):
        return self.get_random_tensor([], dtype).item()

    def get_converted_scalar(self, value, dtype):
        if dtype == torch.bool:
            return bool(value)
        elif dtype in common_dtype.integral_types():
            return int(value)
        elif dtype in common_dtype.floating_types():
            return float(value)
        else:
            raise ValueError(f"Unsupported Dtype: {dtype}")

    def get_scalar_val_with_dtype(self, dtype):
        if self.value_given:
            return self.get_converted_scalar(self.value, dtype)
        elif self._mode == ArgMode.RANDOM:
            return self.get_random_scalar(dtype)
        elif self._mode == ArgMode.ONES:
            return self.get_converted_scalar(1, dtype)
        else:
            raise ValueError(f"Unsupported Mode: {self._mode}")

    def get_tensor_val_with_dtype(self, dtype):
        if self.value_given:
            return torch.tensor(self.value, dtype=dtype)
        elif self.fill_given:
            return torch.full(self.size, self.fill, dtype=dtype)
        elif self._mode == ArgMode.RANDOM:
            return self.get_random_tensor(self.size, dtype=dtype)
        elif self._mode == ArgMode.ONES:
            return torch.full(self.size, 1, dtype=dtype)
        elif self.size_given:
            return torch.full(self.size, self.fill, dtype=dtype)
        else:
            raise ValueError(f"Unsupported Mode: {self._mode}")

    def get_val_with_dtype(self, dtype):
        if dtype is None:
            return None
        # ScalarType arguments carry the dtype itself as their value.
        if self.type.is_scalar_type():
            return dtype
        elif self.type.is_scalar():
            return self.get_scalar_val_with_dtype(dtype)
        elif self.type.is_tensor():
            return self.get_tensor_val_with_dtype(dtype)
        elif self.type.is_tensor_list():
            # Tensor lists recurse into their element models.
            if not self.value_given:
                return []
            return [x.get_val_with_dtype(dtype) for x in self.value]
        else:
            raise ValueError(f"Unsupported Type: {self.type}")

    def get_val(self):
        if self.type.has_dtype():
            return self.get_val_with_dtype(self.dtype)
        else:
            return self.value


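# A minimal usage sketch (illustration only): `ArgType.Tensor` below is an assumed
# member name; see executorch.exir.dialects.edge.arg.type for the actual spelling
# of ArgType members.
#
#   arg = BaseArg(ArgType.Tensor, size=(3, 3), dtype=torch.int32, nonneg=True)
#   arg.mode = ArgMode.RANDOM
#   t = arg.get_val()   # 3x3 int32 tensor of non-negative random values
#
#   ones = BaseArg(ArgType.Tensor)  # defaults: size (2, 2), dtype float, mode ONES
#   ones.get_val()                  # 2x2 tensor filled with ones

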
class BaseKwarg(BaseArg):
    """Model for a keyword argument, identified by its argname."""

    def __init__(self, argtype, argname, **kwargs):
        BaseArg.__init__(self, argtype, **kwargs)
        self.argname = argname
        self._kw = True

    # Getter-only override: kw is read-only and always True for keyword arguments.
    @property
    def kw(self):
        return super().kw


class InArg(BaseArg):
    """Model for an input argument."""

    def __init__(self, *args, **kwargs):
        BaseArg.__init__(self, *args, **kwargs)
        self._out = False

    # Getter-only override: out is read-only and always False for inputs.
    @property
    def out(self):
        return self._out


class InKwarg(BaseKwarg, InArg):
    """Model for an input passed as a keyword argument."""

    def __init__(self, *args, **kwargs):
        BaseKwarg.__init__(self, *args, **kwargs)


class OutArg(BaseKwarg):
    """Model for an out argument, named "out" and zero-filled by default."""

    def __init__(self, argtype, *, argname="out", fill=0, **kwargs):
        BaseKwarg.__init__(self, argtype, argname, fill=fill, **kwargs)
        self._out = True

    # Getter-only override: out is read-only and always True for out arguments.
    @property
    def out(self):
        return self._out


class Return(BaseKwarg):
    """Model for the return values of operators."""

    RETURN_NAME_PREFIX = "__ret"

    def __init__(self, argtype, *, argname=RETURN_NAME_PREFIX, fill=0, **kwargs):
        BaseKwarg.__init__(self, argtype, argname=argname, fill=fill, **kwargs)

    def is_expected(self, result: Any) -> bool:
        """Check whether a return value matches expectations.

        For a Tensor result, only its dtype is compared against the expected dtype.
        """
        if isinstance(result, torch.Tensor):
            return result.dtype == self.dtype
        else:
            raise NotImplementedError(f"Not implemented for {type(result)}")

    def to_out(self, *, name: Optional[str] = None) -> OutArg:
        """Convert this return model into an equivalent out-argument model."""
        return OutArg(
            self.type,
            argname=name if name else self.argname,
            fill=self.fill,
            size=self.size,
            dtype=self.dtype,
            value=self.value,
        )
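
# A minimal usage sketch of Return (illustration only): `ArgType.Tensor` is an
# assumed member name; see executorch.exir.dialects.edge.arg.type for the actual
# ArgType members.
#
#   ret = Return(ArgType.Tensor, dtype=torch.float)
#   ret.is_expected(torch.ones(2, 2))                     # True: dtype matches
#   ret.is_expected(torch.ones(2, 2, dtype=torch.int32))  # False: dtype differs
#   out_arg = ret.to_out(name="out")                      # equivalent out-argument model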
239