# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""Operator support tables for the Vulkan backend.

Every operator listed here is mapped to an :class:`OpFeatures` record holding
three flags: ``supports_texture``, ``supports_buffer`` and
``supports_dynamic_shape``. :func:`enumerate_supported_ops` builds the full
table by applying the ``register_*`` helpers to the operator groups below.
"""

import operator

# NOTE(review): the imported names are not referenced in this module (hence the
# ``noqa``). The import is presumably kept for its side effect of defining the
# ``et_vk`` custom ops used below (``et_vk.conv_with_clamp``,
# ``et_vk.grid_priors``) -- confirm before removing it.
from executorch.backends.vulkan._passes.custom_ops_defs import (  # noqa
    conv_with_clamp_op,
    grid_priors_op,
)

from executorch.exir.dialects._ops import ops as exir_ops


class OpFeatures:
    """Capability flags for a single operator.

    Args:
        supports_dynamic_shape: Whether the op supports dynamic shapes.
            Defaults to ``False``.
        supports_buffer: Whether the op supports buffers. Defaults to
            ``False``.
        supports_texture: Whether the op supports textures. Defaults to
            ``True``.
    """

    __slots__ = ["supports_texture", "supports_buffer", "supports_dynamic_shape"]

    def __init__(
        self,
        supports_dynamic_shape: bool = False,
        supports_buffer: bool = False,
        supports_texture: bool = True,
    ):
        self.supports_dynamic_shape = supports_dynamic_shape
        self.supports_texture = supports_texture
        self.supports_buffer = supports_buffer

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"supports_dynamic_shape={self.supports_dynamic_shape}, "
            f"supports_buffer={self.supports_buffer}, "
            f"supports_texture={self.supports_texture})"
        )


class OpList:
    """Table mapping operators to their :class:`OpFeatures`.

    Indexing with an operator that is not yet present registers it with a
    default ``OpFeatures()`` and returns that record, so ``ops[op].flag = value``
    works without a separate registration step. Use ``op in ops`` to test
    membership without registering anything. Iterating yields the registered
    operators.
    """

    def __init__(self):
        self._ops = {}

    def __getitem__(self, op):
        # Auto-register unknown ops with default feature flags.
        if op not in self._ops:
            self._ops[op] = OpFeatures()
        return self._ops[op]

    def __contains__(self, op):
        return op in self._ops

    def __iter__(self):
        # Without an explicit __iter__, iter() falls back to the legacy
        # sequence protocol: it calls __getitem__(0), __getitem__(1), ...
        # until IndexError. Because __getitem__ auto-registers every key and
        # never raises, iterating an OpList would then loop forever while
        # filling the table with integer keys. Iterate the registered ops.
        return iter(self._ops)


# Operator groups. The register_* functions below assign feature flags to
# whole groups at once.

PRIM_OPS = [
    operator.getitem,
]

BINARY_OPS = [
    exir_ops.edge.aten.add.Tensor,
    exir_ops.edge.aten.sub.Tensor,
    exir_ops.edge.aten.minimum.default,
    exir_ops.edge.aten.mul.Tensor,
    exir_ops.edge.aten.div.Tensor,
    exir_ops.edge.aten.div.Tensor_mode,
    exir_ops.edge.aten.pow.Tensor_Tensor,
]

UNARY_OPS = [
    exir_ops.edge.aten.abs.default,
    exir_ops.edge.aten.clamp.default,
    exir_ops.edge.aten.cos.default,
    exir_ops.edge.aten.exp.default,
    exir_ops.edge.aten.gelu.default,
    exir_ops.edge.aten.hardshrink.default,
    exir_ops.edge.aten.hardtanh.default,
    exir_ops.edge.aten.neg.default,
    exir_ops.edge.aten.relu.default,
    exir_ops.edge.aten.sigmoid.default,
    exir_ops.edge.aten.sin.default,
    exir_ops.edge.aten.sqrt.default,
    exir_ops.edge.aten.tanh.default,
]

MATMUL_OPS = [
    exir_ops.edge.aten.bmm.default,
    exir_ops.edge.aten.mm.default,
    exir_ops.edge.aten.addmm.default,
    exir_ops.edge.aten.linear.default,
]

POOLING_OPS = [
    exir_ops.edge.aten.avg_pool2d.default,
    exir_ops.edge.aten.max_pool2d_with_indices.default,
]

CONVOLUTION_OPS = [
    exir_ops.edge.aten.convolution.default,
    exir_ops.edge.et_vk.conv_with_clamp.default,
]

REDUCTION_OPS = [
    exir_ops.edge.aten.mean.dim,
    exir_ops.edge.aten.sum.dim_IntList,
    exir_ops.edge.aten._log_softmax.default,
    exir_ops.edge.aten._softmax.default,
]

NORMALIZATION_OPS = [
    exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
    exir_ops.edge.aten.native_layer_norm.default,
]

SHAPE_MANIPULATION_OPS = [
    exir_ops.edge.aten.squeeze_copy.dims,
    exir_ops.edge.aten.unsqueeze_copy.default,
    exir_ops.edge.aten.view_copy.default,
    exir_ops.edge.aten.permute_copy.default,
    exir_ops.edge.aten.t_copy.default,
]

INDEXING_OPS = [
    exir_ops.edge.aten.embedding.default,
    exir_ops.edge.aten.index_select.default,
    exir_ops.edge.aten.select_copy.int,
    exir_ops.edge.aten.slice_copy.Tensor,
]

ORCHESTRATION_OPS = [
    exir_ops.edge.aten.cat.default,
    exir_ops.edge.aten.split_with_sizes_copy.default,
    exir_ops.edge.aten.split.Tensor,
    exir_ops.edge.aten.repeat.default,
]

CREATION_OPS = [
    exir_ops.edge.aten.arange.start_step,
    exir_ops.edge.aten.clone.default,
    exir_ops.edge.aten.constant_pad_nd.default,
    exir_ops.edge.aten.full.default,
    exir_ops.edge.aten.full_like.default,
    exir_ops.edge.aten.ones.default,
    exir_ops.edge.aten.ones_like.default,
    exir_ops.edge.aten.upsample_nearest2d.vec,
    exir_ops.edge.aten.zeros.default,
    exir_ops.edge.aten.zeros_like.default,
    exir_ops.edge.et_vk.grid_priors.default,
]


def register_prim_ops(ops: OpList) -> None:
    """Register PRIM_OPS with texture, buffer and dynamic-shape support."""
    for op in PRIM_OPS:
        ops[op].supports_texture = True
        ops[op].supports_buffer = True
        ops[op].supports_dynamic_shape = True


def register_no_dynamic_shape_ops(ops: OpList) -> None:
    """Register the listed groups without dynamic-shape support.

    ``False`` is already the ``OpFeatures`` default, but setting it explicitly
    makes sure each op is added to ``ops`` (indexing auto-registers). The
    texture/buffer flags keep their defaults.
    """
    for op in [
        *REDUCTION_OPS,
        *NORMALIZATION_OPS,
        *SHAPE_MANIPULATION_OPS,
        *INDEXING_OPS,
        *ORCHESTRATION_OPS,
        *CREATION_OPS,
    ]:
        ops[op].supports_dynamic_shape = False


def register_dynamic_shape_ops(ops: OpList) -> None:
    """Register the listed groups with dynamic-shape support.

    The texture/buffer flags keep their ``OpFeatures`` defaults.
    """
    for op in [
        *BINARY_OPS,
        *UNARY_OPS,
        *MATMUL_OPS,
        *POOLING_OPS,
        *CONVOLUTION_OPS,
    ]:
        ops[op].supports_dynamic_shape = True


def enumerate_supported_ops() -> OpList:
    """Build and return the table of all supported operators.

    Registrations run in order, so if an op appeared in more than one group,
    the flags written by the later ``register_*`` call would win.
    """
    ops = OpList()
    register_prim_ops(ops)
    register_no_dynamic_shape_ops(ops)
    register_dynamic_shape_ops(ops)
    return ops