# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import tempfile

from dataclasses import dataclass, fields, is_dataclass
from typing import ClassVar, Literal

import pkg_resources
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNGraph
from executorch.exir._serialize._dataclass import _DataclassEncoder

from executorch.exir._serialize._flatbuffer import _flatc_compile

# Byte order of numbers written to program headers. Always little-endian
# regardless of the host system, since all commonly-used modern CPUs are little
# endian.
_HEADER_BYTEORDER: Literal["little"] = "little"

# Constant Tensor alignment for serializing XNNPACK payloads
CONSTANT_TENSOR_ALIGNMENT = 16


def sanity_check_xnngraph_dataclass(table, name: str = ""):
    """
    Make sure no SymInt sneaked in during the preparation of XNNGraph.

    Walks every dataclass field of `table`, recursing into nested dataclasses
    and into the elements of list fields, and asserts that no visited value has
    a class name containing "Sym".

    Args:
        table: The dataclass to check.
        name: Dotted path prefix used in assertion messages. When empty, the
            class name of `table` is used instead.
    Raises:
        AssertionError: If `table` is not a dataclass, or if a value whose
            class name contains "Sym" is found.
    """
    assert is_dataclass(table), f"Expecting a dataclass but got {type(table)}"

    def get_cls_name(obj, field_name=None):
        # "<ClassName>field" when a field name is given, else just "ClassName".
        return (
            f"<{obj.__class__.__name__}>{field_name}"
            if field_name
            else obj.__class__.__name__
        )

    def check_for_sym(obj, name):
        """
        Basic check against the class name of the given obj: it must not
        contain "Sym". This catches SymInt, the main culprit.
        """
        class_name = get_cls_name(obj)
        assert (
            "Sym" not in class_name
        ), f"Non serializable type {class_name} found at type {name}"

    _name = name if name else get_cls_name(table)

    for field in fields(table):
        o = getattr(table, field.name)

        # Skip str and bytes
        if isinstance(o, (str, bytes)):
            continue

        _name_field = f"{_name}.{get_cls_name(o, field.name)}"

        # Recurse
        if is_dataclass(o):
            sanity_check_xnngraph_dataclass(o, _name_field)

        # Only handles List type, add more if needed
        elif isinstance(o, list):
            for i, v in enumerate(o):
                _name_field_i = _name_field + f"[{i}]"
                # Recurse
                if is_dataclass(v):
                    sanity_check_xnngraph_dataclass(v, f"{_name_field_i}")
                else:
                    check_for_sym(v, _name_field_i)
        else:
            check_for_sym(o, _name_field)


@dataclass
class XNNHeader:
    """Fixed-size header that precedes the flatbuffer and constant data payloads.

    Byte layout (integers use _HEADER_BYTEORDER):
        [0, 4)   zero padding
        [4, 8)   header magic (EXPECTED_MAGIC)
        [8, 10)  header length
        [10, 14) flatbuffer offset
        [14, 18) flatbuffer size
        [18, 22) constant data offset
        [22, 30) constant data size
    """

    # Class Constants
    MAGIC_OFFSET: ClassVar[slice] = slice(4, 8)
    HEADER_SIZE_OFFSET: ClassVar[slice] = slice(8, 10)
    FLATBUFFER_OFFSET_OFFSET: ClassVar[slice] = slice(10, 14)
    FLATBUFFER_SIZE_OFFSET: ClassVar[slice] = slice(14, 18)
    CONSTANT_DATA_OFFSET_OFFSET: ClassVar[slice] = slice(18, 22)
    CONSTANT_DATA_SIZE_OFFSET: ClassVar[slice] = slice(22, 30)

    # magic bytes that should be at the beginning of the header
    EXPECTED_MAGIC: ClassVar[bytes] = b"XH00"
    # The length of the header in bytes.
    EXPECTED_LENGTH: ClassVar[int] = (
        # Zeros magic
        # We offset the magic by 4 bytes so that it is in the same location
        # as the flatbuffer payload's magic. This way we can dynamically
        # choose between the XNNPACK Header and Flatbuffer Header
        4
        # Header magic
        + 4
        # Header Length
        + 2
        # Flatbuffer offset
        + 4
        # Flatbuffer size
        + 4
        # Constant Data offset
        + 4
        # Constant Data size
        + 8
    )

    # Instance attributes. @dataclass will turn these into ctor args.

    # offset to the flatbuffer data
    flatbuffer_offset: int

    # flatbuffer size
    flatbuffer_size: int

    # offset to the constant data
    constant_data_offset: int

    # constant data size
    constant_data_size: int

    @staticmethod
    def from_bytes(data: bytes) -> "XNNHeader":
        """
        Converts the given bytes into an XNNHeader object.

        We check that the magic and length are valid, but do not check that the
        offset and size values are valid. Callers should use is_valid() to
        validate the header contents.

        Args:
            data: Data to read from
        Returns:
            XNNHeader object that contains the parsed data
        Raises:
            ValueError: if too much or not enough data is provided, or if the
                parsed length/magic are invalid
        """
        if len(data) > XNNHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Invalid XNNHeader: expected no more than {XNNHeader.EXPECTED_LENGTH} bytes, got {len(data)}"
            )
        # Slicing past the end of `data` silently yields shorter (or empty)
        # bytes, and int.from_bytes(b"") is 0. Without this check a truncated
        # buffer whose length field matches its truncated size would parse
        # "successfully" with zeroed offsets and sizes.
        if len(data) < XNNHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Invalid XNNHeader: expected at least {XNNHeader.EXPECTED_LENGTH} bytes, got {len(data)}"
            )

        magic: bytes = data[XNNHeader.MAGIC_OFFSET]
        length_bytes: bytes = data[XNNHeader.HEADER_SIZE_OFFSET]
        flatbuffer_offset_bytes: bytes = data[XNNHeader.FLATBUFFER_OFFSET_OFFSET]
        flatbuffer_size_bytes: bytes = data[XNNHeader.FLATBUFFER_SIZE_OFFSET]
        constant_data_offset_bytes: bytes = data[XNNHeader.CONSTANT_DATA_OFFSET_OFFSET]
        constant_data_size_bytes: bytes = data[XNNHeader.CONSTANT_DATA_SIZE_OFFSET]

        length = int.from_bytes(length_bytes, byteorder=_HEADER_BYTEORDER)

        if magic != XNNHeader.EXPECTED_MAGIC:
            raise ValueError(
                f"Invalid XNNHeader: invalid magic bytes {magic}, expected {XNNHeader.EXPECTED_MAGIC}"
            )
        if length != len(data):
            raise ValueError(
                f"Invalid XNNHeader: Invalid parsed length: data given was {len(data)} bytes, parsed length was {length} bytes"
            )

        return XNNHeader(
            flatbuffer_offset=int.from_bytes(
                flatbuffer_offset_bytes, byteorder=_HEADER_BYTEORDER
            ),
            flatbuffer_size=int.from_bytes(
                flatbuffer_size_bytes, byteorder=_HEADER_BYTEORDER
            ),
            constant_data_offset=int.from_bytes(
                constant_data_offset_bytes, byteorder=_HEADER_BYTEORDER
            ),
            constant_data_size=int.from_bytes(
                constant_data_size_bytes, byteorder=_HEADER_BYTEORDER
            ),
        )

    def is_valid(self) -> bool:
        """
        Sanity checks the XNNHeader.

        We check that the flatbuffer size is non-zero and that the constant data offset
        is after the flatbuffer payload. We check that the constant data size is non-negative.

        Returns:
            True if the XNNHeader is valid, False otherwise
        """
        # flatbuffer payload must have a non-zero size
        valid_flatbuffer_size = self.flatbuffer_size > 0
        # constant data offset is after flatbuffer payload
        valid_const_data_offset = (
            self.constant_data_offset >= self.flatbuffer_offset + self.flatbuffer_size
        )
        valid_const_data_size = self.constant_data_size >= 0

        return (
            valid_flatbuffer_size and valid_const_data_offset and valid_const_data_size
        )

    def to_bytes(self) -> bytes:
        """
        Converts XNNHeader to bytes for serialization.

        Returns:
            Returns the binary representation of the XNNPACK Header.
        Raises:
            ValueError: If the header fails the is_valid() check.
        """

        # We expect the given offsets and sizes to be valid
        if not self.is_valid():
            raise ValueError("Invalid XNNHeader: header failed is_valid() check")

        data: bytes = (
            # Padding for magic bytes. This is so that header magic is in the same position
            # as the flatbuffer magic, and allows consumer to detect whether the header is
            # being used or not
            b"\x00\x00\x00\x00"
            # XNNPACK Header's magic. This allows consumer to detect whether or not the header
            # is being used or the flatbuffer header is being used
            + self.EXPECTED_MAGIC
            # uint16_t: Size of this header. This makes it easier to add new fields to the header
            # in the future.
            + self.EXPECTED_LENGTH.to_bytes(2, byteorder=_HEADER_BYTEORDER)
            # uint32_t: Offset to the start of the flatbuffer data
            + self.flatbuffer_offset.to_bytes(4, byteorder=_HEADER_BYTEORDER)
            # uint32_t: Size of the flatbuffer data payload
            + self.flatbuffer_size.to_bytes(4, byteorder=_HEADER_BYTEORDER)
            # uint32_t: Offset to the start of the constant data
            + self.constant_data_offset.to_bytes(4, byteorder=_HEADER_BYTEORDER)
            # uint64_t: Size of the constant data
            + self.constant_data_size.to_bytes(8, byteorder=_HEADER_BYTEORDER)
        )

        assert len(data) == XNNHeader.EXPECTED_LENGTH

        return data


def _padding_required(offset: int, alignment: int) -> int:
    """Returns the padding required to align `offset` to `alignment`."""
    remainder: int = offset % alignment
    if remainder != 0:
        return alignment - remainder
    return 0


def _aligned_size(input_size: int, alignment: int) -> int:
    """Returns input_size padded up to the next whole multiple of alignment."""
    aligned_size = input_size + _padding_required(input_size, alignment)
    assert aligned_size % alignment == 0
    return aligned_size


def _pad_to(data: bytes, length: int) -> bytes:
    """Returns the input followed by enough zero bytes to become the requested length.

    Args:
        data: The data to pad.
        length: The length of the returned data.
    Returns:
        The padded data.
    Raises:
        ValueError: If the requested length is less than the input length.
    """
    if length < len(data):
        raise ValueError(f"Data length {len(data)} > padded length {length}")
    if length > len(data):
        data = data + b"\x00" * (length - len(data))
    assert len(data) == length
    return data


def pretty_print_xnngraph(xnnpack_graph_json: str):
    """
    Pretty print the XNNGraph

    Args:
        xnnpack_graph_json: JSON text of the graph; it is parsed with
            json.loads and the result is printed with pprint.
    """
    from pprint import pprint

    d = json.loads(xnnpack_graph_json)
    pprint(d)


def convert_to_flatbuffer(xnnpack_graph: XNNGraph) -> bytes:
    """Returns the flatbuffer binary produced from the given XNNGraph.

    The graph is sanity-checked, dumped to JSON, and written next to a copy of
    the packaged "schema.fbs" inside a temporary directory. _flatc_compile is
    then run on that directory and the resulting "schema.bin" is read back.

    Args:
        xnnpack_graph: XNNGraph object to convert.
    Returns:
        The contents of the compiled "schema.bin" file.
    """
    sanity_check_xnngraph_dataclass(xnnpack_graph)
    xnnpack_graph_json = json.dumps(xnnpack_graph, cls=_DataclassEncoder)
    with tempfile.TemporaryDirectory() as d:
        schema_path = os.path.join(d, "schema.fbs")
        with open(schema_path, "wb") as schema_file:
            schema_file.write(pkg_resources.resource_string(__name__, "schema.fbs"))
        json_path = os.path.join(d, "schema.json")
        with open(json_path, "wb") as json_file:
            json_file.write(xnnpack_graph_json.encode("ascii"))

        # NOTE(review): presumably writes its output as <json stem>.bin into
        # `d`; the read below relies on "schema.bin" existing afterwards.
        _flatc_compile(d, schema_path, json_path)
        output_path = os.path.join(d, "schema.bin")
        with open(output_path, "rb") as output_file:
            return output_file.read()


def serialize_xnnpack_binary(
    xnnpack_graph: XNNGraph, constant_data_bytes: bytearray
) -> bytes:
    """Returns the runtime binary representation of the given XNNGraph.

    The result is laid out as: XNNHeader, flatbuffer payload, constant data.
    The header and the flatbuffer payload are each zero-padded up to a multiple
    of CONSTANT_TENSOR_ALIGNMENT; the constant data is appended as given.

    Args:
        xnnpack_graph: XNNGraph object to serialize.
        constant_data_bytes: Constant data appended after the padded
            flatbuffer payload; its length is recorded in the header.

    Returns:
        The serialized form of the XNNGraph, ready for execution by XNNPACK Backend
    """

    # Convert the XNNGraph to a flatbuffer
    flatbuffer_payload = convert_to_flatbuffer(xnnpack_graph)

    # size of flatbuffer data, padded to be CONSTANT_TENSOR_ALIGNMENT byte aligned
    padded_flatbuffer_length: int = _aligned_size(
        input_size=len(flatbuffer_payload),
        alignment=CONSTANT_TENSOR_ALIGNMENT,
    )
    # size of header to insert, padded to be CONSTANT_TENSOR_ALIGNMENT byte aligned
    padded_header_length: int = _aligned_size(
        input_size=XNNHeader.EXPECTED_LENGTH, alignment=CONSTANT_TENSOR_ALIGNMENT
    )

    # Create the XNNPACK Header
    header: bytes = XNNHeader(
        flatbuffer_offset=padded_header_length,
        flatbuffer_size=len(flatbuffer_payload),
        constant_data_offset=padded_header_length + padded_flatbuffer_length,
        constant_data_size=len(constant_data_bytes),
    ).to_bytes()

    return b"".join(
        [
            _pad_to(header, padded_header_length),
            _pad_to(flatbuffer_payload, padded_flatbuffer_length),
            constant_data_bytes,
        ]
    )