# xref: /aosp_15_r20/external/executorch/exir/schema.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
8
from dataclasses import dataclass
from enum import IntEnum
import math
from typing import List, Optional, Union

from executorch.exir.backend.compile_spec_schema import CompileSpec

from executorch.exir.scalar_type import ScalarType
16
17
@dataclass
class AllocationDetails:
    """Location of a tensor's storage within a memory buffer.

    The offset is split across two integer fields holding its low and high
    32-bit halves; use ``memory_offset`` to get the combined value.
    """

    # Memory buffer id. ExecutionPlan.non_const_buffer_sizes is indexed by
    # this id, and id 0 is reserved for the constant memory buffer.
    memory_id: int
    # Low 32 bits
    memory_offset_low: int
    # High 32 bits (typically zero)
    memory_offset_high: int

    @property
    def memory_offset(self) -> int:
        """Return ``memory_offset_low | (memory_offset_high << 32)``."""
        return self.memory_offset_low | (self.memory_offset_high << 32)
29
30
@dataclass
class OptionalTensorList:
    """List of optional tensors, each represented by an ``int``.

    NOTE(review): the ints are presumably indices into ExecutionPlan.values,
    with some sentinel marking an absent tensor -- confirm in program.fbs.
    """

    items: List[int]
34
35
class TensorShapeDynamism(IntEnum):
    """
    Shape dynamism of a tensor (the type of Tensor.shape_dynamism).

    Check program.fbs for explanations of this enum.
    """

    # NOTE(review): by name, STATIC presumably means a fixed shape,
    # DYNAMIC_BOUND a shape that may change within an upper bound, and
    # DYNAMIC_UNBOUND a shape with no bound -- confirm in program.fbs.
    STATIC = 0
    DYNAMIC_BOUND = 1
    DYNAMIC_UNBOUND = 2
44
45
@dataclass
class ExtraTensorInfo:
    """
    Optional extra information attached to a Tensor.

    Both fields default to None. Check program.fbs for explanations of these
    fields.
    """

    # NOTE(review): presumably an index into Program.mutable_data_segments --
    # confirm in program.fbs.
    mutable_data_segments_idx: Optional[int] = None
    fully_qualified_name: Optional[str] = None
54
55
@dataclass
class Tensor:
    """Description of a single tensor value.

    Check program.fbs for explanations of the individual fields.
    """

    scalar_type: ScalarType
    storage_offset: int
    sizes: List[int]
    # NOTE(review): annotated as List[bytes]; confirm the element type that
    # callers actually pass and that program.fbs declares.
    dim_order: List[bytes]
    requires_grad: bool
    layout: int
    # NOTE(review): presumably an index into Program.constant_buffer -- confirm.
    data_buffer_idx: int
    # Memory location of the tensor's storage, when provided.
    allocation_info: Optional[AllocationDetails]

    # check program.fbs for explanations.
    shape_dynamism: TensorShapeDynamism
    extra_tensor_info: Optional[ExtraTensorInfo] = None
70
71
@dataclass
class Null:
    """EValue payload that carries no data."""

    pass
75
76
@dataclass
class Int:
    """EValue payload holding an integer scalar."""

    int_val: int
80
81
@dataclass
class Bool:
    """EValue payload holding a boolean scalar."""

    bool_val: bool
85
86
@dataclass
class Double:
    """EValue payload holding a floating-point scalar.

    Infinities are stored as the strings ``"inf"`` and ``"-inf"``; every other
    value is stored as a ``float``. NOTE(review): the string form presumably
    exists because the serialization path cannot carry IEEE infinities as
    numbers -- confirm against the serializer.
    """

    double_val: Union[float, str]

    def __init__(self, double_val: Union[float, str]) -> None:
        """Normalize ``double_val`` into its stored form and validate it.

        Args:
            double_val: A real number, or the string ``"inf"`` / ``"-inf"``
                (the already-normalized form). Positive and negative infinity
                become ``"inf"`` / ``"-inf"``. Any other number is converted
                with ``float()``, so ints are accepted and stored as floats.

        Raises:
            ValueError: If ``double_val`` is a string other than ``"inf"`` or
                ``"-inf"``.
            TypeError: If ``double_val`` is neither a string nor a real number.
        """
        if isinstance(double_val, str):
            self.double_val = double_val
        elif math.isinf(double_val):
            self.double_val = "inf" if double_val > 0 else "-inf"
        else:
            self.double_val = float(double_val)
        # Because this class defines its own __init__, @dataclass does not
        # generate one. The generated __init__ is the only thing that calls
        # __post_init__, so call it explicitly; otherwise the invariant checks
        # below would never run.
        self.__post_init__()

    def __post_init__(self) -> None:
        """Check that ``double_val`` is in normalized form.

        Raises:
            ValueError: If it is a string other than ``"inf"`` / ``"-inf"``,
                or an infinite float (infinities must use the string form).
            TypeError: If it is neither a ``str`` nor a ``float``.
        """
        value = self.double_val
        if isinstance(value, str):
            if value not in ("inf", "-inf"):
                raise ValueError(
                    f"Double string value must be 'inf' or '-inf', got {value!r}"
                )
        elif not isinstance(value, float):
            raise TypeError(
                f"Double value must be a float or str, got {type(value).__name__}"
            )
        elif math.isinf(value):
            raise ValueError(
                "Infinite Double values must be stored as 'inf' or '-inf'"
            )
106
107
@dataclass
class String:
    """EValue payload holding a string."""

    string_val: str
111
112
@dataclass
class ContainerMetadata:
    """Encoded descriptions of an execution plan's input and output containers.

    NOTE(review): the strings are presumably serialized pytree specs for the
    inputs and outputs -- confirm against the code that produces them.
    """

    encoded_inp_str: str
    encoded_out_str: str
117
118
@dataclass
class IntList:
    """EValue payload holding a list of ints.

    NOTE(review): in program.fbs the items may be indices into
    ExecutionPlan.values rather than literal ints -- confirm.
    """

    items: List[int]
122
123
@dataclass
class DoubleList:
    """EValue payload holding a list of floats."""

    items: List[float]
127
128
@dataclass
class BoolList:
    """EValue payload holding a list of booleans."""

    items: List[bool]
132
133
@dataclass
class TensorList:
    """EValue payload holding a list of tensors, each represented by an ``int``.

    NOTE(review): the ints are presumably indices into ExecutionPlan.values --
    confirm in program.fbs.
    """

    items: List[int]
137
138
# Every payload type an EValue can hold. EValue.val is annotated with this
# alias, written as a string. The member order below is kept unchanged.
KernelTypes = Union[
    Int,
    Double,
    Bool,
    String,
    Tensor,
    IntList,
    BoolList,
    DoubleList,
    TensorList,
    Null,
    OptionalTensorList,
]
152
153
@dataclass
class EValue:
    """A single value slot: wraps one of the KernelTypes payloads."""

    # Union types must be specified as strings so DataclassEncoder can see them.
    val: "KernelTypes"
158
159
@dataclass
class Buffer:
    """Raw byte storage (the element type of Program.constant_buffer)."""

    storage: bytes
163
164
@dataclass
class BackendDelegateInlineData:
    """Delegate payload bytes (the element type of Program.backend_delegate_data)."""

    data: bytes
168
169
@dataclass
class KernelCall:
    """Instruction payload for a kernel call: an operator index plus arguments.

    NOTE(review): op_index presumably indexes ExecutionPlan.operators and args
    presumably index ExecutionPlan.values -- confirm in program.fbs.
    """

    op_index: int
    args: List[int]
174
175
@dataclass
class DelegateCall:
    """Instruction payload for a delegate call: a delegate index plus arguments.

    NOTE(review): delegate_index presumably indexes ExecutionPlan.delegates and
    args presumably index ExecutionPlan.values -- confirm in program.fbs.
    """

    delegate_index: int
    args: List[int]
180
181
@dataclass
class MoveCall:
    """Instruction payload for a move between two value slots.

    NOTE(review): move_from / move_to are presumably indices into
    ExecutionPlan.values -- confirm in program.fbs.
    """

    move_from: int
    move_to: int
186
187
@dataclass
class JumpFalseCall:
    """Instruction payload for a conditional jump.

    NOTE(review): by name, execution presumably jumps to
    destination_instruction when the value at cond_value_index is false --
    confirm in program.fbs.
    """

    cond_value_index: int
    destination_instruction: int
192
193
@dataclass
class FreeCall:
    """Instruction payload that frees a single value.

    NOTE(review): value_index is presumably an index into ExecutionPlan.values
    -- confirm in program.fbs.
    """

    value_index: int
197
198
# Every payload type an Instruction can carry. Instruction.instr_args is
# annotated with this alias, written as a string. Member order is unchanged.
InstructionArguments = Union[
    KernelCall,
    DelegateCall,
    MoveCall,
    JumpFalseCall,
    FreeCall,
]
206
207
@dataclass
class Instruction:
    """One instruction in a Chain; wraps one of the InstructionArguments payloads."""

    # Union-typed field spelled as a string, following the same convention as
    # EValue.val (so DataclassEncoder can see the Union).
    instr_args: "InstructionArguments"
211
212
@dataclass
class Frame:
    """One stack-trace frame (the element type of FrameList.items)."""

    filename: str
    lineno: int
    name: str
    # NOTE(review): presumably the source text at filename:lineno -- confirm.
    context: str
219
220
@dataclass
class FrameList:
    """A stack trace: an ordered list of Frame entries."""

    items: List[Frame]
224
225
class DataLocation(IntEnum):
    """Where the data referenced by a BackendDelegateDataReference is stored.

    NOTE(review): INLINE presumably refers to Program.backend_delegate_data and
    SEGMENT to Program.segments -- confirm in program.fbs.
    """

    INLINE = 0
    SEGMENT = 1
229
230
@dataclass
class BackendDelegateDataReference:
    """Reference to delegate data: which storage it lives in, plus an index.

    NOTE(review): index presumably selects an entry in the list implied by
    ``location`` -- confirm in program.fbs.
    """

    location: DataLocation
    index: int
235
236
@dataclass
class BackendDelegate:
    """A backend delegate entry (the element type of ExecutionPlan.delegates)."""

    # NOTE(review): presumably identifies the backend -- confirm in program.fbs.
    id: str
    # Where the delegate's processed payload is stored.
    processed: BackendDelegateDataReference
    compile_specs: List[CompileSpec]
242
243
@dataclass
class Chain:
    """A list of instructions with its integer input and output lists.

    NOTE(review): inputs/outputs are presumably indices into
    ExecutionPlan.values, and stacktrace, when present, presumably holds one
    FrameList per instruction -- confirm in program.fbs.
    """

    inputs: List[int]
    outputs: List[int]
    instructions: List[Instruction]
    stacktrace: Optional[List[FrameList]]
250
251
@dataclass
class Operator:
    """An operator referenced by an execution plan: its name and overload."""

    name: str
    overload: str
256
257
@dataclass
class ExecutionPlan:
    """One named execution plan of a Program.

    Bundles the plan's values, inputs/outputs, instruction chains, operators,
    delegates, and non-constant memory buffer sizes.
    NOTE(review): inputs/outputs are presumably indices into ``values`` --
    confirm in program.fbs.
    """

    name: str
    container_meta_type: ContainerMetadata
    values: List[EValue]
    inputs: List[int]
    outputs: List[int]
    chains: List[Chain]
    operators: List[Operator]
    delegates: List[BackendDelegate]
    # the list index is memory buffer id, the value is the memory buffer size.
    # memory_buffer_id == 0 is special and is for constant memory buffer.
    # Runtime should use the len(constant_buffer) as the ground truth of
    # constant memory buffer size, and ignore non_const_buffer_sizes[0].
    non_const_buffer_sizes: List[int]
273
274
@dataclass
class DataSegment:
    """A data segment described by an offset and a size.

    NOTE(review): both are presumably byte counts relative to the segment
    data region of the serialized file -- confirm in program.fbs.
    """

    offset: int
    size: int
279
280
@dataclass
class SubsegmentOffsets:
    """Offsets of sub-regions within one data segment.

    NOTE(review): segment_index presumably indexes Program.segments -- confirm
    in program.fbs.
    """

    segment_index: int
    offsets: List[int]
285
286
@dataclass
class Program:
    """Top-level program.

    Holds the execution plans plus the data they share: constant buffers,
    inline delegate data, data segments, and sub-segment offset tables.
    """

    version: int
    execution_plan: List[ExecutionPlan]
    constant_buffer: List[Buffer]
    backend_delegate_data: List[BackendDelegateInlineData]
    segments: List[DataSegment]
    constant_segment: SubsegmentOffsets
    mutable_data_segments: Optional[List[SubsegmentOffsets]] = None
296