xref: /aosp_15_r20/external/executorch/exir/dialects/_ops.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport types
8*523fa7a6SAndroid Build Coastguard Workerfrom typing import Callable
9*523fa7a6SAndroid Build Coastguard Worker
10*523fa7a6SAndroid Build Coastguard Workerimport torch
11*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects.backend._ops import (
12*523fa7a6SAndroid Build Coastguard Worker    _BACKEND_OP_LIB,
13*523fa7a6SAndroid Build Coastguard Worker    BackendOpOverloadPacket,
14*523fa7a6SAndroid Build Coastguard Worker)
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects.edge._ops import EdgeOpOverloadPacket
16*523fa7a6SAndroid Build Coastguard Workerfrom torch._C import DispatchKey  # @manual
17*523fa7a6SAndroid Build Coastguard Workerfrom torch.library import Library
18*523fa7a6SAndroid Build Coastguard Workerfrom torchgen.model import FunctionSchema
19*523fa7a6SAndroid Build Coastguard Worker
# Dialect name -> overload-packet class used to wrap ops looked up through
# ``exir.ops.<dialect>.<namespace>``. Dialects without an entry here cannot
# resolve ops (the lookup in _OpNamespace.__getattr__ raises KeyError).
_OPOVERLOAD_PACKET_CLS_MAPPING = dict(
    edge=EdgeOpOverloadPacket,
    backend=BackendOpOverloadPacket,
)
24*523fa7a6SAndroid Build Coastguard Worker
25*523fa7a6SAndroid Build Coastguard Worker
def bind_pattern_to_op(library: Library, schema_or_name: str):
    """Bind a pattern of ops to a backend op.

    A backend op should only appear when a user wants to replace a pattern of
    nodes with a custom op. The kernel registered for that op determines its
    decomposition behavior:

    *   If the backend op is registered with a CompositeExplicitAutograd (or
        Meta) kernel, then once the graph is lowered (meaning the pass that
        replaces the pattern with the op has run) the op sticks in the graph,
        and the original graph is not recovered even after retracing.
    *   Otherwise, the backend op should support retracing and be able to
        "promote" back to the original graph through retracing.

    This decorator handles that complexity for users: apply it to the pattern
    function and the decision is made for them. Concretely, if the op has no
    CompositeExplicitAutograd, CompositeImplicitAutograd or Meta kernel yet,
    the decorated function is registered as its CompositeImplicitAutograd
    kernel. In all cases the function is stored as the backend op overload's
    ``_equivalent_callable``.

    Args:
        library (Library): torch library. Its namespace is recorded as a
            backend op namespace.
        schema_or_name (str): schema string, e.g.
            "add.int(SymInt a, SymInt b) -> SymInt", or a (optionally
            namespace-qualified) op name such as "ns::my_op.overload". A
            schema string is defined in ``library``; a bare name is only
            looked up, not defined.

    Returns:
        A decorator that performs the registration and returns the decorated
        function unchanged.
    """

    def wrapper(f: Callable):
        if library.ns not in _BACKEND_OP_LIB:
            _BACKEND_OP_LIB.append(library.ns)
        # Any "ns::" prefix is discarded; the op is always resolved in library.ns.
        no_namespace = schema_or_name.split("::")[-1]
        try:
            # Succeeds only for a full schema string; a bare op name is
            # expected to fail parsing with AssertionError.
            func = FunctionSchema.parse(no_namespace)
        except AssertionError:
            func = None

        if func is not None:
            # Use the full base operator name. `.base` alone strips the inplace
            # "_" suffix, the "_functional" suffix and dunder wrapping, so the
            # lookups below would target a different op than the one defined.
            name = str(func.name.name)
            overload_name = func.name.overload_name
            library.define(no_namespace)
        elif "." in no_namespace:
            name, overload_name = no_namespace.split(".")
        else:
            name, overload_name = no_namespace, None

        # Name used for library.impl: "op" or "op.overload".
        opname = name + ("." + overload_name if overload_name else "")
        # Attribute name of the overload on torch.ops / exir ops packets.
        overload_name = overload_name if overload_name else "default"
        torch_op = getattr(getattr(getattr(torch.ops, library.ns), name), overload_name)
        # we can't have both CompositeExplicitAutograd and CompositeImplicitAutograd kernel,
        # we can't have both Meta and CompositeImplicitAutograd kernel either,
        # so only register `f` as the CompositeImplicitAutograd kernel when
        # none of these exist yet.
        keys = [
            DispatchKey.CompositeExplicitAutograd,
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.Meta,
        ]
        if not any(torch_op.has_kernel_for_dispatch_key(k) for k in keys):
            library.impl(opname, f, "CompositeImplicitAutograd")
        op = getattr(getattr(getattr(ops.backend, library.ns), name), overload_name)
        op._equivalent_callable = f
        return f

    return wrapper
72*523fa7a6SAndroid Build Coastguard Worker
73*523fa7a6SAndroid Build Coastguard Worker
class _OpNamespace(types.ModuleType):
    """
    EXIR dialect op namespace, i.e. the object behind
    ``exir.ops.<dialect>.<name>``.

    It wraps the PyTorch dispatcher namespace ``torch.ops.<name>``. The first
    access to an op name fetches the packet from that namespace, wraps it in
    the dialect's overload-packet class, and caches the wrapper on this
    object. Iterating the namespace yields the op names resolved so far.
    """

    def __init__(self, dialect, name):
        super().__init__(f"exir.ops.{dialect}.{name}")
        self._dialect = dialect
        # Backend namespaces must first be registered (see _BACKEND_OP_LIB).
        unregistered_backend = dialect == "backend" and name not in _BACKEND_OP_LIB
        if unregistered_backend:
            raise RuntimeError(f"{name} op library does not belong to backend ops.")
        self._name = name
        # Op names resolved so far, in first-access order.
        self._dir = []
        self._op_namespace = getattr(torch.ops, name)

    def __iter__(self):
        return iter(self._dir)

    def __getattr__(self, op_name):
        # "__file__" is not an op name; report a fixed module path instead.
        if op_name == "__file__":
            return "exir.ops"

        if op_name in self.__dict__:
            return getattr(self, op_name)

        try:
            torch_packet = getattr(self._op_namespace, op_name)
        except AttributeError as err:
            # Keep this an AttributeError so getattr(obj, key, default) works
            # (this is called by TorchScript with __origin__).
            raise AttributeError(
                f"'_OpNamespace' '{self._dialect}.{self._name}' object has no attribute '{op_name}'"
            ) from err

        packet_cls = _OPOVERLOAD_PACKET_CLS_MAPPING[self._dialect]
        wrapped_packet = packet_cls(
            f"{self._name}::{op_name}",
            op_name,
            parent_overload_packet=torch_packet,
        )
        wrapped_packet.__module__ = f"{self.__module__}.{self._name}"
        # Store on the instance so every later access of this op returns the
        # same packet object without re-entering __getattr__.
        setattr(self, op_name, wrapped_packet)
        self._dir.append(op_name)
        return wrapped_packet
120*523fa7a6SAndroid Build Coastguard Worker
121*523fa7a6SAndroid Build Coastguard Worker
class _DialectNamespace(types.ModuleType):
    """
    Dialect namespace. Currently the dialects are:
    - ATen Dialect: core ATen ops and overloads, see torch._ops._OpNamespace
    - Edge Dialect: ATen ops with explicit Tensor dtype
    - Backend Dialect: backend ops only meaningful to the backend we are lowering into
    - Execution Dialect: memory planning ready, all out-variants

    Op namespaces under a dialect (``exir.ops.<dialect>.<namespace>``) are
    created lazily on first attribute access and cached on this object.
    Iterating the dialect yields the namespace names created so far.
    """

    def __init__(self, dialect_name):
        super().__init__("exir.ops" + "." + dialect_name)
        self._dialect_name = dialect_name
        # Names of the op namespaces created so far, in creation order.
        self._dir = []

    def __iter__(self):
        # Expose the namespaces created so far, in creation order.
        return iter(self._dir)

    def __getattr__(self, name):
        if name in self.__dict__:
            return getattr(self, name)
        # Here we are creating `exir.ops.<dialect_ns>.<my_namespace>`
        namespace = _OpNamespace(self._dialect_name, name)
        # Cache it so later accesses return the same namespace object.
        setattr(self, name, namespace)
        self._dir.append(name)
        return namespace
144*523fa7a6SAndroid Build Coastguard Worker
145*523fa7a6SAndroid Build Coastguard Worker
class _Ops(types.ModuleType):
    """
    Root ``exir.ops`` object. Each attribute access for an unknown name
    creates a dialect namespace, caches it, and records its name. Iterating
    yields the dialect names created so far.
    """

    __file__ = "_ops.py"

    def __init__(self):
        super().__init__("exir.ops")
        # Dialect names created so far, in creation order.
        self._dir = []

    def __iter__(self):
        return iter(self._dir)

    def __getattr__(self, name):
        if name in self.__dict__:
            return getattr(self, name)
        dialect_ns = _DialectNamespace(name)
        # Once stored on the instance, normal attribute lookup finds it and
        # __getattr__ is no longer reached for this name.
        setattr(self, name, dialect_ns)
        self._dir.append(name)
        return dialect_ns
163*523fa7a6SAndroid Build Coastguard Worker
164*523fa7a6SAndroid Build Coastguard Worker
# Module-level root of the ops namespace tree; dialects and op namespaces
# are reached as ``ops.<dialect>.<namespace>.<op>``.
ops: _Ops = _Ops()
167