1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-strict 8# pyre-ignore-all-errors[6] 9# pyre-ignore-all-errors[16] 10from __future__ import annotations 11 12import copy 13 14import math 15import typing 16from typing import Dict, List, Optional, Tuple, Union 17 18import executorch.exir.schema as schema 19import torch 20from executorch.exir.error import internal_assert 21from executorch.exir.schema import ScalarType, TensorShapeDynamism 22from executorch.exir.sym_util import eval_shape 23 24 25class AddressSpaceOverflowException(Exception): 26 pass 27 28 29def num_bytes_from_shape_and_dtype(shape: torch.Size, dtype: torch.dtype) -> int: 30 """ 31 Assume the tensor is a contiguous one. 32 """ 33 34 return math.prod(shape) * torch._utils._element_size(dtype) 35 36 37def contiguous_stride_from_shape(shape: torch.Size) -> Tuple[int]: 38 strides = [] 39 accum = 1 40 for sz in reversed(shape): 41 strides.append(accum) 42 # For sizes[i] == 0, treat it as 1 to be consistent with core Pytorch 43 # This preserves the PT equivalent behavior for dims with 0 elements 44 if isinstance(sz, int): 45 if sz != 0: 46 accum *= sz 47 else: 48 # Unbacked symints may error on the != 0 check 49 accum *= sz 50 return tuple(reversed(strides)) 51 52 53def dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]: 54 """ 55 Dimension order represents how dimensions are laid out in memory, 56 starting from the outer-most to the inner-most dimension. 57 Thus, the conversion from strides is done by sorting the strides 58 from larger to smaller since the dimension with the largest stride 59 is the outer-most and the dimension with the smallest stride is the inner-most. 60 For example, tensor with sizes = (3, 5, 2) and strides = (5, 1, 15), implies 61 dimension order of (2, 0, 1). Dimension order of (2, 0, 1) can be obtained 62 by sorting strides from large to smaller. 63 64 When strides do not convey dimension order unambiguously, dimension order 65 returned is dependent on stability of sort. In python same key elements are kept 66 in original order. Thus when strides = (4, 3, 1, 1) returned value is (0, 1, 2, 3) 67 Another example is: sizes = (1, 3, 1, 1) with strides = (3, 1, 3, 3), returned 68 value is (0, 2, 3, 1) 69 """ 70 for _, s in enumerate(stride): 71 if s == 0: 72 raise ValueError("0 in strides is not supported for ExecuTorch.") 73 sorted_dims = [ 74 i[0] for i in sorted(enumerate(stride), key=lambda x: x[1], reverse=True) 75 ] 76 return tuple(typing.cast(Tuple[bytes], sorted_dims)) 77 78 79def stride_from_dim_order(sizes: List[int], dim_order: List[bytes]) -> List[int]: 80 """ 81 Converts dim order to stride using sizes 82 e.g. if sizes = (2, 3, 4) and dim_order = (0, 1, 2) then strides = (12, 4, 1) 83 while for the same size if dim_order = (0, 2, 1) then strides = (12, 1, 3) 84 See executorch/runtime/core/exec_aten/util/dim_order_util.h for details 85 Args: 86 sizes (Tuple[int]): sizes of the tensor 87 dim_order (Tuple[bytes]): dim order of the tensor 88 Returns: 89 Tuple[int]: stride 90 """ 91 if len(sizes) == 0: 92 return [] 93 strides = copy.deepcopy(sizes) 94 ndim = len(sizes) 95 strides[dim_order[ndim - 1]] = 1 96 for i in range(ndim - 2, -1, -1): 97 if sizes[dim_order[i + 1]] == 0: 98 strides[dim_order[i]] = strides[dim_order[i + 1]] 99 else: 100 strides[dim_order[i]] = sizes[dim_order[i + 1]] * strides[dim_order[i + 1]] 101 return strides 102 103 104def calculate_aligned_num_bytes(num: int, alignment: int) -> int: 105 return math.ceil(num / alignment) * alignment 106 107 108def determine_tensor_dynanism(shape: torch.Size) -> TensorShapeDynamism: 109 if all(isinstance(s, int) for s in shape): 110 return TensorShapeDynamism.STATIC 111 else: 112 try: 113 _ = eval_shape(shape) 114 return TensorShapeDynamism.DYNAMIC_BOUND 115 except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: 116 return TensorShapeDynamism.DYNAMIC_UNBOUND 117 118 119ALIGNMENT = 16 120 121 122class TensorSpec: 123 """ 124 Captures the metadata for a given Tensor (ex. scalar type, storage, etc.). 125 """ 126 127 def __init__( 128 self, 129 dtype: torch.dtype, 130 shape: torch.Size, 131 layout: torch.layout = torch.strided, 132 is_sparse: bool = False, 133 const: bool = False, 134 requires_grad: bool = False, 135 ) -> None: 136 self.scalar_type = dtype 137 self.const = const 138 self.alignment: int = ALIGNMENT 139 self.storage: Optional[torch.UntypedStorage] = None 140 # convert to list making it easier to handle type checking 141 self.shape: List[int] = list(shape) 142 self.stride: Tuple[int] = contiguous_stride_from_shape(shape) 143 self.dim_order: Tuple[bytes] = dim_order_from_stride(self.stride) 144 self.requires_grad = requires_grad 145 self.layout = layout 146 self.is_sparse = is_sparse 147 self.init_mem_planning_fields() 148 self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(self.shape) 149 150 @property 151 def allocated_memory(self) -> int: 152 nbytes = num_bytes_from_shape_and_dtype(self.shape, self.dtype) 153 return calculate_aligned_num_bytes(nbytes, self.alignment) 154 155 def realign(self, new_alignment: int) -> int: 156 self.alignment = new_alignment 157 return self.allocated_memory 158 159 def nbytes(self) -> int: 160 return num_bytes_from_shape_and_dtype(self.shape, self.dtype) 161 162 @classmethod 163 def from_tensor(cls, tensor: torch.Tensor, const: bool = False) -> TensorSpec: 164 if const: 165 # for non-contigous tensors, convert to a contiguous one 166 tensor = tensor.contiguous() 167 # Weights cannot be views during emission or serialization 168 if tensor.nbytes != tensor.untyped_storage().nbytes(): 169 tensor = tensor.clone() 170 171 spec = cls( 172 dtype=tensor.dtype, 173 shape=tensor.shape, 174 layout=tensor.layout, 175 const=const, 176 is_sparse=tensor.is_sparse, 177 ) 178 spec.stride = tensor.stride() 179 spec.dim_order = dim_order_from_stride(spec.stride) 180 spec.requires_grad = tensor.requires_grad 181 spec.storage = tensor.untyped_storage() if const else None 182 183 return spec 184 185 def init_mem_planning_fields(self) -> None: 186 self.lifetime = [None, None] 187 self.mem_id = None 188 self.mem_obj_id = None 189 self.mem_offset = None 190 191 @property 192 def dtype(self) -> torch.dtype: 193 return self.scalar_type 194 195 @property 196 def is_dynamic_shape_tensor(self) -> bool: 197 return self.shape_dynamism != schema.TensorShapeDynamism.STATIC 198 199 @property 200 def is_static_shape_tensor(self) -> bool: 201 return self.shape_dynamism == TensorShapeDynamism.STATIC 202 203 @property 204 def is_upper_bound_tensor(self) -> bool: 205 return self.shape_dynamism == TensorShapeDynamism.DYNAMIC_BOUND 206 207 @property 208 def is_dynamic_unbound_tensor(self) -> bool: 209 return self.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND 210 211 def debug(self) -> str: 212 return ( 213 f"TensorSpec(id={id(self)}, const={self.const}, scalar_type={self.scalar_type}" 214 + f", allocated_memory={self.allocated_memory}, mem_id={self.mem_id}" 215 + f", mem_offset={self.mem_offset}, lifetime={self.lifetime}" 216 + f", shape_dynamism={self.shape_dynamism}" 217 + (f", shape={self.shape}") 218 + ")" 219 ) 220 221 def __repr__(self) -> str: 222 """ 223 Round-trippable printing function 224 """ 225 return ( 226 f"TensorSpec(dtype={self.scalar_type}, shape={self.shape}" 227 + f", layout={self.layout}" 228 + f", is_sparse={self.is_sparse}" 229 + f", shape_dynamism={self.shape_dynamism}" 230 + f", const={self.const}, requires_grad={self.requires_grad}" 231 + ")" 232 ) 233 234 235def memory_format_enum(memory_format: torch.memory_format) -> int: 236 internal_assert( 237 isinstance(memory_format, torch.memory_format), 238 "We only support torch.memory_format", 239 ) 240 table = { 241 torch.contiguous_format: 0, 242 torch.preserve_format: 1, 243 } 244 return table[memory_format] 245 246 247scalar_type_table: Dict[torch.dtype, ScalarType] = { 248 torch.uint8: ScalarType.BYTE, 249 torch.int8: ScalarType.CHAR, 250 torch.int16: ScalarType.SHORT, 251 torch.int32: ScalarType.INT, 252 torch.int64: ScalarType.LONG, 253 torch.half: ScalarType.HALF, 254 torch.float: ScalarType.FLOAT, 255 torch.double: ScalarType.DOUBLE, 256 torch.complex32: ScalarType.COMPLEX32, 257 torch.complex64: ScalarType.COMPLEX64, 258 torch.complex128: ScalarType.COMPLEX128, 259 torch.bool: ScalarType.BOOL, 260 torch.qint8: ScalarType.QINT8, 261 torch.quint8: ScalarType.QUINT8, 262 torch.qint32: ScalarType.QINT32, 263 torch.bfloat16: ScalarType.BFLOAT16, 264 torch.quint4x2: ScalarType.QUINT4x2, 265 torch.uint16: ScalarType.UINT16, 266} 267 268 269enum_to_scalar_map: Dict[ScalarType, torch.dtype] = { 270 scalar_type_table[key]: key for key in scalar_type_table 271} 272 273 274def scalar_type_enum(dtype: torch.dtype) -> ScalarType: 275 # TODO (zhengxu) single source of truth from c10/core/ScalarType.h. 276 internal_assert( 277 isinstance(dtype, torch.dtype), "We only support dtypes defined in Pytorch Core" 278 ) 279 return scalar_type_table[dtype] 280 281 282def get_scalar_type(enum: ScalarType) -> torch.dtype: 283 return enum_to_scalar_map[enum] 284 285 286def layout_enum(layout: torch.layout) -> int: 287 # TODO single source of truth. 288 table = { 289 torch.strided: 0, 290 torch.sparse_coo: 1, 291 } 292 return table[layout] 293 294 295def make_allocation_info(mem_id: int, mem_offset: int) -> schema.AllocationDetails: 296 """ 297 Creates the allocation_details object for creating tensors 298 """ 299 if mem_offset < 0: 300 raise ValueError(f"mem_offset {mem_offset} must not be negative") 301 memory_offset_low = mem_offset & ((1 << 32) - 1) 302 memory_offset_high = mem_offset >> 32 303 if memory_offset_high >= 1 << 32: 304 raise AddressSpaceOverflowException( 305 f"mem_offset {mem_offset} does not fit in 64 bits" 306 ) 307 308 allocation_info = schema.AllocationDetails( 309 memory_id=mem_id, 310 memory_offset_low=memory_offset_low, 311 memory_offset_high=memory_offset_high, 312 ) 313 return allocation_info 314 315 316def make_tensor_value( 317 data_buffer_idx: int, 318 allocation_info: Optional[schema.AllocationDetails], 319 spec: TensorSpec, 320) -> schema.Tensor: 321 """ 322 Converts the normal torch tensor to a flatbuffer tensor. 323 """ 324 325 def to_list( 326 x: Union[torch.Size, int, List[int], Tuple[int]] 327 ) -> Union[List[int], List[torch.Size]]: 328 if isinstance(x, torch.Size) or isinstance(x, tuple): 329 return list(x) 330 elif isinstance(x, int): 331 return [x] 332 else: 333 return x 334 335 tensor_size = to_list(spec.shape) 336 tensor_dim_order = to_list(spec.dim_order) 337 338 flatbuffer_tensor = schema.Tensor( 339 scalar_type=scalar_type_enum(spec.scalar_type), 340 # The runtime currently only supports tensors with offsets of zero. 341 storage_offset=0, 342 sizes=tensor_size, 343 dim_order=tensor_dim_order, 344 requires_grad=spec.requires_grad, 345 data_buffer_idx=data_buffer_idx, 346 allocation_info=allocation_info, 347 layout=layout_enum(spec.layout), 348 shape_dynamism=spec.shape_dynamism, 349 ) 350 return flatbuffer_tensor 351 352 353def check_spec(tensor: torch.Tensor, spec: TensorSpec) -> None: 354 internal_assert( 355 tensor.is_sparse == spec.is_sparse, 356 f"Tensor attribute 'is_sparse' is expected to be equal to '{spec.is_sparse}', actually got: '{tensor.is_sparse}'", 357 ) 358 internal_assert( 359 tensor.shape == spec.shape, 360 f"Tensor attribute 'shape' is expected to be equal to '{spec.shape}', actually got: '{tensor.shape}'", 361 ) 362 internal_assert( 363 tensor.dtype == spec.dtype, 364 f"Tensor attribute 'dtype' is expected to be equal to '{spec.dtype}', actually got: '{tensor.dtype}'", 365 ) 366