# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[6]
# pyre-ignore-all-errors[16]
from __future__ import annotations

import copy
import math
import typing
from typing import Dict, List, Optional, Tuple, Union

import executorch.exir.schema as schema
import torch
from executorch.exir.error import internal_assert
from executorch.exir.schema import ScalarType, TensorShapeDynamism
from executorch.exir.sym_util import eval_shape


class AddressSpaceOverflowException(Exception):
    pass


def num_bytes_from_shape_and_dtype(shape: torch.Size, dtype: torch.dtype) -> int:
    """
    Assumes the tensor is contiguous.
    """
    return math.prod(shape) * torch._utils._element_size(dtype)
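
# Editor's illustrative sketch (not part of the original module): for a
# contiguous float32 tensor of shape (2, 3), this is 6 elements * 4 bytes:
#
#   >>> num_bytes_from_shape_and_dtype(torch.Size([2, 3]), torch.float32)
#   24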


def contiguous_stride_from_shape(shape: torch.Size) -> Tuple[int, ...]:
    strides = []
    accum = 1
    for sz in reversed(shape):
        strides.append(accum)
        # For sizes[i] == 0, treat it as 1 to be consistent with core PyTorch.
        # This preserves PyTorch-equivalent behavior for dims with 0 elements.
        if isinstance(sz, int):
            if sz != 0:
                accum *= sz
        else:
            # Unbacked symints may error on the != 0 check.
            accum *= sz
    return tuple(reversed(strides))
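
# Editor's illustrative sketch (not part of the original module): strides are
# accumulated innermost-first and then reversed, with size-0 dims treated as
# size 1:
#
#   >>> contiguous_stride_from_shape(torch.Size([2, 3, 4]))
#   (12, 4, 1)
#   >>> contiguous_stride_from_shape(torch.Size([2, 0, 4]))
#   (4, 4, 1)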


def dim_order_from_stride(stride: Tuple[int, ...]) -> Tuple[bytes, ...]:
    """
    Dimension order represents how dimensions are laid out in memory,
    starting from the outermost to the innermost dimension. The conversion
    from strides therefore sorts the strides from larger to smaller, since
    the dimension with the largest stride is the outermost and the dimension
    with the smallest stride is the innermost. For example, a tensor with
    sizes = (3, 5, 2) and strides = (5, 1, 15) has dimension order (2, 0, 1),
    obtained by sorting the strides from largest to smallest.

    When the strides do not convey the dimension order unambiguously, the
    returned dimension order depends on the stability of the sort. Python's
    sort is stable, so elements with equal keys keep their original order.
    Thus for strides = (4, 3, 1, 1) the returned value is (0, 1, 2, 3), and
    for sizes = (1, 3, 1, 1) with strides = (3, 1, 3, 3) the returned value
    is (0, 2, 3, 1).
    """
    for s in stride:
        if s == 0:
            raise ValueError("0 in strides is not supported for ExecuTorch.")
    sorted_dims = [
        i[0] for i in sorted(enumerate(stride), key=lambda x: x[1], reverse=True)
    ]
    return typing.cast(Tuple[bytes, ...], tuple(sorted_dims))
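
# Editor's illustrative sketch (not part of the original module), restating the
# docstring examples as doctest-style calls:
#
#   >>> dim_order_from_stride((5, 1, 15))
#   (2, 0, 1)
#   >>> dim_order_from_stride((3, 1, 3, 3))
#   (0, 2, 3, 1)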


def stride_from_dim_order(sizes: List[int], dim_order: List[bytes]) -> List[int]:
    """
    Converts a dim order to strides using the sizes.
    E.g., if sizes = (2, 3, 4) and dim_order = (0, 1, 2), then strides = (12, 4, 1);
    for the same sizes, if dim_order = (0, 2, 1), then strides = (12, 1, 3).
    See executorch/runtime/core/exec_aten/util/dim_order_util.h for details.
    Args:
        sizes (List[int]): sizes of the tensor
        dim_order (List[bytes]): dim order of the tensor
    Returns:
        List[int]: strides
    """
    if len(sizes) == 0:
        return []
    strides = copy.deepcopy(sizes)
    ndim = len(sizes)
    strides[dim_order[ndim - 1]] = 1
    for i in range(ndim - 2, -1, -1):
        if sizes[dim_order[i + 1]] == 0:
            strides[dim_order[i]] = strides[dim_order[i + 1]]
        else:
            strides[dim_order[i]] = sizes[dim_order[i + 1]] * strides[dim_order[i + 1]]
    return strides
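
# Editor's illustrative sketch (not part of the original module): strides are
# filled in from the innermost dimension of the dim order outwards:
#
#   >>> stride_from_dim_order([2, 3, 4], [0, 1, 2])
#   [12, 4, 1]
#   >>> stride_from_dim_order([2, 3, 4], [0, 2, 1])
#   [12, 1, 3]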


def calculate_aligned_num_bytes(num: int, alignment: int) -> int:
    return math.ceil(num / alignment) * alignment
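
# Editor's illustrative sketch (not part of the original module): rounds a byte
# count up to the next multiple of the alignment:
#
#   >>> calculate_aligned_num_bytes(24, 16)
#   32
#   >>> calculate_aligned_num_bytes(32, 16)
#   32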


def determine_tensor_dynanism(shape: torch.Size) -> TensorShapeDynamism:
    """
    Classifies the shape's dynamism: STATIC if all dims are concrete ints;
    DYNAMIC_BOUND if symbolic dims can be evaluated to an upper bound;
    DYNAMIC_UNBOUND if evaluation hits a data-dependent (unbacked) symbol.
    """
    if all(isinstance(s, int) for s in shape):
        return TensorShapeDynamism.STATIC
    else:
        try:
            _ = eval_shape(shape)
            return TensorShapeDynamism.DYNAMIC_BOUND
        except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
            return TensorShapeDynamism.DYNAMIC_UNBOUND
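
# Editor's illustrative sketch (not part of the original module): a shape of
# plain ints is STATIC; symbolic dims go through the eval_shape probe above:
#
#   >>> determine_tensor_dynanism(torch.Size([2, 3])) == TensorShapeDynamism.STATIC
#   True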


ALIGNMENT = 16


class TensorSpec:
    """
    Captures the metadata for a given Tensor (e.g., scalar type, storage).
    """
    def __init__(
        self,
        dtype: torch.dtype,
        shape: torch.Size,
        layout: torch.layout = torch.strided,
        is_sparse: bool = False,
        const: bool = False,
        requires_grad: bool = False,
    ) -> None:
        self.scalar_type = dtype
        self.const = const
        self.alignment: int = ALIGNMENT
        self.storage: Optional[torch.UntypedStorage] = None
        # Convert to a list to make type checking easier to handle.
        self.shape: List[int] = list(shape)
        self.stride: Tuple[int, ...] = contiguous_stride_from_shape(shape)
        self.dim_order: Tuple[bytes, ...] = dim_order_from_stride(self.stride)
        self.requires_grad = requires_grad
        self.layout = layout
        self.is_sparse = is_sparse
        self.init_mem_planning_fields()
        self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(self.shape)

    @property
    def allocated_memory(self) -> int:
        nbytes = num_bytes_from_shape_and_dtype(self.shape, self.dtype)
        return calculate_aligned_num_bytes(nbytes, self.alignment)

    def realign(self, new_alignment: int) -> int:
        self.alignment = new_alignment
        return self.allocated_memory

    def nbytes(self) -> int:
        return num_bytes_from_shape_and_dtype(self.shape, self.dtype)

    @classmethod
    def from_tensor(cls, tensor: torch.Tensor, const: bool = False) -> TensorSpec:
        if const:
            # For non-contiguous tensors, convert to a contiguous one.
            tensor = tensor.contiguous()
            # Weights cannot be views during emission or serialization.
            if tensor.nbytes != tensor.untyped_storage().nbytes():
                tensor = tensor.clone()

        spec = cls(
            dtype=tensor.dtype,
            shape=tensor.shape,
            layout=tensor.layout,
            const=const,
            is_sparse=tensor.is_sparse,
        )
        spec.stride = tensor.stride()
        spec.dim_order = dim_order_from_stride(spec.stride)
        spec.requires_grad = tensor.requires_grad
        spec.storage = tensor.untyped_storage() if const else None

        return spec
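
    # Editor's illustrative sketch (not part of the original module): a spec
    # built from a 2x3 float32 tensor occupies 24 bytes, padded up to the
    # 16-byte ALIGNMENT for memory planning:
    #
    #   >>> spec = TensorSpec.from_tensor(torch.randn(2, 3))
    #   >>> spec.nbytes()
    #   24
    #   >>> spec.allocated_memory
    #   32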

    def init_mem_planning_fields(self) -> None:
        self.lifetime = [None, None]
        self.mem_id = None
        self.mem_obj_id = None
        self.mem_offset = None

    @property
    def dtype(self) -> torch.dtype:
        return self.scalar_type

    @property
    def is_dynamic_shape_tensor(self) -> bool:
        return self.shape_dynamism != TensorShapeDynamism.STATIC

    @property
    def is_static_shape_tensor(self) -> bool:
        return self.shape_dynamism == TensorShapeDynamism.STATIC

    @property
    def is_upper_bound_tensor(self) -> bool:
        return self.shape_dynamism == TensorShapeDynamism.DYNAMIC_BOUND

    @property
    def is_dynamic_unbound_tensor(self) -> bool:
        return self.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND

    def debug(self) -> str:
        return (
            f"TensorSpec(id={id(self)}, const={self.const}, scalar_type={self.scalar_type}"
            f", allocated_memory={self.allocated_memory}, mem_id={self.mem_id}"
            f", mem_offset={self.mem_offset}, lifetime={self.lifetime}"
            f", shape_dynamism={self.shape_dynamism}"
            f", shape={self.shape})"
        )

    def __repr__(self) -> str:
        """
        Round-trippable printing function.
        """
        return (
            f"TensorSpec(dtype={self.scalar_type}, shape={self.shape}"
            f", layout={self.layout}"
            f", is_sparse={self.is_sparse}"
            f", shape_dynamism={self.shape_dynamism}"
            f", const={self.const}, requires_grad={self.requires_grad})"
        )


def memory_format_enum(memory_format: torch.memory_format) -> int:
    internal_assert(
        isinstance(memory_format, torch.memory_format),
        "We only support torch.memory_format",
    )
    table = {
        torch.contiguous_format: 0,
        torch.preserve_format: 1,
    }
    return table[memory_format]


scalar_type_table: Dict[torch.dtype, ScalarType] = {
    torch.uint8: ScalarType.BYTE,
    torch.int8: ScalarType.CHAR,
    torch.int16: ScalarType.SHORT,
    torch.int32: ScalarType.INT,
    torch.int64: ScalarType.LONG,
    torch.half: ScalarType.HALF,
    torch.float: ScalarType.FLOAT,
    torch.double: ScalarType.DOUBLE,
    torch.complex32: ScalarType.COMPLEX32,
    torch.complex64: ScalarType.COMPLEX64,
    torch.complex128: ScalarType.COMPLEX128,
    torch.bool: ScalarType.BOOL,
    torch.qint8: ScalarType.QINT8,
    torch.quint8: ScalarType.QUINT8,
    torch.qint32: ScalarType.QINT32,
    torch.bfloat16: ScalarType.BFLOAT16,
    torch.quint4x2: ScalarType.QUINT4x2,
    torch.uint16: ScalarType.UINT16,
}


enum_to_scalar_map: Dict[ScalarType, torch.dtype] = {
    scalar_type: dtype for dtype, scalar_type in scalar_type_table.items()
}


def scalar_type_enum(dtype: torch.dtype) -> ScalarType:
    # TODO (zhengxu): single source of truth from c10/core/ScalarType.h.
    internal_assert(
        isinstance(dtype, torch.dtype), "We only support dtypes defined in PyTorch core"
    )
    return scalar_type_table[dtype]


def get_scalar_type(enum: ScalarType) -> torch.dtype:
    return enum_to_scalar_map[enum]
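
# Editor's illustrative sketch (not part of the original module): the two maps
# are inverses of each other, so a dtype round-trips through its enum:
#
#   >>> get_scalar_type(scalar_type_enum(torch.float)) is torch.float32
#   True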


def layout_enum(layout: torch.layout) -> int:
    # TODO: single source of truth.
    table = {
        torch.strided: 0,
        torch.sparse_coo: 1,
    }
    return table[layout]


def make_allocation_info(mem_id: int, mem_offset: int) -> schema.AllocationDetails:
    """
    Creates the AllocationDetails object used when creating tensors.
    """
    if mem_offset < 0:
        raise ValueError(f"mem_offset {mem_offset} must not be negative")
    memory_offset_low = mem_offset & ((1 << 32) - 1)
    memory_offset_high = mem_offset >> 32
    if memory_offset_high >= 1 << 32:
        raise AddressSpaceOverflowException(
            f"mem_offset {mem_offset} does not fit in 64 bits"
        )

    allocation_info = schema.AllocationDetails(
        memory_id=mem_id,
        memory_offset_low=memory_offset_low,
        memory_offset_high=memory_offset_high,
    )
    return allocation_info
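
# Editor's illustrative sketch (not part of the original module): the 64-bit
# offset is split into two unsigned 32-bit halves for the flatbuffer schema:
#
#   >>> info = make_allocation_info(mem_id=1, mem_offset=(1 << 32) + 5)
#   >>> (info.memory_offset_high, info.memory_offset_low)
#   (1, 5)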


def make_tensor_value(
    data_buffer_idx: int,
    allocation_info: Optional[schema.AllocationDetails],
    spec: TensorSpec,
) -> schema.Tensor:
    """
    Converts the metadata in a TensorSpec to a flatbuffer tensor.
    """

    def to_list(x: Union[torch.Size, int, List[int], Tuple[int, ...]]) -> List[int]:
        if isinstance(x, (torch.Size, tuple)):
            return list(x)
        elif isinstance(x, int):
            return [x]
        else:
            return x

    tensor_size = to_list(spec.shape)
    tensor_dim_order = to_list(spec.dim_order)

    flatbuffer_tensor = schema.Tensor(
        scalar_type=scalar_type_enum(spec.scalar_type),
        # The runtime currently only supports tensors with offsets of zero.
        storage_offset=0,
        sizes=tensor_size,
        dim_order=tensor_dim_order,
        requires_grad=spec.requires_grad,
        data_buffer_idx=data_buffer_idx,
        allocation_info=allocation_info,
        layout=layout_enum(spec.layout),
        shape_dynamism=spec.shape_dynamism,
    )
    return flatbuffer_tensor
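
# Editor's illustrative sketch (not part of the original module); the mem_id,
# mem_offset, and data_buffer_idx values here are made up for illustration:
#
#   >>> spec = TensorSpec.from_tensor(torch.randn(2, 3), const=True)
#   >>> tensor_value = make_tensor_value(
#   ...     data_buffer_idx=0,
#   ...     allocation_info=make_allocation_info(mem_id=0, mem_offset=0),
#   ...     spec=spec,
#   ... )
#   >>> tensor_value.sizes
#   [2, 3]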


def check_spec(tensor: torch.Tensor, spec: TensorSpec) -> None:
    internal_assert(
        tensor.is_sparse == spec.is_sparse,
        f"Tensor attribute 'is_sparse' is expected to equal '{spec.is_sparse}', but got '{tensor.is_sparse}'",
    )
    internal_assert(
        tensor.shape == spec.shape,
        f"Tensor attribute 'shape' is expected to equal '{spec.shape}', but got '{tensor.shape}'",
    )
    internal_assert(
        tensor.dtype == spec.dtype,
        f"Tensor attribute 'dtype' is expected to equal '{spec.dtype}', but got '{tensor.dtype}'",
    )
366