xref: /aosp_15_r20/external/pytorch/torch/_export/passes/_node_metadata_hook.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import contextlib
3
4import torch
5from torch.fx.graph_module import GraphModule
6
7
# Placeholder string used by _node_metadata_hook as the only key, and as both
# tuple entries, of the fallback "nn_module_stack" it assigns when the first
# Node argument's metadata has no "nn_module_stack" entry of its own.
_EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook"
9
10
def _node_metadata_hook(node: torch.fx.Node, stack_trace: str) -> None:
    """
    Populate metadata on a node that was just created via graph.create_node.

    Intended to be installed with _set_node_metadata_hook while a pass runs:

    ```
    with _set_node_metadata_hook(gm,
        functools.partial(_node_metadata_hook, stack_trace="file")
    ):
        my_pass(gm)
    ```

    This hook is deliberately narrow and is not expected to work for every
    kind of node: it only accepts call_function nodes that have at least one
    torch.fx.Node among their positional arguments, and it copies the
    nn_module_stack of the first such argument.

    Args:
        node: The newly created node. Its ``meta`` dict is updated in place
            with the keys "val", "stack_trace", "nn_module_stack" and
            "torch_fn".
        stack_trace: Stored unchanged under ``node.meta["stack_trace"]``.

    Raises:
        AssertionError: If ``node`` is not a call_function node with a
            callable target, or none of its positional args is a Node.
        KeyError: If "val" has to be computed and a Node argument has no
            "val" entry in its metadata.
    """
    assert node.op == "call_function" and callable(node.target), (
        "_node_metadata_hook only supports call_function nodes with a "
        f"callable target, got op={node.op!r}, target={node.target!r}"
    )

    node_arg_metas = [
        arg.meta for arg in node.args if isinstance(arg, torch.fx.Node)
    ]
    assert len(node_arg_metas) >= 1, (
        "_node_metadata_hook needs at least one torch.fx.Node positional "
        "argument to copy nn_module_stack from"
    )
    first_arg_meta = node_arg_metas[0]

    if (
        isinstance(node.target, torch._ops.OpOverload)
        and len(node.target._schema.returns) == 0
    ):
        # The op's schema declares no return values, so there is nothing to
        # compute; record an explicit None.
        node.meta["val"] = None
    else:
        # Replace every Node with its recorded "val" -- including Nodes nested
        # inside list/tuple/dict arguments and Nodes passed as keyword
        # arguments -- then run the target on those values to derive this
        # node's "val".
        fake_args = torch.fx.map_arg(node.args, lambda n: n.meta["val"])
        fake_kwargs = torch.fx.map_arg(node.kwargs, lambda n: n.meta["val"])
        # map_arg preserves the container kind; the asserts narrow the static
        # type so the values can be star-unpacked.
        assert isinstance(fake_args, tuple)
        assert isinstance(fake_kwargs, dict)
        node.meta["val"] = node.target(*fake_args, **fake_kwargs)

    node.meta["stack_trace"] = stack_trace
    # Inherit the module stack of the first Node argument; when it has none,
    # fall back to a single placeholder entry.
    node.meta["nn_module_stack"] = first_arg_meta.get(
        "nn_module_stack",
        {
            _EMPTY_NN_MODULE_STACK_KEY: (
                _EMPTY_NN_MODULE_STACK_KEY,
                _EMPTY_NN_MODULE_STACK_KEY,
            )
        },
    )
    node.meta["torch_fn"] = (
        f"{node.target.__name__}_0",
        f"{node.target.__class__.__name__}.{node.target.__name__}",
    )
60
61
62@contextlib.contextmanager
63def _set_node_metadata_hook(gm: torch.fx.GraphModule, f):
64    """
65    Takes a callable which will be called after we create a new node. The
66    callable takes the newly created node as input and returns None.
67    """
68    assert callable(f), "node_metadata_hook must be a callable."
69
70    # Add the hook to all submodules
71    for m in gm.modules():
72        if isinstance(m, GraphModule):
73            m._register_create_node_hook(f)
74    try:
75        yield
76    finally:
77        # Restore hook for all submodules
78        for m in gm.modules():
79            if isinstance(m, GraphModule):
80                m._unregister_create_node_hook(f)
81