xref: /aosp_15_r20/external/executorch/exir/delegate.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerfrom __future__ import annotations
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workertry:  # noqa: C901
12*523fa7a6SAndroid Build Coastguard Worker    from torch._higher_order_ops.executorch_call_delegate import (
13*523fa7a6SAndroid Build Coastguard Worker        executorch_call_delegate as executorch_call_delegate,
14*523fa7a6SAndroid Build Coastguard Worker        get_lowered_module_name as get_lowered_module_name,
15*523fa7a6SAndroid Build Coastguard Worker        is_lowered_module as is_lowered_module,
16*523fa7a6SAndroid Build Coastguard Worker    )
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Workerexcept ImportError:
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Worker    # TODO: Delete this code once pytorch pin advances
21*523fa7a6SAndroid Build Coastguard Worker
22*523fa7a6SAndroid Build Coastguard Worker    from typing import Any, cast
23*523fa7a6SAndroid Build Coastguard Worker
24*523fa7a6SAndroid Build Coastguard Worker    import torch
25*523fa7a6SAndroid Build Coastguard Worker    import torch.utils._pytree as pytree
26*523fa7a6SAndroid Build Coastguard Worker    from torch._ops import HigherOrderOperator
27*523fa7a6SAndroid Build Coastguard Worker    from torch._subclasses.fake_tensor import FakeTensorMode
28*523fa7a6SAndroid Build Coastguard Worker    from torch.fx.experimental.proxy_tensor import (
29*523fa7a6SAndroid Build Coastguard Worker        disable_proxy_modes_tracing,
30*523fa7a6SAndroid Build Coastguard Worker        get_proxy_slot,
31*523fa7a6SAndroid Build Coastguard Worker        ProxyTorchDispatchMode,
32*523fa7a6SAndroid Build Coastguard Worker        track_tensor_tree,
33*523fa7a6SAndroid Build Coastguard Worker    )
34*523fa7a6SAndroid Build Coastguard Worker    from torch.utils._pytree import tree_flatten
35*523fa7a6SAndroid Build Coastguard Worker
    # Fallback definition of the `executorch_call_delegate` higher-order operator,
    # used only when torch._higher_order_ops.executorch_call_delegate cannot be
    # imported (this whole branch runs on ImportError). Per-dispatch-key
    # implementations are attached to this object further down via `py_impl`
    # / `py_functionalize_impl`.
    # NOTE(review): newer torch releases may reject direct instantiation of
    # HigherOrderOperator; this branch presumably only runs on older pins — confirm.
    executorch_call_delegate = HigherOrderOperator("executorch_call_delegate")
    # Mark these dispatch keys as fallthrough for this operator, so dispatch
    # passes through them to the next key instead of needing a dedicated py_impl.
    executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)
    executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot)
    executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView)
    executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU)

    # Class name that `is_lowered_module` compares against. Lowered modules are
    # matched by name rather than isinstance to avoid importing
    # LoweredBackendModule here (which may cause a circular import).
    LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"
43*523fa7a6SAndroid Build Coastguard Worker
44*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore
45*523fa7a6SAndroid Build Coastguard Worker    def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
46*523fa7a6SAndroid Build Coastguard Worker        # pyre-ignore
47*523fa7a6SAndroid Build Coastguard Worker        def _unwrap_proxy(e):
48*523fa7a6SAndroid Build Coastguard Worker            if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
49*523fa7a6SAndroid Build Coastguard Worker                return e
50*523fa7a6SAndroid Build Coastguard Worker            return get_proxy_slot(
51*523fa7a6SAndroid Build Coastguard Worker                cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy
52*523fa7a6SAndroid Build Coastguard Worker            )
53*523fa7a6SAndroid Build Coastguard Worker
54*523fa7a6SAndroid Build Coastguard Worker        if not is_lowered_module(lowered_module):
55*523fa7a6SAndroid Build Coastguard Worker            raise ValueError(
56*523fa7a6SAndroid Build Coastguard Worker                "executorch_call_delegate()'s first argument must be a LoweredBackendModule"
57*523fa7a6SAndroid Build Coastguard Worker            )
58*523fa7a6SAndroid Build Coastguard Worker
59*523fa7a6SAndroid Build Coastguard Worker        with disable_proxy_modes_tracing():
60*523fa7a6SAndroid Build Coastguard Worker            out = call_delegate_cpu(lowered_module, *args)
61*523fa7a6SAndroid Build Coastguard Worker
62*523fa7a6SAndroid Build Coastguard Worker        get_lowered_module_name(proxy_mode.tracer.root, lowered_module)
63*523fa7a6SAndroid Build Coastguard Worker
64*523fa7a6SAndroid Build Coastguard Worker        node_args = (lowered_module, *args)
65*523fa7a6SAndroid Build Coastguard Worker        proxy_args = pytree.tree_map(_unwrap_proxy, node_args)
66*523fa7a6SAndroid Build Coastguard Worker        out_proxy = proxy_mode.tracer.create_proxy(
67*523fa7a6SAndroid Build Coastguard Worker            "call_function",
68*523fa7a6SAndroid Build Coastguard Worker            func_overload,
69*523fa7a6SAndroid Build Coastguard Worker            proxy_args,
70*523fa7a6SAndroid Build Coastguard Worker            {},
71*523fa7a6SAndroid Build Coastguard Worker            name="executorch_call_delegate",
72*523fa7a6SAndroid Build Coastguard Worker        )
73*523fa7a6SAndroid Build Coastguard Worker        return track_tensor_tree(
74*523fa7a6SAndroid Build Coastguard Worker            out, out_proxy, constant=None, tracer=proxy_mode.tracer
75*523fa7a6SAndroid Build Coastguard Worker        )
76*523fa7a6SAndroid Build Coastguard Worker
77*523fa7a6SAndroid Build Coastguard Worker    @executorch_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
78*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore
79*523fa7a6SAndroid Build Coastguard Worker    def call_delegate_cpu(lowered_module, *args):
80*523fa7a6SAndroid Build Coastguard Worker        # FX creates this immutable_dict/list concept. Get rid of this.
81*523fa7a6SAndroid Build Coastguard Worker        map_types = {
82*523fa7a6SAndroid Build Coastguard Worker            torch.fx.immutable_collections.immutable_dict: dict,
83*523fa7a6SAndroid Build Coastguard Worker            torch.fx.immutable_collections.immutable_list: list,
84*523fa7a6SAndroid Build Coastguard Worker        }
85*523fa7a6SAndroid Build Coastguard Worker        new_args = pytree.tree_map_only(
86*523fa7a6SAndroid Build Coastguard Worker            tuple(map_types.keys()),
87*523fa7a6SAndroid Build Coastguard Worker            lambda a: map_types[type(a)](a),
88*523fa7a6SAndroid Build Coastguard Worker            args,
89*523fa7a6SAndroid Build Coastguard Worker            lambda a: isinstance(a, tuple(map_types.keys())),
90*523fa7a6SAndroid Build Coastguard Worker        )
91*523fa7a6SAndroid Build Coastguard Worker        return lowered_module.original_module.module()(*new_args)
92*523fa7a6SAndroid Build Coastguard Worker
93*523fa7a6SAndroid Build Coastguard Worker    @executorch_call_delegate.py_impl(torch._C.DispatchKey.Autograd)
94*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore
95*523fa7a6SAndroid Build Coastguard Worker    def call_delegate_autograd(lowered_module, *args):
96*523fa7a6SAndroid Build Coastguard Worker        # TODO: support autograd
97*523fa7a6SAndroid Build Coastguard Worker        flat_operands, _ = tree_flatten([lowered_module, *args])
98*523fa7a6SAndroid Build Coastguard Worker        requires_grad = any(
99*523fa7a6SAndroid Build Coastguard Worker            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
100*523fa7a6SAndroid Build Coastguard Worker        )
101*523fa7a6SAndroid Build Coastguard Worker
102*523fa7a6SAndroid Build Coastguard Worker        with torch._C._ExcludeDispatchKeyGuard(
103*523fa7a6SAndroid Build Coastguard Worker            torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
104*523fa7a6SAndroid Build Coastguard Worker        ):
105*523fa7a6SAndroid Build Coastguard Worker            res = executorch_call_delegate(lowered_module, *args)
106*523fa7a6SAndroid Build Coastguard Worker
107*523fa7a6SAndroid Build Coastguard Worker            if requires_grad:
108*523fa7a6SAndroid Build Coastguard Worker                # Create aliases of the output that has requires_grad=True. We need
109*523fa7a6SAndroid Build Coastguard Worker                # at least one of the inputs to err_fn to require grad so that the
110*523fa7a6SAndroid Build Coastguard Worker                # output will have a grad_fn.
111*523fa7a6SAndroid Build Coastguard Worker
112*523fa7a6SAndroid Build Coastguard Worker                # pyre-ignore
113*523fa7a6SAndroid Build Coastguard Worker                def fake_requires_grad(var):
114*523fa7a6SAndroid Build Coastguard Worker                    if var is not None:
115*523fa7a6SAndroid Build Coastguard Worker                        var = var.detach()
116*523fa7a6SAndroid Build Coastguard Worker                        if torch.is_floating_point(var) or torch.is_complex(var):
117*523fa7a6SAndroid Build Coastguard Worker                            var.requires_grad = True
118*523fa7a6SAndroid Build Coastguard Worker                    return var
119*523fa7a6SAndroid Build Coastguard Worker
120*523fa7a6SAndroid Build Coastguard Worker                return pytree.tree_map_only(torch.Tensor, fake_requires_grad, res)
121*523fa7a6SAndroid Build Coastguard Worker
122*523fa7a6SAndroid Build Coastguard Worker            return res
123*523fa7a6SAndroid Build Coastguard Worker
124*523fa7a6SAndroid Build Coastguard Worker    @executorch_call_delegate.py_impl(ProxyTorchDispatchMode)
125*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore
126*523fa7a6SAndroid Build Coastguard Worker    def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args):
127*523fa7a6SAndroid Build Coastguard Worker        res = trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args)
128*523fa7a6SAndroid Build Coastguard Worker        return res
129*523fa7a6SAndroid Build Coastguard Worker
130*523fa7a6SAndroid Build Coastguard Worker    @executorch_call_delegate.py_impl(FakeTensorMode)
131*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore
132*523fa7a6SAndroid Build Coastguard Worker    def call_delegate_fake_tensor_mode(mode, lowered_module, *args):
133*523fa7a6SAndroid Build Coastguard Worker        with mode:
134*523fa7a6SAndroid Build Coastguard Worker            return call_delegate_cpu(lowered_module, *args)
135*523fa7a6SAndroid Build Coastguard Worker
136*523fa7a6SAndroid Build Coastguard Worker    @executorch_call_delegate.py_functionalize_impl
137*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore
138*523fa7a6SAndroid Build Coastguard Worker    def call_delegate_functionalize(ctx, lowered_module, *args):
139*523fa7a6SAndroid Build Coastguard Worker        unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
140*523fa7a6SAndroid Build Coastguard Worker        with ctx.redispatch_to_next():
141*523fa7a6SAndroid Build Coastguard Worker            res = executorch_call_delegate(lowered_module, *unwrapped_args)
142*523fa7a6SAndroid Build Coastguard Worker            return ctx.wrap_tensors(res)
143*523fa7a6SAndroid Build Coastguard Worker
144*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre
145*523fa7a6SAndroid Build Coastguard Worker    def is_lowered_module(obj: Any) -> bool:
146*523fa7a6SAndroid Build Coastguard Worker        """
147*523fa7a6SAndroid Build Coastguard Worker        This function is added to avoid using isinstance(obj, LoweredBackendModule) as it will import LoweredBackendModule, which may cause a circular import.
148*523fa7a6SAndroid Build Coastguard Worker        """
149*523fa7a6SAndroid Build Coastguard Worker        return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE
150*523fa7a6SAndroid Build Coastguard Worker
151*523fa7a6SAndroid Build Coastguard Worker    def get_lowered_module_name(
152*523fa7a6SAndroid Build Coastguard Worker        root: torch.nn.Module,
153*523fa7a6SAndroid Build Coastguard Worker        # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
154*523fa7a6SAndroid Build Coastguard Worker        lowered_module: LOWERED_BACKEND_MODULE_TYPE,  # noqa
155*523fa7a6SAndroid Build Coastguard Worker    ) -> str:
156*523fa7a6SAndroid Build Coastguard Worker        """
157*523fa7a6SAndroid Build Coastguard Worker        Adds the given lowered_module into the given root module and returns the
158*523fa7a6SAndroid Build Coastguard Worker        name of the module added.
159*523fa7a6SAndroid Build Coastguard Worker        """
160*523fa7a6SAndroid Build Coastguard Worker        # Find a qualifying name for the lowered submodule
161*523fa7a6SAndroid Build Coastguard Worker        qualname = None
162*523fa7a6SAndroid Build Coastguard Worker        i = 0
163*523fa7a6SAndroid Build Coastguard Worker        while True:
164*523fa7a6SAndroid Build Coastguard Worker            qualname = f"lowered_module_{i}"
165*523fa7a6SAndroid Build Coastguard Worker            if not hasattr(root, qualname):
166*523fa7a6SAndroid Build Coastguard Worker                break
167*523fa7a6SAndroid Build Coastguard Worker            i += 1
168*523fa7a6SAndroid Build Coastguard Worker        assert qualname is not None
169*523fa7a6SAndroid Build Coastguard Worker
170*523fa7a6SAndroid Build Coastguard Worker        root.add_module(qualname, lowered_module)
171*523fa7a6SAndroid Build Coastguard Worker        return qualname
172