# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from torch._C import DispatchKey  # @manual

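# Op library namespaces treated as backend op libraries. (Descriptive note,
# not from the original file: ops registered under one of these namespaces
# are the ones wrapped by the backend op classes defined below.)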
_BACKEND_OP_LIB = [
    "executorch_prim",
    "quantized_decomposed",
    "DO_NOT_USE_TEST_ONLY",
]


class BackendOpOverload(EdgeOpOverload):
    """OpOverload for backend ops.

    A backend op is a custom op that doesn't show up in the ATen dialect.
    It therefore must replace an existing node, or a pattern of nodes, in the
    Edge dialect. This data structure makes sure that after lowering (part of)
    the Edge dialect to backend ops, the whole graph can still be captured
    properly.

    Differences from delegation:
    1. The result of delegation is still a module (a target of call_module,
       at least for now), while a backend op is an operator (a target of
       call_function).
    2. A backend op is stateless, while a delegate doesn't have to be.
    3. A backend op runs in the standard ExecuTorch runtime, while a delegate
       doesn't have to.

    Examples of backend ops include fused ops for a specific backend, and
    ExecuTorch prim ops for handling symbolic shapes.

    Note that the assumption here is that the backend op and the original
    (equivalent) callable form a 1-to-1 mapping.

    BackendOpOverload makes sure that:
    1. The backend op contains a CompositeExplicitAutograd,
       CompositeImplicitAutograd, or Meta kernel.
    2. It holds a reference to the original node/pattern it replaces.

    Example:

    add -> relu
        |
        v
    add_relu (only works on DSP): holds a reference to the add -> relu
    pattern, for re-capturing purposes.

    Retrace example:

    A very common practice in delegation is that a module is lowered to a
    backend, then the lowered module is composed with the original nn.Module
    and retraced:

    LoweredModule l_of_m = to_backend(g_of_m.to_edge(), ...)
    Module main(l_of_m)
    export(main, inputs)
    """

    def __init__(
        self,
        op_overload: EdgeOpOverload,
    ):
        super().__init__(
            op_overload._op,
            op_overload._schema,
        )
        # Reference to the original callable/pattern this backend op replaces
        # (see class docstring); set after construction when the op is bound
        # to a pattern.
        self._equivalent_callable = None
        # A backend op must be retraceable, so it needs a kernel that can run
        # under tracing: Meta, CompositeExplicitAutograd, or
        # CompositeImplicitAutograd.
        self._has_meta_kernel = self._op.has_kernel_for_dispatch_key(DispatchKey.Meta)
        self._has_composite_explicit_autograd_kernel = (
            self._op.has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd)
        )
        self._has_composite_implicit_autograd_kernel = (
            self._op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd)
        )
        assert (
            self._has_meta_kernel
            or self._has_composite_explicit_autograd_kernel
            or self._has_composite_implicit_autograd_kernel
        ), "A backend op must contain either a CompositeExplicitAutograd, CompositeImplicitAutograd, or Meta kernel."
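
    # A minimal sketch (not from this file; op name and schema are
    # hypothetical) of how a backend op can satisfy the assertion above:
    # registering a CompositeExplicitAutograd kernel via torch.library makes
    # has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd)
    # return True for the op.
    #
    #   lib = torch.library.Library("DO_NOT_USE_TEST_ONLY", "DEF")
    #   lib.define("add_relu(Tensor a, Tensor b) -> Tensor")
    #   lib.impl(
    #       "add_relu",
    #       lambda a, b: torch.relu(a + b),
    #       "CompositeExplicitAutograd",
    #   )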


class BackendOpOverloadPacket(EdgeOpOverloadPacket):
    """OpOverloadPacket for backend ops.

    Wraps EdgeOpOverloadPacket and overrides __getattr__ to return
    BackendOpOverload instances for backend ops.
    """

    def __init__(
        self,
        qualified_op_name: str,
        op_name: str,
        parent_overload_packet: torch._ops.OpOverloadPacket,
    ):
        super().__init__(
            qualified_op_name, op_name, parent_overload_packet
        )

    def __repr__(self):
        return "<BackendOpOverloadPacket(op='{}', parent_op='{}')>".format(
            self._qualified_op_name.replace("::", "."),
            self._parent_qualified_op_name.replace("::", "."),
        )

    def __hash__(self):
        return hash(self._op)

    def __str__(self):
        return self._qualified_op_name.replace("::", ".")

    @property
    def op(self):
        return self._op

    def __getattr__(self, key):
        try:
            # Get the edge op and cache the wrapped backend op as an
            # attribute. Note that a backend op constructed this way doesn't
            # have `_original_pattern`.
            result = super().__getattr__(key)
            if isinstance(result, EdgeOpOverload):
                backend_op = BackendOpOverload(result)
                setattr(self, key, backend_op)
                return backend_op
            else:
                return result
        except AttributeError as e:
            raise AttributeError(
                "The underlying op of '{}' has no overload name '{}'. Original error message:\n{}".format(
                    self, key, e
                )
            ) from e
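

# A minimal usage sketch (hypothetical; assumes a backend op
# "quantized_decomposed::add" with an overload named "out" has been
# registered). Accessing an overload on a BackendOpOverloadPacket lazily
# wraps the underlying EdgeOpOverload in a BackendOpOverload and caches it:
#
#   packet = BackendOpOverloadPacket(
#       "quantized_decomposed::add", "add", torch.ops.quantized_decomposed.add
#   )
#   op = packet.out  # returns (and caches) a BackendOpOverload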
124