# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import operator

from typing import Callable, Dict, Optional, Set, Union

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)

from executorch.backends.vulkan.utils import (
    all_memory_layouts,
    all_packed_dims,
    PackedDim,
)
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._subclasses.fake_tensor import FakeTensor

######################
## OpFeatures class ##
######################


def allow_node(node: torch.fx.Node) -> bool:
    return True


class TextureImplFeatures:
    __slots__ = [
        "valid_packed_dims",
        "uses_axis_map",
    ]

    def __init__(
        self,
        uses_axis_map: bool = False,
        valid_packed_dims: Optional[Set[PackedDim]] = None,
    ):
        self.uses_axis_map: bool = uses_axis_map
        self.valid_packed_dims: Set[PackedDim] = set()
        if valid_packed_dims is not None:
            self.valid_packed_dims = valid_packed_dims

    def valid_memory_layouts(self) -> Set[VkMemoryLayout]:
        """
        Derive the set of memory layouts supported by the texture implementation based
        on the valid packed dimensions.
        """
        layouts = set()

        if PackedDim.WIDTH in self.valid_packed_dims:
            layouts.add(VkMemoryLayout.TENSOR_WIDTH_PACKED)

        if PackedDim.HEIGHT in self.valid_packed_dims:
            layouts.add(VkMemoryLayout.TENSOR_HEIGHT_PACKED)

        if PackedDim.CHANNELS in self.valid_packed_dims:
            layouts.add(VkMemoryLayout.TENSOR_CHANNELS_PACKED)

        return layouts
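
# Illustrative sketch (not part of the registry; for exposition only): each
# PackedDim in valid_packed_dims maps one-to-one to a VkMemoryLayout, so a
# texture implementation that supports width and channels packing yields
# exactly those two memory layouts.
def _example_valid_memory_layouts() -> None:
    features = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH, PackedDim.CHANNELS},
    )
    assert features.valid_memory_layouts() == {
        VkMemoryLayout.TENSOR_WIDTH_PACKED,
        VkMemoryLayout.TENSOR_CHANNELS_PACKED,
    }
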
class OpFeatures:
    __slots__ = [
        # None or TextureImplFeatures to specify implementation details of the texture
        # based operator implementation.
        "texture_impl",
        # bool indicating if the operator has a buffer based implementation.
        "buffer_impl",
        # bool indicating if the operator has a resize function, which allows it to
        # support dynamic shape tensors.
        "resize_fn",
        # Optional VkStorageType and VkMemoryLayout specifying the optimal storage
        # type and memory layout to use for the operator, if any.
        "optimal_storage",
        "optimal_layout",
        # bool indicating if the operator handles its own prepacking. If this is True,
        # then the insert_prepack_nodes pass will not insert prepack nodes for the args
        # of the op.
        "handles_own_prepacking",
        # Optional set of argument indices for which the texture image extents limits
        # check should be skipped.
        "skip_limits_check",
        # Optional check function used during partitioning to determine if a node's
        # inputs are supported by the operator implementation.
        "check_node_fn",
    ]

    def __init__(
        self,
        texture_impl: Optional[TextureImplFeatures] = None,
        buffer_impl: bool = False,
        resize_fn: bool = False,
        optimal_storage: Optional[VkStorageType] = None,
        optimal_layout: Optional[VkMemoryLayout] = None,
        handles_own_prepacking: bool = False,
        skip_limits_check: Optional[Set[int]] = None,
        check_node_fn: Optional[Callable] = None,
    ):
        self.texture_impl: Optional[TextureImplFeatures] = texture_impl
        self.buffer_impl: bool = buffer_impl
        self.resize_fn: bool = resize_fn
        self.optimal_storage: Optional[VkStorageType] = optimal_storage
        self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout
        self.handles_own_prepacking: bool = handles_own_prepacking

        self.skip_limits_check: Set[int] = set()
        if skip_limits_check is not None:
            self.skip_limits_check = skip_limits_check

        self.check_node_fn: Callable = allow_node
        if check_node_fn is not None:
            self.check_node_fn = check_node_fn

    def propose_storage_type(self) -> Optional[VkStorageType]:
        """
        Propose a storage type that should be used for this operator. A proposal can be
        made if one of the following is true:
        1. The operator specifies an optimal storage type
        2. Only one storage type is supported.

        If both storage types are supported and no optimal storage type is specified,
        then None is returned to indicate that there is no preference in storage type.
        """
        if self.optimal_storage is not None:
            return self.optimal_storage

        if self.texture_impl is not None and not self.buffer_impl:
            return VkStorageType.TEXTURE_3D
        elif self.buffer_impl and self.texture_impl is None:
            return VkStorageType.BUFFER

        return None

    def supported_storage_types(self) -> Set[VkStorageType]:
        """
        Return the set of storage types supported by this operator.
        """
        storage_types = set()
        if self.texture_impl is not None:
            storage_types.add(VkStorageType.TEXTURE_3D)
        if self.buffer_impl:
            storage_types.add(VkStorageType.BUFFER)

        return storage_types

    def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]:
        """
        Given a storage type as a precondition, propose a memory layout that should be
        used for this operator. A proposal can be made if one of the following is true:
        1. The operator specifies an optimal memory layout
        2. Only one memory layout is supported.

        If multiple memory layouts are supported and no optimal memory layout is
        specified, then return None to indicate that the "best" memory layout for the
        operator is ambiguous.
        """
        if self.optimal_layout is not None:
            return self.optimal_layout

        if storage == VkStorageType.TEXTURE_3D:
            assert self.texture_impl is not None
            possible_layouts = self.texture_impl.valid_memory_layouts()
            if len(possible_layouts) == 1:
                return next(iter(possible_layouts))

        return None

    def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]:
        """
        Return the set of memory layouts supported by this operator for a given storage
        type.
        """
        if storage == VkStorageType.TEXTURE_3D:
            assert self.texture_impl is not None
            return self.texture_impl.valid_memory_layouts()
        else:
            return all_memory_layouts
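
# Illustrative sketch (not part of the registry; for exposition only): the
# proposal methods commit to a storage type or memory layout only when an
# optimal setting is registered or the choice is unambiguous.
def _example_storage_and_layout_proposal() -> None:
    # Texture-only implementation with a single valid packed dim: both
    # proposals are unambiguous.
    texture_only = OpFeatures(
        texture_impl=TextureImplFeatures(valid_packed_dims={PackedDim.WIDTH}),
    )
    assert texture_only.propose_storage_type() == VkStorageType.TEXTURE_3D
    assert (
        texture_only.propose_memory_layout(VkStorageType.TEXTURE_3D)
        == VkMemoryLayout.TENSOR_WIDTH_PACKED
    )

    # Both storage types supported with no optimal storage type registered:
    # there is no preference, so None is returned.
    ambiguous = OpFeatures(
        texture_impl=TextureImplFeatures(valid_packed_dims=all_packed_dims),
        buffer_impl=True,
    )
    assert ambiguous.propose_storage_type() is None
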
#######################
## Operator Registry ##
#######################

OpKey = Union[str, torch._ops.OpOverload, EdgeOpOverload]

vulkan_supported_ops: Dict[OpKey, OpFeatures] = {}


def update_features(aten_op):
    """
    Decorator used to register an operator (or a list of operators) in
    vulkan_supported_ops. The decorated function receives a default-constructed
    OpFeatures instance and returns it with the operator's features filled in.
    """

    def features_decorator(fn: Callable):
        def update_features_impl(op: OpKey):
            if op in vulkan_supported_ops:
                raise RuntimeError(f"[Vulkan delegate] duplicate registration of {op}!")
            vulkan_supported_ops[op] = fn(OpFeatures())

        if isinstance(aten_op, list):
            for op in aten_op:
                update_features_impl(op)
        else:
            update_features_impl(aten_op)

        return fn

    return features_decorator


@update_features(
    [
        operator.getitem,
        # Quantization-related ops will be fused via graph passes
        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
    ]
)
def register_ephemeral_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=True,
        valid_packed_dims=all_packed_dims,
    )
    features.buffer_impl = True
    features.resize_fn = True
    return features


@update_features(
    [
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.minimum.default,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Tensor,
        exir_ops.edge.aten.div.Tensor_mode,
        exir_ops.edge.aten.pow.Tensor_Tensor,
    ]
)
def register_binary_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=True,
        valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True
    return features


@update_features(
    [
        exir_ops.edge.aten.abs.default,
        exir_ops.edge.aten.clamp.default,
        exir_ops.edge.aten.cos.default,
        exir_ops.edge.aten.exp.default,
        exir_ops.edge.aten.gelu.default,
        exir_ops.edge.aten.hardshrink.default,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.neg.default,
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.sigmoid.default,
        exir_ops.edge.aten.sin.default,
        exir_ops.edge.aten.sqrt.default,
        exir_ops.edge.aten.rsqrt.default,
        exir_ops.edge.aten.tanh.default,
    ]
)
def register_unary_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=True,
        valid_packed_dims=all_packed_dims,
    )
    features.buffer_impl = True
    features.resize_fn = True
    return features


@update_features(exir_ops.edge.aten._to_copy.default)
def register_to_copy_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=True,
        valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True

    def check_to_copy_node(node: torch.fx.Node) -> bool:
        # Only accept _to_copy nodes that convert between float dtypes.
        float_dtypes = [torch.float16, torch.float32]

        if len(node.args) != 1:
            return False

        in_arg = node.args[0]
        if not isinstance(in_arg, torch.fx.Node):
            return False

        in_tensor = in_arg.meta.get("val", None)
        out_tensor = node.meta.get("val", None)

        if isinstance(in_tensor, FakeTensor) and isinstance(out_tensor, FakeTensor):
            if out_tensor.dtype in float_dtypes and in_tensor.dtype in float_dtypes:
                return True

        return False

    features.check_node_fn = check_to_copy_node

    return features
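
# Illustrative sketch of the registration pattern above. The op key and check
# function below are hypothetical placeholders; registering a key that already
# exists raises a RuntimeError in update_features.
#
#     @update_features("my_namespace::my_custom_op")
#     def register_my_custom_op(features: OpFeatures):
#         features.texture_impl = TextureImplFeatures(
#             valid_packed_dims={PackedDim.CHANNELS},
#         )
#         features.resize_fn = True
#
#         def check_my_custom_node(node: torch.fx.Node) -> bool:
#             # Reject nodes with inputs the implementation cannot handle.
#             return len(node.args) == 1
#
#         features.check_node_fn = check_my_custom_node
#         return features
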
@update_features(
    [
        exir_ops.edge.aten.bmm.default,
        exir_ops.edge.aten.mm.default,
        exir_ops.edge.aten.addmm.default,
        exir_ops.edge.aten.linear.default,
    ]
)
def register_mm_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=True,
        valid_packed_dims={
            PackedDim.WIDTH,
            PackedDim.CHANNELS,
        },
    )
    features.buffer_impl = True
    features.resize_fn = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.handles_own_prepacking = True
    return features


@update_features(exir_ops.edge.aten._weight_int8pack_mm.default)
def register_int8_mm_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=False,
        valid_packed_dims={PackedDim.WIDTH},
    )
    features.buffer_impl = True
    features.resize_fn = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.handles_own_prepacking = True
    return features


@update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
def register_int4_mm_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=False,
        valid_packed_dims={PackedDim.WIDTH},
    )
    features.resize_fn = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.handles_own_prepacking = True
    return features


@update_features(
    [
        exir_ops.edge.aten._log_softmax.default,
        exir_ops.edge.aten._softmax.default,
    ]
)
def register_softmax_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True
    return features


@update_features(
    [
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.sum.dim_IntList,
        exir_ops.edge.aten.amax.default,
        exir_ops.edge.aten.amin.default,
    ]
)
def register_reduce_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True

    def check_reduce_node(node: torch.fx.Node) -> bool:
        # Only single-dim reductions with keepdim=True are currently supported.
        dim_list = node.args[1]
        if isinstance(dim_list, list) and len(dim_list) != 1:
            return False

        keepdim = node.args[2]
        if isinstance(keepdim, bool) and not keepdim:
            return False

        return True

    features.check_node_fn = check_reduce_node
    return features


@update_features(
    [
        exir_ops.edge.aten.avg_pool2d.default,
        exir_ops.edge.aten.max_pool2d_with_indices.default,
    ]
)
def register_2d_pool_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.CHANNELS},
    )
    features.resize_fn = True
    return features


@update_features(
    [
        exir_ops.edge.aten.convolution.default,
        exir_ops.edge.et_vk.conv_with_clamp.default,
    ]
)
def register_convolution_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.CHANNELS},
    )
    features.resize_fn = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
    features.handles_own_prepacking = True
    features.skip_limits_check = {1, 2}
    return features
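
# Illustrative sketch (not part of the registry; for exposition only): because
# convolution registers an optimal storage type and memory layout, the proposal
# methods return those settings directly rather than inferring them. The
# skip_limits_check indices {1, 2} correspond to the weight and bias arguments
# of aten.convolution, presumably exempted because the op prepacks them itself.
def _example_convolution_features() -> None:
    features = vulkan_supported_ops[exir_ops.edge.aten.convolution.default]
    assert features.propose_storage_type() == VkStorageType.TEXTURE_3D
    assert (
        features.propose_memory_layout(VkStorageType.TEXTURE_3D)
        == VkMemoryLayout.TENSOR_CHANNELS_PACKED
    )
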
@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH},
    )
    features.resize_fn = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.handles_own_prepacking = True
    return features


@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH},
    )
    features.resize_fn = True
    return features


@update_features(exir_ops.edge.aten.view_copy.default)
def register_view_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True
    return features


# Ops ported from the PyTorch Vulkan backend. These ops commonly support only
# channels-packed tensors and do not have a resize function.
@update_features(
    [
        # Shape Manipulation
        exir_ops.edge.aten.squeeze_copy.dims,
        exir_ops.edge.aten.unsqueeze_copy.default,
        exir_ops.edge.aten.permute_copy.default,
        exir_ops.edge.aten.t_copy.default,
        # Indexing and lookup
        exir_ops.edge.aten.flip.default,
        exir_ops.edge.aten.index_select.default,
        exir_ops.edge.aten.select_copy.int,
        exir_ops.edge.aten.slice_copy.Tensor,
        # Tensor combination
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.split_with_sizes_copy.default,
        exir_ops.edge.aten.split.Tensor,
        exir_ops.edge.aten.repeat.default,
        # Tensor creation
        exir_ops.edge.aten.arange.start_step,
        exir_ops.edge.aten.clone.default,
        exir_ops.edge.aten.constant_pad_nd.default,
        exir_ops.edge.aten.full.default,
        exir_ops.edge.aten.full_like.default,
        exir_ops.edge.aten.ones.default,
        exir_ops.edge.aten.ones_like.default,
        exir_ops.edge.aten.upsample_nearest2d.vec,
        exir_ops.edge.aten.zeros.default,
        exir_ops.edge.aten.zeros_like.default,
        exir_ops.edge.et_vk.grid_priors.default,
    ]
)
def register_ported_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.CHANNELS},
    )
    return features
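
# Illustrative sketch (not part of the registry; for exposition only): ported
# ops only support the CHANNELS packed dim, so even without a registered
# optimal layout the memory layout proposal is unambiguous.
def _example_ported_op_layout() -> None:
    features = vulkan_supported_ops[exir_ops.edge.aten.cat.default]
    assert (
        features.propose_memory_layout(VkStorageType.TEXTURE_3D)
        == VkMemoryLayout.TENSOR_CHANNELS_PACKED
    )
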
# Ported ops that support their own prepacking.
@update_features(
    [
        exir_ops.edge.aten.embedding.default,
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
        exir_ops.edge.aten.native_layer_norm.default,
    ]
)
def register_ported_ops_with_prepacking(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.CHANNELS},
    )
    features.handles_own_prepacking = True
    return features


#######################
## Utility functions ##
#######################


def has_impl(target: OpKey) -> bool:
    if target in vulkan_supported_ops:
        return True
    # Op overloads may be registered under their string name instead.
    if not isinstance(target, str):
        return target.name() in vulkan_supported_ops
    return False


def get_op_features(target: OpKey) -> OpFeatures:
    if not isinstance(target, str) and target not in vulkan_supported_ops:
        # Fall back to looking up the op's string name.
        return vulkan_supported_ops[target.name()]

    return vulkan_supported_ops[target]


def handles_own_prepacking(target: OpKey) -> bool:
    return get_op_features(target).handles_own_prepacking
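
# Illustrative usage sketch (not part of the registry; for exposition only):
# querying the registry through the utility functions above.
def _example_registry_queries() -> None:
    linear_op = exir_ops.edge.aten.linear.default
    assert has_impl(linear_op)
    features = get_op_features(linear_op)
    assert features.buffer_impl and features.resize_fn
    assert handles_own_prepacking(linear_op)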