xref: /aosp_15_r20/external/executorch/exir/memory.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerfrom typing import List, Tuple, Union
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerimport torch
12*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.sym_util import eval_shape
13*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tensor import TensorSpec
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree
16*523fa7a6SAndroid Build Coastguard Workerfrom typing_extensions import TypeAlias
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard WorkerTensorAllocSpec: TypeAlias = Tuple[Tuple[int], torch.dtype]
19*523fa7a6SAndroid Build Coastguard WorkerAllocSpec: TypeAlias = Union[
20*523fa7a6SAndroid Build Coastguard Worker    TensorAllocSpec,
21*523fa7a6SAndroid Build Coastguard Worker    List[TensorAllocSpec],
22*523fa7a6SAndroid Build Coastguard Worker]
23*523fa7a6SAndroid Build Coastguard Worker
24*523fa7a6SAndroid Build Coastguard Worker
def alloc(spec: AllocSpec) -> pytree.PyTree:
    """
    Allocate the tensor(s) described by ``spec``.

    A single ``(shape, dtype)`` spec produces one tensor created with
    ``torch.empty``. A list of specs produces a list with one entry per
    element, each obtained by calling ``alloc`` on that element.
    """
    if not isinstance(spec, list):
        shape, dtype = spec
        # The shape is evaluated to ints so that the traced module can be
        # run in python for testing.
        concrete_shape = eval_shape(shape)
        return torch.empty(concrete_shape, dtype=dtype)

    return [alloc(item) for item in spec]
34*523fa7a6SAndroid Build Coastguard Worker
35*523fa7a6SAndroid Build Coastguard Worker
def free(spec: TensorSpec) -> None:
    """
    Do nothing; this function is intentionally a nop.

    Its main purpose is to be placed in the Fx IR, e.g. as the target of a
    call_function node.
    """
    return
42*523fa7a6SAndroid Build Coastguard Worker
43*523fa7a6SAndroid Build Coastguard Worker
def view(base: torch.Tensor, size: List[int]) -> torch.Tensor:
    """
    Return ``base.view(size)``.

    This mimics torch.ops.aten.view.default and is used to elide
    view_copy nodes.
    """
    reshaped = base.view(size)
    return reshaped
51