# mypy: allow-untyped-defs
import contextlib
import inspect
import logging
from collections import defaultdict
from typing import Any, Callable, Dict, List, Tuple, TYPE_CHECKING, Union

import torch
import torch.utils._pytree as pytree
from torch._dynamo.source import (
    AttrSource,
    GetItemSource,
    LocalSource,
    TensorProperty,
    TensorPropertySource,
)
from torch._dynamo.variables.builder import TrackedFake
from torch._export.passes.add_runtime_assertions_for_constraints_pass import InputDim
from torch._export.passes.lift_constants_pass import ConstantAttrMap
from torch._guards import Source
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import Constraint
from torch.export.dynamic_shapes import (
    _check_dynamic_shapes,
    _combine_args,
    _DimHint,
    _process_dynamic_shapes,
    _transform_shapes_for_default_dynamic,
    _tree_map_with_path,
)
from torch.export.graph_signature import CustomObjArgument
from torch.fx.experimental import _config as config
from torch.fx.experimental.symbolic_shapes import (
    _find_user_code_frame,
    _suggest_fixes_for_data_dependent_error_non_strict,
    ConstraintViolationError,
    DimDynamic,
    EqualityConstraint,
    GuardOnDataDependentSymNode,
    ShapeEnv,
    StatelessSymbolicContext,
    ValueRanges,
)
from torch.utils._pytree import (
    GetAttrKey,
    KeyPath,
    MappingKey,
    SequenceKey,
    tree_map_with_path,
)


if TYPE_CHECKING:
    from sympy import Symbol


log = logging.getLogger(__name__)

def key_path_to_source(kp: KeyPath) -> Source:
    """
    Given a key path, return the source for the key path.
    """
    source: Source = LocalSource("args")
    for k in kp:
        if isinstance(k, SequenceKey):
            source = GetItemSource(source, k.idx)
        elif isinstance(k, MappingKey):
            source = GetItemSource(source, k.key)
        elif isinstance(k, GetAttrKey):
            source = AttrSource(source, k.name)
        else:
            raise ValueError(f"Unknown KeyEntry {k}")

    return source

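# Illustrative sketch (comment only): a pytree key path such as
# (SequenceKey(idx=0), GetAttrKey(name="weight")) resolves step by step to
# the Dynamo source for `args[0].weight`:
#
#     kp = (SequenceKey(idx=0), GetAttrKey(name="weight"))
#     src = key_path_to_source(kp)
#     # src.name() is roughly "L['args'][0].weight"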

def _is_constant_argument(t):
    return t is None or isinstance(t, (int, float, bool, str))


def fakify(
    mode: FakeTensorMode,
    kp: KeyPath,
    t: Any,
    t_constraints: Dict[int, Dict[int, Constraint]],
    sources: Dict[Tuple[int, int], List[Source]],
):
    """
    Fakify a single input: constants and ScriptObjects pass through unchanged,
    while tensors are converted to fake tensors with fully dynamic sizes,
    applying any user-specified constraints and recording per-dimension
    sources for later guard processing.
    """
    source = key_path_to_source(kp)
    if _is_constant_argument(t) or isinstance(t, torch.ScriptObject):
        return t

    if not isinstance(t, torch.Tensor):
        raise ValueError(f"Unsupported input type {type(t)}")
    n_dims = len(t.shape)
    symbolic_context = StatelessSymbolicContext(
        dynamic_sizes=[DimDynamic.DYNAMIC] * n_dims,
        constraint_sizes=[None] * n_dims,
    )
    t_id = id(t)
    assert mode.shape_env is not None
    if t_id in t_constraints:
        for i, constraint in t_constraints[t_id].items():
            symbolic_context.constraint_sizes[i] = constraint.constraint_range
            src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i)
            sources[(t_id, i)].append(src)
            mode.shape_env.source_name_to_debug_name[src.name()] = constraint.name  # type: ignore[assignment]
    fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context)
    mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context))  # type: ignore[union-attr]
    return fake

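# Minimal usage sketch (comment only; values are hypothetical, and the mode is
# assumed to have a ShapeEnv with tracked fakes, as in make_fake_inputs below):
#
#     mode = FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))
#     fake = fakify(
#         mode, (SequenceKey(idx=0),), torch.randn(2, 3), {}, defaultdict(list)
#     )
#     # `fake` is a FakeTensor whose sizes are symbolic (DimDynamic.DYNAMIC).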

def make_fake_inputs(
    nn_module,
    args,
    kwargs,
    dynamic_shapes,
    _is_torch_jit_trace=False,
    allow_complex_guards_as_runtime_asserts=False,
):
    """
    Given an nn module, example inputs, and constraints, return a new fake mode,
    fake inputs created in that mode whose dynamic shape dimensions are constrained
    by the given ranges, and sources for pairs of dynamic shape dimensions that are
    constrained to be equal.
    """
    # TODO(avik): refactor Dynamo to avoid duplication of the following code
    # between non-strict and strict.
    # Specifically, here (non-strict) we do the following pre-tracing steps:
    #   - Fakify inputs.
    #   - Process input shape equalities.
    # In strict, these steps are spread across multiple files:
    #   - output_graph.py fakifies inputs.
    #   - [post-tracing] guards.py processes input shape equalities.

    combined_args = _combine_args(nn_module, args, kwargs)
    _check_dynamic_shapes(combined_args, dynamic_shapes)
    transformed_dynamic_shapes = _transform_shapes_for_default_dynamic(
        combined_args, dynamic_shapes
    )
    constraints = _process_dynamic_shapes(combined_args, transformed_dynamic_shapes)
    t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict)
    for constraint in constraints:
        t_constraints[constraint.t_id][constraint.dim] = constraint

    context = torch._guards.TracingContext.try_get()
    if context is not None:
        # This occurs when we are exporting within Dynamo. There already exists
        # a top-level TracingContext with a fake mode, so we do not want to
        # create another fake mode.
        fake_mode = context.fake_mode
    elif not _is_torch_jit_trace:
        code = nn_module.forward.__code__
        co_fields = {
            "co_name": code.co_name,
            "co_filename": code.co_filename,
            "co_firstlineno": code.co_firstlineno,
        }
        fake_mode = FakeTensorMode(
            shape_env=ShapeEnv(
                tracked_fakes=[],
                co_fields=co_fields,
                prefer_deferred_runtime_asserts_over_guards=True,
                allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
            ),
            allow_non_fake_inputs=True,
            export=True,
        )
    else:
        fake_mode = FakeTensorMode(
            shape_env=ShapeEnv(
                tracked_fakes=[],
                prefer_deferred_runtime_asserts_over_guards=True,
                allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
            ),
            allow_non_fake_inputs=True,
        )
    if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None:
        raise ValueError(
            "Detected fake_mode does not have a shape_env with tracked fakes. "
            "If you constructed the module under a FakeTensorMode, "
            "please initialize it like: FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))"
        )

    with fake_mode:
        # FIXME(ycao): ScriptMethod doesn't have a signature, so use None as a
        # placeholder to unblock.
        if not _is_torch_jit_trace:
            original_signature = inspect.signature(nn_module.forward)
        else:
            original_signature = None
        sources: Dict[Tuple[int, int], List[Source]] = defaultdict(list)
        fake_args, fake_kwargs = tree_map_with_path(
            lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
            (args, kwargs),
        )

        names: Dict[str, Tuple[int, int]] = {}
        source_pairs: List[Tuple[Source, Source]] = []
        derived_equalities: List[Tuple[Source, Union[Source, Symbol], Callable]] = []
        phantom_symbols: Dict[str, Symbol] = {}
        for constraint in constraints:
            torch.export.dynamic_shapes._process_equalities(
                constraint,
                lambda t_id, dim: sources[(t_id, dim)],
                fake_mode.shape_env,
                names,
                source_pairs,
                derived_equalities,
                phantom_symbols,
            )

        equalities_inputs = EqualityConstraint(
            source_pairs=source_pairs,
            derived_equalities=derived_equalities,
            phantom_symbols=list(phantom_symbols.values()),
            warn_only=False,
        )
        return (
            fake_mode,
            fake_args,
            fake_kwargs,
            equalities_inputs,
            original_signature,
            transformed_dynamic_shapes,
        )

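# Hedged usage sketch (comment only; `M` and the shapes are hypothetical):
#
#     import torch
#     from torch.export import Dim
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return x + 1
#
#     batch = Dim("batch", min=1, max=128)
#     (fake_mode, fake_args, fake_kwargs, equalities_inputs,
#      original_signature, transformed_dynamic_shapes) = make_fake_inputs(
#         M(), (torch.randn(4, 3),), {}, {"x": {0: batch}}
#     )
#     # fake_args[0] is a FakeTensor whose dim 0 is symbolic and constrained
#     # to the range [1, 128].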

def _flatten_dynamic_shapes(
    combined_args: Dict[str, Any],
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
) -> List[Any]:
    """
    Flatten dynamic_shapes into a list with one (possibly None) shape spec
    per leaf of combined_args.
    """
    flat_shapes = []

    def _tree_map_helper(path, t, shape):
        nonlocal flat_shapes
        flat_shapes.append(shape)

    _tree_map_with_path(_tree_map_helper, combined_args, dynamic_shapes)
    return flat_shapes

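# Roughly (comment only; assuming matching tree structures): for
# combined_args == {"x": tensor_a, "y": tensor_b} and
# dynamic_shapes == {"x": {0: batch}, "y": None}, this returns
# [{0: batch}, None], one shape spec per tensor leaf, in traversal order.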

def produce_guards_and_solve_constraints(
    fake_mode: FakeTensorMode,
    gm: torch.fx.GraphModule,
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
    equalities_inputs: EqualityConstraint,
    original_signature: inspect.Signature,
    _is_torch_jit_trace=False,
):
    """
    Given a fake mode, source pairs corresponding to equal dynamic shape
    dimensions, and a graph module, produce guards on the fake mode's shape env
    (raising constraint violations if any), then solve the constraints to
    suggest simplifications or fixes.
    Dynamo already performs this, so this is for non-strict mode.

    Additional inputs:
        equalities_inputs: the equality constraints to use for guards
        original_signature: the signature of the forward method
    """
    shape_env = fake_mode.shape_env
    assert shape_env is not None
    assert shape_env.tracked_fakes is not None

    placeholders = [tf.fake for tf in shape_env.tracked_fakes]
    sources = [tf.source for tf in shape_env.tracked_fakes]
    input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes]
    constraint_violation_error = None
    try:
        shape_env.produce_guards(
            placeholders,
            sources,
            input_contexts=input_contexts,
            equalities_inputs=equalities_inputs,
            ignore_static=False,
        )
    except ConstraintViolationError as e:
        constraint_violation_error = e

    shape_env.frozen = True
    dim_constraints = shape_env.dim_constraints
    if dim_constraints is None:
        # Expected when shape_env.produce_guards throws an early constraint violation error.
        # There is nothing to solve for in this case.
        # TODO(avik): Maybe record the constraint violation error instead and replay later?
        assert constraint_violation_error
        raise constraint_violation_error
    dim_constraints.solve()
    forced_specializations = dim_constraints.forced_specializations()
    if not _is_torch_jit_trace:
        msg = dim_constraints.prettify_results(
            original_signature,
            dynamic_shapes,
            constraint_violation_error,
            forced_specializations,
        )
    else:
        # FIXME(ycao): This is a hack to get around the missing signature of ScriptMethod.
        msg = "dummy constraint violation message"
    if constraint_violation_error:
        constraint_violation_error.args = (constraint_violation_error.args[0] + msg,)
    elif forced_specializations:
        constraint_violation_error = ConstraintViolationError(msg)
    if constraint_violation_error:
        raise constraint_violation_error

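# Typical call site in non-strict export (hedged sketch, reusing the values
# returned by make_fake_inputs above):
#
#     produce_guards_and_solve_constraints(
#         fake_mode, gm, dynamic_shapes, equalities_inputs, original_signature
#     )
#     # Raises ConstraintViolationError (with suggested fixes appended to the
#     # message) if the traced guards contradict the requested dynamic shapes.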

def make_constraints(
    fake_mode: FakeTensorMode,
    gm: torch.fx.GraphModule,
    combined_args: Dict[str, Any],
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
    num_lifted_inputs: int,
):
    """
    Given a fake mode's shape env and user-specified dynamic shapes,
    return the resulting range constraints on symbolic shape dimensions.

    Additional args:
        num_lifted_inputs: the number of non-user-input placeholder nodes in the graph
        (used only to enumerate the user-input nodes)
    """

    shape_env = fake_mode.shape_env
    assert shape_env is not None
    inline_constraints = gm.meta.get("inline_constraints", [])
    range_constraints = {
        symbol: inline_constraints[symbol] for symbol in inline_constraints
    }
    if not dynamic_shapes:
        return range_constraints

    # get individual dynamic shapes spec for each input
    if not isinstance(dynamic_shapes, dict):
        assert isinstance(dynamic_shapes, (tuple, list))
        combined_args = type(dynamic_shapes)(combined_args.values())  # type: ignore[assignment, misc]
    flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes)

    # check number of shapes vs. number of inputs
    num_placeholders = [node.op == "placeholder" for node in gm.graph.nodes].count(True)
    assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs

    input_dims = defaultdict(list)
    free_symbols = set()
    for input_index, node in enumerate(gm.graph.nodes):
        if input_index < num_lifted_inputs or node.op != "placeholder":
            continue
        if _is_constant_argument(node.meta["val"]) or isinstance(
            node.meta["val"], CustomObjArgument
        ):
            continue
        shape_spec = flat_dynamic_shapes[input_index - num_lifted_inputs]
        for i, d in enumerate(node.meta["val"].shape):
            if isinstance(d, torch.SymInt) and not d.node.expr.is_number:
                # Look up the range constraint for the symbol corresponding to this shape dimension
                # and store it indexed by the symbolic expression corresponding to it.
                # NOTE(avik): Use node._expr instead of node.expr for the lookup here because
                # we want the symbol, not its replacement, which could be an expression. Maybe
                # there's a better way to do this, e.g., by (re)computing value ranges for expressions?
                dim = shape_spec[i] if shape_spec else None
                if dim is None or isinstance(dim, _DimHint):
                    range_constraints[d.node.expr] = shape_env.var_to_range[
                        d.node._expr
                    ]
                else:
                    range_constraints[d.node.expr] = ValueRanges(
                        lower=dim.min, upper=dim.max
                    )
                input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i))
                free_symbols.update(d.node.expr.free_symbols)

    for symbol in free_symbols:
        if symbol not in range_constraints:
            # Placeholders can have symbolic shapes that are derived expressions.
            # The above code will record direct range constraints for them
            # so that we can do runtime assertions. In addition, for serde checks
            # we want to record range constraints for their root symbols.
            range_constraints[symbol] = shape_env.var_to_range[symbol]

    return range_constraints

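# Hedged sketch (comment only): after non-strict tracing produces `gm`, range
# constraints for the exported program can be collected as follows, where
# `combined_args` comes from _combine_args and `num_lifted_inputs` counts the
# parameter/buffer/constant placeholders that precede user inputs:
#
#     range_constraints = make_constraints(
#         fake_mode, gm, combined_args, dynamic_shapes, num_lifted_inputs
#     )
#     # Maps each symbolic shape expression to its ValueRanges.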

def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap:
    """Search the module hierarchy, gathering up all tensor and ScriptObject constants.

    Returns a dictionary mapping hash(value) to the name of the constant. We
    have to abuse `hash` here unfortunately, see: [ScriptObject hash].
    """
    constants = ConstantAttrMap()
    buffers_parameters = set(m.buffers())
    buffers_parameters.update(m.parameters())

    def inner(m: torch.nn.Module, prefix_atoms: List[str], constants):
        for k, v in m.__dict__.items():
            if isinstance(
                v,
                (
                    torch.Tensor,
                    torch.ScriptObject,
                    FakeScriptObject,
                ),
            ):
                if v in buffers_parameters:
                    # filter out buffers and parameters, leaving only constants
                    continue

                fqn = ".".join(prefix_atoms + [k])
                constants.add(v, fqn)
        for k, v in m.named_children():
            inner(v, prefix_atoms + [k], constants)

    inner(m, [], constants)
    return constants

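# Illustrative (hypothetical module): a plain tensor attribute that is neither
# a parameter nor a buffer is gathered as a constant:
#
#     class M(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.w = torch.nn.Parameter(torch.randn(2))  # skipped (parameter)
#             self.c = torch.randn(2)                      # gathered under fqn "c"
#
#     constants = _gather_constant_attrs(M())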

@contextlib.contextmanager
def _fakify_script_objects(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Dict[Any, Any],
    fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
):
    # This context manager is used to fakify script objects into FakeScriptObject.
    # Inputs:
    #   mod: the module to be exported; neither it nor its recursive submodules
    #        have had their script object attributes fakified yet.
    #   args, kwargs: the args and kwargs inputs for mod; script object inputs
    #        have not been fakified yet.
    #   fake_mode: the fake mode used for fakifying script objects. It's the
    #        same mode that fakifies the input tensors.
    #
    # Returns:
    #   mod: the patched module; its (and its recursive submodules') script
    #        object attributes have been fakified.
    #   fake_args, fake_kwargs: new fakified args and kwargs.
    #        Script object inputs have been fakified; the tensors are left untouched.
    #   fake_constant_attrs: a new map from FakeScriptObject to the fqn of the
    #        original script object.
    #   fake_to_real: a mapping from FakeScriptObject to the original script
    #        object, used to undo the patching.

    constant_attrs: ConstantAttrMap = _gather_constant_attrs(mod)
    assert not any(
        isinstance(obj, FakeScriptObject) for obj in constant_attrs.values()
    ), "mod shouldn't contain any FakeScriptObject."
    assert not pytree.tree_any(
        lambda obj: isinstance(obj, FakeScriptObject), (args, kwargs)
    ), "args and kwargs shouldn't contain any FakeScriptObject."

    patched_attr = {}
    fake_constant_attrs = ConstantAttrMap()
    fake_to_real = {}

    def _maybe_fakify_obj(obj):
        fake_obj = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, obj)
        fake_to_real[fake_obj] = obj
        return fake_obj

    def _leaf_mod_and_attr(
        mod: torch.nn.Module, attr_fqn: str
    ) -> Tuple[torch.nn.Module, str]:
        *prefix_attr, last_attr = attr_fqn.split(".")
        cur_mod = mod
        for attr in prefix_attr:
            cur_mod = getattr(cur_mod, attr)
        return cur_mod, last_attr

    try:
        for obj, fqns in constant_attrs.items():
            if isinstance(obj, torch.ScriptObject):
                fake_script_obj = _maybe_fakify_obj(obj)
                for fqn in fqns:
                    cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
                    assert obj is getattr(cur_mod, attr)
                    setattr(cur_mod, attr, fake_script_obj)
                    fake_constant_attrs.add(fake_script_obj, fqn)
                    patched_attr[fqn] = obj
            else:
                for fqn in fqns:
                    fake_constant_attrs.add(obj, fqn)

        fake_args, fake_kwargs = pytree.tree_map_only(
            torch.ScriptObject, _maybe_fakify_obj, (args, kwargs)
        )
        yield (mod, fake_args, fake_kwargs, fake_constant_attrs, fake_to_real)
    finally:
        for fqn, orig_obj in patched_attr.items():
            cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
            setattr(cur_mod, attr, orig_obj)

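# Hedged usage sketch: patch script-object attributes for the duration of
# tracing, restoring the originals afterwards:
#
#     with _fakify_script_objects(mod, args, kwargs, fake_mode) as (
#         patched_mod, fake_args, fake_kwargs, fake_constant_attrs, fake_to_real
#     ):
#         ...  # trace patched_mod with fake_args / fake_kwargs
#     # On exit, the original ScriptObjects are restored on mod.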

class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
    """
    1. Handles data-dependent errors raised by torch function calls in non-strict.

    Any data-dependent error is due to some condition on unbacked symints
    that cannot be resolved. A mechanical way of fixing the error is to use
    a torch._check() call to assert either that condition or its negation.
    The handler suggests these options as code and points to the location
    of the torch function call that raised the error as part of the error
    message shown to the user, who can then simply select and copy-paste
    a suggested fix at that location.

    NOTE: Not all data-dependent errors are raised by torch function calls.
    In particular, conditions on unbacked symints can appear outside such
    calls, and as such are not handled here.

    2. Handles line-of-code logging for each torch function call in non-strict.

    Usage: TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC=1 TORCH_LOGS="+export" ...
    """

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
            frame = _find_user_code_frame()
            if frame is not None:
                log.debug(
                    "%s called at %s:%s in %s",
                    func.__qualname__,
                    frame.f_code.co_filename,
                    frame.f_lineno,
                    frame.f_code.co_name,
                )
        try:
            return func(*args, **kwargs)
        except GuardOnDataDependentSymNode as e:
            _suggest_fixes_for_data_dependent_error_non_strict(e)
            raise
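

# Hedged usage sketch: as a TorchFunctionMode, the handler can wrap non-strict
# tracing as a context manager (`traced_fn` is hypothetical):
#
#     with _NonStrictTorchFunctionHandler():
#         out = traced_fn(*fake_args)
#     # Data-dependent errors surface with suggested torch._check() fixes.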
524