1# mypy: allow-untyped-defs 2import dataclasses 3import inspect 4import logging 5import sys 6from collections import defaultdict 7from enum import auto, Enum 8from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union 9 10import torch 11from torch.utils._pytree import ( 12 _get_node_type, 13 BUILTIN_TYPES, 14 keystr, 15 LeafSpec, 16 MappingKey, 17 SequenceKey, 18 SUPPORTED_NODES, 19 tree_flatten, 20 tree_map_with_path, 21) 22 23from .exported_program import ExportedProgram 24 25 26if TYPE_CHECKING: 27 from sympy import Symbol 28 29 from torch._guards import Source 30 from torch.fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint 31 32__all__ = [ 33 "Constraint", 34 "Dim", 35 "dims", 36 "refine_dynamic_shapes_from_suggested_fixes", 37] 38 39 40log = logging.getLogger(__name__) 41 42 43class _DimHint(Enum): 44 """ 45 Enum for dynamic shape hints. 46 - AUTO means automatic inference of shape (static or dynamic). 47 - STATIC means static shape (always specialized). 48 """ 49 50 AUTO = auto() 51 STATIC = auto() 52 53 54class _Dim(type): 55 """ 56 Metaclass for :func:`Dim` types. 57 """ 58 59 @staticmethod 60 def readable(name, min_, max_): 61 from torch.utils._sympy.numbers import int_oo 62 63 if min_ == 2: 64 min_ = None 65 if max_ == int_oo: 66 max_ = None 67 if min_ is None and max_ is None: 68 return f"Dim('{name}')" 69 if min_ is None: 70 return f"Dim('{name}', max={max_})" 71 if max_ is None: 72 return f"Dim('{name}', min={min_})" 73 return f"Dim('{name}', min={min_}, max={max_})" 74 75 def __add__(cls, other): 76 # e.g., dim + 1 77 if type(other) is not int: 78 raise NotImplementedError( 79 f"Attempted to add {other} to {cls.__name__}, where an integer was expected. " 80 "(Only increasing linear operations with integer coefficients are supported.)" 81 ) 82 return cls._derive(lambda x: x + other) 83 84 def __radd__(cls, other): 85 return cls + other 86 87 def __sub__(cls, other): 88 # e.g., dim - 1 89 if type(other) is not int: 90 raise NotImplementedError( 91 f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. " 92 "(Only increasing linear operations with integer coefficients are supported.)" 93 ) 94 return cls._derive(lambda x: x - other) 95 96 def __rsub__(cls, other): 97 raise NotImplementedError( 98 f"Attempted to negate {cls.__name__}. " 99 "(Only increasing linear operations with integer coefficients are supported.)" 100 ) 101 102 def __mul__(cls, other): 103 # e.g., dim * 2 104 if type(other) is not int or other <= 0: 105 raise NotImplementedError( 106 f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. " 107 "(Only increasing linear operations with integer coefficients are supported.)" 108 ) 109 return cls._derive(lambda x: x * other) 110 111 def __rmul__(cls, other): 112 return cls * other 113 114 def _derived_name(cls, fn): 115 from sympy import sympify 116 117 return str(fn(sympify(cls.__name__))) 118 119 def _derive(cls, fn): 120 return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn}) 121 122 123class _StaticDim(_Dim): 124 """ 125 Meta class for static :func:`Dim` types. 126 127 This class is only for setting and checking static dim constraints, 128 and the user should never interact with it. 129 """ 130 131 @property 132 def min(self): 133 return self.value # type: ignore[attr-defined] 134 135 @property 136 def max(self): 137 return self.value # type: ignore[attr-defined] 138 139 140class _DerivedDim(_Dim): 141 """ 142 Metaclass for derived :func:`Dim` types. 143 144 Currently we only support increasing linear expressions with integer coefficients. 145 In other words, a derived Dim can always be written in the form Ax + B, where 146 x is a regular Dim (i.e., non-derived Dim), A and B are integers, and A is positive. 147 (In particular, the latter ensures that x < y => Ax + B < Ay + B.) 148 These restrictions on the form of derived Dims makes the metatheory simpler: e.g., 149 it simplifies computing ranges for derived Dims, solving for underlying regular Dims, 150 deciding equalities between derived Dims, and so on. 151 152 The function lambda x: Ax + B is expressed by `fn`, where x is a normal Dim, `root`. 153 The range of a derived Dim is computed by mapping `fn` over the range of its `root`. 154 """ 155 156 @property 157 def min(self): 158 # assume that self.fn is an increasing function 159 # TODO(avik): use sympy value range analysis instead? 160 from sympy import Integer 161 162 from torch.utils._sympy.numbers import int_oo 163 164 if self.root.min is -int_oo: # type: ignore[attr-defined] 165 return -int_oo # fn not needed cuz increasing 166 167 _min_symint = self.fn(Integer(self.root.min)) # type: ignore[attr-defined] 168 root = self.root # type: ignore[attr-defined] 169 assert _min_symint >= 0, ( 170 f"Expected derived min value of {self.__name__} to be >= 0. " 171 f"Please specify an appropriate min value for {root.__name__} " 172 f"(currently {root.min})." 173 ) 174 return int(_min_symint) 175 176 @property 177 def max(self): 178 # assume that self.fn is an increasing function 179 # TODO(avik): use sympy value range analysis instead? 180 from sympy import Integer 181 182 from torch.utils._sympy.numbers import int_oo 183 184 if self.root.max is int_oo: # type: ignore[attr-defined] 185 return int_oo # fn not needed cuz increasing 186 187 _max_symint = self.fn(Integer(self.root.max)) # type: ignore[attr-defined] 188 root = self.root # type: ignore[attr-defined] 189 assert _max_symint <= sys.maxsize - 1, ( 190 f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. " 191 f"Please specify an appropriate max value for {root.__name__} " 192 f"(currently {root.max})." 193 ) 194 return int(_max_symint) 195 196 def _derive(self, fn): 197 # We support nesting, e.g., 2*dim + 1. 198 # This is implemented by composing operations on the same root. 199 # As a consequence, roots are always regular Dims (i.e., not derived Dims). 200 return _DerivedDim( 201 self._derived_name(fn), 202 (int,), 203 {"root": self.root, "fn": lambda x: fn(self.fn(x))}, # type: ignore[attr-defined] 204 ) 205 206 207def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None): 208 """ 209 :func:`Dim` constructs a type analogous to a named symbolic integer with a range. 210 It can be used to describe multiple possible values of a dynamic tensor dimension. 211 Note that different dynamic dimensions of the same tensor, or of different tensors, 212 can be described by the same type. 213 214 Args: 215 name (str): Human-readable name for debugging. 216 min (Optional[int]): Minimum possible value of given symbol (inclusive) 217 max (Optional[int]): Maximum possible value of given symbol (inclusive) 218 219 Returns: 220 A type that can be used in dynamic shape specifications for tensors. 221 """ 222 223 from torch.utils._sympy.numbers import int_oo 224 225 _min = 0 if min is None else min 226 _max = int_oo if max is None else max 227 assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}" 228 assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}" 229 dim = _Dim(name, (int,), {"min": _min, "max": _max}) 230 dim.__module__ = getattr( 231 inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__" 232 ) 233 return dim 234 235 236Dim.AUTO = _DimHint.AUTO # type: ignore[attr-defined] 237Dim.STATIC = _DimHint.STATIC # type: ignore[attr-defined] 238 239 240def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None): 241 """ 242 Util to create multiple :func:`Dim` types. 243 """ 244 return tuple(Dim(name, min=min, max=max) for name in names) 245 246 247@dataclasses.dataclass 248class _ConstraintTarget: 249 """ 250 This represents input tensor dimensions. 251 """ 252 253 t_id: int 254 dim: int 255 256 257@dataclasses.dataclass 258class _Constraint(_ConstraintTarget): 259 """ 260 This represents a Dim describing a constraint target. 261 262 `name` is the name of the Dim. 263 `constraint_range` contains the min/max bounds of the Dim. 264 """ 265 266 name: str 267 constraint_range: "StrictMinMaxConstraint" 268 269 def _clone_with_range(self, lower=0, upper=None): 270 # Import sympy locally 271 from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint 272 from torch.utils._sympy.numbers import int_oo 273 from torch.utils._sympy.value_ranges import ValueRanges 274 275 if upper is None: 276 upper = int_oo 277 278 constraint_range = StrictMinMaxConstraint( 279 vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), 280 warn_only=False, 281 ) 282 return _Constraint( 283 self.t_id, 284 self.dim, 285 self.name, 286 constraint_range, 287 ) 288 289 def __ge__(self, lower): 290 return self._clone_with_range(lower=lower) 291 292 def __gt__(self, lower): 293 return self._clone_with_range(lower=lower + 1) 294 295 def __le__(self, upper): 296 return self._clone_with_range(upper=upper) 297 298 def __lt__(self, upper): 299 return self._clone_with_range(upper=upper - 1) 300 301 def __bool__(self): 302 # NOTE(avik): We do not support compound expressions like a <= x <= b. 303 # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b), 304 # and moreover, enforces that any overload of __bool__ must return True or False. 305 # FWIW, sympy also raises TypeError in this case. 306 raise TypeError( 307 "Cannot determine truth value of _Constraint. " 308 "If you are trying to combine _Constraint's with logical connectives, " 309 "you can specify them separately instead." 310 ) 311 312 @property 313 def serializable_spec(self): 314 # We need a serialization compatible format of the constraint so that it 315 # can be savedin the graph module w/o breaking the module serialization. 316 # The saved constraints will be used directly for the post-exporting pass 317 # that converts constraints to runtime assertion. The saved constraints 318 # will not be saved in the serialized module. 319 # TODO: A better way is needed. Currently we use 't_id' to map the constraint, 320 # which is not reliable 321 return { 322 "t_id": self.t_id, 323 "dim": self.dim, 324 "min": self.constraint_range.vr.lower, 325 "max": self.constraint_range.vr.upper, 326 } 327 328 329@dataclasses.dataclass 330class _PhantomRoot: 331 """ 332 This represents the root of a derived Dim where the root does not directly 333 specify the shape of any input dimension, but the derived Dim does. 334 335 e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim. 336 337 The fields `name`, `constraint_range`, and `val` carried by a phantom root 338 help create a symbol for it. Any derived dims with this phantom root are 339 backed by expressions over this symbol. 340 """ 341 342 name: str 343 constraint_range: "StrictMinMaxConstraint" 344 val: int 345 346 347@dataclasses.dataclass 348class _DerivedConstraint(_ConstraintTarget): 349 """ 350 This represents a derived Dim, whose root is either a regular constraint target 351 (which directly specifies the shape of some input dimension) or a phantom root 352 (which does so indirectly). 353 354 It can be thought of as a subclass of `_Constraint`, except that it does not 355 support <, <=, >, >= operations. 356 """ 357 358 name: str 359 constraint_range: "StrictMinMaxConstraint" 360 root: Union[_ConstraintTarget, _PhantomRoot] 361 fn: Callable 362 363 @property 364 def serializable_spec(self): 365 # same as _Constraint.serializable_spec 366 return { 367 "t_id": self.t_id, 368 "dim": self.dim, 369 "min": self.constraint_range.vr.lower, 370 "max": self.constraint_range.vr.upper, 371 } 372 373 374Constraint = Union[_Constraint, _DerivedConstraint] 375 376 377def _process_equalities( 378 constraint: Constraint, 379 get_sources: Callable[[int, int], List["Source"]], 380 shape_env: "ShapeEnv", 381 names: Dict[str, Tuple[int, int]], 382 source_pairs: List[Tuple["Source", "Source"]], 383 derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]], 384 phantom_symbols: Dict[str, "Symbol"], 385): 386 """ 387 Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become 388 fields of `EqualityConstraint`) based on a given input `constraint`. 389 """ 390 391 sources = get_sources(constraint.t_id, constraint.dim) 392 if not sources: # empty sources due to unused shapes 393 return 394 395 source, *other_sources = sources 396 # When t.size()[dim] maps to src0, src1, ..., srcN, we add 397 # constraints that make src0 "equal" to src1, ..., srcN. 398 source_pairs.extend((source, other_source) for other_source in other_sources) 399 if not isinstance(constraint, _DerivedConstraint): 400 if constraint.name in names: 401 shared_t_id, shared_dim = names[constraint.name] 402 other_sources = get_sources(shared_t_id, shared_dim) 403 source_pairs.extend( 404 (source, other_source) for other_source in other_sources 405 ) 406 else: 407 names[constraint.name] = (constraint.t_id, constraint.dim) 408 else: 409 # branch based on the root of the _DerivedConstraint 410 if not isinstance(constraint.root, _PhantomRoot): 411 # either root points to an input source 412 root = get_sources(constraint.root.t_id, constraint.root.dim)[0] # type: ignore[assignment] 413 else: 414 # or root points to a phantom symbol 415 if constraint.root.name in phantom_symbols: 416 root = phantom_symbols[constraint.root.name] # type: ignore[assignment] 417 else: 418 # create a phantom symbol in the shape env based on the _PhantomRoot 419 root = shape_env.create_symbol( 420 val=constraint.root.val, 421 source=torch._dynamo.source.ConstantSource(constraint.root.name), 422 dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC, 423 constraint_dim=constraint.root.constraint_range, 424 ) 425 phantom_symbols[constraint.root.name] = root # type: ignore[assignment] 426 427 fn = constraint.fn 428 # A derived equality (source, root, fn) informally corresponds to source = fn(root). 429 # Here source describes an input and root might describe another input or a phantom symbol. 430 derived_equalities.append((source, root, fn)) 431 432 433def _tree_map_with_path( 434 func: Callable[..., Any], 435 tree: Any, 436 *dynamic_shapes: Any, 437 tree_name: Optional[str] = None, 438) -> Any: 439 """ 440 Customized tree_map for mapping pytrees to dynamic_shapes. 441 442 For built-in types (e.g., standard collections) this behaves exactly like tree_map. 443 444 OTOH for a user-defined class C registered with pytree, we cannot assume that a C 445 containing tensors can be mapped to a C containing dynamic shapes (i.e., C may not 446 be a polymorphic container). In that case we use the flattened form of C instead. 447 Thus a C(**tensors) that flattens to (**tensors) will map to (**dynamic_shapes). 448 449 Args: 450 func: function to apply to each (int, float, str, bool, None, torch.Tensor) 451 tree: input pytree 452 dynamic_shapes: zero or more (typically one) dynamic_shapes to match 453 454 Returns: 455 output pytree mapping func to each (int, float, str, bool, None, torch.Tensor) 456 """ 457 458 def is_leaf(t): 459 # BUILTIN_TYPES is a subset of SUPPORTED_NODES, the latter being all types 460 # registered with pytree. Types *not* in BUILTIN_TYPES include primitive types 461 # (int, float, str, bool, None, torch.Tensor), which are not in SUPPORTED_NODES, 462 # as well as user-defined classes registered with pytree, which are. 463 return _get_node_type(t) not in BUILTIN_TYPES 464 465 def f(path, t, *dynamic_shapes): 466 typ = _get_node_type(t) 467 # typ is not in BUILTIN_TYPES 468 if typ in SUPPORTED_NODES: 469 # thus typ is a user-defined class registered with pytree, 470 # in which case flatten and recurse 471 return tree_map_with_path( 472 f, 473 SUPPORTED_NODES[typ].flatten_fn(t)[0], 474 *dynamic_shapes, 475 is_leaf=is_leaf, 476 ) 477 else: 478 return func(path, t, *dynamic_shapes) 479 480 try: 481 return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf) 482 except ValueError as e: 483 if "mismatch" in e.args[0]: 484 # When PyTree finds a structural mismatch between tree and dynamic_shapes, 485 # the error message is unfortunately quite horrible. Let's fix that. 486 assert dynamic_shapes, "Cannot be a mismatch if there is no dynamic_shapes" 487 assert tree_name, "Must provide a tree_name when there might be a mismatch" 488 489 def _key(type_, context, i): 490 # derive a PyTree key given the type, context, and child # of a TreeSpec 491 if type_ is dict: 492 return MappingKey(context[i]) 493 if type_ in (list, tuple): 494 assert context is None 495 return SequenceKey(i) 496 raise AssertionError(f"Did not expect type {type_}") 497 498 def raise_mismatch_error(msg): 499 from torch._dynamo.exc import UserError, UserErrorType 500 501 raise UserError( 502 UserErrorType.INVALID_INPUT, 503 f"Detected mismatch between the structure of `{tree_name}` and `dynamic_shapes`: {msg}", 504 case_name="dynamic_shapes_validation", 505 ) 506 507 def _compare(tree, dynamic_shapes, path): 508 # raise an error at the point where tree and dynamic_shapes differ, 509 # including the path to that point and the reason for the difference 510 rendered_path = keystr(path) 511 if isinstance(tree, LeafSpec): 512 return 513 if isinstance(dynamic_shapes, LeafSpec): 514 raise_mismatch_error( 515 f"`{tree_name}{rendered_path}` is a {tree.type}, " 516 f"but `dynamic_shapes{rendered_path}` is not" 517 ) 518 if tree.type != dynamic_shapes.type: 519 raise_mismatch_error( 520 f"`{tree_name}{rendered_path}` is a {tree.type}, " 521 f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}" 522 ) 523 if len(tree.children_specs) != len(dynamic_shapes.children_specs): 524 raise_mismatch_error( 525 f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, " 526 f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements" 527 ) 528 if tree.type is dict: 529 # context, children could be out of order 530 if sorted(tree.context) != sorted(dynamic_shapes.context): 531 raise_mismatch_error( 532 f"`{tree_name}{rendered_path}` has keys {tree.context}, " 533 f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}" 534 ) 535 _remap = dict( 536 zip(dynamic_shapes.context, dynamic_shapes.children_specs) 537 ) 538 dynamic_shapes_children_specs = [_remap[k] for k in tree.context] 539 else: 540 dynamic_shapes_children_specs = dynamic_shapes.children_specs 541 for i, (tree_, dynamic_shapes_) in enumerate( 542 zip(tree.children_specs, dynamic_shapes_children_specs) 543 ): 544 _compare( 545 tree_, 546 dynamic_shapes_, 547 path + [_key(tree.type, tree.context, i)], 548 ) 549 550 _, tree_spec = tree_flatten(tree, is_leaf=is_leaf) 551 for other_tree in dynamic_shapes: 552 _, other_tree_spec = tree_flatten(other_tree, is_leaf) 553 _compare(tree_spec, other_tree_spec, []) 554 raise 555 556 557def _combine_args(f, args, kwargs, _is_torch_jit_trace=False) -> Dict[str, Any]: 558 # combine args and kwargs following the signature of f, as it happens 559 # in the body of f when called with *args, **kwargs 560 if isinstance(f, ExportedProgram): 561 f = f.module() 562 if not _is_torch_jit_trace: 563 signature = ( 564 inspect.signature(f.forward) 565 if isinstance(f, torch.nn.Module) 566 else inspect.signature(f) 567 ) 568 kwargs = kwargs if kwargs is not None else {} 569 return signature.bind(*args, **kwargs).arguments 570 return args 571 572 573class ShapesCollection: 574 """ 575 Builder for dynamic_shapes. 576 Used to assign dynamic shape specifications to tensors that appear in inputs. 577 578 Example:: 579 args = ({"x": tensor_x, "others": [tensor_y, tensor_z]}) 580 581 dim = torch.export.Dim(...) 582 dynamic_shapes = torch.export.ShapesCollection() 583 dynamic_shapes[tensor_x] = (dim, dim + 1, 8) 584 dynamic_shapes[tensor_y] = {0: dim * 2} 585 # This is equivalent to the following (now auto-generated): 586 # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]} 587 588 torch.export(..., args, dynamic_shapes=dynamic_shapes) 589 """ 590 591 def __init__(self): 592 self._shapes = {} 593 594 def __setitem__(self, t, shape): 595 assert isinstance( 596 t, torch.Tensor 597 ), f"Cannot assign shape to non-tensor type {type(t)}" 598 # TODO(avik): check that shape is indeed a Shape 599 t_id = id(t) 600 if t_id in self._shapes: 601 _shape = self._shapes[t_id] 602 assert ( 603 shape == _shape 604 ), f"Shapes assigned to tensor do not match: expected {_shape}, got {shape}" 605 else: 606 self._shapes[id(t)] = shape 607 608 def __getitem__(self, t): 609 t_id = id(t) 610 if t_id in self._shapes: 611 return self._shapes[t_id] 612 else: 613 return None 614 615 def __len__(self): 616 return len(self._shapes) 617 618 def dynamic_shapes(self, m, args, kwargs=None): 619 """ 620 Generate dynamic_shapes. 621 """ 622 623 t_ids = set() 624 625 def find_shape(path, t): 626 t_id = id(t) 627 if t_id in self._shapes: 628 t_ids.add(t_id) 629 return self._shapes[t_id] 630 else: 631 return None 632 633 combined_args = _combine_args(m, args, kwargs) 634 dynamic_shapes = _tree_map_with_path(find_shape, combined_args) 635 if any(t_id not in t_ids for t_id in self._shapes): 636 raise ValueError( 637 "Some tensors that were assigned shapes were not found in args. " 638 "Maybe such tensors were copied when passing them as args? " 639 "Maybe such tensors are contained in classes that were not registered with pytree?" 640 ) 641 return dynamic_shapes 642 643 644def _warn_on_None_dynamic_shape_dimension(): 645 msg = ( 646 "Using None as a dynamic shape dimension is deprecated. " 647 "Please use Dim.STATIC instead" 648 ) 649 # TODO(avik): raise an error in the future 650 log.warning(msg) 651 652 653def _check_dynamic_shapes( 654 combined_args: Dict[str, Any], 655 dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], 656): 657 """ 658 Checks the dynamic_shapes specification for correctness, 659 using combined args + kwargs as reference for inputs structure. 660 """ 661 from torch._dynamo.exc import UserError, UserErrorType 662 from torch._export.non_strict_utils import _flatten_dynamic_shapes 663 664 if dynamic_shapes is None or len(dynamic_shapes) == 0: 665 return 666 if isinstance(dynamic_shapes, (tuple, list)): 667 combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] 668 669 bounds: Dict[str, Tuple[int, int]] = {} 670 671 def check_same_bounds(dim): 672 if dim.__name__ in bounds: 673 min_, max_ = bounds[dim.__name__] 674 if dim.min != min_ or dim.max != max_: 675 this_ = _Dim.readable(dim.__name__, min_, max_) 676 that_ = _Dim.readable(dim.__name__, dim.min, dim.max) 677 raise UserError( 678 UserErrorType.INVALID_INPUT, 679 f"Found different definitions {this_} and {that_} " 680 f"for the same symbolic dimension {dim}!", 681 ) 682 else: 683 bounds[dim.__name__] = (dim.min, dim.max) 684 685 def check_symbols(path, tensor, shape): 686 if isinstance(shape, dict): 687 for i, dim in shape.items(): 688 if isinstance(dim, _Dim): 689 check_same_bounds(dim) 690 elif dim is None: 691 _warn_on_None_dynamic_shape_dimension() 692 elif not (isinstance(dim, (int, _DimHint))): 693 raise UserError( 694 UserErrorType.INVALID_INPUT, 695 f"Unexpected dimension mapped to index {i} in input tensor shape {shape} " 696 f"specified at `dynamic_shapes{keystr(path)}` " 697 f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)", 698 case_name="dynamic_shapes_validation", 699 ) 700 elif isinstance(shape, (tuple, list)): 701 for i, dim in enumerate(shape): 702 if isinstance(dim, _Dim): 703 check_same_bounds(dim) 704 elif dim is None: 705 _warn_on_None_dynamic_shape_dimension() 706 elif not (isinstance(dim, (int, _DimHint))): 707 raise UserError( 708 UserErrorType.INVALID_INPUT, 709 f"Unexpected dimension #{i} in input tensor shape {shape} " 710 f"specified at `dynamic_shapes{keystr(path)}` " 711 f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)", 712 case_name="dynamic_shapes_validation", 713 ) 714 elif shape is not None: 715 raise UserError( 716 UserErrorType.INVALID_INPUT, 717 f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` " 718 f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," 719 f" where each dimension is an int, a Dim, Dim.AUTO, or Dim.STATIC)", 720 case_name="dynamic_shapes_validation", 721 ) 722 723 assert isinstance(dynamic_shapes, (dict, tuple, list)) 724 if isinstance(dynamic_shapes, dict): 725 got_keys = list(dynamic_shapes.keys()) 726 expected_arg_names = list(combined_args.keys()) 727 if sorted(got_keys) != sorted(expected_arg_names): 728 msg = ( 729 f"When `dynamic_shapes` is specified as a dict, its top-level keys " 730 f"must be the arg names {expected_arg_names} of `inputs`, but " 731 f"here they are {got_keys}. " 732 ) 733 if ( 734 len(combined_args) == 1 735 and expected_arg_names[0] not in got_keys 736 and isinstance(combined_args[expected_arg_names[0]], dict) 737 ): 738 msg += ( 739 "Since here `inputs` is a list/tuple enclosing a single dict, " 740 "maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?" 741 ) 742 else: 743 msg += ( 744 "Alternatively, you could also ignore arg names entirely " 745 "and specify `dynamic_shapes` as a list/tuple matching `inputs`." 746 ) 747 raise UserError( 748 UserErrorType.INVALID_INPUT, msg, case_name="dynamic_shapes_validation" 749 ) 750 751 def check_shape(path, t, dynamic_shape): 752 if isinstance(t, torch.Tensor): 753 check_symbols(path, t, dynamic_shape) 754 else: 755 if dynamic_shape is not None: 756 rendered_path = keystr(path) 757 raise UserError( 758 UserErrorType.INVALID_INPUT, 759 f"Cannot associate shape {dynamic_shape} specified at `dynamic_shapes{rendered_path}` " 760 f"to non-tensor type {type(t)} at `inputs{rendered_path}` (expected None)", 761 case_name="dynamic_shapes_validation", 762 ) 763 764 _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs") 765 766 # raise user warning if both Dim.AUTO & Dims are specified in dynamic_shapes 767 flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes) 768 flatter_dynamic_shapes, _ = tree_flatten(flat_dynamic_shapes) 769 if any(isinstance(s, _Dim) for s in flatter_dynamic_shapes) and any( 770 s == _DimHint.AUTO for s in flatter_dynamic_shapes 771 ): 772 raise UserError( 773 UserErrorType.INVALID_INPUT, 774 "Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " 775 "and can easily lead to constraint violation errors or obscure errors in torch.export. Dim/DerivedDims " 776 "expect all equal or related dimensions to be specified, and does not yet compose well with `Dim.AUTO`. " 777 "We suggest using `Dim.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), " 778 "torch._check(dim <= max) calls in your program to specify min/max ranges, or `Dim`/`DerivedDim` mixed with `None` " 779 "if you want to assert on the exact specification of your program's dynamic shapes behavior.", 780 case_name="dynamic_shapes_validation", 781 ) 782 783 784def _transform_shapes_for_default_dynamic( 785 combined_args: Dict[str, Any], 786 dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], 787) -> Union[Dict[str, Any], Tuple[Any], List[Any], None]: 788 """ 789 In the long run this might not be needed, but this exists because export.export() and _dynamo.export() 790 historically have different semantics for how dynamic_shapes are specified, but go through the same 791 process of producing constraints, and now both use assume_static_by_default=False. 792 793 For _dynamo.export(), the semantics for dynamic_shapes are: 794 - None: dynamic, allocated a symbol 795 - Dim/DerivedDim: a strict assertion on the min/max range for this symbol, and require a specification 796 for all dims governed by this symbol (i.e. relations, equality, linear relations, etc.) 797 798 For export.export(), historically dynamism for unspecified dims has been undesirable, so the semantics are: 799 - Dim.AUTO: dynamic, allocated a symbol 800 - None/unspecified/Dim.STATIC: static 801 - Dim/DerivedDims: also a strict assertion 802 803 To allow both APIs to follow the same process for producing constraints, this function converts dynamic_shapes 804 for export.export() to be compatible with _process_dynamic_shapes() and assume_static_by_default=False, turning them 805 into essentially what they'd look like for _dynamo.export(). 806 807 An example conversion might look like, for a 3-d input tensor: 808 809 input spec: { 810 0: Dim.AUTO, 811 1: None, # or Dim.STATIC 812 2: Dim("dx"), 813 } 814 output spec: { 815 0: None, # None: dynamic by default 816 1: 32, # explicitly provide static shape 817 2: Dim("dx"), # remains the same 818 } 819 """ 820 821 def _tree_map_helper(tree, val): 822 """ 823 If the user generally specifies dynamic_shapes=None for a pytree input, 824 we'd like to convert this into a tree of Nones following the input spec, 825 so we can explicitly specify static dims for all tensor dimensions. 826 Non-builtin types for pytree (e.g. custom dataclasses) creates some difficulty, 827 in which case the correct format is a list containing specs for each child attribute. 828 """ 829 if (node_type := _get_node_type(tree)) not in SUPPORTED_NODES: # is_leaf 830 return val 831 flatten_fn = SUPPORTED_NODES[node_type].flatten_fn 832 child_pytrees, context = flatten_fn(tree) # flatten from whatever original type 833 unflatten_fn = SUPPORTED_NODES[ 834 node_type if node_type in BUILTIN_TYPES else list 835 ].unflatten_fn 836 children = [_tree_map_helper(child, val) for child in child_pytrees] 837 return unflatten_fn( 838 children, context 839 ) # unflatten into original type, or list if not built-in type 840 841 if ( 842 dynamic_shapes is None or len(dynamic_shapes) == 0 843 ): # create pytree structure of static dim 844 dynamic_shapes = _tree_map_helper(combined_args, None) 845 if isinstance(dynamic_shapes, (tuple, list)): 846 combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] 847 848 def transform_shapes(path, tensor, shape): 849 def _marked_dynamic(tensor, i): 850 # TODO(pianpwk): deprecate mark_dynamic() usage for export 851 return i in getattr(tensor, "_dynamo_dynamic_indices", set()) 852 853 out: Union[None, List[Any], Dict[int, Any]] = None 854 if isinstance(shape, dict): 855 out = {} 856 for i, val in enumerate(tensor.shape): 857 dim = shape.get(i, _DimHint.STATIC) 858 if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO: 859 # don't have to specify anything if dynamic 860 # None also works, since assume_static_by_default=False 861 if dim == _DimHint.AUTO: 862 torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing 863 continue 864 elif isinstance(dim, _Dim): 865 out[i] = dim 866 elif isinstance(dim, int): 867 # important that this is dim and not val, 868 # so we can raise error if user-specified dim != val 869 out[i] = dim 870 elif dim is None: 871 _warn_on_None_dynamic_shape_dimension() 872 out[i] = val 873 else: 874 # make explicitly static 875 assert dim == _DimHint.STATIC 876 out[i] = val 877 elif isinstance(shape, (tuple, list)): 878 out = [] 879 for i, val in enumerate(tensor.shape): 880 dim = shape[i] 881 if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO: 882 if dim == _DimHint.AUTO: 883 torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing 884 out.append(None) 885 elif isinstance(dim, _Dim): 886 out.append(dim) 887 elif isinstance(dim, int): 888 out.append(dim) 889 elif dim is None: 890 _warn_on_None_dynamic_shape_dimension() 891 out.append(val) 892 else: 893 assert dim == _DimHint.STATIC 894 out.append(val) 895 out = type(shape)(out) # type: ignore[assignment] 896 else: 897 assert shape is None 898 if isinstance(tensor, torch.Tensor): 899 out = [] 900 for i, val in enumerate(tensor.shape): 901 out.append(None if _marked_dynamic(tensor, i) else val) 902 out = out or None 903 else: 904 out = None 905 return out 906 907 def transform_shape(path, t, dynamic_shape): 908 if isinstance(t, torch.Tensor): 909 return transform_shapes(path, t, dynamic_shape) 910 911 result = _tree_map_with_path( 912 transform_shape, combined_args, dynamic_shapes, tree_name="inputs" 913 ) 914 return result 915 916 917def _process_dynamic_shapes( 918 combined_args: Dict[str, Any], 919 dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], 920) -> List[Constraint]: 921 """ 922 Reads the dynamic_shapes specification and produces a list of constraints. 923 """ 924 from torch._dynamo.exc import UserError, UserErrorType 925 926 if dynamic_shapes is None or len(dynamic_shapes) == 0: 927 # we run with dynamic by default, so no need to produce constraints 928 return [] 929 if isinstance(dynamic_shapes, (tuple, list)): 930 combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] 931 932 # map of Dim names representing input shape dimensions to constraints on them 933 symbols: Dict[str, List[Constraint]] = defaultdict(list) 934 # track roots that do not directly represent input shape dimensions 935 phantom_roots: Dict[str, _PhantomRoot] = {} 936 derived_constraints_with_phantom_root: List[_DerivedConstraint] = [] 937 938 def to_constraint(dim, tensor, i): 939 import sympy 940 941 from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint 942 from torch.utils._sympy.solve import try_solve 943 from torch.utils._sympy.value_ranges import ValueRanges 944 945 def root_value(): 946 # given tensor.shape[i] is the value of dim = fn(root), 947 # find the value of root 948 symbol = sympy.Symbol(dim.root.__name__, integer=True) 949 expr = dim.fn(symbol) 950 solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol) 951 if solution is not None: 952 return int(solution[1]) # type: ignore[call-overload] 953 else: 954 raise UserError( # noqa: B904 955 UserErrorType.CONSTRAINT_VIOLATION, 956 f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be " 957 f"of the form {expr}, where {symbol} is an integer", 958 ) 959 960 if isinstance(dim, _DerivedDim): 961 # generate a _DerivedConstraint where the root is: 962 # - either a _ConstraintTarget (if dim.root directly describes an input shape) 963 # - or a _PhantomRoot (otherwise) 964 dim_root = dim.root # type: ignore[attr-defined] 965 if dim_root.__name__ in symbols: 966 # root represents an input shape dimension 967 root_constraint = symbols[dim_root.__name__][0] 968 root = _ConstraintTarget( 969 root_constraint.t_id, 970 root_constraint.dim, 971 ) 972 elif dim_root.__name__ not in phantom_roots: 973 # create a phantom root 974 root = _PhantomRoot( # type: ignore[assignment] 975 name=dim_root.__name__, 976 constraint_range=StrictMinMaxConstraint( 977 vr=ValueRanges(lower=dim_root.min, upper=dim_root.max), 978 warn_only=False, 979 ), 980 val=root_value(), 981 ) 982 phantom_roots[dim_root.__name__] = root # type: ignore[assignment] 983 else: 984 root = phantom_roots[dim_root.__name__] # type: ignore[assignment] 985 constraint = _DerivedConstraint( 986 id(tensor), 987 i, 988 dim.__name__, 989 StrictMinMaxConstraint( 990 vr=ValueRanges(lower=dim.min, upper=dim.max), 991 warn_only=False, 992 ), 993 root, 994 dim.fn, # type: ignore[attr-defined] 995 ) 996 if isinstance(root, _PhantomRoot): 997 # NOTE(avik): since we have not processed all inputs yet, we may replace this 998 # with a root that does represent an input shape dimension later (see below) 999 derived_constraints_with_phantom_root.append(constraint) 1000 elif isinstance(dim, _StaticDim): 1001 constraint = _Constraint( # type: ignore[assignment] 1002 id(tensor), 1003 i, 1004 dim.__name__, 1005 StrictMinMaxConstraint( 1006 vr=ValueRanges(lower=dim.value, upper=dim.value), warn_only=False # type: ignore[attr-defined] 1007 ), 1008 ) 1009 else: 1010 constraint = _Constraint( # type: ignore[assignment] 1011 id(tensor), 1012 i, 1013 dim.__name__, 1014 StrictMinMaxConstraint( 1015 vr=ValueRanges(lower=dim.min, upper=dim.max), warn_only=False # type: ignore[attr-defined] 1016 ), 1017 ) 1018 return constraint 1019 1020 def update_symbols(path, tensor, shape): 1021 def _create_static_dim(tensor, i, value): 1022 return _StaticDim(str(value), (int,), {"value": value}) 1023 1024 if isinstance(shape, dict): 1025 for i, dim in shape.items(): 1026 if isinstance(dim, (int, _Dim)): 1027 if isinstance(dim, int): 1028 dim = _create_static_dim(tensor, i, dim) 1029 constraint = to_constraint(dim, tensor, i) 1030 symbols[dim.__name__].append(constraint) 1031 elif isinstance(shape, (tuple, list)): 1032 for i, dim in enumerate(shape): 1033 if isinstance(dim, (int, _Dim)): 1034 if isinstance(dim, int): 1035 dim = _create_static_dim(tensor, i, dim) 1036 constraint = to_constraint(dim, tensor, i) 1037 symbols[dim.__name__].append(constraint) 1038 1039 def assoc_shape(path, t, dynamic_shape): 1040 if isinstance(t, torch.Tensor): 1041 update_symbols(path, t, dynamic_shape) 1042 1043 _tree_map_with_path(assoc_shape, combined_args, dynamic_shapes, tree_name="inputs") 1044 1045 constraints = [] 1046 for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root: 1047 phantom_root_name = derived_constraint_with_phantom_root.root.name # type: ignore[union-attr] 1048 if phantom_root_name in symbols: 1049 # We found an input shape dimension corresponding to this name, so we 1050 # do not need a phantom symbol for it after all. 1051 # NOTE(avik): Overall we want to maintain the invariant that roots that 1052 # are phantom symbols are really "phantom," i.e., they cannot be represented 1053 # by any input source. This is important when we are deciding derived equalities, 1054 # since we can focus our attention exclusively on input sources: deciding 1055 # derived equalities involving phantom symbols are, in comparison, trivial. 1056 derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0] 1057 1058 for dynamic_dims in symbols.values(): 1059 constraints.extend(dynamic_dims) 1060 1061 return constraints # type: ignore[return-value] 1062 1063 1064def _get_dim_name_mapping( 1065 dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None] 1066): 1067 name_to_dim = {} 1068 for dim in tree_flatten( 1069 dynamic_shapes, 1070 is_leaf=lambda x: isinstance(x, _Dim), 1071 )[0]: 1072 if dim is None: 1073 # NOTE: this must denote a non-Tensor or automatic at this point. 1074 continue 1075 if isinstance(dim, int): 1076 continue 1077 assert isinstance(dim, _Dim) # dim hints should have boiled away 1078 name_to_dim[dim.__name__] = dim 1079 if isinstance(dim, _DerivedDim): 1080 name_to_dim[dim.root.__name__] = dim.root # type: ignore[attr-defined] 1081 return name_to_dim 1082 1083 1084def refine_dynamic_shapes_from_suggested_fixes( 1085 msg: str, 1086 dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]], 1087) -> Union[Dict[str, Any], Tuple[Any], List[Any]]: 1088 """ 1089 For working with export's dynamic shapes suggested fixes, and/or automatic dynamic shapes. 1090 Refines the given dynamic shapes spec, given a ConstraintViolation error message and the original dynamic shapes. 1091 1092 For most cases behavior is straightforward - i.e. for suggested fixes that specialize or refine a Dim's range, 1093 or fixes that suggest a derived relation, the new dynamic shapes spec will be updated as such. 1094 1095 e.g. 1096 Suggested fixes: 1097 1098 dim = Dim('dim', min=3, max=6) -> this just refines the dim's range 1099 dim = 4 -> this specializes to a constant 1100 dy = dx + 1 -> dy was specified as an independent dim, but is actually tied to dx with this relation 1101 1102 However, suggested fixes associated with derived dims can be more complicated. 1103 For example, if a suggested fix is provided for a root dim, the new derived dim value is evaluated based on the root. 1104 1105 e.g. 1106 dx = Dim('dx') 1107 dy = dx + 2 1108 dynamic_shapes = {"x": (dx,), "y": (dy,)} 1109 1110 Suggested fixes: 1111 1112 dx = 4 # specialization will lead to dy also specializing = 6 1113 dx = Dim('dx', max=6) # dy now has max = 8 1114 1115 Derived dims suggested fixes can also be used to express divisibility constraints. 1116 This involves creating new root dims that aren't tied to a particular input shape. 1117 In this case the root dims won't appear directly in the new spec, but as a root of 1118 one of the dims. 1119 1120 e.g. 1121 Suggested fixes: 1122 1123 _dx = Dim('_dx', max=1024) # this won't appear in the return result, but dx will 1124 dx = 4*_dx # dx is now divisible by 4, with a max value of 4096 1125 """ 1126 1127 import re 1128 1129 import sympy 1130 1131 from torch._dynamo.exc import UserError, UserErrorType 1132 from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence 1133 1134 try: 1135 shape_fixes_msg = msg.split("Suggested fixes:")[1].strip() 1136 except Exception as exc: 1137 raise UserError( 1138 UserErrorType.INVALID_INPUT, 1139 "Suggested fixes not found in error message given to refine_dynamic_shapes_from_suggested_fixes()", 1140 ) from exc 1141 1142 # build shape_fixes dictionary 1143 shape_fixes = {} 1144 for fix in shape_fixes_msg.split("\n"): 1145 fix = fix.strip() 1146 if match := re.match(r"(.*) = Dim\('(.*)'.*\)", fix): 1147 name = match.group(1) 1148 _min, _max = None, None 1149 if match_min := re.match(r".* = Dim\('.*', min\=([0-9]+).*\)", fix): 1150 _min = int(match_min.group(1)) 1151 if match_max := re.match(r".* = Dim\('.*'.*max\=([0-9]+)\)", fix): 1152 _max = int(match_max.group(1)) 1153 shape_fixes[name] = Dim(name, min=_min, max=_max) 1154 else: 1155 name, expr = fix.split(" = ") 1156 expr = sympy.sympify(expr) 1157 if isinstance(expr, sympy.Number): 1158 # static, integer 1159 shape_fixes[name] = int(expr) # type: ignore[assignment] 1160 else: 1161 # relation or derived dim 1162 shape_fixes[name] = expr 1163 1164 name_to_dim = _get_dim_name_mapping(dynamic_shapes) 1165 1166 # track derived dim roots 1167 roots: Set[str] = set() 1168 for k, c in shape_fixes.items(): 1169 assert isinstance(c, (int, _Dim, _DerivedDim, sympy.Expr)) 1170 if isinstance(c, sympy.Expr): # check dim/derived dim expression 1171 assert _is_supported_equivalence(c) 1172 shape_fixes[k] = c 1173 roots.add(str(next(iter(c.free_symbols)))) 1174 if isinstance(c, _DerivedDim): 1175 roots.add(c.root.__name__) # type: ignore[attr-defined] 1176 1177 # check keys are existing dims or new roots 1178 for k, c in shape_fixes.items(): 1179 assert k in name_to_dim or k in roots 1180 1181 # cache so we don't produce multiple derived dim objects 1182 derived_dim_cache: Dict[str, _DerivedDim] = {} 1183 1184 def apply_fixes(path, dim, dummy): 1185 if dim is None or isinstance(dim, int): # not dynamic 1186 return dim 1187 elif dim.__name__ in shape_fixes: # directly fix 1188 fix = shape_fixes[dim.__name__] 1189 if isinstance(fix, sympy.Expr): # now derived or related 1190 if str(fix) in derived_dim_cache: 1191 return derived_dim_cache[str(fix)] 1192 else: 1193 symbol = next(iter(fix.free_symbols)) 1194 # try to locate symbol 1195 if symbol.name in shape_fixes: # type: ignore[attr-defined] 1196 root = shape_fixes[symbol.name] # type: ignore[attr-defined] 1197 else: 1198 assert symbol.name in name_to_dim # type: ignore[attr-defined] 1199 root = name_to_dim[symbol.name] # type: ignore[attr-defined] 1200 # figure out value of fix 1201 modulus, remainder = sympy.polys.polytools.div(fix, symbol) 1202 dim = root 1203 if modulus != 1: 1204 dim = int(modulus) * dim 1205 if remainder != 0: 1206 dim = dim + int(remainder) 1207 derived_dim_cache[str(fix)] = dim 1208 return dim 1209 else: 1210 return fix 1211 elif isinstance(dim, _DerivedDim) and dim.root.__name__ in shape_fixes: # type: ignore[attr-defined] 1212 if dim.__name__ in derived_dim_cache: 1213 return derived_dim_cache[dim.__name__] 1214 else: # evaluate new derived value based on root 1215 _dim = dim.fn(shape_fixes[dim.root.__name__]) # type: ignore[attr-defined] 1216 derived_dim_cache[dim.__name__] = _dim 1217 return _dim 1218 return dim # unchanged dim 1219 1220 return _tree_map_with_path(apply_fixes, dynamic_shapes, dynamic_shapes) 1221