1# Owner(s): ["module: fx"] 2import torch 3from torch.fx import symbolic_trace 4from torch.testing._internal.common_utils import TestCase 5 6 7class TestFXNodeHook(TestCase): 8 def test_hooks_for_node_update(self): 9 global create_node_hook1_called 10 global create_node_hook2_called 11 global erase_node_hook1_called 12 global erase_node_hook2_called 13 create_node_hook1_called = False 14 create_node_hook2_called = False 15 erase_node_hook1_called = False 16 erase_node_hook2_called = False 17 18 def fn(a, b, c): 19 x = torch.nn.functional.linear(a, b) 20 x = x + c 21 return x.cos() 22 23 def create_node_hook1(node): 24 global create_node_hook1_called 25 create_node_hook1_called = True 26 27 def create_node_hook2(node): 28 global create_node_hook2_called 29 create_node_hook2_called = True 30 31 def erase_node_hook1(node): 32 global erase_node_hook1_called 33 erase_node_hook1_called = True 34 35 def erase_node_hook2(node): 36 global erase_node_hook2_called 37 erase_node_hook2_called = True 38 39 gm = symbolic_trace(fn) 40 gm._register_create_node_hook(create_node_hook1) 41 gm._register_create_node_hook(create_node_hook2) 42 gm._register_erase_node_hook(erase_node_hook1) 43 gm._register_erase_node_hook(erase_node_hook2) 44 45 graph = gm.graph 46 node_a = None 47 for node in graph.find_nodes(op="placeholder"): 48 node_a = node 49 break 50 assert node_a is not None 51 # This will create a new node 52 node_a_copy = graph.node_copy(node_a) 53 node_a.replace_all_uses_with(node_a_copy) 54 graph.erase_node(node_a) 55 56 assert ( 57 create_node_hook1_called 58 and create_node_hook2_called 59 and erase_node_hook1_called 60 and erase_node_hook2_called 61 ) 62 63 gm._unregister_create_node_hook(create_node_hook1) 64 gm._unregister_create_node_hook(create_node_hook2) 65 gm._unregister_erase_node_hook(erase_node_hook1) 66 gm._unregister_erase_node_hook(erase_node_hook2) 67 68 assert gm._create_node_hooks == [] 69 assert gm._erase_node_hooks == [] 70