# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

from typing import Callable, Dict, final, Iterable, List, Optional, Tuple

import torch
from executorch.backends.mediatek.preprocess import NeuropilotBackend
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data

from mtk_converter.python.converters.pytorch import importer_v2
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase


class NeuropilotOperatorsSupport(OperatorSupportBase):
    """Decide which FX nodes may be delegated to the Neuropilot backend.

    A node is supported only if it is a ``call_function`` node, is not
    excluded by either skip list, and is accepted by
    ``importer_v2.is_fx_node_supported``.
    """

    def __init__(
        self,
        op_types_to_skip: Optional[Iterable[str]] = None,
        op_names_to_skip: Optional[Iterable[str]] = None,
    ) -> None:
        """
        Args:
            op_types_to_skip: Operator type names (the node target's
                ``__name__``) that must never be delegated.
            op_names_to_skip: FX node names that must never be delegated.
        """
        # Copy into private sets so later mutation of the caller's containers
        # does not silently change partitioning, and so any iterable works.
        self._op_types_to_skip = (
            set(op_types_to_skip) if op_types_to_skip is not None else set()
        )
        self._op_names_to_skip = (
            set(op_names_to_skip) if op_names_to_skip is not None else set()
        )

    def is_node_supported(self, _, node: torch.fx.Node) -> bool:
        """Return True if ``node`` can be lowered to the Neuropilot backend.

        The first positional argument (supplied by the partitioner) is unused.
        """
        # Handle 'call_function' only, because 'placeholder' and 'output' nodes
        # cannot be tagged.
        # Ref: https://github.com/pytorch/executorch/pull/1398
        if node.op != "call_function":
            return False

        # Not every callable target defines __name__ (e.g. functools.partial);
        # fall back to its string form instead of raising AttributeError.
        op_type = getattr(node.target, "__name__", str(node.target))
        if op_type in self._op_types_to_skip or node.name in self._op_names_to_skip:
            print(
                f"[Neuropilot Backend] The {op_type} operator with name '{node.name}' is skipped."
            )
            return False

        return importer_v2.is_fx_node_supported(node)


@final
class NeuropilotPartitioner(Partitioner):
    """Partition an ExportedProgram for delegation to ``NeuropilotBackend``.

    Args:
        compile_spec: Compile specs carried by the ``DelegationSpec`` that is
            attached to every proposed partition.
        op_types_to_skip: Operator type names excluded from delegation.
        op_names_to_skip: FX node names excluded from delegation.
    """

    def __init__(
        self,
        compile_spec: List[CompileSpec],
        op_types_to_skip: Optional[Iterable[str]] = None,
        op_names_to_skip: Optional[Iterable[str]] = None,
    ) -> None:
        self.delegation_spec = DelegationSpec(NeuropilotBackend.__name__, compile_spec)
        self._op_types_to_skip = op_types_to_skip
        self._op_names_to_skip = op_names_to_skip

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """Return the ops that should be preserved instead of decomposed.

        ``ep`` is unused; the op list is fixed. No per-node filter is provided,
        so the second element of the returned tuple is ``None``.
        """
        ops_not_decompose = [
            torch.ops.aten.pixel_shuffle.default,
            torch.ops.aten.upsample_bilinear2d.default,
            torch.ops.aten.upsample_bilinear2d.vec,
            torch.ops.aten.upsample_nearest2d.default,
            torch.ops.aten.upsample_nearest2d.vec,
        ]
        return (ops_not_decompose, None)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """Tag delegable nodes of ``exported_program`` and report the partitions.

        Each partition proposed by ``CapabilityBasedPartitioner`` (single-node
        partitions allowed) gets the tag ``"tag<partition.id>"`` written to
        ``node.meta["delegation_tag"]`` of its nodes, and that tag is mapped to
        this partitioner's ``DelegationSpec``. ``tag_constant_data`` is then
        applied. The graph is tagged in place and the same ExportedProgram is
        returned inside the ``PartitionResult``.
        """
        capability_partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            NeuropilotOperatorsSupport(self._op_types_to_skip, self._op_names_to_skip),
            allows_single_node_partition=True,
        )
        partition_list = capability_partitioner.propose_partitions()

        partition_tags: Dict[str, DelegationSpec] = {}
        for partition in partition_list:
            # A partition without nodes contributes no tag entry.
            if not partition.nodes:
                continue
            # The tag depends only on the partition, so build it once.
            tag = f"tag{partition.id}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = tag
            partition_tags[tag] = self.delegation_spec

        # NOTE(review): presumably extends tags to constant/param inputs of
        # tagged nodes — confirm against executorch.exir.backend.utils.
        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )