# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Please refer to executorch/backends/xnnpack/serialization/schema.fbs for the schema definitions
"""

from dataclasses import dataclass
from enum import IntEnum
from typing import List, Optional, Union


# Generic node data class with one input and one output
@dataclass
class XNNNode1x1:
    input_id: int
    output_id: int
    flags: int


# Generic node data class with two inputs and one output
@dataclass
class XNNNode2x1:
    input1_id: int
    input2_id: int
    output_id: int
    flags: int


# Generic node data class for concatenation node
@dataclass
class XNNCat:
    axis: int
    input1_id: int
    input2_id: int
    input3_id: int
    input4_id: int
    output_id: int
    flags: int


# Generic node data class for convolution type nodes
@dataclass
class XNNNodeConv:
    padding_top: int
    padding_right: int
    padding_bottom: int
    padding_left: int
    kernel_height: int
    kernel_width: int
    subsampling_height: int
    subsampling_width: int
    dilation_height: int
    dilation_width: int
    group_input_channels: int
    group_output_channels: int
    groups: int
    adjustment_height: int
    adjustment_width: int
    input1_id: int
    filter_id: int
    bias_id: int
    output_id: int
    flags: int


@dataclass
class XNNPooling2D:
    padding_top: int
    padding_right: int
    padding_bottom: int
    padding_left: int
    pooling_height: int
    pooling_width: int
    stride_height: int
    stride_width: int
    dilation_height: int
    dilation_width: int
    input_id: int
    output_id: int
    flags: int


# Node data class for average pooling 2d
@dataclass
class XNNAvgPooling2d(XNNPooling2D):
    pass


@dataclass
class XNNMaxPooling2d(XNNPooling2D):
    pass


@dataclass
class XNNConv2d(XNNNodeConv):
    pass


@dataclass
class XNNAdd(XNNNode2x1):
    pass


@dataclass
class XNNGlobalAvgPooling2d(XNNNode1x1):
    pass


@dataclass
class XNNDiv(XNNNode2x1):
    pass


@dataclass
class XNNMultiply(XNNNode2x1):
    pass


@dataclass
class XNNMinimum(XNNNode2x1):
    pass


@dataclass
class XNNSubtract(XNNNode2x1):
    pass


@dataclass
class XNNSoftmax(XNNNode1x1):
    pass


@dataclass
class XNNSigmoid(XNNNode1x1):
    pass


@dataclass
class XNNFloor(XNNNode1x1):
    pass


@dataclass
class XNNConvert(XNNNode1x1):
    pass


@dataclass
class XNNNegate(XNNNode1x1):
    pass


@dataclass
class XNNAbs(XNNNode1x1):
    pass


@dataclass
class XNNConcatenate2(XNNCat):
    pass


@dataclass
class XNNConcatenate3(XNNCat):
    pass


@dataclass
class XNNConcatenate4(XNNCat):
    pass


@dataclass
class XNNBatchMatrixMultiply(XNNNode2x1):
    pass


@dataclass
class XNNStaticTranspose:
    num_dims: int
    perm: List[int]
    input_id: int
    output_id: int
    flags: int


@dataclass
class XNNStaticSlice:
    num_dims: int
    offsets: List[int]
    sizes: List[int]
    input_id: int
    output_id: int
    flags: int


@dataclass
class XNNClamp(XNNNode1x1):
    pass


@dataclass
class XNNStaticResizeBilinear2D:
    new_height: int
    new_width: int
    input_id: int
    output_id: int
    flags: int


@dataclass
class XNNStaticConstantPad:
    pre_paddings: List[int]
    post_paddings: List[int]
    padding_value: float
    input_id: int
    output_id: int
    flags: int


@dataclass
class XNNDepthwiseConv2d(XNNNodeConv):
    pass


@dataclass
class XNNArgMaxPooling2d:
    padding_top: int
    padding_right: int
    padding_bottom: int
    padding_left: int
    pooling_height: int
    pooling_width: int
    input_id: int
    output_value_id: int
    output_index_id: int
    flags: int


# This class exists so that Python can infer XNodeUnion as a Union type. If the Union had only
# one member, e.g. Union[XNNAdd], Python would infer it as the XNNAdd type instead of a Union
# type. Once more operators are added, this workaround note can be removed.
@dataclass
class XNNFullyConnected:  # aten::Linear
    input1_id: int
    filter_id: int
    bias_id: int
    output_id: int
    flags: int


@dataclass
class XNNStaticReshape:
    num_dims: int
    new_shape: List[int]
    input_id: int
    output_id: int
    flags: int


@dataclass
class XNNSquareRoot(XNNNode1x1):
    pass


@dataclass
class XNNCeiling(XNNNode1x1):
    pass


@dataclass
class XNNHardswish(XNNNode1x1):
    pass


@dataclass
class XNNSquare(XNNNode1x1):
    pass


@dataclass
class XNNLeakyReLU:
    negative_slope: float
    input_id: int
    output_id: int
    flags: int


@dataclass
class XNNMaximum(XNNNode2x1):
    pass


@dataclass
class XNNELU:
    alpha: float
    input_id: int
    output_id: int
    flags: int


@dataclass
class XNNPReLU(XNNNode2x1):
    pass


@dataclass
class XNNScaledDotProductAttention:
    query_id: int
    key_id: int
    value_id: int
    scale_id: int
    mask_id: int
    output_id: int
    flags: int


XNodeUnion = Union[
    XNNAdd,
    XNNFullyConnected,
    XNNSoftmax,
    XNNSigmoid,
    XNNStaticTranspose,
    XNNClamp,
    XNNConv2d,
    XNNDiv,
    XNNStaticResizeBilinear2D,
    XNNStaticConstantPad,
    XNNAvgPooling2d,
    XNNMinimum,
    XNNDepthwiseConv2d,
    XNNMaxPooling2d,
    XNNMultiply,
    XNNSubtract,
    XNNFloor,
    XNNConvert,
    XNNGlobalAvgPooling2d,
    XNNStaticReshape,
    XNNArgMaxPooling2d,
    XNNSquareRoot,
    XNNCeiling,
    XNNHardswish,
    XNNLeakyReLU,
    XNNMaximum,
    XNNNegate,
    XNNSquare,
    XNNELU,
    XNNAbs,
    XNNPReLU,
    XNNConcatenate2,
    XNNConcatenate3,
    XNNConcatenate4,
    XNNStaticSlice,
    XNNScaledDotProductAttention,
    XNNBatchMatrixMultiply,
]


@dataclass
class OutputMinMax:
    output_min: Union[float, str]
    output_max: Union[float, str]


@dataclass
class XNode:
    xnode_union: "XNodeUnion"
    debug_handle: int
    output_min_max: Optional[OutputMinMax] = None


class XNNDatatype(IntEnum):
    xnn_datatype_invalid = 0
    xnn_datatype_fp32 = 1
    xnn_datatype_fp16 = 2
    xnn_datatype_qint8 = 3
    xnn_datatype_quint8 = 4
    xnn_datatype_qint32 = 5
    xnn_datatype_qcint8 = 6
    xnn_datatype_qcint32 = 7
    xnn_datatype_qcint4 = 8
    xnn_datatype_qdint8 = 9
    xnn_datatype_qbint4 = 10


@dataclass
class PerChannelQuant:
    scale: List[float]
    channel_dim: int


@dataclass
class PerChannelGroupQuant:
    scale: List[float]
    channel_dim: int
    group_size: int = 1


@dataclass
class PerTokenDynamicQuant:
    num_nonbatch_dims: int

@dataclass
class PerTensorQuant:
    scale: float
    zero_point: int


XNNQuantParams = Union[
    PerChannelQuant, PerTensorQuant, PerTokenDynamicQuant, PerChannelGroupQuant
]


@dataclass
class XNNTensorValue:
    datatype: XNNDatatype
    num_dims: int
    dims: List[int]
    constant_buffer_idx: int
    external_id: int
    flags: int
    id_out: int


@dataclass
class XNNQuantizedTensorValue:
    tensor_value: XNNTensorValue
    quant_params: "XNNQuantParams"


XValueUnion = Union[
    XNNTensorValue,
    XNNQuantizedTensorValue,
]


@dataclass
class XValue:
    xvalue_union: "XValueUnion"


@dataclass
class ConstantDataOffset:
    offset: int
    size: int


@dataclass
class XNNGraph:
    version: str
    xnodes: List[XNode]
    xvalues: List[XValue]

    num_externs: int
    input_ids: List[int]
    output_ids: List[int]

    constant_data: List[ConstantDataOffset]
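
# ---------------------------------------------------------------------------
# Minimal illustrative sketch (not part of the serialized schema): how the
# dataclasses above compose into an XNNGraph. All ids, flags, sizes, and the
# version string below are placeholder values chosen for demonstration only;
# in practice the XNNPACK serializer assigns them while lowering a program.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Two fp32 input tensors and one output tensor, addressed by value ids.
    inputs = [
        XValue(
            xvalue_union=XNNTensorValue(
                datatype=XNNDatatype.xnn_datatype_fp32,
                num_dims=1,
                dims=[4],
                constant_buffer_idx=0,  # placeholder: these tensors carry no constant data
                external_id=i,  # placeholder external id
                flags=0,
                id_out=i,  # placeholder value id
            )
        )
        for i in range(2)
    ]
    output = XValue(
        xvalue_union=XNNTensorValue(
            datatype=XNNDatatype.xnn_datatype_fp32,
            num_dims=1,
            dims=[4],
            constant_buffer_idx=0,
            external_id=2,
            flags=0,
            id_out=2,
        )
    )

    # A single elementwise add node wiring value ids 0 and 1 into value id 2.
    add_node = XNode(
        xnode_union=XNNAdd(input1_id=0, input2_id=1, output_id=2, flags=0),
        debug_handle=0,
    )

    graph = XNNGraph(
        version="0",  # placeholder schema version string
        xnodes=[add_node],
        xvalues=inputs + [output],
        num_externs=3,  # placeholder: three externally visible values in this sketch
        input_ids=[0, 1],
        output_ids=[2],
        constant_data=[],
    )
    print(graph)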