# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Dict, final, List

import executorch.backends.vulkan.utils as utils

from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform
from executorch.backends.transforms.fuse_batch_norm_with_conv import (
    FuseBatchNormWithConvPass,
)
from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform

from executorch.backends.vulkan._passes import (
    insert_prepack_nodes,
    RemoveLocalScalarDenseOpsTransform,
    TagMemoryMetaPass,
)

from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
    serialize_vulkan_graph,
)

from executorch.exir.backend.backend_details import (
    BackendDetails,
    CompileSpec,
    ExportedProgram,
    PreprocessResult,
)
from executorch.exir.backend.utils import DelegateMappingBuilder
from executorch.exir.pass_base import ExportPass, PassBase

from executorch.exir.passes import MemoryPlanningPass, SpecPropPass

from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

from executorch.exir.program._program import _copy_module

from torch.export._remove_auto_functionalized_pass import (
    unsafe_remove_auto_functionalized_pass,
)

DEFAULT_DEBUG_HANDLE = 65535


# pyre-ignore
def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
    """Apply each entry of `passes` to `program`, in order, and return the result.

    Two kinds of entries are supported:

    * Instances of `ExportPass` / `PassBase`: called on `program.graph_module`;
      the resulting graph module is copied back into `program.graph_module`
      via `_copy_module`, so `program` itself is updated in place.
    * Any other callable: called with the whole `ExportedProgram`; its return
      value replaces `program` for the subsequent passes.

    Args:
        program: The exported program to transform.
        passes: Iterable of pass instances and/or program-level callables.

    Returns:
        The transformed `ExportedProgram`.
    """
    for p in passes:
        if not isinstance(p, (ExportPass, PassBase)):
            # Program-level transform: operates on (and returns) the whole program.
            program = p(program)
            continue

        graph_module = program.graph_module
        # This is a workaround to allow the memory planning pass to work without
        # having to first apply ToOutVarPass(). See the `greedy()` function in
        # `exir.memory_planning`; if this attribute isn't set, assertions in
        # `collect_spec_from_nodes()` will fail.
        if isinstance(p, MemoryPlanningPass):
            graph_module.encounter_to_out_var_failure = True

        pass_result = p(graph_module)
        assert pass_result is not None
        new_gm = pass_result.graph_module

        # See the application of this function in exir/program/_program.py for more
        # details on why this step is necessary.
        if isinstance(p, SpecPropPass):
            p.update_placeholder_tensor_specs(program, new_gm)

        _copy_module(program.graph_module, new_gm)

    return program


def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]:
    """Decode the recognized Vulkan compile specs into an options dictionary.

    Every recognized value is decoded from little-endian bytes:

    * `storage_type_override` -> `VkStorageType`
    * `memory_layout_override` -> `VkMemoryLayout`
    * `texture_limits_x` / `texture_limits_y` / `texture_limits_z` -> `int`
    * `skip_tag_memory_metadata` -> `bool`

    Args:
        compile_specs: Compile specs supplied for the delegated module.

    Returns:
        A dictionary keyed by spec key containing the decoded values. Specs
        with unrecognized keys are ignored; if a key appears more than once the
        last occurrence wins.
    """
    options: Dict[str, Any] = {}
    for spec in compile_specs:
        key = spec.key
        if key == "storage_type_override":
            options[key] = VkStorageType(
                int.from_bytes(spec.value, byteorder="little")
            )
        elif key == "memory_layout_override":
            options[key] = VkMemoryLayout(
                int.from_bytes(spec.value, byteorder="little")
            )
        elif key in {"texture_limits_x", "texture_limits_y", "texture_limits_z"}:
            options[key] = int.from_bytes(spec.value, byteorder="little")
        elif key == "skip_tag_memory_metadata":
            options[key] = bool.from_bytes(spec.value, byteorder="little")

        # Unhandled options are ignored

    return options


@final
class VulkanBackend(BackendDetails):
    """Backend implementation that lowers an `ExportedProgram` to a serialized Vulkan graph."""

    @classmethod
    # pyre-ignore
    def preprocess(  # noqa: C901
        cls,
        program: ExportedProgram,
        module_compile_spec: List[CompileSpec],
    ) -> PreprocessResult:
        """Run the Vulkan pass pipeline on `program` and serialize the result.

        Args:
            program: The exported program to lower.
            module_compile_spec: Compile specs; see `parse_compile_spec` for the
                keys that are recognized.

        Returns:
            A `PreprocessResult` whose `processed_bytes` hold the serialized
            Vulkan graph and whose `debug_handle_map` is the delegate mapping
            produced by the graph builder.
        """
        compile_options = parse_compile_spec(module_compile_spec)
        limits_x = compile_options.get(
            "texture_limits_x", utils.DEFAULT_TEXTURE_LIMITS[0]
        )
        limits_y = compile_options.get(
            "texture_limits_y", utils.DEFAULT_TEXTURE_LIMITS[1]
        )
        limits_z = compile_options.get(
            "texture_limits_z", utils.DEFAULT_TEXTURE_LIMITS[2]
        )
        texture_limits = (limits_x, limits_y, limits_z)

        default_storage_type = compile_options.get(
            "storage_type_override", VkStorageType.TEXTURE_3D
        )
        default_memory_layout = compile_options.get(
            "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED
        )

        program = unsafe_remove_auto_functionalized_pass(program)

        # First, apply passes that fuse/remove operators to consolidate the graph
        # structure but still preserve an "ATen-compliant" graph structure (i.e. all
        # arguments to ATen operators must match the ATen function schema).
        program = apply_passes(
            program,
            [
                RemoveCloneOpsTransform(),
                AddmmToLinearTransform(),
                FuseDequantLinearPass(),
                FuseViewCopyTransform(),
                FuseBatchNormWithConvPass(program),
                FuseClampPass(),
            ],
        )

        # Next annotate tensor nodes with TensorSpec structs which is needed for dynamic
        # shapes and memory planning. Until this point, the graph must be ATen compliant
        # because SpecPropPass will be calling the underlying ATen operators during its
        # execution.
        program = apply_passes(program, [SpecPropPass()])

        # Apply graph transforms which either require `TensorSpec`s to have been created
        # or would create a non-ATen-compliant graph structure.
        program = apply_passes(
            program,
            [
                # Since this pass may replace a scalar argument with a tensor argument,
                # this pass may result in a non-ATen-compliant graph structure.
                RemoveLocalScalarDenseOpsTransform(),
                insert_prepack_nodes,
            ],
        )

        # Optionally apply the memory metadata tagging pass, which will insert storage
        # type and memory layout transition nodes to ensure that all tensor arguments
        # to an operator are in a supported or optimal configuration. If this pass is
        # not applied, there is a risk that some operators receive arguments with
        # memory settings that are not supported by the implementation.
        if not compile_options.get("skip_tag_memory_metadata", False):
            program = apply_passes(
                program,
                [
                    TagMemoryMetaPass(
                        texture_limits,
                        default_storage_type=default_storage_type,
                        default_memory_layout=default_memory_layout,
                    ),
                ],
            )

        # Finally, apply dynamic shape passes and memory planning pass. These passes
        # must be applied only when the graph structure is finalized.
        program = apply_passes(
            program,
            [
                ConstraintBasedSymShapeEvalPass(),
                MemoryPlanningPass(),
            ],
        )

        graph_builder = VkGraphBuilder(
            program, DelegateMappingBuilder(generated_identifiers=True)
        )
        vk_graph = graph_builder.build_graph()

        return PreprocessResult(
            processed_bytes=serialize_vulkan_graph(
                vk_graph, graph_builder.const_tensors, []
            ),
            debug_handle_map=graph_builder.delegate_mapping_builder.get_delegate_mapping(),
        )