xref: /aosp_15_r20/external/pytorch/torch/fx/passes/utils/source_matcher_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from dataclasses import dataclass, field
2from torch.fx.graph import Graph
3from torch.fx.node import Node
4from torch.fx._compatibility import compatibility
5from typing import Dict, List, Any, Type, Optional, Callable
6import logging
7import os
8
9
10__all__ = ['get_source_partitions', 'check_subgraphs_connected', 'SourcePartition']
11
# Set `PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
def _init_logger() -> logging.Logger:
    """Build this module's logger, honoring PYTORCH_MATCHER_LOGLEVEL (default WARNING)."""
    log = logging.getLogger(__name__)
    level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
    log.setLevel(level)

    # Single stream handler with a compact "<file> > <message>" format.
    handler = logging.StreamHandler()
    handler.setLevel(level)
    handler.setFormatter(logging.Formatter("%(filename)s > %(message)s"))
    log.addHandler(handler)

    # Don't bubble records up to the root logger; our handler is authoritative.
    log.propagate = False
    return log

logger = _init_logger()
28
29
@compatibility(is_backward_compatible=False)
@dataclass
class SourcePartition:
    """A group of fx Nodes that were all decomposed from the same source.

    The source is either a function (ex. torch.nn.functional.linear) or a
    leaf module type (ex. torch.nn.Linear). Instances are produced by
    get_source_partitions().
    """

    # Nodes in a particular partition
    nodes: List[Node]

    # The source these nodes decomposed from
    source: Any

    # Nodes in the graph that are needed as inputs to the partition
    # (i.e. Node-valued args of partition nodes that live outside it)
    input_nodes: List[Node] = field(default_factory=list)

    # Nodes in the partition that are being used by nodes outside of the
    # partition
    output_nodes: List[Node] = field(default_factory=list)

    # Parameters that are being used (get_attr nodes inside the partition)
    params: List[Node] = field(default_factory=list)
48
49
50@compatibility(is_backward_compatible=False)  # type: ignore[misc]
51def get_source_partitions(
52    graph: Graph,
53    wanted_sources: List[Any],
54    filter_fn: Optional[Callable[[Node], bool]] = None,
55) -> Dict[Any, List[SourcePartition]]:
56    """
57    Args:
58        graph: The graph we want to partition
59        wanted_sources: List of sources of nodes that were decomposed from this
60            source. This can be a function (ex. torch.nn.functional.linear) or a
61            leaf module type (ex. torch.nn.Linear).
62
63    Returns:
64        Dictionary mapping sources that were given to a list of SourcePartitions
65        that correspond to the list of nodes that were decomposed from the given
66        source.
67    """
68    modules: Dict[Type, Dict[str, List[Node]]] = {}
69
70    for node in graph.nodes:
71        # The metadata source_fn should contain a tuple of a unique name for the
72        # source, and the source function if the node is decomposed from a
73        # function, or the type of module if the node is decomposed from a leaf
74        # module
75
76        # TODO: Bypass "torch_fn" when "source_fn_stack" because now "torch_fn" can
77        # be different from "source_fn_stack", for example for the add_ node
78        # decomposed from batch norm. We should remove the check on "source_fn_stack"
79        # after we fix "torch_fn". T199561090
80        if ((source_fn_st := node.meta.get("source_fn_stack", None)) is None and
81           (torch_fn := node.meta.get("torch_fn", None)) is not None):
82            node_fqn, source_fn = torch_fn
83            source_fn_name = source_fn.split(".")[1]
84            if source_fn_name in wanted_sources:
85                diff_modules = modules.setdefault(source_fn_name, {})
86                partition = diff_modules.setdefault(node_fqn, [])
87                partition.append(node)
88
89
90        if (source_fn_st := node.meta.get("source_fn_stack", None)) is not None:
91            source_fn = source_fn_st[-1]
92            if source_fn[1] in wanted_sources:
93                diff_modules = modules.setdefault(source_fn[1], {})
94                partition = diff_modules.setdefault(source_fn[0], [])
95                partition.append(node)
96
97    def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition:
98        input_nodes = set()
99        output_nodes = set()
100        params = set()
101        for node in nodes:
102            for arg in node.args:
103                if isinstance(arg, Node) and arg not in nodes:
104                    input_nodes.add(arg)
105
106            if node.op == "get_attr":
107                params.add(node)
108
109            for user in node.users.keys():
110                if user not in nodes:
111                    output_nodes.add(node)
112
113        return SourcePartition(
114            nodes,
115            module_type,
116            list(input_nodes),
117            list(output_nodes),
118            list(params),  # type: ignore[arg-type]
119        )
120
121    ret: Dict[Type[Any], List[SourcePartition]] = {}
122
123    if filter_fn:
124        # for each partition, we apply filter_fn to filter out all partitions that doesn't satisfy the
125        # filter condition
126        filtered_modules = {}
127        for tp, name_to_partition in modules.items():
128            filtered_name_to_partition = {
129                name: partition
130                for name, partition in name_to_partition.items()
131                if all(map(filter_fn, partition))
132            }
133            filtered_modules[tp] = filtered_name_to_partition
134        modules = filtered_modules
135
136    for k, v in modules.items():
137        ret[k] = [make_partition(partition, k) for partition in v.values()]
138
139    return ret
140
141
142@compatibility(is_backward_compatible=False)  # type: ignore[misc]
143def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
144    """
145    Given two subgraphs A and B (in the form of a list of nodes), checks if
146    A has nodes connecting to at least one node in B -- aka there exists a node
147    in B that uses a node in A (not the other way around).
148    """
149
150    for node in reversed(subgraph1.nodes):
151        for user in node.users.keys():
152            if user in subgraph2.nodes:
153                return True
154    return False
155