1# mypy: ignore-errors 2 3""" 4``torch.fx.experimental.symbolic_shapes`` provides interfaces for interacting with 5our symbolic shapes reasoning system that is used heavily in torch.compile. Although 6this is not generally considered public API, when writing framework code in PyTorch 7as well as extensions to PyTorch (e.g., in custom operator implementations), you may 8need to make use of these APIs to setup dynamic shapes support appropriately. 9""" 10 11import builtins 12import collections 13import functools 14import inspect 15import itertools 16import logging 17import math 18import operator 19import os 20import re 21import sys 22import threading 23import traceback 24from collections import defaultdict 25from contextlib import contextmanager 26from dataclasses import dataclass, field 27from enum import Enum 28import atexit 29from typing import ( 30 Any, 31 cast, 32 Callable, 33 Dict, 34 Iterable, 35 List, 36 Optional, 37 Sequence, 38 Set, 39 Tuple, 40 Type, 41 Union, 42 TYPE_CHECKING 43) 44from typing_extensions import TypeAlias 45 46import torch 47import torch.fx 48import torch.fx.traceback as fx_traceback 49from torch.fx.experimental import _config as config 50 51from torch.fx.experimental.recording import ( 52 FakeTensorMeta, 53 ShapeEnvEvent, 54 record_shapeenv_event, 55 replay_shape_env_events, 56 shape_env_check_state_equal 57) 58from torch.fx.experimental.sym_node import SymNode, SymTypes 59from torch._logging import trace_structured, structured 60 61# NB: The sym_* functions are used via getattr() and must be imported here. 62from torch import SymBool, SymFloat, SymInt 63from torch._guards import ShapeGuard, Source, TracingContext 64from torch.utils._python_dispatch import is_traceable_wrapper_subclass 65from torch.utils._sympy.functions import ( 66 Application, FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt 67) 68from torch.utils._sympy.solve import try_solve 69from torch.utils._sympy.numbers import int_oo 70from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError 71from torch.utils._sympy.singleton_int import SingletonInt 72from torch.utils._traceback import format_frame, CapturedTraceback 73from torch._utils_internal import signpost_event 74from torch._subclasses.meta_utils import is_sparse_any 75import torch.utils._pytree as pytree 76from torch.utils._sympy.symbol import SymT, make_symbol, symbol_is_type 77 78from torch._logging import LazyString 79 80if TYPE_CHECKING: 81 from torch._dynamo.source import TensorPropertySource 82 83InputList = List 84DimList = List 85 86log = logging.getLogger(__name__) 87 88import sympy 89from sympy.printing.str import StrPrinter 90from sympy.printing.precedence import precedence, PRECEDENCE 91 92class GuardOnDataDependentSymNode(RuntimeError): 93 cond: sympy.Expr 94 95 def __init__(self, cond, *args): 96 super().__init__(*args) 97 self.cond = cond 98 99class PendingUnbackedSymbolNotFound(RuntimeError): 100 pass 101 102aten = torch._ops.ops.aten # type: ignore[has-type] 103 104__all__ = [ 105 "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int", 106 "guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr", 107 "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node", 108 "is_concrete_bool", "is_nested_int", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY", 109 "has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext", 110 "StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true", 111 "guard_size_oblivious", "check_consistent", 112 "compute_unbacked_bindings", "ConvertIntKey", 113 "rebind_unbacked", "resolve_unbacked_bindings", "is_accessor_node", 114] 115 116# FX node metadata keys for symbolic shape FX graph. 117SHAPEENV_EVENT_KEY = "shapeenv_event" 118CURRENT_NODE_KEY = "current_node" 119 120 121def log_lru_cache_stats(wrapped_f): 122 log.debug("lru_cache_stats %s: %s", wrapped_f.__name__, wrapped_f.cumulative_cache_info()) 123 124 125# Wrapper on lru_cache that reports statistics at process end 126def lru_cache(maxsize): 127 def inner(f): 128 wrapped_f = functools.lru_cache(maxsize)(f) 129 old_cache_clear = wrapped_f.cache_clear 130 prev_hits = 0 131 prev_misses = 0 132 133 # TODO: There's a ref-cycle here (wrapped_f -> cumulative_cache_info 134 # -> wrapped_f) but cannot be solved with weakref as wrapped_f is not 135 # weakref'able on some versions of Python 136 137 def cumulative_cache_info(): 138 cur = wrapped_f.cache_info() 139 return functools._CacheInfo( 140 prev_hits + cur.hits, 141 prev_misses + cur.misses, 142 cur.maxsize, 143 cur.currsize, 144 ) 145 146 def new_cache_clear(): 147 nonlocal prev_hits, prev_misses 148 cur = wrapped_f.cache_info() 149 prev_hits += cur.hits 150 prev_misses += cur.misses 151 old_cache_clear() 152 153 wrapped_f.cache_clear = new_cache_clear 154 wrapped_f.cumulative_cache_info = cumulative_cache_info 155 if log.isEnabledFor(logging.DEBUG): 156 atexit.register(log_lru_cache_stats, wrapped_f) 157 return wrapped_f 158 159 return inner 160 161# These are modules that contain generic code for interacting with ShapeEnv 162# which are unlikely to identify a particular interesting guard statement 163@lru_cache(None) 164def uninteresting_files() -> Set[str]: 165 import torch._inductor.sizevars 166 import torch._library.fake_impl 167 import torch._subclasses.meta_utils 168 import torch._subclasses.fake_tensor 169 mods = [ 170 sys.modules[__name__], 171 torch.fx.experimental.recording, 172 torch.fx.experimental.sym_node, 173 torch.fx.interpreter, 174 torch, 175 torch._inductor.sizevars, 176 torch._library.fake_impl, 177 torch._subclasses.meta_utils, 178 torch._subclasses.fake_tensor, 179 ] 180 return {inspect.getfile(m) for m in mods} 181 182# We don't bother with the metaclass as all of the dispatching logic happens 183# entirely from Python 184# 185# Didn't bother with ancestors for now, unlikely to have multiple modes for 186# symints right now 187 188class ConstraintViolationError(RuntimeError): 189 pass 190 191def has_symbolic_sizes_strides(elem) -> bool: 192 return elem._has_symbolic_sizes_strides 193 194Int = Union[torch.SymInt, int] 195 196def create_contiguous(shape: Sequence[Int]) -> List[Int]: 197 strides: List[Int] = [1] 198 for dim in reversed(shape[:-1]): 199 strides.append(dim * strides[-1]) 200 return list(reversed(strides)) 201 202def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int: 203 """ 204 Retrieve the hint for an int (based on the underlying real values as observed 205 at runtime). If no hint is available (e.g., because data dependent shapes), 206 if fallback is not None, use that instead (otherwise raise an error). 207 """ 208 if isinstance(a, torch.SymInt): 209 return a.node.require_hint(fallback) 210 assert type(a) is int, a 211 return a 212 213Scalar = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool] 214 215def has_hint(a: Scalar) -> bool: 216 if isinstance(a, SymTypes): 217 return a.node.has_hint() 218 return True 219 220def is_concrete_int(a: Union[int, SymInt]) -> bool: 221 r""" Utility to check if underlying object 222 in SymInt is concrete value. Also returns 223 true if integer is passed in. 224 225 Args: 226 a (SymInt or int): Object to test if it int 227 """ 228 assert isinstance(a, (SymInt, int)) 229 230 if isinstance(a, int): 231 return True 232 233 if isinstance(a.node.expr, sympy.core.numbers.Integer): 234 return True 235 236 return False 237 238# In obscure Meta only situations, sympy.logic.boolalg doesn't exist at runtime. 239# So make sure only type checker evaluates this alias. 240# Xref: https://www.internalfb.com/diff/D53324783 241SympyBoolean: TypeAlias = "sympy.logic.boolalg.Boolean" 242 243def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool: 244 """ 245 Perform a guard on a symbolic boolean expression in a size oblivious way. 246 This is typically used when a non-oblivious test would result in a guard 247 on a data dependent value of which we don't know the value of at compile time. 248 When a guard is tested this way, we may diverge in behavior from how regular 249 PyTorch semantics would treat it. For more information, see 250 https://github.com/pytorch/pytorch/pull/118579 251 """ 252 if isinstance(expr, torch.SymBool): 253 return expr.node.guard_size_oblivious("", 0) 254 else: 255 assert isinstance(expr, bool), expr 256 return expr 257 258def check_consistent(new, old) -> None: 259 """ 260 Test that two "meta" values (typically either Tensor or SymInt) have 261 the same values, e.g., after retracing. If we don't understand the 262 quantities in question, we'll just skip the consistency check. 263 """ 264 # TODO: do boolean equality test too, see 265 # https://github.com/pytorch/pytorch/issues/124110 266 scalar_types = (torch.SymInt, torch.SymFloat, int, float) 267 268 if isinstance(new, torch.Tensor): 269 assert isinstance(old, torch.Tensor) 270 torch._check(old.dim() == new.dim(), lambda: f"{old.shape} != {new.shape} (old != new)") 271 # Do this manually so that each individual test is irrefutable 272 # (TODO: should be a helper for this, maybe sym_eq? That 273 # gives us a compound expression and I'm not sure it 274 # simplifies right now) 275 for i, j in zip(old.shape, new.shape): 276 torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)") 277 # NB: bool is subclass of int 278 elif isinstance(new, scalar_types) and not isinstance(new, bool): 279 assert isinstance(old, scalar_types) and not isinstance(old, bool), f"{old} != {new}" 280 torch._check(old == new, lambda: f"{old} != {new} (old != new)") 281 282def resolve_unbacked_bindings(shape_env, bindings): 283 if bindings is None: 284 return None 285 return { 286 shape_env.unbacked_renamings.get(k, k): v 287 for k, v in bindings.items() 288 } 289 290def rebind_unbacked(shape_env, n: torch.fx.Node, result): 291 """ 292 Suppose we are retracing a pre-existing FX graph that previously had 293 fake tensor propagation (and therefore unbacked SymInts). When we retrace, 294 we re-propagate fake tensors, which results in new unbacked SymInts. 295 When this happens, we need to tell the shape environment about the equivalence 296 of the old and new unbacked SymInts. Pass us the old torch.fx.Node (which 297 has the old binding information) and the new result (which we can extract the 298 new unbacked SymInts out from). 299 """ 300 from torch._dynamo.tensor_version_op import _tensor_version 301 302 # Inputs never need rebinding 303 if n.op == "placeholder": 304 return 305 306 if bindings := resolve_unbacked_bindings(shape_env, n.meta.get("unbacked_bindings")): 307 for raw_u0, path in bindings.items(): 308 u1 = pytree.key_get(result, path) 309 # tensor_version ops get specialized after AOTAutograd, it's OK, 310 # we don't actually want to do asserts on them. This is all a bit 311 # questionable though 312 if isinstance(u1, int) and n.target is _tensor_version: 313 log.info("rebind_unbacked: discard _tensor_version %s %s -> %s", raw_u0, path, u1) 314 continue 315 raw_u1 = u1.node.expr 316 # Simplify SymBool binding 317 if ( 318 isinstance(raw_u1, sympy.Piecewise) and 319 len(raw_u1.args) == 2 and 320 raw_u1.args[0][0] == 1 and 321 isinstance(eq := raw_u1.args[0][1], sympy.Eq) and 322 isinstance(new_raw_u1 := eq.lhs, sympy.Symbol) and 323 shape_env.var_to_range[new_raw_u1].issubset(ValueRanges(0, 1)) and 324 eq.rhs == 1 and 325 raw_u1.args[1] == (0, True) 326 ): 327 # This is what the pattern match above is testing 328 repacked = _sympy_cast_symbool_to_symint_guardless(sympy.Eq(new_raw_u1, 1)) 329 assert repacked == raw_u1, f"{repacked} != {raw_u1}" 330 # Cancel the to_int(to_bool(x)). This is sound because x in 331 # [0, 1] 332 raw_u1 = new_raw_u1 333 assert isinstance(raw_u1, sympy.Symbol) 334 # The old and new could be the same if you improperly hit the memo 335 # while retracing. Make sure you updated FakeTensorMode.epoch 336 assert raw_u0 != raw_u1, f"{raw_u0} possible memo disaster" 337 # Reuse the OLD symbol name 338 shape_env._rename_unbacked_to(raw_u1, raw_u0) 339 340# NB: You could try to expand this to cover more cases by simply 341# detecting whenever you have an int output, but this is a bit 342# dangerous in case someone adds a function that returns an int but is 343# mutating. So manually whitelist for now. 344def is_accessor_node(node: torch.fx.Node) -> bool: 345 # Dynamo only exercised condition 346 if ( 347 node.op == "call_method" 348 and isinstance(node.args[0].meta.get("example_value"), torch.Tensor) 349 and node.target in ["size", "stride", "storage_offset", "item"] 350 ): 351 return True 352 if node.op == "call_function" and node.target in [ 353 torch.ops.aten.sym_size, 354 torch.ops.aten.sym_size.default, 355 torch.ops.aten.sym_size.int, 356 torch.ops.aten.sym_stride, 357 torch.ops.aten.sym_stride.default, 358 torch.ops.aten.sym_stride.int, 359 torch.ops.aten.sym_storage_offset, 360 torch.ops.aten.sym_storage_offset.default, 361 torch.ops.aten.sym_numel.default, 362 ]: 363 return True 364 return False 365 366def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean: 367 r""" Canonicalize a boolean expression by transforming it into a lt / le 368 inequality and moving all the non-constant terms to the rhs. 369 We canonicalize And / Ors / Not via cnf and then canonicalize their subexpr 370 recursively 371 nb. sympy.Rel.canonical is not good enough https://github.com/sympy/sympy/issues/25924 372 373 Args: 374 expr (sympy.Expr): Expression to canonicalize 375 """ 376 # Canonicalise an inequality by transforming it into a lt / le 377 # inequality and moving all the non-constant terms to the rhs 378 # We canonicalise And / Ors / Not via cnf 379 # nb. Relational.canonical in sympy is broken 380 # https://github.com/sympy/sympy/issues/25924 381 382 if not isinstance(expr, (sympy.Rel, sympy.And, sympy.Or, sympy.Not, sympy.Eq, sympy.Ne)): 383 return expr 384 385 if isinstance(expr, (sympy.And, sympy.Or, sympy.Not)): 386 expr = sympy.logic.boolalg.to_cnf(expr) 387 return _canonicalize_bool_expr_impl(expr) 388 389def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: 390 """ 391 After canonicalization, we are guaranteed to have eliminated Ge/Gt relations 392 (rewriting them to Le/Lt, respectively). 393 """ 394 if isinstance(expr, (sympy.And, sympy.Or)): 395 return type(expr)(*map(canonicalize_bool_expr, expr.args)) 396 397 opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le} 398 if isinstance(expr, tuple(opposite.keys())): 399 rhs = expr.lhs - expr.rhs 400 t = opposite[type(expr)] 401 else: 402 assert isinstance(expr, (sympy.Lt, sympy.Le, sympy.Eq, sympy.Ne)) 403 rhs = expr.rhs - expr.lhs 404 t = type(expr) 405 406 def is_neg(t): 407 return t.is_negative or (isinstance(t, sympy.Mul) and t.args[0].is_negative) 408 409 lhs = 0 410 rhs = _reduce_to_lowest_terms(rhs) 411 if isinstance(rhs, sympy.Add): 412 pos = [] 413 neg = [] 414 for term in rhs.args: 415 if is_neg(term): 416 neg.append(-term) 417 else: 418 pos.append(term) 419 lhs = sympy.Add(*neg) 420 rhs = sympy.Add(*pos) 421 elif is_neg(rhs): 422 # lhs == 0 423 lhs, rhs = -rhs, 0 424 return t(lhs, rhs) 425 426 427def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr: 428 """ 429 Eliminates any integer factor from a given expression. 430 E.g., 6x + 4y reduces to 3x + 2y. 431 432 Useful when an expression is == or != to 0. 433 """ 434 def integer_coefficient(x): 435 if isinstance(x, sympy.Integer): 436 return abs(int(x)) 437 elif isinstance(x, sympy.Mul): 438 return math.prod([abs(int(arg)) for arg in x.args if isinstance(arg, sympy.Integer)]) 439 else: 440 return 1 441 442 if isinstance(expr, sympy.Add): 443 atoms = expr.args 444 factor = functools.reduce(math.gcd, map(integer_coefficient, atoms)) 445 atoms = [x / factor for x in atoms] 446 return sympy.Add(*atoms) 447 else: 448 return expr / integer_coefficient(expr) 449 450 451def is_concrete_bool(a: Union[bool, SymBool]) -> bool: 452 r""" Utility to check if underlying object 453 in SymBool is concrete value. Also returns 454 true if integer is passed in. 455 Args: 456 a (SymBool or bool): Object to test if it bool 457 """ 458 assert isinstance(a, (SymBool, bool)) 459 460 if isinstance(a, bool): 461 return True 462 463 if isinstance(a.node.expr, (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse)): 464 return True 465 466 return False 467 468def is_nested_int(s): 469 return isinstance(s, torch.SymInt) and s.node.is_nested_int() 470 471def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]: 472 if isinstance(val, SymTypes): 473 # This allow applies to the jagged layout NestedTensor case as 474 # nested ints are not symbolic 475 if is_symbolic(val): 476 yield val.node.expr 477 elif isinstance(val, sympy.Basic): 478 yield val 479 elif isinstance(val, (int, float, bool)): 480 pass 481 elif isinstance(val, (tuple, list)): 482 for s in val: 483 yield from _iterate_exprs(s) 484 elif is_sparse_any(val): 485 yield from _iterate_exprs(val.size()) 486 elif isinstance(val, torch.Tensor): 487 yield from _iterate_exprs(val.size()) 488 yield from _iterate_exprs(val.stride()) 489 yield from _iterate_exprs(val.storage_offset()) 490 elif val is None: 491 pass 492 else: 493 raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}") 494 495def free_symbols(val: Union[SymInt, sympy.Expr, torch.Tensor]) -> Set[sympy.Symbol]: 496 if val is None: 497 return set() 498 itr = _iterate_exprs(val) 499 # we need at least 1 to call union, so we hand code the identity 500 try: 501 first_expr = next(itr) 502 except StopIteration: 503 return set() 504 505 return first_expr.free_symbols.union(*(e.free_symbols for e in itr)) 506 507def has_free_symbols(val: Union[SymInt, torch.Tensor]) -> bool: 508 """Faster version of bool(free_symbols(val))""" 509 return not all(e.is_number for e in _iterate_exprs(val)) 510 511# Like free_symbols, but filtered to only report unbacked symbols 512def free_unbacked_symbols(x): 513 # NB: keep synced with is_unbacked_symint 514 return {s for s in free_symbols(x) if symbol_is_type(s, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT))} 515 516# WARNING: Don't use this on Dynamo produced graphs, they don't have meta 517# setup! 518def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]: 519 if ( 520 "val" in node.meta and 521 isinstance(node.meta["val"], torch.SymInt) and 522 isinstance(node.meta["val"].node.expr, sympy.Symbol) and 523 (node.op == "placeholder" or free_unbacked_symbols(node.meta["val"].node.expr)) 524 ): 525 return node.meta["val"].node.expr 526 return None 527 528def find_symbol_binding_fx_nodes(graph): 529 r = {} 530 # NB: Prefer first occurrence of symbol 531 for node in graph.nodes: 532 if is_symbol_binding_fx_node(node) and node.meta["val"].node.expr not in r: 533 r[node.meta["val"].node.expr] = node 534 return r 535 536 537# Analogous to ConvertIntSource 538@dataclass(frozen=True) 539class ConvertIntKey: 540 def __str__(self) -> str: 541 return ".cast_symbool_to_symint_guardless()" 542 543 def get(self, b: bool) -> int: 544 """Get the int value from bool""" 545 return cast_symbool_to_symint_guardless(b) 546 547 548@dataclass(frozen=True) 549class CallMethodKey: 550 name: str 551 552 def __str__(self) -> str: 553 return f".{self.name}()" 554 555 def get(self, o: Any) -> Any: 556 """Call the method on object""" 557 return getattr(o, self.name)() 558 559 560@dataclass(frozen=True) 561class InnerTensorKey: 562 inner_name: str 563 564 def __str__(self) -> str: 565 return f".{self.inner_name}" 566 567 def get(self, o: Any) -> Any: 568 """Get the inner tensor attribute""" 569 return getattr(o, self.inner_name) 570 571 572@dataclass(frozen=True) 573class DivideByKey: 574 divisor: int 575 576 def __str__(self) -> str: 577 return f".__floordiv__({self.divisor})" 578 579 def get(self, o: int) -> int: 580 """Divide object by divisor""" 581 return o // self.divisor 582 583 584def compute_unbacked_bindings(shape_env, example_value, old_example_value=None, peek=False): 585 """ 586 After having run fake tensor propagation and producing example_value 587 result, traverse example_value looking for freshly bound unbacked 588 symbols and record their paths for later. It is an error if 589 we have allocated an unbacked SymInt but it cannot be found in 590 example_value. (NB: this means if you have a multi-output 591 function, you must call this on the tuple of tensor output, you 592 cannot wait!) 593 594 The peek parameter lets you check out what the bindings are without 595 changing the affected list. This is primarily useful for ensuring 596 unbacked_var_to_val is promptly populated when propagate_real_tensors is on. 597 """ 598 if shape_env is None: 599 return 600 fs = shape_env.pending_fresh_unbacked_symbols 601 pending = set(fs) 602 if pending: 603 if not peek: 604 log.info("compute_unbacked_bindings %s", fs) 605 fs.clear() 606 607 def free_unbacked_symbols_with_path( 608 a, path, real=None 609 ) -> Dict[sympy.Symbol, pytree.KeyPath]: 610 r = {} 611 if isinstance(a, (tuple, list)): 612 for i in range(len(a)): 613 r.update( 614 free_unbacked_symbols_with_path( 615 a[i], path + (pytree.SequenceKey(i),), 616 real=real[i] if real is not None else None 617 ) 618 ) 619 elif is_traceable_wrapper_subclass(a): 620 # TODO: Determine if this is correct 621 attrs, _ = a.__tensor_flatten__() 622 for attr in attrs: 623 sub = getattr(a, attr) 624 r.update( 625 free_unbacked_symbols_with_path(sub, path + (InnerTensorKey(attr),)) 626 ) 627 elif isinstance(a, torch.Tensor): 628 r.update( 629 free_unbacked_symbols_with_path( 630 a.size(), path + (CallMethodKey("size"),), 631 real=a.real_tensor.size() if a.real_tensor is not None else None 632 ) 633 ) 634 r.update( 635 free_unbacked_symbols_with_path( 636 a.stride(), path + (CallMethodKey("stride"),), 637 real=a.real_tensor.stride() if a.real_tensor is not None else None 638 ) 639 ) 640 r.update( 641 free_unbacked_symbols_with_path( 642 a.storage_offset(), path + (CallMethodKey("storage_offset"),), 643 real=a.real_tensor.storage_offset() if a.real_tensor is not None else None 644 ) 645 ) 646 647 # NB: Intentionally access _expr, not expr, do not want 648 # simplification! 649 elif ( 650 isinstance(a, (torch.SymInt, torch.SymFloat)) 651 and isinstance(s := a.node._expr, sympy.Symbol) 652 and s in pending 653 ): 654 r[s] = path 655 if real is not None: 656 shape_env.set_unbacked_var_to_val(s, real) 657 pending.remove(s) 658 # When an unbacked SymInt is perfectly divisible by an integer 659 # constant, we replace it with the integer constant to improve 660 # reasoning capabilities. However, in synthetic examples, it is 661 # then possible that the factor never is explicitly allocated. 662 # Fortunately, we can compute it by division. 663 elif ( 664 isinstance(a, torch.SymInt) 665 and isinstance(s := a.node._expr, sympy.Mul) 666 and len(s.args) == 2 667 and isinstance(lhs := s.args[0], sympy.Integer) 668 and isinstance(rhs := s.args[1], sympy.Symbol) 669 and rhs in pending 670 ): 671 # TODO: DivideByKey needs to test divisibility at runtime! 672 r[s] = path + (DivideByKey(int(lhs)),) 673 if real is not None: 674 shape_env.set_unbacked_var_to_val(s, real // int(lhs)) 675 pending.remove(rhs) 676 # The annoyance here arises from the fact that SymBool is 677 # allocated by allocating a SymInt and then testing if it's equal 678 # to one. So you have a complicated binding site logic for this. 679 elif ( 680 isinstance(a, torch.SymBool) 681 and isinstance(s := a.node._expr, sympy.Eq) 682 # This must match create_unbacked_symbool EXACTLY 683 and isinstance(s.lhs, sympy.Symbol) 684 and s.rhs == 1 685 and s.lhs in pending 686 ): 687 r[s.lhs] = path + (ConvertIntKey(),) 688 if real is not None: 689 shape_env.set_unbacked_var_to_val(s, int(real)) 690 pending.remove(s.lhs) 691 692 return r 693 694 symbol_to_path = free_unbacked_symbols_with_path(example_value, ()) 695 if not peek and pending: 696 extra = ( 697 repr((example_value.stride(), example_value.storage_offset())) 698 if isinstance(example_value, torch.Tensor) 699 else "" 700 ) 701 raise PendingUnbackedSymbolNotFound( 702 f"Pending unbacked symbols {pending} not in returned outputs {example_value} {extra}.\n" 703 "Did you accidentally call new_dynamic_size() or item() more times " 704 "than you needed to in your fake implementation?\n" 705 "For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit" 706 ) 707 708 # Why do we have to do some rebinding here? If the original FX node 709 # wasn't a binding site because you had a memo hit, but post 710 # translation you aren't a memo hit anymore, there's now a new binding 711 # site... but we know (because it's the same FX node) that the value 712 # is actually the same, they're just not obviously equal anymore. 713 # 714 # The logic here is written carefully, because unlike the 715 # bind_unbacked case, we are not guaranteed to have a symbol for 716 # old_sym. If we have a symbol, do regular rename unbacked to; but if 717 # we don't, we need to specially eliminate the fresh unbacked symbol 718 # (NB: we are /trusting/ that the memoization is correct, and that we 719 # don't need to generate a new runtime assert. This is load bearing, 720 # as repropagation can happen after we've frozen runtime asserts.) 721 if old_example_value is not None: 722 for keypath in symbol_to_path.values(): 723 old_sym = pytree.key_get(old_example_value, keypath) 724 new_sym = pytree.key_get(example_value, keypath) 725 if ( 726 isinstance(new_sym, SymTypes) and 727 isinstance(new_s := new_sym.node.expr, sympy.Symbol) 728 ): 729 if isinstance(old_sym, SymTypes) and (old_s := old_sym.node.expr) != new_s: 730 if isinstance(old_s, sympy.Symbol): 731 shape_env._rename_unbacked_to(new_s, old_s) 732 else: 733 shape_env._eliminate_unbacked(new_s, old_s) 734 elif not isinstance(old_sym, SymTypes): 735 shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym)) 736 737 return symbol_to_path 738 739def definitely_true(a): 740 """ 741 Returns True only if we can tell that a is True, possibly introducing 742 a guard in the process. If a depends on some unbacked SymInt, we may 743 return False even though there may exist a possible value of the SymInt 744 that would cause the expression to return True. 745 746 When is it appropriate to use definitely_true? First, if you can use 747 a higher level combinator like parallel_or/parallel_and, prefer using 748 those instead, they are definitely safe (modulo short-circuiting). 749 Second, it can be used if the program would behave equivalently if 750 definitely_true always returned False (parallel_or/parallel_and are 751 examples of this pattern, modulo short-circuiting). Finally, it even 752 be OK if the program wouldn't behave equivalently, so long as the 753 change is semantics preserving. It can be semantics preserving if 754 the program errors in more cases than it did previously (but otherwise 755 behaves identically), or if it changes some quantity in a way that 756 doesn't matter (e.g., strides often fall in this bucket.) 757 """ 758 if isinstance(a, SymBool): 759 if a.node.has_hint(): 760 return guard_bool(a) 761 else: 762 return False 763 return bool(a) 764 765def definitely_false(a): 766 """ 767 Returns True only if we can tell that a is False, possibly introducing 768 a guard in the process. If a depends on some unbacked SymInt, we may 769 return False even though there may exist a possible value of the SymInt 770 that would cause the expression a to be False. See definitely_true 771 for more usage guidance. 772 """ 773 if isinstance(a, SymBool): 774 if a.node.has_hint(): 775 return not guard_bool(a) 776 else: 777 return False 778 return not bool(a) 779 780def statically_known_true(x: Union[bool, SymBool]) -> bool: 781 """Returns True if x can be simplified to a constant and is true. 782 783 .. note:: 784 This function doesn't introduce new guards, so the expression may end 785 up evaluating to true at runtime even if this function returns False. 786 787 Args: 788 x (bool, SymBool): The expression to try statically evaluating 789 790 """ 791 if isinstance(x, SymBool): 792 expr = x.node.expr 793 shape_env = x.node.shape_env 794 try: 795 simplified = shape_env._maybe_evaluate_static(expr) 796 if simplified is not None: 797 return bool(simplified) 798 except Exception: 799 log.debug("Could not simplify %s", expr) 800 return False 801 assert isinstance(x, bool) 802 return x 803 804 805def parallel_or(*args): 806 """ 807 Evaluate the logical OR of several arguments, avoiding guarding on 808 unbacked SymInts if another argument is definitely True. 809 """ 810 if any(statically_known_true(a) for a in args): 811 return True 812 if any(definitely_true(a) for a in args): 813 return True 814 return any(args) 815 816def parallel_and(*args): 817 """ 818 Evaluate the logical FALSE of several arguments, avoiding guarding on 819 unbacked SymInts if another argument is definitely False. 820 """ 821 if any(statically_known_true(torch.sym_not(a)) for a in args): 822 return False 823 if any(definitely_false(a) for a in args): 824 return False 825 return all(args) 826 827def sym_eq(x, y): 828 """ 829 Like ==, but when run on list/tuple, it will recursively test equality 830 and use sym_and to join the results together, without guarding. 831 """ 832 if (isinstance(x, tuple) and isinstance(y, tuple)) or (isinstance(x, list) and isinstance(y, list)): 833 if len(x) != len(y): 834 return False 835 return functools.reduce(operator.and_, map(sym_eq, x, y), True) 836 elif isinstance(x, (int, torch.SymInt)) and isinstance(y, (int, torch.SymInt)): 837 return x == y 838 else: 839 raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}") 840 841def guard_scalar(a): 842 if isinstance(a, (SymBool, bool)): 843 return guard_bool(a) 844 elif isinstance(a, (SymInt, int)): 845 return guard_int(a) 846 elif isinstance(a, (SymFloat, float)): 847 return guard_float(a) 848 else: 849 raise AssertionError(f"unrecognized scalar {a}") 850 851 852def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int): 853 shape_env.constrain_symbol_range(s, compiler_min, compiler_max) 854 855 856def _advise_is_size(a): 857 """ 858 Don't use this directly; use torch._check_is_size instead. 859 860 This is a softer version of _constrain_range_for_size (with min=0, 861 max=Inf). Instead of forcibly constraining a variable (and erroring if we 862 failed to constrain it), it will simply advise us that a size is 863 constrained in some way. We will always defer a runtime assert for this 864 constraint if we cannot prove it at compile-time, but we we only 865 *sometimes* learn useful extra information at compile-time with this 866 information. This is in contrast to constrain_range_for_size, where if 867 you don't call that on a fresh unbacked symint, chances are we will choke. 868 869 TODO: Make Dynamo handle this appropriately if this is seen in Dynamo-ed 870 code. Right now this is only really used in code with AOTAutograd trace 871 through, so it is not a big problem that this isn't supported, but in 872 principle all of this code should be Dynamo'able too. 873 874 TODO: I didn't support min/max because I didn't have a use case where this 875 actually helped. In principle we can support it, it just makes the 876 implementation below more complicated. 877 """ 878 879 # This must always succeed, because the sole allowed caller _check_is_size 880 # was responsible for expect_true'ing this 881 # This assert triggers expensive sym compute, do not do it until its cheap. 882 # assert a >= 0 883 884 # NB: it's important not to constrain range for size for *hinted* SymInts, 885 # because it is not only unsound, it will immediately trip our asserts 886 # that hints have to be consistent with static analysis! If you somehow 887 # have an unbounded SymInt that later constrains to 1, this will be 888 # inconsistent with the range 889 if ( 890 isinstance(a, SymInt) 891 and isinstance(a.node, SymNode) 892 and isinstance(a.node.expr, sympy.Symbol) 893 and a.node.shape_env.is_unbacked_symint(a.node.expr) 894 ): 895 _constrain_range_for_size(a) 896 897def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None): 898 """ 899 This function is NOT INTENDED to be used by itself. 900 """ 901 902 if isinstance(a, (SymFloat, SymBool)): 903 raise ValueError("Constraining SymFloat/SymBool is nyi") 904 905 assert isinstance(a, SymInt), "can only constrain range for SymInt" 906 assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" 907 908 a.node.shape_env._constrain_range_for_size(a.node.expr, min, max) 909 910 911# inclusive both ways 912def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): 913 """ 914 Applies a constraint that the passed in SymInt must lie between min-max 915 inclusive-inclusive, WITHOUT introducing a guard on the SymInt (meaning 916 that it can be used on unbacked SymInts). If min/max are None, we assume 917 that the dimension is unbounded in that direction. Repeated application 918 of constrain_range intersects the ranges. This is a fairly low level API 919 that doesn't have a lot of safety guarantees (TODO: provide higher level 920 APIs). 921 922 Currently, we use this API in the following circumstance: when we allocate 923 an unbacked SymInt, denoting an integer quantity which is data dependent, 924 we ordinarily do not know anything about what values it may take. This 925 means that any sort of guard on it will immediately fail. However, in 926 many cases, we know something about the unbacked SymInt: for example, we 927 know that nonzero(x).size(0) must be >= 0. We use constrain_range to 928 narrow the possible range, declaring that negative symbols are impossible. 929 This permits to definitely answer True to queries like 'nnz >= 0', even if 930 we don't know what the actual (hinted) value of 'nnz' is. In fact, we 931 actually use constrain_range to unsoundly discharge common guards: for an 932 unbacked SymInt produced by nonzero, we will also assume that it is not 933 equal to 0/1 (even though these are perfectly possible values at runtime), 934 because we generally expect graphs that are valid for N=2 to also be valid 935 for N=1. 936 """ 937 if min is None: 938 min = -int_oo 939 if max is None: 940 max = int_oo 941 942 if max < min: 943 raise ValueError( 944 "Maximum value to constrain_as_size can't be less than the specified min value, " 945 "received min={min} and max={max}" 946 ) 947 948 if isinstance(a, int): 949 if not (min <= a <= max): 950 raise ValueError(f"Invalid value {a} for range [{min}:{max}]") 951 return 952 953 a.node.shape_env._constrain_range(a.node.expr, min, max) 954 955def constrain_unify(a: torch.SymInt, b: torch.SymInt) -> None: 956 """ 957 Given two SymInts, constrain them so that they must be equal. NB: 958 this will not work with SymInts that represent nontrivial expressions 959 (yet!) 960 """ 961 if not isinstance(a, SymInt): 962 if not isinstance(b, SymInt): 963 assert a == b 964 return 965 else: 966 shape_env = b.node.shape_env 967 else: 968 shape_env = a.node.shape_env 969 970 shape_env._constrain_unify(a, b) 971 972# Assume that a boolean is true for the purposes of subsequent symbolic 973# reasoning. This will keep track of corresponding runtime checks to verify 974# that the result is upheld: either as a regular guard, or as a special set 975# of asserts which are triggered when an unbacked SymInt is allocated. 976# 977# DO NOT use this function for these cases: 978# 979# - This is inappropriate for "branching" conditions (where both 980# true and false result in valid programs). We will always assume 981# the condition evaluates true, and so it will never be possible 982# to trace the false condition when you use it. For true branching 983# on unbacked SymInts, you must use torch.cond; if you incorrectly 984# use expect_true in this case, you will make the false branch 985# unreachable (as we will simply assume that only the true branch 986# is ever exercised). 987# 988# - This is inappropriate for situations where you know some other system 989# invariant guarantees that this property holds, since you don't 990# really need to insert a runtime check in that case. Use something 991# like constrain_range in that case. 992# 993# This API has a hitch. To avoid having to reimplement error reporting 994# capabilities, this function CAN return False. The invariant is that 995# the surrounding code must raise an error when this function returns 996# False. This is quite low level, so we recommend using other functions 997# like check() which enforce this in a more intuitive way. 998# 999# By the way, this name is a nod to the __builtin_expect macro, 1000# which is used similarly (but unlike __builtin_expect, you MUST fail 1001# in the unlikely branch.) (I think expect is a good name; in recent 1002# versions of C++, this is replaced with [[likely]], which is weaker 1003# and not accurate for this function!) 1004def expect_true(a, skip: int = 0): 1005 if isinstance(a, SymBool): 1006 # TODO: check perf implications of this 1007 frame = inspect.currentframe() 1008 for _ in range(skip + 1): # always run this loop at least once 1009 frame = frame.f_back 1010 return a.node.expect_true(frame.f_code.co_filename, frame.f_lineno) 1011 assert type(a) is bool, a 1012 return a 1013 1014def guard_bool(a): 1015 if isinstance(a, SymBool): 1016 return a.node.guard_bool("", 0) # NB: uses Python backtrace 1017 assert type(a) is bool, a 1018 return a 1019 1020def guard_int(a): 1021 if isinstance(a, SymInt): 1022 return a.node.guard_int("", 0) # NB: uses Python backtrace 1023 assert type(a) is int, a 1024 return a 1025 1026def guard_float(a): 1027 if isinstance(a, SymFloat): 1028 return a.node.guard_float("", 0) # NB: uses Python backtrace 1029 assert isinstance(a, float), a 1030 return a 1031 1032# Given a GraphModule, return all the FakeTensors for all the placeholders 1033def fx_placeholder_vals(gm): 1034 return [n.meta['val'] for n in gm.graph.nodes if n.op == "placeholder"] 1035 1036def fx_placeholder_targets(gm): 1037 return [n.target for n in gm.graph.nodes if n.op == "placeholder"] 1038 1039# Given a GraphModule and arguments to run it with, evaluate that the guards 1040# for its associated ShapeEnv are satisfied by the passed arguments. This 1041# WILL check for duck sizing. 1042def eval_guards(gm, *args, ignore_static=True): 1043 return gm.shape_env.evaluate_guards_for_args(fx_placeholder_vals(gm), args, ignore_static=ignore_static) 1044 1045def bind_symbols(gm, *args): 1046 return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) 1047 1048class DimDynamic(Enum): 1049 """ 1050 Controls how to perform symbol allocation for a dimension. It is always 1051 sound to default this to DYNAMIC, but the policies DUCK and STATIC can 1052 result in better trace-time and compile-time performance, as they reduce 1053 the number of allocated symbols and generally make your graph more static. 1054 1055 NB: If we notice you've applied a constraint to the dimension, we will 1056 force it to DYNAMIC for simplicity. 1057 1058 DimDynamic is controlled by a variety of higher level UX features. 1059 Currently: 1060 1061 - In eager mode, the default policy is DUCK. 1062 - The default is changed to STATIC with assume_static_by_default. 1063 - An individual dim is marked DYNAMIC if you mark_dynamic_dim. 1064 - In export mode, the default policy is STATIC. 1065 - An individual dim is marked DYNAMIC if you specify it in 1066 dynamic_shapes passed to export. 1067 """ 1068 # Treat the dimension symbolically 1069 DYNAMIC = 0 1070 # Treat the dimension symbolically, but if its hint matches another 1071 # dynamic dimension, unify the two symbols ("duck sizing") 1072 DUCK = 1 1073 # Treat the dimension statically based on its hint 1074 STATIC = 2 1075 # Treat the dimension as a size-like unbacked 1076 SIZE_LIKE_UNBACKED = 3 1077 # Infer the strides from stride. If size is static, strides will be static as well. 1078 INFER_STRIDE = 4 1079 1080 1081# NB: These constraints affect both clients and backends: given some 1082# constraint C, the client must pass inputs that satisfy the constraint, 1083# while a backend must not introduce guards BEYOND this constraint. 1084# For clarity, we document the implications on both sides for both the client 1085# and the backend. 1086# 1087# NB: These constraints are on a *single* dimension. In principle, we could 1088# also have multi-dimension constraints, but our guess is that this is not 1089# actually useful and so we are not supporting it right now. 1090# 1091# NB: Strict constraints are typically only suitable for export, as in eager 1092# a backend like inductor may validly introduce extra, discretionary guards 1093# to improve performance of code. A StrictMinMaxConstraint would be brittle 1094# under future optimizations performed by inductor; we don't guarantee 1095# eager code with StrictMinMaxConstraint will keep working in the future! 1096 1097@dataclass(frozen=True) 1098class Constraint: 1099 warn_only: bool 1100 1101@dataclass(frozen=True) 1102class StrictMinMaxConstraint(Constraint): 1103 """ 1104 For clients: the size at this dimension must be within 'vr' (which 1105 specifies a lower and upper bound, inclusive-inclusive) AND it 1106 must be non-negative and should not be 0 or 1 (but see NB below). 1107 1108 For backends: there must not be any guards on this dimension which 1109 are not implied by the given lower and upper bound. Regardless of 1110 the lower bound, the backend can assume the size is non-negative 1111 and that it is not 0 or 1. 1112 1113 An unbounded StrictMinMaxConstraint can be thought of as a strict version 1114 of "RelaxedUnspecConstraint". 1115 1116 NB: Export will often unsoundly assume that a graph works for 0/1, even 1117 though at trace time we assumed size is not 0 or 1. The idea is that 1118 if we produce a graph that works for a range of values, it will be OK 1119 for N=0/1 too. 1120 """ 1121 vr: ValueRanges 1122 1123 def render(self, source: Source): 1124 """Format the constrain equation""" 1125 # TODO: better printing for -oo and oo 1126 return f"{self.vr.lower} <= {source.name()} <= {self.vr.upper}" 1127 1128@dataclass(frozen=True) 1129class RelaxedUnspecConstraint(Constraint): 1130 """ 1131 For clients: no explicit constraint; constraint is whatever is implicitly 1132 inferred by guards from tracing. 1133 1134 For backends: there must exist at least TWO possible values for the 1135 size at this dimension which satisfy the guards for this dimension. 1136 1137 In other words, this constraint helps us distinguish between "we don't 1138 care if this dimension specializes or not" versus "this dimension must be 1139 unspecialized." However, this constraint doesn't say very much about what 1140 specialization is permitted; for example, if we guard on a size being 1141 even, this would still be acceptable under an unspec constraint. This 1142 makes RelaxedUnspecConstraint useful for eager mode, where your backend compiler 1143 may add constraints to otherwise dynamic dimensions; we can't assert that 1144 there are NO guards as this is brittle because compilers should be able to 1145 add extra constraints. If you want to assert that there are no guards, 1146 use StrictMinMaxConstraint with an unbounded ValueRanges. 1147 """ 1148 def render(self, source: Source): 1149 return f"RelaxedUnspecConstraint({source.name()})" 1150 1151# NB: None here indicates the client constraint is whatever is implicitly 1152# inferred by guards from tracing, and that a backend can add whatever guards 1153# it wants (including fully specializing the value). 1154DimConstraint = Union[StrictMinMaxConstraint, RelaxedUnspecConstraint, None] 1155 1156@dataclass(frozen=True) 1157class EqualityConstraint(Constraint): 1158 """ 1159 Represent and decide various kinds of equality constraints between input sources. 1160 1161 A "source pair" is a pair of input sources for dynamic dimensions that 1162 are specified equal. We represent `source_pairs` in a union-find forest 1163 so that we can efficiently check whether two such sources are transitively equal. 1164 1165 A "derived equality" relates an input source to an expression over a root. 1166 The root can be another input source, corresponding to some dynamic dimension, 1167 or a phantom symbol that does not directly represent any dynamic dimension. We 1168 represent `derived_equalities` involving input sources in a transitively-closed map 1169 so that we can efficiently check whether an input source is transitively equal to 1170 a given expression over another input source. 1171 (NOTE: In contrast, it is easy to decide whether an input source is transitively equal 1172 to a given expression over a phantom symbol; such expressions are already in canonical 1173 form and so the problem reduces to symbolic expression equality.) 1174 """ 1175 source_pairs: List[Tuple[Source, Source]] 1176 derived_equalities: List[Tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]]] 1177 phantom_symbols: List[sympy.Symbol] 1178 1179 def __post_init__(self): 1180 """Pre-processing to answer queries `is_equal` and `is_derived` below. 1181 1182 Example: Suppose we are given: 1183 source_pairs [a = b, b = c] 1184 derived_equalities [d = c + 1, e = d - 1] 1185 We first construct a union find with source_pairs: 1186 _parents = {a: a, b: a, c: a} 1187 Then we compute canonical symbolic expressions, recursively applying derived_equalities 1188 until we bottom out: 1189 _defs = {d: c + 1, e: (c + 1) - 1 aka c} 1190 """ 1191 1192 # self._parents is a map from input sources to input sources where, conceptually, 1193 # these are directed edges in a union-find forest 1194 _parents: Dict[Source, Source] = {} 1195 object.__setattr__(self, "_parents", _parents) 1196 # self._defs is a map from input sources to "canonical" symbolic expressions, 1197 # i.e., unary expressions with symbols that corresponds to regular Dims (i.e., 1198 # not derived Dims) 1199 _defs: Dict[Source, sympy.Expr] = {} 1200 object.__setattr__(self, "_defs", _defs) 1201 1202 for source1, source2 in self.source_pairs: 1203 # preprocess into a union-find forest 1204 self._union(self._find(source1), self._find(source2)) 1205 for source, root, fn in self.derived_equalities: 1206 # preprocess into a transitively-closed map 1207 # NOTE(avik): we reuse the union-find forest for canonicalizing input sources 1208 if isinstance(root, sympy.Symbol): 1209 self._defs[self._find(source)] = fn(root) 1210 else: 1211 self._defs[self._find(source)] = fn(self._rewrite(root)) 1212 1213 def _find(self, source): 1214 # chase edges to find the root of this equivalence class 1215 if source in self._parents: 1216 return self._find(self._parents[source]) 1217 else: 1218 return source 1219 1220 def _union(self, root1, root2): 1221 # merge two equivalence classes by adding an edge from one root to the other 1222 if root1 != root2: 1223 self._parents[root1] = root2 1224 1225 def _rewrite(self, src): 1226 # always represent the given source by the root of its equivalence class 1227 src = self._find(src) 1228 if src in self._defs: 1229 # simply look up the definition if it exists 1230 # NOTE(avik): This works because definitions are always transitively-closed; 1231 # otherwise we would have to do recursive rewriting. 1232 return self._defs[src] 1233 else: 1234 # otherwise, create a symbol representing the source 1235 return sympy.Symbol(src.name()) 1236 1237 def is_equal(self, source1, source2): 1238 return ( 1239 # check whether source1 and source2 have the same root 1240 self._find(source1) == self._find(source2) or 1241 # check whether source1 is derived equal to source2 1242 self.is_derived(source1, source2, lambda x: x) 1243 ) 1244 1245 def is_derived(self, src, symbol_src, fn): 1246 # check whether both src and symbol_src have the same definition 1247 return self._rewrite(src) == fn(self._rewrite(symbol_src)) 1248 1249 1250def _assert_symbol_context(symbolic_context): 1251 assert isinstance(symbolic_context, SymbolicContext), "Invalid symbolic_context object" 1252 assert type(symbolic_context) is not SymbolicContext, "Illegal usage of symbolic_context ABC" 1253 1254def _is_supported_equivalence(expr): 1255 # Currently supported Dim ops are linear expressions with integer coefficients. 1256 # So check that expr only contains +, *, ints, and a single occurrence of a symbol. 1257 # (See also documentation of dynamic_shapes._DerivedDim.) 1258 if isinstance(expr, (sympy.Add, sympy.Mul)): 1259 if len(expr.args) > 2: 1260 return False 1261 lhs, rhs = expr.args 1262 return ( 1263 (_is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or 1264 (isinstance(lhs, sympy.Integer) and _is_supported_equivalence(rhs)) 1265 ) 1266 return isinstance(expr, sympy.Symbol) 1267 1268def _has_uninterpretable_sympy_function(expr) -> bool: 1269 """ 1270 Add functions that our sympy interpreter can't reify into FX nodes 1271 """ 1272 return expr.has( 1273 torch.utils._sympy.functions.ToFloat, 1274 torch.utils._sympy.functions.TruncToInt, 1275 torch.utils._sympy.functions.CeilToInt, 1276 ) 1277 1278@dataclass(frozen=True) 1279class SymbolicContext: 1280 """ 1281 Data structure specifying how we should create symbols in 1282 ``create_symbolic_sizes_strides_storage_offset``; e.g., should 1283 they be static or dynamic. 1284 1285 This is an abstract base class because we are probably going to add 1286 another version of this that says "use exactly these SymInts, don't 1287 allocate fresh symbols." 1288 """ 1289 1290 1291@dataclass(frozen=True) 1292class StatelessSymbolicContext(SymbolicContext): 1293 """ 1294 Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via 1295 a symbolic_context determination as given by ``DimDynamic`` and ``DimConstraint``. 1296 This will cause fresh symbols to be allocated 1297 """ 1298 dynamic_sizes: DimList[DimDynamic] 1299 dynamic_strides: DimList[DimDynamic] = None 1300 constraint_sizes: DimList[DimConstraint] = None 1301 constraint_strides: DimList[DimConstraint] = None 1302 # If the tensor is a view, this should be populated for the base. It contains 1303 # information on how to allocate symbols when recursively fakeifying the base 1304 # during view fake-ification. 1305 view_base_context: Optional[SymbolicContext] = None 1306 # TODO: add storage offset and stride symbolic_context 1307 1308 def __post_init__(self): 1309 if self.dynamic_strides is None: 1310 object.__setattr__(self, 'dynamic_strides', [DimDynamic.INFER_STRIDE] * len(self.dynamic_sizes)) 1311 if self.constraint_sizes is None: 1312 object.__setattr__(self, 'constraint_sizes', [None] * len(self.dynamic_sizes)) 1313 if self.constraint_strides is None: 1314 object.__setattr__(self, 'constraint_strides', [None] * len(self.dynamic_sizes)) 1315 assert all(stride in (DimDynamic.INFER_STRIDE, DimDynamic.DYNAMIC, DimDynamic.DUCK) for stride in self.dynamic_strides) 1316 1317 1318# note [Tensor Fakification and Symbol Caching] 1319# 1320# As of the time of this note, dynamo creates a fresh fake tensor mode for backends. 1321# The reason we do this is because there are certain classes of operations, namely, 1322# metadata mutations, that change tensor size, stride, etc. This means that the fake tensor 1323# state at the end of a dynamo trace is different than the fake tensor state at the beginning 1324# of a trace. Backends like aot_autograd need a fresh fake tensor to correctly track metadata mutation, 1325# view relationships, etc. 1326# 1327# As we create a new fake mode, we also lose the memoization that comes with it. Rather than 1328# transfer the memoization cache, we instead transfer the shape env. However, with this 1329# comes nuance - as dynamo is selective in how it makes symbolic shapes. Due to strategies in 1330# automatic dynamic and constraints, the policy for which dims are dynamic is nuanced and varies across 1331# recompilations. 1332# 1333# In order to preserve the symbolic decisions made during dynamo tensor fakification, we pass 1334# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on the TracingContext. 1335# The lifecycle of this object should match the lifecycle of the original dynamo tracked tensor, and it is 1336# safe to reuse this object as many times as necessary to create a fake tensor. Fake tensors 1337# created with new fake modes should produce the same exact symbols as the original, providing the same shape_env 1338# is used. 1339# TODO(voz): Shape env validation 1340@dataclass(frozen=True) 1341class StatefulSymbolicContext(StatelessSymbolicContext): 1342 """ 1343 Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via 1344 a symbolic_context determination as given by a cache of Source:Symbol. A cache hit 1345 will reuse a stored symbol, and a cache miss will write to this cache. 1346 1347 This behaves like StatelessSymbolicContext, except the cache supersedes the 1348 other values - dynamic_sizes and constraint_sizes will not be read if we cache 1349 hit. 1350 1351 It is the cache owners responsibility to maintain the lifecycle of the cache 1352 w/r/t different shape_envs, clearing, etc. 1353 """ 1354 tensor_source: Source = None 1355 # Why is this keyd on int first? 1356 # That integer is actually the id of the shape_env. This cache short-circuits symbol 1357 # creation, and we must store it per shape env. Now, while tracing invariants are a single 1358 # shape env per tracing context, and every new frame gets a new shape_env. So where would we have 1359 # multiple shape envs? The answer lies in recording. When we are replaying, replay_shape_env_events 1360 # is invoked, and creates a new shape_env. Replaying events against this new shape_env will 1361 # cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never 1362 # get recorded in var_to_val, etc. 1363 # TODO(voz): consider a weakref to the shape_env here 1364 shape_env_to_source_to_symbol_cache : Dict[int, Dict["TensorPropertySource", "sympy.Expr"]] = None 1365 1366 def __post_init__(self): 1367 super().__post_init__() 1368 # The None default is annoying, but required because of dataclass limitations 1369 assert self.tensor_source is not None 1370 if not self.shape_env_to_source_to_symbol_cache: 1371 object.__setattr__(self, 'shape_env_to_source_to_symbol_cache', {}) 1372 1373 1374@dataclass(frozen=True) 1375class SubclassSymbolicContext(StatefulSymbolicContext): 1376 """ 1377 The correct symbolic context for a given inner tensor of a traceable tensor subclass 1378 may differ from that of the outer symbolic context. This structure allows for this 1379 flexibility, with inner symbolic contexts mapped via attr -> symbolic context. 1380 """ 1381 inner_contexts: Dict[str, SymbolicContext] = None 1382 1383 def __post_init__(self): 1384 super().__post_init__() 1385 if self.inner_contexts is None: 1386 self.inner_contexts = {} 1387 1388 1389def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool: 1390 if isinstance(val, (int, float, bool)): 1391 return False 1392 return val.node.is_symbolic() 1393 1394IndicatorTypes = (IsNonOverlappingAndDenseIndicator,) 1395 1396@lru_cache(256) 1397def safe_expand(r): 1398 if hasattr(r, 'expand'): 1399 try: 1400 return sympy.expand(r) 1401 except RecursionError: 1402 log.warning("RecursionError in sympy.expand(%s)", r) 1403 return r 1404 else: 1405 return r 1406 1407def error(): 1408 raise AssertionError("shouldn't be hit") 1409 1410 1411# TODO: Deduplicate this with torch/_prims_common/__init__.py 1412def eval_is_non_overlapping_and_dense(sizes, strides): 1413 return int(guard_bool(_eval_is_non_overlapping_and_dense(sizes, strides))) 1414 1415def _eval_is_non_overlapping_and_dense(sizes, strides): 1416 dim = len(sizes) 1417 1418 # Short-circuits for tensors of rank one, which are 1419 # non-overlapping and "dense" if their stride is one 1420 # or it is a 0/1 element tensor 1421 if dim == 1: 1422 return strides[0] == 1 or sizes[0] < 2 1423 1424 # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous 1425 # Sorts (length, stride) pairs by stride 1426 lengths_and_strides = sorted( 1427 zip(sizes, strides), key=operator.itemgetter(1) 1428 ) 1429 1430 # Unlike the C++ code, we don't move the 0/1 size dimensions to the 1431 # end. So we have to keep going for this code. 1432 expected_stride = 1 1433 for length, stride in lengths_and_strides: 1434 1435 if length == 1: 1436 continue 1437 1438 if stride != expected_stride: 1439 return False 1440 1441 expected_stride *= length 1442 1443 return True 1444 1445 1446def _sympy_cast_symbool_to_symint_guardless(x: sympy.Expr) -> sympy.Expr: 1447 return sympy.Piecewise((1, x), (0, True)) 1448 1449 1450def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: 1451 if isinstance(symbool, bool): 1452 return 1 if symbool else 0 1453 int_sym = _sympy_cast_symbool_to_symint_guardless(symbool.node.expr) 1454 return symbool.node.shape_env.create_symintnode(int_sym, hint=int(symbool.node.require_hint()) if has_hint(symbool) else None) 1455 1456SYMPY_INTERP = { 1457 'Abs': operator.abs, 1458 'Eq': operator.eq, 1459 'Ne': operator.ne, 1460 'Gt': operator.gt, 1461 'Lt': operator.lt, 1462 'Le': operator.le, 1463 'Ge': operator.ge, 1464 'Min': min, 1465 'Max': max, 1466 'Mod': operator.mod, 1467 'PythonMod': operator.mod, 1468 'FloorDiv': operator.floordiv, 1469 'TrueDiv': operator.truediv, 1470 'PowByNatural': operator.pow, 1471 'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense, 1472 'floor': math.floor, 1473 'ceiling': math.ceil, 1474 'FloorToInt': math.floor, 1475 'FloatPow': math.pow, 1476 'CeilToInt': math.ceil, 1477 'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless, 1478 'RoundToInt': builtins.round, 1479 'RoundDecimal': builtins.round, 1480 'TruncToInt': math.trunc, 1481 'IntTrueDiv': operator.truediv, 1482 'FloatTrueDiv': operator.truediv, 1483 'ToFloat': builtins.float, 1484} 1485 1486 1487def _lru_cache(fn, maxsize=None): 1488 """ 1489 Wrapper around lru_cache that clears when new info about shapes has been 1490 updated. 1491 1492 Use lru_cache if the output is always the same, regardless of the 1493 constraints we know now (i.e. evaluate_expr) 1494 1495 Use _lru_cache otherwise. 1496 1497 Also note that this depends on _update_version_counter being called on the 1498 shape environment whenever the constraints are updated, otherwise the cache 1499 will not be cleared. 1500 """ 1501 fn_cache = lru_cache(maxsize)(fn) 1502 prior_version = 0 1503 1504 if config.validate_shape_env_version_key: 1505 prior_key = None 1506 1507 @functools.wraps(fn) 1508 def wrapper(self, *args, **kwargs): 1509 nonlocal prior_version, prior_key 1510 if prior_key is None: 1511 prior_key = self._get_key() 1512 1513 if prior_version != self._version_counter: 1514 fn_cache.cache_clear() 1515 prior_version = self._version_counter 1516 prior_key = self._get_key() 1517 else: 1518 assert prior_key == self._get_key(), \ 1519 "ShapeEnv cache key changed without version being updated!" 1520 1521 return fn_cache(self, *args, **kwargs) 1522 1523 else: 1524 1525 @functools.wraps(fn) 1526 def wrapper(self, *args, **kwargs): 1527 nonlocal prior_version 1528 if prior_version != self._version_counter: 1529 fn_cache.cache_clear() 1530 prior_version = self._version_counter 1531 1532 return fn_cache(self, *args, **kwargs) 1533 1534 wrapper.cache_clear = fn_cache.cache_clear 1535 wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined] 1536 return wrapper 1537 1538 1539# This is pretty similar to ShapeGuard but it also comes with a message, 1540# and is exclusively used for things that MUST be true (unlike guards, 1541# which can evaluate False, in which case you just choose not to use 1542# a particular specialization) 1543@dataclass(frozen=True) 1544class RuntimeAssert: 1545 expr: sympy.Expr 1546 msg: str = field(repr=False) 1547 stack: str = field(repr=False) 1548 1549 1550# Used for printing SymExprs in compile_fx 1551class SymExprPrinter(StrPrinter): 1552 def _print_Float(self, expr): 1553 return str(float(expr)) 1554 1555 1556class ShapeGuardPrinter(SymExprPrinter): 1557 def __init__( 1558 self, 1559 symbol_to_source, 1560 source_ref, 1561 var_to_sources, 1562 ): 1563 super().__init__() 1564 self.symbol_to_source = symbol_to_source 1565 self.source_ref = source_ref 1566 self.var_to_sources = var_to_sources 1567 1568 def _print_Not(self, expr): 1569 return 'not {}'.format(self.parenthesize(expr.args[0], PRECEDENCE["Not"])) 1570 1571 def _print_And(self, expr): 1572 return self.stringify(expr.args, " and ", PRECEDENCE["And"]) 1573 1574 def _print_Or(self, expr): 1575 return self.stringify(expr.args, " or ", PRECEDENCE["Or"]) 1576 1577 def _print_Symbol(self, expr) -> str: 1578 assert isinstance(expr, sympy.Symbol), str(type(expr)) 1579 1580 def repr_symbol_to_source(): 1581 return repr({ 1582 symbol: [s.name() for s in sources] 1583 for symbol, sources in self.symbol_to_source.items() 1584 }) 1585 1586 assert self.symbol_to_source.get(expr), ( 1587 f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) " 1588 f"not in {repr_symbol_to_source()}. If this assert is failing, it could be " 1589 "due to the issue described in https://github.com/pytorch/pytorch/pull/90665" 1590 ) 1591 return self.source_ref(self.symbol_to_source[expr][0]) 1592 1593 1594class LoggingShapeGuardPrinter(ShapeGuardPrinter): 1595 def __init__(self, var_to_sources): 1596 super().__init__(var_to_sources, lambda n: n.name(), var_to_sources) 1597 1598 1599class DynamicDimConstraintPrinter(StrPrinter): 1600 """ 1601 Printer for dynamic dim constraints. 1602 - Instead of symbol s_k it prints its source t.size()[i] 1603 - Instead of Eq(_, _), Mod(_, _), etc. it prints _ == _, _ % _, etc. 1604 1605 We use this to suggest code for specifying dynamic dim constraints. 1606 """ 1607 def __init__(self, symbol_to_source, source_name_to_debug_name): 1608 super().__init__() 1609 self.symbol_to_source = symbol_to_source 1610 self.source_name_to_debug_name = source_name_to_debug_name 1611 1612 def _print_Symbol(self, expr) -> str: 1613 assert isinstance(expr, sympy.Symbol), str(type(expr)) 1614 assert self.symbol_to_source.get(expr), ( 1615 f"Unknown symbol {expr} created by constraints solver" 1616 ) 1617 return self.symbol_to_source[expr][0].name() 1618 1619 def _print_Relational(self, expr): 1620 return f'{self.parenthesize(expr.lhs, precedence(expr))} {expr.rel_op} {self.parenthesize(expr.rhs, precedence(expr))}' 1621 1622 1623class DimConstraints: 1624 """ 1625 Custom solver for a system of constraints on symbolic dimensions. 1626 Solutions are "static" values or simplified "dynamic" constraints. 1627 """ 1628 1629 def __init__( 1630 self, 1631 symbol_to_source, 1632 var_to_val, 1633 marked_dynamic, 1634 source_name_to_debug_name, 1635 ): 1636 # We try to solve systems of inequalities with 1 free variable. 1637 self._univariate_inequalities: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set) 1638 # Among them, we prioritize solving for a free variable that has equalities. 1639 # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys() 1640 # and removing a symbol from the former => removing it from the latter. 1641 self._symbols_with_equalities: Set[sympy.Symbol] = set() 1642 # A solution of a free variable with equalities becomes a substitution. 1643 # We use these substitutions to simplify other constraints. 1644 # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions. 1645 self._substitutions: Dict[sympy.Symbol, sympy.Integer] = {} 1646 1647 # In general, constraints may have // and % operations. 1648 # Of course, // can be expressed in terms of / and %. 1649 # Our inequality solver can handle / but not %. So we need to transform them away. 1650 # We do so by using the values of variables as hints to evaluate %. 1651 # For soundness we record additional congruence guards and solve them separately. 1652 self._var_to_val: Dict[sympy.Symbol, sympy.Integer] = var_to_val 1653 self._congruences: Set[sympy.Expr] = defaultdict(set) 1654 1655 # We do not try to (directly) solve inequalities with > 1 free variables. 1656 # NOTE: free variables in these inequalities cannot also be in _substitutions. 1657 self._multivariate_inequalities: Set[sympy.Expr] = set() 1658 1659 # We park external equalities between free variables here. 1660 self._symbolic_equivalences: List[Tuple[Source, sympy.Expr]] = [] 1661 1662 # Solutions come in two forms: 1663 # - (static) specializations 1664 # - (dynamic) inequalities / congruences 1665 self._static_results: Set[str] = set() 1666 self._dynamic_results: Set[str] = set() 1667 1668 # printer for solutions 1669 self._dcp = DynamicDimConstraintPrinter(symbol_to_source, source_name_to_debug_name) 1670 1671 # inconsistencies found on substituting with concrete values / static solutions 1672 self._inconsistencies: List[str] = [] 1673 1674 # symbols that are marked dynamic 1675 self._marked_dynamic = marked_dynamic 1676 1677 # track supported sympy functions and subtract from list of all sympy functions 1678 self._supported_sympy_functions: Set[sympy.Function] = { 1679 Application, 1680 Mod, 1681 PythonMod, 1682 FloorDiv, 1683 } 1684 self._enumerate_sympy_functions() 1685 1686 def rewrite_with_congruences(self, s, expr): 1687 """ 1688 Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k. 1689 This leaves rational operators (in particular of the form b / d) that our inequality solver can handle. 1690 We solve the added congruences separately (using our congruence solver, see below). 1691 """ 1692 def mod_handler(*args): 1693 # Suppose that we have an expression of the form b % d with free variable s. 1694 # Using the value of s as a "hint," we can evaluate b % d to a value k. 1695 # Then we can rewrite b % d to k while adding the guard b % d == k. 1696 1697 # NOTE(avik): This abstraction is provably sound but, in general, incomplete. It is complete IFF 1698 # the original expression always evaluates to a constant value (i.e., it does not vary with s). 1699 # In other words, 1700 # - solutions of s with the rewritten expression are guaranteed to also be solutions of s with 1701 # the original expression; 1702 # - while it may be possible to find solutions of s with the original expression that are not 1703 # solutions with the rewritten expression, in that case the original expression cannot evaluate 1704 # to the same value for all solutions of s. 1705 # 1706 # Should we be worried about this incompleteness? No, because of the following reasons: 1707 # 1. It unblocks dramatic simplification that would not be otherwise possible with current tech 1708 # (i.e., "don't let perfect be the enemy of the good"). 1709 # 2. We already have a tradition of using hints to add guards in the compiler for making progress. 1710 # 3. We have not yet seen a counterexample arise in practice! In particular, any congruence guards 1711 # we generate (or simplify to) seem to be of the form b % d == k where k is a constant. 1712 # 1713 # Here's a theoretical counterexample: 3*s % (s + 1) == s - 2, that is satisfied by all s >= 2. 1714 # With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we 1715 # would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution! 1716 base, divisor = args 1717 base, divisor = self.rewrite_with_congruences(s, base), self.rewrite_with_congruences(s, divisor) 1718 mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(self._var_to_val) 1719 congruence = (base - mod_reduced) % divisor 1720 if congruence != 0: 1721 self._congruences[s].add(congruence) 1722 return mod_reduced 1723 1724 def floor_div_handler(*args): 1725 # Suppose that we have an expression of the form b // d with free variable s. 1726 # Using the value of s, we can evaluate b % d to a value k. 1727 # Then we can rewrite b // d to (b - k) / d, while adding the guard b % d == k. 1728 1729 # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d 1730 # and eliminating b % d as above. 1731 base, divisor = args 1732 base, divisor = self.rewrite_with_congruences(s, base), self.rewrite_with_congruences(s, divisor) 1733 mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(self._var_to_val) 1734 congruence = (base - mod_reduced) % divisor 1735 if congruence != 0: 1736 self._congruences[s].add(congruence) 1737 # NB: Must not be CleanDiv, it needs to be regular sympy division 1738 # so inequality solver works. This is sort of problematic for 1739 # is_integer tests though haha 1740 return (base - mod_reduced) / divisor 1741 1742 if expr.has(Mod): 1743 expr = expr.replace(Mod, mod_handler) 1744 # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative 1745 # arguments should be OK. 1746 if expr.has(PythonMod): 1747 expr = expr.replace(PythonMod, mod_handler) 1748 if expr.has(FloorDiv): 1749 expr = expr.replace(FloorDiv, floor_div_handler) 1750 return expr 1751 1752 def _enumerate_sympy_functions(self): 1753 module = torch.utils._sympy.functions 1754 all_functions = set() 1755 for attr in dir(module): 1756 if isinstance(func := getattr(module, attr), sympy.FunctionClass): 1757 all_functions.add(func) 1758 self._unsupported_sympy_functions = all_functions.difference(self._supported_sympy_functions) 1759 1760 def _has_unsupported_sympy_function(self, expr) -> bool: 1761 """ 1762 Tracks list of sympy.Functions the export solver doesn't know how to handle. 1763 """ 1764 return expr.has(*self._unsupported_sympy_functions) 1765 1766 def add(self, expr) -> bool: 1767 """Add an expression to the set of constraints. 1768 1769 Return whether the expression is a trivial constraint (i.e., an obvious tautology). 1770 """ 1771 if expr == sympy.true: 1772 return True 1773 orig_expr = expr 1774 orig_reduced = orig_expr.xreplace(self._var_to_val) 1775 # TODO(avik): https://github.com/pytorch/pytorch/issues/101093 1776 # It is possible that `expr` will fail the consistency check because of 1777 # precision errors. Specifically, on substituting its free symbols with 1778 # their concrete values, we might end up comparing floats. Until we have 1779 # a fix for this issue, we delay raising such failures. See solve(). 1780 if orig_reduced == sympy.false: 1781 self._inconsistencies.append(f"{orig_expr} is inconsistent!") 1782 if isinstance(expr, sympy.Ne) or self._has_unsupported_sympy_function(expr): 1783 # we're not going to do anything useful with these, so drop them 1784 return False 1785 free_symbols = expr.free_symbols 1786 assert free_symbols, f"Did not expect constraint with no free variables: {expr}" 1787 if len(free_symbols) > 1: 1788 # multivariate: record and move on 1789 self._multivariate_inequalities.add(expr) 1790 else: 1791 # univariate: can solve these immediately 1792 s = next(iter(free_symbols)) 1793 # eliminate // and % (see documentation of `rewrite_with_congruences` above) 1794 old_n_congruences = len(self._congruences[s]) 1795 expr = self.rewrite_with_congruences(s, expr) 1796 new_n_congruences = len(self._congruences[s]) 1797 if expr == sympy.true: 1798 return old_n_congruences == new_n_congruences 1799 reduced = expr.xreplace(self._var_to_val) 1800 if reduced == sympy.false: 1801 self._inconsistencies.append( 1802 f"{expr}, obtained by rewriting {orig_expr} with congruences, " 1803 "is inconsistent!" 1804 ) 1805 if isinstance(expr, sympy.Eq): 1806 # special status for symbols that have equalities (see `solve` below) 1807 self._symbols_with_equalities.add(s) 1808 self._univariate_inequalities[s].add(expr) 1809 return False 1810 1811 def add_equality(self, source, expr): 1812 """Add an equality constraint""" 1813 if expr.is_number: 1814 # specialization, right here 1815 self._static_results.add(f"{source.name()} == {expr}") 1816 else: 1817 # these will resolve to either specializations or dynamic equality constraints 1818 self._symbolic_equivalences.append((source, expr)) 1819 1820 def _reduce_congruences(self): 1821 reduced_congruences = {} 1822 for s, congruences in self._congruences.items(): 1823 remainder_modulus_pairs = [] 1824 congruences_to_check = set() 1825 for congruence in congruences: 1826 base, divisor = congruence.args 1827 # We are given a congruence of the form base % divisor == 0 with a free variable s. So: 1828 # - we transform this into an equation of the form base = divisor * tmp; 1829 # - we solve this equation for s to get a linear solution with free variable tmp. 1830 tmp = sympy.Symbol("reduce_congruences_tmp", integer=True) 1831 symbol, solution = sympy.solve_linear(base - divisor * tmp, symbols=[s]) 1832 # See https://docs.sympy.org/latest/modules/solvers/solvers.html#sympy.solvers.solvers.solve_linear 1833 # for how to interpret the results. 1834 if s == symbol: 1835 # This means the solution is of the form s = modulus*tmp + remainder. 1836 modulus, remainder = sympy.polys.polytools.div(solution, tmp) 1837 if isinstance(modulus, sympy.Integer) and isinstance(remainder, sympy.Integer): 1838 # Make sure 0 <= remainder <= modulus. 1839 remainder = remainder % modulus 1840 remainder_modulus_pairs.append((remainder, modulus)) 1841 continue 1842 # This means that we did not get a unique solution to the equation. 1843 # No problem, we will check it. 1844 congruences_to_check.add(congruence) 1845 # Finally we solve for a congruence s such that s = r_i mod m_i for each (r_i, m_i). 1846 # The solution will be a congruence of the form s = r mod m. 1847 # NOTE(avik): Since the given m_i may not be pairwise coprime, we can't just use CRT. 1848 if remainder_modulus_pairs: 1849 remainder, modulus = sympy.ntheory.modular.solve_congruence(*remainder_modulus_pairs) 1850 reduced_congruences[s] = {(s - remainder) % modulus} 1851 substitution = {s: modulus * sympy.Symbol("tmp", integer=True) + remainder} 1852 reduced_congruences[s].update( 1853 congruence for congruence in congruences_to_check 1854 if not sympy.checksol(congruence, substitution) 1855 ) 1856 else: 1857 reduced_congruences[s] = congruences_to_check 1858 1859 return reduced_congruences 1860 1861 def _raise_inconsistencies(self): 1862 if self._inconsistencies: 1863 msg = "\n".join(self._inconsistencies) 1864 self._inconsistencies.clear() 1865 raise ValueError(f"The following inconsistencies were found:\n{msg}") 1866 1867 def solve(self): 1868 """Solve the system of constraint equations to find simplified constraints 1869 """ 1870 self._raise_inconsistencies() 1871 # as long as there are symbols with equalities, solve for them 1872 # NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols) 1873 while self._symbols_with_equalities: 1874 s = self._symbols_with_equalities.pop() 1875 exprs = self._univariate_inequalities.pop(s) 1876 solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s) 1877 if isinstance(solution, sympy.And): 1878 solution = next((arg for arg in solution.args if isinstance(arg, sympy.Eq)), solution) 1879 assert isinstance(solution, sympy.Eq), f"Expected an equality constraint for {s}, got {solution}" 1880 symbol, val = solution.args 1881 assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}" 1882 # because this is univariate, the solution is a specialization 1883 self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}") 1884 # add this as a substitution to simplify other constraints 1885 self._substitutions[s] = val 1886 1887 # simplify multivariate inequalities: some of them will now become univariate! 1888 multivariate_inequalities = self._multivariate_inequalities 1889 self._multivariate_inequalities = set() 1890 for expr in multivariate_inequalities: 1891 self.add(expr.xreplace({s: self._substitutions[s]})) 1892 self._raise_inconsistencies() 1893 1894 # solve linear congruences 1895 # NOTE(avik): We do not need to solve them for symbols that have already been specialized. 1896 reduced_congruences = self._reduce_congruences() 1897 for s, congruences in reduced_congruences.items(): 1898 for congruence in congruences: 1899 # any congruence that cannot be checked becomes a dynamic constraint as well 1900 if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}): 1901 if self._is_supported_congruence(congruence): 1902 base, divisor = congruence.args 1903 tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}" 1904 tmp = sympy.Symbol(tmp_name, integer=True) 1905 from torch._dynamo.source import ConstantSource 1906 self._dcp.symbol_to_source[tmp] = [ConstantSource(tmp_name)] 1907 r = try_solve(sympy.Eq(base, divisor * tmp), s) 1908 self._dynamic_results.add(self._dcp.doprint(sympy.Eq(s, r[1]))) 1909 1910 # remaining symbols have only pure inequalities (no equalities) 1911 for s, exprs in self._univariate_inequalities.items(): 1912 try: 1913 solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s) 1914 # because this is univariate, the solution is a dynamic (range) constraint 1915 if isinstance(solution, sympy.Or): 1916 solution = next(iter(arg for arg in solution.args if arg.xreplace(self._var_to_val))) 1917 if isinstance(solution, sympy.And): 1918 for arg in solution.args: 1919 self._dynamic_results.add(self._dcp.doprint(arg)) 1920 else: 1921 self._dynamic_results.add(self._dcp.doprint(solution)) 1922 except (NotImplementedError, AssertionError) as e: 1923 log.warning("Failed to reduce inequalities: %s", e) 1924 for expr in exprs: 1925 self._dynamic_results.add(self._dcp.doprint(expr)) 1926 1927 # simplify symbolic equivalences: some of them will now become specializations! 1928 symbolic_equivalences = self._symbolic_equivalences 1929 self._symbolic_equivalences = [] 1930 for source, expr in symbolic_equivalences: 1931 self.add_equality(source, expr.xreplace(self._substitutions)) 1932 1933 # remaining symbolic equivalences become dynamic equality constraints 1934 for source, expr in self._symbolic_equivalences: 1935 self._dynamic_results.add(f"{source.name()} == {self._dcp.doprint(expr)}") 1936 1937 @classmethod 1938 def _is_supported_congruence(cls, congruence): 1939 base, divisor = congruence.args 1940 # Congruences that can be currently expressed with supported Dim ops are 1941 # of the form (x + a) % b == 0, where x is a Dim and a and b are constants. 1942 # This allows us to derive x as b*y - a for some Dim y. 1943 # (See also documentation of dynamic_shapes._DerivedDim.) 1944 if isinstance(base, sympy.Add): 1945 lhs, rhs = base.args 1946 cond = ( 1947 (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Integer)) or 1948 (isinstance(lhs, sympy.Integer) and isinstance(rhs, sympy.Symbol)) 1949 ) 1950 else: 1951 cond = isinstance(base, sympy.Symbol) 1952 cond = cond and isinstance(divisor, sympy.Integer) 1953 return cond 1954 1955 def forced_specializations(self): 1956 """Returns a dictionary of the names of symbols to their specialized value 1957 """ 1958 def debug_name(src): 1959 name = src.name() 1960 if self._dcp.source_name_to_debug_name: 1961 return f"{self._dcp.source_name_to_debug_name[name]} = {name}" 1962 else: 1963 return name 1964 1965 return { 1966 debug_name(self._dcp.symbol_to_source[s][0]): val 1967 for s, val in self._substitutions.items() 1968 if s in self._marked_dynamic 1969 } 1970 1971 def _is_derived_dim(self, dim): 1972 return isinstance(dim, torch.export.dynamic_shapes._DerivedDim) 1973 1974 def _is_dim(self, dim): 1975 return ( 1976 isinstance(dim, torch.export.dynamic_shapes._Dim) 1977 and not isinstance(dim, torch.export.dynamic_shapes._DerivedDim) 1978 ) 1979 1980 def _process_derived_dim_roots( 1981 self, 1982 results: Dict[str, Dict[str, Any]], 1983 name_to_dim: Dict[str, Any], 1984 ) -> None: 1985 ''' 1986 Here we resolve 2 concerns with derived dims suggested fixes: 1) newly introduced roots, 1987 and 2) root swapping. 1988 1989 1) Newly introduced roots appear with modulo guards, e.g. Mod(dx, 2) = 0 suggests 1990 dx is a derived dim equal to 2 * _dx, introducing a new root _dx. Currently the final 1991 suggested fixes handle this correctly, but we can get intermediate results that look like 1992 {"dy": {"eq": "dx + 1"}, "dx": {"eq": "2 * _dx + 1, "min": 3, "max": 15}} 1993 and this routine prettifies this by unifying to a single root, and making each suggestion 1994 either a derived dim or min/max range, not both. 1995 1996 2) With suggested fixes for derived dims, roots can be swapped, 1997 e.g. dx, dx - 1 -> dy + 1, dy. Here we don't want to print out the attached name, 1998 since this leads to messages like "dx - 1 = Dim("dx - 1", ...)". 1999 Instead we evaluate the new root value, and remove results for its derivations. 2000 2001 First we find all the original roots (specified in dynamic_shapes), that are found in the 2002 values of results (i.e. used for computing suggesting fix values). These original roots 2003 (suppose `dx`) are either specialized, unchanged, refined, or swapped 2004 (expressed as a derived dim). If any of the first 3 cases happen, we suggest `dx`'s value 2005 in results, and remove suggestions for derivations of `dx`, assuming the derived relation 2006 is valid. If swapped, we find the new root, and use the fix to evaluate `dx`'s new value, 2007 and then do the same with `dx`'s derivations. 2008 2009 Assuming the originally specified derived relations are correct is valid, because: 2010 1) if the relations are plain wrong (e.g. input shape = (6, 4) with spec (dx, dx - 1)) 2011 produce_guards() will catch this and crash before hand. 2012 2) if the relations are numerically correct but do not match the emitted guard, 2013 for example: 2014 2015 def forward(self, x, y): 2016 return x.reshape([-1]) + y # guard: s0 * 2 = s1 2017 inputs = (torch.randn(6, 2), torch.randn(12)) 2018 dx = Dim("dx", min=2, max=32) 2019 dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )} # this matches values but not op 2020 2021 then this leads to 2 linear equations, and a) produce_guards() is able to solve for 2022 the unique solution of dx = 6 and specialize, and b) the export constraint solver will 2023 raise an issue due to range constraints (a unique solution means not all values in a 2024 range satisfy a guard) and also force specializations. 2025 ''' 2026 from torch.export.dynamic_shapes import Dim 2027 2028 def _check_same_range(c, dim): 2029 # returns True if c & dim are both min/max ranges with same values 2030 return ( 2031 self._is_dim(dim) 2032 and ("min" in c or "max" in c) 2033 and ( 2034 (dim.min < 2 and c.get("min", 2) == 2) 2035 or dim.min == c.get("min", 2) 2036 ) # let pass if analysis min = 2 and specified min = 0/1 2037 and dim.max == c.get("max", int_oo) 2038 ) 2039 2040 # 1) newly introduced roots 2041 # this part we handle adding newly introduced roots 2042 # these arise from guards like "x.shape[0] % 3 == 0" 2043 # leading to suggested fixes like "dx = 3*_dx" 2044 # extract _dx, and find appropriate min/max values 2045 # 2046 # before, we have something like: 2047 # {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2} 2048 # we want instead: 2049 # {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3} 2050 introduced_roots: Dict[str, str] = {} # map new root -> old root 2051 for k, c in list(results.items()): 2052 if "eq" in c and isinstance(c["eq"], sympy.Expr): # derived dim 2053 root = next(iter(c["eq"].free_symbols)) 2054 if str(root) not in name_to_dim: 2055 introduced_roots[str(root)] = k 2056 # calculate necessary min & max 2057 modulus, remainder = sympy.polys.polytools.div(c["eq"], root) 2058 c_min = c.get("min", 2) 2059 min_ = math.ceil((c_min - remainder) / modulus) 2060 c_max = c.get("max", int_oo) 2061 max_ = math.floor((c_max - remainder) / modulus) 2062 # create result & dim 2063 results[str(root)] = {"min": min_, "max": max_} 2064 name_to_dim[str(root)] = Dim(str(root), min=min_, max=max_) 2065 # remove old root min/max bounds 2066 c.pop("min", None) 2067 c.pop("max", None) 2068 2069 # alter derivations that depend on old root, to unify to new root 2070 # e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2 2071 for old_root in introduced_roots.values(): 2072 for k, c in list(results.items()): 2073 if ( 2074 "eq" in c 2075 and isinstance(c["eq"], sympy.Expr) 2076 and str(symbol := next(iter(c["eq"].free_symbols))) == old_root 2077 ): # derived dim with root = old_root 2078 new_root_expr = results[str(old_root)]["eq"] # dx=3*_dx+1 2079 new_expr = c["eq"].subs({symbol: new_root_expr}) # dy=(3*_dx+1)+1 2080 c["eq"] = new_expr 2081 2082 # 2) root swapping 2083 # collect all the original roots that are used for calculating values of suggested fixes 2084 # this consists of: 2085 # 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim 2086 # 2) {"dy": "dx + 1"} -> dx: root for suggested fix 2087 modified_roots: Set[str] = set() 2088 for k, c in results.items(): 2089 if k not in name_to_dim: # _dynamo.export() may handle source directly 2090 continue 2091 if self._is_dim(name_to_dim[k]) and ("min" in c or "max" in c): # case 1) 2092 modified_roots.add(k) 2093 elif "eq" in c and isinstance(c["eq"], sympy.Expr): # case 2) 2094 root = next(iter(c["eq"].free_symbols)) 2095 assert root is not None 2096 modified_roots.add(str(root)) 2097 2098 # exclude newly introduced roots, we've already processed these 2099 modified_roots = modified_roots.difference(introduced_roots) 2100 2101 # evaluate the new value for each root 2102 # this is now either 1) unchanged, 2) refined with a new range, 2103 # or 3) specialized to a concrete value 2104 modified_root_values: Dict[str, Dict[str, Any]] = {} 2105 for root in modified_roots: 2106 swapped_root = True 2107 if root in results: 2108 c = results[root] 2109 if ( 2110 ("min" in c or "max" in c) # range 2111 or isinstance(c["eq"], int) # specialized 2112 ): 2113 # here, the original root is a root Dim or concrete value in results. 2114 # if it is a derived dim, it is swapped, and we handle that below. 2115 if not _check_same_range(c, name_to_dim[root]): # ignore if unchanged 2116 modified_root_values[root] = c 2117 swapped_root = False 2118 2119 if swapped_root: 2120 # if the original root has been swapped in results, that means the new root 2121 # is a range (if it had specialized, the original root would have too). 2122 # find this new root, and solve for the original root's range. 2123 for k, c in results.items(): 2124 if k not in name_to_dim: 2125 continue 2126 dim = name_to_dim[k] 2127 if dim.__class__.__name__ == "_DerivedDim" and dim.root.__name__ == root: 2128 # only look for min/max root, otherwise root would have specialized 2129 if "min" in c or "max" in c: 2130 expr = sympy.sympify(k) 2131 s = next(iter(expr.free_symbols)) 2132 result = { 2133 "min": try_solve(sympy.Eq(expr, c["min"]), s)[1], # type: ignore[arg-type] 2134 "max": try_solve(sympy.Eq(expr, c["max"]), s)[1], # type: ignore[arg-type] 2135 } 2136 if not _check_same_range(result, name_to_dim[root]): # ignore if unchanged 2137 modified_root_values[root] = result 2138 break 2139 2140 # filter out results where the key is a derived dim (e.g. {"dx - 1" : 4}) 2141 # we only want to suggest fixes for the root, to avoid derived names. 2142 # also, remove anything in modified_roots, since we either add new modified values after this, 2143 # or have decided they are unchanged. 2144 for k in list(results.keys()): 2145 if k not in name_to_dim: 2146 continue 2147 if self._is_derived_dim(name_to_dim[k]) or k in modified_roots: 2148 del results[k] 2149 2150 # update results with modified root values 2151 # now results has the following properties: 2152 # - only contains original roots as keys 2153 # - each root is now either specialized, refined, or derived from another original root 2154 results.update(modified_root_values) 2155 2156 def prettify_results( 2157 self, 2158 original_signature: inspect.Signature, 2159 dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, 2160 constraint_violation_error=None, 2161 forced_specializations=None, 2162 ): 2163 """Format a message for constraint violation erros""" 2164 from torch.export.dynamic_shapes import _get_dim_name_mapping 2165 if not self._dcp.source_name_to_debug_name: 2166 # nothing to do 2167 return "" 2168 2169 def transform(s, inverse=False): 2170 for k, v in self._dcp.source_name_to_debug_name.items(): 2171 s = s.replace(k, v) if not inverse else s.replace(v, k) 2172 return s 2173 2174 results = defaultdict(dict) 2175 if dynamic_shapes is None: 2176 dynamic_shapes = {} 2177 2178 def flip(op): 2179 if op == "<=": 2180 return ">=" 2181 if op == ">=": 2182 return "<=" 2183 if op == "<": 2184 return ">" 2185 if op == ">": 2186 return "<" 2187 assert op == "==" 2188 return op 2189 2190 def relation_with_digit(expr, op, digit): 2191 if op == "<=": 2192 results[expr]["max"] = digit 2193 elif op == "<": 2194 results[expr]["max"] = digit - 1 2195 elif op == ">=": 2196 results[expr]["min"] = digit 2197 elif op == ">": 2198 results[expr]["min"] = digit + 1 2199 else: 2200 assert op == "==" 2201 results[expr]["eq"] = digit 2202 2203 # retrieve dynamic shapes 2204 name_to_dim = _get_dim_name_mapping(dynamic_shapes) 2205 2206 for s in self._static_results.union(self._dynamic_results): 2207 t = transform(s) 2208 if t == s: 2209 continue 2210 left, op, right = re.split(r"( == | <= | >= | < | > )", t) 2211 op = op.strip() 2212 if op == "==" and left == right: 2213 continue 2214 if right.isdigit(): 2215 relation_with_digit(left, op, int(right)) 2216 elif left.isdigit(): 2217 relation_with_digit(right, flip(op), int(left)) 2218 else: 2219 assert op == "==", t 2220 results[left]["eq"] = sympy.sympify(right) 2221 2222 # order forced specializations based on name 2223 forced_specializations = { 2224 k: forced_specializations[k] 2225 for k in sorted( 2226 forced_specializations.keys(), 2227 key=lambda x: x.split(" = ")[1], 2228 ) 2229 } 2230 2231 buf = "" 2232 if forced_specializations: 2233 debug_names = set() 2234 for k in forced_specializations: 2235 dim = name_to_dim[k.split(" = ")[0]] 2236 if self._is_derived_dim(dim): 2237 debug_names.add(dim.root.__name__) 2238 else: 2239 debug_names.add(dim.__name__) 2240 2241 buf += ( 2242 f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! " 2243 'For more information, run with TORCH_LOGS="+dynamic".\n' 2244 ) 2245 for s, val in forced_specializations.items(): 2246 buf += f" - solving the guards generated for {s} resulted in a specialized value of {val}.\n" 2247 2248 self._process_derived_dim_roots(results, name_to_dim) 2249 2250 dims = [] 2251 others = [] 2252 2253 # order results by source name 2254 results = { 2255 k: results[k] for k in sorted( 2256 results.keys(), 2257 key=lambda x: transform(x, inverse=True), 2258 ) 2259 } 2260 for k, c in results.items(): 2261 if "eq" in c: 2262 other = c["eq"] 2263 if isinstance(other, int): 2264 others.append(f"{k} = {other}") 2265 elif _is_supported_equivalence(other): 2266 others.append(f"{k} = {other}") 2267 else: 2268 min_ = c.get("min", None) 2269 if min_ == 2: 2270 min_ = None 2271 max_ = c.get("max", None) 2272 if min_ is not None and max_ is not None: 2273 dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})") 2274 elif min_ is not None: 2275 dims.append(f"{k} = Dim('{k}', min={min_})") 2276 elif max_ is not None: 2277 dims.append(f"{k} = Dim('{k}', max={max_})") 2278 else: 2279 dims.append(f"{k} = Dim('{k}')") 2280 2281 # results will get filtered out if no new suggestions, 2282 # this can happen if guards are too complex. 2283 # in that case don't suggest fix 2284 if dims or others: 2285 buf += "\nSuggested fixes:\n " 2286 buf += "\n ".join(dims + others) 2287 2288 return buf 2289 2290 2291TLS = threading.local() 2292 2293 2294@dataclass(frozen=True) 2295class ShapeEnvSettings: 2296 """ 2297 Encapsulates all shape env settings that could potentially affect 2298 FakeTensor dispatch. Used when creating dispatch cache keys. 2299 """ 2300 2301 allow_scalar_outputs: bool 2302 allow_dynamic_output_shape_ops: bool 2303 assume_static_by_default: bool 2304 specialize_zero_one: bool 2305 duck_shape: bool 2306 prefer_deferred_runtime_asserts_over_guards: bool 2307 allow_complex_guards_as_runtime_asserts: bool 2308 2309 2310class ShapeEnv: 2311 # This is a wrapper over the actual __init__ function. 2312 # 2313 # Where to add a new constructor parameter to ShapeEnv? 2314 # ===================================================== 2315 # This __init__ function should be used only for parameters related to event recording. 2316 # These are parameters that we don't wish to pass down the road to new ShapeEnv instances 2317 # created from replaying events. 2318 # 2319 # If you wish to add a parameter to the constructor of ShapeEnv, unrelated to event 2320 # recording, do so in the _init function. 2321 def __init__( 2322 self, *, 2323 should_record_events: Optional[bool] = None, 2324 tracked_fakes: Optional[List[Any]] = None, 2325 **kwargs 2326 ) -> None: 2327 self._init(**kwargs) 2328 2329 # Disable event recording when replaying. 2330 kwargs["should_record_events"] = False 2331 2332 from torch.fx.experimental.validator import translation_validation_enabled 2333 self._translation_validation_enabled = translation_validation_enabled() 2334 2335 # If not specified, enable event recording if both: 2336 # - Translation validation is on 2337 # - Translation validation bisection is not disabled 2338 self.should_record_events = ( 2339 should_record_events 2340 if should_record_events is not None 2341 else ( 2342 self._translation_validation_enabled 2343 and not config.translation_validation_no_bisect 2344 ) 2345 ) 2346 2347 # Enable event recording check if both: 2348 # - It should record events 2349 # - The recording check is enabled 2350 self.check_recorded_events = ( 2351 self.should_record_events and config.check_shape_env_recorded_events 2352 ) 2353 2354 # This will make sure we only record the top-level function call. 2355 self.is_recording = not self.should_record_events 2356 # Keep track of the list of tracked fakes. 2357 self.tracked_fakes = tracked_fakes 2358 # List of events for reconstructing ShapeEnv at arbitrary points in time. 2359 self.events: List[ShapeEnvEvent] = ( 2360 [ShapeEnvEvent(ShapeEnv, kwargs=kwargs)] if self.should_record_events else [] 2361 ) 2362 2363 # FakeTensor per-ShapeEnv operation cache. This is used for caching 2364 # operations that contain symbolic shapes which have guards on the 2365 # ShapeEnv (so are ShapeEnv-dependent). 2366 # 2367 # NOTE: It's important that SymNodes in this cache have their ShapeEnv 2368 # stripped otherwise you end up with cycles which can only be cleaned 2369 # with the GC. 2370 self.fake_tensor_cache: Dict[torch._subclasses.fake_tensor._DispatchCacheKey, 2371 torch._subclasses.fake_tensor._DispatchCacheEntry] = {} 2372 2373 # Pro-tip: if you add new field to ShapeEnv, this affects some accept 2374 # tests. Accept their output with: 2375 # 2376 # EXPECTTEST_ACCEPT=1 python test/dynamo/test_dynamic_shapes.py -k test_shape_env_equal 2377 # 2378 def _init( 2379 self, *, 2380 allow_scalar_outputs=True, 2381 allow_dynamic_output_shape_ops=True, 2382 # NB: These are legacy configuration that help us make good choices 2383 # when the constraint/dynamic dims are not explicitly passed to us. 2384 # Ideally we will fix all call sites to be explicit and not have 2385 # implicit choices, but this apparently was pretty involved. 2386 assume_static_by_default=False, 2387 # Note - On 0/1 specialization 2388 # 2389 # The following options affect decisions we make about eager 2390 # specialization. Disabling them will increase trace time (as we do 2391 # more symbolic reasoning) and can also harm the quality of generated 2392 # code (because inductor may not be able to specialize for bounds 2393 # being equal--although if we later respecialize because of a guard, 2394 # your code may be just as good as it was before.) 2395 # 2396 # When True, eagerly specialize input sizes which have 0/1. 2397 specialize_zero_one=True, 2398 # When True, assume input sizes which have the same size are 2399 # symbolically equal. 2400 duck_shape: Optional[bool] = None, 2401 # For debugging 2402 co_fields=None, 2403 # When True, whenever safe, we will generate a deferred runtime assert 2404 # instead of a guard whenever we know that an expression must be True, 2405 # otherwise it would be an error, even for backed SymInts (where we 2406 # could ostensibly unconditionally generate guards). This is useful 2407 # for export, where preventing "error checking" sizes from showing up 2408 # in guards is helpful, since these guards in some sense are overly 2409 # pedantic. See also https://github.com/pytorch/pytorch/issues/121749 2410 prefer_deferred_runtime_asserts_over_guards=False, 2411 # When True, does not emit or raise constraint violation errors on 2412 # implicit guards generated by ops, and defers to runtime assertions 2413 # in the graph instead. For export. 2414 allow_complex_guards_as_runtime_asserts=False, 2415 # XXX Add any new settings that could affect FakeTensor evaluation 2416 # to: torch._subclasses.fake_tensor._ShapeEnvSettings 2417 ): 2418 if duck_shape is None: 2419 duck_shape = config.use_duck_shape 2420 2421 self.settings = ShapeEnvSettings( 2422 # Not directly used by ShapeEnv; indirectly used by FakeTensor 2423 allow_scalar_outputs=allow_scalar_outputs, 2424 allow_dynamic_output_shape_ops=allow_dynamic_output_shape_ops, 2425 # End 2426 assume_static_by_default=assume_static_by_default, 2427 specialize_zero_one=specialize_zero_one, 2428 duck_shape=duck_shape, 2429 prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, 2430 allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, 2431 ) 2432 2433 self.guards: List[ShapeGuard] = [] 2434 # Maps symbolic ints to their original concrete values 2435 # Currently populated from tensors 2436 self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {} 2437 # Like var_to_val, but only set when propagate_real_tensors is on. 2438 # Used as last resort to avoid GuardOnDataDependent error 2439 self.unbacked_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {} 2440 # Maps symbolic ints to their min/max range. These ranges 2441 # are conservative: the int MUST fall in the range, but the 2442 # range may contain ints which may not actually appear in 2443 # practice 2444 self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {} 2445 self.source_name_to_debug_name: Dict[str, str] = {} 2446 self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {} 2447 self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {} 2448 # Maps from sympy ints to expressions representing them 2449 # Populated from equality guards (i.e. a.shape[0] == b.shape[0]) 2450 self.replacements: Dict[sympy.Symbol, sympy.Expr] = {} 2451 self.unbacked_renamings: Dict[sympy.Symbol, sympy.Symbol] = {} 2452 # Set holds a % b expressions that evaluate to 0. 2453 self.divisible: Set[sympy.Expr] = set() 2454 # Set that holds "size-like" symbols. When we perform 2455 # "size-oblivious" tests, these can be assumed to be >= 2. 2456 self.size_like: Set[sympy.Symbol] = set() 2457 # Duck-shaping says that if two input tensors have the same size, 2458 # they get assigned the same symbolic variable 2459 self.val_to_var: Dict[int, sympy.Expr] = {} 2460 if specialize_zero_one: 2461 self.val_to_var = {0: sympy.Integer(0), 1: sympy.Integer(1)} 2462 self.unbacked_symfloat_counter = itertools.count() 2463 self.unbacked_symint_counter = itertools.count() 2464 # Similar to guards, but these MUST evaluate to true and can 2465 # only be evaluated at runtime midway through (i.e., they always 2466 # involve unbacked symints) 2467 # 2468 # For efficiency reasons, we index in the following way. Suppose you have 2469 # a runtime assert i0 + i1 <= s1. We pick the most recently allocated 2470 # symbol in the source expression and add the assert to the list for 2471 # that symbol e.g., {i1: [i0 + i1 <= s1]}. 2472 # 2473 # We access the runtime asserts in two situations: 2474 # 2475 # - When we are guarding on an expression, we will attempt to 2476 # statically evaluate it, in case the unbacked SymInts can 2477 # simplify away. If we have a runtime assert, we may be able 2478 # to discharge the guard entirely. We only need to attempt 2479 # runtime asserts that mention freevars of the expression in 2480 # question. 2481 # 2482 # - When we are performing codegen (in Inductor for eager, or 2483 # when finalizing the export FX graph), we need to know what 2484 # extra runtime asserts to insert. Whenever an unbacked 2485 # SymInt comes into scope, all runtime asserts involving it 2486 # become eligible for insertion (so long as all of their other 2487 # free unbacked symbols are also in scope). We technically 2488 # can handle any choice of key by kicking inexpressible asserts 2489 # to the next unbacked symbol to wait on, but if we choose the 2490 # latest key, an assert will only show up at the moment when 2491 # we can actually codegen it. 2492 self.deferred_runtime_asserts: Dict[sympy.Symbol, List[RuntimeAssert]] = {} 2493 # This exists so we can efficiently invalidate the cache (it's used as 2494 # part of the cache key); otherwise we'd have to iterate through 2495 # deferred_runtime_asserts to compute its length 2496 self.num_deferred_runtime_asserts = 0 2497 self.log = log 2498 self.log.debug("create_env") 2499 self.frozen = False 2500 self.runtime_asserts_frozen = False 2501 self.dim_constraints: Optional[DimConstraints] = None 2502 self.counter = collections.Counter() 2503 # Mapping from sympy.Symbol to the number of guards which mention this 2504 # symbol 2505 self.symbol_guard_counter = collections.Counter() 2506 # A selection of important fields on co_field; solely used for 2507 # signpost_event 2508 self.co_fields = co_fields if co_fields else {} 2509 2510 # Whenever we allocate a fresh unbacked Symbol, we add it to this 2511 # pending list. Unbacked symbol allocation can occur at unpredictable 2512 # points during meta tensor propagation, but at some point, the we 2513 # have to know what the binding site for an unbacked symbol is, and 2514 # this is computed when we actually place the node in the graph. The 2515 # important thing is that we always actually handle every unaccounted 2516 # for unbacked symbol, so this list helps us keep track of them and 2517 # then make sure they are all accounted for. 2518 # 2519 # We could potentially give rise to errors earlier by lexically 2520 # scoping when we do propagation, and only allowing unbacked symbols 2521 # to be allocated at this point in time. However this is inconvenient 2522 # to do in Dynamo, because fake tensor propagation is far from when we 2523 # analyze binding sites (set_example_value), so we do it in a more 2524 # mutatey way. 2525 # 2526 # NB: fresh unbacked symbols NEVER get substitutions applied to them, 2527 # they are binding sites! 2528 self.pending_fresh_unbacked_symbols: List[sympy.Symbol] = [] 2529 2530 # Version counter used to invalidate cached values 2531 self._prev_cache_key = self._get_key() 2532 self._version_counter = 0 2533 2534 # Cache for FX nodes. 2535 # Maps an already built node a tuple of: 2536 # 1. node's target 2537 # 2. list of arguments 2538 # This drastically reduces the size of the FX graph, avoiding 2539 # duplicated nodes. 2540 self.fx_node_cache: Dict[Tuple[Callable, Tuple[Any, ...]], torch.fx.Node] = {} 2541 self.source_to_symbol: Dict[str, sympy.Symbol] = {} 2542 2543 # Suppose you want to replace an unbacked symbol with another 2544 # unbacked symbol. This is error prone because you can cause 2545 # references to unbacked symbols to time travel backwards. E.g., 2546 # 2547 # u1 = x.item() 2548 # ... use of u1 ... 2549 # u2 = y.item() 2550 # u3 = z.item() 2551 # torch._check(u1 == u2 + u3) 2552 # 2553 # If you replace u1 with u2 + u3, then the use of u1 now 2554 # references u2 and u3 prior to them actually being bound at 2555 # runtime. 2556 # 2557 # To control for this, we track the order unbacked symbols 2558 # were allocated, and only allow substitutions if they respect 2559 # the dependency from this order; an unbacked symbol can only 2560 # be substituted with unbacked symbols that come before it in the 2561 # order. 2562 # 2563 # This also imposes an ordering on the unbacked symbol binding 2564 # sites themselves: you are not allowed to reorder unbacked symbol 2565 # bindings. At the moment, this is not tracked, but we potentially 2566 # could track this at the IR level using a higher order operator 2567 # with something like effect token tracking. 2568 self.unbacked_alloc_order: Dict[sympy.Symbol, int] = {} 2569 2570 from torch.fx.experimental.validator import translation_validation_enabled 2571 self._translation_validation_enabled = translation_validation_enabled() 2572 2573 if self._translation_validation_enabled: 2574 from torch.fx.experimental.validator import TranslationValidator 2575 2576 self.validator = TranslationValidator() 2577 self.graph = torch.fx.Graph() 2578 # Create an output graph and start inserting before that. 2579 # This is needed when 'deepcopy'-ing this object. 2580 self.graph.inserting_before(self.graph.output(None)) 2581 2582 # Mapping of each node name to the node itself. 2583 # 2584 # This is useful for matching an FX node from a recorded ShapeEnv.graph 2585 # to the FX node of the ShapeEnv we are running the event on. 2586 # 2587 # Whenever you add a node to self.graph, you must add a mapping to this 2588 # variable. Otherwise, the built FX graph on the replayed ShapeEnv will 2589 # not be valid. 2590 self.name_to_node: Dict[str, torch.fx.Node] = {} 2591 2592 @property 2593 def allow_scalar_outputs(self): 2594 return self.settings.allow_scalar_outputs 2595 2596 @property 2597 def allow_dynamic_output_shape_ops(self): 2598 return self.settings.allow_dynamic_output_shape_ops 2599 2600 @property 2601 def assume_static_by_default(self): 2602 return self.settings.assume_static_by_default 2603 2604 @property 2605 def specialize_zero_one(self): 2606 return self.settings.specialize_zero_one 2607 2608 @property 2609 def duck_shape(self): 2610 return self.settings.duck_shape 2611 2612 @property 2613 def prefer_deferred_runtime_asserts_over_guards(self): 2614 return self.settings.prefer_deferred_runtime_asserts_over_guards 2615 2616 @property 2617 def allow_complex_guards_as_runtime_asserts(self): 2618 return self.settings.allow_complex_guards_as_runtime_asserts 2619 2620 def check_equal(self, other: "ShapeEnv") -> None: 2621 """Compare another ShapeEnv for equivalence 2622 """ 2623 # ShapeEnv fields that are not relevant for the outcome of 2624 # ShapeEnv.produce_guards call: 2625 # - Debugging variables 2626 # - Translation validation related variables 2627 # - Events recording related variables 2628 non_state_variable_names = ( 2629 "counter", 2630 "log", 2631 "var_to_stack", 2632 "fx_node_cache", 2633 "graph", 2634 "validator", 2635 "check_recorded_events", 2636 "should_record_events", 2637 "is_recording", 2638 "tracked_fakes", 2639 "events", 2640 "source_name_to_debug_name", 2641 "_prev_cache_key", 2642 "_version_counter", 2643 "dim_constraints", 2644 ) 2645 2646 # Mapping of the value of each to-be-compared field into the values that 2647 # should actually be compared. 2648 # 2649 # You should modify this if, for example, the field that holds state and 2650 # debugging information. e.g. ShapeGuard holds the actual guard (sympy.Expr) 2651 # and the stack when it was added to the set of guards. In order to compare 2652 # it, we throw away the stack information. 2653 def map_value(key: str, value: Any) -> Any: 2654 if key in ("unbacked_symfloat_counter", "unbacked_symint_counter"): 2655 from copy import copy 2656 2657 # For itertools.count(), we compare the next integer returned 2658 # by the count iterators. Not that we need to copy the iterator 2659 # first. Otherwise we are mutating the object. 2660 return next(copy(value)) 2661 elif key == "guards": 2662 # Transform the list of ShapeGuard into a list of expressions. 2663 return [g.expr for g in value] 2664 elif key == "deferred_runtime_asserts": 2665 # Transform the list of RuntimeAsserts into a list of expressions. 2666 return {s: [ra.expr for ra in ras] for s, ras in value.items()} 2667 elif key == "name_to_node": 2668 # Compare just the set of keys is the same. 2669 return set(value.keys()) 2670 elif key in ("symbol_guard_counter", "pending_fresh_unbacked_symbols", "fake_tensor_cache"): 2671 # Skip this for comparisons 2672 return None 2673 return value 2674 2675 shape_env_check_state_equal(self, other, non_state_variable_names, map_value) 2676 2677 def _snapshot_tracked_fakes(self) -> Optional[List[Any]]: 2678 if self.tracked_fakes is None: 2679 return None 2680 2681 from torch._dynamo.variables.builder import TrackedFake 2682 2683 def maybe_transform_fake(fake: TrackedFake): 2684 inner_fake = fake.fake \ 2685 if isinstance(fake.fake, (torch.SymInt, torch.SymFloat)) \ 2686 else FakeTensorMeta.from_fake(fake.fake) 2687 # Even though TrackedFake accepts either a Union[SymInt, FakeTensor], here we give it a 2688 # FakeTensorMeta for two reasons: 2689 # 1. this is all the information we need when recording ShapeEnvEvents. 2690 # 2. it works even if each TrackedFake changes its metadata. 2691 return TrackedFake(inner_fake, fake.source, fake.symbolic_context) # type: ignore[arg-type] 2692 2693 return [maybe_transform_fake(fake) for fake in self.tracked_fakes] 2694 2695 def _last_event_index(self) -> int: 2696 return len(self.events) - 1 2697 2698 @contextmanager 2699 def _recording(self): 2700 self.is_recording = True 2701 try: 2702 yield 2703 finally: 2704 self.is_recording = False 2705 2706 @record_shapeenv_event() 2707 def _eliminate_unbacked(self, orig_s: sympy.Symbol, new_s: sympy.Expr): 2708 self._set_replacement(orig_s, new_s, "eliminate_unbacked") 2709 2710 @record_shapeenv_event() 2711 def set_unbacked_var_to_val(self, k: sympy.Symbol, v: int) -> None: 2712 """Used only when propagate_real_tensors; registers a value for an 2713 unbacked symbol, which can be used last resort to resolve hints.""" 2714 self.unbacked_var_to_val[k] = sympy.sympify(v) 2715 2716 # Unlike set_replacement, this records a shapeenv event 2717 @record_shapeenv_event() 2718 def _rename_unbacked_to(self, orig_s: sympy.Symbol, new_s: sympy.Symbol): 2719 assert isinstance(orig_s, sympy.Symbol), orig_s 2720 assert isinstance(new_s, sympy.Symbol), new_s 2721 assert free_unbacked_symbols(new_s), new_s 2722 assert free_unbacked_symbols(orig_s), orig_s 2723 dest = self.replacements.get(orig_s) 2724 assert not free_unbacked_symbols(dest), f"{orig_s} -> {dest}" 2725 self._set_replacement(orig_s, new_s, "rename_unbacked_to") 2726 self.unbacked_renamings[orig_s] = new_s 2727 if dest is not None: 2728 self._set_replacement(new_s, dest, "rename_unbacked_to_dest") 2729 2730 @record_shapeenv_event() 2731 def _constrain_range_for_size(self, a: sympy.Symbol, min: Optional[int] = None, max: Optional[int] = None): 2732 if min is None: 2733 min = 0 2734 if max is None: 2735 max = int_oo 2736 2737 if max < min: 2738 raise ValueError( 2739 "Maximum value to constrain_as_size can't be less than the specified min value, " 2740 "received min={min} and max={max}" 2741 ) 2742 2743 self.constrain_symbol_range( 2744 a, 2745 compiler_min=min, 2746 compiler_max=max, 2747 ) 2748 self.size_like.add(a) 2749 2750 @record_shapeenv_event() 2751 def _constrain_range(self, a: sympy.Expr, min: int, max: int): 2752 if isinstance(a, sympy.Integer): 2753 if not (min <= int(a) <= max): 2754 raise ValueRangeError(f"Invalid value {int(a)} for range [{min}:{max}]") 2755 return 2756 2757 # TODO: Shouldn't we install a guard if the symbol is backed? Or is the 2758 # semantics that this is an "unchecked" assert (but it this actually 2759 # something useful? Might be better to restrict only for unbacked 2760 # SymInt). 2761 if isinstance(a, sympy.Symbol): 2762 self.constrain_symbol_range( 2763 a, 2764 compiler_min=min, 2765 compiler_max=max, 2766 ) 2767 2768 @record_shapeenv_event() 2769 def _constrain_unify(self, a, b): 2770 """ 2771 Given two SymInts, constrain them so that they must be equal. NB: 2772 this will not work with SymInts that represent nontrivial expressions 2773 (yet!) 2774 """ 2775 # TODO: this does not install a deferred runtime assert yet 2776 2777 # TODO: Maybe dedupe this with _maybe_guard_rel? 2778 # Update Feb 2024: this is extra important to do, this doesn't handle 2779 # unbacked replacements properly nor does it generate deferred runtime 2780 # asserts 2781 if not isinstance(a, SymInt): 2782 if not isinstance(b, SymInt): 2783 assert a == b 2784 else: 2785 assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI" 2786 assert b.node.shape_env is self 2787 self.replacements[b.node.expr] = sympy.Integer(a) 2788 else: 2789 # TODO: Actually, we can support this as long as one of them is a symbol. 2790 # NB: We can't actually do "unification" as our operators are not 2791 # injective 2792 assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" 2793 assert a.node.shape_env is self 2794 if not isinstance(b, SymInt): 2795 self.replacements[a.node.expr] = sympy.Integer(b) 2796 else: 2797 assert a.node.shape_env is b.node.shape_env 2798 assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI" 2799 new_var = self._find(a.node.expr) 2800 self.replacements[b.node.expr] = new_var 2801 2802 def _ignore_fresh_unbacked_symbols_tls(self): 2803 return getattr(TLS, "ignore_fresh_unbacked_symbols", False) 2804 2805 @record_shapeenv_event() 2806 def _ignore_fresh_unbacked_symbols_enter(self): 2807 TLS.ignore_fresh_unbacked_symbols = True 2808 2809 @record_shapeenv_event() 2810 def _ignore_fresh_unbacked_symbols_exit(self): 2811 TLS.ignore_fresh_unbacked_symbols = False 2812 2813 @contextmanager 2814 def ignore_fresh_unbacked_symbols(self): 2815 """ 2816 Indicates that the newly allocated unbacked SymInts are being 2817 discarded 2818 """ 2819 self._ignore_fresh_unbacked_symbols_enter() 2820 try: 2821 yield 2822 finally: 2823 self._ignore_fresh_unbacked_symbols_exit() 2824 2825 @record_shapeenv_event() 2826 def freeze(self): 2827 """Freeze this ShapeEnv to stop accumulating guards 2828 2829 A frozen ShapeEnv will ignore any further guards generated on it and 2830 only emit a warning which may lead to accuracy problems. 2831 """ 2832 self.frozen = True 2833 2834 @record_shapeenv_event() 2835 def freeze_runtime_asserts(self): 2836 """Freeze this ShapeEnv to stop adding deferred runtime asserts. 2837 2838 We will error if you try to install a new runtime assert when it is 2839 frozen. This would indicate a lowering violation, or perhaps something 2840 we know statically is already True but we are checking it again in a way 2841 that is not clearly dischargeable. 2842 """ 2843 # self.prefer_deferred_runtime_asserts_over_guards = False 2844 self.runtime_asserts_frozen = True 2845 2846 def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]: 2847 if not self._translation_validation_enabled: 2848 return None 2849 srcname = source.name() 2850 if source not in self.source_to_symbol: 2851 self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True) 2852 return self.source_to_symbol[srcname] 2853 2854 def _add_z3var(self, symbol: sympy.Symbol, type: Type) -> None: 2855 if self._translation_validation_enabled: 2856 self.validator.add_var(symbol, type) 2857 2858 def _add_target_expr(self, expr) -> None: 2859 if self._translation_validation_enabled: 2860 self.validator.add_target_expr(expr) 2861 2862 def _add_assertion(self, expr) -> None: 2863 if self._translation_validation_enabled: 2864 self.validator.add_assertion(expr) 2865 2866 def _check_translation_validate(self) -> None: 2867 if self._translation_validation_enabled: 2868 self.validator.validate() 2869 2870 @record_shapeenv_event() 2871 def _create_fx_call_function( 2872 self, 2873 op: Callable, 2874 args: Tuple, 2875 ) -> Tuple[Optional[torch.fx.Node], bool]: 2876 # Cache this tuple in order to avoid duplicated nodes. 2877 node_key = (op, args) 2878 # Flags whether the returned node was cached or not. 2879 fresh = False 2880 2881 if self._translation_validation_enabled and node_key not in self.fx_node_cache: 2882 2883 # Presence of None in the arguments implies that we should ignore this operation. 2884 if any(a is None for a in args): 2885 # We check if we are not mixing SymNode that should not be ignored 2886 # (fx_node is not None) with those that should (fx_node is None). 2887 assert all(not isinstance(a, torch.fx.Node) for a in args) 2888 return None, fresh 2889 2890 fresh = True 2891 2892 # If translation validation is enabled, all arguments must have its 2893 # own FX node. 2894 assert all(a is not None for a in args), f"missing arg in FX graph ({op.__name__}): {args}" 2895 node = self.fx_node_cache[node_key] = self.graph.call_function(op, args) 2896 self.name_to_node[node.name] = node 2897 2898 return self.fx_node_cache.get(node_key, None), fresh 2899 2900 def _create_fx_placeholder_and_z3var( 2901 self, 2902 symbol: sympy.Symbol, 2903 type: Type, 2904 ) -> Optional[torch.fx.Node]: 2905 if not self._translation_validation_enabled: 2906 return None 2907 2908 node_key = (self.graph.placeholder, (symbol,)) 2909 2910 # Check if we haven't added this symbol already. 2911 # If so, skip the placeholder creation, as it 2912 # generates invalid Python code. 2913 if node_key not in self.fx_node_cache: 2914 # Add a Z3 variable according to 'type'. 2915 self._add_z3var(symbol, type) 2916 # Create the FX placeholder out of a mangled name. 2917 mangled_name = re.sub(r'[^a-zA-Z0-9]', '_', re.sub(r'[()]', '', symbol.name)) 2918 node = self.fx_node_cache[node_key] = self.graph.placeholder(mangled_name) 2919 self.name_to_node[node.name] = node 2920 # Attach the 'symbol' to the placeholder so that we can retrieve 2921 # the Z3 variable later. 2922 node.meta["symbol"] = symbol 2923 2924 return self.fx_node_cache[node_key] 2925 2926 def _remove_fx_node(self, node: Optional[torch.fx.Node]) -> None: 2927 if self._translation_validation_enabled and node is not None: 2928 self.name_to_node.pop(node.name) 2929 self.graph.erase_node(node) 2930 2931 def _add_fx_node_metadata(self, node: torch.fx.Node) -> None: 2932 from torch._dynamo.utils import get_current_node 2933 2934 if self.should_record_events: 2935 node.meta[SHAPEENV_EVENT_KEY] = self._last_event_index() 2936 node.meta[CURRENT_NODE_KEY] = get_current_node() 2937 2938 def _suppress_guards_tls(self): 2939 return getattr(TLS, "suppress_guards", False) 2940 2941 @record_shapeenv_event() 2942 def _suppress_guards_enter(self): 2943 TLS.suppress_guards = True 2944 2945 @record_shapeenv_event() 2946 def _suppress_guards_exit(self): 2947 TLS.suppress_guards = False 2948 2949 @contextmanager 2950 def suppress_guards(self): 2951 """Context manager to ignore all guards generated inside""" 2952 self._suppress_guards_enter() 2953 try: 2954 yield 2955 finally: 2956 self._suppress_guards_exit() 2957 2958 def _get_key(self): 2959 """ 2960 Defines the current "state" of the guards we've accumulated in this ShapeEnv. 2961 Determines when we need to invalidate our cache 2962 """ 2963 return (len(self.replacements), len(self.divisible), self.num_deferred_runtime_asserts, len(self.unbacked_var_to_val)) 2964 2965 def _update_version_counter(self): 2966 # The shape environment is queried orders of magnitude more often than 2967 # it is changed, so we summarise the cache key into a linearly 2968 # increasing version counter which is cheaper to check in _lru_cache 2969 2970 # Only update version counter if the state actually changed 2971 cur_key = self._get_key() 2972 if self._prev_cache_key != cur_key: 2973 self._prev_cache_key = cur_key 2974 self._version_counter += 1 2975 2976 def _produce_dyn_sizes(self, 2977 ex_size: Sequence[int], 2978 source: Source, 2979 symbolic_context: SymbolicContext 2980 ) -> List[sympy.Expr]: 2981 return self._produce_dyn_sizes_from_int_tuple(tuple(ex_size), source, symbolic_context) 2982 2983 def _produce_dyn_sizes_from_int_tuple(self, 2984 tensor_size: Tuple[int], 2985 source: Source, 2986 symbolic_context: SymbolicContext, 2987 ) -> List[sympy.Expr]: 2988 assert all(not is_symbolic(val) for val in tensor_size), f"Expect size to be a plain tuple of ints but got {tensor_size}" 2989 from torch._dynamo.source import TensorPropertySource, TensorProperty 2990 _assert_symbol_context(symbolic_context) 2991 dynamic_dims = symbolic_context.dynamic_sizes 2992 constraint_dims = symbolic_context.constraint_sizes 2993 size = [] 2994 for i, val in enumerate(tensor_size): 2995 size.append(self.create_symbol( 2996 val, 2997 TensorPropertySource(source, TensorProperty.SIZE, i), 2998 dynamic_dims[i], 2999 constraint_dims[i], 3000 symbolic_context=symbolic_context 3001 )) 3002 return size 3003 3004 def create_symbolic_sizes_strides_storage_offset( 3005 self, 3006 ex: torch.Tensor, 3007 source: Source, 3008 *, 3009 symbolic_context: Optional[SymbolicContext] = None, 3010 ): 3011 """ 3012 Returns a list of symbolic sizes and strides for the given tensor. 3013 We try our best to express stride in terms of the sizes, so as to not 3014 introduce new symbolic variables. 3015 """ 3016 3017 ex_size = tuple(self._maybe_specialize_sym_int_with_hint(sz) for sz in ex.size()) 3018 ex_stride = tuple(self._maybe_specialize_sym_int_with_hint(sd) for sd in ex.stride()) 3019 ex_storage_offset = self._maybe_specialize_sym_int_with_hint(ex.storage_offset()) 3020 3021 return self._create_symbolic_sizes_strides_storage_offset( 3022 ex_size, 3023 ex_stride, 3024 ex_storage_offset, 3025 [_is_dim_dynamic(ex, i) for i in range(ex.dim())], 3026 source, 3027 symbolic_context=symbolic_context, 3028 ) 3029 3030 # Dynamo may want to wrap FakeTensors with SymInt sizes up e.g. make_fx(opt_f(), tracing_mode="symbolic"). 3031 # We create symbols in shape_env using the backed hints behind SymInt. 3032 3033 # Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape. 3034 # produce_guards will trigger specializations on the outer stuff 3035 3036 # Case 2: when the SymInt is unbacked, we will throw an data dependent error in require_hint(). 3037 # 3038 # It's probably good for now but it's important to note that this approach has implications for 3039 # the original shape_env when checking guards in different order. 3040 3041 # Example: 3042 # --------- 3043 # Consider a function "opt_f" as shown below: 3044 3045 # @torch.compile() 3046 # def opt_f(x: bool, y: Tensor): 3047 # if x == True: 3048 # return y + torch.randn([4]) 3049 # else: 3050 # return y 3051 # Depending on the sequence of calls, we might install two different sets of guards: 3052 3053 # 1. opt_f(False, y): 3054 # - "x == False" (always works for any size y) 3055 3056 # 2. opt_f(True, y): 3057 # - Triggers recompilation and results in guards like: 3058 # - "x == True and y.size(0) == 4" 3059 # - (or "y.size(0) == 4 and x == True") 3060 3061 # The order of checking the guards matters. In this specific example: 3062 # If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True, 3063 # we may have an unnessary shape speciliazation for y. 3064 def _maybe_specialize_sym_int_with_hint(self, maybe_sym) -> int: 3065 assert isinstance(maybe_sym, (int, torch.SymInt)) 3066 if is_symbolic(maybe_sym): 3067 assert maybe_sym.node.shape_env is not self, \ 3068 "expect the symbol is created from an shape env other than current one." 3069 return maybe_sym.node.require_hint() 3070 return maybe_sym 3071 3072 @record_shapeenv_event() 3073 def _create_symbolic_sizes_strides_storage_offset( 3074 self, 3075 ex_size: Sequence[int], 3076 ex_stride: Sequence[int], 3077 ex_storage_offset: int, 3078 is_dim_dynamic: Sequence[bool], 3079 source: Source, 3080 *, 3081 symbolic_context: Optional[SymbolicContext] = None, 3082 ): 3083 dim = len(ex_size) 3084 3085 # Reimplement the legacy behavior 3086 if symbolic_context is None: 3087 constraint_sizes = [None] * dim 3088 constraint_strides = [None] * dim 3089 dynamic_dims = [] 3090 dynamic_strides = [] 3091 for i in range(dim): 3092 # NB: This is encapsulation breaking! Legacy behavior was 3093 # bad. 3094 if is_dim_dynamic[i]: 3095 r = DimDynamic.DYNAMIC 3096 elif self.assume_static_by_default: 3097 r = DimDynamic.STATIC 3098 else: 3099 r = DimDynamic.DUCK 3100 dynamic_dims.append(r) 3101 dynamic_strides.append(r) 3102 dynamic_dims = [DimDynamic.DUCK] * dim 3103 dynamic_strides = [DimDynamic.INFER_STRIDE] * dim 3104 # symbolic_context is None - set one 3105 symbolic_context = StatelessSymbolicContext( 3106 dynamic_sizes=dynamic_dims, 3107 dynamic_strides=dynamic_strides, 3108 constraint_sizes=constraint_sizes, 3109 constraint_strides=constraint_strides, 3110 ) 3111 # We got a StatelessSymbolicContext 3112 _assert_symbol_context(symbolic_context) 3113 constraint_sizes = symbolic_context.constraint_sizes 3114 constraint_strides = symbolic_context.constraint_strides 3115 dynamic_sizes = symbolic_context.dynamic_sizes 3116 dynamic_strides = symbolic_context.dynamic_strides 3117 3118 # TODO: make this configurable from outside symbolic_context; we made a symbolic_context 3119 # decision here where if all sizes are static, we are going to 3120 # specialize all of the inner strides/offset too. We don't have to 3121 # do this, and arguably we should ALWAYS allow for dynamic offset, 3122 # this is cheap. 3123 # TODO: This should be DYNAMIC, using DUCK for BC 3124 dynamic_offset = DimDynamic.STATIC if all(r == DimDynamic.STATIC for r in dynamic_sizes) else DimDynamic.DUCK 3125 are_sizes_static = all(r == DimDynamic.STATIC for r in dynamic_sizes) 3126 3127 assert len(dynamic_sizes) == dim, f"{len(dynamic_sizes)} != {dim}" 3128 assert len(dynamic_strides) == dim, f"{len(dynamic_sizes)} != {dim}" 3129 assert len(constraint_sizes) == dim 3130 assert len(constraint_strides) == dim 3131 3132 from torch._dynamo.source import TensorPropertySource, TensorProperty 3133 size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, symbolic_context) 3134 stride: List[Optional[sympy.Expr]] = [None] * len(size) 3135 for i, val in enumerate(ex_stride): 3136 if val in (0, 1): 3137 stride[i] = sympy.Integer(val) 3138 while any(x is None for x in stride): 3139 candidates = { 3140 ex_size[i] * ex_stride[i]: size[i] * stride[i] 3141 for i in range(len(size)) 3142 if stride[i] is not None and ex_stride[i] >= 0 3143 } 3144 3145 # iterate over unbound strides in sorted order 3146 def _nested_int_aware_sort(tup): 3147 return ( 3148 # Order nested ints by their coefficients. 3149 # 1 here to order nested ints after non-nested-ints. 3150 (1, tup[0].node.nested_int_coeff(), tup[1]) if is_nested_int(tup[0]) 3151 else (0, *tup) 3152 ) 3153 val_list = sorted( 3154 [(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None], 3155 key=_nested_int_aware_sort, 3156 ) 3157 for _, i in val_list: 3158 # Set stride to a candidate only for DimDynamic.INFER_STRIDE 3159 if stride[i] is None and dynamic_strides[i] == DimDynamic.INFER_STRIDE and ex_stride[i] in candidates: 3160 stride[i] = candidates[ex_stride[i]] 3161 candidates[ex_size[i] * ex_stride[i]] = size[i] * stride[i] 3162 3163 if any(x is None for x in stride): 3164 # bind the smallest unbound stride to a new variable 3165 val, i = min( 3166 [ 3167 (ex_stride[i], i) 3168 for i in range(len(stride)) 3169 if stride[i] is None 3170 ], key=_nested_int_aware_sort 3171 ) 3172 # Set INFER_STRIDE to STATIC or DUCK depending on sizes 3173 dyn_stride = dynamic_strides[i] 3174 if dynamic_strides[i] == DimDynamic.INFER_STRIDE: 3175 dyn_stride = DimDynamic.STATIC if are_sizes_static else DimDynamic.DUCK 3176 stride[i] = self.create_symbol( 3177 val, 3178 TensorPropertySource(source, TensorProperty.STRIDE, i), 3179 dynamic_dim=dyn_stride, 3180 constraint_dim=constraint_strides[i], 3181 symbolic_context=symbolic_context, 3182 ) 3183 assert all(x is not None for x in stride) 3184 3185 sym_sizes = [ 3186 self.create_symintnode( 3187 sym, 3188 hint=hint, 3189 source=TensorPropertySource(source, TensorProperty.SIZE, i), 3190 ) 3191 for i, (sym, hint) in enumerate(zip(size, ex_size)) 3192 ] 3193 sym_stride = [] 3194 for i, stride_expr in enumerate(stride): 3195 # NB: Don't duck size the stride; instead use the expression 3196 # we computed 3197 assert stride_expr is not None 3198 sym_stride.append(self.create_symintnode( 3199 stride_expr, hint=ex_stride[i], source=TensorPropertySource(source, TensorProperty.STRIDE, i))) 3200 sym_storage_offset = self.create_symintnode( 3201 self.create_symbol( 3202 ex_storage_offset, 3203 TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), 3204 dynamic_dim=dynamic_offset, 3205 constraint_dim=None, 3206 symbolic_context=symbolic_context 3207 ), 3208 hint=ex_storage_offset, 3209 source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET)) 3210 return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset 3211 3212 @record_shapeenv_event() 3213 def create_symintnode( 3214 self, 3215 sym: "sympy.Expr", 3216 *, 3217 hint: Optional[int], 3218 source: Optional[Source] = None, 3219 ): 3220 """Create a SymInt value from a symbolic expression 3221 3222 If you know what the current hint value of the SymInt to be created 3223 is, pass it into hint. Otherwise, pass None and we will make our best 3224 guess 3225 3226 """ 3227 source_name = source.name() if source else None 3228 3229 if self._translation_validation_enabled and source is not None: 3230 # Create a new symbol for this source. 3231 symbol = self._create_symbol_for_source(source) 3232 assert symbol is not None 3233 3234 # Create a new FX placeholder and Z3 variable for 'symbol'. 3235 fx_node = self._create_fx_placeholder_and_z3var(symbol, int) 3236 3237 # Add an equality assertion for the newly created symbol and 'sym'. 3238 self._add_assertion(sympy.Eq(symbol, sym)) 3239 else: 3240 fx_node = None 3241 3242 if isinstance(sym, sympy.Integer): 3243 if hint is not None: 3244 assert int(sym) == hint 3245 out = int(sym) 3246 else: 3247 # How can this occur? When we mark_unbacked, we end up with a real 3248 # tensor that has hints for all sizes, but we MUST NOT create a 3249 # SymNode with a hint, because we're hiding the hint from our eyes 3250 # with the unbacked Symbol. And in fact, the hint compute may be 3251 # inconsistent with size oblivious tests. 3252 if free_unbacked_symbols(sym): 3253 hint = None 3254 out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node)) 3255 return out 3256 3257 @record_shapeenv_event() 3258 def create_symfloatnode( 3259 self, 3260 sym: "sympy.Expr", 3261 *, 3262 hint: Optional[int], 3263 source: Optional[Source] = None, 3264 ): 3265 """Create a SymFloat value from a symbolic expression""" 3266 source_name = source.name() if source else None 3267 3268 if self._translation_validation_enabled and source is not None: 3269 # Create a new symbol for this source. 3270 symbol = self._create_symbol_for_source(source) 3271 assert symbol is not None 3272 3273 # Create a new FX placeholder and Z3 variable for 'symbol'. 3274 fx_node = self._create_fx_placeholder_and_z3var(symbol, float) 3275 3276 # Add an equality assertion for the newly created symbol and 'sym'. 3277 self._add_assertion(sympy.Eq(symbol, sym)) 3278 else: 3279 fx_node = None 3280 3281 if isinstance(sym, sympy.Float): 3282 if hint is not None: 3283 assert float(sym) == hint 3284 out = float(sym) 3285 else: 3286 # You could give this the same treatment as SymInt above if 3287 # you supported mark_unbacked on a float, but it's a kind of 3288 # strange thing to do though because floats don't get 0/1 3289 # specialization anyway 3290 if free_unbacked_symbols(sym): 3291 assert hint is None, sym 3292 out = SymFloat(SymNode(sym, self, float, hint, fx_node=fx_node)) 3293 return out 3294 3295 @record_shapeenv_event() 3296 def create_unspecified_symint_and_symbol(self, value, source, dynamic_dim): 3297 """Create a SymInt wrapping a new unspecified symbol""" 3298 return self.create_symintnode( 3299 self.create_unspecified_symbol( 3300 value, 3301 source=source, 3302 dynamic_dim=dynamic_dim, 3303 ), 3304 hint=value, 3305 source=source, 3306 ) 3307 3308 def create_symboolnode(self, sym: "sympy.Expr"): 3309 """Create a SymBool object from a sympy boolean expression""" 3310 # This function is only being used in serialization, so we do not track it 3311 # for validation. 3312 return SymBool(SymNode(sym, self, bool, None)) 3313 3314 def _log_create_unbacked_symbol(self, prefix: str, symbol, vr: ValueRanges): 3315 is_debug = config.extended_debug_create_symbol is not None and str(symbol) in config.extended_debug_create_symbol.split(',') 3316 fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug) 3317 log.info( 3318 "%s %s [%s, %s]%s (%s)%s", 3319 prefix, symbol, vr.lower, vr.upper, maybe_user_loc, format_frame(fsummary), maybe_extra_debug, stack_info=is_debug 3320 ) 3321 3322 @record_shapeenv_event() 3323 def create_unbacked_symfloat(self): 3324 """Create a symbolic float without a hint value 3325 """ 3326 symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_FLOAT, next(self.unbacked_symfloat_counter)) 3327 self.counter["create_unbacked_symbol"] += 1 3328 if not self._ignore_fresh_unbacked_symbols_tls(): 3329 self.pending_fresh_unbacked_symbols.append(symbol) 3330 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) 3331 vr = self.var_to_range[symbol] = ValueRanges.unknown() 3332 assert vr.is_float 3333 3334 # Create a new FX placeholder and Z3 variable for 'symbol'. 3335 fx_node = self._create_fx_placeholder_and_z3var(symbol, float) 3336 3337 self._log_create_unbacked_symbol("create_unbacked_symfloat", symbol, vr) 3338 3339 return SymFloat(SymNode(symbol, self, float, None, fx_node=fx_node)) 3340 3341 @record_shapeenv_event() 3342 def create_unbacked_symint(self): 3343 """Create a symbolic integer without a hint value 3344 """ 3345 symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True) 3346 if not self._ignore_fresh_unbacked_symbols_tls(): 3347 self.pending_fresh_unbacked_symbols.append(symbol) 3348 self.counter["create_unbacked_symbol"] += 1 3349 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) 3350 vr = self.var_to_range[symbol] = self._default_unspecified_value_range() 3351 assert vr.is_int 3352 3353 # Create a new FX placeholder and Z3 variable for 'symbol'. 3354 fx_node = self._create_fx_placeholder_and_z3var(symbol, int) 3355 3356 self._log_create_unbacked_symbol("create_unbacked_symint", symbol, vr) 3357 3358 return SymInt(SymNode(symbol, self, int, None, fx_node=fx_node)) 3359 3360 def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool: 3361 """Check if a sympy symbol matches the naming convention for unbacked symbols 3362 """ 3363 return symbol_is_type(symbol, SymT.UNBACKED_INT) 3364 3365 @record_shapeenv_event() 3366 def create_unbacked_symbool(self): 3367 """Create a symbolic boolean without a hint value 3368 """ 3369 symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True) 3370 if not self._ignore_fresh_unbacked_symbols_tls(): 3371 self.pending_fresh_unbacked_symbols.append(symbol) 3372 self.counter["create_unbacked_symbol"] += 1 3373 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) 3374 vr = self.var_to_range[symbol] = ValueRanges(0, 1) 3375 assert vr.is_int 3376 3377 # Create a new FX placeholder and Z3 variable for 'symbol'. 3378 fx_node = self._create_fx_placeholder_and_z3var(symbol, bool) 3379 3380 self._log_create_unbacked_symbol("create_unbacked_symbool", symbol, vr) 3381 3382 return SymBool(SymNode(sympy.Eq(symbol, 1), self, bool, None, fx_node=fx_node)) 3383 3384 @record_shapeenv_event() 3385 def create_unspecified_symbol( 3386 self, 3387 val: Union[int, SymInt, float, SymFloat], 3388 source: Source, 3389 dynamic_dim: DimDynamic = DimDynamic.DUCK, 3390 constraint_dim: DimConstraint = None, # NB: includes None 3391 ) -> "sympy.Expr": 3392 """Create a symbol with an unspecified value 3393 3394 Compared to standard symbols we do not assume the value is positive, 3395 nor do we specialze on zero or one values. 3396 """ 3397 # 'positive' is None for unspecified symbols, since we can't 3398 # assume that it will be neither positive nor negative. 3399 3400 # We don't want to specialize zero one val for unspecified symbol 3401 # so that we can always get a new symbol despite val. 3402 return self.create_symbol( 3403 val, 3404 source, 3405 dynamic_dim, 3406 constraint_dim, 3407 positive=None, 3408 do_not_specialize_zero_one=True, 3409 symbolic_context=None) 3410 3411 @record_shapeenv_event() 3412 def create_symbol( 3413 self, 3414 val: int, 3415 source: Source, 3416 dynamic_dim: DimDynamic = DimDynamic.DUCK, 3417 constraint_dim: DimConstraint = None, # NB: includes None 3418 positive: Optional[bool] = True, 3419 do_not_specialize_zero_one: bool = False, 3420 symbolic_context=None, 3421 ) -> "sympy.Expr": 3422 """Create a new symbol which is tracked by this ShapeEnv 3423 """ 3424 # check if constraint_dim is actually static integer 3425 if isinstance(constraint_dim, StrictMinMaxConstraint) and constraint_dim.vr.lower == constraint_dim.vr.upper: 3426 dynamic_dim = DimDynamic.STATIC 3427 if constraint_dim.vr.lower != val: 3428 raise ConstraintViolationError( 3429 f"Static shape constraint of {constraint_dim.vr.lower} does not match input size of {val}, " 3430 f"for {source.name()}" 3431 ) 3432 if symbolic_context: 3433 symbolic_context.dynamic_sizes[source.idx] = dynamic_dim 3434 symbolic_context.constraint_sizes[source.idx] = None 3435 constraint_dim = None 3436 3437 # see note [Tensor Fakification and Symbol Caching] 3438 source_name = source.name() 3439 if (isinstance(symbolic_context, StatefulSymbolicContext) 3440 and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache): 3441 symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] = {} 3442 3443 if (isinstance(symbolic_context, StatefulSymbolicContext) 3444 and source_name 3445 and (source_name in symbolic_context.shape_env_to_source_to_symbol_cache[id(self)])): 3446 return symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] 3447 3448 if dynamic_dim is DimDynamic.SIZE_LIKE_UNBACKED: 3449 out = self.create_unbacked_symint().node.expr 3450 self._constrain_range_for_size(out) 3451 # TODO: maybe put the hint somewhere 3452 if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: 3453 symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = out 3454 return out 3455 3456 if do_not_specialize_zero_one: 3457 specialize_zero_one = False 3458 else: 3459 specialize_zero_one = self.specialize_zero_one 3460 3461 assert isinstance(source, Source), f"{type(source)} {source}" 3462 assert not (positive and val < 0), f"positive set for negative value: {val}" 3463 # It's always sound to allocate a symbol as DYNAMIC. If the user 3464 # constrained the symbol, force the symbolic_context to DYNAMIC, because our 3465 # constraint code will do weird stuff if, e.g., it's duck shaped 3466 if constraint_dim is not None: 3467 dynamic_dim = DimDynamic.DYNAMIC 3468 3469 if dynamic_dim is DimDynamic.STATIC: 3470 out = sympy.Integer(val) 3471 if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: 3472 symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = out 3473 return out 3474 3475 elif dynamic_dim is DimDynamic.DUCK: 3476 # duck_shape can be used to globally turn off duck shaping, even 3477 # if it was requested 3478 duck = self.duck_shape 3479 elif dynamic_dim is DimDynamic.DYNAMIC: 3480 duck = False 3481 else: 3482 raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}") 3483 3484 if val in (0, 1) and specialize_zero_one: 3485 r = self.val_to_var[val] 3486 elif not duck or val not in self.val_to_var: 3487 # If we're not duck shaping, we always create a new symbol 3488 # Even if we're duck shaping, if we haven't seen this particular 3489 # value before, we also create a new symbol 3490 if type(val) is int or is_nested_int(val): 3491 sympy_expr = make_symbol(SymT.SIZE, len(self.var_to_val), positive=positive, integer=True) 3492 else: 3493 sympy_expr = make_symbol(SymT.FLOAT, len(self.var_to_val), positive=positive, real=True) 3494 # We always associate vars to vals 3495 if isinstance(val, int): 3496 self.var_to_val[sympy_expr] = sympy.Integer(val) 3497 elif isinstance(val, float): 3498 self.var_to_val[sympy_expr] = sympy.Float(val) 3499 else: 3500 # Only used for jagged layout nested tensors 3501 self.var_to_val[sympy_expr] = SingletonInt(val.node.nested_int(), coeff=val.node.nested_int_coeff()) 3502 3503 # Do the appending later, because we always want to populate this 3504 self.var_to_sources[sympy_expr] = [] 3505 # Create a Z3 variable for the new symbol. 3506 self._add_z3var(sympy_expr, int) 3507 3508 if duck: 3509 # Make sure to reuse this symbol for subsequent duck shaping 3510 self.val_to_var[val] = sympy_expr 3511 3512 if isinstance(val, int): 3513 if positive: 3514 # Add assertions for the newly created symbols 3515 self._add_assertion(sympy_expr > 1) 3516 3517 # Apply default range, which assumes not zero-one 3518 self.var_to_range[sympy_expr] = self._default_value_range() 3519 else: 3520 self.var_to_range[sympy_expr] = self._default_unspecified_value_range() 3521 3522 # Small performance optimization: if we have a min-max constraint, 3523 # we can proactively narrow to that range 3524 if isinstance(constraint_dim, StrictMinMaxConstraint): 3525 assert not duck 3526 self.var_to_range[sympy_expr] &= constraint_dim.vr 3527 3528 vr = self.var_to_range[sympy_expr] 3529 assert vr.is_int 3530 3531 if val not in vr: 3532 raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]") 3533 3534 range_str = f"[{vr.lower}, {vr.upper}]" 3535 elif isinstance(val, float): 3536 self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo) 3537 range_str = f"[{vr.lower}, {vr.upper}]" 3538 assert vr.is_float 3539 else: 3540 # Skip var_range logic for SingletonInt 3541 # Only used for jagged layout nested tensors 3542 range_str = "" 3543 3544 r = sympy_expr 3545 3546 is_debug = ( 3547 config.extended_debug_create_symbol is not None and 3548 str(sympy_expr) in config.extended_debug_create_symbol.split(',') 3549 ) 3550 maybe_more_info = "" 3551 if not is_debug: 3552 maybe_more_info = ( 3553 ", for more info run with " 3554 f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{sympy_expr}"' 3555 ) 3556 fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug) 3557 self.log.info( 3558 "create_symbol %s = %s for %s %s%s (%s)%s%s", 3559 sympy_expr, val, source.name(), range_str, 3560 maybe_user_loc, format_frame(fsummary), maybe_more_info, maybe_extra_debug, stack_info=is_debug 3561 ) 3562 3563 self.counter["create_symbol"] += 1 3564 else: 3565 # This implements duck-shaping: input sizes that match are assigned 3566 # the same symint 3567 r = self.val_to_var[val] 3568 self.log.debug("create_symbol %s duck sized %s", r, source.name()) 3569 3570 if isinstance(r, sympy.Symbol): 3571 r_sources = self.var_to_sources[r] 3572 r_sources.append(source) 3573 if not source.is_ephemeral() and r_sources[0].is_ephemeral(): 3574 # prefer non-ephemeral source first since it may be guarded on later 3575 r_sources[0], r_sources[-1] = r_sources[-1], r_sources[0] 3576 3577 # This ensures we get zeros in symbol_guard_counts, which makes 3578 # some queries simpler (since we will accumulate mass on 0 this 3579 # way) 3580 self.symbol_guard_counter[r] = 0 3581 3582 if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: 3583 symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = r 3584 return r 3585 3586 def add_var_to_val(self, expr: sympy.Symbol, val: int): 3587 """ Adds a new symbol to the symbolic environment. """ 3588 log.debug("add_var_to_val %s %s", expr, val, stack_info=True) 3589 assert expr not in self.var_to_val, f"{expr} already exists" 3590 self.var_to_val[expr] = sympy.Integer(val) 3591 3592 def _debug_name(self, source): 3593 src_name = source.name() 3594 return self.source_name_to_debug_name.get(src_name, src_name) 3595 3596 def _render_range_for_constraint_violation(self, source, c): 3597 if isinstance(c, StrictMinMaxConstraint): 3598 lower, upper = c.vr.lower, c.vr.upper 3599 default = self._default_value_range() 3600 if lower <= default.lower: 3601 lower = None 3602 if upper >= default.upper: 3603 upper = None 3604 c_render = f"{self._debug_name(source)} = {source.name()} in the specified range" 3605 if lower is not None and upper is not None: 3606 c_render += f" {lower} <= {self._debug_name(source)} <= {upper}" 3607 elif lower is None and upper is not None: 3608 c_render += f" {self._debug_name(source)} <= {upper}" 3609 elif lower is not None and upper is None: 3610 c_render += f" {lower} <= {self._debug_name(source)}" 3611 return c_render 3612 return c.render(source) 3613 3614 def produce_guards( 3615 self, 3616 placeholders, 3617 sources, 3618 source_ref=lambda n: n.name(), 3619 *, 3620 guards: List[ShapeGuard] = None, 3621 input_contexts: Optional[DimList[SymbolicContext]] = None, 3622 # Encodes user-specified input shape equations of the form s = s' and s = fn(s'). 3623 # (See docs on EqualityConstraint for details of the encoding.) 3624 equalities_inputs: Optional[EqualityConstraint] = None, 3625 _simplified=False, 3626 # Indicates if we should produce guards for known static values. 3627 ignore_static=True, 3628 ) -> List[str]: 3629 """ 3630 Generates a list of guards strings which, when evaluated in a context that 3631 defines tensors for all the sources, returns True or False depending 3632 on if the guards in the list evaluated to True or not. Primarily used by Dynamo, 3633 but this is also helpful for manual testing of guards (see 3634 evaluate_guards_for_args) 3635 3636 For convenience in testing, a source is allowed to be a str, 3637 in which case we will assume it is a LocalSource 3638 3639 simplified lets you omit duck sizing, equality and 0/1 guards. 3640 This is useful for testing when you don't care about the boilerplate 3641 guards, and it may be helpful for user output too (be careful though; 3642 some equality guards are nontrivial! It would be nice to get simplified 3643 output to print them too). It's private because it's not 3644 intended for normal use 3645 """ 3646 self.log.info("produce_guards") 3647 3648 # Check if we get to the same ShapeEnv state by replaying the recorded events. 3649 # This will create a new ShapeEnv instance, and call all recorded function 3650 # calls on this new instance. Finally, it will check whether this new instance 3651 # has equal state. 3652 # 3653 # It's important that we do it in the begining of this function, since it modifies 3654 # self.dim_constraints through its execution. Changes that happen in this method 3655 # aren't interesting, since this is the function call we wish to reproduce at the 3656 # end. If we wish to simply reproduce ShapeEnv instances even after this call, 3657 # this method should also be recorded. 3658 if self.check_recorded_events: 3659 shape_env = replay_shape_env_events(self.events) 3660 self.check_equal(shape_env) 3661 3662 assert len(placeholders) == len(sources), f"len({placeholders}) != len({sources})" 3663 Tensorlike = (torch.Tensor, FakeTensorMeta) 3664 3665 def _create_no_constraints_context(t): 3666 return StatelessSymbolicContext( 3667 # Ignored; only the constraints part is relevant below. 3668 dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim(), 3669 dynamic_strides=[DimDynamic.INFER_STRIDE] * t.dim(), 3670 constraint_sizes=[None] * t.dim(), 3671 constraint_strides=[None] * t.dim() 3672 ) 3673 3674 # Expand optional inputs, or verify invariants are upheld 3675 if input_contexts is None: 3676 input_contexts = [ 3677 _create_no_constraints_context(t) if isinstance(t, Tensorlike) 3678 else None for t in placeholders 3679 ] 3680 else: 3681 assert len(input_contexts) == len(placeholders) 3682 for i, (t, context) in enumerate(zip(placeholders, input_contexts)): 3683 if isinstance(t, Tensorlike): 3684 if context is None: 3685 input_contexts[i] = _create_no_constraints_context(t) 3686 else: 3687 assert isinstance(t, (SymInt, int, SymFloat, float)) 3688 assert not isinstance(context, list) 3689 3690 # It took a lot of sweat to figure out the algorithm here. Let's 3691 # explain how it works. 3692 # 3693 # The ShapeEnv lifecycle looks something like this: 3694 # 3695 # - For each input, you either generate a fresh Sympy symbol (s0) to 3696 # represent its value (a binding site), or you reuse some 3697 # preexisting symbol or expression, skipping the symbol allocation 3698 # (e.g., duck sizing to a preexisting symbol, or expressing a 3699 # stride as a multiplication of a separate stride and size.) 3700 # Naively, you might expect to bind a fresh Sympy symbol for 3701 # every input, but this is fairly wasteful as most of these 3702 # symbols immediately simplify away, and if you don't eagerly 3703 # specialize, e.g., 0/1 symbols, you end up with very complicated 3704 # expressions that are not optimizable in practice. 3705 # 3706 # - You perform some compute on these symbols, occasionally 3707 # introducing guards on boolean expressions on these symbols. 3708 # In particular, whenever we guard on equality (_maybe_guard_rel), 3709 # we can simplify shapes; e.g., when s0 == s1 * 2, we can now 3710 # replace all occurrences of s0 with s1 * 2. Sometimes, a 3711 # boolean expression evaluation doesn't introduce a guard, as 3712 # the guard is already entailed by the simplifications we have 3713 # applied. 3714 # 3715 # - In the end, you have a bunch of replacements (saying how to 3716 # simplify shapes) and a bunch of guards (all the equality guards 3717 # are trivial, because they're covered by the replacements). 3718 # 3719 # From the ShapeEnv, we must generate a Python expression that, when 3720 # evaluated on a set of inputs, tells us whether or not these boolean 3721 # expressions would have evaluated in the same way. However, 3722 # we cannot easily compute this, as we elide recording boolean 3723 # expressions when we think they are vacuously true. Thus, we seek 3724 # an approximation: we must generate an expression, if true, would have 3725 # produced an "equivalent" ShapeEnv, which would answer guard 3726 # expressions in the same way. 3727 # 3728 # Our notion of equivalence is a bit subtle. For example, consider 3729 # the ShapeEnv created from an input of size (5, 4) versus (4, 4) 3730 # (no other guards.) Duck sizing would generate (s0, s1) in the first 3731 # case but (s0, s0) in the second. We do NOT assume that size 3732 # variables are disjoint; so in fact a graph that assumes the input 3733 # could be (s0, s1) subsumes (s0, s0) (setting s0 == s1), but not 3734 # vice versa. However, consider an analogous case (1,) versus (2,). 3735 # Duck sizing generates (1,) and (s0,); the (s0,) graph does NOT 3736 # subsume the (1,) graph because we assume that any size variables 3737 # is NOT 0/1 (and make simplifications according to this; e.g., if 3738 # we queried s0 == 0, we would immediately return False without 3739 # returning a guard.) 3740 # 3741 # So, it is perhaps easier to flip things on their head: the guard 3742 # expressions we generate here say what simplifications are valid, 3743 # and what are not. Below, we explain each of the guard expressions 3744 # we generate 3745 3746 # TODO: Make this more efficient by binding all the size/stride/offsets 3747 # to locals before performing tests on them. 3748 3749 from torch._dynamo.source import TensorPropertySource, TensorProperty 3750 3751 # Actual codegen must be delayed as we don't necessarily know what 3752 # the symbol mapping is 3753 input_guards = [] 3754 3755 symbol_to_source = collections.defaultdict(list) 3756 symbol_to_constraints = collections.defaultdict(set) 3757 constraint_violations : List[Tuple[bool, str, Callable[[], str]]] = [] 3758 3759 def record_constraint_violation(warn_only, debug_name, msg, hint=None): 3760 constraint_violations.append( 3761 (warn_only, debug_name, lambda: f"{msg}{hint()}" if hint else msg) 3762 ) 3763 3764 def is_dim(src): 3765 return isinstance(src, TensorPropertySource) and src.prop is TensorProperty.SIZE 3766 3767 if equalities_inputs: 3768 source_index = {} 3769 for i, src in enumerate(sources): 3770 source_index[src.name()] = i 3771 3772 def get_expression(tensor_dim_src): 3773 fake = placeholders[source_index[tensor_dim_src.base.name()]] 3774 symint = fake.shape[tensor_dim_src.idx] 3775 if isinstance(symint, torch.SymInt): 3776 return symint.node.expr 3777 else: 3778 assert type(symint) is int, f"Expected int, got {type(symint)}" 3779 return symint 3780 3781 for src1, src2 in equalities_inputs.source_pairs: 3782 expr1, expr2 = get_expression(src1), get_expression(src2) 3783 # Check whether given input shape values satisfy a specified equation s = s'. 3784 # - Raise when the equation was violated by the given input shape values. 3785 # - Otherwise issue a guard to constrain them. 3786 concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2)) 3787 if not concrete_val: 3788 raise ConstraintViolationError( 3789 f"{src1.name()} = {expr1 if isinstance(expr1, int) else expr1.xreplace(self.var_to_val)}" 3790 " is not equal to " 3791 f"{src2.name()} = {expr2 if isinstance(expr2, int) else expr2.xreplace(self.var_to_val)}" 3792 ) 3793 3794 for src, root, fn in equalities_inputs.derived_equalities: 3795 expr1 = get_expression(src) 3796 # recall that root is either a phantom symbol or an input source 3797 expr2, debug_name = ( 3798 (root, self.var_to_sources[root][0].name()) if isinstance(root, sympy.Symbol) 3799 else (get_expression(root), self._debug_name(root)) 3800 ) 3801 expr2_ = fn(expr2) 3802 # Check whether given input shape values satisfy a specified equation s = fn(s'). 3803 # - Raise when the equation was violated by the given input shape values. 3804 # - Otherwise issue a guard to constrain them. 3805 concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2_)) 3806 if not concrete_val: 3807 raise ConstraintViolationError( 3808 f"Expected input {src.name()} to be equal to " 3809 f"{fn(sympy.Symbol(debug_name))}, " 3810 f"where {debug_name} = {expr2.xreplace(self.var_to_val)}, " 3811 f"but got {expr1.xreplace(self.var_to_val)}" 3812 ) 3813 3814 for phantom_symbol in equalities_inputs.phantom_symbols: 3815 # we created additional phantom symbols that are not input shape dimensions 3816 symbol_to_source[phantom_symbol].extend(self.var_to_sources[phantom_symbol]) 3817 3818 # How do we know what the value of s0 is? Fresh variables can only be 3819 # bound by inputs, so there MUST be some other input which binds the 3820 # variable. If there is no such input, this is an error in our 3821 # system. We record where all symbols come from, to help you diagnose 3822 # why those symbols didn't occur. 3823 # 3824 # In fact, generally speaking it is only possible for the "outermost" 3825 # user of a ShapeEnv to evaluate the guards, because some inputs may 3826 # not be available to inner levels. For example, Dynamo can guard on 3827 # tensors that never actually become graph arguments (they are 3828 # pruned). In this case, only Dynamo knows about these arguments. 3829 def track_symint(source, val, constraint=None): 3830 log.debug("track_symint %s %s %s", LazyString(source.name), val, constraint) 3831 assert not isinstance(val, SymInt) or is_symbolic(val) 3832 3833 if isinstance(val, SymInt) and val.node.maybe_as_int() is not None: 3834 val = val.node.maybe_as_int() 3835 3836 if isinstance(val, SymInt): 3837 s = val.node.expr 3838 if isinstance(s, sympy.Symbol): 3839 symbol_to_source[s].append(source) 3840 if ( 3841 constraint is not None 3842 and not isinstance(constraint, RelaxedUnspecConstraint) 3843 ): 3844 symbol_to_constraints[s].add(constraint) 3845 else: 3846 constraint_violated = False 3847 if isinstance(constraint, StrictMinMaxConstraint): 3848 # try inferring the ranges of the expr s 3849 sym_vrs = {x: self.var_to_range.get(x, None) for x in s.free_symbols} 3850 if any(vr is None for vr in sym_vrs.values()): 3851 # some of the free symbols in s don't have ranges 3852 constraint_violated = True 3853 elif isinstance(constraint, RelaxedUnspecConstraint): 3854 if s.is_number: 3855 i = int(s) 3856 # Don't complain about 0/1 specialization, we 3857 # expect to have to compile in this case anyway 3858 if i not in (0, 1): 3859 constraint_violated = True 3860 if constraint_violated: 3861 def hint(s): 3862 sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(s) 3863 return f"{sexpr}." 3864 3865 var_with_range = self._render_range_for_constraint_violation(source, constraint) 3866 msg = ( 3867 f"Not all values of {var_with_range} are valid because " 3868 f"{self._debug_name(source)} was inferred to be equal to " 3869 ) 3870 record_constraint_violation( 3871 constraint.warn_only, 3872 self._debug_name(source), 3873 msg, 3874 hint=functools.partial(hint, s), 3875 ) 3876 3877 input_guards.append((source, s)) 3878 else: 3879 s = sympy.Integer(val) 3880 input_guards.append((source, s)) 3881 constraint_violated = False 3882 if isinstance(constraint, StrictMinMaxConstraint): 3883 if not (s == constraint.vr.lower == constraint.vr.upper): # allow static constraints 3884 constraint_violated = True 3885 elif isinstance(constraint, RelaxedUnspecConstraint): 3886 # Don't complain about 0/1 specialization, we 3887 # expect to have to compile in this case anyway 3888 if val not in (0, 1): 3889 constraint_violated = True 3890 if constraint_violated: 3891 var_with_range = self._render_range_for_constraint_violation(source, constraint) 3892 msg = ( 3893 f"Not all values of {var_with_range} are valid because " 3894 f"{self._debug_name(source)} was inferred to be a constant ({val})." 3895 ) 3896 record_constraint_violation(constraint.warn_only, self._debug_name(source), msg) 3897 3898 def track_symfloat(source, val): 3899 log.debug("track_symfloat %s %s", LazyString(source.name), val) 3900 assert not isinstance(val, SymFloat) or is_symbolic(val) 3901 3902 if isinstance(val, SymFloat) and val.node.maybe_as_float() is not None: 3903 val = val.node.maybe_as_float() 3904 3905 if isinstance(val, SymFloat): 3906 s = val.node.expr 3907 if isinstance(s, sympy.Symbol): 3908 symbol_to_source[s].append(source) 3909 input_guards.append((source, s)) 3910 else: 3911 s = sympy.Float(val) 3912 input_guards.append((source, s)) 3913 3914 for t, source, context in zip(placeholders, sources, input_contexts): 3915 if isinstance(source, str): 3916 from torch._dynamo.source import LocalSource 3917 source = LocalSource(source) 3918 assert isinstance(source, Source) 3919 if t is None: 3920 continue 3921 if isinstance(t, (SymInt, int)): 3922 track_symint(source, t) 3923 continue 3924 elif isinstance(t, (SymFloat, float)): 3925 track_symfloat(source, t) 3926 continue 3927 assert isinstance(t, Tensorlike) 3928 if is_traceable_wrapper_subclass(t): 3929 from torch._dynamo.source import AttrSource 3930 3931 assert isinstance(context, SubclassSymbolicContext) 3932 3933 # For subclasses, we need to track symints on BOTH the outer 3934 # and inner tensors. 3935 sources_tensors_constraints = [ 3936 (source, t, context.constraint_sizes, context.constraint_strides) 3937 ] 3938 attrs, _ = t.__tensor_flatten__() 3939 for attr in attrs: 3940 inner_t = getattr(t, attr) 3941 inner_context = context.inner_contexts[attr] 3942 sources_tensors_constraints.append(( 3943 AttrSource(source, attr), 3944 inner_t, 3945 inner_context.constraint_sizes, 3946 inner_context.constraint_strides 3947 )) 3948 else: 3949 sources_tensors_constraints = [(source, t, context.constraint_sizes, context.constraint_strides)] 3950 3951 for src, curr_t, constraint_size, constraint_stride in sources_tensors_constraints: 3952 if is_sparse_any(curr_t): 3953 for i, ss in enumerate(curr_t.size()): 3954 property_source = TensorPropertySource(src, TensorProperty.SIZE, i) 3955 track_symint(property_source, ss, constraint_size[i]) 3956 else: 3957 for i, ss in enumerate(curr_t.size()): 3958 property_source = TensorPropertySource(src, TensorProperty.SIZE, i) 3959 track_symint(property_source, ss, constraint_size[i]) 3960 for i, ss in enumerate(curr_t.stride()): 3961 property_source = TensorPropertySource(src, TensorProperty.STRIDE, i) 3962 track_symint(property_source, ss, constraint_stride[i]) 3963 track_symint(TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), curr_t.storage_offset()) 3964 3965 # 1. Every input must equal the final simplified symbolic expression 3966 # stored on the placeholder. Given a placeholder (s0*2, s1), 3967 # if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3. 3968 # This does a lot of work: it covers duck sizing and equality guards. 3969 exprs = [] 3970 self.dim_constraints = DimConstraints( 3971 symbol_to_source, 3972 self.var_to_val, 3973 set(symbol_to_constraints.keys()), 3974 self.source_name_to_debug_name, 3975 ) 3976 3977 if not _simplified: 3978 for source, expr in input_guards: 3979 if self._translation_validation_enabled: 3980 # Ignore sources that were not turned into SymInts. 3981 srcname = source.name() 3982 if srcname in self.source_to_symbol: 3983 self._add_target_expr(sympy.Eq(self.source_to_symbol[srcname], expr)) 3984 3985 # Small optimization 3986 if ( 3987 isinstance(expr, sympy.Symbol) and 3988 symbol_to_source.get(expr) and 3989 source == symbol_to_source[expr][0] 3990 ): 3991 continue 3992 3993 # This logic excludes static values found on tensors from guarding, because 3994 # dynamo's check_tensor_fn does that (see guards.cpp). 3995 # However, for non tensor sources, we still need to guard here. 3996 if ignore_static and isinstance(source, TensorPropertySource): 3997 if expr.is_number: 3998 self.log.debug("Skipping guard %s", f"{source_ref(source)} == {expr}") 3999 continue 4000 4001 if is_dim(source): 4002 self.dim_constraints.add_equality(source, expr) 4003 4004 sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr) 4005 exprs.append(f"{source_ref(source)} == {sexpr}") 4006 if ( 4007 isinstance(source, TensorPropertySource) 4008 and source.prop is TensorProperty.SIZE 4009 and equalities_inputs 4010 and len(expr.free_symbols) == 1 4011 ): 4012 symbol = next(iter(expr.free_symbols)) 4013 if ( 4014 isinstance(expr, sympy.Symbol) and 4015 expr in symbol_to_constraints and 4016 not equalities_inputs.is_equal(source, symbol_to_source[expr][0]) 4017 ): 4018 msg = ( 4019 f"The values of {self._debug_name(source)} = {source.name()} and " 4020 f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name()} " 4021 "must always be equal." 4022 ) 4023 record_constraint_violation(equalities_inputs.warn_only, self._debug_name(source), msg) 4024 4025 if ( 4026 not isinstance(expr, sympy.Symbol) and 4027 symbol in symbol_to_constraints and 4028 not equalities_inputs.is_derived(source, symbol_to_source[symbol][0], lambda x: expr.xreplace({symbol: x})) 4029 ): 4030 src = symbol_to_source[symbol][0] 4031 msg = ( 4032 f"The values of {self._debug_name(source)} = {source.name()} must always be related to " 4033 f"the values of {self._debug_name(src)} = {src.name()} by " 4034 f"{self._debug_name(source)} = {expr.xreplace({symbol: sympy.sympify(self._debug_name(src))})}." 4035 ) 4036 record_constraint_violation(equalities_inputs.warn_only, self._debug_name(source), msg) 4037 4038 # NB: Not necessary to report constraint violations here: 4039 # constraints are guaranteed to be on symbols (we've already 4040 # caught constants and non-atomic expressions), so we only 4041 # have relational constraints, but we don't support those 4042 # at the moment 4043 4044 # 2. Every guard must evaluate to True (but remember many guards 4045 # like s0 == s1*2 because trivial due to simplification) 4046 issued = set() 4047 4048 def issue_guard(guard: ShapeGuard) -> None: 4049 expr = self.simplify(guard.expr) 4050 4051 # Avoid re-issueing the same guard. 4052 if expr in issued: 4053 return 4054 4055 issued.add(expr) 4056 4057 try: 4058 is_trivial = False 4059 if any(is_dim(source) for s in expr.free_symbols for source in symbol_to_source[s]): 4060 is_trivial = self.dim_constraints.add(expr) 4061 guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr) 4062 exprs.append(guard_expr) 4063 self._add_target_expr(expr) 4064 # A non-relational constraint on a single sizevar can violate 4065 # a constraint 4066 if not is_trivial and len(expr.free_symbols) == 1: 4067 symbol = next(iter(expr.free_symbols)) 4068 source = symbol_to_source[symbol][0] 4069 constraints = symbol_to_constraints[symbol] 4070 for c in constraints: 4071 if isinstance(c, StrictMinMaxConstraint): 4072 var_with_range = self._render_range_for_constraint_violation(source, c) 4073 msg = ( 4074 f"Not all values of {var_with_range} " 4075 f"satisfy the generated guard {guard_expr}." 4076 ) 4077 record_constraint_violation(c.warn_only, self._debug_name(source), msg) 4078 elif isinstance(c, RelaxedUnspecConstraint): 4079 # This is fine, we allow guards here as long as it 4080 # didn't constrain it to one value (we don't 4081 # actually know this; this depends on our 4082 # ValueRanges reasoning capability) 4083 pass 4084 else: 4085 raise AssertionError(f"unrecognized constraint {c}") 4086 except Exception: 4087 self.log.warning("Failing guard allocated at: \n%s", ''.join(guard.stack.format())) 4088 raise 4089 4090 # First, issue all guards. 4091 # This removes all the checks that follow from bounds 4092 # We could simply emit those and also the bounds 2 <= size when necessary 4093 for guard in (guards if guards is not None else self.guards): 4094 if self._maybe_evaluate_static(guard.expr, axioms=()) is not None: 4095 continue 4096 issue_guard(guard) 4097 4098 # Because there are guards that export's constraint solver can suggest good fixes for, that we may have 4099 # deferred as runtime asserts, and that produce_guards() alone won't do anything with (e.g. divisiblity guards), 4100 # we want to send runtime asserts to export's constraint solver too. These will still stay in the graph as asserts, 4101 # but export's constraint solver can decide whether to do anything with them (i.e. raise an error and provide 4102 # suggested fixes, or decide it's out of scope and leave as a runtime assert in the graph). 4103 for ra in self.deferred_runtime_asserts.get(None, []): 4104 if self._maybe_evaluate_static(ra.expr, axioms=()) is not None: 4105 continue 4106 expr = self.simplify(ra.expr) 4107 self.dim_constraints.add(expr) 4108 4109 # 3. Every symbol must be within its value range (this handles 0/1 4110 # specialization too). 4111 for symbol, sources in symbol_to_source.items(): 4112 r = self.var_to_range.get(symbol) 4113 if r is None: 4114 if symbol not in self.var_to_range: 4115 continue 4116 r = self.var_to_range[symbol] 4117 4118 assert sources 4119 bounds = [] 4120 if r.lower not in (-sympy.oo, -int_oo): 4121 if any(is_dim(source) for source in sources): 4122 self.dim_constraints.add(sympy.Ge(symbol, r.lower)) 4123 # Only print lower bound in simplified mode if it is not the 4124 # default 4125 if not _simplified or r.lower != self._default_value_range().lower: 4126 bounds.append(str(r.lower)) 4127 bounds.append(source_ref(sources[0])) 4128 if r.upper not in (sympy.oo, int_oo): 4129 if any(is_dim(source) for source in sources): 4130 self.dim_constraints.add(sympy.Le(symbol, r.upper)) 4131 # nontrivial upper bound is always interesting 4132 bounds.append(str(r.upper)) 4133 if len(bounds) > 1: 4134 exprs.append(" <= ".join(bounds)) 4135 4136 # Check constraints 4137 constraints = symbol_to_constraints[symbol] 4138 for c in constraints: 4139 if isinstance(c, StrictMinMaxConstraint): 4140 # TODO: With int_oo, I think this condition is a noop 4141 # now 4142 if not (c.vr & self._default_value_range()).issubset(r): 4143 source = sources[0] 4144 4145 expr = sympy.And(sympy.Le(r.lower, symbol), sympy.Le(symbol, r.upper)) 4146 guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr) 4147 var_with_range = self._render_range_for_constraint_violation(source, c) 4148 msg = ( 4149 f"Not all values of {var_with_range} satisfy the generated guard {guard_expr}" 4150 ) 4151 record_constraint_violation( 4152 c.warn_only, 4153 self._debug_name(source), 4154 msg, 4155 ) 4156 # We NaN specialize, which means similar to 0/1 specialization we 4157 # should assume that the float is NOT nan. This is load bearing 4158 # if you have something like an equality guard, nan will play 4159 # merry hell with the reasoning. 4160 if symbol_is_type(symbol, SymT.FLOAT): 4161 exprs.append(f"not __math_isnan({source_ref(sources[0])})") 4162 4163 if constraint_violations: 4164 warn_msgs = [] 4165 error_msgs = [] 4166 debug_names = set() 4167 for warn_only, debug_name, msg in constraint_violations: 4168 if warn_only: 4169 msg = f" {len(warn_msgs) + 1}. {msg()}" 4170 warn_msgs.append(msg) 4171 else: 4172 msg = f" - {msg()}" 4173 error_msgs.append(msg) 4174 debug_names.add(debug_name) 4175 if len(error_msgs) > 0: 4176 debug_names = ', '.join(sorted(debug_names)) 4177 err = '\n'.join(error_msgs) 4178 raise ConstraintViolationError( 4179 f"Constraints violated ({debug_names})! " 4180 'For more information, run with TORCH_LOGS="+dynamic".\n' 4181 f"{err}" 4182 ) 4183 elif len(warn_msgs) > 0: 4184 log.debug("%s Warning only constraints violated", len(warn_msgs)) 4185 4186 signpost_event( 4187 "dynamic", 4188 "produce_guards", 4189 { 4190 **self.co_fields, 4191 **self.counter, 4192 "num_guards": len(exprs), 4193 "free_symbols": sum(1 for v in symbol_to_source.values() if v), 4194 # The keys are meaningless from an aggregate perspective, so 4195 # don't include them. Biggest first. 4196 "symbol_guard_counts": sorted(self.symbol_guard_counter.values(), reverse=True), 4197 }, 4198 ) 4199 4200 if self._translation_validation_enabled: 4201 from torch.fx.experimental.validator import PopulateValidator 4202 4203 # Add all deferred runtime assertions; these are not technically 4204 # handled by produce_guards but we need to put them in the target 4205 # set 4206 for ras in self.deferred_runtime_asserts.values(): 4207 for ra in ras: 4208 self._add_target_expr(ra.expr) 4209 4210 # Add value range bound guards for all symbols with no trivial bounds. 4211 # Reason: '_maybe_evaluate_static' may eliminate guards based on the 4212 # refined value ranges. 4213 for sym, vr in self.var_to_range.items(): 4214 if vr.lower not in (-sympy.oo, -int_oo): 4215 self._add_target_expr(sympy.Le(vr.lower, sym)) 4216 if vr.upper not in (sympy.oo, int_oo): 4217 self._add_target_expr(sympy.Le(sym, vr.upper)) 4218 4219 # Before validating, populate the input of the validator with the 4220 # built FX graph. 4221 with fx_traceback.preserve_node_meta(): 4222 PopulateValidator(self.graph, self.validator).run() 4223 4224 # Only run translation validation when we are not passing custom guards 4225 if guards is None: 4226 self._check_translation_validate() 4227 return exprs 4228 4229 def produce_guards_expression( 4230 self, 4231 placeholders, 4232 *, 4233 guards: Optional[List[ShapeGuard]] = None, 4234 ignore_static=True 4235 ): 4236 """ 4237 Expected to be used with evaluate_guards_expression(). Produces the guards 4238 for the given placeholders and returns a string expression to be evaluated 4239 by evaluate_guards_expression given concrete values for the placeholders. 4240 """ 4241 from torch._dynamo.source import LocalSource 4242 arg_names = [f"t{i}" for i in range(len(placeholders))] 4243 produced_guards = self.produce_guards( 4244 placeholders, 4245 [LocalSource(a) for a in arg_names], 4246 guards=guards, 4247 ignore_static=ignore_static, 4248 ) 4249 if produced_guards: 4250 return " and ".join(produced_guards) 4251 return None 4252 4253 def evaluate_symexpr(self, code): 4254 """ 4255 To be used by compile_fx to evaluate symexprs 4256 """ 4257 args = {str(e): val for e, val in self.var_to_val.items()} 4258 return eval(code, SYMPY_INTERP, args) 4259 4260 def evaluate_guards_expression(self, code, args): 4261 """ 4262 Expected to be used with produce_guards_expression(). Evaluates an expression 4263 generated by produce_guards_expression for the given concrete args. 4264 """ 4265 arg_names = [f"t{i}" for i in range(len(args))] 4266 return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))}) 4267 4268 def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True): 4269 """Generate guards for a graph's placeholder values and evaluate the guards with args 4270 """ 4271 code = self.produce_guards_expression(placeholders, ignore_static=ignore_static) 4272 if code: 4273 return self.evaluate_guards_expression(code, args) 4274 return True 4275 4276 def get_pruned_guards(self, symints): 4277 """ 4278 Get a list of guards, but pruned so it only provides guards that 4279 reference symints from the passed in input 4280 """ 4281 symints = {s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)} 4282 guards = [] 4283 for g in self.guards: 4284 if all(s in symints for s in g.expr.free_symbols): 4285 guards.append(g) 4286 return guards 4287 4288 def bind_symbols(self, placeholders, args): 4289 """ 4290 Given a paired list of placeholders (fake tensors with 4291 symbolic sizes) and concrete arguments (regular tensors 4292 with real sizes), returns a dictionary mapping each 4293 symbol to its real value. So for example, if you 4294 have a placeholder with size (s0, s1), binding 4295 (2, 4) to it will give you {s0: 2, s1: 4}. This is 4296 not guaranteed to bind ALL symbols in the ShapeEnv; 4297 we can't bind a symbol if it doesn't occur in any placeholder, 4298 and symbols that already have replacements won't get bindings. 4299 4300 This is a little duplicative with evaluate_guards but 4301 it's different enough that it seemed cleanest to make 4302 another copy. This assumes the guards are already checked, 4303 though if it's cheap we'll check for shenanigans 4304 """ 4305 bindings: Dict[sympy.Symbol, int] = {} 4306 4307 def bind_symint(arg, val): 4308 if isinstance(val, SymInt): 4309 s = val.node.expr 4310 4311 if isinstance(s, sympy.Symbol): 4312 if s in bindings: 4313 assert bindings[s] == arg, f"{bindings[s]} != {arg}" 4314 else: 4315 bindings[s] = arg 4316 elif isinstance(-s, sympy.Symbol): 4317 if -s in bindings: 4318 assert bindings[-s] == -arg, f"{bindings[-s]} != {-arg}" 4319 else: 4320 bindings[-s] = -arg 4321 4322 for t, arg in zip(placeholders, args): 4323 if t is None: 4324 continue 4325 if isinstance(t, SymInt): 4326 bind_symint(arg, t) 4327 continue 4328 assert isinstance(t, torch.Tensor) 4329 for i, s in enumerate(t.size()): 4330 bind_symint(arg.size(i), s) 4331 for i, s in enumerate(t.stride()): 4332 bind_symint(arg.stride(i), s) 4333 bind_symint(arg.storage_offset(), t.storage_offset()) 4334 4335 return bindings 4336 4337 def get_nontrivial_guards(self): 4338 """Returns a list of guard expressions that aren't statically known (i.e. not trivial)""" 4339 return [self.simplify(guard.expr) for guard in self.guards if self._maybe_evaluate_static(guard.expr, axioms=()) is None] 4340 4341 def format_guards(self, verbose=False): 4342 """Format this shape env's guard expressions with optional traceback info if verbose""" 4343 def format_tb(tb): 4344 if not verbose: 4345 return "" 4346 return f"\n Guarded at:\n{''.join(' ' + l for l in tb.format())}" 4347 4348 return '\n'.join(f" - {guard.expr}{format_tb(guard.stack)}" for guard in self.guards) 4349 4350 def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRanges: 4351 """Given a sympy expression, computes a ValueRanges bound for what values it can be""" 4352 var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols} 4353 if size_oblivious: 4354 # Clamp values of size-like variables 4355 # NB: discarding the old upper bound in intentional, per 4356 # https://github.com/pytorch/pytorch/pull/123675 4357 for x in self.size_like & var_to_range.keys(): 4358 if var_to_range[x] is not None: 4359 # NB: do NOT set upper to 2 ** 48, we're using this solely 4360 # to determine if we can do size-like replacement, the 4361 # upper bound is irrelevant here 4362 var_to_range[x] = ValueRanges(2, int_oo) 4363 assert var_to_range[x].is_int 4364 return bound_sympy(expr, var_to_range) 4365 4366 @_lru_cache 4367 def get_axioms(self, symbols: Optional[Tuple["sympy.Symbol"]] = None, compute_hint: bool = False) -> Tuple["sympy.Expr"]: 4368 """ 4369 Given the symbols in an expression, it returns all the runtime asserts that have those symbols 4370 concatenated with all the guards. 4371 If symbols is None, it returns all the runtime asserts (and all the guards) 4372 """ 4373 if symbols is None: 4374 runtime_asserts = (r.expr 4375 for rs in self.deferred_runtime_asserts.values() 4376 for r in rs) 4377 else: 4378 runtime_asserts = (r.expr 4379 for s in symbols if s not in self.var_to_val 4380 for r in self.deferred_runtime_asserts.get(s, ())) 4381 guards = (g.expr for g in self.guards) 4382 axioms = itertools.chain(guards, runtime_asserts) 4383 if compute_hint: 4384 axioms = (canonicalize_bool_expr(a.xreplace(self.var_to_val)) for a in axioms) 4385 return tuple(dict.fromkeys(axioms).keys()) 4386 4387 @lru_cache(None) 4388 def get_implications(self, 4389 e: "sympy.Expr") -> Tuple[Tuple["sympy.Expr", 'sympy.logic.boolalg.BooleanAtom']]: 4390 """ Given a expression, it returns a list of predicates that follow from it """ 4391 equiv = {} 4392 4393 def add_expr(expr): 4394 expr = canonicalize_bool_expr(expr) 4395 if isinstance(expr, (sympy.Eq, sympy.Ne)): 4396 # No need to canonicalize 4397 # TODO We could further canonicalize Eq ordering the lhs and rhs somehow 4398 # With this, we could remove the need for the commutativity part 4399 opposite = sympy.Eq if isinstance(expr, sympy.Ne) else sympy.Ne 4400 # Commutativity of == and != 4401 equiv[type(expr)(expr.lhs, expr.rhs)] = sympy.true 4402 equiv[type(expr)(expr.rhs, expr.lhs)] = sympy.true 4403 equiv[opposite(expr.lhs, expr.rhs)] = sympy.false 4404 equiv[opposite(expr.rhs, expr.lhs)] = sympy.false 4405 else: 4406 # Expr and negation 4407 equiv[expr] = sympy.true 4408 equiv[canonicalize_bool_expr(sympy.Not(expr))] = sympy.false 4409 4410 add_expr(e) 4411 # Other relational expressions this expression implies 4412 if isinstance(e, sympy.Eq): 4413 add_expr(sympy.Le(e.lhs, e.rhs)) 4414 add_expr(sympy.Ge(e.lhs, e.rhs)) 4415 elif isinstance(e, sympy.Lt): 4416 add_expr(sympy.Le(e.lhs, e.rhs)) 4417 add_expr(sympy.Ne(e.lhs, e.rhs)) 4418 if e.lhs.is_integer and e.rhs.is_integer: 4419 add_expr(sympy.Le(e.lhs, e.rhs - 1)) 4420 elif isinstance(e, sympy.Le): 4421 add_expr(sympy.Lt(e.lhs, e.rhs + 1)) 4422 return tuple(equiv.items()) 4423 4424 @_lru_cache 4425 def _maybe_evaluate_static( 4426 self, expr: "sympy.Expr", *, unbacked_only: bool = False, compute_hint: bool = False, 4427 size_oblivious: bool = False, axioms: Optional[Tuple[sympy.Expr]] = None, 4428 var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges]]] = None 4429 ) -> "Optional[sympy.Expr]": 4430 """ 4431 Tries to evaluate expr without introducing guards 4432 4433 If unbacked_only == True, then we only do substitutions on 4434 unbacked SymInts (leaving regular hinted integers alone). This could 4435 result in an expression that still contains backed SymInts, which you 4436 could then potentially guard on. 4437 4438 Use compute_hint == True if you are trying to compute a non-binding 4439 hint for the particular hint values of backed SymInts, e.g., if 4440 s0 happens to be 3 this run, compute_hint will subsitute s0 with 3. 4441 """ 4442 4443 # axioms with compute hint NYE 4444 assert not compute_hint or not axioms 4445 4446 if var_to_range is None: 4447 var_ranges = self.var_to_range 4448 else: 4449 var_ranges = dict(var_to_range) 4450 4451 expr = self.simplify(expr) 4452 4453 if compute_hint: 4454 expr = expr.xreplace(self.var_to_val) 4455 4456 expr = canonicalize_bool_expr(expr) 4457 4458 # Pattern matching 4459 symbols = tuple(expr.free_symbols) 4460 if axioms is None: 4461 axioms = self.get_axioms(symbols, compute_hint=compute_hint) 4462 subst = {} 4463 for e in axioms: 4464 if e.free_symbols.issubset(expr.free_symbols): 4465 subst.update(dict(self.get_implications(e))) 4466 4467 expr = expr.xreplace(subst) 4468 4469 symbols = tuple(expr.free_symbols) 4470 4471 # Simplify making use of value range lower bound 4472 new_shape_env = {} 4473 new_range_env = {} 4474 for idx, k in enumerate(symbols): 4475 if isinstance(self.var_to_val.get(k, None), SingletonInt): 4476 # Skip var_ranges logic for SingletonInt which is only used 4477 # for jagged layout NestedTensors today 4478 continue 4479 vr = var_ranges[k] 4480 if size_oblivious and k in self.size_like: 4481 lower = max(2, vr.lower) 4482 # Clamping size-oblivious to some quantity below sys.maxsize 4483 # helps us determine that f(u0) != sys.maxsize, which is a 4484 # test that is looking for sys.maxsize as a sentinel, but you 4485 # don't really want to worry about it for unbacked SymInts. 4486 # This is similar to the flavor where size oblivious omits 4487 # 0/1, it changes semantics but in a benign way. 4488 upper = min(2 ** 48, vr.upper) 4489 # This is a bit dodgy: what this means is that there was a 4490 # size-like unbacked symbol whose upper bound < 2. This 4491 # causes... problems. 4492 if lower <= upper: 4493 vr = ValueRanges(lower, upper) 4494 else: 4495 lower = vr.lower 4496 # Don't do anything if we don't have a nontrivial lower bound 4497 # Also don't do anything if we asked only to simplify unbacked 4498 # SymInt 4499 if ( 4500 lower is -int_oo or 4501 (unbacked_only and k in self.var_to_val) or 4502 not vr.is_int 4503 ): 4504 new_range_env[k] = vr 4505 continue 4506 # The goal is to take our symbols which have various lower bounds 4507 # and reallocate them into new symbols which are exactly positive; 4508 # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in 4509 # [1, inf], where s0 = ess0 + 1. This gives the most information 4510 # to sympy for subsequent simplifications. 4511 # 4512 # Positive means >= 1 4513 # Positive - 1 means >= 0 4514 # Positive + lower - 1 means >= lower 4515 # The new symbol 's' is "too low", so when we substitute it in 4516 # we have to increase it by offset (and conversely, the new 4517 # variables have to have their value range bounds adjusted as 4518 # well) 4519 s = sympy.Symbol(f"evaluate_static_shape_{idx}", positive=True, integer=True) 4520 4521 # Note: 4522 # Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers. 4523 # Sympy might give unexepected results when comparing an integer with a non-integer 4524 # Therefore, we cast offset to int here. 4525 # For example: 4526 # shape_0 = sympy.Symbol("shape_0", positive=True, integer=True) 4527 # expr = sympy.Eq(shape_0 - 1/3, 4) 4528 # expr.xreplace({}) # False 4529 offset = int(lower - 1) 4530 new_shape_env[k] = s + offset 4531 new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset) 4532 4533 try: 4534 new_expr = expr.xreplace(new_shape_env) 4535 except RecursionError: 4536 log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env) 4537 self.counter["sympy_recursion_error"] += 1 4538 return None 4539 4540 # We need to canonicalize, as after expand we may have something like `a + b = a` and 4541 # sympy will not simplify the a. The two appeareances of the a will then make value ranges 4542 # analysis give lose bounds 4543 new_expr = canonicalize_bool_expr(safe_expand(new_expr)) 4544 if new_expr.is_number: 4545 return new_expr 4546 4547 # This is bad to do, the replacement with division leaves us with 4548 # rationals when atom.args[0] is addition, e.g., sympy will happily 4549 # turn (s0 + s1) // 2 into s0 / 2 + s1 / 2. Needless complication! 4550 """ 4551 floor_div_replace = {} 4552 for atom in new_expr.atoms(FloorDiv): 4553 floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1]) 4554 new_expr = safe_expand(new_expr.xreplace(floor_div_replace)) 4555 # TODO: when unbacked_only, can sometimes early return even when there 4556 # are still free symbols 4557 if new_expr.is_number: 4558 return new_expr 4559 """ 4560 4561 # Check if the range can solve it statically 4562 out = bound_sympy(new_expr, new_range_env) 4563 if out.is_singleton(): 4564 return out.lower 4565 4566 return new_expr if unbacked_only else None 4567 4568 @_lru_cache 4569 def replace(self, expr: "sympy.Expr") -> "sympy.Expr": 4570 """Apply symbol replacements to any symbols in the given expression 4571 """ 4572 replacements = {s: self._find(cast(sympy.Symbol, s)) for s in expr.free_symbols} 4573 return safe_expand(expr.xreplace(replacements)) 4574 4575 @_lru_cache 4576 def _update_divisible(self): 4577 new_divisible = set() 4578 for k in self.divisible: 4579 res = self.replace(k) 4580 if not res.is_number: 4581 new_divisible.add(k) 4582 4583 self.divisible = new_divisible 4584 self._update_version_counter() 4585 4586 @_lru_cache 4587 def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": 4588 """Use known constraints and replacements to simplify the given expr 4589 """ 4590 expr = self.replace(expr) 4591 # TODO it would seem that this pass is not necessary given the 4592 # below replacement of // with /, but for nested FloorDivs 4593 # the non-recursive replacement doesn't work, and 4594 # recursive makes it hard to look up divisibility, 4595 # because existing divisibility info has FloorDiv in it, not / 4596 # for now just do a separate pass to catch common nested case 4597 if expr.has(FloorDiv): 4598 self._update_divisible() 4599 div_replacements = {} 4600 for atom in expr.atoms(FloorDiv): 4601 base, divisor = atom.args 4602 if isinstance(divisor, FloorDiv): 4603 base1, divisor1 = divisor.args 4604 if self.replace(Mod(base, divisor)) in self.divisible and \ 4605 base == base1 and self.replace(Mod(base1, divisor1)) in self.divisible: 4606 div_replacements[atom] = divisor1 4607 expr = expr.xreplace(div_replacements) 4608 expr = safe_expand(expr) 4609 if expr.has(FloorDiv): 4610 div_replacements = {} 4611 pows = expr.atoms(sympy.Pow) 4612 rationals = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer)) 4613 for fd in expr.atoms(FloorDiv): 4614 base, divisor = fd.args 4615 if self.replace(Mod(base, divisor)) in self.divisible: 4616 div_replacements[fd] = CleanDiv(base, divisor) 4617 new_expr = expr.xreplace(div_replacements) 4618 new_expr = safe_expand(new_expr) 4619 new_pows = new_expr.atoms(sympy.Pow) 4620 new_rationals = new_expr.atoms(sympy.Rational).difference(new_expr.atoms(sympy.Integer)) 4621 # divisions simplified away 4622 if new_pows.issubset(pows) and new_rationals.issubset(rationals): 4623 expr = new_expr 4624 return expr 4625 4626 @lru_cache(256) 4627 def size_hint(self, expr: "sympy.Expr", *, allow_none=False): 4628 """ 4629 Gets a size hint for a given expression from the underlying shapes we had. 4630 Does not introduce a guard, so only use this when you can guarantee that 4631 your code is still valid for arbitrary shapes (such as optimization decisions) 4632 """ 4633 result_expr = safe_expand(expr).xreplace(self.var_to_val) 4634 if not result_expr.is_number: 4635 4636 from torch.utils._sympy.singleton_int import SingletonInt 4637 4638 if isinstance(result_expr, SingletonInt): 4639 return None 4640 r = self._maybe_evaluate_static(result_expr, compute_hint=True) 4641 if r is not None: 4642 return r 4643 if allow_none: 4644 return None 4645 4646 if self.unbacked_var_to_val: 4647 unsound_expr = result_expr.xreplace(self.unbacked_var_to_val) 4648 if not unsound_expr.free_symbols: 4649 log.warning("propagate_real_tensors size_hint(%s) -> %s", expr, unsound_expr) 4650 trace_structured( 4651 "propagate_real_tensors", 4652 metadata_fn=lambda: { 4653 "expr": repr(expr), 4654 "result": repr(unsound_expr), 4655 "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()), 4656 }, 4657 ) 4658 self.defer_runtime_assert( 4659 sympy.Eq(result_expr, unsound_expr), 4660 f"propagate_real_tensors: {result_expr} == {unsound_expr}" 4661 ) 4662 return unsound_expr 4663 4664 raise self._make_data_dependent_error(result_expr, expr) 4665 return result_expr 4666 4667 # NB: keep in sync with size_hint 4668 @lru_cache(256) 4669 def has_hint(self, expr: "sympy.Expr"): 4670 result_expr = safe_expand(expr).xreplace(self.var_to_val) 4671 return result_expr.is_number or self._maybe_evaluate_static(result_expr) is not None 4672 4673 def _make_data_dependent_error(self, expr, unhinted_expr, *, size_oblivious_result: Optional[bool] = None): 4674 # TODO: in a Dynamo context, having user code, and having the 4675 # name of the local, will be much better 4676 size_like_symbols = [] 4677 for s in expr.free_symbols: 4678 stacktrace = ''.join(self.var_to_stack[s].format()) 4679 self.log.debug("Data dependent variable '%s' allocated at:\n%s", s, stacktrace) 4680 if s in self.size_like: 4681 size_like_symbols.append(s) 4682 size_oblivious_result_msg = "" 4683 if size_oblivious_result is not None: 4684 size_oblivious_result_msg = ( 4685 f"ATTENTION: guard_size_oblivious would fix the error, evaluating expression to {size_oblivious_result}.\n" 4686 "Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.\n\n" 4687 ) 4688 fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(True) 4689 if expr.is_integer: 4690 desc = "Could not extract specialized integer from data-dependent expression" 4691 else: 4692 desc = "Could not guard on data-dependent expression" 4693 msg = ( 4694 f"{desc} {expr} (unhinted: {unhinted_expr}). " 4695 f"(Size-like symbols: {', '.join(map(str, size_like_symbols)) or 'none'})\n\n" 4696 f"{size_oblivious_result_msg}" 4697 "Potential framework code culprit (scroll up for full backtrace):\n" 4698 f"{''.join(traceback.StackSummary.from_list([fsummary]).format())}\n" 4699 'For more information, run with TORCH_LOGS="dynamic"\n' 4700 "For extended logs when we create symbols, also add " 4701 f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{','.join(map(str, expr.free_symbols))}\"\n" 4702 "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n" 4703 "For more debugging help, see " 4704 "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n" + 4705 maybe_extra_debug 4706 # TODO: Help text about how to use our runtime tests to fix this 4707 # problem 4708 ) 4709 return GuardOnDataDependentSymNode(expr, msg) 4710 4711 def _update_var_to_range(self, symbol, vr): 4712 lower, upper = vr.lower, vr.upper 4713 4714 # If we have a size-like unbacked SymInt, refuse to refine the range to be 4715 # less than two. This is because when we intersect this range 4716 # with [2, inf] for size oblivious tests, the range would be 4717 # unsatisfiable. In other words, once you have a size-like 4718 # unbacked SymInt, we can never learn that it is exactly zero or one, 4719 # because we would now give inconsistent results for all size 4720 # oblivous tests! 4721 if upper < 2 and symbol in self.size_like: 4722 upper = 2 4723 4724 # Updates the range and the guards corresponding to each bound of the symbol. 4725 if symbol not in self.var_to_range: 4726 r = ValueRanges(lower, upper) 4727 self.log.debug("_update_var_to_range %s = %s (new)", symbol, r) 4728 self.var_to_range[symbol] = r 4729 else: 4730 old = self.var_to_range[symbol] 4731 new = old & ValueRanges(lower, upper) 4732 if new != old: 4733 self.var_to_range[symbol] = new 4734 self.log.debug("_update_var_to_range %s = %s (update)", symbol, new) 4735 4736 if (v := self.var_to_val.get(symbol)) is not None: 4737 r = self.var_to_range[symbol] 4738 assert v in r, f"{v} not in {r}" 4739 4740 def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> None: 4741 """ 4742 Adds or updates a replacement for a symbol. 4743 Use this instead of `self.replacements[a] = tgt`. 4744 """ 4745 4746 if tgt == self.replacements.get(a, None): 4747 return 4748 4749 # Precondition: a == tgt 4750 assert isinstance(a, sympy.Symbol) 4751 4752 if self.allow_complex_guards_as_runtime_asserts and not _is_supported_equivalence(tgt): 4753 return # continuing leads to placeholder shapes having complex expressions that we can't resolve 4754 4755 # Handles nested tensor symbolic variables which don't have 4756 # var_to_range bounds 4757 tgt_bound = None 4758 if a in self.var_to_range: 4759 src_bound = self.var_to_range[a] 4760 4761 # First, refine the value range of a based on the computed value range 4762 # of tgt. This is always OK to do, even if we decide not to do the 4763 # substitution in the end. This might be a no-op, if a already has 4764 # a tighter bound 4765 tgt_bound = self.bound_sympy(tgt) 4766 self._update_var_to_range(a, tgt_bound) 4767 4768 # Next, check if we can update the range of free symbols in tgt 4769 # based on the range in a. But only do it if: 4770 # - the source bound non-trivially improves over what we get out of 4771 # the existing bounds. 4772 # - the replacement is univariate and we can invert the tgt expression 4773 if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1: 4774 b = next(iter(tgt.free_symbols)) 4775 # Try to invert the equality 4776 r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) 4777 if r is not None: 4778 self.log.debug("set_replacement: solve for %s in %s == %s gives %s", b, a, tgt, r) 4779 # The solution here can be non-integral, for example, if 4780 # we have s0 = 2*s1, then s1 = s0/2. What we would like 4781 # to do is calculated the bounds in arbitrary precision, 4782 # and then requantize the bound to integers when we are 4783 # done. 4784 rat_b_bound = self.bound_sympy(r[1]) 4785 b_bound = ValueRanges(CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper)) 4786 self._update_var_to_range(b, b_bound) 4787 tgt_bound = self.bound_sympy(tgt) 4788 assert tgt_bound.issubset(src_bound) 4789 4790 # TODO: Should we propagate size-like-ness? 4791 # 4792 # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1 4793 # to become size-like. 4794 # 4795 # Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T 4796 # propagate in this case, because what if u0 == 0, then u1 is negative 4797 # and clearly isn't a size. So, at minimum, any f(x) whose value 4798 # range isn't [0, inf] given x in [0, inf] cannot propagate 4799 # size-like-ness. But there are many situations where you could 4800 # imagine u1 is going to be size-like and actually you just didn't 4801 # have a refined enough value range on u0. Since even innocuous 4802 # looking arithmetic operations can destroy size-like-ness, it's 4803 # best to not propagate it at all and force the user to annotate it 4804 # as necessary. 4805 # 4806 # Compromise: we preserve size-like-ness only for exact equality 4807 # and nothing else. 4808 if a in self.size_like and isinstance(tgt, sympy.Symbol): 4809 self.size_like.add(tgt) 4810 elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like: 4811 self.size_like.add(a) 4812 4813 # Now, decide if we will do the substitution. 4814 # 4815 # - If the source has a non-trivial range, only substitute if 4816 # we preserve this range. Note that we may have propagated 4817 # the src_range to free variables in tgt when tgt is univariate 4818 # and we could find an inverse, which helps us achieve this. 4819 # This ensures we never "forget" about user defined ranges, 4820 # even if they end up being defined on composite formulas 4821 # like s0 + s1. 4822 # 4823 # - If the variable is unbacked, only substitute if the substitution 4824 # would preserve the bounds also under size-like-ness conditions. 4825 4826 if not tgt_bound.issubset(src_bound): 4827 self.log.debug("skipped set_replacement %s = %s (%s) [%s not subset of %s]", a, tgt, msg, tgt_bound, src_bound) 4828 return 4829 elif a in self.size_like: 4830 tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True) 4831 src_bound_so = self.bound_sympy(a, size_oblivious=True) 4832 if not tgt_bound_so.issubset(src_bound_so): 4833 self.log.debug("skipped set_replacement %s = %s (%s) " 4834 "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so) 4835 return 4836 4837 if isinstance(tgt, (sympy.Integer, sympy.Float)): 4838 # specializing to a constant, which is likely unexpected (unless 4839 # you specified dynamic=True) 4840 4841 user_tb = TracingContext.extract_stack() 4842 trace_structured( 4843 "symbolic_shape_specialization", 4844 metadata_fn=lambda: { 4845 "symbol": repr(a), 4846 "sources": [s.name() for s in self.var_to_sources.get(a, [])], 4847 "value": repr(tgt), 4848 "reason": msg, 4849 "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()), 4850 "user_stack": structured.from_traceback(user_tb) if user_tb else None, 4851 } 4852 ) 4853 4854 if config.print_specializations: 4855 self.log.warning("Specializing %s to %s", self.var_to_sources[a][0].name(), tgt) 4856 self.log.debug("SPECIALIZATION", stack_info=True) 4857 log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound) 4858 self.replacements[a] = tgt 4859 self._update_version_counter() 4860 4861 # When specializing 'a == tgt', the equality should be also conveyed to 4862 # Z3, in case an expression uses 'a'. 4863 self._add_target_expr(sympy.Eq(a, tgt)) 4864 4865 def _add_divisible(self, expr: "sympy.Expr"): 4866 self.divisible.add(expr) 4867 self._update_version_counter() 4868 4869 @_lru_cache 4870 @record_shapeenv_event() 4871 def _find(self, a: "sympy.Symbol") -> "sympy.Expr": 4872 """ 4873 Implements a DSU-like algorithm to find the variable that represents a 4874 Also handles transitive non-identity replacements. 4875 4876 a: b + c 4877 c: d 4878 """ 4879 if a not in self.replacements: 4880 return a 4881 res = self.replacements[a] 4882 cur_replace = {s: self._find(s) for s in res.free_symbols} 4883 replaced, changed = self.replacements[a]._xreplace(cur_replace) 4884 if changed: 4885 self._set_replacement(a, replaced, "find") 4886 return self.replacements[a] 4887 4888 @lru_cache(256) 4889 def _maybe_guard_rel(self, expr: "sympy.Rel") -> None: 4890 """ 4891 The relational guard is guarded to be true. Use this information to 4892 simplify shapes (i.e. a == b or a % 5 == 0) 4893 """ 4894 assert isinstance(expr, sympy.Rel) 4895 4896 # A good example of what goes wrong if you don't do this is 4897 # python test/functorch/test_aotdispatch.py -k 4898 # test_aot_autograd_symbolic_module_exhaustive_nn_LazyConv3d_cpu_float32 4899 if isinstance(expr, sympy.Ne): 4900 return 4901 4902 free = list(expr.free_symbols) 4903 4904 assert len(free) > 0, f"The expression should not be static by this point: {expr}" 4905 # In case of really gnarly expression, we don't blow up 4906 if len(free) > 5: 4907 return 4908 4909 # Prioritize unbacked symints for solving by ordering them last. 4910 # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3). 4911 # (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols) 4912 # Prefer to simplify out symbols with ephemeral sources. 4913 def _smart_symbol_sort(x): 4914 has_only_ephemeral_sources = ( 4915 x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x]) 4916 ) 4917 # NB: size_hint is int, not sympy.Expr, do not use int_oo here 4918 size = self.size_hint(x, allow_none=True) or sys.maxsize 4919 name = x.name 4920 # 1 puts ephemeral sourced symbols first when sorting in reverse 4921 return (1 if has_only_ephemeral_sources else 0, size, name) 4922 4923 free = sorted(free, key=_smart_symbol_sort, reverse=True) # type: ignore[attr-defined] 4924 lhs = expr.lhs 4925 rhs = expr.rhs 4926 4927 self._refine_ranges(expr) 4928 4929 # The rest of this stuff is for equality only 4930 if not isinstance(expr, sympy.Eq): 4931 return 4932 4933 if not expr.has(Mod): 4934 try: 4935 floor_div_atoms = lhs.atoms(FloorDiv).union(rhs.atoms(FloorDiv)) 4936 if len(floor_div_atoms) > 0 and any(a.divisor != 1 for a in floor_div_atoms): 4937 raise NotImplementedError 4938 4939 # Never replace unbacked symbols with other unbacked symbols. 4940 # This is error prone because you can cause references to 4941 # unbacked symbols to time travel backwards. E.g., 4942 # 4943 # u1 = x.item() 4944 # ... use of u1 ... 4945 # u2 = y.item() 4946 # u3 = z.item() 4947 # torch._check(u1 == u2 + u3) 4948 # 4949 # If you replace u1 with u2 + u3, then the use of u1 now 4950 # references u2 and u3 prior to them actually being bound at 4951 # runtime. It's pretty inconvenient to setup control 4952 # dependencies for substitutions, so ban it entirely. 4953 def trivial_solve(lhs, rhs): 4954 if isinstance(lhs, sympy.Symbol): 4955 if free_unbacked_symbols(lhs) and not free_unbacked_symbols(rhs): 4956 return True 4957 if symbol_is_type(lhs, SymT.FLOAT): 4958 return True 4959 # TODO: Maybe trivial solutions for int should also be 4960 # done? 4961 return False 4962 4963 # short-circuit when no solving is needed 4964 if trivial_solve(lhs, rhs): 4965 self._set_replacement(lhs, self._find(rhs), "trivial_lhs") 4966 elif trivial_solve(rhs, lhs): 4967 self._set_replacement(rhs, self._find(lhs), "trivial_rhs") 4968 else: 4969 r = try_solve(expr, free[0], floordiv_inequality=False) 4970 if r is not None and all(t.is_integer for t in sympy.preorder_traversal(r[1])): 4971 new_var = self._find(r[1]) 4972 ok = len(free_unbacked_symbols(new_var)) == 0 4973 if ok: 4974 self._set_replacement(cast(sympy.Symbol, free[0]), new_var, "solve") 4975 except NotImplementedError: 4976 pass 4977 if expr.has(Mod): 4978 mod_expr = next(iter(expr.atoms(Mod))) 4979 try: 4980 r = try_solve(expr, mod_expr, floordiv_inequality=False) 4981 if r is not None and r[1] == 0: 4982 self._add_divisible(mod_expr) 4983 # This is a little bit of extra logic to make things like 4984 # torch.empty(i0, q).view(c, -1, q) work out 4985 p, q = mod_expr.args 4986 if isinstance(q, sympy.Number) and isinstance(p, sympy.Mul) and len(p.args) == 2: 4987 c, i0 = p.args 4988 # Given Mod(c * i0, q) == 0 4989 if ( 4990 isinstance(c, sympy.Number) and 4991 isinstance(i0, sympy.Symbol) and 4992 self.is_unbacked_symint(i0) 4993 ): 4994 # We have Mod(i0, q / c) == 0, which means we can 4995 # rewrite i0 as (q / gcd(q, c)) * i1 4996 d = q / sympy.gcd(q, c) # TODO: CleanDiv? 4997 i1 = self.create_unbacked_symint().node.expr 4998 # Propagate the value ranges. It doesn't really 4999 # matter if we use truediv or floordiv, because we 5000 # have established divisibility. 5001 self._update_var_to_range(i1, SymPyValueRangeAnalysis.floordiv( 5002 self.var_to_range[i0], ValueRanges.wrap(d) 5003 )) 5004 # Propagate size-like-ness 5005 if i0 in self.size_like: 5006 self.size_like.add(i1) 5007 self._set_replacement(i0, d * i1, "divisibility") 5008 5009 except NotImplementedError: 5010 pass 5011 return 5012 5013 # See: Note - On 0/1 specialization 5014 def _default_value_range(self) -> ValueRanges: 5015 lower = 2 if self.specialize_zero_one else 0 5016 return ValueRanges(lower, int_oo) 5017 5018 def _default_unspecified_value_range(self) -> ValueRanges: 5019 return ValueRanges(-int_oo, int_oo) 5020 5021 @_lru_cache 5022 def _simplify_floor_div(self, expr): 5023 floor_divs = tuple(expr.atoms(FloorDiv)) 5024 # we expect floor_divs to be exact, 5025 # and thus add the guards for the exact floordivs, 5026 # even if tracing doesn't require them otherwise 5027 for fd in reversed(floor_divs): 5028 base, divisor = fd.args 5029 mod_expr = Mod(base, divisor) 5030 eq_expr = sympy.Eq(mod_expr, 0) 5031 # add necessary mod guards 5032 self.evaluate_expr(eq_expr) 5033 return self.simplify(expr) 5034 5035 # We're about to add a guard/runtime assert, check if the ShapeEnv is frozen 5036 # and if so issue a warning 5037 def _check_frozen(self, expr, concrete_val): 5038 if self.frozen: 5039 self.counter["ignored_backward_guard"] += 1 5040 signpost_event( 5041 "dynamic", 5042 "evaluate_expr_frozen", 5043 { 5044 **self.co_fields, 5045 "ignored_guard": f"{expr} == {concrete_val}", 5046 # no version = original state (this signpost is expected) 5047 # version 2 = dynamic backwards is eagerly compiled 5048 "version": 2, 5049 }, 5050 ) 5051 log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val, stack_info=True) 5052 5053 5054 def _get_stack_summary(self, is_debug: bool = False): 5055 fsummary = None 5056 frame = inspect.currentframe() 5057 try: 5058 while frame is not None: 5059 if frame.f_code.co_filename not in uninteresting_files(): 5060 fsummary = traceback.FrameSummary( 5061 frame.f_code.co_filename, 5062 frame.f_lineno, 5063 frame.f_code.co_name, 5064 ) 5065 break 5066 frame = frame.f_back 5067 finally: 5068 del frame 5069 5070 # NB: this stack is truncated, but it's fine because the main 5071 # stack_info will give you the rest of the info you need 5072 maybe_user_loc = "" 5073 user_tb = TracingContext.extract_stack() 5074 if user_tb: 5075 maybe_user_loc = " at " + format_frame(user_tb[-1]) 5076 5077 maybe_extra_debug = "" 5078 if is_debug and user_tb: 5079 maybe_extra_debug = ( 5080 '\nUser Stack (most recent call last):\n' + 5081 ' (snipped, see stack below for prefix)\n' + 5082 ''.join(traceback.format_list(user_tb)) 5083 ) 5084 if is_debug and config.extended_debug_cpp: 5085 cpp_stack = CapturedTraceback.extract(cpp=True) 5086 maybe_extra_debug += "\nC++ stack trace:\n" + ''.join(cpp_stack.format()) 5087 elif is_debug: 5088 maybe_extra_debug += ( 5089 "\nFor C++ stack trace, run with " 5090 "TORCHDYNAMO_EXTENDED_DEBUG_CPP=1" 5091 ) 5092 5093 return fsummary, maybe_user_loc, maybe_extra_debug 5094 5095 def _log_guard(self, prefix: str, g, forcing_spec: bool): 5096 if self.log.isEnabledFor(logging.INFO): 5097 str_g = str(g) 5098 is_debug = config.extended_debug_guard_added is not None and str_g == config.extended_debug_guard_added 5099 fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug) 5100 maybe_more_info = "" 5101 if not is_debug: 5102 maybe_more_info = ( 5103 ", for more info run with " 5104 f'TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="{str_g}"' 5105 ) 5106 self.log.info( 5107 "%s %s [guard added]%s (%s)%s%s", 5108 prefix if not forcing_spec else f"{prefix} (forcing_spec)", 5109 str_g, 5110 maybe_user_loc, 5111 format_frame(fsummary), 5112 maybe_more_info, 5113 maybe_extra_debug, 5114 stack_info=is_debug, 5115 ) 5116 5117 @lru_cache(256) 5118 @record_shapeenv_event(save_tracked_fakes=True) 5119 def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None, 5120 size_oblivious: bool = False, *, forcing_spec: bool = False): 5121 try: 5122 return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec) 5123 except Exception: 5124 self.log.warning( 5125 "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s", 5126 orig_expr, hint, size_oblivious, forcing_spec 5127 ) 5128 raise 5129 5130 def _evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None, 5131 size_oblivious: bool = False, *, forcing_spec: bool = False): 5132 """ 5133 Given an expression, evaluates it, adding guards if necessary 5134 """ 5135 5136 # TODO: split conjunctions and evaluate them separately 5137 5138 # Don't track this one 5139 @functools.lru_cache(None) 5140 def compute_concrete_val(): 5141 if hint is None: 5142 return self.size_hint(orig_expr) 5143 else: 5144 return sympy.sympify(hint) 5145 5146 # Check if: 5147 # 1. 'translation_validation' is set 5148 # 2. the corresponding 'fx_node' is not 'None' 5149 # 3. the guard should not be suppressed 5150 # 5151 # If all of the above check, we create an FX node representing the 5152 # actual expression to be guarded. 5153 node = None 5154 fresh = False 5155 if ( 5156 self._translation_validation_enabled 5157 and fx_node is not None 5158 and not self._suppress_guards_tls() 5159 and not size_oblivious 5160 ): 5161 concrete_val = compute_concrete_val() 5162 if concrete_val is sympy.true: 5163 node, fresh = self._create_fx_call_function(torch._assert, (fx_node,)) 5164 elif concrete_val is sympy.false: 5165 neg, _ = self._create_fx_call_function(operator.not_, (fx_node,)) 5166 node, fresh = self._create_fx_call_function(torch._assert, (neg,)) 5167 else: 5168 eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val)) 5169 node, fresh = self._create_fx_call_function(torch._assert, (eql,)) 5170 5171 assert node is not None 5172 # If this is a fresh node, we have to remember the event index that 5173 # corresponds to this assertion node. 5174 # Reason: so that, given an assertion node, we can replay the ShapeEnv 5175 # events until the point where this assertion node was freshly created. 5176 if fresh: 5177 self._add_fx_node_metadata(node) 5178 5179 # After creating the FX node corresponding to orig_expr, we must make sure that 5180 # no error will be raised until the end of this function. 5181 # 5182 # Reason: the translation validation may become invalid otherwise. 5183 # 5184 # If an error is raised before the end of this function, we remove the FX node 5185 # inserted, and re-raise the error. 5186 guard = None 5187 tb = None 5188 5189 try: 5190 if orig_expr.is_number: 5191 self.log.debug("eval %s [trivial]", orig_expr) 5192 if hint is not None: 5193 assert orig_expr == hint, f"{orig_expr} != {hint}" 5194 return orig_expr 5195 5196 expr = orig_expr 5197 5198 static_expr = self._maybe_evaluate_static(expr, 5199 size_oblivious=size_oblivious) 5200 if static_expr is not None: 5201 self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr) 5202 if hint is not None: 5203 assert static_expr == hint, f"{static_expr} != {hint}" 5204 return static_expr 5205 5206 transmute_into_runtime_assert = False 5207 5208 concrete_val = None 5209 if not (expr.free_symbols <= self.var_to_val.keys()): 5210 # TODO: dedupe this with _maybe_evaluate_static 5211 # Attempt to eliminate the unbacked SymInt 5212 new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) 5213 if not (new_expr.free_symbols <= self.var_to_val.keys()): 5214 size_oblivious_result = None 5215 if not size_oblivious: 5216 size_oblivious_result = self._maybe_evaluate_static( 5217 expr, 5218 size_oblivious=True 5219 ) 5220 5221 # Last ditch 5222 if ( 5223 self.unbacked_var_to_val and 5224 not (unsound_result := orig_expr.xreplace(self.unbacked_var_to_val)).free_symbols 5225 ): 5226 log.warning("propagate_real_tensors evaluate_expr(%s) -> %s", orig_expr, unsound_result) 5227 trace_structured( 5228 "propagate_real_tensors", 5229 metadata_fn=lambda: { 5230 "expr": repr(orig_expr), 5231 "result": repr(unsound_result), 5232 "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()), 5233 }, 5234 ) 5235 transmute_into_runtime_assert = True 5236 concrete_val = unsound_result 5237 else: 5238 raise self._make_data_dependent_error( 5239 expr.xreplace(self.var_to_val), 5240 expr, 5241 size_oblivious_result=size_oblivious_result 5242 ) 5243 else: 5244 expr = new_expr 5245 5246 if concrete_val is None: 5247 concrete_val = compute_concrete_val() 5248 self._check_frozen(expr, concrete_val) 5249 5250 if ( 5251 config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY 5252 and isinstance(hint, bool) 5253 and isinstance(expr, (sympy.Eq, sympy.Ne)) 5254 ): 5255 expr = sympy.Not(expr) 5256 5257 # Turn this into a boolean expression, no longer need to consult 5258 # concrete_val 5259 if concrete_val is sympy.true: 5260 g = expr 5261 elif concrete_val is sympy.false: 5262 g = sympy.Not(expr) 5263 else: 5264 g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type] 5265 5266 if transmute_into_runtime_assert: 5267 self.defer_runtime_assert( 5268 g, 5269 f"propagate_real_tensors: {orig_expr} == {unsound_result}" 5270 ) 5271 return concrete_val 5272 5273 if not self._suppress_guards_tls(): 5274 if isinstance(g, sympy.Rel): 5275 # TODO: If we successfully eliminate a symbol via equality, it 5276 # is not actually necessary to save a guard for the equality, 5277 # as we will implicitly generate a guard when we match that 5278 # input against the symbol. Probably the easiest way to 5279 # implement this is to have maybe_guard_rel return a bool 5280 # saying if it "subsumed" the guard (and therefore the guard 5281 # is no longer necessary) 5282 self._maybe_guard_rel(g) 5283 5284 if not self.allow_complex_guards_as_runtime_asserts: 5285 # at this point, we've evaluated the concrete expr value, and have 5286 # flipped/negated the guard if necessary. Now we know what to guard 5287 # or defer to runtime assert on. 5288 stack = CapturedTraceback.extract(skip=1) 5289 guard = ShapeGuard(g, stack) 5290 self.guards.append(guard) 5291 else: 5292 # it's fine to defer simple guards here without checking, 5293 # the _maybe_guard_rel() call above will set replacements if possible, 5294 # and so the result here will be statically known 5295 self.defer_runtime_assert(g, f"evaluate_expr: {orig_expr}") 5296 5297 except Exception: 5298 if fresh: 5299 self._remove_fx_node(node) 5300 raise 5301 else: 5302 if not self._suppress_guards_tls(): 5303 if guard is not None: # we might have deferred this to runtime assert 5304 self._log_guard("eval", g, forcing_spec=forcing_spec) 5305 5306 for s in g.free_symbols: 5307 self.symbol_guard_counter[s] += 1 5308 # Forcing_spec to avoid infinite recursion 5309 if ( 5310 not forcing_spec and 5311 config.symbol_guard_limit_before_specialize is not None and 5312 self.symbol_guard_counter[s] > config.symbol_guard_limit_before_specialize 5313 ): 5314 # Force specialization 5315 self.log.info( 5316 "symbol_guard_limit_before_specialize=%s exceeded on %s", 5317 config.symbol_guard_limit_before_specialize, 5318 s 5319 ) 5320 self.evaluate_expr(s, forcing_spec=True) 5321 else: 5322 self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec) 5323 5324 return concrete_val 5325 5326 def cleanup(self): 5327 """ 5328 Break reference cycles. 5329 5330 This destroys the stacks. If you really want to keep them, we 5331 just need some way to break references on code objects. 5332 """ 5333 for g in self.guards: 5334 g.stack.cleanup() 5335 for s in self.var_to_stack.values(): 5336 s.cleanup() 5337 for ras in self.deferred_runtime_asserts.values(): 5338 for ra in ras: 5339 ra.stack.cleanup() 5340 5341 @record_shapeenv_event(save_tracked_fakes=True) 5342 def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None): 5343 """Create an assert that is checked at runtime 5344 5345 Args: 5346 orig_expr (sympy.Expr): Boolean expression to assert is true 5347 msg (str): Message to display on assertion failure 5348 fx_node (Optional, torch.fx.Node): node in ``self.graph`` corresponding 5349 to the expression, if applicable 5350 5351 """ 5352 expr = orig_expr 5353 5354 # TODO: split conjunctions and evaluate them separately 5355 5356 static_expr = self._maybe_evaluate_static(expr) 5357 if static_expr is not None: 5358 self.log.debug("runtime_assert %s == %s [statically known]", orig_expr, static_expr) 5359 return static_expr 5360 5361 # Attempt to eliminate the unbacked SymInt 5362 new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) 5363 if not self.prefer_deferred_runtime_asserts_over_guards and new_expr.free_symbols <= self.var_to_val.keys(): 5364 # Do a normal guard 5365 return self.evaluate_expr(new_expr, fx_node=fx_node) 5366 # NB: Don't use new_expr as expr; it could contain gunk like shape0 5367 # which we don't want to guard on 5368 5369 # OK, we're definitely doing a runtime assert now 5370 if ( 5371 self._translation_validation_enabled 5372 and fx_node is not None 5373 and not self._suppress_guards_tls() 5374 ): 5375 node, fresh = self._create_fx_call_function(torch._assert, (fx_node,)) 5376 assert node is not None 5377 if fresh: 5378 self._add_fx_node_metadata(node) 5379 5380 if not self._suppress_guards_tls(): 5381 # If you're here because of this assert, read Note [Backwards runtime asserts] 5382 # in torch/_inductor/graph.py 5383 assert not self.runtime_asserts_frozen, expr 5384 5385 self._check_frozen(expr, sympy.true) 5386 5387 # eliminate symbols on equality tests / refine ranges 5388 if isinstance(expr, sympy.Rel): 5389 self._maybe_guard_rel(expr) 5390 5391 # canonicalise to remove equations that are trivially equal 5392 orig_expr = expr 5393 expr = canonicalize_bool_expr(expr) 5394 stack = CapturedTraceback.extract(skip=1) 5395 ra = RuntimeAssert(expr, msg, stack) 5396 # TODO: Do this in a way that is less janky than int(s.name[1:]) 5397 cands = sorted((s for s in expr.free_symbols if symbol_is_type(s, SymT.UNBACKED_INT)), key=lambda s: int(s.name[1:])) 5398 # Is None when prefer_deferred_runtime_asserts_over_guards=True 5399 # and the guard in question has no unbacked SymInts in front 5400 ix = cands[-1] if cands else None 5401 self.deferred_runtime_asserts.setdefault(ix, []).append(ra) 5402 self.num_deferred_runtime_asserts += 1 5403 self._update_version_counter() 5404 self._log_guard("runtime_assert", orig_expr, forcing_spec=False) 5405 else: 5406 self._log_guard("runtime_assert [guard suppressed]", orig_expr, forcing_spec=False) 5407 5408 return True 5409 5410 # Refines the ranges of the variables present in 'guard'. 5411 # 5412 # This function tries to refine the range of the variables inside 5413 # 'guard' by reasoning about it. Specifically, when 'guard' is a 5414 # 'sympy.Relational' operation. 5415 # 5416 # It does mainly 3 things: 5417 # 1. Tries to isolate a variable in the left-hand side 5418 # 2. Compute the value range of the right-hand side 5419 # 3. Update the value range of the variable, if better 5420 def _refine_ranges(self, expr: sympy.Expr) -> None: 5421 expr = self.simplify(expr) 5422 5423 for symbol in expr.free_symbols: 5424 assert isinstance(symbol, sympy.Symbol) 5425 5426 if isinstance(self.var_to_val.get(symbol, None), SingletonInt): 5427 # Skip var_to_range logic for SingletonInt which is only used 5428 # for jagged layout NestedTensors today 5429 continue 5430 5431 r = try_solve(expr, symbol) 5432 5433 if r is None or not (symbol.is_integer and r[1].is_integer): 5434 # Range refinement only supports integer symbols for now. 5435 # There are lots of SymPy bugs when it comes to comparing 5436 # reals and integers, so we skip that for now. 5437 continue 5438 5439 r_expr, rhs = r 5440 vr = self.var_to_range[symbol] 5441 lower, upper = vr.lower, vr.upper 5442 5443 rhs_vr = bound_sympy(rhs, self.var_to_range) 5444 5445 # Let's suppose that we have a preexisting range for x [0, 100]. 5446 # Now, we issue a guard x > y, where the range for y is [50, 150]. 5447 # Then, lower = 0, rhs_vr.lower = 50 and therefore refinement can happen, 5448 # refining x to [51, 100], since x must be greater than y, but the lowest 5449 # y could be is 50. 5450 # 5451 # sympy.Eq may update both lower and upper bounds. 5452 # sympy.G{t,e} may update the lower bound, only. 5453 # sympy.L{t,e} may update the upper bound, only. 5454 if lower < rhs_vr.lower and isinstance(r_expr, (sympy.Eq, sympy.Ge, sympy.Gt)): 5455 # Strictly greater relations allow us to refine a bit more, since 5456 # x < y implies that the lower bound for x is: y + 1. 5457 lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt)) 5458 if upper > rhs_vr.upper and isinstance(r_expr, (sympy.Eq, sympy.Le, sympy.Lt)): 5459 upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt)) 5460 5461 # Do nothing if the new value range is no better than what we already have. 5462 if vr == ValueRanges(lower, upper): 5463 continue 5464 5465 # Updates the range and the guards corresponding to each bound of the symbol. 5466 self._update_var_to_range(symbol, ValueRanges(lower, upper)) 5467 # If the range is refined to singleton, set replacement 5468 if self.var_to_range[symbol].is_singleton(): 5469 self._set_replacement(symbol, self.var_to_range[symbol].lower, "range_refined_to_singleton") 5470 5471 # Clears the cache, since this update can change the result. 5472 self._maybe_evaluate_static.cache_clear() 5473 5474 @lru_cache(maxsize=None) 5475 @record_shapeenv_event() 5476 def constrain_symbol_range(self, s: sympy.Symbol, compiler_min: int, compiler_max: int): 5477 upd_vr = ValueRanges(compiler_min, compiler_max) 5478 old_vr = self.var_to_range.get(s, ValueRanges.unknown()) 5479 self._update_var_to_range(s, upd_vr) 5480 if (new_vr := self.var_to_range[s]) != old_vr: 5481 log.info("constrain_symbol_range %s [%s, %s]", s, new_vr.lower, new_vr.upper) 5482 5483 5484def _is_int(expr): 5485 return isinstance(expr, SymInt) and expr.node.expr.is_number 5486 5487# WARNING: This is legacy, DO NOT USE 5488def _is_dim_dynamic(t, d): 5489 return hasattr(t, "_dynamo_dynamic_indices") and d in t._dynamo_dynamic_indices 5490 5491class PropagateUnbackedSymInts(torch.fx.Interpreter): 5492 def run_node(self, n: torch.fx.Node): 5493 """ 5494 Run an FX node, propagating unbacked Symbol bindings to the new fake tensor 5495 """ 5496 from torch._guards import detect_fake_mode 5497 5498 result = super().run_node(n) 5499 rebind_unbacked(detect_fake_mode().shape_env, n, result) 5500 return result 5501 5502 5503def _find_user_code_frame(): 5504 frame = inspect.currentframe() 5505 while frame is not None: 5506 if not frame.f_code.co_filename.startswith( 5507 os.path.dirname(inspect.getfile(torch)) + os.path.sep 5508 ): 5509 break 5510 frame = frame.f_back 5511 return frame 5512 5513 5514def _blame_user_code(e, frame): 5515 frame_summary = traceback.FrameSummary( 5516 frame.f_code.co_filename, 5517 frame.f_lineno, 5518 frame.f_code.co_name, 5519 ) 5520 msg = e.args[0] 5521 msg += ( 5522 '\n\nThe following call raised this error:\n' + 5523 ''.join(traceback.StackSummary.from_list([frame_summary]).format()) 5524 ) 5525 e.args = (msg,) 5526 5527 5528class _PythonPrinter(sympy.printing.str.StrPrinter): 5529 """ 5530 Util printer that replaces sympy symbols with their source-level names 5531 and renders sympy relational operators (e.g., Eq, Ne, Ge, Le) inline 5532 (i.e., as ==, !=, >, <). 5533 """ 5534 5535 def __init__(self, src_map): 5536 super().__init__() 5537 self.src_map = src_map 5538 5539 def _print_Symbol(self, sym): 5540 return self.src_map[sym.name][0] 5541 5542 def _print_Relational(self, expr): 5543 lhs = self.parenthesize(expr.lhs, sympy.printing.precedence.precedence(expr)) 5544 rel_op = expr.rel_op 5545 rhs = self.parenthesize(expr.rhs, sympy.printing.precedence.precedence(expr)) 5546 return f"{lhs} {rel_op} {rhs}" 5547 5548 5549def _suggest_torch_checks(e, src_map): 5550 # extract the unresolved condition on unbacked symints in the error 5551 cond = e.cond 5552 diff = ", ".join(s.name for s in cond.free_symbols if s.name not in src_map) 5553 if diff: 5554 log.warning("Unable to find user code corresponding to {%s}", diff) 5555 return 5556 printer = _PythonPrinter(src_map) 5557 msg = e.args[0] 5558 msg += "\nTo fix the error, insert one of the following checks before this call:" 5559 # suggested fixes to resolve `cond`` are to tell the compiler to assume 5560 # either `cond` or its negation (the user will need to select which) 5561 suggested_fixes = [ 5562 f"torch._check({printer.doprint(cond)})", 5563 f"torch._check({printer.doprint(sympy.Not(cond))})", 5564 ] 5565 for i, fix in enumerate(suggested_fixes): 5566 msg += f"\n {i+1}. {fix}" 5567 src_mapped = ', '.join( 5568 f"`{s}` with {' or '.join(src_map[s])}" 5569 for s in sorted(s.name for s in cond.free_symbols) 5570 ) 5571 msg += f"\n\n(These suggested fixes were derived by replacing {src_mapped} in {cond} and its negation.)" 5572 e.args = (msg,) 5573 5574 5575def _suggest_fixes_for_data_dependent_error_non_strict(e): 5576 """ 5577 Given a raised data-dependent error, add the following to the error message: 5578 1. the closest user code location that raised the error; 5579 2. suggested fixes for the error in terms of live variables at that location. 5580 """ 5581 5582 # walk the stack up from the data-dependent error until a non-torch frame is found 5583 frame = _find_user_code_frame() 5584 if frame is not None: 5585 # add frame info to error message 5586 _blame_user_code(e, frame) 5587 5588 # map symbol names reachable via frame locals to their source-level names 5589 src_map = defaultdict(list) 5590 for var, val in frame.f_locals.items(): 5591 # figure out how to access any symbol inside `val` through `var` 5592 for path, leaf in pytree.tree_leaves_with_path(val): 5593 name = var + pytree.keystr(path) 5594 if isinstance(leaf, torch.SymInt): 5595 src_map[str(leaf.node.expr)].append(name) 5596 elif isinstance(leaf, torch.Tensor): 5597 for i, dim in enumerate(leaf.shape): 5598 if isinstance(dim, torch.SymInt): 5599 src_map[str(dim.node.expr)].append(f"{name}.shape[{i}]") 5600 5601 # add suggested torch.check()s based on `src_map` to the error message 5602 # replacing unbacked symints in the unresolved condition in the error 5603 _suggest_torch_checks(e, src_map) 5604