# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, final

import torch
from executorch.backends.example.example_backend import ExampleBackend
from executorch.backends.example.example_operators.ops import module_to_annotator
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_partitions_from_list_of_nodes,
)
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.export import ExportedProgram
from torch.fx.passes.operator_support import OperatorSupportBase


class _DequantQuantOperatorSupport(OperatorSupportBase):
    """
    Operator support that accepts only edge-dialect per-tensor
    quantize/dequantize ``call_function`` nodes.

    Used as the operator-support check passed to
    ``generate_partitions_from_list_of_nodes``.
    NOTE(review): presumably this lets q/dq nodes adjacent to a matched
    pattern join that pattern's partition - confirm against that helper.
    """

    def is_node_supported(self, _submodules, node: torch.fx.Node) -> bool:
        # The op targets are looked up on each call rather than cached at
        # import time, keeping the exir_ops attribute lookup lazy.
        return node.op == "call_function" and node.target in (
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        )


@final
class ExamplePartitioner(Partitioner):
    """
    Partitions every node sequence matching a pattern registered in
    ``module_to_annotator`` (as found by ``find_sequential_partitions``) for
    delegation to ``ExampleBackend``.

    Per-tensor quantize/dequantize nodes are accepted by the operator-support
    check handed to ``generate_partitions_from_list_of_nodes``, and
    control-flow submodules are partitioned recursively.
    """

    def __init__(self) -> None:
        # Patterns to search for; a live view of module_to_annotator's keys.
        self.patterns = module_to_annotator.keys()
        # Every partition is delegated to ExampleBackend, with an empty list
        # as the second DelegationSpec argument.
        self.delegation_spec = DelegationSpec(ExampleBackend.__name__, [])
        self.dequant_quant_support = _DequantQuantOperatorSupport()

    def _partition_graph_module(
        self, edge_graph_module: torch.fx.GraphModule
    ) -> Dict[str, DelegationSpec]:
        """
        Tag delegatable nodes of ``edge_graph_module`` (in place) and of its
        control-flow submodules.

        Each node in a partition gets ``node.meta["delegation_tag"]`` set to
        ``"tag<partition.id>"``. ``get_attr`` nodes that appear directly in
        the positional args of a tagged ``call_function`` node get the same
        tag.

        Args:
            edge_graph_module: Graph module whose nodes are tagged in place.

        Returns:
            Mapping from every delegation tag applied in this module and,
            recursively, in its control-flow submodules to
            ``self.delegation_spec``.
        """
        # Collect the node lists of every sequential match of every pattern.
        partition_nodes: List[List[torch.fx.Node]] = []
        for pattern in self.patterns:
            fused_partitions = find_sequential_partitions(edge_graph_module, pattern)
            for fused_partition in fused_partitions:
                partition_nodes.extend(
                    partition.nodes for partition in fused_partition
                )

        partitions = generate_partitions_from_list_of_nodes(
            edge_graph_module, partition_nodes, self.dequant_quant_support
        )

        partition_tags: Dict[str, DelegationSpec] = {}
        for partition in partitions:
            # Only report tags that are actually applied to at least one node.
            if not partition.nodes:
                continue
            # The tag depends only on the partition, so build/register it once.
            delegation_tag = f"tag{partition.id}"
            partition_tags[delegation_tag] = self.delegation_spec
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
                if node.op != "call_function":
                    continue
                # Constants fetched via get_attr and consumed by a delegated op
                # share its tag (presumably so they are lowered with it).
                for arg_node in node.args:
                    if (
                        isinstance(arg_node, torch.fx.Node)
                        and arg_node.op == "get_attr"
                    ):
                        arg_node.meta["delegation_tag"] = delegation_tag

        # NOTE(review): tag strings are derived from per-call partition ids, so
        # the same tag may be used in a submodule and its parent; the merged
        # dict is unaffected because every tag maps to the same spec.
        for _, submodule, _ in get_control_flow_submodules(edge_graph_module):
            partition_tags.update(self._partition_graph_module(submodule))

        return partition_tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Tag ``exported_program``'s graph module in place and return it
        together with the mapping of applied tags to delegation specs.
        """
        partition_tags = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )