xref: /aosp_15_r20/external/executorch/backends/vulkan/partitioner/supported_ops.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import operator
10
11from executorch.backends.vulkan._passes.custom_ops_defs import (  # noqa
12    conv_with_clamp_op,
13    grid_priors_op,
14)
15
16from executorch.exir.dialects._ops import ops as exir_ops
17
18
class OpFeatures:
    """Per-operator feature flags stored in an ``OpList``.

    Instances hold exactly three boolean attributes and nothing else
    (``__slots__`` removes the per-instance ``__dict__``).

    NOTE(review): judging by the names, the flags describe texture-backed
    storage, buffer-backed storage and dynamic input shapes; the code that
    consumes them lives outside this file -- confirm there.

    Args:
        supports_dynamic_shape: Initial value of the flag; defaults to False.
        supports_buffer: Initial value of the flag; defaults to False.
        supports_texture: Initial value of the flag; defaults to True.
    """

    __slots__ = ["supports_texture", "supports_buffer", "supports_dynamic_shape"]

    def __init__(
        self,
        supports_dynamic_shape: bool = False,
        supports_buffer: bool = False,
        supports_texture: bool = True,
    ) -> None:
        # Assigned in the same order as ``__slots__`` for easy cross-checking.
        self.supports_texture = supports_texture
        self.supports_buffer = supports_buffer
        self.supports_dynamic_shape = supports_dynamic_shape
31
32
class OpList:
    """Registry mapping an operator to its ``OpFeatures`` entry.

    Indexing is create-on-demand: ``ops[op]`` returns the stored entry, and
    for an operator that has none yet it first stores a default-constructed
    ``OpFeatures``. As a consequence ``ops[op].flag = value`` both registers
    ``op`` and sets the flag, and a plain read of an unknown operator makes
    ``op in ops`` true afterwards.
    """

    def __init__(self):
        # operator -> OpFeatures
        self._ops = {}

    def __getitem__(self, op):
        """Return the entry for ``op``, creating a default one if missing."""
        try:
            return self._ops[op]
        except KeyError:
            created = self._ops[op] = OpFeatures()
            return created

    def __contains__(self, op):
        """Return True if ``op`` already has an entry."""
        return op in self._ops
44
45
# Operator groups.
#
# Each list below is consumed by one of the register_* helpers in this file,
# which records per-operator feature flags in an OpList. List order is kept
# stable because other modules may import these names directly.

# Registered by register_prim_ops with texture, buffer and dynamic-shape
# support all set to True.
PRIM_OPS = [
    operator.getitem,
]

# --- Groups registered by register_dynamic_shape_ops
# --- (supports_dynamic_shape = True).

BINARY_OPS = [
    exir_ops.edge.aten.add.Tensor,
    exir_ops.edge.aten.sub.Tensor,
    exir_ops.edge.aten.minimum.default,
    exir_ops.edge.aten.mul.Tensor,
    exir_ops.edge.aten.div.Tensor,
    exir_ops.edge.aten.div.Tensor_mode,
    exir_ops.edge.aten.pow.Tensor_Tensor,
]

UNARY_OPS = [
    exir_ops.edge.aten.abs.default,
    exir_ops.edge.aten.clamp.default,
    exir_ops.edge.aten.cos.default,
    exir_ops.edge.aten.exp.default,
    exir_ops.edge.aten.gelu.default,
    exir_ops.edge.aten.hardshrink.default,
    exir_ops.edge.aten.hardtanh.default,
    exir_ops.edge.aten.neg.default,
    exir_ops.edge.aten.relu.default,
    exir_ops.edge.aten.sigmoid.default,
    exir_ops.edge.aten.sin.default,
    exir_ops.edge.aten.sqrt.default,
    exir_ops.edge.aten.tanh.default,
]

MATMUL_OPS = [
    exir_ops.edge.aten.bmm.default,
    exir_ops.edge.aten.mm.default,
    exir_ops.edge.aten.addmm.default,
    exir_ops.edge.aten.linear.default,
]

POOLING_OPS = [
    exir_ops.edge.aten.avg_pool2d.default,
    exir_ops.edge.aten.max_pool2d_with_indices.default,
]

# NOTE(review): the et_vk.* entries here and in CREATION_OPS are presumably
# made resolvable by the custom_ops_defs import at the top of the file --
# confirm before removing that import.
CONVOLUTION_OPS = [
    exir_ops.edge.aten.convolution.default,
    exir_ops.edge.et_vk.conv_with_clamp.default,
]

# --- Groups registered by register_no_dynamic_shape_ops
# --- (supports_dynamic_shape = False).

REDUCTION_OPS = [
    exir_ops.edge.aten.mean.dim,
    exir_ops.edge.aten.sum.dim_IntList,
    exir_ops.edge.aten._log_softmax.default,
    exir_ops.edge.aten._softmax.default,
]

NORMALIZATION_OPS = [
    exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
    exir_ops.edge.aten.native_layer_norm.default,
]

SHAPE_MANIPULATION_OPS = [
    exir_ops.edge.aten.squeeze_copy.dims,
    exir_ops.edge.aten.unsqueeze_copy.default,
    exir_ops.edge.aten.view_copy.default,
    exir_ops.edge.aten.permute_copy.default,
    exir_ops.edge.aten.t_copy.default,
]

INDEXING_OPS = [
    exir_ops.edge.aten.embedding.default,
    exir_ops.edge.aten.index_select.default,
    exir_ops.edge.aten.select_copy.int,
    exir_ops.edge.aten.slice_copy.Tensor,
]

ORCHESTRATION_OPS = [
    exir_ops.edge.aten.cat.default,
    exir_ops.edge.aten.split_with_sizes_copy.default,
    exir_ops.edge.aten.split.Tensor,
    exir_ops.edge.aten.repeat.default,
]

CREATION_OPS = [
    exir_ops.edge.aten.arange.start_step,
    exir_ops.edge.aten.clone.default,
    exir_ops.edge.aten.constant_pad_nd.default,
    exir_ops.edge.aten.full.default,
    exir_ops.edge.aten.full_like.default,
    exir_ops.edge.aten.ones.default,
    exir_ops.edge.aten.ones_like.default,
    exir_ops.edge.aten.upsample_nearest2d.vec,
    exir_ops.edge.aten.zeros.default,
    exir_ops.edge.aten.zeros_like.default,
    exir_ops.edge.et_vk.grid_priors.default,
]
140
141
def register_prim_ops(ops: OpList):
    """Register every operator in ``PRIM_OPS`` with all three flags enabled.

    Args:
        ops: Registry to update in place. Indexing it creates the entry for
            an operator that is not present yet.
    """
    for prim_op in PRIM_OPS:
        features = ops[prim_op]
        features.supports_texture = True
        features.supports_buffer = True
        features.supports_dynamic_shape = True
147
148
def register_no_dynamic_shape_ops(ops: OpList):
    """Register the operator groups whose dynamic-shape flag is False.

    ``False`` is already the ``OpFeatures`` default, so for an operator seen
    for the first time the visible effect is that indexing ``ops`` creates
    (i.e. registers) its entry.

    Args:
        ops: Registry to update in place.
    """
    static_shape_groups = (
        REDUCTION_OPS,
        NORMALIZATION_OPS,
        SHAPE_MANIPULATION_OPS,
        INDEXING_OPS,
        ORCHESTRATION_OPS,
        CREATION_OPS,
    )
    for group in static_shape_groups:
        for op in group:
            ops[op].supports_dynamic_shape = False
159
160
def register_dynamic_shape_ops(ops: OpList):
    """Register the operator groups whose dynamic-shape flag is True.

    Args:
        ops: Registry to update in place. Indexing it creates the entry for
            an operator that is not present yet.
    """
    dynamic_shape_groups = (
        BINARY_OPS,
        UNARY_OPS,
        MATMUL_OPS,
        POOLING_OPS,
        CONVOLUTION_OPS,
    )
    for group in dynamic_shape_groups:
        for op in group:
            ops[op].supports_dynamic_shape = True
170
171
def enumerate_supported_ops():
    """Build and return a fresh ``OpList`` populated by all register helpers.

    The helpers run in a fixed order: primitive ops first, then the groups
    without dynamic-shape support, then the groups with it.

    Returns:
        The newly created, fully populated ``OpList``.
    """
    registry = OpList()
    registration_steps = (
        register_prim_ops,
        register_no_dynamic_shape_ops,
        register_dynamic_shape_ops,
    )
    for register in registration_steps:
        register(registry)
    return registry
178