1# mypy: allow-untyped-defs 2 3import hashlib 4from itertools import chain 5from typing import Any, Dict, Optional, TYPE_CHECKING 6 7import torch 8import torch.fx 9from torch.fx._compatibility import compatibility 10from torch.fx.graph import _parse_stack_trace 11from torch.fx.node import _format_arg, _get_qualified_name 12from torch.fx.operator_schemas import normalize_function 13from torch.fx.passes.shape_prop import TensorMetadata 14 15 16try: 17 import pydot 18 19 HAS_PYDOT = True 20except ModuleNotFoundError: 21 HAS_PYDOT = False 22 pydot = None 23 24 25__all__ = ["FxGraphDrawer"] 26 27_COLOR_MAP = { 28 "placeholder": '"AliceBlue"', 29 "call_module": "LemonChiffon1", 30 "get_param": "Yellow2", 31 "get_attr": "LightGrey", 32 "output": "PowderBlue", 33} 34 35_HASH_COLOR_MAP = [ 36 "CadetBlue1", 37 "Coral", 38 "DarkOliveGreen1", 39 "DarkSeaGreen1", 40 "GhostWhite", 41 "Khaki1", 42 "LavenderBlush1", 43 "LightSkyBlue", 44 "MistyRose1", 45 "MistyRose2", 46 "PaleTurquoise2", 47 "PeachPuff1", 48 "Salmon", 49 "Thistle1", 50 "Thistle3", 51 "Wheat1", 52] 53 54_WEIGHT_TEMPLATE = { 55 "fillcolor": "Salmon", 56 "style": '"filled,rounded"', 57 "fontcolor": "#000000", 58} 59 60if HAS_PYDOT: 61 @compatibility(is_backward_compatible=False) 62 class FxGraphDrawer: 63 """ 64 Visualize a torch.fx.Graph with graphviz 65 Basic usage: 66 g = FxGraphDrawer(symbolic_traced, "resnet18") 67 g.get_dot_graph().write_svg("a.svg") 68 """ 69 70 def __init__( 71 self, 72 graph_module: torch.fx.GraphModule, 73 name: str, 74 ignore_getattr: bool = False, 75 ignore_parameters_and_buffers: bool = False, 76 skip_node_names_in_args: bool = True, 77 parse_stack_trace: bool = False, 78 dot_graph_shape: Optional[str] = None, 79 normalize_args: bool = False, 80 ): 81 self._name = name 82 self.dot_graph_shape = ( 83 dot_graph_shape if dot_graph_shape is not None else "record" 84 ) 85 self.normalize_args = normalize_args 86 _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape 87 88 self._dot_graphs = { 89 name: self._to_dot( 90 graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace 91 ) 92 } 93 94 for node in graph_module.graph.nodes: 95 if node.op != "call_module": 96 continue 97 98 leaf_node = self._get_leaf_node(graph_module, node) 99 100 if not isinstance(leaf_node, torch.fx.GraphModule): 101 continue 102 103 self._dot_graphs[f"{name}_{node.target}"] = self._to_dot( 104 leaf_node, 105 f"{name}_{node.target}", 106 ignore_getattr, 107 ignore_parameters_and_buffers, 108 skip_node_names_in_args, 109 parse_stack_trace, 110 ) 111 112 def get_dot_graph(self, submod_name=None) -> pydot.Dot: 113 """ 114 Visualize a torch.fx.Graph with graphviz 115 Example: 116 >>> # xdoctest: +REQUIRES(module:pydot) 117 >>> # xdoctest: +REQUIRES(module:ubelt) 118 >>> # define module 119 >>> class MyModule(torch.nn.Module): 120 >>> def __init__(self) -> None: 121 >>> super().__init__() 122 >>> self.linear = torch.nn.Linear(4, 5) 123 >>> def forward(self, x): 124 >>> return self.linear(x).clamp(min=0.0, max=1.0) 125 >>> module = MyModule() 126 >>> # trace the module 127 >>> symbolic_traced = torch.fx.symbolic_trace(module) 128 >>> # setup output file 129 >>> import ubelt as ub 130 >>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir() 131 >>> fpath = dpath / 'linear.svg' 132 >>> # draw the graph 133 >>> g = FxGraphDrawer(symbolic_traced, "linear") 134 >>> g.get_dot_graph().write_svg(fpath) 135 """ 136 if submod_name is None: 137 return self.get_main_dot_graph() 138 else: 139 return self.get_submod_dot_graph(submod_name) 140 141 def get_main_dot_graph(self) -> pydot.Dot: 142 return self._dot_graphs[self._name] 143 144 def get_submod_dot_graph(self, submod_name) -> pydot.Dot: 145 return self._dot_graphs[f"{self._name}_{submod_name}"] 146 147 def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]: 148 return self._dot_graphs 149 150 def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]: 151 152 template = { 153 "shape": self.dot_graph_shape, 154 "fillcolor": "#CAFFE3", 155 "style": '"filled,rounded"', 156 "fontcolor": "#000000", 157 } 158 if node.op in _COLOR_MAP: 159 template["fillcolor"] = _COLOR_MAP[node.op] 160 else: 161 # Use a random color for each node; based on its name so it's stable. 162 target_name = node._pretty_print_target(node.target) 163 target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16) 164 template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)] 165 return template 166 167 def _get_leaf_node( 168 self, module: torch.nn.Module, node: torch.fx.Node 169 ) -> torch.nn.Module: 170 py_obj = module 171 assert isinstance(node.target, str) 172 atoms = node.target.split(".") 173 for atom in atoms: 174 if not hasattr(py_obj, atom): 175 raise RuntimeError( 176 str(py_obj) + " does not have attribute " + atom + "!" 177 ) 178 py_obj = getattr(py_obj, atom) 179 return py_obj 180 181 def _typename(self, target: Any) -> str: 182 if isinstance(target, torch.nn.Module): 183 ret = torch.typename(target) 184 elif isinstance(target, str): 185 ret = target 186 else: 187 ret = _get_qualified_name(target) 188 189 # Escape "{" and "}" to prevent dot files like: 190 # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc 191 # which triggers `Error: bad label format (...)` from dot 192 return ret.replace("{", r"\{").replace("}", r"\}") 193 194 # shorten path to avoid drawing long boxes 195 # for full path = '/home/weif/pytorch/test.py' 196 # return short path = 'pytorch/test.py' 197 def _shorten_file_name( 198 self, 199 full_file_name: str, 200 truncate_to_last_n: int = 2, 201 ): 202 splits = full_file_name.split('/') 203 if len(splits) >= truncate_to_last_n: 204 return '/'.join(splits[-truncate_to_last_n:]) 205 return full_file_name 206 207 208 def _get_node_label( 209 self, 210 module: torch.fx.GraphModule, 211 node: torch.fx.Node, 212 skip_node_names_in_args: bool, 213 parse_stack_trace: bool, 214 ) -> str: 215 def _get_str_for_args_kwargs(arg): 216 if isinstance(arg, tuple): 217 prefix, suffix = r"|args=(\l", r",\n)\l" 218 arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg] 219 elif isinstance(arg, dict): 220 prefix, suffix = r"|kwargs={\l", r",\n}\l" 221 arg_strs_list = [ 222 f"{k}: {_format_arg(v, max_list_len=8)}" 223 for k, v in arg.items() 224 ] 225 else: # Fall back to nothing in unexpected case. 226 return "" 227 228 # Strip out node names if requested. 229 if skip_node_names_in_args: 230 arg_strs_list = [a for a in arg_strs_list if "%" not in a] 231 if len(arg_strs_list) == 0: 232 return "" 233 arg_strs = prefix + r",\n".join(arg_strs_list) + suffix 234 if len(arg_strs_list) == 1: 235 arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "") 236 return arg_strs.replace("{", r"\{").replace("}", r"\}") 237 238 239 label = "{" + f"name=%{node.name}|op_code={node.op}\n" 240 241 if node.op == "call_module": 242 leaf_module = self._get_leaf_node(module, node) 243 label += r"\n" + self._typename(leaf_module) + r"\n|" 244 extra = "" 245 if hasattr(leaf_module, "__constants__"): 246 extra = r"\n".join( 247 [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__] # type: ignore[union-attr] 248 ) 249 label += extra + r"\n" 250 else: 251 label += f"|target={self._typename(node.target)}" + r"\n" 252 if self.normalize_args: 253 try: 254 args, kwargs = normalize_function( # type: ignore[misc] 255 node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type] 256 ) 257 except Exception: 258 # Fallback to not normalizing if there's an exception. 259 # Some functions need overloads specified to normalize. 260 args, kwargs = node.args, node.kwargs 261 else: 262 args, kwargs = node.args, node.kwargs 263 if len(args) > 0: 264 label += _get_str_for_args_kwargs(args) 265 if len(kwargs) > 0: 266 label += _get_str_for_args_kwargs(kwargs) 267 label += f"|num_users={len(node.users)}" + r"\n" 268 269 tensor_meta = node.meta.get('tensor_meta') 270 label += self._tensor_meta_to_label(tensor_meta) 271 272 # for original fx graph 273 # print buf=buf0, n_origin=6 274 buf_meta = node.meta.get('buf_meta', None) 275 if buf_meta is not None: 276 label += f"|buf={buf_meta.name}" + r"\n" 277 label += f"|n_origin={buf_meta.n_origin}" + r"\n" 278 279 # for original fx graph 280 # print file:lineno code 281 if parse_stack_trace and node.stack_trace is not None: 282 parsed_stack_trace = _parse_stack_trace(node.stack_trace) 283 fname = self._shorten_file_name(parsed_stack_trace.file) 284 label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n" 285 286 287 return label + "}" 288 289 def _tensor_meta_to_label(self, tm) -> str: 290 if tm is None: 291 return "" 292 elif isinstance(tm, TensorMetadata): 293 return self._stringify_tensor_meta(tm) 294 elif isinstance(tm, list): 295 result = "" 296 for item in tm: 297 result += self._tensor_meta_to_label(item) 298 return result 299 elif isinstance(tm, dict): 300 result = "" 301 for v in tm.values(): 302 result += self._tensor_meta_to_label(v) 303 return result 304 elif isinstance(tm, tuple): 305 result = "" 306 for item in tm: 307 result += self._tensor_meta_to_label(item) 308 return result 309 else: 310 raise RuntimeError(f"Unsupported tensor meta type {type(tm)}") 311 312 def _stringify_tensor_meta(self, tm: TensorMetadata) -> str: 313 result = "" 314 if not hasattr(tm, "dtype"): 315 print("tm", tm) 316 result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n" 317 result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n" 318 result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n" 319 result += "|" + "stride" + "=" + str(tm.stride) + r"\n" 320 if tm.is_quantized: 321 assert tm.qparams is not None 322 assert "qscheme" in tm.qparams 323 qscheme = tm.qparams["qscheme"] 324 if qscheme in { 325 torch.per_tensor_affine, 326 torch.per_tensor_symmetric, 327 }: 328 result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n" 329 result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" 330 elif qscheme in { 331 torch.per_channel_affine, 332 torch.per_channel_symmetric, 333 torch.per_channel_affine_float_qparams, 334 }: 335 result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n" 336 result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" 337 result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n" 338 else: 339 raise RuntimeError(f"Unsupported qscheme: {qscheme}") 340 result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n" 341 return result 342 343 def _get_tensor_label(self, t: torch.Tensor) -> str: 344 return str(t.dtype) + str(list(t.shape)) + r"\n" 345 346 # when parse_stack_trace=True 347 # print file:lineno code 348 def _to_dot( 349 self, 350 graph_module: torch.fx.GraphModule, 351 name: str, 352 ignore_getattr: bool, 353 ignore_parameters_and_buffers: bool, 354 skip_node_names_in_args: bool, 355 parse_stack_trace: bool, 356 ) -> pydot.Dot: 357 """ 358 Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph. 359 If ignore_parameters_and_buffers is True, the parameters and buffers 360 created with the module will not be added as nodes and edges. 361 """ 362 363 # "TB" means top-to-bottom rank direction in layout 364 dot_graph = pydot.Dot(name, rankdir="TB") 365 366 367 buf_name_to_subgraph = {} 368 369 for node in graph_module.graph.nodes: 370 if ignore_getattr and node.op == "get_attr": 371 continue 372 373 style = self._get_node_style(node) 374 dot_node = pydot.Node( 375 node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style 376 ) 377 378 current_graph = dot_graph 379 380 buf_meta = node.meta.get('buf_meta', None) 381 if buf_meta is not None and buf_meta.n_origin > 1: 382 buf_name = buf_meta.name 383 if buf_name not in buf_name_to_subgraph: 384 buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name) 385 current_graph = buf_name_to_subgraph.get(buf_name) 386 387 current_graph.add_node(dot_node) 388 389 def get_module_params_or_buffers(): 390 for pname, ptensor in chain( 391 leaf_module.named_parameters(), leaf_module.named_buffers() 392 ): 393 pname1 = node.name + "." + pname 394 label1 = ( 395 pname1 + "|op_code=get_" + "parameter" 396 if isinstance(ptensor, torch.nn.Parameter) 397 else "buffer" + r"\l" 398 ) 399 dot_w_node = pydot.Node( 400 pname1, 401 label="{" + label1 + self._get_tensor_label(ptensor) + "}", 402 **_WEIGHT_TEMPLATE, 403 ) 404 dot_graph.add_node(dot_w_node) 405 dot_graph.add_edge(pydot.Edge(pname1, node.name)) 406 407 if node.op == "call_module": 408 leaf_module = self._get_leaf_node(graph_module, node) 409 410 if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule): 411 get_module_params_or_buffers() 412 413 for subgraph in buf_name_to_subgraph.values(): 414 subgraph.set('color', 'royalblue') 415 subgraph.set('penwidth', '2') 416 dot_graph.add_subgraph(subgraph) 417 418 for node in graph_module.graph.nodes: 419 if ignore_getattr and node.op == "get_attr": 420 continue 421 422 for user in node.users: 423 dot_graph.add_edge(pydot.Edge(node.name, user.name)) 424 425 return dot_graph 426 427else: 428 if not TYPE_CHECKING: 429 @compatibility(is_backward_compatible=False) 430 class FxGraphDrawer: 431 def __init__( 432 self, 433 graph_module: torch.fx.GraphModule, 434 name: str, 435 ignore_getattr: bool = False, 436 ignore_parameters_and_buffers: bool = False, 437 skip_node_names_in_args: bool = True, 438 parse_stack_trace: bool = False, 439 dot_graph_shape: Optional[str] = None, 440 normalize_args: bool = False, 441 ): 442 raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install ' 443 'pydot through your favorite Python package manager.') 444