# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import List, Tuple, Union

import torch
from executorch.exir.sym_util import eval_shape
from executorch.exir.tensor import TensorSpec

from torch.utils import _pytree as pytree
from typing_extensions import TypeAlias

# A single allocation request: (shape, dtype). The shape may have any rank;
# note that ``Tuple[int]`` would mean "exactly one int", hence ``int, ...``.
TensorAllocSpec: TypeAlias = Tuple[Tuple[int, ...], torch.dtype]
# Either one allocation request or a list of them.
AllocSpec: TypeAlias = Union[
    TensorAllocSpec,
    List[TensorAllocSpec],
]


def alloc(spec: AllocSpec) -> pytree.PyTree:
    """
    Allocate the tensor(s) described by ``spec``.

    Args:
        spec: either a single ``(shape, dtype)`` pair, or a list of such
            pairs. A list is handled recursively, element by element.

    Returns:
        For a single pair, one tensor created with ``torch.empty`` (its
        contents are uninitialized). For a list, a list of such tensors in
        the same order as the input specs.
    """
    if isinstance(spec, list):
        return [alloc(s) for s in spec]

    shape, dtype = spec
    # evaluate the shape to int so we can run the traced module
    # in python for testing
    shape = eval_shape(shape)
    return torch.empty(shape, dtype=dtype)


def free(spec: TensorSpec) -> None:
    """
    Do nothing; this function is a no-op.

    Its purpose is to exist as a callable that can be placed in the FX IR,
    e.g. as the target of a ``call_function`` node.
    """


def view(base: torch.Tensor, size: List[int]) -> torch.Tensor:
    """
    Mimic ``torch.ops.aten.view.default``.

    It is used to elide view_copy nodes.

    Args:
        base: the tensor to view.
        size: the target shape, forwarded unchanged to ``Tensor.view``.

    Returns:
        ``base.view(size)``. This is a view that shares storage with
        ``base`` rather than a copy.
    """
    return base.view(size)