1# mypy: allow-untyped-defs 2""" 3This file does three things: 4- Contains the definition of SymNode 5- Installs all the magic methods into SymBool, SymFloat, SymFloat at import time 6- Does not depend on sympy at import time 7 8As this file is imported from within torch/__init__.py we do not want it to depend on SymPy 9to avoid having to load SymPy at import time, as doing so is *very* slow. 10""" 11 12import builtins 13import itertools 14import logging 15import math 16import operator 17import sys 18from functools import lru_cache, update_wrapper 19from typing import Optional, Type, TYPE_CHECKING, Union 20 21import torch 22 23# NB: The sym_* functions are used via getattr() and must be imported here. 24from torch import ( # noqa: F401 25 sym_float, 26 sym_ite, 27 sym_max, 28 sym_min, 29 sym_not, 30 SymBool, 31 SymFloat, 32 SymInt, 33) 34 35 36if TYPE_CHECKING: 37 from torch.fx.experimental.symbolic_shapes import ShapeEnv 38 39log = logging.getLogger(__name__) 40sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node") 41 42 43__all__ = ["SymNode", "method_to_operator", "magic_methods"] 44 45 46from torch.types import py_sym_types as SymTypes 47 48 49def _to_symtype(t): 50 if t is bool: 51 return SymBool 52 if t is int: 53 return SymInt 54 if t is float: 55 return SymFloat 56 return t 57 58 59# TODO: An incomplete list 60# 1. Set variables to be equal when we do equality 61# 2. Specialize on 0/1 when we do subtraction 62class SymNode: 63 """ 64 This is a type erased SymInt/SymFloat which we use to do actual operations. 65 End users don't touch this. Magic methods are NOT defined on this object. 66 """ 67 68 def __init__( 69 self, 70 expr, 71 shape_env, 72 pytype, 73 hint: Optional[Union[int, float, bool]], 74 constant=None, 75 fx_node=None, 76 ): 77 self._expr = expr 78 self.shape_env = shape_env 79 self.pytype = pytype 80 81 # What's the difference between hint and constant? 82 # 83 # - A constant is known to be invariant across invocations of the model; 84 # it will always be this value. We only really know this when we 85 # encounter an honest-to-goodness literal (when wrapping it into 86 # a SymNode, we set constant.) Most of the time, constant is None 87 # 88 # - A hint is a *particular* value from the particular run we are 89 # tracing, but it may vary the next time around. It's useful to 90 # keep this around, as if we need a concrete value from a SymNode, 91 # we will return the hint and guard on the expression that produced 92 # it giving the same hint next time around. The hint is not 93 # guaranteed to be set either: if you have an unbacked SymNode, 94 # there won't be any hint; it was the result of some tensor-dependent 95 # computation, but we don't know what it actually is because we 96 # haven't actually run the tensor computation. 97 # 98 # If _hint is None, we will query maybe_evaluate_static(compute_hint=True) 99 # in hopes that we've learned enough about the unbacked symints to 100 # discharge the hint; otherwise, you're likely to just error out. 101 # 102 # (A previous version of this system had some optimizations to only 103 # recompute when it was possible we had learned enough about the 104 # unbacked symint that a hint was now possible, but as we added more 105 # potential refinements to unbacked symints this got harder to keep 106 # in sync, so we've deleted it for now.) 107 108 def compute_hint(): 109 from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols 110 111 # This occasionally gets exercised by, e.g., 112 # convert_shape_to_symint. It's just a nicety so you don't HAVE 113 # to have a correct hint on hand when making a SymNode. 114 # Don't attempt to compute for unbacked, this can be quite 115 # expensive. 116 if free_unbacked_symbols(self.expr): 117 return None 118 hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True) 119 if hint is not None: 120 hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint 121 return hint 122 123 if hint is not None: 124 assert type(hint) is pytype or type(hint) is _to_symtype(pytype), ( 125 "Cannot create SymNode of type " 126 f"{pytype} with incompatible hint of type {type(hint)}" 127 ) 128 if self.shape_env and self.shape_env._translation_validation_enabled: 129 # This is technically not TV, but this assert is expensive so 130 # let's only do it when we're already doing expensive things 131 computed_hint = compute_hint() 132 assert ( 133 hint == computed_hint 134 ), f"{hint} != {computed_hint} (for {self.expr})" 135 else: 136 hint = compute_hint() 137 self._hint = hint 138 self.constant: Optional[Union[int, float, bool]] = constant 139 140 # Record the FX node of the current node if we are doing translation 141 # validation. They will be used for building the input assertions for 142 # the translation validation problem. 143 tx_validation_en = ( 144 self.shape_env and self.shape_env._translation_validation_enabled 145 ) 146 self.fx_node = tx_validation_en and fx_node 147 148 def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode": 149 return SymNode( 150 self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node 151 ) 152 153 def _value_eq(self, other: "SymNode") -> bool: 154 # Purposely don't include the shape_env in the eq. 155 return ( 156 self._expr == other._expr 157 and self.pytype == other.pytype 158 and self._hint == other._hint 159 and self.constant == other.constant 160 and self.fx_node == other.fx_node 161 ) 162 163 def _value_hash(self) -> int: 164 # Purposely don't include the shape_env in the hash. 165 return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node)) 166 167 @property 168 def expr(self): 169 return self.shape_env.replace(self._expr) 170 171 @property 172 def hint(self): 173 return self._hint 174 175 def has_hint(self): 176 return self._hint is not None 177 178 def require_hint(self, fallback=None): 179 if self._hint is None: 180 if fallback is not None: 181 return fallback 182 # NB: we expect this to raise 183 return self.shape_env.size_hint(self.expr) 184 return self._hint 185 186 def maybe_as_int(self): 187 if self.expr.is_number: 188 return int(self.expr) 189 else: 190 return None 191 192 # NB: This does conversions, not sure if this is good or not 193 def maybe_as_float(self): 194 import sympy 195 196 if isinstance(self.expr, sympy.Float): 197 return float(self.expr) 198 else: 199 return None 200 201 def maybe_as_bool(self): 202 import sympy 203 204 if self.expr is sympy.true: 205 return True 206 elif self.expr is sympy.false: 207 return False 208 else: 209 return None 210 211 def is_int(self): 212 return self.pytype is int 213 214 def is_float(self): 215 return self.pytype is float 216 217 def is_bool(self): 218 return self.pytype is bool 219 220 def is_nested_int(self): 221 # Unbacked SymInts cannot be nested int today 222 return ( 223 self._hint is not None 224 and isinstance(self._hint, SymInt) 225 and self._hint.node.is_nested_int() 226 ) 227 228 def wrap_int(self, num): 229 assert type(num) is int 230 import sympy 231 232 return SymNode( 233 sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num 234 ) 235 236 def wrap_float(self, num): 237 assert type(num) is float 238 import sympy 239 240 return SymNode( 241 sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num 242 ) 243 244 def wrap_bool(self, num): 245 assert type(num) is bool 246 import sympy 247 248 return SymNode( 249 sympy.true if num else sympy.false, 250 self.shape_env, 251 bool, 252 num, 253 constant=num, 254 fx_node=num, 255 ) 256 257 def clone(self): 258 return self 259 260 def str(self): 261 return f"{self.expr}" 262 263 def __str__(self): 264 return self.str() 265 266 def __repr__(self): 267 rep = [ 268 f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}", 269 ] 270 if self._hint is not None: 271 rep.append(f"hint={self._hint}") 272 if self.constant is not None: 273 rep.append(f"constant={self.constant}") 274 if self.fx_node is not None: 275 rep.append(f"fx_node={self.fx_node}") 276 return ", ".join(rep) + ")" 277 278 def _graph_repr(self) -> builtins.str: 279 # Representation used by GraphModule to create a pythonic version of a graph 280 return self.str() 281 282 # These methods call the metaprogrammed methods, they're hand written 283 # here so we get good stack traces 284 def abs(self) -> "SymNode": 285 return self._abs() # type: ignore[attr-defined] 286 287 def pos(self) -> "SymNode": 288 return self._pos() # type: ignore[attr-defined] 289 290 def round(self, ndigits=None) -> "SymNode": 291 return self._round(ndigits) # type: ignore[attr-defined] 292 293 def trunc(self) -> "SymNode": 294 return self._trunc() # type: ignore[attr-defined] 295 296 def add(self, other) -> "SymNode": 297 return self._add(other) # type: ignore[attr-defined] 298 299 def sub(self, other) -> "SymNode": 300 return self._sub(other) # type: ignore[attr-defined] 301 302 def mul(self, other) -> "SymNode": 303 return self._mul(other) # type: ignore[attr-defined] 304 305 def mod(self, other) -> "SymNode": 306 return self._mod(other) # type: ignore[attr-defined] 307 308 def float_pow(self, other) -> "SymNode": 309 return self._float_pow(other) # type: ignore[attr-defined] 310 311 def pow_by_natural(self, other) -> "SymNode": 312 return self._pow_by_natural(other) # type: ignore[attr-defined] 313 314 def and_(self, other) -> "SymNode": 315 return self._and_(other) # type: ignore[attr-defined] 316 317 def or_(self, other) -> "SymNode": 318 return self._or_(other) # type: ignore[attr-defined] 319 320 def float_truediv(self, other) -> "SymNode": 321 return self._float_truediv(other) # type: ignore[attr-defined] 322 323 def int_truediv(self, other) -> "SymNode": 324 return self._int_truediv(other) # type: ignore[attr-defined] 325 326 def int_floordiv(self, other) -> "SymNode": 327 return self._int_floordiv(other) # type: ignore[attr-defined] 328 329 def lshift(self, other) -> "SymNode": 330 return self._lshift(other) # type: ignore[attr-defined] 331 332 def rshift(self, other) -> "SymNode": 333 return self._rshift(other) # type: ignore[attr-defined] 334 335 def sym_not(self) -> "SymNode": # noqa: F811 336 return self._sym_not() # type: ignore[attr-defined] 337 338 def eq(self, other) -> "SymNode": 339 return self._eq(other) # type: ignore[attr-defined] 340 341 def ne(self, other) -> "SymNode": 342 return self._ne(other) # type: ignore[attr-defined] 343 344 def gt(self, other) -> "SymNode": 345 return self._gt(other) # type: ignore[attr-defined] 346 347 def lt(self, other) -> "SymNode": 348 return self._lt(other) # type: ignore[attr-defined] 349 350 def le(self, other) -> "SymNode": 351 return self._le(other) # type: ignore[attr-defined] 352 353 def ge(self, other) -> "SymNode": 354 return self._ge(other) # type: ignore[attr-defined] 355 356 def floor(self) -> "SymNode": 357 return self._floor() # type: ignore[attr-defined] 358 359 def is_integer(self) -> "SymNode": 360 return self._is_integer() # type: ignore[attr-defined] 361 362 def sym_float(self) -> "SymNode": # noqa: F811 363 return self._sym_float() # type: ignore[attr-defined] 364 365 def sym_int(self) -> "SymNode": 366 return self._sym_int() # type: ignore[attr-defined] 367 368 def ceil(self) -> "SymNode": 369 return self._ceil() # type: ignore[attr-defined] 370 371 def neg(self) -> "SymNode": 372 return self._neg() # type: ignore[attr-defined] 373 374 def sym_min(self, other) -> "SymNode": # noqa: F811 375 return self._sym_min(other) # type: ignore[attr-defined] 376 377 def sym_max(self, other) -> "SymNode": # noqa: F811 378 return self._sym_max(other) # type: ignore[attr-defined] 379 380 def sym_ite(self, then_val, else_val) -> "SymNode": 381 return self._sym_ite(then_val, else_val) # type: ignore[attr-defined] 382 383 def is_contiguous(self, sizes, strides) -> "SymNode": 384 return self._is_contiguous(sizes, strides) # type: ignore[attr-defined] 385 386 def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode": 387 return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined] 388 389 def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode": 390 return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined] 391 392 def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode": 393 return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined] 394 395 def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode": 396 return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined] 397 398 def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode": 399 return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined] 400 401 # Make C++ happy 402 def sym_or(self, other): 403 return self.or_(other) 404 405 def sym_and(self, other): 406 return self.and_(other) 407 408 # There is no int_truediv available from C++ 409 def truediv(self, other): 410 return self.float_truediv(other) 411 412 def floordiv(self, other) -> "SymNode": 413 return self.int_floordiv(other) 414 415 # We didn't bind integer pow in C++ 416 def pow(self, other): 417 return self.float_pow(other) 418 419 def is_non_overlapping_and_dense(self, sizes, strides): 420 return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] 421 422 def int_(self): 423 return self.guard_int("", 0) # NB: uses Python backtrace 424 425 # You can manually trigger a guard with this function 426 def guard_int(self, file, line): 427 # TODO: use the file/line for some useful diagnostic on why a 428 # guard occurred 429 r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) 430 try: 431 return int(r) 432 except Exception: 433 log.warning("Failed to convert to int: %s", r) 434 raise 435 436 def guard_float(self, file, line): 437 # TODO: use the file/line for some useful diagnostic on why a 438 # guard occurred 439 r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) 440 try: 441 return float(r) 442 except Exception: 443 log.warning("Failed to convert to float: %s", r) 444 raise 445 446 def guard_bool(self, file, line): 447 # TODO: use the file/line for some useful diagnostic on why a 448 # guard occurred 449 r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) 450 try: 451 return bool(r) 452 except Exception: 453 log.warning("Failed to convert to bool: %s", r) 454 raise 455 456 def expect_true(self, file, line): 457 from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols 458 459 if ( 460 self.has_hint() 461 and not free_unbacked_symbols(self.expr) 462 and not self.shape_env.prefer_deferred_runtime_asserts_over_guards 463 ): 464 # OK to generate guards 465 return self.guard_bool(file, line) 466 # Generate a deferred runtime assert (this might actually end up doing 467 # a regular guard if we can!) 468 # TODO: file/line here is very important, because the assert has been 469 # deferred so you can't backtrace easily 470 return self.shape_env.defer_runtime_assert( 471 self.expr, f"{file}:{line}", fx_node=self.fx_node 472 ) 473 474 def expect_size(self, file, line): 475 from torch.fx.experimental.symbolic_shapes import _advise_is_size 476 477 b = self.ge(self.wrap_int(0)) 478 # Generate a deferred runtime assert 479 r = b.expect_true(file, line) 480 # Refine compile time range, but only if it's unbacked. 481 # If you refine range for hinted variables, you can end up making 482 # improper deductions since compile time reasoning may be 483 # incompatible with runtime reasoning. 484 if r and not self.has_hint(): 485 _advise_is_size(SymInt(self)) 486 return r 487 488 def guard_size_oblivious(self, file, line): 489 """ 490 Like guard_bool, but if we encounter unbacked symbols, if those symbols 491 are size-like, we will treat them as >= 2 for the purposes of the analysis. 492 493 This CHANGES the runtime semantics, but all size-oblivious sites have been 494 audited to ensure that the runtime semantics don't change in a material way. 495 Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping 496 an unbacked one size, or a tensor reporting as non-contiguous even if it's 497 contiguous if it would have been reported contiguous due to being empty. 498 """ 499 # TODO: use the file/line for some useful diagnostic on why a 500 # guard occurred 501 r = self.shape_env.evaluate_expr( 502 self.expr, self.hint, fx_node=self.fx_node, size_oblivious=True 503 ) 504 try: 505 return bool(r) 506 except Exception: 507 log.warning("Failed to convert to bool: %s", r) 508 raise 509 510 def bool_(self): 511 return self.guard_bool("", 0) 512 513 def is_symbolic(self): 514 return True 515 516 def nested_int(self): 517 return None 518 519 def is_constant(self): 520 return False 521 522 523# TODO: this probably needs the sizes-strides eval functions 524METHOD_TO_OPERATOR = { 525 "pos": operator.pos, 526 "abs": operator.abs, 527 "add": operator.add, 528 "and": operator.and_, 529 "ceil": math.ceil, 530 "eq": operator.eq, 531 "floor": math.floor, 532 "trunc": math.trunc, 533 "int_floordiv": operator.floordiv, 534 "ge": operator.ge, 535 "gt": operator.gt, 536 "is_integer": lambda x: x.is_integer(), 537 "le": operator.le, 538 "lshift": operator.lshift, 539 "lt": operator.lt, 540 "mod": operator.mod, 541 "mul": operator.mul, 542 "ne": operator.ne, 543 "neg": operator.neg, 544 "or": operator.or_, 545 "float_pow": operator.pow, 546 "pow_by_natural": operator.pow, 547 "round": builtins.round, 548 "rshift": operator.rshift, 549 "sub": operator.sub, 550 "sym_float": sym_float, 551 "sym_ite": sym_ite, 552 "sym_max": sym_max, 553 "sym_min": sym_min, 554 "sym_not": sym_not, 555 "float_truediv": operator.truediv, 556 "int_truediv": operator.truediv, 557} 558 559unary_magic_methods = { 560 "abs", 561 "sym_float", 562 "sym_int", 563 "ceil", 564 "floor", 565 "neg", 566 "sym_not", 567 "pos", 568 "trunc", 569} 570 571 572# Adding math ops: sqrt, cos, sin, ... 573def _get_sym_node_fn(name): 574 def fn(self): 575 return getattr(self, f"_sym_{name}")() 576 577 return fn 578 579 580math_op_names = ( 581 "sqrt", 582 "cos", 583 "cosh", 584 "sin", 585 "sinh", 586 "tan", 587 "tanh", 588 "asin", 589 "acos", 590 "atan", 591) 592for name in math_op_names: 593 sym_name = f"sym_{name}" 594 priv_sym_name = f"_{sym_name}" 595 setattr(SymNode, sym_name, _get_sym_node_fn(name)) 596 METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name) 597 unary_magic_methods.add(sym_name) 598 __all__.append(sym_name) 599 600 601# Unary methods that are not magic methods 602unary_nonmagic_methods = { 603 "is_integer", 604} 605 606unary_methods = unary_magic_methods | unary_nonmagic_methods 607 608# Most methods are only registered on SymInt and SymFloat 609# Some methods are only be registered on SymBool 610only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"} 611# Methods that implicitly convert SymBool into SymInt 612bool_becomes_int_magic_methods = {"add", "sub", "mul"} 613# Methods that are also on SymBool, in addition to on SymInt and SymFloat 614also_bool_magic_methods = {"eq"} 615bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods 616 617# Methods that are only for float 618only_float_magic_methods = {"is_integer", "round", "sym_int"} 619 620 621magic_methods_on_operator_with_trailing_underscore = {"and", "or"} 622 623 624always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"} 625 626for name in math_op_names: 627 sym_name = f"sym_{name}" 628 always_float_magic_methods.add(sym_name) 629 630 631always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"} 632always_bool_magic_methods = { 633 "eq", 634 "ne", 635 "gt", 636 "lt", 637 "le", 638 "ge", 639 "and", 640 "or", 641 "sym_not", 642 "is_non_overlapping_and_dense", 643 "is_integer", 644} 645 646# Methods that have a `__foo__` as well as `__rfoo__` 647 648 649def _sympy_float_truediv(a, b): 650 from torch.utils._sympy.functions import FloatTrueDiv 651 652 return FloatTrueDiv(a, b) 653 654 655def _sympy_int_truediv(a, b): 656 from torch.utils._sympy.functions import IntTrueDiv 657 658 return IntTrueDiv(a, b) 659 660 661def _sympy_floordiv(a, b): 662 from torch.utils._sympy.functions import FloorDiv 663 664 return FloorDiv(a, b) 665 666 667def _sympy_mod(a, b): 668 from torch.utils._sympy.functions import Mod, PythonMod 669 670 if a.is_nonnegative and b.is_nonnegative: 671 return Mod(a, b) 672 else: 673 return PythonMod(a, b) 674 675 676def _sympy_pow_by_natural(a, b): 677 from torch.utils._sympy.functions import PowByNatural 678 679 return PowByNatural(a, b) 680 681 682def _sympy_float_pow(a, b): 683 from torch.utils._sympy.functions import FloatPow 684 685 return FloatPow(a, b) 686 687 688def _sympy_and(a, b): 689 import sympy 690 691 return sympy.And(a, b) 692 693 694def _sympy_or(a, b): 695 import sympy 696 697 return sympy.Or(a, b) 698 699 700def _sympy_lshift(a, b): 701 from torch.utils._sympy.functions import LShift 702 703 return LShift(a, b) 704 705 706def _sympy_rshift(a, b): 707 from torch.utils._sympy.functions import RShift 708 709 return RShift(a, b) 710 711 712reflectable_magic_methods = { 713 "add": operator.add, 714 "sub": operator.sub, 715 "mul": operator.mul, 716 "mod": _sympy_mod, 717 "pow_by_natural": _sympy_pow_by_natural, 718 "float_pow": _sympy_float_pow, 719 "and": _sympy_and, 720 "or": _sympy_or, 721 "float_truediv": _sympy_float_truediv, 722 "int_truediv": _sympy_int_truediv, 723 "int_floordiv": _sympy_floordiv, 724 "lshift": _sympy_lshift, 725 "rshift": _sympy_rshift, 726} 727 728 729def _floor_ceil_helper(a, fn): 730 import sympy 731 732 if isinstance(a, sympy.Mul): 733 aa = a.args 734 if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer: 735 coef = sympy.Integer(aa[0]) 736 if aa[0] == coef: # structural equality test 737 return coef * aa[1] 738 if ( 739 isinstance(a, sympy.Float) 740 and a == sympy.Integer(a) 741 or isinstance(a, sympy.Integer) 742 ): 743 return sympy.Integer(a) 744 return fn(a) 745 746 747def _sympy_floor(a): 748 from torch.utils._sympy.functions import FloorToInt 749 750 return FloorToInt(a) 751 752 753# NB: this is Python trunc semantics which returns an int. Do NOT use this to 754# represent torch.trunc (which is float to float) 755def _sympy_trunc(a): 756 from torch.utils._sympy.functions import TruncToInt 757 758 return TruncToInt(a) 759 760 761def _sympy_ceil(a): 762 from torch.utils._sympy.functions import CeilToInt 763 764 return CeilToInt(a) 765 766 767def _sympy_eq(a, b): 768 import sympy 769 770 return sympy.Eq(a, b) 771 772 773def _sympy_ne(a, b): 774 import sympy 775 776 return sympy.Ne(a, b) 777 778 779def _sympy_gt(a, b): 780 import sympy 781 782 return sympy.Gt(a, b) 783 784 785def _sympy_lt(a, b): 786 import sympy 787 788 return sympy.Lt(a, b) 789 790 791def _sympy_le(a, b): 792 import sympy 793 794 return sympy.Le(a, b) 795 796 797def _sympy_ge(a, b): 798 import sympy 799 800 return sympy.Ge(a, b) 801 802 803def _sympy_min(a, b): 804 from torch.utils._sympy.functions import Min 805 806 return Min(a, b) 807 808 809def _sympy_max(a, b): 810 from torch.utils._sympy.functions import Max 811 812 return Max(a, b) 813 814 815def _sympy_ite(a, t, f): 816 import sympy 817 818 return sympy.Piecewise((t, a), (f, True)) 819 820 821current_module = sys.modules[__name__] 822 823 824def _get_sym_math_fn(name): 825 def fn(a): 826 import torch.utils._sympy.functions 827 828 return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a) 829 830 return fn 831 832 833for name in math_op_names: 834 priv_sympy_name = f"_sympy_{name}" 835 fn = _get_sym_math_fn(name) 836 fn.__qualname__ = fn.__name__ = priv_sympy_name 837 setattr(current_module, priv_sympy_name, fn) 838 839del fn, name, priv_sympy_name # type: ignore[possibly-undefined] 840 841 842def _sympy_abs(a): 843 import sympy 844 845 return sympy.Abs(a) 846 847 848def _sympy_round(number, ndigits=None): 849 from torch.utils._sympy.functions import RoundDecimal, RoundToInt 850 851 if ndigits is None: 852 return RoundToInt(number) 853 else: 854 return RoundDecimal(number, ndigits) 855 856 857def _sympy_sym_float(a): 858 from torch.utils._sympy.functions import ToFloat 859 860 # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly 861 # reports that it is an integer 862 return ToFloat(a) 863 864 865def _sympy_is_integer(a): 866 import sympy 867 868 from torch.utils._sympy.functions import ToFloat 869 870 return sympy.Eq(ToFloat(sympy.floor(a)), a) 871 872 873magic_methods = { 874 **reflectable_magic_methods, 875 "sym_not": operator.invert, 876 "pos": operator.pos, 877 "eq": _sympy_eq, 878 "ne": _sympy_ne, 879 "gt": _sympy_gt, 880 "lt": _sympy_lt, 881 "le": _sympy_le, 882 "ge": _sympy_ge, 883 "floor": _sympy_floor, 884 "trunc": _sympy_trunc, 885 "sym_float": _sympy_sym_float, 886 "ceil": _sympy_ceil, 887 "neg": operator.neg, 888 "sym_min": _sympy_min, 889 "sym_max": _sympy_max, 890 "sym_ite": _sympy_ite, 891 "abs": _sympy_abs, 892 "round": _sympy_round, 893 "is_integer": _sympy_is_integer, 894} 895 896 897for name in math_op_names: 898 sym_name = f"sym_{name}" 899 magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}") 900 901del name, sym_name, math_op_names, current_module # type: ignore[possibly-undefined] 902 903 904def sympy_is_contiguous(sizes, strides): 905 dim = len(sizes) 906 return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1))) 907 908 909def sympy_is_contiguous_generic(sizes, strides, dim_order): 910 import sympy 911 912 dim = len(sizes) 913 914 if len(dim_order) != dim: 915 return sympy.false 916 917 is_contiguous = sympy.true 918 z = sympy.Integer(1) 919 # Contiguous if the strides make sense (or the dim is size 1) 920 for d in dim_order: 921 is_contiguous &= sympy.Eq(sizes[d], sympy.Integer(1)) | sympy.Eq(strides[d], z) 922 z *= sizes[d] 923 # OR if any size is zero 924 for d in range(dim): 925 is_contiguous |= sympy.Eq(sizes[d], sympy.Integer(0)) 926 return is_contiguous 927 928 929# NB: There is a TODO in C++ to allow omitting the batch dim. If that 930# happens you will need to refactor this 931 932 933def sympy_is_channels_last_contiguous_2d(sizes, strides): 934 return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0]) 935 936 937def sympy_is_channels_last_contiguous_3d(sizes, strides): 938 return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0]) 939 940 941def sympy_is_channels_last_strides_generic(sizes, strides, dim_order): 942 import sympy 943 944 from torch.utils._sympy.functions import Max 945 946 dim = len(sizes) 947 948 if dim != len(dim_order): 949 return sympy.false 950 951 m = sympy.Integer(0) 952 r = sympy.true 953 954 # special case for trivial C dimension. default to NCHW 955 r &= sympy.Ne(strides[1], 0) 956 957 for d in dim_order: 958 r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m) 959 # Fallback to NCHW as default layout for ambiguous cases 960 # This is the flaw of implicit memory_format from strides. 961 # N111 tensor with identical strides for size 1 dimension; 962 # Two cases could lead us here: 963 # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1]) 964 # b. N11W contiguous Tensor sliced on the W-dimension. 965 # ([N,1,1,1]@[W,W,W,W]) 966 if d == 0: 967 r &= sympy.Ne(m, strides[1]) 968 # This is necessary to: 969 # 1. distinguish the memory_format of N1H1; 970 # [H, 1, 1, 1] channels_last stride 971 # [H, H, 1, 1] contiguous stride 972 # 2. permutation of 1C1W: 973 # [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3) 974 # [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as 975 # channels_last 976 m = strides[d] * Max(sizes[d], 1) 977 978 return r 979 980 981def sympy_is_channels_last_strides_2d(sizes, strides): 982 return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0]) 983 984 985def sympy_is_channels_last_strides_3d(sizes, strides): 986 return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0]) 987 988 989def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides): 990 from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator 991 992 return IsNonOverlappingAndDenseIndicator(*sizes, *strides) 993 994 995sizes_strides_methods = { 996 # TODO: These could also be done with indicators, maybe it is better 997 # for reasoning to do it that way 998 "is_contiguous": sympy_is_contiguous, 999 "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d, 1000 "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d, 1001 "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d, 1002 "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d, 1003 "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator, 1004} 1005 1006alternate_impl_if_hinted_methods = { 1007 "sym_min": builtins.min, 1008 "sym_max": builtins.max, 1009} 1010 1011 1012def to_node(self, num): 1013 if isinstance(num, SymTypes): 1014 return num.node 1015 elif type(num) is bool: 1016 return self.wrap_bool(num) 1017 elif type(num) is int: 1018 return self.wrap_int(num) 1019 elif type(num) is float: 1020 return self.wrap_float(num) 1021 else: 1022 # NotImplemented is important so that Python tries the 1023 # other magic method 1024 return NotImplemented 1025 1026 1027def wrap_node(x): 1028 # TODO: let C++ also take advantage of this 1029 if isinstance(x, SymNode) and x.constant is not None: 1030 return x.constant 1031 if x.is_int(): 1032 return SymInt(x) 1033 elif x.is_float(): 1034 return SymFloat(x) 1035 elif x.is_bool(): 1036 return SymBool(x) 1037 else: 1038 raise AssertionError(f"unrecognized return type {x}") 1039 1040 1041def method_to_operator(method): 1042 return METHOD_TO_OPERATOR[method] 1043 1044 1045def _make_node_magic(method, func): 1046 func = lru_cache(256)(func) 1047 1048 if method in magic_methods_on_operator_with_trailing_underscore: 1049 method_attr = f"{method}_" 1050 else: 1051 method_attr = method 1052 1053 def binary_magic_impl(self, other): 1054 from torch.fx.experimental.proxy_tensor import ( 1055 get_proxy_mode, 1056 handle_sym_dispatch, 1057 ) 1058 from torch.fx.experimental.symbolic_shapes import safe_expand 1059 1060 op = method_to_operator(method) 1061 1062 out_hint = None 1063 if self.hint is not None and other.hint is not None: 1064 out_hint = op(self.hint, other.hint) 1065 1066 alternate_impl = alternate_impl_if_hinted_methods.get(method) 1067 if alternate_impl and out_hint is not None: 1068 return to_node(self, alternate_impl(wrap_node(self), wrap_node(other))) 1069 1070 if get_proxy_mode(): 1071 return to_node( 1072 self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) 1073 ) 1074 assert isinstance(other, SymNode) 1075 try: 1076 if method == "mod": 1077 from torch.utils._sympy.functions import Mod, PythonMod 1078 1079 # Special handling for mod that requires access to the value 1080 # ranges 1081 shape_env = self.shape_env 1082 if ( 1083 self.expr.is_nonnegative 1084 or shape_env.bound_sympy(self.expr).lower >= 0 1085 ) and ( 1086 other.expr.is_nonnegative 1087 or shape_env.bound_sympy(other.expr).lower >= 0 1088 ): 1089 out = Mod(self.expr, other.expr) 1090 else: 1091 out = PythonMod(self.expr, other.expr) 1092 else: 1093 # TODO: consider constant prop here 1094 out = func(self.expr, other.expr) 1095 except Exception: 1096 log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) 1097 raise 1098 out = safe_expand(out) 1099 sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out) 1100 pytype: Type 1101 # This is not strictly correct. In Python, a**b may return complex when 1102 # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This 1103 # returns a float while both arguments are ints: 2**(-1). Also, max and 1104 # min do not type promote. To avoid having data-dependent control flow 1105 # here, we just set the type to float if one of the args is a float. In 1106 # case of a type mismatch, we assume that it will be detected during 1107 # evaluation. 1108 if method in always_float_magic_methods: 1109 pytype = float 1110 elif method in always_bool_magic_methods: 1111 pytype = bool 1112 elif self.pytype is float or other.pytype is float: 1113 pytype = float 1114 else: 1115 pytype = self.pytype 1116 1117 if ( 1118 pytype is not None 1119 and out_hint is not None 1120 and not isinstance(out_hint, SymTypes) 1121 ): 1122 out_hint = pytype(out_hint) 1123 1124 # Create a FX node that corresponds to the operation being applied to 1125 # this node. 1126 fx_node, _ = self.shape_env._create_fx_call_function( 1127 op, (self.fx_node, other.fx_node) 1128 ) 1129 return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) 1130 1131 def unary_magic_impl(self): 1132 from torch.fx.experimental.proxy_tensor import ( 1133 get_proxy_mode, 1134 handle_sym_dispatch, 1135 ) 1136 from torch.fx.experimental.symbolic_shapes import safe_expand 1137 1138 op = method_to_operator(method) 1139 if get_proxy_mode(): 1140 return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {})) 1141 # TODO: consider constant prop here 1142 expr = self.expr 1143 if method == "floor" or method == "ceiling": 1144 expr = self.shape_env._simplify_floor_div(expr) 1145 1146 try: 1147 out = func(expr) 1148 except Exception: 1149 log.warning("failed to eval %s(%s)", method, expr) 1150 raise 1151 sym_node_log.debug("%s %s -> %s", func, expr, out) 1152 out_hint = None 1153 if self.hint is not None: 1154 out_hint = op(self.hint) 1155 out = safe_expand(out) 1156 pytype: Type 1157 if method in always_int_magic_methods: 1158 pytype = int 1159 elif method in always_bool_magic_methods: 1160 pytype = bool 1161 elif method in always_float_magic_methods: 1162 pytype = float 1163 else: 1164 pytype = self.pytype 1165 1166 fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,)) 1167 return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) 1168 1169 if method in unary_methods: 1170 setattr(SymNode, f"_{method_attr}", unary_magic_impl) 1171 elif method == "sym_ite": 1172 1173 def sym_ite_impl(pred_node, then_node, else_node): 1174 from torch.fx.experimental.proxy_tensor import ( 1175 get_proxy_mode, 1176 handle_sym_dispatch, 1177 ) 1178 from torch.fx.experimental.symbolic_shapes import safe_expand 1179 1180 out_hint = then_node.hint if pred_node.hint else else_node.hint 1181 if get_proxy_mode(): 1182 return to_node( 1183 pred_node, 1184 handle_sym_dispatch( 1185 sym_ite, 1186 ( 1187 wrap_node(pred_node), 1188 wrap_node(then_node), 1189 wrap_node(else_node), 1190 ), 1191 {}, 1192 ), 1193 ) 1194 1195 try: 1196 out = func(pred_node.expr, then_node.expr, else_node.expr) 1197 except Exception: 1198 log.warning( 1199 "failed to eval %s(%s, %s, %s)", 1200 method, 1201 pred_node.expr, 1202 then_node.expr, 1203 else_node.expr, 1204 ) 1205 raise 1206 1207 out = safe_expand(out) 1208 fx_node, _ = pred_node.shape_env._create_fx_call_function( 1209 sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node) 1210 ) 1211 return SymNode( 1212 out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node 1213 ) 1214 1215 setattr(SymNode, f"_{method_attr}", sym_ite_impl) 1216 elif method == "round": 1217 1218 def round_impl(self, ndigits=None): 1219 from torch.fx.experimental.proxy_tensor import ( 1220 get_proxy_mode, 1221 handle_sym_dispatch, 1222 ) 1223 from torch.fx.experimental.symbolic_shapes import safe_expand 1224 1225 op = builtins.round 1226 if get_proxy_mode(): 1227 return to_node( 1228 self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {}) 1229 ) 1230 1231 expr = self.expr 1232 try: 1233 out = func(expr, ndigits) 1234 except Exception: 1235 log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) 1236 raise 1237 1238 out = safe_expand(out) 1239 1240 if ndigits is None: 1241 pytype = int 1242 else: 1243 pytype = self.pytype 1244 1245 out_hint = None 1246 if self.hint is not None: 1247 out_hint = op(self.hint, ndigits) 1248 1249 # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the 1250 # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here 1251 # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The 1252 # hack down below works, because all round function down the line all take ndigits=None as default in their 1253 # signature. 1254 # TODO: Remove the args construction below if a different sentinel is used by FX. 1255 # ezyang(May 2024): LOL 1256 args = [self.fx_node] 1257 if ndigits is not None: 1258 args.append(ndigits) 1259 fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args)) 1260 return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) 1261 1262 setattr(SymNode, f"_{method_attr}", round_impl) 1263 else: 1264 setattr(SymNode, f"_{method_attr}", binary_magic_impl) 1265 1266 1267def _make_node_sizes_strides(method, func): 1268 # NB: don't LRU cache, lots of arguments 1269 1270 def sizes_strides_impl(self, sizes, strides): 1271 from torch.fx.experimental.proxy_tensor import ( 1272 get_proxy_mode, 1273 handle_sym_dispatch, 1274 ) 1275 1276 op = getattr(sys.modules[__name__], method) 1277 if get_proxy_mode(): 1278 return to_node( 1279 self, 1280 handle_sym_dispatch( 1281 op, 1282 ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]), 1283 {}, 1284 ), 1285 ) 1286 size_exprs = [s.expr for s in sizes] 1287 stride_exprs = [s.expr for s in strides] 1288 try: 1289 out = func(size_exprs, stride_exprs) 1290 except Exception: 1291 log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs) 1292 raise 1293 # bool is never expandable 1294 1295 size_hints = [] 1296 out_hint = None 1297 for s in sizes: 1298 if s.hint is None: 1299 break 1300 size_hints.append(s.hint) 1301 else: 1302 stride_hints = [] 1303 for s in strides: 1304 if s.hint is None: 1305 break 1306 stride_hints.append(s.hint) 1307 else: 1308 out_hint = op(size_hints, stride_hints) 1309 1310 # NB: This is the indicator function, not the actual bool! 1311 pytype: Type 1312 if method.endswith("_indicator"): 1313 pytype = int 1314 else: 1315 pytype = bool 1316 return SymNode(out, self.shape_env, pytype, out_hint) 1317 1318 setattr(SymNode, f"_{method}", sizes_strides_impl) 1319 1320 # TODO: This is technically hotpath, but in the ideal end state 1321 # guards on this will resolve at a higher level so you never 1322 # spend time in this code 1323 def sizes_strides_user(sizes, strides): 1324 import sympy 1325 1326 from torch.fx.experimental.symbolic_shapes import ( 1327 eval_is_non_overlapping_and_dense, 1328 ) 1329 1330 for a in itertools.chain(sizes, strides): 1331 if isinstance(a, SymInt): 1332 return wrap_node( 1333 getattr(a.node, method)( 1334 [to_node(a.node, b) for b in sizes], 1335 [to_node(a.node, b) for b in strides], 1336 ) 1337 ) 1338 if method == "is_non_overlapping_and_dense_indicator": 1339 return eval_is_non_overlapping_and_dense(sizes, strides) 1340 else: 1341 # TODO: this is an awful implementation 1342 return bool( 1343 func( 1344 [sympy.sympify(a) for a in sizes], 1345 [sympy.sympify(a) for a in strides], 1346 ) 1347 ) 1348 1349 # Skip for is_non_overlapping_and_dense_indicator 1350 if not hasattr(sys.modules[__name__], method): 1351 setattr(sys.modules[__name__], method, sizes_strides_user) 1352 1353 1354for method, func in magic_methods.items(): 1355 _make_node_magic(method, func) 1356 1357for method, func in sizes_strides_methods.items(): 1358 _make_node_sizes_strides(method, func) 1359 1360 1361def _make_user_magic(method, user_type): 1362 # User magic takes care of wrapping the other operand into a node, 1363 # so that our internal logic can assume everything is nodes 1364 1365 if method in magic_methods_on_operator_with_trailing_underscore: 1366 method_attr = f"sym_{method}" 1367 else: 1368 method_attr = method 1369 1370 def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]): 1371 if isinstance(x, (int, float, bool)): 1372 return x 1373 if isinstance(x, SymBool): 1374 return x.node.guard_bool("", 0) 1375 raise AssertionError("expect to be called with constant SymBools") 1376 1377 def is_constant(x): 1378 if isinstance(x, (int, float, bool)): 1379 return True 1380 if isinstance(x, (SymInt, SymFloat, SymBool)): 1381 return x.node.is_constant() 1382 return False 1383 1384 # Promotion rules for binary operations. NB: we preserve PYTHON semantics 1385 # - if args are same type, do nothing 1386 # - if one arg is float, promote other arg to float 1387 # - nb: this applies to floordiv, even though output is integral 1388 # (it's still float) 1389 # - pow is funny business 1390 # - if both ints 1391 # - trigger a guard on exponent >= 0 1392 # - if non-negative, output is int 1393 # - otherwise, output is float 1394 # - otherwise, promote other arg to float 1395 # - nb: complex is impossible to handle correctly lol, with 1396 # negative base and integral float need to diverge semantics and 1397 # just always return complex. Neener neener pretend this problem 1398 # doesn't exist 1399 # - equality is pain: Python does the fancy thing where it unpacks the 1400 # mantissa from the float and then compares that against the int. 1401 # Which means it is able to tell that 1402 # 9007199254740993 != 9007199254740992. (rather than if the LHS was 1403 # promoted to float, in which case it would have truncated to the RHS 1404 # and subsequently been equal). We'll model this exactly by having 1405 # special mixed type equality operations. Unfortunately, we need to 1406 # do this for all comparison operations (maybe I'll only implement 1407 # compare) 1408 # - sym_ite mumble mumble really shouldn't allow mixed but whatever 1409 1410 if method in bool_becomes_int_magic_methods: 1411 1412 def promote(x): 1413 """Implements True+True=2, which works in python but not sympy""" 1414 if isinstance(x, SymBool): 1415 return SymInt(x.node.wrap_int(int(x))) 1416 return x 1417 1418 else: 1419 1420 def promote(x): 1421 return x 1422 1423 def promote2(self, other): 1424 # TODO: Remove eq and other relations from this list. 1425 # CPython has fancy implementations for these to get as much precision 1426 # as possible instead of just promoting to float64 and praying, so we 1427 # need to handle them specially too. 1428 # Also, note that int_truediv doesn't go through this path: both 1429 # arguments are "int" so there isn't any promotion 1430 if method not in [ 1431 "add", 1432 "sub", 1433 "mul", 1434 "mod", 1435 "float_pow", 1436 "float_truediv", 1437 "int_floordiv", 1438 "sym_min", 1439 "sym_max", 1440 # TODO: remove these 1441 "eq", 1442 "ne", 1443 "gt", 1444 "lt", 1445 "le", 1446 "ge", 1447 ]: 1448 return self, other 1449 f_self = isinstance(self, (float, torch.SymFloat)) 1450 f_other = isinstance(other, (float, torch.SymFloat)) 1451 if f_self or f_other: 1452 if not f_self: 1453 self = torch.sym_float(self) 1454 if not f_other: 1455 other = torch.sym_float(other) 1456 return self, other 1457 1458 # Before and after performing the operation, check if any operands are constant. 1459 # If so, extract out the constant values first. If `self` itself is a 1460 # constant, then "redispatch" by calling back into the operator. Sometimes 1461 # this means that operations involving SymBool return plain bools. 1462 # Alternatively, we could also rewrap into constant Symbool (i.e. by 1463 # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that 1464 # today for no particular reason. 1465 def unary_magic_impl(self): 1466 self = promote(self) 1467 if is_constant(self): 1468 return (method_to_operator(method))(get_constant(self)) 1469 return wrap_node(getattr(self.node, method_attr)()) 1470 1471 def binary_magic_impl(self, other): 1472 if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): 1473 return NotImplemented 1474 sym_node_log.debug("MAGIC %s %s %s", method, self, other) 1475 self = promote(self) 1476 other = promote(other) 1477 self, other = promote2(self, other) 1478 if is_constant(self): 1479 return (method_to_operator(method))(get_constant(self), other) 1480 if is_constant(other): 1481 other = get_constant(other) 1482 other_node = to_node(self.node, other) 1483 if other_node is NotImplemented: 1484 return NotImplemented 1485 ret = wrap_node(getattr(self.node, method_attr)(other_node)) 1486 return get_constant(ret) if is_constant(ret) else ret 1487 1488 def rbinary_magic_impl(self, other): 1489 if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): 1490 return NotImplemented 1491 self = promote(self) 1492 other = promote(other) 1493 self, other = promote2(self, other) 1494 if is_constant(self): 1495 return (method_to_operator(method))(get_constant(self), other) 1496 if is_constant(other): 1497 other = get_constant(other) 1498 other_node = to_node(self.node, other) 1499 if other_node is NotImplemented: 1500 return NotImplemented 1501 ret = wrap_node(getattr(other_node, method_attr)(self.node)) 1502 return get_constant(ret) if is_constant(ret) else ret 1503 1504 if method in unary_magic_methods: 1505 setattr(user_type, f"__{method}__", unary_magic_impl) 1506 elif method in unary_nonmagic_methods: 1507 orig = getattr(user_type, method) 1508 setattr(user_type, method, update_wrapper(unary_magic_impl, orig)) 1509 elif method == "sym_ite": 1510 1511 def sym_ite_magic_impl(pred, then_val, else_val): 1512 pred_node = pred.node 1513 then_node = to_node(pred_node, then_val) 1514 else_node = to_node(pred_node, else_val) 1515 if then_node is NotImplemented or else_node is NotImplemented: 1516 return NotImplemented 1517 assert ( 1518 isinstance(then_node, SymNode) 1519 and isinstance(else_node, SymNode) 1520 and then_node.pytype == else_node.pytype 1521 ) 1522 ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node)) 1523 return get_constant(ret) if ret.node.is_constant() else ret 1524 1525 setattr(user_type, f"__{method}__", sym_ite_magic_impl) 1526 elif method == "round": 1527 1528 def round_magic_impl(self, ndigits=None): 1529 if is_constant(self): 1530 return builtins.round(get_constant(self), ndigits) 1531 1532 return wrap_node(getattr(self.node, method)(ndigits)) 1533 1534 setattr(user_type, f"__{method}__", round_magic_impl) 1535 else: 1536 setattr(user_type, f"__{method}__", binary_magic_impl) 1537 if method in reflectable_magic_methods: 1538 setattr(user_type, f"__r{method}__", rbinary_magic_impl) 1539 1540 1541for method, func in magic_methods.items(): # type: ignore[assignment] 1542 if method in only_bool_magic_methods: 1543 _make_user_magic(method, SymBool) 1544 continue 1545 if method in only_float_magic_methods: 1546 _make_user_magic(method, SymFloat) 1547 continue 1548 if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods: 1549 _make_user_magic(method, SymBool) 1550 _make_user_magic(method, SymInt) 1551 _make_user_magic(method, SymFloat) 1552 1553del method 1554del func 1555