# xref: /aosp_15_r20/external/executorch/backends/xnnpack/serialization/xnnpack_graph_schema.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Please refer to executorch/backends/xnnpack/serialization/schema.fbs for the schema definitions
"""
12
from dataclasses import dataclass
from enum import IntEnum
from typing import List, Optional, Union
16
17
# Generic node data class with one input and one output.
# Serves as the base for the single-input node classes in this file; the
# authoritative field definitions live in schema.fbs (see module docstring).
@dataclass
class XNNNode1x1:
    input_id: int  # identifier of the node's single input
    output_id: int  # identifier of the node's single output
    flags: int  # NOTE(review): presumably operator flags -- confirm in schema.fbs
24
25
# Generic node data class with two inputs and one output.
# Serves as the base for the two-input node classes in this file; the
# authoritative field definitions live in schema.fbs (see module docstring).
@dataclass
class XNNNode2x1:
    input1_id: int  # identifier of the first input
    input2_id: int  # identifier of the second input
    output_id: int  # identifier of the single output
    flags: int  # NOTE(review): presumably operator flags -- confirm in schema.fbs
33
34
# Generic node data class for concatenation nodes.
# All four input fields are always required by the constructor, including for
# the two- and three-input subclasses (XNNConcatenate2 / XNNConcatenate3).
# NOTE(review): presumably unused inputs carry a placeholder id -- confirm
# against the serializer.
@dataclass
class XNNCat:
    axis: int  # axis argument of the concatenation
    input1_id: int
    input2_id: int
    input3_id: int
    input4_id: int
    output_id: int
    flags: int
45
46
# Generic node data class for convolution type nodes.
# Field order is part of the positional constructor signature, so it is kept
# exactly as declared; see schema.fbs for the authoritative definitions.
@dataclass
class XNNNodeConv:
    # Padding on each side.
    padding_top: int
    padding_right: int
    padding_bottom: int
    padding_left: int
    # Kernel extent.
    kernel_height: int
    kernel_width: int
    # Subsampling per spatial dimension.
    subsampling_height: int
    subsampling_width: int
    # Dilation per spatial dimension.
    dilation_height: int
    dilation_width: int
    # Grouping parameters.
    group_input_channels: int
    group_output_channels: int
    groups: int
    # NOTE(review): adjustment_* presumably only matters for transposed
    # convolution -- confirm in schema.fbs.
    adjustment_height: int
    adjustment_width: int
    # Identifiers of the input, filter, bias and output.
    input1_id: int
    filter_id: int
    bias_id: int
    output_id: int
    flags: int
70
71
# Generic node data class for 2D pooling nodes; base of XNNAvgPooling2d and
# XNNMaxPooling2d. Field order is part of the positional constructor.
@dataclass
class XNNPooling2D:
    # Padding on each side.
    padding_top: int
    padding_right: int
    padding_bottom: int
    padding_left: int
    # Pooling window extent.
    pooling_height: int
    pooling_width: int
    # Stride per spatial dimension.
    stride_height: int
    stride_width: int
    # Dilation per spatial dimension.
    dilation_height: int
    dilation_width: int
    # Identifiers of the input and output.
    input_id: int
    output_id: int
    flags: int
87
88
# Node data class for average pooling 2d.
# Adds no fields of its own; everything is inherited from XNNPooling2D.
@dataclass
class XNNAvgPooling2d(XNNPooling2D):
    pass
93
94
# Node data class for max pooling 2d.
# Adds no fields of its own; everything is inherited from XNNPooling2D.
@dataclass
class XNNMaxPooling2d(XNNPooling2D):
    pass
98
99
# Node data class for XNNConv2d.
# Adds no fields of its own; everything is inherited from XNNNodeConv.
@dataclass
class XNNConv2d(XNNNodeConv):
    pass
103
104
# Node data class for XNNAdd.
# Adds no fields of its own; everything is inherited from XNNNode2x1.
@dataclass
class XNNAdd(XNNNode2x1):
    pass
108
109
# Node data class for XNNGlobalAvgPooling2d.
# Uses the plain one-input/one-output layout of XNNNode1x1 (not XNNPooling2D),
# so it carries no padding/window/stride fields.
@dataclass
class XNNGlobalAvgPooling2d(XNNNode1x1):
    pass
113
114
# Node data class for XNNDiv.
# Adds no fields of its own; everything is inherited from XNNNode2x1.
@dataclass
class XNNDiv(XNNNode2x1):
    pass
118
119
# Node data class for XNNMultiply.
# Adds no fields of its own; everything is inherited from XNNNode2x1.
@dataclass
class XNNMultiply(XNNNode2x1):
    pass
123
124
# Node data class for XNNMinimum.
# Adds no fields of its own; everything is inherited from XNNNode2x1.
@dataclass
class XNNMinimum(XNNNode2x1):
    pass
128
129
# Node data class for XNNSubtract.
# Adds no fields of its own; everything is inherited from XNNNode2x1.
@dataclass
class XNNSubtract(XNNNode2x1):
    pass
133
134
# Node data class for XNNSoftmax.
# Adds no fields of its own; everything is inherited from XNNNode1x1.
@dataclass
class XNNSoftmax(XNNNode1x1):
    pass
138
139
# Node data class for XNNSigmoid.
# Adds no fields of its own; everything is inherited from XNNNode1x1.
@dataclass
class XNNSigmoid(XNNNode1x1):
    pass
143
144
# Node data class for XNNFloor.
# Adds no fields of its own; everything is inherited from XNNNode1x1.
@dataclass
class XNNFloor(XNNNode1x1):
    pass
148
149
# Node data class for XNNConvert.
# Adds no fields of its own; everything is inherited from XNNNode1x1.
@dataclass
class XNNConvert(XNNNode1x1):
    pass
153
154
# Node data class for XNNNegate.
# Adds no fields of its own; everything is inherited from XNNNode1x1.
@dataclass
class XNNNegate(XNNNode1x1):
    pass
158
159
# Node data class for XNNAbs.
# Adds no fields of its own; everything is inherited from XNNNode1x1.
@dataclass
class XNNAbs(XNNNode1x1):
    pass
163
164
# Node data class for XNNConcatenate2.
# Adds no fields of its own; it inherits all of XNNCat, including the
# input3_id and input4_id fields, which the constructor still requires.
@dataclass
class XNNConcatenate2(XNNCat):
    pass
168
169
# Node data class for XNNConcatenate3.
# Adds no fields of its own; it inherits all of XNNCat, including the
# input4_id field, which the constructor still requires.
@dataclass
class XNNConcatenate3(XNNCat):
    pass
173
174
# Node data class for XNNConcatenate4.
# Adds no fields of its own; everything is inherited from XNNCat.
@dataclass
class XNNConcatenate4(XNNCat):
    pass
178
179
# Node data class for XNNBatchMatrixMultiply.
# Adds no fields of its own; everything is inherited from XNNNode2x1.
@dataclass
class XNNBatchMatrixMultiply(XNNNode2x1):
    pass
183
184
# Node data class for XNNStaticTranspose.
@dataclass
class XNNStaticTranspose:
    num_dims: int
    # NOTE(review): presumably a permutation of length num_dims -- nothing in
    # this class enforces that; confirm against the serializer.
    perm: List[int]
    input_id: int
    output_id: int
    flags: int
192
193
# Node data class for XNNStaticSlice.
@dataclass
class XNNStaticSlice:
    num_dims: int
    # NOTE(review): offsets and sizes presumably hold one entry per dimension
    # (num_dims each) -- not enforced here; confirm against the serializer.
    offsets: List[int]
    sizes: List[int]
    input_id: int
    output_id: int
    flags: int
202
203
# Node data class for XNNClamp.
# Adds no fields of its own; everything is inherited from XNNNode1x1.
# NOTE(review): no min/max bound is stored on this class -- presumably the
# bounds travel via XNode.output_min_max; confirm against the serializer.
@dataclass
class XNNClamp(XNNNode1x1):
    pass
207
208
# Node data class for XNNStaticResizeBilinear2D.
@dataclass
class XNNStaticResizeBilinear2D:
    # Target output extent.
    new_height: int
    new_width: int
    input_id: int
    output_id: int
    flags: int
216
217
# Node data class for XNNStaticConstantPad.
@dataclass
class XNNStaticConstantPad:
    # NOTE(review): presumably one pre/post padding amount per dimension --
    # not enforced here; confirm against the serializer.
    pre_paddings: List[int]
    post_paddings: List[int]
    padding_value: float  # the only float-typed field on this node
    input_id: int
    output_id: int
    flags: int
226
227
# Node data class for XNNDepthwiseConv2d.
# Adds no fields of its own; everything is inherited from XNNNodeConv.
@dataclass
class XNNDepthwiseConv2d(XNNNodeConv):
    pass
231
232
# Node data class for XNNArgMaxPooling2d.
# Declared standalone rather than on top of XNNPooling2D: it has no
# stride/dilation fields and has two outputs instead of one.
@dataclass
class XNNArgMaxPooling2d:
    # Padding on each side.
    padding_top: int
    padding_right: int
    padding_bottom: int
    padding_left: int
    # Pooling window extent.
    pooling_height: int
    pooling_width: int
    input_id: int
    # Two separate outputs: one for values, one for indices.
    output_value_id: int
    output_index_id: int
    flags: int
245
246
# Node data class for the fully connected node (aten::Linear).
# Historical note: this class was originally added so that XNodeUnion would
# have more than one member. With a single member, e.g. Union[XNNAdd], Python
# treats the annotation as plain XNNAdd rather than as a Union type. The
# original note said the class could be removed once more operators were
# added; it is still listed in XNodeUnion, so it must stay.
@dataclass
class XNNFullyConnected:  # aten::Linear
    input1_id: int
    filter_id: int
    bias_id: int
    output_id: int
    flags: int
257
258
# Node data class for XNNStaticReshape.
@dataclass
class XNNStaticReshape:
    num_dims: int
    # NOTE(review): presumably holds num_dims entries -- not enforced here;
    # confirm against the serializer.
    new_shape: List[int]
    input_id: int
    output_id: int
    flags: int
266
267
# Node data class for XNNSquareRoot.
# Adds no fields of its own; everything is inherited from XNNNode1x1.
@dataclass
class XNNSquareRoot(XNNNode1x1):
    pass
271
272
# Node data class for XNNCeiling.
# Adds no fields of its own; everything is inherited from XNNNode1x1.
@dataclass
class XNNCeiling(XNNNode1x1):
    pass
276
277
# Node data class for XNNHardswish.
# Adds no fields of its own; everything is inherited from XNNNode1x1.
@dataclass
class XNNHardswish(XNNNode1x1):
    pass
281
282
# Node data class for XNNSquare.
# Adds no fields of its own; everything is inherited from XNNNode1x1.
@dataclass
class XNNSquare(XNNNode1x1):
    pass
286
287
# Node data class for XNNLeakyReLU.
# Declared standalone (not via XNNNode1x1) because negative_slope must come
# first in the positional constructor order.
@dataclass
class XNNLeakyReLU:
    negative_slope: float
    input_id: int
    output_id: int
    flags: int
294
295
# Node data class for XNNMaximum.
# Adds no fields of its own; everything is inherited from XNNNode2x1.
@dataclass
class XNNMaximum(XNNNode2x1):
    pass
299
300
# Node data class for XNNELU.
# Declared standalone (not via XNNNode1x1) because alpha must come first in
# the positional constructor order.
@dataclass
class XNNELU:
    alpha: float
    input_id: int
    output_id: int
    flags: int
307
308
# Node data class for XNNPReLU.
# Adds no fields of its own; everything is inherited from XNNNode2x1.
# NOTE(review): presumably input2_id refers to the slope tensor -- confirm
# against the serializer.
@dataclass
class XNNPReLU(XNNNode2x1):
    pass
312
313
# Node data class for XNNScaledDotProductAttention.
@dataclass
class XNNScaledDotProductAttention:
    # Identifiers of the five inputs.
    query_id: int
    key_id: int
    value_id: int
    scale_id: int
    mask_id: int
    output_id: int
    flags: int
323
324
# Union of every node dataclass accepted by XNode.xnode_union.
# NOTE(review): member order presumably has to mirror the union declared in
# schema.fbs -- confirm before reordering or inserting entries mid-list.
XNodeUnion = Union[
    XNNAdd,
    XNNFullyConnected,
    XNNSoftmax,
    XNNSigmoid,
    XNNStaticTranspose,
    XNNClamp,
    XNNConv2d,
    XNNDiv,
    XNNStaticResizeBilinear2D,
    XNNStaticConstantPad,
    XNNAvgPooling2d,
    XNNMinimum,
    XNNDepthwiseConv2d,
    XNNMaxPooling2d,
    XNNMultiply,
    XNNSubtract,
    XNNFloor,
    XNNConvert,
    XNNGlobalAvgPooling2d,
    XNNStaticReshape,
    XNNArgMaxPooling2d,
    XNNSquareRoot,
    XNNCeiling,
    XNNHardswish,
    XNNLeakyReLU,
    XNNMaximum,
    XNNNegate,
    XNNSquare,
    XNNELU,
    XNNAbs,
    XNNPReLU,
    XNNConcatenate2,
    XNNConcatenate3,
    XNNConcatenate4,
    XNNStaticSlice,
    XNNScaledDotProductAttention,
    XNNBatchMatrixMultiply,
]
364
365
# Pair of output bounds; used as the optional XNode.output_min_max field.
# Each bound may be a float or a str.
# NOTE(review): the str form presumably encodes non-finite bounds such as
# "inf"/"-inf" -- confirm against the serializer.
@dataclass
class OutputMinMax:
    output_min: Union[float, str]
    output_max: Union[float, str]
370
371
# Wrapper for a single graph node: the concrete node payload plus a debug
# handle and optional output bounds.
@dataclass
class XNode:
    # Kept as a quoted annotation, exactly as in the original declaration; it
    # names the XNodeUnion alias.
    xnode_union: "XNodeUnion"
    debug_handle: int
    # Optional output bounds; None when not supplied.
    output_min_max: Optional[OutputMinMax] = None
377
378
# Integer codes for tensor datatypes, numbered contiguously from 0 to 10.
# NOTE(review): the values presumably have to match the XNNDatatype enum in
# schema.fbs -- confirm before renumbering or inserting members.
class XNNDatatype(IntEnum):
    xnn_datatype_invalid = 0
    xnn_datatype_fp32 = 1
    xnn_datatype_fp16 = 2
    xnn_datatype_qint8 = 3
    xnn_datatype_quint8 = 4
    xnn_datatype_qint32 = 5
    xnn_datatype_qcint8 = 6
    xnn_datatype_qcint32 = 7
    xnn_datatype_qcint4 = 8
    xnn_datatype_qdint8 = 9
    xnn_datatype_qbint4 = 10
391
392
# Per-channel quantization parameters; one member of XNNQuantParams.
@dataclass
class PerChannelQuant:
    # NOTE(review): presumably one scale per channel along channel_dim --
    # confirm against the serializer.
    scale: List[float]
    channel_dim: int
397
398
# Per-channel, per-group quantization parameters; one member of
# XNNQuantParams. Same fields as PerChannelQuant plus group_size.
@dataclass
class PerChannelGroupQuant:
    scale: List[float]
    channel_dim: int
    group_size: int = 1  # optional; defaults to 1 when omitted
404
405
# Per-token dynamic quantization parameters; one member of XNNQuantParams.
# Carries no scale or zero point, only the count of non-batch dimensions.
@dataclass
class PerTokenDynamicQuant:
    num_nonbatch_dims: int
409
410
# Per-tensor quantization parameters; one member of XNNQuantParams.
# A single scale and zero point for the whole tensor.
@dataclass
class PerTensorQuant:
    scale: float
    zero_point: int
415
416
# Union of the quantization-parameter dataclasses accepted by
# XNNQuantizedTensorValue.quant_params.
# NOTE(review): member order presumably has to mirror the union declared in
# schema.fbs -- confirm before reordering.
XNNQuantParams = Union[
    PerChannelQuant, PerTensorQuant, PerTokenDynamicQuant, PerChannelGroupQuant
]
420
421
# Description of a (non-quantized) tensor value in the graph.
@dataclass
class XNNTensorValue:
    datatype: XNNDatatype
    num_dims: int
    # NOTE(review): presumably holds num_dims entries -- not enforced here.
    dims: List[int]
    # NOTE(review): presumably an index into the graph's constant data --
    # confirm against the serializer.
    constant_buffer_idx: int
    external_id: int
    flags: int
    id_out: int
431
432
# A tensor value paired with its quantization parameters.
@dataclass
class XNNQuantizedTensorValue:
    tensor_value: XNNTensorValue
    # Kept as a quoted annotation, exactly as in the original declaration; it
    # names the XNNQuantParams alias.
    quant_params: "XNNQuantParams"
437
438
# Union of the value dataclasses accepted by XValue.xvalue_union.
XValueUnion = Union[
    XNNTensorValue,
    XNNQuantizedTensorValue,
]
443
444
# Wrapper for a single graph value.
@dataclass
class XValue:
    # Kept as a quoted annotation, exactly as in the original declaration; it
    # names the XValueUnion alias.
    xvalue_union: "XValueUnion"
448
449
# Location of one entry of constant data, given as an offset and a size.
# NOTE(review): units are presumably bytes -- confirm against the serializer.
@dataclass
class ConstantDataOffset:
    offset: int
    size: int
454
455
# Top-level graph container: nodes, values, external input/output ids and
# the table of constant data locations.
@dataclass
class XNNGraph:
    version: str
    xnodes: List[XNode]
    xvalues: List[XValue]

    num_externs: int
    input_ids: List[int]
    output_ids: List[int]

    constant_data: List[ConstantDataOffset]
467