# mypy: allow-untyped-defs

import hashlib
from itertools import chain
from typing import Any, Dict, Optional, TYPE_CHECKING

import torch
import torch.fx
from torch.fx._compatibility import compatibility
from torch.fx.graph import _parse_stack_trace
from torch.fx.node import _format_arg, _get_qualified_name
from torch.fx.operator_schemas import normalize_function
from torch.fx.passes.shape_prop import TensorMetadata


try:
    import pydot

    HAS_PYDOT = True
except ModuleNotFoundError:
    HAS_PYDOT = False
    pydot = None


__all__ = ["FxGraphDrawer"]

_COLOR_MAP = {
    "placeholder": '"AliceBlue"',
    "call_module": "LemonChiffon1",
    "get_param": "Yellow2",
    "get_attr": "LightGrey",
    "output": "PowderBlue",
}

_HASH_COLOR_MAP = [
    "CadetBlue1",
    "Coral",
    "DarkOliveGreen1",
    "DarkSeaGreen1",
    "GhostWhite",
    "Khaki1",
    "LavenderBlush1",
    "LightSkyBlue",
    "MistyRose1",
    "MistyRose2",
    "PaleTurquoise2",
    "PeachPuff1",
    "Salmon",
    "Thistle1",
    "Thistle3",
    "Wheat1",
]

_WEIGHT_TEMPLATE = {
    "fillcolor": "Salmon",
    "style": '"filled,rounded"',
    "fontcolor": "#000000",
}

if HAS_PYDOT:
    @compatibility(is_backward_compatible=False)
    class FxGraphDrawer:
        """
        Visualize a torch.fx.Graph with graphviz
        Basic usage:
            g = FxGraphDrawer(symbolic_traced, "resnet18")
            g.get_dot_graph().write_svg("a.svg")
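
        The returned pydot graph can be written to other formats as well,
        assuming the Graphviz "dot" binary is installed and on PATH
        (a sketch):
            g.get_dot_graph().write_png("a.png")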
        """

        def __init__(
            self,
            graph_module: torch.fx.GraphModule,
            name: str,
            ignore_getattr: bool = False,
            ignore_parameters_and_buffers: bool = False,
            skip_node_names_in_args: bool = True,
            parse_stack_trace: bool = False,
            dot_graph_shape: Optional[str] = None,
            normalize_args: bool = False,
        ):
            self._name = name
            self.dot_graph_shape = (
                dot_graph_shape if dot_graph_shape is not None else "record"
            )
            self.normalize_args = normalize_args
            _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape

            self._dot_graphs = {
                name: self._to_dot(
                    graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace
                )
            }

            for node in graph_module.graph.nodes:
                if node.op != "call_module":
                    continue

                leaf_node = self._get_leaf_node(graph_module, node)

                if not isinstance(leaf_node, torch.fx.GraphModule):
                    continue

                self._dot_graphs[f"{name}_{node.target}"] = self._to_dot(
                    leaf_node,
                    f"{name}_{node.target}",
                    ignore_getattr,
                    ignore_parameters_and_buffers,
                    skip_node_names_in_args,
                    parse_stack_trace,
                )

        def get_dot_graph(self, submod_name: Optional[str] = None) -> pydot.Dot:
            """
            Visualize a torch.fx.Graph with graphviz
            Example:
                >>> # xdoctest: +REQUIRES(module:pydot)
                >>> # xdoctest: +REQUIRES(module:ubelt)
                >>> # define module
                >>> class MyModule(torch.nn.Module):
                >>>     def __init__(self) -> None:
                >>>         super().__init__()
                >>>         self.linear = torch.nn.Linear(4, 5)
                >>>     def forward(self, x):
                >>>         return self.linear(x).clamp(min=0.0, max=1.0)
                >>> module = MyModule()
                >>> # trace the module
                >>> symbolic_traced = torch.fx.symbolic_trace(module)
                >>> # setup output file
                >>> import ubelt as ub
                >>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir()
                >>> fpath = dpath / 'linear.svg'
                >>> # draw the graph
                >>> g = FxGraphDrawer(symbolic_traced, "linear")
                >>> g.get_dot_graph().write_svg(fpath)
            """
            if submod_name is None:
                return self.get_main_dot_graph()
            else:
                return self.get_submod_dot_graph(submod_name)

        def get_main_dot_graph(self) -> pydot.Dot:
            return self._dot_graphs[self._name]

        def get_submod_dot_graph(self, submod_name: str) -> pydot.Dot:
            return self._dot_graphs[f"{self._name}_{submod_name}"]

        def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]:
            return self._dot_graphs
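
        # Usage sketch for the accessors above (names are illustrative):
        # write the main graph and every traced GraphModule submodule to
        # its own SVG file.
        #
        #     drawer = FxGraphDrawer(symbolic_traced, "model")
        #     for graph_name, dot in drawer.get_all_dot_graphs().items():
        #         dot.write_svg(f"{graph_name}.svg")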

        def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
            template = {
                "shape": self.dot_graph_shape,
                "fillcolor": "#CAFFE3",
                "style": '"filled,rounded"',
                "fontcolor": "#000000",
            }
            if node.op in _COLOR_MAP:
                template["fillcolor"] = _COLOR_MAP[node.op]
            else:
                # Use a random color for each node; based on its name so it's stable.
                target_name = node._pretty_print_target(node.target)
                target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16)
                template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)]
            return template

        def _get_leaf_node(
            self, module: torch.nn.Module, node: torch.fx.Node
        ) -> torch.nn.Module:
            py_obj = module
            assert isinstance(node.target, str)
            atoms = node.target.split(".")
            for atom in atoms:
                if not hasattr(py_obj, atom):
                    raise RuntimeError(
                        str(py_obj) + " does not have attribute " + atom + "!"
                    )
                py_obj = getattr(py_obj, atom)
            return py_obj

        def _typename(self, target: Any) -> str:
            if isinstance(target, torch.nn.Module):
                ret = torch.typename(target)
            elif isinstance(target, str):
                ret = target
            else:
                ret = _get_qualified_name(target)

            # Escape "{" and "}" to prevent dot files like:
            # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc
            # which triggers `Error: bad label format (...)` from dot
            return ret.replace("{", r"\{").replace("}", r"\}")

        # shorten path to avoid drawing long boxes
        # for full path = '/home/weif/pytorch/test.py'
        # return short path = 'pytorch/test.py'
        def _shorten_file_name(
            self,
            full_file_name: str,
            truncate_to_last_n: int = 2,
        ) -> str:
            splits = full_file_name.split('/')
            if len(splits) >= truncate_to_last_n:
                return '/'.join(splits[-truncate_to_last_n:])
            return full_file_name

        def _get_node_label(
            self,
            module: torch.fx.GraphModule,
            node: torch.fx.Node,
            skip_node_names_in_args: bool,
            parse_stack_trace: bool,
        ) -> str:
            def _get_str_for_args_kwargs(arg):
                if isinstance(arg, tuple):
                    prefix, suffix = r"|args=(\l", r",\n)\l"
                    arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg]
                elif isinstance(arg, dict):
                    prefix, suffix = r"|kwargs={\l", r",\n}\l"
                    arg_strs_list = [
                        f"{k}: {_format_arg(v, max_list_len=8)}"
                        for k, v in arg.items()
                    ]
                else:  # Fall back to nothing in unexpected case.
                    return ""

                # Strip out node names if requested.
                if skip_node_names_in_args:
                    arg_strs_list = [a for a in arg_strs_list if "%" not in a]
                if len(arg_strs_list) == 0:
                    return ""
                arg_strs = prefix + r",\n".join(arg_strs_list) + suffix
                if len(arg_strs_list) == 1:
                    arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "")
                return arg_strs.replace("{", r"\{").replace("}", r"\}")

            label = "{" + f"name=%{node.name}|op_code={node.op}\n"

            if node.op == "call_module":
                leaf_module = self._get_leaf_node(module, node)
                label += r"\n" + self._typename(leaf_module) + r"\n|"
                extra = ""
                if hasattr(leaf_module, "__constants__"):
                    extra = r"\n".join(
                        [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__]  # type: ignore[union-attr]
                    )
                label += extra + r"\n"
            else:
                label += f"|target={self._typename(node.target)}" + r"\n"
                if self.normalize_args:
                    try:
                        args, kwargs = normalize_function(  # type: ignore[misc]
                            node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True  # type: ignore[arg-type]
                        )
                    except Exception:
                        # Fallback to not normalizing if there's an exception.
                        # Some functions need overloads specified to normalize.
                        args, kwargs = node.args, node.kwargs
                else:
                    args, kwargs = node.args, node.kwargs
                if len(args) > 0:
                    label += _get_str_for_args_kwargs(args)
                if len(kwargs) > 0:
                    label += _get_str_for_args_kwargs(kwargs)
                label += f"|num_users={len(node.users)}" + r"\n"

            tensor_meta = node.meta.get('tensor_meta')
            label += self._tensor_meta_to_label(tensor_meta)

            # for original fx graph
            # print buf=buf0, n_origin=6
            buf_meta = node.meta.get('buf_meta', None)
            if buf_meta is not None:
                label += f"|buf={buf_meta.name}" + r"\n"
                label += f"|n_origin={buf_meta.n_origin}" + r"\n"

            # for original fx graph
            # print file:lineno code
            if parse_stack_trace and node.stack_trace is not None:
                parsed_stack_trace = _parse_stack_trace(node.stack_trace)
                fname = self._shorten_file_name(parsed_stack_trace.file)
                label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n"

            return label + "}"

        def _tensor_meta_to_label(self, tm) -> str:
            if tm is None:
                return ""
            elif isinstance(tm, TensorMetadata):
                return self._stringify_tensor_meta(tm)
            elif isinstance(tm, (list, tuple)):
                result = ""
                for item in tm:
                    result += self._tensor_meta_to_label(item)
                return result
            elif isinstance(tm, dict):
                result = ""
                for v in tm.values():
                    result += self._tensor_meta_to_label(v)
                return result
            else:
                raise RuntimeError(f"Unsupported tensor meta type {type(tm)}")

        def _stringify_tensor_meta(self, tm: TensorMetadata) -> str:
            result = ""
            if not hasattr(tm, "dtype"):
                print("tm", tm)
            result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n"
            result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n"
            result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n"
            result += "|" + "stride" + "=" + str(tm.stride) + r"\n"
            if tm.is_quantized:
                assert tm.qparams is not None
                assert "qscheme" in tm.qparams
                qscheme = tm.qparams["qscheme"]
                if qscheme in {
                        torch.per_tensor_affine,
                        torch.per_tensor_symmetric,
                }:
                    result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
                    result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
                elif qscheme in {
                        torch.per_channel_affine,
                        torch.per_channel_symmetric,
                        torch.per_channel_affine_float_qparams,
                }:
                    result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
                    result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
                    result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n"
                else:
                    raise RuntimeError(f"Unsupported qscheme: {qscheme}")
                result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n"
            return result

        def _get_tensor_label(self, t: torch.Tensor) -> str:
            return str(t.dtype) + str(list(t.shape)) + r"\n"

        # when parse_stack_trace=True
        # print file:lineno code
        def _to_dot(
            self,
            graph_module: torch.fx.GraphModule,
            name: str,
            ignore_getattr: bool,
            ignore_parameters_and_buffers: bool,
            skip_node_names_in_args: bool,
            parse_stack_trace: bool,
        ) -> pydot.Dot:
            """
            Actual interface to visualize an fx.Graph. Note that it takes in the GraphModule instead of the Graph.
            If ignore_parameters_and_buffers is True, the parameters and buffers
            created with the module will not be added as nodes and edges.
            """

            # "TB" means top-to-bottom rank direction in layout
            dot_graph = pydot.Dot(name, rankdir="TB")

            buf_name_to_subgraph = {}

            for node in graph_module.graph.nodes:
                if ignore_getattr and node.op == "get_attr":
                    continue

                style = self._get_node_style(node)
                dot_node = pydot.Node(
                    node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style
                )

                current_graph = dot_graph

                buf_meta = node.meta.get('buf_meta', None)
                if buf_meta is not None and buf_meta.n_origin > 1:
                    buf_name = buf_meta.name
                    if buf_name not in buf_name_to_subgraph:
                        buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name)
                    current_graph = buf_name_to_subgraph.get(buf_name)

                current_graph.add_node(dot_node)

                def get_module_params_or_buffers():
                    for pname, ptensor in chain(
                        leaf_module.named_parameters(), leaf_module.named_buffers()
                    ):
                        pname1 = node.name + "." + pname
                        # Label parameters and buffers alike with their
                        # qualified name and op code.
                        label1 = (
                            pname1
                            + "|op_code=get_"
                            + ("parameter" if isinstance(ptensor, torch.nn.Parameter) else "buffer")
                            + r"\l"
                        )
                        dot_w_node = pydot.Node(
                            pname1,
                            label="{" + label1 + self._get_tensor_label(ptensor) + "}",
                            **_WEIGHT_TEMPLATE,
                        )
                        dot_graph.add_node(dot_w_node)
                        dot_graph.add_edge(pydot.Edge(pname1, node.name))

                if node.op == "call_module":
                    leaf_module = self._get_leaf_node(graph_module, node)

                    if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule):
                        get_module_params_or_buffers()

            for subgraph in buf_name_to_subgraph.values():
                subgraph.set('color', 'royalblue')
                subgraph.set('penwidth', '2')
                dot_graph.add_subgraph(subgraph)

            for node in graph_module.graph.nodes:
                if ignore_getattr and node.op == "get_attr":
                    continue

                for user in node.users:
                    dot_graph.add_edge(pydot.Edge(node.name, user.name))

            return dot_graph

else:
    if not TYPE_CHECKING:
        @compatibility(is_backward_compatible=False)
        class FxGraphDrawer:
            def __init__(
                self,
                graph_module: torch.fx.GraphModule,
                name: str,
                ignore_getattr: bool = False,
                ignore_parameters_and_buffers: bool = False,
                skip_node_names_in_args: bool = True,
                parse_stack_trace: bool = False,
                dot_graph_shape: Optional[str] = None,
                normalize_args: bool = False,
            ):
                raise RuntimeError('FxGraphDrawer requires the pydot package to be installed. Please install '
                                   'pydot through your favorite Python package manager.')
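
# A minimal smoke test of this module (a sketch; the demo module and names
# below are illustrative, and running it assumes pydot is installed):
#
#     if __name__ == "__main__":
#         class _Demo(torch.nn.Module):
#             def forward(self, x):
#                 return torch.relu(x) + 1
#
#         gm = torch.fx.symbolic_trace(_Demo())
#         drawer = FxGraphDrawer(gm, "demo")
#         # Raw DOT source; printing it needs no Graphviz binary.
#         print(drawer.get_dot_graph().to_string())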
444