xref: /aosp_15_r20/external/executorch/backends/xnnpack/serialization/xnnpack_graph_serialize.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import json
8import os
9import tempfile
10
11from dataclasses import dataclass, fields, is_dataclass
12from typing import ClassVar, Literal
13
14import pkg_resources
15from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNGraph
16from executorch.exir._serialize._dataclass import _DataclassEncoder
17
18from executorch.exir._serialize._flatbuffer import _flatc_compile
19
# Byte order of numbers written to program headers. Always little-endian
# regardless of the host system, since all commonly-used modern CPUs are little
# endian.
_HEADER_BYTEORDER: Literal["little"] = "little"

# Alignment, in bytes, used when serializing XNNPACK payloads: the header and
# the flatbuffer section are each padded up to a multiple of this value so the
# constant tensor data that follows them starts on an aligned offset.
CONSTANT_TENSOR_ALIGNMENT = 16
27
28
def sanity_check_xnngraph_dataclass(table, name: str = ""):
    """
    Walk a dataclass and assert that no symbolic value (e.g. SymInt) is stored in it.

    Every field of `table` is visited. `str` and `bytes` values are skipped,
    nested dataclasses are checked recursively, and the elements of `list`
    values are checked one by one. Any other value fails the check when the
    name of its class contains the substring "Sym".

    Args:
        table: Dataclass to inspect.
        name: Path prefix used in failure messages. When empty, the class
            name of `table` is used instead.
    Raises:
        AssertionError: if `table` is not a dataclass, or if a value whose
            class name contains "Sym" is found.
    """
    assert is_dataclass(table), f"Expecting a dataclass but got {type(table)}"

    def _type_label(obj, field_name=None):
        # "<ClassName>field" when a field name is given, otherwise "ClassName".
        cls_name = obj.__class__.__name__
        if field_name:
            return f"<{cls_name}>{field_name}"
        return cls_name

    def _reject_sym(obj, path):
        # Substring match on the class name; SymInt is the main culprit.
        class_name = _type_label(obj)
        assert (
            "Sym" not in class_name
        ), f"Non serializable type {class_name} found at type {path}"

    def _visit(value, path):
        # Dataclasses are descended into, everything else is a leaf.
        if is_dataclass(value):
            sanity_check_xnngraph_dataclass(value, path)
        else:
            _reject_sym(value, path)

    prefix = name if len(name) != 0 else _type_label(table)

    for member in fields(table):
        value = getattr(table, member.name)

        # Plain text and raw bytes can never hide a symbolic value.
        if isinstance(value, (str, bytes)):
            continue

        path = f"{prefix}.{_type_label(value, member.name)}"

        # Only list containers are expanded; add more container types if needed.
        if isinstance(value, list) and not is_dataclass(value):
            for index, element in enumerate(value):
                _visit(element, f"{path}[{index}]")
        else:
            _visit(value, path)
78
79
@dataclass
class XNNHeader:
    """
    Fixed-size header describing where the flatbuffer and the constant data
    live inside a serialized XNNPACK payload.

    Byte layout (integers are encoded with `_HEADER_BYTEORDER`):
        [0, 4)   zero padding
        [4, 8)   header magic (`EXPECTED_MAGIC`)
        [8, 10)  header length
        [10, 14) flatbuffer offset
        [14, 18) flatbuffer size
        [18, 22) constant data offset
        [22, 30) constant data size
    """

    # Class Constants
    MAGIC_OFFSET: ClassVar[slice] = slice(4, 8)
    HEADER_SIZE_OFFSET: ClassVar[slice] = slice(8, 10)
    FLATBUFFER_OFFSET_OFFSET: ClassVar[slice] = slice(10, 14)
    FLATBUFFER_SIZE_OFFSET: ClassVar[slice] = slice(14, 18)
    CONSTANT_DATA_OFFSET_OFFSET: ClassVar[slice] = slice(18, 22)
    CONSTANT_DATA_SIZE_OFFSET: ClassVar[slice] = slice(22, 30)

    # magic bytes that should be at the beginning of the header
    EXPECTED_MAGIC: ClassVar[bytes] = b"XH00"
    # The length of the header in bytes.
    EXPECTED_LENGTH: ClassVar[int] = (
        # Zeros magic
        # We offset the magic by 4 bytes so that it is in the same location
        # as the flatbuffer payload's magic. This way we can dynamically
        # choose between the XNNPACK Header and Flatbuffer Header
        4
        # Header magic
        + 4
        # Header Length
        + 2
        # Flatbuffer offset
        + 4
        # Flatbuffer size
        + 4
        # Constant Data offset
        + 4
        # Constant Data size
        + 8
    )

    # Instance attributes. @dataclass will turn these into ctor args.

    # offset to the flatbuffer data
    flatbuffer_offset: int

    # flatbuffer size
    flatbuffer_size: int

    # offset to the constant data
    constant_data_offset: int

    # constant data size
    constant_data_size: int

    @staticmethod
    def from_bytes(data: bytes) -> "XNNHeader":
        """
        Converts the given bytes into an XNNHeader object.

        Only the header metadata (buffer length, magic and encoded header
        length) is validated here. The parsed offsets and sizes are not
        checked; callers should use is_valid() to validate the header contents.

        Args:
            data: Data to read from. Must be exactly EXPECTED_LENGTH bytes.
        Returns:
            XNNHeader object that contains the parsed data
        Raises:
            ValueError: if too much or not enough data is provided, or if the
                parsed length/magic are invalid
        """
        if len(data) > XNNHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Invalid XNNHeader: expected no more than {XNNHeader.EXPECTED_LENGTH} bytes, got {len(data)}"
            )
        # Slicing a short buffer silently yields truncated (or empty) fields and
        # int.from_bytes(b"") is 0, so a truncated header whose length field
        # matches its own size would otherwise parse "successfully" with
        # zeroed offsets and sizes. Reject it explicitly.
        if len(data) < XNNHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Invalid XNNHeader: expected at least {XNNHeader.EXPECTED_LENGTH} bytes, got {len(data)}"
            )

        magic: bytes = data[XNNHeader.MAGIC_OFFSET]
        length_bytes: bytes = data[XNNHeader.HEADER_SIZE_OFFSET]
        flatbuffer_offset_bytes: bytes = data[XNNHeader.FLATBUFFER_OFFSET_OFFSET]
        flatbuffer_size_bytes: bytes = data[XNNHeader.FLATBUFFER_SIZE_OFFSET]
        constant_data_offset_bytes: bytes = data[XNNHeader.CONSTANT_DATA_OFFSET_OFFSET]
        constant_data_size_bytes: bytes = data[XNNHeader.CONSTANT_DATA_SIZE_OFFSET]

        length = int.from_bytes(length_bytes, byteorder=_HEADER_BYTEORDER)

        if magic != XNNHeader.EXPECTED_MAGIC:
            raise ValueError(
                f"Invalid XNNHeader: invalid magic bytes {magic}, expected {XNNHeader.EXPECTED_MAGIC}"
            )
        if length != len(data):
            raise ValueError(
                f"Invalid XNNHeader: Invalid parsed length: data given was {len(data)} bytes, parsed length was {length} bytes"
            )

        return XNNHeader(
            flatbuffer_offset=int.from_bytes(
                flatbuffer_offset_bytes, byteorder=_HEADER_BYTEORDER
            ),
            flatbuffer_size=int.from_bytes(
                flatbuffer_size_bytes, byteorder=_HEADER_BYTEORDER
            ),
            constant_data_offset=int.from_bytes(
                constant_data_offset_bytes, byteorder=_HEADER_BYTEORDER
            ),
            constant_data_size=int.from_bytes(
                constant_data_size_bytes, byteorder=_HEADER_BYTEORDER
            ),
        )

    def is_valid(self) -> bool:
        """
        Sanity checks the XNNHeader.

        We check that the flatbuffer size is non-zero, that the constant data
        offset is at or after the end of the flatbuffer payload, and that the
        constant data size is non-negative.

        Returns:
            True if the XNNHeader is valid, False otherwise
        """
        # flatbuffer payload must have a non-zero size
        valid_flatbuffer_size = self.flatbuffer_size > 0
        # constant data offset is after flatbuffer payload
        valid_const_data_offset = (
            self.constant_data_offset >= self.flatbuffer_offset + self.flatbuffer_size
        )
        valid_const_data_size = self.constant_data_size >= 0

        return (
            valid_flatbuffer_size and valid_const_data_offset and valid_const_data_size
        )

    def to_bytes(self) -> bytes:
        """
        Converts XNNHeader to bytes for serialization.

        Returns:
            Returns the binary representation of the XNNPACK Header.
        Raises:
            ValueError: if the header fails the is_valid() check
        """

        # We expect the given offsets and sizes to be valid
        if not self.is_valid():
            raise ValueError("Invalid XNNHeader: header failed is_valid() check")

        data: bytes = (
            # Padding for magic bytes. This is so that header magic is in the same position
            # as the flatbuffer magic, and allows consumer to detect whether the header is
            # being used or not
            b"\x00\x00\x00\x00"
            # XNNPACK Header's magic. This allows consumer to detect whether or not the header
            # is being used or the flatbuffer header is being used
            + self.EXPECTED_MAGIC
            # uint16_t: Size of this header. This makes it easier to add new fields to the header
            # in the future.
            + self.EXPECTED_LENGTH.to_bytes(2, byteorder=_HEADER_BYTEORDER)
            # uint32_t: Offset to the start of the flatbuffer data
            + self.flatbuffer_offset.to_bytes(4, byteorder=_HEADER_BYTEORDER)
            # uint32_t: Size of the flatbuffer data payload
            + self.flatbuffer_size.to_bytes(4, byteorder=_HEADER_BYTEORDER)
            # uint32_t: Offset to the start of the constant data
            + self.constant_data_offset.to_bytes(4, byteorder=_HEADER_BYTEORDER)
            # uint64_t: Size of the constant data
            + self.constant_data_size.to_bytes(8, byteorder=_HEADER_BYTEORDER)
        )

        assert len(data) == XNNHeader.EXPECTED_LENGTH

        return data
240
241
def _padding_required(offset: int, alignment: int) -> int:
    """Returns how many bytes must follow `offset` to reach a multiple of `alignment`.

    The result is 0 when `offset` is already a multiple of `alignment`.
    """
    overshoot: int = offset % alignment
    return alignment - overshoot if overshoot else 0
248
249
def _aligned_size(input_size: int, alignment: int) -> int:
    """Returns `input_size` rounded up to the nearest multiple of `alignment`.

    A size that is already a multiple of `alignment` is returned unchanged.
    """
    padding = _padding_required(input_size, alignment)
    rounded_up = input_size + padding
    assert rounded_up % alignment == 0
    return rounded_up
255
256
def _pad_to(data: bytes, length: int) -> bytes:
    """Zero-pads `data` on the right so that it is exactly `length` bytes long.

    Args:
        data: The data to pad.
        length: The total length the returned data must have.
    Returns:
        `data` followed by as many zero bytes as needed; `data` itself when it
        already has the requested length.
    Raises:
        ValueError: If `data` is longer than the requested length.
    """
    current = len(data)
    shortfall = length - current
    if shortfall < 0:
        raise ValueError(f"Data length {current} > padded length {length}")
    if shortfall > 0:
        data = data + b"\x00" * shortfall
    assert len(data) == length
    return data
274
275
def pretty_print_xnngraph(xnnpack_graph_json: str):
    """
    Parse the JSON text of an XNNGraph and pretty-print the resulting object
    to standard output.

    Args:
        xnnpack_graph_json: JSON document to parse and print.
    """
    import pprint as _pprint

    parsed = json.loads(xnnpack_graph_json)
    _pprint.pprint(parsed)
284
285
def convert_to_flatbuffer(xnnpack_graph: XNNGraph) -> bytes:
    """Compiles the given XNNGraph into its binary form.

    The graph is first sanity-checked, then dumped to JSON with
    `_DataclassEncoder`. The packaged `schema.fbs` resource and the JSON are
    written into a temporary directory, `_flatc_compile` is run on them, and
    the `schema.bin` file found in that directory afterwards is returned.

    Args:
        xnnpack_graph: XNNGraph object to convert.
    Returns:
        The raw contents of the generated `schema.bin` file.
    """
    sanity_check_xnngraph_dataclass(xnnpack_graph)
    graph_json = json.dumps(xnnpack_graph, cls=_DataclassEncoder)

    with tempfile.TemporaryDirectory() as workdir:
        schema_path = os.path.join(workdir, "schema.fbs")
        json_path = os.path.join(workdir, "schema.json")
        output_path = os.path.join(workdir, "schema.bin")

        # Materialize the schema that ships alongside this module.
        with open(schema_path, "wb") as schema_out:
            schema_out.write(pkg_resources.resource_string(__name__, "schema.fbs"))

        # The JSON input sits next to the schema in the same directory.
        with open(json_path, "wb") as json_out:
            json_out.write(graph_json.encode("ascii"))

        _flatc_compile(workdir, schema_path, json_path)

        # Read the result before the temporary directory is removed.
        with open(output_path, "rb") as compiled:
            return compiled.read()
301
302
def serialize_xnnpack_binary(
    xnnpack_graph: XNNGraph, constant_data_bytes: bytearray
) -> bytes:
    """Returns the runtime binary representation of the given XNNGraph.

    The result is the concatenation of three sections:
        1. the XNNHeader, zero-padded to a multiple of CONSTANT_TENSOR_ALIGNMENT
        2. the flatbuffer payload, zero-padded to a multiple of
           CONSTANT_TENSOR_ALIGNMENT
        3. the constant data, appended as-is

    Args:
        xnnpack_graph: XNNGraph object to serialize.
        constant_data_bytes: Constant data appended after the flatbuffer
            section; its length is recorded in the header.

    Returns:
        The serialized form of the XNNGraph, ready for execution by XNNPACK Backend
    """
    # Flatbuffer section: the converted graph.
    flatbuffer_payload = convert_to_flatbuffer(xnnpack_graph)

    # Both leading sections are rounded up to the alignment boundary so that
    # the constant data begins at an aligned offset.
    header_span: int = _aligned_size(
        input_size=XNNHeader.EXPECTED_LENGTH,
        alignment=CONSTANT_TENSOR_ALIGNMENT,
    )
    flatbuffer_span: int = _aligned_size(
        input_size=len(flatbuffer_payload),
        alignment=CONSTANT_TENSOR_ALIGNMENT,
    )

    # The header records the unpadded flatbuffer size, but the constant data
    # offset accounts for the padding of both preceding sections.
    header_fields = XNNHeader(
        flatbuffer_offset=header_span,
        flatbuffer_size=len(flatbuffer_payload),
        constant_data_offset=header_span + flatbuffer_span,
        constant_data_size=len(constant_data_bytes),
    )
    header: bytes = header_fields.to_bytes()

    sections = (
        _pad_to(header, header_span),
        _pad_to(flatbuffer_payload, flatbuffer_span),
        constant_data_bytes,
    )
    return b"".join(sections)
343