# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

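# This module re-exports the executorch_call_delegate higher-order op (the
# call that invokes a LoweredBackendModule from an exported graph). The op
# normally comes from PyTorch; a local fallback definition is kept below for
# older pins.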
try:  # noqa: C901
    from torch._higher_order_ops.executorch_call_delegate import (
        executorch_call_delegate as executorch_call_delegate,
        get_lowered_module_name as get_lowered_module_name,
        is_lowered_module as is_lowered_module,
    )

except ImportError:

    # TODO: Delete this fallback once the PyTorch pin advances.

    from typing import Any, cast

    import torch
    import torch.utils._pytree as pytree
    from torch._ops import HigherOrderOperator
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.proxy_tensor import (
        disable_proxy_modes_tracing,
        get_proxy_slot,
        ProxyTorchDispatchMode,
        track_tensor_tree,
    )
    from torch.utils._pytree import tree_flatten

    executorch_call_delegate = HigherOrderOperator("executorch_call_delegate")
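    # Register fallthroughs for these dispatch keys so the dispatcher skips
    # them and falls through to the kernels registered below.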
    executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)
    executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot)
    executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView)
    executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU)

    LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"

    # pyre-ignore
    def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
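        """Record an executorch_call_delegate node in proxy_mode's graph,
        running the delegate eagerly first to obtain concrete outputs."""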
        # pyre-ignore
        def _unwrap_proxy(e):
            if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
                return e
            return get_proxy_slot(
                cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy
            )

        if not is_lowered_module(lowered_module):
            raise ValueError(
                "executorch_call_delegate()'s first argument must be a LoweredBackendModule"
            )

        # Compute concrete outputs with tracing disabled so the new node can
        # carry real tensor metadata.
        with disable_proxy_modes_tracing():
            out = call_delegate_cpu(lowered_module, *args)

        # Register the lowered module on the tracer root so the graph can
        # reference it as a named submodule.
        get_lowered_module_name(proxy_mode.tracer.root, lowered_module)

        node_args = (lowered_module, *args)
        proxy_args = pytree.tree_map(_unwrap_proxy, node_args)
        out_proxy = proxy_mode.tracer.create_proxy(
            "call_function",
            func_overload,
            proxy_args,
            {},
            name="executorch_call_delegate",
        )
        return track_tensor_tree(
            out, out_proxy, constant=None, tracer=proxy_mode.tracer
        )

    @executorch_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
    # pyre-ignore
    def call_delegate_cpu(lowered_module, *args):
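        """Eager (CompositeExplicitAutograd) kernel: run the original,
        non-lowered module directly."""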
        # FX wraps dict/list arguments in immutable_dict/immutable_list;
        # convert them back to plain containers before calling the module.
        map_types = {
            torch.fx.immutable_collections.immutable_dict: dict,
            torch.fx.immutable_collections.immutable_list: list,
        }
        new_args = pytree.tree_map_only(
            tuple(map_types.keys()),
            lambda a: map_types[type(a)](a),
            args,
            lambda a: isinstance(a, tuple(map_types.keys())),
        )
        return lowered_module.original_module.module()(*new_args)

    @executorch_call_delegate.py_impl(torch._C.DispatchKey.Autograd)
    # pyre-ignore
    def call_delegate_autograd(lowered_module, *args):
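        """Autograd kernel. Backward is not implemented yet: the op is
        re-dispatched with AutogradCPU excluded, and if any input requires
        grad the outputs are returned as detached aliases with
        requires_grad=True."""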
        # TODO: support autograd
        flat_operands, _ = tree_flatten([lowered_module, *args])
        requires_grad = any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        )

        with torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
        ):
            res = executorch_call_delegate(lowered_module, *args)

            if requires_grad:
                # Return detached aliases of the outputs with
                # requires_grad=True so that downstream code still sees
                # grad-tracking tensors even though no grad_fn is recorded.

                # pyre-ignore
                def fake_requires_grad(var):
                    if var is not None:
                        var = var.detach()
                        if torch.is_floating_point(var) or torch.is_complex(var):
                            var.requires_grad = True
                    return var

                return pytree.tree_map_only(torch.Tensor, fake_requires_grad, res)

            return res

    @executorch_call_delegate.py_impl(ProxyTorchDispatchMode)
    # pyre-ignore
    def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args):
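        """Tracing kernel: emit a graph node via trace_call_delegate."""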
        res = trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args)
        return res

    @executorch_call_delegate.py_impl(FakeTensorMode)
    # pyre-ignore
    def call_delegate_fake_tensor_mode(mode, lowered_module, *args):
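        """FakeTensorMode kernel: run the original module on fake tensors to
        produce outputs with the correct shapes and dtypes."""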
        with mode:
            return call_delegate_cpu(lowered_module, *args)

    @executorch_call_delegate.py_functionalize_impl
    # pyre-ignore
    def call_delegate_functionalize(ctx, lowered_module, *args):
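        """Functionalization kernel: unwrap functional tensor args, redispatch
        to the next key, and wrap the results back up."""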
        unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
        with ctx.redispatch_to_next():
            res = executorch_call_delegate(lowered_module, *unwrapped_args)
            return ctx.wrap_tensors(res)

    # pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.
    def is_lowered_module(obj: Any) -> bool:
        """
        Checks by type name instead of isinstance(obj, LoweredBackendModule),
        since importing LoweredBackendModule here could create a circular
        import.
        """
        return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE

    def get_lowered_module_name(
        root: torch.nn.Module,
        # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
        lowered_module: LOWERED_BACKEND_MODULE_TYPE,  # noqa
    ) -> str:
        """
        Adds the given lowered_module as a submodule of the given root module
        and returns the name under which it was added.
        """
        # Find an unused name for the lowered submodule
        qualname = None
        i = 0
        while True:
            qualname = f"lowered_module_{i}"
            if not hasattr(root, qualname):
                break
            i += 1
        assert qualname is not None

        root.add_module(qualname, lowered_module)
        return qualname
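
# A minimal usage sketch (illustrative only; `lowered` stands in for a
# hypothetical LoweredBackendModule produced by exir's to_backend flow):
#
#     out = executorch_call_delegate(lowered, *example_inputs)
#
# During export tracing this records a single executorch_call_delegate node
# whose first argument is the lowered submodule registered on the tracer root.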