# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from typing import Any, Callable, Dict, final, List, Mapping, Optional, Tuple

import executorch.backends.vulkan.utils as utils

import torch

from executorch.backends.vulkan.op_registry import (
    get_op_features,
    has_impl,
    OpFeatures,
    vulkan_supported_ops,
)

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)
from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend

from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from executorch.exir.dialects._ops import ops as exir_ops

from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

from torch.fx.passes.operator_support import OperatorSupportBase

# pyre-ignore
ops_not_to_decompose = [
    torch.ops.aten.upsample_nearest2d.vec,
]

logger: logging.Logger = logging.getLogger("")
logger.setLevel(logging.INFO)


class VulkanSupportedOperators(OperatorSupportBase):
    def __init__(
        self,
        texture_limits: utils.ImageExtents,
        buffer_limit: int,
        require_dynamic_shape: bool = False,
    ) -> None:
        super().__init__()
        self.texture_limits: utils.ImageExtents = texture_limits
        self.buffer_limit = buffer_limit
        self.require_dynamic_shapes = require_dynamic_shape
        # The tensor dim limit is to guard against tensors with one or more
        # large dimensions, which cannot be represented by an image texture due
        # to the texture axis limits.
        self.tensor_dim_limit = 16384

    def op_node_is_compatible(
        self, node: torch.fx.Node, features: Optional[OpFeatures] = None
    ) -> Tuple[bool, str]:
        """
        Check if a given node is compatible with the Vulkan delegate's implementation
        of the operator called by the node. Each tensor argument participating in the
        operator call must be able to be represented with a (storage type, memory layout)
        combination that is supported by the operator implementation.
76 """ 77 target = node.target 78 # Account for custom operators 79 if node.target == torch.ops.higher_order.auto_functionalized: 80 first_arg = node.args[0] 81 assert isinstance(first_arg, torch._ops.OpOverload) 82 target = first_arg.name() 83 84 # Extract the features for the node's operator, if no override was provided 85 if features is None: 86 if not has_impl(target): 87 return False, "no operator implementation" 88 features = get_op_features(target) 89 90 valid_texture_layouts = utils.possible_node_memory_layouts( 91 node, self.texture_limits 92 ) 93 94 can_use_buffers = utils.within_buffer_limit(node, self.buffer_limit) 95 for i, arg in enumerate(node.args): 96 if ( 97 isinstance(arg, torch.fx.Node) 98 and utils.is_tensor_node(arg) 99 and i not in features.skip_limits_check 100 ): 101 arg_texture_layouts = utils.possible_node_memory_layouts( 102 arg, self.texture_limits 103 ) 104 valid_texture_layouts = valid_texture_layouts.intersection( 105 arg_texture_layouts 106 ) 107 can_use_buffers = can_use_buffers and utils.within_buffer_limit( 108 arg, self.buffer_limit 109 ) 110 111 # If there are no valid texture memory layouts, then buffer storage must be 112 # supported by the operator implementation. 113 if len(valid_texture_layouts) == 0: 114 if not can_use_buffers: 115 return ( 116 False, 117 f"op requires buffers that exceed the buffer limit ({self.buffer_limit})", 118 ) 119 120 compatible = VkStorageType.BUFFER in features.supported_storage_types() 121 reason = "op is compatible" 122 if not compatible: 123 reason = "op requires buffers which is not supported by op impl" 124 return compatible, reason 125 126 op_available_layouts = features.supported_memory_layouts( 127 VkStorageType.TEXTURE_3D 128 ) 129 130 is_compatible = any( 131 layout in op_available_layouts for layout in valid_texture_layouts 132 ) 133 if not is_compatible: 134 return False, "Required texutre memory layout not supported" 135 136 return is_compatible, "Op is compatible" 137 138 def node_is_compatible( 139 self, node: torch.fx.Node, features: Optional[OpFeatures] = None 140 ) -> Tuple[bool, str]: 141 # TODO(ssjia) support symbolic ints 142 if utils.is_symint_node(node): 143 return False, "symint node not supported yet" 144 elif utils.is_tensor_node(node): 145 return self.op_node_is_compatible(node, features=features) 146 147 return False, f"Unsupported node type: {node.format_node()}" 148 149 def is_linear_permute(self, node: torch.fx.Node) -> Tuple[bool, bool]: 150 """ 151 Detect if a node is a permute/transpose that precedes a call to a `mm` or 152 `addmm` operator. This node can be fused with the `mm` or `addmm` to produce a 153 `linear` operator. 154 155 This function returns two bool values: 156 1. The first indicates if this node can be fused into a linear node 157 2. The second indicates if the overall linear op can be executed with Vulkan 158 159 The node will be partitioned only if both are true. 
160 """ 161 if node.target not in [ 162 exir_ops.edge.aten.t_copy.default, 163 exir_ops.edge.aten.permute_copy.default, 164 ]: 165 return False, False 166 167 if len(node.users) != 1: 168 return False, False 169 170 first_user = list(node.users.keys())[0] 171 if first_user.target in [ 172 exir_ops.edge.aten.mm.default, 173 exir_ops.edge.aten.addmm.default, 174 ]: 175 # Only mark this node if the target linear op is valid 176 if self.node_is_compatible(first_user)[0]: 177 return True, True 178 else: 179 return True, False 180 181 return False, False 182 183 def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, bool]: 184 """ 185 Scalar tensors are usually converted to scalar values in the graph via` 186 scalar_tensor[0].item()` in Python, which translates to a chain of 187 `local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph. 188 This function marks the entire chain as supported by the Vulkan delegate. 189 190 Later, within vulkan_preprocess there will be a graph transform which replaces 191 the chain with passing in the scalar tensor directly. 192 193 Similar to the `is_linear_permute` function, this function has 2 return values. 194 """ 195 if node.target == exir_ops.edge.aten.select_copy.int: 196 if len(node.users) != 1: 197 return False, False 198 # pyre-ignore 199 if node.args[0].meta["val"].numel() != 1: 200 return False, False 201 202 local_scalar_dense = list(node.users.keys())[0] 203 if local_scalar_dense.target != torch.ops.aten._local_scalar_dense.default: 204 return False, False 205 206 return self.is_in_local_scalar_dense_chain(local_scalar_dense) 207 208 if node.target == torch.ops.aten._local_scalar_dense.default: 209 return True, all(self.node_is_compatible(user)[0] for user in node.users) 210 211 return False, False 212 213 def log_skip(self, node: torch.fx.Node, reason: str) -> None: 214 if node.op == "call_function": 215 logger.info( 216 f"[Vulkan Partitioner] Due to [{reason}], skipping {node.format_node()}" 217 ) 218 219 def is_node_supported( 220 self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node 221 ) -> bool: 222 r = self._is_node_supported(node) 223 return r 224 225 def _is_node_supported(self, node: torch.fx.Node) -> bool: 226 target = node.target 227 if node.target == torch.ops.higher_order.auto_functionalized: 228 first_arg = node.args[0] 229 assert isinstance(first_arg, torch._ops.OpOverload) 230 target = first_arg.name() 231 232 is_linear_permute, target_linear_is_compatible = self.is_linear_permute(node) 233 if is_linear_permute and target_linear_is_compatible: 234 return True 235 elif is_linear_permute: 236 # Skip so that the permute can be fused into a linear by another backend 237 self.log_skip(node, "permute node of non compatible linear node") 238 return False 239 240 is_in_local_scalar_dense_chain, dst_node_is_compatible = ( 241 self.is_in_local_scalar_dense_chain(node) 242 ) 243 if is_in_local_scalar_dense_chain and dst_node_is_compatible: 244 return True 245 elif is_in_local_scalar_dense_chain: 246 self.log_skip(node, "local scalar dense of incompatible op node") 247 return False 248 249 if target not in vulkan_supported_ops: 250 self.log_skip(node, "no operator implementation") 251 return False 252 253 features = vulkan_supported_ops[target] 254 255 if not features.check_node_fn(node): 256 self.log_skip(node, "op args not supported") 257 return False 258 259 if self.require_dynamic_shapes and not features.resize_fn: 260 self.log_skip(node, "no dynamic shape support") 261 return False 262 263 
        is_compatible, reason = self.node_is_compatible(node, features=features)
        if not is_compatible:
            self.log_skip(node, reason)

        return is_compatible


def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]:
    compile_specs = []

    for key, value in compile_options.items():
        if isinstance(value, (VkStorageType, VkMemoryLayout)):
            value_bytes = int(value).to_bytes(4, byteorder="little")
            compile_specs.append(CompileSpec(key, value_bytes))

        if isinstance(value, bool):
            value_bytes = value.to_bytes(1, byteorder="little")
            compile_specs.append(CompileSpec(key, value_bytes))

        if key == "texture_limits":
            compile_specs.append(
                CompileSpec(
                    "texture_limits_x", int(value[0]).to_bytes(4, byteorder="little")
                )
            )
            compile_specs.append(
                CompileSpec(
                    "texture_limits_y", int(value[1]).to_bytes(4, byteorder="little")
                )
            )
            compile_specs.append(
                CompileSpec(
                    "texture_limits_z", int(value[2]).to_bytes(4, byteorder="little")
                )
            )

        # Unhandled options are ignored

    return compile_specs


@final
class VulkanPartitioner(Partitioner):
    def __init__(
        self,
        compile_options: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.options: Dict[str, Any] = {}
        if compile_options is not None:
            self.options = compile_options

        compile_spec = parse_compile_options(self.options)
        self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec)

    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        return (ops_not_to_decompose, None)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        # Run the CapabilityBasedPartitioner to return the largest possible
        # subgraphs containing the nodes with the tags
        partition_tags = {}

        texture_limits: utils.ImageExtents = self.options.get(
            "texture_limits", utils.DEFAULT_TEXTURE_LIMITS
        )
        buffer_limit: int = self.options.get("buffer_limit", utils.DEFAULT_BUFFER_LIMIT)
        capability_partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            VulkanSupportedOperators(
                texture_limits,
                buffer_limit,
                require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
            ),
            allows_single_node_partition=True,
        )
        partition_list = capability_partitioner.propose_partitions()
        for partition in partition_list:
            for node in partition.nodes:
                tag = f"tag{partition.id}"
                node.meta["delegation_tag"] = tag
                partition_tags[tag] = self.delegation_spec

        pl = len(partition_list)
        if pl == 0:
            logger.warning("No Vulkan subgraphs can be partitioned!")
        else:
            logger.info(f"Found {pl} Vulkan subgraphs to be partitioned.")

        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
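
# Usage sketch (assumes the standard ExecuTorch export/lowering flow; `model` and
# `example_inputs` are placeholders for a user-provided nn.Module and its inputs):
#
#   import torch
#   from executorch.exir import to_edge_transform_and_lower
#
#   exported_program = torch.export.export(model, example_inputs)
#   edge_manager = to_edge_transform_and_lower(
#       exported_program,
#       partitioner=[VulkanPartitioner()],
#   )
#   executorch_program = edge_manager.to_executorch()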