# mypy: allow-untyped-defs
import ast
import dataclasses
import inspect
import math
import operator
import re
from inspect import Parameter
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING

import torch
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensor


if TYPE_CHECKING:
    from torch._export.passes.lift_constants_pass import ConstantAttrMap
    from torch.export import ExportedProgram
    from torch.export.graph_signature import ExportGraphSignature

from torch.export.graph_signature import InputKind, OutputKind
from torch.utils._pytree import (
    _register_pytree_node,
    Context,
    FlattenFunc,
    FromDumpableContextFn,
    GetAttrKey,
    KeyPath,
    keystr,
    MappingKey,
    SequenceKey,
    ToDumpableContextFn,
    tree_flatten_with_path,
    UnflattenFunc,
)


placeholder_prefixes = {
    InputKind.USER_INPUT: "",
    InputKind.PARAMETER: "p_",
    InputKind.BUFFER: "b_",
    InputKind.CONSTANT_TENSOR: "c_",
    InputKind.CUSTOM_OBJ: "obj_",
    InputKind.TOKEN: "token",
}
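
# For example, under these prefixes a parameter with FQN `bar.l0.weight` is
# lifted as the placeholder `p_bar_l0_weight`, a buffer `foo.running_mean` as
# `b_foo_running_mean`, and effect tokens are named `token`, `token_1`, ...
# (see placeholder_naming_pass below for details).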


def _collect_and_set_constant_attrs(
    graph_signature, constants, mod
) -> "ConstantAttrMap":
    # The exported module stores constants & non-persistent buffers in a way
    # that makes retracing treat them as persistent buffers, so we inform the
    # constant-lifting pass and overwrite the new graph signature using the
    # previous program. This is intended to be used only in run_decompositions,
    # where we still have access to the original ExportedProgram.
    from torch._export.passes.lift_constants_pass import ConstantAttrMap

    constant_attrs = ConstantAttrMap()
    non_persistent_buffers = {
        spec.target
        for spec in graph_signature.input_specs
        if spec.kind == InputKind.BUFFER and not spec.persistent
    }
    for name, value in constants.items():
        if name in non_persistent_buffers:
            continue
        # recursive getattr to reach the owning submodule
        _mod = mod
        *atoms, attr = name.split(".")
        for atom in atoms:
            _mod = getattr(_mod, atom)
        # remove as buffer, reassign as constant/non-persistent buffer
        _mod._buffers.pop(attr, None)
        setattr(_mod, attr, value)
        constant_attrs.add(value, name)
    return constant_attrs


def _overwrite_signature_for_non_persistent_buffers(
    old_sig: "ExportGraphSignature", new_sig: "ExportGraphSignature"
):
    # overwrite signature for non-persistent buffers
    non_persistent_buffers = {
        spec.target
        for spec in old_sig.input_specs
        if spec.kind == InputKind.BUFFER and not spec.persistent
    }

    for spec in new_sig.input_specs:
        if spec.kind == InputKind.BUFFER and spec.target in non_persistent_buffers:
            spec.persistent = False
    return new_sig


def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> Dict[str, Any]:
    """
    Param/buffer metadata needs to be saved before lowering to aten IR because
    aten IR lifts them; as a result, automatic preservation doesn't work.
    This is intended to be called during strict-mode tracing, right before
    lowering to aten IR or the run_decompositions pass.
    """
    params_buffers_to_node_meta = {}

    def _getattr(model: torch.fx.GraphModule, attr_name: str):
        *prefix, field = attr_name.split(".")
        t = model
        for item in prefix:
            t = getattr(t, item, None)  # type: ignore[assignment]
            assert t is not None

        return getattr(t, field)

    for node in mod.graph.nodes:
        target = node.target
        meta = node.meta
        if node.op == "call_module":
            submodule = _getattr(mod, target)
            if isinstance(submodule, torch.nn.Module):
                for name, _ in submodule.named_parameters(
                    recurse=True, remove_duplicate=False
                ):
                    params_buffers_to_node_meta[target + "." + name] = meta

                for name, _ in submodule.named_buffers(
                    recurse=True, remove_duplicate=False
                ):
                    params_buffers_to_node_meta[target + "." + name] = meta

        if node.op == "get_attr":
            submodule = _getattr(mod, target)
            if not isinstance(submodule, torch.fx.GraphModule):
                params_buffers_to_node_meta[target] = meta

        # If a call_function uses a param as input, we also need to update the
        # param's meta with this call_function node's meta.
        # This is basically the same flow as torch.fx.traceback.preserve_meta()
        if node.op == "call_function" and not isinstance(
            node.target, torch._ops.HigherOrderOperator
        ):
            for arg in node._input_nodes:
                if arg.op == "get_attr":
                    for entry in torch.fx.proxy._COPY_META_FIELDS:
                        if entry in meta:
                            params_buffers_to_node_meta[arg.target][entry] = meta[entry]

    return params_buffers_to_node_meta


def _populate_param_buffer_metadata_to_new_gm(
    params_buffers_to_node_meta: Dict[str, Any],
    gm: torch.fx.GraphModule,
    new_sig: "ExportGraphSignature",
) -> None:
    """
    Given the param/buffer metadata collected earlier, populate it back into
    the newly traced graph module.
    """
    # Don't copy over nn_module_stack, stack_trace metadata for params/buffers nodes
    for metadata in params_buffers_to_node_meta.values():
        metadata.pop("nn_module_stack", None)
        metadata.pop("stack_trace", None)

    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if node.target in new_sig.inputs_to_parameters:
                param_name = new_sig.inputs_to_parameters[node.target]
                if param_name in params_buffers_to_node_meta:
                    for k, v in params_buffers_to_node_meta[param_name].items():
                        node.meta[k] = v
            if node.target in new_sig.inputs_to_buffers:
                buffer_name = new_sig.inputs_to_buffers[node.target]
                if buffer_name in params_buffers_to_node_meta:
                    for k, v in params_buffers_to_node_meta[buffer_name].items():
                        node.meta[k] = v


def _get_shape_env_from_gm(gm: torch.fx.GraphModule):
    vals = [
        node.meta["val"]
        for node in gm.graph.nodes
        if node.meta.get("val", None) is not None
    ]

    fake_mode = _detect_fake_mode_from_gm(gm)
    if fake_mode is not None:
        return fake_mode.shape_env
    for v in vals:
        if isinstance(v, torch.SymInt):
            return v.node.shape_env


def _rename_without_collisions(
    name_map: Dict[str, str],
    orig_name: str,
    name: str,
    is_placeholder: bool = False,
):
    """
    Renames nodes to avoid name collisions, with suffixing.
    name_map: map from original name to new name
    orig_name: mapping key
    name: candidate name (potentially suffixed, e.g. mul_2)
    is_placeholder: if the node is a placeholder, avoid detecting suffix
    """
    if name in name_map.values():
        # non-placeholder nodes may be suffixed with the count;
        # instead of adding another suffix, we will try to increment it
        match = re.match(r"(.*)_(\d+)", name)
        if match and not is_placeholder:
            name, n = match.group(1), int(match.group(2))
        else:
            n = 0
        while (dup_name := f"{name}_{n + 1}") in name_map.values():
            n += 1
        name_map[orig_name] = dup_name
    else:
        name_map[orig_name] = name
    return name_map[orig_name]
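

# Hedged usage sketch (illustrative addition, not part of the original file):
# how _rename_without_collisions suffixes duplicate candidate names. All names
# below are hypothetical.
def _example_rename_without_collisions() -> None:
    name_map: Dict[str, str] = {}
    assert _rename_without_collisions(name_map, "x", "x") == "x"
    # "x" is taken, so the next candidate gets a count suffix
    assert _rename_without_collisions(name_map, "y", "x") == "x_1"
    # a non-placeholder name that already carries a suffix is incremented
    assert _rename_without_collisions(name_map, "z", "x_1") == "x_2"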


def _check_input_constraints_for_graph(
    input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints
):
    def get_keystr(key_path: KeyPath) -> str:
        """For a given index into the flat_args, return a human-readable string
        describing how to access it, e.g. "*args["foo"][0].bar"
        """
        # Prefix the keypath with "*args" or "**kwargs" to make it clearer where
        # the arguments come from. Ultimately we ought to serialize the
        # original arg names for the best error message here.
        args_kwargs_key_path = key_path[0]
        assert isinstance(args_kwargs_key_path, SequenceKey)
        if args_kwargs_key_path.idx == 0:
            return f"*args{keystr(key_path[1:])}"
        else:
            kwarg_key = key_path[1]
            assert isinstance(kwarg_key, MappingKey)
            name = str(kwarg_key)[1:-1]  # get rid of the enclosing []
            return f"{name}{keystr(key_path[2:])}"

    import sympy

    from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
        _convert_range_to_int,
    )
    from torch.utils._sympy.solve import try_solve

    if len(flat_args_with_path) != len(input_placeholders):
        raise RuntimeError(
            "Unexpected number of inputs "
            f"(expected {len(input_placeholders)}, got {len(flat_args_with_path)})"
        )
    # NOTE: export already guarantees that the same symbol is used in metadata
    # for all InputDims related by equality constraints, so we can just unify
    # symbols with given input dimension values to check equality constraints.
    unification_map: Dict[sympy.Symbol, Any] = {}
    for (key_path, arg), node in zip(flat_args_with_path, input_placeholders):
        node_val = node.meta.get("val")
        if isinstance(node_val, FakeTensor):
            if not isinstance(arg, torch.Tensor):
                raise RuntimeError(
                    f"Expected input at {get_keystr(key_path)} to be a tensor, but got {type(arg)}",
                )

            if len(node_val.shape) != len(arg.shape):
                raise RuntimeError(
                    f"Unexpected number of dimensions in input at {get_keystr(key_path)}.shape "
                    f"(expected {node_val.shape}, got {arg.shape})"
                )

            for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)):
                # TODO(avik): Assert the following property in the IR verifier:
                # node_dim is either an int or a SymInt containing an int or a unary sympy.Expr
                if (
                    isinstance(node_dim, torch.SymInt)
                    and len(node_dim.node.expr.free_symbols) == 1
                ):
                    symbol = next(iter(node_dim.node.expr.free_symbols))
                    if symbol in unification_map:
                        existing_dim = node_dim.node.expr.subs(unification_map)
                        if arg_dim != existing_dim:
                            raise RuntimeError(
                                f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to "
                                f"{existing_dim}, but got {arg_dim}",
                            )
                    else:
                        if (
                            isinstance(arg_dim, torch.SymInt)
                            and not arg_dim.node.expr.is_number
                        ):
                            # This can happen when, say, arg is a fake tensor.
                            # We do not run checks on symbolic shapes of fake inputs as
                            # such checks can affect the shape env.
                            pass
                        else:
                            if isinstance(node_dim.node.expr, sympy.Symbol):
                                # Short cut for try_solve below. Also useful in cases where
                                # sympy.Eq(node_dim.node.expr, arg_dim) would evaluate to False
                                # purely because symbol is constrained to be size-like,
                                # e.g., when node_dim.node.expr = symbol and arg_dim = 0.
                                unification_map[symbol] = int(arg_dim)
                            else:
                                solution = try_solve(
                                    sympy.Eq(node_dim.node.expr, arg_dim), symbol
                                )
                                if solution is None:
                                    raise RuntimeError(  # noqa: B904
                                        f"Expected input {node.name}.shape[{j}] = {arg_dim} to be "
                                        f"of the form {node_dim.node.expr}, where {symbol} is an integer"
                                    )
                                else:
                                    unification_map[symbol] = int(solution[1])

                    if node_dim.node.expr in range_constraints:
                        min_val, max_val = _convert_range_to_int(
                            range_constraints[node_dim.node.expr]
                        )
                        # NOTE: we allow dimensions to be 0/1 at runtime
                        if min_val > 2:
                            if arg_dim < min_val:
                                raise RuntimeError(
                                    f"Expected input at {get_keystr(key_path)}.shape[{j}] to be >= "
                                    f"{min_val}, but got {arg_dim}",
                                )
                        if max_val < math.inf:
                            if arg_dim > max_val:
                                raise RuntimeError(
                                    f"Expected input at {get_keystr(key_path)}.shape[{j}] to be <= "
                                    f"{max_val}, but got {arg_dim}",
                                )
                else:
                    if arg_dim != node_dim:
                        if (
                            isinstance(node_dim, torch.SymInt)
                            and not node_dim.node.expr.is_number
                        ):
                            # This means we deferred a guard from export analysis to runtime; let this pass.
                            # We'll add a runtime assert checking equality against this replacement expression.
                            continue
                        raise RuntimeError(
                            f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to "
                            f"{node_dim}, but got {arg_dim}",
                        )
        elif isinstance(node_val, (int, float, str)):
            if type(arg) != type(node_val) or arg != node_val:
                raise RuntimeError(
                    f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}",
                )
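

# Hedged usage sketch (illustrative addition, not part of the original file):
# a plausible call-site shape for _check_input_constraints_for_graph. `ep` is a
# hypothetical ExportedProgram; runtime args are flattened with their pytree
# key paths and checked against the user-input placeholders' fake metadata.
def _example_check_input_constraints(ep: "ExportedProgram", args, kwargs) -> None:
    user_inputs = set(ep.graph_signature.user_inputs)
    input_placeholders = [
        node
        for node in ep.graph.nodes
        if node.op == "placeholder" and node.name in user_inputs
    ]
    flat_args_with_path, _ = tree_flatten_with_path((args, kwargs))
    _check_input_constraints_for_graph(
        input_placeholders, flat_args_with_path, ep.range_constraints
    )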


def register_dataclass_as_pytree_node(
    cls: Type[Any],
    flatten_fn: Optional[FlattenFunc] = None,
    unflatten_fn: Optional[UnflattenFunc] = None,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    return_none_fields: bool = False,
) -> None:
    assert dataclasses.is_dataclass(
        cls
    ), f"Only dataclasses can be registered with this function: {cls}"

    def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
        flattened = []
        flat_names = []
        none_names = []
        for f in dataclasses.fields(obj):
            name, val = f.name, getattr(obj, f.name)
            if val is not None or return_none_fields:
                flattened.append(val)
                flat_names.append(name)
            else:
                none_names.append(name)
        return flattened, [flat_names, none_names]

    def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
        flat_names, none_names = context
        return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))

    def default_flatten_fn_with_keys(obj: Any) -> Tuple[List[Any], Context]:
        flattened, (flat_names, none_names) = flatten_fn(obj)  # type: ignore[misc]
        return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names

    flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
    unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn

    if (to_dumpable_context is None) ^ (from_dumpable_context is None):
        raise ValueError(
            f"Both to_dumpable_context and from_dumpable_context for {cls} must "
            "be None or registered."
        )

    _register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        flatten_with_keys_fn=default_flatten_fn_with_keys,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
    )
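

# Hedged usage sketch (illustrative addition, not part of the original file):
# registering a hypothetical dataclass with the default flatten/unflatten
# behavior. `_Point` and its fields are made-up names.
def _example_register_dataclass() -> None:
    from torch.utils._pytree import tree_flatten

    @dataclasses.dataclass
    class _Point:
        x: torch.Tensor
        y: Optional[torch.Tensor] = None

    register_dataclass_as_pytree_node(
        _Point, serialized_type_name="example._Point"
    )
    values, _spec = tree_flatten(_Point(x=torch.ones(2)))
    # `y` is None and return_none_fields defaults to False, so only `x`
    # appears among the flattened leaves.
    assert len(values) == 1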


def is_param(program: "ExportedProgram", node: torch.fx.Node) -> bool:
    """
    Checks if the given node is a parameter within the exported program
    """

    return node.name in program.graph_signature.inputs_to_parameters


def get_param(
    program: "ExportedProgram",
    node: torch.fx.Node,
) -> Optional[torch.nn.Parameter]:
    """
    Returns the parameter associated with the given node in the exported program.
    Returns None if the node is not a parameter within the exported program
    """

    if is_param(program, node):
        parameter_name = program.graph_signature.inputs_to_parameters[node.name]
        return program.state_dict[parameter_name]

    return None


def is_buffer(program: "ExportedProgram", node: torch.fx.Node) -> bool:
    """
    Checks if the given node is a buffer within the exported program
    """

    return node.name in program.graph_signature.inputs_to_buffers


def get_buffer(
    program: "ExportedProgram",
    node: torch.fx.Node,
) -> Optional[torch.Tensor]:
    """
    Returns the buffer associated with the given node in the exported program.
    Returns None if the node is not a buffer within the exported program
    """

    if is_buffer(program, node):
        buffer_name = program.graph_signature.inputs_to_buffers[node.name]
        if buffer_name in program.graph_signature.non_persistent_buffers:
            return program.constants[buffer_name]
        else:
            return program.state_dict[buffer_name]

    return None


def is_lifted_tensor_constant(
    program: "ExportedProgram",
    node: torch.fx.Node,
) -> bool:
    """
    Checks if the given node is a lifted tensor constant within the exported program
    """

    return node.name in program.graph_signature.inputs_to_lifted_tensor_constants


def get_lifted_tensor_constant(
    program: "ExportedProgram",
    node: torch.fx.Node,
) -> Optional[torch.Tensor]:
    """
    Returns the lifted tensor constant associated with the given node in the exported program.
    Returns None if the node is not a lifted tensor constant within the exported program
    """

    if is_lifted_tensor_constant(program, node):
        lifted_tensor_name = program.graph_signature.inputs_to_lifted_tensor_constants[
            node.name
        ]
        return program.constants[lifted_tensor_name]

    return None
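

# Hedged usage sketch (illustrative addition, not part of the original file):
# resolving every lifted placeholder of a hypothetical ExportedProgram `ep`
# to its underlying tensor using the accessors above.
def _example_resolve_lifted_placeholders(
    ep: "ExportedProgram",
) -> Dict[str, Optional[torch.Tensor]]:
    resolved: Dict[str, Optional[torch.Tensor]] = {}
    for node in ep.graph.nodes:
        if node.op != "placeholder":
            continue
        if is_param(ep, node):
            resolved[node.name] = get_param(ep, node)
        elif is_buffer(ep, node):
            resolved[node.name] = get_buffer(ep, node)
        elif is_lifted_tensor_constant(ep, node):
            resolved[node.name] = get_lifted_tensor_constant(ep, node)
    return resolved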


def sequential_split(gm: torch.fx.GraphModule, node_call_back) -> torch.fx.GraphModule:
    """
    sequential_split creates a new graph module that splits the input graph module into
    multiple submodules based on node_call_back. It doesn't mutate the input graph module.
    node_call_back should return True if the node is a delimiter; a delimiter becomes the
    first node of the next submodule.
    """
    from torch.fx.passes.split_module import split_module

    split_map = {}
    split_id = 0
    for node in gm.graph.nodes:
        if node_call_back(node):
            split_id += 1
        split_map[node] = split_id

    new_gm = split_module(
        gm,
        gm,
        lambda node: split_map[node],
        keep_original_order=True,
        keep_original_node_name=True,
    )
    # Keep the codegen from the original graph module to preserve e.g. pytree info.
    new_gm.graph._codegen = gm.graph._codegen
    new_gm.recompile()
    return new_gm
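

# Hedged usage sketch (illustrative addition, not part of the original file):
# splitting a graph module so that every aten._assert_async.msg call starts a
# new submodule. The choice of delimiter op here is hypothetical.
def _example_sequential_split(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    def _is_delimiter(node: torch.fx.Node) -> bool:
        return (
            node.op == "call_function"
            and node.target == torch.ops.aten._assert_async.msg
        )

    return sequential_split(gm, _is_delimiter)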


def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
    """Returns the nodes that match the node_call_back as a list."""
    return [node for node in nodes if node_call_back(node)]


def nodes_first(
    nodes: List[torch.fx.Node], node_call_back=None
) -> Optional[torch.fx.Node]:
    """
    Returns the first node that matches the node_call_back. If no node matches, returns None.
    When node_call_back is None, returns the first node in the node list.
    """
    ret = nodes_filter(nodes, node_call_back if node_call_back else lambda node: True)
    if len(ret) > 0:
        return ret[0]
    return None


def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int:
    """Returns the number of nodes that match the node_call_back."""
    return len(nodes_filter(nodes, node_call_back))


def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
    """
    Sequentially visit the nodes list and invoke node_call_back on each element.
    Returns the nodes list after the node_call_back is invoked on each element.
    """
    for node in nodes:
        node_call_back(node)
    return nodes
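

# Hedged usage sketch (illustrative addition, not part of the original file):
# composing the small node helpers above on a graph module `gm`.
def _example_node_helpers(gm: torch.fx.GraphModule) -> int:
    nodes = list(gm.graph.nodes)
    calls = nodes_filter(nodes, lambda n: n.op == "call_function")
    first_call = nodes_first(nodes, lambda n: n.op == "call_function")
    assert first_call is None or first_call is calls[0]
    # annotate every call_function node, then count the annotated ones
    nodes_map(calls, lambda n: n.meta.setdefault("visited", True))
    return nodes_count(nodes, lambda n: n.meta.get("visited", False))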


def node_replace_(old_node: torch.fx.Node, new_node: torch.fx.Node) -> None:
    """
    Replace all uses of old_node with new_node.
    """
    old_node.replace_all_uses_with(new_node)
    old_node.users.clear()
    old_node.graph.erase_node(old_node)


def node_inline_(call_mod_node: torch.fx.Node) -> torch.fx.GraphModule:
    """
    Inline the submodule of the given node into the parent module.
    Note: we only support the case where the submodule takes tensor inputs.
    """
    assert call_mod_node.op == "call_module"
    gm = call_mod_node.graph.owning_module

    assert isinstance(call_mod_node.target, str)
    sub_gm = getattr(gm, call_mod_node.target)

    phs = (node for node in sub_gm.graph.nodes if node.op == "placeholder")
    body = (
        node for node in sub_gm.graph.nodes if node.op not in ("placeholder", "output")
    )
    output = [node for node in sub_gm.graph.nodes if node.op == "output"]

    for ph, arg in zip(phs, call_mod_node.args):
        assert isinstance(arg, torch.fx.Node)
        node_replace_(ph, arg)

    with gm.graph.inserting_before(call_mod_node):
        for node in body:
            new_node = gm.graph.node_copy(node)
            node_replace_(node, new_node)

        if len(output) > 0:
            assert len(output) == 1 and len(output[0].args) == 1
            new_output = output[0].args[0]

            if isinstance(new_output, torch.fx.Node):
                # Clear the users of the output node and set
                # the users to be the users of original call_module node.
                new_output.users.clear()
                node_replace_(call_mod_node, new_output)
            elif isinstance(new_output, (list, tuple)):
                # Pop subgraph output node from users.
                for node in new_output:
                    node.users.pop(output[0])

                # Inline the get_item calls for the output node.
                get_item_users = nodes_filter(
                    list(call_mod_node.users.keys()),
                    lambda node: node.op == "call_function"
                    and node.target == operator.getitem,
                )
                # get_item_node.args[1] is the idx referring to new_output[idx]
                nodes_map(
                    get_item_users,
                    lambda get_item_node: node_replace_(
                        get_item_node,
                        new_output[get_item_node.args[1]],
                    ),
                )
                call_mod_node.graph.erase_node(call_mod_node)
            else:
                raise NotImplementedError(
                    f"Unsupported output type {type(new_output)}. Expect it to be a Node or a list/tuple of Nodes."
                )
        else:
            call_mod_node.graph.erase_node(call_mod_node)

    gm.delete_all_unused_submodules()
    gm.recompile()
    return gm
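

# Hedged usage sketch (illustrative addition, not part of the original file):
# inlining every call_module node of `gm` in place, assuming each submodule
# takes only tensor inputs (the constraint noted in node_inline_'s docstring).
def _example_inline_all_submodules(gm: torch.fx.GraphModule) -> None:
    # iterate over a snapshot since node_inline_ erases nodes as it goes
    for node in list(gm.graph.nodes):
        if node.op == "call_module":
            node_inline_(node)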


def _get_torch_jit_trace_forward_signature(mod: torch.nn.Module):
    """
    Get the source code and parse argument names using the AST. Returns an
    inspect.Signature for the forward() function.

    # TODO: Directly provide an inspect.signature-compatible TorchScripted module.
    """
    ast_mod = ast.parse(mod.code)
    ast_func_def: ast.FunctionDef = ast_mod.body[0]  # type: ignore[assignment]

    # FIXME(jiashenc): TorchScript should only allow positional-or-keyword arguments.
    arg_type_map = {"args": Parameter.POSITIONAL_OR_KEYWORD}

    # Traverse all argument types in the AST tree and create associated parameters.
    param_list = []
    for arg_type, param_type in arg_type_map.items():
        arg_name_list = [a.arg for a in getattr(ast_func_def.args, arg_type)]
        for arg_name in arg_name_list:
            if arg_name == "self":
                continue  # Skip self argument.
            param_list.append(inspect.Parameter(arg_name, param_type))

    return inspect.Signature(parameters=param_list)


def _bind_signature_to_inputs(mod, fake_args, fake_kwargs):
    if isinstance(mod, (torch.jit.ScriptModule, torch.jit.TracedModule)):
        sig = _get_torch_jit_trace_forward_signature(mod)

        # Sanity check for placeholder names coming from TorchScript.
        assert len(sig.parameters) == len(fake_args) + len(fake_kwargs), (
            "Arguments other than POSITIONAL_OR_KEYWORD kinds in forward() "
            "are not supported in _get_torch_jit_trace_forward_signature"
        )
    else:
        sig = inspect.signature(mod.forward)

    return sig.bind(*fake_args, **fake_kwargs).arguments
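

# Hedged usage sketch (illustrative addition, not part of the original file):
# for a hypothetical module whose forward is `def forward(self, x, flag=False)`,
# binding one positional and one keyword example input yields an ordered
# mapping like {"x": <tensor>, "flag": True}.
def _example_bind_signature(mod: torch.nn.Module) -> Dict[str, Any]:
    return _bind_signature_to_inputs(mod, (torch.ones(2),), {"flag": True})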


def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
    """
    Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs,
    and handle collisions with non-placeholders by count suffixing.
    Different HOO subgraph types have different input schemas, so we first enumerate them
    and gather the top-level named placeholder nodes.
    """
    # gather all HOO subgraphs and their top-level named placeholder nodes
    subgraph_ph_tuples: List[Tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = []
    for node in gm.graph.nodes:
        if node.op == "call_function" and isinstance(
            node.target, torch._ops.HigherOrderOperator
        ):
            # HOO subgraphs have varying input schemas, so we enumerate them here
            if node.target._name == "cond":
                _, true_graph, false_graph, cond_args = node._args
                subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args))
                subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args))
            elif node.target._name == "wrap_with_set_grad_enabled":
                subgraph, phs = node._args[1], node._args[2:]
                subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs))
            elif node.target._name == "map_impl":
                body_graph, array, args = node._args
                subgraph_ph_tuples.append(
                    (getattr(gm, body_graph.target), array + args)
                )

    # propagate names
    for subgraph, hoo_phs in subgraph_ph_tuples:
        name_map: Dict[str, str] = {}
        for i, node in enumerate(subgraph.graph.nodes):
            if i < len(hoo_phs):  # placeholder, retain name
                name_map[node.name] = hoo_phs[i].name
                node.name = node.target = hoo_phs[i].name
            else:  # non-placeholder, check for collisions
                node.name = _rename_without_collisions(name_map, node.name, node.name)

        # recurse and recompile
        _name_hoo_subgraph_placeholders(subgraph)
        subgraph.recompile()


def placeholder_naming_pass(
    gm: torch.fx.GraphModule,
    export_graph_signature: "ExportGraphSignature",
    mod: torch.nn.Module,
    fake_args,
    fake_kwargs,
    fake_params_buffers,
    constants: Dict[str, Any],
) -> None:
    """
    This pass is run at the end of _export_non_strict() to assign better placeholder node names:
        - User inputs:
            These follow the signature of mod.forward(), e.g. forward(x, y) produces nodes x, y.
            For nested inputs from dictionaries, lists, tuples, or dataclasses,
            the names are a concatenation of the path to the tensor.
                e.g. x = {
                    'a': torch.randn(),
                    'b': [torch.randn(), torch.randn()]
                }
            produces nodes x_a, x_b_0, x_b_1.
        - Parameters/buffers/constants/custom objects:
            These follow the FQN of the object, prefixed by "p", "b", "c", "obj" respectively.
                e.g. self.bar.l0.weight produces "p_bar_l0_weight".
        - Effect tokens:
            These are named token, token_1, ...
    """

    def _strip_name(x):
        if x.startswith("L__self___"):
            x = x[len("L__self___") :]
        elif x.startswith("self_"):
            x = x[len("self_") :]
        x = re.sub(r"[^a-zA-Z0-9]", "_", x)
        return x

    def _extract_pytree_key(x):
        if isinstance(x, MappingKey):
            x = re.sub(r"[^a-zA-Z0-9]", "_", str(x.key))
            return x
        elif isinstance(x, SequenceKey):
            return str(x.idx)
        elif isinstance(x, GetAttrKey):
            return x.name
        else:
            raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}")

    name_map: Dict[str, str] = {}

    # map user input names to the mod.forward() signature
    combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs)

    flat_args_with_path, _ = tree_flatten_with_path(combined_args)
    user_input_names = [
        spec.arg.name
        for spec in export_graph_signature.input_specs
        if spec.kind == InputKind.USER_INPUT
    ]

    # use pytree paths to name nested user inputs
    for (arg_path, arg), user_input_name in zip(flat_args_with_path, user_input_names):
        if user_input_name:
            _rename_without_collisions(
                name_map,
                user_input_name,
                placeholder_prefixes[InputKind.USER_INPUT]
                + "_".join(_extract_pytree_key(x).lower() for x in arg_path),
                is_placeholder=True,
            )

    # use graph signature input specs to map param/buffer/constant names;
    # name effect tokens as token, token_1, ... (these aren't visible to the user)
    for spec in export_graph_signature.input_specs:
        if spec.kind == InputKind.USER_INPUT:
            continue
        if spec.kind == InputKind.TOKEN:
            base_name = ""
        else:
            base_name = _strip_name(spec.target).lower()
        base_name = re.sub(r"[^a-zA-Z0-9]", "_", base_name)

        _rename_without_collisions(
            name_map,
            spec.arg.name,
            placeholder_prefixes[spec.kind] + base_name,
            is_placeholder=True,
        )

    # Handle naming collisions with call_function/get_attr nodes.
    # Here we prioritize user input names over call_function names:
    # e.g. forward(self, mul) should not produce a placeholder node named mul_13,
    # so we increment the suffixes of call_function nodes as needed.
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            continue
        _rename_without_collisions(name_map, node.name, node.name)

    # assign new node names
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            assert node.name in name_map
            node.name = node.target = name_map[node.name]
        elif node.name in name_map:
            node.name = name_map[node.name]

    # propagate names to higher order op subgraphs
    _name_hoo_subgraph_placeholders(gm)

    # re-generate graph module code
    gm.recompile()

    # modify graph signature (input specs, output specs, user input mutations)
    for spec in export_graph_signature.input_specs:
        assert spec.arg.name in name_map
        spec.arg.name = name_map[spec.arg.name]
        if (  # handle targets for custom objects
            spec.kind == InputKind.CUSTOM_OBJ and spec.target in name_map
        ):
            spec.target = name_map[spec.target][4:]  # strip obj_ prefix

    for spec in export_graph_signature.output_specs:
        if spec.arg.name in name_map:
            spec.arg.name = name_map[spec.arg.name]
        if spec.kind == OutputKind.USER_INPUT_MUTATION and spec.target in name_map:
            spec.target = name_map[spec.target]

    # rename keys in constants dict for custom objects
    for name in list(constants.keys()):
        constant = constants[name]
        if name in name_map and not isinstance(
            constant, torch.Tensor
        ):  # rename custom objects with generic names
            new_name = name_map[name]
            if (
                new_name != name
                and re.match(r"arg(\d+)_1", name)
                and new_name != placeholder_prefixes[InputKind.CUSTOM_OBJ] + name
            ):
                constants[new_name] = constant
                del constants[name]


def remove_proxy_from_state_dict(state_dict: Dict, in_place: bool) -> Dict:
    """
    If `in_place` is False, return a new copy of `state_dict` with the "proxy"
    attribute removed from each value's `__dict__`.
    If `in_place` is True, remove the attribute from the values of `state_dict`
    in place and return it.
    """
    if in_place:
        for k, v in state_dict.items():
            if hasattr(v, "proxy"):
                delattr(state_dict[k], "proxy")
        return state_dict
    else:
        new_state_dict = {}
        for k, v in state_dict.items():
            if hasattr(v, "proxy"):
                # cloning yields a fresh tensor that does not carry the
                # instance-level "proxy" attribute
                new_state_dict[k] = v.clone().detach()
            else:
                new_state_dict[k] = v
        return new_state_dict
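

# Hedged usage sketch (illustrative addition, not part of the original file):
# both modes of remove_proxy_from_state_dict. `state_dict` is hypothetical.
def _example_remove_proxy(state_dict: Dict) -> Dict:
    cleaned = remove_proxy_from_state_dict(state_dict, in_place=False)  # copies
    remove_proxy_from_state_dict(state_dict, in_place=True)  # mutates input
    return cleaned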


def _detect_fake_mode_from_gm(
    gm: torch.fx.GraphModule,
) -> Optional[torch._subclasses.fake_tensor.FakeTensorMode]:
    """
    For a given graph module, look at the "val" of placeholder nodes to find the fake inputs.
    Additionally, if gm doesn't have placeholders, look at the "example_value" or "val" of other nodes.
    Returns None if no fake mode is found.
    """

    fake_inps: List[torch.Tensor] = []
    fake_vals: List[torch.Tensor] = []
    for node in gm.graph.nodes:
        if node.op == "placeholder" and "val" in node.meta:
            fake_val = node.meta["val"]
            if fake_val is not None and isinstance(fake_val, torch.Tensor):
                fake_inps.append(fake_val)
        elif len(fake_inps) == 0 and (
            "example_value" in node.meta or "val" in node.meta
        ):
            fake_val = None
            if "example_value" in node.meta:
                fake_val = node.meta["example_value"]
            elif "val" in node.meta:
                fake_val = node.meta["val"]
            if fake_val is not None and isinstance(fake_val, torch.Tensor):
                fake_vals.append(fake_val)

    return detect_fake_mode(fake_inps + fake_vals)
894