xref: /aosp_15_r20/external/executorch/exir/memory.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import List, Tuple, Union

import torch
from executorch.exir.sym_util import eval_shape
from executorch.exir.tensor import TensorSpec

from torch.utils import _pytree as pytree
from typing_extensions import TypeAlias

18TensorAllocSpec: TypeAlias = Tuple[Tuple[int], torch.dtype]
19AllocSpec: TypeAlias = Union[
20    TensorAllocSpec,
21    List[TensorAllocSpec],
22]
23
24
def alloc(spec: AllocSpec) -> pytree.PyTree:
    """Materialize uninitialized storage described by *spec*.

    A list spec is expanded recursively into a list of tensors; a single
    ``(shape, dtype)`` pair becomes one ``torch.empty`` tensor.
    """
    if not isinstance(spec, list):
        sizes, dtype = spec
        # Resolve any symbolic dimensions to concrete ints so the traced
        # module can be executed in plain Python (e.g. for testing).
        return torch.empty(eval_shape(sizes), dtype=dtype)
    return [alloc(sub) for sub in spec]

def free(spec: TensorSpec) -> None:
    """Deliberate no-op.

    Its only purpose is to exist as a named target in the Fx IR — e.g.
    as the target of a call_function node; it performs no work.
    """
    return None

def view(base: torch.Tensor, size: List[int]) -> torch.Tensor:
    """Return a zero-copy view of ``base`` reshaped to ``size``.

    Mimics ``torch.ops.aten.view.default``; used to elide view_copy
    nodes.
    """
    return torch.Tensor.view(base, size)
