xref: /aosp_15_r20/external/pytorch/torch/utils/_sympy/functions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import math
4import operator
5import sys
6
7import sympy
8from sympy import S
9from sympy.core import sympify
10from sympy.core.expr import Expr
11from sympy.core.function import Application
12from sympy.core.logic import _torf, fuzzy_and, fuzzy_or
13from sympy.core.numbers import equal_valued
14from sympy.core.operations import LatticeOp, ShortCircuit
15from sympy.core.sorting import ordered
16from sympy.core.traversal import walk
17from sympy.utilities.iterables import sift
18
19from .numbers import int_oo
20
21
22# Portions of this file are adapted from the Sympy codebase, which was
23# licensed as follows:
24#
25#   Copyright (c) 2006-2023 SymPy Development Team
26#
27#   All rights reserved.
28#
29#   Redistribution and use in source and binary forms, with or without
30#   modification, are permitted provided that the following conditions are met:
31#
32#     a. Redistributions of source code must retain the above copyright notice,
33#        this list of conditions and the following disclaimer.
34#     b. Redistributions in binary form must reproduce the above copyright
35#        notice, this list of conditions and the following disclaimer in the
36#        documentation and/or other materials provided with the distribution.
37#     c. Neither the name of SymPy nor the names of its contributors
38#        may be used to endorse or promote products derived from this software
39#        without specific prior written permission.
40#
41#   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
42#   AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
43#   IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
44#   ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
45#   ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
46#   DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
47#   SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
48#   CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
49#   LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
50#   OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
51#   DAMAGE.
52
# Public API of this module: the names bound by ``from <this module> import *``.
__all__ = [
    "FloorDiv",
    "ModularIndexing",
    "Where",
    "PythonMod",
    "Mod",
    "CleanDiv",
    "CeilToInt",
    "FloorToInt",
    "CeilDiv",
    "IntTrueDiv",
    "FloatTrueDiv",
    "LShift",
    "RShift",
    "IsNonOverlappingAndDenseIndicator",
    "TruncToFloat",
    "TruncToInt",
    "RoundToInt",
    "RoundDecimal",
    "ToFloat",
    "FloatPow",
    "PowByNatural",
    "Identity",
]
77
78
79def _keep_float(f):
80    @functools.wraps(f)
81    def inner(*args):
82        r = f(*args)
83        if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
84            r, sympy.Float
85        ):
86            r = sympy.Float(float(r))
87        return r
88
89    return inner
90
91
def fuzzy_eq(x, y):
    """Three-valued equality for fuzzy (True/False/None) values.

    Returns None when either operand is None (unknown). Otherwise returns
    ``x == y``.
    """
    if None not in (x, y):
        return x == y
    return None
96
97
def simple_floordiv_gcd(p, q):
    """
    Fast path for sympy.gcd, using a simple factoring strategy.

    We try to rewrite p and q in the form n*e*p1 + n*e*p2 and n*e*q0,
    where n is the greatest common integer factor and e is the largest
    syntactic common factor (i.e., common sub-expression) in p and q.
    Then the gcd returned is n*e, cancelling which we would be left with
    p1 + p2 and q0.

    Note that further factoring of p1 + p2 and q0 might be possible with
    sympy.factor (which uses domain-specific theories). E.g., we are unable
    to find that x*y + x + y + 1 is divisible by x + 1. More generally,
    when q is of the form q1 + q2 (instead of being already factored) it
    might be necessary to fall back on sympy.gcd.
    """

    def integer_coefficient(x):
        # Product of |k| over the integer factors k of the product term x
        # (1 when x has no integer factor).
        coefficient = 1
        for factor in sympy.Mul.make_args(x):
            if isinstance(factor, (int, sympy.Integer)):
                coefficient *= abs(int(factor))
        return coefficient

    def integer_factor(expr):
        # gcd of the integer coefficients of every additive term of expr.
        coefficients = (
            integer_coefficient(term) for term in sympy.Add.make_args(expr)
        )
        return functools.reduce(math.gcd, coefficients)

    common = math.gcd(integer_factor(p), integer_factor(q))
    p, q = p / common, q / common

    # Multiplicative factors of each additive term of (the reduced) p.
    term_factors = [sympy.Mul.make_args(term) for term in sympy.Add.make_args(p)]
    # A factor of q that appears syntactically in every term of p is common.
    for factor in sympy.Mul.make_args(q):
        if all(factor in factors for factors in term_factors):
            common = common * factor
    return common
136
137
# It would be nice to assert whether or not inputs are is_integer.
# However, with bugs like https://github.com/sympy/sympy/issues/26620, sympy
# sometimes inconsistently reports floats as integers.
#
# What we can assume from sympy is that if something is an int, it
# definitely is is_integer, but if it is a float it may or may not
# be is_integer.  So we are unable to do strong asserts that things
# are NOT integers.
146
147
148# TODO: In Triton, // rounds to zero, but in Python, it is floor division.
149# When we can prove both arguments are non-negative, we should just have a
150# GenericFloorDiv (name pending) which can codegen efficiently in Python/C,
151# and then PythonFloorDiv and CIntDiv which have the appropriate rounding
152# semantics.
153#
154# Right now, FloorDiv de facto changes behavior if arguments are negative or
155# not, this can potentially cause correctness issues.
class FloorDiv(sympy.Function):
    """
    We maintain this so that:
    1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
    2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)

    NB: This is Python-style floor division, round to -Inf

    ``eval`` tries the following, in order:
    - raise ZeroDivisionError for a divisor known to be zero;
    - return nan for infinite // infinite or for nan operands;
    - return 0 for a zero base;
    - return ``a`` / ``-a`` for an integer ``a`` divided by 1 / -1;
    - fold numeric operands when one side is infinite (via float division)
      or when both are Integers;
    - flatten ``(a // b) // c`` into ``a // (b * c)``;
    - pull out an additive term of the base that an Integer divisor divides
      exactly;
    - cancel a common factor of base and divisor.
    If none of these applies, the expression stays unevaluated.
    """

    nargs = (2,)
    precedence = 50  # precedence of mul  # noqa: F811

    is_integer = True

    @property
    def base(self):
        # First argument: the dividend.
        return self.args[0]

    @property
    def divisor(self):
        # Second argument: the divisor.
        return self.args[1]

    def _sympystr(self, printer):
        # Printed as "(base//divisor)", each operand parenthesized relative to
        # this node's precedence.
        base = printer.parenthesize(self.base, self.precedence)
        divisor = printer.parenthesize(self.divisor, self.precedence)
        return f"({base}//{divisor})"

    # Automatic evaluation.
    # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
    @classmethod
    def eval(cls, base, divisor):
        # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
        # Assert triggered by inequality solver
        # assert base.is_integer, base
        # assert divisor.is_integer, divisor

        # We don't provide the same error message as in Python because SymPy
        # makes it difficult to check the types.
        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")
        # infinity // infinity (either sign, int_oo or sympy.oo) is undefined.
        if base in (int_oo, -int_oo, sympy.oo, -sympy.oo) and divisor in (
            int_oo,
            -int_oo,
            sympy.oo,
            -sympy.oo,
        ):
            return sympy.nan
        if base is sympy.nan or divisor is sympy.nan:
            return sympy.nan

        if base.is_zero:
            return sympy.S.Zero
        if base.is_integer and equal_valued(divisor, 1):
            return base
        if base.is_integer and equal_valued(divisor, -1):
            return sympy.Mul(base, -1)
        # Numeric operands where exactly one side is infinite: fold through
        # float division, mapping +-inf to +-int_oo and nan to nan.
        if (
            isinstance(base, sympy.Number)
            and isinstance(divisor, sympy.Number)
            and (
                base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
                or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
            )
        ):
            r = float(base) / float(divisor)
            if r == math.inf:
                return int_oo
            elif r == -math.inf:
                return -int_oo
            elif math.isnan(r):
                return sympy.nan
            else:
                return sympy.Integer(math.floor(r))
        if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
            return sympy.Integer(int(base) // int(divisor))
        # (a // b) // c -> a // (b * c).
        # NOTE(review): floor(floor(x) / c) == floor(x / c) holds for a
        # positive integer c; the sign of `divisor` is not checked here, so
        # confirm callers only reach this with positive divisors.
        if isinstance(base, FloorDiv):
            return FloorDiv(base.args[0], base.args[1] * divisor)

        # Expands (x + y) // b into x // b + y // b.
        # This only works if floor is an identity, i.e. x / b is an integer.
        for term in sympy.Add.make_args(base):
            quotient = term / divisor
            if quotient.is_integer and isinstance(divisor, sympy.Integer):
                # NB: this is correct even if the divisor is not an integer, but it
                # creates rational expressions that cause problems with dynamic
                # shapes.
                return FloorDiv(base - term, divisor) + quotient

        # Cancel a common factor of base and divisor. Try the cheap syntactic
        # gcd first; fall back to sympy.gcd only when the divisor is a sum.
        try:
            gcd = simple_floordiv_gcd(base, divisor)
            if equal_valued(gcd, 1) and isinstance(divisor, sympy.Add):
                gcd = sympy.gcd(base, divisor)
            if not equal_valued(gcd, 1):
                return FloorDiv(
                    sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
                )
        except sympy.PolynomialError:
            pass  # https://github.com/pytorch/pytorch/issues/108276
254
255
class ModularIndexing(sympy.Function):
    """
    ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus

    ``eval`` returns 0 when ``a == 0`` or ``c == 1`` and folds three Integer
    literals. It also cancels a common factor of ``a`` and ``b``, drops
    additive terms of ``a`` that are multiples of ``b * c`` (only when no
    kept term has a negative leading coefficient), and absorbs a nested
    ``FloorDiv`` into the divisor. Otherwise the expression stays
    unevaluated.

    NOTE(review): the all-literal fold uses Python ``//`` and ``%``, which
    agree with C semantics only for non-negative operands.
    """

    nargs = (3,)
    is_integer = True

    @classmethod
    def eval(cls, base, divisor, modulus):
        if base == 0 or modulus == 1:
            return sympy.Integer(0)

        if (
            isinstance(base, sympy.Integer)
            and isinstance(divisor, sympy.Integer)
            and isinstance(modulus, sympy.Integer)
        ):
            return (base // divisor) % modulus

        try:
            if divisor != 1:
                # Dividing base and divisor by the same nonzero factor leaves
                # base // divisor unchanged.
                gcd = sympy.gcd(base, divisor)
                if gcd != 1:
                    return ModularIndexing(
                        sympy.simplify(base / gcd),
                        sympy.simplify(divisor / gcd),
                        modulus,
                    )
        except sympy.PolynomialError:
            pass  # https://github.com/pytorch/pytorch/issues/108276

        if isinstance(base, sympy.Add):
            # A term that is a multiple of divisor * modulus shifts
            # base // divisor by a multiple of modulus, so (for non-negative
            # operands) it does not change the result and can be dropped.
            new_terms = []
            all_positive = True
            for term in base.args:
                if sympy.gcd(term, modulus * divisor) != modulus * divisor:
                    if (isinstance(term, sympy.Integer) and term < 0) or (
                        isinstance(term, sympy.Mul)
                        and isinstance(term.args[0], sympy.Integer)
                        and term.args[0] < 0
                    ):
                        # workaround for https://github.com/openai/triton/issues/619,
                        # if there are negative terms, // produces wrong result
                        # TODO if https://github.com/openai/triton/issues/619 is fixed
                        # this optimization would become valid
                        all_positive = False
                        break
                    else:
                        new_terms.append(term)

            if len(new_terms) != len(base.args) and all_positive:
                return ModularIndexing(sum(new_terms), divisor, modulus)

        # ((a // b1) // b2) % c == (a // (b1 * b2)) % c
        if isinstance(base, FloorDiv):
            return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)

    def _eval_is_nonnegative(self):
        # Nonnegative when base and divisor agree on non-negativity, so the
        # floored quotient is not negative. The modulus is not consulted.
        p, q = self.args[:2]
        return fuzzy_eq(p.is_nonnegative, q.is_nonnegative)  # type: ignore[attr-defined]

    def _eval_is_positive(self):
        # The signs of base and divisor cannot establish positivity: the
        # modulus can map a positive quotient to zero, e.g. (4 // 1) % 2 == 0.
        # Returning None ("unknown") lets sympy derive what it can from the
        # other assumptions; for instance, a result known not to be
        # nonnegative is also not positive.
        return None
320
321
class Where(sympy.Function):
    """
    Good ol' ternary operator.

    ``Where(c, p, q)`` evaluates to ``p`` when ``c`` is ``sympy.true`` and to
    ``q`` when ``c`` is ``sympy.false``. Any other condition leaves it
    unevaluated. The integer, nonnegative and positive assumptions are
    reported True only when both branches have them; otherwise they are
    unknown (None), never False.
    """

    nargs = (3,)

    def _branches_have(self, attr):
        # True when both branches (args[1] and args[2]) report `attr` truthy,
        # None otherwise.
        both = all(getattr(branch, attr) for branch in self.args[1:])
        return True if both else None

    def _eval_is_integer(self):
        return self._branches_have("is_integer")

    def _eval_is_nonnegative(self):
        return self._branches_have("is_nonnegative")

    def _eval_is_positive(self):
        return self._branches_have("is_positive")

    @classmethod
    def eval(cls, c, p, q):
        if c == sympy.true:
            return p
        if c == sympy.false:
            return q
        return None
348
349
# Python-style modulus: take sign from RHS
class PythonMod(sympy.Function):
    """Python-style ``p % q``: a non-zero result has the sign of ``q``.

    ``eval`` folds the cases it can decide and otherwise leaves the
    expression unevaluated.

    Raises:
        ZeroDivisionError: if ``q`` is known to be zero.
    """

    nargs = (2,)

    is_integer = True

    @classmethod
    def eval(cls, p, q):
        # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint
        # Triggered by sympy.solvers.inequalities.reduce_inequalities
        # assert p.is_integer, p
        # assert q.is_integer, q

        if q.is_zero:
            raise ZeroDivisionError("Modulo by zero")

        # Three cases:
        #   1. p == 0
        #   2. p is either q or -q
        #   3. p is integer and q == 1
        if p is S.Zero or p in (q, -q) or q == 1:
            return S.Zero

        # Evaluate if they are both literals.
        if q.is_Number and p.is_Number:
            return p % q

        # If q == 2, it's a matter of whether p is odd or even.
        if q.is_Number and q == 2:
            if p.is_even:
                return S.Zero
            if p.is_odd:
                return S.One

        # If p is a multiple of q.
        r = p / q
        if r.is_integer:
            return S.Zero

        # If 0 < p / q < 1, then floor(p / q) == 0 and
        # p % q == p - floor(p / q) * q == p.
        # A positive ratio means p and q share a sign, and |p| < |q| reads:
        #   p < q  when q > 0  (0 < p < q)
        #   p > q  when q < 0  (q < p < 0)
        # Testing only ``p < q`` is wrong for negative q: p = -10, q = -3
        # satisfies p < q with a positive ratio, yet -10 % -3 == -1.
        if r.is_positive:
            if q.is_positive:
                within = p < q
            elif q.is_negative:
                within = p > q
            else:
                within = None
            if within is not None and within.is_Boolean and bool(within):
                return p

        if sympy.Mod(p, q) == 0:
            return S.Zero

    # NB: args[1] for PythonMod
    def _eval_is_nonnegative(self):
        # A positive modulus gives a result in [0, q).
        return True if self.args[1].is_positive else None  # type: ignore[attr-defined]

    def _eval_is_nonpositive(self):
        # A negative modulus gives a result in (q, 0].
        return True if self.args[1].is_negative else None  # type: ignore[attr-defined]
405
406
# Generic modulus: only defined on non-negative arguments
class Mod(sympy.Function):
    """Modulus ``p % q`` for a non-negative ``p`` and a positive ``q``.

    Evaluates eagerly when the answer is decidable, otherwise stays
    unevaluated. When both arguments are literals, the domain is checked
    with asserts.

    Raises:
        ZeroDivisionError: if ``q`` is known to be zero.
    """

    nargs = (2,)

    is_integer = True
    is_nonnegative = True

    @classmethod
    def eval(cls, p, q):
        # Adapted from sympy/core/mod.py.
        #
        # The integrality preconditions below stay disabled; they are tripped by
        #   python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
        # assert p.is_integer, p
        # assert q.is_integer, q

        if q.is_zero:
            raise ZeroDivisionError("Modulo by zero")

        # Zero when p == 0, when p == +-q, or when q == 1.
        if p is S.Zero or p in (q, -q) or q == 1:
            return S.Zero

        q_is_literal = q.is_Number
        if q_is_literal and p.is_Number:
            # Two literals: check the domain, then fold directly.
            assert p >= 0, p
            assert q >= 1, q
            return p % q

        if q_is_literal and q == 2:
            # Modulo 2 is decided by parity, when parity is known.
            if p.is_even:
                return S.Zero
            if p.is_odd:
                return S.One

        ratio = p / q
        if ratio.is_integer:
            # p is an exact multiple of q.
            return S.Zero

        # 0 < p / q < 1 means floor(p / q) == 0, hence p % q == p.
        p_below_q = p < q
        if p_below_q.is_Boolean and bool(p_below_q) and ratio.is_positive:
            return p
        return None
457
458
class CleanDiv(FloorDiv):
    """
    Div where we can assume no rounding.
    This is to enable future optimizations.

    It inherits all evaluation and printing from FloorDiv and adds nothing of
    its own. The separate class only marks the division as exact; for
    example, CeilDiv produces a CleanDiv when sympy.gcd(base, divisor)
    equals the divisor.
    """
464
465
466# Don't use sympy ceiling/floor as they will attempt simplifications involving
467# frac
class CeilToInt(sympy.Function):
    """Ceiling of a number as an integer-valued expression.

    +oo/int_oo map to int_oo and -oo/-int_oo map to -int_oo. Other numeric
    literals are folded with ``math.ceil``. Anything else stays unevaluated.
    """

    is_integer = True

    @classmethod
    def eval(cls, number):
        # assert number.is_integer is not True, number
        positive_infinities = (sympy.oo, int_oo)
        negative_infinities = (-sympy.oo, -int_oo)
        if number in positive_infinities:
            return int_oo
        if number in negative_infinities:
            return -int_oo
        if not isinstance(number, sympy.Number):
            return None
        return sympy.Integer(math.ceil(float(number)))
480
481
class FloorToInt(sympy.Function):
    """Floor of a number as an integer-valued expression.

    +oo/int_oo map to int_oo and -oo/-int_oo map to -int_oo. Other numeric
    literals are folded with ``math.floor``. Anything else stays unevaluated.
    """

    is_integer = True

    @classmethod
    def eval(cls, number):
        # assert number.is_integer is not True, number
        if number in (sympy.oo, int_oo):
            return int_oo
        # Mirror of the positive case (and of CeilToInt). -int_oo must be
        # caught here; otherwise it falls through to math.floor(float(...)),
        # which cannot convert an infinity to an integer.
        if number in (-sympy.oo, -int_oo):
            return -int_oo
        if isinstance(number, sympy.Number):
            return sympy.Integer(math.floor(float(number)))
494
495
class CeilDiv(sympy.Function):
    """
    Div used in indexing that rounds up.

    Construction never yields a CeilDiv instance. When
    ``sympy.gcd(base, divisor)`` equals ``divisor``, the division is exact and
    a CleanDiv is returned. Otherwise the ceiling is rewritten as
    ``FloorDiv(base + (divisor - 1), divisor)``.
    """

    is_integer = True

    def __new__(cls, base, divisor):
        numerator = sympy.sympify(base)
        denominator = sympy.sympify(divisor)
        if sympy.gcd(numerator, denominator) != denominator:
            # ceil(a / b) == floor((a + b - 1) / b) for integer a and
            # positive integer b.
            return FloorDiv(numerator + (denominator - 1), denominator)
        return CleanDiv(numerator, denominator)
510
511
class LShift(sympy.Function):
    """Symbolic ``base << shift``, always rewritten to ``base * 2**shift``.

    Raises:
        ValueError: if ``shift`` is known to be negative (matching Python's
            "negative shift count" error).
    """

    is_integer = True

    @classmethod
    def eval(cls, base, shift):
        # Reject only shifts that are provably negative. For a symbolic shift,
        # `shift < 0` is an undecidable Relational, and using it in an `if`
        # raises TypeError; `is_negative` is None in that case instead.
        if shift.is_negative:
            raise ValueError("negative shift count")
        return base * 2**shift
520
521
class RShift(sympy.Function):
    """Symbolic ``base >> shift``, rewritten to ``FloorDiv(base, 2**shift)``.

    Python's ``>>`` floors toward -inf, which is exactly FloorDiv's semantics.
    FloorDiv is used rather than sympy's ``//`` (which builds ``floor``)
    because this module avoids sympy floor/ceiling simplifications.

    Raises:
        ValueError: if ``shift`` is known to be negative (matching Python's
            "negative shift count" error).
    """

    is_integer = True

    @classmethod
    def eval(cls, base, shift):
        # Reject only shifts that are provably negative. For a symbolic shift,
        # `shift < 0` is an undecidable Relational, and using it in an `if`
        # raises TypeError; `is_negative` is None in that case instead.
        if shift.is_negative:
            raise ValueError("negative shift count")
        return FloorDiv(base, 2**shift)
530
531
class MinMaxBase(Expr, LatticeOp):  # type: ignore[misc]
    """Shared implementation of Min and Max (adapted from SymPy's MinMaxBase).

    With evaluation enabled, construction does the following:
    - flattens nested same-class calls;
    - returns ``cls.zero`` if any argument equals it;
    - drops arguments equal to ``cls.identity``;
    - removes redundant arguments;
    - keeps only the "local zeros" (arguments that cannot be ordered against
      one another).
    Subclasses define ``zero`` and ``identity``.
    """

    def __new__(cls, *args, **assumptions):
        from sympy.core.parameters import global_parameters

        evaluate = assumptions.pop("evaluate", global_parameters.evaluate)
        args = (sympify(arg) for arg in args)

        # first standard filter, for cls.zero and cls.identity
        # also reshape Max(a, Max(b, c)) to Max(a, b, c)

        if evaluate:
            try:
                args = frozenset(cls._new_args_filter(args))  # type: ignore[assignment]
            except ShortCircuit:
                # _new_args_filter raises this when an argument equals
                # cls.zero, which absorbs every other argument.
                return cls.zero  # type: ignore[attr-defined]
            # remove redundant args that are easily identified
            args = cls._collapse_arguments(args, **assumptions)
            # find local zeros
            args = cls._find_localzeros(args, **assumptions)
        args = frozenset(args)

        # Every argument was filtered out: the result is the identity.
        if not args:
            return cls.identity  # type: ignore[attr-defined]

        # A single surviving argument is the result itself.
        if len(args) == 1:
            return list(args).pop()

        # base creation
        obj = Expr.__new__(cls, *ordered(args), **assumptions)
        obj._argset = args
        return obj

    @classmethod
    def _collapse_arguments(cls, args, **assumptions):
        """Remove redundant args.

        Examples
        ========

        >>> from sympy import Min, Max
        >>> from sympy.abc import a, b, c, d, e

        Any arg in parent that appears in any
        parent-like function in any of the flat args
        of parent can be removed from that sub-arg:

        >>> Min(a, Max(b, Min(a, c, d)))
        Min(a, Max(b, Min(c, d)))

        If the arg of parent appears in an opposite-than parent
        function in any of the flat args of parent that function
        can be replaced with the arg:

        >>> Min(a, Max(b, Min(c, d, Max(a, e))))
        Min(a, Max(b, Min(a, c, d)))
        """
        if not args:
            return args
        args = list(ordered(args))
        if cls is Min:
            other = Max
        else:
            other = Min  # type: ignore[assignment]

        # find global comparable max of Max and min of Min if a new
        # value is being introduced in these args at position 0 of
        # the ordered args
        if args[0].is_number:
            # Boolean index: False -> mins (Min nodes), True -> maxs (Max nodes).
            sifted = mins, maxs = [], []  # type: ignore[var-annotated]
            for i in args:
                for v in walk(i, Min, Max):
                    if v.args[0].is_comparable:
                        sifted[isinstance(v, Max)].append(v)
            small = Min.identity
            for i in mins:
                v = i.args[0]
                if v.is_number and (v < small) == True:  # noqa: E712
                    small = v
            big = Max.identity
            for i in maxs:
                v = i.args[0]
                if v.is_number and (v > big) == True:  # noqa: E712
                    big = v
            # at the point when this function is called from __new__,
            # there may be more than one numeric arg present since
            # local zeros have not been handled yet, so look through
            # more than the first arg
            if cls is Min:
                for arg in args:
                    if not arg.is_number:
                        break
                    if (arg < small) == True:  # noqa: E712
                        small = arg
            elif cls == Max:
                for arg in args:
                    if not arg.is_number:
                        break
                    if (arg > big) == True:  # noqa: E712
                        big = arg
            T = None
            if cls is Min:
                if small != Min.identity:
                    other = Max
                    T = small
            elif big != Max.identity:
                other = Min  # type: ignore[assignment]
                T = big
            if T is not None:
                # remove numerical redundancy
                for i in range(len(args)):
                    a = args[i]
                    if isinstance(a, other):
                        a0 = a.args[0]
                        if (  # noqa: E712
                            (a0 > T) if other == Max else (a0 < T)  # noqa: E712
                        ) == True:  # noqa: E712
                            args[i] = cls.identity  # type: ignore[attr-defined]

        # remove redundant symbolic args
        # do(ai, a): rewrite the Min/Max tree ai given that `a` is already an
        # argument of the parent: drop `a` from same-class nodes, and collapse
        # opposite-class nodes that contain `a` to `a` itself.
        def do(ai, a):
            if not isinstance(ai, (Min, Max)):
                return ai
            cond = a in ai.args
            if not cond:
                return ai.func(*[do(i, a) for i in ai.args], evaluate=False)
            if isinstance(ai, cls):
                return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False)
            return a

        for i, a in enumerate(args):
            args[i + 1 :] = [do(ai, a) for ai in args[i + 1 :]]

        # factor out common elements as for
        # Min(Max(x, y), Max(x, z)) -> Max(x, Min(y, z))
        # and vice versa when swapping Min/Max -- do this only for the
        # easy case where all functions contain something in common;
        # trying to find some optimal subset of args to modify takes
        # too long

        def factor_minmax(args):
            is_other = lambda arg: isinstance(arg, other)  # noqa: E731
            other_args, remaining_args = sift(args, is_other, binary=True)
            if not other_args:
                return args

            # Min(Max(x, y, z), Max(x, y, u, v)) -> {x,y}, ({z}, {u,v})
            arg_sets = [set(arg.args) for arg in other_args]
            common = set.intersection(*arg_sets)
            if not common:
                return args

            new_other_args = list(common)
            arg_sets_diff = [arg_set - common for arg_set in arg_sets]

            # If any set is empty after removing common then all can be
            # discarded e.g. Min(Max(a, b, c), Max(a, b)) -> Max(a, b)
            if all(arg_sets_diff):
                other_args_diff = [other(*s, evaluate=False) for s in arg_sets_diff]
                new_other_args.append(cls(*other_args_diff, evaluate=False))

            other_args_factored = other(*new_other_args, evaluate=False)
            return remaining_args + [other_args_factored]

        if len(args) > 1:
            args = factor_minmax(args)

        return args

    @classmethod
    def _new_args_filter(cls, arg_sequence):
        """
        Generator filtering args.

        first standard filter, for cls.zero and cls.identity.
        Also reshape ``Max(a, Max(b, c))`` to ``Max(a, b, c)``,
        and check arguments for comparability

        Raises ValueError for a non-Expr, non-real, or non-comparable numeric
        argument, and ShortCircuit when an argument equals ``cls.zero``.
        """
        for arg in arg_sequence:
            # pre-filter, checking comparability of arguments
            if (
                not isinstance(arg, Expr)
                or arg.is_extended_real is False
                or (arg.is_number and not arg.is_comparable)
            ):
                raise ValueError(f"The argument '{arg}' is not comparable.")

            if arg == cls.zero:  # type: ignore[attr-defined]
                raise ShortCircuit(arg)
            elif arg == cls.identity:  # type: ignore[attr-defined]
                continue
            elif arg.func == cls:
                yield from arg.args
            else:
                yield arg

    @classmethod
    def _find_localzeros(cls, values, **options):
        """
        Sequentially allocate values to localzeros.

        When a value is identified as being more extreme than another member it
        replaces that member; if this is never true, then the value is simply
        appended to the localzeros.

        Returns the resulting set of values.
        """
        localzeros = set()  # type: ignore[var-annotated]
        for v in values:
            is_newzero = True
            localzeros_ = list(localzeros)
            for z in localzeros_:
                if id(v) == id(z):
                    is_newzero = False
                else:
                    con = cls._is_connected(v, z)
                    if con:
                        is_newzero = False
                        # v equals z, or v is the more extreme value for this
                        # class: v replaces z.
                        if con is True or con == cls:
                            localzeros.remove(z)
                            localzeros.update([v])
            if is_newzero:
                localzeros.update([v])
        return localzeros

    @classmethod
    def _is_connected(cls, x, y):
        """
        Check if x and y are connected somehow.

        Returns True if ``x == y``. Otherwise it tries ``>=`` and ``<=`` in both
        argument orders; the first comparison that evaluates to a definite
        truth value decides the result:
        - Max when it shows x is the larger (x >= y);
        - Min when it shows x is the smaller.
        Returns False if nothing is decidable or a comparison raises TypeError.
        """
        if x == y:
            return True
        t, f = Max, Min
        for op in "><":
            for j in range(2):
                try:
                    if op == ">":
                        v = x >= y
                    else:
                        v = x <= y
                except TypeError:
                    return False  # non-real arg
                if not v.is_Relational:
                    return t if v else f
                t, f = f, t  # type: ignore[assignment]
                x, y = y, x
            x, y = y, x  # run next pass with reversed order relative to start

        return False

    # Assumption handlers: _torf gives True if every argument has the property
    # True, False if every argument has it False, and None otherwise. Max and
    # Min override the sign-related handlers.
    _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args)  # noqa: E731
    _eval_is_antihermitian = lambda s: _torf(  # noqa: E731
        i.is_antihermitian for i in s.args  # noqa: E731
    )  # noqa: E731
    _eval_is_commutative = lambda s: _torf(  # noqa: E731
        i.is_commutative for i in s.args  # noqa: E731
    )  # noqa: E731
    _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args)  # noqa: E731
    _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args)  # noqa: E731
    _eval_is_even = lambda s: _torf(i.is_even for i in s.args)  # noqa: E731
    _eval_is_finite = lambda s: _torf(i.is_finite for i in s.args)  # noqa: E731
    _eval_is_hermitian = lambda s: _torf(i.is_hermitian for i in s.args)  # noqa: E731
    _eval_is_imaginary = lambda s: _torf(i.is_imaginary for i in s.args)  # noqa: E731
    _eval_is_infinite = lambda s: _torf(i.is_infinite for i in s.args)  # noqa: E731
    _eval_is_integer = lambda s: _torf(i.is_integer for i in s.args)  # noqa: E731
    _eval_is_irrational = lambda s: _torf(i.is_irrational for i in s.args)  # noqa: E731
    _eval_is_negative = lambda s: _torf(i.is_negative for i in s.args)  # noqa: E731
    _eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args)  # noqa: E731
    _eval_is_nonnegative = lambda s: _torf(  # noqa: E731
        i.is_nonnegative for i in s.args  # noqa: E731
    )  # noqa: E731
    _eval_is_nonpositive = lambda s: _torf(  # noqa: E731
        i.is_nonpositive for i in s.args  # noqa: E731
    )  # noqa: E731
    _eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args)  # noqa: E731
    _eval_is_odd = lambda s: _torf(i.is_odd for i in s.args)  # noqa: E731
    _eval_is_polar = lambda s: _torf(i.is_polar for i in s.args)  # noqa: E731
    _eval_is_positive = lambda s: _torf(i.is_positive for i in s.args)  # noqa: E731
    _eval_is_prime = lambda s: _torf(i.is_prime for i in s.args)  # noqa: E731
    _eval_is_rational = lambda s: _torf(i.is_rational for i in s.args)  # noqa: E731
    _eval_is_real = lambda s: _torf(i.is_real for i in s.args)  # noqa: E731
    _eval_is_extended_real = lambda s: _torf(  # noqa: E731
        i.is_extended_real for i in s.args  # noqa: E731
    )  # noqa: E731
    _eval_is_transcendental = lambda s: _torf(  # noqa: E731
        i.is_transcendental for i in s.args  # noqa: E731
    )  # noqa: E731
    _eval_is_zero = lambda s: _torf(i.is_zero for i in s.args)  # noqa: E731
817
818
class Max(MinMaxBase, Application):  # type: ignore[misc]
    r"""
    Return, if possible, the maximum value of the list.

    Sign assumptions follow from the definition of a maximum:
    - it is positive (nonnegative) as soon as one argument is;
    - it is negative only when every argument is.
    """

    zero = S.Infinity
    identity = S.NegativeInfinity

    def _eval_is_positive(self):
        return fuzzy_or(map(operator.attrgetter("is_positive"), self.args))

    def _eval_is_nonnegative(self):
        return fuzzy_or(map(operator.attrgetter("is_nonnegative"), self.args))

    def _eval_is_negative(self):
        return fuzzy_and(map(operator.attrgetter("is_negative"), self.args))
834
835
class Min(MinMaxBase, Application):  # type: ignore[misc]
    """
    Return, if possible, the minimum value of the list.

    Sign assumptions follow from the definition of a minimum:
    - it is positive (nonnegative) only when every argument is;
    - it is negative as soon as one argument is.
    """

    zero = S.NegativeInfinity
    identity = S.Infinity

    def _eval_is_positive(self):
        return fuzzy_and(map(operator.attrgetter("is_positive"), self.args))

    def _eval_is_nonnegative(self):
        return fuzzy_and(map(operator.attrgetter("is_nonnegative"), self.args))

    def _eval_is_negative(self):
        return fuzzy_or(map(operator.attrgetter("is_negative"), self.args))
852
853
854def safe_pow(base, exp):
855    sign = 1
856    if base < 0:
857        base = -base
858        sign = 1 if exp % 2 == 0 else -1
859    return sign * _safe_pow(base, exp)
860
861
862# Prevent people from overflowing pow
863def _safe_pow(base, exponent):
864    if exponent < 0:
865        raise ValueError("Exponent must be non-negative.")
866
867    if exponent == 0:
868        return 1
869
870    half_exp = safe_pow(base, exponent // 2)
871    if half_exp is int_oo:
872        return int_oo
873
874    # TODO: microoptimization is to avoid overflowing into arbitrary precision
875    # and detect overflow prior to doing operations
876
877    result = half_exp * half_exp
878    if result > sys.maxsize:
879        return int_oo
880
881    if exponent % 2 == 1:
882        result *= base
883        if result > sys.maxsize:
884            return int_oo
885
886    return result
887
888
class PowByNatural(sympy.Function):
    """
    ``base ** exp`` where ``exp`` is known to be a natural number.

    Concrete integer operands fold via ``safe_pow``, saturating to
    ``int_oo`` / ``-int_oo`` on overflow.  An integer exponent with a symbolic
    base becomes a plain ``sympy.Pow``.  An infinite exponent folds according
    to the base.  Otherwise the call stays unevaluated.
    """

    is_integer = True

    @classmethod
    def eval(cls, base, exp):
        if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
            r = safe_pow(base, exp)
            if r in (-int_oo, int_oo):
                return r
            return sympy.Integer(r)
        if isinstance(exp, sympy.Integer):
            # Rely on regular sympy Pow for this (note that iterated
            # multiplication turns into a Pow anyway, you can't escape!!)
            return sympy.Pow(base, exp)
        if exp in (int_oo, sympy.oo):
            # 0 ** n == 0 and 1 ** n == 1 for every positive natural n, so an
            # unbounded exponent leaves these two bases fixed rather than
            # blowing up to int_oo.
            if base in (0, 1):
                return sympy.Integer(base)
            if base.is_nonnegative:
                return int_oo
            elif base.is_negative:
                return sympy.zoo  # this is apparently what (-2)**sympy.oo does
        # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp
        # is a natural number if we do
910
911
# base is assumed to be nonnegative, thereby preventing complex numbers from
# occurring
class FloatPow(sympy.Function):
    """
    Float exponentiation that only constant-folds.

    When both operands are sympy numbers, the result is computed with Python
    float ``**`` and wrapped in ``sympy.Float``.  Otherwise the expression is
    left unevaluated.  The base is presumably nonnegative at every call site
    (TODO confirm): Python's float ``**`` with a negative base and a
    non-integral exponent produces a complex number, not a real one.
    """

    is_real = True

    @classmethod
    def eval(cls, base, exp):
        # sympy.Number (rather than sympy.Float) is tested on purpose:
        #   - operands may be sympy.oo or int_oo, which are not Floats but
        #     still coerce to a float infinity;
        #   - Float(0.0) sometimes decays to Integer(0), and that must still be
        #     accepted in float contexts.
        if not (isinstance(base, sympy.Number) and isinstance(exp, sympy.Number)):
            # NB: deliberately no nontrivial symbolic reasoning here.
            return None
        powered = float(base) ** float(exp)
        return sympy.Float(powered)
927
928
929# Overloaded to be compatible with regular Python.
930# https://github.com/pytorch/pytorch/issues/90900
931#
932# In particular, sympy division is willing to simplify x/x == 1
933# where 1 is an integer, but this must be a float if x was float.
class FloatTrueDiv(sympy.Function):
    """
    Python-compatible float true division.

    This does not apply sympy's own division simplifications (such as
    ``x / x`` -> integer 1).  Two numeric operands fold to a ``sympy.Float``;
    anything else stays unevaluated.  Raises ``ZeroDivisionError`` when the
    divisor is known to be zero.
    """

    is_real = True

    @classmethod
    def eval(cls, base, divisor):
        # assert base.is_integer is not True, base
        # assert divisor.is_integer is not True, divisor

        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")

        numeric = isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number)
        if not numeric:
            return None
        return sympy.Float(float(base) / float(divisor))
947
948
949# Overloaded to be compatible with regular Python.  We distinguish this from
950# FloatTrueDiv, because the code generation has to be different for this case:
951# Python has a fancy algorithm for integer true division that isn't just
952# "promote both arguments to float and use float division", so you need to
953# codegen it differently.  While technically you can work it out from the
954# types of the input, this is often inconvenient to do in Inductor codegen,
955# so just have a different operator
956# NB: Right now, Inductor codegen doesn't implement this correctly lol
class IntTrueDiv(sympy.Function):
    """
    Python-style true division of integers, producing a float.

    Concrete integer operands fold via Python's own ``int / int``, which is
    correctly rounded.  If either numeric operand is infinite, the result is
    computed with plain float division.  Raises ``ZeroDivisionError`` when the
    divisor is known to be zero.
    """

    is_real = True

    @classmethod
    def eval(cls, base, divisor):
        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")

        infinities = (int_oo, -int_oo, sympy.oo, -sympy.oo)
        if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number):
            if base in infinities or divisor in infinities:
                # Precision is irrelevant here: the quotient is zero or an
                # infinity anyway.
                return sympy.Float(float(base) / float(divisor))

        if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
            return sympy.Float(int(base) / int(divisor))

        return None
978
979
980# TODO: As an indicator, this != 0 implies == 1 (and vice versa).
981# Because we do not have the ability to guard on the stride permutation
982# at the moment, it is hard to make further inferences when this is true,
983# as although we know the tensor is contiguous in *some* layout, we don't
984# know which one (however, you could, for example, make the inference that
985# reshaping this to a 1D tensor can be guard-free.)
class IsNonOverlappingAndDenseIndicator(sympy.Function):
    """
    Integer indicator for whether a tensor layout is non-overlapping and dense.

    Arguments are ``*sizes, *strides``: the first half of ``args`` holds the
    sizes and the second half the strides, one of each per dimension.
    ``eval`` folds to the result of ``eval_is_non_overlapping_and_dense``
    whenever enough arguments are concrete integers, short-circuits some
    rank-1 cases to ``1``, and otherwise returns ``None`` (stays unevaluated).
    """

    is_integer = True

    @classmethod
    def eval(cls, *args):
        # args is sizes followed by strides, so its length must be even.
        assert len(args) % 2 == 0
        dim = len(args) // 2
        sizes = args[0:dim]
        strides = args[dim:]

        # sym_node imported in torch.__init__. Local import to avoid an import cycle
        from torch.fx.experimental.symbolic_shapes import (
            eval_is_non_overlapping_and_dense,
        )

        # Fully concrete: defer to the regular (non-symbolic) implementation.
        if all(isinstance(a, sympy.Integer) for a in args):
            return eval_is_non_overlapping_and_dense(
                [int(a) for a in sizes], [int(a) for a in strides]
            )

        if dim == 1:
            # Manually implement the rank one short circuit
            if strides[0].is_Number and strides[0] == 1:
                return 1

            if sizes[0].is_Number and sizes[0] < 2:
                return 1

            # return 0 case covered by case above

            # TODO: Inability to access size-obliviousness sucks: if we have a
            # size oblivious test on a size-like unbacked SymInt, we could
            # confidently return zero when we have a size-like u0 stride
            # and a size-like u1 size.  Maybe a fancy ValueRanges analysis for
            # this function could help figure this out.

        if all(isinstance(a, sympy.Integer) for a in strides):
            # dim == 0 cannot get here: with no args the all-integer check
            # above already returned.
            assert dim != 0
            # When all strides are integral, we can sort, and the size for the
            # largest stride doesn't matter and can be arbitrarily symbolic
            s_sizes, s_strides = zip(
                *sorted(zip(sizes, strides), key=operator.itemgetter(1))
            )
            # Put something arbitrary in the max size spot, it'll be ignored
            if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]):
                s_sizes = s_sizes[:-1] + (42,)
                # We can reuse the regular eval, because it is invariant to
                # permutation of dimensions
                return eval_is_non_overlapping_and_dense(
                    [int(a) for a in s_sizes], [int(a) for a in s_strides]
                )

        # Not enough concrete information: leave the indicator unevaluated.
        return None
1039
1040
# NB: unlike math.trunc in Python, this returns a float rather than an int
class TruncToFloat(sympy.Function):
    """
    Truncate toward zero while keeping the result a float.

    Finite numbers fold to ``sympy.Float``.  Infinities (``sympy.oo`` or
    ``int_oo``) map to the signed ``sympy.oo``, and NaN stays NaN.  Symbolic
    inputs stay unevaluated.
    """

    is_real = True

    @classmethod
    def eval(cls, number):
        # assert number.is_integer is not True, number

        # math.trunc raises on non-finite floats (OverflowError for inf,
        # ValueError for nan), so handle those before the Number branch.
        if number in (sympy.oo, int_oo):
            return sympy.oo
        if number in (-sympy.oo, -int_oo):
            return -sympy.oo
        if number is sympy.nan:
            return sympy.nan
        if isinstance(number, sympy.Number):
            # NB: It is safe to use truncation to integer, which is what
            # math.trunc does, as Python integers are arbitrary precision and
            # so we are guaranteed not to lose precision when we do this
            return sympy.Float(math.trunc(float(number)))
1053
1054
class TruncToInt(sympy.Function):
    """
    Truncate toward zero, producing an integer (like Python's ``math.trunc``).

    Infinities (``sympy.oo`` or ``int_oo``) saturate to ``int_oo`` with the
    matching sign.  Other numbers fold to ``sympy.Integer``, and symbolic
    inputs stay unevaluated.
    """

    is_integer = True

    @classmethod
    def eval(cls, number):
        # assert number.is_integer is not True, number
        positive_infinities = (sympy.oo, int_oo)
        negative_infinities = (-sympy.oo, -int_oo)
        if number in positive_infinities:
            return int_oo
        if number in negative_infinities:
            return -int_oo
        if not isinstance(number, sympy.Number):
            return None
        return sympy.Integer(math.trunc(float(number)))
1067
1068
1069# This is float -> int
class RoundToInt(sympy.Function):
    """
    Round to the nearest integer with Python ``round`` semantics.

    Ties go to the even neighbour, as with Python's ``round``.  Infinite
    inputs (``sympy.oo`` or ``int_oo``) saturate to ``int_oo`` with the
    matching sign.  Symbolic inputs stay unevaluated.
    """

    is_integer = True

    @classmethod
    def eval(cls, number):
        # assert number.is_integer is not True, number

        # Check both infinity flavours, as TruncToInt does.  Otherwise an
        # int_oo argument would fall through to the Number branch and hand a
        # non-finite value to sympy.Integer instead of saturating.
        if number in (sympy.oo, int_oo):
            return int_oo
        if number in (-sympy.oo, -int_oo):
            return -int_oo
        if isinstance(number, sympy.Number):
            return sympy.Integer(round(float(number), 0))
1083
1084
# For reference: Python-style float -> int rounding (single-argument round(),
# ties to even), as implemented in CPython's float.__round__:
1086#
1087#   x = PyFloat_AsDouble(self);
1088#   if (o_ndigits == Py_None) {
1089#       /* single-argument round or with None ndigits:
1090#        * round to nearest integer */
1091#       rounded = round(x);
1092#       if (fabs(x-rounded) == 0.5)
1093#           /* halfway case: round to even */
1094#           rounded = 2.0*round(x/2.0);
1095#       return PyLong_FromDouble(rounded);
1096#   }
1097
1098
# NB: Unlike RoundToInt, this only ever returns floats.  ndigits cannot be None
class RoundDecimal(sympy.Function):
    """
    Round to ``ndigits`` decimal places, like Python ``round(x, ndigits)``;
    the result is always a float.

    Folds only when ``number`` is a sympy number and ``ndigits`` is a sympy
    integer.  Otherwise the expression stays unevaluated.
    """

    is_real = True

    @classmethod
    def eval(cls, number, ndigits):
        # assert number.is_integer is not True, number

        if not isinstance(number, sympy.Number):
            return None
        if not isinstance(ndigits, sympy.Integer):
            return None
        places = int(ndigits)
        return sympy.Float(round(float(number), places))
1109
1110
class ToFloat(sympy.Function):
    """
    Convert an integer expression to a float.

    Float infinities pass through unchanged, integer infinities map to their
    float counterparts, and concrete integers fold to ``sympy.Float``.
    Anything else stays unevaluated.
    """

    is_real = True

    @classmethod
    def eval(cls, number):
        if number in [sympy.oo, -sympy.oo]:
            return number

        if isinstance(number, sympy.Integer):
            return sympy.Float(int(number))

        for int_infinity, float_infinity in ((int_oo, sympy.oo), (-int_oo, -sympy.oo)):
            if number is int_infinity:
                return float_infinity
        return None
1125
1126
class Identity(sympy.Function):
    """
    Opaque wrapper around its first argument that blocks expansion and other
    sympy optimizations.  Realness and integrality are taken from the wrapped
    expression.
    """

    def __repr__(self):
        wrapped = self.args[0]
        return f"Identity({wrapped})"

    def _eval_is_real(self):
        wrapped = self.args[0]
        return wrapped.is_real

    def _eval_is_integer(self):
        wrapped = self.args[0]
        return wrapped.is_integer  # type: ignore[attr-defined]
1140
1141
def make_opaque_unary_fn(name):
    """Create a sympy ``Function`` subclass that applies the unary op ``name``.

    The returned class is named ``OpaqueUnaryFn_<name>`` and only evaluates on
    concrete values.  ``math.<name>`` is used for ``sympy.Integer`` /
    ``sympy.Float`` arguments.  ``sympy.<name>`` is used for infinities, and
    when the ``math`` call overflows.  Any other argument is left unevaluated.
    Exceptions other than ``OverflowError`` raised by the ``math`` call
    propagate.
    """

    class OpaqueUnaryFn(sympy.Function):
        """
        Unlike the builtin sympy functions on real numbers like sympy.sqrt,
        these equivalents do not do any nontrivial reasoning besides
        constant propagation.  This helps avoid performing transformations
        that are valid for real numbers but are invalid for floating point;
        in particular, while we are willing to make optimizations that change
        numerics for Tensor compute, we are NOT willing to make optimizations
        that change numerics for size compute.
        """

        _torch_handler_name = name

        @classmethod
        def eval(cls, a):
            if isinstance(a, (sympy.Integer, sympy.Float)):
                # Python converts to float64 before computing, c.f.
                # >>> math.sin(2**53+1)
                # -0.848925964814655
                # >>> math.sin(float(2**53+1))
                # -0.848925964814655
                try:
                    return sympy.Float(getattr(math, name)(float(a)))
                except OverflowError:
                    # Fall back to sympy semantics on overflow; the result may
                    # be an odd object, but ask silly questions, get silly answers.
                    return getattr(sympy, name)(a)

            if a not in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo, int_oo, -int_oo]:
                # Symbolic (or otherwise non-foldable) argument: stay opaque.
                return None

            # Translate the integer infinities to sympy's ordinary infinities
            # before delegating to sympy.<name>.
            if a is int_oo:
                a = sympy.oo
            elif a is -int_oo:
                a = -sympy.oo
            return getattr(sympy, name)(a)

    OpaqueUnaryFn.__name__ = "OpaqueUnaryFn_" + name

    return OpaqueUnaryFn
1181
1182
# Keep in sync with math_op_names in torch/fx/experimental/sym_node.py
# One module-level OpaqueUnaryFn_<name> class per supported unary math op.
# Each constant-folds via math.<name> (falling back to sympy.<name> for
# infinities/overflow) and is otherwise left unevaluated by sympy.
# They are spelled out explicitly (not generated in a loop) so the names stay
# visible to static analysis and direct imports.
OpaqueUnaryFn_sqrt = make_opaque_unary_fn("sqrt")
OpaqueUnaryFn_cos = make_opaque_unary_fn("cos")
OpaqueUnaryFn_cosh = make_opaque_unary_fn("cosh")
OpaqueUnaryFn_sin = make_opaque_unary_fn("sin")
OpaqueUnaryFn_sinh = make_opaque_unary_fn("sinh")
OpaqueUnaryFn_tan = make_opaque_unary_fn("tan")
OpaqueUnaryFn_tanh = make_opaque_unary_fn("tanh")
OpaqueUnaryFn_asin = make_opaque_unary_fn("asin")
OpaqueUnaryFn_acos = make_opaque_unary_fn("acos")
OpaqueUnaryFn_atan = make_opaque_unary_fn("atan")
OpaqueUnaryFn_exp = make_opaque_unary_fn("exp")
OpaqueUnaryFn_log = make_opaque_unary_fn("log")
OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh")
1197