# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import json
import re

from dataclasses import dataclass
from typing import ClassVar, List, Literal, Optional, Tuple

from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
from executorch.exir._serialize._flatbuffer import (
    _FlatbufferResult,
    _program_flatbuffer_to_json,
    _program_json_to_flatbuffer,
)

from executorch.exir.schema import (
    BackendDelegateDataReference,
    BackendDelegateInlineData,
    Buffer,
    DataLocation,
    DataSegment,
    Program,
    SubsegmentOffsets,
)
from executorch.exir.tensor import ALIGNMENT


# Byte order of numbers written to program headers. Always little-endian
# regardless of the host system, since all commonly-used modern CPUs are little
# endian.
_HEADER_BYTEORDER: Literal["little"] = "little"


def _program_to_json(program: Program) -> str:
    """Returns the JSON representation of the given Program."""
    return json.dumps(program, cls=_DataclassEncoder)


def _json_to_program(program_json: bytes) -> Program:
    """Returns a Program deserialized from the given JSON data."""
    # construct program class recursively from dict
    return _json_to_dataclass(json.loads(program_json), cls=Program)


def _padding_required(offset: int, alignment: int) -> int:
    """Returns the padding required to align `offset` to `alignment`."""
    remainder: int = offset % alignment
    if remainder != 0:
        return alignment - remainder
    return 0


def _aligned_size(input_size: int, alignment: int) -> int:
    """Returns input_size padded up to the next whole multiple of alignment."""
    return input_size + _padding_required(input_size, alignment)


def _insert_flatbuffer_header(
    flatbuffer_data: bytes, magic_regex: str, header_data: bytes
) -> bytes:
    """Inserts a header just after the magic string of the provided flatbuffer data.

    Args:
        flatbuffer_data: The input data to modify.
        magic_regex: A regex pattern that must match the magic file_identifier
            characters of flatbuffer_data.
        header_data: The data to insert into flatbuffer_data. To ensure that
            flatbuffer internal alignment is preserved, the caller must
            guarantee that its length is a power of 2 >= the largest
            force_align value in the schema.

    Returns:
        The modified flatbuffer_data with header_data inserted.

    Raises:
        ValueError: If flatbuffer_data is too short to be valid.
        ValueError: If the magic bytes of flatbuffer_data do not match
            magic_regex.
    """
    # The binary flatbuffer file should begin with:
    # - Offset in bytes to root table (4 bytes little endian)
    # - file_identifier string from the schema (4 bytes, string order)
    if len(flatbuffer_data) < 8:
        raise ValueError(f"Flatbuffer data length {len(flatbuffer_data)} < 8")

    # Ensure that the magic matches.
    actual_magic: str = flatbuffer_data[4:8].decode(errors="replace")
    if not re.match(magic_regex, actual_magic):
        raise ValueError(
            f"Flatbuffer data magic bytes {repr(actual_magic)} "
            + f"does not match pattern /{magic_regex}/"
        )

    # Avoid a potentially big allocation/copy if there's nothing to do.
    if len(header_data) == 0:
        return flatbuffer_data

    # We will need to adjust the root object offset after inserting the header.
    root_offset = int.from_bytes(flatbuffer_data[0:4], byteorder=_HEADER_BYTEORDER)

    return (
        # New root offset.
        (root_offset + len(header_data)).to_bytes(4, byteorder=_HEADER_BYTEORDER)
        # Existing magic bytes.
        + flatbuffer_data[4:8]
        # Provided header + padding.
        + header_data
        # Remainder of the file. Note that this can be O(10MB to 100MB), so it
        # can trigger a large allocation + copy.
        + flatbuffer_data[8:]
    )


@dataclass
class _ExtendedHeader:
    """Header inserted after the flatbuffer magic when segments are present.

    Records the size of the program data and the offset of the first segment.
    """

    # Class constants

    # The magic bytes that should be at the beginning of the header.
    EXPECTED_MAGIC: ClassVar[bytes] = b"eh00"
    # The length of the header in bytes.
    EXPECTED_LENGTH: ClassVar[int] = (
        # Header magic
        4
        # Header length
        + 4
        # Flatbuffer data size
        + 8
        # Segment base offset
        + 8
    )

    # Instance attributes. @dataclass will turn these into ctor args.

    # The size of the serialized program data in bytes.
    program_size: int
    # Offset to the start of the first segment, or zero if there
    # are no segments.
    segment_base_offset: int

    # The magic bytes read from or to be written to the binary header.
    magic: bytes = EXPECTED_MAGIC
    # The header length, in bytes, read from or to be written to the binary
    # header.
    length: int = EXPECTED_LENGTH

    @staticmethod
    def from_bytes(data: bytes) -> "_ExtendedHeader":
        """Tries to read an extended header from the provided data.

        Does not validate that the header is well-formed. Callers should
        use is_valid().

        Args:
            data: The data to read from.
        Returns:
            The contents of the extended header.
        Raises:
            ValueError: If not enough data is provided.
        """
        if len(data) < _ExtendedHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Not enough data for extended header: {len(data)} "
                + f"< {_ExtendedHeader.EXPECTED_LENGTH}"
            )

        return _ExtendedHeader(
            magic=data[0:4],
            length=int.from_bytes(data[4:8], byteorder=_HEADER_BYTEORDER),
            program_size=int.from_bytes(data[8:16], byteorder=_HEADER_BYTEORDER),
            segment_base_offset=int.from_bytes(
                data[16:24], byteorder=_HEADER_BYTEORDER
            ),
        )

    def is_valid(self) -> bool:
        """Returns true if the extended header appears to be well-formed."""
        return (
            self.magic == _ExtendedHeader.EXPECTED_MAGIC
            and self.length >= _ExtendedHeader.EXPECTED_LENGTH
        )

    def to_bytes(self) -> bytes:
        """Returns the binary representation of the extended header.

        Note that this will ignore self.magic and self.length and will always
        write the proper magic/length.
        """
        data: bytes = (
            # Extended header magic. This lets consumers detect whether the
            # header was inserted or not. Always use the proper magic value
            # (i.e., ignore self.magic) since there's no reason to create an
            # invalid header.
            self.EXPECTED_MAGIC
            # uint32_t: Size of this header. This makes it easier to add new
            # fields to this header in the future. Always use the proper size
            # (i.e., ignore self.length) since there's no reason to create an
            # invalid header.
            + self.EXPECTED_LENGTH.to_bytes(4, byteorder=_HEADER_BYTEORDER)
            # uint64_t: Size of the flatbuffer data, including this header.
            + self.program_size.to_bytes(8, byteorder=_HEADER_BYTEORDER)
            # uint64_t: Offset to the start of the first segment, or zero if
            # there are no segments.
            + self.segment_base_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER)
        )
        return data


def _pad_to(data: bytes, length: int) -> bytes:
    """Returns the input followed by enough zero bytes to become the requested length.

    Args:
        data: The data to pad.
        length: The length of the returned data.
    Returns:
        The padded data.
    Raises:
        ValueError: If the requested length is less than the input length.
    """
    if length < len(data):
        raise ValueError(f"Data length {len(data)} > padded length {length}")
    if length > len(data):
        data = data + b"\x00" * (length - len(data))
    assert len(data) == length
    return data


def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:
    """Returns the extended header of the program data, if present and valid."""
    # The extended header, if present, follows the 4-byte root offset and the
    # 4-byte file identifier. Slice out only the bytes the header occupies:
    # program_data can be very large, and slicing the whole tail would copy
    # all of it just to read a few bytes.
    header_start = 8
    try:
        eh = _ExtendedHeader.from_bytes(
            program_data[header_start : header_start + _ExtendedHeader.EXPECTED_LENGTH]
        )
    except ValueError:
        # Too short to contain an extended header.
        return None
    return eh if eh.is_valid() else None


def _extract_delegate_segments(
    program: Program,
    segments: List[Cord],
) -> None:
    """Extracts the delegate segments inlined in the program into a list of buffers.

    The program is modified in-place to remove the delegate data.

    Args:
        program: The program to extract segments from. Modified in-place.
        segments: A list of buffers to append extracted segments to. Modified
            in-place.

    Raises:
        ValueError: If a delegate's data is not inline, if a delegate refers to
            an index outside Program.backend_delegate_data, or if some entry of
            Program.backend_delegate_data is not referenced by any delegate.
    """
    remaining_inline: List[BackendDelegateInlineData] = []
    inline_indices_seen: set[int] = set()
    num_inline_entries = len(program.backend_delegate_data)
    for plan in program.execution_plan:
        for delegate in plan.delegates:
            if delegate.processed.location != DataLocation.INLINE:
                raise ValueError(
                    "Program must only contain inline delegate data, "
                    + f"saw {repr(delegate)}"
                )
            # TODO(T144120904): Don't extract small blobs into segments;
            # have a cutoff. Or callers could provide a callback that
            # returns true/false for a given BackendDelegate, letting them
            # use their own logic.
            index = delegate.processed.index
            # Check the range explicitly: relying on IndexError would let a
            # negative index silently select an entry from the end of the list.
            if index < 0:
                raise ValueError(
                    f"Delegate processed index {index} < 0 in {repr(delegate)}"
                )
            if index >= num_inline_entries:
                raise ValueError(
                    f"Delegate processed index {delegate.processed.index} "
                    + ">= len(Program.backend_delegate_data) "
                    + f"{len(program.backend_delegate_data)} "
                    + f"in {repr(delegate)}"
                )
            inline: BackendDelegateInlineData = program.backend_delegate_data[index]
            inline_indices_seen.add(index)
            if inline.data:
                # Move the delegate data out of the program.
                segment_index = len(segments)
                segments.append(Cord(inline.data))
                delegate.processed = BackendDelegateDataReference(
                    location=DataLocation.SEGMENT,
                    index=segment_index,
                )
            else:
                # Not moving into a segment. Keep it inline, but update the
                # index.
                new_index = len(remaining_inline)
                remaining_inline.append(inline)
                delegate.processed.index = new_index

    # Make sure we visited all entries in backend_delegate_data, so that it's
    # safe to overwrite it.
    remaining_indices: set[int] = set(range(num_inline_entries)).difference(
        inline_indices_seen
    )
    if remaining_indices:
        raise ValueError(
            "Did not handle all elements of backend_delegate_data; "
            + f"remaining: {remaining_indices}"
        )

    # Preserve any entries that were not moved into segments.
    program.backend_delegate_data = remaining_inline


def _extract_constant_segment(
    constant_buffer: List[Buffer],
    tensor_alignment: Optional[int] = None,
) -> Tuple[Cord, List[int]]:
    """Copies the tensors from the provided list into a Cord and tracks the offsets
    of each tensor.

    Args:
        constant_buffer: list of Buffers from which to extract constants from.
            Not modified.
        tensor_alignment: Alignment in bytes. If provided, each tensor except
            the last is followed by zero padding so that the next tensor starts
            at a multiple of this value. If None, tensors are packed
            back-to-back with no padding.

    Returns:
        A tuple of (constant segment, list of offsets for each tensor in the
        segment).
    """
    constant_segment_data: Cord = Cord()
    constant_segment_offsets: List[int] = []
    current_offset: int = 0
    last_index = len(constant_buffer) - 1
    for i, buffer in enumerate(constant_buffer):
        constant_segment_data.append(buffer.storage)
        buffer_length = len(buffer.storage)
        pad_length = (
            _padding_required(buffer_length, tensor_alignment)
            if tensor_alignment is not None
            else 0
        )
        # The final tensor is not padded since nothing follows it in this
        # segment.
        if i < last_index:
            constant_segment_data.append(b"\x00" * pad_length)
        constant_segment_offsets.append(current_offset)
        current_offset += buffer_length + pad_length

    return constant_segment_data, constant_segment_offsets


def serialize_pte_binary(
    program: Program,
    *,
    mutable_data: Optional[List[Buffer]] = None,
    extract_delegate_segments: bool = False,
    segment_alignment: int = 128,
    constant_tensor_alignment: Optional[int] = None,
    delegate_alignment: Optional[int] = None,
) -> Cord:
    """Returns the runtime binary representation of the given Program.

    Args:
        program: The Program to serialize. Not modified; a deep copy is
            serialized instead.
        mutable_data: If provided, buffers to place in their own segment,
            packed without per-tensor padding. When the resulting data is
            non-empty, Program.mutable_data_segments is set to reference it.
        extract_delegate_segments: Whether to move delegate data blobs from the
            Program into separate segments, rather than encoding those blobs
            in the flatbuffer data. When true, will also:
            - Add an extended header to the output, containing the program size
              and the starting segment offset.
            - Update the Program.segments field with the offsets and lengths
              of each segment.
        segment_alignment: Alignment in bytes. The starting offset of each
            segment will be aligned to this value in the output data.
        constant_tensor_alignment: The minimum alignment of tensor
            buffers in the program. Must be a power of 2. Defaults to ALIGNMENT.
        delegate_alignment: If provided, the minimum alignment of delegate data
            in the program. Must be a power of 2. If not provided, uses the
            value in the schema file.

    Returns:
        The serialized form of the Program, ready for execution by the runtime.
    """
    # Default tensor alignment.
    if constant_tensor_alignment is None:
        constant_tensor_alignment = ALIGNMENT

    # Don't modify the original program.
    # TODO(T144120904): Could avoid yet more huge copies with a more shallow
    # copy, reusing the actual data blobs.
    program = copy.deepcopy(program)

    # Store extracted segment data; this may be constant data or delegate data.
    segments: List[Cord] = []

    constant_segment_data, constant_segment_offsets = _extract_constant_segment(
        program.constant_buffer, tensor_alignment=constant_tensor_alignment
    )

    # If there are no constants, len(constant_segment_data) = 0. However, there
    # may be non-constants, in which case len(constant_segment_offsets) = 1,
    # containing the placeholder value 0. Ensure the placeholder value is put
    # into program.constant_segment.offsets.
    if len(constant_segment_offsets) > 0:
        # Update program.constant_segment with constant subsegment offset
        # information.
        program.constant_segment = SubsegmentOffsets(
            segment_index=len(segments), offsets=constant_segment_offsets
        )
        # Clear the constant buffer, as constant data will be stored in segments.
        program.constant_buffer = []
        # Add to the aggregate segments cord.
        segments.append(constant_segment_data)

    if mutable_data is not None:
        mutable_segment_data, mutable_segment_offsets = _extract_constant_segment(
            mutable_data,
            tensor_alignment=None,  # data is copied at Method load so no need to align.
        )
        if len(mutable_segment_data) > 0:
            # Update program.mutable_data_segments with mutable subsegment
            # offset information.
            program.mutable_data_segments = [
                SubsegmentOffsets(
                    segment_index=len(segments), offsets=mutable_segment_offsets
                ),
            ]
            # Add to the aggregate segments cord.
            segments.append(mutable_segment_data)

    if extract_delegate_segments:
        _extract_delegate_segments(program, segments)

    # Append all segments into a single Cord, adding any necessary padding to
    # ensure that each segment begins at the required alignment.
    # Update program.segments with the offsets to each segment.
    segments_data = Cord()
    for data in segments:
        prev_end = (
            (program.segments[-1].offset + program.segments[-1].size)
            if program.segments
            else 0
        )
        program.segments.append(
            DataSegment(
                offset=_aligned_size(prev_end, segment_alignment), size=len(data)
            )
        )
        # Add to aggregate segments cord with padding. This keeps the physical
        # layout of segments_data in sync with the offsets recorded above.
        padding_length = _padding_required(len(segments_data), segment_alignment)
        if padding_length > 0:
            segments_data.append(b"\x00" * padding_length)
        segments_data.append(data)

    # Convert to a standard flatbuffer binary.
    result: _FlatbufferResult = _program_json_to_flatbuffer(
        _program_to_json(program),
        constant_tensor_alignment=constant_tensor_alignment,
        delegate_alignment=delegate_alignment,
    )

    # If there are no segments present, do not insert the extended header.
    if len(segments_data) == 0:
        return Cord(result.data)

    # Size of the header to insert. Its size is padded to the largest
    # force_align value present in the schema.
    padded_header_length: int = _aligned_size(
        input_size=_ExtendedHeader.EXPECTED_LENGTH,
        alignment=result.max_alignment,
    )
    # Size of the program with the header inserted.
    program_size: int = padded_header_length + len(result.data)
    # Offset to the first segment. Segments are known to be present here
    # because of the early return above.
    segment_base_offset: int = _aligned_size(
        input_size=program_size, alignment=segment_alignment
    )

    # Construct and pad the extended header.
    header_data: bytes = _ExtendedHeader(
        program_size=program_size, segment_base_offset=segment_base_offset
    ).to_bytes()
    header_data = _pad_to(header_data, padded_header_length)

    # Insert the header into the flatbuffer data.
    program_data: bytes = _insert_flatbuffer_header(
        flatbuffer_data=result.data,
        magic_regex=r"ET[0-9a-zA-Z][0-9a-zA-Z]",
        header_data=header_data,
    )
    assert len(program_data) == program_size

    # Potentially large. Try to free it as soon as we can.
    del result.data

    # Double-check that the extended header is in the right place and has the
    # right contents.
    eh = _get_extended_header(program_data)
    assert eh is not None
    assert eh.program_size == program_size
    assert eh.segment_base_offset == segment_base_offset

    # Construct the final pte file containing:
    # - program data; written to offset 0.
    # - segments data; aligned to segment_alignment.
    pte_data = Cord(program_data)
    padding_length = _padding_required(len(pte_data), segment_alignment)
    pte_data.append(b"\x00" * padding_length)
    # The first segment after program data should start at the segment base
    # offset.
    assert (
        len(pte_data) == segment_base_offset
    ), f"Offset of first segment {len(pte_data)} != segment base offset {segment_base_offset}"
    pte_data.append(segments_data)
    return pte_data


def _restore_segments(program: Program, segment_data: bytes) -> Program:
    """Moves segments from `segment_data` into `program`.

    This should recreate the original Program that the segments were extracted
    from.

    Args:
        program: The Program to restore. `program.segments` must describe the
            segment locations.
        segment_data: The data containing the segments. Assumes that this data
            begins at `segment_base_offset` from the extended header: i.e.,
            the preceding data has been stripped off so that the first segment
            begins at offset zero.
    Returns:
        The Program with segments restored.
    Raises:
        ValueError: If a segment overflows `segment_data`, or if a delegate
            references an unknown data location or a nonexistent segment.
    """
    # Extract the list of segment data blobs, which parallel program.segments.
    segments: List[bytes] = []
    for i, segment in enumerate(program.segments):
        if segment.offset + segment.size > len(segment_data):
            raise ValueError(
                f"Segment {i} {segment} overflows data length {len(segment_data)}"
            )
        segments.append(segment_data[segment.offset : segment.offset + segment.size])

    # Find and replace the Program's references to these segments, inlining the
    # data.
    for plan_index, plan in enumerate(program.execution_plan):
        for delegate_index, delegate in enumerate(plan.delegates):
            if delegate.processed.location == DataLocation.INLINE:
                continue
            # Raise rather than assert: this validates deserialized input, and
            # asserts are stripped when Python runs with -O.
            if delegate.processed.location != DataLocation.SEGMENT:
                raise ValueError(
                    f"Plan {plan_index} delegate {delegate_index} "
                    + f"has unexpected data location {delegate.processed.location}"
                )
            index = delegate.processed.index
            if index >= len(segments):
                raise ValueError(
                    f"Plan {plan_index} delegate {delegate_index} "
                    + f"segment index {index} >= num segments {len(segments)}"
                )

            data_index: int = len(program.backend_delegate_data)
            program.backend_delegate_data.append(
                BackendDelegateInlineData(data=segments[index])
            )
            delegate.processed = BackendDelegateDataReference(
                location=DataLocation.INLINE, index=data_index
            )

    # Replace constants from constant_segment into constant_buffer.
    if program.constant_segment and len(program.constant_segment.offsets) > 0:
        constant_segment = segments[program.constant_segment.segment_index]
        start_offsets = program.constant_segment.offsets
        # Note: each buffer ends at the next buffer's start offset, so it
        # includes any padding after the original data; the last buffer runs
        # to the end of the segment.
        end_offsets = list(start_offsets[1:]) + [len(constant_segment)]
        program.constant_buffer = [
            Buffer(storage=constant_segment[start:end])
            for start, end in zip(start_offsets, end_offsets)
        ]
        program.constant_segment.segment_index = 0
        program.constant_segment.offsets = []

    # Clear out the segments list since the original Program didn't have one.
    program.segments = []
    return program


def deserialize_pte_binary(program_data: bytes) -> Program:
    """Returns a Program deserialized from the given runtime binary data.

    If the data contains a valid extended header, only the first
    `program_size` bytes are parsed as flatbuffer data. When
    `segment_base_offset` is non-zero, the segments stored after that offset
    are moved back into the returned Program.
    """
    program_size = len(program_data)
    segment_base_offset = 0

    # Look for an extended header to see if segments follow the flatbuffer
    # data. _get_extended_header() only returns headers that pass is_valid().
    eh: Optional[_ExtendedHeader] = _get_extended_header(program_data)
    if eh is not None:
        program_size = eh.program_size
        segment_base_offset = eh.segment_base_offset

    # Parse the flatbuffer data.
    program: Program = _json_to_program(
        _program_flatbuffer_to_json(program_data[:program_size])
    )

    if segment_base_offset != 0:
        # Move segment data back into the Program.
        program = _restore_segments(
            program=program, segment_data=program_data[segment_base_offset:]
        )

    return program