# mypy: allow-untyped-defs
import dataclasses
import inspect
import logging
import sys
from collections import defaultdict
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

import torch
from torch.utils._pytree import (
    _get_node_type,
    BUILTIN_TYPES,
    keystr,
    LeafSpec,
    MappingKey,
    SequenceKey,
    SUPPORTED_NODES,
    tree_flatten,
    tree_map_with_path,
)

from .exported_program import ExportedProgram


if TYPE_CHECKING:
    from sympy import Symbol

    from torch._guards import Source
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint

__all__ = [
    "Constraint",
    "Dim",
    "dims",
    "refine_dynamic_shapes_from_suggested_fixes",
]


log = logging.getLogger(__name__)


class _DimHint(Enum):
    """
    Enum for dynamic shape hints.
    - AUTO means automatic inference of shape (static or dynamic).
    - STATIC means static shape (always specialized).
    """

    AUTO = auto()
    STATIC = auto()


class _Dim(type):
    """
    Metaclass for :func:`Dim` types.
    """

    @staticmethod
    def readable(name, min_, max_):
        from torch.utils._sympy.numbers import int_oo

        if min_ == 2:
            min_ = None
        if max_ == int_oo:
            max_ = None
        if min_ is None and max_ is None:
            return f"Dim('{name}')"
        if min_ is None:
            return f"Dim('{name}', max={max_})"
        if max_ is None:
            return f"Dim('{name}', min={min_})"
        return f"Dim('{name}', min={min_}, max={max_})"

    def __add__(cls, other):
        # e.g., dim + 1
        if type(other) is not int:
            raise NotImplementedError(
                f"Attempted to add {other} to {cls.__name__}, where an integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x + other)

    def __radd__(cls, other):
        return cls + other

    def __sub__(cls, other):
        # e.g., dim - 1
        if type(other) is not int:
            raise NotImplementedError(
                f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x - other)

    def __rsub__(cls, other):
        raise NotImplementedError(
            f"Attempted to negate {cls.__name__}. "
            "(Only increasing linear operations with integer coefficients are supported.)"
        )

    def __mul__(cls, other):
        # e.g., dim * 2
        if type(other) is not int or other <= 0:
            raise NotImplementedError(
                f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x * other)

    def __rmul__(cls, other):
        return cls * other

    def _derived_name(cls, fn):
        from sympy import sympify

        return str(fn(sympify(cls.__name__)))

    def _derive(cls, fn):
        return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn})


class _StaticDim(_Dim):
    """
    Metaclass for static :func:`Dim` types.

    This class is only for setting and checking static dim constraints,
    and the user should never interact with it.
    """

    @property
    def min(self):
        return self.value  # type: ignore[attr-defined]

    @property
    def max(self):
        return self.value  # type: ignore[attr-defined]


class _DerivedDim(_Dim):
    """
    Metaclass for derived :func:`Dim` types.

    Currently we only support increasing linear expressions with integer coefficients.
    In other words, a derived Dim can always be written in the form Ax + B, where
    x is a regular Dim (i.e., non-derived Dim), A and B are integers, and A is positive.
    (In particular, the latter ensures that x < y => Ax + B < Ay + B.)
    These restrictions on the form of derived Dims make the metatheory simpler: e.g.,
    it simplifies computing ranges for derived Dims, solving for underlying regular Dims,
    deciding equalities between derived Dims, and so on.

    The function lambda x: Ax + B is expressed by `fn`, where x is a normal Dim, `root`.
    The range of a derived Dim is computed by mapping `fn` over the range of its `root`.
    """

    @property
    def min(self):
        # assume that self.fn is an increasing function
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        from torch.utils._sympy.numbers import int_oo

        if self.root.min is -int_oo:  # type: ignore[attr-defined]
            return -int_oo  # fn not needed because it is increasing

        _min_symint = self.fn(Integer(self.root.min))  # type: ignore[attr-defined]
        root = self.root  # type: ignore[attr-defined]
        assert _min_symint >= 0, (
            f"Expected derived min value of {self.__name__} to be >= 0. "
            f"Please specify an appropriate min value for {root.__name__} "
            f"(currently {root.min})."
        )
        return int(_min_symint)

    @property
    def max(self):
        # assume that self.fn is an increasing function
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        from torch.utils._sympy.numbers import int_oo

        if self.root.max is int_oo:  # type: ignore[attr-defined]
            return int_oo  # fn not needed because it is increasing

        _max_symint = self.fn(Integer(self.root.max))  # type: ignore[attr-defined]
        root = self.root  # type: ignore[attr-defined]
        assert _max_symint <= sys.maxsize - 1, (
            f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. "
            f"Please specify an appropriate max value for {root.__name__} "
            f"(currently {root.max})."
        )
        return int(_max_symint)

    def _derive(self, fn):
        # We support nesting, e.g., 2*dim + 1.
        # This is implemented by composing operations on the same root.
        # As a consequence, roots are always regular Dims (i.e., not derived Dims).
        return _DerivedDim(
            self._derived_name(fn),
            (int,),
            {"root": self.root, "fn": lambda x: fn(self.fn(x))},  # type: ignore[attr-defined]
        )
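
# A minimal illustrative sketch of derived Dims (names are arbitrary and the lines are
# not executed here): the arithmetic operators come from the _Dim metaclass above, and
# _DerivedDim computes its range by mapping `fn` over the range of its root.
#
#     dx = Dim("dx", min=1, max=64)
#     dy = 2 * dx + 1     # a _DerivedDim with root dx and fn: x -> 2*x + 1
#     dy.min, dy.max      # -> (3, 129)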


def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
    """
    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
    It can be used to describe multiple possible values of a dynamic tensor dimension.
    Note that different dynamic dimensions of the same tensor, or of different tensors,
    can be described by the same type.

    Args:
        name (str): Human-readable name for debugging.
        min (Optional[int]): Minimum possible value of given symbol (inclusive)
        max (Optional[int]): Maximum possible value of given symbol (inclusive)

    Returns:
        A type that can be used in dynamic shape specifications for tensors.
    """

    from torch.utils._sympy.numbers import int_oo

    _min = 0 if min is None else min
    _max = int_oo if max is None else max
    assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
    assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}"
    dim = _Dim(name, (int,), {"min": _min, "max": _max})
    dim.__module__ = getattr(
        inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__"
    )
    return dim
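
# Usage sketch (assuming a module `M` whose forward takes one tensor argument named "x";
# `M` is not defined in this file):
#
#     batch = Dim("batch", max=1024)
#     x = torch.randn(4, 16)
#     ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {0: batch}})
#     # dimension 0 of x may now vary up to 1024; dimension 1 stays static at 16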


Dim.AUTO = _DimHint.AUTO  # type: ignore[attr-defined]
Dim.STATIC = _DimHint.STATIC  # type: ignore[attr-defined]


def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None):
    """
    Util to create multiple :func:`Dim` types.
    """
    return tuple(Dim(name, min=min, max=max) for name in names)
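
# Sketch: create several Dims with shared bounds in one call (names are arbitrary):
#
#     dx, dy, dz = dims("dx", "dy", "dz", max=512)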


@dataclasses.dataclass
class _ConstraintTarget:
    """
    This represents a dimension of an input tensor.
    """

    t_id: int
    dim: int


@dataclasses.dataclass
class _Constraint(_ConstraintTarget):
    """
    This represents a Dim describing a constraint target.

    `name` is the name of the Dim.
    `constraint_range` contains the min/max bounds of the Dim.
    """

    name: str
    constraint_range: "StrictMinMaxConstraint"

    def _clone_with_range(self, lower=0, upper=None):
        # Import sympy locally
        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.numbers import int_oo
        from torch.utils._sympy.value_ranges import ValueRanges

        if upper is None:
            upper = int_oo

        constraint_range = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
            warn_only=False,
        )
        return _Constraint(
            self.t_id,
            self.dim,
            self.name,
            constraint_range,
        )

    def __ge__(self, lower):
        return self._clone_with_range(lower=lower)

    def __gt__(self, lower):
        return self._clone_with_range(lower=lower + 1)

    def __le__(self, upper):
        return self._clone_with_range(upper=upper)

    def __lt__(self, upper):
        return self._clone_with_range(upper=upper - 1)

    def __bool__(self):
        # NOTE(avik): We do not support compound expressions like a <= x <= b.
        # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b),
        # and moreover, enforces that any overload of __bool__ must return True or False.
        # FWIW, sympy also raises TypeError in this case.
        raise TypeError(
            "Cannot determine truth value of _Constraint. "
            "If you are trying to combine _Constraint's with logical connectives, "
            "you can specify them separately instead."
        )

    @property
    def serializable_spec(self):
        # We need a serialization-compatible format of the constraint so that it
        # can be saved in the graph module without breaking module serialization.
        # The saved constraints will be used directly by the post-export pass
        # that converts constraints to runtime assertions. The saved constraints
        # will not be stored in the serialized module.
        # TODO: A better way is needed. Currently we use 't_id' to map the constraint,
        # which is not reliable.
        return {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }


@dataclasses.dataclass
class _PhantomRoot:
    """
    This represents the root of a derived Dim where the root does not directly
    specify the shape of any input dimension, but the derived Dim does.

    e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim.

    The fields `name`, `constraint_range`, and `val` carried by a phantom root
    help create a symbol for it. Any derived dims with this phantom root are
    backed by expressions over this symbol.
    """

    name: str
    constraint_range: "StrictMinMaxConstraint"
    val: int


@dataclasses.dataclass
class _DerivedConstraint(_ConstraintTarget):
    """
    This represents a derived Dim, whose root is either a regular constraint target
    (which directly specifies the shape of some input dimension) or a phantom root
    (which does so indirectly).

    It can be thought of as a subclass of `_Constraint`, except that it does not
    support <, <=, >, >= operations.
    """

    name: str
    constraint_range: "StrictMinMaxConstraint"
    root: Union[_ConstraintTarget, _PhantomRoot]
    fn: Callable

    @property
    def serializable_spec(self):
        # same as _Constraint.serializable_spec
        return {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }


Constraint = Union[_Constraint, _DerivedConstraint]


def _process_equalities(
    constraint: Constraint,
    get_sources: Callable[[int, int], List["Source"]],
    shape_env: "ShapeEnv",
    names: Dict[str, Tuple[int, int]],
    source_pairs: List[Tuple["Source", "Source"]],
    derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]],
    phantom_symbols: Dict[str, "Symbol"],
):
    """
    Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become
    fields of `EqualityConstraint`) based on a given input `constraint`.
    """

    sources = get_sources(constraint.t_id, constraint.dim)
    if not sources:  # empty sources due to unused shapes
        return

    source, *other_sources = sources
    # When t.size()[dim] maps to src0, src1, ..., srcN, we add
    # constraints that make src0 "equal" to src1, ..., srcN.
    source_pairs.extend((source, other_source) for other_source in other_sources)
    if not isinstance(constraint, _DerivedConstraint):
        if constraint.name in names:
            shared_t_id, shared_dim = names[constraint.name]
            other_sources = get_sources(shared_t_id, shared_dim)
            source_pairs.extend(
                (source, other_source) for other_source in other_sources
            )
        else:
            names[constraint.name] = (constraint.t_id, constraint.dim)
    else:
        # branch based on the root of the _DerivedConstraint
        if not isinstance(constraint.root, _PhantomRoot):
            # either root points to an input source
            root = get_sources(constraint.root.t_id, constraint.root.dim)[0]  # type: ignore[assignment]
        else:
            # or root points to a phantom symbol
            if constraint.root.name in phantom_symbols:
                root = phantom_symbols[constraint.root.name]  # type: ignore[assignment]
            else:
                # create a phantom symbol in the shape env based on the _PhantomRoot
                root = shape_env.create_symbol(
                    val=constraint.root.val,
                    source=torch._dynamo.source.ConstantSource(constraint.root.name),
                    dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC,
                    constraint_dim=constraint.root.constraint_range,
                )
                phantom_symbols[constraint.root.name] = root  # type: ignore[assignment]

        fn = constraint.fn
        # A derived equality (source, root, fn) informally corresponds to source = fn(root).
        # Here source describes an input and root might describe another input or a phantom symbol.
        derived_equalities.append((source, root, fn))


def _tree_map_with_path(
    func: Callable[..., Any],
    tree: Any,
    *dynamic_shapes: Any,
    tree_name: Optional[str] = None,
) -> Any:
    """
    Customized tree_map for mapping pytrees to dynamic_shapes.

    For built-in types (e.g., standard collections) this behaves exactly like tree_map.

    On the other hand, for a user-defined class C registered with pytree, we cannot assume that a C
    containing tensors can be mapped to a C containing dynamic shapes (i.e., C may not
    be a polymorphic container). In that case we use the flattened form of C instead.
    Thus a C(**tensors) that flattens to (**tensors) will map to (**dynamic_shapes).

    Args:
        func: function to apply to each (int, float, str, bool, None, torch.Tensor)
        tree: input pytree
        dynamic_shapes: zero or more (typically one) dynamic_shapes to match

    Returns:
        output pytree mapping func to each (int, float, str, bool, None, torch.Tensor)
    """

    def is_leaf(t):
        # BUILTIN_TYPES is a subset of SUPPORTED_NODES, the latter being all types
        # registered with pytree. Types *not* in BUILTIN_TYPES include primitive types
        # (int, float, str, bool, None, torch.Tensor), which are not in SUPPORTED_NODES,
        # as well as user-defined classes registered with pytree, which are.
        return _get_node_type(t) not in BUILTIN_TYPES

    def f(path, t, *dynamic_shapes):
        typ = _get_node_type(t)
        # typ is not in BUILTIN_TYPES
        if typ in SUPPORTED_NODES:
            # thus typ is a user-defined class registered with pytree,
            # in which case flatten and recurse
            return tree_map_with_path(
                f,
                SUPPORTED_NODES[typ].flatten_fn(t)[0],
                *dynamic_shapes,
                is_leaf=is_leaf,
            )
        else:
            return func(path, t, *dynamic_shapes)

    try:
        return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf)
    except ValueError as e:
        if "mismatch" in e.args[0]:
            # When PyTree finds a structural mismatch between tree and dynamic_shapes,
            # the error message is unfortunately quite horrible. Let's fix that.
            assert dynamic_shapes, "Cannot be a mismatch if there is no dynamic_shapes"
            assert tree_name, "Must provide a tree_name when there might be a mismatch"

            def _key(type_, context, i):
                # derive a PyTree key given the type, context, and child # of a TreeSpec
                if type_ is dict:
                    return MappingKey(context[i])
                if type_ in (list, tuple):
                    assert context is None
                    return SequenceKey(i)
                raise AssertionError(f"Did not expect type {type_}")

            def raise_mismatch_error(msg):
                from torch._dynamo.exc import UserError, UserErrorType

                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Detected mismatch between the structure of `{tree_name}` and `dynamic_shapes`: {msg}",
                    case_name="dynamic_shapes_validation",
                )

            def _compare(tree, dynamic_shapes, path):
                # raise an error at the point where tree and dynamic_shapes differ,
                # including the path to that point and the reason for the difference
                rendered_path = keystr(path)
                if isinstance(tree, LeafSpec):
                    return
                if isinstance(dynamic_shapes, LeafSpec):
                    raise_mismatch_error(
                        f"`{tree_name}{rendered_path}` is a {tree.type}, "
                        f"but `dynamic_shapes{rendered_path}` is not"
                    )
                if tree.type != dynamic_shapes.type:
                    raise_mismatch_error(
                        f"`{tree_name}{rendered_path}` is a {tree.type}, "
                        f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}"
                    )
                if len(tree.children_specs) != len(dynamic_shapes.children_specs):
                    raise_mismatch_error(
                        f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, "
                        f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements"
                    )
                if tree.type is dict:
                    # context, children could be out of order
                    if sorted(tree.context) != sorted(dynamic_shapes.context):
                        raise_mismatch_error(
                            f"`{tree_name}{rendered_path}` has keys {tree.context}, "
                            f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}"
                        )
                    _remap = dict(
                        zip(dynamic_shapes.context, dynamic_shapes.children_specs)
                    )
                    dynamic_shapes_children_specs = [_remap[k] for k in tree.context]
                else:
                    dynamic_shapes_children_specs = dynamic_shapes.children_specs
                for i, (tree_, dynamic_shapes_) in enumerate(
                    zip(tree.children_specs, dynamic_shapes_children_specs)
                ):
                    _compare(
                        tree_,
                        dynamic_shapes_,
                        path + [_key(tree.type, tree.context, i)],
                    )

            _, tree_spec = tree_flatten(tree, is_leaf=is_leaf)
            for other_tree in dynamic_shapes:
                _, other_tree_spec = tree_flatten(other_tree, is_leaf)
                _compare(tree_spec, other_tree_spec, [])
        raise


def _combine_args(f, args, kwargs, _is_torch_jit_trace=False) -> Dict[str, Any]:
    # combine args and kwargs following the signature of f, as it happens
    # in the body of f when called with *args, **kwargs
    if isinstance(f, ExportedProgram):
        f = f.module()
    if not _is_torch_jit_trace:
        signature = (
            inspect.signature(f.forward)
            if isinstance(f, torch.nn.Module)
            else inspect.signature(f)
        )
        kwargs = kwargs if kwargs is not None else {}
        return signature.bind(*args, **kwargs).arguments
    return args
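
# Behavior sketch (assuming a module `m` whose forward signature is `forward(self, x, y=None)`;
# `m`, `tensor_x`, and `tensor_y` are placeholders, not defined here):
#
#     _combine_args(m, (tensor_x,), {"y": tensor_y})
#     # -> {"x": tensor_x, "y": tensor_y}, keyed and ordered by forward's parameters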


class ShapesCollection:
    """
    Builder for dynamic_shapes.
    Used to assign dynamic shape specifications to tensors that appear in inputs.

    Example::
        args = ({"x": tensor_x, "others": [tensor_y, tensor_z]})

        dim = torch.export.Dim(...)
        dynamic_shapes = torch.export.ShapesCollection()
        dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
        dynamic_shapes[tensor_y] = {0: dim * 2}
        # This is equivalent to the following (now auto-generated):
        # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

        torch.export.export(..., args, dynamic_shapes=dynamic_shapes)
    """

    def __init__(self):
        self._shapes = {}

    def __setitem__(self, t, shape):
        assert isinstance(
            t, torch.Tensor
        ), f"Cannot assign shape to non-tensor type {type(t)}"
        # TODO(avik): check that shape is indeed a Shape
        t_id = id(t)
        if t_id in self._shapes:
            _shape = self._shapes[t_id]
            assert (
                shape == _shape
            ), f"Shapes assigned to tensor do not match: expected {_shape}, got {shape}"
        else:
            self._shapes[id(t)] = shape

    def __getitem__(self, t):
        t_id = id(t)
        if t_id in self._shapes:
            return self._shapes[t_id]
        else:
            return None

    def __len__(self):
        return len(self._shapes)

    def dynamic_shapes(self, m, args, kwargs=None):
        """
        Generate dynamic_shapes.
        """

        t_ids = set()

        def find_shape(path, t):
            t_id = id(t)
            if t_id in self._shapes:
                t_ids.add(t_id)
                return self._shapes[t_id]
            else:
                return None

        combined_args = _combine_args(m, args, kwargs)
        dynamic_shapes = _tree_map_with_path(find_shape, combined_args)
        if any(t_id not in t_ids for t_id in self._shapes):
            raise ValueError(
                "Some tensors that were assigned shapes were not found in args. "
                "Maybe such tensors were copied when passing them as args? "
                "Maybe such tensors are contained in classes that were not registered with pytree?"
            )
        return dynamic_shapes
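
# Usage sketch for ShapesCollection.dynamic_shapes (assuming a module `M` whose forward
# takes one tensor argument named "x"; `M` is a placeholder, not defined here):
#
#     x = torch.randn(4, 8)
#     batch = Dim("batch")
#     sc = ShapesCollection()
#     sc[x] = {0: batch}
#     spec = sc.dynamic_shapes(M(), (x,))   # -> {"x": {0: batch}}
#     ep = torch.export.export(M(), (x,), dynamic_shapes=spec)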


def _warn_on_None_dynamic_shape_dimension():
    msg = (
        "Using None as a dynamic shape dimension is deprecated. "
        "Please use Dim.STATIC instead."
    )
    # TODO(avik): raise an error in the future
    log.warning(msg)


def _check_dynamic_shapes(
    combined_args: Dict[str, Any],
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
):
    """
    Checks the dynamic_shapes specification for correctness,
    using combined args + kwargs as a reference for the structure of inputs.
    """
    from torch._dynamo.exc import UserError, UserErrorType
    from torch._export.non_strict_utils import _flatten_dynamic_shapes

    if dynamic_shapes is None or len(dynamic_shapes) == 0:
        return
    if isinstance(dynamic_shapes, (tuple, list)):
        combined_args = type(dynamic_shapes)(combined_args.values())  # type: ignore[assignment, misc]

    bounds: Dict[str, Tuple[int, int]] = {}

    def check_same_bounds(dim):
        if dim.__name__ in bounds:
            min_, max_ = bounds[dim.__name__]
            if dim.min != min_ or dim.max != max_:
                this_ = _Dim.readable(dim.__name__, min_, max_)
                that_ = _Dim.readable(dim.__name__, dim.min, dim.max)
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Found different definitions {this_} and {that_} "
                    f"for the same symbolic dimension {dim}!",
                )
        else:
            bounds[dim.__name__] = (dim.min, dim.max)

    def check_symbols(path, tensor, shape):
        if isinstance(shape, dict):
            for i, dim in shape.items():
                if isinstance(dim, _Dim):
                    check_same_bounds(dim)
                elif dim is None:
                    _warn_on_None_dynamic_shape_dimension()
                elif not (isinstance(dim, (int, _DimHint))):
                    raise UserError(
                        UserErrorType.INVALID_INPUT,
                        f"Unexpected dimension mapped to index {i} in input tensor shape {shape} "
                        f"specified at `dynamic_shapes{keystr(path)}` "
                        f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)",
                        case_name="dynamic_shapes_validation",
                    )
        elif isinstance(shape, (tuple, list)):
            for i, dim in enumerate(shape):
                if isinstance(dim, _Dim):
                    check_same_bounds(dim)
                elif dim is None:
                    _warn_on_None_dynamic_shape_dimension()
                elif not (isinstance(dim, (int, _DimHint))):
                    raise UserError(
                        UserErrorType.INVALID_INPUT,
                        f"Unexpected dimension #{i} in input tensor shape {shape} "
                        f"specified at `dynamic_shapes{keystr(path)}` "
                        f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)",
                        case_name="dynamic_shapes_validation",
                    )
        elif shape is not None:
            raise UserError(
                UserErrorType.INVALID_INPUT,
                f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` "
                f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions,"
                f" where each dimension is an int, a Dim, Dim.AUTO, or Dim.STATIC)",
                case_name="dynamic_shapes_validation",
            )

    assert isinstance(dynamic_shapes, (dict, tuple, list))
    if isinstance(dynamic_shapes, dict):
        got_keys = list(dynamic_shapes.keys())
        expected_arg_names = list(combined_args.keys())
        if sorted(got_keys) != sorted(expected_arg_names):
            msg = (
                f"When `dynamic_shapes` is specified as a dict, its top-level keys "
                f"must be the arg names {expected_arg_names} of `inputs`, but "
                f"here they are {got_keys}. "
            )
            if (
                len(combined_args) == 1
                and expected_arg_names[0] not in got_keys
                and isinstance(combined_args[expected_arg_names[0]], dict)
            ):
                msg += (
                    "Since here `inputs` is a list/tuple enclosing a single dict, "
                    "maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?"
                )
            else:
                msg += (
                    "Alternatively, you could also ignore arg names entirely "
                    "and specify `dynamic_shapes` as a list/tuple matching `inputs`."
                )
            raise UserError(
                UserErrorType.INVALID_INPUT, msg, case_name="dynamic_shapes_validation"
            )

    def check_shape(path, t, dynamic_shape):
        if isinstance(t, torch.Tensor):
            check_symbols(path, t, dynamic_shape)
        else:
            if dynamic_shape is not None:
                rendered_path = keystr(path)
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Cannot associate shape {dynamic_shape} specified at `dynamic_shapes{rendered_path}` "
                    f"to non-tensor type {type(t)} at `inputs{rendered_path}` (expected None)",
                    case_name="dynamic_shapes_validation",
                )

    _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs")

    # raise an error if both Dim.AUTO and Dims are specified in dynamic_shapes
    flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes)
    flatter_dynamic_shapes, _ = tree_flatten(flat_dynamic_shapes)
    if any(isinstance(s, _Dim) for s in flatter_dynamic_shapes) and any(
        s == _DimHint.AUTO for s in flatter_dynamic_shapes
    ):
        raise UserError(
            UserErrorType.INVALID_INPUT,
            "Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, "
            "and can easily lead to constraint violation errors or obscure errors in torch.export. Dim/DerivedDims "
            "expect all equal or related dimensions to be specified, and do not yet compose well with `Dim.AUTO`. "
            "We suggest using `Dim.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), "
            "torch._check(dim <= max) calls in your program to specify min/max ranges, or `Dim`/`DerivedDim` mixed with `None` "
            "if you want to assert on the exact specification of your program's dynamic shapes behavior.",
            case_name="dynamic_shapes_validation",
        )


def _transform_shapes_for_default_dynamic(
    combined_args: Dict[str, Any],
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
) -> Union[Dict[str, Any], Tuple[Any], List[Any], None]:
    """
    In the long run this might not be needed, but this exists because export.export() and _dynamo.export()
    historically have different semantics for how dynamic_shapes are specified, but go through the same
    process of producing constraints, and now both use assume_static_by_default=False.

    For _dynamo.export(), the semantics for dynamic_shapes are:
    - None: dynamic, allocated a symbol
    - Dim/DerivedDim: a strict assertion on the min/max range for this symbol, and a requirement that
      all dims governed by this symbol be specified (i.e. relations, equality, linear relations, etc.)

    For export.export(), historically dynamism for unspecified dims has been undesirable, so the semantics are:
    - Dim.AUTO: dynamic, allocated a symbol
    - None/unspecified/Dim.STATIC: static
    - Dim/DerivedDims: also a strict assertion

    To allow both APIs to follow the same process for producing constraints, this function converts dynamic_shapes
    for export.export() to be compatible with _process_dynamic_shapes() and assume_static_by_default=False, turning them
    into essentially what they'd look like for _dynamo.export().

    An example conversion might look like, for a 3-d input tensor:

        input spec: {
            0: Dim.AUTO,
            1: None,  # or Dim.STATIC
            2: Dim("dx"),
        }
        output spec: {
            0: None,  # None: dynamic by default
            1: 32,  # explicitly provide static shape
            2: Dim("dx"),  # remains the same
        }
    """

    def _tree_map_helper(tree, val):
        """
        If the user generally specifies dynamic_shapes=None for a pytree input,
        we'd like to convert this into a tree of Nones following the input spec,
        so we can explicitly specify static dims for all tensor dimensions.
        Non-builtin types for pytree (e.g. custom dataclasses) create some difficulty,
        in which case the correct format is a list containing specs for each child attribute.
        """
        if (node_type := _get_node_type(tree)) not in SUPPORTED_NODES:  # is_leaf
            return val
        flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
        child_pytrees, context = flatten_fn(tree)  # flatten from whatever original type
        unflatten_fn = SUPPORTED_NODES[
            node_type if node_type in BUILTIN_TYPES else list
        ].unflatten_fn
        children = [_tree_map_helper(child, val) for child in child_pytrees]
        return unflatten_fn(
            children, context
        )  # unflatten into original type, or list if not built-in type

    if (
        dynamic_shapes is None or len(dynamic_shapes) == 0
    ):  # create pytree structure of static dim
        dynamic_shapes = _tree_map_helper(combined_args, None)
    if isinstance(dynamic_shapes, (tuple, list)):
        combined_args = type(dynamic_shapes)(combined_args.values())  # type: ignore[assignment, misc]

    def transform_shapes(path, tensor, shape):
        def _marked_dynamic(tensor, i):
            # TODO(pianpwk): deprecate mark_dynamic() usage for export
            return i in getattr(tensor, "_dynamo_dynamic_indices", set())

        out: Union[None, List[Any], Dict[int, Any]] = None
        if isinstance(shape, dict):
            out = {}
            for i, val in enumerate(tensor.shape):
                dim = shape.get(i, _DimHint.STATIC)
                if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO:
                    # don't have to specify anything if dynamic
                    # None also works, since assume_static_by_default=False
                    if dim == _DimHint.AUTO:
                        torch._dynamo.maybe_mark_dynamic(tensor, i)  # avoid duck sizing
                    continue
                elif isinstance(dim, _Dim):
                    out[i] = dim
                elif isinstance(dim, int):
                    # important that this is dim and not val,
                    # so we can raise an error if user-specified dim != val
                    out[i] = dim
                elif dim is None:
                    _warn_on_None_dynamic_shape_dimension()
                    out[i] = val
                else:
                    # make explicitly static
                    assert dim == _DimHint.STATIC
                    out[i] = val
        elif isinstance(shape, (tuple, list)):
            out = []
            for i, val in enumerate(tensor.shape):
                dim = shape[i]
                if _marked_dynamic(tensor, i) or dim == _DimHint.AUTO:
                    if dim == _DimHint.AUTO:
                        torch._dynamo.maybe_mark_dynamic(tensor, i)  # avoid duck sizing
                    out.append(None)
                elif isinstance(dim, _Dim):
                    out.append(dim)
                elif isinstance(dim, int):
                    out.append(dim)
                elif dim is None:
                    _warn_on_None_dynamic_shape_dimension()
                    out.append(val)
                else:
                    assert dim == _DimHint.STATIC
                    out.append(val)
            out = type(shape)(out)  # type: ignore[assignment]
        else:
            assert shape is None
            if isinstance(tensor, torch.Tensor):
                out = []
                for i, val in enumerate(tensor.shape):
                    out.append(None if _marked_dynamic(tensor, i) else val)
                out = out or None
            else:
                out = None
        return out

    def transform_shape(path, t, dynamic_shape):
        if isinstance(t, torch.Tensor):
            return transform_shapes(path, t, dynamic_shape)

    result = _tree_map_with_path(
        transform_shape, combined_args, dynamic_shapes, tree_name="inputs"
    )
    return result


def _process_dynamic_shapes(
    combined_args: Dict[str, Any],
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
) -> List[Constraint]:
    """
    Reads the dynamic_shapes specification and produces a list of constraints.
    """
    from torch._dynamo.exc import UserError, UserErrorType

    if dynamic_shapes is None or len(dynamic_shapes) == 0:
        # we run with dynamic by default, so no need to produce constraints
        return []
    if isinstance(dynamic_shapes, (tuple, list)):
        combined_args = type(dynamic_shapes)(combined_args.values())  # type: ignore[assignment, misc]

    # map of Dim names representing input shape dimensions to constraints on them
    symbols: Dict[str, List[Constraint]] = defaultdict(list)
    # track roots that do not directly represent input shape dimensions
    phantom_roots: Dict[str, _PhantomRoot] = {}
    derived_constraints_with_phantom_root: List[_DerivedConstraint] = []

    def to_constraint(dim, tensor, i):
        import sympy

        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.solve import try_solve
        from torch.utils._sympy.value_ranges import ValueRanges

        def root_value():
            # given tensor.shape[i] is the value of dim = fn(root),
            # find the value of root
            symbol = sympy.Symbol(dim.root.__name__, integer=True)
            expr = dim.fn(symbol)
            solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol)
            if solution is not None:
                return int(solution[1])  # type: ignore[call-overload]
            else:
                raise UserError(  # noqa: B904
                    UserErrorType.CONSTRAINT_VIOLATION,
                    f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be "
                    f"of the form {expr}, where {symbol} is an integer",
                )

        if isinstance(dim, _DerivedDim):
            # generate a _DerivedConstraint where the root is:
            # - either a _ConstraintTarget (if dim.root directly describes an input shape)
            # - or a _PhantomRoot (otherwise)
            dim_root = dim.root  # type: ignore[attr-defined]
            if dim_root.__name__ in symbols:
                # root represents an input shape dimension
                root_constraint = symbols[dim_root.__name__][0]
                root = _ConstraintTarget(
                    root_constraint.t_id,
                    root_constraint.dim,
                )
            elif dim_root.__name__ not in phantom_roots:
                # create a phantom root
                root = _PhantomRoot(  # type: ignore[assignment]
                    name=dim_root.__name__,
                    constraint_range=StrictMinMaxConstraint(
                        vr=ValueRanges(lower=dim_root.min, upper=dim_root.max),
                        warn_only=False,
                    ),
                    val=root_value(),
                )
                phantom_roots[dim_root.__name__] = root  # type: ignore[assignment]
            else:
                root = phantom_roots[dim_root.__name__]  # type: ignore[assignment]
            constraint = _DerivedConstraint(
                id(tensor),
                i,
                dim.__name__,
                StrictMinMaxConstraint(
                    vr=ValueRanges(lower=dim.min, upper=dim.max),
                    warn_only=False,
                ),
                root,
                dim.fn,  # type: ignore[attr-defined]
            )
            if isinstance(root, _PhantomRoot):
                # NOTE(avik): since we have not processed all inputs yet, we may replace this
                # with a root that does represent an input shape dimension later (see below)
                derived_constraints_with_phantom_root.append(constraint)
        elif isinstance(dim, _StaticDim):
            constraint = _Constraint(  # type: ignore[assignment]
                id(tensor),
                i,
                dim.__name__,
                StrictMinMaxConstraint(
                    vr=ValueRanges(lower=dim.value, upper=dim.value), warn_only=False  # type: ignore[attr-defined]
                ),
            )
        else:
            constraint = _Constraint(  # type: ignore[assignment]
                id(tensor),
                i,
                dim.__name__,
                StrictMinMaxConstraint(
                    vr=ValueRanges(lower=dim.min, upper=dim.max), warn_only=False  # type: ignore[attr-defined]
                ),
            )
        return constraint

    def update_symbols(path, tensor, shape):
        def _create_static_dim(tensor, i, value):
            return _StaticDim(str(value), (int,), {"value": value})

        if isinstance(shape, dict):
            for i, dim in shape.items():
                if isinstance(dim, (int, _Dim)):
                    if isinstance(dim, int):
                        dim = _create_static_dim(tensor, i, dim)
                    constraint = to_constraint(dim, tensor, i)
                    symbols[dim.__name__].append(constraint)
        elif isinstance(shape, (tuple, list)):
            for i, dim in enumerate(shape):
                if isinstance(dim, (int, _Dim)):
                    if isinstance(dim, int):
                        dim = _create_static_dim(tensor, i, dim)
                    constraint = to_constraint(dim, tensor, i)
                    symbols[dim.__name__].append(constraint)

    def assoc_shape(path, t, dynamic_shape):
        if isinstance(t, torch.Tensor):
            update_symbols(path, t, dynamic_shape)

    _tree_map_with_path(assoc_shape, combined_args, dynamic_shapes, tree_name="inputs")

    constraints = []
    for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root:
        phantom_root_name = derived_constraint_with_phantom_root.root.name  # type: ignore[union-attr]
        if phantom_root_name in symbols:
            # We found an input shape dimension corresponding to this name, so we
            # do not need a phantom symbol for it after all.
            # NOTE(avik): Overall we want to maintain the invariant that roots that
            # are phantom symbols are really "phantom," i.e., they cannot be represented
            # by any input source. This is important when we are deciding derived equalities,
            # since we can focus our attention exclusively on input sources: deciding
            # derived equalities involving phantom symbols is, in comparison, trivial.
            derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0]

    for dynamic_dims in symbols.values():
        constraints.extend(dynamic_dims)

    return constraints  # type: ignore[return-value]
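
# For reference, a rough sketch of what _process_dynamic_shapes produces for a simple
# spec (a single 2-d tensor argument named "x"; the repr shown is approximate):
#
#     x = torch.randn(4, 8)
#     dx = Dim("dx", max=64)
#     _process_dynamic_shapes({"x": x}, {"x": {0: dx}})
#     # -> [_Constraint(t_id=id(x), dim=0, name="dx", constraint_range=<0 <= dx <= 64>)]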


def _get_dim_name_mapping(
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None]
):
    name_to_dim = {}
    for dim in tree_flatten(
        dynamic_shapes,
        is_leaf=lambda x: isinstance(x, _Dim),
    )[0]:
        if dim is None:
            # NOTE: this must denote a non-Tensor or automatic at this point.
            continue
        if isinstance(dim, int):
            continue
        assert isinstance(dim, _Dim)  # dim hints should have boiled away
        name_to_dim[dim.__name__] = dim
        if isinstance(dim, _DerivedDim):
            name_to_dim[dim.root.__name__] = dim.root  # type: ignore[attr-defined]
    return name_to_dim


def refine_dynamic_shapes_from_suggested_fixes(
    msg: str,
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
) -> Union[Dict[str, Any], Tuple[Any], List[Any]]:
    """
    Helper for working with export's suggested fixes for dynamic shapes, and/or automatic dynamic shapes.
    Refines the given dynamic shapes spec, given a ConstraintViolation error message and the original dynamic shapes.

    In most cases behavior is straightforward - e.g. for suggested fixes that specialize or refine a Dim's range,
    or fixes that suggest a derived relation, the new dynamic shapes spec will be updated accordingly.

    e.g.
    Suggested fixes:

        dim = Dim('dim', min=3, max=6) -> this just refines the dim's range
        dim = 4 -> this specializes to a constant
        dy = dx + 1 -> dy was specified as an independent dim, but is actually tied to dx with this relation

    However, suggested fixes associated with derived dims can be more complicated.
    For example, if a suggested fix is provided for a root dim, the new derived dim value is evaluated based on the root.

    e.g.
    dx = Dim('dx')
    dy = dx + 2
    dynamic_shapes = {"x": (dx,), "y": (dy,)}

    Suggested fixes:

        dx = 4  # specialization will lead to dy also specializing to 6
        dx = Dim('dx', max=6)  # dy now has max = 8

    Suggested fixes for derived dims can also be used to express divisibility constraints.
    This involves creating new root dims that aren't tied to a particular input shape.
    In this case the root dims won't appear directly in the new spec, but as a root of
    one of the dims.

    e.g.
    Suggested fixes:

        _dx = Dim('_dx', max=1024)  # this won't appear in the return result, but dx will
        dx = 4*_dx  # dx is now divisible by 4, with a max value of 4096
    """

    import re

    import sympy

    from torch._dynamo.exc import UserError, UserErrorType
    from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence

    try:
        shape_fixes_msg = msg.split("Suggested fixes:")[1].strip()
    except Exception as exc:
        raise UserError(
            UserErrorType.INVALID_INPUT,
            "Suggested fixes not found in error message given to refine_dynamic_shapes_from_suggested_fixes()",
        ) from exc

    # build shape_fixes dictionary
    shape_fixes = {}
    for fix in shape_fixes_msg.split("\n"):
        fix = fix.strip()
        if match := re.match(r"(.*) = Dim\('(.*)'.*\)", fix):
            name = match.group(1)
            _min, _max = None, None
            if match_min := re.match(r".* = Dim\('.*', min\=([0-9]+).*\)", fix):
                _min = int(match_min.group(1))
            if match_max := re.match(r".* = Dim\('.*'.*max\=([0-9]+)\)", fix):
                _max = int(match_max.group(1))
            shape_fixes[name] = Dim(name, min=_min, max=_max)
        else:
            name, expr = fix.split(" = ")
            expr = sympy.sympify(expr)
            if isinstance(expr, sympy.Number):
                # static, integer
                shape_fixes[name] = int(expr)  # type: ignore[assignment]
            else:
                # relation or derived dim
                shape_fixes[name] = expr

    name_to_dim = _get_dim_name_mapping(dynamic_shapes)

    # track derived dim roots
    roots: Set[str] = set()
    for k, c in shape_fixes.items():
        assert isinstance(c, (int, _Dim, _DerivedDim, sympy.Expr))
        if isinstance(c, sympy.Expr):  # check dim/derived dim expression
            assert _is_supported_equivalence(c)
            shape_fixes[k] = c
            roots.add(str(next(iter(c.free_symbols))))
        if isinstance(c, _DerivedDim):
            roots.add(c.root.__name__)  # type: ignore[attr-defined]

    # check keys are existing dims or new roots
    for k, c in shape_fixes.items():
        assert k in name_to_dim or k in roots

    # cache so we don't produce multiple derived dim objects
    derived_dim_cache: Dict[str, _DerivedDim] = {}

    def apply_fixes(path, dim, dummy):
        if dim is None or isinstance(dim, int):  # not dynamic
            return dim
        elif dim.__name__ in shape_fixes:  # directly fix
            fix = shape_fixes[dim.__name__]
            if isinstance(fix, sympy.Expr):  # now derived or related
                if str(fix) in derived_dim_cache:
                    return derived_dim_cache[str(fix)]
                else:
                    symbol = next(iter(fix.free_symbols))
                    # try to locate symbol
                    if symbol.name in shape_fixes:  # type: ignore[attr-defined]
                        root = shape_fixes[symbol.name]  # type: ignore[attr-defined]
                    else:
                        assert symbol.name in name_to_dim  # type: ignore[attr-defined]
                        root = name_to_dim[symbol.name]  # type: ignore[attr-defined]
                    # figure out value of fix
                    modulus, remainder = sympy.polys.polytools.div(fix, symbol)
                    dim = root
                    if modulus != 1:
                        dim = int(modulus) * dim
                    if remainder != 0:
                        dim = dim + int(remainder)
                    derived_dim_cache[str(fix)] = dim
                    return dim
            else:
                return fix
        elif isinstance(dim, _DerivedDim) and dim.root.__name__ in shape_fixes:  # type: ignore[attr-defined]
            if dim.__name__ in derived_dim_cache:
                return derived_dim_cache[dim.__name__]
            else:  # evaluate new derived value based on root
                _dim = dim.fn(shape_fixes[dim.root.__name__])  # type: ignore[attr-defined]
                derived_dim_cache[dim.__name__] = _dim
                return _dim
        return dim  # unchanged dim

    return _tree_map_with_path(apply_fixes, dynamic_shapes, dynamic_shapes)
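
# Usage sketch (assuming `M`, `x`, and `dx` are user-defined; export surfaces suggested
# fixes in the message of the error it raises, which this helper parses):
#
#     dynamic_shapes = {"x": (dx,)}
#     try:
#         ep = torch.export.export(M(), (x,), dynamic_shapes=dynamic_shapes)
#     except torch._dynamo.exc.UserError as exc:
#         new_shapes = refine_dynamic_shapes_from_suggested_fixes(exc.msg, dynamic_shapes)
#         ep = torch.export.export(M(), (x,), dynamic_shapes=new_shapes)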