xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/sym_node.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""
3This file does three things:
4- Contains the definition of SymNode
- Installs all the magic methods into SymBool, SymInt, SymFloat at import time
6- Does not depend on sympy at import time
7
8As this file is imported from within torch/__init__.py we do not want it to depend on SymPy
9to avoid having to load SymPy at import time, as doing so is *very* slow.
10"""
11
12import builtins
13import itertools
14import logging
15import math
16import operator
17import sys
18from functools import lru_cache, update_wrapper
19from typing import Optional, Type, TYPE_CHECKING, Union
20
21import torch
22
23# NB: The sym_* functions are used via getattr() and must be imported here.
24from torch import (  # noqa: F401
25    sym_float,
26    sym_ite,
27    sym_max,
28    sym_min,
29    sym_not,
30    SymBool,
31    SymFloat,
32    SymInt,
33)
34
35
36if TYPE_CHECKING:
37    from torch.fx.experimental.symbolic_shapes import ShapeEnv
38
# Module logger (e.g. SymNode.guard_* warn through it when a guarded value
# fails to convert to a Python scalar).
log = logging.getLogger(__name__)
# Logger registered under the "sym_node" artifact name via torch._logging.
sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")


# Public API.  NB: this must stay a list -- the math-op registration loop
# later in this module appends "sym_<op>" names to it.
__all__ = ["SymNode", "method_to_operator", "magic_methods"]
44
45
46from torch.types import py_sym_types as SymTypes
47
48
49def _to_symtype(t):
50    if t is bool:
51        return SymBool
52    if t is int:
53        return SymInt
54    if t is float:
55        return SymFloat
56    return t
57
58
# TODO: An incomplete list
# 1. Set variables to be equal when we do equality
# 2. Specialize on 0/1 when we do subtraction
class SymNode:
    """
    This is a type erased SymInt/SymFloat/SymBool which we use to do actual
    operations.  The Python type it stands for is recorded in ``pytype``
    (``int``, ``float`` or ``bool``).
    End users don't touch this.  Magic methods are NOT defined on this object.

    The public arithmetic/comparison methods below are thin hand-written
    wrappers around private ``_<op>`` methods that are metaprogrammed onto
    this class (they do not appear in the class body, hence the
    ``type: ignore[attr-defined]`` comments).
    """

    def __init__(
        self,
        expr,
        shape_env,
        pytype,
        hint: Optional[Union[int, float, bool]],
        constant=None,
        fx_node=None,
    ):
        """
        Args:
            expr: the symbolic expression for this value (e.g. ``wrap_int``
                passes a ``sympy.Integer``).  Read back through the ``expr``
                property, which applies ``shape_env.replace``.
            shape_env: the ShapeEnv this node belongs to.  May be falsy; the
                translation-validation branches below are then skipped.
            pytype: ``int``, ``float`` or ``bool``.
            hint: concrete value for the current trace, or None to try to
                compute one (see the comment below).  When given, its type
                must be exactly ``pytype`` or the matching Sym type.
            constant: literal value when this node wraps a true constant.
            fx_node: FX node used for translation validation; only kept
                when the ShapeEnv has translation validation enabled.
        """
        self._expr = expr
        self.shape_env = shape_env
        self.pytype = pytype

        # What's the difference between hint and constant?
        #
        # - A constant is known to be invariant across invocations of the model;
        #   it will always be this value.  We only really know this when we
        #   encounter an honest-to-goodness literal (when wrapping it into
        #   a SymNode, we set constant.)  Most of the time, constant is None
        #
        # - A hint is a *particular* value from the particular run we are
        #   tracing, but it may vary the next time around.  It's useful to
        #   keep this around, as if we need a concrete value from a SymNode,
        #   we will return the hint and guard on the expression that produced
        #   it giving the same hint next time around.  The hint is not
        #   guaranteed to be set either: if you have an unbacked SymNode,
        #   there won't be any hint; it was the result of some tensor-dependent
        #   computation, but we don't know what it actually is because we
        #   haven't actually run the tensor computation.
        #
        # If _hint is None, we will query maybe_evaluate_static(compute_hint=True)
        # in hopes that we've learned enough about the unbacked symints to
        # discharge the hint; otherwise, you're likely to just error out.
        #
        # (A previous version of this system had some optimizations to only
        # recompute when it was possible we had learned enough about the
        # unbacked symint that a hint was now possible, but as we added more
        # potential refinements to unbacked symints this got harder to keep
        # in sync, so we've deleted it for now.)

        # Ask the ShapeEnv for a hint.  Returns None when the expression has
        # free unbacked symbols or cannot be evaluated statically; otherwise
        # the value is coerced to ``pytype`` unless it is already a Sym type.
        def compute_hint():
            from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

            # This occasionally gets exercised by, e.g.,
            # convert_shape_to_symint.  It's just a nicety so you don't HAVE
            # to have a correct hint on hand when making a SymNode.
            # Don't attempt to compute for unbacked, this can be quite
            # expensive.
            if free_unbacked_symbols(self.expr):
                return None
            hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
            if hint is not None:
                hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint
            return hint

        if hint is not None:
            assert type(hint) is pytype or type(hint) is _to_symtype(pytype), (
                "Cannot create SymNode of type "
                f"{pytype} with incompatible hint of type {type(hint)}"
            )
            if self.shape_env and self.shape_env._translation_validation_enabled:
                # This is technically not TV, but this assert is expensive so
                # let's only do it when we're already doing expensive things
                computed_hint = compute_hint()
                assert (
                    hint == computed_hint
                ), f"{hint} != {computed_hint} (for {self.expr})"
        else:
            hint = compute_hint()
        self._hint = hint
        self.constant: Optional[Union[int, float, bool]] = constant

        # Record the FX node of the current node if we are doing translation
        # validation. They will be used for building the input assertions for
        # the translation validation problem.
        tx_validation_en = (
            self.shape_env and self.shape_env._translation_validation_enabled
        )
        # NB: when translation validation is off this stores the falsy value
        # of tx_validation_en (None when shape_env is None, otherwise False),
        # not fx_node.
        self.fx_node = tx_validation_en and fx_node

    def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode":
        """Return a new SymNode carrying the same fields but bound to ``shape_env``.

        Goes through ``__init__`` again, so ``fx_node`` is re-filtered against
        the new ShapeEnv's translation-validation setting.
        """
        return SymNode(
            self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
        )

    def _value_eq(self, other: "SymNode") -> bool:
        """Field-wise equality over the raw (unreplaced) expr, pytype, hint,
        constant and fx_node; ``shape_env`` is deliberately ignored."""
        # Purposely don't include the shape_env in the eq.
        return (
            self._expr == other._expr
            and self.pytype == other.pytype
            and self._hint == other._hint
            and self.constant == other.constant
            and self.fx_node == other.fx_node
        )

    def _value_hash(self) -> int:
        """Hash over the same fields that ``_value_eq`` compares."""
        # Purposely don't include the shape_env in the hash.
        return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))

    @property
    def expr(self):
        """``_expr`` with the ShapeEnv's replacements applied.

        Calls ``self.shape_env.replace``, so this needs a real ShapeEnv.
        """
        return self.shape_env.replace(self._expr)

    @property
    def hint(self):
        """The stored hint (may be None); it is not recomputed here."""
        return self._hint

    def has_hint(self):
        """Whether a hint is stored."""
        return self._hint is not None

    def require_hint(self, fallback=None):
        """Return the hint.

        Without a stored hint, return ``fallback`` if one was given; otherwise
        defer to ``shape_env.size_hint``, which is expected to raise.
        """
        if self._hint is None:
            if fallback is not None:
                return fallback
            # NB: we expect this to raise
            return self.shape_env.size_hint(self.expr)
        return self._hint

    def maybe_as_int(self):
        """Return ``int(expr)`` when the (replaced) expr is a number, else None.

        NB: ``int()`` is applied to any numeric expr, not only integers.
        """
        if self.expr.is_number:
            return int(self.expr)
        else:
            return None

    # NB: This does conversions, not sure if this is good or not
    def maybe_as_float(self):
        """Return ``float(expr)`` only when expr is a ``sympy.Float``, else None."""
        import sympy

        if isinstance(self.expr, sympy.Float):
            return float(self.expr)
        else:
            return None

    def maybe_as_bool(self):
        """Return True/False when expr is ``sympy.true``/``sympy.false``, else None."""
        import sympy

        if self.expr is sympy.true:
            return True
        elif self.expr is sympy.false:
            return False
        else:
            return None

    def is_int(self):
        """Whether this node stands for a Python ``int``."""
        return self.pytype is int

    def is_float(self):
        """Whether this node stands for a Python ``float``."""
        return self.pytype is float

    def is_bool(self):
        """Whether this node stands for a Python ``bool``."""
        return self.pytype is bool

    def is_nested_int(self):
        """True when the hint is a SymInt whose own node is a nested int."""
        # Unbacked SymInts cannot be nested int today
        return (
            self._hint is not None
            and isinstance(self._hint, SymInt)
            and self._hint.node.is_nested_int()
        )

    def wrap_int(self, num):
        """Wrap an ``int`` literal as a constant SymNode in this node's ShapeEnv.

        The literal is used as expr (``sympy.Integer``), hint, constant and
        fx_node.
        """
        assert type(num) is int
        import sympy

        return SymNode(
            sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num
        )

    def wrap_float(self, num):
        """Wrap a ``float`` literal as a constant SymNode (expr is ``sympy.Float``)."""
        assert type(num) is float
        import sympy

        return SymNode(
            sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num
        )

    def wrap_bool(self, num):
        """Wrap a ``bool`` literal as a constant SymNode (expr is ``sympy.true``/``sympy.false``)."""
        assert type(num) is bool
        import sympy

        return SymNode(
            sympy.true if num else sympy.false,
            self.shape_env,
            bool,
            num,
            constant=num,
            fx_node=num,
        )

    def clone(self):
        """Return ``self``; no copy is made."""
        return self

    def str(self):
        """String form of the (replaced) expression."""
        return f"{self.expr}"

    def __str__(self):
        return self.str()

    def __repr__(self):
        """Debug form: raw expr, shape_env and pytype, plus whichever of
        hint / constant / fx_node are not None."""
        rep = [
            f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}",
        ]
        if self._hint is not None:
            rep.append(f"hint={self._hint}")
        if self.constant is not None:
            rep.append(f"constant={self.constant}")
        if self.fx_node is not None:
            rep.append(f"fx_node={self.fx_node}")
        return ", ".join(rep) + ")"

    def _graph_repr(self) -> builtins.str:
        # Representation used by GraphModule to create a pythonic version of a graph
        # (``builtins.str`` because the ``str`` method above shadows the builtin
        # inside the class body.)
        return self.str()

    # These methods call the metaprogrammed methods, they're hand written
    # here so we get good stack traces
    def abs(self) -> "SymNode":
        return self._abs()  # type: ignore[attr-defined]

    def pos(self) -> "SymNode":
        return self._pos()  # type: ignore[attr-defined]

    def round(self, ndigits=None) -> "SymNode":
        return self._round(ndigits)  # type: ignore[attr-defined]

    def trunc(self) -> "SymNode":
        return self._trunc()  # type: ignore[attr-defined]

    def add(self, other) -> "SymNode":
        return self._add(other)  # type: ignore[attr-defined]

    def sub(self, other) -> "SymNode":
        return self._sub(other)  # type: ignore[attr-defined]

    def mul(self, other) -> "SymNode":
        return self._mul(other)  # type: ignore[attr-defined]

    def mod(self, other) -> "SymNode":
        return self._mod(other)  # type: ignore[attr-defined]

    def float_pow(self, other) -> "SymNode":
        return self._float_pow(other)  # type: ignore[attr-defined]

    def pow_by_natural(self, other) -> "SymNode":
        return self._pow_by_natural(other)  # type: ignore[attr-defined]

    def and_(self, other) -> "SymNode":
        return self._and_(other)  # type: ignore[attr-defined]

    def or_(self, other) -> "SymNode":
        return self._or_(other)  # type: ignore[attr-defined]

    def float_truediv(self, other) -> "SymNode":
        return self._float_truediv(other)  # type: ignore[attr-defined]

    def int_truediv(self, other) -> "SymNode":
        return self._int_truediv(other)  # type: ignore[attr-defined]

    def int_floordiv(self, other) -> "SymNode":
        return self._int_floordiv(other)  # type: ignore[attr-defined]

    def lshift(self, other) -> "SymNode":
        return self._lshift(other)  # type: ignore[attr-defined]

    def rshift(self, other) -> "SymNode":
        return self._rshift(other)  # type: ignore[attr-defined]

    def sym_not(self) -> "SymNode":  # noqa: F811
        return self._sym_not()  # type: ignore[attr-defined]

    def eq(self, other) -> "SymNode":
        return self._eq(other)  # type: ignore[attr-defined]

    def ne(self, other) -> "SymNode":
        return self._ne(other)  # type: ignore[attr-defined]

    def gt(self, other) -> "SymNode":
        return self._gt(other)  # type: ignore[attr-defined]

    def lt(self, other) -> "SymNode":
        return self._lt(other)  # type: ignore[attr-defined]

    def le(self, other) -> "SymNode":
        return self._le(other)  # type: ignore[attr-defined]

    def ge(self, other) -> "SymNode":
        return self._ge(other)  # type: ignore[attr-defined]

    def floor(self) -> "SymNode":
        return self._floor()  # type: ignore[attr-defined]

    def is_integer(self) -> "SymNode":
        return self._is_integer()  # type: ignore[attr-defined]

    def sym_float(self) -> "SymNode":  # noqa: F811
        return self._sym_float()  # type: ignore[attr-defined]

    def sym_int(self) -> "SymNode":
        return self._sym_int()  # type: ignore[attr-defined]

    def ceil(self) -> "SymNode":
        return self._ceil()  # type: ignore[attr-defined]

    def neg(self) -> "SymNode":
        return self._neg()  # type: ignore[attr-defined]

    def sym_min(self, other) -> "SymNode":  # noqa: F811
        return self._sym_min(other)  # type: ignore[attr-defined]

    def sym_max(self, other) -> "SymNode":  # noqa: F811
        return self._sym_max(other)  # type: ignore[attr-defined]

    def sym_ite(self, then_val, else_val) -> "SymNode":
        return self._sym_ite(then_val, else_val)  # type: ignore[attr-defined]

    # Size/stride predicates: take lists of sizes and strides rather than a
    # single other operand.
    def is_contiguous(self, sizes, strides) -> "SymNode":
        return self._is_contiguous(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_contiguous_2d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_contiguous_3d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_strides_2d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_strides_3d(sizes, strides)  # type: ignore[attr-defined]

    def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode":
        return self._is_non_overlapping_and_dense_indicator(sizes, strides)  # type: ignore[attr-defined]

    # Make C++ happy
    def sym_or(self, other):
        return self.or_(other)

    def sym_and(self, other):
        return self.and_(other)

    # There is no int_truediv available from C++
    def truediv(self, other):
        return self.float_truediv(other)

    def floordiv(self, other) -> "SymNode":
        return self.int_floordiv(other)

    # We didn't bind integer pow in C++
    def pow(self, other):
        return self.float_pow(other)

    def is_non_overlapping_and_dense(self, sizes, strides):
        """Boolean node: the non-overlapping-and-dense indicator equals 1."""
        return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1))  # type: ignore[attr-defined]

    def int_(self):
        """Guard on this node and return its value as a Python int."""
        return self.guard_int("", 0)  # NB: uses Python backtrace

    # You can manually trigger a guard with this function
    def guard_int(self, file, line):
        """Evaluate ``expr`` through ``shape_env.evaluate_expr`` (passing the
        hint and fx_node) and return the result as an int.

        ``file``/``line`` are currently unused.  If the conversion fails, a
        warning is logged and the exception re-raised.
        """
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
        try:
            return int(r)
        except Exception:
            log.warning("Failed to convert to int: %s", r)
            raise

    def guard_float(self, file, line):
        """Like ``guard_int`` but converts the evaluated result to float."""
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
        try:
            return float(r)
        except Exception:
            log.warning("Failed to convert to float: %s", r)
            raise

    def guard_bool(self, file, line):
        """Like ``guard_int`` but converts the evaluated result to bool."""
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
        try:
            return bool(r)
        except Exception:
            log.warning("Failed to convert to bool: %s", r)
            raise

    def expect_true(self, file, line):
        """Assert that this boolean node is true.

        Uses a regular guard (``guard_bool``) when the node has a hint, no
        free unbacked symbols, and the ShapeEnv does not prefer deferred
        runtime asserts; otherwise hands ``expr`` to
        ``shape_env.defer_runtime_assert`` tagged with ``"file:line"``.
        """
        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

        if (
            self.has_hint()
            and not free_unbacked_symbols(self.expr)
            and not self.shape_env.prefer_deferred_runtime_asserts_over_guards
        ):
            # OK to generate guards
            return self.guard_bool(file, line)
        # Generate a deferred runtime assert (this might actually end up doing
        # a regular guard if we can!)
        # TODO: file/line here is very important, because the assert has been
        # deferred so you can't backtrace easily
        return self.shape_env.defer_runtime_assert(
            self.expr, f"{file}:{line}", fx_node=self.fx_node
        )

    def expect_size(self, file, line):
        """Assert ``self >= 0`` via ``expect_true``.

        If that succeeds and this node has no hint, additionally pass it to
        ``_advise_is_size``.  Returns the ``expect_true`` result.
        """
        from torch.fx.experimental.symbolic_shapes import _advise_is_size

        b = self.ge(self.wrap_int(0))
        # Generate a deferred runtime assert
        r = b.expect_true(file, line)
        # Refine compile time range, but only if it's unbacked.
        # If you refine range for hinted variables, you can end up making
        # improper deductions since compile time reasoning may be
        # incompatible with runtime reasoning.
        if r and not self.has_hint():
            _advise_is_size(SymInt(self))
        return r

    def guard_size_oblivious(self, file, line):
        """
        Like guard_bool, but if we encounter unbacked symbols, if those symbols
        are size-like, we will treat them as >= 2 for the purposes of the analysis.

        This CHANGES the runtime semantics, but all size-oblivious sites have been
        audited to ensure that the runtime semantics don't change in a material way.
        Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping
        an unbacked one size, or a tensor reporting as non-contiguous even if it's
        contiguous if it would have been reported contiguous due to being empty.
        """
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.shape_env.evaluate_expr(
            self.expr, self.hint, fx_node=self.fx_node, size_oblivious=True
        )
        try:
            return bool(r)
        except Exception:
            log.warning("Failed to convert to bool: %s", r)
            raise

    def bool_(self):
        """Guard on this node and return its value as a Python bool."""
        return self.guard_bool("", 0)

    def is_symbolic(self):
        """Always True for this class."""
        return True

    def nested_int(self):
        """Always None for this class."""
        return None

    def is_constant(self):
        """Always False for this class (even when ``constant`` is set)."""
        return False
521
522
# TODO: this probably needs the sizes-strides eval functions
# Maps a SymNode method name to the Python callable that performs the same
# operation on concrete values.  "sym_<math op>" entries are added by the
# math-op registration loop further down in this module.
METHOD_TO_OPERATOR = {
    "pos": operator.pos,
    "abs": operator.abs,
    "add": operator.add,
    "and": operator.and_,
    "ceil": math.ceil,
    "eq": operator.eq,
    "floor": math.floor,
    "trunc": math.trunc,
    "int_floordiv": operator.floordiv,
    "ge": operator.ge,
    "gt": operator.gt,
    "is_integer": lambda x: x.is_integer(),
    "le": operator.le,
    "lshift": operator.lshift,
    "lt": operator.lt,
    "mod": operator.mod,
    "mul": operator.mul,
    "ne": operator.ne,
    "neg": operator.neg,
    "or": operator.or_,
    "float_pow": operator.pow,
    "pow_by_natural": operator.pow,
    "round": builtins.round,
    "rshift": operator.rshift,
    "sub": operator.sub,
    "sym_float": sym_float,
    "sym_ite": sym_ite,
    "sym_max": sym_max,
    "sym_min": sym_min,
    "sym_not": sym_not,
    # Both true divisions map to Python "/" on concrete values.
    "float_truediv": operator.truediv,
    "int_truediv": operator.truediv,
}
558
# Magic methods that take only ``self`` (no other operand).  The math-op
# registration loop further down adds "sym_<math op>" names to this set.
unary_magic_methods = {
    "abs",
    "sym_float",
    "sym_int",
    "ceil",
    "floor",
    "neg",
    "sym_not",
    "pos",
    "trunc",
}
570
571
572# Adding math ops: sqrt, cos, sin, ...
573def _get_sym_node_fn(name):
574    def fn(self):
575        return getattr(self, f"_sym_{name}")()
576
577    return fn
578
579
# Unary math functions that get symbolic support.  For each <op> the loop
# below installs SymNode.sym_<op> (forwarding to SymNode._sym_<op>), maps
# "sym_<op>" to torch._sym_<op> in METHOD_TO_OPERATOR, marks it as a unary
# magic method and exports it through __all__.
math_op_names = (
    "sqrt",
    "cos",
    "cosh",
    "sin",
    "sinh",
    "tan",
    "tanh",
    "asin",
    "acos",
    "atan",
)
# NB: the loop variables (name, sym_name, priv_sym_name) become module
# globals; later module code rebinds and ``del``s some of them, so keep
# these names as they are.
for name in math_op_names:
    sym_name = f"sym_{name}"
    priv_sym_name = f"_{sym_name}"
    setattr(SymNode, sym_name, _get_sym_node_fn(name))
    METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name)
    unary_magic_methods.add(sym_name)
    __all__.append(sym_name)
599
600
# Unary methods that are not magic methods
unary_nonmagic_methods = {
    "is_integer",
}

# Every single-operand method, magic or not.
unary_methods = unary_magic_methods | unary_nonmagic_methods

# Most methods are only registered on SymInt and SymFloat
# Some methods are only registered on SymBool
only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
# Methods that implicitly convert SymBool into SymInt
bool_becomes_int_magic_methods = {"add", "sub", "mul"}
# Methods that are also on SymBool, in addition to on SymInt and SymFloat
also_bool_magic_methods = {"eq"}
# Every method registered on SymBool (exclusive + shared).
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods

# Methods that are only for float
only_float_magic_methods = {"is_integer", "round", "sym_int"}
619
620
# Method names whose ``operator`` function carries a trailing underscore
# (``operator.and_`` / ``operator.or_``); SymNode likewise exposes them as
# ``and_`` / ``or_``.
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}


# Result-type classification: methods whose result is always a float.
always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"}

# Every "sym_<math op>" also yields a float.  NB: this loop rebinds the
# module-level ``name`` / ``sym_name`` globals.
for name in math_op_names:
    sym_name = f"sym_{name}"
    always_float_magic_methods.add(sym_name)


# Methods whose result is always an int.
always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"}
# Methods whose result is always a bool.
always_bool_magic_methods = {
    "eq",
    "ne",
    "gt",
    "lt",
    "le",
    "ge",
    "and",
    "or",
    "sym_not",
    "is_non_overlapping_and_dense",
    "is_integer",
}
645
# Methods that have a `__foo__` as well as `__rfoo__` (collected in reflectable_magic_methods below)
647
648
649def _sympy_float_truediv(a, b):
650    from torch.utils._sympy.functions import FloatTrueDiv
651
652    return FloatTrueDiv(a, b)
653
654
655def _sympy_int_truediv(a, b):
656    from torch.utils._sympy.functions import IntTrueDiv
657
658    return IntTrueDiv(a, b)
659
660
661def _sympy_floordiv(a, b):
662    from torch.utils._sympy.functions import FloorDiv
663
664    return FloorDiv(a, b)
665
666
667def _sympy_mod(a, b):
668    from torch.utils._sympy.functions import Mod, PythonMod
669
670    if a.is_nonnegative and b.is_nonnegative:
671        return Mod(a, b)
672    else:
673        return PythonMod(a, b)
674
675
676def _sympy_pow_by_natural(a, b):
677    from torch.utils._sympy.functions import PowByNatural
678
679    return PowByNatural(a, b)
680
681
682def _sympy_float_pow(a, b):
683    from torch.utils._sympy.functions import FloatPow
684
685    return FloatPow(a, b)
686
687
688def _sympy_and(a, b):
689    import sympy
690
691    return sympy.And(a, b)
692
693
694def _sympy_or(a, b):
695    import sympy
696
697    return sympy.Or(a, b)
698
699
700def _sympy_lshift(a, b):
701    from torch.utils._sympy.functions import LShift
702
703    return LShift(a, b)
704
705
706def _sympy_rshift(a, b):
707    from torch.utils._sympy.functions import RShift
708
709    return RShift(a, b)
710
711
# Binary magic methods -> function building the sympy expression from the
# two operand expressions.  Judging by the name, these are the methods that
# also get a reflected (``__rfoo__``) form.
reflectable_magic_methods = {
    "add": operator.add,
    "sub": operator.sub,
    "mul": operator.mul,
    "mod": _sympy_mod,
    "pow_by_natural": _sympy_pow_by_natural,
    "float_pow": _sympy_float_pow,
    "and": _sympy_and,
    "or": _sympy_or,
    "float_truediv": _sympy_float_truediv,
    "int_truediv": _sympy_int_truediv,
    "int_floordiv": _sympy_floordiv,
    "lshift": _sympy_lshift,
    "rshift": _sympy_rshift,
}
727
728
729def _floor_ceil_helper(a, fn):
730    import sympy
731
732    if isinstance(a, sympy.Mul):
733        aa = a.args
734        if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
735            coef = sympy.Integer(aa[0])
736            if aa[0] == coef:  # structural equality test
737                return coef * aa[1]
738    if (
739        isinstance(a, sympy.Float)
740        and a == sympy.Integer(a)
741        or isinstance(a, sympy.Integer)
742    ):
743        return sympy.Integer(a)
744    return fn(a)
745
746
747def _sympy_floor(a):
748    from torch.utils._sympy.functions import FloorToInt
749
750    return FloorToInt(a)
751
752
753# NB: this is Python trunc semantics which returns an int.  Do NOT use this to
754# represent torch.trunc (which is float to float)
755def _sympy_trunc(a):
756    from torch.utils._sympy.functions import TruncToInt
757
758    return TruncToInt(a)
759
760
761def _sympy_ceil(a):
762    from torch.utils._sympy.functions import CeilToInt
763
764    return CeilToInt(a)
765
766
767def _sympy_eq(a, b):
768    import sympy
769
770    return sympy.Eq(a, b)
771
772
773def _sympy_ne(a, b):
774    import sympy
775
776    return sympy.Ne(a, b)
777
778
779def _sympy_gt(a, b):
780    import sympy
781
782    return sympy.Gt(a, b)
783
784
785def _sympy_lt(a, b):
786    import sympy
787
788    return sympy.Lt(a, b)
789
790
791def _sympy_le(a, b):
792    import sympy
793
794    return sympy.Le(a, b)
795
796
797def _sympy_ge(a, b):
798    import sympy
799
800    return sympy.Ge(a, b)
801
802
803def _sympy_min(a, b):
804    from torch.utils._sympy.functions import Min
805
806    return Min(a, b)
807
808
809def _sympy_max(a, b):
810    from torch.utils._sympy.functions import Max
811
812    return Max(a, b)
813
814
815def _sympy_ite(a, t, f):
816    import sympy
817
818    return sympy.Piecewise((t, a), (f, True))
819
820
# Handle to this module object; the loops below use it to setattr/getattr the
# generated _sympy_<math op> helpers, and it is deleted once they are done.
current_module = sys.modules[__name__]
822
823
824def _get_sym_math_fn(name):
825    def fn(a):
826        import torch.utils._sympy.functions
827
828        return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a)
829
830    return fn
831
832
# Materialize a module-level ``_sympy_<op>`` builder for every math op, with
# __name__/__qualname__ set to that public name, then drop the loop
# temporaries from the module namespace.
for name in math_op_names:
    priv_sympy_name = f"_sympy_{name}"
    fn = _get_sym_math_fn(name)
    fn.__qualname__ = fn.__name__ = priv_sympy_name
    setattr(current_module, priv_sympy_name, fn)

del fn, name, priv_sympy_name  # type: ignore[possibly-undefined]
840
841
842def _sympy_abs(a):
843    import sympy
844
845    return sympy.Abs(a)
846
847
848def _sympy_round(number, ndigits=None):
849    from torch.utils._sympy.functions import RoundDecimal, RoundToInt
850
851    if ndigits is None:
852        return RoundToInt(number)
853    else:
854        return RoundDecimal(number, ndigits)
855
856
857def _sympy_sym_float(a):
858    from torch.utils._sympy.functions import ToFloat
859
860    # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly
861    # reports that it is an integer
862    return ToFloat(a)
863
864
865def _sympy_is_integer(a):
866    import sympy
867
868    from torch.utils._sympy.functions import ToFloat
869
870    return sympy.Eq(ToFloat(sympy.floor(a)), a)
871
872
# Every magic method name -> function that builds the resulting sympy
# expression from the operand expression(s).  Includes all reflectable
# methods; "sym_<math op>" entries are added by the loop right after this.
magic_methods = {
    **reflectable_magic_methods,
    "sym_not": operator.invert,
    "pos": operator.pos,
    "eq": _sympy_eq,
    "ne": _sympy_ne,
    "gt": _sympy_gt,
    "lt": _sympy_lt,
    "le": _sympy_le,
    "ge": _sympy_ge,
    "floor": _sympy_floor,
    "trunc": _sympy_trunc,
    "sym_float": _sympy_sym_float,
    "ceil": _sympy_ceil,
    "neg": operator.neg,
    "sym_min": _sympy_min,
    "sym_max": _sympy_max,
    "sym_ite": _sympy_ite,
    "abs": _sympy_abs,
    "round": _sympy_round,
    "is_integer": _sympy_is_integer,
}
895
896
# Register the generated _sympy_<op> builders as "sym_<op>" magic methods,
# then remove the metaprogramming helpers from the module namespace.
for name in math_op_names:
    sym_name = f"sym_{name}"
    magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")

del name, sym_name, math_op_names, current_module  # type: ignore[possibly-undefined]
902
903
def sympy_is_contiguous(sizes, strides):
    """Symbolic check for standard (row-major) contiguity.

    Dimensions are walked from the last (innermost) to the first.
    """
    innermost_first = list(reversed(range(len(sizes))))
    return sympy_is_contiguous_generic(sizes, strides, innermost_first)


def sympy_is_contiguous_generic(sizes, strides, dim_order):
    """Sympy boolean: are ``strides`` contiguous w.r.t. ``dim_order``?

    ``dim_order`` lists dimensions from innermost to outermost.  Every
    dimension must have size 1 or a stride equal to the product of the sizes
    of the dimensions before it in ``dim_order``.  The result is also true
    when any size is 0.  If ``dim_order`` and ``sizes`` differ in length,
    ``sympy.false`` is returned.
    """
    import sympy

    ndim = len(sizes)
    if ndim != len(dim_order):
        return sympy.false

    one = sympy.Integer(1)
    zero = sympy.Integer(0)
    expected_stride = sympy.Integer(1)
    result = sympy.true
    # Contiguous if the strides make sense (or the dim is size 1)
    for d in dim_order:
        result &= sympy.Eq(sizes[d], one) | sympy.Eq(strides[d], expected_stride)
        expected_stride *= sizes[d]
    # OR if any size is zero
    for idx in range(ndim):
        result |= sympy.Eq(sizes[idx], zero)
    return result


# NB: There is a TODO in C++ to allow omitting the batch dim.  If that
# happens you will need to refactor this


def sympy_is_channels_last_contiguous_2d(sizes, strides):
    """Channels-last contiguity for 4-d (N, C, H, W) sizes/strides: C is innermost, then W, H, N."""
    channels_last_order = [1, 3, 2, 0]
    return sympy_is_contiguous_generic(sizes, strides, channels_last_order)


def sympy_is_channels_last_contiguous_3d(sizes, strides):
    """Channels-last contiguity for 5-d (N, C, D, H, W) sizes/strides: C is innermost, then W, H, D, N."""
    channels_last_order = [1, 4, 3, 2, 0]
    return sympy_is_contiguous_generic(sizes, strides, channels_last_order)
939
940
941def sympy_is_channels_last_strides_generic(sizes, strides, dim_order):
942    import sympy
943
944    from torch.utils._sympy.functions import Max
945
946    dim = len(sizes)
947
948    if dim != len(dim_order):
949        return sympy.false
950
951    m = sympy.Integer(0)
952    r = sympy.true
953
954    # special case for trivial C dimension. default to NCHW
955    r &= sympy.Ne(strides[1], 0)
956
957    for d in dim_order:
958        r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m)
959        # Fallback to NCHW as default layout for ambiguous cases
960        # This is the flaw of implicit memory_format from strides.
961        # N111 tensor with identical strides for size 1 dimension;
962        # Two cases could lead us here:
963        # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
964        # b. N11W contiguous Tensor sliced on the W-dimension.
965        # ([N,1,1,1]@[W,W,W,W])
966        if d == 0:
967            r &= sympy.Ne(m, strides[1])
968        # This is necessary to:
969        # 1. distinguish the memory_format of N1H1;
970        #     [H, 1, 1, 1] channels_last stride
971        #     [H, H, 1, 1] contiguous stride
972        # 2. permutation of 1C1W:
973        #     [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
974        #     [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as
975        #     channels_last
976        m = strides[d] * Max(sizes[d], 1)
977
978    return r
979
980
981def sympy_is_channels_last_strides_2d(sizes, strides):
982    return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0])
983
984
985def sympy_is_channels_last_strides_3d(sizes, strides):
986    return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0])
987
988
989def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides):
990    from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator
991
992    return IsNonOverlappingAndDenseIndicator(*sizes, *strides)
993
994
995sizes_strides_methods = {
996    # TODO: These could also be done with indicators, maybe it is better
997    # for reasoning to do it that way
998    "is_contiguous": sympy_is_contiguous,
999    "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d,
1000    "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d,
1001    "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d,
1002    "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d,
1003    "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator,
1004}
1005
1006alternate_impl_if_hinted_methods = {
1007    "sym_min": builtins.min,
1008    "sym_max": builtins.max,
1009}
1010
1011
1012def to_node(self, num):
1013    if isinstance(num, SymTypes):
1014        return num.node
1015    elif type(num) is bool:
1016        return self.wrap_bool(num)
1017    elif type(num) is int:
1018        return self.wrap_int(num)
1019    elif type(num) is float:
1020        return self.wrap_float(num)
1021    else:
1022        # NotImplemented is important so that Python tries the
1023        # other magic method
1024        return NotImplemented
1025
1026
1027def wrap_node(x):
1028    # TODO: let C++ also take advantage of this
1029    if isinstance(x, SymNode) and x.constant is not None:
1030        return x.constant
1031    if x.is_int():
1032        return SymInt(x)
1033    elif x.is_float():
1034        return SymFloat(x)
1035    elif x.is_bool():
1036        return SymBool(x)
1037    else:
1038        raise AssertionError(f"unrecognized return type {x}")
1039
1040
1041def method_to_operator(method):
1042    return METHOD_TO_OPERATOR[method]
1043
1044
def _make_node_magic(method, func):
    """Install the SymNode-level implementation ``SymNode._<method_attr>``.

    ``func`` is the sympy-level implementation of ``method``. It is wrapped in
    ``functools.lru_cache(256)`` before use. Every generated implementation:

    * defers to ``handle_sym_dispatch`` (re-wrapping the result with
      ``to_node``) when ``get_proxy_mode()`` is active;
    * otherwise applies ``func`` to the operands' sympy expressions, runs the
      result through ``safe_expand``, derives the output hint from the input
      hints when they are known, picks the output ``pytype``, and creates an
      FX call node via ``shape_env._create_fx_call_function``.

    Unary methods, ``sym_ite`` and ``round`` get dedicated implementations.
    Every other method is installed as a binary implementation.
    """
    func = lru_cache(256)(func)

    if method in magic_methods_on_operator_with_trailing_underscore:
        method_attr = f"{method}_"
    else:
        method_attr = method

    def binary_magic_impl(self, other):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )
        from torch.fx.experimental.symbolic_shapes import safe_expand

        op = method_to_operator(method)

        out_hint = None
        if self.hint is not None and other.hint is not None:
            out_hint = op(self.hint, other.hint)

        # With both hints known, some methods (see
        # alternate_impl_if_hinted_methods, e.g. sym_min/sym_max) are computed
        # by a plain builtin on the wrapped values.
        alternate_impl = alternate_impl_if_hinted_methods.get(method)
        if alternate_impl and out_hint is not None:
            return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))

        if get_proxy_mode():
            return to_node(
                self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
            )
        assert isinstance(other, SymNode)
        try:
            if method == "mod":
                from torch.utils._sympy.functions import Mod, PythonMod

                # Special handling for mod that requires access to the value
                # ranges: when both operands are known non-negative (by sympy
                # assumptions or by their value-range lower bound) use Mod,
                # otherwise PythonMod.
                shape_env = self.shape_env
                if (
                    self.expr.is_nonnegative
                    or shape_env.bound_sympy(self.expr).lower >= 0
                ) and (
                    other.expr.is_nonnegative
                    or shape_env.bound_sympy(other.expr).lower >= 0
                ):
                    out = Mod(self.expr, other.expr)
                else:
                    out = PythonMod(self.expr, other.expr)
            else:
                # TODO: consider constant prop here
                out = func(self.expr, other.expr)
        except Exception:
            log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
            raise
        out = safe_expand(out)
        sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out)
        pytype: Type
        # This is not strictly correct. In Python, a**b may return complex when
        # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
        # returns a float while both arguments are ints: 2**(-1). Also, max and
        # min do not type promote. To avoid having data-dependent control flow
        # here, we just set the type to float if one of the args is a float. In
        # case of a type mismatch, we assume that it will be detected during
        # evaluation.
        if method in always_float_magic_methods:
            pytype = float
        elif method in always_bool_magic_methods:
            pytype = bool
        elif self.pytype is float or other.pytype is float:
            pytype = float
        else:
            pytype = self.pytype

        # Coerce concrete hints to the chosen output type (symbolic hints are
        # left untouched).
        if (
            pytype is not None
            and out_hint is not None
            and not isinstance(out_hint, SymTypes)
        ):
            out_hint = pytype(out_hint)

        # Create a FX node that corresponds to the operation being applied to
        # this node.
        fx_node, _ = self.shape_env._create_fx_call_function(
            op, (self.fx_node, other.fx_node)
        )
        return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

    def unary_magic_impl(self):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )
        from torch.fx.experimental.symbolic_shapes import safe_expand

        op = method_to_operator(method)
        if get_proxy_mode():
            return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
        # TODO: consider constant prop here
        expr = self.expr
        if method == "floor" or method == "ceiling":
            expr = self.shape_env._simplify_floor_div(expr)

        try:
            out = func(expr)
        except Exception:
            log.warning("failed to eval %s(%s)", method, expr)
            raise
        # Log the method name (like the binary path), not the repr of the
        # lru_cache-wrapped sympy function.
        sym_node_log.debug("%s %s -> %s", method, expr, out)
        out_hint = None
        if self.hint is not None:
            out_hint = op(self.hint)
        out = safe_expand(out)
        pytype: Type
        if method in always_int_magic_methods:
            pytype = int
        elif method in always_bool_magic_methods:
            pytype = bool
        elif method in always_float_magic_methods:
            pytype = float
        else:
            pytype = self.pytype

        fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,))
        return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

    if method in unary_methods:
        setattr(SymNode, f"_{method_attr}", unary_magic_impl)
    elif method == "sym_ite":

        def sym_ite_impl(pred_node, then_node, else_node):
            from torch.fx.experimental.proxy_tensor import (
                get_proxy_mode,
                handle_sym_dispatch,
            )
            from torch.fx.experimental.symbolic_shapes import safe_expand

            # The output hint is only known when the predicate's hint is known.
            # A None (unknown) predicate hint is falsy and must not silently
            # select the else-branch hint.
            if pred_node.hint is None:
                out_hint = None
            else:
                out_hint = then_node.hint if pred_node.hint else else_node.hint
            if get_proxy_mode():
                return to_node(
                    pred_node,
                    handle_sym_dispatch(
                        sym_ite,
                        (
                            wrap_node(pred_node),
                            wrap_node(then_node),
                            wrap_node(else_node),
                        ),
                        {},
                    ),
                )

            try:
                out = func(pred_node.expr, then_node.expr, else_node.expr)
            except Exception:
                log.warning(
                    "failed to eval %s(%s, %s, %s)",
                    method,
                    pred_node.expr,
                    then_node.expr,
                    else_node.expr,
                )
                raise

            out = safe_expand(out)
            fx_node, _ = pred_node.shape_env._create_fx_call_function(
                sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node)
            )
            # The result takes the pytype of the then-branch.
            return SymNode(
                out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node
            )

        setattr(SymNode, f"_{method_attr}", sym_ite_impl)
    elif method == "round":

        def round_impl(self, ndigits=None):
            from torch.fx.experimental.proxy_tensor import (
                get_proxy_mode,
                handle_sym_dispatch,
            )
            from torch.fx.experimental.symbolic_shapes import safe_expand

            op = builtins.round
            if get_proxy_mode():
                return to_node(
                    self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
                )

            expr = self.expr
            try:
                out = func(expr, ndigits)
            except Exception:
                log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits)
                raise

            out = safe_expand(out)

            # round(x) yields an int; round(x, ndigits) keeps the input pytype.
            if ndigits is None:
                pytype = int
            else:
                pytype = self.pytype

            out_hint = None
            if self.hint is not None:
                out_hint = op(self.hint, ndigits)

            # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the
            # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here
            # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The
            # hack down below works, because all round function down the line all take ndigits=None as default in their
            # signature.
            # TODO: Remove the args construction below if a different sentinel is used by FX.
            # ezyang(May 2024): LOL
            args = [self.fx_node]
            if ndigits is not None:
                args.append(ndigits)
            fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args))
            return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

        setattr(SymNode, f"_{method_attr}", round_impl)
    else:
        setattr(SymNode, f"_{method_attr}", binary_magic_impl)
1265
1266
def _make_node_sizes_strides(method, func):
    """Install the sizes/strides predicate ``method`` at two levels.

    ``func`` is the sympy implementation; it takes two lists of sympy
    expressions ``(sizes, strides)``.

    * ``SymNode._<method>(sizes, strides)`` takes lists of nodes. Under a
      proxy mode it defers to ``handle_sym_dispatch``. Otherwise it evaluates
      ``func`` on the node expressions, and computes a concrete hint only when
      every size and stride node has one.
    * A module-level ``<method>(sizes, strides)`` function (only installed if
      the module has no attribute of that name yet) takes user values. If any
      of them is a ``SymInt`` it routes through that SymInt's node; otherwise
      it evaluates on the concrete values.
    """
    # NB: don't LRU cache, lots of arguments

    def collect_hints(nodes):
        # Hints of ``nodes`` in order, or None as soon as one is missing.
        hints = []
        for node in nodes:
            if node.hint is None:
                return None
            hints.append(node.hint)
        return hints

    def sizes_strides_impl(self, sizes, strides):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        user_fn = getattr(sys.modules[__name__], method)
        if get_proxy_mode():
            dispatch_args = (
                [wrap_node(s) for s in sizes],
                [wrap_node(s) for s in strides],
            )
            return to_node(self, handle_sym_dispatch(user_fn, dispatch_args, {}))

        size_exprs = [s.expr for s in sizes]
        stride_exprs = [s.expr for s in strides]
        try:
            out = func(size_exprs, stride_exprs)
        except Exception:
            log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs)
            raise
        # bool is never expandable

        out_hint = None
        size_hints = collect_hints(sizes)
        if size_hints is not None:
            stride_hints = collect_hints(strides)
            if stride_hints is not None:
                out_hint = user_fn(size_hints, stride_hints)

        # NB: This is the indicator function, not the actual bool!
        pytype: Type = int if method.endswith("_indicator") else bool
        return SymNode(out, self.shape_env, pytype, out_hint)

    setattr(SymNode, f"_{method}", sizes_strides_impl)

    # TODO: This is technically hotpath, but in the ideal end state
    # guards on this will resolve at a higher level so you never
    # spend time in this code
    def sizes_strides_user(sizes, strides):
        import sympy

        from torch.fx.experimental.symbolic_shapes import (
            eval_is_non_overlapping_and_dense,
        )

        first_symint = next(
            (a for a in itertools.chain(sizes, strides) if isinstance(a, SymInt)),
            None,
        )
        if first_symint is not None:
            anchor = first_symint.node
            node_method = getattr(anchor, method)
            size_nodes = [to_node(anchor, b) for b in sizes]
            stride_nodes = [to_node(anchor, b) for b in strides]
            return wrap_node(node_method(size_nodes, stride_nodes))

        if method == "is_non_overlapping_and_dense_indicator":
            return eval_is_non_overlapping_and_dense(sizes, strides)
        # TODO: this is an awful implementation
        concrete = func(
            [sympy.sympify(a) for a in sizes],
            [sympy.sympify(a) for a in strides],
        )
        return bool(concrete)

    # Skip for is_non_overlapping_and_dense_indicator
    this_module = sys.modules[__name__]
    if not hasattr(this_module, method):
        setattr(this_module, method, sizes_strides_user)
1352
1353
# Install the SymNode-level ``_<method>`` implementation for every entry of
# ``magic_methods``. NB: ``method``/``func`` become module globals here; they
# are deleted at the end of the module-level user-magic installation loop.
for method, func in magic_methods.items():
    _make_node_magic(method, func)

# Install the sizes/strides predicates: ``SymNode._<method>`` plus, when the
# module does not already define it, a module-level ``<method>`` function.
for method, func in sizes_strides_methods.items():
    _make_node_sizes_strides(method, func)
1359
1360
def _make_user_magic(method, user_type):
    """Install the user-facing implementation(s) of ``method`` on ``user_type``.

    The installed wrappers promote operands (see the promotion-rules comment
    below), short-circuit when an operand is a constant, turn the other
    operand into a node with ``to_node``, call the corresponding SymNode
    method, and convert the result back with ``wrap_node``.

    What gets installed depends on ``method``:

    * unary magic methods: ``user_type.__<method>__``;
    * unary non-magic methods: replaces ``user_type.<method>``, copying the
      original's metadata with ``functools.update_wrapper``;
    * ``sym_ite`` / ``round``: dedicated ``__<method>__`` implementations;
    * anything else: ``__<method>__``, plus ``__r<method>__`` when ``method``
      is in ``reflectable_magic_methods``.
    """
    # User magic takes care of wrapping the other operand into a node,
    # so that our internal logic can assume everything is nodes

    if method in magic_methods_on_operator_with_trailing_underscore:
        method_attr = f"sym_{method}"
    else:
        method_attr = method

    def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]):
        # Plain Python numbers are returned as-is; constant SymBools are
        # materialized with guard_bool. Other symbolic constants are rejected.
        if isinstance(x, (int, float, bool)):
            return x
        if isinstance(x, SymBool):
            return x.node.guard_bool("", 0)
        raise AssertionError("expect to be called with constant SymBools")

    def is_constant(x):
        if isinstance(x, (int, float, bool)):
            return True
        if isinstance(x, (SymInt, SymFloat, SymBool)):
            return x.node.is_constant()
        return False

    # Promotion rules for binary operations.  NB: we preserve PYTHON semantics
    #   - if args are same type, do nothing
    #   - if one arg is float, promote other arg to float
    #       - nb: this applies to floordiv, even though output is integral
    #       (it's still float)
    #   - pow is funny business
    #       - if both ints
    #       - trigger a guard on exponent >= 0
    #           - if non-negative, output is int
    #           - otherwise, output is float
    #   - otherwise, promote other arg to float
    #       - nb: complex is impossible to handle correctly lol, with
    #       negative base and integral float need to diverge semantics and
    #       just always return complex.  Neener neener pretend this problem
    #       doesn't exist
    #   - equality is pain: Python does the fancy thing where it unpacks the
    #     mantissa from the float and then compares that against the int.
    #     Which means it is able to tell that
    #     9007199254740993 != 9007199254740992. (rather than if the LHS was
    #     promoted to float, in which case it would have truncated to the RHS
    #     and subsequently been equal).  We'll model this exactly by having
    #     special mixed type equality operations.  Unfortunately, we need to
    #     do this for all comparison operations (maybe I'll only implement
    #     compare)
    #   - sym_ite mumble mumble really shouldn't allow mixed but whatever

    if method in bool_becomes_int_magic_methods:

        def promote(x):
            """Implements True+True=2, which works in python but not sympy"""
            if isinstance(x, SymBool):
                return SymInt(x.node.wrap_int(int(x)))
            return x

    else:

        def promote(x):
            return x

    def promote2(self, other):
        # TODO: Remove eq and other relations from this list.
        # CPython has fancy implementations for these to get as much precision
        # as possible instead of just promoting to float64 and praying, so we
        # need to handle them specially too.
        # Also, note that int_truediv doesn't go through this path: both
        # arguments are "int" so there isn't any promotion
        if method not in [
            "add",
            "sub",
            "mul",
            "mod",
            "float_pow",
            "float_truediv",
            "int_floordiv",
            "sym_min",
            "sym_max",
            # TODO: remove these
            "eq",
            "ne",
            "gt",
            "lt",
            "le",
            "ge",
        ]:
            return self, other
        f_self = isinstance(self, (float, torch.SymFloat))
        f_other = isinstance(other, (float, torch.SymFloat))
        if f_self or f_other:
            if not f_self:
                self = torch.sym_float(self)
            if not f_other:
                other = torch.sym_float(other)
        return self, other

    # Before and after performing the operation, check if any operands are constant.
    # If so, extract out the constant values first. If `self` itself is a
    # constant, then "redispatch" by calling back into the operator. Sometimes
    # this means that operations involving SymBool return plain bools.
    # Alternatively, we could also rewrap into constant Symbool (i.e. by
    # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that
    # today for no particular reason.
    def unary_magic_impl(self):
        self = promote(self)
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self))
        return wrap_node(getattr(self.node, method_attr)())

    def binary_magic_impl(self, other):
        # Implements ``self <op> other``.
        if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
            return NotImplemented
        sym_node_log.debug("MAGIC %s %s %s", method, self, other)
        self = promote(self)
        other = promote(other)
        self, other = promote2(self, other)
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self), other)
        if is_constant(other):
            other = get_constant(other)
        other_node = to_node(self.node, other)
        if other_node is NotImplemented:
            return NotImplemented
        ret = wrap_node(getattr(self.node, method_attr)(other_node))
        return get_constant(ret) if is_constant(ret) else ret

    def rbinary_magic_impl(self, other):
        # Reflected operator: implements ``other <op> self``, so ``other`` is
        # the LEFT operand throughout.
        if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
            return NotImplemented
        self = promote(self)
        other = promote(other)
        self, other = promote2(self, other)
        if is_constant(self):
            # Keep the reflected operand order (other first), matching the
            # non-constant path below; swapping would be wrong for
            # non-commutative ops such as sub or truediv.
            return (method_to_operator(method))(other, get_constant(self))
        if is_constant(other):
            other = get_constant(other)
        other_node = to_node(self.node, other)
        if other_node is NotImplemented:
            return NotImplemented
        ret = wrap_node(getattr(other_node, method_attr)(self.node))
        return get_constant(ret) if is_constant(ret) else ret

    if method in unary_magic_methods:
        setattr(user_type, f"__{method}__", unary_magic_impl)
    elif method in unary_nonmagic_methods:
        orig = getattr(user_type, method)
        setattr(user_type, method, update_wrapper(unary_magic_impl, orig))
    elif method == "sym_ite":

        def sym_ite_magic_impl(pred, then_val, else_val):
            pred_node = pred.node
            then_node = to_node(pred_node, then_val)
            else_node = to_node(pred_node, else_val)
            if then_node is NotImplemented or else_node is NotImplemented:
                return NotImplemented
            assert (
                isinstance(then_node, SymNode)
                and isinstance(else_node, SymNode)
                and then_node.pytype == else_node.pytype
            )
            ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
            # wrap_node may already return a plain Python constant (which has
            # no ``.node``); is_constant handles both cases.
            return get_constant(ret) if is_constant(ret) else ret

        setattr(user_type, f"__{method}__", sym_ite_magic_impl)
    elif method == "round":

        def round_magic_impl(self, ndigits=None):
            if is_constant(self):
                return builtins.round(get_constant(self), ndigits)

            return wrap_node(getattr(self.node, method)(ndigits))

        setattr(user_type, f"__{method}__", round_magic_impl)
    else:
        setattr(user_type, f"__{method}__", binary_magic_impl)
        if method in reflectable_magic_methods:
            setattr(user_type, f"__r{method}__", rbinary_magic_impl)
1539
1540
# Install the user-facing magic methods on SymBool / SymInt / SymFloat,
# choosing the target types from the method's category.
for method, func in magic_methods.items():  # type: ignore[assignment]
    if method in only_bool_magic_methods:
        _make_user_magic(method, SymBool)
    elif method in only_float_magic_methods:
        _make_user_magic(method, SymFloat)
    else:
        # Generic numeric methods go on SymInt and SymFloat; some of them are
        # also installed on SymBool.
        if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods:
            _make_user_magic(method, SymBool)
        _make_user_magic(method, SymInt)
        _make_user_magic(method, SymFloat)

# Remove the loop variables so they do not linger as module attributes.
del method
del func
1555