xref: /aosp_15_r20/external/pytorch/torch/fx/passes/graph_transform_observer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import os
3from typing import Optional
4
5from torch.fx._compatibility import compatibility
6from torch.fx.graph_module import GraphModule
7
8from .graph_drawer import FxGraphDrawer
9
10
11__all__ = ["GraphTransformObserver"]
12
13
@compatibility(is_backward_compatible=False)
class GraphTransformObserver:
    """Context manager that records how a pass changes an FX ``GraphModule``.

    When ``log_url`` is given, the module's graph is rendered with
    ``FxGraphDrawer`` at construction time. While the context is active,
    create/erase node hooks registered on ``gm`` record the names of nodes
    the pass creates or erases. On exit, if at least one node was created or
    erased, two ``.dot`` files are written into the ``log_url`` directory:

    * ``pass_<N>_<passname>_input_graph.dot``: the graph as it was before
      the pass, erased nodes filled yellow and all others grey.
    * ``pass_<N>_<passname>_output_graph.dot``: the graph after the pass,
      created nodes filled yellow and all others grey.

    ``N`` is taken from a class-wide counter that is incremented once per
    observer constructed with a non-None ``log_url``. Each observer keeps
    the value it obtained at construction, so nested observers do not
    mislabel each other's files. With ``log_url=None`` the observer does
    nothing.
    """

    # Class-wide counter numbering the observers that actually log.
    __pass_count = 0

    def __init__(self, gm: GraphModule, passname: str, log_url: Optional[str] = None):
        """Create an observer for ``gm``.

        Args:
            gm: the graph module the pass will transform.
            passname: pass name, used in the output file names.
            log_url: directory to write ``.dot`` files into; ``None``
                disables all logging.
        """
        self.gm = gm
        self.passname = passname
        # If log_url is None, we don't log anything
        self.log_url = log_url
        if self.log_url is None:
            return

        GraphTransformObserver.__pass_count += 1
        # Remember *this* observer's pass number. The class counter keeps
        # advancing when other observers are created before this one exits
        # (e.g. nested passes), and the file names must identify this pass.
        self._pass_id = GraphTransformObserver.__pass_count

        # Snapshot of the graph before the pass runs.
        self.input_dot_graph = self._draw_graph()

    @classmethod
    def get_current_pass_count(cls):
        """Return how many logging-enabled observers have been created."""
        return cls.__pass_count

    def __enter__(self):
        """Start recording node creations/erasures on ``gm`` when logging."""
        if self.log_url is None or self.gm is None:
            return self

        # Names of nodes erased / created while the context is active.
        self.erased_nodes: set[str] = set()
        self.created_nodes: set[str] = set()
        self.gm._register_create_node_hook(self.on_node_creation)
        self.gm._register_erase_node_hook(self.on_node_erase)

        return self

    def __exit__(self, type, value, tb):
        """Remove the hooks and, if the graph changed, write both ``.dot`` files.

        The files are written even if the ``with`` body raised. Returns
        ``None``, so exceptions from the body are never suppressed.
        """
        if self.log_url is None or self.gm is None:
            return

        self.gm._unregister_create_node_hook(self.on_node_creation)
        self.gm._unregister_erase_node_hook(self.on_node_erase)

        # Nothing changed: don't clutter the log directory.
        if not self.created_nodes and not self.erased_nodes:
            return

        self._write_highlighted(self.input_dot_graph, self.erased_nodes, "input")
        self._write_highlighted(self._draw_graph(), self.created_nodes, "output")

    def _draw_graph(self):
        # Render the current state of self.gm as a dot graph (getattr nodes
        # and parameters/buffers ignored via the drawer's flags).
        return FxGraphDrawer(
            self.gm,
            self.passname,
            ignore_getattr=True,
            ignore_parameters_and_buffers=True,
        ).get_dot_graph()

    def _write_highlighted(self, dot_graph, highlighted, suffix: str):
        # Fill nodes whose names are in `highlighted` yellow and all others
        # grey, then write <log_url>/pass_<id>_<passname>_<suffix>_graph.dot.
        for node in dot_graph.get_node_list():
            color = "yellow" if node.get_name() in highlighted else "grey"
            node.obj_dict["attributes"]["fillcolor"] = color
        dot_graph.write(
            os.path.join(
                self.log_url,
                f"pass_{self._pass_id}_{self.passname}_{suffix}_graph.dot",
            )
        )

    def on_node_creation(self, node):
        """Hook: record the name of a node created while the context is active."""
        self.created_nodes.add(node.name)

    def on_node_erase(self, node):
        """Hook: record the name of a node erased while the context is active."""
        self.erased_nodes.add(node.name)
92