xref: /aosp_15_r20/external/executorch/backends/vulkan/vulkan_preprocess.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9from typing import Any, Dict, final, List
10
11import executorch.backends.vulkan.utils as utils
12
13from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform
14from executorch.backends.transforms.fuse_batch_norm_with_conv import (
15    FuseBatchNormWithConvPass,
16)
17from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
18from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
19from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
20from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
21
22from executorch.backends.vulkan._passes import (
23    insert_prepack_nodes,
24    RemoveLocalScalarDenseOpsTransform,
25    TagMemoryMetaPass,
26)
27
28from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
29from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
30    VkMemoryLayout,
31    VkStorageType,
32)
33from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
34    serialize_vulkan_graph,
35)
36
37from executorch.exir.backend.backend_details import (
38    BackendDetails,
39    CompileSpec,
40    ExportedProgram,
41    PreprocessResult,
42)
43from executorch.exir.backend.utils import DelegateMappingBuilder
44from executorch.exir.pass_base import ExportPass, PassBase
45
46from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
47
48from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
49
50from executorch.exir.program._program import _copy_module
51
52from torch.export._remove_auto_functionalized_pass import (
53    unsafe_remove_auto_functionalized_pass,
54)
55
# NOTE(review): 65535 is 0xFFFF. This constant is not referenced anywhere else
# in this module; presumably it is a fallback/sentinel debug handle consumed by
# other Vulkan backend code -- confirm before relying on or removing it.
DEFAULT_DEBUG_HANDLE = 65535
57
58
# pyre-ignore
def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
    """Run every entry of ``passes`` over ``program``, in order.

    Two kinds of entries are supported:

    * Instances whose type derives from ``ExportPass`` or ``PassBase`` are
      called on ``program.graph_module``; the graph module they return is
      copied back into ``program.graph_module`` with ``_copy_module``.
    * Any other callable is invoked with the program itself, and whatever it
      returns becomes the program used by the remaining passes.

    Args:
        program: The exported program to transform.
        passes: Iterable of pass instances / callables as described above.

    Returns:
        The program after all passes have been applied.
    """
    for transform in passes:
        if not issubclass(type(transform), (ExportPass, PassBase)):
            # Program-level transform: it consumes and returns the program.
            program = transform(program)
            continue

        graph_module = program.graph_module
        # This is a workaround to allow the memory planning pass to work without
        # having to first apply ToOutVarPass(). See the `greedy()` function in
        # `exir.memory_planning`; if this attribute isn't set, assertions in
        # `collect_spec_from_nodes()` will fail.
        if isinstance(transform, MemoryPlanningPass):
            graph_module.encounter_to_out_var_failure = True

        pass_result = transform(graph_module)
        assert pass_result is not None
        transformed_gm = pass_result.graph_module

        # See the application of this function in exir/program/_program.py for
        # more details on why this step is necessary.
        if isinstance(transform, SpecPropPass):
            transform.update_placeholder_tensor_specs(program, transformed_gm)

        _copy_module(program.graph_module, transformed_gm)

    return program
86
87
def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]:
    """Decode the recognized compile specs into a dictionary of options.

    Every recognized value is decoded from its bytes as a little-endian
    integer and then converted as follows:

    * ``storage_type_override``    -> ``VkStorageType``
    * ``memory_layout_override``   -> ``VkMemoryLayout``
    * ``texture_limits_x|y|z``     -> ``int``
    * ``skip_tag_memory_metadata`` -> ``bool``

    Specs with any other key are ignored. If a key appears more than once,
    the last occurrence wins.

    Args:
        compile_specs: Specs exposing ``key`` and ``value`` attributes.

    Returns:
        Mapping from each recognized spec key to its decoded value.
    """
    texture_limit_keys = {"texture_limits_x", "texture_limits_y", "texture_limits_z"}

    parsed: Dict[str, Any] = {}
    for entry in compile_specs:
        name = entry.key
        if name == "storage_type_override":
            raw = int.from_bytes(entry.value, byteorder="little")
            parsed[name] = VkStorageType(raw)
        elif name == "memory_layout_override":
            raw = int.from_bytes(entry.value, byteorder="little")
            parsed[name] = VkMemoryLayout(raw)
        elif name in texture_limit_keys:
            parsed[name] = int.from_bytes(entry.value, byteorder="little")
        elif name == "skip_tag_memory_metadata":
            parsed[name] = bool.from_bytes(entry.value, byteorder="little")
        # Unhandled options are ignored

    return parsed
108
109
@final
class VulkanBackend(BackendDetails):
    """Backend that lowers an exported program to a serialized Vulkan graph."""

    @classmethod
    # pyre-ignore
    def preprocess(  # noqa: C901
        cls,
        program: ExportedProgram,
        module_compile_spec: List[CompileSpec],
    ) -> PreprocessResult:
        """Transform ``program`` and serialize it as a Vulkan graph.

        The compile specs are parsed first, then several groups of passes are
        applied in a fixed order, and finally the resulting program is handed
        to ``VkGraphBuilder`` and serialized.

        Args:
            program: The exported program to lower.
            module_compile_spec: Compile specs. The options read here are
                ``texture_limits_x|y|z``, ``storage_type_override``,
                ``memory_layout_override`` and ``skip_tag_memory_metadata``.

        Returns:
            A ``PreprocessResult`` holding the serialized graph bytes and the
            delegate debug handle map produced by the graph builder.
        """
        compile_options = parse_compile_spec(module_compile_spec)

        # Each texture limit falls back to the matching default component when
        # it has not been overridden through a compile spec.
        texture_limits = (
            compile_options.get("texture_limits_x", utils.DEFAULT_TEXTURE_LIMITS[0]),
            compile_options.get("texture_limits_y", utils.DEFAULT_TEXTURE_LIMITS[1]),
            compile_options.get("texture_limits_z", utils.DEFAULT_TEXTURE_LIMITS[2]),
        )

        storage_type = compile_options.get(
            "storage_type_override", VkStorageType.TEXTURE_3D
        )
        memory_layout = compile_options.get(
            "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED
        )

        program = unsafe_remove_auto_functionalized_pass(program)

        # First, apply passes that fuse/remove operators to consolidate the graph
        # structure but still preserve an "ATen-compliant" graph structure (i.e.
        # all arguments to ATen operators must match the ATen function schema).
        fusion_passes = [
            RemoveCloneOpsTransform(),
            AddmmToLinearTransform(),
            FuseDequantLinearPass(),
            FuseViewCopyTransform(),
            FuseBatchNormWithConvPass(program),
            FuseClampPass(),
        ]
        program = apply_passes(program, fusion_passes)

        # Next, annotate tensor nodes with TensorSpec structs, which are needed
        # for dynamic shapes and memory planning. Until this point, the graph
        # must be ATen compliant because SpecPropPass will be calling the
        # underlying ATen operators during its execution.
        program = apply_passes(program, [SpecPropPass()])

        # Apply graph transforms which either require `TensorSpec`s to have been
        # created or would create a non-ATen-compliant graph structure.
        post_spec_passes = [
            # Since this pass may replace a scalar argument with a tensor
            # argument, it may result in a non-ATen-compliant graph structure.
            RemoveLocalScalarDenseOpsTransform(),
            insert_prepack_nodes,
        ]
        program = apply_passes(program, post_spec_passes)

        # Optionally apply the memory metadata tagging pass, which will insert
        # storage type and memory layout transition nodes to ensure that all
        # tensor arguments to an operator are in a supported or optimal
        # configuration. If this pass is not applied, there is a risk that some
        # operators receive arguments with memory settings that are not
        # supported by the implementation.
        skip_tagging = compile_options.get("skip_tag_memory_metadata", False)
        if not skip_tagging:
            tagging_pass = TagMemoryMetaPass(
                texture_limits,
                default_storage_type=storage_type,
                default_memory_layout=memory_layout,
            )
            program = apply_passes(program, [tagging_pass])

        # Finally, apply dynamic shape passes and the memory planning pass. These
        # passes must be applied only when the graph structure is finalized.
        finalization_passes = [
            ConstraintBasedSymShapeEvalPass(),
            MemoryPlanningPass(),
        ]
        program = apply_passes(program, finalization_passes)

        builder = VkGraphBuilder(
            program, DelegateMappingBuilder(generated_identifiers=True)
        )
        vk_graph = builder.build_graph()

        # Serialize before querying the delegate mapping, matching the order in
        # which the two results are consumed below.
        serialized_graph = serialize_vulkan_graph(vk_graph, builder.const_tensors, [])
        handle_map = builder.delegate_mapping_builder.get_delegate_mapping()

        return PreprocessResult(
            processed_bytes=serialized_graph,
            debug_handle_map=handle_map,
        )
211