1# mypy: allow-untyped-defs 2import sys 3from typing import Dict, Optional 4 5import torch 6from torch._logging import LazyString 7 8 9def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): 10 """ 11 Returns a LazyString that formats the graph code. 12 """ 13 14 def format_name(): 15 if maybe_id is not None: 16 return f"{name} {maybe_id}" 17 else: 18 return name 19 20 if "print_output" not in kwargs: 21 kwargs["print_output"] = False 22 23 if "colored" in kwargs and not sys.stdout.isatty(): 24 kwargs["colored"] = False 25 26 return LazyString( 27 lambda: _format_graph_code( 28 f"===== {format_name()} =====\n", 29 gm.forward.__code__.co_filename, 30 gm.print_readable(**kwargs), 31 ) 32 ) 33 34 35def _format_graph_code(name, filename, graph_str): 36 """ 37 Returns a string that formats the graph code. 38 """ 39 return f"TRACED GRAPH\n {name} {filename} {graph_str}\n" 40 41 42def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]: 43 """ 44 Returns the nn_module_stack of the first call_function node. 45 """ 46 for node in graph.nodes: 47 if node.op == "call_function" and "nn_module_stack" in node.meta: 48 return node.meta["nn_module_stack"] 49 return None 50 51 52def get_node_context(node, num_nodes=2) -> str: 53 """ 54 Returns a string of the last num_nodes nodes in the graph. 55 """ 56 node_contexts = [] 57 cur = node 58 for i in range(num_nodes): 59 node_contexts.append(cur.format_node()) 60 if cur.op == "root": 61 break 62 cur = cur.prev 63 return "\n".join(node_contexts[::-1]) 64