# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
This module contains APIs to manipulate ops: decorators that attach helper
functions (scratch-tensor metadata providers and upper-bound shape
calculators) to ``torch._ops.OpOverload`` objects.
"""
from dataclasses import dataclass
import functools
from typing import Callable, Dict

import torch
from executorch.exir.tensor import TensorSpec


@dataclass
class ScratchTensorMetadata:
    """
    Metadata describing one scratch tensor for an out-variant op.

    Instances are the values of the dictionary returned by functions
    registered through ``attach_get_scratch_metas_fn``.
    """

    dtype: torch.dtype
    shape: torch.Size
    layout: torch.layout = torch.strided
    # NOTE(review): `device` is not forwarded when this metadata is converted
    # to a TensorSpec by `attach_get_scratch_metas_fn` — confirm intended.
    device: torch.device = torch.device("cpu")
    is_sparse: bool = False


# Type of a user-provided get_scratch_metas function: it accepts the out
# variant's arguments (with Tensors replaced by TensorSpec) and returns a dict
# mapping string keys (presumably scratch-tensor names — confirm) to
# ScratchTensorMetadata.
ScratchCallableType = Callable[..., Dict[str, ScratchTensorMetadata]]


def attach_get_scratch_metas_fn(
    out_variant: torch._ops.OpOverload,
) -> Callable[[ScratchCallableType], ScratchCallableType]:
    """
    Apply this decorator to the get_scratch_metas method for the `out_variant`
    op. The decorator attaches an adapted version of that method to the op as
    ``out_variant.get_scratch_metas``.

    The get_scratch_metas method has the same signature as the out variant op,
    with two differences:
    - every Tensor input argument is replaced with a TensorSpec
    - the return value is a dictionary of ScratchTensorMetadata

    The attached ``out_variant.get_scratch_metas`` accepts the same arguments
    and returns a dictionary with the same keys, whose values are converted
    from ScratchTensorMetadata to TensorSpec.

    Args:
        out_variant: the out-variant OpOverload to attach the method to.

    Returns:
        A decorator that registers the decorated function on `out_variant`
        and returns the decorated function itself unchanged.
    """

    def to_tensor_spec(meta: ScratchTensorMetadata) -> TensorSpec:
        # The produced spec is always marked non-constant and not requiring
        # grad. NOTE(review): meta.device is not passed on here.
        return TensorSpec(
            const=False,
            requires_grad=False,
            # fields copied from ScratchTensorMetadata
            dtype=meta.dtype,
            shape=meta.shape,
            layout=meta.layout,
            is_sparse=meta.is_sparse,
        )

    def adapt_return_value(
        get_scratch_metas_fn: ScratchCallableType,
    ) -> Callable[..., Dict[str, TensorSpec]]:
        """
        Wrap `get_scratch_metas_fn` so that every ScratchTensorMetadata value
        in its returned dictionary is converted to a TensorSpec.
        """

        # Copy identifying attributes (name, qualname, docstring, module) so the
        # attached function is recognizable when debugging. `__annotations__`
        # is deliberately NOT copied: this wrapper returns TensorSpec values,
        # not ScratchTensorMetadata.
        @functools.wraps(
            get_scratch_metas_fn,
            assigned=("__module__", "__name__", "__qualname__", "__doc__"),
        )
        def wrapper(*args: object, **kwargs: object) -> Dict[str, TensorSpec]:
            # Only Tensor arguments are TensorSpec; any non-Tensor arguments of
            # the out variant are forwarded unchanged, hence `object`.
            meta_dict = get_scratch_metas_fn(*args, **kwargs)
            return {k: to_tensor_spec(v) for k, v in meta_dict.items()}

        return wrapper

    def wrapper(get_scratch_metas_fn: ScratchCallableType) -> ScratchCallableType:
        # pyre-fixme[16]: `OpOverload` has no attribute `get_scratch_metas`.
        out_variant.get_scratch_metas = adapt_return_value(get_scratch_metas_fn)
        # Return the original (unadapted) function so it stays usable as-is.
        return get_scratch_metas_fn

    return wrapper


# pyre-ignore
def attach_calculate_upper_bound_shape_fn(func_op: torch._ops.OpOverload):
    """
    Decorator factory that attaches a calculate_upper_bound_shape function
    to an op.

    The input is the OpOverload for the functional op. The decorated function
    is stored as ``func_op.calculate_upper_bound_shape`` and returned
    unchanged.
    """

    # pyre-ignore
    def wrapper(calculate_upper_bound_shape_fn):
        # pyre-fixme[16]: `OpOverload` has no attribute `calculate_upper_bound_shape`.
        func_op.calculate_upper_bound_shape = calculate_upper_bound_shape_fn
        return calculate_upper_bound_shape_fn

    return wrapper