1# mypy: allow-untyped-defs 2import contextlib 3 4import torch 5from torch.fx.graph_module import GraphModule 6 7 8_EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook" 9 10 11def _node_metadata_hook(node: torch.fx.Node, stack_trace: str) -> None: 12 """ 13 Hook for adding the appropriate metadata to nodes that are created during a 14 pass using graph.create_node. An example of how to use it: 15 16 ``` 17 with _set_node_metadata_hook(gm, 18 functools.partial(_node_metadata_hook, stack_trace="file") 19 ): 20 pass(gm) 21 ``` 22 23 This hook should not work for all generic cases -- specifically it assumes 24 that nodes being added are only call_function nodes, and copies over the 25 first argument node's nn_module_stack. 26 """ 27 assert node.op == "call_function" and callable(node.target) 28 29 arg_meta = [arg.meta for arg in node.args if isinstance(arg, torch.fx.Node)] 30 assert len(arg_meta) >= 1 31 arg_meta = arg_meta[0] 32 33 if ( 34 isinstance(node.target, torch._ops.OpOverload) 35 and len(node.target._schema.returns) == 0 36 ): 37 node.meta["val"] = None 38 else: 39 fake_args = [ 40 arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg 41 for arg in node.args 42 ] 43 fake_res = node.target(*fake_args) 44 node.meta["val"] = fake_res 45 46 node.meta["stack_trace"] = stack_trace 47 node.meta["nn_module_stack"] = arg_meta.get( 48 "nn_module_stack", 49 { 50 _EMPTY_NN_MODULE_STACK_KEY: ( 51 _EMPTY_NN_MODULE_STACK_KEY, 52 _EMPTY_NN_MODULE_STACK_KEY, 53 ) 54 }, 55 ) 56 node.meta["torch_fn"] = ( 57 f"{node.target.__name__}_0", 58 f"{node.target.__class__.__name__}.{node.target.__name__}", 59 ) 60 61 62@contextlib.contextmanager 63def _set_node_metadata_hook(gm: torch.fx.GraphModule, f): 64 """ 65 Takes a callable which will be called after we create a new node. The 66 callable takes the newly created node as input and returns None. 67 """ 68 assert callable(f), "node_metadata_hook must be a callable." 69 70 # Add the hook to all submodules 71 for m in gm.modules(): 72 if isinstance(m, GraphModule): 73 m._register_create_node_hook(f) 74 try: 75 yield 76 finally: 77 # Restore hook for all submodules 78 for m in gm.modules(): 79 if isinstance(m, GraphModule): 80 m._unregister_create_node_hook(f) 81