# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import List, Tuple, Union

import torch
from executorch.exir.sym_util import eval_shape
from executorch.exir.tensor import TensorSpec

from torch.utils import _pytree as pytree
from typing_extensions import TypeAlias

# One tensor allocation request: (shape, dtype). A shape has a variable
# number of dimensions, so it is ``Tuple[int, ...]``; ``Tuple[int]`` would
# only describe a 1-element tuple (i.e. rank-1 shapes).
TensorAllocSpec: TypeAlias = Tuple[Tuple[int, ...], torch.dtype]
# Either a single tensor allocation request or a list of them.
AllocSpec: TypeAlias = Union[
    TensorAllocSpec,
    List[TensorAllocSpec],
]


def alloc(spec: AllocSpec) -> pytree.PyTree:
    """
    Allocate tensor(s) described by ``spec`` using ``torch.empty``.

    Args:
        spec: Either a ``(shape, dtype)`` tuple, or a list whose elements are
            each passed back through ``alloc``.

    Returns:
        A tensor from ``torch.empty(shape, dtype=dtype)`` for a single spec,
        or a list with one result per element for a list spec.
    """
    if isinstance(spec, list):
        return [alloc(s) for s in spec]

    shape, dtype = spec
    # evaluate the shape to int so we can run the traced module
    # in python for testing
    shape = eval_shape(shape)
    return torch.empty(shape, dtype=dtype)


def free(spec: TensorSpec) -> None:
    """
    The function is a no-op. Its main purpose is to appear in the FX IR,
    e.g. as the target of a call_function node.

    Args:
        spec: The tensor spec being freed; intentionally unused here.
    """


def view(base: torch.Tensor, size: List[int]) -> torch.Tensor:
    """
    This function mimics torch.ops.aten.view.default.

    It is used to elide view_copy nodes.

    Args:
        base: The tensor to view.
        size: The target shape, forwarded unchanged to ``Tensor.view``.

    Returns:
        ``base.view(size)``, a view that shares storage with ``base``.
    """
    return base.view(size)