xref: /aosp_15_r20/external/executorch/backends/vulkan/op_registry.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import operator
10
11from typing import Callable, Dict, Optional, Set, Union
12
13import executorch.backends.vulkan.custom_ops_lib  # noqa
14
15import torch
16
17from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
18    VkMemoryLayout,
19    VkStorageType,
20)
21
22from executorch.backends.vulkan.utils import (
23    all_memory_layouts,
24    all_packed_dims,
25    PackedDim,
26)
27from executorch.exir.dialects._ops import ops as exir_ops
28
29from executorch.exir.dialects.edge._ops import EdgeOpOverload
30from torch._subclasses.fake_tensor import FakeTensor
31
32######################
33## OpFeatures class ##
34######################
35
36
def allow_node(node: torch.fx.Node) -> bool:
    """Default node check: accept every node unconditionally."""
    return True
39
40
class TextureImplFeatures:
    """
    Describes the texture-based implementation of an operator: which packed
    dimensions it can handle and whether it uses the axis map.
    """

    __slots__ = [
        "valid_packed_dims",
        "uses_axis_map",
    ]

    def __init__(
        self,
        uses_axis_map: bool = False,
        valid_packed_dims: Optional[Set[PackedDim]] = None,
    ):
        self.uses_axis_map: bool = uses_axis_map
        # Default to an empty set when no packed dims are specified.
        self.valid_packed_dims = (
            valid_packed_dims if valid_packed_dims is not None else set()
        )

    def valid_memory_layouts(self) -> Set[VkMemoryLayout]:
        """
        Derive the set of memory layouts supported by the texture implementation based
        on the valid packed dimensions.
        """
        dim_to_layout = {
            PackedDim.WIDTH: VkMemoryLayout.TENSOR_WIDTH_PACKED,
            PackedDim.HEIGHT: VkMemoryLayout.TENSOR_HEIGHT_PACKED,
            PackedDim.CHANNELS: VkMemoryLayout.TENSOR_CHANNELS_PACKED,
        }
        return {
            layout
            for dim, layout in dim_to_layout.items()
            if dim in self.valid_packed_dims
        }
74
75
class OpFeatures:
    """
    Describes the capabilities of the Vulkan delegate's implementation of a
    single operator: which storage types exist (texture and/or buffer), which
    memory layouts are valid, and hints consumed during partitioning.
    """

    __slots__ = [
        # None, or a TextureImplFeatures describing the texture based
        # implementation of the operator.
        "texture_impl",
        # Whether the operator has a buffer based implementation.
        "buffer_impl",
        # Whether the operator has a resize function, which allows it to
        # support dynamic shape tensors.
        "resize_fn",
        # Preferred storage type / memory layout for the operator, if any.
        "optimal_storage",
        "optimal_layout",
        # If True, the operator handles its own prepacking, and the
        # insert_prepack_nodes pass will not insert prepack nodes for the
        # args of the op.
        "handles_own_prepacking",
        # Set of argument indices exempted from the image extent limits
        # check (presumably prepacked args such as weights — TODO confirm).
        "skip_limits_check",
        # Callable used during partitioning to determine if a node's inputs
        # are supported by the operator implementation.
        "check_node_fn",
    ]

    def __init__(
        self,
        texture_impl: Optional[TextureImplFeatures] = None,
        buffer_impl: bool = False,
        resize_fn: bool = False,
        optimal_storage: Optional[VkStorageType] = None,
        optimal_layout: Optional[VkMemoryLayout] = None,
        handles_own_prepacking: bool = False,
        skip_limits_check: Optional[Set[int]] = None,
        check_node_fn: Optional[Callable] = None,
    ):
        self.texture_impl: Optional[TextureImplFeatures] = texture_impl
        self.buffer_impl: bool = buffer_impl
        self.resize_fn: bool = resize_fn
        self.optimal_storage: Optional[VkStorageType] = optimal_storage
        self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout
        self.handles_own_prepacking: bool = handles_own_prepacking

        # Normalize the optional arguments to their non-None defaults.
        self.skip_limits_check: Set[int] = (
            skip_limits_check if skip_limits_check is not None else set()
        )
        self.check_node_fn: Callable = (
            check_node_fn if check_node_fn is not None else allow_node
        )

    def propose_storage_type(self) -> Optional[VkStorageType]:
        """
        Propose a storage type that should be used for this operator. A proposal can be
        made if one of the following is true:
        1. The operator specifies an optimal storage type
        2. Only one storage type is supported.

        If both storage types are supported and no optimal storage type is specified,
        then None is returned to indicate that there is no preference in storage type.
        """
        if self.optimal_storage is not None:
            return self.optimal_storage

        has_texture = self.texture_impl is not None
        if has_texture and not self.buffer_impl:
            return VkStorageType.TEXTURE_3D
        if self.buffer_impl and not has_texture:
            return VkStorageType.BUFFER

        # Both (or neither) storage types are available; no preference.
        return None

    def supported_storage_types(self) -> Set[VkStorageType]:
        """
        Return the set of storage types supported by this operator.
        """
        supported = set()
        if self.buffer_impl:
            supported.add(VkStorageType.BUFFER)
        if self.texture_impl is not None:
            supported.add(VkStorageType.TEXTURE_3D)

        return supported

    def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]:
        """
        Given a storage type as a precondition, propose a memory layout that should be
        used for this operator. A proposal can be made if one of the following is true:
        1. The operator specifies an optimal memory layout
        2. Only one memory layout is supported.

        If multiple memory layouts are supported and no optimal memory layout is
        specified then return None to indicate that the "best" memory layout for the
        operator is ambiguous.
        """
        if self.optimal_layout is not None:
            return self.optimal_layout

        if storage == VkStorageType.TEXTURE_3D:
            assert self.texture_impl is not None
            candidates = self.texture_impl.valid_memory_layouts()
            if len(candidates) == 1:
                return next(iter(candidates))

        return None

    def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]:
        """
        Return the set of memory layouts supported by this operator for a given storage
        type.
        """
        if storage != VkStorageType.TEXTURE_3D:
            # Buffer storage places no layout restriction.
            return all_memory_layouts

        assert self.texture_impl is not None
        return self.texture_impl.valid_memory_layouts()
191
192
193#######################
194## Operator Registry ##
195#######################
196
# Key type used to index the registry: a string op name, a core ATen
# OpOverload, or an Edge dialect EdgeOpOverload.
OpKey = Union[str, torch._ops.OpOverload, EdgeOpOverload]

# Registry mapping each supported op to the features of the Vulkan delegate's
# implementation of that op. Populated via the @update_features decorator.
vulkan_supported_ops: Dict[OpKey, OpFeatures] = {}
200
201
def update_features(aten_op):
    """
    Decorator factory that registers one or more ops in vulkan_supported_ops.

    The decorated function receives a default-constructed OpFeatures instance
    and returns it with the op's capabilities filled in. Registering the same
    op twice raises a RuntimeError.
    """

    def features_decorator(fn: Callable):
        def register_one(op: OpKey):
            if op in vulkan_supported_ops:
                raise RuntimeError(f"[Vulkan delegate] duplicate registration of {op}!")
            # Register defaults first, then let fn refine them.
            vulkan_supported_ops[op] = OpFeatures()
            vulkan_supported_ops[op] = fn(vulkan_supported_ops[op])

        ops = aten_op if isinstance(aten_op, list) else [aten_op]
        for op in ops:
            register_one(op)

        return fn

    return features_decorator
219
220
@update_features(
    [
        operator.getitem,
        # Quantization related ops will be fused via graph passes
        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
    ]
)
def register_ephemeral_op(features: OpFeatures):
    """Ops with both buffer and texture implementations, all packed dims valid."""
    features.buffer_impl = True
    features.resize_fn = True
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
        uses_axis_map=True,
    )
    return features
241
242
@update_features(
    [
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.minimum.default,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Tensor,
        exir_ops.edge.aten.div.Tensor_mode,
        exir_ops.edge.aten.pow.Tensor_Tensor,
    ]
)
def register_binary_op(features: OpFeatures):
    """Elementwise binary ops: texture implementation only, dynamic shapes OK."""
    features.resize_fn = True
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
        uses_axis_map=True,
    )
    return features
261
262
@update_features(
    [
        exir_ops.edge.aten.abs.default,
        exir_ops.edge.aten.clamp.default,
        exir_ops.edge.aten.cos.default,
        exir_ops.edge.aten.exp.default,
        exir_ops.edge.aten.gelu.default,
        exir_ops.edge.aten.hardshrink.default,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.neg.default,
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.sigmoid.default,
        exir_ops.edge.aten.sin.default,
        exir_ops.edge.aten.sqrt.default,
        exir_ops.edge.aten.rsqrt.default,
        exir_ops.edge.aten.tanh.default,
    ]
)
def register_unary_op(features: OpFeatures):
    """Elementwise unary ops: buffer and texture implementations, dynamic shapes OK."""
    features.buffer_impl = True
    features.resize_fn = True
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
        uses_axis_map=True,
    )
    return features
289
290
@update_features(exir_ops.edge.aten._to_copy.default)
def register_to_copy_op(features: OpFeatures):
    """_to_copy: supported only for casts between float16 and float32 tensors."""
    features.resize_fn = True
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
        uses_axis_map=True,
    )

    def check_to_copy_node(node: torch.fx.Node) -> bool:
        # Only single-positional-arg calls are supported.
        if len(node.args) != 1:
            return False

        src = node.args[0]
        if not isinstance(src, torch.fx.Node):
            return False

        src_val = src.meta.get("val", None)
        dst_val = node.meta.get("val", None)

        if not (isinstance(src_val, FakeTensor) and isinstance(dst_val, FakeTensor)):
            return False

        # Both the input and output dtypes must be float16 or float32.
        float_dtypes = (torch.float16, torch.float32)
        return src_val.dtype in float_dtypes and dst_val.dtype in float_dtypes

    features.check_node_fn = check_to_copy_node

    return features
321
322
@update_features(
    [
        exir_ops.edge.aten.bmm.default,
        exir_ops.edge.aten.mm.default,
        exir_ops.edge.aten.addmm.default,
        exir_ops.edge.aten.linear.default,
    ]
)
def register_mm_op(features: OpFeatures):
    """Matrix multiplication ops: prefer width-packed textures, prepack their own weights."""
    features.buffer_impl = True
    features.resize_fn = True
    features.handles_own_prepacking = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=True,
        valid_packed_dims={PackedDim.CHANNELS, PackedDim.WIDTH},
    )
    return features
345
346
@update_features(exir_ops.edge.aten._weight_int8pack_mm.default)
def register_int8_mm_op(features: OpFeatures):
    """Weight-int8 matmul: width-packed only, no axis map, prepacks its own weights."""
    features.buffer_impl = True
    features.resize_fn = True
    features.handles_own_prepacking = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=False,
        valid_packed_dims={PackedDim.WIDTH},
    )
    return features
359
360
@update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
def register_int4_mm_op(features: OpFeatures):
    """Int4 linear: texture-only, width-packed, prepacks its own weights."""
    features.resize_fn = True
    features.handles_own_prepacking = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=False,
        valid_packed_dims={PackedDim.WIDTH},
    )
    return features
372
373
@update_features(
    [
        exir_ops.edge.aten._log_softmax.default,
        exir_ops.edge.aten._softmax.default,
    ]
)
def register_softmax_op(features: OpFeatures):
    """Softmax ops: texture implementation, all packed dims valid."""
    features.resize_fn = True
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
    )
    return features
386
387
@update_features(
    [
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.sum.dim_IntList,
        exir_ops.edge.aten.amax.default,
        exir_ops.edge.aten.amin.default,
    ]
)
def register_reduce_op(features: OpFeatures):
    """
    Reduction ops: texture implementation, all packed dims valid.

    Only reductions over a single dim with keepdim=True are supported; the
    node check below rejects everything else.
    """
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True

    def check_reduce_node(node: torch.fx.Node) -> bool:
        # Guard the arg accesses: the dim list and keepdim args may be omitted
        # from the node, in which case indexing node.args would raise.
        dim_list = node.args[1] if len(node.args) > 1 else None
        if isinstance(dim_list, list) and len(dim_list) != 1:
            # Only a single reduction dim is supported.
            return False

        # keepdim defaults to False in the aten schemas, which is unsupported.
        keepdim = node.args[2] if len(node.args) > 2 else False
        if isinstance(keepdim, bool) and not keepdim:
            return False

        return True

    features.check_node_fn = check_reduce_node
    return features
415
416
@update_features(
    [
        exir_ops.edge.aten.avg_pool2d.default,
        exir_ops.edge.aten.max_pool2d_with_indices.default,
    ]
)
def register_2d_pool_op(features: OpFeatures):
    """2D pooling ops: channels-packed textures only, dynamic shapes OK."""
    features.resize_fn = True
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.CHANNELS},
    )
    return features
429
430
@update_features(
    [
        exir_ops.edge.aten.convolution.default,
        exir_ops.edge.et_vk.conv_with_clamp.default,
    ]
)
def register_convolution_op(features: OpFeatures):
    """Convolution ops: channels-packed textures, prepack their own weights."""
    features.resize_fn = True
    features.handles_own_prepacking = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.CHANNELS},
    )
    # Skip image extent limit checks for arg indices 1 and 2 (presumably the
    # weight and bias tensors — TODO confirm against the partitioner).
    features.skip_limits_check = {1, 2}
    return features
447
448
# Registered under the op's string target name rather than an OpOverload.
@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_op(features: OpFeatures):
    """SDPA with KV cache: width-packed textures, prepacks its own weights."""
    features.resize_fn = True
    features.handles_own_prepacking = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH},
    )
    return features
459
460
@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op(features: OpFeatures):
    """Rotary position embedding: width-packed textures, dynamic shapes OK."""
    features.resize_fn = True
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH},
    )
    return features
468
469
@update_features(exir_ops.edge.aten.view_copy.default)
def register_view_op(features: OpFeatures):
    """view_copy: texture implementation, all packed dims valid."""
    features.resize_fn = True
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
    )
    return features
477
478
479# Ops ported from PyTorch Vulkan backend. These ops commonly support channels
480# packed tensors only and do not have a resize function.
@update_features(
    [
        # Shape Manipulation
        exir_ops.edge.aten.squeeze_copy.dims,
        exir_ops.edge.aten.unsqueeze_copy.default,
        exir_ops.edge.aten.permute_copy.default,
        exir_ops.edge.aten.t_copy.default,
        # Indexing and lookup
        exir_ops.edge.aten.flip.default,
        exir_ops.edge.aten.index_select.default,
        exir_ops.edge.aten.select_copy.int,
        exir_ops.edge.aten.slice_copy.Tensor,
        # Tensor combination
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.split_with_sizes_copy.default,
        exir_ops.edge.aten.split.Tensor,
        exir_ops.edge.aten.repeat.default,
        # Tensor creation
        exir_ops.edge.aten.arange.start_step,
        exir_ops.edge.aten.clone.default,
        exir_ops.edge.aten.constant_pad_nd.default,
        exir_ops.edge.aten.full.default,
        exir_ops.edge.aten.full_like.default,
        exir_ops.edge.aten.ones.default,
        exir_ops.edge.aten.ones_like.default,
        exir_ops.edge.aten.upsample_nearest2d.vec,
        exir_ops.edge.aten.zeros.default,
        exir_ops.edge.aten.zeros_like.default,
        exir_ops.edge.et_vk.grid_priors.default,
    ]
)
def register_ported_op(features: OpFeatures):
    """Ops ported from the PyTorch Vulkan backend: channels-packed textures, no resize function."""
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.CHANNELS},
    )
    return features
517
518
519# Ported ops that support their own prepacking.
@update_features(
    [
        exir_ops.edge.aten.embedding.default,
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
        exir_ops.edge.aten.native_layer_norm.default,
    ]
)
def register_ported_ops_with_prepacking(features: OpFeatures):
    """Ported ops that additionally handle their own prepacking."""
    features.handles_own_prepacking = True
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.CHANNELS},
    )
    return features
533
534
535#######################
536## Utility functions ##
537#######################
538
539
def has_impl(target: OpKey) -> bool:
    """
    Return True if `target` has a registered Vulkan implementation.

    Non-string targets (OpOverload / EdgeOpOverload) may have been registered
    under their string name instead, so fall back to `target.name()`.

    Note: the original nested branches ended with a membership test that was
    always True at that point; this is the equivalent flattened form.
    """
    if target in vulkan_supported_ops:
        return True
    return not isinstance(target, str) and target.name() in vulkan_supported_ops
547
548
def get_op_features(target: OpKey) -> OpFeatures:
    """
    Return the OpFeatures registered for `target`.

    Non-string targets (OpOverload / EdgeOpOverload) may have been registered
    under their string name instead, so fall back to `target.name()`. Raises
    KeyError if neither key is registered.
    """
    if not isinstance(target, str) and target not in vulkan_supported_ops:
        # Try the op's name
        return vulkan_supported_ops[target.name()]

    return vulkan_supported_ops[target]
558
559
def handles_own_prepacking(target: OpKey) -> bool:
    """Return True if the registered op prepacks its own arguments."""
    features = get_op_features(target)
    return features.handles_own_prepacking
562