1# mypy: allow-untyped-defs 2import functools 3import math 4import operator 5import sys 6 7import sympy 8from sympy import S 9from sympy.core import sympify 10from sympy.core.expr import Expr 11from sympy.core.function import Application 12from sympy.core.logic import _torf, fuzzy_and, fuzzy_or 13from sympy.core.numbers import equal_valued 14from sympy.core.operations import LatticeOp, ShortCircuit 15from sympy.core.sorting import ordered 16from sympy.core.traversal import walk 17from sympy.utilities.iterables import sift 18 19from .numbers import int_oo 20 21 22# Portions of this file are adapted from the Sympy codebase, which was 23# licensed as follows: 24# 25# Copyright (c) 2006-2023 SymPy Development Team 26# 27# All rights reserved. 28# 29# Redistribution and use in source and binary forms, with or without 30# modification, are permitted provided that the following conditions are met: 31# 32# a. Redistributions of source code must retain the above copyright notice, 33# this list of conditions and the following disclaimer. 34# b. Redistributions in binary form must reproduce the above copyright 35# notice, this list of conditions and the following disclaimer in the 36# documentation and/or other materials provided with the distribution. 37# c. Neither the name of SymPy nor the names of its contributors 38# may be used to endorse or promote products derived from this software 39# without specific prior written permission. 40# 41# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 42# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 43# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 44# ARE DISCLAIMED. 
# IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.

__all__ = [
    "FloorDiv",
    "ModularIndexing",
    "Where",
    "PythonMod",
    "Mod",
    "CleanDiv",
    "CeilToInt",
    "FloorToInt",
    "CeilDiv",
    "IntTrueDiv",
    "FloatTrueDiv",
    "LShift",
    "RShift",
    "IsNonOverlappingAndDenseIndicator",
    "TruncToFloat",
    "TruncToInt",
    "RoundToInt",
    "RoundDecimal",
    "ToFloat",
    "FloatPow",
    "PowByNatural",
    "Identity",
]


def _keep_float(f):
    """Decorator: if any positional argument is a ``sympy.Float`` but the
    result of ``f`` is not, coerce the result to ``sympy.Float``."""

    @functools.wraps(f)
    def inner(*args):
        r = f(*args)
        if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
            r, sympy.Float
        ):
            r = sympy.Float(float(r))
        return r

    return inner


def fuzzy_eq(x, y):
    """Three-valued equality: ``None`` if either side is ``None``, else ``x == y``."""
    if None in (x, y):
        return None
    return x == y


def simple_floordiv_gcd(p, q):
    """
    Fast path for sympy.gcd, using a simple factoring strategy.

    We try to rewrite p and q in the form n*e*p1 + n*e*p2 and n*e*q0,
    where n is the greatest common integer factor and e is the largest
    syntactic common factor (i.e., common sub-expression) in p and q.
    Then the gcd returned is n*e, cancelling which we would be left with
    p1 + p2 and q0.

    Note that further factoring of p1 + p2 and q0 might be possible with
    sympy.factor (which uses domain-specific theories). E.g., we are unable
    to find that x*y + x + y + 1 is divisible by x + 1. More generally,
    when q is of the form q1 + q2 (instead of being already factored) it
    might be necessary to fall back on sympy.gcd.
    """

    def integer_coefficient(x):
        # Product of the absolute values of the integer factors of one term.
        integer_coefficients = [
            abs(int(arg))
            for arg in sympy.Mul.make_args(x)
            if isinstance(arg, (int, sympy.Integer))
        ]
        return math.prod(integer_coefficients)

    def integer_factor(expr):
        # gcd of the integer coefficients across all additive terms.
        integer_factors = map(integer_coefficient, sympy.Add.make_args(expr))
        return functools.reduce(math.gcd, integer_factors)

    gcd = math.gcd(integer_factor(p), integer_factor(q))
    p, q = p / gcd, q / gcd

    # A multiplicative factor of q is common if it syntactically appears in
    # every additive term of p.
    base_splits = list(map(sympy.Mul.make_args, sympy.Add.make_args(p)))
    divisor_split = sympy.Mul.make_args(q)
    for x in divisor_split:
        if all(x in base_split for base_split in base_splits):
            gcd = gcd * x
    return gcd


# It would be nice to have assertions on whether or not inputs is_integer
# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy
# sometimes inconsistently reports floats as integers.
#
# What we can assume from sympy is that if something is an int, it
# definitely is_integer, but if it is a float it may or may not
# be is_integer.  So we are unable to do strong asserts that things
# are NOT integers.


# TODO: In Triton, // rounds to zero, but in Python, it is floor division.
# When we can prove both arguments are non-negative, we should just have a
# GenericFloorDiv (name pending) which can codegen efficiently in Python/C,
# and then PythonFloorDiv and CIntDiv which have the appropriate rounding
# semantics.
#
# Right now, FloorDiv de facto changes behavior if arguments are negative or
# not, this can potentially cause correctness issues.
class FloorDiv(sympy.Function):
    """
    We maintain this so that:
    1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
    2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)

    NB: This is Python-style floor division, round to -Inf
    """

    nargs = (2,)
    precedence = 50  # precedence of mul  # noqa: F811

    is_integer = True

    @property
    def base(self):
        return self.args[0]

    @property
    def divisor(self):
        return self.args[1]

    def _sympystr(self, printer):
        base = printer.parenthesize(self.base, self.precedence)
        divisor = printer.parenthesize(self.divisor, self.precedence)
        return f"({base}//{divisor})"

    # Automatic evaluation.
    # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
    @classmethod
    def eval(cls, base, divisor):
        # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
        # Assert triggered by inequality solver
        # assert base.is_integer, base
        # assert divisor.is_integer, divisor

        # We don't provide the same error message as in Python because SymPy
        # makes it difficult to check the types.
        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")
        if base in (int_oo, -int_oo, sympy.oo, -sympy.oo) and divisor in (
            int_oo,
            -int_oo,
            sympy.oo,
            -sympy.oo,
        ):
            return sympy.nan
        if base is sympy.nan or divisor is sympy.nan:
            return sympy.nan

        if base.is_zero:
            return sympy.S.Zero
        if base.is_integer and equal_valued(divisor, 1):
            return base
        if base.is_integer and equal_valued(divisor, -1):
            return sympy.Mul(base, -1)
        if (
            isinstance(base, sympy.Number)
            and isinstance(divisor, sympy.Number)
            and (
                base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
                or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
            )
        ):
            # Exactly one side is infinite here (both-infinite returned nan
            # above), so float division gives +-inf, nan, or a finite value.
            r = float(base) / float(divisor)
            if r == math.inf:
                return int_oo
            elif r == -math.inf:
                return -int_oo
            elif math.isnan(r):
                return sympy.nan
            else:
                return sympy.Integer(math.floor(r))
        if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
            return sympy.Integer(int(base) // int(divisor))
        if isinstance(base, FloorDiv):
            # (a // b) // c == a // (b * c)
            return FloorDiv(base.args[0], base.args[1] * divisor)

        # Expands (x + y) // b into x // b + y // b.
        # This only works if floor is an identity, i.e. x / b is an integer.
        for term in sympy.Add.make_args(base):
            quotient = term / divisor
            if quotient.is_integer and isinstance(divisor, sympy.Integer):
                # NB: this is correct even if the divisor is not an integer, but it
                # creates rational expressions that cause problems with dynamic
                # shapes.
                return FloorDiv(base - term, divisor) + quotient

        try:
            gcd = simple_floordiv_gcd(base, divisor)
            if equal_valued(gcd, 1) and isinstance(divisor, sympy.Add):
                gcd = sympy.gcd(base, divisor)
            if not equal_valued(gcd, 1):
                return FloorDiv(
                    sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
                )
        except sympy.PolynomialError:
            pass  # https://github.com/pytorch/pytorch/issues/108276


class ModularIndexing(sympy.Function):
    """
    ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus
    """

    nargs = (3,)
    is_integer = True

    @classmethod
    def eval(cls, base, divisor, modulus):
        if base == 0 or modulus == 1:
            return sympy.Integer(0)

        if (
            isinstance(base, sympy.Integer)
            and isinstance(divisor, sympy.Integer)
            and isinstance(modulus, sympy.Integer)
        ):
            return (base // divisor) % modulus

        try:
            if divisor != 1:
                gcd = sympy.gcd(base, divisor)
                if gcd != 1:
                    return ModularIndexing(
                        sympy.simplify(base / gcd),
                        sympy.simplify(divisor / gcd),
                        modulus,
                    )
        except sympy.PolynomialError:
            pass  # https://github.com/pytorch/pytorch/issues/108276

        if isinstance(base, sympy.Add):
            # Drop additive terms that are multiples of modulus * divisor:
            # they cannot affect (base // divisor) % modulus.
            new_terms = []
            all_positive = True
            for term in base.args:
                if sympy.gcd(term, modulus * divisor) != modulus * divisor:
                    if (isinstance(term, sympy.Integer) and term < 0) or (
                        isinstance(term, sympy.Mul)
                        and isinstance(term.args[0], sympy.Integer)
                        and term.args[0] < 0
                    ):
                        # workaround for https://github.com/openai/triton/issues/619,
                        # if there are negative terms, // produces wrong result
                        # TODO if https://github.com/openai/triton/issues/619 is fixed
                        # this optimization would become valid
                        all_positive = False
                        break
                    else:
                        new_terms.append(term)

            if len(new_terms) != len(base.args) and all_positive:
                return ModularIndexing(sum(new_terms), divisor, modulus)

        if isinstance(base, FloorDiv):
            return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)

    def _eval_is_nonnegative(self):
        p, q = self.args[:2]
        return fuzzy_eq(p.is_nonnegative, q.is_nonnegative)  # type: ignore[attr-defined]

    def _eval_is_positive(self):
        p, q = self.args[:2]
        return fuzzy_eq(p.is_positive, q.is_positive)  # type: ignore[attr-defined]


class Where(sympy.Function):
    """
    Good ol' ternary operator
    """

    nargs = (3,)

    def _eval_is_integer(self):
        return True if self.args[1].is_integer and self.args[2].is_integer else None  # type: ignore[attr-defined]

    def _eval_is_nonnegative(self):
        return (
            True
            if self.args[1].is_nonnegative and self.args[2].is_nonnegative  # type: ignore[attr-defined]
            else None
        )

    def _eval_is_positive(self):
        return True if self.args[1].is_positive and self.args[2].is_positive else None  # type: ignore[attr-defined]

    @classmethod
    def eval(cls, c, p, q):
        if c == sympy.true:
            return p
        elif c == sympy.false:
            return q


# Python-style modulus: take sign from RHS
class PythonMod(sympy.Function):
    """Python-style ``p % q``: the result takes the sign of the divisor ``q``."""

    nargs = (2,)

    is_integer = True

    @classmethod
    def eval(cls, p, q):
        # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint
        # Triggered by sympy.solvers.inequalities.reduce_inequalities
        # assert p.is_integer, p
        # assert q.is_integer, q

        if q.is_zero:
            raise ZeroDivisionError("Modulo by zero")

        # Three cases:
        #   1. p == 0
        #   2. p is either q or -q
        #   3. p is integer and q == 1
        if p is S.Zero or p in (q, -q) or q == 1:
            return S.Zero

        # Evaluate if they are both literals.
        if q.is_Number and p.is_Number:
            return p % q

        # If q == 2, it's a matter of whether p is odd or even.
        if q.is_Number and q == 2:
            if p.is_even:
                return S.Zero
            if p.is_odd:
                return S.One

        # If p is a multiple of q.
        r = p / q
        if r.is_integer:
            return S.Zero

        # If p < q and its ratio is positive, then:
        #   - floor(p / q) = 0
        #   - p % q = p - floor(p / q) * q = p
        less = p < q
        if less.is_Boolean and bool(less) and r.is_positive:
            return p

        if sympy.Mod(p, q) == 0:
            return S.Zero

    # NB: args[1] for PythonMod
    def _eval_is_nonnegative(self):
        return True if self.args[1].is_positive else None  # type: ignore[attr-defined]

    def _eval_is_nonpositive(self):
        return True if self.args[1].is_negative else None  # type: ignore[attr-defined]


# Generic modulus: only defined on non-negative arguments
class Mod(sympy.Function):
    """Modulus for non-negative operands; literal evaluation asserts
    ``p >= 0`` and ``q >= 1``."""

    nargs = (2,)

    is_integer = True
    is_nonnegative = True

    @classmethod
    def eval(cls, p, q):
        # This was adapted from: sympy/core/mod.py

        # Triggered by
        # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
        # assert p.is_integer, p
        # assert q.is_integer, q

        if q.is_zero:
            raise ZeroDivisionError("Modulo by zero")

        # Three cases:
        #   1. p == 0
        #   2. p is either q or -q
        #   3. p is integer and q == 1
        if p is S.Zero or p in (q, -q) or q == 1:
            return S.Zero

        # Evaluate if they are both literals.
        if q.is_Number and p.is_Number:
            assert p >= 0, p
            assert q >= 1, q
            return p % q

        # If q == 2, it's a matter of whether p is odd or even.
        if q.is_Number and q == 2:
            if p.is_even:
                return S.Zero
            if p.is_odd:
                return S.One

        # If p is a multiple of q.
        r = p / q
        if r.is_integer:
            return S.Zero

        # If p < q and its ratio is positive, then:
        #   - floor(p / q) = 0
        #   - p % q = p - floor(p / q) * q = p
        less = p < q
        if less.is_Boolean and bool(less) and r.is_positive:
            return p


class CleanDiv(FloorDiv):
    """
    Div where we can assume no rounding.
    This is to enable future optimizations.
    """


# Don't use sympy ceiling/floor as they will attempt simplifications involving
# frac
class CeilToInt(sympy.Function):
    """Ceiling to an integer; +-oo / +-int_oo map to +-int_oo and numeric
    literals are folded."""

    is_integer = True

    @classmethod
    def eval(cls, number):
        # assert number.is_integer is not True, number
        if number in (sympy.oo, int_oo):
            return int_oo
        if number in (-sympy.oo, -int_oo):
            return -int_oo
        if isinstance(number, sympy.Number):
            return sympy.Integer(math.ceil(float(number)))


class FloorToInt(sympy.Function):
    """Floor to an integer; +-oo / +-int_oo map to +-int_oo and numeric
    literals are folded."""

    is_integer = True

    @classmethod
    def eval(cls, number):
        # assert number.is_integer is not True, number
        if number in (sympy.oo, int_oo):
            return int_oo
        # Must test -int_oo here (int_oo was handled above); otherwise -int_oo
        # falls through to math.floor(-inf), which raises OverflowError.
        if number in (-sympy.oo, -int_oo):
            return -int_oo
        if isinstance(number, sympy.Number):
            return sympy.Integer(math.floor(float(number)))


class CeilDiv(sympy.Function):
    """
    Div used in indexing that rounds up.
    """

    is_integer = True

    def __new__(cls, base, divisor):
        base = sympy.sympify(base)
        divisor = sympy.sympify(divisor)
        if sympy.gcd(base, divisor) == divisor:
            return CleanDiv(base, divisor)
        else:
            # ceil(a / b) == floor((a + b - 1) / b) for positive b
            return FloorDiv(base + (divisor - 1), divisor)


class LShift(sympy.Function):
    """``base << shift`` expressed as ``base * 2**shift``; a negative shift
    raises ValueError."""

    is_integer = True

    @classmethod
    def eval(cls, base, shift):
        if shift < 0:
            raise ValueError("negative shift count")
        return base * 2**shift


class RShift(sympy.Function):
    """``base >> shift`` expressed as ``base // 2**shift``; a negative shift
    raises ValueError."""

    is_integer = True

    @classmethod
    def eval(cls, base, shift):
        if shift < 0:
            raise ValueError("negative shift count")
        return base // 2**shift


class MinMaxBase(Expr, LatticeOp):  # type: ignore[misc]
    """Shared lattice-op machinery for ``Max`` and ``Min``: flattening,
    redundancy removal and local-zero detection over comparable arguments."""

    def __new__(cls, *args, **assumptions):
        from sympy.core.parameters import global_parameters

        evaluate = assumptions.pop("evaluate", global_parameters.evaluate)
        args = (sympify(arg) for arg in args)

        # first standard filter, for cls.zero and cls.identity
        # also reshape Max(a, Max(b, c)) to Max(a, b, c)

        if evaluate:
            try:
                args = frozenset(cls._new_args_filter(args))  # type: ignore[assignment]
            except ShortCircuit:
                return cls.zero  # type: ignore[attr-defined]
            # remove redundant args that are easily identified
            args = cls._collapse_arguments(args, **assumptions)
            # find local zeros
            args = cls._find_localzeros(args, **assumptions)
        args = frozenset(args)

        if not args:
            return cls.identity  # type: ignore[attr-defined]

        if len(args) == 1:
            return list(args).pop()

        # base creation
        obj = Expr.__new__(cls, *ordered(args), **assumptions)
        obj._argset = args
        return obj

    @classmethod
    def _collapse_arguments(cls, args, **assumptions):
        """Remove redundant args.

        Examples
        ========

        >>> from sympy import Min, Max
        >>> from sympy.abc import a, b, c, d, e

        Any arg in parent that appears in any
        parent-like function in any of the flat args
        of parent can be removed from that sub-arg:

        >>> Min(a, Max(b, Min(a, c, d)))
        Min(a, Max(b, Min(c, d)))

        If the arg of parent appears in an opposite-than parent
        function in any of the flat args of parent that function
        can be replaced with the arg:

        >>> Min(a, Max(b, Min(c, d, Max(a, e))))
        Min(a, Max(b, Min(a, c, d)))
        """
        if not args:
            return args
        args = list(ordered(args))
        if cls is Min:
            other = Max
        else:
            other = Min  # type: ignore[assignment]

        # find global comparable max of Max and min of Min if a new
        # value is being introduced in these args at position 0 of
        # the ordered args
        if args[0].is_number:
            sifted = mins, maxs = [], []  # type: ignore[var-annotated]
            for i in args:
                for v in walk(i, Min, Max):
                    if v.args[0].is_comparable:
                        sifted[isinstance(v, Max)].append(v)
            small = Min.identity
            for i in mins:
                v = i.args[0]
                if v.is_number and (v < small) == True:  # noqa: E712
                    small = v
            big = Max.identity
            for i in maxs:
                v = i.args[0]
                if v.is_number and (v > big) == True:  # noqa: E712
                    big = v
            # at the point when this function is called from __new__,
            # there may be more than one numeric arg present since
            # local zeros have not been handled yet, so look through
            # more than the first arg
            if cls is Min:
                for arg in args:
                    if not arg.is_number:
                        break
                    if (arg < small) == True:  # noqa: E712
                        small = arg
            elif cls == Max:
                for arg in args:
                    if not arg.is_number:
                        break
                    if (arg > big) == True:  # noqa: E712
                        big = arg
            T = None
            if cls is Min:
                if small != Min.identity:
                    other = Max
                    T = small
            elif big != Max.identity:
                other = Min  # type: ignore[assignment]
                T = big
            if T is not None:
                # remove numerical redundancy
                for i in range(len(args)):
                    a = args[i]
                    if isinstance(a, other):
                        a0 = a.args[0]
                        if (  # noqa: E712
                            (a0 > T) if other == Max else (a0 < T)  # noqa: E712
                        ) == True:  # noqa: E712
                            args[i] = cls.identity  # type: ignore[attr-defined]

        # remove redundant symbolic args
        def do(ai, a):
            if not isinstance(ai, (Min, Max)):
                return ai
            cond = a in ai.args
            if not cond:
                return ai.func(*[do(i, a) for i in ai.args], evaluate=False)
            if isinstance(ai, cls):
                return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False)
            return a

        for i, a in enumerate(args):
            args[i + 1 :] = [do(ai, a) for ai in args[i + 1 :]]

        # factor out common elements as for
        # Min(Max(x, y), Max(x, z)) -> Max(x, Min(y, z))
        # and vice versa when swapping Min/Max -- do this only for the
        # easy case where all functions contain something in common;
        # trying to find some optimal subset of args to modify takes
        # too long

        def factor_minmax(args):
            is_other = lambda arg: isinstance(arg, other)  # noqa: E731
            other_args, remaining_args = sift(args, is_other, binary=True)
            if not other_args:
                return args

            # Min(Max(x, y, z), Max(x, y, u, v)) -> {x,y}, ({z}, {u,v})
            arg_sets = [set(arg.args) for arg in other_args]
            common = set.intersection(*arg_sets)
            if not common:
                return args

            new_other_args = list(common)
            arg_sets_diff = [arg_set - common for arg_set in arg_sets]

            # If any set is empty after removing common then all can be
            # discarded e.g. Min(Max(a, b, c), Max(a, b)) -> Max(a, b)
            if all(arg_sets_diff):
                other_args_diff = [other(*s, evaluate=False) for s in arg_sets_diff]
                new_other_args.append(cls(*other_args_diff, evaluate=False))

            other_args_factored = other(*new_other_args, evaluate=False)
            return remaining_args + [other_args_factored]

        if len(args) > 1:
            args = factor_minmax(args)

        return args

    @classmethod
    def _new_args_filter(cls, arg_sequence):
        """
        Generator filtering args.

        first standard filter, for cls.zero and cls.identity.
        Also reshape ``Max(a, Max(b, c))`` to ``Max(a, b, c)``,
        and check arguments for comparability
        """
        for arg in arg_sequence:
            # pre-filter, checking comparability of arguments
            if (
                not isinstance(arg, Expr)
                or arg.is_extended_real is False
                or (arg.is_number and not arg.is_comparable)
            ):
                raise ValueError(f"The argument '{arg}' is not comparable.")

            if arg == cls.zero:  # type: ignore[attr-defined]
                raise ShortCircuit(arg)
            elif arg == cls.identity:  # type: ignore[attr-defined]
                continue
            elif arg.func == cls:
                yield from arg.args
            else:
                yield arg

    @classmethod
    def _find_localzeros(cls, values, **options):
        """
        Sequentially allocate values to localzeros.

        When a value is identified as being more extreme than another member it
        replaces that member; if this is never true, then the value is simply
        appended to the localzeros.
        """
        localzeros = set()  # type: ignore[var-annotated]
        for v in values:
            is_newzero = True
            localzeros_ = list(localzeros)
            for z in localzeros_:
                if id(v) == id(z):
                    is_newzero = False
                else:
                    con = cls._is_connected(v, z)
                    if con:
                        is_newzero = False
                        if con is True or con == cls:
                            localzeros.remove(z)
                            localzeros.update([v])
            if is_newzero:
                localzeros.update([v])
        return localzeros

    @classmethod
    def _is_connected(cls, x, y):
        """
        Check if x and y are connected somehow.
        """
        if x == y:
            return True
        t, f = Max, Min
        for op in "><":
            for j in range(2):
                try:
                    if op == ">":
                        v = x >= y
                    else:
                        v = x <= y
                except TypeError:
                    return False  # non-real arg
                if not v.is_Relational:
                    return t if v else f
                t, f = f, t  # type: ignore[assignment]
                x, y = y, x
            x, y = y, x  # run next pass with reversed order relative to start

        return False

    _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args)  # noqa: E731
    _eval_is_antihermitian = lambda s: _torf(  # noqa: E731
        i.is_antihermitian for i in s.args  # noqa: E731
    )  # noqa: E731
    _eval_is_commutative = lambda s: _torf(  # noqa: E731
        i.is_commutative for i in s.args  # noqa: E731
    )  # noqa: E731
    _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args)  # noqa: E731
    _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args)  # noqa: E731
    _eval_is_even = lambda s: _torf(i.is_even for i in s.args)  # noqa: E731
    _eval_is_finite = lambda s: _torf(i.is_finite for i in s.args)  # noqa: E731
    _eval_is_hermitian = lambda s: _torf(i.is_hermitian for i in s.args)  # noqa: E731
    _eval_is_imaginary = lambda s: _torf(i.is_imaginary for i in s.args)  # noqa: E731
    _eval_is_infinite = lambda s: _torf(i.is_infinite for i in s.args)  # noqa: E731
    _eval_is_integer = lambda s: _torf(i.is_integer for i in s.args)  # noqa: E731
    _eval_is_irrational = lambda s: _torf(i.is_irrational for i in s.args)  # noqa: E731
    _eval_is_negative = lambda s: _torf(i.is_negative for i in s.args)  # noqa: E731
    _eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args)  # noqa: E731
    _eval_is_nonnegative = lambda s: _torf(  # noqa: E731
        i.is_nonnegative for i in s.args  # noqa: E731
    )  # noqa: E731
    _eval_is_nonpositive = lambda s: _torf(  # noqa: E731
        i.is_nonpositive for i in s.args  # noqa: E731
    )  # noqa: E731
    _eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args)  # noqa: E731
    _eval_is_odd = lambda s: _torf(i.is_odd for i in s.args)  # noqa: E731
    _eval_is_polar = lambda s: _torf(i.is_polar for i in s.args)  # noqa: E731
    _eval_is_positive = lambda s: _torf(i.is_positive for i in s.args)  # noqa: E731
    _eval_is_prime = lambda s: _torf(i.is_prime for i in s.args)  # noqa: E731
    _eval_is_rational = lambda s: _torf(i.is_rational for i in s.args)  # noqa: E731
    _eval_is_real = lambda s: _torf(i.is_real for i in s.args)  # noqa: E731
    _eval_is_extended_real = lambda s: _torf(  # noqa: E731
        i.is_extended_real for i in s.args  # noqa: E731
    )  # noqa: E731
    _eval_is_transcendental = lambda s: _torf(  # noqa: E731
        i.is_transcendental for i in s.args  # noqa: E731
    )  # noqa: E731
    _eval_is_zero = lambda s: _torf(i.is_zero for i in s.args)  # noqa: E731


class Max(MinMaxBase, Application):  # type: ignore[misc]
    r"""
    Return, if possible, the maximum value of the list.
    """
    zero = S.Infinity
    identity = S.NegativeInfinity

    def _eval_is_positive(self):
        return fuzzy_or(a.is_positive for a in self.args)  # type: ignore[attr-defined]

    def _eval_is_nonnegative(self):
        return fuzzy_or(a.is_nonnegative for a in self.args)  # type: ignore[attr-defined]

    def _eval_is_negative(self):
        return fuzzy_and(a.is_negative for a in self.args)


class Min(MinMaxBase, Application):  # type: ignore[misc]
    """
    Return, if possible, the minimum value of the list.
    """

    zero = S.NegativeInfinity
    identity = S.Infinity

    def _eval_is_positive(self):
        return fuzzy_and(a.is_positive for a in self.args)  # type: ignore[attr-defined]

    def _eval_is_nonnegative(self):
        return fuzzy_and(a.is_nonnegative for a in self.args)  # type: ignore[attr-defined]

    def _eval_is_negative(self):
        return fuzzy_or(a.is_negative for a in self.args)


def safe_pow(base, exp):
    """``base ** exp`` for a non-negative integer ``exp`` that saturates to
    +-int_oo instead of materializing integers larger than ``sys.maxsize``.

    A negative base is handled by computing on ``-base`` and restoring the
    sign from the parity of ``exp``.
    """
    sign = 1
    if base < 0:
        base = -base
        sign = 1 if exp % 2 == 0 else -1
    return sign * _safe_pow(base, exp)


# Prevent people from overflowing pow
def _safe_pow(base, exponent):
    """Square-and-multiply power for a non-negative ``base``.

    Returns ``int_oo`` as soon as an intermediate result exceeds
    ``sys.maxsize``.  Raises ValueError for a negative ``exponent``.
    """
    if exponent < 0:
        raise ValueError("Exponent must be non-negative.")

    if exponent == 0:
        return 1

    half_exp = safe_pow(base, exponent // 2)
    if half_exp is int_oo:
        return int_oo

    # TODO: microoptimization is to avoid overflowing into arbitrary precision
    # and detect overflow prior to doing operations

    result = half_exp * half_exp
    if result > sys.maxsize:
        return int_oo

    if exponent % 2 == 1:
        result *= base
        if result > sys.maxsize:
            return int_oo

    return result


class PowByNatural(sympy.Function):
    """``base ** exp`` where ``exp`` is known to be a natural number.

    Kept as its own function (rather than sympy.Pow) for non-literal
    exponents so that knowledge is not lost.
    """

    is_integer = True

    @classmethod
    def eval(cls, base, exp):
        if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
            r = safe_pow(base, exp)
            if r in (-int_oo, int_oo):
                return r
            return sympy.Integer(r)
        if isinstance(exp, sympy.Integer):
            # Rely on regular sympy Pow for this (note that iterated
            # multiplication turns into a Pow anyway, you can't escape!!)
            return sympy.Pow(base, exp)
        if exp in (int_oo, sympy.oo):
            if base.is_nonnegative:
                return int_oo
            elif base.is_negative:
                return sympy.zoo  # this is apparently what (-2)**sympy.oo does
        # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp
        # is a natural number if we do


# base is assumed to be nonnegative, thereby prevent complex numbers from
# occurring
class FloatPow(sympy.Function):
    """Float ``base ** exp``; only constant-folds numeric literals."""

    is_real = True

    @classmethod
    def eval(cls, base, exp):
        # NB: These test sympy.Number, not sympy.Float, because:
        #   - Sometimes we may have sympy.oo or int_oo, and that's not a Float
        #     (but coerces to math.Inf)
        #   - Sometimes Float(0.0) will unpredictably decay to Integer(0),
        #     but we should still accept it in floatey contexts
        if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
            return sympy.Float(float(base) ** float(exp))
        # NB: do not do any nontrivial reasoning


# Overloaded to be compatible with regular Python.
# https://github.com/pytorch/pytorch/issues/90900
#
# In particular, sympy division is willing to simplify x/x == 1
# where 1 is an integer, but this must be a float if x was float.
class FloatTrueDiv(sympy.Function):
    """Float true division; numeric literals fold to a ``sympy.Float``."""

    is_real = True

    @classmethod
    def eval(cls, base, divisor):
        # assert base.is_integer is not True, base
        # assert divisor.is_integer is not True, divisor

        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")

        if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number):
            return sympy.Float(float(base) / float(divisor))


# Overloaded to be compatible with regular Python.
# We distinguish this from FloatTrueDiv, because the code generation has to be
# different for this case: Python has a fancy algorithm for integer true
# division that isn't just "promote both arguments to float and use float
# division", so you need to codegen it differently.  While technically you can
# work it out from the types of the input, this is often inconvenient to do in
# Inductor codegen, so just have a different operator
# NB: Right now, Inductor codegen doesn't implement this correctly lol
class IntTrueDiv(sympy.Function):
    """Python ``int / int`` true division; literal operands fold to a
    ``sympy.Float``."""

    is_real = True

    @classmethod
    def eval(cls, base, divisor):
        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")

        if (
            isinstance(base, sympy.Number)
            and isinstance(divisor, sympy.Number)
            and (
                base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
                or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
            )
        ):
            # Don't have to worry about precision here, you're getting zero or
            # inf from the division
            return sympy.Float(float(base) / float(divisor))
        if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
            # int / int uses Python's correctly-rounded integer true division
            return sympy.Float(int(base) / int(divisor))


# TODO: As an indicator, this != 0 implies == 1 (and vice versa).
# Because we do not have the ability to guard on the stride permutation
# at the moment, it is hard to make further inferences when this is true,
# as although we know the tensor is contiguous in *some* layout, we don't
# know which one (however, you could, for example, make the inference that
# reshaping this to a 1D tensor can be guard-free.)
class IsNonOverlappingAndDenseIndicator(sympy.Function):
    """Indicator over ``(*sizes, *strides)``: the first half of ``args`` are
    sizes, the second half strides.  Evaluates only when enough of them are
    literal integers."""

    is_integer = True

    @classmethod
    def eval(cls, *args):
        assert len(args) % 2 == 0
        dim = len(args) // 2
        sizes = args[0:dim]
        strides = args[dim:]

        # sym_node imported in torch.__init__. Local import to avoid an import cycle
        from torch.fx.experimental.symbolic_shapes import (
            eval_is_non_overlapping_and_dense,
        )

        if all(isinstance(a, sympy.Integer) for a in args):
            return eval_is_non_overlapping_and_dense(
                [int(a) for a in sizes], [int(a) for a in strides]
            )

        if dim == 1:
            # Manually implement the rank one short circuit
            if strides[0].is_Number and strides[0] == 1:
                return 1

            if sizes[0].is_Number and sizes[0] < 2:
                return 1

            # return 0 case covered by case above

            # TODO: Inability to access size-obliviousness sucks: if we have a
            # size oblivious test on a size-like unbacked SymInt, we could
            # confidently return zero when we have a size-like u0 stride
            # and a size-like u1 size.  Maybe a fancy ValueRanges analysis for
            # this function could help figure this out.

        if all(isinstance(a, sympy.Integer) for a in strides):
            assert dim != 0
            # When all strides are integral, we can sort, and the size for the
            # largest stride doesn't matter and can be arbitrarily symbolic
            s_sizes, s_strides = zip(
                *sorted(zip(sizes, strides), key=operator.itemgetter(1))
            )
            # Put something arbitrary in the max size spot, it'll be ignored
            if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]):
                s_sizes = s_sizes[:-1] + (42,)
                # We can reuse the regular eval, because it is invariant to
                # permutation of dimensions
                return eval_is_non_overlapping_and_dense(
                    [int(a) for a in s_sizes], [int(a) for a in s_strides]
                )

        return None


# NB: this is inconsistent with math.trunc in Python
class TruncToFloat(sympy.Function):
    """Truncate toward zero, producing a float.  Numeric literals are folded;
    +-inf and nan pass through unchanged (as IEEE ``trunc`` does)."""

    is_real = True

    @classmethod
    def eval(cls, number):
        # assert number.is_integer is not True, number
        if isinstance(number, sympy.Number):
            value = float(number)
            # math.trunc raises on inf (OverflowError) and nan (ValueError);
            # float truncation leaves those values unchanged.
            if math.isinf(value) or math.isnan(value):
                return sympy.Float(value)
            # NB: It is safe to use truncation to integer, which is what
            # math.trunc does, as Python integers are arbitrary precision and
            # so we are guaranteed not to lose precision when we do this
            return sympy.Float(math.trunc(value))


class TruncToInt(sympy.Function):
    """Truncate toward zero to an integer; +-oo / +-int_oo map to +-int_oo."""

    is_integer = True

    @classmethod
    def eval(cls, number):
        # assert number.is_integer is not True, number
        if number in (sympy.oo, int_oo):
            return int_oo
        if number in (-sympy.oo, -int_oo):
            return -int_oo
        if isinstance(number, sympy.Number):
            return sympy.Integer(math.trunc(float(number)))


# This is float -> int
class RoundToInt(sympy.Function):
    """Python ``round(x)`` (half-to-even) producing an integer; +-oo and
    +-int_oo map to +-int_oo."""

    is_integer = True

    @classmethod
    def eval(cls, number):
        # assert number.is_integer is not True, number

        # Also accept int_oo (like CeilToInt/FloorToInt/TruncToInt); otherwise
        # it reaches sympy.Integer(round(inf, 0)) below, which raises.
        if number in (sympy.oo, int_oo):
            return int_oo
        if number in (-sympy.oo, -int_oo):
            return -int_oo
        if isinstance(number, sympy.Number):
            return sympy.Integer(round(float(number), 0))


# To get float -> int, Python style round semantics.
#
#   x = PyFloat_AsDouble(self);
#   if (o_ndigits == Py_None) {
#       /* single-argument round or with None ndigits:
#        * round to nearest integer */
#       rounded = round(x);
#       if (fabs(x-rounded) == 0.5)
#           /* halfway case: round to even */
#           rounded = 2.0*round(x/2.0);
#       return PyLong_FromDouble(rounded);
#   }


# NB: Like Round, this only ever returns floats.
# ndigits cannot be None.
class RoundDecimal(sympy.Function):
    """Round ``number`` to ``ndigits`` decimal places, Python-style, as a float.

    Folds only when ``number`` is a sympy Number and ``ndigits`` is a sympy
    Integer; the folded value is ``round(float(number), int(ndigits))`` wrapped
    in a sympy Float.
    """

    is_real = True

    @classmethod
    def eval(cls, number, ndigits):
        # assert number.is_integer is not True, number

        if isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer):
            return sympy.Float(round(float(number), int(ndigits)))


class ToFloat(sympy.Function):
    """Convert an integer-valued expression to a float.

    Folds concrete sympy Integers to a sympy Float.  sympy.oo / -sympy.oo are
    returned unchanged and int_oo / -int_oo map to sympy.oo / -sympy.oo.
    Anything else stays unevaluated.
    """

    is_real = True

    @classmethod
    def eval(cls, number):
        if number in [sympy.oo, -sympy.oo]:
            return number

        if isinstance(number, sympy.Integer):
            return sympy.Float(int(number))
        if number is int_oo:
            return sympy.oo
        if number is -int_oo:
            return -sympy.oo


class Identity(sympy.Function):
    """
    Prevents expansion and other optimizations.

    Wraps a single expression; its realness and integrality are forwarded
    from the wrapped argument.
    """

    def __repr__(self):
        return f"Identity({self.args[0]})"

    def _eval_is_real(self):
        return self.args[0].is_real

    def _eval_is_integer(self):
        return self.args[0].is_integer  # type: ignore[attr-defined]


def make_opaque_unary_fn(name):
    """Create an opaque sympy Function class for the unary function ``name``.

    ``name`` is looked up on :mod:`math` (constant folding of concrete
    Integer/Float arguments) and on :mod:`sympy` (infinities, and arguments
    whose ``math`` evaluation overflows).  The returned class is named
    ``OpaqueUnaryFn_<name>``.
    """

    class OpaqueUnaryFn(sympy.Function):
        """
        Unlike the builtin sympy functions on real numbers like sympy.sqrt,
        these equivalents do not do any nontrivial reasoning besides
        constant propagation. This helps avoid performing transformations
        that are valid for real numbers but are invalid for floating point;
        in particular, while we are willing to make optimizations that change
        numerics for Tensor compute, we are NOT willing to make optimizations
        that change numerics for size compute.
        """

        _torch_handler_name = name

        @classmethod
        def eval(cls, a):
            if isinstance(a, (sympy.Integer, sympy.Float)):
                # Python converts to float64 before computing, c.f.
                # >>> math.sin(2**53+1)
                # -0.848925964814655
                # >>> math.sin(float(2**53+1))
                # -0.848925964814655
                try:
                    return sympy.Float(getattr(math, name)(float(a)))
                # Just use sympy semantics for infinity/overflow, you might get
                # some weird objects but ask silly questions, get silly answers.
                # NB: math domain errors (ValueError, e.g. sqrt of a negative)
                # are deliberately not caught here and propagate to the caller.
                except OverflowError:
                    return getattr(sympy, name)(a)
            elif a in (sympy.oo, -sympy.oo, sympy.zoo, int_oo, -int_oo):
                # sympy does not know about int_oo; use the float infinities.
                if a is int_oo:
                    a = sympy.oo
                if a is -int_oo:
                    a = -sympy.oo
                return getattr(sympy, name)(a)
            return None

    # Without the __qualname__ override the class would be known as
    # "make_opaque_unary_fn.<locals>.OpaqueUnaryFn", which pickle cannot look
    # up.  Using the module-level name lets the instances bound below (e.g.
    # OpaqueUnaryFn_sqrt) be pickled by reference.
    OpaqueUnaryFn.__name__ = "OpaqueUnaryFn_" + name
    OpaqueUnaryFn.__qualname__ = "OpaqueUnaryFn_" + name

    return OpaqueUnaryFn


# Keep in sync with math_op_names in torch/fx/experimental/sym_node.py
OpaqueUnaryFn_sqrt = make_opaque_unary_fn("sqrt")
OpaqueUnaryFn_cos = make_opaque_unary_fn("cos")
OpaqueUnaryFn_cosh = make_opaque_unary_fn("cosh")
OpaqueUnaryFn_sin = make_opaque_unary_fn("sin")
OpaqueUnaryFn_sinh = make_opaque_unary_fn("sinh")
OpaqueUnaryFn_tan = make_opaque_unary_fn("tan")
OpaqueUnaryFn_tanh = make_opaque_unary_fn("tanh")
OpaqueUnaryFn_asin = make_opaque_unary_fn("asin")
OpaqueUnaryFn_acos = make_opaque_unary_fn("acos")
OpaqueUnaryFn_atan = make_opaque_unary_fn("atan")
OpaqueUnaryFn_exp = make_opaque_unary_fn("exp")
OpaqueUnaryFn_log = make_opaque_unary_fn("log")
OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh")