xref: /aosp_15_r20/external/executorch/backends/vulkan/serialization/vulkan_graph_serialize.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import ctypes
8import json
9import os
10import tempfile
11
12from dataclasses import dataclass
13from typing import ClassVar, List
14
15import pkg_resources
16import torch
17
18from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
19    VkBytes,
20    VkGraph,
21)
22from executorch.exir._serialize._dataclass import _DataclassEncoder
23
24from executorch.exir._serialize._flatbuffer import _flatc_compile
25
26
def convert_to_flatbuffer(vk_graph: VkGraph) -> bytes:
    """Serialize a VkGraph into flatbuffer bytes via an on-disk flatc round trip.

    The graph is dumped to JSON, compiled against the packaged ``schema.fbs``
    with flatc inside a scratch directory, and the resulting binary blob is
    returned.
    """
    graph_json = json.dumps(vk_graph, cls=_DataclassEncoder)

    with tempfile.TemporaryDirectory() as tmp_dir:
        schema_fbs = os.path.join(tmp_dir, "schema.fbs")
        schema_json = os.path.join(tmp_dir, "schema.json")
        schema_bin = os.path.join(tmp_dir, "schema.bin")

        with open(schema_fbs, "wb") as f:
            f.write(pkg_resources.resource_string(__name__, "schema.fbs"))
        with open(schema_json, "wb") as f:
            f.write(graph_json.encode("ascii"))

        # flatc writes its output next to the inputs as schema.bin
        _flatc_compile(tmp_dir, schema_fbs, schema_json)

        with open(schema_bin, "rb") as f:
            return f.read()
41
42
@dataclass
class VulkanDelegateHeader:
    """Fixed-size header prepended to a serialized Vulkan delegate payload.

    Layout (little-endian, 30 bytes total):
      [0:4)   zero padding (so the magic lines up with the flatbuffer magic)
      [4:8)   magic bytes b"VH00"
      [8:10)  header length field (must equal 30)
      [10:14) flatbuffer offset
      [14:18) flatbuffer size
      [18:22) constant-bytes offset
      [22:30) constant-bytes size
    """

    # Defines the byte region that each component of the header corresponds to
    MAGIC_IX: ClassVar[slice] = slice(4, 8)
    HEADER_SIZE_IX: ClassVar[slice] = slice(8, 10)
    FLATBUFFER_OFFSET_IX: ClassVar[slice] = slice(10, 14)
    FLATBUFFER_SIZE_IX: ClassVar[slice] = slice(14, 18)
    BYTES_OFFSET_IX: ClassVar[slice] = slice(18, 22)
    BYTES_SIZE_IX: ClassVar[slice] = slice(22, 30)

    # magic bytes that should be at the beginning of the header
    EXPECTED_MAGIC: ClassVar[bytes] = b"VH00"
    # The length of the header in bytes
    EXPECTED_LENGTH: ClassVar[int] = 30

    # Instance attributes, @dataclass will turn these into constructor args
    flatbuffer_offset: int
    flatbuffer_size: int
    bytes_offset: int
    bytes_size: int

    @staticmethod
    def from_bytes(data: bytes) -> "VulkanDelegateHeader":
        """Parse a header from exactly EXPECTED_LENGTH bytes.

        Raises:
            ValueError: if ``data`` is not exactly EXPECTED_LENGTH bytes long,
                the magic bytes are wrong, or the embedded length field does
                not equal EXPECTED_LENGTH.
        """
        # BUG FIX: the original guard was `len(data) > EXPECTED_LENGTH`, which
        # rejected oversized input but silently accepted *truncated* input;
        # short buffers then decode through out-of-range slices into garbage
        # (e.g. an empty BYTES_SIZE_IX slice yields bytes_size == 0). Require
        # an exact length instead — oversized input is still rejected.
        if len(data) != VulkanDelegateHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Expected header to be {VulkanDelegateHeader.EXPECTED_LENGTH} bytes, "
                f"but got {len(data)} bytes."
            )

        magic_b: bytes = data[VulkanDelegateHeader.MAGIC_IX]

        if magic_b != VulkanDelegateHeader.EXPECTED_MAGIC:
            raise ValueError(
                f"Expected magic bytes to be {VulkanDelegateHeader.EXPECTED_MAGIC}, "
                f"but got {magic_b}."
            )

        length: int = int.from_bytes(
            data[VulkanDelegateHeader.HEADER_SIZE_IX], byteorder="little"
        )

        if length != VulkanDelegateHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Expected header to be {VulkanDelegateHeader.EXPECTED_LENGTH} bytes, "
                f"but got {length} bytes."
            )

        flatbuffer_offset_b: bytes = data[VulkanDelegateHeader.FLATBUFFER_OFFSET_IX]
        flatbuffer_size_b: bytes = data[VulkanDelegateHeader.FLATBUFFER_SIZE_IX]
        bytes_offset_b: bytes = data[VulkanDelegateHeader.BYTES_OFFSET_IX]
        bytes_size_b: bytes = data[VulkanDelegateHeader.BYTES_SIZE_IX]

        return VulkanDelegateHeader(
            flatbuffer_offset=int.from_bytes(flatbuffer_offset_b, byteorder="little"),
            flatbuffer_size=int.from_bytes(flatbuffer_size_b, byteorder="little"),
            bytes_offset=int.from_bytes(bytes_offset_b, byteorder="little"),
            bytes_size=int.from_bytes(bytes_size_b, byteorder="little"),
        )

    def is_valid(self) -> bool:
        """Return True if the header's offsets/sizes describe a coherent layout."""
        if self.flatbuffer_size <= 0:
            return False

        # The constant-bytes section must start at or after the end of the
        # flatbuffer section; overlapping sections are invalid.
        expected_offset = self.flatbuffer_offset + self.flatbuffer_size
        if self.bytes_offset < expected_offset:
            return False

        if self.bytes_size < 0:
            return False

        return True

    def to_bytes(self) -> bytes:
        """Encode this header into its EXPECTED_LENGTH-byte on-disk form.

        Raises:
            ValueError: if the instance fails `is_valid()`.
        """
        if not self.is_valid():
            raise ValueError("VulkanDelegateHeader instance contains invalid values")

        data: bytes = (
            # 4 bytes of padding for magic bytes, this is so that the header magic
            # bytes is in the same position as the flatbuffer header magic bytes
            b"\x00\x00\x00\x00"
            + self.EXPECTED_MAGIC
            + self.EXPECTED_LENGTH.to_bytes(2, byteorder="little")
            + self.flatbuffer_offset.to_bytes(4, byteorder="little")
            + self.flatbuffer_size.to_bytes(4, byteorder="little")
            + self.bytes_offset.to_bytes(4, byteorder="little")
            + self.bytes_size.to_bytes(8, byteorder="little")
        )

        assert len(data) == VulkanDelegateHeader.EXPECTED_LENGTH

        return data
134
135
def padding_required(data_len: int, alignment: int = 16) -> int:
    """Return how many bytes must be appended to reach the next multiple of
    ``alignment`` (0 if ``data_len`` is already aligned)."""
    # (-x) mod a is exactly the distance from x up to the next multiple of a.
    return -data_len % alignment
141
142
def aligned_size(data_len: int, alignment: int = 16) -> int:
    """Return ``data_len`` rounded up to the next multiple of ``alignment``."""
    # Ceiling division via floor division of the negation, then scale back up.
    return -(-data_len // alignment) * alignment
145
146
def pad_to(data: bytes, size: int) -> bytes:
    """Return ``data`` right-padded with zero bytes up to ``size``.

    If ``data`` is already ``size`` bytes or longer it is returned unchanged.
    """
    return data.ljust(size, b"\x00")
151
152
def serialize_constant_tensors(
    vk_graph: VkGraph,
    const_tensors: List[torch.Tensor],
    raw_bytes: bytearray,
) -> None:
    """Append each constant tensor's raw data to ``raw_bytes`` (mutated in
    place) and record a matching (offset, length) ``VkBytes`` entry in
    ``vk_graph.constants``.

    Each tensor's bytes are zero-padded to a 16 byte boundary so every
    constant starts at an aligned offset within the blob. The recorded size
    is the unpadded byte count; offsets advance by the aligned size.
    """
    # Make sure that the graph does not have any registered constants prior to calling
    # this function.
    assert len(vk_graph.constants) == 0

    current_offset = len(raw_bytes)
    for tensor in const_tensors:
        # Read the tensor's backing memory directly via ctypes: view the whole
        # untyped storage (nbytes) starting at its data_ptr as a char array.
        # NOTE(review): this copies the *entire* storage, so it presumably
        # assumes each tensor is contiguous and owns its storage from offset 0
        # (no views/slices sharing a larger buffer) — confirm callers
        # guarantee this.
        array_type = ctypes.c_char * tensor.untyped_storage().nbytes()
        array = ctypes.cast(
            tensor.untyped_storage().data_ptr(),
            ctypes.POINTER(array_type),
        ).contents

        # Materialize a copy of the tensor memory before the storage can move.
        tensor_bytes = bytes(array)
        # Pad the tensor bytes to the next 16 byte boundary
        raw_bytes += tensor_bytes
        raw_bytes += b"\x00" * padding_required(len(tensor_bytes))

        vk_graph.constants.append(VkBytes(current_offset, len(tensor_bytes)))
        current_offset += aligned_size(len(tensor_bytes))
177
178
def serialize_custom_shaders(
    vk_graph: VkGraph,
    custom_shaders: List[str],
    raw_bytes: bytearray,
) -> bytes:
    """Serialize custom shaders into the graph's byte blob.

    Only the empty case is currently supported; supplying any custom shader
    raises NotImplementedError.
    """
    # The graph must not have any registered shaders prior to calling this
    # function.
    assert len(vk_graph.shaders) == 0

    if custom_shaders:
        raise NotImplementedError("Serializing Custom shaders are not yet supported")

    return b""
193
194
def serialize_vulkan_graph(
    vk_graph: VkGraph, const_tensors: List[torch.Tensor], custom_shaders: List[str]
) -> bytes:
    """Build the final delegate payload: header | flatbuffer | constant bytes.

    Constant tensors (and, eventually, custom shaders) are packed into a raw
    byte blob and registered on ``vk_graph``, the graph is flatbuffer-encoded,
    and a VulkanDelegateHeader describing the layout is prepended. Every
    section is zero-padded to a 16 byte boundary.
    """
    blob = bytearray()
    serialize_constant_tensors(vk_graph, const_tensors, blob)
    serialize_custom_shaders(vk_graph, custom_shaders, blob)
    constant_bytes = bytes(blob)

    # Must run after the constants/shaders are registered on the graph so the
    # flatbuffer includes their (offset, size) entries.
    flatbuffer_bytes = convert_to_flatbuffer(vk_graph)

    # Section sizes rounded up to the 16 byte alignment boundary.
    header_size = aligned_size(VulkanDelegateHeader.EXPECTED_LENGTH)
    flatbuffer_section = aligned_size(len(flatbuffer_bytes))
    constants_section = aligned_size(len(constant_bytes))

    # Offsets are section-aligned; sizes stored in the header are unpadded.
    header_bytes: bytes = VulkanDelegateHeader(
        flatbuffer_offset=header_size,
        flatbuffer_size=len(flatbuffer_bytes),
        bytes_offset=header_size + flatbuffer_section,
        bytes_size=len(constant_bytes),
    ).to_bytes()

    sections = [
        pad_to(header_bytes, header_size),
        pad_to(flatbuffer_bytes, flatbuffer_section),
        pad_to(constant_bytes, constants_section),
    ]
    return b"".join(sections)
223