import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch._dynamo.exc import UserError, UserErrorType
from torch.export.dynamic_shapes import (
    _check_dynamic_shapes,
    _DerivedDim,
    _Dim,
    _DimHint,
    _tree_map_with_path,
    Dim,
)
from torch.utils._pytree import tree_map

from .serialize import _dataclass_to_dict


@dataclasses.dataclass
class RootDim:
    """
    This represents a _Dim object.
    """

    # Lower bound of the root dim's range.
    min: int
    # Upper bound of the root dim's range; None stands for "unbounded"
    # (int_oo is mapped to None when the spec is post-processed).
    max: Union[int, None]
    # Names (string expressions, e.g. "dx + 1") of dims derived from this root.
    derived: List[str]


@dataclasses.dataclass
class DynamicShapesSpec:
    """
    This stores a dynamic_shapes spec for de/serialization.
    """

    # Serialized pytree of per-input shape specs (see _dump_dynamic_shapes).
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None]
    # Root symbol name -> range and derived-dim information.
    dims: Dict[str, RootDim]


def _postprocess_serialized_shapes(
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
    dims: Dict[str, Dict[str, Union[int, List[str], None]]],
    to_dict: Optional[bool] = False,
) -> Union[DynamicShapesSpec, Dict[str, Any]]:
    """
    Sorts dims and dumps to dictionary format.

    Args:
        dynamic_shapes: the already-serialized dynamic_shapes tree.
        dims: mapping from root symbol name to a dict with keys "min", "max"
            and "derived" (any sortable collection of derived dim names).
        to_dict: if truthy, return the result of `_dataclass_to_dict` on the
            spec instead of the DynamicShapesSpec itself.

    Returns:
        A DynamicShapesSpec whose `dims` are sorted by name (with each
        `derived` list sorted as well), or its dictionary form.
    """
    from torch.utils._sympy.numbers import int_oo

    dims = {
        k: RootDim(
            min=v["min"],  # type: ignore[arg-type]
            # an unbounded max (int_oo) is stored as None
            max=None if v["max"] is int_oo else v["max"],  # type: ignore[arg-type]
            derived=sorted(v["derived"]),  # type: ignore[arg-type]
        )
        for k, v in sorted(dims.items())
    }
    spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims)
    if to_dict:
        return _dataclass_to_dict(spec)
    else:
        return spec


def _dump_dynamic_shapes(
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    to_dict: Optional[bool] = False,
) -> Union[DynamicShapesSpec, Dict[str, Any]]:
    """
    Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec.
    Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims".
    Uses args & kwargs to distinguish between tensor-level and dim-level specs (only for Nones).

    dynamic_shapes: A pytree structure mirroring the dynamic_shapes input to export():
        - Each tensor input is represented with a list of values, non-tensor inputs with None.
        - dynamic dimensions (i.e. symbols) in tensors and Dim enums are represented with strings.
        - static dimensions are represented with ints.

    dims: A dictionary mapping each symbol name to the min/max range and derived dim names.

    For example:
    ```
    dx = Dim("dx", min=4, max=16)
    dy = dx + 1

    inputs = (
        [
            torch.randn(4, 4),
            torch.randn(5, 4),
        ],
        torch.randn(4),
        torch.randn(4, 4),
        "hello",
    )
    dynamic_shapes = {
        "a": [
            (dx, 4),
            (dy, 4),
        ],
        "b": (Dim.STATIC,),
        "c": None,
        "d": None,
    }
    out = _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True)
    ```
    would generate the following output:
    ```
    {
        'dynamic_shapes': (
            [
                ['dx', 4],
                ['dx + 1', 4],
            ],
            ['_DimHint.STATIC'],
            ['_DimHint.STATIC', '_DimHint.STATIC'],
            None,
        ),
        'dims': {
            'dx': {
                'min': 4,
                'max': 16,
                'derived': ['dx + 1'],
            },
        },
    }
    ```

    Args:
        dynamic_shapes: the dynamic_shapes spec to serialize, or None.
        args: positional example inputs the spec refers to.
        kwargs: keyword example inputs; their values are appended after `args`.
        to_dict: if truthy, return the dictionary form of the spec.

    Returns:
        A DynamicShapesSpec, or a dictionary with keys "dynamic_shapes" and "dims".
    """
    # root symbol name -> {"min", "max", "derived"}; filled in by _track_dim_from_dims
    dims: Dict[str, Dict[str, Any]] = {}

    def _standardize_shapes(path, tensor, shape):  # type: ignore[no-untyped-def]
        """
        Helps standardize the dynamic_shapes tree structure we serialize,
        returning lists for each tensor shape, handling tensor-level Nones.
        """
        if not isinstance(tensor, torch.Tensor):
            return None
        if shape is None:
            # tensor-level None: every dimension is marked static
            return [Dim.STATIC] * len(tensor.shape)  # type: ignore[attr-defined]

        out = []
        if isinstance(shape, dict):
            # dict specs may omit dimensions; missing/None entries fall back
            # to the concrete size from the example tensor
            for i, s in enumerate(tensor.shape):
                out.append(s if shape.get(i) is None else shape.get(i))
        else:
            assert isinstance(shape, (tuple, list))
            for i, s in enumerate(tensor.shape):
                out.append(s if shape[i] is None else shape[i])
        return out

    def _track_dim_from_dims(
        val: Union[None, int, _DimHint, _Dim]
    ) -> Union[None, int, str]:
        """
        Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec.
        """
        if val is None or isinstance(val, int):  # non-tensor input or static
            return val
        if isinstance(val, _DimHint):  # store enum as string
            return val.__class__.__name__ + "." + val.name

        assert isinstance(val, _Dim)

        # track root dim
        root = val.root if isinstance(val, _DerivedDim) else val  # type: ignore[attr-defined]
        if root.__name__ not in dims:
            dims[root.__name__] = {
                "min": root.min,
                "max": root.max,
                "derived": set(),
            }

        # track derived dims
        if isinstance(val, _DerivedDim):
            dims[root.__name__]["derived"].add(val.__name__)

        return val.__name__

    if dynamic_shapes is None:
        # NOTE(review): this returns a plain dict even when to_dict is False,
        # unlike the non-None path which then returns a DynamicShapesSpec.
        # Kept as-is so existing callers see the same return value.
        return {"dynamic_shapes": None, "dims": {}}

    # convert to tuple of specs, for each arg/kwarg
    kwargs = kwargs or {}
    if isinstance(dynamic_shapes, dict):
        dynamic_shapes = dynamic_shapes.values()  # type: ignore[assignment]
    dynamic_shapes = tuple(dynamic_shapes)
    combined_args = tuple(args) + tuple(kwargs.values())

    # run same check when we're processing shapes for export - is this too lazy?
    _check_dynamic_shapes(dict(enumerate(combined_args)), dynamic_shapes)  # type: ignore[arg-type]

    tree_shapes = _tree_map_with_path(
        _standardize_shapes, combined_args, dynamic_shapes, tree_name="inputs"
    )
    serialized_shapes = tree_map(_track_dim_from_dims, tree_shapes)
    return _postprocess_serialized_shapes(serialized_shapes, dims, to_dict=to_dict)


def _load_dynamic_shapes(
    spec: Union[DynamicShapesSpec, Dict[str, Any]],
    from_dict: Optional[bool] = False,
) -> Union[Dict[str, Any], Tuple[Any], List[Any], None]:
    """
    Utility function for dynamic shapes serialization.
    Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export().

    Args:
        spec: a DynamicShapesSpec, or (with from_dict=True) a dictionary with
            exactly the keys "dims" and "dynamic_shapes".
        from_dict: whether `spec` is in dictionary form. In that case the
            contents of `spec["dims"]` are validated before use.

    Returns:
        The dynamic_shapes tree with string leaves replaced by Dim objects /
        Dim hints, or None if the spec's dynamic_shapes is None.

    Raises:
        UserError: (INVALID_INPUT) if `spec` has the wrong type or layout, a
            dim name is not a single symbol, a derived expression is not a
            supported expression of its root symbol, or a string leaf does not
            refer to a tracked dim.
    """
    import sympy

    from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence

    if from_dict:
        if not isinstance(spec, dict):
            raise UserError(
                UserErrorType.INVALID_INPUT,
                f"With from_dict=True, expected `spec` to be a dict, got {type(spec)}",
            )
        if sorted(spec.keys()) != ["dims", "dynamic_shapes"]:
            raise UserError(
                UserErrorType.INVALID_INPUT,
                "With from_dict=True, expected `spec` to have keys `dims` and `dynamic_shapes`, "
                f"instead found {spec.keys()}",
            )
        dims = {}
        for k, v in spec["dims"].items():
            if not isinstance(k, str):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected `spec['dims']` keys to be strings for symbols, got key {type(k)}",
                )
            if sorted(v.keys()) != ["derived", "max", "min"]:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected `spec['dims']` values to have keys `derived`, `max`, and `min`, "
                    f"instead found {v.keys()}",
                )
            if not isinstance(v["min"], int):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dims in `spec['dims']` to map `min` to an int, got {k}: {v['min']}",
                )
            # None is a valid max (unbounded dim, as written by
            # _postprocess_serialized_shapes); only reject non-None non-ints.
            if v["max"] is not None and not isinstance(v["max"], int):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dims in `spec['dims']` to map `max` to an int or None, got {k}: {v['max']}",
                )
            if not isinstance(v["derived"], list) or any(
                not isinstance(d, str) for d in v["derived"]
            ):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    "Expected dims in `spec['dims']` to map `derived` to a list of derived expressions, "
                    f"got {k}: {v['derived']}",
                )
            dims[k] = RootDim(**v)
        dynamic_shapes = spec["dynamic_shapes"]
    else:
        if not isinstance(spec, DynamicShapesSpec):
            raise UserError(
                UserErrorType.INVALID_INPUT,
                f"Expected `spec` to be a DynamicShapesSpec, got {type(spec)}",
            )
        dims = spec.dims
        dynamic_shapes = spec.dynamic_shapes

    if dynamic_shapes is None:
        return None

    # name / derived-expression string -> reconstructed Dim object
    dim_cache = {}
    for name, info in dims.items():
        symbol = sympy.sympify(name)
        if not isinstance(symbol, sympy.Symbol):
            raise UserError(
                UserErrorType.INVALID_INPUT,
                f"Expected `spec['dims']` keys to be symbols, got {name}",
            )
        dim_cache[name] = Dim(name, min=info.min, max=info.max)  # cache root dim
        for _expr in info.derived:
            expr = sympy.sympify(_expr)
            if len(expr.free_symbols) != 1 or symbol not in expr.free_symbols:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected derived expressions in to have {name} as the only free symbol, got {expr}",
                )
            if not _is_supported_equivalence(expr):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected derived expressions to be linear expressions, got {expr}",
                )
            # polynomial division by the root symbol: for expr = a * symbol + b
            # this yields quotient a and remainder b
            modulus, remainder = sympy.polys.polytools.div(expr, symbol)
            # rebuild the derived dim by applying the same arithmetic to the root Dim
            ddim = dim_cache[name]
            if modulus != 1:
                ddim = int(modulus) * ddim
            if remainder != 0:
                ddim = ddim + int(remainder)
            dim_cache[_expr] = ddim  # cache derived dims

    def deserialize_shape(
        val: Union[None, int, str]
    ) -> Union[None, int, _Dim, _DimHint]:
        """
        Maps one serialized leaf back to None, an int, a Dim hint, or a cached Dim.
        """
        if val is None or isinstance(val, int):
            return val
        elif val == "_DimHint.AUTO":
            return _DimHint.AUTO
        elif val == "_DimHint.STATIC":
            return _DimHint.STATIC
        if not isinstance(val, str):
            raise UserError(
                UserErrorType.INVALID_INPUT,
                "Expected leaves in `spec['dynamic_shapes']` to be ints, None, Dim.AUTO/STATIC, symbols, "
                f" or derived expressions, got {val}",
            )
        if val not in dim_cache:
            raise UserError(
                UserErrorType.INVALID_INPUT,
                "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, "
                f"got {val} which is not in {dims.keys()}",
            )
        return dim_cache[val]

    return tree_map(deserialize_shape, dynamic_shapes)