xref: /aosp_15_r20/external/executorch/exir/operator/manip.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker"""
10*523fa7a6SAndroid Build Coastguard WorkerThis module contains APIs to manipulate ops.
11*523fa7a6SAndroid Build Coastguard Worker"""
12*523fa7a6SAndroid Build Coastguard Workerfrom dataclasses import dataclass
13*523fa7a6SAndroid Build Coastguard Workerfrom typing import Callable, Dict
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Workerimport torch
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tensor import TensorSpec
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker
@dataclass()
class ScratchTensorMetadata:
    """
    Description of a scratch tensor returned by a get_scratch_metas function.

    Only ``dtype`` and ``shape`` must be supplied; the remaining fields
    default to a strided, non-sparse tensor on the CPU device.
    """

    # Element type of the scratch tensor.
    dtype: torch.dtype
    # Size of the scratch tensor.
    shape: torch.Size
    # Memory layout; torch.strided unless overridden.
    layout: torch.layout = torch.strided
    # Device recorded for the tensor; CPU unless overridden.
    device: torch.device = torch.device("cpu")
    # Whether the tensor uses a sparse representation.
    is_sparse: bool = False


# Signature of a get_scratch_metas function: it mirrors the out-variant op's
# signature (with Tensor inputs replaced by TensorSpec) and returns a mapping
# from scratch-tensor name to ScratchTensorMetadata.
ScratchCallableType = Callable[..., Dict[str, ScratchTensorMetadata]]
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Worker
def attach_get_scratch_metas_fn(
    out_variant: torch._ops.OpOverload,
) -> Callable[[ScratchCallableType], ScratchCallableType]:
    """
    Build a decorator that attaches a ``get_scratch_metas`` method to the
    ``out_variant`` op.

    Apply the returned decorator to a get_scratch_metas function for
    ``out_variant``. That function has the same signature as the out-variant
    op, with two differences:
    - every Tensor input argument is replaced with a TensorSpec;
    - the return value is a dictionary of ScratchTensorMetadata.

    The decorator stores an adapted version of the function on
    ``out_variant.get_scratch_metas``. The adapted version converts every
    returned ScratchTensorMetadata into a TensorSpec under the same key.
    The decorated function itself is returned unchanged.

    Args:
        out_variant: the out-variant OpOverload to attach the method to.

    Returns:
        A decorator that registers the function on ``out_variant`` and
        returns it unchanged.
    """
    # Local import keeps this block self-contained.
    from functools import wraps

    def to_tensor_spec(meta: ScratchTensorMetadata) -> TensorSpec:
        # Scratch specs are always created non-const and without grad;
        # the remaining fields are copied from the metadata.
        # NOTE(review): `meta.device` is not copied into the TensorSpec;
        # confirm that dropping it is intended.
        return TensorSpec(
            const=False,
            requires_grad=False,
            dtype=meta.dtype,
            shape=meta.shape,
            layout=meta.layout,
            is_sparse=meta.is_sparse,
        )

    def adapt_return_value(
        get_scratch_metas_fn: ScratchCallableType,
    ) -> Callable[..., Dict[str, TensorSpec]]:
        """
        Wrap ``get_scratch_metas_fn`` so that it returns TensorSpecs instead
        of ScratchTensorMetadata, keyed by the same names.
        """

        # Preserve the name, docstring and __wrapped__ of the user function
        # so the attached method stays introspectable.
        @wraps(get_scratch_metas_fn)
        def adapted(*args: object, **kwargs: object) -> Dict[str, TensorSpec]:
            # Arguments are forwarded untouched: Tensor inputs arrive as
            # TensorSpec, while non-Tensor arguments keep their own types.
            meta_dict = get_scratch_metas_fn(*args, **kwargs)
            return {name: to_tensor_spec(meta) for name, meta in meta_dict.items()}

        return adapted

    def register(get_scratch_metas_fn: ScratchCallableType) -> ScratchCallableType:
        # pyre-fixme[16]: `OpOverload` has no attribute `get_scratch_metas`.
        out_variant.get_scratch_metas = adapt_return_value(get_scratch_metas_fn)
        return get_scratch_metas_fn

    return register
75*523fa7a6SAndroid Build Coastguard Worker
76*523fa7a6SAndroid Build Coastguard Worker
# pyre-ignore
def attach_calculate_upper_bound_shape_fn(func_op: torch._ops.OpOverload):
    """
    Build a decorator that registers a function as ``func_op``'s
    ``calculate_upper_bound_shape`` attribute.

    Args:
        func_op: the OpOverload for the functional op.

    Returns:
        A decorator that stores the decorated function on ``func_op`` and
        returns that same function unchanged.
    """

    # pyre-ignore
    def register(calculate_upper_bound_shape_fn):
        # setattr sidesteps the static "no such attribute" complaint that a
        # direct assignment on OpOverload would trigger.
        setattr(func_op, "calculate_upper_bound_shape", calculate_upper_bound_shape_fn)
        return calculate_upper_bound_shape_fn

    return register
90