1# mypy: allow-untyped-defs 2import math 3import operator 4 5import sympy 6 7import torch 8from torch.utils._sympy.functions import ( 9 _keep_float, 10 FloatPow, 11 FloatTrueDiv, 12 FloorDiv, 13 IntTrueDiv, 14 Max, 15 Min, 16 Mod, 17 OpaqueUnaryFn_exp, 18 OpaqueUnaryFn_log, 19 OpaqueUnaryFn_sqrt, 20 PowByNatural, 21 RoundDecimal, 22 RoundToInt, 23 ToFloat, 24 TruncToInt, 25) 26 27 28# The sympy interpretation of operators. It will also sometimes work with 29# plain int/float, but if you do certain operations you will get out a 30# sympy.Basic in the end. If you want the Python/FX traceable interpretation, 31# check PythonReferenceAnalysis. 32# NB: For magic methods this needs to use normal magic methods 33# so that test_magic_methods works 34class ReferenceAnalysis: 35 @staticmethod 36 def constant(c, dtype): 37 return sympy.sympify(c) 38 39 @staticmethod 40 def or_(a, b): 41 return a | b 42 43 @staticmethod 44 def and_(a, b): 45 return a & b 46 47 @staticmethod 48 def eq(a, b): 49 if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr): 50 return sympy.Eq(a, b) 51 return a == b 52 53 @classmethod 54 def ne(cls, a, b): 55 return cls.not_(cls.eq(a, b)) 56 57 @staticmethod 58 def lt(a, b): 59 return a < b 60 61 @staticmethod 62 def gt(a, b): 63 return a > b 64 65 @staticmethod 66 def le(a, b): 67 return a <= b 68 69 @staticmethod 70 def ge(a, b): 71 return a >= b 72 73 @staticmethod 74 def not_(a): 75 assert not isinstance(a, bool) 76 return ~a 77 78 @staticmethod 79 def reciprocal(x): 80 return FloatTrueDiv(1.0, x) 81 82 @staticmethod 83 def square(x): 84 return PowByNatural(x, 2) 85 86 @staticmethod 87 def trunc_to_int(x, dtype): 88 return TruncToInt(x) 89 90 @staticmethod 91 def ceil_to_int(x, dtype): 92 return sympy.ceiling(x) 93 94 @staticmethod 95 def floor_to_int(x, dtype): 96 return sympy.floor(x) 97 98 @staticmethod 99 def floor(x): 100 return _keep_float(sympy.floor)(x) 101 102 @staticmethod 103 def ceil(x): 104 return _keep_float(sympy.ceiling)(x) 105 
106 @staticmethod 107 def to_dtype(x, dtype): 108 if dtype == torch.float64: 109 return ToFloat(x) 110 raise NotImplementedError(f"to_dtype {dtype} NYI") 111 112 @staticmethod 113 def mod(x, y): 114 return Mod(x, y) 115 116 @staticmethod 117 def abs(x): 118 return abs(x) 119 120 @staticmethod 121 def neg(x): 122 return -x 123 124 @staticmethod 125 def truediv(a, b): 126 return FloatTrueDiv(a, b) 127 128 @staticmethod 129 def int_truediv(a, b): 130 return IntTrueDiv(a, b) 131 132 @staticmethod 133 def floordiv(a, b): 134 return FloorDiv(a, b) 135 136 @staticmethod 137 def truncdiv(a, b): 138 raise NotImplementedError("TODO: truncdiv") 139 140 @staticmethod 141 def add(a, b): 142 return _keep_float(operator.add)(a, b) 143 144 @staticmethod 145 def mul(a, b): 146 return _keep_float(operator.mul)(a, b) 147 148 @staticmethod 149 def sub(a, b): 150 return _keep_float(operator.sub)(a, b) 151 152 @staticmethod 153 def exp(x): 154 return OpaqueUnaryFn_exp(x) 155 156 @staticmethod 157 def log(x): 158 return OpaqueUnaryFn_log(x) 159 160 @staticmethod 161 def sqrt(x): 162 return OpaqueUnaryFn_sqrt(x) 163 164 @staticmethod 165 def pow(a, b): 166 return _keep_float(FloatPow)(a, b) 167 168 @staticmethod 169 def pow_by_natural(a, b): 170 return PowByNatural(a, b) 171 172 @staticmethod 173 def minimum(a, b): 174 return Min(a, b) 175 176 @staticmethod 177 def maximum(a, b): 178 return Max(a, b) 179 180 @staticmethod 181 def round_to_int(a, dtype): 182 return RoundToInt(a) 183 184 @staticmethod 185 def round_decimal(a, b): 186 return RoundDecimal(a, b) 187 188 189# Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain 190# Python types and is FX traceable. Inheritance here is purely for code 191# sharing (TODO: considering splitting out a BaseReferenceAnalysis). 
class PythonReferenceAnalysis(ReferenceAnalysis):
    """Python / FX-traceable counterpart of ``ReferenceAnalysis``.

    Handlers operate on plain Python numbers, or on torch symbolic values
    through the ``torch.sym_*`` helpers, instead of building sympy
    expressions. Base-class handlers that unconditionally construct a sympy
    object (``TruncToInt``, ``FloatTrueDiv``, ``PowByNatural``,
    ``IntTrueDiv``, ...) are overridden here with Python equivalents.
    """

    @staticmethod
    def constant(c, dtype):
        """Convert ``c`` to the Python type for ``dtype`` (int64/double/bool)."""
        if dtype is torch.int64:
            return int(c)
        elif dtype is torch.double:
            return float(c)
        elif dtype is torch.bool:
            return bool(c)
        else:
            raise AssertionError(f"unrecognized dtype {dtype}")

    @staticmethod
    def not_(a):
        return torch.sym_not(a)

    @staticmethod
    def reciprocal(x):
        # The inherited version builds a sympy FloatTrueDiv; stay in Python.
        return 1.0 / x

    @staticmethod
    def square(x):
        # The inherited version builds a sympy PowByNatural; stay in Python.
        return x * x

    @staticmethod
    def floordiv(a, b):
        return a // b

    @staticmethod
    def mod(x, y):
        return x % y

    @staticmethod
    def truncdiv(a, b):
        """Divide and round toward zero (C-style), e.g. ``truncdiv(-7, 2) == -3``.

        Computed as ``math.trunc(a / b)``: the quotient passes through a float,
        so the result is exact only while it fits in float precision.
        """
        return math.trunc(a / b)

    @staticmethod
    def int_truediv(a, b):
        # The inherited version builds a sympy IntTrueDiv; Python's ``/`` on
        # ints already yields the true (float) quotient.
        return a / b

    @staticmethod
    def to_dtype(x, dtype):
        """Convert ``x`` to ``dtype``; only ``torch.float64`` is implemented."""
        if dtype == torch.float64:
            return torch.sym_float(x)
        raise NotImplementedError(f"to_dtype {dtype} NYI")

    @staticmethod
    def exp(x):
        raise AssertionError("exp is not valid shape sympy expr")

    @staticmethod
    def log(x):
        raise AssertionError("log is not valid shape sympy expr")

    @staticmethod
    def sqrt(x):
        return torch._sym_sqrt(x)  # type: ignore[attr-defined]

    @staticmethod
    def minimum(a, b):
        return torch.sym_min(a, b)

    @staticmethod
    def maximum(a, b):
        return torch.sym_max(a, b)

    @staticmethod
    def trunc_to_int(x, dtype):
        # The inherited version builds a sympy TruncToInt; mirror
        # floor_to_int/ceil_to_int and use the math module instead.
        return math.trunc(x)

    @staticmethod
    def floor_to_int(x, dtype):
        return math.floor(x)

    @staticmethod
    def ceil_to_int(x, dtype):
        return math.ceil(x)

    @staticmethod
    def floor(x):
        # Float-valued floor: round down but keep a float result.
        return float(math.floor(x))

    @staticmethod
    def ceil(x):
        # Float-valued ceil: round up but keep a float result.
        return float(math.ceil(x))

    @staticmethod
    def truediv(a, b):
        return a / b

    @staticmethod
    def pow(a, b):
        return a**b

    @staticmethod
    def pow_by_natural(a, b):
        # Pray that safe_pow is not needed here lol. In particular, this
        # never participates in VR low/high ranges, so overflow should be
        # unlikely
        return a**b

    @staticmethod
    def round_to_int(a, dtype):
        return round(a)

    @staticmethod
    def round_decimal(a, b):
        return round(a, ndigits=b)