# mypy: allow-untyped-defs
import os
from typing import Optional

from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule

from .graph_drawer import FxGraphDrawer


__all__ = ["GraphTransformObserver"]


@compatibility(is_backward_compatible=False)
class GraphTransformObserver:
    """Context manager that dumps before/after DOT graphs around a graph transform.

    When ``log_url`` is given, the constructor renders ``gm`` as a DOT graph.
    While the ``with`` block runs, node-creation and node-erase hooks registered
    on ``gm`` record the names of nodes that are added and removed. On exit, if
    at least one node was added or removed, two files are written into
    ``log_url``:

    * ``pass_<N>_<passname>_input_graph.dot``: the graph rendered at
      construction, with erased nodes filled yellow and all others grey.
    * ``pass_<N>_<passname>_output_graph.dot``: the graph rendered at exit,
      with created nodes filled yellow and all others grey.

    ``<N>`` comes from a class-level counter that is incremented once for every
    observer constructed with a non-None ``log_url``. When ``log_url`` is None
    the observer is a no-op.
    """

    __pass_count = 0

    def __init__(self, gm: GraphModule, passname: str, log_url: Optional[str] = None):
        """
        Args:
            gm: graph module whose node creations/erasures are observed.
            passname: name of the pass; used as the DOT graph name and in the
                output file names.
            log_url: directory the DOT files are written to. If None, nothing
                is recorded or written.
        """
        self.gm = gm
        self.passname = passname
        # If log_url is None, we don't log anything
        self.log_url = log_url
        if self.log_url is None:
            return
        GraphTransformObserver.__pass_count += 1
        # Pin this observer's own pass number. The class counter keeps advancing
        # if other observers are created (e.g. nested passes) before this one
        # exits, and this pass's files must still carry its own number.
        self._pass_id = GraphTransformObserver.__pass_count

        # Snapshot of the graph before the transform runs.
        self.input_dot_graph = self._draw_dot_graph()

    @classmethod
    def get_current_pass_count(cls):
        """Return the number of logging observers constructed so far."""
        return cls.__pass_count

    def __enter__(self):
        """Start recording created/erased node names on ``gm`` (if logging)."""
        if self.log_url is None or self.gm is None:
            return self

        self.erased_nodes = set()
        self.created_nodes = set()
        self.gm._register_create_node_hook(self.on_node_creation)
        self.gm._register_erase_node_hook(self.on_node_erase)

        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Stop recording and, if the graph changed, write input/output DOT files.

        Always returns None, so exceptions raised inside the ``with`` block
        propagate.
        """
        if self.log_url is None or self.gm is None:
            return

        # Stop recording before any graph is rendered or written.
        self.gm._unregister_create_node_hook(self.on_node_creation)
        self.gm._unregister_erase_node_hook(self.on_node_erase)

        if not (self.created_nodes or self.erased_nodes):
            # The pass did not add or remove nodes: nothing worth dumping.
            return

        self._highlight_nodes(self.input_dot_graph, self.erased_nodes)
        self.input_dot_graph.write(self._dot_file_path("input"))

        output_dot_graph = self._draw_dot_graph()
        self._highlight_nodes(output_dot_graph, self.created_nodes)
        output_dot_graph.write(self._dot_file_path("output"))

    def on_node_creation(self, node):
        """Hook: record the name of a node created in ``gm``."""
        self.created_nodes.add(node.name)

    def on_node_erase(self, node):
        """Hook: record the name of a node erased from ``gm``."""
        self.erased_nodes.add(node.name)

    def _draw_dot_graph(self):
        """Render the current ``gm`` via FxGraphDrawer with the getattr and
        parameter/buffer ignore flags enabled, and return its DOT graph."""
        return FxGraphDrawer(
            self.gm,
            self.passname,
            ignore_getattr=True,
            ignore_parameters_and_buffers=True,
        ).get_dot_graph()

    def _dot_file_path(self, stage):
        """Return the path of this pass's ``stage`` ("input" or "output") DOT file."""
        return os.path.join(
            self.log_url,
            f"pass_{self._pass_id}_{self.passname}_{stage}_graph.dot",
        )

    @staticmethod
    def _highlight_nodes(dot_graph, highlighted_names):
        """Fill DOT nodes whose name is in ``highlighted_names`` yellow, others grey."""
        for dot_node in dot_graph.get_node_list():
            color = "yellow" if dot_node.get_name() in highlighted_names else "grey"
            dot_node.obj_dict["attributes"]["fillcolor"] = color