1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import logging 8from typing import List, Optional 9 10import torch 11from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition 12from torch.fx.passes.operator_support import any_chain, OperatorSupportBase 13from torch.fx.passes.utils.matcher_utils import SubgraphMatcher 14 15 16def generate_partitions_from_list_of_nodes( 17 graph_module: torch.fx.GraphModule, 18 pattern_list: Optional[List[List[torch.fx.Node]]] = None, 19 op_support: Optional[OperatorSupportBase] = None, 20) -> List[Partition]: 21 final_op_support: Optional[OperatorSupportBase] = op_support 22 23 if pattern_list is not None: 24 # Tag all the nodes in these patterns 25 for node_list in pattern_list: 26 for node in node_list: 27 node.meta["match"] = True 28 29 class MatchTag(OperatorSupportBase): 30 def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: 31 return node.meta.get("match", False) 32 33 final_op_support = ( 34 MatchTag() 35 if final_op_support is None 36 else any_chain(final_op_support, MatchTag()) 37 ) 38 39 assert ( 40 final_op_support is not None 41 ), "Did not give a pattern or OperatorSupportBase instance to partition with" 42 43 # Run the CapabilityBasedPartitioner to return the largest possible 44 # subgraphs containing the nodes with the tags 45 capability_partitioner = CapabilityBasedPartitioner( 46 graph_module, 47 final_op_support, 48 allows_single_node_partition=True, 49 ) 50 partition_list = capability_partitioner.propose_partitions() 51 52 # Remove the metadata field we added 53 for partition in partition_list: 54 for node in partition.nodes: 55 node.meta.pop("match", False) 56 return partition_list 57 58 59def generate_pattern_op_partitions( 60 graph_module: torch.fx.GraphModule, 61 patterns: Optional[List[torch.fx.Graph]] = None, 62 partitions_list: Optional[List[List[torch.fx.Node]]] = None, 63 op_support: Optional[OperatorSupportBase] = None, 64 ignore_literals: bool = False, 65) -> List[Partition]: 66 """ 67 Args: 68 graph_module: Module that we want to partition 69 patterns: A list of patterns in the form of torch.fx.Graph. These graphs 70 can be obtained through the `graph` field from a GraphModule obtained by 71 exir.capture (recommended) or symbolic tracing (which might not result 72 in an accurate edge dialect graph), or by manual crafting a graph 73 module. 74 partitions_list: A list of node lists whose nodes are intended to be tagged 75 along with the nodes detected by the pattern matcher. 76 op_support: A OperatorSupportBase that can be created in the following ways: 77 - Subclassing it directly and implementing is_node_supported() 78 - Getting the result of create_op_support() 79 - Getting the result of create_pattern_support() 80 - Multiple OperatorSupportBase classes chained together with chain() 81 82 Returns 83 A list of partitions (largest possible subgraphs) containing nodes are 84 supported by the given OperatorSupportBase object 85 """ 86 final_op_support: Optional[OperatorSupportBase] = op_support 87 88 if patterns is not None: 89 # Find all patterns in the graph (even if they're invalid) 90 matches = [] 91 for pattern in patterns: 92 logging.debug(f"Finding matches for pattern: {pattern}") 93 subgraph_matcher = SubgraphMatcher(pattern, ignore_literals=ignore_literals) 94 matches.extend(subgraph_matcher.match(graph_module.graph)) 95 96 # Tag all the nodes in these patterns 97 for match in matches: 98 for node_in_pattern, node_in_graph in match.nodes_map.items(): 99 if ( 100 node_in_pattern.op != "placeholder" 101 and node_in_graph.op != "placeholder" 102 ): 103 node_in_graph.meta["match"] = True 104 105 if partitions_list: 106 for node_list in partitions_list: 107 for node in node_list: 108 node.meta["match"] = True 109 110 class MatchTag(OperatorSupportBase): 111 def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: 112 return node.meta.get("match", False) 113 114 final_op_support = ( 115 MatchTag() 116 if final_op_support is None 117 else any_chain(final_op_support, MatchTag()) 118 ) 119 120 assert ( 121 final_op_support is not None 122 ), "Did not give a pattern or OperatorSupportBase instance to partition with" 123 124 # Run the CapabilityBasedPartitioner to return the largest possible 125 # subgraphs containing the nodes with the tags 126 capability_partitioner = CapabilityBasedPartitioner( 127 graph_module, 128 final_op_support, 129 allows_single_node_partition=True, 130 ) 131 partition_list = capability_partitioner.propose_partitions() 132 133 # Remove the metadata field we added 134 for partition in partition_list: 135 for node in partition.nodes: 136 node.meta.pop("match", False) 137 return partition_list 138