# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import ctypes
import json
import os
import tempfile

from dataclasses import dataclass
from typing import ClassVar, List

import pkg_resources
import torch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkBytes,
    VkGraph,
)
from executorch.exir._serialize._dataclass import _DataclassEncoder

from executorch.exir._serialize._flatbuffer import _flatc_compile


def convert_to_flatbuffer(vk_graph: VkGraph) -> bytes:
    """Serialize a ``VkGraph`` into flatbuffer bytes.

    The graph is dumped to JSON, then compiled against the bundled
    ``schema.fbs`` with ``flatc`` inside a temporary directory; the resulting
    binary file is read back and returned.
    """
    vk_graph_json = json.dumps(vk_graph, cls=_DataclassEncoder)

    with tempfile.TemporaryDirectory() as d:
        schema_path = os.path.join(d, "schema.fbs")
        with open(schema_path, "wb") as schema_file:
            schema_file.write(pkg_resources.resource_string(__name__, "schema.fbs"))
        json_path = os.path.join(d, "schema.json")
        with open(json_path, "wb") as json_file:
            json_file.write(vk_graph_json.encode("ascii"))
        _flatc_compile(d, schema_path, json_path)
        # flatc names its output after the schema file, with a .bin suffix.
        output_path = os.path.join(d, "schema.bin")
        with open(output_path, "rb") as output_file:
            return output_file.read()


@dataclass
class VulkanDelegateHeader:
    """Fixed-layout, little-endian header that prefixes a serialized Vulkan
    delegate blob, recording where the flatbuffer payload and the raw constant
    bytes live relative to the start of the blob."""

    # Defines the byte region that each component of the header corresponds to
    MAGIC_IX: ClassVar[slice] = slice(4, 8)
    HEADER_SIZE_IX: ClassVar[slice] = slice(8, 10)
    FLATBUFFER_OFFSET_IX: ClassVar[slice] = slice(10, 14)
    FLATBUFFER_SIZE_IX: ClassVar[slice] = slice(14, 18)
    BYTES_OFFSET_IX: ClassVar[slice] = slice(18, 22)
    BYTES_SIZE_IX: ClassVar[slice] = slice(22, 30)

    # magic bytes that should be at the beginning of the header
    EXPECTED_MAGIC: ClassVar[bytes] = b"VH00"
    # The length of the header in bytes
    EXPECTED_LENGTH: ClassVar[int] = 30

    # Instance attributes, @dataclass will turn these into constructor args
    flatbuffer_offset: int
    flatbuffer_size: int
    bytes_offset: int
    bytes_size: int

    @staticmethod
    def from_bytes(data: bytes) -> "VulkanDelegateHeader":
        """Parse a header from ``data``.

        Raises:
            ValueError: if ``data`` is too short to contain a header, if the
                magic bytes do not match, or if the encoded header length
                disagrees with ``EXPECTED_LENGTH``.
        """
        # BUGFIX: the original check was `len(data) > EXPECTED_LENGTH`, which
        # rejected buffers carrying trailing payload while silently accepting
        # *truncated* buffers (the fixed slices below would then read short).
        # The correct requirement is that at least EXPECTED_LENGTH bytes are
        # present.
        if len(data) < VulkanDelegateHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Expected header to be at least "
                f"{VulkanDelegateHeader.EXPECTED_LENGTH} bytes, "
                f"but got {len(data)} bytes."
            )

        magic_b: bytes = data[VulkanDelegateHeader.MAGIC_IX]

        if magic_b != VulkanDelegateHeader.EXPECTED_MAGIC:
            raise ValueError(
                f"Expected magic bytes to be {VulkanDelegateHeader.EXPECTED_MAGIC}, "
                f"but got {magic_b}."
            )

        length: int = int.from_bytes(
            data[VulkanDelegateHeader.HEADER_SIZE_IX], byteorder="little"
        )

        if length != VulkanDelegateHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Expected header to be {VulkanDelegateHeader.EXPECTED_LENGTH} bytes, "
                f"but got {length} bytes."
            )

        flatbuffer_offset_b: bytes = data[VulkanDelegateHeader.FLATBUFFER_OFFSET_IX]
        flatbuffer_size_b: bytes = data[VulkanDelegateHeader.FLATBUFFER_SIZE_IX]
        bytes_offset_b: bytes = data[VulkanDelegateHeader.BYTES_OFFSET_IX]
        bytes_size_b: bytes = data[VulkanDelegateHeader.BYTES_SIZE_IX]

        return VulkanDelegateHeader(
            flatbuffer_offset=int.from_bytes(flatbuffer_offset_b, byteorder="little"),
            flatbuffer_size=int.from_bytes(flatbuffer_size_b, byteorder="little"),
            bytes_offset=int.from_bytes(bytes_offset_b, byteorder="little"),
            bytes_size=int.from_bytes(bytes_size_b, byteorder="little"),
        )

    def is_valid(self) -> bool:
        """Return True if the recorded offsets/sizes describe a coherent
        layout: a non-empty flatbuffer region, a bytes region that starts at
        or after the end of the flatbuffer region, and a non-negative bytes
        size."""
        if self.flatbuffer_size <= 0:
            return False

        expected_offset = self.flatbuffer_offset + self.flatbuffer_size
        if self.bytes_offset < expected_offset:
            return False

        if self.bytes_size < 0:
            return False

        return True

    def to_bytes(self) -> bytes:
        """Encode this header into its fixed 30-byte little-endian layout.

        Raises:
            ValueError: if ``is_valid()`` fails.
        """
        if not self.is_valid():
            raise ValueError("VulkanDelegateHeader instance contains invalid values")

        data: bytes = (
            # 4 bytes of padding for magic bytes, this is so that the header magic
            # bytes is in the same position as the flatbuffer header magic bytes
            b"\x00\x00\x00\x00"
            + self.EXPECTED_MAGIC
            + self.EXPECTED_LENGTH.to_bytes(2, byteorder="little")
            + self.flatbuffer_offset.to_bytes(4, byteorder="little")
            + self.flatbuffer_size.to_bytes(4, byteorder="little")
            + self.bytes_offset.to_bytes(4, byteorder="little")
            + self.bytes_size.to_bytes(8, byteorder="little")
        )

        assert len(data) == VulkanDelegateHeader.EXPECTED_LENGTH

        return data


def padding_required(data_len: int, alignment: int = 16) -> int:
    """Return how many padding bytes are needed to bring ``data_len`` up to a
    multiple of ``alignment`` (0 if already aligned)."""
    remainder: int = data_len % alignment
    if remainder != 0:
        return alignment - remainder
    return 0


def aligned_size(data_len: int, alignment: int = 16) -> int:
    """Return ``data_len`` rounded up to the next multiple of ``alignment``."""
    return data_len + padding_required(data_len, alignment)


def pad_to(data: bytes, size: int) -> bytes:
    """Right-pad ``data`` with zero bytes to exactly ``size`` bytes; data that
    is already ``size`` bytes or longer is returned unchanged."""
    if size > len(data):
        data += b"\x00" * (size - len(data))
    return data


def serialize_constant_tensors(
    vk_graph: VkGraph,
    const_tensors: List[torch.Tensor],
    raw_bytes: bytearray,
) -> None:
    """Append each tensor's raw storage bytes (16-byte aligned) to
    ``raw_bytes`` and register a corresponding ``VkBytes`` (offset, size)
    entry on ``vk_graph.constants``."""
    # Make sure that the graph does not have any registered constants prior to
    # calling this function.
    assert len(vk_graph.constants) == 0

    current_offset = len(raw_bytes)
    for tensor in const_tensors:
        # View the tensor's untyped storage as a raw C byte array so it can be
        # copied without a torch -> numpy round trip.
        array_type = ctypes.c_char * tensor.untyped_storage().nbytes()
        array = ctypes.cast(
            tensor.untyped_storage().data_ptr(),
            ctypes.POINTER(array_type),
        ).contents

        tensor_bytes = bytes(array)
        # Pad the tensor bytes to the next 16 byte boundary
        raw_bytes += tensor_bytes
        raw_bytes += b"\x00" * padding_required(len(tensor_bytes))

        # The recorded size is the unpadded size; the offset accounts for
        # alignment padding.
        vk_graph.constants.append(VkBytes(current_offset, len(tensor_bytes)))
        current_offset += aligned_size(len(tensor_bytes))


def serialize_custom_shaders(
    vk_graph: VkGraph,
    custom_shaders: List[str],
    raw_bytes: bytearray,
) -> bytes:
    """Placeholder for custom shader serialization; currently only supports an
    empty shader list.

    Raises:
        NotImplementedError: if any custom shaders are provided.
    """
    # Make sure that the graph does not have any registered shaders prior to
    # calling this function.
    assert len(vk_graph.shaders) == 0

    if len(custom_shaders) == 0:
        return b""

    else:
        raise NotImplementedError("Serializing Custom shaders are not yet supported")


def serialize_vulkan_graph(
    vk_graph: VkGraph, const_tensors: List[torch.Tensor], custom_shaders: List[str]
) -> bytes:
    """Produce the final delegate blob: a VulkanDelegateHeader, followed by
    the flatbuffer-encoded graph, followed by the raw constant bytes, with
    each section padded to a 16-byte boundary."""
    raw_bytes = bytearray()
    # Constants must be registered on the graph before it is flatbuffer-encoded.
    serialize_constant_tensors(vk_graph, const_tensors, raw_bytes)
    serialize_custom_shaders(vk_graph, custom_shaders, raw_bytes)
    raw_bytes = bytes(raw_bytes)

    flatbuffer_payload = convert_to_flatbuffer(vk_graph)

    header_len = aligned_size(VulkanDelegateHeader.EXPECTED_LENGTH)
    flatbuffer_payload_len = aligned_size(len(flatbuffer_payload))
    raw_bytes_len = aligned_size(len(raw_bytes))

    # The header records the *unpadded* sizes but the *aligned* offsets.
    header: bytes = VulkanDelegateHeader(
        flatbuffer_offset=header_len,
        flatbuffer_size=len(flatbuffer_payload),
        bytes_offset=header_len + flatbuffer_payload_len,
        bytes_size=len(raw_bytes),
    ).to_bytes()

    return b"".join(
        [
            pad_to(header, header_len),
            pad_to(flatbuffer_payload, flatbuffer_payload_len),
            pad_to(raw_bytes, raw_bytes_len),
        ]
    )