# mypy: ignore-errors

"""
``torch.fx.experimental.symbolic_shapes`` provides interfaces for interacting with
our symbolic shapes reasoning system that is used heavily in torch.compile.  Although
this is not generally considered public API, when writing framework code in PyTorch
as well as extensions to PyTorch (e.g., in custom operator implementations), you may
need to make use of these APIs to set up dynamic shapes support appropriately.
"""

import builtins
import collections
import functools
import inspect
import itertools
import logging
import math
import operator
import os
import re
import sys
import threading
import traceback
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
import atexit
from typing import (
    Any,
    cast,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
    TYPE_CHECKING
)
from typing_extensions import TypeAlias

import torch
import torch.fx
import torch.fx.traceback as fx_traceback
from torch.fx.experimental import _config as config

from torch.fx.experimental.recording import (
    FakeTensorMeta,
    ShapeEnvEvent,
    record_shapeenv_event,
    replay_shape_env_events,
    shape_env_check_state_equal
)
from torch.fx.experimental.sym_node import SymNode, SymTypes
from torch._logging import trace_structured, structured

# NB: The sym_* functions are used via getattr() and must be imported here.
from torch import SymBool, SymFloat, SymInt
from torch._guards import ShapeGuard, Source, TracingContext
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._sympy.functions import (
    Application, FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt
)
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._traceback import format_frame, CapturedTraceback
from torch._utils_internal import signpost_event
from torch._subclasses.meta_utils import is_sparse_any
import torch.utils._pytree as pytree
from torch.utils._sympy.symbol import SymT, make_symbol, symbol_is_type

from torch._logging import LazyString

if TYPE_CHECKING:
    from torch._dynamo.source import TensorPropertySource

InputList = List
DimList = List

log = logging.getLogger(__name__)

import sympy
from sympy.printing.str import StrPrinter
from sympy.printing.precedence import precedence, PRECEDENCE

class GuardOnDataDependentSymNode(RuntimeError):
    cond: sympy.Expr

    def __init__(self, cond, *args):
        super().__init__(*args)
        self.cond = cond

class PendingUnbackedSymbolNotFound(RuntimeError):
    pass

aten = torch._ops.ops.aten  # type: ignore[has-type]

__all__ = [
    "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int",
    "guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr",
    "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
    "is_concrete_bool", "is_nested_int", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
    "has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext",
    "StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true",
    "guard_size_oblivious", "check_consistent",
    "compute_unbacked_bindings", "ConvertIntKey",
    "rebind_unbacked", "resolve_unbacked_bindings", "is_accessor_node",
]

# FX node metadata keys for symbolic shape FX graph.
SHAPEENV_EVENT_KEY = "shapeenv_event"
CURRENT_NODE_KEY = "current_node"


def log_lru_cache_stats(wrapped_f):
    log.debug("lru_cache_stats %s: %s", wrapped_f.__name__, wrapped_f.cumulative_cache_info())


# Wrapper on lru_cache that reports statistics at process end
def lru_cache(maxsize):
    def inner(f):
        wrapped_f = functools.lru_cache(maxsize)(f)
        old_cache_clear = wrapped_f.cache_clear
        prev_hits = 0
        prev_misses = 0

        # TODO: There's a ref-cycle here (wrapped_f -> cumulative_cache_info
        # -> wrapped_f) but it cannot be solved with weakref, as wrapped_f is
        # not weakref'able on some versions of Python

        def cumulative_cache_info():
            cur = wrapped_f.cache_info()
            return functools._CacheInfo(
                prev_hits + cur.hits,
                prev_misses + cur.misses,
                cur.maxsize,
                cur.currsize,
            )

        def new_cache_clear():
            nonlocal prev_hits, prev_misses
            cur = wrapped_f.cache_info()
            prev_hits += cur.hits
            prev_misses += cur.misses
            old_cache_clear()

        wrapped_f.cache_clear = new_cache_clear
        wrapped_f.cumulative_cache_info = cumulative_cache_info
        if log.isEnabledFor(logging.DEBUG):
            atexit.register(log_lru_cache_stats, wrapped_f)
        return wrapped_f

    return inner

# These are modules that contain generic code for interacting with ShapeEnv
# which are unlikely to identify a particular interesting guard statement
@lru_cache(None)
def uninteresting_files() -> Set[str]:
    import torch._inductor.sizevars
    import torch._library.fake_impl
    import torch._subclasses.meta_utils
    import torch._subclasses.fake_tensor
    mods = [
        sys.modules[__name__],
        torch.fx.experimental.recording,
        torch.fx.experimental.sym_node,
        torch.fx.interpreter,
        torch,
        torch._inductor.sizevars,
        torch._library.fake_impl,
        torch._subclasses.meta_utils,
        torch._subclasses.fake_tensor,
    ]
    return {inspect.getfile(m) for m in mods}

# We don't bother with the metaclass as all of the dispatching logic happens
# entirely from Python
#
# Didn't bother with ancestors for now, unlikely to have multiple modes for
# symints right now

class ConstraintViolationError(RuntimeError):
    pass

def has_symbolic_sizes_strides(elem) -> bool:
    return elem._has_symbolic_sizes_strides

Int = Union[torch.SymInt, int]

def create_contiguous(shape: Sequence[Int]) -> List[Int]:
    strides: List[Int] = [1]
    # Walk the shape back-to-front: each stride is the next dimension's size
    # times the next stride.  (Iterating over shape[:-1] here would be off by
    # one and produce overlapping strides.)
    for dim in reversed(shape[1:]):
        strides.append(dim * strides[-1])
    return list(reversed(strides))
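
# Example (illustrative): contiguous strides are the running products of the
# trailing sizes, e.g.:
#
#   >>> create_contiguous([2, 3, 4])
#   [12, 4, 1]
#   >>> create_contiguous([5])
#   [1]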

def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int:
    """
    Retrieve the hint for an int (based on the underlying real values as observed
    at runtime).  If no hint is available (e.g., because of data-dependent shapes),
    use fallback if it is not None; otherwise raise an error.
    """
    if isinstance(a, torch.SymInt):
        return a.node.require_hint(fallback)
    assert type(a) is int, a
    return a
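
# Example (illustrative): hint_int(5) == 5.  For a backed SymInt s whose
# observed runtime value was 12, hint_int(s) == 12; for an unbacked SymInt u
# (e.g., from .item()), hint_int(u, fallback=1) returns 1 instead of raising.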

Scalar = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool]

def has_hint(a: Scalar) -> bool:
    if isinstance(a, SymTypes):
        return a.node.has_hint()
    return True

def is_concrete_int(a: Union[int, SymInt]) -> bool:
    r""" Utility to check if the underlying object
    in a SymInt is a concrete value. Also returns
    True if a plain integer is passed in.

    Args:
        a (SymInt or int): Object to test
    """
    assert isinstance(a, (SymInt, int))

    if isinstance(a, int):
        return True

    if isinstance(a.node.expr, sympy.core.numbers.Integer):
        return True

    return False
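
# Example (illustrative): is_concrete_int(4) is True, and a SymInt whose
# expression has already simplified to a sympy Integer is also concrete;
# a SymInt backed by a free symbol such as s0 is not.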

# In obscure Meta only situations, sympy.logic.boolalg doesn't exist at runtime.
# So make sure only the type checker evaluates this alias.
# Xref: https://www.internalfb.com/diff/D53324783
SympyBoolean: TypeAlias = "sympy.logic.boolalg.Boolean"

def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool:
    """
    Perform a guard on a symbolic boolean expression in a size-oblivious way.
    This is typically used when a non-oblivious test would result in a guard
    on a data-dependent value whose value we don't know at compile time.
    When a guard is tested this way, we may diverge in behavior from how regular
    PyTorch semantics would treat it.  For more information, see
    https://github.com/pytorch/pytorch/pull/118579
    """
    if isinstance(expr, torch.SymBool):
        return expr.node.guard_size_oblivious("", 0)
    else:
        assert isinstance(expr, bool), expr
        return expr
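
# Example (illustrative): if u0 is a size-like unbacked SymInt (say,
# nonzero(x).size(0)), evaluating bool(u0 == 0) would raise
# GuardOnDataDependentSymNode, whereas guard_size_oblivious(u0 == 0)
# returns False, roughly by reasoning as if size-like symbols were >= 2.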

def check_consistent(new, old) -> None:
    """
    Test that two "meta" values (typically either Tensor or SymInt) have
    the same values, e.g., after retracing.  If we don't understand the
    quantities in question, we'll just skip the consistency check.
    """
    # TODO: do boolean equality test too, see
    # https://github.com/pytorch/pytorch/issues/124110
    scalar_types = (torch.SymInt, torch.SymFloat, int, float)

    if isinstance(new, torch.Tensor):
        assert isinstance(old, torch.Tensor)
        torch._check(old.dim() == new.dim(), lambda: f"{old.shape} != {new.shape} (old != new)")
        # Do this manually so that each individual test is irrefutable
        # (TODO: should be a helper for this, maybe sym_eq?  That
        # gives us a compound expression and I'm not sure it
        # simplifies right now)
        for i, j in zip(old.shape, new.shape):
            torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)")
    # NB: bool is subclass of int
    elif isinstance(new, scalar_types) and not isinstance(new, bool):
        assert isinstance(old, scalar_types) and not isinstance(old, bool), f"{old} != {new}"
        torch._check(old == new, lambda: f"{old} != {new} (old != new)")

def resolve_unbacked_bindings(shape_env, bindings):
    if bindings is None:
        return None
    return {
        shape_env.unbacked_renamings.get(k, k): v
        for k, v in bindings.items()
    }

def rebind_unbacked(shape_env, n: torch.fx.Node, result):
    """
    Suppose we are retracing a pre-existing FX graph that previously had
    fake tensor propagation (and therefore unbacked SymInts).  When we retrace,
    we re-propagate fake tensors, which results in new unbacked SymInts.
    When this happens, we need to tell the shape environment about the equivalence
    of the old and new unbacked SymInts.  Pass us the old torch.fx.Node (which
    has the old binding information) and the new result (from which we can
    extract the new unbacked SymInts).
    """
    from torch._dynamo.tensor_version_op import _tensor_version

    # Inputs never need rebinding
    if n.op == "placeholder":
        return

    if bindings := resolve_unbacked_bindings(shape_env, n.meta.get("unbacked_bindings")):
        for raw_u0, path in bindings.items():
            u1 = pytree.key_get(result, path)
            # tensor_version ops get specialized after AOTAutograd; that's OK,
            # we don't actually want to do asserts on them.  This is all a bit
            # questionable though
            if isinstance(u1, int) and n.target is _tensor_version:
                log.info("rebind_unbacked: discard _tensor_version %s %s -> %s", raw_u0, path, u1)
                continue
            raw_u1 = u1.node.expr
            # Simplify SymBool binding
            if (
                isinstance(raw_u1, sympy.Piecewise) and
                len(raw_u1.args) == 2 and
                raw_u1.args[0][0] == 1 and
                isinstance(eq := raw_u1.args[0][1], sympy.Eq) and
                isinstance(new_raw_u1 := eq.lhs, sympy.Symbol) and
                shape_env.var_to_range[new_raw_u1].issubset(ValueRanges(0, 1)) and
                eq.rhs == 1 and
                raw_u1.args[1] == (0, True)
            ):
                # This is what the pattern match above is testing
                repacked = _sympy_cast_symbool_to_symint_guardless(sympy.Eq(new_raw_u1, 1))
                assert repacked == raw_u1, f"{repacked} != {raw_u1}"
                # Cancel the to_int(to_bool(x)). This is sound because x is in
                # [0, 1]
                raw_u1 = new_raw_u1
            assert isinstance(raw_u1, sympy.Symbol)
            # The old and new could be the same if you improperly hit the memo
            # while retracing.  Make sure you updated FakeTensorMode.epoch
            assert raw_u0 != raw_u1, f"{raw_u0} possible memo disaster"
            # Reuse the OLD symbol name
            shape_env._rename_unbacked_to(raw_u1, raw_u0)

# NB: You could try to expand this to cover more cases by simply
# detecting whenever you have an int output, but this is a bit
# dangerous in case someone adds a function that returns an int but is
# mutating.  So manually whitelist for now.
def is_accessor_node(node: torch.fx.Node) -> bool:
    # Dynamo only exercised condition
    if (
        node.op == "call_method"
        and isinstance(node.args[0].meta.get("example_value"), torch.Tensor)
        and node.target in ["size", "stride", "storage_offset", "item"]
    ):
        return True
    if node.op == "call_function" and node.target in [
        torch.ops.aten.sym_size,
        torch.ops.aten.sym_size.default,
        torch.ops.aten.sym_size.int,
        torch.ops.aten.sym_stride,
        torch.ops.aten.sym_stride.default,
        torch.ops.aten.sym_stride.int,
        torch.ops.aten.sym_storage_offset,
        torch.ops.aten.sym_storage_offset.default,
        torch.ops.aten.sym_numel.default,
    ]:
        return True
    return False

def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
    r""" Canonicalize a boolean expression by transforming it into a lt / le
    inequality and moving all the non-constant terms to the rhs.
    We canonicalize And / Or / Not via CNF and then canonicalize their
    subexpressions recursively.
    nb. sympy.Rel.canonical is not good enough https://github.com/sympy/sympy/issues/25924

    Args:
        expr (sympy.Expr): Expression to canonicalize
    """
    if not isinstance(expr, (sympy.Rel, sympy.And, sympy.Or, sympy.Not, sympy.Eq, sympy.Ne)):
        return expr

    if isinstance(expr, (sympy.And, sympy.Or, sympy.Not)):
        expr = sympy.logic.boolalg.to_cnf(expr)
    return _canonicalize_bool_expr_impl(expr)
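
# Example (illustrative): canonicalization rewrites Gt/Ge as Lt/Le, reduces
# integer factors, and splits positive and negative terms across the relation:
#
#   >>> import sympy
#   >>> x = sympy.Symbol("x")
#   >>> canonicalize_bool_expr(sympy.Ge(2*x, 4))
#   2 <= x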

def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean:
    """
    After canonicalization, we are guaranteed to have eliminated Ge/Gt relations
    (rewriting them to Le/Lt, respectively).
    """
    if isinstance(expr, (sympy.And, sympy.Or)):
        return type(expr)(*map(canonicalize_bool_expr, expr.args))

    opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le}
    if isinstance(expr, tuple(opposite.keys())):
        rhs = expr.lhs - expr.rhs
        t = opposite[type(expr)]
    else:
        assert isinstance(expr, (sympy.Lt, sympy.Le, sympy.Eq, sympy.Ne))
        rhs = expr.rhs - expr.lhs
        t = type(expr)

    def is_neg(t):
        return t.is_negative or (isinstance(t, sympy.Mul) and t.args[0].is_negative)

    lhs = 0
    rhs = _reduce_to_lowest_terms(rhs)
    if isinstance(rhs, sympy.Add):
        pos = []
        neg = []
        for term in rhs.args:
            if is_neg(term):
                neg.append(-term)
            else:
                pos.append(term)
        lhs = sympy.Add(*neg)
        rhs = sympy.Add(*pos)
    elif is_neg(rhs):
        # lhs == 0
        lhs, rhs = -rhs, 0
    return t(lhs, rhs)


def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr:
    """
    Eliminates any integer factor from a given expression.
    E.g., 6x + 4y reduces to 3x + 2y.

    Useful when an expression is == or != to 0.
    """
    def integer_coefficient(x):
        if isinstance(x, sympy.Integer):
            return abs(int(x))
        elif isinstance(x, sympy.Mul):
            return math.prod([abs(int(arg)) for arg in x.args if isinstance(arg, sympy.Integer)])
        else:
            return 1

    if isinstance(expr, sympy.Add):
        atoms = expr.args
        factor = functools.reduce(math.gcd, map(integer_coefficient, atoms))
        atoms = [x / factor for x in atoms]
        return sympy.Add(*atoms)
    else:
        return expr / integer_coefficient(expr)


def is_concrete_bool(a: Union[bool, SymBool]) -> bool:
    r""" Utility to check if the underlying object
    in a SymBool is a concrete value. Also returns
    True if a plain boolean is passed in.
    Args:
        a (SymBool or bool): Object to test
    """
    assert isinstance(a, (SymBool, bool))

    if isinstance(a, bool):
        return True

    if isinstance(a.node.expr, (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse)):
        return True

    return False

def is_nested_int(s):
    return isinstance(s, torch.SymInt) and s.node.is_nested_int()

def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
    if isinstance(val, SymTypes):
        # This check also covers the jagged layout NestedTensor case, as
        # nested ints are not symbolic
        if is_symbolic(val):
            yield val.node.expr
    elif isinstance(val, sympy.Basic):
        yield val
    elif isinstance(val, (int, float, bool)):
        pass
    elif isinstance(val, (tuple, list)):
        for s in val:
            yield from _iterate_exprs(s)
    elif is_sparse_any(val):
        yield from _iterate_exprs(val.size())
    elif isinstance(val, torch.Tensor):
        yield from _iterate_exprs(val.size())
        yield from _iterate_exprs(val.stride())
        yield from _iterate_exprs(val.storage_offset())
    elif val is None:
        pass
    else:
        raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")

def free_symbols(val: Union[SymInt, sympy.Expr, torch.Tensor]) -> Set[sympy.Symbol]:
    if val is None:
        return set()
    itr = _iterate_exprs(val)
    # we need at least 1 to call union, so we hand code the identity
    try:
        first_expr = next(itr)
    except StopIteration:
        return set()

    return first_expr.free_symbols.union(*(e.free_symbols for e in itr))
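
# Example (illustrative): free_symbols accepts SymInts, sympy expressions,
# and tensors (covering sizes, strides, and storage offset):
#
#   >>> import sympy
#   >>> s0, s1 = sympy.symbols("s0 s1")
#   >>> free_symbols(s0 * s1 + 1) == {s0, s1}
#   True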

def has_free_symbols(val: Union[SymInt, torch.Tensor]) -> bool:
    """Faster version of bool(free_symbols(val))"""
    return not all(e.is_number for e in _iterate_exprs(val))

# Like free_symbols, but filtered to only report unbacked symbols
def free_unbacked_symbols(x):
    # NB: keep synced with is_unbacked_symint
    return {s for s in free_symbols(x) if symbol_is_type(s, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT))}

# WARNING: Don't use this on Dynamo produced graphs, they don't have meta
# setup!
def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]:
    if (
        "val" in node.meta and
        isinstance(node.meta["val"], torch.SymInt) and
        isinstance(node.meta["val"].node.expr, sympy.Symbol) and
        (node.op == "placeholder" or free_unbacked_symbols(node.meta["val"].node.expr))
    ):
        return node.meta["val"].node.expr
    return None

def find_symbol_binding_fx_nodes(graph):
    r = {}
    # NB: Prefer first occurrence of symbol
    for node in graph.nodes:
        if is_symbol_binding_fx_node(node) and node.meta["val"].node.expr not in r:
            r[node.meta["val"].node.expr] = node
    return r


# Analogous to ConvertIntSource
@dataclass(frozen=True)
class ConvertIntKey:
    def __str__(self) -> str:
        return ".cast_symbool_to_symint_guardless()"

    def get(self, b: bool) -> int:
        """Get the int value from bool"""
        return cast_symbool_to_symint_guardless(b)


@dataclass(frozen=True)
class CallMethodKey:
    name: str

    def __str__(self) -> str:
        return f".{self.name}()"

    def get(self, o: Any) -> Any:
        """Call the method on object"""
        return getattr(o, self.name)()


@dataclass(frozen=True)
class InnerTensorKey:
    inner_name: str

    def __str__(self) -> str:
        return f".{self.inner_name}"

    def get(self, o: Any) -> Any:
        """Get the inner tensor attribute"""
        return getattr(o, self.inner_name)


@dataclass(frozen=True)
class DivideByKey:
    divisor: int

    def __str__(self) -> str:
        return f".__floordiv__({self.divisor})"

    def get(self, o: int) -> int:
        """Divide object by divisor"""
        return o // self.divisor
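
# Example (illustrative): these key classes compose into pytree key paths
# that record where an unbacked symbol lives inside an example value.  A
# binding path of (CallMethodKey("size"), SequenceKey(0)) means the symbol's
# value can be recovered from an example value t as t.size()[0]; appending
# DivideByKey(2) would recover it as t.size()[0] // 2.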


def compute_unbacked_bindings(shape_env, example_value, old_example_value=None, peek=False):
    """
    After having run fake tensor propagation and producing example_value
    result, traverse example_value looking for freshly bound unbacked
    symbols and record their paths for later.  It is an error if
    we have allocated an unbacked SymInt but it cannot be found in
    example_value.  (NB: this means if you have a multi-output
    function, you must call this on the tuple of tensor outputs; you
    cannot wait!)

    The peek parameter lets you check out what the bindings are without
    changing the affected list.  This is primarily useful for ensuring
    unbacked_var_to_val is promptly populated when propagate_real_tensors is on.
    """
    if shape_env is None:
        return
    fs = shape_env.pending_fresh_unbacked_symbols
    pending = set(fs)
    if pending:
        if not peek:
            log.info("compute_unbacked_bindings %s", fs)
            fs.clear()

        def free_unbacked_symbols_with_path(
            a, path, real=None
        ) -> Dict[sympy.Symbol, pytree.KeyPath]:
            r = {}
            if isinstance(a, (tuple, list)):
                for i in range(len(a)):
                    r.update(
                        free_unbacked_symbols_with_path(
                            a[i], path + (pytree.SequenceKey(i),),
                            real=real[i] if real is not None else None
                        )
                    )
            elif is_traceable_wrapper_subclass(a):
                # TODO: Determine if this is correct
                attrs, _ = a.__tensor_flatten__()
                for attr in attrs:
                    sub = getattr(a, attr)
                    r.update(
                        free_unbacked_symbols_with_path(sub, path + (InnerTensorKey(attr),))
                    )
            elif isinstance(a, torch.Tensor):
                r.update(
                    free_unbacked_symbols_with_path(
                        a.size(), path + (CallMethodKey("size"),),
                        real=a.real_tensor.size() if a.real_tensor is not None else None
                    )
                )
                r.update(
                    free_unbacked_symbols_with_path(
                        a.stride(), path + (CallMethodKey("stride"),),
                        real=a.real_tensor.stride() if a.real_tensor is not None else None
                    )
                )
                r.update(
                    free_unbacked_symbols_with_path(
                        a.storage_offset(), path + (CallMethodKey("storage_offset"),),
                        real=a.real_tensor.storage_offset() if a.real_tensor is not None else None
                    )
                )

            # NB: Intentionally access _expr, not expr; we do not want
            # simplification!
            elif (
                isinstance(a, (torch.SymInt, torch.SymFloat))
                and isinstance(s := a.node._expr, sympy.Symbol)
                and s in pending
            ):
                r[s] = path
                if real is not None:
                    shape_env.set_unbacked_var_to_val(s, real)
                pending.remove(s)
            # When an unbacked SymInt is perfectly divisible by an integer
            # constant, we replace it with the integer constant to improve
            # reasoning capabilities.  However, in synthetic examples, it is
            # then possible that the factor never is explicitly allocated.
            # Fortunately, we can compute it by division.
            elif (
                isinstance(a, torch.SymInt)
                and isinstance(s := a.node._expr, sympy.Mul)
                and len(s.args) == 2
                and isinstance(lhs := s.args[0], sympy.Integer)
                and isinstance(rhs := s.args[1], sympy.Symbol)
                and rhs in pending
            ):
                # TODO: DivideByKey needs to test divisibility at runtime!
                r[s] = path + (DivideByKey(int(lhs)),)
                if real is not None:
                    shape_env.set_unbacked_var_to_val(s, real // int(lhs))
                pending.remove(rhs)
            # The annoyance here arises from the fact that SymBool is
            # allocated by allocating a SymInt and then testing if it's equal
            # to one.  So you have a complicated binding site logic for this.
            elif (
                isinstance(a, torch.SymBool)
                and isinstance(s := a.node._expr, sympy.Eq)
                # This must match create_unbacked_symbool EXACTLY
                and isinstance(s.lhs, sympy.Symbol)
                and s.rhs == 1
                and s.lhs in pending
            ):
                r[s.lhs] = path + (ConvertIntKey(),)
                if real is not None:
                    shape_env.set_unbacked_var_to_val(s, int(real))
                pending.remove(s.lhs)

            return r

        symbol_to_path = free_unbacked_symbols_with_path(example_value, ())
        if not peek and pending:
            extra = (
                repr((example_value.stride(), example_value.storage_offset()))
                if isinstance(example_value, torch.Tensor)
                else ""
            )
            raise PendingUnbackedSymbolNotFound(
                f"Pending unbacked symbols {pending} not in returned outputs {example_value} {extra}.\n"
                "Did you accidentally call new_dynamic_size() or item() more times "
                "than you needed to in your fake implementation?\n"
                "For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit"
            )

        # Why do we have to do some rebinding here?  If the original FX node
        # wasn't a binding site because you had a memo hit, but post
        # translation you aren't a memo hit anymore, there's now a new binding
        # site... but we know (because it's the same FX node) that the value
        # is actually the same, they're just not obviously equal anymore.
        #
        # The logic here is written carefully, because unlike the
        # bind_unbacked case, we are not guaranteed to have a symbol for
        # old_sym.  If we have a symbol, do regular rename unbacked to; but if
        # we don't, we need to specially eliminate the fresh unbacked symbol
        # (NB: we are /trusting/ that the memoization is correct, and that we
        # don't need to generate a new runtime assert.  This is load bearing,
        # as repropagation can happen after we've frozen runtime asserts.)
        if old_example_value is not None:
            for keypath in symbol_to_path.values():
                old_sym = pytree.key_get(old_example_value, keypath)
                new_sym = pytree.key_get(example_value, keypath)
                if (
                    isinstance(new_sym, SymTypes) and
                    isinstance(new_s := new_sym.node.expr, sympy.Symbol)
                ):
                    if isinstance(old_sym, SymTypes) and (old_s := old_sym.node.expr) != new_s:
                        if isinstance(old_s, sympy.Symbol):
                            shape_env._rename_unbacked_to(new_s, old_s)
                        else:
                            shape_env._eliminate_unbacked(new_s, old_s)
                    elif not isinstance(old_sym, SymTypes):
                        shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym))

        return symbol_to_path

def definitely_true(a):
    """
    Returns True only if we can tell that a is True, possibly introducing
    a guard in the process.  If a depends on some unbacked SymInt, we may
    return False even though there may exist a possible value of the SymInt
    that would cause the expression to return True.

    When is it appropriate to use definitely_true?  First, if you can use
    a higher level combinator like parallel_or/parallel_and, prefer using
    those instead; they are definitely safe (modulo short-circuiting).
    Second, it can be used if the program would behave equivalently if
    definitely_true always returned False (parallel_or/parallel_and are
    examples of this pattern, modulo short-circuiting).  Finally, it can even
    be OK if the program wouldn't behave equivalently, so long as the
    change is semantics preserving.  It can be semantics preserving if
    the program errors in more cases than it did previously (but otherwise
    behaves identically), or if it changes some quantity in a way that
    doesn't matter (e.g., strides often fall in this bucket.)
    """
    if isinstance(a, SymBool):
        if a.node.has_hint():
            return guard_bool(a)
        else:
            return False
    return bool(a)

def definitely_false(a):
    """
    Returns True only if we can tell that a is False, possibly introducing
    a guard in the process.  If a depends on some unbacked SymInt, we may
    return False even though there may exist a possible value of the SymInt
    that would cause the expression a to be False.  See definitely_true
    for more usage guidance.
    """
    if isinstance(a, SymBool):
        if a.node.has_hint():
            return not guard_bool(a)
        else:
            return False
    return not bool(a)

def statically_known_true(x: Union[bool, SymBool]) -> bool:
    """Returns True if x can be simplified to a constant and is true.

    .. note::
        This function doesn't introduce new guards, so the expression may end
        up evaluating to true at runtime even if this function returns False.

    Args:
        x (bool, SymBool): The expression to try statically evaluating

    """
    if isinstance(x, SymBool):
        expr = x.node.expr
        shape_env = x.node.shape_env
        try:
            simplified = shape_env._maybe_evaluate_static(expr)
            if simplified is not None:
                return bool(simplified)
        except Exception:
            log.debug("Could not simplify %s", expr)
        return False
    assert isinstance(x, bool)
    return x
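
# Example (illustrative): for a symbolic size s0 allocated with value range
# [0, inf], statically_known_true(s0 >= 0) is True (provable from the value
# range, with no guard added), while statically_known_true(s0 == 2) is False
# even though a guard could make it true at runtime.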


def parallel_or(*args):
    """
    Evaluate the logical OR of several arguments, avoiding guarding on
    unbacked SymInts if another argument is definitely True.
    """
    if any(statically_known_true(a) for a in args):
        return True
    if any(definitely_true(a) for a in args):
        return True
    return any(args)

def parallel_and(*args):
    """
    Evaluate the logical AND of several arguments, avoiding guarding on
    unbacked SymInts if another argument is definitely False.
    """
    if any(statically_known_true(torch.sym_not(a)) for a in args):
        return False
    if any(definitely_false(a) for a in args):
        return False
    return all(args)
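
# Example (illustrative): parallel_or(u0 == 1, s0 == s0) returns True without
# guarding on the unbacked u0, because the second argument is statically True;
# a plain `or` would evaluate u0 == 1 first and raise on the data-dependent
# guard.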

def sym_eq(x, y):
    """
    Like ==, but when run on list/tuple, it will recursively test equality
    and use sym_and to join the results together, without guarding.
    """
    if (isinstance(x, tuple) and isinstance(y, tuple)) or (isinstance(x, list) and isinstance(y, list)):
        if len(x) != len(y):
            return False
        return functools.reduce(operator.and_, map(sym_eq, x, y), True)
    elif isinstance(x, (int, torch.SymInt)) and isinstance(y, (int, torch.SymInt)):
        return x == y
    else:
        raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}")
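
# Example (illustrative): sym_eq((s0, 2), (s0, 2)) folds the elementwise
# comparisons together with &, so the result stays symbolic when the inputs
# are; on plain ints, sym_eq((2, 3), (2, 3)) is simply True.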

def guard_scalar(a):
    if isinstance(a, (SymBool, bool)):
        return guard_bool(a)
    elif isinstance(a, (SymInt, int)):
        return guard_int(a)
    elif isinstance(a, (SymFloat, float)):
        return guard_float(a)
    else:
        raise AssertionError(f"unrecognized scalar {a}")


def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int):
    shape_env.constrain_symbol_range(s, compiler_min, compiler_max)


def _advise_is_size(a):
    """
    Don't use this directly; use torch._check_is_size instead.

    This is a softer version of _constrain_range_for_size (with min=0,
    max=Inf).  Instead of forcibly constraining a variable (and erroring if we
    failed to constrain it), it will simply advise us that a size is
    constrained in some way.  We will always defer a runtime assert for this
    constraint if we cannot prove it at compile-time, but we only
    *sometimes* learn useful extra information at compile-time from it.
    This is in contrast to constrain_range_for_size, where if you don't call
    that on a fresh unbacked symint, chances are we will choke.

    TODO: Make Dynamo handle this appropriately if this is seen in Dynamo-ed
    code.  Right now this is only really used in code with AOTAutograd trace
    through, so it is not a big problem that this isn't supported, but in
    principle all of this code should be Dynamo'able too.

    TODO: I didn't support min/max because I didn't have a use case where this
    actually helped.  In principle we can support it, it just makes the
    implementation below more complicated.
    """

    # This must always succeed, because the sole allowed caller _check_is_size
    # was responsible for expect_true'ing this
    # This assert triggers expensive sym compute, so do not do it until it's cheap.
    # assert a >= 0

    # NB: it's important not to constrain range for size for *hinted* SymInts,
    # because it is not only unsound, it will immediately trip our asserts
    # that hints have to be consistent with static analysis!  If you somehow
    # have an unbounded SymInt that later constrains to 1, this will be
    # inconsistent with the range
    if (
        isinstance(a, SymInt)
        and isinstance(a.node, SymNode)
        and isinstance(a.node.expr, sympy.Symbol)
        and a.node.shape_env.is_unbacked_symint(a.node.expr)
    ):
        _constrain_range_for_size(a)

def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None):
    """
    This function is NOT INTENDED to be used by itself.
    """

    if isinstance(a, (SymFloat, SymBool)):
        raise ValueError("Constraining SymFloat/SymBool is nyi")

    assert isinstance(a, SymInt), "can only constrain range for SymInt"
    assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"

    a.node.shape_env._constrain_range_for_size(a.node.expr, min, max)


# inclusive both ways
def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
    """
    Applies a constraint that the passed in SymInt must lie between min-max
    inclusive-inclusive, WITHOUT introducing a guard on the SymInt (meaning
    that it can be used on unbacked SymInts).  If min/max are None, we assume
    that the dimension is unbounded in that direction.  Repeated application
    of constrain_range intersects the ranges.  This is a fairly low level API
    that doesn't have a lot of safety guarantees (TODO: provide higher level
    APIs).

    Currently, we use this API in the following circumstance: when we allocate
    an unbacked SymInt, denoting an integer quantity which is data dependent,
    we ordinarily do not know anything about what values it may take.  This
    means that any sort of guard on it will immediately fail.  However, in
    many cases, we know something about the unbacked SymInt: for example, we
    know that nonzero(x).size(0) must be >= 0.  We use constrain_range to
    narrow the possible range, declaring that negative symbols are impossible.
    This lets us definitely answer True to queries like 'nnz >= 0', even if
    we don't know what the actual (hinted) value of 'nnz' is.  In fact, we
    actually use constrain_range to unsoundly discharge common guards: for an
    unbacked SymInt produced by nonzero, we will also assume that it is not
    equal to 0/1 (even though these are perfectly possible values at runtime),
    because we generally expect graphs that are valid for N=2 to also be valid
    for N=1.
    """
    if min is None:
        min = -int_oo
    if max is None:
        max = int_oo

    if max < min:
        raise ValueError(
            "Maximum value to constrain_as_size can't be less than the specified min value, "
            f"received min={min} and max={max}"
        )

    if isinstance(a, int):
        if not (min <= a <= max):
            raise ValueError(f"Invalid value {a} for range [{min}:{max}]")
        return

    a.node.shape_env._constrain_range(a.node.expr, min, max)
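
# Example (illustrative): given an unbacked u0 = x.item(), constrain_range(u0,
# min=0) records that u0 >= 0 without guarding, so later queries such as
# statically_known_true(u0 >= 0) can succeed from the recorded value range.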

def constrain_unify(a: torch.SymInt, b: torch.SymInt) -> None:
    """
    Given two SymInts, constrain them so that they must be equal.  NB:
    this will not work with SymInts that represent nontrivial expressions
    (yet!)
    """
    if not isinstance(a, SymInt):
        if not isinstance(b, SymInt):
            assert a == b
            return
        else:
            shape_env = b.node.shape_env
    else:
        shape_env = a.node.shape_env

    shape_env._constrain_unify(a, b)

# Assume that a boolean is true for the purposes of subsequent symbolic
# reasoning.  This will keep track of corresponding runtime checks to verify
# that the result is upheld: either as a regular guard, or as a special set
# of asserts which are triggered when an unbacked SymInt is allocated.
#
# DO NOT use this function for these cases:
#
#  - This is inappropriate for "branching" conditions (where both
#    true and false result in valid programs).  We will always assume
#    the condition evaluates true, and so it will never be possible
#    to trace the false condition when you use it.  For true branching
#    on unbacked SymInts, you must use torch.cond; if you incorrectly
#    use expect_true in this case, you will make the false branch
#    unreachable (as we will simply assume that only the true branch
#    is ever exercised).
#
#  - This is inappropriate for situations where you know some other system
#    invariant guarantees that this property holds, since you don't
#    really need to insert a runtime check in that case.  Use something
#    like constrain_range in that case.
#
# This API has a hitch.  To avoid having to reimplement error reporting
# capabilities, this function CAN return False.  The invariant is that
# the surrounding code must raise an error when this function returns
# False.  This is quite low level, so we recommend using other functions
# like check() which enforce this in a more intuitive way.
#
# By the way, this name is a nod to the __builtin_expect macro,
# which is used similarly (but unlike __builtin_expect, you MUST fail
# in the unlikely branch.)  (I think expect is a good name; in recent
# versions of C++, this is replaced with [[likely]], which is weaker
# and not accurate for this function!)
def expect_true(a, skip: int = 0):
    if isinstance(a, SymBool):
        # TODO: check perf implications of this
        frame = inspect.currentframe()
        for _ in range(skip + 1):  # always run this loop at least once
            frame = frame.f_back
        return a.node.expect_true(frame.f_code.co_filename, frame.f_lineno)
    assert type(a) is bool, a
    return a

def guard_bool(a):
    if isinstance(a, SymBool):
        return a.node.guard_bool("", 0)  # NB: uses Python backtrace
    assert type(a) is bool, a
    return a

def guard_int(a):
    if isinstance(a, SymInt):
        return a.node.guard_int("", 0)  # NB: uses Python backtrace
    assert type(a) is int, a
    return a

def guard_float(a):
    if isinstance(a, SymFloat):
        return a.node.guard_float("", 0)  # NB: uses Python backtrace
    assert isinstance(a, float), a
    return a

# Given a GraphModule, return all the FakeTensors for all the placeholders
def fx_placeholder_vals(gm):
    return [n.meta['val'] for n in gm.graph.nodes if n.op == "placeholder"]

def fx_placeholder_targets(gm):
    return [n.target for n in gm.graph.nodes if n.op == "placeholder"]

# Given a GraphModule and arguments to run it with, evaluate that the guards
# for its associated ShapeEnv are satisfied by the passed arguments.  This
# WILL check for duck sizing.
def eval_guards(gm, *args, ignore_static=True):
    return gm.shape_env.evaluate_guards_for_args(fx_placeholder_vals(gm), args, ignore_static=ignore_static)

def bind_symbols(gm, *args):
    return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args)

class DimDynamic(Enum):
    """
    Controls how to perform symbol allocation for a dimension.  It is always
    sound to default this to DYNAMIC, but the policies DUCK and STATIC can
    result in better trace-time and compile-time performance, as they reduce
    the number of allocated symbols and generally make your graph more static.

    NB: If we notice you've applied a constraint to the dimension, we will
    force it to DYNAMIC for simplicity.

    DimDynamic is controlled by a variety of higher level UX features.
    Currently:

    - In eager mode, the default policy is DUCK.
        - The default is changed to STATIC with assume_static_by_default.
        - An individual dim is marked DYNAMIC if you mark_dynamic_dim.
    - In export mode, the default policy is STATIC.
        - An individual dim is marked DYNAMIC if you specify it in
          dynamic_shapes passed to export.
    """
    # Treat the dimension symbolically
    DYNAMIC = 0
    # Treat the dimension symbolically, but if its hint matches another
    # dynamic dimension, unify the two symbols ("duck sizing")
    DUCK = 1
    # Treat the dimension statically based on its hint
    STATIC = 2
    # Treat the dimension as a size-like unbacked
    SIZE_LIKE_UNBACKED = 3
    # Infer the stride from the corresponding size: if the size is static,
    # the stride will be static as well.
    INFER_STRIDE = 4


# NB: These constraints affect both clients and backends: given some
# constraint C, the client must pass inputs that satisfy the constraint,
# while a backend must not introduce guards BEYOND this constraint.
# For clarity, we document the implications on both sides for both the client
# and the backend.
#
# NB: These constraints are on a *single* dimension.  In principle, we could
# also have multi-dimension constraints, but our guess is that this is not
# actually useful and so we are not supporting it right now.
#
# NB: Strict constraints are typically only suitable for export, as in eager
# a backend like inductor may validly introduce extra, discretionary guards
# to improve performance of code.  A StrictMinMaxConstraint would be brittle
# under future optimizations performed by inductor; we don't guarantee
# eager code with StrictMinMaxConstraint will keep working in the future!

@dataclass(frozen=True)
class Constraint:
    warn_only: bool

@dataclass(frozen=True)
class StrictMinMaxConstraint(Constraint):
    """
    For clients: the size at this dimension must be within 'vr' (which
    specifies a lower and upper bound, inclusive-inclusive) AND it
    must be non-negative and should not be 0 or 1 (but see NB below).

    For backends: there must not be any guards on this dimension which
    are not implied by the given lower and upper bound.  Regardless of
    the lower bound, the backend can assume the size is non-negative
    and that it is not 0 or 1.

    An unbounded StrictMinMaxConstraint can be thought of as a strict version
    of "RelaxedUnspecConstraint".

    NB: Export will often unsoundly assume that a graph works for 0/1, even
    though at trace time we assumed size is not 0 or 1.  The idea is that
    if we produce a graph that works for a range of values, it will be OK
    for N=0/1 too.
    """
    vr: ValueRanges

    def render(self, source: Source):
        """Format the constraint equation"""
        # TODO: better printing for -oo and oo
        return f"{self.vr.lower} <= {source.name()} <= {self.vr.upper}"

@dataclass(frozen=True)
class RelaxedUnspecConstraint(Constraint):
    """
    For clients: no explicit constraint; constraint is whatever is implicitly
    inferred by guards from tracing.

    For backends: there must exist at least TWO possible values for the
    size at this dimension which satisfy the guards for this dimension.

    In other words, this constraint helps us distinguish between "we don't
    care if this dimension specializes or not" versus "this dimension must be
    unspecialized."  However, this constraint doesn't say very much about what
    specialization is permitted; for example, if we guard on a size being
    even, this would still be acceptable under an unspec constraint.  This
    makes RelaxedUnspecConstraint useful for eager mode, where your backend compiler
    may add constraints to otherwise dynamic dimensions; we can't assert that
    there are NO guards, as that would be brittle: compilers should be able to
    add extra constraints.  If you want to assert that there are no guards,
    use StrictMinMaxConstraint with an unbounded ValueRanges.
    """
    def render(self, source: Source):
        return f"RelaxedUnspecConstraint({source.name()})"

# NB: None here indicates the client constraint is whatever is implicitly
# inferred by guards from tracing, and that a backend can add whatever guards
# it wants (including fully specializing the value).
DimConstraint = Union[StrictMinMaxConstraint, RelaxedUnspecConstraint, None]

@dataclass(frozen=True)
class EqualityConstraint(Constraint):
    """
    Represent and decide various kinds of equality constraints between input sources.

    A "source pair" is a pair of input sources for dynamic dimensions that
    are specified equal. We represent `source_pairs` in a union-find forest
    so that we can efficiently check whether two such sources are transitively equal.

    A "derived equality" relates an input source to an expression over a root.
    The root can be another input source, corresponding to some dynamic dimension,
    or a phantom symbol that does not directly represent any dynamic dimension. We
    represent `derived_equalities` involving input sources in a transitively-closed map
    so that we can efficiently check whether an input source is transitively equal to
    a given expression over another input source.
    (NOTE: In contrast, it is easy to decide whether an input source is transitively equal
    to a given expression over a phantom symbol; such expressions are already in canonical
    form and so the problem reduces to symbolic expression equality.)
    """
    source_pairs: List[Tuple[Source, Source]]
    derived_equalities: List[Tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]]]
    phantom_symbols: List[sympy.Symbol]

    def __post_init__(self):
        """Pre-processing to answer queries `is_equal` and `is_derived` below.

        Example: Suppose we are given:
          source_pairs [a = b, b = c]
          derived_equalities [d = c + 1, e = d - 1]
        We first construct a union find with source_pairs:
          _parents = {a: a, b: a, c: a}
        Then we compute canonical symbolic expressions, recursively applying derived_equalities
        until we bottom out:
          _defs = {d: c + 1, e: (c + 1) - 1 aka c}
        """

        # self._parents is a map from input sources to input sources where, conceptually,
        # these are directed edges in a union-find forest
        _parents: Dict[Source, Source] = {}
        object.__setattr__(self, "_parents", _parents)
        # self._defs is a map from input sources to "canonical" symbolic expressions,
        # i.e., unary expressions with symbols that correspond to regular Dims (i.e.,
        # not derived Dims)
        _defs: Dict[Source, sympy.Expr] = {}
        object.__setattr__(self, "_defs", _defs)

        for source1, source2 in self.source_pairs:
            # preprocess into a union-find forest
            self._union(self._find(source1), self._find(source2))
        for source, root, fn in self.derived_equalities:
            # preprocess into a transitively-closed map
            # NOTE(avik): we reuse the union-find forest for canonicalizing input sources
            if isinstance(root, sympy.Symbol):
                self._defs[self._find(source)] = fn(root)
            else:
                self._defs[self._find(source)] = fn(self._rewrite(root))

    def _find(self, source):
        # chase edges to find the root of this equivalence class
        if source in self._parents:
            return self._find(self._parents[source])
        else:
            return source

    def _union(self, root1, root2):
        # merge two equivalence classes by adding an edge from one root to the other
        if root1 != root2:
            self._parents[root1] = root2

    def _rewrite(self, src):
        # always represent the given source by the root of its equivalence class
        src = self._find(src)
        if src in self._defs:
            # simply look up the definition if it exists
            # NOTE(avik): This works because definitions are always transitively-closed;
            # otherwise we would have to do recursive rewriting.
            return self._defs[src]
        else:
            # otherwise, create a symbol representing the source
            return sympy.Symbol(src.name())

    def is_equal(self, source1, source2):
        return (
            # check whether source1 and source2 have the same root
            self._find(source1) == self._find(source2) or
            # check whether source1 is derived equal to source2
            self.is_derived(source1, source2, lambda x: x)
        )

    def is_derived(self, src, symbol_src, fn):
        # check whether both src and symbol_src have the same definition
        return self._rewrite(src) == fn(self._rewrite(symbol_src))


def _assert_symbol_context(symbolic_context):
    assert isinstance(symbolic_context, SymbolicContext), "Invalid symbolic_context object"
    assert type(symbolic_context) is not SymbolicContext, "Illegal usage of symbolic_context ABC"

def _is_supported_equivalence(expr):
    # Currently supported Dim ops are linear expressions with integer coefficients.
    # So check that expr only contains +, *, ints, and a single occurrence of a symbol.
    # (See also documentation of dynamic_shapes._DerivedDim.)
    if isinstance(expr, (sympy.Add, sympy.Mul)):
        if len(expr.args) > 2:
            return False
        lhs, rhs = expr.args
        return (
            (_is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or
            (isinstance(lhs, sympy.Integer) and _is_supported_equivalence(rhs))
        )
    return isinstance(expr, sympy.Symbol)
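
# Example (illustrative): with s, s2 = sympy.symbols("s s2"),
# _is_supported_equivalence(2*s + 1) is True (linear with integer
# coefficients and a single symbol), while _is_supported_equivalence(s + s2)
# and _is_supported_equivalence(s**2) are both False.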
1267
1268def _has_uninterpretable_sympy_function(expr) -> bool:
1269    """
1270    Add functions that our sympy interpreter can't reify into FX nodes
1271    """
1272    return expr.has(
1273        torch.utils._sympy.functions.ToFloat,
1274        torch.utils._sympy.functions.TruncToInt,
1275        torch.utils._sympy.functions.CeilToInt,
1276    )
1277
1278@dataclass(frozen=True)
1279class SymbolicContext:
1280    """
1281    Data structure specifying how we should create symbols in
1282    ``create_symbolic_sizes_strides_storage_offset``; e.g., should
1283    they be static or dynamic.
1284
1285    This is an abstract base class because we are probably going to add
1286    another version of this that says "use exactly these SymInts, don't
1287    allocate fresh symbols."
1288    """
1289
1290
1291@dataclass(frozen=True)
1292class StatelessSymbolicContext(SymbolicContext):
1293    """
1294    Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via
1295    a symbolic_context determination as given by ``DimDynamic`` and ``DimConstraint``.
1296    This will cause fresh symbols to be allocated
1297    """
1298    dynamic_sizes: DimList[DimDynamic]
1299    dynamic_strides: DimList[DimDynamic] = None
1300    constraint_sizes: DimList[DimConstraint] = None
1301    constraint_strides: DimList[DimConstraint] = None
1302    # If the tensor is a view, this should be populated for the base. It contains
1303    # information on how to allocate symbols when recursively fakeifying the base
1304    # during view fake-ification.
1305    view_base_context: Optional[SymbolicContext] = None
1306    # TODO: add storage offset and stride symbolic_context
1307
1308    def __post_init__(self):
1309        if self.dynamic_strides is None:
1310            object.__setattr__(self, 'dynamic_strides', [DimDynamic.INFER_STRIDE] * len(self.dynamic_sizes))
1311        if self.constraint_sizes is None:
1312            object.__setattr__(self, 'constraint_sizes', [None] * len(self.dynamic_sizes))
1313        if self.constraint_strides is None:
1314            object.__setattr__(self, 'constraint_strides', [None] * len(self.dynamic_sizes))
1315        assert all(stride in (DimDynamic.INFER_STRIDE, DimDynamic.DYNAMIC, DimDynamic.DUCK) for stride in self.dynamic_strides)
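
# A minimal construction sketch (not executed; DimDynamic members as defined
# earlier in this file):
#
#   ctx = StatelessSymbolicContext(
#       dynamic_sizes=[DimDynamic.DYNAMIC, DimDynamic.STATIC],
#   )
#   # __post_init__ then fills dynamic_strides with DimDynamic.INFER_STRIDE and
#   # the two constraint lists with None entries.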
1316
1317
1318# note [Tensor Fakification and Symbol Caching]
1319#
1320# As of the time of this note, dynamo creates a fresh fake tensor mode for backends.
# The reason we do this is that there are certain classes of operations, namely,
1322# metadata mutations, that change tensor size, stride, etc. This means that the fake tensor
1323# state at the end of a dynamo trace is different than the fake tensor state at the beginning
1324# of a trace. Backends like aot_autograd need a fresh fake tensor to correctly track metadata mutation,
1325# view relationships, etc.
1326#
1327# As we create a new fake mode, we also lose the memoization that comes with it. Rather than
# transfer the memoization cache, we instead transfer the shape env. However, this comes
# with nuance, as dynamo is selective in how it makes symbolic shapes. Due to strategies in
# automatic dynamic and constraints, the policy for which dims are dynamic is nuanced and varies across
# recompilations.
1332#
1333# In order to preserve the symbolic decisions made during dynamo tensor fakification, we pass
1334# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on the TracingContext.
1335# The lifecycle of this object should match the lifecycle of the original dynamo tracked tensor, and it is
# safe to reuse this object as many times as necessary to create a fake tensor. Fake tensors
# created with new fake modes should produce the exact same symbols as the original, provided the same shape_env
# is used.
1339# TODO(voz): Shape env validation
1340@dataclass(frozen=True)
1341class StatefulSymbolicContext(StatelessSymbolicContext):
1342    """
1343    Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via
1344    a symbolic_context determination as given by a cache of Source:Symbol. A cache hit
1345    will reuse a stored symbol, and a cache miss will write to this cache.
1346
    This behaves like StatelessSymbolicContext, except that the cache supersedes
    the other values: dynamic_sizes and constraint_sizes will not be read on a
    cache hit.

    It is the cache owner's responsibility to maintain the lifecycle of the cache
    w/r/t different shape_envs, clearing, etc.
1353    """
1354    tensor_source: Source = None
    # Why is this keyed on int first?
    # That integer is actually the id of the shape_env. This cache short-circuits symbol
    # creation, and we must store it per shape env. Now, the tracing invariant is a single
    # shape env per tracing context, and every new frame gets a new shape_env. So where would we have
    # multiple shape envs? The answer lies in recording. When we are replaying, replay_shape_env_events
    # is invoked and creates a new shape_env. Replaying events against this new shape_env will
    # cause it to fail with unknown symbols, as the symbols cached here will skip creation and never
    # get recorded in var_to_val, etc.
    # TODO(voz): consider a weakref to the shape_env here
    shape_env_to_source_to_symbol_cache: Dict[int, Dict["TensorPropertySource", "sympy.Expr"]] = None
1365
1366    def __post_init__(self):
1367        super().__post_init__()
1368        # The None default is annoying, but required because of dataclass limitations
1369        assert self.tensor_source is not None
1370        if not self.shape_env_to_source_to_symbol_cache:
1371            object.__setattr__(self, 'shape_env_to_source_to_symbol_cache', {})
1372
1373
1374@dataclass(frozen=True)
1375class SubclassSymbolicContext(StatefulSymbolicContext):
1376    """
1377    The correct symbolic context for a given inner tensor of a traceable tensor subclass
1378    may differ from that of the outer symbolic context. This structure allows for this
1379    flexibility, with inner symbolic contexts mapped via attr -> symbolic context.
1380    """
1381    inner_contexts: Dict[str, SymbolicContext] = None
1382
1383    def __post_init__(self):
1384        super().__post_init__()
        if self.inner_contexts is None:
            # NB: must use object.__setattr__ because this dataclass is frozen;
            # plain attribute assignment would raise FrozenInstanceError.
            object.__setattr__(self, 'inner_contexts', {})
1387
1388
1389def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool:
1390    if isinstance(val, (int, float, bool)):
1391        return False
1392    return val.node.is_symbolic()
1393
1394IndicatorTypes = (IsNonOverlappingAndDenseIndicator,)
1395
1396@lru_cache(256)
1397def safe_expand(r):
1398    if hasattr(r, 'expand'):
1399        try:
1400            return sympy.expand(r)
1401        except RecursionError:
1402            log.warning("RecursionError in sympy.expand(%s)", r)
1403            return r
1404    else:
1405        return r
1406
1407def error():
1408    raise AssertionError("shouldn't be hit")
1409
1410
1411# TODO: Deduplicate this with torch/_prims_common/__init__.py
1412def eval_is_non_overlapping_and_dense(sizes, strides):
1413    return int(guard_bool(_eval_is_non_overlapping_and_dense(sizes, strides)))
1414
1415def _eval_is_non_overlapping_and_dense(sizes, strides):
1416    dim = len(sizes)
1417
    # Short-circuits for tensors of rank one, which are
    # non-overlapping and "dense" if their stride is one
    # or if they have 0 or 1 elements
1421    if dim == 1:
1422        return strides[0] == 1 or sizes[0] < 2
1423
1424    # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
1425    # Sorts (length, stride) pairs by stride
1426    lengths_and_strides = sorted(
1427        zip(sizes, strides), key=operator.itemgetter(1)
1428    )
1429
    # Unlike the C++ code, we don't move the 0/1 size dimensions to the
    # end.  So we have to skip over them as we go.
1432    expected_stride = 1
1433    for length, stride in lengths_and_strides:
1434
1435        if length == 1:
1436            continue
1437
1438        if stride != expected_stride:
1439            return False
1440
1441        expected_stride *= length
1442
1443    return True
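

# A worked example (not executed) of the permutation check above:
#
#   _eval_is_non_overlapping_and_dense([2, 3], [3, 1])
#   # True: sorted (length, stride) pairs [(3, 1), (2, 3)] match the expected
#   # strides 1 and 1 * 3 == 3
#   _eval_is_non_overlapping_and_dense([2, 3], [1, 1])
#   # False: the second pair expects stride 2 but has stride 1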
1444
1445
1446def _sympy_cast_symbool_to_symint_guardless(x: sympy.Expr) -> sympy.Expr:
1447    return sympy.Piecewise((1, x), (0, True))
1448
1449
1450def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt:
1451    if isinstance(symbool, bool):
1452        return 1 if symbool else 0
1453    int_sym = _sympy_cast_symbool_to_symint_guardless(symbool.node.expr)
1454    return symbool.node.shape_env.create_symintnode(int_sym, hint=int(symbool.node.require_hint()) if has_hint(symbool) else None)
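
# A minimal sketch (not executed) of the Piecewise encoding used above:
#
#   b = sympy.Symbol("b")
#   _sympy_cast_symbool_to_symint_guardless(sympy.Eq(b, 1))
#   # -> Piecewise((1, Eq(b, 1)), (0, True)): 1 when the condition holds,
#   # 0 otherwise, without installing a guard on the condition itself.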
1455
1456SYMPY_INTERP = {
1457    'Abs': operator.abs,
1458    'Eq': operator.eq,
1459    'Ne': operator.ne,
1460    'Gt': operator.gt,
1461    'Lt': operator.lt,
1462    'Le': operator.le,
1463    'Ge': operator.ge,
1464    'Min': min,
1465    'Max': max,
1466    'Mod': operator.mod,
1467    'PythonMod': operator.mod,
1468    'FloorDiv': operator.floordiv,
1469    'TrueDiv': operator.truediv,
1470    'PowByNatural': operator.pow,
1471    'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
1472    'floor': math.floor,
1473    'ceiling': math.ceil,
1474    'FloorToInt': math.floor,
1475    'FloatPow': math.pow,
1476    'CeilToInt': math.ceil,
1477    'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless,
1478    'RoundToInt': builtins.round,
1479    'RoundDecimal': builtins.round,
1480    'TruncToInt': math.trunc,
1481    'IntTrueDiv': operator.truediv,
1482    'FloatTrueDiv': operator.truediv,
1483    'ToFloat': builtins.float,
1484}
1485
1486
1487def _lru_cache(fn, maxsize=None):
1488    """
1489    Wrapper around lru_cache that clears when new info about shapes has been
1490    updated.
1491
    Use lru_cache if the output is always the same, regardless of the
    constraints we know now (i.e. evaluate_expr).

    Use _lru_cache otherwise.
1496
1497    Also note that this depends on _update_version_counter being called on the
1498    shape environment whenever the constraints are updated, otherwise the cache
1499    will not be cleared.
1500    """
1501    fn_cache = lru_cache(maxsize)(fn)
1502    prior_version = 0
1503
1504    if config.validate_shape_env_version_key:
1505        prior_key = None
1506
1507        @functools.wraps(fn)
1508        def wrapper(self, *args, **kwargs):
1509            nonlocal prior_version, prior_key
1510            if prior_key is None:
1511                prior_key = self._get_key()
1512
1513            if prior_version != self._version_counter:
1514                fn_cache.cache_clear()
1515                prior_version = self._version_counter
1516                prior_key = self._get_key()
1517            else:
1518                assert prior_key == self._get_key(), \
1519                    "ShapeEnv cache key changed without version being updated!"
1520
1521            return fn_cache(self, *args, **kwargs)
1522
1523    else:
1524
1525        @functools.wraps(fn)
1526        def wrapper(self, *args, **kwargs):
1527            nonlocal prior_version
1528            if prior_version != self._version_counter:
1529                fn_cache.cache_clear()
1530                prior_version = self._version_counter
1531
1532            return fn_cache(self, *args, **kwargs)
1533
1534    wrapper.cache_clear = fn_cache.cache_clear
1535    wrapper.cache_info = fn_cache.cache_info  # type: ignore[attr-defined]
1536    return wrapper
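

# A minimal sketch (not executed) of the contract _lru_cache relies on: the
# owner of the decorated method must expose _version_counter (and _get_key,
# when config.validate_shape_env_version_key is set) and bump the counter
# whenever previously cached answers may have changed:
#
#   class _ToyEnv:
#       _version_counter = 0
#
#       @_lru_cache
#       def query(self, expr):
#           ...  # recomputed after every _version_counter bump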
1537
1538
1539# This is pretty similar to ShapeGuard but it also comes with a message,
1540# and is exclusively used for things that MUST be true (unlike guards,
1541# which can evaluate False, in which case you just choose not to use
1542# a particular specialization)
1543@dataclass(frozen=True)
1544class RuntimeAssert:
1545    expr: sympy.Expr
1546    msg: str = field(repr=False)
1547    stack: str = field(repr=False)
1548
1549
1550# Used for printing SymExprs in compile_fx
1551class SymExprPrinter(StrPrinter):
1552    def _print_Float(self, expr):
1553        return str(float(expr))
1554
1555
1556class ShapeGuardPrinter(SymExprPrinter):
1557    def __init__(
1558        self,
1559        symbol_to_source,
1560        source_ref,
1561        var_to_sources,
1562    ):
1563        super().__init__()
1564        self.symbol_to_source = symbol_to_source
1565        self.source_ref = source_ref
1566        self.var_to_sources = var_to_sources
1567
1568    def _print_Not(self, expr):
1569        return 'not {}'.format(self.parenthesize(expr.args[0], PRECEDENCE["Not"]))
1570
1571    def _print_And(self, expr):
1572        return self.stringify(expr.args, " and ", PRECEDENCE["And"])
1573
1574    def _print_Or(self, expr):
1575        return self.stringify(expr.args, " or ", PRECEDENCE["Or"])
1576
1577    def _print_Symbol(self, expr) -> str:
1578        assert isinstance(expr, sympy.Symbol), str(type(expr))
1579
1580        def repr_symbol_to_source():
1581            return repr({
1582                symbol: [s.name() for s in sources]
1583                for symbol, sources in self.symbol_to_source.items()
1584            })
1585
1586        assert self.symbol_to_source.get(expr), (
1587            f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) "
1588            f"not in {repr_symbol_to_source()}.  If this assert is failing, it could be "
1589            "due to the issue described in https://github.com/pytorch/pytorch/pull/90665"
1590        )
1591        return self.source_ref(self.symbol_to_source[expr][0])
1592
1593
1594class LoggingShapeGuardPrinter(ShapeGuardPrinter):
1595    def __init__(self, var_to_sources):
1596        super().__init__(var_to_sources, lambda n: n.name(), var_to_sources)
1597
1598
1599class DynamicDimConstraintPrinter(StrPrinter):
1600    """
1601    Printer for dynamic dim constraints.
1602    - Instead of symbol s_k it prints its source t.size()[i]
1603    - Instead of Eq(_, _), Mod(_, _), etc. it prints _ == _, _ % _, etc.
1604
1605    We use this to suggest code for specifying dynamic dim constraints.
1606    """
1607    def __init__(self, symbol_to_source, source_name_to_debug_name):
1608        super().__init__()
1609        self.symbol_to_source = symbol_to_source
1610        self.source_name_to_debug_name = source_name_to_debug_name
1611
1612    def _print_Symbol(self, expr) -> str:
1613        assert isinstance(expr, sympy.Symbol), str(type(expr))
1614        assert self.symbol_to_source.get(expr), (
1615            f"Unknown symbol {expr} created by constraints solver"
1616        )
1617        return self.symbol_to_source[expr][0].name()
1618
1619    def _print_Relational(self, expr):
1620        return f'{self.parenthesize(expr.lhs, precedence(expr))} {expr.rel_op} {self.parenthesize(expr.rhs, precedence(expr))}'
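
    # A hedged illustration (not executed; source names schematic): with
    # symbol_to_source mapping s0 to a source named "x.size()[0]", doprint on
    # Eq(s0, 4) yields "x.size()[0] == 4" instead of "Eq(s0, 4)".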
1621
1622
1623class DimConstraints:
1624    """
1625    Custom solver for a system of constraints on symbolic dimensions.
1626    Solutions are "static" values or simplified "dynamic" constraints.
1627    """
1628
1629    def __init__(
1630        self,
1631        symbol_to_source,
1632        var_to_val,
1633        marked_dynamic,
1634        source_name_to_debug_name,
1635    ):
1636        # We try to solve systems of inequalities with 1 free variable.
1637        self._univariate_inequalities: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set)
1638        # Among them, we prioritize solving for a free variable that has equalities.
1639        # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys()
1640        # and removing a symbol from the former => removing it from the latter.
1641        self._symbols_with_equalities: Set[sympy.Symbol] = set()
1642        # A solution of a free variable with equalities becomes a substitution.
1643        # We use these substitutions to simplify other constraints.
1644        # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions.
1645        self._substitutions: Dict[sympy.Symbol, sympy.Integer] = {}
1646
1647        # In general, constraints may have // and % operations.
1648        # Of course, // can be expressed in terms of / and %.
1649        # Our inequality solver can handle / but not %. So we need to transform them away.
1650        # We do so by using the values of variables as hints to evaluate %.
1651        # For soundness we record additional congruence guards and solve them separately.
1652        self._var_to_val: Dict[sympy.Symbol, sympy.Integer] = var_to_val
        self._congruences: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set)
1654
1655        # We do not try to (directly) solve inequalities with > 1 free variables.
1656        # NOTE: free variables in these inequalities cannot also be in _substitutions.
1657        self._multivariate_inequalities: Set[sympy.Expr] = set()
1658
1659        # We park external equalities between free variables here.
1660        self._symbolic_equivalences: List[Tuple[Source, sympy.Expr]] = []
1661
1662        # Solutions come in two forms:
1663        # - (static) specializations
1664        # - (dynamic) inequalities / congruences
1665        self._static_results: Set[str] = set()
1666        self._dynamic_results: Set[str] = set()
1667
1668        # printer for solutions
1669        self._dcp = DynamicDimConstraintPrinter(symbol_to_source, source_name_to_debug_name)
1670
1671        # inconsistencies found on substituting with concrete values / static solutions
1672        self._inconsistencies: List[str] = []
1673
1674        # symbols that are marked dynamic
1675        self._marked_dynamic = marked_dynamic
1676
1677        # track supported sympy functions and subtract from list of all sympy functions
1678        self._supported_sympy_functions: Set[sympy.Function] = {
1679            Application,
1680            Mod,
1681            PythonMod,
1682            FloorDiv,
1683        }
1684        self._enumerate_sympy_functions()
1685
1686    def rewrite_with_congruences(self, s, expr):
1687        """
1688        Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k.
1689        This leaves rational operators (in particular of the form b / d) that our inequality solver can handle.
1690        We solve the added congruences separately (using our congruence solver, see below).
1691        """
1692        def mod_handler(*args):
1693            # Suppose that we have an expression of the form b % d with free variable s.
1694            # Using the value of s as a "hint," we can evaluate b % d to a value k.
1695            # Then we can rewrite b % d to k while adding the guard b % d == k.
1696
1697            # NOTE(avik): This abstraction is provably sound but, in general, incomplete. It is complete IFF
1698            # the original expression always evaluates to a constant value (i.e., it does not vary with s).
1699            # In other words,
1700            # - solutions of s with the rewritten expression are guaranteed to also be solutions of s with
1701            #   the original expression;
1702            # - while it may be possible to find solutions of s with the original expression that are not
1703            #   solutions with the rewritten expression, in that case the original expression cannot evaluate
1704            #   to the same value for all solutions of s.
1705            #
1706            # Should we be worried about this incompleteness? No, because of the following reasons:
1707            # 1. It unblocks dramatic simplification that would not be otherwise possible with current tech
1708            #    (i.e., "don't let perfect be the enemy of the good").
1709            # 2. We already have a tradition of using hints to add guards in the compiler for making progress.
1710            # 3. We have not yet seen a counterexample arise in practice! In particular, any congruence guards
1711            #    we generate (or simplify to) seem to be of the form b % d == k where k is a constant.
1712            #
1713            # Here's a theoretical counterexample: 3*s % (s + 1) == s - 2, that is satisfied by all s >= 2.
1714            # With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we
1715            # would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution!
1716            base, divisor = args
1717            base, divisor = self.rewrite_with_congruences(s, base), self.rewrite_with_congruences(s, divisor)
1718            mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(self._var_to_val)
1719            congruence = (base - mod_reduced) % divisor
1720            if congruence != 0:
1721                self._congruences[s].add(congruence)
1722            return mod_reduced
1723
1724        def floor_div_handler(*args):
1725            # Suppose that we have an expression of the form b // d with free variable s.
1726            # Using the value of s, we can evaluate b % d to a value k.
1727            # Then we can rewrite b // d to (b - k) / d, while adding the guard b % d == k.
1728
1729            # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d
1730            # and eliminating b % d as above.
1731            base, divisor = args
1732            base, divisor = self.rewrite_with_congruences(s, base), self.rewrite_with_congruences(s, divisor)
1733            mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(self._var_to_val)
1734            congruence = (base - mod_reduced) % divisor
1735            if congruence != 0:
1736                self._congruences[s].add(congruence)
1737            # NB: Must not be CleanDiv, it needs to be regular sympy division
1738            # so inequality solver works.  This is sort of problematic for
1739            # is_integer tests though haha
1740            return (base - mod_reduced) / divisor
1741
1742        if expr.has(Mod):
1743            expr = expr.replace(Mod, mod_handler)
        # 7 // -3 is -3, 7 % -3 is -2, and (7 - (-2)) / -3 is -3.0, so negative
        # arguments should be OK.
1746        if expr.has(PythonMod):
1747            expr = expr.replace(PythonMod, mod_handler)
1748        if expr.has(FloorDiv):
1749            expr = expr.replace(FloorDiv, floor_div_handler)
1750        return expr
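
    # A worked example (not executed) of the rewriting above, assuming
    # self._var_to_val = {s: 6}:
    #
    #   rewrite_with_congruences(s, FloorDiv(s, 3))
    #   # returns (s - 0) / 3 and records the congruence Mod(s, 3) in
    #   # self._congruences[s], i.e. the separately solved guard s % 3 == 0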
1751
1752    def _enumerate_sympy_functions(self):
1753        module = torch.utils._sympy.functions
1754        all_functions = set()
1755        for attr in dir(module):
1756            if isinstance(func := getattr(module, attr), sympy.FunctionClass):
1757                all_functions.add(func)
1758        self._unsupported_sympy_functions = all_functions.difference(self._supported_sympy_functions)
1759
1760    def _has_unsupported_sympy_function(self, expr) -> bool:
1761        """
        Return whether the expression contains sympy.Functions the export solver doesn't know how to handle.
1763        """
1764        return expr.has(*self._unsupported_sympy_functions)
1765
1766    def add(self, expr) -> bool:
1767        """Add an expression to the set of constraints.
1768
1769        Return whether the expression is a trivial constraint (i.e., an obvious tautology).
1770        """
1771        if expr == sympy.true:
1772            return True
1773        orig_expr = expr
1774        orig_reduced = orig_expr.xreplace(self._var_to_val)
1775        # TODO(avik): https://github.com/pytorch/pytorch/issues/101093
1776        # It is possible that `expr` will fail the consistency check because of
1777        # precision errors. Specifically, on substituting its free symbols with
1778        # their concrete values, we might end up comparing floats. Until we have
1779        # a fix for this issue, we delay raising such failures. See solve().
1780        if orig_reduced == sympy.false:
1781            self._inconsistencies.append(f"{orig_expr} is inconsistent!")
1782        if isinstance(expr, sympy.Ne) or self._has_unsupported_sympy_function(expr):
1783            # we're not going to do anything useful with these, so drop them
1784            return False
1785        free_symbols = expr.free_symbols
1786        assert free_symbols, f"Did not expect constraint with no free variables: {expr}"
1787        if len(free_symbols) > 1:
1788            # multivariate: record and move on
1789            self._multivariate_inequalities.add(expr)
1790        else:
1791            # univariate: can solve these immediately
1792            s = next(iter(free_symbols))
1793            # eliminate // and % (see documentation of `rewrite_with_congruences` above)
1794            old_n_congruences = len(self._congruences[s])
1795            expr = self.rewrite_with_congruences(s, expr)
1796            new_n_congruences = len(self._congruences[s])
1797            if expr == sympy.true:
1798                return old_n_congruences == new_n_congruences
1799            reduced = expr.xreplace(self._var_to_val)
1800            if reduced == sympy.false:
1801                self._inconsistencies.append(
1802                    f"{expr}, obtained by rewriting {orig_expr} with congruences, "
1803                    "is inconsistent!"
1804                )
1805            if isinstance(expr, sympy.Eq):
1806                # special status for symbols that have equalities (see `solve` below)
1807                self._symbols_with_equalities.add(s)
1808            self._univariate_inequalities[s].add(expr)
1809        return False
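
    # A hedged example (not executed) of how add() classifies constraints,
    # assuming self._var_to_val = {s0: 8, s1: 2}:
    #
    #   add(sympy.Eq(Mod(s0, 2), 0))  # univariate: records congruence Mod(s0, 2)
    #   add(s0 + s1 <= 16)            # multivariate: parked for later substitution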
1810
1811    def add_equality(self, source, expr):
1812        """Add an equality constraint"""
1813        if expr.is_number:
1814            # specialization, right here
1815            self._static_results.add(f"{source.name()} == {expr}")
1816        else:
1817            # these will resolve to either specializations or dynamic equality constraints
1818            self._symbolic_equivalences.append((source, expr))
1819
1820    def _reduce_congruences(self):
1821        reduced_congruences = {}
1822        for s, congruences in self._congruences.items():
1823            remainder_modulus_pairs = []
1824            congruences_to_check = set()
1825            for congruence in congruences:
1826                base, divisor = congruence.args
1827                # We are given a congruence of the form base % divisor == 0 with a free variable s. So:
1828                # - we transform this into an equation of the form base = divisor * tmp;
1829                # - we solve this equation for s to get a linear solution with free variable tmp.
1830                tmp = sympy.Symbol("reduce_congruences_tmp", integer=True)
1831                symbol, solution = sympy.solve_linear(base - divisor * tmp, symbols=[s])
1832                # See https://docs.sympy.org/latest/modules/solvers/solvers.html#sympy.solvers.solvers.solve_linear
1833                # for how to interpret the results.
1834                if s == symbol:
1835                    # This means the solution is of the form s = modulus*tmp + remainder.
1836                    modulus, remainder = sympy.polys.polytools.div(solution, tmp)
1837                    if isinstance(modulus, sympy.Integer) and isinstance(remainder, sympy.Integer):
1838                        # Make sure 0 <= remainder <= modulus.
1839                        remainder = remainder % modulus
1840                        remainder_modulus_pairs.append((remainder, modulus))
1841                        continue
1842                # This means that we did not get a unique solution to the equation.
1843                # No problem, we will check it.
1844                congruences_to_check.add(congruence)
1845            # Finally we solve for a congruence s such that s = r_i mod m_i for each (r_i, m_i).
1846            # The solution will be a congruence of the form s = r mod m.
1847            # NOTE(avik): Since the given m_i may not be pairwise coprime, we can't just use CRT.
1848            if remainder_modulus_pairs:
1849                remainder, modulus = sympy.ntheory.modular.solve_congruence(*remainder_modulus_pairs)
1850                reduced_congruences[s] = {(s - remainder) % modulus}
1851                substitution = {s: modulus * sympy.Symbol("tmp", integer=True) + remainder}
1852                reduced_congruences[s].update(
1853                    congruence for congruence in congruences_to_check
1854                    if not sympy.checksol(congruence, substitution)
1855                )
1856            else:
1857                reduced_congruences[s] = congruences_to_check
1858
1859        return reduced_congruences
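
    # A worked example (not executed) of the reduction above: congruences
    # Mod(s, 2) and Mod(s, 3) yield remainder/modulus pairs (0, 2) and (0, 3);
    # solve_congruence combines them into (0, 6), so the reduced congruence
    # for s is (s - 0) % 6, i.e. s % 6 == 0.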
1860
1861    def _raise_inconsistencies(self):
1862        if self._inconsistencies:
1863            msg = "\n".join(self._inconsistencies)
1864            self._inconsistencies.clear()
1865            raise ValueError(f"The following inconsistencies were found:\n{msg}")
1866
1867    def solve(self):
1868        """Solve the system of constraint equations to find simplified constraints
1869        """
1870        self._raise_inconsistencies()
1871        # as long as there are symbols with equalities, solve for them
1872        # NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
1873        while self._symbols_with_equalities:
1874            s = self._symbols_with_equalities.pop()
1875            exprs = self._univariate_inequalities.pop(s)
1876            solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
1877            if isinstance(solution, sympy.And):
1878                solution = next((arg for arg in solution.args if isinstance(arg, sympy.Eq)), solution)
1879            assert isinstance(solution, sympy.Eq), f"Expected an equality constraint for {s}, got {solution}"
1880            symbol, val = solution.args
1881            assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
1882            # because this is univariate, the solution is a specialization
1883            self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
1884            # add this as a substitution to simplify other constraints
1885            self._substitutions[s] = val
1886
1887            # simplify multivariate inequalities: some of them will now become univariate!
1888            multivariate_inequalities = self._multivariate_inequalities
1889            self._multivariate_inequalities = set()
1890            for expr in multivariate_inequalities:
1891                self.add(expr.xreplace({s: self._substitutions[s]}))
1892            self._raise_inconsistencies()
1893
1894        # solve linear congruences
1895        # NOTE(avik): We do not need to solve them for symbols that have already been specialized.
1896        reduced_congruences = self._reduce_congruences()
1897        for s, congruences in reduced_congruences.items():
1898            for congruence in congruences:
1899                # any congruence that cannot be checked becomes a dynamic constraint as well
1900                if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}):
1901                    if self._is_supported_congruence(congruence):
1902                        base, divisor = congruence.args
1903                        tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}"
1904                        tmp = sympy.Symbol(tmp_name, integer=True)
1905                        from torch._dynamo.source import ConstantSource
1906                        self._dcp.symbol_to_source[tmp] = [ConstantSource(tmp_name)]
1907                        r = try_solve(sympy.Eq(base, divisor * tmp), s)
1908                        self._dynamic_results.add(self._dcp.doprint(sympy.Eq(s, r[1])))
1909
1910        # remaining symbols have only pure inequalities (no equalities)
1911        for s, exprs in self._univariate_inequalities.items():
1912            try:
1913                solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
1914                # because this is univariate, the solution is a dynamic (range) constraint
1915                if isinstance(solution, sympy.Or):
1916                    solution = next(iter(arg for arg in solution.args if arg.xreplace(self._var_to_val)))
1917                if isinstance(solution, sympy.And):
1918                    for arg in solution.args:
1919                        self._dynamic_results.add(self._dcp.doprint(arg))
1920                else:
1921                    self._dynamic_results.add(self._dcp.doprint(solution))
1922            except (NotImplementedError, AssertionError) as e:
1923                log.warning("Failed to reduce inequalities: %s", e)
1924                for expr in exprs:
1925                    self._dynamic_results.add(self._dcp.doprint(expr))
1926
1927        # simplify symbolic equivalences: some of them will now become specializations!
1928        symbolic_equivalences = self._symbolic_equivalences
1929        self._symbolic_equivalences = []
1930        for source, expr in symbolic_equivalences:
1931            self.add_equality(source, expr.xreplace(self._substitutions))
1932
1933        # remaining symbolic equivalences become dynamic equality constraints
1934        for source, expr in self._symbolic_equivalences:
1935            self._dynamic_results.add(f"{source.name()} == {self._dcp.doprint(expr)}")
1936
1937    @classmethod
1938    def _is_supported_congruence(cls, congruence):
1939        base, divisor = congruence.args
1940        # Congruences that can be currently expressed with supported Dim ops are
1941        # of the form (x + a) % b == 0, where x is a Dim and a and b are constants.
1942        # This allows us to derive x as b*y - a for some Dim y.
1943        # (See also documentation of dynamic_shapes._DerivedDim.)
1944        if isinstance(base, sympy.Add):
1945            lhs, rhs = base.args
1946            cond = (
1947                (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Integer)) or
1948                (isinstance(lhs, sympy.Integer) and isinstance(rhs, sympy.Symbol))
1949            )
1950        else:
1951            cond = isinstance(base, sympy.Symbol)
1952        cond = cond and isinstance(divisor, sympy.Integer)
1953        return cond
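
    # Illustrative cases (not executed) for the check above:
    #
    #   Mod(s0 + 1, 3)   # supported: base is symbol + integer, divisor an integer
    #   Mod(s0 * s1, 3)  # unsupported: base is neither a symbol nor symbol + integer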
1954
1955    def forced_specializations(self):
1956        """Returns a dictionary of the names of symbols to their specialized value
1957        """
1958        def debug_name(src):
1959            name = src.name()
1960            if self._dcp.source_name_to_debug_name:
1961                return f"{self._dcp.source_name_to_debug_name[name]} = {name}"
1962            else:
1963                return name
1964
1965        return {
1966            debug_name(self._dcp.symbol_to_source[s][0]): val
1967            for s, val in self._substitutions.items()
1968            if s in self._marked_dynamic
1969        }
1970
1971    def _is_derived_dim(self, dim):
1972        return isinstance(dim, torch.export.dynamic_shapes._DerivedDim)
1973
1974    def _is_dim(self, dim):
1975        return (
1976            isinstance(dim, torch.export.dynamic_shapes._Dim)
1977            and not isinstance(dim, torch.export.dynamic_shapes._DerivedDim)
1978        )
1979
1980    def _process_derived_dim_roots(
1981        self,
1982        results: Dict[str, Dict[str, Any]],
1983        name_to_dim: Dict[str, Any],
1984    ) -> None:
1985        '''
1986        Here we resolve 2 concerns with derived dims suggested fixes: 1) newly introduced roots,
1987        and 2) root swapping.
1988
1989        1) Newly introduced roots appear with modulo guards, e.g. Mod(dx, 2) = 0 suggests
1990        dx is a derived dim equal to 2 * _dx, introducing a new root _dx. Currently the final
1991        suggested fixes handle this correctly, but we can get intermediate results that look like
1992        {"dy": {"eq": "dx + 1"}, "dx": {"eq": "2 * _dx + 1, "min": 3, "max": 15}}
1993        and this routine prettifies this by unifying to a single root, and making each suggestion
1994        either a derived dim or min/max range, not both.
1995
1996        2) With suggested fixes for derived dims, roots can be swapped,
1997        e.g. dx, dx - 1 -> dy + 1, dy. Here we don't want to print out the attached name,
1998        since this leads to messages like "dx - 1 = Dim("dx - 1", ...)".
1999        Instead we evaluate the new root value, and remove results for its derivations.
2000
2001        First we find all the original roots (specified in dynamic_shapes), that are found in the
2002        values of results (i.e. used for computing suggesting fix values). These original roots
2003        (suppose `dx`) are either specialized, unchanged, refined, or swapped
2004        (expressed as a derived dim). If any of the first 3 cases happen, we suggest `dx`'s value
2005        in results, and remove suggestions for derivations of `dx`, assuming the derived relation
2006        is valid. If swapped, we find the new root, and use the fix to evaluate `dx`'s new value,
2007        and then do the same with `dx`'s derivations.
2008
        It is valid to assume the originally specified derived relations are correct, because:
            1) if the relations are plain wrong (e.g. input shape = (6, 4) with spec (dx, dx - 1))
               produce_guards() will catch this and crash beforehand.
2012            2) if the relations are numerically correct but do not match the emitted guard,
2013               for example:
2014
2015                    def forward(self, x, y):
2016                        return x.reshape([-1]) + y  # guard: s0 * 2 = s1
2017                    inputs = (torch.randn(6, 2), torch.randn(12))
2018                    dx = Dim("dx", min=2, max=32)
2019                    dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )}  # this matches values but not op
2020
2021               then this leads to 2 linear equations, and a) produce_guards() is able to solve for
2022               the unique solution of dx = 6 and specialize, and b) the export constraint solver will
2023               raise an issue due to range constraints (a unique solution means not all values in a
2024               range satisfy a guard) and also force specializations.
2025        '''
2026        from torch.export.dynamic_shapes import Dim
2027
2028        def _check_same_range(c, dim):
2029            # returns True if c & dim are both min/max ranges with same values
2030            return (
2031                self._is_dim(dim)
2032                and ("min" in c or "max" in c)
2033                and (
2034                    (dim.min < 2 and c.get("min", 2) == 2)
2035                    or dim.min == c.get("min", 2)
2036                )  # let pass if analysis min = 2 and specified min = 0/1
2037                and dim.max == c.get("max", int_oo)
2038            )
2039
2040        # 1) newly introduced roots
2041        # this part we handle adding newly introduced roots
2042        # these arise from guards like "x.shape[0] % 3 == 0"
2043        # leading to suggested fixes like "dx = 3*_dx"
2044        # extract _dx, and find appropriate min/max values
2045        #
2046        # before, we have something like:
2047        # {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2}
2048        # we want instead:
2049        # {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3}
2050        introduced_roots: Dict[str, str] = {}  # map new root -> old root
2051        for k, c in list(results.items()):
2052            if "eq" in c and isinstance(c["eq"], sympy.Expr):  # derived dim
2053                root = next(iter(c["eq"].free_symbols))
2054                if str(root) not in name_to_dim:
2055                    introduced_roots[str(root)] = k
2056                    # calculate necessary min & max
2057                    modulus, remainder = sympy.polys.polytools.div(c["eq"], root)
2058                    c_min = c.get("min", 2)
2059                    min_ = math.ceil((c_min - remainder) / modulus)
2060                    c_max = c.get("max", int_oo)
2061                    max_ = math.floor((c_max - remainder) / modulus)
2062                    # create result & dim
2063                    results[str(root)] = {"min": min_, "max": max_}
2064                    name_to_dim[str(root)] = Dim(str(root), min=min_, max=max_)
2065                    # remove old root min/max bounds
2066                    c.pop("min", None)
2067                    c.pop("max", None)
2068
2069        # alter derivations that depend on old root, to unify to new root
2070        # e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2
2071        for old_root in introduced_roots.values():
2072            for k, c in list(results.items()):
2073                if (
2074                    "eq" in c
2075                    and isinstance(c["eq"], sympy.Expr)
2076                    and str(symbol := next(iter(c["eq"].free_symbols))) == old_root
2077                ):  # derived dim with root = old_root
2078                    new_root_expr = results[str(old_root)]["eq"]  # dx=3*_dx+1
2079                    new_expr = c["eq"].subs({symbol: new_root_expr})  # dy=(3*_dx+1)+1
2080                    c["eq"] = new_expr
2081
2082        # 2) root swapping
2083        # collect all the original roots that are used for calculating values of suggested fixes
2084        # this consists of:
2085        # 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim
2086        # 2) {"dy": "dx + 1"} -> dx: root for suggested fix
2087        modified_roots: Set[str] = set()
2088        for k, c in results.items():
2089            if k not in name_to_dim:  # _dynamo.export() may handle source directly
2090                continue
2091            if self._is_dim(name_to_dim[k]) and ("min" in c or "max" in c):  # case 1)
2092                modified_roots.add(k)
2093            elif "eq" in c and isinstance(c["eq"], sympy.Expr):  # case 2)
2094                root = next(iter(c["eq"].free_symbols))
2095                assert root is not None
2096                modified_roots.add(str(root))
2097
2098        # exclude newly introduced roots, we've already processed these
2099        modified_roots = modified_roots.difference(introduced_roots)
2100
2101        # evaluate the new value for each root
2102        # this is now either 1) unchanged, 2) refined with a new range,
2103        # or 3) specialized to a concrete value
2104        modified_root_values: Dict[str, Dict[str, Any]] = {}
2105        for root in modified_roots:
2106            swapped_root = True
2107            if root in results:
2108                c = results[root]
2109                if (
2110                    ("min" in c or "max" in c)  # range
2111                    or isinstance(c["eq"], int)  # specialized
2112                ):
2113                    # here, the original root is a root Dim or concrete value in results.
2114                    # if it is a derived dim, it is swapped, and we handle that below.
2115                    if not _check_same_range(c, name_to_dim[root]):  # ignore if unchanged
2116                        modified_root_values[root] = c
2117                    swapped_root = False
2118
2119            if swapped_root:
2120                # if the original root has been swapped in results, that means the new root
2121                # is a range (if it had specialized, the original root would have too).
2122                # find this new root, and solve for the original root's range.
2123                for k, c in results.items():
2124                    if k not in name_to_dim:
2125                        continue
2126                    dim = name_to_dim[k]
2127                    if dim.__class__.__name__ == "_DerivedDim" and dim.root.__name__ == root:
2128                        # only look for min/max root, otherwise root would have specialized
2129                        if "min" in c or "max" in c:
2130                            expr = sympy.sympify(k)
2131                            s = next(iter(expr.free_symbols))
2132                            result = {
2133                                "min": try_solve(sympy.Eq(expr, c["min"]), s)[1],  # type: ignore[arg-type]
2134                                "max": try_solve(sympy.Eq(expr, c["max"]), s)[1],  # type: ignore[arg-type]
2135                            }
2136                            if not _check_same_range(result, name_to_dim[root]):  # ignore if unchanged
2137                                modified_root_values[root] = result
2138                                break
2139
2140        # filter out results where the key is a derived dim (e.g. {"dx - 1" : 4})
2141        # we only want to suggest fixes for the root, to avoid derived names.
2142        # also, remove anything in modified_roots, since we either add new modified values after this,
2143        # or have decided they are unchanged.
2144        for k in list(results.keys()):
2145            if k not in name_to_dim:
2146                continue
2147            if self._is_derived_dim(name_to_dim[k]) or k in modified_roots:
2148                del results[k]
2149
2150        # update results with modified root values
2151        # now results has the following properties:
2152        # - only contains original roots as keys
2153        # - each root is now either specialized, refined, or derived from another original root
2154        results.update(modified_root_values)
2155
2156    def prettify_results(
2157        self,
2158        original_signature: inspect.Signature,
2159        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
2160        constraint_violation_error=None,
2161        forced_specializations=None,
2162    ):
2163        """Format a message for constraint violation erros"""
2164        from torch.export.dynamic_shapes import _get_dim_name_mapping
2165        if not self._dcp.source_name_to_debug_name:
2166            # nothing to do
2167            return ""
2168
2169        def transform(s, inverse=False):
2170            for k, v in self._dcp.source_name_to_debug_name.items():
2171                s = s.replace(k, v) if not inverse else s.replace(v, k)
2172            return s
2173
2174        results = defaultdict(dict)
2175        if dynamic_shapes is None:
2176            dynamic_shapes = {}
2177
2178        def flip(op):
2179            if op == "<=":
2180                return ">="
2181            if op == ">=":
2182                return "<="
2183            if op == "<":
2184                return ">"
2185            if op == ">":
2186                return "<"
2187            assert op == "=="
2188            return op
2189
2190        def relation_with_digit(expr, op, digit):
2191            if op == "<=":
2192                results[expr]["max"] = digit
2193            elif op == "<":
2194                results[expr]["max"] = digit - 1
2195            elif op == ">=":
2196                results[expr]["min"] = digit
2197            elif op == ">":
2198                results[expr]["min"] = digit + 1
2199            else:
2200                assert op == "=="
2201                results[expr]["eq"] = digit
2202
2203        # retrieve dynamic shapes
2204        name_to_dim = _get_dim_name_mapping(dynamic_shapes)
2205
2206        for s in self._static_results.union(self._dynamic_results):
2207            t = transform(s)
2208            if t == s:
2209                continue
2210            left, op, right = re.split(r"( == | <= | >= | < | > )", t)
2211            op = op.strip()
2212            if op == "==" and left == right:
2213                continue
2214            if right.isdigit():
2215                relation_with_digit(left, op, int(right))
2216            elif left.isdigit():
2217                relation_with_digit(right, flip(op), int(left))
2218            else:
2219                assert op == "==", t
2220                results[left]["eq"] = sympy.sympify(right)
2221
2222        # order forced specializations based on name
2223        forced_specializations = {
2224            k: forced_specializations[k]
2225            for k in sorted(
2226                forced_specializations.keys(),
2227                key=lambda x: x.split(" = ")[1],
2228            )
2229        }
2230
2231        buf = ""
2232        if forced_specializations:
2233            debug_names = set()
2234            for k in forced_specializations:
2235                dim = name_to_dim[k.split(" = ")[0]]
2236                if self._is_derived_dim(dim):
2237                    debug_names.add(dim.root.__name__)
2238                else:
2239                    debug_names.add(dim.__name__)
2240
2241            buf += (
2242                f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! "
2243                'For more information, run with TORCH_LOGS="+dynamic".\n'
2244            )
2245            for s, val in forced_specializations.items():
2246                buf += f"  - solving the guards generated for {s} resulted in a specialized value of {val}.\n"
2247
2248        self._process_derived_dim_roots(results, name_to_dim)
2249
2250        dims = []
2251        others = []
2252
2253        # order results by source name
2254        results = {
2255            k: results[k] for k in sorted(
2256                results.keys(),
2257                key=lambda x: transform(x, inverse=True),
2258            )
2259        }
2260        for k, c in results.items():
2261            if "eq" in c:
2262                other = c["eq"]
2263                if isinstance(other, int):
2264                    others.append(f"{k} = {other}")
2265                elif _is_supported_equivalence(other):
2266                    others.append(f"{k} = {other}")
2267            else:
2268                min_ = c.get("min", None)
2269                if min_ == 2:
2270                    min_ = None
2271                max_ = c.get("max", None)
2272                if min_ is not None and max_ is not None:
2273                    dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})")
2274                elif min_ is not None:
2275                    dims.append(f"{k} = Dim('{k}', min={min_})")
2276                elif max_ is not None:
2277                    dims.append(f"{k} = Dim('{k}', max={max_})")
2278                else:
2279                    dims.append(f"{k} = Dim('{k}')")
2280
        # results may be filtered out entirely if there are no new suggestions;
        # this can happen if guards are too complex.
        # in that case, don't suggest a fix
2284        if dims or others:
2285            buf += "\nSuggested fixes:\n  "
2286            buf += "\n  ".join(dims + others)
2287
2288        return buf
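
# A schematic example (not real output) of a message assembled by
# DimConstraints.prettify_results above:
#
#   Suggested fixes:
#     dx = Dim('dx', min=4, max=16)
#     dy = dx + 1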
2289
2290
2291TLS = threading.local()
2292
2293
2294@dataclass(frozen=True)
2295class ShapeEnvSettings:
2296    """
2297    Encapsulates all shape env settings that could potentially affect
2298    FakeTensor dispatch. Used when creating dispatch cache keys.
2299    """
2300
2301    allow_scalar_outputs: bool
2302    allow_dynamic_output_shape_ops: bool
2303    assume_static_by_default: bool
2304    specialize_zero_one: bool
2305    duck_shape: bool
2306    prefer_deferred_runtime_asserts_over_guards: bool
2307    allow_complex_guards_as_runtime_asserts: bool
2308
2309
2310class ShapeEnv:
2311    # This is a wrapper over the actual __init__ function.
2312    #
2313    # Where to add a new constructor parameter to ShapeEnv?
2314    # =====================================================
2315    # This __init__ function should be used only for parameters related to event recording.
2316    # These are parameters that we don't wish to pass down the road to new ShapeEnv instances
2317    # created from replaying events.
2318    #
2319    # If you wish to add a parameter to the constructor of ShapeEnv, unrelated to event
2320    # recording, do so in the _init function.
2321    def __init__(
2322        self, *,
2323        should_record_events: Optional[bool] = None,
2324        tracked_fakes: Optional[List[Any]] = None,
2325        **kwargs
2326    ) -> None:
2327        self._init(**kwargs)
2328
2329        # Disable event recording when replaying.
2330        kwargs["should_record_events"] = False
2331
2332        from torch.fx.experimental.validator import translation_validation_enabled
2333        self._translation_validation_enabled = translation_validation_enabled()
2334
2335        # If not specified, enable event recording if both:
2336        #   - Translation validation is on
2337        #   - Translation validation bisection is not disabled
2338        self.should_record_events = (
2339            should_record_events
2340            if should_record_events is not None
2341            else (
2342                self._translation_validation_enabled
2343                and not config.translation_validation_no_bisect
2344            )
2345        )
2346
2347        # Enable event recording check if both:
2348        #   - It should record events
2349        #   - The recording check is enabled
2350        self.check_recorded_events = (
2351            self.should_record_events and config.check_shape_env_recorded_events
2352        )
2353
2354        # This will make sure we only record the top-level function call.
2355        self.is_recording = not self.should_record_events
2356        # Keep track of the list of tracked fakes.
2357        self.tracked_fakes = tracked_fakes
2358        # List of events for reconstructing ShapeEnv at arbitrary points in time.
2359        self.events: List[ShapeEnvEvent] = (
2360            [ShapeEnvEvent(ShapeEnv, kwargs=kwargs)] if self.should_record_events else []
2361        )
2362
2363        # FakeTensor per-ShapeEnv operation cache. This is used for caching
2364        # operations that contain symbolic shapes which have guards on the
2365        # ShapeEnv (so are ShapeEnv-dependent).
2366        #
2367        # NOTE: It's important that SymNodes in this cache have their ShapeEnv
2368        # stripped otherwise you end up with cycles which can only be cleaned
2369        # with the GC.
2370        self.fake_tensor_cache: Dict[torch._subclasses.fake_tensor._DispatchCacheKey,
2371                                     torch._subclasses.fake_tensor._DispatchCacheEntry] = {}
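
    # A hedged usage sketch (not executed): an export-style caller might create
    #
    #   ShapeEnv(assume_static_by_default=True,
    #            prefer_deferred_runtime_asserts_over_guards=True)
    #
    # Event recording defaults to on only when translation validation is
    # enabled and translation_validation_no_bisect is not set.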
2372
    # Pro-tip: if you add a new field to ShapeEnv, this affects some accept
    # tests.  Accept their output with:
2375    #
2376    #   EXPECTTEST_ACCEPT=1 python test/dynamo/test_dynamic_shapes.py -k test_shape_env_equal
2377    #
2378    def _init(
2379        self, *,
2380        allow_scalar_outputs=True,
2381        allow_dynamic_output_shape_ops=True,
        # NB: These are legacy configuration options that help us make good choices
2383        # when the constraint/dynamic dims are not explicitly passed to us.
2384        # Ideally we will fix all call sites to be explicit and not have
2385        # implicit choices, but this apparently was pretty involved.
2386        assume_static_by_default=False,
2387        # Note - On 0/1 specialization
2388        #
2389        # The following options affect decisions we make about eager
2390        # specialization.  Disabling them will increase trace time (as we do
2391        # more symbolic reasoning) and can also harm the quality of generated
2392        # code (because inductor may not be able to specialize for bounds
2393        # being equal--although if we later respecialize because of a guard,
2394        # your code may be just as good as it was before.)
2395        #
2396        # When True, eagerly specialize input sizes which have 0/1.
2397        specialize_zero_one=True,
2398        # When True, assume input sizes which have the same size are
2399        # symbolically equal.
2400        duck_shape: Optional[bool] = None,
2401        # For debugging
2402        co_fields=None,
2403        # When True, whenever safe, we will generate a deferred runtime assert
2404        # instead of a guard whenever we know that an expression must be True,
2405        # otherwise it would be an error, even for backed SymInts (where we
2406        # could ostensibly unconditionally generate guards).  This is useful
2407        # for export, where preventing "error checking" sizes from showing up
2408        # in guards is helpful, since these guards in some sense are overly
2409        # pedantic.  See also https://github.com/pytorch/pytorch/issues/121749
2410        prefer_deferred_runtime_asserts_over_guards=False,
2411        # When True, does not emit or raise constraint violation errors on
2412        # implicit guards generated by ops, and defers to runtime assertions
2413        # in the graph instead. For export.
2414        allow_complex_guards_as_runtime_asserts=False,
2415        # XXX Add any new settings that could affect FakeTensor evaluation
2416        # to: torch._subclasses.fake_tensor._ShapeEnvSettings
2417    ):
2418        if duck_shape is None:
2419            duck_shape = config.use_duck_shape
2420
2421        self.settings = ShapeEnvSettings(
2422            # Not directly used by ShapeEnv; indirectly used by FakeTensor
2423            allow_scalar_outputs=allow_scalar_outputs,
2424            allow_dynamic_output_shape_ops=allow_dynamic_output_shape_ops,
2425            # End
2426            assume_static_by_default=assume_static_by_default,
2427            specialize_zero_one=specialize_zero_one,
2428            duck_shape=duck_shape,
2429            prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
2430            allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
2431        )
2432
2433        self.guards: List[ShapeGuard] = []
2434        # Maps symbolic ints to their original concrete values
2435        # Currently populated from tensors
2436        self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
2437        # Like var_to_val, but only set when propagate_real_tensors is on.
2438        # Used as last resort to avoid GuardOnDataDependent error
2439        self.unbacked_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
2440        # Maps symbolic ints to their min/max range.  These ranges
2441        # are conservative: the int MUST fall in the range, but the
2442        # range may contain ints which may not actually appear in
2443        # practice
2444        self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
2445        self.source_name_to_debug_name: Dict[str, str] = {}
2446        self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
2447        self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {}
2448        # Maps symbols to the expressions that replace them
2449        # Populated from equality guards (i.e. a.shape[0] == b.shape[0])
2450        self.replacements: Dict[sympy.Symbol, sympy.Expr] = {}
2451        self.unbacked_renamings: Dict[sympy.Symbol, sympy.Symbol] = {}
2452        # Set that holds a % b expressions that evaluate to 0.
2453        self.divisible: Set[sympy.Expr] = set()
2454        # Set that holds "size-like" symbols.  When we perform
2455        # "size-oblivious" tests, these can be assumed to be >= 2.
2456        self.size_like: Set[sympy.Symbol] = set()
2457        # Duck-shaping says that if two input tensors have the same size,
2458        # they get assigned the same symbolic variable
2459        self.val_to_var: Dict[int, sympy.Expr] = {}
2460        if specialize_zero_one:
2461            self.val_to_var = {0: sympy.Integer(0), 1: sympy.Integer(1)}
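        # For example (an illustrative sketch, not taken from a real trace):
        # with duck shaping on, fakifying tensors of sizes (3, 4) and (3,)
        # might populate
        #
        #   val_to_var == {0: 0, 1: 1, 3: s0, 4: s1}
        #
        # so the second tensor's dimension of size 3 reuses s0 rather than
        # receiving a fresh symbol.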
2462        self.unbacked_symfloat_counter = itertools.count()
2463        self.unbacked_symint_counter = itertools.count()
2464        # Similar to guards, but these MUST evaluate to true, and they can
2465        # only be checked at runtime, midway through execution (i.e., they
2466        # always involve unbacked symints)
2467        #
2468        # For efficiency reasons, we index in the following way.  Suppose you have
2469        # a runtime assert i0 + i1 <= s1.  We pick the most recently allocated
2470        # symbol in the source expression and add the assert to the list for
2471        # that symbol e.g., {i1: [i0 + i1 <= s1]}.
2472        #
2473        # We access the runtime asserts in two situations:
2474        #
2475        #   - When we are guarding on an expression, we will attempt to
2476        #     statically evaluate it, in case the unbacked SymInts can
2477        #     simplify away.  If we have a runtime assert, we may be able
2478        #     to discharge the guard entirely.  We only need to attempt
2479        #     runtime asserts that mention freevars of the expression in
2480        #     question.
2481        #
2482        #   - When we are performing codegen (in Inductor for eager, or
2483        #     when finalizing the export FX graph), we need to know what
2484        #     extra runtime asserts to insert.  Whenever an unbacked
2485        #     SymInt comes into scope, all runtime asserts involving it
2486        #     become eligible for insertion (so long as all of their other
2487        #     free unbacked symbols are also in scope).  We technically
2488        #     can handle any choice of key by kicking inexpressible asserts
2489        #     to the next unbacked symbol to wait on, but if we choose the
2490        #     latest key, an assert will only show up at the moment when
2491        #     we can actually codegen it.
2492        self.deferred_runtime_asserts: Dict[sympy.Symbol, List[RuntimeAssert]] = {}
2493        # This exists so we can efficiently invalidate the cache (it's used as
2494        # part of the cache key); otherwise we'd have to iterate through
2495        # deferred_runtime_asserts to compute its length
2496        self.num_deferred_runtime_asserts = 0
2497        self.log = log
2498        self.log.debug("create_env")
2499        self.frozen = False
2500        self.runtime_asserts_frozen = False
2501        self.dim_constraints: Optional[DimConstraints] = None
2502        self.counter = collections.Counter()
2503        # Mapping from sympy.Symbol to the number of guards which mention this
2504        # symbol
2505        self.symbol_guard_counter = collections.Counter()
2506        # A selection of important fields on co_fields; solely used for
2507        # signpost_event
2508        self.co_fields = co_fields if co_fields else {}
2509
2510        # Whenever we allocate a fresh unbacked Symbol, we add it to this
2511        # pending list.  Unbacked symbol allocation can occur at unpredictable
2512        # points during meta tensor propagation, but at some point, we
2513        # have to know what the binding site for an unbacked symbol is, and
2514        # this is computed when we actually place the node in the graph.  The
2515        # important thing is that we always handle every unaccounted-for
2516        # unbacked symbol, so this list helps us keep track of them and then
2517        # make sure they are all accounted for.
2518        #
2519        # We could potentially give rise to errors earlier by lexically
2520        # scoping when we do propagation, and only allowing unbacked symbols
2521        # to be allocated at this point in time.  However this is inconvenient
2522        # to do in Dynamo, because fake tensor propagation is far from when we
2523        # analyze binding sites (set_example_value), so we do it in a more
2524        # mutatey way.
2525        #
2526        # NB: fresh unbacked symbols NEVER get substitutions applied to them,
2527        # they are binding sites!
2528        self.pending_fresh_unbacked_symbols: List[sympy.Symbol] = []
2529
2530        # Version counter used to invalidate cached values
2531        self._prev_cache_key = self._get_key()
2532        self._version_counter = 0
2533
2534        # Cache for FX nodes.
2535        # Maps a tuple of:
2536        #   1. node's target
2537        #   2. list of arguments
2538        # to an already built node.  This drastically reduces the size of
2539        # the FX graph, avoiding duplicated nodes.
2540        self.fx_node_cache: Dict[Tuple[Callable, Tuple[Any, ...]], torch.fx.Node] = {}
2541        self.source_to_symbol: Dict[str, sympy.Symbol] = {}
2542
2543        # Suppose you want to replace an unbacked symbol with another
2544        # unbacked symbol.  This is error prone because you can cause
2545        # references to unbacked symbols to time travel backwards.  E.g.,
2546        #
2547        # u1 = x.item()
2548        # ... use of u1 ...
2549        # u2 = y.item()
2550        # u3 = z.item()
2551        # torch._check(u1 == u2 + u3)
2552        #
2553        # If you replace u1 with u2 + u3, then the use of u1 now
2554        # references u2 and u3 prior to them actually being bound at
2555        # runtime.
2556        #
2557        # To control for this, we track the order unbacked symbols
2558        # were allocated, and only allow substitutions if they respect
2559        # the dependency from this order; an unbacked symbol can only
2560        # be substituted with unbacked symbols that come before it in the
2561        # order.
2562        #
2563        # This also imposes an ordering on the unbacked symbol binding
2564        # sites themselves: you are not allowed to reorder unbacked symbol
2565        # bindings.  At the moment, this is not tracked, but we potentially
2566        # could track this at the IR level using a higher order operator
2567        # with something like effect token tracking.
2568        self.unbacked_alloc_order: Dict[sympy.Symbol, int] = {}
2569
2570        from torch.fx.experimental.validator import translation_validation_enabled
2571        self._translation_validation_enabled = translation_validation_enabled()
2572
2573        if self._translation_validation_enabled:
2574            from torch.fx.experimental.validator import TranslationValidator
2575
2576            self.validator = TranslationValidator()
2577            self.graph = torch.fx.Graph()
2578            # Create an output graph and start inserting before that.
2579            # Create an output node and start inserting before it.
2580            self.graph.inserting_before(self.graph.output(None))
2581
2582            # Mapping of each node name to the node itself.
2583            #
2584            # This is useful for matching an FX node from a recorded ShapeEnv.graph
2585            # to the FX node of the ShapeEnv we are running the event on.
2586            #
2587            # Whenever you add a node to self.graph, you must add a mapping to this
2588            # variable. Otherwise, the built FX graph on the replayed ShapeEnv will
2589            # not be valid.
2590            self.name_to_node: Dict[str, torch.fx.Node] = {}
2591
2592    @property
2593    def allow_scalar_outputs(self):
2594        return self.settings.allow_scalar_outputs
2595
2596    @property
2597    def allow_dynamic_output_shape_ops(self):
2598        return self.settings.allow_dynamic_output_shape_ops
2599
2600    @property
2601    def assume_static_by_default(self):
2602        return self.settings.assume_static_by_default
2603
2604    @property
2605    def specialize_zero_one(self):
2606        return self.settings.specialize_zero_one
2607
2608    @property
2609    def duck_shape(self):
2610        return self.settings.duck_shape
2611
2612    @property
2613    def prefer_deferred_runtime_asserts_over_guards(self):
2614        return self.settings.prefer_deferred_runtime_asserts_over_guards
2615
2616    @property
2617    def allow_complex_guards_as_runtime_asserts(self):
2618        return self.settings.allow_complex_guards_as_runtime_asserts
2619
2620    def check_equal(self, other: "ShapeEnv") -> None:
2621        """Compare this ShapeEnv against another for equivalence
2622        """
2623        # ShapeEnv fields that are not relevant for the outcome of
2624        # ShapeEnv.produce_guards call:
2625        #   - Debugging variables
2626        #   - Translation validation related variables
2627        #   - Events recording related variables
2628        non_state_variable_names = (
2629            "counter",
2630            "log",
2631            "var_to_stack",
2632            "fx_node_cache",
2633            "graph",
2634            "validator",
2635            "check_recorded_events",
2636            "should_record_events",
2637            "is_recording",
2638            "tracked_fakes",
2639            "events",
2640            "source_name_to_debug_name",
2641            "_prev_cache_key",
2642            "_version_counter",
2643            "dim_constraints",
2644        )
2645
2646        # Mapping of the value of each to-be-compared field into the values that
2647        # should actually be compared.
2648        #
2649        # You should modify this if, for example, a field holds both state and
2650        # debugging information, e.g. ShapeGuard holds the actual guard (sympy.Expr)
2651        # and the stack at the time it was added to the set of guards. In order to
2652        # compare it, we throw away the stack information.
2653        def map_value(key: str, value: Any) -> Any:
2654            if key in ("unbacked_symfloat_counter", "unbacked_symint_counter"):
2655                from copy import copy
2656
2657                # For itertools.count(), we compare the next integer returned
2658                # by the count iterators. Note that we need to copy the iterator
2659                # first. Otherwise, we would be mutating the object.
2660                return next(copy(value))
2661            elif key == "guards":
2662                # Transform the list of ShapeGuard into a list of expressions.
2663                return [g.expr for g in value]
2664            elif key == "deferred_runtime_asserts":
2665                # Transform the list of RuntimeAsserts into a list of expressions.
2666                return {s: [ra.expr for ra in ras] for s, ras in value.items()}
2667            elif key == "name_to_node":
2668                # Compare just that the sets of keys are the same.
2669                return set(value.keys())
2670            elif key in ("symbol_guard_counter", "pending_fresh_unbacked_symbols", "fake_tensor_cache"):
2671                # Skip this for comparisons
2672                return None
2673            return value
2674
2675        shape_env_check_state_equal(self, other, non_state_variable_names, map_value)
2676
2677    def _snapshot_tracked_fakes(self) -> Optional[List[Any]]:
2678        if self.tracked_fakes is None:
2679            return None
2680
2681        from torch._dynamo.variables.builder import TrackedFake
2682
2683        def maybe_transform_fake(fake: TrackedFake):
2684            inner_fake = fake.fake \
2685                if isinstance(fake.fake, (torch.SymInt, torch.SymFloat)) \
2686                else FakeTensorMeta.from_fake(fake.fake)
2687            # Even though TrackedFake accepts either a SymInt or a FakeTensor, here we give it a
2688            # FakeTensorMeta for two reasons:
2689            #   1. this is all the information we need when recording ShapeEnvEvents.
2690            #   2. it works even if each TrackedFake changes its metadata.
2691            return TrackedFake(inner_fake, fake.source, fake.symbolic_context)  # type: ignore[arg-type]
2692
2693        return [maybe_transform_fake(fake) for fake in self.tracked_fakes]
2694
2695    def _last_event_index(self) -> int:
2696        return len(self.events) - 1
2697
2698    @contextmanager
2699    def _recording(self):
2700        self.is_recording = True
2701        try:
2702            yield
2703        finally:
2704            self.is_recording = False
2705
2706    @record_shapeenv_event()
2707    def _eliminate_unbacked(self, orig_s: sympy.Symbol, new_s: sympy.Expr):
2708        self._set_replacement(orig_s, new_s, "eliminate_unbacked")
2709
2710    @record_shapeenv_event()
2711    def set_unbacked_var_to_val(self, k: sympy.Symbol, v: int) -> None:
2712        """Used only when propagate_real_tensors; registers a value for an
2713        unbacked symbol, which can be used as a last resort to resolve hints."""
2714        self.unbacked_var_to_val[k] = sympy.sympify(v)
2715
2716    # Unlike set_replacement, this records a shapeenv event
2717    @record_shapeenv_event()
2718    def _rename_unbacked_to(self, orig_s: sympy.Symbol, new_s: sympy.Symbol):
2719        assert isinstance(orig_s, sympy.Symbol), orig_s
2720        assert isinstance(new_s, sympy.Symbol), new_s
2721        assert free_unbacked_symbols(new_s), new_s
2722        assert free_unbacked_symbols(orig_s), orig_s
2723        dest = self.replacements.get(orig_s)
2724        assert not free_unbacked_symbols(dest), f"{orig_s} -> {dest}"
2725        self._set_replacement(orig_s, new_s, "rename_unbacked_to")
2726        self.unbacked_renamings[orig_s] = new_s
2727        if dest is not None:
2728            self._set_replacement(new_s, dest, "rename_unbacked_to_dest")
2729
2730    @record_shapeenv_event()
2731    def _constrain_range_for_size(self, a: sympy.Symbol, min: Optional[int] = None, max: Optional[int] = None):
2732        if min is None:
2733            min = 0
2734        if max is None:
2735            max = int_oo
2736
2737        if max < min:
2738            raise ValueError(
2739                "Maximum value to constrain_as_size can't be less than the specified min value, "
2740                f"received min={min} and max={max}"
2741            )
2742
2743        self.constrain_symbol_range(
2744            a,
2745            compiler_min=min,
2746            compiler_max=max,
2747        )
2748        self.size_like.add(a)
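        # A minimal sketch of how this is typically reached (assuming the
        # torch._check_is_size helper; `x` is a hypothetical 0-dim tensor):
        #
        #   u0 = x.item()
        #   torch._check_is_size(u0)
        #
        # which constrains u0's symbol to [0, int_oo] and marks it size-like,
        # so size-oblivious tests may assume it is >= 2.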
2749
2750    @record_shapeenv_event()
2751    def _constrain_range(self, a: sympy.Expr, min: int, max: int):
2752        if isinstance(a, sympy.Integer):
2753            if not (min <= int(a) <= max):
2754                raise ValueRangeError(f"Invalid value {int(a)} for range [{min}:{max}]")
2755            return
2756
2757        # TODO: Shouldn't we install a guard if the symbol is backed?  Or are the
2758        # semantics that this is an "unchecked" assert (but is this actually
2759        # something useful?  It might be better to restrict this to unbacked
2760        # SymInts).
2761        if isinstance(a, sympy.Symbol):
2762            self.constrain_symbol_range(
2763                a,
2764                compiler_min=min,
2765                compiler_max=max,
2766            )
2767
2768    @record_shapeenv_event()
2769    def _constrain_unify(self, a, b):
2770        """
2771        Given two SymInts, constrain them so that they must be equal.  NB:
2772        this will not work with SymInts that represent nontrivial expressions
2773        (yet!)
2774        """
2775        # TODO: this does not install a deferred runtime assert yet
2776
2777        # TODO: Maybe dedupe this with _maybe_guard_rel?
2778        # Update Feb 2024: this is extra important to do, because this doesn't
2779        # handle unbacked replacements properly, nor does it generate deferred
2780        # runtime asserts
2781        if not isinstance(a, SymInt):
2782            if not isinstance(b, SymInt):
2783                assert a == b
2784            else:
2785                assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
2786                assert b.node.shape_env is self
2787                self.replacements[b.node.expr] = sympy.Integer(a)
2788        else:
2789            # TODO: Actually, we can support this as long as one of them is a symbol.
2790            # NB: We can't actually do "unification" as our operators are not
2791            # injective
2792            assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
2793            assert a.node.shape_env is self
2794            if not isinstance(b, SymInt):
2795                self.replacements[a.node.expr] = sympy.Integer(b)
2796            else:
2797                assert a.node.shape_env is b.node.shape_env
2798                assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
2799                new_var = self._find(a.node.expr)
2800                self.replacements[b.node.expr] = new_var
2801
2802    def _ignore_fresh_unbacked_symbols_tls(self):
2803        return getattr(TLS, "ignore_fresh_unbacked_symbols", False)
2804
2805    @record_shapeenv_event()
2806    def _ignore_fresh_unbacked_symbols_enter(self):
2807        TLS.ignore_fresh_unbacked_symbols = True
2808
2809    @record_shapeenv_event()
2810    def _ignore_fresh_unbacked_symbols_exit(self):
2811        TLS.ignore_fresh_unbacked_symbols = False
2812
2813    @contextmanager
2814    def ignore_fresh_unbacked_symbols(self):
2815        """
2816        Indicates that the newly allocated unbacked SymInts are being
2817        discarded
2818        """
2819        self._ignore_fresh_unbacked_symbols_enter()
2820        try:
2821            yield
2822        finally:
2823            self._ignore_fresh_unbacked_symbols_exit()
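        # Usage (an illustrative sketch; `shape_env` stands for this ShapeEnv):
        #
        #   with shape_env.ignore_fresh_unbacked_symbols():
        #       u0 = shape_env.create_unbacked_symint()
        #   # u0 is NOT added to pending_fresh_unbacked_symbols, so no
        #   # binding site will be demanded for it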
2824
2825    @record_shapeenv_event()
2826    def freeze(self):
2827        """Freeze this ShapeEnv to stop accumulating guards
2828
2829        A frozen ShapeEnv will ignore any further guards generated on it, only
2830        emitting a warning; dropping guards this way may lead to accuracy problems.
2831        """
2832        self.frozen = True
2833
2834    @record_shapeenv_event()
2835    def freeze_runtime_asserts(self):
2836        """Freeze this ShapeEnv to stop adding deferred runtime asserts.
2837
2838        We will error if you try to install a new runtime assert when it is
2839        frozen.  This would indicate a lowering violation, or perhaps something
2840        we know statically is already True but we are checking it again in a way
2841        that is not clearly dischargeable.
2842        """
2843        # self.prefer_deferred_runtime_asserts_over_guards = False
2844        self.runtime_asserts_frozen = True
2845
2846    def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]:
2847        if not self._translation_validation_enabled:
2848            return None
2849        srcname = source.name()
2850        if srcname not in self.source_to_symbol:
2851            self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True)
2852        return self.source_to_symbol[srcname]
2853
2854    def _add_z3var(self, symbol: sympy.Symbol, type: Type) -> None:
2855        if self._translation_validation_enabled:
2856            self.validator.add_var(symbol, type)
2857
2858    def _add_target_expr(self, expr) -> None:
2859        if self._translation_validation_enabled:
2860            self.validator.add_target_expr(expr)
2861
2862    def _add_assertion(self, expr) -> None:
2863        if self._translation_validation_enabled:
2864            self.validator.add_assertion(expr)
2865
2866    def _check_translation_validate(self) -> None:
2867        if self._translation_validation_enabled:
2868            self.validator.validate()
2869
2870    @record_shapeenv_event()
2871    def _create_fx_call_function(
2872            self,
2873            op: Callable,
2874            args: Tuple,
2875    ) -> Tuple[Optional[torch.fx.Node], bool]:
2876        # Cache this tuple in order to avoid duplicated nodes.
2877        node_key = (op, args)
2878        # Flags whether the returned node is freshly created (i.e. was not cached).
2879        fresh = False
2880
2881        if self._translation_validation_enabled and node_key not in self.fx_node_cache:
2882
2883            # Presence of None in the arguments implies that we should ignore this operation.
2884            if any(a is None for a in args):
2885                # We check that we are not mixing SymNodes that should not be ignored
2886                # (fx_node is not None) with those that should be (fx_node is None).
2887                assert all(not isinstance(a, torch.fx.Node) for a in args)
2888                return None, fresh
2889
2890            fresh = True
2891
2892            # If translation validation is enabled, each argument must have its
2893            # own FX node.
2894            assert all(a is not None for a in args), f"missing arg in FX graph ({op.__name__}): {args}"
2895            node = self.fx_node_cache[node_key] = self.graph.call_function(op, args)
2896            self.name_to_node[node.name] = node
2897
2898        return self.fx_node_cache.get(node_key, None), fresh
2899
2900    def _create_fx_placeholder_and_z3var(
2901            self,
2902            symbol: sympy.Symbol,
2903            type: Type,
2904    ) -> Optional[torch.fx.Node]:
2905        if not self._translation_validation_enabled:
2906            return None
2907
2908        node_key = (self.graph.placeholder, (symbol,))
2909
2910        # Check whether we have already added this symbol.
2911        # If so, skip the placeholder creation, as duplicating
2912        # it generates invalid Python code.
2913        if node_key not in self.fx_node_cache:
2914            # Add a Z3 variable according to 'type'.
2915            self._add_z3var(symbol, type)
2916            # Create the FX placeholder out of a mangled name.
2917            mangled_name = re.sub(r'[^a-zA-Z0-9]', '_', re.sub(r'[()]', '', symbol.name))
2918            node = self.fx_node_cache[node_key] = self.graph.placeholder(mangled_name)
2919            self.name_to_node[node.name] = node
2920            # Attach the 'symbol' to the placeholder so that we can retrieve
2921            # the Z3 variable later.
2922            node.meta["symbol"] = symbol
2923
2924        return self.fx_node_cache[node_key]
2925
2926    def _remove_fx_node(self, node: Optional[torch.fx.Node]) -> None:
2927        if self._translation_validation_enabled and node is not None:
2928            self.name_to_node.pop(node.name)
2929            self.graph.erase_node(node)
2930
2931    def _add_fx_node_metadata(self, node: torch.fx.Node) -> None:
2932        from torch._dynamo.utils import get_current_node
2933
2934        if self.should_record_events:
2935            node.meta[SHAPEENV_EVENT_KEY] = self._last_event_index()
2936            node.meta[CURRENT_NODE_KEY] = get_current_node()
2937
2938    def _suppress_guards_tls(self):
2939        return getattr(TLS, "suppress_guards", False)
2940
2941    @record_shapeenv_event()
2942    def _suppress_guards_enter(self):
2943        TLS.suppress_guards = True
2944
2945    @record_shapeenv_event()
2946    def _suppress_guards_exit(self):
2947        TLS.suppress_guards = False
2948
2949    @contextmanager
2950    def suppress_guards(self):
2951        """Context manager to ignore all guards generated inside"""
2952        self._suppress_guards_enter()
2953        try:
2954            yield
2955        finally:
2956            self._suppress_guards_exit()
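        # Usage (an illustrative sketch; `shape_env` stands for this ShapeEnv
        # and s0 for one of its backed SymInts):
        #
        #   with shape_env.suppress_guards():
        #       bool(s0 == 4)  # evaluated with the hint, but no guard recorded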
2957
2958    def _get_key(self):
2959        """
2960        Defines the current "state" of the guards we've accumulated in this ShapeEnv.
2961        Determines when we need to invalidate our cache
2962        """
2963        return (len(self.replacements), len(self.divisible), self.num_deferred_runtime_asserts, len(self.unbacked_var_to_val))
2964
2965    def _update_version_counter(self):
2966        # The shape environment is queried orders of magnitude more often than
2967        # it is changed, so we summarise the cache key into a linearly
2968        # increasing version counter which is cheaper to check in _lru_cache
2969
2970        # Only update version counter if the state actually changed
2971        cur_key = self._get_key()
2972        if self._prev_cache_key != cur_key:
2973            self._prev_cache_key = cur_key
2974            self._version_counter += 1
2975
2976    def _produce_dyn_sizes(self,
2977                           ex_size: Sequence[int],
2978                           source: Source,
2979                           symbolic_context: SymbolicContext
2980                           ) -> List[sympy.Expr]:
2981        return self._produce_dyn_sizes_from_int_tuple(tuple(ex_size), source, symbolic_context)
2982
2983    def _produce_dyn_sizes_from_int_tuple(self,
2984                                          tensor_size: Tuple[int],
2985                                          source: Source,
2986                                          symbolic_context: SymbolicContext,
2987                                          ) -> List[sympy.Expr]:
2988        assert all(not is_symbolic(val) for val in tensor_size), f"Expect size to be a plain tuple of ints but got {tensor_size}"
2989        from torch._dynamo.source import TensorPropertySource, TensorProperty
2990        _assert_symbol_context(symbolic_context)
2991        dynamic_dims = symbolic_context.dynamic_sizes
2992        constraint_dims = symbolic_context.constraint_sizes
2993        size = []
2994        for i, val in enumerate(tensor_size):
2995            size.append(self.create_symbol(
2996                val,
2997                TensorPropertySource(source, TensorProperty.SIZE, i),
2998                dynamic_dims[i],
2999                constraint_dims[i],
3000                symbolic_context=symbolic_context
3001            ))
3002        return size
3003
3004    def create_symbolic_sizes_strides_storage_offset(
3005        self,
3006        ex: torch.Tensor,
3007        source: Source,
3008        *,
3009        symbolic_context: Optional[SymbolicContext] = None,
3010    ):
3011        """
3012        Returns a list of symbolic sizes and strides for the given tensor.
3013        We try our best to express stride in terms of the sizes, so as to not
3014        introduce new symbolic variables.
3015        """
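        # For example (a sketch; s0, s1 denote freshly allocated size symbols):
        # for a contiguous tensor with ex.size() == (2, 3) and
        # ex.stride() == (3, 1), we would produce sizes (s0, s1) and strides
        # (s1, 1): the stride of dim 0 is expressed via the size of dim 1
        # instead of getting its own symbol.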
3016
3017        ex_size = tuple(self._maybe_specialize_sym_int_with_hint(sz) for sz in ex.size())
3018        ex_stride = tuple(self._maybe_specialize_sym_int_with_hint(sd) for sd in ex.stride())
3019        ex_storage_offset = self._maybe_specialize_sym_int_with_hint(ex.storage_offset())
3020
3021        return self._create_symbolic_sizes_strides_storage_offset(
3022            ex_size,
3023            ex_stride,
3024            ex_storage_offset,
3025            [_is_dim_dynamic(ex, i) for i in range(ex.dim())],
3026            source,
3027            symbolic_context=symbolic_context,
3028        )
3029
3030    # Dynamo may want to wrap up FakeTensors that have SymInt sizes, e.g. for make_fx(opt_f(), tracing_mode="symbolic").
3031    # We create symbols in shape_env using the backed hints behind SymInt.
3032
3033    # Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape.
3034    # produce_guards will trigger specializations on the outer stuff
3035
3036    # Case 2: when the SymInt is unbacked, we will throw a data-dependent error in require_hint().
3037    #
3038    # This is probably fine for now, but it's important to note that this approach has implications
3039    # for the original shape_env when guards are checked in a different order.
3040
3041    # Example:
3042    # ---------
3043    # Consider a function "opt_f" as shown below:
3044
3045    # @torch.compile()
3046    # def opt_f(x: bool, y: Tensor):
3047    #   if x == True:
3048    #     return y + torch.randn([4])
3049    #   else:
3050    #     return y
3051    # Depending on the sequence of calls, we might install two different sets of guards:
3052
3053    # 1. opt_f(False, y):
3054    #    - "x == False" (always works for any size y)
3055
3056    # 2. opt_f(True, y):
3057    #    - Triggers recompilation and results in guards like:
3058    #      - "x == True and y.size(0) == 4"
3059    #      - (or "y.size(0) == 4 and x == True")
3060
3061    # The order in which the guards are checked matters. In this specific example:
3062    # if the True-branch guard check precedes the False-branch one, and within the True branch
3063    # the y.size(0) check precedes the x == True check, we may get an unnecessary shape specialization for y.
3064    def _maybe_specialize_sym_int_with_hint(self, maybe_sym) -> int:
3065        assert isinstance(maybe_sym, (int, torch.SymInt))
3066        if is_symbolic(maybe_sym):
3067            assert maybe_sym.node.shape_env is not self, \
3068                "expected the symbol to be created from a shape env other than the current one."
3069            return maybe_sym.node.require_hint()
3070        return maybe_sym
3071
3072    @record_shapeenv_event()
3073    def _create_symbolic_sizes_strides_storage_offset(
3074        self,
3075        ex_size: Sequence[int],
3076        ex_stride: Sequence[int],
3077        ex_storage_offset: int,
3078        is_dim_dynamic: Sequence[bool],
3079        source: Source,
3080        *,
3081        symbolic_context: Optional[SymbolicContext] = None,
3082    ):
3083        dim = len(ex_size)
3084
3085        # Reimplement the legacy behavior
3086        if symbolic_context is None:
3087            constraint_sizes = [None] * dim
3088            constraint_strides = [None] * dim
3089            dynamic_dims = []
3090            dynamic_strides = []
3091            for i in range(dim):
3092                # NB: This is encapsulation breaking!  Legacy behavior was
3093                # bad.
3094                if is_dim_dynamic[i]:
3095                    r = DimDynamic.DYNAMIC
3096                elif self.assume_static_by_default:
3097                    r = DimDynamic.STATIC
3098                else:
3099                    r = DimDynamic.DUCK
3100                dynamic_dims.append(r)
3101                dynamic_strides.append(r)
3102            dynamic_dims = [DimDynamic.DUCK] * dim
3103            dynamic_strides = [DimDynamic.INFER_STRIDE] * dim
3104            # symbolic_context is None - set one
3105            symbolic_context = StatelessSymbolicContext(
3106                dynamic_sizes=dynamic_dims,
3107                dynamic_strides=dynamic_strides,
3108                constraint_sizes=constraint_sizes,
3109                constraint_strides=constraint_strides,
3110            )
3111        # We got a StatelessSymbolicContext
3112        _assert_symbol_context(symbolic_context)
3113        constraint_sizes = symbolic_context.constraint_sizes
3114        constraint_strides = symbolic_context.constraint_strides
3115        dynamic_sizes = symbolic_context.dynamic_sizes
3116        dynamic_strides = symbolic_context.dynamic_strides
3117
3118        # TODO: make this configurable from outside symbolic_context; we made a
3119        # decision here that if all sizes are static, we are going to
3120        # specialize all of the inner strides/offset too. We don't have to
3121        # do this; arguably we should ALWAYS allow for dynamic offset, as
3122        # it is cheap.
3123        # TODO: This should be DYNAMIC, using DUCK for BC
3124        dynamic_offset = DimDynamic.STATIC if all(r == DimDynamic.STATIC for r in dynamic_sizes) else DimDynamic.DUCK
3125        are_sizes_static = all(r == DimDynamic.STATIC for r in dynamic_sizes)
3126
3127        assert len(dynamic_sizes) == dim, f"{len(dynamic_sizes)} != {dim}"
3128        assert len(dynamic_strides) == dim, f"{len(dynamic_strides)} != {dim}"
3129        assert len(constraint_sizes) == dim
3130        assert len(constraint_strides) == dim
3131
3132        from torch._dynamo.source import TensorPropertySource, TensorProperty
3133        size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, symbolic_context)
3134        stride: List[Optional[sympy.Expr]] = [None] * len(size)
3135        for i, val in enumerate(ex_stride):
3136            if val in (0, 1):
3137                stride[i] = sympy.Integer(val)
3138        while any(x is None for x in stride):
3139            candidates = {
3140                ex_size[i] * ex_stride[i]: size[i] * stride[i]
3141                for i in range(len(size))
3142                if stride[i] is not None and ex_stride[i] >= 0
3143            }
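            # For example (an illustrative sketch): with ex_size == (2, 3) and
            # ex_stride == (3, 1), stride[1] was already pinned to 1 above, so
            # candidates == {3: s1}; if dim 0 is DimDynamic.INFER_STRIDE, its
            # unbound stride (ex value 3) is looked up here and bound to s1
            # rather than a fresh symbol.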
3144
3145            # iterate over unbound strides in sorted order
3146            def _nested_int_aware_sort(tup):
3147                return (
3148                    # Order nested ints by their coefficients.
3149                    # 1 here to order nested ints after non-nested-ints.
3150                    (1, tup[0].node.nested_int_coeff(), tup[1]) if is_nested_int(tup[0])
3151                    else (0, *tup)
3152                )
3153            val_list = sorted(
3154                [(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None],
3155                key=_nested_int_aware_sort,
3156            )
3157            for _, i in val_list:
3158                # Set stride to a candidate only for DimDynamic.INFER_STRIDE
3159                if stride[i] is None and dynamic_strides[i] == DimDynamic.INFER_STRIDE and ex_stride[i] in candidates:
3160                    stride[i] = candidates[ex_stride[i]]
3161                    candidates[ex_size[i] * ex_stride[i]] = size[i] * stride[i]
3162
3163            if any(x is None for x in stride):
3164                # bind the smallest unbound stride to a new variable
3165                val, i = min(
3166                    [
3167                        (ex_stride[i], i)
3168                        for i in range(len(stride))
3169                        if stride[i] is None
3170                    ], key=_nested_int_aware_sort
3171                )
3172                # Set INFER_STRIDE to STATIC or DUCK depending on sizes
3173                dyn_stride = dynamic_strides[i]
3174                if dynamic_strides[i] == DimDynamic.INFER_STRIDE:
3175                    dyn_stride = DimDynamic.STATIC if are_sizes_static else DimDynamic.DUCK
3176                stride[i] = self.create_symbol(
3177                    val,
3178                    TensorPropertySource(source, TensorProperty.STRIDE, i),
3179                    dynamic_dim=dyn_stride,
3180                    constraint_dim=constraint_strides[i],
3181                    symbolic_context=symbolic_context,
3182                )
3183        assert all(x is not None for x in stride)
3184
3185        sym_sizes = [
3186            self.create_symintnode(
3187                sym,
3188                hint=hint,
3189                source=TensorPropertySource(source, TensorProperty.SIZE, i),
3190            )
3191            for i, (sym, hint) in enumerate(zip(size, ex_size))
3192        ]
3193        sym_stride = []
3194        for i, stride_expr in enumerate(stride):
3195            # NB: Don't duck size the stride; instead use the expression
3196            # we computed
3197            assert stride_expr is not None
3198            sym_stride.append(self.create_symintnode(
3199                stride_expr, hint=ex_stride[i], source=TensorPropertySource(source, TensorProperty.STRIDE, i)))
3200        sym_storage_offset = self.create_symintnode(
3201            self.create_symbol(
3202                ex_storage_offset,
3203                TensorPropertySource(source, TensorProperty.STORAGE_OFFSET),
3204                dynamic_dim=dynamic_offset,
3205                constraint_dim=None,
3206                symbolic_context=symbolic_context
3207            ),
3208            hint=ex_storage_offset,
3209            source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET))
3210        return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset
3211
3212    @record_shapeenv_event()
3213    def create_symintnode(
3214            self,
3215            sym: "sympy.Expr",
3216            *,
3217            hint: Optional[int],
3218            source: Optional[Source] = None,
3219    ):
3220        """Create a SymInt value from a symbolic expression
3221
3222        If you know the current hint value of the SymInt to be created,
3223        pass it in as hint.  Otherwise, pass None and we will make our
3224        best guess.
3225
3226        """
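        # A minimal usage sketch (assuming s0 is a symbol this ShapeEnv
        # already tracks with hint 4):
        #
        #   si = self.create_symintnode(s0, hint=4)
        #   isinstance(si, torch.SymInt)  # True
        #
        # If sym is a plain sympy.Integer, a real Python int is returned instead.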
3227        source_name = source.name() if source else None
3228
3229        if self._translation_validation_enabled and source is not None:
3230            # Create a new symbol for this source.
3231            symbol = self._create_symbol_for_source(source)
3232            assert symbol is not None
3233
3234            # Create a new FX placeholder and Z3 variable for 'symbol'.
3235            fx_node = self._create_fx_placeholder_and_z3var(symbol, int)
3236
3237            # Add an equality assertion for the newly created symbol and 'sym'.
3238            self._add_assertion(sympy.Eq(symbol, sym))
3239        else:
3240            fx_node = None
3241
3242        if isinstance(sym, sympy.Integer):
3243            if hint is not None:
3244                assert int(sym) == hint
3245            out = int(sym)
3246        else:
3247            # How can this occur? When we mark_unbacked, we end up with a real
3248            # tensor that has hints for all sizes, but we MUST NOT create a
3249            # SymNode with a hint, because we're hiding the hint from our eyes
3250            # with the unbacked Symbol.  And in fact, the hint compute may be
3251            # with the unbacked Symbol.  And in fact, the hint computation may be
3252            # inconsistent with size-oblivious tests.
3253                hint = None
3254            out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node))
3255        return out
3256
3257    @record_shapeenv_event()
3258    def create_symfloatnode(
3259            self,
3260            sym: "sympy.Expr",
3261            *,
3262            hint: Optional[float],
3263            source: Optional[Source] = None,
3264    ):
3265        """Create a SymFloat value from a symbolic expression"""
3266        source_name = source.name() if source else None
3267
3268        if self._translation_validation_enabled and source is not None:
3269            # Create a new symbol for this source.
3270            symbol = self._create_symbol_for_source(source)
3271            assert symbol is not None
3272
3273            # Create a new FX placeholder and Z3 variable for 'symbol'.
3274            fx_node = self._create_fx_placeholder_and_z3var(symbol, float)
3275
3276            # Add an equality assertion for the newly created symbol and 'sym'.
3277            self._add_assertion(sympy.Eq(symbol, sym))
3278        else:
3279            fx_node = None
3280
3281        if isinstance(sym, sympy.Float):
3282            if hint is not None:
3283                assert float(sym) == hint
3284            out = float(sym)
3285        else:
3286            # You could give this the same treatment as SymInt above if
3287            # you supported mark_unbacked on a float, but that would be a
3288            # strange thing to do, because floats don't get 0/1
3289            # specialization anyway
3290            if free_unbacked_symbols(sym):
3291                assert hint is None, sym
3292            out = SymFloat(SymNode(sym, self, float, hint, fx_node=fx_node))
3293        return out
3294
3295    @record_shapeenv_event()
3296    def create_unspecified_symint_and_symbol(self, value, source, dynamic_dim):
3297        """Create a SymInt wrapping a new unspecified symbol"""
3298        return self.create_symintnode(
3299            self.create_unspecified_symbol(
3300                value,
3301                source=source,
3302                dynamic_dim=dynamic_dim,
3303            ),
3304            hint=value,
3305            source=source,
3306        )
3307
3308    def create_symboolnode(self, sym: "sympy.Expr"):
3309        """Create a SymBool object from a sympy boolean expression"""
3310        # This function is only being used in serialization, so we do not track it
3311        # for validation.
3312        return SymBool(SymNode(sym, self, bool, None))
3313
3314    def _log_create_unbacked_symbol(self, prefix: str, symbol, vr: ValueRanges):
3315        is_debug = config.extended_debug_create_symbol is not None and str(symbol) in config.extended_debug_create_symbol.split(',')
3316        fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug)
3317        log.info(
3318            "%s %s [%s, %s]%s (%s)%s",
3319            prefix, symbol, vr.lower, vr.upper, maybe_user_loc, format_frame(fsummary), maybe_extra_debug, stack_info=is_debug
3320        )
3321
3322    @record_shapeenv_event()
3323    def create_unbacked_symfloat(self):
3324        """Create a symbolic float without a hint value
3325        """
3326        symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_FLOAT, next(self.unbacked_symfloat_counter))
3327        self.counter["create_unbacked_symbol"] += 1
3328        if not self._ignore_fresh_unbacked_symbols_tls():
3329            self.pending_fresh_unbacked_symbols.append(symbol)
3330        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
3331        vr = self.var_to_range[symbol] = ValueRanges.unknown()
3332        assert vr.is_float
3333
3334        # Create a new FX placeholder and Z3 variable for 'symbol'.
3335        fx_node = self._create_fx_placeholder_and_z3var(symbol, float)
3336
3337        self._log_create_unbacked_symbol("create_unbacked_symfloat", symbol, vr)
3338
3339        return SymFloat(SymNode(symbol, self, float, None, fx_node=fx_node))
3340
3341    @record_shapeenv_event()
3342    def create_unbacked_symint(self):
3343        """Create a symbolic integer without a hint value
3344        """
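        # This is how data-dependent values surface during tracing; a sketch
        # (meta/fake kernels for e.g. Tensor.item() or nonzero() end up here):
        #
        #   u0 = shape_env.create_unbacked_symint()  # u0, u1, ... naming
        #
        # where `shape_env` stands for this ShapeEnv (hypothetical binding).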
3345        symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True)
3346        if not self._ignore_fresh_unbacked_symbols_tls():
3347            self.pending_fresh_unbacked_symbols.append(symbol)
3348        self.counter["create_unbacked_symbol"] += 1
3349        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
3350        vr = self.var_to_range[symbol] = self._default_unspecified_value_range()
3351        assert vr.is_int
3352
3353        # Create a new FX placeholder and Z3 variable for 'symbol'.
3354        fx_node = self._create_fx_placeholder_and_z3var(symbol, int)
3355
3356        self._log_create_unbacked_symbol("create_unbacked_symint", symbol, vr)
3357
3358        return SymInt(SymNode(symbol, self, int, None, fx_node=fx_node))
3359
3360    def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool:
3361        """Check if a sympy symbol matches the naming convention for unbacked symbols
3362        """
3363        return symbol_is_type(symbol, SymT.UNBACKED_INT)
3364
3365    @record_shapeenv_event()
3366    def create_unbacked_symbool(self):
3367        """Create a symbolic boolean without a hint value
3368        """
3369        symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True)
3370        if not self._ignore_fresh_unbacked_symbols_tls():
3371            self.pending_fresh_unbacked_symbols.append(symbol)
3372        self.counter["create_unbacked_symbol"] += 1
3373        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
3374        vr = self.var_to_range[symbol] = ValueRanges(0, 1)
3375        assert vr.is_int
3376
3377        # Create a new FX placeholder and Z3 variable for 'symbol'.
3378        fx_node = self._create_fx_placeholder_and_z3var(symbol, bool)
3379
3380        self._log_create_unbacked_symbol("create_unbacked_symbool", symbol, vr)
3381
3382        return SymBool(SymNode(sympy.Eq(symbol, 1), self, bool, None, fx_node=fx_node))
3383
3384    @record_shapeenv_event()
3385    def create_unspecified_symbol(
3386        self,
3387        val: Union[int, SymInt, float, SymFloat],
3388        source: Source,
3389        dynamic_dim: DimDynamic = DimDynamic.DUCK,
3390        constraint_dim: DimConstraint = None,  # NB: includes None
3391    ) -> "sympy.Expr":
3392        """Create a symbol with an unspecified value
3393
3394        Compared to standard symbols we do not assume the value is positive,
3395        nor do we specialize on values of zero or one.
3396        """
3397        # 'positive' is None for unspecified symbols, since we can't
3398        # assume that it will be either positive or negative.
3399
3400        # We don't want to specialize on zero/one values for an unspecified
3401        # symbol, so that we always get a fresh symbol regardless of val.
3402        return self.create_symbol(
3403            val,
3404            source,
3405            dynamic_dim,
3406            constraint_dim,
3407            positive=None,
3408            do_not_specialize_zero_one=True,
3409            symbolic_context=None)
3410
3411    @record_shapeenv_event()
3412    def create_symbol(
3413        self,
3414        val: int,
3415        source: Source,
3416        dynamic_dim: DimDynamic = DimDynamic.DUCK,
3417        constraint_dim: DimConstraint = None,  # NB: includes None
3418        positive: Optional[bool] = True,
3419        do_not_specialize_zero_one: bool = False,
3420        symbolic_context=None,
3421    ) -> "sympy.Expr":
3422        """Create a new symbol which is tracked by this ShapeEnv
3423        """
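        # An illustrative sketch (symbol names hypothetical): fakifying a
        # tensor of size (3, 3) with DimDynamic.DUCK on both dims calls
        # create_symbol twice with val=3; the first call allocates s0, and
        # the second call duck-sizes to the same s0, because val_to_var
        # already maps 3 -> s0.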
3424        # check if constraint_dim is actually a static integer
3425        if isinstance(constraint_dim, StrictMinMaxConstraint) and constraint_dim.vr.lower == constraint_dim.vr.upper:
3426            dynamic_dim = DimDynamic.STATIC
3427            if constraint_dim.vr.lower != val:
3428                raise ConstraintViolationError(
3429                    f"Static shape constraint of {constraint_dim.vr.lower} does not match input size of {val}, "
3430                    f"for {source.name()}"
3431                )
3432            if symbolic_context:
3433                symbolic_context.dynamic_sizes[source.idx] = dynamic_dim
3434                symbolic_context.constraint_sizes[source.idx] = None
3435            constraint_dim = None
3436
3437        # see note [Tensor Fakification and Symbol Caching]
3438        source_name = source.name()
3439        if (isinstance(symbolic_context, StatefulSymbolicContext)
3440                and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache):
3441            symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] = {}
3442
3443        if (isinstance(symbolic_context, StatefulSymbolicContext)
3444                and source_name
3445                and (source_name in symbolic_context.shape_env_to_source_to_symbol_cache[id(self)])):
3446            return symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name]
3447
3448        if dynamic_dim is DimDynamic.SIZE_LIKE_UNBACKED:
3449            out = self.create_unbacked_symint().node.expr
3450            self._constrain_range_for_size(out)
3451            # TODO: maybe put the hint somewhere
3452            if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
3453                symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = out
3454            return out
3455
3456        if do_not_specialize_zero_one:
3457            specialize_zero_one = False
3458        else:
3459            specialize_zero_one = self.specialize_zero_one
3460
3461        assert isinstance(source, Source), f"{type(source)} {source}"
3462        assert not (positive and val < 0), f"positive set for negative value: {val}"
3463        # It's always sound to allocate a symbol as DYNAMIC.  If the user
3464        # constrained the symbol, force the symbolic_context to DYNAMIC, because our
3465        # constraint code will do weird stuff if, e.g., it's duck shaped
3466        if constraint_dim is not None:
3467            dynamic_dim = DimDynamic.DYNAMIC
3468
3469        if dynamic_dim is DimDynamic.STATIC:
3470            out = sympy.Integer(val)
3471            if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
3472                symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = out
3473            return out
3474
3475        elif dynamic_dim is DimDynamic.DUCK:
3476            # duck_shape can be used to globally turn off duck shaping, even
3477            # if it was requested
3478            duck = self.duck_shape
3479        elif dynamic_dim is DimDynamic.DYNAMIC:
3480            duck = False
3481        else:
3482            raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}")
3483
3484        if val in (0, 1) and specialize_zero_one:
3485            r = self.val_to_var[val]
3486        elif not duck or val not in self.val_to_var:
3487            # If we're not duck shaping, we always create a new symbol
3488            # Even if we're duck shaping, if we haven't seen this particular
3489            # value before, we also create a new symbol
3490            if type(val) is int or is_nested_int(val):
3491                sympy_expr = make_symbol(SymT.SIZE, len(self.var_to_val), positive=positive, integer=True)
3492            else:
3493                sympy_expr = make_symbol(SymT.FLOAT, len(self.var_to_val), positive=positive, real=True)
3494            # We always associate vars to vals
3495            if isinstance(val, int):
3496                self.var_to_val[sympy_expr] = sympy.Integer(val)
3497            elif isinstance(val, float):
3498                self.var_to_val[sympy_expr] = sympy.Float(val)
3499            else:
3500                # Only used for jagged layout nested tensors
3501                self.var_to_val[sympy_expr] = SingletonInt(val.node.nested_int(), coeff=val.node.nested_int_coeff())
3502
3503            # The actual source is appended later; we always want this entry populated
3504            self.var_to_sources[sympy_expr] = []
3505            # Create a Z3 variable for the new symbol.
3506            self._add_z3var(sympy_expr, int)
3507
3508            if duck:
3509                # Make sure to reuse this symbol for subsequent duck shaping
3510                self.val_to_var[val] = sympy_expr
3511
3512            if isinstance(val, int):
3513                if positive:
3514                    # Add assertions for the newly created symbols
3515                    self._add_assertion(sympy_expr > 1)
3516
3517                    # Apply default range, which assumes not zero-one
3518                    self.var_to_range[sympy_expr] = self._default_value_range()
3519                else:
3520                    self.var_to_range[sympy_expr] = self._default_unspecified_value_range()
3521
3522                # Small performance optimization: if we have a min-max constraint,
3523                # we can proactively narrow to that range
3524                if isinstance(constraint_dim, StrictMinMaxConstraint):
3525                    assert not duck
3526                    self.var_to_range[sympy_expr] &= constraint_dim.vr
3527
3528                vr = self.var_to_range[sympy_expr]
3529                assert vr.is_int
3530
3531                if val not in vr:
3532                    raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]")
3533
3534                range_str = f"[{vr.lower}, {vr.upper}]"
3535            elif isinstance(val, float):
3536                self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo)
3537                range_str = f"[{vr.lower}, {vr.upper}]"
3538                assert vr.is_float
3539            else:
3540                # Skip var_range logic for SingletonInt
3541                # Only used for jagged layout nested tensors
3542                range_str = ""
3543
3544            r = sympy_expr
3545
3546            is_debug = (
3547                config.extended_debug_create_symbol is not None and
3548                str(sympy_expr) in config.extended_debug_create_symbol.split(',')
3549            )
3550            maybe_more_info = ""
3551            if not is_debug:
3552                maybe_more_info = (
3553                    ", for more info run with "
3554                    f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{sympy_expr}"'
3555                )
3556            fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug)
3557            self.log.info(
3558                "create_symbol %s = %s for %s %s%s (%s)%s%s",
3559                sympy_expr, val, source.name(), range_str,
3560                maybe_user_loc, format_frame(fsummary), maybe_more_info, maybe_extra_debug, stack_info=is_debug
3561            )
3562
3563            self.counter["create_symbol"] += 1
3564        else:
3565            # This implements duck-shaping: input sizes that match are assigned
3566            # the same symint
3567            r = self.val_to_var[val]
3568            self.log.debug("create_symbol %s duck sized %s", r, source.name())
3569
3570        if isinstance(r, sympy.Symbol):
3571            r_sources = self.var_to_sources[r]
3572            r_sources.append(source)
3573            if not source.is_ephemeral() and r_sources[0].is_ephemeral():
3574                # prefer non-ephemeral source first since it may be guarded on later
3575                r_sources[0], r_sources[-1] = r_sources[-1], r_sources[0]
3576
3577            # This ensures we get zeros in symbol_guard_counter, which makes
3578            # some queries simpler (since we will accumulate mass on 0 this
3579            # way)
3580            self.symbol_guard_counter[r] = 0
3581
3582        if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
3583            symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = r
3584        return r
3585
3586    def add_var_to_val(self, expr: sympy.Symbol, val: int):
3587        """ Adds a new symbol to the symbolic environment. """
3588        log.debug("add_var_to_val %s %s", expr, val, stack_info=True)
3589        assert expr not in self.var_to_val, f"{expr} already exists"
3590        self.var_to_val[expr] = sympy.Integer(val)
3591
3592    def _debug_name(self, source):
3593        src_name = source.name()
3594        return self.source_name_to_debug_name.get(src_name, src_name)
3595
3596    def _render_range_for_constraint_violation(self, source, c):
3597        if isinstance(c, StrictMinMaxConstraint):
3598            lower, upper = c.vr.lower, c.vr.upper
3599            default = self._default_value_range()
3600            if lower <= default.lower:
3601                lower = None
3602            if upper >= default.upper:
3603                upper = None
3604            c_render = f"{self._debug_name(source)} = {source.name()} in the specified range"
3605            if lower is not None and upper is not None:
3606                c_render += f" {lower} <= {self._debug_name(source)} <= {upper}"
3607            elif lower is None and upper is not None:
3608                c_render += f" {self._debug_name(source)} <= {upper}"
3609            elif lower is not None and upper is None:
3610                c_render += f" {lower} <= {self._debug_name(source)}"
3611            return c_render
3612        return c.render(source)
3613
3614    def produce_guards(
3615        self,
3616        placeholders,
3617        sources,
3618        source_ref=lambda n: n.name(),
3619        *,
3620        guards: Optional[List[ShapeGuard]] = None,
3621        input_contexts: Optional[DimList[SymbolicContext]] = None,
3622        # Encodes user-specified input shape equations of the form s = s' and s = fn(s').
3623        # (See docs on EqualityConstraint for details of the encoding.)
3624        equalities_inputs: Optional[EqualityConstraint] = None,
3625        _simplified=False,
3626        # Indicates if we should produce guards for known static values.
3627        ignore_static=True,
3628    ) -> List[str]:
        """
        Generates a list of guard strings which, when evaluated in a context that
        defines tensors for all the sources, return True or False depending
        on whether the guards in the list evaluated to True or not.  Primarily
        used by Dynamo, but this is also helpful for manual testing of guards
        (see evaluate_guards_for_args).

        For convenience in testing, a source is allowed to be a str,
        in which case we will assume it is a LocalSource.

        _simplified lets you omit duck sizing, equality and 0/1 guards.
        This is useful for testing when you don't care about the boilerplate
        guards, and it may be helpful for user output too (be careful though;
        some equality guards are nontrivial!  It would be nice to get simplified
        output to print them too).  It's private because it's not
        intended for normal use.
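
        For example (illustrative; the exact strings depend on the ShapeEnv
        state), if dim 0 of two inputs was duck sized to the same s0 with
        range [2, 4096], the result may look like::

            ["L['x'].size()[0] == L['y'].size()[0]",
             "2 <= L['x'].size()[0] <= 4096"]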
        """
        self.log.info("produce_guards")

        # Check if we get to the same ShapeEnv state by replaying the recorded events.
        # This will create a new ShapeEnv instance, and call all recorded function
        # calls on this new instance. Finally, it will check whether this new instance
        # has equal state.
        #
        # It's important that we do this at the beginning of this function, since
        # produce_guards modifies self.dim_constraints through its execution, and
        # changes that happen in this method aren't interesting: this is the very
        # function call we wish to reproduce at the end. If we wished to reproduce
        # ShapeEnv instances even after this call, this method would also have to
        # be recorded.
        if self.check_recorded_events:
            shape_env = replay_shape_env_events(self.events)
            self.check_equal(shape_env)

        assert len(placeholders) == len(sources), (
            f"len(placeholders) ({len(placeholders)}) != len(sources) ({len(sources)})"
        )
        Tensorlike = (torch.Tensor, FakeTensorMeta)

        def _create_no_constraints_context(t):
            return StatelessSymbolicContext(
                # Ignored; only the constraints part is relevant below.
                dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim(),
                dynamic_strides=[DimDynamic.INFER_STRIDE] * t.dim(),
                constraint_sizes=[None] * t.dim(),
                constraint_strides=[None] * t.dim()
            )

        # Expand optional inputs, or verify invariants are upheld
        if input_contexts is None:
            input_contexts = [
                _create_no_constraints_context(t) if isinstance(t, Tensorlike)
                else None for t in placeholders
            ]
        else:
            assert len(input_contexts) == len(placeholders)
            for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
                if isinstance(t, Tensorlike):
                    if context is None:
                        input_contexts[i] = _create_no_constraints_context(t)
                else:
                    assert isinstance(t, (SymInt, int, SymFloat, float))
                    assert not isinstance(context, list)

        # It took a lot of sweat to figure out the algorithm here.  Let's
        # explain how it works.
        #
        # The ShapeEnv lifecycle looks something like this:
        #
        # - For each input, you either generate a fresh Sympy symbol (s0) to
        #   represent its value (a binding site), or you reuse some
        #   preexisting symbol or expression, skipping the symbol allocation
        #   (e.g., duck sizing to a preexisting symbol, or expressing a
        #   stride as a multiplication of a separate stride and size.)
        #   Naively, you might expect to bind a fresh Sympy symbol for
        #   every input, but this is fairly wasteful as most of these
        #   symbols immediately simplify away, and if you don't eagerly
        #   specialize, e.g., 0/1 symbols, you end up with very complicated
        #   expressions that are not optimizable in practice.
        #
        # - You perform some compute on these symbols, occasionally
        #   introducing guards on boolean expressions on these symbols.
        #   In particular, whenever we guard on equality (_maybe_guard_rel),
        #   we can simplify shapes; e.g., when s0 == s1 * 2, we can now
        #   replace all occurrences of s0 with s1 * 2.  Sometimes, a
        #   boolean expression evaluation doesn't introduce a guard, as
        #   the guard is already entailed by the simplifications we have
        #   applied.
        #
        # - In the end, you have a bunch of replacements (saying how to
        #   simplify shapes) and a bunch of guards (all the equality guards
        #   are trivial, because they're covered by the replacements).
        #
        # From the ShapeEnv, we must generate a Python expression that, when
        # evaluated on a set of inputs, tells us whether or not these boolean
        # expressions would have evaluated in the same way.  However,
        # we cannot easily compute this, as we elide recording boolean
        # expressions when we think they are vacuously true.  Thus, we seek
        # an approximation: we must generate an expression that, if true,
        # would have produced an "equivalent" ShapeEnv, one that would answer
        # guard expressions in the same way.
        #
        # Our notion of equivalence is a bit subtle.  For example, consider
        # the ShapeEnv created from an input of size (5, 4) versus (4, 4)
        # (no other guards.)  Duck sizing would generate (s0, s1) in the first
        # case but (s0, s0) in the second.  We do NOT assume that size
        # variables are disjoint; so in fact a graph that assumes the input
        # could be (s0, s1) subsumes (s0, s0) (setting s0 == s1), but not
        # vice versa.  However, consider an analogous case (1,) versus (2,).
        # Duck sizing generates (1,) and (s0,); the (s0,) graph does NOT
        # subsume the (1,) graph because we assume that any size variable
        # is NOT 0/1 (and we make simplifications according to this; e.g., if
        # we queried s0 == 0, we would immediately return False without
        # issuing a guard.)
        #
        # So, it is perhaps easier to flip things on their head: the guard
        # expressions we generate here say what simplifications are valid,
        # and what are not.  Below, we explain each of the guard expressions
        # we generate.

        # TODO: Make this more efficient by binding all the size/stride/offsets
        # to locals before performing tests on them.

        from torch._dynamo.source import TensorPropertySource, TensorProperty

        # Actual codegen must be delayed as we don't necessarily know what
        # the symbol mapping is
        input_guards = []

        symbol_to_source = collections.defaultdict(list)
        symbol_to_constraints = collections.defaultdict(set)
        constraint_violations: List[Tuple[bool, str, Callable[[], str]]] = []

        def record_constraint_violation(warn_only, debug_name, msg, hint=None):
            constraint_violations.append(
                (warn_only, debug_name, lambda: f"{msg}{hint()}" if hint else msg)
            )

        def is_dim(src):
            return isinstance(src, TensorPropertySource) and src.prop is TensorProperty.SIZE

        if equalities_inputs:
            source_index = {}
            for i, src in enumerate(sources):
                source_index[src.name()] = i

            def get_expression(tensor_dim_src):
                fake = placeholders[source_index[tensor_dim_src.base.name()]]
                symint = fake.shape[tensor_dim_src.idx]
                if isinstance(symint, torch.SymInt):
                    return symint.node.expr
                else:
                    assert type(symint) is int, f"Expected int, got {type(symint)}"
                    return symint

            for src1, src2 in equalities_inputs.source_pairs:
                expr1, expr2 = get_expression(src1), get_expression(src2)
                # Check whether the given input shape values satisfy a specified equation s = s'.
                # - Raise when the equation was violated by the given input shape values.
                # - Otherwise issue a guard to constrain them.
                concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2))
                if not concrete_val:
                    raise ConstraintViolationError(
                        f"{src1.name()} = {expr1 if isinstance(expr1, int) else expr1.xreplace(self.var_to_val)}"
                        " is not equal to "
                        f"{src2.name()} = {expr2 if isinstance(expr2, int) else expr2.xreplace(self.var_to_val)}"
                    )

            for src, root, fn in equalities_inputs.derived_equalities:
                expr1 = get_expression(src)
                # Recall that root is either a phantom symbol or an input source.
                expr2, debug_name = (
                    (root, self.var_to_sources[root][0].name()) if isinstance(root, sympy.Symbol)
                    else (get_expression(root), self._debug_name(root))
                )
                expr2_ = fn(expr2)
                # Check whether the given input shape values satisfy a specified equation s = fn(s').
                # - Raise when the equation was violated by the given input shape values.
                # - Otherwise issue a guard to constrain them.
                concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2_))
                if not concrete_val:
                    raise ConstraintViolationError(
                        f"Expected input {src.name()} to be equal to "
                        f"{fn(sympy.Symbol(debug_name))}, "
                        f"where {debug_name} = {expr2.xreplace(self.var_to_val)}, "
                        f"but got {expr1.xreplace(self.var_to_val)}"
                    )

            for phantom_symbol in equalities_inputs.phantom_symbols:
                # We created additional phantom symbols that are not input shape dimensions.
                symbol_to_source[phantom_symbol].extend(self.var_to_sources[phantom_symbol])

        # How do we know what the value of s0 is?  Fresh variables can only be
        # bound by inputs, so there MUST be some other input which binds the
        # variable.  If there is no such input, this is an error in our
        # system.  We record where all symbols come from, to help you diagnose
        # why those symbols didn't occur.
        #
        # In fact, generally speaking it is only possible for the "outermost"
        # user of a ShapeEnv to evaluate the guards, because some inputs may
        # not be available to inner levels.  For example, Dynamo can guard on
        # tensors that never actually become graph arguments (they are
        # pruned).  In this case, only Dynamo knows about these arguments.
        def track_symint(source, val, constraint=None):
            log.debug("track_symint %s %s %s", LazyString(source.name), val, constraint)
            assert not isinstance(val, SymInt) or is_symbolic(val)

            if isinstance(val, SymInt) and val.node.maybe_as_int() is not None:
                val = val.node.maybe_as_int()

            if isinstance(val, SymInt):
                s = val.node.expr
                if isinstance(s, sympy.Symbol):
                    symbol_to_source[s].append(source)
                    if (
                        constraint is not None
                        and not isinstance(constraint, RelaxedUnspecConstraint)
                    ):
                        symbol_to_constraints[s].add(constraint)
                else:
                    constraint_violated = False
                    if isinstance(constraint, StrictMinMaxConstraint):
                        # Try inferring the ranges of the expr s.
                        sym_vrs = {x: self.var_to_range.get(x, None) for x in s.free_symbols}
                        if any(vr is None for vr in sym_vrs.values()):
                            # Some of the free symbols in s don't have ranges.
                            constraint_violated = True
                    elif isinstance(constraint, RelaxedUnspecConstraint):
                        if s.is_number:
                            i = int(s)
                            # Don't complain about 0/1 specialization; we
                            # expect to have to compile in this case anyway.
                            if i not in (0, 1):
                                constraint_violated = True
                    if constraint_violated:
                        def hint(s):
                            sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(s)
                            return f"{sexpr}."

                        var_with_range = self._render_range_for_constraint_violation(source, constraint)
                        msg = (
                            f"Not all values of {var_with_range} are valid because "
                            f"{self._debug_name(source)} was inferred to be equal to "
                        )
                        record_constraint_violation(
                            constraint.warn_only,
                            self._debug_name(source),
                            msg,
                            hint=functools.partial(hint, s),
                        )

                input_guards.append((source, s))
            else:
                s = sympy.Integer(val)
                input_guards.append((source, s))
                constraint_violated = False
                if isinstance(constraint, StrictMinMaxConstraint):
                    if not (s == constraint.vr.lower == constraint.vr.upper):  # allow static constraints
                        constraint_violated = True
                elif isinstance(constraint, RelaxedUnspecConstraint):
                    # Don't complain about 0/1 specialization; we
                    # expect to have to compile in this case anyway.
                    if val not in (0, 1):
                        constraint_violated = True
                if constraint_violated:
                    var_with_range = self._render_range_for_constraint_violation(source, constraint)
                    msg = (
                        f"Not all values of {var_with_range} are valid because "
                        f"{self._debug_name(source)} was inferred to be a constant ({val})."
                    )
                    record_constraint_violation(constraint.warn_only, self._debug_name(source), msg)

        def track_symfloat(source, val):
            log.debug("track_symfloat %s %s", LazyString(source.name), val)
            assert not isinstance(val, SymFloat) or is_symbolic(val)

            if isinstance(val, SymFloat) and val.node.maybe_as_float() is not None:
                val = val.node.maybe_as_float()

            if isinstance(val, SymFloat):
                s = val.node.expr
                if isinstance(s, sympy.Symbol):
                    symbol_to_source[s].append(source)
                input_guards.append((source, s))
            else:
                s = sympy.Float(val)
                input_guards.append((source, s))

        for t, source, context in zip(placeholders, sources, input_contexts):
            if isinstance(source, str):
                from torch._dynamo.source import LocalSource
                source = LocalSource(source)
            assert isinstance(source, Source)
            if t is None:
                continue
            if isinstance(t, (SymInt, int)):
                track_symint(source, t)
                continue
            elif isinstance(t, (SymFloat, float)):
                track_symfloat(source, t)
                continue
            assert isinstance(t, Tensorlike)
            if is_traceable_wrapper_subclass(t):
                from torch._dynamo.source import AttrSource

                assert isinstance(context, SubclassSymbolicContext)

                # For subclasses, we need to track symints on BOTH the outer
                # and inner tensors.
                sources_tensors_constraints = [
                    (source, t, context.constraint_sizes, context.constraint_strides)
                ]
                attrs, _ = t.__tensor_flatten__()
                for attr in attrs:
                    inner_t = getattr(t, attr)
                    inner_context = context.inner_contexts[attr]
                    sources_tensors_constraints.append((
                        AttrSource(source, attr),
                        inner_t,
                        inner_context.constraint_sizes,
                        inner_context.constraint_strides
                    ))
            else:
                sources_tensors_constraints = [(source, t, context.constraint_sizes, context.constraint_strides)]

            for src, curr_t, constraint_size, constraint_stride in sources_tensors_constraints:
                for i, ss in enumerate(curr_t.size()):
                    property_source = TensorPropertySource(src, TensorProperty.SIZE, i)
                    track_symint(property_source, ss, constraint_size[i])
                if not is_sparse_any(curr_t):
                    for i, ss in enumerate(curr_t.stride()):
                        property_source = TensorPropertySource(src, TensorProperty.STRIDE, i)
                        track_symint(property_source, ss, constraint_stride[i])
                    track_symint(TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), curr_t.storage_offset())

        # 1. Every input must equal the final simplified symbolic expression
        #    stored on the placeholder.  Given a placeholder (s0*2, s1),
        #    if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3.
        #    This does a lot of work: it covers duck sizing and equality guards.
        exprs = []
        self.dim_constraints = DimConstraints(
            symbol_to_source,
            self.var_to_val,
            set(symbol_to_constraints.keys()),
            self.source_name_to_debug_name,
        )

        if not _simplified:
            for source, expr in input_guards:
                if self._translation_validation_enabled:
                    # Ignore sources that were not turned into SymInts.
                    srcname = source.name()
                    if srcname in self.source_to_symbol:
                        self._add_target_expr(sympy.Eq(self.source_to_symbol[srcname], expr))

                # Small optimization
                if (
                    isinstance(expr, sympy.Symbol) and
                    symbol_to_source.get(expr) and
                    source == symbol_to_source[expr][0]
                ):
                    continue

                # This logic excludes static values found on tensors from guarding, because
                # dynamo's check_tensor_fn does that (see guards.cpp).
                # However, for non-tensor sources, we still need to guard here.
                if ignore_static and isinstance(source, TensorPropertySource):
                    if expr.is_number:
                        self.log.debug("Skipping guard %s", f"{source_ref(source)} == {expr}")
                        continue

                if is_dim(source):
                    self.dim_constraints.add_equality(source, expr)

                sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
                exprs.append(f"{source_ref(source)} == {sexpr}")
                if (
                    isinstance(source, TensorPropertySource)
                    and source.prop is TensorProperty.SIZE
                    and equalities_inputs
                    and len(expr.free_symbols) == 1
                ):
                    symbol = next(iter(expr.free_symbols))
                    if (
                        isinstance(expr, sympy.Symbol) and
                        expr in symbol_to_constraints and
                        not equalities_inputs.is_equal(source, symbol_to_source[expr][0])
                    ):
                        msg = (
                            f"The values of {self._debug_name(source)} = {source.name()} and "
                            f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name()} "
                            "must always be equal."
                        )
                        record_constraint_violation(equalities_inputs.warn_only, self._debug_name(source), msg)

                    if (
                        not isinstance(expr, sympy.Symbol) and
                        symbol in symbol_to_constraints and
                        not equalities_inputs.is_derived(source, symbol_to_source[symbol][0], lambda x: expr.xreplace({symbol: x}))
                    ):
                        src = symbol_to_source[symbol][0]
                        msg = (
                            f"The values of {self._debug_name(source)} = {source.name()} must always be related to "
                            f"the values of {self._debug_name(src)} = {src.name()} by "
                            f"{self._debug_name(source)} = {expr.xreplace({symbol: sympy.sympify(self._debug_name(src))})}."
                        )
                        record_constraint_violation(equalities_inputs.warn_only, self._debug_name(source), msg)

                # NB: It is not necessary to report constraint violations here:
                # constraints are guaranteed to be on symbols (we've already
                # caught constants and non-atomic expressions), so we only
                # have relational constraints, but we don't support those
                # at the moment.

        # 2. Every guard must evaluate to True (but remember, many guards
        #    like s0 == s1*2 became trivial due to simplification).
        issued = set()

        def issue_guard(guard: ShapeGuard) -> None:
            expr = self.simplify(guard.expr)

            # Avoid re-issuing the same guard.
            if expr in issued:
                return

            issued.add(expr)

            try:
                is_trivial = False
                if any(is_dim(source) for s in expr.free_symbols for source in symbol_to_source[s]):
                    is_trivial = self.dim_constraints.add(expr)
                guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
                exprs.append(guard_expr)
                self._add_target_expr(expr)
                # A guard mentioning a single sizevar can violate one of its
                # (non-relational) constraints.
                if not is_trivial and len(expr.free_symbols) == 1:
                    symbol = next(iter(expr.free_symbols))
                    source = symbol_to_source[symbol][0]
                    constraints = symbol_to_constraints[symbol]
                    for c in constraints:
                        if isinstance(c, StrictMinMaxConstraint):
                            var_with_range = self._render_range_for_constraint_violation(source, c)
                            msg = (
                                f"Not all values of {var_with_range} "
                                f"satisfy the generated guard {guard_expr}."
                            )
                            record_constraint_violation(c.warn_only, self._debug_name(source), msg)
                        elif isinstance(c, RelaxedUnspecConstraint):
                            # This is fine; we allow guards here as long as
                            # they didn't constrain the symbol to one value
                            # (we don't actually know this; it depends on our
                            # ValueRanges reasoning capability).
                            pass
                        else:
                            raise AssertionError(f"unrecognized constraint {c}")
            except Exception:
                self.log.warning("Failing guard allocated at: \n%s", ''.join(guard.stack.format()))
                raise

        # First, issue all guards.
        # This removes all the checks that follow from bounds.
        # We could simply emit those checks, and also the bounds 2 <= size
        # when necessary.
        for guard in (guards if guards is not None else self.guards):
            if self._maybe_evaluate_static(guard.expr, axioms=()) is not None:
                continue
            issue_guard(guard)

        # There are guards that export's constraint solver can suggest good fixes for
        # but that we may have deferred as runtime asserts, and that produce_guards()
        # alone won't do anything with (e.g. divisibility guards).  So we send runtime
        # asserts to export's constraint solver too.  These still stay in the graph as
        # asserts, but export's constraint solver can decide whether to do anything
        # with them (i.e. raise an error and provide suggested fixes, or decide it's
        # out of scope and leave them as runtime asserts in the graph).
        for ra in self.deferred_runtime_asserts.get(None, []):
            if self._maybe_evaluate_static(ra.expr, axioms=()) is not None:
                continue
            expr = self.simplify(ra.expr)
            self.dim_constraints.add(expr)

        # 3. Every symbol must be within its value range (this handles 0/1
        #    specialization too).
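        # E.g. (illustrative): a symbol s0 with range [2, 4096] bound to
        # L['x'].size()[0] contributes the guard string
        # "2 <= L['x'].size()[0] <= 4096" via the bounds logic below.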
        for symbol, sources in symbol_to_source.items():
            r = self.var_to_range.get(symbol)
            if r is None:
                continue

            assert sources
            bounds = []
            if r.lower not in (-sympy.oo, -int_oo):
                if any(is_dim(source) for source in sources):
                    self.dim_constraints.add(sympy.Ge(symbol, r.lower))
                # Only print the lower bound in simplified mode if it is not
                # the default.
                if not _simplified or r.lower != self._default_value_range().lower:
                    bounds.append(str(r.lower))
            bounds.append(source_ref(sources[0]))
            if r.upper not in (sympy.oo, int_oo):
                if any(is_dim(source) for source in sources):
                    self.dim_constraints.add(sympy.Le(symbol, r.upper))
                # A nontrivial upper bound is always interesting.
                bounds.append(str(r.upper))
            if len(bounds) > 1:
                exprs.append(" <= ".join(bounds))

                # Check constraints
                constraints = symbol_to_constraints[symbol]
                for c in constraints:
                    if isinstance(c, StrictMinMaxConstraint):
                        # TODO: With int_oo, I think this condition is a noop
                        # now
                        if not (c.vr & self._default_value_range()).issubset(r):
                            source = sources[0]

                            expr = sympy.And(sympy.Le(r.lower, symbol), sympy.Le(symbol, r.upper))
                            guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
                            var_with_range = self._render_range_for_constraint_violation(source, c)
                            msg = (
                                f"Not all values of {var_with_range} satisfy the generated guard {guard_expr}"
                            )
                            record_constraint_violation(
                                c.warn_only,
                                self._debug_name(source),
                                msg,
                            )
            # We NaN specialize, which means similar to 0/1 specialization we
            # should assume that the float is NOT nan.  This is load bearing:
            # if you have something like an equality guard, NaN will play
            # merry hell with the reasoning.
            if symbol_is_type(symbol, SymT.FLOAT):
                exprs.append(f"not __math_isnan({source_ref(sources[0])})")

        if constraint_violations:
            warn_msgs = []
            error_msgs = []
            debug_names = set()
            for warn_only, debug_name, msg in constraint_violations:
                if warn_only:
                    msg = f"  {len(warn_msgs) + 1}. {msg()}"
                    warn_msgs.append(msg)
                else:
                    msg = f"  - {msg()}"
                    error_msgs.append(msg)
                    debug_names.add(debug_name)
            if len(error_msgs) > 0:
                debug_names = ', '.join(sorted(debug_names))
                err = '\n'.join(error_msgs)
                raise ConstraintViolationError(
                    f"Constraints violated ({debug_names})! "
                    'For more information, run with TORCH_LOGS="+dynamic".\n'
                    f"{err}"
                )
            elif len(warn_msgs) > 0:
                log.debug("%s warn-only constraint violations", len(warn_msgs))

        signpost_event(
            "dynamic",
            "produce_guards",
            {
                **self.co_fields,
                **self.counter,
                "num_guards": len(exprs),
                "free_symbols": sum(1 for v in symbol_to_source.values() if v),
                # The keys are meaningless from an aggregate perspective, so
                # don't include them.  Biggest first.
                "symbol_guard_counts": sorted(self.symbol_guard_counter.values(), reverse=True),
            },
        )

        if self._translation_validation_enabled:
            from torch.fx.experimental.validator import PopulateValidator

            # Add all deferred runtime assertions; these are not technically
            # handled by produce_guards but we need to put them in the target
            # set.
            for ras in self.deferred_runtime_asserts.values():
                for ra in ras:
                    self._add_target_expr(ra.expr)

            # Add value range bound guards for all symbols with no trivial bounds.
            # Reason: '_maybe_evaluate_static' may eliminate guards based on the
            # refined value ranges.
            for sym, vr in self.var_to_range.items():
                if vr.lower not in (-sympy.oo, -int_oo):
                    self._add_target_expr(sympy.Le(vr.lower, sym))
                if vr.upper not in (sympy.oo, int_oo):
                    self._add_target_expr(sympy.Le(sym, vr.upper))

            # Before validating, populate the input of the validator with the
            # built FX graph.
            with fx_traceback.preserve_node_meta():
                PopulateValidator(self.graph, self.validator).run()

        # Only run translation validation when we are not passing custom guards.
        if guards is None:
            self._check_translation_validate()
        return exprs

    def produce_guards_expression(
        self,
        placeholders,
        *,
        guards: Optional[List[ShapeGuard]] = None,
        ignore_static=True
    ):
        """
        Expected to be used with evaluate_guards_expression(). Produces the guards
        for the given placeholders and returns a string expression to be evaluated
        by evaluate_guards_expression given concrete values for the placeholders.
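
        Example (illustrative; the exact guard string depends on the ShapeEnv
        state; placeholders are bound to names t0, t1, ...)::

            code = shape_env.produce_guards_expression([fake_x])
            # code might look like "L['t0'].size()[0] == L['t0'].size()[1]"
            shape_env.evaluate_guards_expression(code, [torch.randn(3, 3)])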
        """
        from torch._dynamo.source import LocalSource
        arg_names = [f"t{i}" for i in range(len(placeholders))]
        produced_guards = self.produce_guards(
            placeholders,
            [LocalSource(a) for a in arg_names],
            guards=guards,
            ignore_static=ignore_static,
        )
        if produced_guards:
            return " and ".join(produced_guards)
        return None

    def evaluate_symexpr(self, code):
        """
        To be used by compile_fx to evaluate symexprs.
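
        Example (illustrative)::

            # With var_to_val = {s0: 3}, this evaluates to 6:
            shape_env.evaluate_symexpr("s0 * 2")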
        """
        args = {str(e): val for e, val in self.var_to_val.items()}
        return eval(code, SYMPY_INTERP, args)

    def evaluate_guards_expression(self, code, args):
        """
        Expected to be used with produce_guards_expression(). Evaluates an expression
        generated by produce_guards_expression for the given concrete args.
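
        Example (illustrative)::

            # args are bound positionally to t0, t1, ... under the L namespace:
            shape_env.evaluate_guards_expression(
                "L['t0'].size()[0] == 3", [torch.randn(3)]
            )  # True for this input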
        """
        arg_names = [f"t{i}" for i in range(len(args))]
        return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))})

    def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True):
        """Generate guards for a graph's placeholder values and evaluate the guards with args."""
        code = self.produce_guards_expression(placeholders, ignore_static=ignore_static)
        if code:
            return self.evaluate_guards_expression(code, args)
        return True

    def get_pruned_guards(self, symints):
        """
        Get a list of guards, but pruned so it only provides guards that
        reference symints from the passed-in input.
        """
        symints = {s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)}
        guards = []
        for g in self.guards:
            if all(s in symints for s in g.expr.free_symbols):
                guards.append(g)
        return guards

    def bind_symbols(self, placeholders, args):
        """
        Given a paired list of placeholders (fake tensors with
        symbolic sizes) and concrete arguments (regular tensors
        with real sizes), returns a dictionary mapping each
        symbol to its real value.  So for example, if you
        have a placeholder with size (s0, s1), binding
        (2, 4) to it will give you {s0: 2, s1: 4}.  This is
        not guaranteed to bind ALL symbols in the ShapeEnv;
        we can't bind a symbol if it doesn't occur in any placeholder,
        and symbols that already have replacements won't get bindings.

        This is a little duplicative with evaluate_guards but
        it's different enough that it seemed cleanest to make
        another copy.  This assumes the guards are already checked,
        though if it's cheap we'll check for shenanigans.
        """
        bindings: Dict[sympy.Symbol, int] = {}

        def bind_symint(arg, val):
            if isinstance(val, SymInt):
                s = val.node.expr

                if isinstance(s, sympy.Symbol):
                    if s in bindings:
                        assert bindings[s] == arg, f"{bindings[s]} != {arg}"
                    else:
                        bindings[s] = arg
                elif isinstance(-s, sympy.Symbol):
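                    # If s is a negated symbol (e.g., s == -s0, which can
                    # arise from negative strides), bind the underlying
                    # symbol to the negated concrete value.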
                    if -s in bindings:
                        assert bindings[-s] == -arg, f"{bindings[-s]} != {-arg}"
                    else:
                        bindings[-s] = -arg

        for t, arg in zip(placeholders, args):
            if t is None:
                continue
            if isinstance(t, SymInt):
                bind_symint(arg, t)
                continue
            assert isinstance(t, torch.Tensor)
            for i, s in enumerate(t.size()):
                bind_symint(arg.size(i), s)
            for i, s in enumerate(t.stride()):
                bind_symint(arg.stride(i), s)
            bind_symint(arg.storage_offset(), t.storage_offset())

        return bindings

    def get_nontrivial_guards(self):
        """Returns a list of guard expressions that aren't statically known (i.e. not trivial)."""
        return [
            self.simplify(guard.expr)
            for guard in self.guards
            if self._maybe_evaluate_static(guard.expr, axioms=()) is None
        ]

    def format_guards(self, verbose=False):
        """Format this shape env's guard expressions, with optional traceback info if verbose."""
        def format_tb(tb):
            if not verbose:
                return ""
            return f"\n   Guarded at:\n{''.join('   ' + l for l in tb.format())}"

        return '\n'.join(f" - {guard.expr}{format_tb(guard.stack)}" for guard in self.guards)

    def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRanges:
        """Given a sympy expression, computes a ValueRanges bound for what values it can be."""
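        # E.g. (illustrative): with self.var_to_range = {s0: [2, 10]},
        # bound_sympy(s0 + 1) returns ValueRanges(3, 11).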
        var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols}
        if size_oblivious:
            # Clamp values of size-like variables.
            # NB: discarding the old upper bound is intentional, per
            # https://github.com/pytorch/pytorch/pull/123675
            for x in self.size_like & var_to_range.keys():
                if var_to_range[x] is not None:
                    # NB: do NOT set upper to 2 ** 48; we're using this solely
                    # to determine if we can do size-like replacement, and the
                    # upper bound is irrelevant here.
                    var_to_range[x] = ValueRanges(2, int_oo)
                    assert var_to_range[x].is_int
        return bound_sympy(expr, var_to_range)

    @_lru_cache
    def get_axioms(self, symbols: Optional[Tuple["sympy.Symbol"]] = None, compute_hint: bool = False) -> Tuple["sympy.Expr"]:
        """
        Given the symbols in an expression, returns all the runtime asserts
        that mention those symbols, concatenated with all the guards.
        If symbols is None, it returns all the runtime asserts (and all the guards).
        """
        if symbols is None:
            runtime_asserts = (r.expr
                               for rs in self.deferred_runtime_asserts.values()
                               for r in rs)
        else:
            runtime_asserts = (r.expr
                               for s in symbols if s not in self.var_to_val
                               for r in self.deferred_runtime_asserts.get(s, ()))
        guards = (g.expr for g in self.guards)
        axioms = itertools.chain(guards, runtime_asserts)
        if compute_hint:
            axioms = (canonicalize_bool_expr(a.xreplace(self.var_to_val)) for a in axioms)
        return tuple(dict.fromkeys(axioms).keys())

    @lru_cache(None)
    def get_implications(self,
                         e: "sympy.Expr") -> Tuple[Tuple["sympy.Expr", 'sympy.logic.boolalg.BooleanAtom']]:
        """Given an expression, returns a list of predicates that follow from it."""
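        # E.g. (illustrative): for Lt(s0, 5) with integer operands, this
        # records Lt(s0, 5) -> true and its negation -> false, plus implied
        # facts such as Le(s0, 5), Ne(s0, 5), and Le(s0, 4).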
        equiv = {}

        def add_expr(expr):
            expr = canonicalize_bool_expr(expr)
            if isinstance(expr, (sympy.Eq, sympy.Ne)):
                # No need to canonicalize.
                # TODO: We could further canonicalize Eq by ordering the lhs and
                # rhs somehow; with this, we could remove the need for the
                # commutativity part.
                opposite = sympy.Eq if isinstance(expr, sympy.Ne) else sympy.Ne
                # Commutativity of == and !=
                equiv[type(expr)(expr.lhs, expr.rhs)] = sympy.true
                equiv[type(expr)(expr.rhs, expr.lhs)] = sympy.true
                equiv[opposite(expr.lhs, expr.rhs)] = sympy.false
                equiv[opposite(expr.rhs, expr.lhs)] = sympy.false
            else:
                # Expr and negation
                equiv[expr] = sympy.true
                equiv[canonicalize_bool_expr(sympy.Not(expr))] = sympy.false

        add_expr(e)
        # Other relational expressions this expression implies
        if isinstance(e, sympy.Eq):
            add_expr(sympy.Le(e.lhs, e.rhs))
            add_expr(sympy.Ge(e.lhs, e.rhs))
        elif isinstance(e, sympy.Lt):
            add_expr(sympy.Le(e.lhs, e.rhs))
            add_expr(sympy.Ne(e.lhs, e.rhs))
            if e.lhs.is_integer and e.rhs.is_integer:
                add_expr(sympy.Le(e.lhs, e.rhs - 1))
        elif isinstance(e, sympy.Le):
            add_expr(sympy.Lt(e.lhs, e.rhs + 1))
        return tuple(equiv.items())

    @_lru_cache
    def _maybe_evaluate_static(
        self, expr: "sympy.Expr", *, unbacked_only: bool = False, compute_hint: bool = False,
        size_oblivious: bool = False, axioms: Optional[Tuple[sympy.Expr]] = None,
        var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges]]] = None
    ) -> "Optional[sympy.Expr]":
        """
        Tries to evaluate expr without introducing guards.

        If unbacked_only == True, then we only do substitutions on
        unbacked SymInts (leaving regular hinted integers alone).  This could
        result in an expression that still contains backed SymInts, which you
        could then potentially guard on.

        Use compute_hint == True if you are trying to compute a non-binding
        hint for the particular hint values of backed SymInts, e.g., if
        s0 happens to be 3 this run, compute_hint will substitute s0 with 3.
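
        Example (illustrative; s0 is a size symbol already tracked by this
        ShapeEnv with range [2, inf])::

            shape_env._maybe_evaluate_static(sympy.Gt(s0, 1))   # -> sympy.true
            shape_env._maybe_evaluate_static(sympy.Gt(s0, 10))  # -> None (unknown)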
        """

        # Axioms together with compute_hint are not yet implemented.
        assert not compute_hint or not axioms

        if var_to_range is None:
            var_ranges = self.var_to_range
        else:
            var_ranges = dict(var_to_range)

        expr = self.simplify(expr)

        if compute_hint:
            expr = expr.xreplace(self.var_to_val)

        expr = canonicalize_bool_expr(expr)

        # Pattern matching
        symbols = tuple(expr.free_symbols)
        if axioms is None:
            axioms = self.get_axioms(symbols, compute_hint=compute_hint)
        subst = {}
        for e in axioms:
            if e.free_symbols.issubset(expr.free_symbols):
                subst.update(dict(self.get_implications(e)))

        expr = expr.xreplace(subst)

        symbols = tuple(expr.free_symbols)

        # Simplify making use of value range lower bounds.
        new_shape_env = {}
        new_range_env = {}
        for idx, k in enumerate(symbols):
            if isinstance(self.var_to_val.get(k, None), SingletonInt):
                # Skip var_ranges logic for SingletonInt, which is only used
                # for jagged layout NestedTensors today.
                continue
            vr = var_ranges[k]
            if size_oblivious and k in self.size_like:
                lower = max(2, vr.lower)
                # Clamping size-oblivious symbols to some quantity below
                # sys.maxsize helps us determine that f(u0) != sys.maxsize,
                # which is a test that is looking for sys.maxsize as a
                # sentinel, but you don't really want to worry about it for
                # unbacked SymInts.  This is similar to the flavor where size
                # oblivious omits 0/1: it changes semantics, but in a benign
                # way.
                upper = min(2 ** 48, vr.upper)
                # This is a bit dodgy: what this means is that there was a
                # size-like unbacked symbol whose upper bound < 2.  This
                # causes... problems.
                if lower <= upper:
                    vr = ValueRanges(lower, upper)
            else:
                lower = vr.lower
            # Don't do anything if we don't have a nontrivial lower bound.
            # Also don't do anything if we were asked to only simplify
            # unbacked SymInts.
            if (
                lower is -int_oo or
                (unbacked_only and k in self.var_to_val) or
                not vr.is_int
            ):
                new_range_env[k] = vr
                continue
            # The goal is to take our symbols which have various lower bounds
            # and reallocate them into new symbols which are exactly positive;
            # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in
            # [1, inf], where s0 = ess0 + 1.  This gives the most information
            # to sympy for subsequent simplifications.
            #
            # Positive means >= 1
            # Positive - 1 means >= 0
            # Positive + lower - 1 means >= lower
            # The new symbol 's' is "too low", so when we substitute it in
            # we have to increase it by offset (and conversely, the new
            # variables have to have their value range bounds adjusted as
            # well).
            s = sympy.Symbol(f"evaluate_static_shape_{idx}", positive=True, integer=True)

            # Note:
            #   Offset might be a fraction (e.g., from aten.split.Tensor), but
            #   shapes are always integers.  Sympy might give unexpected results
            #   when comparing an integer with a non-integer, so we cast offset
            #   to int here.
            #   For example:
            #       shape_0 = sympy.Symbol("shape_0", positive=True, integer=True)
            #       expr = sympy.Eq(shape_0 - 1/3, 4)
            #       expr.xreplace({}) # False
            offset = int(lower - 1)
            new_shape_env[k] = s + offset
            new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset)

        try:
            new_expr = expr.xreplace(new_shape_env)
        except RecursionError:
            log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
            self.counter["sympy_recursion_error"] += 1
            return None

        # We need to canonicalize, as after expansion we may have something
        # like `a + b == a` that sympy will not simplify; the two appearances
        # of `a` would then make value range analysis give loose bounds.
        new_expr = canonicalize_bool_expr(safe_expand(new_expr))
        if new_expr.is_number:
            return new_expr

        # The following replacement of FloorDiv is deliberately disabled: it
        # leaves us with rationals when atom.args[0] is an addition, e.g.,
        # sympy will happily turn (s0 + s1) // 2 into s0 / 2 + s1 / 2.
        # Needless complication!
        """
        floor_div_replace = {}
        for atom in new_expr.atoms(FloorDiv):
            floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1])
        new_expr = safe_expand(new_expr.xreplace(floor_div_replace))
        # TODO: when unbacked_only, can sometimes early return even when there
        # are still free symbols
        if new_expr.is_number:
            return new_expr
        """

        # Check if the value ranges can solve it statically.
        out = bound_sympy(new_expr, new_range_env)
        if out.is_singleton():
            return out.lower

        return new_expr if unbacked_only else None

    @_lru_cache
    def replace(self, expr: "sympy.Expr") -> "sympy.Expr":
        """Apply symbol replacements to any symbols in the given expression."""
        replacements = {s: self._find(cast(sympy.Symbol, s)) for s in expr.free_symbols}
        return safe_expand(expr.xreplace(replacements))

    @_lru_cache
    def _update_divisible(self):
        new_divisible = set()
        for k in self.divisible:
            res = self.replace(k)
            if not res.is_number:
                new_divisible.add(k)

        self.divisible = new_divisible
        self._update_version_counter()

    @_lru_cache
    def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
        """Use known constraints and replacements to simplify the given expr."""
        expr = self.replace(expr)
        # TODO: It would seem that this pass is unnecessary given the
        # replacement of // with / below, but for nested FloorDivs
        # the non-recursive replacement doesn't work, and the recursive
        # one makes it hard to look up divisibility, because existing
        # divisibility info has FloorDiv in it, not /.  For now, just do
        # a separate pass to catch the common nested case.
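        # E.g. (illustrative): when divisibility info says 4 divides s0 and
        # (s0 // 4) divides s0, the nested term s0 // (s0 // 4) is rewritten
        # to 4 by the pass below.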
        if expr.has(FloorDiv):
            self._update_divisible()
            div_replacements = {}
            for atom in expr.atoms(FloorDiv):
                base, divisor = atom.args
                if isinstance(divisor, FloorDiv):
                    base1, divisor1 = divisor.args
                    if self.replace(Mod(base, divisor)) in self.divisible and \
                            base == base1 and self.replace(Mod(base1, divisor1)) in self.divisible:
                        div_replacements[atom] = divisor1
            expr = expr.xreplace(div_replacements)
            expr = safe_expand(expr)
        if expr.has(FloorDiv):
            div_replacements = {}
            pows = expr.atoms(sympy.Pow)
            rationals = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer))
            for fd in expr.atoms(FloorDiv):
                base, divisor = fd.args
                if self.replace(Mod(base, divisor)) in self.divisible:
                    div_replacements[fd] = CleanDiv(base, divisor)
            new_expr = expr.xreplace(div_replacements)
            new_expr = safe_expand(new_expr)
            new_pows = new_expr.atoms(sympy.Pow)
            new_rationals = new_expr.atoms(sympy.Rational).difference(new_expr.atoms(sympy.Integer))
            # Only keep the new expression if the divisions simplified away.
            if new_pows.issubset(pows) and new_rationals.issubset(rationals):
                expr = new_expr
        return expr

    @lru_cache(256)
    def size_hint(self, expr: "sympy.Expr", *, allow_none=False):
        """
        Gets a size hint for a given expression from the underlying shapes we had.
        Does not introduce a guard, so only use this when you can guarantee that
        your code is still valid for arbitrary shapes (such as optimization decisions).
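
        Example (illustrative)::

            # With var_to_val = {s0: 8}, this returns sympy.Integer(16)
            # without emitting any guard:
            shape_env.size_hint(s0 * 2)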
        """
        result_expr = safe_expand(expr).xreplace(self.var_to_val)
        if not result_expr.is_number:
            from torch.utils._sympy.singleton_int import SingletonInt

            if isinstance(result_expr, SingletonInt):
                return None
            r = self._maybe_evaluate_static(result_expr, compute_hint=True)
            if r is not None:
                return r
            if allow_none:
                return None

            if self.unbacked_var_to_val:
                unsound_expr = result_expr.xreplace(self.unbacked_var_to_val)
                if not unsound_expr.free_symbols:
                    log.warning("propagate_real_tensors size_hint(%s) -> %s", expr, unsound_expr)
                    trace_structured(
                        "propagate_real_tensors",
                        metadata_fn=lambda: {
                            "expr": repr(expr),
                            "result": repr(unsound_expr),
                            "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()),
                        },
                    )
                    self.defer_runtime_assert(
                        sympy.Eq(result_expr, unsound_expr),
                        f"propagate_real_tensors: {result_expr} == {unsound_expr}"
                    )
                    return unsound_expr

            raise self._make_data_dependent_error(result_expr, expr)
        return result_expr

    # NB: keep in sync with size_hint
    @lru_cache(256)
    def has_hint(self, expr: "sympy.Expr"):
        result_expr = safe_expand(expr).xreplace(self.var_to_val)
        return result_expr.is_number or self._maybe_evaluate_static(result_expr) is not None

    def _make_data_dependent_error(self, expr, unhinted_expr, *, size_oblivious_result: Optional[bool] = None):
        # TODO: In a Dynamo context, having the user code and the name of the
        # local would be much better.
        size_like_symbols = []
        for s in expr.free_symbols:
            stacktrace = ''.join(self.var_to_stack[s].format())
            self.log.debug("Data dependent variable '%s' allocated at:\n%s", s, stacktrace)
            if s in self.size_like:
                size_like_symbols.append(s)
        size_oblivious_result_msg = ""
        if size_oblivious_result is not None:
            size_oblivious_result_msg = (
                f"ATTENTION: guard_size_oblivious would fix the error, evaluating expression to {size_oblivious_result}.\n"
                "Maybe you need to add guard_size_oblivious to framework code; see the doc below for more guidance.\n\n"
            )
        fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(True)
        if expr.is_integer:
            desc = "Could not extract specialized integer from data-dependent expression"
        else:
            desc = "Could not guard on data-dependent expression"
        msg = (
            f"{desc} {expr} (unhinted: {unhinted_expr}).  "
            f"(Size-like symbols: {', '.join(map(str, size_like_symbols)) or 'none'})\n\n"
            f"{size_oblivious_result_msg}"
            "Potential framework code culprit (scroll up for full backtrace):\n"
            f"{''.join(traceback.StackSummary.from_list([fsummary]).format())}\n"
            'For more information, run with TORCH_LOGS="dynamic"\n'
            "For extended logs when we create symbols, also add "
            f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{','.join(map(str, expr.free_symbols))}\"\n"
            "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n"
            "For more debugging help, see "
            "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n" +
            maybe_extra_debug
            # TODO: Help text about how to use our runtime tests to fix this
            # problem
        )
        return GuardOnDataDependentSymNode(expr, msg)
4710
4711    def _update_var_to_range(self, symbol, vr):
4712        lower, upper = vr.lower, vr.upper
4713
4714        # If we have a size-like unbacked SymInt, refuse to refine the range to be
4715        # less than two.  This is because when we intersect this range
4716        # with [2, inf] for size oblivious tests, the range would be
4717        # unsatisfiable.  In other words, once you have a size-like
4718        # unbacked SymInt, we can never learn that it is exactly zero or one,
4719        # because we would now give inconsistent results for all size
        # oblivious tests!
4721        if upper < 2 and symbol in self.size_like:
4722            upper = 2
4723
4724        # Updates the range and the guards corresponding to each bound of the symbol.
4725        if symbol not in self.var_to_range:
4726            r = ValueRanges(lower, upper)
4727            self.log.debug("_update_var_to_range %s = %s (new)", symbol, r)
4728            self.var_to_range[symbol] = r
4729        else:
4730            old = self.var_to_range[symbol]
4731            new = old & ValueRanges(lower, upper)
4732            if new != old:
4733                self.var_to_range[symbol] = new
4734                self.log.debug("_update_var_to_range %s = %s (update)", symbol, new)
4735
4736        if (v := self.var_to_val.get(symbol)) is not None:
4737            r = self.var_to_range[symbol]
4738            assert v in r, f"{v} not in {r}"
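
        # A rough illustration (hypothetical values): if s0 currently has
        # range [2, 100], _update_var_to_range(s0, ValueRanges(0, 50))
        # stores the intersection [2, 50].  For a size-like unbacked u0,
        # a requested range [0, 1] is clamped to [0, 2] so size-oblivious
        # reasoning stays satisfiable.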
4739
4740    def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> None:
4741        """
4742        Adds or updates a replacement for a symbol.
4743        Use this instead of `self.replacements[a] = tgt`.
4744        """
4745
4746        if tgt == self.replacements.get(a, None):
4747            return
4748
4749        # Precondition: a == tgt
4750        assert isinstance(a, sympy.Symbol)
4751
4752        if self.allow_complex_guards_as_runtime_asserts and not _is_supported_equivalence(tgt):
4753            return  # continuing leads to placeholder shapes having complex expressions that we can't resolve
4754
4755        # Handles nested tensor symbolic variables which don't have
4756        # var_to_range bounds
4757        tgt_bound = None
4758        if a in self.var_to_range:
4759            src_bound = self.var_to_range[a]
4760
4761            # First, refine the value range of a based on the computed value range
4762            # of tgt.  This is always OK to do, even if we decide not to do the
4763            # substitution in the end.  This might be a no-op, if a already has
4764            # a tighter bound
4765            tgt_bound = self.bound_sympy(tgt)
4766            self._update_var_to_range(a, tgt_bound)
4767
4768            # Next, check if we can update the range of free symbols in tgt
4769            # based on the range in a. But only do it if:
4770            #  - the source bound non-trivially improves over what we get out of
4771            #    the existing bounds.
4772            #  - the replacement is univariate and we can invert the tgt expression
4773            if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1:
4774                b = next(iter(tgt.free_symbols))
4775                # Try to invert the equality
4776                r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
4777                if r is not None:
4778                    self.log.debug("set_replacement: solve for %s in %s == %s gives %s", b, a, tgt, r)
4779                    # The solution here can be non-integral, for example, if
4780                    # we have s0 = 2*s1, then s1 = s0/2.  What we would like
                    # to do is calculate the bounds in arbitrary precision,
4782                    # and then requantize the bound to integers when we are
4783                    # done.
4784                    rat_b_bound = self.bound_sympy(r[1])
4785                    b_bound = ValueRanges(CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper))
4786                    self._update_var_to_range(b, b_bound)
4787                    tgt_bound = self.bound_sympy(tgt)
4788                    assert tgt_bound.issubset(src_bound)
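                    # Concrete illustration (hypothetical numbers): if
                    # a = 2*b and a has range [4, 9], solving gives b = a/2
                    # with rational bounds [2, 4.5]; CeilToInt/FloorToInt
                    # requantize this to the integer range [2, 4] for b.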
4789
4790            # TODO: Should we propagate size-like-ness?
4791            #
4792            # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1
4793            # to become size-like.
4794            #
4795            # Cons: if u0 is size-like, what about u0 - 1 == u1?  You CAN'T
4796            # propagate in this case, because what if u0 == 0, then u1 is negative
4797            # and clearly isn't a size.  So, at minimum, any f(x) whose value
4798            # range isn't [0, inf] given x in [0, inf] cannot propagate
4799            # size-like-ness.  But there are many situations where you could
4800            # imagine u1 is going to be size-like and actually you just didn't
4801            # have a refined enough value range on u0.  Since even innocuous
4802            # looking arithmetic operations can destroy size-like-ness, it's
4803            # best to not propagate it at all and force the user to annotate it
4804            # as necessary.
4805            #
4806            # Compromise: we preserve size-like-ness only for exact equality
4807            # and nothing else.
4808            if a in self.size_like and isinstance(tgt, sympy.Symbol):
4809                self.size_like.add(tgt)
4810            elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like:
4811                self.size_like.add(a)
4812
4813            # Now, decide if we will do the substitution.
4814            #
4815            #  - If the source has a non-trivial range, only substitute if
4816            #    we preserve this range.  Note that we may have propagated
4817            #    the src_range to free variables in tgt when tgt is univariate
4818            #    and we could find an inverse, which helps us achieve this.
4819            #    This ensures we never "forget" about user defined ranges,
4820            #    even if they end up being defined on composite formulas
4821            #    like s0 + s1.
4822            #
4823            #  - If the variable is unbacked, only substitute if the substitution
            #    would also preserve the bounds under size-oblivious conditions.
4825
4826            if not tgt_bound.issubset(src_bound):
4827                self.log.debug("skipped set_replacement %s = %s (%s) [%s not subset of %s]", a, tgt, msg, tgt_bound, src_bound)
4828                return
4829            elif a in self.size_like:
4830                tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True)
4831                src_bound_so = self.bound_sympy(a, size_oblivious=True)
4832                if not tgt_bound_so.issubset(src_bound_so):
4833                    self.log.debug("skipped set_replacement %s = %s (%s) "
4834                                   "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so)
4835                    return
4836
4837        if isinstance(tgt, (sympy.Integer, sympy.Float)):
4838            # specializing to a constant, which is likely unexpected (unless
4839            # you specified dynamic=True)
4840
4841            user_tb = TracingContext.extract_stack()
4842            trace_structured(
4843                "symbolic_shape_specialization",
4844                metadata_fn=lambda: {
4845                    "symbol": repr(a),
4846                    "sources": [s.name() for s in self.var_to_sources.get(a, [])],
4847                    "value": repr(tgt),
4848                    "reason": msg,
4849                    "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()),
4850                    "user_stack": structured.from_traceback(user_tb) if user_tb else None,
4851                }
4852            )
4853
4854            if config.print_specializations:
4855                self.log.warning("Specializing %s to %s", self.var_to_sources[a][0].name(), tgt)
4856                self.log.debug("SPECIALIZATION", stack_info=True)
4857        log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
4858        self.replacements[a] = tgt
4859        self._update_version_counter()
4860
4861        # When specializing 'a == tgt', the equality should be also conveyed to
4862        # Z3, in case an expression uses 'a'.
4863        self._add_target_expr(sympy.Eq(a, tgt))
4864
4865    def _add_divisible(self, expr: "sympy.Expr"):
4866        self.divisible.add(expr)
4867        self._update_version_counter()
4868
4869    @_lru_cache
4870    @record_shapeenv_event()
4871    def _find(self, a: "sympy.Symbol") -> "sympy.Expr":
4872        """
        Implements a DSU-like (union-find) algorithm to find the variable that
        represents ``a``.  Also handles transitive non-identity replacements,
        e.g. given:
4875
4876        a: b + c
4877        c: d
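
        _find(a) first resolves c -> d, updates the stored replacement of
        a to b + d, and returns b + d.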
4878        """
4879        if a not in self.replacements:
4880            return a
4881        res = self.replacements[a]
4882        cur_replace = {s: self._find(s) for s in res.free_symbols}
4883        replaced, changed = self.replacements[a]._xreplace(cur_replace)
4884        if changed:
4885            self._set_replacement(a, replaced, "find")
4886        return self.replacements[a]
4887
4888    @lru_cache(256)
4889    def _maybe_guard_rel(self, expr: "sympy.Rel") -> None:
4890        """
        The relational guard is assumed to be true.  Use this information to
        simplify shapes (e.g. a == b or a % 5 == 0).
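
        E.g. given Eq(s4, s3 + 1) we may record the replacement s4 -> s3 + 1,
        and given Eq(Mod(s0, 5), 0) we may record Mod(s0, 5) as divisible.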
4893        """
4894        assert isinstance(expr, sympy.Rel)
4895
4896        # A good example of what goes wrong if you don't do this is
4897        # python test/functorch/test_aotdispatch.py -k
4898        # test_aot_autograd_symbolic_module_exhaustive_nn_LazyConv3d_cpu_float32
4899        if isinstance(expr, sympy.Ne):
4900            return
4901
4902        free = list(expr.free_symbols)
4903
4904        assert len(free) > 0, f"The expression should not be static by this point: {expr}"
        # In case of a really gnarly expression, bail so we don't blow up
4906        if len(free) > 5:
4907            return
4908
4909        # Prioritize unbacked symints for solving by ordering them last.
4910        # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3).
4911        #   (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols)
4912        # Prefer to simplify out symbols with ephemeral sources.
4913        def _smart_symbol_sort(x):
4914            has_only_ephemeral_sources = (
4915                x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x])
4916            )
4917            # NB: size_hint is int, not sympy.Expr, do not use int_oo here
4918            size = self.size_hint(x, allow_none=True) or sys.maxsize
4919            name = x.name
4920            # 1 puts ephemeral sourced symbols first when sorting in reverse
4921            return (1 if has_only_ephemeral_sources else 0, size, name)
4922
4923        free = sorted(free, key=_smart_symbol_sort, reverse=True)  # type: ignore[attr-defined]
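        # E.g. with free = [s3, s4], equal size hints, and no ephemeral
        # sources, the reverse sort puts s4 before s3, so s4 becomes the
        # preferred candidate to simplify out.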
4924        lhs = expr.lhs
4925        rhs = expr.rhs
4926
4927        self._refine_ranges(expr)
4928
4929        # The rest of this stuff is for equality only
4930        if not isinstance(expr, sympy.Eq):
4931            return
4932
4933        if not expr.has(Mod):
4934            try:
4935                floor_div_atoms = lhs.atoms(FloorDiv).union(rhs.atoms(FloorDiv))
4936                if len(floor_div_atoms) > 0 and any(a.divisor != 1 for a in floor_div_atoms):
4937                    raise NotImplementedError
4938
4939                # Never replace unbacked symbols with other unbacked symbols.
4940                # This is error prone because you can cause references to
4941                # unbacked symbols to time travel backwards.  E.g.,
4942                #
4943                # u1 = x.item()
4944                # ... use of u1 ...
4945                # u2 = y.item()
4946                # u3 = z.item()
4947                # torch._check(u1 == u2 + u3)
4948                #
4949                # If you replace u1 with u2 + u3, then the use of u1 now
4950                # references u2 and u3 prior to them actually being bound at
                # runtime.  It's pretty inconvenient to set up control
4952                # dependencies for substitutions, so ban it entirely.
4953                def trivial_solve(lhs, rhs):
4954                    if isinstance(lhs, sympy.Symbol):
4955                        if free_unbacked_symbols(lhs) and not free_unbacked_symbols(rhs):
4956                            return True
4957                        if symbol_is_type(lhs, SymT.FLOAT):
4958                            return True
4959                        # TODO: Maybe trivial solutions for int should also be
4960                        # done?
4961                    return False
4962
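                # E.g. for Eq(u0, s0 + 1), trivial_solve(u0, s0 + 1) holds
                # (unbacked lhs, no unbacked symbols on the rhs), so u0 is
                # replaced by s0 + 1 below.
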
4963                # short-circuit when no solving is needed
4964                if trivial_solve(lhs, rhs):
4965                    self._set_replacement(lhs, self._find(rhs), "trivial_lhs")
4966                elif trivial_solve(rhs, lhs):
4967                    self._set_replacement(rhs, self._find(lhs), "trivial_rhs")
4968                else:
4969                    r = try_solve(expr, free[0], floordiv_inequality=False)
4970                    if r is not None and all(t.is_integer for t in sympy.preorder_traversal(r[1])):
4971                        new_var = self._find(r[1])
4972                        ok = len(free_unbacked_symbols(new_var)) == 0
4973                        if ok:
4974                            self._set_replacement(cast(sympy.Symbol, free[0]), new_var, "solve")
4975            except NotImplementedError:
4976                pass
4977        if expr.has(Mod):
4978            mod_expr = next(iter(expr.atoms(Mod)))
4979            try:
4980                r = try_solve(expr, mod_expr, floordiv_inequality=False)
4981                if r is not None and r[1] == 0:
4982                    self._add_divisible(mod_expr)
4983                    # This is a little bit of extra logic to make things like
4984                    # torch.empty(i0, q).view(c, -1, q) work out
4985                    p, q = mod_expr.args
4986                    if isinstance(q, sympy.Number) and isinstance(p, sympy.Mul) and len(p.args) == 2:
4987                        c, i0 = p.args
4988                        # Given Mod(c * i0, q) == 0
4989                        if (
4990                            isinstance(c, sympy.Number) and
4991                            isinstance(i0, sympy.Symbol) and
4992                            self.is_unbacked_symint(i0)
4993                        ):
                            # Then i0 must be divisible by q / gcd(q, c), so
                            # we can rewrite i0 as (q / gcd(q, c)) * i1
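                            # E.g. for Mod(2*i0, 6) == 0: gcd(6, 2) = 2,
                            # d = 3, and i0 is replaced by 3*i1 (a fresh
                            # unbacked symbol).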
4996                            d = q / sympy.gcd(q, c)  # TODO: CleanDiv?
4997                            i1 = self.create_unbacked_symint().node.expr
4998                            # Propagate the value ranges.  It doesn't really
4999                            # matter if we use truediv or floordiv, because we
5000                            # have established divisibility.
5001                            self._update_var_to_range(i1, SymPyValueRangeAnalysis.floordiv(
5002                                self.var_to_range[i0], ValueRanges.wrap(d)
5003                            ))
5004                            # Propagate size-like-ness
5005                            if i0 in self.size_like:
5006                                self.size_like.add(i1)
5007                            self._set_replacement(i0, d * i1, "divisibility")
5008
5009            except NotImplementedError:
5010                pass
5011        return
5012
5013    # See: Note - On 0/1 specialization
5014    def _default_value_range(self) -> ValueRanges:
5015        lower = 2 if self.specialize_zero_one else 0
5016        return ValueRanges(lower, int_oo)
5017
5018    def _default_unspecified_value_range(self) -> ValueRanges:
5019        return ValueRanges(-int_oo, int_oo)
5020
5021    @_lru_cache
5022    def _simplify_floor_div(self, expr):
5023        floor_divs = tuple(expr.atoms(FloorDiv))
5024        # we expect floor_divs to be exact,
5025        # and thus add the guards for the exact floordivs,
5026        # even if tracing doesn't require them otherwise
5027        for fd in reversed(floor_divs):
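        # E.g. if expr contains FloorDiv(s0, 4), we evaluate Eq(Mod(s0, 4), 0),
        # guarding that s0 is divisible by 4, so simplify() may treat
        # s0 // 4 as exact division.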
5028            base, divisor = fd.args
5029            mod_expr = Mod(base, divisor)
5030            eq_expr = sympy.Eq(mod_expr, 0)
5031            # add necessary mod guards
5032            self.evaluate_expr(eq_expr)
5033        return self.simplify(expr)
5034
5035    # We're about to add a guard/runtime assert, check if the ShapeEnv is frozen
5036    # and if so issue a warning
5037    def _check_frozen(self, expr, concrete_val):
5038        if self.frozen:
5039            self.counter["ignored_backward_guard"] += 1
5040            signpost_event(
5041                "dynamic",
5042                "evaluate_expr_frozen",
5043                {
5044                    **self.co_fields,
5045                    "ignored_guard": f"{expr} == {concrete_val}",
5046                    # no version = original state (this signpost is expected)
5047                    # version 2 = dynamic backwards is eagerly compiled
5048                    "version": 2,
5049                },
5050            )
5051            log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val, stack_info=True)
5052
5053
5054    def _get_stack_summary(self, is_debug: bool = False):
5055        fsummary = None
5056        frame = inspect.currentframe()
5057        try:
5058            while frame is not None:
5059                if frame.f_code.co_filename not in uninteresting_files():
5060                    fsummary = traceback.FrameSummary(
5061                        frame.f_code.co_filename,
5062                        frame.f_lineno,
5063                        frame.f_code.co_name,
5064                    )
5065                    break
5066                frame = frame.f_back
5067        finally:
5068            del frame
5069
5070        # NB: this stack is truncated, but it's fine because the main
5071        # stack_info will give you the rest of the info you need
5072        maybe_user_loc = ""
5073        user_tb = TracingContext.extract_stack()
5074        if user_tb:
5075            maybe_user_loc = " at " + format_frame(user_tb[-1])
5076
5077        maybe_extra_debug = ""
5078        if is_debug and user_tb:
5079            maybe_extra_debug = (
5080                '\nUser Stack (most recent call last):\n' +
5081                '  (snipped, see stack below for prefix)\n' +
5082                ''.join(traceback.format_list(user_tb))
5083            )
5084        if is_debug and config.extended_debug_cpp:
5085            cpp_stack = CapturedTraceback.extract(cpp=True)
5086            maybe_extra_debug += "\nC++ stack trace:\n" + ''.join(cpp_stack.format())
5087        elif is_debug:
5088            maybe_extra_debug += (
5089                "\nFor C++ stack trace, run with "
5090                "TORCHDYNAMO_EXTENDED_DEBUG_CPP=1"
5091            )
5092
5093        return fsummary, maybe_user_loc, maybe_extra_debug
5094
5095    def _log_guard(self, prefix: str, g, forcing_spec: bool):
5096        if self.log.isEnabledFor(logging.INFO):
5097            str_g = str(g)
5098            is_debug = config.extended_debug_guard_added is not None and str_g == config.extended_debug_guard_added
5099            fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug)
5100            maybe_more_info = ""
5101            if not is_debug:
5102                maybe_more_info = (
5103                    ", for more info run with "
5104                    f'TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="{str_g}"'
5105                )
5106            self.log.info(
5107                "%s %s [guard added]%s (%s)%s%s",
5108                prefix if not forcing_spec else f"{prefix} (forcing_spec)",
5109                str_g,
5110                maybe_user_loc,
5111                format_frame(fsummary),
5112                maybe_more_info,
5113                maybe_extra_debug,
5114                stack_info=is_debug,
5115            )
5116
5117    @lru_cache(256)
5118    @record_shapeenv_event(save_tracked_fakes=True)
5119    def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None,
5120                      size_oblivious: bool = False, *, forcing_spec: bool = False):
5121        try:
5122            return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
5123        except Exception:
5124            self.log.warning(
                "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s)",
5126                orig_expr, hint, size_oblivious, forcing_spec
5127            )
5128            raise
5129
5130    def _evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None,
5131                       size_oblivious: bool = False, *, forcing_spec: bool = False):
5132        """
5133        Given an expression, evaluates it, adding guards if necessary
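
        A rough usage sketch (hypothetical backed symbol s0 with hint 8):

            shape_env.evaluate_expr(sympy.Eq(s0 % 4, 0))
            # -> sympy.true, recording a guard equivalent to s0 % 4 == 0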
5134        """
5135
5136        # TODO: split conjunctions and evaluate them separately
5137
5138        # Don't track this one
5139        @functools.lru_cache(None)
5140        def compute_concrete_val():
5141            if hint is None:
5142                return self.size_hint(orig_expr)
5143            else:
5144                return sympy.sympify(hint)
5145
5146        # Check if:
5147        #   1. 'translation_validation' is set
5148        #   2. the corresponding 'fx_node' is not 'None'
5149        #   3. the guard should not be suppressed
5150        #
        # If all of the above hold, we create an FX node representing the
5152        # actual expression to be guarded.
5153        node = None
5154        fresh = False
5155        if (
5156                self._translation_validation_enabled
5157                and fx_node is not None
5158                and not self._suppress_guards_tls()
5159                and not size_oblivious
5160        ):
5161            concrete_val = compute_concrete_val()
5162            if concrete_val is sympy.true:
5163                node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
5164            elif concrete_val is sympy.false:
5165                neg, _ = self._create_fx_call_function(operator.not_, (fx_node,))
5166                node, fresh = self._create_fx_call_function(torch._assert, (neg,))
5167            else:
5168                eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val))
5169                node, fresh = self._create_fx_call_function(torch._assert, (eql,))
5170
5171            assert node is not None
5172            # If this is a fresh node, we have to remember the event index that
5173            # corresponds to this assertion node.
5174            # Reason: so that, given an assertion node, we can replay the ShapeEnv
5175            # events until the point where this assertion node was freshly created.
5176            if fresh:
5177                self._add_fx_node_metadata(node)
5178
5179        # After creating the FX node corresponding to orig_expr, we must make sure that
5180        # no error will be raised until the end of this function.
5181        #
5182        # Reason: the translation validation may become invalid otherwise.
5183        #
5184        # If an error is raised before the end of this function, we remove the FX node
5185        # inserted, and re-raise the error.
5186        guard = None
5187        tb = None
5188
5189        try:
5190            if orig_expr.is_number:
5191                self.log.debug("eval %s [trivial]", orig_expr)
5192                if hint is not None:
5193                    assert orig_expr == hint, f"{orig_expr} != {hint}"
5194                return orig_expr
5195
5196            expr = orig_expr
5197
5198            static_expr = self._maybe_evaluate_static(expr,
5199                                                      size_oblivious=size_oblivious)
5200            if static_expr is not None:
5201                self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr)
5202                if hint is not None:
5203                    assert static_expr == hint, f"{static_expr} != {hint}"
5204                return static_expr
5205
5206            transmute_into_runtime_assert = False
5207
5208            concrete_val = None
5209            if not (expr.free_symbols <= self.var_to_val.keys()):
5210                # TODO: dedupe this with _maybe_evaluate_static
5211                # Attempt to eliminate the unbacked SymInt
5212                new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
5213                if not (new_expr.free_symbols <= self.var_to_val.keys()):
5214                    size_oblivious_result = None
5215                    if not size_oblivious:
5216                        size_oblivious_result = self._maybe_evaluate_static(
5217                            expr,
5218                            size_oblivious=True
5219                        )
5220
5221                    # Last ditch
5222                    if (
5223                        self.unbacked_var_to_val and
5224                        not (unsound_result := orig_expr.xreplace(self.unbacked_var_to_val)).free_symbols
5225                    ):
5226                        log.warning("propagate_real_tensors evaluate_expr(%s) -> %s", orig_expr, unsound_result)
5227                        trace_structured(
5228                            "propagate_real_tensors",
5229                            metadata_fn=lambda: {
5230                                "expr": repr(orig_expr),
5231                                "result": repr(unsound_result),
5232                                "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()),
5233                            },
5234                        )
5235                        transmute_into_runtime_assert = True
5236                        concrete_val = unsound_result
5237                    else:
5238                        raise self._make_data_dependent_error(
5239                            expr.xreplace(self.var_to_val),
5240                            expr,
5241                            size_oblivious_result=size_oblivious_result
5242                        )
5243                else:
5244                    expr = new_expr
5245
5246            if concrete_val is None:
5247                concrete_val = compute_concrete_val()
5248            self._check_frozen(expr, concrete_val)
5249
5250            if (
5251                    config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
5252                    and isinstance(hint, bool)
5253                    and isinstance(expr, (sympy.Eq, sympy.Ne))
5254            ):
5255                expr = sympy.Not(expr)
5256
5257            # Turn this into a boolean expression, no longer need to consult
5258            # concrete_val
5259            if concrete_val is sympy.true:
5260                g = expr
5261            elif concrete_val is sympy.false:
5262                g = sympy.Not(expr)
5263            else:
5264                g = sympy.Eq(expr, concrete_val)  # type: ignore[arg-type]
5265
5266            if transmute_into_runtime_assert:
5267                self.defer_runtime_assert(
5268                    g,
5269                    f"propagate_real_tensors: {orig_expr} == {unsound_result}"
5270                )
5271                return concrete_val
5272
5273            if not self._suppress_guards_tls():
5274                if isinstance(g, sympy.Rel):
5275                    # TODO: If we successfully eliminate a symbol via equality, it
5276                    # is not actually necessary to save a guard for the equality,
5277                    # as we will implicitly generate a guard when we match that
5278                    # input against the symbol.  Probably the easiest way to
5279                    # implement this is to have maybe_guard_rel return a bool
5280                    # saying if it "subsumed" the guard (and therefore the guard
5281                    # is no longer necessary)
5282                    self._maybe_guard_rel(g)
5283
5284                if not self.allow_complex_guards_as_runtime_asserts:
5285                    # at this point, we've evaluated the concrete expr value, and have
5286                    # flipped/negated the guard if necessary. Now we know what to guard
5287                    # or defer to runtime assert on.
5288                    stack = CapturedTraceback.extract(skip=1)
5289                    guard = ShapeGuard(g, stack)
5290                    self.guards.append(guard)
5291                else:
                    # It's fine to defer simple guards here without checking:
                    # the _maybe_guard_rel() call above will set replacements
                    # if possible, so the result here will be statically known.
5295                    self.defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
5296
5297        except Exception:
5298            if fresh:
5299                self._remove_fx_node(node)
5300            raise
5301        else:
5302            if not self._suppress_guards_tls():
5303                if guard is not None:  # we might have deferred this to runtime assert
5304                    self._log_guard("eval", g, forcing_spec=forcing_spec)
5305
5306                    for s in g.free_symbols:
5307                        self.symbol_guard_counter[s] += 1
5308                        # Forcing_spec to avoid infinite recursion
5309                        if (
5310                            not forcing_spec and
5311                            config.symbol_guard_limit_before_specialize is not None and
5312                            self.symbol_guard_counter[s] > config.symbol_guard_limit_before_specialize
5313                        ):
5314                            # Force specialization
5315                            self.log.info(
5316                                "symbol_guard_limit_before_specialize=%s exceeded on %s",
5317                                config.symbol_guard_limit_before_specialize,
5318                                s
5319                            )
5320                            self.evaluate_expr(s, forcing_spec=True)
5321            else:
5322                self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec)
5323
5324        return concrete_val
5325
5326    def cleanup(self):
5327        """
5328        Break reference cycles.
5329
5330        This destroys the stacks. If you really want to keep them, we
5331        just need some way to break references on code objects.
5332        """
5333        for g in self.guards:
5334            g.stack.cleanup()
5335        for s in self.var_to_stack.values():
5336            s.cleanup()
5337        for ras in self.deferred_runtime_asserts.values():
5338            for ra in ras:
5339                ra.stack.cleanup()
5340
5341    @record_shapeenv_event(save_tracked_fakes=True)
5342    def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
5343        """Create an assert that is checked at runtime
5344
5345        Args:
5346            orig_expr (sympy.Expr): Boolean expression to assert is true
5347            msg (str): Message to display on assertion failure
5348            fx_node (Optional, torch.fx.Node): node in ``self.graph`` corresponding
5349                to the expression, if applicable
5350
5351        """
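        A rough usage sketch (u0 the sympy expression of an unbacked SymInt):

            shape_env.defer_runtime_assert(sympy.Ge(u0, 1), "expected u0 >= 1")
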
5352        expr = orig_expr
5353
5354        # TODO: split conjunctions and evaluate them separately
5355
5356        static_expr = self._maybe_evaluate_static(expr)
5357        if static_expr is not None:
5358            self.log.debug("runtime_assert %s == %s [statically known]", orig_expr, static_expr)
5359            return static_expr
5360
5361        # Attempt to eliminate the unbacked SymInt
5362        new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
5363        if not self.prefer_deferred_runtime_asserts_over_guards and new_expr.free_symbols <= self.var_to_val.keys():
5364            # Do a normal guard
5365            return self.evaluate_expr(new_expr, fx_node=fx_node)
5366        # NB: Don't use new_expr as expr; it could contain gunk like shape0
5367        # which we don't want to guard on
5368
5369        # OK, we're definitely doing a runtime assert now
5370        if (
5371            self._translation_validation_enabled
5372            and fx_node is not None
5373            and not self._suppress_guards_tls()
5374        ):
5375            node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
5376            assert node is not None
5377            if fresh:
5378                self._add_fx_node_metadata(node)
5379
5380        if not self._suppress_guards_tls():
5381            # If you're here because of this assert, read Note [Backwards runtime asserts]
5382            # in torch/_inductor/graph.py
5383            assert not self.runtime_asserts_frozen, expr
5384
5385            self._check_frozen(expr, sympy.true)
5386
5387            # eliminate symbols on equality tests / refine ranges
5388            if isinstance(expr, sympy.Rel):
5389                self._maybe_guard_rel(expr)
5390
            # canonicalize to remove equations that are trivially equal
5392            orig_expr = expr
5393            expr = canonicalize_bool_expr(expr)
5394            stack = CapturedTraceback.extract(skip=1)
5395            ra = RuntimeAssert(expr, msg, stack)
5396            # TODO: Do this in a way that is less janky than int(s.name[1:])
5397            cands = sorted((s for s in expr.free_symbols if symbol_is_type(s, SymT.UNBACKED_INT)), key=lambda s: int(s.name[1:]))
5398            # Is None when prefer_deferred_runtime_asserts_over_guards=True
            # and the guard in question has no unbacked SymInts in it
5400            ix = cands[-1] if cands else None
5401            self.deferred_runtime_asserts.setdefault(ix, []).append(ra)
5402            self.num_deferred_runtime_asserts += 1
5403            self._update_version_counter()
5404            self._log_guard("runtime_assert", orig_expr, forcing_spec=False)
5405        else:
5406            self._log_guard("runtime_assert [guard suppressed]", orig_expr, forcing_spec=False)
5407
5408        return True
5409
5410    # Refines the ranges of the variables present in 'guard'.
5411    #
5412    # This function tries to refine the range of the variables inside
    # 'guard' by reasoning about it, specifically when 'guard' is a
    # 'sympy.Relational' operation.
    #
    # It mainly does 3 things:
    #   1. Tries to isolate a variable on the left-hand side
    #   2. Computes the value range of the right-hand side
    #   3. Updates the value range of the variable, if the new range is better
5420    def _refine_ranges(self, expr: sympy.Expr) -> None:
5421        expr = self.simplify(expr)
5422
5423        for symbol in expr.free_symbols:
5424            assert isinstance(symbol, sympy.Symbol)
5425
5426            if isinstance(self.var_to_val.get(symbol, None), SingletonInt):
5427                # Skip var_to_range logic for SingletonInt which is only used
5428                # for jagged layout NestedTensors today
5429                continue
5430
5431            r = try_solve(expr, symbol)
5432
5433            if r is None or not (symbol.is_integer and r[1].is_integer):
5434                # Range refinement only supports integer symbols for now.
5435                # There are lots of SymPy bugs when it comes to comparing
5436                # reals and integers, so we skip that for now.
5437                continue
5438
5439            r_expr, rhs = r
5440            vr = self.var_to_range[symbol]
5441            lower, upper = vr.lower, vr.upper
5442
5443            rhs_vr = bound_sympy(rhs, self.var_to_range)
5444
5445            # Let's suppose that we have a preexisting range for x [0, 100].
5446            # Now, we issue a guard x > y, where the range for y is [50, 150].
5447            # Then, lower = 0, rhs_vr.lower = 50 and therefore refinement can happen,
5448            # refining x to [51, 100], since x must be greater than y, but the lowest
5449            # y could be is 50.
5450            #
5451            # sympy.Eq may update both lower and upper bounds.
5452            # sympy.G{t,e} may update the lower bound, only.
5453            # sympy.L{t,e} may update the upper bound, only.
5454            if lower < rhs_vr.lower and isinstance(r_expr, (sympy.Eq, sympy.Ge, sympy.Gt)):
5455                # Strictly greater relations allow us to refine a bit more, since
                # x > y implies that the lower bound for x is y + 1.
5457                lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt))
5458            if upper > rhs_vr.upper and isinstance(r_expr, (sympy.Eq, sympy.Le, sympy.Lt)):
5459                upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt))
5460
5461            # Do nothing if the new value range is no better than what we already have.
5462            if vr == ValueRanges(lower, upper):
5463                continue
5464
5465            # Updates the range and the guards corresponding to each bound of the symbol.
5466            self._update_var_to_range(symbol, ValueRanges(lower, upper))
5467            # If the range is refined to singleton, set replacement
5468            if self.var_to_range[symbol].is_singleton():
5469                self._set_replacement(symbol, self.var_to_range[symbol].lower, "range_refined_to_singleton")
5470
5471            # Clears the cache, since this update can change the result.
5472            self._maybe_evaluate_static.cache_clear()
5473
5474    @lru_cache(maxsize=None)
5475    @record_shapeenv_event()
5476    def constrain_symbol_range(self, s: sympy.Symbol, compiler_min: int, compiler_max: int):
5477        upd_vr = ValueRanges(compiler_min, compiler_max)
5478        old_vr = self.var_to_range.get(s, ValueRanges.unknown())
5479        self._update_var_to_range(s, upd_vr)
5480        if (new_vr := self.var_to_range[s]) != old_vr:
5481            log.info("constrain_symbol_range %s [%s, %s]", s, new_vr.lower, new_vr.upper)
5482
5483
5484def _is_int(expr):
5485    return isinstance(expr, SymInt) and expr.node.expr.is_number
5486
5487# WARNING: This is legacy, DO NOT USE
5488def _is_dim_dynamic(t, d):
5489    return hasattr(t, "_dynamo_dynamic_indices") and d in t._dynamo_dynamic_indices
5490
5491class PropagateUnbackedSymInts(torch.fx.Interpreter):
5492    def run_node(self, n: torch.fx.Node):
5493        """
5494        Run an FX node, propagating unbacked Symbol bindings to the new fake tensor
5495        """
5496        from torch._guards import detect_fake_mode
5497
5498        result = super().run_node(n)
5499        rebind_unbacked(detect_fake_mode().shape_env, n, result)
5500        return result
5501
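# A rough usage sketch (assumes `gm` is a torch.fx.GraphModule traced under a
# FakeTensorMode whose shape_env produced unbacked symbols, and `args` are the
# corresponding fake inputs):
#
#   PropagateUnbackedSymInts(gm).run(*args)
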
5502
5503def _find_user_code_frame():
5504    frame = inspect.currentframe()
5505    while frame is not None:
5506        if not frame.f_code.co_filename.startswith(
5507            os.path.dirname(inspect.getfile(torch)) + os.path.sep
5508        ):
5509            break
5510        frame = frame.f_back
5511    return frame
5512
5513
5514def _blame_user_code(e, frame):
5515    frame_summary = traceback.FrameSummary(
5516        frame.f_code.co_filename,
5517        frame.f_lineno,
5518        frame.f_code.co_name,
5519    )
5520    msg = e.args[0]
5521    msg += (
5522        '\n\nThe following call raised this error:\n' +
5523        ''.join(traceback.StackSummary.from_list([frame_summary]).format())
5524    )
5525    e.args = (msg,)
5526
5527
5528class _PythonPrinter(sympy.printing.str.StrPrinter):
5529    """
5530    Util printer that replaces sympy symbols with their source-level names
5531    and renders sympy relational operators (e.g., Eq, Ne, Ge, Le) inline
5532    (i.e., as ==, !=, >, <).
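
    For example (a hypothetical mapping), with src_map = {"u0": ["x.item()"]},
    doprint(sympy.Eq(sympy.Symbol("u0"), 2)) yields "x.item() == 2".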
5533    """
5534
5535    def __init__(self, src_map):
5536        super().__init__()
5537        self.src_map = src_map
5538
5539    def _print_Symbol(self, sym):
5540        return self.src_map[sym.name][0]
5541
5542    def _print_Relational(self, expr):
5543        lhs = self.parenthesize(expr.lhs, sympy.printing.precedence.precedence(expr))
5544        rel_op = expr.rel_op
5545        rhs = self.parenthesize(expr.rhs, sympy.printing.precedence.precedence(expr))
5546        return f"{lhs} {rel_op} {rhs}"
5547
5548
5549def _suggest_torch_checks(e, src_map):
5550    # extract the unresolved condition on unbacked symints in the error
5551    cond = e.cond
5552    diff = ", ".join(s.name for s in cond.free_symbols if s.name not in src_map)
5553    if diff:
5554        log.warning("Unable to find user code corresponding to {%s}", diff)
5555        return
5556    printer = _PythonPrinter(src_map)
5557    msg = e.args[0]
5558    msg += "\nTo fix the error, insert one of the following checks before this call:"
    # suggested fixes to resolve `cond` are to tell the compiler to assume
5560    # either `cond` or its negation (the user will need to select which)
5561    suggested_fixes = [
5562        f"torch._check({printer.doprint(cond)})",
5563        f"torch._check({printer.doprint(sympy.Not(cond))})",
5564    ]
5565    for i, fix in enumerate(suggested_fixes):
5566        msg += f"\n  {i+1}. {fix}"
5567    src_mapped = ', '.join(
5568        f"`{s}` with {' or '.join(src_map[s])}"
5569        for s in sorted(s.name for s in cond.free_symbols)
5570    )
5571    msg += f"\n\n(These suggested fixes were derived by replacing {src_mapped} in {cond} and its negation.)"
5572    e.args = (msg,)
5573
5574
5575def _suggest_fixes_for_data_dependent_error_non_strict(e):
5576    """
5577    Given a raised data-dependent error, add the following to the error message:
5578    1. the closest user code location that raised the error;
5579    2. suggested fixes for the error in terms of live variables at that location.
5580    """
5581
5582    # walk the stack up from the data-dependent error until a non-torch frame is found
5583    frame = _find_user_code_frame()
5584    if frame is not None:
5585        # add frame info to error message
5586        _blame_user_code(e, frame)
5587
5588        # map symbol names reachable via frame locals to their source-level names
5589        src_map = defaultdict(list)
5590        for var, val in frame.f_locals.items():
5591            # figure out how to access any symbol inside `val` through `var`
5592            for path, leaf in pytree.tree_leaves_with_path(val):
5593                name = var + pytree.keystr(path)
5594                if isinstance(leaf, torch.SymInt):
5595                    src_map[str(leaf.node.expr)].append(name)
5596                elif isinstance(leaf, torch.Tensor):
5597                    for i, dim in enumerate(leaf.shape):
5598                        if isinstance(dim, torch.SymInt):
5599                            src_map[str(dim.node.expr)].append(f"{name}.shape[{i}]")
5600
        # add suggested torch._check()s based on `src_map` to the error message
5602        # replacing unbacked symints in the unresolved condition in the error
5603        _suggest_torch_checks(e, src_map)
5604