# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# NOTE: This is a placeholder for iterating on export serialization schema design.
# Anything is subject to change and no guarantee is provided at this point.

from dataclasses import dataclass, field
from enum import IntEnum
from typing import Dict, List, Optional, Tuple

import executorch.exir.serde.schema as export_schema

from executorch.exir.serde.union import _Union

# NOTE: Please update this value if any modifications are made to the schema
SCHEMA_VERSION = (5, 3)
TREESPEC_VERSION = 1


# Scalar (element) type tags. Every member carries an explicit integer value;
# TensorMeta.dtype and Argument.as_scalar_type are typed with this enum.
class ScalarType(IntEnum):
    UNKNOWN = 0
    # Integer types
    BYTE = 1
    CHAR = 2
    SHORT = 3
    INT = 4
    LONG = 5
    # Floating-point types
    HALF = 6
    FLOAT = 7
    DOUBLE = 8
    # Complex types
    COMPLEXHALF = 9
    COMPLEXFLOAT = 10
    COMPLEXDOUBLE = 11
    # Remaining types
    BOOL = 12
    BFLOAT16 = 13
    UINT16 = 14


# Tensor layout tags with explicit integer values; TensorMeta.layout and
# Argument.as_layout are typed with this enum.
class Layout(IntEnum):
    Unknown = 0
    SparseCoo = 1
    SparseCsr = 2
    SparseCsc = 3
    SparseBsr = 4
    SparseBsc = 5
    _mkldnn = 6
    Strided = 7


# Memory format tags with explicit integer values; Argument.as_memory_format
# is typed with this enum.
class MemoryFormat(IntEnum):
    Unknown = 0
    ContiguousFormat = 1
    ChannelsLast = 2
    ChannelsLast3d = 3
    PreserveFormat = 4


# A device reference: a device type string plus an optional index.
@dataclass
class Device:
    type: str
    # None when no index is given.
    index: Optional[int] = None


# The concrete value that may accompany a SymExpr as its hint: an int, a float
# or a bool, declared as a _Union with one field per alternative.
@dataclass(repr=False)
class SymExprHint(_Union):
    as_int: int
    as_float: float
    as_bool: bool


# This is for storing the symbolic expressions behind symints/symfloats/symbools
# For example, we can get something like
# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4))
# if we also have the hint that s0 and s1 are both 2.
@dataclass
class SymExpr:
    # The symbolic expression in string form, e.g. "s0 + s1".
    expr_str: str
    # Optional concrete value for the expression; None when no hint is stored.
    hint: Optional[SymExprHint] = None


# An integer value that is stored either as a symbolic expression (as_expr)
# or as a plain int (as_int).
@dataclass(repr=False)
class SymInt(_Union):
    as_expr: SymExpr
    as_int: int


# A boolean value that is stored either as a symbolic expression (as_expr)
# or as a plain bool (as_bool).
@dataclass(repr=False)
class SymBool(_Union):
    as_expr: SymExpr
    as_bool: bool


# Metadata recorded for a tensor value. Sizes, strides and the storage offset
# are SymInts, so each of them may be symbolic or concrete.
@dataclass
class TensorMeta:
    dtype: ScalarType
    sizes: List[SymInt]
    requires_grad: bool
    device: Device
    strides: List[SymInt]
    storage_offset: SymInt
    layout: Layout


# In most cases we will use the "as_name" field to store arguments which are
# SymInts.
# The "as_int" field is used in the case where we have a list containing a mix
# of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to
# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints
# to the "as_int" field.
@dataclass(repr=False)
class SymIntArgument(_Union):
    as_name: str
    as_int: int


# In most cases we will use the "as_name" field to store arguments which are
# SymBools.
# The "as_bool" field is used in the case where we have a list containing a mix
# of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to
# be List[SymBoolArgument] and map the SymBools to the "as_name" field, and bools
# to the "as_bool" field.
@dataclass(repr=False)
class SymBoolArgument(_Union):
    as_name: str
    as_bool: bool


# Refers to a tensor by name.
@dataclass
class TensorArgument:
    name: str


# Refers to a token by name.
@dataclass
class TokenArgument:
    name: str


# This is used for storing the contents of a list which contains optional tensors
# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the
# type List[OptionalTensorArgument], with tensor values serialized to the
# "as_tensor" field, and None values serialized to the "as_none" field.
@dataclass(repr=False)
class OptionalTensorArgument(_Union):
    as_tensor: TensorArgument
    as_none: Tuple[()]


# Refers to a graph by name and carries the graph itself.
@dataclass
class GraphArgument:
    name: str
    # Forward reference: Graph is defined further down in this module.
    graph: "Graph"


# Refers to a custom object by name.
@dataclass
class CustomObjArgument:
    name: str
    # NOTE(review): presumably the fully-qualified class name of the object;
    # confirm against the serializer.
    class_fqn: str


# This is actually a union type
@dataclass(repr=False)
class Argument(_Union):
    as_none: Tuple[()]
    as_tensor: TensorArgument
    as_tensors: List[TensorArgument]
    as_int: int
    as_ints: List[int]
    as_float: float
    as_floats: List[float]
    as_string: str
    as_strings: List[str]
    as_sym_int: SymIntArgument
    as_sym_ints: List[SymIntArgument]
    as_scalar_type: ScalarType
    as_memory_format: MemoryFormat
    as_layout: Layout
    as_device: Device
    as_bool: bool
    as_bools: List[bool]
    as_sym_bool: SymBoolArgument
    as_sym_bools: List[SymBoolArgument]
    as_graph: GraphArgument
    as_optional_tensors: List[OptionalTensorArgument]
    as_custom_obj: CustomObjArgument
    as_operator: str


# An Argument paired with its parameter name.
@dataclass
class NamedArgument:
    # Argument name from the operator schema
    name: str
    arg: Argument


# A single graph node: a target string, named inputs, outputs, and
# string-to-string metadata.
@dataclass
class Node:
    target: str
    inputs: List[NamedArgument]
    outputs: List[Argument]
    metadata: Dict[str, str]


# A graph: its inputs, outputs and nodes, plus name-keyed tables for the
# tensor, SymInt, SymBool and custom-object values.
@dataclass
class Graph:
    inputs: List[Argument]
    outputs: List[Argument]
    nodes: List[Node]
    tensor_values: Dict[str, TensorMeta]
    sym_int_values: Dict[str, SymInt]
    sym_bool_values: Dict[str, SymBool]
    # This is for deserializing the submodule graphs from higher order ops
    # (ex. cond, map) where single tensor returns will just return a single
    # tensor, rather than following export schema and returning a singleton
    # list.
    is_single_tensor_return: bool = False
    # default_factory gives every Graph its own dict.
    custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)


@dataclass
class UserInputSpec:
    # Actually, only tensors and SymInts are allowed here
    arg: Argument


# A constant value: None, int, float, str or bool.
@dataclass(repr=False)
class ConstantValue(_Union):
    as_none: Tuple[()]
    as_int: int
    as_float: float
    as_string: str
    as_bool: bool


# A named input whose value is a constant.
@dataclass
class ConstantInputSpec:
    name: str
    value: ConstantValue


# Associates a tensor input with a parameter name.
@dataclass
class InputToParameterSpec:
    arg: TensorArgument
    parameter_name: str


# Associates a tensor input with a buffer name and records the buffer's
# `persistent` flag.
@dataclass
class InputToBufferSpec:
    arg: TensorArgument
    buffer_name: str
    persistent: bool


# Associates a tensor input with a tensor-constant name.
@dataclass
class InputToTensorConstantSpec:
    arg: TensorArgument
    tensor_constant_name: str


# Associates a custom-object input with a custom-object name.
@dataclass
class InputToCustomObjSpec:
    arg: CustomObjArgument
    custom_obj_name: str


# An input that is a token.
@dataclass
class InputTokenSpec:
    arg: TokenArgument


# One entry of GraphSignature.input_specs; a union over the input kinds below.
@dataclass(repr=False)
class InputSpec(_Union):
    user_input: UserInputSpec
    parameter: InputToParameterSpec
    buffer: InputToBufferSpec
    tensor_constant: InputToTensorConstantSpec
    custom_obj: InputToCustomObjSpec
    token: InputTokenSpec
    constant_input: ConstantInputSpec


# An output of kind `user_output` (see OutputSpec).
@dataclass
class UserOutputSpec:
    arg: Argument


# An output of kind `loss_output` (see OutputSpec).
@dataclass
class LossOutputSpec:
    arg: TensorArgument


# Associates a tensor output with a buffer name.
@dataclass
class BufferMutationSpec:
    arg: TensorArgument
    buffer_name: str


# Associates a tensor output with a parameter name.
@dataclass
class GradientToParameterSpec:
    arg: TensorArgument
    parameter_name: str


# Associates a tensor output with a user-input name.
@dataclass
class GradientToUserInputSpec:
    arg: TensorArgument
    user_input_name: str


# Associates a tensor output with a user-input name.
@dataclass
class UserInputMutationSpec:
    arg: TensorArgument
    user_input_name: str


# An output that is a token.
@dataclass
class OutputTokenSpec:
    arg: TokenArgument


# One entry of GraphSignature.output_specs; a union over the output kinds below.
@dataclass(repr=False)
class OutputSpec(_Union):
    user_output: UserOutputSpec
    loss_output: LossOutputSpec
    buffer_mutation: BufferMutationSpec
    gradient_to_parameter: GradientToParameterSpec
    gradient_to_user_input: GradientToUserInputSpec
    user_input_mutation: UserInputMutationSpec
    token: OutputTokenSpec


# The input and output specs of a graph.
@dataclass
class GraphSignature:
    input_specs: List[InputSpec]
    output_specs: List[OutputSpec]


# A pair of integer bounds; ExportedProgram.range_constraints maps names to
# these.
@dataclass
class RangeConstraint:
    min_val: int
    max_val: int


# The inputs and outputs of a module call together with its string-encoded
# input/output tree specs.
@dataclass
class ModuleCallSignature:
    inputs: List[Argument]
    outputs: List[Argument]

    # These are serialized by calling pytree.treespec_dumps
    # And deserialized by calling pytree.treespec_loads
    in_spec: str
    out_spec: str


# One entry of GraphModule.module_call_graph, identified by `fqn`.
@dataclass
class ModuleCallEntry:
    fqn: str
    # None when no signature is stored for this entry.
    signature: Optional[ModuleCallSignature] = None


# A graph together with its signature and module call graph.
@dataclass
class GraphModule:
    graph: Graph
    signature: GraphSignature
    # This is used for unflattening, by tracking the calling structure of all of
    # the modules in order to unflatten the modules back to the eager calling
    # conventions.
    module_call_graph: List[ModuleCallEntry]


# Invariant: Every time a change is made to the schema, one of the versions
# should be updated.
@dataclass
class SchemaVersion:
    major: int  # Major version number is bumped every time a breaking change is made.
    minor: int  # Minor version number is bumped when a compatible change is made.


# Top-level container: the graph module plus opset versions, range
# constraints, the schema version, and the dialect/verifiers fields.
@dataclass
class ExportedProgram:
    graph_module: GraphModule
    # Key is the opset namespace (ex. aten), and value is the version number
    opset_version: Dict[str, int]
    range_constraints: Dict[str, RangeConstraint]
    schema_version: SchemaVersion
    # Declared exactly once, with its default, ahead of `verifiers`, so the
    # generated __init__ keeps the argument order
    # (..., schema_version, dialect="", verifiers=[]).
    dialect: str = ""  # TODO deprecated
    verifiers: List[str] = field(default_factory=list)


# A single key/value pair; LoweredBackendModule.compile_specs holds a list of
# these.
@dataclass
class CompileSpec:
    key: str
    value: str


# A lowered backend module: the backend id, the processed payload, the compile
# specs, and the original module with its state dict and constants. Everything
# except `compile_specs` and `original_module` is stored as a string.
@dataclass
class LoweredBackendModule:
    backend_id: str
    processed_bytes: str
    compile_specs: List[CompileSpec]
    # Typed with ExportedProgram as exposed by the imported `export_schema`
    # module, not with the bare local name.
    original_module: export_schema.ExportedProgram
    original_state_dict: str
    original_constants: str