xref: /aosp_15_r20/external/pytorch/torch/_export/serde/dynamic_shapes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import dataclasses
2from typing import Any, Dict, List, Optional, Tuple, Union
3
4import torch
5from torch._dynamo.exc import UserError, UserErrorType
6from torch.export.dynamic_shapes import (
7    _check_dynamic_shapes,
8    _DerivedDim,
9    _Dim,
10    _DimHint,
11    _tree_map_with_path,
12    Dim,
13)
14from torch.utils._pytree import tree_map
15
16from .serialize import _dataclass_to_dict
17
18
@dataclasses.dataclass
class RootDim:
    """
    Serialized form of a root _Dim object: its value range, plus the names of
    every derived dim (e.g. "dx + 1") that was built on top of it.
    """

    min: int  # lower bound of the dim's range
    max: Optional[int]  # upper bound; None is used for an unbounded dim
    derived: List[str]  # names of dims derived from this root
28
29
@dataclasses.dataclass
class DynamicShapesSpec:
    """
    This stores a dynamic_shapes spec for de/serialization.

    It pairs the standardized dynamic_shapes tree, whose dynamic dims are
    referenced by name, with the range and derived-dim information needed to
    rebuild those dims.
    """

    # Standardized dynamic_shapes tree, or None when no spec was given. The
    # tuple form may hold any number of entries, hence Tuple[Any, ...].
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any, ...], List[Any], None]
    # Root dim name -> its range and the names of dims derived from it.
    dims: Dict[str, RootDim]
38
39
def _postprocess_serialized_shapes(
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
    dims: Dict[str, Dict[str, Union[int, List[str], None]]],
    to_dict: Optional[bool] = False,
) -> Union[DynamicShapesSpec, Dict[str, Any]]:
    """
    Package serialized shapes and tracked root dims into a DynamicShapesSpec.

    Root dims are emitted in name order and each root's derived names are
    sorted, so the output is deterministic. A max of ``int_oo`` (unbounded) is
    stored as None. When ``to_dict`` is truthy the spec is converted with
    ``_dataclass_to_dict``; otherwise the dataclass itself is returned.
    """
    from torch.utils._sympy.numbers import int_oo

    root_dims: Dict[str, RootDim] = {}
    for name in sorted(dims):
        entry = dims[name]
        upper = entry["max"]
        root_dims[name] = RootDim(
            min=entry["min"],  # type: ignore[arg-type]
            max=None if upper is int_oo else upper,  # type: ignore[arg-type]
            derived=sorted(entry["derived"]),  # type: ignore[arg-type]
        )

    spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=root_dims)
    return _dataclass_to_dict(spec) if to_dict else spec
63
64
def _dump_dynamic_shapes(
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any, ...], List[Any], None],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    to_dict: Optional[bool] = False,
) -> Union[DynamicShapesSpec, Dict[str, Any]]:
    """
    Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec.
    Returns a DynamicShapesSpec dataclass (or, with ``to_dict=True``, its dictionary form)
    containing 2 fields, "dynamic_shapes" and "dims".
    Uses args & kwargs to distinguish between tensor-level and dim-level specs (only for Nones).

    Args:
        dynamic_shapes: the dynamic_shapes spec as passed to export(), or None.
            A dict spec is matched by position, in insertion order, against
            ``args`` followed by ``kwargs.values()``. Its keys are not looked up.
        args: positional example inputs.
        kwargs: keyword example inputs, if any.
        to_dict: if truthy, return the dictionary form instead of a DynamicShapesSpec.
            This applies to the ``dynamic_shapes=None`` case as well.

    Errors raised by ``_check_dynamic_shapes`` for an invalid spec propagate.

    dynamic_shapes: A pytree structure mirroring the dynamic_shapes input to export():
        - Each tensor input is represented with a list of values, non-tensor inputs with None.
        - dynamic dimensions (i.e. symbols) in tensors and Dim enums are represented with strings.
        - static dimensions are represented with ints.

    dims: A dictionary mapping each symbol name to the min/max range and derived dim names.

    For example:
    ```
    dx = Dim("dx", min=4, max=16)
    dy = dx + 1

    inputs = (
        [
            torch.randn(4, 4),
            torch.randn(5, 4),
        ],
        torch.randn(4),
        torch.randn(4, 4),
        "hello",
    )
    dynamic_shapes = {
        "a": [
            (dx, 4),
            (dy, 4),
        ],
        "b": (Dim.STATIC,),
        "c": None,
        "d": None,
    }
    out = _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True)
    ```
    would generate the following output:
    ```
    {
        'dynamic_shapes': (
            [
                ['dx', 4],
                ['dx + 1', 4],
            ],
            ['_DimHint.STATIC'],
            ['_DimHint.STATIC', '_DimHint.STATIC'],
            None,
        ),
        'dims': {
            'dx': {
                'min': 4,
                'max': 16,
                'derived': ['dx + 1'],
            },
        },
    }
    ```
    """
    if dynamic_shapes is None:
        # Honor ``to_dict`` here as well. Previously a dict was always returned,
        # which _load_dynamic_shapes(spec) (from_dict=False) rejects.
        if to_dict:
            return {"dynamic_shapes": None, "dims": {}}
        return DynamicShapesSpec(dynamic_shapes=None, dims={})

    # Root dim name -> {"min", "max", "derived": set of derived dim names}.
    # Populated as a side effect of _track_dim_from_dims below.
    dims: Dict[str, Dict[str, Any]] = {}

    def _standardize_shapes(path, tensor, shape):  # type: ignore[no-untyped-def]
        """
        Helps standardize the dynamic_shapes tree structure we serialize,
        returning lists for each tensor shape, handling tensor-level Nones.
        Unspecified (None) dims are filled in with the tensor's concrete size.
        """
        if not isinstance(tensor, torch.Tensor):
            return None
        if shape is None:
            # A tensor-level None marks every dim of the tensor as static.
            return [Dim.STATIC] * len(tensor.shape)  # type: ignore[attr-defined]

        out = []
        if isinstance(shape, dict):
            for i, s in enumerate(tensor.shape):
                dim_spec = shape.get(i)
                out.append(s if dim_spec is None else dim_spec)
        else:
            assert isinstance(shape, (tuple, list))
            for i, s in enumerate(tensor.shape):
                out.append(s if shape[i] is None else shape[i])
        return out

    def _track_dim_from_dims(
        val: Union[None, int, _DimHint, _Dim]
    ) -> Union[None, int, str]:
        """
        Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec,
        and returns the serialized leaf (None, an int, or a string name).
        """
        if val is None or isinstance(val, int):  # non-tensor input or static
            return val
        if isinstance(val, _DimHint):  # store enum as string, e.g. "_DimHint.STATIC"
            return val.__class__.__name__ + "." + val.name

        assert isinstance(val, _Dim)

        # track root dim; the first occurrence of a name fixes its range
        root = val.root if isinstance(val, _DerivedDim) else val  # type: ignore[attr-defined]
        if root.__name__ not in dims:
            dims[root.__name__] = {
                "min": root.min,
                "max": root.max,
                "derived": set(),
            }

        # track derived dims
        if isinstance(val, _DerivedDim):
            dims[root.__name__]["derived"].add(val.__name__)

        return val.__name__

    # convert to tuple of specs, for each arg/kwarg
    kwargs = kwargs or {}
    if isinstance(dynamic_shapes, dict):
        dynamic_shapes = dynamic_shapes.values()  # type: ignore[assignment]
    dynamic_shapes = tuple(dynamic_shapes)
    combined_args = tuple(args) + tuple(kwargs.values())

    # run same check when we're processing shapes for export - is this too lazy?
    _check_dynamic_shapes(dict(enumerate(combined_args)), dynamic_shapes)  # type: ignore[arg-type]

    tree_shapes = _tree_map_with_path(
        _standardize_shapes, combined_args, dynamic_shapes, tree_name="inputs"
    )
    serialized_shapes = tree_map(_track_dim_from_dims, tree_shapes)
    return _postprocess_serialized_shapes(serialized_shapes, dims, to_dict=to_dict)
198
199
def _load_dynamic_shapes(
    spec: Union[DynamicShapesSpec, Dict[str, Any]],
    from_dict: Optional[bool] = False,
) -> Union[Dict[str, Any], Tuple[Any, ...], List[Any], None]:
    """
    Utility function for dynamic shapes serialization.
    Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export().

    Args:
        spec: a DynamicShapesSpec or, with ``from_dict=True``, its dictionary form
            (keys "dynamic_shapes" and "dims").
        from_dict: whether ``spec`` is the dictionary form. The dictionary form is
            validated before use.

    Returns:
        The dynamic_shapes tree with string leaves replaced by the matching Dim,
        derived dim, or _DimHint. Ints and None are passed through unchanged.
        Returns None if the serialized dynamic_shapes is None.

    Raises:
        UserError(INVALID_INPUT): if the spec is malformed, a dim name is not a
            symbol, a derived expression is unsupported, or a leaf references an
            untracked dim.
    """
    import sympy

    from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence

    if from_dict:
        if not isinstance(spec, dict):
            raise UserError(
                UserErrorType.INVALID_INPUT,
                f"With from_dict=True, expected `spec` to be a dict, got {type(spec)}",
            )
        # Set comparison: unlike sorted(), this cannot raise TypeError on mixed key types.
        if set(spec.keys()) != {"dims", "dynamic_shapes"}:
            raise UserError(
                UserErrorType.INVALID_INPUT,
                "With from_dict=True, expected `spec` to have keys `dims` and `dynamic_shapes`, "
                f"instead found {spec.keys()}",
            )
        if not isinstance(spec["dims"], dict):
            raise UserError(
                UserErrorType.INVALID_INPUT,
                f"With from_dict=True, expected `spec['dims']` to be a dict, got {type(spec['dims'])}",
            )
        dims = {}
        for k, v in spec["dims"].items():
            if not isinstance(k, str):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected `spec['dims']` keys to be strings for symbols, got key {type(k)}",
                )
            if not isinstance(v, dict):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected `spec['dims']` values to be dicts, got {k}: {v}",
                )
            if set(v.keys()) != {"derived", "max", "min"}:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected `spec['dims']` values to have keys `derived`, `max`, and `min`, "
                    f"instead found {v.keys()}",
                )
            if not isinstance(v["min"], int):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dims in `spec['dims']` to map `min` to an int, got {k}: {v['min']}",
                )
            # None is the serialized form of an unbounded max, so it must be accepted.
            if v["max"] is not None and not isinstance(v["max"], int):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dims in `spec['dims']` to map `max` to an int or None, got {k}: {v['max']}",
                )
            if not isinstance(v["derived"], list) or any(
                not isinstance(d, str) for d in v["derived"]
            ):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    "Expected dims in `spec['dims']` to map `derived` to a list of derived expressions, "
                    f"got {k}: {v['derived']}",
                )
            dims[k] = RootDim(**v)
        dynamic_shapes = spec["dynamic_shapes"]
    else:
        if not isinstance(spec, DynamicShapesSpec):
            raise UserError(
                UserErrorType.INVALID_INPUT,
                f"Expected `spec` to be a DynamicShapesSpec, got {type(spec)}",
            )
        dims = spec.dims
        dynamic_shapes = spec.dynamic_shapes

    if dynamic_shapes is None:
        return None

    # Maps both root dim names and derived expression strings to Dim objects.
    dim_cache = {}
    for name, info in dims.items():
        # NOTE(security): sympy.sympify evaluates its string argument (it uses
        # eval), so `spec` must come from a trusted source.
        symbol = sympy.sympify(name)
        if not isinstance(symbol, sympy.Symbol):
            raise UserError(
                UserErrorType.INVALID_INPUT,
                f"Expected `spec['dims']` keys to be symbols, got {name}",
            )
        dim_cache[name] = Dim(name, min=info.min, max=info.max)  # cache root dim
        for _expr in info.derived:
            expr = sympy.sympify(_expr)
            if len(expr.free_symbols) != 1 or symbol not in expr.free_symbols:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected derived expressions in `spec['dims']` to have {name} as the only free symbol, got {expr}",
                )
            if not _is_supported_equivalence(expr):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected derived expressions to be linear expressions, got {expr}",
                )
            # Rebuild the derived dim as (modulus * root + remainder).
            modulus, remainder = sympy.polys.polytools.div(expr, symbol)
            ddim = dim_cache[name]
            if modulus != 1:
                ddim = int(modulus) * ddim
            if remainder != 0:
                ddim = ddim + int(remainder)
            dim_cache[_expr] = ddim  # cache derived dims

    def deserialize_shape(
        val: Union[None, int, str]
    ) -> Union[None, int, _Dim, _DimHint]:
        """Map one serialized leaf back to None, an int, a _DimHint, or a cached Dim."""
        if val is None or isinstance(val, int):
            return val
        elif val == "_DimHint.AUTO":
            return _DimHint.AUTO
        elif val == "_DimHint.STATIC":
            return _DimHint.STATIC
        if not isinstance(val, str):
            raise UserError(
                UserErrorType.INVALID_INPUT,
                "Expected leaves in `spec['dynamic_shapes']` to be ints, None, Dim.AUTO/STATIC, symbols, "
                f" or derived expressions, got {val}",
            )
        if val not in dim_cache:
            raise UserError(
                UserErrorType.INVALID_INPUT,
                "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, "
                f"got {val} which is not in {dims.keys()}",
            )
        return dim_cache[val]

    return tree_map(deserialize_shape, dynamic_shapes)
322