xref: /aosp_15_r20/external/executorch/backends/arm/_passes/arm_pass_manager.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# Copyright 2024 Arm Limited and/or its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8# pyre-unsafe
9
10import torch
11from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
12    AnnotateChannelsLastDimOrder,
13)
14from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
15from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
16from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
17    ConvertExpandCopyToRepeatPass,
18)
19from executorch.backends.arm._passes.convert_split_to_slice import (
20    ConvertSplitToSlicePass,
21)
22from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
23from executorch.backends.arm._passes.decompose_layernorm_pass import (
24    DecomposeLayerNormPass,
25)
26from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
27from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
28from executorch.backends.arm._passes.decompose_softmaxes_pass import (
29    DecomposeSoftmaxesPass,
30)
31from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
32from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
33    InsertSqueezeAfterSumPass,
34)
35from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
36from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
37    ConvertMeanDimToAveragePool,
38)
39from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
40from executorch.backends.arm._passes.scalars_to_attribute_pass import (
41    ScalarsToAttributePass,
42)
43from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
44from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
45    UnsqueezeScalarPlaceholdersPass,
46)
47from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
48from executorch.exir import ExportedProgram
49from executorch.exir.backend.compile_spec_schema import CompileSpec
50from executorch.exir.pass_manager import PassManager
51
52
class ArmPassManager(PassManager):
    """Pass manager that assembles and runs the Arm backend pass pipelines.

    Each ``transform_*`` method registers its passes on this manager through
    ``add_pass`` and then runs the manager on a graph module.
    """

    def _transform(self, graph_module: torch.fx.GraphModule):
        """Run this manager on ``graph_module`` and return the resulting graph module."""
        pass_result = self(graph_module)
        return pass_result.graph_module

    def transform_to_backend_pipeline(
        self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
    ):
        """Register the backend lowering passes and run them on the program's graph.

        Args:
            exported_program: Program whose ``graph_module`` is transformed; it
                is also handed to the passes that take it as a constructor
                argument.
            compile_spec: Compile specs to inspect. For every spec whose key is
                ``"permute_memory_format"`` and whose decoded value is
                ``"nhwc"``, a channels-last dim-order annotation pass is
                appended after the fixed passes.

        Returns:
            The graph module produced by running the registered passes.
        """
        # The registration order below is the execution order; keep it as is.
        backend_passes = [
            CastInt64ToInt32Pass(exported_program),
            RemoveGetItemPass(),
            UnsqueezeScalarPlaceholdersPass(exported_program),
            SizeAdjustConv2DPass(),
            RemoveClonePass(),
            ConvertExpandCopyToRepeatPass(),
            DecomposeLayerNormPass(),
            DecomposeVarPass(),
            ConvertMeanDimToAveragePool(),
            DecomposeMeanDimPass(),
            MatchArgRanksPass(exported_program),
            DecomposeDivPass(),
            InsertSqueezeAfterSumPass(),
            ConvertSplitToSlicePass(),
            Conv1dUnsqueezePass(exported_program),
            DecomposeSoftmaxesPass(),
            DecomposeLinearPass(),
        ]
        for backend_pass in backend_passes:
            self.add_pass(backend_pass)

        for spec in compile_spec:
            if spec.key != "permute_memory_format":
                continue
            if spec.value.decode() == "nhwc":
                self.add_pass(AnnotateChannelsLastDimOrder())

        return self._transform(exported_program.graph_module)

    def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
        """Register the annotation-time passes and run them on ``graph_module``.

        Returns:
            The graph module produced by running the registered passes.
        """
        # The registration order below is the execution order; keep it as is.
        annotation_passes = [
            ScalarsToAttributePass(),
            DecomposeLayerNormPass(),
            DecomposeVarPass(),
            DecomposeMeanDimPass(),
            DecomposeDivPass(),
            DecomposeSoftmaxesPass(),
        ]
        for annotation_pass in annotation_passes:
            self.add_pass(annotation_pass)
        return self._transform(graph_module)
95