1# mypy: allow-untyped-defs 2import contextlib 3import inspect 4import logging 5from collections import defaultdict 6from typing import Any, Callable, Dict, List, Tuple, TYPE_CHECKING, Union 7 8import torch 9import torch.utils._pytree as pytree 10from torch._dynamo.source import ( 11 AttrSource, 12 GetItemSource, 13 LocalSource, 14 TensorProperty, 15 TensorPropertySource, 16) 17from torch._dynamo.variables.builder import TrackedFake 18from torch._export.passes.add_runtime_assertions_for_constraints_pass import InputDim 19from torch._export.passes.lift_constants_pass import ConstantAttrMap 20from torch._guards import Source 21from torch._library.fake_class_registry import FakeScriptObject 22from torch._subclasses.fake_tensor import FakeTensorMode 23from torch.export import Constraint 24from torch.export.dynamic_shapes import ( 25 _check_dynamic_shapes, 26 _combine_args, 27 _DimHint, 28 _process_dynamic_shapes, 29 _transform_shapes_for_default_dynamic, 30 _tree_map_with_path, 31) 32from torch.export.graph_signature import CustomObjArgument 33from torch.fx.experimental import _config as config 34from torch.fx.experimental.symbolic_shapes import ( 35 _find_user_code_frame, 36 _suggest_fixes_for_data_dependent_error_non_strict, 37 ConstraintViolationError, 38 DimDynamic, 39 EqualityConstraint, 40 GuardOnDataDependentSymNode, 41 ShapeEnv, 42 StatelessSymbolicContext, 43 ValueRanges, 44) 45from torch.utils._pytree import ( 46 GetAttrKey, 47 KeyPath, 48 MappingKey, 49 SequenceKey, 50 tree_map_with_path, 51) 52 53 54if TYPE_CHECKING: 55 from sympy import Symbol 56 57 58log = logging.getLogger(__name__) 59 60 61def key_path_to_source(kp: KeyPath) -> Source: 62 """ 63 Given a key path, return the source for the key path. 64 """ 65 source: Source = LocalSource("args") 66 for k in kp: 67 if isinstance(k, SequenceKey): 68 source = GetItemSource(source, k.idx) 69 elif isinstance(k, MappingKey): 70 source = GetItemSource(source, k.key) 71 elif isinstance(k, GetAttrKey): 72 source = AttrSource(source, k.name) 73 else: 74 raise ValueError(f"Unknown KeyEntry {k}") 75 76 return source 77 78 79def _is_constant_argument(t): 80 return t is None or isinstance(t, (int, float, bool, str)) 81 82 83def fakify( 84 mode: FakeTensorMode, 85 kp: KeyPath, 86 t: Any, 87 t_constraints: Dict[int, Dict[int, Constraint]], 88 sources: Dict[Tuple[int, int], List[Source]], 89): 90 source = key_path_to_source(kp) 91 if _is_constant_argument(t) or isinstance(t, torch.ScriptObject): 92 return t 93 94 if not isinstance(t, torch.Tensor): 95 raise ValueError(f"Unsupported input type {type(t)}") 96 n_dims = len(t.shape) 97 symbolic_context = StatelessSymbolicContext( 98 dynamic_sizes=[DimDynamic.DYNAMIC] * n_dims, 99 constraint_sizes=[None] * n_dims, 100 ) 101 t_id = id(t) 102 assert mode.shape_env is not None 103 if t_id in t_constraints: 104 for i, constraint in t_constraints[t_id].items(): 105 symbolic_context.constraint_sizes[i] = constraint.constraint_range 106 src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i) 107 sources[(t_id, i)].append(src) 108 mode.shape_env.source_name_to_debug_name[src.name()] = constraint.name # type: ignore[assignment] 109 fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context) 110 mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context)) # type: ignore[union-attr] 111 return fake 112 113 114def make_fake_inputs( 115 nn_module, 116 args, 117 kwargs, 118 dynamic_shapes, 119 _is_torch_jit_trace=False, 120 allow_complex_guards_as_runtime_asserts=False, 121): 122 """ 123 Given an nn module, example inputs, and constraints, return a new fake mode, 124 fake inputs created in that mode whose dynamic shape dimensions are constrained 125 by the given ranges, and sources for pairs of dynamic shape dimensions that are 126 constrained to be equal. 127 """ 128 # TODO(avik): refactor Dynamo to avoid duplication of the following code 129 # between non-strict and strict. 130 # Specifically, here (non-strict) we do the following pre-tracing steps: 131 # - Fakify inputs. 132 # - Process input shape equalities. 133 # In strict, these steps are spread across multiple files: 134 # - output_graph.py fakifies inputs. 135 # - [post-tracing] guards.py processes input shape equalities. 136 137 combined_args = _combine_args(nn_module, args, kwargs) 138 _check_dynamic_shapes(combined_args, dynamic_shapes) 139 transformed_dynamic_shapes = _transform_shapes_for_default_dynamic( 140 combined_args, dynamic_shapes 141 ) 142 constraints = _process_dynamic_shapes(combined_args, transformed_dynamic_shapes) 143 t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict) 144 for constraint in constraints: 145 t_constraints[constraint.t_id][constraint.dim] = constraint 146 147 context = torch._guards.TracingContext.try_get() 148 if context is not None: 149 # This occurs when we are exporting within dynamo. There already exists 150 # a toplevel TracingContext with a fake mode, so we do not want to 151 # create another fake mode. 152 fake_mode = context.fake_mode 153 elif not _is_torch_jit_trace: 154 code = nn_module.forward.__code__ 155 co_fields = { 156 "co_name": code.co_name, 157 "co_filename": code.co_filename, 158 "co_firstlineno": code.co_firstlineno, 159 } 160 fake_mode = FakeTensorMode( 161 shape_env=ShapeEnv( 162 tracked_fakes=[], 163 co_fields=co_fields, 164 prefer_deferred_runtime_asserts_over_guards=True, 165 allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, 166 ), 167 allow_non_fake_inputs=True, 168 export=True, 169 ) 170 else: 171 fake_mode = FakeTensorMode( 172 shape_env=ShapeEnv( 173 tracked_fakes=[], 174 prefer_deferred_runtime_asserts_over_guards=True, 175 allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, 176 ), 177 allow_non_fake_inputs=True, 178 ) 179 if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None: 180 raise ValueError( 181 "Detected fake_mode does not have a shape_env with tracked fakes. " 182 "If you constructed the module under a FakeTensorMode, " 183 "please initialize it like: FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))" 184 ) 185 186 with fake_mode: 187 # FIXME(ycao) ScriptMethod doesn't have signature, I am using an empty one to unblock 188 if not _is_torch_jit_trace: 189 original_signature = inspect.signature(nn_module.forward) 190 else: 191 original_signature = None 192 sources: Dict[Tuple[int, int], List[Source]] = defaultdict(list) 193 fake_args, fake_kwargs = tree_map_with_path( 194 lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources), 195 (args, kwargs), 196 ) 197 198 names: Dict[str, Tuple[int, int]] = {} 199 source_pairs: List[Tuple[Source, Source]] = [] 200 derived_equalities: List[Tuple[Source, Union[Source, Symbol], Callable]] = [] 201 phantom_symbols: Dict[str, Symbol] = {} 202 for constraint in constraints: 203 torch.export.dynamic_shapes._process_equalities( 204 constraint, 205 lambda t_id, dim: sources[(t_id, dim)], 206 fake_mode.shape_env, 207 names, 208 source_pairs, 209 derived_equalities, 210 phantom_symbols, 211 ) 212 213 equalities_inputs = EqualityConstraint( 214 source_pairs=source_pairs, 215 derived_equalities=derived_equalities, 216 phantom_symbols=list(phantom_symbols.values()), 217 warn_only=False, 218 ) 219 return ( 220 fake_mode, 221 fake_args, 222 fake_kwargs, 223 equalities_inputs, 224 original_signature, 225 transformed_dynamic_shapes, 226 ) 227 228 229def _flatten_dynamic_shapes( 230 combined_args: Dict[str, Any], 231 dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]], 232) -> List[Any]: 233 flat_shapes = [] 234 235 def _tree_map_helper(path, t, shape): 236 nonlocal flat_shapes 237 flat_shapes.append(shape) 238 239 _tree_map_with_path(_tree_map_helper, combined_args, dynamic_shapes) 240 return flat_shapes 241 242 243def produce_guards_and_solve_constraints( 244 fake_mode: FakeTensorMode, 245 gm: torch.fx.GraphModule, 246 dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], 247 equalities_inputs: EqualityConstraint, 248 original_signature: inspect.Signature, 249 _is_torch_jit_trace=False, 250): 251 """ 252 Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions, 253 and a graph module, produce guards on the fake mode's shape env (raising constraint 254 violations if any), solve (to suggest simplifications or fixes). 255 Dynamo already performs this, so this is for non-strict mode. 256 257 Additional inputs: 258 equalities_inputs: the equality constraints to use for guards 259 original_signature: the signature of the forward method 260 """ 261 shape_env = fake_mode.shape_env 262 assert shape_env is not None 263 assert shape_env.tracked_fakes is not None 264 265 placeholders = [tf.fake for tf in shape_env.tracked_fakes] 266 sources = [tf.source for tf in shape_env.tracked_fakes] 267 input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes] 268 constraint_violation_error = None 269 try: 270 shape_env.produce_guards( 271 placeholders, 272 sources, 273 input_contexts=input_contexts, 274 equalities_inputs=equalities_inputs, 275 ignore_static=False, 276 ) 277 except ConstraintViolationError as e: 278 constraint_violation_error = e 279 280 shape_env.frozen = True 281 dim_constraints = shape_env.dim_constraints 282 if dim_constraints is None: 283 # Expected when shape_env.produce_guards throws an early constraint violation error. 284 # There is nothing to solve for in this case. 285 # TODO(avik): Maybe record the constraint violation error instead and replay later? 286 assert constraint_violation_error 287 raise constraint_violation_error 288 dim_constraints.solve() 289 forced_specializations = dim_constraints.forced_specializations() 290 if not _is_torch_jit_trace: 291 msg = dim_constraints.prettify_results( 292 original_signature, 293 dynamic_shapes, 294 constraint_violation_error, 295 forced_specializations, 296 ) 297 else: 298 # FIXME(ycao): This is a hack to get around missing signature from ScriptMethod 299 msg = "dummy constraint violation message" 300 if constraint_violation_error: 301 constraint_violation_error.args = (constraint_violation_error.args[0] + msg,) 302 elif forced_specializations: 303 constraint_violation_error = ConstraintViolationError(msg) 304 if constraint_violation_error: 305 raise constraint_violation_error 306 307 308def make_constraints( 309 fake_mode: FakeTensorMode, 310 gm: torch.fx.GraphModule, 311 combined_args: Dict[str, Any], 312 dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], 313 num_lifted_inputs: int, 314): 315 """ 316 Given a fake mode's shape env and user-specified dynamic shapes, 317 return the resulting range constraints and equality constraints. 318 319 Additional args: 320 num_lifted_inputs: the number of non-user-input placeholder nodes in the graph 321 (used only to enumerate the user-input nodes) 322 """ 323 324 shape_env = fake_mode.shape_env 325 assert shape_env is not None 326 inline_constraints = gm.meta.get("inline_constraints", []) 327 range_constraints = { 328 symbol: inline_constraints[symbol] for symbol in inline_constraints 329 } 330 if not dynamic_shapes: 331 return range_constraints 332 333 # get individual dynamic shapes spec for each input 334 if not isinstance(dynamic_shapes, dict): 335 assert isinstance(dynamic_shapes, (tuple, list)) 336 combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] 337 flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes) 338 339 # check number of shapes vs. number of inputs 340 num_placeholders = [node.op == "placeholder" for node in gm.graph.nodes].count(True) 341 assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs 342 343 input_dims = defaultdict(list) 344 free_symbols = set() 345 for input_index, node in enumerate(gm.graph.nodes): 346 if input_index < num_lifted_inputs or node.op != "placeholder": 347 continue 348 if _is_constant_argument(node.meta["val"]) or isinstance( 349 node.meta["val"], CustomObjArgument 350 ): 351 continue 352 shape_spec = flat_dynamic_shapes[input_index - num_lifted_inputs] 353 for i, d in enumerate(node.meta["val"].shape): 354 if isinstance(d, torch.SymInt) and not d.node.expr.is_number: 355 # Look up the range constraint for the symbol corresponding to this shape dimension 356 # and store it indexed by the symbolic expression corresponding to it. 357 # NOTE(avik): Use node._expr instead of node.expr for the lookup here because 358 # we want the symbol, not its replacement, which could be an expression. Maybe 359 # there's a better way to do this, e.g., by (re)computing value ranges for expressions? 360 dim = shape_spec[i] if shape_spec else None 361 if dim is None or isinstance(dim, _DimHint): 362 range_constraints[d.node.expr] = shape_env.var_to_range[ 363 d.node._expr 364 ] 365 else: 366 range_constraints[d.node.expr] = ValueRanges( 367 lower=dim.min, upper=dim.max 368 ) 369 input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i)) 370 free_symbols.update(d.node.expr.free_symbols) 371 372 for symbol in free_symbols: 373 if symbol not in range_constraints: 374 # Placeholders can have symbolic shapes that are derived expressions. 375 # The above code will record direct range constraints for them 376 # so that we can do runtime assertions. In addition, for serde checks 377 # we want to record range constraints for their root symbols. 378 range_constraints[symbol] = shape_env.var_to_range[symbol] 379 380 return range_constraints 381 382 383def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap: 384 """Search the module hierarchy, gathering up all tensor and ScriptObject constants. 385 386 Returns a dictionary mapping hash(value) to the name of the constant. We 387 have to abuse `hash` here unfortunately, see: [ScriptObject hash]. 388 """ 389 constants = ConstantAttrMap() 390 buffers_parameters = set(m.buffers()) 391 buffers_parameters.update(m.parameters()) 392 393 def inner(m: torch.nn.Module, prefix_atoms: List[str], constants): 394 for k, v in m.__dict__.items(): 395 if isinstance( 396 v, 397 ( 398 torch.Tensor, 399 torch.ScriptObject, 400 FakeScriptObject, 401 ), 402 ): 403 if v in buffers_parameters: 404 # filter out buffers and parameters, leaving only constants 405 continue 406 407 fqn = ".".join(prefix_atoms + [k]) 408 constants.add(v, fqn) 409 for k, v in m.named_children(): 410 inner(v, prefix_atoms + [k], constants) 411 412 inner(m, [], constants) 413 return constants 414 415 416@contextlib.contextmanager 417def _fakify_script_objects( 418 mod: torch.nn.Module, 419 args: Tuple[Any], 420 kwargs: Dict[Any, Any], 421 fake_mode: torch._subclasses.fake_tensor.FakeTensorMode, 422): 423 # This context manager is used to fakify script objects into FakeScriptObject. 424 # Inputs: 425 # mod: the module to be exported, it (and its recursive submodules)'s script object attrs haven't been fakified. 426 # args, kwargs: the args and kwargs inputs for mod, script object inputs haven't been fakified. 427 # fake_mode: the fake mode to be used for fakifying script objects. It's the same mode that fakify input tensors. 428 # 429 # Returns: 430 # mod: the patched module, its (and its recursive submodules) script object attrs have been fakified. 431 # fake_args, fake_kwargs: new fakified args and kwargs. 432 # Script object inputs have been fakified. Don't touch the tensors. 433 # fake_constant_attrs: a new map from FakeScriptObject to the fqn of the original script object. 434 # fake_to_real: a mapping between FakeScriptObject and the original script object in order to un-do the patching. 435 436 constant_attrs: ConstantAttrMap = _gather_constant_attrs(mod) 437 assert not any( 438 isinstance(obj, FakeScriptObject) for obj in constant_attrs.values() 439 ), "Mod shouldn't contain any FakeScriptObject." 440 assert not pytree.tree_any( 441 lambda obj: isinstance(obj, FakeScriptObject), (args, kwargs) 442 ), "args and kwargs shouldn't contain any FakeScriptObject." 443 444 patched_attr = {} 445 fake_constant_attrs = ConstantAttrMap() 446 fake_to_real = {} 447 448 def _maybe_fakify_obj(obj): 449 fake_obj = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, obj) 450 fake_to_real[fake_obj] = obj 451 return fake_obj 452 453 def _leaf_mod_and_attr( 454 mod: torch.nn.Module, attr_fqn: str 455 ) -> Tuple[torch.nn.Module, str]: 456 *prefix_attr, last_attr = attr_fqn.split(".") 457 cur_mod = mod 458 for attr in prefix_attr: 459 cur_mod = getattr(cur_mod, attr) 460 return cur_mod, last_attr 461 462 try: 463 for obj, fqns in constant_attrs.items(): 464 if isinstance(obj, torch.ScriptObject): 465 fake_script_obj = _maybe_fakify_obj(obj) 466 for fqn in fqns: 467 cur_mod, attr = _leaf_mod_and_attr(mod, fqn) 468 assert obj is getattr(cur_mod, attr) 469 setattr(cur_mod, attr, fake_script_obj) 470 fake_constant_attrs.add(fake_script_obj, fqn) 471 patched_attr[fqn] = obj 472 else: 473 for fqn in fqns: 474 fake_constant_attrs.add(obj, fqn) 475 476 fake_args, fake_kwargs = pytree.tree_map_only( 477 torch.ScriptObject, _maybe_fakify_obj, (args, kwargs) 478 ) 479 yield (mod, fake_args, fake_kwargs, fake_constant_attrs, fake_to_real) 480 finally: 481 for fqn, orig_obj in patched_attr.items(): 482 cur_mod, attr = _leaf_mod_and_attr(mod, fqn) 483 setattr(cur_mod, attr, orig_obj) 484 485 486class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): 487 """ 488 1. Handles data-dependent errors raised by torch function calls in non-strict. 489 490 Any data-dependent error is due to some condition on unbacked symints 491 that cannot be resolved. A mechanical way of fixing the error is to use 492 a torch._check() call to assert either that condition or its negation. 493 The handler suggests these options as code and points to the location 494 of the torch function call that raised the error as part of the error 495 message shown to the user, who can then simply select and copy-paste 496 a suggested fix at that location. 497 498 NOTE: Not all data-dependent errors are raised by torch function calls. 499 In particular, conditions on unbacked symints can appear outside such 500 calls, and as such are not handled here. 501 502 2. Handles line-of-code logging for each torch function call in non-strict. 503 504 Usage: TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC=1 TORCH_LOGS="+export" ... 505 """ 506 507 def __torch_function__(self, func, types, args=(), kwargs=None): 508 kwargs = kwargs or {} 509 if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc: 510 frame = _find_user_code_frame() 511 if frame is not None: 512 log.debug( 513 "%s called at %s:%s in %s", 514 func.__qualname__, 515 frame.f_code.co_filename, 516 frame.f_lineno, 517 frame.f_code.co_name, 518 ) 519 try: 520 return func(*args, **kwargs) 521 except GuardOnDataDependentSymNode as e: 522 _suggest_fixes_for_data_dependent_error_non_strict(e) 523 raise 524