xref: /aosp_15_r20/external/pytorch/torch/utils/_sympy/value_ranges.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import dataclasses
5import itertools
6import logging
7import math
8import operator
9from typing import (
10    Callable,
11    Dict,
12    Generic,
13    Optional,
14    overload,
15    SupportsFloat,
16    TYPE_CHECKING,
17    TypeVar,
18    Union,
19)
20from typing_extensions import TypeGuard
21
22import sympy
23from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
24
25import torch
26from torch._logging import LazyString
27from torch._prims_common import dtype_to_type
28
29from .functions import (
30    _keep_float,
31    FloatTrueDiv,
32    FloorDiv,
33    IntTrueDiv,
34    OpaqueUnaryFn_exp,
35    OpaqueUnaryFn_log,
36    OpaqueUnaryFn_sqrt,
37    PowByNatural,
38    RoundDecimal,
39    RoundToInt,
40    safe_pow,
41    ToFloat,
42    TruncToFloat,
43    TruncToInt,
44)
45from .interp import sympy_interp
46from .numbers import int_oo, IntInfinity, NegativeIntInfinity
47
48
# Module-level logger.
log = logging.getLogger(__name__)

# Public names of this module.  NOTE(review): ValueRangeAnalysis and
# bound_sympy are not defined in the portion of the file reviewed here;
# presumably they are defined further down -- confirm.
__all__ = ["ValueRanges", "ValueRangeAnalysis", "bound_sympy"]

# Constrained type variable for the bound type of a ValueRanges: either a
# sympy expression (numeric ranges) or a sympy boolean (boolean ranges).
_T = TypeVar("_T", sympy.Expr, SympyBoolean)
54
55
class ValueRangeError(RuntimeError):
    """Raised when a ValueRanges is built whose lower bound exceeds its upper bound."""
58
59
# Like sympify, but supports less stuff, and also ensures that direct
# sympy expressions don't have free variables
def simple_sympify(e):
    """Convert a scalar into a closed (free-variable-free) sympy atom.

    Conversions:
      * ``bool`` -> ``sympy.true`` / ``sympy.false``.  This is checked before
        ``int`` because ``bool`` is a subclass of ``int``.
      * ``int`` -> ``sympy.Integer``.
      * ``float`` -> ``sympy.Float``, except +/-inf, which become
        ``sympy.oo`` / ``-sympy.oo``.
      * A numeric ``sympy.Expr`` (``is_number`` and not NaN) is returned
        unchanged.
      * A sympy ``BooleanAtom`` is returned unchanged.

    Raises:
        AssertionError: for unsupported types, non-numeric sympy expressions,
            or sympy NaN.
    """
    if isinstance(e, bool):
        return sympy.true if e else sympy.false
    if isinstance(e, int):
        return sympy.Integer(e)
    if isinstance(e, float):
        if not math.isinf(e):
            return sympy.Float(e)
        # infinity is special; we use it to bracket integers as well
        return sympy.oo if e > 0 else -sympy.oo
    if isinstance(e, sympy.Expr):
        assert e.is_number, e
        # NaNs can arise from e.g. 0 * sympy.oo.  The operator that produced
        # one should handle it itself, because NaN is sometimes the wrong
        # answer: for ints, [-oo, oo] multiplied by [0, 0] should be zero.
        assert e != sympy.nan
        return e
    if isinstance(e, BooleanAtom):
        return e
    raise AssertionError(f"not simple sympy type {type(e)}: {e}")
84
85
# Sympy atomics only. Unlike <=, it also works on Sympy bools.
def sympy_generic_le(lower, upper):
    """Return whether ``lower <= upper`` for sympy numeric atoms or sympy booleans.

    Numeric atoms use sympy's own ``<=``, and the result is whatever that
    relational evaluates to.  Booleans are ordered false < true, and a
    Python ``bool`` is returned.
    """
    if not isinstance(lower, sympy.Expr):
        # only negative condition is True > False
        assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean), (
            lower,
            upper,
        )
        # False is <= anything; True is <= only True.
        return (not lower) or bool(upper)
    assert isinstance(upper, sympy.Expr)
    return lower <= upper
98
99
def vr_is_bool(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[SympyBoolean]]:
    """Type guard: narrow ``vr`` to a boolean ValueRanges when its ``is_bool`` flag is set."""
    flag = vr.is_bool
    return flag
102
103
def vr_is_expr(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[sympy.Expr]]:
    """Type guard: narrow ``vr`` to a numeric (sympy.Expr) ValueRanges when it is not boolean."""
    return False if vr.is_bool else True
106
107
# Raw inputs accepted as bounds of a numeric range; they are normalised by
# simple_sympify.
ExprIn = Union[int, float, sympy.Expr]
# Raw inputs accepted as bounds of a boolean range.
BoolIn = Union[bool, SympyBoolean]
AllIn = Union[ExprIn, BoolIn]
# Unary / binary endpoint functions taken by the ValueRanges *_map helpers.
ExprFn = Callable[[sympy.Expr], sympy.Expr]
ExprFn2 = Callable[[sympy.Expr, sympy.Expr], sympy.Expr]
BoolFn = Callable[[SympyBoolean], SympyBoolean]
BoolFn2 = Callable[[SympyBoolean, SympyBoolean], SympyBoolean]
AllFn = Union[ExprFn, BoolFn]
AllFn2 = Union[ExprFn2, BoolFn2]
117
118
@dataclasses.dataclass(frozen=True)
class ValueRanges(Generic[_T]):
    """Closed interval ``[lower, upper]`` whose bounds are constant sympy atoms.

    Bounds:
      * Numeric ranges have sympy number bounds; ``sympy.oo`` / ``int_oo``
        and their negations mark unbounded ends.
      * Boolean ranges have ``sympy.false`` / ``sympy.true`` bounds, ordered
        false < true.

    The class is a frozen dataclass, so the custom ``__init__`` stores its
    fields through ``object.__setattr__``.  Exactly one of ``is_bool`` /
    ``is_int`` / ``is_float`` is set, derived from the bound types.
    """

    if TYPE_CHECKING:
        # ruff doesn't understand circular references but mypy does
        ExprVR = ValueRanges[sympy.Expr]  # noqa: F821
        BoolVR = ValueRanges[SympyBoolean]  # noqa: F821
        AllVR = Union[ExprVR, BoolVR]

    # Although the type signature here suggests you can pass any
    # sympy expression, in practice the analysis here only works
    # with constant sympy expressions
    lower: _T
    upper: _T
    is_bool: bool
    is_int: bool
    is_float: bool

    def __repr__(self) -> str:
        return f"VR[{self.lower}, {self.upper}]"

    @overload
    def __init__(
        self: ValueRanges[sympy.Expr],
        lower: ExprIn,
        upper: ExprIn,
    ) -> None:
        ...

    @overload
    def __init__(  # type: ignore[misc]
        self: ValueRanges[SympyBoolean],
        lower: BoolIn,
        upper: BoolIn,
    ) -> None:
        ...

    def __init__(self, lower: AllIn, upper: AllIn) -> None:
        """Build the range ``[lower, upper]``.

        Both bounds are normalised with ``simple_sympify``.

        Raises:
            ValueRangeError: if ``lower <= upper`` does not hold.
            TypeError: if sympy cannot compare the two bounds.
        """
        lower = simple_sympify(lower)
        upper = simple_sympify(upper)
        # TODO: when the bounds have free variables, this may be
        # nontrivial to actually verify
        try:
            if not sympy_generic_le(lower, upper):
                raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]")
        except TypeError as e:
            raise TypeError(f"Could not compare {lower} <= {upper}") from e

        # Both bounds must be the same kind: both boolean or both numeric.
        is_bool_lower = isinstance(lower, SympyBoolean)
        is_bool_upper = isinstance(upper, SympyBoolean)
        assert is_bool_lower == is_bool_upper, (lower, upper)

        # Warning: is_int/is_float is best effort.  We do pretty well in
        # Dynamo, but in Inductor these attributes are often wrong because we
        # are not very rigorous in dtype analysis.  This is also why we need
        # the flexible analysis for is_int: sometimes a sympy.oo pops in for
        # an integer bound. I would /like/ for us not to do this, but it's
        # too hard to push the invariant through right now.
        # An integer bound paired with a plain +/-oo is promoted to +/-int_oo
        # so that the range is classified as integral below.
        if isinstance(lower, sympy.Integer) and upper == sympy.oo:
            upper = int_oo
        if isinstance(upper, sympy.Integer) and lower == -sympy.oo:
            lower = -int_oo
        # NB: [-int_oo, -int_oo] and [int_oo, int_oo] are allowed
        integer_types = (sympy.Integer, NegativeIntInfinity, IntInfinity)
        is_int_lower = isinstance(lower, integer_types)
        is_int_upper = isinstance(upper, integer_types)

        # Because this is a frozen class
        object.__setattr__(self, "lower", lower)
        object.__setattr__(self, "upper", upper)
        # Unlike bool/int in Python, we don't report bools are ints
        #
        # NB: is_bool_lower == is_bool_upper, so we only need to check one
        object.__setattr__(self, "is_bool", is_bool_lower)
        object.__setattr__(
            self,
            "is_int",
            not self.is_bool and is_int_lower and is_int_upper,
        )
        """
        # This assert is just impossible right now, too many sympy bugs
        if self.is_int:
            # NB: sympy will sometimes randomly lose the float-ness of zero,
            # so we also need to account for that in the assertion here.
            # See also https://github.com/sympy/sympy/issues/26620
            assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], (
                lower,
                upper,
            )
            assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper)
        """
        # NB: [-oo, oo] always advertises as float!
        object.__setattr__(self, "is_float", not self.is_bool and not self.is_int)
        assert self.is_bool or self.is_int or self.is_float, (lower, upper)

    def boolify(self) -> ValueRanges[SympyBoolean]:
        """Return this range viewed as a boolean range.

        A boolean range is returned unchanged, and ``unknown()`` becomes
        ``unknown_bool()``.  Any other numeric range raises AssertionError.
        """
        if vr_is_bool(self):
            return self
        elif self == ValueRanges.unknown():
            return ValueRanges.unknown_bool()
        else:
            raise AssertionError(f"not bool like {self}")

    def __contains__(self, x: AllIn) -> bool:
        """Whether ``x``, wrapped as a singleton range, lies within this range."""
        return ValueRanges.wrap(x).issubset(self)

    def issubset(self, other):
        """Whether ``self`` is contained in ``other``, comparing bounds with sympy_generic_le."""
        return sympy_generic_le(other.lower, self.lower) and sympy_generic_le(
            self.upper, other.upper
        )

    def tighten(self, other) -> ValueRanges:
        """Given two ValueRanges, returns their intersection"""
        return self & other

    # Intersection
    @overload
    def __and__(
        self: ValueRanges[sympy.Expr],
        other: ValueRanges[sympy.Expr],
    ) -> ValueRanges[sympy.Expr]:
        ...

    @overload
    def __and__(  # type: ignore[misc]
        self: ValueRanges[SympyBoolean],
        other: ValueRanges[SympyBoolean],
    ) -> ValueRanges[SympyBoolean]:
        ...

    def __and__(self: AllVR, other: AllVR) -> AllVR:
        """Intersection of two ranges.

        ``unknown()`` acts as the identity.  Otherwise both ranges must be of
        the same kind (bool / int / float), or an AssertionError is raised.
        The result takes the larger lower bound and the smaller upper bound;
        for booleans these are Or / And.  Raises ValueRangeError if the
        intersection is empty.
        """
        if other == ValueRanges.unknown():
            return self
        if self == ValueRanges.unknown():
            return other
        assert self.is_bool == other.is_bool, (self, other)
        assert self.is_int == other.is_int, (self, other)
        assert self.is_float == other.is_float, (self, other)
        if self.is_bool:
            return ValueRanges(
                sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper)
            )
        else:
            return ValueRanges(
                sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper)
            )

    # Union
    @overload
    def __or__(
        self: ValueRanges[sympy.Expr],
        other: ValueRanges[sympy.Expr],
    ) -> ValueRanges[sympy.Expr]:
        ...

    @overload
    def __or__(  # type: ignore[misc]
        self: ValueRanges[SympyBoolean],
        other: ValueRanges[SympyBoolean],
    ) -> ValueRanges[SympyBoolean]:
        ...

    def __or__(self: AllVR, other: AllVR) -> AllVR:
        """Union (convex hull) of two ranges.

        If either side is ``unknown()``, the result is ``unknown()``.
        Otherwise both ranges must be of the same kind.  The result takes the
        smaller lower bound and the larger upper bound; for booleans these
        are And / Or.
        """
        if ValueRanges.unknown() in (self, other):
            return ValueRanges.unknown()
        assert self.is_bool == other.is_bool, (self, other)
        assert self.is_int == other.is_int, (self, other)
        assert self.is_float == other.is_float, (self, other)
        if self.is_bool:
            return ValueRanges(
                sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper)
            )
        else:
            return ValueRanges(
                sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper)
            )

    def is_singleton(self) -> bool:
        """Whether the range holds exactly one value (lower == upper)."""
        return self.lower == self.upper

    @staticmethod
    def unknown() -> ValueRanges[sympy.Expr]:
        """The unconstrained range [-oo, oo]; classified as float (see __init__)."""
        return ValueRanges(-sympy.oo, sympy.oo)

    @staticmethod
    def unknown_int() -> ValueRanges[sympy.Expr]:
        """The unconstrained integer range [-int_oo, int_oo]."""
        return ValueRanges(-int_oo, int_oo)

    @staticmethod
    def unknown_bool() -> ValueRanges[SympyBoolean]:
        """The unconstrained boolean range [false, true]."""
        return ValueRanges(sympy.false, sympy.true)

    @overload
    @staticmethod
    # work around the fact that bool and int overlap
    def wrap(arg: Union[ExprIn, ExprVR]) -> ExprVR:  # type: ignore[overload-overlap]
        ...

    @overload
    @staticmethod
    def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR:  # type: ignore[misc]
        ...

    @staticmethod
    def wrap(arg: Union[AllIn, AllVR]) -> AllVR:
        """Coerce ``arg`` into a ValueRanges.

        An existing range is returned as-is.  A Python float NaN becomes
        ``unknown()``.  Any other value becomes the singleton ``[arg, arg]``.
        """
        if isinstance(arg, ValueRanges):
            return arg
        if isinstance(arg, float) and math.isnan(arg):
            return ValueRanges.unknown()
        # arg is either ExprIn or BoolIn, but we don't know it here
        return ValueRanges(arg, arg)  # type: ignore[arg-type]

    @staticmethod
    def increasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
        """Increasing: x <= y => f(x) <= f(y).

        Returns ``[fn(lower), fn(upper)]``.
        """
        x = ValueRanges.wrap(x)
        return ValueRanges(fn(x.lower), fn(x.upper))

    @overload
    @staticmethod
    def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
        ...

    @overload
    @staticmethod
    def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR:  # type: ignore[misc]
        ...

    @staticmethod
    def decreasing_map(x: Union[AllIn, AllVR], fn: AllFn) -> AllVR:
        """Decreasing: x <= y => f(x) >= f(y).

        Returns ``[fn(upper), fn(lower)]``.
        """
        x = ValueRanges.wrap(x)
        # consistently either Expr or Bool, but we don't know it here
        return ValueRanges(fn(x.upper), fn(x.lower))  # type: ignore[arg-type]

    @staticmethod
    def monotone_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
        """It's increasing or decreasing.

        Evaluates ``fn`` at both bounds and orders the two results.
        """
        x = ValueRanges.wrap(x)
        l = fn(x.lower)
        u = fn(x.upper)
        return ValueRanges(min(l, u), max(l, u))

    @staticmethod
    def convex_min_zero_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
        """Fn is convex and has a minimum at 0.

        If 0 is in the range, the result is ``[0, max(fn(lower), fn(upper))]``.
        The zero is written as 0.0 when that upper bound is a Float or
        ``sympy.oo``, so the range stays float.  Otherwise ``fn`` is monotone
        on the range and ``monotone_map`` is used.
        """
        x = ValueRanges.wrap(x)
        if 0 in x:
            upper = max(fn(x.lower), fn(x.upper))
            upper = simple_sympify(upper)
            if isinstance(upper, sympy.Float) or upper == sympy.oo:
                return ValueRanges(0.0, upper)
            return ValueRanges(0, upper)
        return ValueRanges.monotone_map(x, fn)

    @overload
    @staticmethod
    def coordinatewise_increasing_map(
        x: Union[ExprIn, ExprVR],
        y: Union[ExprIn, ExprVR],
        fn: ExprFn2,
    ) -> ExprVR:
        ...

    @overload
    @staticmethod
    def coordinatewise_increasing_map(  # type: ignore[misc]
        x: Union[BoolIn, BoolVR],
        y: Union[BoolIn, BoolVR],
        fn: BoolFn2,
    ) -> BoolVR:
        ...

    @staticmethod
    def coordinatewise_increasing_map(
        x: Union[AllIn, AllVR],
        y: Union[AllIn, AllVR],
        fn: AllFn2,
    ) -> AllVR:
        """
        It's increasing on each coordinate.

        Mathematically:
        For every 1 <= i <= n and x_i <= y_i we have that
        f(x_1, ..., x_i, ..., x_n) <= f(x_1, ..., y_i, ..., x_n)

        Hence the result is [fn(x.lower, y.lower), fn(x.upper, y.upper)].
        """
        x, y = ValueRanges.wrap(x), ValueRanges.wrap(y)
        return ValueRanges(
            fn(x.lower, y.lower),  # type: ignore[arg-type]
            fn(x.upper, y.upper),  # type: ignore[arg-type]
        )

    @classmethod
    def coordinatewise_monotone_map(cls, x, y, fn):
        """It's increasing or decreasing on each coordinate.

        The extremes are attained at the corners, so the result spans the
        min/max of ``fn`` over the four endpoint pairs.
        """
        x, y = cls.wrap(x), cls.wrap(y)
        products = [
            fn(a, b)
            for a, b in itertools.product([x.lower, x.upper], [y.lower, y.upper])
        ]
        return ValueRanges(min(products), max(products))
419
420
421class SymPyValueRangeAnalysis:
422    """
423    It gives bounds on a SymPy operator given bounds on its arguments
424    See the function `bound_sympy` for a function that applies this logic to a full SymPy expression
425    """
426
    @staticmethod
    def constant(value, dtype):
        """Return the singleton range of the constant ``value`` of ``dtype``.

        Args:
            value: a Python scalar, a sympy number/boolean atom, or a
                singleton ValueRanges (its bound is used).
            dtype: the torch dtype of the constant.

        Python scalars are cast with ``dtype_to_type(dtype)``.  Sympy values
        are only sanity-checked against ``dtype`` (via asserts), not cast.
        A NaN constant yields the unknown range matching ``dtype`` instead of
        a NaN bound.
        """
        if isinstance(value, ValueRanges):
            assert value.is_singleton()
            value = value.lower
        # NB: value is NOT a sympy expression, it's a constant!
        is_python = isinstance(value, (int, float, bool))
        assert is_python or isinstance(
            value, (BooleanAtom, sympy.Integer, sympy.Number)
        )

        # using nan makes subsequent computation throw, and for the purposes of optimization
        # returning -math.inf - math.inf is equivalent to giving up
        if isinstance(value, SupportsFloat) and math.isnan(value):
            if dtype == torch.bool:
                return ValueRanges.unknown_bool()
            elif dtype.is_floating_point:
                return ValueRanges.unknown()
            else:
                return ValueRanges(-int_oo, int_oo)

        if is_python:
            type_ = dtype_to_type(dtype)
            value = type_(value)
        else:
            # We do a type check on a best-effort basis
            # We don't want to force a cast to sympy.Float if the value is Rational to avoid losing precision
            if dtype == torch.bool:
                assert isinstance(value, BooleanAtom)
            elif dtype.is_floating_point:
                assert not value.is_finite or value.is_real
            else:
                # dtype is intXX
                assert value.is_integer

        r = ValueRanges.wrap(value)
        return r
464
465    @staticmethod
466    def to_dtype(a, dtype, src_dtype=None):
467        if dtype == torch.float64:
468            return ValueRanges.increasing_map(a, ToFloat)
469        elif dtype == torch.bool:
470            return ValueRanges.unknown_bool()
471        elif not dtype.is_floating_point:
472            return ValueRanges.unknown_int()
473        return ValueRanges.unknown()
474
475    @staticmethod
476    def trunc_to_int(a, dtype):
477        return ValueRanges.increasing_map(a, TruncToInt)
478
479    @staticmethod
480    def not_(a):
481        a = ValueRanges.wrap(a)
482        a = a.boolify()
483        assert a.is_bool
484        return ValueRanges.decreasing_map(a, sympy.Not)
485
486    @staticmethod
487    def or_(a, b):
488        return ValueRanges.coordinatewise_increasing_map(a, b, sympy.Or)
489
490    @staticmethod
491    def and_(a, b):
492        return ValueRanges.coordinatewise_increasing_map(a, b, sympy.And)
493
494    @staticmethod
495    def eq(a, b):
496        a = ValueRanges.wrap(a)
497        b = ValueRanges.wrap(b)
498        if a.is_singleton() and b.is_singleton() and a.lower == b.lower:
499            return ValueRanges.wrap(sympy.true)
500        elif a.lower > b.upper or b.lower > a.upper:  # ranges disjoint
501            return ValueRanges.wrap(sympy.false)
502        return ValueRanges(sympy.false, sympy.true)
503
504    @classmethod
505    def ne(cls, a, b):
506        return cls.not_(cls.eq(a, b))
507
508    @classmethod
509    def identity(cls, a):
510        return ValueRanges.wrap(a)
511
512    @classmethod
513    def lt(cls, a, b):
514        a = ValueRanges.wrap(a)
515        b = ValueRanges.wrap(b)
516        assert a.is_bool == b.is_bool
517        if a.is_bool:
518            return cls.and_(cls.not_(a), b)
519        else:
520            if a.upper < b.lower:
521                return ValueRanges.wrap(sympy.true)
522            elif a.lower >= b.upper:
523                return ValueRanges.wrap(sympy.false)
524            return ValueRanges(sympy.false, sympy.true)
525
526    @classmethod
527    def gt(cls, a, b):
528        return cls.lt(b, a)
529
530    @classmethod
531    def le(cls, a, b):
532        return cls.not_(cls.gt(a, b))
533
534    @classmethod
535    def ge(cls, a, b):
536        return cls.not_(cls.lt(a, b))
537
538    @staticmethod
539    def add(a, b):
540        return ValueRanges.coordinatewise_increasing_map(
541            a, b, _keep_float(operator.add)
542        )
543
544    @classmethod
545    def mul(cls, a, b):
546        a = ValueRanges.wrap(a)
547        b = ValueRanges.wrap(b)
548
549        assert a.is_bool == b.is_bool
550        if a.is_bool:
551            return cls.and_(a, b)
552
553        def safe_mul(a, b):
554            # Make unknown() * wrap(0.0) == wrap(0.0)
555            if a == 0.0 or a == 0:
556                return a
557            elif b == 0.0 or b == 0:
558                return b
559            else:
560                return a * b
561
562        return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul))
563
564    @staticmethod
565    def int_truediv(a, b):
566        a = ValueRanges.wrap(a)
567        b = ValueRanges.wrap(b)
568        if 0 in b or ((-int_oo in a or int_oo in a) and (-int_oo in b or int_oo in b)):
569            return ValueRanges.unknown()
570        else:
571            return ValueRanges.coordinatewise_monotone_map(
572                a, b, _keep_float(IntTrueDiv)
573            )
574
575    @staticmethod
576    def truediv(a, b):
577        a = ValueRanges.wrap(a)
578        b = ValueRanges.wrap(b)
579        if 0 in b or (
580            (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b)
581        ):
582            return ValueRanges.unknown()
583        else:
584            return ValueRanges.coordinatewise_monotone_map(
585                a, b, _keep_float(FloatTrueDiv)
586            )
587
588    @staticmethod
589    def floordiv(a, b):
590        a = ValueRanges.wrap(a)
591        b = ValueRanges.wrap(b)
592        if 0 in b:
593            return ValueRanges.unknown_int()
594        products = []
595        for x, y in itertools.product([a.lower, a.upper], [b.lower, b.upper]):
596            r = FloorDiv(x, y)
597            if r is sympy.nan:
598                products.append((sympy.sign(x) * sympy.sign(y)) * int_oo)
599            else:
600                products.append(r)
601
602        return ValueRanges(min(products), max(products))
603
    @classmethod
    def mod(cls, x, y):
        """Range of ``x % y`` with C semantics (the result takes the sign of ``x``).

        Cases:
          * ``unknown_int()`` if ``y`` may be zero.
          * Singleton ``y``, both ends of ``x`` in the same truncated-quotient
            class: the bounds are mapped through ``c_mod`` directly.
          * Singleton ``y``, otherwise: a sign-based bound within
            (-|y|, |y|), clipped to ``x`` in the mixed-sign case.
          * Any other ``y``: ``[-(|y|.upper - 1), |y|.upper - 1]``.
        """
        x = ValueRanges.wrap(x)
        y = ValueRanges.wrap(y)
        # nb. We implement C semantics

        # Remainder of |a| by |b|, negated when a is negative.
        def c_mod(a, b):
            ret = abs(a) % abs(b)
            if a < 0:
                ret *= -1
            return ret

        # Quotient converted with sympy.Integer (presumably truncating toward
        # zero, as C division does -- confirm for non-Rational inputs).
        # Infinite quotients are returned unchanged.
        def c_div(a, b):
            x = a / b
            return sympy.Integer(x) if x.is_finite and x not in (int_oo, -int_oo) else x

        if 0 in y:
            return ValueRanges.unknown_int()
        elif y.is_singleton():
            y_val = abs(y.lower)
            # If it wraps, we need to take the whole interval

            # The function is locally linear if they are in the same class
            if c_div(x.lower, y_val) == c_div(x.upper, y_val):
                return ValueRanges.increasing_map(x, lambda u: c_mod(u, y_val))
            if x.upper < 0:
                # Negative case
                return ValueRanges(-y_val + 1, 0)
            elif x.lower > 0:
                # Positive case
                return ValueRanges(0, y_val - 1)
            else:
                # Mixed case
                lower = max(-y_val + 1, x.lower)
                upper = min(y_val - 1, x.upper)
                return ValueRanges(lower, upper)
        else:
            # Too difficult, we bail out
            upper = cls.abs(y).upper - 1
            return ValueRanges(-upper, upper)
644
645    @classmethod
646    def modular_indexing(cls, a, b, c):
647        return cls.mod(cls.floordiv(a, b), c)
648
649    @classmethod
650    def is_non_overlapping_and_dense_indicator(cls, *args):
651        return ValueRanges.unknown_int()
652
    @classmethod
    def pow_by_natural(cls, a, b):
        """Range of ``a ** b`` for a natural-number exponent ``b``.

        Cases:
          * Both singletons: the exact value via ``safe_pow``.
          * ``a >= 1``: increasing in both arguments, with ``b`` clamped to
            [0, int_oo].
          * Singleton ``b``: an even power is convex with its minimum at 0;
            an odd power is increasing.
          * Otherwise: symmetric bound from the largest possible |base|.
        """
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        if a.is_singleton() and b.is_singleton():
            return ValueRanges.wrap(safe_pow(a.lower, b.lower))
        # NB: Exclude zero, because zero is special
        elif a.lower >= 1:
            # We should know that b >= 0 but we may have forgotten this fact due
            # to replacements, so don't assert it, but DO clamp it to prevent
            # degenerate problems
            # NOTE(review): if b lies entirely below 0, this intersection is
            # empty and ValueRanges raises ValueRangeError -- confirm that is
            # acceptable.
            return ValueRanges.coordinatewise_increasing_map(
                a, b & ValueRanges(0, int_oo), PowByNatural
            )
        elif b.is_singleton():
            if b.lower % 2 == 0:
                # x^n where n is even
                return ValueRanges.convex_min_zero_map(
                    a, lambda x: safe_pow(x, b.lower)
                )
            else:
                # x^n where n is odd
                return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower))
        else:
            # a is potentially negative, and we don't know if the exponent is
            # even or odd.  So just conservatively set the upper and lower
            # bound based on what the maximum absolute value could be, in both
            # directions
            max_base = max(a.upper, -a.lower)
            return ValueRanges(
                -(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper)
            )
685
686    @classmethod
687    def pow(cls, a, b):
688        return ValueRanges.unknown()
689
690        # We could implement all this, but for floating point pow, is there
691        # really a point?
692        """
693        a = ValueRanges.wrap(a)
694        b = ValueRanges.wrap(b)
695
696        # Not implemented yet. It's a bit tricky
697        # If you want to implement it, compute the partial derivatives of a ** b
698        # and check the ranges where the function is increasing / decreasing
699        # Another non-tight way of doing this is defaulting to doing noting that for a > 0,  a ** b == exp(b * log(a))
700        # If this second option is implemented, by carefult about the types and possible infinities here and there.
701        if not b.is_singleton():
702            return ValueRanges.unknown()
703
704        b = b.lower
705        if a.is_singleton():
706            a = a.lower
707            r = a**b
708            if not r.is_finite:
709                return ValueRanges.unknown()
710            return ValueRanges.wrap(r)
711
712        if b == 0:
713            if not a.lower.is_finite:
714                return ValueRanges.unknown()
715            return ValueRanges.wrap(1.0)
716
717        if b < 0:
718            a = cls.reciprocal(a)
719            b = -b
720
721        if a == ValueRanges.unknown():
722            return ValueRanges.unknown()
723
724        # If the base is positive, then we're good, otherwise nothing's defined
725        if a.lower >= 0:
726            return ValueRanges.increasing_map(a, lambda x: x**b)
727        else:
728            return ValueRanges.unknown()
729        """
730
731    @staticmethod
732    def reciprocal(x):
733        """Needed as it's used in pow, but it won't appear on a SymPy expression"""
734        x = ValueRanges.wrap(x)
735        if 0 in x:
736            return ValueRanges.unknown()
737        else:
738            return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y))  # type: ignore[operator]
739
740    @staticmethod
741    def abs(x):
742        return ValueRanges.convex_min_zero_map(x, abs)
743
744    @staticmethod
745    def exp(x):
746        return ValueRanges.increasing_map(x, OpaqueUnaryFn_exp)
747
748    @staticmethod
749    def log(x):
750        x = ValueRanges.wrap(x)
751        if x.lower <= 0:
752            return ValueRanges.unknown()
753        return ValueRanges.increasing_map(x, OpaqueUnaryFn_log)
754
755    @classmethod
756    def minimum(cls, a, b):
757        return cls.min_or_max(a, b, sympy.Min)
758
759    @classmethod
760    def maximum(cls, a, b):
761        return cls.min_or_max(a, b, sympy.Max)
762
763    @staticmethod
764    def min_or_max(a, b, fn):
765        a = ValueRanges.wrap(a)
766        b = ValueRanges.wrap(b)
767        return ValueRanges.coordinatewise_increasing_map(a, b, fn)
768
769    @classmethod
770    def floor_to_int(cls, x, dtype):
771        return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor)
772
773    @classmethod
774    def ceil_to_int(cls, x, dtype):
775        return ValueRanges.increasing_map(
776            x, sympy.functions.elementary.integers.ceiling
777        )
778
    # I think these implementations are sound.  The hazard here is that sympy
    # will carry out the floor/ceil at too high precision and then something
    # bad will happen when we convert it to float.
    #
    # For truncation, the implementation is clearly sound, because the desired
    # target float is always exactly representable, since you're just chopping
    # off bits of the mantissa.  But what about ceil/floor?
    #
    # The important constraint here is that we're not defining floor on
    # arbitrary real numbers, only representable float numbers.  So we can
    # take advantage of the fact that before we reach the first
    # unrepresentable integer in floating point space, we have the range of
    # numbers corresponding to exponent zero: all integers, with no fractional
    # amounts.  floor/ceil is an identity operation in this case.  In the
    # range just below that one, representable floating point numbers are spaced
    # exactly 1/2 apart, and notably, both the floor/ceil are defined floating
    # point numbers.  There is no "gap" as you step up to the next exponent.
796
797    @classmethod
798    def floor(cls, x):
799        return ValueRanges.increasing_map(
800            x, _keep_float(sympy.functions.elementary.integers.floor)
801        )
802
803    @classmethod
804    def ceil(cls, x):
805        return ValueRanges.increasing_map(
806            x, _keep_float(sympy.functions.elementary.integers.ceiling)
807        )
808
809    @classmethod
810    def round_decimal(cls, number, ndigits):
811        if not ndigits.is_singleton():
812            return ValueRanges.unknown()
813
814        ndigits = ndigits.lower
815        # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind
816        # the second parameter.
817        fn = lambda number: RoundDecimal(number, ndigits)  # type: ignore[misc, assignment]  # noqa: E731
818
819        return ValueRanges.increasing_map(number, fn)
820
821    @classmethod
822    def round_to_int(cls, number, dtype):
823        return ValueRanges.increasing_map(number, RoundToInt)
824
825    # It's used in some models on symints
826    @staticmethod
827    def sqrt(x):
828        x = ValueRanges.wrap(x)
829        if x.lower < 0:
830            return ValueRanges.unknown()
831        return ValueRanges.increasing_map(x, OpaqueUnaryFn_sqrt)
832
833    @staticmethod
834    def where(a, b, c):
835        b = ValueRanges.wrap(b)
836        c = ValueRanges.wrap(c)
837        a = a.boolify()
838        # We sometimes write unknown without specifying the type correctly
839        # In particular, we do that when initialising the bounds for loads in bounds.py
840        assert b.is_bool == c.is_bool or ValueRanges.unknown() in (b, c)
841        if b.is_bool:
842            return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
843        else:
844            return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper))
845
846    # expr_cond_pair is used to represent a single (expr, condition) pair in piecewise.
847    # We just return the value range of the expression and its corresponding condition as a tuple
848    # and defer the analysis to piecewise
849    @staticmethod
850    def expr_cond_pair(a, b):
851        b = b.boolify()
852        return (a, b)
853
854    # piecewise function can be used to convert a SymBool to SymInt:
855    # int_expr = Piecewise((1, bool_expr), (0, True)), it evalutes to 1 when sym_bool is True and 0 otherwise.
856    #
857    # ranges is a sequence of (expr_range, condition_range) pairs. The range pair is constructed in expr_cond_pair.
858    # The ValueRange of Piecewise is just the union of all expr ranges whose condition expr can be True.
859    @staticmethod
860    def piecewise(*ranges):
861        init_range = None
862        for expr_range, cond_range in ranges:
863            if sympy.true in cond_range:
864                if init_range is None:
865                    init_range = expr_range
866                else:
867                    init_range = init_range | expr_range
868        return init_range
869
870    @staticmethod
871    def cos(x):
872        # TODO: We should tighten value ranges
873        # If input range span is pi + 2*pi*k, then output range is (-1, 1)
874        # otherwise the minimum of the value of the function on the extremes
875        return ValueRanges(-1.0, 1.0)
876
877    @staticmethod
878    def cosh(x):
879        return ValueRanges(0.0, sympy.oo)
880        """
881        x = ValueRanges.wrap(x)
882        if x.lower > 0:
883            return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh)
884        elif x.upper < 0:
885            return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh)
886        return ValueRanges(0.0, sympy.oo)
887        """
888
889    @staticmethod
890    def sin(x):
891        # TODO: We should tighten value ranges
892        # See details on cos
893        return ValueRanges(-1.0, 1.0)
894
895    @staticmethod
896    def sinh(x):
897        # return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh)
898        return ValueRanges(-sympy.oo, sympy.oo)
899
900    @staticmethod
901    def tan(x):
902        return ValueRanges(-sympy.oo, sympy.oo)
903
904    @staticmethod
905    def tanh(x):
906        # return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh)
907        return ValueRanges(-sympy.oo, sympy.oo)
908
909    @staticmethod
910    def asin(x):
911        return ValueRanges(-sympy.oo, sympy.oo)
912        """
913        x = ValueRanges.wrap(x)
914        if -1 <= x.lower and x.upper <= 1:
915            return ValueRanges.increasing_map(x, OpaqueUnaryFn_asinh)
916        return ValueRanges.unknown()
917        """
918
919    @staticmethod
920    def acos(x):
921        return ValueRanges(-sympy.oo, sympy.oo)
922        """
923        x = ValueRanges.wrap(x)
924        if -1 <= x.lower and x.upper <= 1:
925            return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos)
926        return ValueRanges.unknown()
927        """
928
929    @staticmethod
930    def atan(x):
931        return ValueRanges(-sympy.oo, sympy.oo)
932        # return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan)
933
934    @staticmethod
935    def trunc(x):
936        return ValueRanges.increasing_map(x, TruncToFloat)
937
938
class ValueRangeAnalysis(SymPyValueRangeAnalysis):
    """Value-range analysis with handlers for load/store-style ops.

    Builds on ``SymPyValueRangeAnalysis`` and adds handlers for ``load``,
    ``store``, ``reduction``, ``index_expr`` and ``to_dtype``, plus ``square``,
    ``neg``, ``truncdiv`` and ``sub``. Boolean ops (``xor``, ``logical_and``,
    ``logical_or``, ``logical_not``) are answered with the full
    ``[False, True]`` range. Any other op with no handler falls back to
    ``default_handler`` (an unknown range) through ``__getattr__``.
    """

    def __init__(self) -> None:
        self.name = "ValueRangeAnalysis"
        # Boolean ops are not analysed precisely: route them all to
        # bool_handler, which returns the full boolean range.
        boolean_operators = (
            "xor",
            "logical_and",
            "logical_or",
            "logical_not",
        )
        for op in boolean_operators:
            setattr(self, op, self.bool_handler)

    @staticmethod
    def bool_handler(*args, **kwargs):
        # just assuming bools can have both values
        return ValueRanges(sympy.false, sympy.true)  # type: ignore[arg-type]

    @staticmethod
    def default_handler(*args, **kwargs):
        # many ops are unlikely to show up in optimizable indexing compute,
        # so we don't have full coverage
        return ValueRanges.unknown()

    def load(self, name: str, index: sympy.Expr):
        # Nothing is known about loaded values.
        return ValueRanges.unknown()

    def store(self, name, index, value, mode=None):
        # Stores produce no value to bound.
        return

    def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
        # Reductions are not analysed; their result is unbounded.
        return ValueRanges.unknown()

    @classmethod
    def index_expr(cls, index, dtype):
        """Return the range of ``index`` (already a ValueRanges) cast to ``dtype``."""
        assert isinstance(index, ValueRanges)
        return cls.to_dtype(index, dtype)

    @staticmethod
    def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None):
        """Return the range of ``x`` after conversion to ``dtype``.

        Args:
            x: a ``ValueRanges``, or anything ``ValueRanges.wrap`` accepts.
            dtype: target dtype.
            src_dtype: unused here.

        Returns:
            For ``torch.bool``: a boolean range, which is a singleton when the
            answer is determined. For a boolean ``x`` cast to a numeric dtype:
            ``[0, 1]``, or the matching singleton. Otherwise each bound is cast
            independently to an integer or float.
        """
        x = ValueRanges.wrap(x)

        if dtype == torch.bool:
            if x.is_singleton():
                return ValueRanges.wrap(x.lower != 0)
            elif x.is_bool:
                return x
            elif 0 not in x:
                # Zero is excluded, so every value converts to True.
                return ValueRanges.wrap(sympy.true)
            else:
                return ValueRanges(sympy.false, sympy.true)

        def cast(x, dtype):
            # dtype is int or float
            if dtype.is_floating_point:
                return sympy.Float(x)
            else:
                # Integer infinities are already valid integer bounds.
                if x in (int_oo, -int_oo):
                    return x
                try:
                    return sympy.Integer(x)
                except TypeError:
                    # inf cannot be cast to Integer
                    return x

        if x.is_bool:
            if x.is_singleton():
                val = 1 if x.lower else 0
                return ValueRanges.wrap(cast(val, dtype))
            else:
                return ValueRanges(cast(0, dtype), cast(1, dtype))
        else:
            # int to float or float to int
            return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype))

    @staticmethod
    def square(x):
        # x**2 is convex with its minimum (zero) at x == 0.
        return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2))

    @staticmethod
    def neg(x):
        return ValueRanges.decreasing_map(x, operator.neg)

    # TODO: this is slightly inaccurate because truncdiv operates at integer
    # precision, but we're going through float truediv which means we can
    # potentially lose precision on the bounds
    @classmethod
    def truncdiv(cls, a, b):
        x = cls.truediv(a, b)
        if x == ValueRanges.unknown():
            return x

        return cls.trunc(x)

    @classmethod
    def sub(cls, a, b):
        # a - b == a + (-b)
        return cls.add(a, cls.neg(b))

    def __getattr__(self, name):
        # Only consulted when normal attribute lookup fails. It maps any
        # unhandled op name to default_handler.
        # Dunder names are excluded: protocol lookups such as ``__deepcopy__``
        # (copy.deepcopy) or ``__setstate__`` (pickle) must not be answered
        # with an op handler. Answering them would silently corrupt those
        # protocols and make every hasattr() check on a dunder succeed.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        log.debug("unhandled ValueRange op %s", name)
        return self.default_handler
1039
1040
def bound_sympy(
    expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None
) -> ValueRanges:
    """Bound ``expr`` by interpreting it with ``SymPyValueRangeAnalysis``.

    Args:
        expr: the sympy expression to bound.
        ranges: optional known ranges for free symbols of ``expr``. These take
            precedence over ``var_to_range`` from the shape env of the active
            tracing context, if there is one.

    Returns:
        A ``ValueRanges`` for ``expr``. Numeric constants are wrapped directly.
        Free symbols with no known range are bounded from their sympy
        assumptions:

        * positive integers get ``[1, int_oo]``;
        * nonnegative integers get ``[0, int_oo]``;
        * other integers get ``unknown_int()``;
        * everything else gets ``unknown()``.
    """
    log.debug(
        "bound_sympy(%s)%s",
        expr,
        LazyString(
            lambda: "\n"
            + "\n".join(
                f"  {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols
            )
            if ranges
            else ""
        ),
    )
    if isinstance(expr, sympy.Number):
        return ValueRanges.wrap(expr)

    ranges = ranges or {}

    # If there's a tracing context, augment available constrained ranges.
    # fake_mode is checked for None so that a context without one is skipped
    # instead of raising AttributeError.
    context = torch._guards.TracingContext.try_get()
    fake_mode = context.fake_mode if context else None
    shape_env = fake_mode.shape_env if fake_mode is not None else None
    if shape_env:
        ranges = {**shape_env.var_to_range, **ranges}

    unbounded_vars = expr.free_symbols - ranges.keys()
    if unbounded_vars:
        # Give some bounds to the free variables via their SymPy assumptions
        # TODO A better way of doing this would be to assign them a range upon creation, as
        #      size variables can come with a lower bound of 2, as we specialize on 0 and 1
        unbounded_ranges: Dict[sympy.Symbol, ValueRanges] = {}
        for s in unbounded_vars:
            if s.is_integer:  # type: ignore[attr-defined]
                if s.is_positive:  # type: ignore[attr-defined]
                    vr = ValueRanges(1, int_oo)
                elif s.is_nonnegative:  # type: ignore[attr-defined]
                    vr = ValueRanges(0, int_oo)
                else:
                    vr = ValueRanges.unknown_int()
            else:
                # Don't bother trying very hard here
                vr = ValueRanges.unknown()
            unbounded_ranges[s] = vr  # type: ignore[index]
        ranges = {**ranges, **unbounded_ranges}

    return sympy_interp(SymPyValueRangeAnalysis, ranges, expr)
1087