# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
import os
from typing import Callable, final, List, Optional, Tuple

import torch
from executorch.backends.arm.arm_backend import ArmBackend  # usort: skip
from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass
from executorch.backends.arm.operator_support.tosa_supported_operators import (
    TOSASupportedOperators,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from executorch.exir.passes import PassManager
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

# Module logger: WARNING by default; setting the environment variable
# TOSA_DBG_VERBOSE=1 raises it (and the root logging config) to INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
if TOSA_DBG_VERBOSE:
    logging.basicConfig(level=logging.INFO)
    logger.setLevel(logging.INFO)


@final
class ArmPartitioner(Partitioner):
    """Partitioner that tags TOSA-supported subgraphs for ``ArmBackend``.

    Operator support is decided by ``TOSASupportedOperators`` for the TOSA
    specification derived from the compile specs given at construction.
    """

    def __init__(self, compile_spec: List[CompileSpec]) -> None:
        """Store the delegation spec built from *compile_spec*.

        Args:
            compile_spec: Compile specs forwarded to ``ArmBackend`` through
                the ``DelegationSpec`` and also used to derive the TOSA
                specification during partitioning.
        """
        # Initialise base-class state as well.
        # NOTE(review): assumes Partitioner.__init__ takes no required
        # arguments -- confirm against executorch.exir.backend.partitioner.
        super().__init__()
        self.delegation_spec = DelegationSpec(ArmBackend.__name__, compile_spec)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """Tag the nodes of *exported_program* that can be delegated.

        Runs the CapabilityBasedPartitioner to obtain the largest possible
        subgraphs of supported nodes, writes a ``delegation_tag`` into the
        ``meta`` of every node in each proposed partition, and then calls
        ``tag_constant_data`` on the program.

        Args:
            exported_program: Program whose graph module is inspected and
                tagged in place.

        Returns:
            A ``PartitionResult`` holding the same (now tagged) program and a
            mapping from each tag to this partitioner's delegation spec.
        """
        logger.info("ArmPartitioner::partition")

        compile_specs = self.delegation_spec.compile_specs
        tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)

        logger.info(f"Partitioning for {tosa_spec}")

        # Run the tagging pass at most once, even if several compile specs
        # request it.
        if any(
            spec.key == "quantize_io" and spec.value.decode() == "True"
            for spec in compile_specs
        ):
            # Exclude IO quantization from the partition
            io_quant_passes = PassManager(
                passes=[
                    TagIOQuantPass(),
                ]
            )
            io_quant_passes(exported_program.graph_module)

        capability_partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            TOSASupportedOperators(tosa_spec),
            allows_single_node_partition=True,
        )

        partition_tags = {}
        for partition in capability_partitioner.propose_partitions():
            # One tag per partition; build the string once, not per node.
            tag = f"tag{partition.id}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = tag
                # Registered inside the loop so that a partition without
                # nodes contributes no tag.
                partition_tags[tag] = self.delegation_spec

        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """Return the operator overloads named by this partitioner.

        Args:
            ep: Exported program; not consulted by this implementation.

        Returns:
            A tuple of the fixed operator list (``aten.linear.default`` and
            ``aten.upsample_nearest2d.vec``) and ``None`` in the optional
            per-node callable slot.
        """
        ops_to_not_decompose = [
            torch.ops.aten.linear.default,
            torch.ops.aten.upsample_nearest2d.vec,
        ]
        return (ops_to_not_decompose, None)