# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import final, List

import torch
from executorch import exir
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_pattern_op_partitions,
)

from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.test.qnn_backend_demo import QnnBackend
from executorch.exir.backend.utils import tag_constant_data
from torch.export import ExportedProgram
from torch.fx.passes.infra.partitioner import Partition


@final
class HTAPartitionerMultiplePatternsDemo(Partitioner):
    """
    An example implementation that partitions a graph for HTA. In this example, the
    backend associated with this partitioner is QnnBackend, and the two lowerable
    patterns are (lstm + conv) and (sub). The backend is a class member rather than
    an instance member, as it is a property of HTAPartitionerMultiplePatternsDemo
    and won't differ between HTAPartitionerMultiplePatternsDemo instances.

    The partition algorithm is:
    1. Find a list of partitions in a graph: generate_partition_list(GraphModule) -> List[Partition]
    2. Check that all partitions from generate_partition_list() are exclusive; if they
       are not, it will error out.
    3. Fuse the partition list as submodules.
    """

    def __init__(self) -> None:
        """
        Initialize a list of pattern partitioners: (lstm + conv) and (sub)
        """

        class LSTMConvPattern(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                return output, hn, cn, k

        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])

        pattern_lstm_conv_lifted = (
            exir.capture(
                LSTMConvPattern(),
                (input_x, input_h, input_c),
                exir.CaptureConfig(enable_aot=True),
            )
            .to_edge(
                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
                exir.EdgeCompileConfig(_check_ir_validity=False)
            )
            .exported_program.graph_module
        )
        pattern_lstm_conv = (
            exir.capture(
                LSTMConvPattern(),
                (input_x, input_h, input_c),
                exir.CaptureConfig(),
            )
            .to_edge(
                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
                exir.EdgeCompileConfig(_check_ir_validity=False)
            )
            .exported_program.graph_module
        )
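
        # Both the lifted (enable_aot=True) and unlifted captures of each pattern
        # are registered below, so matching succeeds regardless of which capture
        # configuration produced the user's graph.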

        def sub(x, y):
            return torch.sub(x, y)

        pattern_sub_lifted = (
            exir.capture(
                sub,
                (input_x, input_h),
                exir.CaptureConfig(enable_aot=True, _unlift=False),
            )
            .to_edge(exir.EdgeCompileConfig(_use_edge_ops=True))
            .exported_program.graph_module
        )
        pattern_sub = (
            exir.capture(
                sub,
                (input_x, input_h),
                exir.CaptureConfig(),
            )
            .to_edge()
            .exported_program.graph_module
        )
        self.patterns = [
            pattern_lstm_conv_lifted.graph,
            pattern_lstm_conv.graph,
            pattern_sub_lifted.graph,
            pattern_sub.graph,
        ]

        backend_id = QnnBackend.__name__
        self.delegation_spec = DelegationSpec(backend_id, [])

    def is_exclusive(self, partition_list_list: List[List[Partition]]) -> bool:
        """
        Each List[Partition] is generated by one pattern partitioner, and this
        partitioner only supports merging exclusive partitions. It checks that all
        partitions are exclusive by comparing len(all_nodes) and len(set(all_nodes)).

        Args:
            partition_list_list: all partitions from all pattern partitioners

        Returns:
            bool: True if all nodes from all partitions are exclusive.

        For example, 0/1 are partition ids and A/B/../L are nodes:
        [
            [(0: A, B, C), (1: D, E, F)], # from pattern lstm + conv
            [(0: B, J, L)], # from sub
        ]
        Node B shows up in both partitions. Usually some special handling (either
        merging the two lists, or keeping only one partition, e.g. [A, B, C]) needs
        to be done here, depending on the user's needs.
        """
        all_partition = [
            partition
            for partition_list in partition_list_list
            for partition in partition_list
        ]

        # All nodes from all partitions from all pattern match results
        all_nodes = []
        for partition in all_partition:
            all_nodes.extend(partition.nodes)
        all_nodes_set = set(all_nodes)

        # Calculate the number of duplicate nodes
        duplicated_node_number = len(all_nodes) - len(all_nodes_set)
        logging.info(f"duplicate node number is {duplicated_node_number}.")
        return duplicated_node_number == 0

    def generate_partition_list(self, graph_module) -> List[Partition]:
        """
        Generate a flat list of partitions, each with a unique id, from all matched
        patterns.

        Args:
            graph_module: the input graph module

        Returns:
            List[Partition]: all matched partitions, re-numbered with unique ids.

        For example, 0/1 are partition ids and A/B/../I are nodes:
        [
            [(0: A, B, C), (1: D, E, F)], # from pattern lstm + conv
            [(0: G, H, I)], # from sub
        ]
        the output will be
        [
            (0: A, B, C), (1: D, E, F), (2: G, H, I)
        ]
        """
        partitions_from_all_pattern = generate_pattern_op_partitions(
            graph_module, self.patterns
        )

        # Assign a unique id to each partition
        partition_id = 0

        flat_proposed_partitions_with_unique_id = []
        for partition in partitions_from_all_pattern:
            partition.id = partition_id
            flat_proposed_partitions_with_unique_id.append(partition)
            partition_id += 1

        return flat_proposed_partitions_with_unique_id

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = {}
        partition_list = self.generate_partition_list(exported_program.graph_module)
        for partition in partition_list:
            for node in partition.nodes:
                delegation_tag = f"tag{partition.id}"
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec
        tag_constant_data(exported_program)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
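

# A quick illustration (not part of the original demo) of the exclusivity check
# in is_exclusive() above: overlap is detected by counting duplicate nodes across
# all partitions, so two matches sharing node "B" yield one duplicate.
#
#     all_nodes = ["A", "B", "C"] + ["B", "J", "L"]
#     duplicates = len(all_nodes) - len(set(all_nodes))  # 1 -> not exclusive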


@final
class HTAPartitionerOnePatternDemo(Partitioner):
    """
    Similar to HTAPartitionerMultiplePatternsDemo, except that only one pattern
    (lstm + conv) is lowerable. The partitions are generated directly with
    generate_pattern_op_partitions() and tagged accordingly.
    """

    def __init__(self) -> None:
        """
        Initialize the single lowerable pattern: (lstm + conv)
        """

        # Only lowering the lstm + conv pattern
        class LSTMConvPattern(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                return output, hn, cn, k

        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])

        pattern_lstm_conv_lifted = (
            exir.capture(
                LSTMConvPattern(),
                (input_x, input_h, input_c),
                exir.CaptureConfig(enable_aot=True),
            )
            .to_edge(
                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
                exir.EdgeCompileConfig(_check_ir_validity=False)
            )
            .exported_program.graph_module
        )
        pattern_lstm_conv_unlifted = (
            exir.capture(
                LSTMConvPattern(),
                (input_x, input_h, input_c),
                exir.CaptureConfig(),
            )
            .to_edge(
                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
                exir.EdgeCompileConfig(_check_ir_validity=False)
            )
            .exported_program.graph_module
        )
        # Only the (lstm + conv) pattern is lowerable
        self.patterns = [
            pattern_lstm_conv_lifted.graph,
            pattern_lstm_conv_unlifted.graph,
        ]

        backend_id = QnnBackend.__name__
        self.delegation_spec = DelegationSpec(backend_id, [])

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = {}
        partition_list = generate_pattern_op_partitions(
            exported_program.graph_module, patterns=self.patterns
        )
        for partition in partition_list:
            for node in partition.nodes:
                delegation_tag = f"tag{partition.id}"
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec
        tag_constant_data(exported_program)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
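

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original demo. It assumes the
    # to_backend(ExportedProgram, Partitioner) overload exposed by
    # executorch.exir.backend.backend_api; the exact signature may differ
    # across ExecuTorch versions. The hypothetical Composite module contains
    # the lowerable (lstm + conv) pattern, so its matched partition gets
    # delegated to QnnBackend.
    from executorch.exir.backend.backend_api import to_backend

    class Composite(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = torch.nn.LSTM(input_size=32, hidden_size=32, num_layers=1)
            self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

        def forward(self, x_raw, h, c):
            output, (hn, cn) = self.lstm(x_raw, (h, c))
            return self.conv(output)

    inputs = (torch.ones(1, 32), torch.ones(1, 32), torch.ones(1, 32))
    edge = exir.capture(Composite(), inputs, exir.CaptureConfig()).to_edge(
        # Same verifier workaround as in the pattern captures above.
        exir.EdgeCompileConfig(_check_ir_validity=False)
    )
    tagged = to_backend(edge.exported_program, HTAPartitionerOnePatternDemo())
    logging.info(f"delegated graph:\n{tagged.graph_module.graph}")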