from dataclasses import dataclass, field
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx._compatibility import compatibility
from typing import Dict, List, Any, Type, Optional, Callable
import logging
import os


__all__ = ['get_source_partitions', 'check_subgraphs_connected', 'SourcePartition']


# Set `PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
def _init_logger() -> logging.Logger:
    """Create the module logger with its own console handler.

    The level is read from the ``PYTORCH_MATCHER_LOGLEVEL`` environment
    variable (default ``WARNING``) and applied to both the logger and the
    handler. Propagation is disabled so records are only emitted through the
    handler attached here.
    """
    logger = logging.getLogger(__name__)

    level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
    logger.setLevel(level)
    console = logging.StreamHandler()
    formatter = logging.Formatter("%(filename)s > %(message)s")
    console.setFormatter(formatter)
    console.setLevel(level)
    # add the handlers to the logger
    logger.addHandler(console)
    logger.propagate = False
    return logger


logger = _init_logger()


# A group of graph nodes that share one source, plus the nodes on the
# boundary of that group.
@compatibility(is_backward_compatible=False)
@dataclass
class SourcePartition:
    # Nodes in a particular partition
    nodes: List[Node]

    # The source these nodes decomposed from
    source: Any

    # Nodes in the graph that are needed as inputs to the partition
    input_nodes: List[Node] = field(default_factory=list)

    # Nodes in the partition that are being used by nodes outside of the
    # partition
    output_nodes: List[Node] = field(default_factory=list)

    # Parameters that are being used
    params: List[Node] = field(default_factory=list)


def _make_partition(nodes: List[Node], module_type: Any) -> SourcePartition:
    """Build a SourcePartition for ``nodes``, computing its boundary nodes.

    Args:
        nodes: The nodes belonging to the partition.
        module_type: The source the nodes were decomposed from; stored as the
            partition's ``source``.

    Returns:
        A SourcePartition whose ``input_nodes``, ``output_nodes`` and
        ``params`` are de-duplicated and listed in first-seen order.
    """
    # Membership is tested once per argument and per user, so use a set
    # instead of scanning the list each time.
    members = set(nodes)

    # Dicts are used as insertion-ordered sets so that the resulting lists
    # have a deterministic order.
    input_nodes: Dict[Node, None] = {}
    output_nodes: Dict[Node, None] = {}
    params: Dict[Node, None] = {}

    for node in nodes:
        # Only direct positional arguments are inspected here.
        for arg in node.args:
            if isinstance(arg, Node) and arg not in members:
                input_nodes[arg] = None

        if node.op == "get_attr":
            params[node] = None

        # A node is an output of the partition as soon as any of its users
        # lives outside the partition.
        if any(user not in members for user in node.users):
            output_nodes[node] = None

    return SourcePartition(
        nodes,
        module_type,
        list(input_nodes),
        list(output_nodes),
        list(params),
    )


@compatibility(is_backward_compatible=False)  # type: ignore[misc]
def get_source_partitions(
    graph: Graph,
    wanted_sources: List[Any],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Dict[Any, List[SourcePartition]]:
    """
    Group the nodes of ``graph`` by the source they were decomposed from.

    Args:
        graph: The graph we want to partition
        wanted_sources: List of sources of nodes that were decomposed from this
            source. This can be a function (ex. torch.nn.functional.linear) or a
            leaf module type (ex. torch.nn.Linear).
        filter_fn: Optional predicate applied to every node of a candidate
            partition. A partition is kept only if the predicate is true for
            all of its nodes.

    Returns:
        Dictionary mapping sources that were given to a list of SourcePartitions
        that correspond to the list of nodes that were decomposed from the given
        source.
    """
    # source -> unique source name -> nodes decomposed from that source
    modules: Dict[Any, Dict[str, List[Node]]] = {}

    for node in graph.nodes:
        # The metadata source_fn should contain a tuple of a unique name for the
        # source, and the source function if the node is decomposed from a
        # function, or the type of module if the node is decomposed from a leaf
        # module

        # TODO: Bypass "torch_fn" when "source_fn_stack" is present because now
        # "torch_fn" can be different from "source_fn_stack", for example for
        # the add_ node decomposed from batch norm. We should remove the check
        # on "source_fn_stack" after we fix "torch_fn". T199561090
        source_fn_st = node.meta.get("source_fn_stack", None)
        if source_fn_st is not None:
            # The last entry of the stack is used as the node's source.
            source_fn = source_fn_st[-1]
            if source_fn[1] in wanted_sources:
                diff_modules = modules.setdefault(source_fn[1], {})
                diff_modules.setdefault(source_fn[0], []).append(node)
            continue

        torch_fn = node.meta.get("torch_fn", None)
        if torch_fn is not None:
            node_fqn, source_fn = torch_fn
            # The second dot-separated component of the "torch_fn" string is
            # the name matched against wanted_sources.
            source_fn_name = source_fn.split(".")[1]
            if source_fn_name in wanted_sources:
                diff_modules = modules.setdefault(source_fn_name, {})
                diff_modules.setdefault(node_fqn, []).append(node)

    if filter_fn:
        # for each partition, we apply filter_fn to filter out all partitions
        # that don't satisfy the filter condition
        modules = {
            tp: {
                name: partition
                for name, partition in name_to_partition.items()
                if all(map(filter_fn, partition))
            }
            for tp, name_to_partition in modules.items()
        }

    ret: Dict[Any, List[SourcePartition]] = {
        source: [_make_partition(partition, source) for partition in named.values()]
        for source, named in modules.items()
    }

    return ret


@compatibility(is_backward_compatible=False)  # type: ignore[misc]
def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
    """
    Given two subgraphs A and B (in the form of SourcePartitions), checks if
    A has nodes connecting to at least one node in B -- aka there exists a node
    in B that uses a node in A (not the other way around).
    """
    # Build the lookup set once rather than scanning the list for every user.
    targets = set(subgraph2.nodes)

    for node in reversed(subgraph1.nodes):
        if any(user in targets for user in node.users):
            return True
    return False