xref: /aosp_15_r20/external/pytorch/torch/utils/_sympy/reference.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import math
3import operator
4
5import sympy
6
7import torch
8from torch.utils._sympy.functions import (
9    _keep_float,
10    FloatPow,
11    FloatTrueDiv,
12    FloorDiv,
13    IntTrueDiv,
14    Max,
15    Min,
16    Mod,
17    OpaqueUnaryFn_exp,
18    OpaqueUnaryFn_log,
19    OpaqueUnaryFn_sqrt,
20    PowByNatural,
21    RoundDecimal,
22    RoundToInt,
23    ToFloat,
24    TruncToInt,
25)
26
27
# The sympy interpretation of operators.  It will also sometimes work with
# plain int/float, but if you do certain operations you will get out a
# sympy.Basic in the end.  If you want the Python/FX traceable interpretation,
# check PythonReferenceAnalysis.
# NB: For magic methods this needs to use normal magic methods
# so that test_magic_methods works
class ReferenceAnalysis:
    """Operator handlers that build sympy expressions.

    Every handler is a ``staticmethod`` or ``classmethod``, so the class can
    be used without instantiation.  Handlers that take a ``dtype`` argument
    ignore it, except ``to_dtype``.
    """

    @staticmethod
    def constant(c, dtype):
        # dtype is unused; sympify chooses the sympy type from the value.
        return sympy.sympify(c)

    @staticmethod
    def or_(a, b):
        return a | b

    @staticmethod
    def and_(a, b):
        return a & b

    @staticmethod
    def eq(a, b):
        # Build a symbolic Eq only when a sympy Expr is involved; otherwise
        # compare directly (for Python values this yields a plain bool).
        if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
            return sympy.Eq(a, b)
        return a == b

    @classmethod
    def ne(cls, a, b):
        """Return the negation of ``cls.eq(a, b)``.

        ``eq`` falls back to ``a == b`` for non-Expr operands, which produces
        a plain Python ``bool``.  ``not_`` refuses bools (``~True == -2``),
        so such results are negated logically here instead of crashing.
        """
        result = cls.eq(a, b)
        if isinstance(result, bool):
            return not result
        return cls.not_(result)

    @staticmethod
    def lt(a, b):
        return a < b

    @staticmethod
    def gt(a, b):
        return a > b

    @staticmethod
    def le(a, b):
        return a <= b

    @staticmethod
    def ge(a, b):
        return a >= b

    @staticmethod
    def not_(a):
        # ``~`` on a Python bool is bitwise integer negation (~True == -2),
        # so plain bools are rejected rather than silently mis-negated.
        assert not isinstance(a, bool)
        return ~a

    @staticmethod
    def reciprocal(x):
        return FloatTrueDiv(1.0, x)

    @staticmethod
    def square(x):
        return PowByNatural(x, 2)

    @staticmethod
    def trunc_to_int(x, dtype):
        return TruncToInt(x)

    @staticmethod
    def ceil_to_int(x, dtype):
        return sympy.ceiling(x)

    @staticmethod
    def floor_to_int(x, dtype):
        return sympy.floor(x)

    @staticmethod
    def floor(x):
        # NOTE(review): _keep_float (torch.utils._sympy.functions) presumably
        # keeps the result float-typed when the input is a float, unlike
        # floor_to_int above -- confirm against its definition.
        return _keep_float(sympy.floor)(x)

    @staticmethod
    def ceil(x):
        return _keep_float(sympy.ceiling)(x)

    @staticmethod
    def to_dtype(x, dtype):
        # Only conversion to float64 is implemented.
        if dtype == torch.float64:
            return ToFloat(x)
        raise NotImplementedError(f"to_dtype {dtype} NYI")

    @staticmethod
    def mod(x, y):
        return Mod(x, y)

    @staticmethod
    def abs(x):
        return abs(x)

    @staticmethod
    def neg(x):
        return -x

    @staticmethod
    def truediv(a, b):
        return FloatTrueDiv(a, b)

    @staticmethod
    def int_truediv(a, b):
        return IntTrueDiv(a, b)

    @staticmethod
    def floordiv(a, b):
        return FloorDiv(a, b)

    @staticmethod
    def truncdiv(a, b):
        raise NotImplementedError("TODO: truncdiv")

    @staticmethod
    def add(a, b):
        return _keep_float(operator.add)(a, b)

    @staticmethod
    def mul(a, b):
        return _keep_float(operator.mul)(a, b)

    @staticmethod
    def sub(a, b):
        return _keep_float(operator.sub)(a, b)

    @staticmethod
    def exp(x):
        return OpaqueUnaryFn_exp(x)

    @staticmethod
    def log(x):
        return OpaqueUnaryFn_log(x)

    @staticmethod
    def sqrt(x):
        return OpaqueUnaryFn_sqrt(x)

    @staticmethod
    def pow(a, b):
        return _keep_float(FloatPow)(a, b)

    @staticmethod
    def pow_by_natural(a, b):
        return PowByNatural(a, b)

    @staticmethod
    def minimum(a, b):
        return Min(a, b)

    @staticmethod
    def maximum(a, b):
        return Max(a, b)

    @staticmethod
    def round_to_int(a, dtype):
        return RoundToInt(a)

    @staticmethod
    def round_decimal(a, b):
        return RoundDecimal(a, b)
186        return RoundDecimal(a, b)
187
188
# Unlike ReferenceAnalysis, does NOT sympify; instead, works with plain
# Python types and is FX traceable.  Inheritance here is purely for code
# sharing (TODO: consider splitting out a BaseReferenceAnalysis).
class PythonReferenceAnalysis(ReferenceAnalysis):
    """Operator handlers built from Python operators and ``torch.sym_*``.

    Handlers that ReferenceAnalysis implements by constructing sympy function
    objects (FloorDiv, Mod, ToFloat, TruncToInt, IntTrueDiv, ...) are
    overridden here with the corresponding Python / torch operations.
    """

    @staticmethod
    def constant(c, dtype):
        if dtype is torch.int64:
            return int(c)
        elif dtype is torch.double:
            return float(c)
        elif dtype is torch.bool:
            return bool(c)
        else:
            raise AssertionError(f"unrecognized dtype {dtype}")

    @staticmethod
    def not_(a):
        return torch.sym_not(a)

    @staticmethod
    def reciprocal(x):
        # Python counterpart of the base class's FloatTrueDiv(1.0, x).
        return 1.0 / x

    @staticmethod
    def square(x):
        # Python counterpart of the base class's PowByNatural(x, 2).
        return x * x

    @staticmethod
    def trunc_to_int(x, dtype):
        # Python counterpart of TruncToInt; mirrors floor_to_int/ceil_to_int.
        return math.trunc(x)

    @staticmethod
    def floordiv(a, b):
        return a // b

    @staticmethod
    def mod(x, y):
        return x % y

    @staticmethod
    def truncdiv(a, b):
        """Divide ``a`` by ``b``, rounding the quotient toward zero.

        Python's ``/`` is true division (``7 / 2 == 3.5``) and ``//`` rounds
        toward negative infinity (``-7 // 2 == -4``), so neither performs
        truncating division on its own.
        """
        if isinstance(a, int) and isinstance(b, int):
            # Exact path for concrete ints: divide magnitudes, reapply sign.
            quotient = abs(a) // abs(b)
            return quotient if (a < 0) == (b < 0) else -quotient
        # Symbolic / traced operands: truncate the true quotient.  NOTE: this
        # goes through a float, so very large magnitudes may lose precision.
        return torch.sym_int(a / b)

    @staticmethod
    def to_dtype(x, dtype):
        # Only conversion to float64 is implemented.
        if dtype == torch.float64:
            return torch.sym_float(x)
        raise NotImplementedError(f"to_dtype {dtype} NYI")

    @staticmethod
    def exp(x):
        raise AssertionError("exp is not valid shape sympy expr")

    @staticmethod
    def log(x):
        raise AssertionError("log is not valid shape sympy expr")

    @staticmethod
    def sqrt(x):
        return torch._sym_sqrt(x)  # type: ignore[attr-defined]

    @staticmethod
    def minimum(a, b):
        return torch.sym_min(a, b)

    @staticmethod
    def maximum(a, b):
        return torch.sym_max(a, b)

    @staticmethod
    def floor_to_int(x, dtype):
        return math.floor(x)

    @staticmethod
    def ceil_to_int(x, dtype):
        return math.ceil(x)

    @staticmethod
    def floor(x):
        # Unlike floor_to_int, keep the result float-typed.
        return float(math.floor(x))

    @staticmethod
    def ceil(x):
        return float(math.ceil(x))

    @staticmethod
    def truediv(a, b):
        return a / b

    @staticmethod
    def int_truediv(a, b):
        # Python counterpart of IntTrueDiv: `/` on ints already gives a float.
        return a / b

    @staticmethod
    def pow(a, b):
        return a**b

    @staticmethod
    def pow_by_natural(a, b):
        # No overflow protection here (unlike safe_pow).  This never
        # participates in value-range low/high bounds, so overflow should be
        # unlikely.
        return a**b

    @staticmethod
    def round_to_int(a, dtype):
        return round(a)

    @staticmethod
    def round_decimal(a, b):
        return round(a, ndigits=b)
283        return round(a, ndigits=b)
284