xref: /aosp_15_r20/external/pytorch/test/fx/test_fx_node_hook.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: fx"]
2import torch
3from torch.fx import symbolic_trace
4from torch.testing._internal.common_utils import TestCase
5
6
7class TestFXNodeHook(TestCase):
8    def test_hooks_for_node_update(self):
9        global create_node_hook1_called
10        global create_node_hook2_called
11        global erase_node_hook1_called
12        global erase_node_hook2_called
13        create_node_hook1_called = False
14        create_node_hook2_called = False
15        erase_node_hook1_called = False
16        erase_node_hook2_called = False
17
18        def fn(a, b, c):
19            x = torch.nn.functional.linear(a, b)
20            x = x + c
21            return x.cos()
22
23        def create_node_hook1(node):
24            global create_node_hook1_called
25            create_node_hook1_called = True
26
27        def create_node_hook2(node):
28            global create_node_hook2_called
29            create_node_hook2_called = True
30
31        def erase_node_hook1(node):
32            global erase_node_hook1_called
33            erase_node_hook1_called = True
34
35        def erase_node_hook2(node):
36            global erase_node_hook2_called
37            erase_node_hook2_called = True
38
39        gm = symbolic_trace(fn)
40        gm._register_create_node_hook(create_node_hook1)
41        gm._register_create_node_hook(create_node_hook2)
42        gm._register_erase_node_hook(erase_node_hook1)
43        gm._register_erase_node_hook(erase_node_hook2)
44
45        graph = gm.graph
46        node_a = None
47        for node in graph.find_nodes(op="placeholder"):
48            node_a = node
49            break
50        assert node_a is not None
51        # This will create a new node
52        node_a_copy = graph.node_copy(node_a)
53        node_a.replace_all_uses_with(node_a_copy)
54        graph.erase_node(node_a)
55
56        assert (
57            create_node_hook1_called
58            and create_node_hook2_called
59            and erase_node_hook1_called
60            and erase_node_hook2_called
61        )
62
63        gm._unregister_create_node_hook(create_node_hook1)
64        gm._unregister_create_node_hook(create_node_hook2)
65        gm._unregister_erase_node_hook(erase_node_hook1)
66        gm._unregister_erase_node_hook(erase_node_hook2)
67
68        assert gm._create_node_hooks == []
69        assert gm._erase_node_hooks == []
70