xref: /aosp_15_r20/external/executorch/backends/qualcomm/builders/qnn_constants.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7from dataclasses import dataclass
8from enum import IntEnum, unique
9
# Op package name shared by the built-in QTI AISW op definitions below.
QNN_OP_PACKAGE_NAME_QTI_AISW: str = "qti.aisw"
11
# The constants below must match those in the QNN headers.
# Consider exposing them through pybind someday instead of
# replicating them here.
15
16
@dataclass(frozen=True, init=False)
class OpBatchnorm:
    """Name constant for the QNN ``Batchnorm`` op."""

    op_name: str = "Batchnorm"
20
21
@dataclass(frozen=True, init=False)
class OpCast:
    """Name constant for the QNN ``Cast`` op."""

    op_name: str = "Cast"
25
26
@dataclass(frozen=True, init=False)
class OpConcat:
    """QNN ``Concat`` op name and its parameter names."""

    op_name: str = "Concat"
    param_axis: str = "axis"
31
32
@dataclass(frozen=True, init=False)
class OpContextLoader:
    """String constants related to loading a QNN context binary.

    Unlike the other ``Op*`` classes this has no ``op_name``; its consumers
    are not visible in this module.
    """

    namespace: str = "qaisw"
    meta_ctx_bin: str = "qnn_context_binary"
37
38
@dataclass(frozen=True, init=False)
class OpConv2d:
    """QNN ``Conv2d`` op name and its parameter names."""

    op_name: str = "Conv2d"
    param_stride: str = "stride"
    param_pad_amount: str = "pad_amount"
    param_group: str = "group"
    param_dilation: str = "dilation"
46
47
@dataclass(frozen=True, init=False)
class OpConvert:
    """Name constant for the QNN ``Convert`` op."""

    op_name: str = "Convert"
51
52
@dataclass(frozen=True, init=False)
class OpDepthToSpace:
    """QNN ``DepthToSpace`` op name, parameter names and ``mode`` values."""

    op_name: str = "DepthToSpace"
    param_block_size: str = "block_size"
    param_mode: str = "mode"

    @unique
    class Mode(IntEnum):
        """Integer values accepted by the ``mode`` parameter."""

        DCR = 0
        CRD = 1
63
64
@dataclass(frozen=True, init=False)
class OpDepthWiseConv2d:
    """QNN ``DepthWiseConv2d`` op name and its parameter names."""

    op_name: str = "DepthWiseConv2d"
    param_stride: str = "stride"
    param_pad_amount: str = "pad_amount"
    param_dilation: str = "dilation"
71
72
@dataclass(frozen=True, init=False)
class OpDequantize:
    """Name constant for the QNN ``Dequantize`` op."""

    op_name: str = "Dequantize"
76
77
@dataclass(frozen=True, init=False)
class OpElementWiseAdd:
    """Name constant for the QNN ``ElementWiseAdd`` op."""

    op_name: str = "ElementWiseAdd"
81
82
@dataclass(init=False, frozen=True)
class OpElementWiseCeil:
    """Name constant for the QNN ``ElementWiseCeil`` op.

    ``op_name`` carries a ``str`` annotation so that ``dataclasses`` registers
    it as a field. Without the annotation it would be a plain class attribute
    and would be missing from ``dataclasses.fields()``.
    """

    op_name: str = "ElementWiseCeil"
86
87
@dataclass(frozen=True, init=False)
class OpElementWiseDivide:
    """Name constant for the QNN ``ElementWiseDivide`` op."""

    op_name: str = "ElementWiseDivide"
91
92
@dataclass(frozen=True, init=False)
class OpElementWiseMultiply:
    """Name constant for the QNN ``ElementWiseMultiply`` op."""

    op_name: str = "ElementWiseMultiply"
96
97
@dataclass(frozen=True, init=False)
class OpElementWiseNeuron:
    """QNN ``ElementWiseNeuron`` op name and its parameter names."""

    op_name: str = "ElementWiseNeuron"
    param_operation: str = "operation"
    param_alpha: str = "alpha"
    param_beta: str = "beta"
104
105
@dataclass(frozen=True, init=False)
class OpElementWisePower:
    """Name constant for the QNN ``ElementWisePower`` op."""

    op_name: str = "ElementWisePower"
109
110
@dataclass(frozen=True, init=False)
class OpElementWiseRsqrt:
    """Name constant for the QNN ``ElementWiseRsqrt`` op."""

    op_name: str = "ElementWiseRsqrt"
114
115
@dataclass(init=False, frozen=True)
class OpElementWiseSubtract:
    """Name constant for the QNN ``ElementWiseSubtract`` op.

    ``op_name`` carries a ``str`` annotation so that ``dataclasses`` registers
    it as a field. Without the annotation it would be a plain class attribute
    and would be missing from ``dataclasses.fields()``.
    """

    op_name: str = "ElementWiseSubtract"
119
120
@dataclass(frozen=True, init=False)
class OpExpandDims:
    """QNN ``ExpandDims`` op name and its parameter names."""

    op_name: str = "ExpandDims"
    param_axis: str = "axis"
125
126
@dataclass(frozen=True, init=False)
class OpFullyConnected:
    """QNN ``FullyConnected`` op name and its parameter names."""

    op_name: str = "FullyConnected"
    param_keep_dims: str = "keep_dims"
131
132
@dataclass(frozen=True, init=False)
class OpGather:
    """QNN ``Gather`` op name and its parameter names."""

    op_name: str = "Gather"
    param_axis: str = "axis"
137
138
@dataclass(frozen=True, init=False)
class OpGatherND:
    """QNN ``GatherNd`` op name and its parameter names."""

    op_name: str = "GatherNd"
    param_batch_dims: str = "batch_dims"
143
144
@dataclass(frozen=True, init=False)
class OpGelu:
    """Name constant for the QNN ``Gelu`` op."""

    op_name: str = "Gelu"
148
149
@dataclass(init=False, frozen=True)
class OpGroupNorm:
    """QNN ``GroupNorm`` op name and its parameter names.

    Declared as a frozen dataclass with ``str``-annotated attributes. Instances
    are therefore immutable, and ``dataclasses.fields()`` lists every constant.
    """

    op_name: str = "GroupNorm"
    param_epsilon: str = "epsilon"
    param_group: str = "group"
154
155
@dataclass(frozen=True, init=False)
class OpHardSwish:
    """Name constant for the QNN ``HardSwish`` op."""

    op_name: str = "HardSwish"
159
160
@dataclass(init=False, frozen=True)
class OpLayerNorm:
    """QNN ``LayerNorm`` op name and its parameter names.

    Every attribute carries a ``str`` annotation so that ``dataclasses``
    registers it as a field. Unannotated names would be plain class attributes
    and would be missing from ``dataclasses.fields()``.
    """

    op_name: str = "LayerNorm"
    param_epsilon: str = "epsilon"
    param_axes: str = "axes"
166
167
@dataclass(frozen=True, init=False)
class OpLogSoftmax:
    """QNN ``LogSoftmax`` op name and its parameter names."""

    op_name: str = "LogSoftmax"
    param_axis: str = "axis"
    param_beta: str = "beta"
173
174
@dataclass(frozen=True, init=False)
class OpMatMul:
    """QNN ``MatMul`` op name and its parameter names."""

    op_name: str = "MatMul"
    param_transpose_in0: str = "transpose_in0"
    param_transpose_in1: str = "transpose_in1"
180
181
@dataclass(frozen=True, init=False)
class OpPack:
    """QNN ``Pack`` op name and its parameter names."""

    op_name: str = "Pack"
    param_axis: str = "axis"
186
187
@dataclass(frozen=True, init=False)
class OpPad:
    """QNN ``Pad`` op name, parameter names and ``scheme`` values."""

    op_name: str = "Pad"
    param_scheme: str = "scheme"
    param_pad_amount: str = "pad_amount"
    param_pad_constant_value: str = "pad_constant_value"

    @unique
    class Scheme(IntEnum):
        """Integer values accepted by the ``scheme`` parameter."""

        CONSTANT = 0
        MIRROR_SYMMETRIC = 1
        MIRROR_REFLECT = 2
        EDGE = 3
201
202
@dataclass(frozen=True, init=False)
class OpPoolAvg2d:
    """QNN ``PoolAvg2d`` op name, parameter names and rounding-mode values."""

    op_name: str = "PoolAvg2d"
    param_filter_size: str = "filter_size"
    param_stride: str = "stride"
    param_pad_amount: str = "pad_amount"
    param_count_pad_for_edges: str = "count_pad_for_edges"
    param_rounding_mode: str = "rounding_mode"

    @unique
    class RoundingMode(IntEnum):
        """Integer values accepted by the ``rounding_mode`` parameter."""

        FLOOR = 0
        CEIL = 1
216
217
@dataclass(frozen=True, init=False)
class OpPoolMax2d:
    """QNN ``PoolMax2d`` op name, parameter names and rounding-mode values."""

    op_name: str = "PoolMax2d"
    param_filter_size: str = "filter_size"
    param_stride: str = "stride"
    param_pad_amount: str = "pad_amount"
    param_rounding_mode: str = "rounding_mode"

    @unique
    class RoundingMode(IntEnum):
        """Integer values accepted by the ``rounding_mode`` parameter."""

        FLOOR = 0
        CEIL = 1
230
231
@dataclass(frozen=True, init=False)
class OpPRelu:
    """Name constant for the QNN ``Prelu`` op."""

    op_name: str = "Prelu"
235
236
@dataclass(frozen=True, init=False)
class OpQuantize:
    """Name constant for the QNN ``Quantize`` op."""

    op_name: str = "Quantize"
240
241
@dataclass(frozen=True, init=False)
class OpReduceMean:
    """QNN ``ReduceMean`` op name and its parameter names."""

    op_name: str = "ReduceMean"
    param_axes: str = "axes"
    param_keep_dims: str = "keep_dims"
247
248
@dataclass(frozen=True, init=False)
class OpReduceSum:
    """QNN ``ReduceSum`` op name and its parameter names."""

    op_name: str = "ReduceSum"
    param_axes: str = "axes"
    param_keep_dims: str = "keep_dims"
254
255
@dataclass(frozen=True, init=False)
class OpRelu:
    """Name constant for the QNN ``Relu`` op."""

    op_name: str = "Relu"
259
260
@dataclass(frozen=True, init=False)
class OpReluMinMax:
    """QNN ``ReluMinMax`` op name and its parameter names."""

    op_name: str = "ReluMinMax"
    param_min_value: str = "min_value"
    param_max_value: str = "max_value"
266
267
@dataclass(frozen=True, init=False)
class OpReshape:
    """Name constant for the QNN ``Reshape`` op."""

    op_name: str = "Reshape"
271
272
@dataclass(frozen=True, init=False)
class OpResizeBilinear:
    """QNN ``ResizeBilinear`` op name and its parameter names."""

    op_name: str = "ResizeBilinear"
    param_align_corners: str = "align_corners"
    param_half_pixel_centers: str = "half_pixel_centers"
278
279
@dataclass(frozen=True, init=False)
class OpResizeNearestNeighbor:
    """QNN ``ResizeNearestNeighbor`` op name and its parameter names."""

    op_name: str = "ResizeNearestNeighbor"
    param_align_corners: str = "align_corners"
    param_half_pixel_centers: str = "half_pixel_centers"
285
286
@dataclass(frozen=True, init=False)
class OpRmsNorm:
    """QNN ``RmsNorm`` op name and its parameter names."""

    op_name: str = "RmsNorm"
    param_epsilon: str = "epsilon"
    param_axes: str = "axes"
292
293
@dataclass(frozen=True, init=False)
class OpScatterNd:
    """QNN ``ScatterNd`` op name and its parameter names."""

    op_name: str = "ScatterNd"
    param_reduction: str = "reduction"
298
299
@dataclass(frozen=True, init=False)
class OpSigmoid:
    """Name constant for the QNN ``Sigmoid`` op."""

    op_name: str = "Sigmoid"
303
304
@dataclass(frozen=True, init=False)
class OpSoftmax:
    """QNN ``Softmax`` op name and its parameter names."""

    op_name: str = "Softmax"
    param_axis: str = "axis"
    param_beta: str = "beta"
310
311
@dataclass(frozen=True, init=False)
class OpSpaceToDepth:
    """QNN ``SpaceToDepth`` op name, parameter names and ``mode`` values."""

    op_name: str = "SpaceToDepth"
    param_block_size: str = "block_size"
    param_mode: str = "mode"

    @unique
    class Mode(IntEnum):
        """Integer values accepted by the ``mode`` parameter."""

        DCR = 0
        CRD = 1
322
323
@dataclass(init=False, frozen=True)
class OpSplit:
    """QNN ``Split`` op name and its parameter names.

    Declared as a frozen dataclass so that instances are immutable and
    ``dataclasses.fields()`` lists every constant.
    """

    op_name: str = "Split"
    param_axis: str = "axis"
    param_split_index: str = "split_index"
328
329
@dataclass(frozen=True, init=False)
class OpSqrt:
    """Name constant for the QNN ``ElementWiseSquareRoot`` op."""

    op_name: str = "ElementWiseSquareRoot"
333
334
@dataclass(frozen=True, init=False)
class OpSqueeze:
    """Name constant for the QNN ``Squeeze`` op."""

    op_name: str = "Squeeze"
338
339
@dataclass(frozen=True, init=False)
class OpStridedSlice:
    """QNN ``StridedSlice`` op name and its parameter names."""

    op_name: str = "StridedSlice"
    param_ranges: str = "ranges"
    param_begin_mask: str = "begin_mask"
    param_end_mask: str = "end_mask"
    param_shrink_axes: str = "shrink_axes"
    param_new_axes_mask: str = "new_axes_mask"
348
349
@dataclass(frozen=True, init=False)
class OpTanh:
    """Name constant for the QNN ``Tanh`` op."""

    op_name: str = "Tanh"
353
354
@dataclass(frozen=True, init=False)
class OpTile:
    """QNN ``Tile`` op name and its parameter names."""

    op_name: str = "Tile"
    param_multiples: str = "multiples"
359
360
@dataclass(frozen=True, init=False)
class OpTopK:
    """QNN ``TopK`` op name and its parameter names."""

    op_name: str = "TopK"
    param_k: str = "k"
    param_largest: str = "largest"
366
367
@dataclass(frozen=True, init=False)
class OpTranspose:
    """QNN ``Transpose`` op name and its parameter names."""

    op_name: str = "Transpose"
    param_perm: str = "perm"
372
373
@dataclass(frozen=True, init=False)
class OpTransposeConv2d:
    """QNN ``TransposeConv2d`` op name and its parameter names."""

    op_name: str = "TransposeConv2d"
    param_stride: str = "stride"
    param_pad_amount: str = "pad_amount"
    param_group: str = "group"
    param_output_padding: str = "output_padding"
381