# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses
import itertools
import logging
import math
import operator
from typing import (
    Callable,
    Dict,
    Generic,
    Optional,
    overload,
    SupportsFloat,
    TYPE_CHECKING,
    TypeVar,
    Union,
)
from typing_extensions import TypeGuard

import sympy
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom

import torch
from torch._logging import LazyString
from torch._prims_common import dtype_to_type

from .functions import (
    _keep_float,
    FloatTrueDiv,
    FloorDiv,
    IntTrueDiv,
    OpaqueUnaryFn_exp,
    OpaqueUnaryFn_log,
    OpaqueUnaryFn_sqrt,
    PowByNatural,
    RoundDecimal,
    RoundToInt,
    safe_pow,
    ToFloat,
    TruncToFloat,
    TruncToInt,
)
from .interp import sympy_interp
from .numbers import int_oo, IntInfinity, NegativeIntInfinity


log = logging.getLogger(__name__)

__all__ = ["ValueRanges", "ValueRangeAnalysis", "bound_sympy"]

_T = TypeVar("_T", sympy.Expr, SympyBoolean)


class ValueRangeError(RuntimeError):
    pass


# Like sympify, but supports less stuff, and also ensures that direct
# sympy expressions don't have free variables
def simple_sympify(e):
    """Convert a Python/SymPy constant into a SymPy atom.

    bool -> sympy.true/false, int -> sympy.Integer, float -> sympy.Float
    (with +/-inf mapped to +/-sympy.oo). SymPy expressions must be numbers
    (no free symbols) and must not be NaN. Anything else raises
    AssertionError.
    """
    if isinstance(e, bool):
        return sympy.true if e else sympy.false
    elif isinstance(e, int):
        return sympy.Integer(e)
    elif isinstance(e, float):
        # infinity is special; we use it to bracket integers as well
        if math.isinf(e):
            return sympy.oo if e > 0 else -sympy.oo
        return sympy.Float(e)
    elif isinstance(e, sympy.Expr):
        assert e.is_number, e
        # NaNs can occur when doing things like 0 * sympy.oo, but it is better
        # if the operator notices this and takes care of it, because sometimes
        # the NaN is inappropriate (for example, for ints, the [-oo, oo] range
        # should go to zero when multiplied with [0, 0])
        assert e != sympy.nan
        return e
    elif isinstance(e, BooleanAtom):
        return e
    else:
        raise AssertionError(f"not simple sympy type {type(e)}: {e}")


# Sympy atomics only. Unlike <=, it also works on Sympy bools.
def sympy_generic_le(lower, upper):
    """Return ``lower <= upper`` for SymPy numbers, or the order False < True for SymPy bools."""
    if isinstance(lower, sympy.Expr):
        assert isinstance(upper, sympy.Expr)
        return lower <= upper
    else:
        # only negative condition is True > False
        assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean), (
            lower,
            upper,
        )
        return not (lower and not upper)


def vr_is_bool(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[SympyBoolean]]:
    """Type guard: True when ``vr`` is a boolean range."""
    return vr.is_bool


def vr_is_expr(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[sympy.Expr]]:
    """Type guard: True when ``vr`` is a numeric (non-boolean) range."""
    return not vr.is_bool


ExprIn = Union[int, float, sympy.Expr]
BoolIn = Union[bool, SympyBoolean]
AllIn = Union[ExprIn, BoolIn]
ExprFn = Callable[[sympy.Expr], sympy.Expr]
ExprFn2 = Callable[[sympy.Expr, sympy.Expr], sympy.Expr]
BoolFn = Callable[[SympyBoolean], SympyBoolean]
BoolFn2 = Callable[[SympyBoolean, SympyBoolean], SympyBoolean]
AllFn = Union[ExprFn, BoolFn]
AllFn2 = Union[ExprFn2, BoolFn2]


@dataclasses.dataclass(frozen=True)
class ValueRanges(Generic[_T]):
    """Closed interval ``[lower, upper]`` of constant SymPy numbers or SymPy bools.

    Bounds are normalized by ``simple_sympify`` and must satisfy
    ``lower <= upper`` (``False < True`` for bools). ``is_bool``, ``is_int``
    and ``is_float`` classify the range; exactly one of them holds.
    """

    if TYPE_CHECKING:
        # ruff doesn't understand circular references but mypy does
        ExprVR = ValueRanges[sympy.Expr]  # noqa: F821
        BoolVR = ValueRanges[SympyBoolean]  # noqa: F821
        AllVR = Union[ExprVR, BoolVR]

    # Although the type signature here suggests you can pass any
    # sympy expression, in practice the analysis here only works
    # with constant sympy expressions
    lower: _T
    upper: _T
    is_bool: bool
    is_int: bool
    is_float: bool

    def __repr__(self) -> str:
        return f"VR[{self.lower}, {self.upper}]"

    @overload
    def __init__(
        self: ValueRanges[sympy.Expr],
        lower: ExprIn,
        upper: ExprIn,
    ) -> None:
        ...

    @overload
    def __init__(  # type: ignore[misc]
        self: ValueRanges[SympyBoolean],
        lower: BoolIn,
        upper: BoolIn,
    ) -> None:
        ...

    def __init__(self, lower: AllIn, upper: AllIn) -> None:
        """Build the range ``[lower, upper]``.

        Raises ValueRangeError if ``lower > upper`` and TypeError if the
        bounds cannot be compared. An integer bound paired with a float
        infinity has that infinity replaced by ``int_oo``/``-int_oo``.
        """
        lower = simple_sympify(lower)
        upper = simple_sympify(upper)
        # TODO: when the bounds have free variables, this may be
        # nontrivial to actually verify
        try:
            if not sympy_generic_le(lower, upper):
                raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]")
        except TypeError as e:
            raise TypeError(f"Could not compare {lower} <= {upper}") from e

        is_bool_lower = isinstance(lower, SympyBoolean)
        is_bool_upper = isinstance(upper, SympyBoolean)
        assert is_bool_lower == is_bool_upper, (lower, upper)

        # Warning: is_int/is_float is best effort.  We do pretty well in
        # Dynamo, but in Inductor these attributes are often wrong because we
        # are not very rigorous in dtype analysis.  This is also why we need
        # the flexible analysis for is_int: sometimes a sympy.oo pops in for
        # an integer bound. I would /like/ for us not to do this, but it's
        # too hard to push the invariant through right now.
        if isinstance(lower, sympy.Integer) and upper == sympy.oo:
            upper = int_oo
        if isinstance(upper, sympy.Integer) and lower == -sympy.oo:
            lower = -int_oo
        # NB: [-int_oo, -int_oo] and [int_oo, int_oo] are allowed
        integer_types = (sympy.Integer, NegativeIntInfinity, IntInfinity)
        is_int_lower = isinstance(lower, integer_types)
        is_int_upper = isinstance(upper, integer_types)

        # Because this is a frozen class
        object.__setattr__(self, "lower", lower)
        object.__setattr__(self, "upper", upper)
        # Unlike bool/int in Python, we don't report bools are ints
        #
        # NB: is_bool_lower == is_bool_upper, so we only need to check one
        object.__setattr__(self, "is_bool", is_bool_lower)
        object.__setattr__(
            self,
            "is_int",
            not self.is_bool and is_int_lower and is_int_upper,
        )
        """
        # This assert is just impossible right now, too many sympy bugs
        if self.is_int:
            # NB: sympy will sometimes randomly lose the float-ness of zero,
            # so we also need to account for that in the assertion here.
            # See also https://github.com/sympy/sympy/issues/26620
            assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], (
                lower,
                upper,
            )
            assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper)
        """
        # NB: [-oo, oo] always advertises as float!
        object.__setattr__(self, "is_float", not self.is_bool and not self.is_int)
        assert self.is_bool or self.is_int or self.is_float, (lower, upper)

    def boolify(self) -> ValueRanges[SympyBoolean]:
        """Return self as a boolean range.

        The untyped ``unknown()`` range becomes ``unknown_bool()``; any other
        non-boolean range raises AssertionError.
        """
        if vr_is_bool(self):
            return self
        elif self == ValueRanges.unknown():
            return ValueRanges.unknown_bool()
        else:
            raise AssertionError(f"not bool like {self}")

    def __contains__(self, x: AllIn) -> bool:
        return ValueRanges.wrap(x).issubset(self)

    def issubset(self, other):
        """True when every value of self also lies in ``other``."""
        return sympy_generic_le(other.lower, self.lower) and sympy_generic_le(
            self.upper, other.upper
        )

    def tighten(self, other) -> ValueRanges:
        """Given two ValueRanges, returns their intersection"""
        return self & other

    # Intersection
    @overload
    def __and__(
        self: ValueRanges[sympy.Expr],
        other: ValueRanges[sympy.Expr],
    ) -> ValueRanges[sympy.Expr]:
        ...

    @overload
    def __and__(  # type: ignore[misc]
        self: ValueRanges[SympyBoolean],
        other: ValueRanges[SympyBoolean],
    ) -> ValueRanges[SympyBoolean]:
        ...

    def __and__(self: AllVR, other: AllVR) -> AllVR:
        if other == ValueRanges.unknown():
            return self
        if self == ValueRanges.unknown():
            return other
        assert self.is_bool == other.is_bool, (self, other)
        assert self.is_int == other.is_int, (self, other)
        assert self.is_float == other.is_float, (self, other)
        if self.is_bool:
            return ValueRanges(
                sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper)
            )
        else:
            return ValueRanges(
                sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper)
            )

    # Union
    @overload
    def __or__(
        self: ValueRanges[sympy.Expr],
        other: ValueRanges[sympy.Expr],
    ) -> ValueRanges[sympy.Expr]:
        ...

    @overload
    def __or__(  # type: ignore[misc]
        self: ValueRanges[SympyBoolean],
        other: ValueRanges[SympyBoolean],
    ) -> ValueRanges[SympyBoolean]:
        ...

    def __or__(self: AllVR, other: AllVR) -> AllVR:
        if ValueRanges.unknown() in (self, other):
            return ValueRanges.unknown()
        assert self.is_bool == other.is_bool, (self, other)
        assert self.is_int == other.is_int, (self, other)
        assert self.is_float == other.is_float, (self, other)
        if self.is_bool:
            return ValueRanges(
                sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper)
            )
        else:
            return ValueRanges(
                sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper)
            )

    def is_singleton(self) -> bool:
        """True when the range holds exactly one value."""
        return self.lower == self.upper

    @staticmethod
    def unknown() -> ValueRanges[sympy.Expr]:
        """The untyped (float-flavored) everything range ``[-oo, oo]``."""
        return ValueRanges(-sympy.oo, sympy.oo)

    @staticmethod
    def unknown_int() -> ValueRanges[sympy.Expr]:
        """The integer everything range ``[-int_oo, int_oo]``."""
        return ValueRanges(-int_oo, int_oo)

    @staticmethod
    def unknown_bool() -> ValueRanges[SympyBoolean]:
        """The boolean everything range ``[false, true]``."""
        return ValueRanges(sympy.false, sympy.true)

    @overload
    @staticmethod
    # work around the fact that bool and int overlap
    def wrap(arg: Union[ExprIn, ExprVR]) -> ExprVR:  # type: ignore[overload-overlap]
        ...

    @overload
    @staticmethod
    def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR:  # type: ignore[misc]
        ...

    @staticmethod
    def wrap(arg: Union[AllIn, AllVR]) -> AllVR:
        """Return ranges unchanged; turn a constant into the singleton ``[arg, arg]``.

        A float NaN becomes ``unknown()``.
        """
        if isinstance(arg, ValueRanges):
            return arg
        if isinstance(arg, float) and math.isnan(arg):
            return ValueRanges.unknown()
        # arg is either ExprIn or BoolIn, but we don't know it here
        return ValueRanges(arg, arg)  # type: ignore[arg-type]

    @staticmethod
    def increasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
        """Increasing: x <= y => f(x) <= f(y)."""
        x = ValueRanges.wrap(x)
        return ValueRanges(fn(x.lower), fn(x.upper))

    @overload
    @staticmethod
    def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
        ...

    @overload
    @staticmethod
    def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR:  # type: ignore[misc]
        ...

    @staticmethod
    def decreasing_map(x: Union[AllIn, AllVR], fn: AllFn) -> AllVR:
        """Decreasing: x <= y => f(x) >= f(y)."""
        x = ValueRanges.wrap(x)
        # consistently either Expr or Bool, but we don't know it here
        return ValueRanges(fn(x.upper), fn(x.lower))  # type: ignore[arg-type]

    @staticmethod
    def monotone_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
        """It's increasing or decreasing."""
        x = ValueRanges.wrap(x)
        l = fn(x.lower)
        u = fn(x.upper)
        return ValueRanges(min(l, u), max(l, u))

    @staticmethod
    def convex_min_zero_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
        """Fn is convex and has a minimum at 0."""
        x = ValueRanges.wrap(x)
        if 0 in x:
            upper = max(fn(x.lower), fn(x.upper))
            upper = simple_sympify(upper)
            # Keep the float-ness of the result: a float (or oo) upper bound
            # gets a float zero as lower bound.
            if isinstance(upper, sympy.Float) or upper == sympy.oo:
                return ValueRanges(0.0, upper)
            return ValueRanges(0, upper)
        return ValueRanges.monotone_map(x, fn)

    @overload
    @staticmethod
    def coordinatewise_increasing_map(
        x: Union[ExprIn, ExprVR],
        y: Union[ExprIn, ExprVR],
        fn: ExprFn2,
    ) -> ExprVR:
        ...

    @overload
    @staticmethod
    def coordinatewise_increasing_map(  # type: ignore[misc]
        x: Union[BoolIn, BoolVR],
        y: Union[BoolIn, BoolVR],
        fn: BoolFn2,
    ) -> BoolVR:
        ...

    @staticmethod
    def coordinatewise_increasing_map(
        x: Union[AllIn, AllVR],
        y: Union[AllIn, AllVR],
        fn: AllFn2,
    ) -> AllVR:
        """
        It's increasing on each coordinate.

        Mathematically:
        For every 1 <= i <= n and x_i <= y_i we have that
        f(x1, ..., xi, ..., xn) <= f(x1, ..., yi, ..., xn)
        """
        x, y = ValueRanges.wrap(x), ValueRanges.wrap(y)
        return ValueRanges(
            fn(x.lower, y.lower),  # type: ignore[arg-type]
            fn(x.upper, y.upper),  # type: ignore[arg-type]
        )

    @classmethod
    def coordinatewise_monotone_map(cls, x, y, fn):
        """It's increasing or decreasing on each coordinate.

        ``fn`` is evaluated on the four corners of the box ``x * y`` and the
        result spans the smallest and largest of those values.
        """
        x, y = cls.wrap(x), cls.wrap(y)
        products = [
            fn(a, b)
            for a, b in itertools.product([x.lower, x.upper], [y.lower, y.upper])
        ]
        return ValueRanges(min(products), max(products))


class SymPyValueRangeAnalysis:
    """
    It gives bounds on a SymPy operator given bounds on its arguments
    See the function `bound_sympy` for a function that applies this logic to a full SymPy expression
    """

    @staticmethod
    def constant(value, dtype):
        """Singleton range for a constant ``value`` of the given torch ``dtype``.

        NaN constants give the "unknown" range matching ``dtype``. Python
        constants are cast with ``dtype_to_type(dtype)``; SymPy constants are
        only type-checked.
        """
        if isinstance(value, ValueRanges):
            assert value.is_singleton()
            value = value.lower
        # NB: value is NOT a sympy expression, it's a constant!
        is_python = isinstance(value, (int, float, bool))
        assert is_python or isinstance(
            value, (BooleanAtom, sympy.Integer, sympy.Number)
        )

        # using nan makes subsequent computation throw, and for the purposes of optimization
        # returning -math.inf - math.inf is equivalent to giving up
        if isinstance(value, SupportsFloat) and math.isnan(value):
            if dtype == torch.bool:
                return ValueRanges.unknown_bool()
            elif dtype.is_floating_point:
                return ValueRanges.unknown()
            else:
                return ValueRanges(-int_oo, int_oo)

        if is_python:
            type_ = dtype_to_type(dtype)
            value = type_(value)
        else:
            # We do a type check on a best-effort basis
            # We don't want to force a cast to sympy.Float if the value is Rational to avoid losing precision
            if dtype == torch.bool:
                assert isinstance(value, BooleanAtom)
            elif dtype.is_floating_point:
                assert not value.is_finite or value.is_real
            else:
                # dtype is intXX
                assert value.is_integer

        r = ValueRanges.wrap(value)
        return r

    @staticmethod
    def to_dtype(a, dtype, src_dtype=None):
        """Range after casting to ``dtype``.

        Only float64 maps the bounds (via ToFloat); other dtypes give the
        matching unknown range.
        """
        if dtype == torch.float64:
            return ValueRanges.increasing_map(a, ToFloat)
        elif dtype == torch.bool:
            return ValueRanges.unknown_bool()
        elif not dtype.is_floating_point:
            return ValueRanges.unknown_int()
        return ValueRanges.unknown()

    @staticmethod
    def trunc_to_int(a, dtype):
        return ValueRanges.increasing_map(a, TruncToInt)

    @staticmethod
    def not_(a):
        a = ValueRanges.wrap(a)
        a = a.boolify()
        assert a.is_bool
        return ValueRanges.decreasing_map(a, sympy.Not)

    @staticmethod
    def or_(a, b):
        return ValueRanges.coordinatewise_increasing_map(a, b, sympy.Or)

    @staticmethod
    def and_(a, b):
        return ValueRanges.coordinatewise_increasing_map(a, b, sympy.And)

    @staticmethod
    def eq(a, b):
        """Range of ``a == b``: definitely true, definitely false, or [false, true]."""
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        if a.is_singleton() and b.is_singleton() and a.lower == b.lower:
            return ValueRanges.wrap(sympy.true)
        elif a.is_bool and b.is_bool:
            # SymPy's true/false reject ordering comparisons (<, >) with a
            # TypeError, so the numeric disjointness test below cannot be used.
            # Two boolean ranges are disjoint exactly when both are singletons
            # holding different values (equal singletons were handled above).
            if a.is_singleton() and b.is_singleton():
                return ValueRanges.wrap(sympy.false)
            return ValueRanges(sympy.false, sympy.true)
        elif a.lower > b.upper or b.lower > a.upper:  # ranges disjoint
            return ValueRanges.wrap(sympy.false)
        return ValueRanges(sympy.false, sympy.true)

    @classmethod
    def ne(cls, a, b):
        return cls.not_(cls.eq(a, b))

    @classmethod
    def identity(cls, a):
        return ValueRanges.wrap(a)

    @classmethod
    def lt(cls, a, b):
        """Range of ``a < b``; for bools, ``a < b`` iff ``not a and b``."""
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        assert a.is_bool == b.is_bool
        if a.is_bool:
            return cls.and_(cls.not_(a), b)
        else:
            if a.upper < b.lower:
                return ValueRanges.wrap(sympy.true)
            elif a.lower >= b.upper:
                return ValueRanges.wrap(sympy.false)
            return ValueRanges(sympy.false, sympy.true)

    @classmethod
    def gt(cls, a, b):
        return cls.lt(b, a)

    @classmethod
    def le(cls, a, b):
        return cls.not_(cls.gt(a, b))

    @classmethod
    def ge(cls, a, b):
        return cls.not_(cls.lt(a, b))

    @staticmethod
    def add(a, b):
        return ValueRanges.coordinatewise_increasing_map(
            a, b, _keep_float(operator.add)
        )

    @classmethod
    def mul(cls, a, b):
        """Range of ``a * b``; for bools this is logical and."""
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)

        assert a.is_bool == b.is_bool
        if a.is_bool:
            return cls.and_(a, b)

        def safe_mul(a, b):
            # Make unknown() * wrap(0.0) == wrap(0.0)
            if a == 0.0 or a == 0:
                return a
            elif b == 0.0 or b == 0:
                return b
            else:
                return a * b

        return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul))

    @staticmethod
    def int_truediv(a, b):
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        # Give up when the divisor may be zero or when both sides may be
        # infinite (inf / inf has no meaningful bound).
        if 0 in b or ((-int_oo in a or int_oo in a) and (-int_oo in b or int_oo in b)):
            return ValueRanges.unknown()
        else:
            return ValueRanges.coordinatewise_monotone_map(
                a, b, _keep_float(IntTrueDiv)
            )

    @staticmethod
    def truediv(a, b):
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        # Same bail-out conditions as int_truediv, with float infinities.
        if 0 in b or (
            (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b)
        ):
            return ValueRanges.unknown()
        else:
            return ValueRanges.coordinatewise_monotone_map(
                a, b, _keep_float(FloatTrueDiv)
            )

    @staticmethod
    def floordiv(a, b):
        """Range of ``FloorDiv(a, b)``; unknown_int() if ``b`` may be zero."""
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        if 0 in b:
            return ValueRanges.unknown_int()
        products = []
        for x, y in itertools.product([a.lower, a.upper], [b.lower, b.upper]):
            r = FloorDiv(x, y)
            if r is sympy.nan:
                # NaN corners are replaced by a signed int infinity.
                products.append((sympy.sign(x) * sympy.sign(y)) * int_oo)
            else:
                products.append(r)

        return ValueRanges(min(products), max(products))

    @classmethod
    def mod(cls, x, y):
        """Range of ``x % y`` using C semantics (result takes the sign of ``x``)."""
        x = ValueRanges.wrap(x)
        y = ValueRanges.wrap(y)
        # nb. We implement C semantics

        def c_mod(a, b):
            ret = abs(a) % abs(b)
            if a < 0:
                ret *= -1
            return ret

        def c_div(a, b):
            x = a / b
            return sympy.Integer(x) if x.is_finite and x not in (int_oo, -int_oo) else x

        if 0 in y:
            return ValueRanges.unknown_int()
        elif y.is_singleton():
            y_val = abs(y.lower)
            # If it wraps, we need to take the whole interval

            # The function is locally linear if they are in the same class
            if c_div(x.lower, y_val) == c_div(x.upper, y_val):
                return ValueRanges.increasing_map(x, lambda u: c_mod(u, y_val))
            if x.upper < 0:
                # Negative case
                return ValueRanges(-y_val + 1, 0)
            elif x.lower > 0:
                # Positive case
                return ValueRanges(0, y_val - 1)
            else:
                # Mixed case
                lower = max(-y_val + 1, x.lower)
                upper = min(y_val - 1, x.upper)
                return ValueRanges(lower, upper)
        else:
            # Too difficult, we bail out
            upper = cls.abs(y).upper - 1
            return ValueRanges(-upper, upper)

    @classmethod
    def modular_indexing(cls, a, b, c):
        return cls.mod(cls.floordiv(a, b), c)

    @classmethod
    def is_non_overlapping_and_dense_indicator(cls, *args):
        return ValueRanges.unknown_int()

    @classmethod
    def pow_by_natural(cls, a, b):
        """Range of ``a ** b`` where ``b`` is a natural-number exponent."""
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        if a.is_singleton() and b.is_singleton():
            return ValueRanges.wrap(safe_pow(a.lower, b.lower))
        # NB: Exclude zero, because zero is special
        elif a.lower >= 1:
            # We should know that b >= 0 but we may have forgotten this fact due
            # to replacements, so don't assert it, but DO clamp it to prevent
            # degenerate problems
            return ValueRanges.coordinatewise_increasing_map(
                a, b & ValueRanges(0, int_oo), PowByNatural
            )
        elif b.is_singleton():
            if b.lower % 2 == 0:
                # x^n where n is even
                return ValueRanges.convex_min_zero_map(
                    a, lambda x: safe_pow(x, b.lower)
                )
            else:
                # x^n where n is odd
                return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower))
        else:
            # a is potentially negative, and we don't know if the exponent is
            # even or odd.  So just conservatively set the upper and lower
            # bound based on what the maximum absolute value could be, in both
            # directions
            max_base = max(a.upper, -a.lower)
            return ValueRanges(
                -(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper)
            )

    @classmethod
    def pow(cls, a, b):
        return ValueRanges.unknown()

        # We could implement all this, but for floating point pow, is there
        # really a point?
        """
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)

        # Not implemented yet. It's a bit tricky
        # If you want to implement it, compute the partial derivatives of a ** b
        # and check the ranges where the function is increasing / decreasing
        # Another non-tight way of doing this is noting that for a > 0, a ** b == exp(b * log(a))
        # If this second option is implemented, be careful about the types and possible infinities here and there.
        if not b.is_singleton():
            return ValueRanges.unknown()

        b = b.lower
        if a.is_singleton():
            a = a.lower
            r = a**b
            if not r.is_finite:
                return ValueRanges.unknown()
            return ValueRanges.wrap(r)

        if b == 0:
            if not a.lower.is_finite:
                return ValueRanges.unknown()
            return ValueRanges.wrap(1.0)

        if b < 0:
            a = cls.reciprocal(a)
            b = -b

        if a == ValueRanges.unknown():
            return ValueRanges.unknown()

        # If the base is positive, then we're good, otherwise nothing's defined
        if a.lower >= 0:
            return ValueRanges.increasing_map(a, lambda x: x**b)
        else:
            return ValueRanges.unknown()
        """

    @staticmethod
    def reciprocal(x):
        """Needed as it's used in pow, but it won't appear on a SymPy expression"""
        x = ValueRanges.wrap(x)
        if 0 in x:
            return ValueRanges.unknown()
        else:
            return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y))  # type: ignore[operator]

    @staticmethod
    def abs(x):
        return ValueRanges.convex_min_zero_map(x, abs)

    @staticmethod
    def exp(x):
        return ValueRanges.increasing_map(x, OpaqueUnaryFn_exp)

    @staticmethod
    def log(x):
        x = ValueRanges.wrap(x)
        if x.lower <= 0:
            return ValueRanges.unknown()
        return ValueRanges.increasing_map(x, OpaqueUnaryFn_log)

    @classmethod
    def minimum(cls, a, b):
        return cls.min_or_max(a, b, sympy.Min)

    @classmethod
    def maximum(cls, a, b):
        return cls.min_or_max(a, b, sympy.Max)

    @staticmethod
    def min_or_max(a, b, fn):
        """Apply ``fn`` (sympy.Min or sympy.Max) bound-wise; both are coordinatewise increasing."""
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        return ValueRanges.coordinatewise_increasing_map(a, b, fn)

    @classmethod
    def floor_to_int(cls, x, dtype):
        return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor)

    @classmethod
    def ceil_to_int(cls, x, dtype):
        return ValueRanges.increasing_map(
            x, sympy.functions.elementary.integers.ceiling
        )

    # I think these implementations are sound.  The hazard here is that sympy
    # will carry out the floor/ceil at too high precision and then something
    # bad will happen when we convert it to float.
    #
    # For truncation, the implementation is clearly sound, because the desired
    # target float is always exactly representable, since you're just chopping
    # off bits of the mantissa.  But what about ceil/floor?
    #
    # The important constraint here is that we're not defining floor on
    # arbitrary real numbers, only representable float numbers.  So we can
    # take advantage of the fact that before we reach the first
    # unrepresentable integer in floating point space, we have the range of
    # numbers corresponding to exponent zero: all integers, with no fractional
    # amounts.  floor/ceil is an identity operation in this case.  In the
    # range below here, representable floating point numbers are spaced
    # exactly 1/2 apart, and notably, both the floor/ceil are defined floating
    # point numbers.  There is no "gap" as you step up to the next exponent.

    @classmethod
    def floor(cls, x):
        return ValueRanges.increasing_map(
            x, _keep_float(sympy.functions.elementary.integers.floor)
        )

    @classmethod
    def ceil(cls, x):
        return ValueRanges.increasing_map(
            x, _keep_float(sympy.functions.elementary.integers.ceiling)
        )

    @classmethod
    def round_decimal(cls, number, ndigits):
        if not ndigits.is_singleton():
            return ValueRanges.unknown()

        ndigits = ndigits.lower
        # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind
        # the second parameter.
        fn = lambda number: RoundDecimal(number, ndigits)  # type: ignore[misc, assignment]  # noqa: E731

        return ValueRanges.increasing_map(number, fn)

    @classmethod
    def round_to_int(cls, number, dtype):
        return ValueRanges.increasing_map(number, RoundToInt)

    # It's used in some models on symints
    @staticmethod
    def sqrt(x):
        x = ValueRanges.wrap(x)
        if x.lower < 0:
            return ValueRanges.unknown()
        return ValueRanges.increasing_map(x, OpaqueUnaryFn_sqrt)

    @staticmethod
    def where(a, b, c):
        """Range of ``b if a else c``: the union of the ranges of ``b`` and ``c``."""
        b = ValueRanges.wrap(b)
        c = ValueRanges.wrap(c)
        a = a.boolify()
        # We sometimes write unknown without specifying the type correctly
        # In particular, we do that when initialising the bounds for loads in bounds.py
        assert b.is_bool == c.is_bool or ValueRanges.unknown() in (b, c)
        if b.is_bool:
            return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
        else:
            return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper))

    # expr_cond_pair is used to represent a single (expr, condition) pair in piecewise.
    # We just return the value range of the expression and its corresponding condition as a tuple
    # and defer the analysis to piecewise
    @staticmethod
    def expr_cond_pair(a, b):
        b = b.boolify()
        return (a, b)

    # piecewise function can be used to convert a SymBool to SymInt:
    # int_expr = Piecewise((1, bool_expr), (0, True)), it evaluates to 1 when sym_bool is True and 0 otherwise.
    #
    # ranges is a sequence of (expr_range, condition_range) pairs. The range pair is constructed in expr_cond_pair.
    # The ValueRange of Piecewise is just the union of all expr ranges whose condition expr can be True.
    @staticmethod
    def piecewise(*ranges):
        init_range = None
        for expr_range, cond_range in ranges:
            if sympy.true in cond_range:
                if init_range is None:
                    init_range = expr_range
                else:
                    init_range = init_range | expr_range
        # NOTE(review): returns None when no condition can be True.
        return init_range

    @staticmethod
    def cos(x):
        # TODO: We should tighten value ranges
        # If input range span is pi + 2*pi*k, then output range is (-1, 1)
        # otherwise the minimum of the value of the function on the extremes
        return ValueRanges(-1.0, 1.0)

    @staticmethod
    def cosh(x):
        return ValueRanges(0.0, sympy.oo)
        """
        x = ValueRanges.wrap(x)
        if x.lower > 0:
            return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh)
        elif x.upper < 0:
            return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh)
        return ValueRanges(0.0, sympy.oo)
        """

    @staticmethod
    def sin(x):
        # TODO: We should tighten value ranges
        # See details on cos
        return ValueRanges(-1.0, 1.0)

    @staticmethod
    def sinh(x):
        # return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh)
        return ValueRanges(-sympy.oo, sympy.oo)

    @staticmethod
    def tan(x):
        return ValueRanges(-sympy.oo, sympy.oo)

    @staticmethod
    def tanh(x):
        # return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh)
        return ValueRanges(-sympy.oo, sympy.oo)

    @staticmethod
    def asin(x):
        return ValueRanges(-sympy.oo, sympy.oo)
        """
        x = ValueRanges.wrap(x)
        if -1 <= x.lower and x.upper <= 1:
            return ValueRanges.increasing_map(x, OpaqueUnaryFn_asin)
        return ValueRanges.unknown()
        """

    @staticmethod
    def acos(x):
        return ValueRanges(-sympy.oo, sympy.oo)
        """
        x = ValueRanges.wrap(x)
        if -1 <= x.lower and x.upper <= 1:
            return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos)
        return ValueRanges.unknown()
        """

    @staticmethod
    def atan(x):
        return ValueRanges(-sympy.oo, sympy.oo)
        # return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan)

    @staticmethod
    def trunc(x):
        return ValueRanges.increasing_map(x, TruncToFloat)


class ValueRangeAnalysis(SymPyValueRangeAnalysis):
    """SymPyValueRangeAnalysis extended with handlers for non-SymPy ops.

    Boolean ops (xor/logical_*) yield ``[false, true]``. load/reduction
    yield ``unknown()``, and store does nothing. Any other missing op is
    routed by ``__getattr__`` to ``default_handler``, which returns
    ``unknown()``.
    """

    def __init__(self) -> None:
        self.name = "ValueRangeAnalysis"
        boolean_operators = (
            "xor",
            "logical_and",
            "logical_or",
            "logical_not",
        )
        for op in boolean_operators:
            setattr(self, op, self.bool_handler)

    @staticmethod
    def bool_handler(*args, **kwargs):
        # just assuming bools can have both values
        return ValueRanges(sympy.false, sympy.true)  # type: ignore[arg-type]

    @staticmethod
    def default_handler(*args, **kwargs):
        # many ops are unlikely to show up in optimizable indexing compute,
        # so we don't have full coverage
        return ValueRanges.unknown()

    def load(self, name: str, index: sympy.Expr):
        return ValueRanges.unknown()

    def store(self, name, index, value, mode=None):
        return

    def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
        return ValueRanges.unknown()

    @classmethod
    def index_expr(cls, index, dtype):
        assert isinstance(index, ValueRanges)
        return cls.to_dtype(index, dtype)

    @staticmethod
    def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None):
        """Range of ``x`` after a cast to ``dtype`` (bool, an int dtype or a float dtype)."""
        x = ValueRanges.wrap(x)

        if dtype == torch.bool:
            if x.is_singleton():
                return ValueRanges.wrap(x.lower != 0)
            elif x.is_bool:
                return x
            elif 0 not in x:
                return ValueRanges.wrap(sympy.true)
            else:
                return ValueRanges(sympy.false, sympy.true)

        def cast(x, dtype):
            # dtype is int or float
            if dtype.is_floating_point:
                return sympy.Float(x)
            else:
                if x in (int_oo, -int_oo):
                    return x
                try:
                    return sympy.Integer(x)
                except TypeError:
                    # inf cannot be cast to Integer
                    return x

        if x.is_bool:
            if x.is_singleton():
                val = 1 if x.lower else 0
                return ValueRanges.wrap(cast(val, dtype))
            else:
                return ValueRanges(cast(0, dtype), cast(1, dtype))
        else:
            # int to float or float to int
            return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype))

    @staticmethod
    def square(x):
        return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2))

    @staticmethod
    def neg(x):
        return ValueRanges.decreasing_map(x, operator.neg)

    # TODO: this is slightly inaccurate because truncdiv operates at integer
    # precision, but we're going through float truediv which means we can
    # potentially lose precision on the bounds
    @classmethod
    def truncdiv(cls, a, b):
        x = cls.truediv(a, b)
        if x == ValueRanges.unknown():
            return x

        return cls.trunc(x)

    @classmethod
    def sub(cls, a, b):
        return cls.add(a, cls.neg(b))

    def __getattr__(self, name):
        log.debug("unhandled ValueRange op %s", name)
        return self.default_handler


def bound_sympy(
    expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None
) -> ValueRanges:
    """Compute a ValueRanges bound for ``expr`` given ranges for its free symbols.

    Ranges come from, in increasing priority:
    - the current TracingContext's shape env (if any);
    - the caller-supplied ``ranges``.

    Free symbols still missing a range get one from their SymPy assumptions:
    - positive integer -> [1, int_oo];
    - nonnegative integer -> [0, int_oo];
    - other integer -> unknown_int();
    - non-integer -> unknown().
    """
    log.debug(
        "bound_sympy(%s)%s",
        expr,
        LazyString(
            lambda: "\n"
            + "\n".join(
                f" {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols
            )
            if ranges
            else ""
        ),
    )
    if isinstance(expr, sympy.Number):
        return ValueRanges.wrap(expr)

    ranges = ranges or {}

    # If there's a tracing context, augment available constrained ranges.
    context = torch._guards.TracingContext.try_get()
    if context and context.fake_mode.shape_env:
        ranges = {**context.fake_mode.shape_env.var_to_range, **ranges}

    unbounded_vars = expr.free_symbols - ranges.keys()
    if unbounded_vars:
        # Give some bounds to the free variables via their SymPy assumptions
        # TODO A better way of doing this would be to assign them a range upon creation, as
        #      size variables can come with a lower bound of 2, as we specialize on 0 and 1
        unbounded_ranges: Dict[sympy.Symbol, ValueRanges] = {}
        for s in unbounded_vars:
            if s.is_integer:  # type: ignore[attr-defined]
                if s.is_positive:  # type: ignore[attr-defined]
                    vr = ValueRanges(1, int_oo)
                elif s.is_nonnegative:  # type: ignore[attr-defined]
                    vr = ValueRanges(0, int_oo)
                else:
                    vr = ValueRanges.unknown_int()
            else:
                # Don't bother trying very hard here
                vr = ValueRanges.unknown()
            unbounded_ranges[s] = vr  # type: ignore[index]
        ranges = {**ranges, **unbounded_ranges}

    return sympy_interp(SymPyValueRangeAnalysis, ranges, expr)