# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
    AnnotateChannelsLastDimOrder,
)
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
    ConvertExpandCopyToRepeatPass,
)
from executorch.backends.arm._passes.convert_split_to_slice import (
    ConvertSplitToSlicePass,
)
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
from executorch.backends.arm._passes.decompose_layernorm_pass import (
    DecomposeLayerNormPass,
)
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
    DecomposeSoftmaxesPass,
)
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
    InsertSqueezeAfterSumPass,
)
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
    ConvertMeanDimToAveragePool,
)
from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
from executorch.backends.arm._passes.scalars_to_attribute_pass import (
    ScalarsToAttributePass,
)
from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
    UnsqueezeScalarPlaceholdersPass,
)
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.pass_manager import PassManager


class ArmPassManager(PassManager):
    """``PassManager`` subclass that registers the Arm backend passes.

    Two pipelines are offered: one used before handing an exported program
    to the backend, and one applied to a bare graph module for annotation.
    Each pipeline registers its passes on this manager and then runs it.
    """

    def _transform(self, graph_module: torch.fx.GraphModule):
        """Call this manager on ``graph_module`` and return the result's ``graph_module``."""
        pass_result = self(graph_module)
        return pass_result.graph_module

    def transform_to_backend_pipeline(
        self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
    ):
        """Apply passes before transforming program to backend.

        Args:
            exported_program: Program whose ``graph_module`` is transformed; it
                is also handed to the passes that take it as a constructor
                argument.
            compile_spec: Compile specs to scan for a
                ``"permute_memory_format"`` entry.

        Returns:
            The graph module obtained by running the registered passes on
            ``exported_program.graph_module``.
        """
        # Registration order is kept exactly as listed.
        # NOTE(review): later passes presumably rely on earlier ones having
        # run -- confirm before reordering.
        lowering_passes = (
            CastInt64ToInt32Pass(exported_program),
            RemoveGetItemPass(),
            UnsqueezeScalarPlaceholdersPass(exported_program),
            SizeAdjustConv2DPass(),
            RemoveClonePass(),
            ConvertExpandCopyToRepeatPass(),
            DecomposeLayerNormPass(),
            DecomposeVarPass(),
            ConvertMeanDimToAveragePool(),
            DecomposeMeanDimPass(),
            MatchArgRanksPass(exported_program),
            DecomposeDivPass(),
            InsertSqueezeAfterSumPass(),
            ConvertSplitToSlicePass(),
            Conv1dUnsqueezePass(exported_program),
            DecomposeSoftmaxesPass(),
            DecomposeLinearPass(),
        )
        for lowering_pass in lowering_passes:
            self.add_pass(lowering_pass)

        # Every "permute_memory_format" spec whose decoded value is "nhwc"
        # appends one channels-last annotation pass.
        for spec in compile_spec:
            if spec.key != "permute_memory_format":
                continue
            if spec.value.decode() == "nhwc":
                self.add_pass(AnnotateChannelsLastDimOrder())

        program_graph = exported_program.graph_module
        return self._transform(program_graph)

    def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
        """Register the annotation passes and run them on ``graph_module``.

        Args:
            graph_module: Graph module to transform.

        Returns:
            The graph module obtained by running the registered passes.
        """
        annotation_passes = (
            ScalarsToAttributePass(),
            DecomposeLayerNormPass(),
            DecomposeVarPass(),
            DecomposeMeanDimPass(),
            DecomposeDivPass(),
            DecomposeSoftmaxesPass(),
        )
        for annotation_pass in annotation_passes:
            self.add_pass(annotation_pass)
        return self._transform(graph_module)