xref: /aosp_15_r20/external/executorch/exir/serde/schema.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# NOTE: This is a placeholder for iterating on export serialization schema design.
8#       Anything is subject to change and no guarantee is provided at this point.
9
10from dataclasses import dataclass, field
11from enum import IntEnum
12from typing import Dict, List, Optional, Tuple
13
14import executorch.exir.serde.schema as export_schema
15
16from executorch.exir.serde.union import _Union
17
# NOTE: Please update this value if any modifications are made to the schema
# Presumably (major, minor), matching the SchemaVersion dataclass defined in
# this file -- TODO confirm against the serializer.
SCHEMA_VERSION = (5, 3)
# NOTE(review): not referenced in this file; presumably versions the pytree
# treespec strings stored in ModuleCallSignature.in_spec/out_spec -- confirm.
TREESPEC_VERSION = 1
21
22
# Integer-coded element types. Used in this file as TensorMeta.dtype and as the
# Argument.as_scalar_type alternative.
# NOTE(review): the integer values are presumably part of the serialized
# format, so existing members should not be renumbered -- confirm.
class ScalarType(IntEnum):
    UNKNOWN = 0
    BYTE = 1
    CHAR = 2
    SHORT = 3
    INT = 4
    LONG = 5
    HALF = 6
    FLOAT = 7
    DOUBLE = 8
    COMPLEXHALF = 9
    COMPLEXFLOAT = 10
    COMPLEXDOUBLE = 11
    BOOL = 12
    BFLOAT16 = 13
    UINT16 = 14
39
# Integer-coded tensor layouts. Used in this file as TensorMeta.layout and as
# the Argument.as_layout alternative.
class Layout(IntEnum):
    Unknown = 0
    SparseCoo = 1
    SparseCsr = 2
    SparseCsc = 3
    SparseBsr = 4
    SparseBsc = 5
    _mkldnn = 6
    Strided = 7
49
50
# Integer-coded memory formats. Used in this file as the
# Argument.as_memory_format alternative.
class MemoryFormat(IntEnum):
    Unknown = 0
    ContiguousFormat = 1
    ChannelsLast = 2
    ChannelsLast3d = 3
    PreserveFormat = 4
57
58
# Device descriptor. Used in this file as TensorMeta.device and as the
# Argument.as_device alternative.
@dataclass
class Device:
    # Device kind as a string (presumably e.g. "cpu" -- TODO confirm).
    type: str
    # Optional device index; None when no index is given.
    index: Optional[int] = None
63
64
# Union (via _Union) of the possible hint values attached to a SymExpr:
# an int, a float or a bool.
@dataclass(repr=False)
class SymExprHint(_Union):
    as_int: int
    as_float: float
    as_bool: bool
70
71
# This is for storing the symbolic expressions behind symints/symfloats/symbools
# For example, we can get something like
# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4))
# if we also have the hint that s0 and s1 are both 2.
@dataclass
class SymExpr:
    # The symbolic expression in string form (ex. "s0 + s1").
    expr_str: str
    # Optional hint value for the expression; None when no hint is stored.
    hint: Optional[SymExprHint] = None
80
81
# Union (via _Union) of a symbolic expression or a plain int. Used in this file
# for TensorMeta sizes/strides/storage_offset and Graph.sym_int_values.
@dataclass(repr=False)
class SymInt(_Union):
    as_expr: SymExpr
    as_int: int
86
87
# Union (via _Union) of a symbolic expression or a plain bool. Used in this
# file for Graph.sym_bool_values.
@dataclass(repr=False)
class SymBool(_Union):
    as_expr: SymExpr
    as_bool: bool
92
93
# Metadata describing a tensor value. Graph.tensor_values maps a name to one
# of these.
@dataclass
class TensorMeta:
    dtype: ScalarType
    # One SymInt per dimension, so sizes may be symbolic or concrete.
    sizes: List[SymInt]
    requires_grad: bool
    device: Device
    # Same representation as sizes: each entry is symbolic or concrete.
    strides: List[SymInt]
    storage_offset: SymInt
    layout: Layout
103
104
# In most cases we will use the "as_name" field to store arguments which are
# SymInts.
# The "as_int" field is used in the case where we have a list containing a mix
# of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to
# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints
# to the "as_int" field.
@dataclass(repr=False)
class SymIntArgument(_Union):
    # Name of the SymInt (presumably a key of Graph.sym_int_values -- confirm).
    as_name: str
    # Concrete int literal.
    as_int: int
115
116
# In most cases we will use the "as_name" field to store arguments which are
# SymBools.
# The "as_bool" field is used in the case where we have a list containing a mix
# of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to
# be List[SymBoolArgument] and map the SymBools to the "as_name" field, and bools
# to the "as_bool" field.
@dataclass(repr=False)
class SymBoolArgument(_Union):
    # Name of the SymBool (presumably a key of Graph.sym_bool_values -- confirm).
    as_name: str
    # Concrete bool literal.
    as_bool: bool
127
128
# Reference to a tensor by name.
# NOTE(review): presumably the name is a key of Graph.tensor_values -- confirm.
@dataclass
class TensorArgument:
    name: str
132
133
# Reference to a token by name. Used in this file by InputTokenSpec and
# OutputTokenSpec.
@dataclass
class TokenArgument:
    name: str
137
138
# This is used for storing the contents of a list which contains optional tensors
# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the
# type List[OptionalTensorArgument], with tensor values serialized to the
# "as_tensor" field, and None values serialized to the "as_none" field.
@dataclass(repr=False)
class OptionalTensorArgument(_Union):
    as_tensor: TensorArgument
    # Empty-tuple type: marks the None alternative.
    as_none: Tuple[()]
147
148
# A named graph passed as an argument (the Argument.as_graph alternative).
@dataclass
class GraphArgument:
    name: str
    # Forward reference: Graph is defined later in this file.
    graph: "Graph"
153
154
# Reference to a custom object. Used in this file by Argument.as_custom_obj,
# Graph.custom_obj_values and InputToCustomObjSpec.
@dataclass
class CustomObjArgument:
    name: str
    # Presumably the fully-qualified class name of the object -- TODO confirm.
    class_fqn: str
159
160
# This is actually a union type
# Each "as_*" field is one alternative; the plural fields ("as_tensors",
# "as_ints", ...) are the list forms of the corresponding singular fields.
@dataclass(repr=False)
class Argument(_Union):
    # Empty-tuple type: marks the None alternative.
    as_none: Tuple[()]
    as_tensor: TensorArgument
    as_tensors: List[TensorArgument]
    as_int: int
    as_ints: List[int]
    as_float: float
    as_floats: List[float]
    as_string: str
    as_strings: List[str]
    as_sym_int: SymIntArgument
    as_sym_ints: List[SymIntArgument]
    as_scalar_type: ScalarType
    as_memory_format: MemoryFormat
    as_layout: Layout
    as_device: Device
    as_bool: bool
    as_bools: List[bool]
    as_sym_bool: SymBoolArgument
    as_sym_bools: List[SymBoolArgument]
    as_graph: GraphArgument
    # List whose entries may each be a tensor or None (see
    # OptionalTensorArgument).
    as_optional_tensors: List[OptionalTensorArgument]
    as_custom_obj: CustomObjArgument
    # Operator given as a string (presumably its name -- TODO confirm).
    as_operator: str
187
188
# An Argument paired with its name. Node.inputs is a list of these.
@dataclass
class NamedArgument:
    # Argument name from the operator schema
    name: str
    arg: Argument
194
195
# A single node of a Graph: a target with named inputs, outputs and
# string-to-string metadata.
@dataclass
class Node:
    # Target of the node as a string (presumably the operator name -- confirm).
    target: str
    inputs: List[NamedArgument]
    outputs: List[Argument]
    metadata: Dict[str, str]
202
203
# A graph: its inputs, outputs and nodes, plus name-keyed tables describing
# the tensor, SymInt, SymBool and custom-object values.
@dataclass
class Graph:
    inputs: List[Argument]
    outputs: List[Argument]
    nodes: List[Node]
    tensor_values: Dict[str, TensorMeta]
    sym_int_values: Dict[str, SymInt]
    sym_bool_values: Dict[str, SymBool]
    # This is for deserializing the submodule graphs from higher order ops
    # (ex. cond, map) where single tensor returns will just return a single
    # tensor, rather than following export schema and returning a singleton
    # list.
    is_single_tensor_return: bool = False
    # Defaults to a fresh empty dict per instance (default_factory).
    custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)
218
219
# Input spec for a user-provided input (the InputSpec.user_input alternative).
@dataclass
class UserInputSpec:
    # Actually, only tensors and SymInts are allowed here
    arg: Argument
224
225
# Union (via _Union) of the constant values a ConstantInputSpec can carry:
# None, int, float, str or bool.
@dataclass(repr=False)
class ConstantValue(_Union):
    # Empty-tuple type: marks the None alternative.
    as_none: Tuple[()]
    as_int: int
    as_float: float
    as_string: str
    as_bool: bool
233
234
# Input spec for a named constant input (the InputSpec.constant_input
# alternative).
@dataclass
class ConstantInputSpec:
    name: str
    value: ConstantValue
239
240
# Input spec pairing a tensor argument with a parameter name (the
# InputSpec.parameter alternative).
@dataclass
class InputToParameterSpec:
    arg: TensorArgument
    parameter_name: str
245
246
# Input spec pairing a tensor argument with a buffer name (the
# InputSpec.buffer alternative).
@dataclass
class InputToBufferSpec:
    arg: TensorArgument
    buffer_name: str
    # NOTE(review): presumably whether the buffer is persistent (included in
    # the state dict) -- confirm against the serializer.
    persistent: bool
252
253
# Input spec pairing a tensor argument with a tensor-constant name (the
# InputSpec.tensor_constant alternative).
@dataclass
class InputToTensorConstantSpec:
    arg: TensorArgument
    tensor_constant_name: str
258
259
# Input spec pairing a custom-object argument with a custom-object name (the
# InputSpec.custom_obj alternative).
@dataclass
class InputToCustomObjSpec:
    arg: CustomObjArgument
    custom_obj_name: str
264
265
# Input spec for a token argument (the InputSpec.token alternative).
@dataclass
class InputTokenSpec:
    arg: TokenArgument
269
270
# Union (via _Union) over every kind of graph input spec.
# GraphSignature.input_specs is a list of these.
@dataclass(repr=False)
class InputSpec(_Union):
    user_input: UserInputSpec
    parameter: InputToParameterSpec
    buffer: InputToBufferSpec
    tensor_constant: InputToTensorConstantSpec
    custom_obj: InputToCustomObjSpec
    token: InputTokenSpec
    constant_input: ConstantInputSpec
280
281
# Output spec for a user-visible output (the OutputSpec.user_output
# alternative).
@dataclass
class UserOutputSpec:
    arg: Argument
285
286
# Output spec for a loss tensor (the OutputSpec.loss_output alternative).
@dataclass
class LossOutputSpec:
    arg: TensorArgument
290
291
# Output spec pairing a tensor argument with a buffer name (the
# OutputSpec.buffer_mutation alternative).
@dataclass
class BufferMutationSpec:
    arg: TensorArgument
    buffer_name: str
296
297
# Output spec pairing a tensor argument with a parameter name (the
# OutputSpec.gradient_to_parameter alternative).
@dataclass
class GradientToParameterSpec:
    arg: TensorArgument
    parameter_name: str
302
303
# Output spec pairing a tensor argument with a user-input name (the
# OutputSpec.gradient_to_user_input alternative).
@dataclass
class GradientToUserInputSpec:
    arg: TensorArgument
    user_input_name: str
308
309
# Output spec pairing a tensor argument with a user-input name (the
# OutputSpec.user_input_mutation alternative).
@dataclass
class UserInputMutationSpec:
    arg: TensorArgument
    user_input_name: str
314
315
# Output spec for a token argument (the OutputSpec.token alternative).
@dataclass
class OutputTokenSpec:
    arg: TokenArgument
319
320
# Union (via _Union) over every kind of graph output spec.
# GraphSignature.output_specs is a list of these.
@dataclass(repr=False)
class OutputSpec(_Union):
    user_output: UserOutputSpec
    loss_output: LossOutputSpec
    buffer_mutation: BufferMutationSpec
    gradient_to_parameter: GradientToParameterSpec
    gradient_to_user_input: GradientToUserInputSpec
    user_input_mutation: UserInputMutationSpec
    token: OutputTokenSpec
330
331
# The input and output specs of a graph. Stored as GraphModule.signature.
# NOTE(review): presumably the lists are ordered to match Graph.inputs and
# Graph.outputs -- confirm against the serializer.
@dataclass
class GraphSignature:
    input_specs: List[InputSpec]
    output_specs: List[OutputSpec]
336
337
# A min/max integer pair. ExportedProgram.range_constraints maps a name to
# one of these.
# NOTE(review): whether the bounds are inclusive is not visible here -- confirm.
@dataclass
class RangeConstraint:
    min_val: int
    max_val: int
342
343
# Inputs/outputs of a module call together with the pytree specs of its
# arguments and results. Stored as ModuleCallEntry.signature.
@dataclass
class ModuleCallSignature:
    inputs: List[Argument]
    outputs: List[Argument]

    # These are serialized by calling pytree.treespec_dumps
    # And deserialized by calling pytree.treespec_loads
    in_spec: str
    out_spec: str
353
354
# One entry of GraphModule.module_call_graph.
@dataclass
class ModuleCallEntry:
    # Presumably the fully-qualified name of the called module -- TODO confirm.
    fqn: str
    # Optional call signature; None when no signature is stored.
    signature: Optional[ModuleCallSignature] = None
359
360
# A graph bundled with its signature and module call graph. Stored as
# ExportedProgram.graph_module.
@dataclass
class GraphModule:
    graph: Graph
    signature: GraphSignature
    # This is used for unflattening, by tracking the calling structure of all of
    # the modules in order to unflatten the modules back to the eager calling
    # conventions.
    module_call_graph: List[ModuleCallEntry]
369
370
# Invariant: Every time a change is made to the schema, one of the versions
#            should be updated.
@dataclass
class SchemaVersion:
    major: int  # Major version number is bumped every time a breaking change is made.
    minor: int  # Minor version number is bumped when a compatible change is made.
377
378
# Top-level serialized program: a graph module plus the opset versions, range
# constraints and schema version it was produced with.
@dataclass
class ExportedProgram:
    graph_module: GraphModule
    # Key is the opset namespace (ex. aten), and value is the version number
    opset_version: Dict[str, int]
    range_constraints: Dict[str, RangeConstraint]
    schema_version: SchemaVersion
    # `dialect` used to be annotated twice in this class body (once without a
    # default, once with ""). A dataclass keeps the position of the first
    # annotation and the default of the last assignment, so the effective
    # field has always been "fifth, optional, default ''". It is declared once
    # here, at that position, so the source matches the generated __init__.
    dialect: str = ""  # TODO deprecated
    # Defaults to a fresh empty list per instance (default_factory).
    verifiers: List[str] = field(default_factory=list)
389
390
# A string key/value pair. LoweredBackendModule.compile_specs is a list of
# these.
@dataclass
class CompileSpec:
    key: str
    value: str
395
396
# Serialized form of a module lowered to a backend: the backend id, the
# processed payload and compile specs, plus the original program it came from.
@dataclass
class LoweredBackendModule:
    backend_id: str
    # NOTE(review): stored as str; the encoding of the backend payload is not
    # visible in this file -- confirm against the serializer.
    processed_bytes: str
    compile_specs: List[CompileSpec]
    # Refer to the ExportedProgram defined in this module directly. The former
    # `export_schema.ExportedProgram` spelling went through an alias of this
    # same module, which only resolves to the same class when the file is
    # imported under its canonical package name.
    original_module: ExportedProgram
    # NOTE(review): both are str; presumably serialized blobs of the original
    # state dict and constants -- confirm.
    original_state_dict: str
    original_constants: str
404    original_constants: str
405