xref: /aosp_15_r20/external/executorch/exir/dialects/_ops.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import types
8from typing import Callable
9
10import torch
11from executorch.exir.dialects.backend._ops import (
12    _BACKEND_OP_LIB,
13    BackendOpOverloadPacket,
14)
15from executorch.exir.dialects.edge._ops import EdgeOpOverloadPacket
16from torch._C import DispatchKey  # @manual
17from torch.library import Library
18from torchgen.model import FunctionSchema
19
# Maps a dialect name to the OpOverloadPacket class used to wrap ops of that
# dialect; looked up by `_OpNamespace.__getattr__` using the namespace's dialect.
_OPOVERLOAD_PACKET_CLS_MAPPING = {
    "edge": EdgeOpOverloadPacket,
    "backend": BackendOpOverloadPacket,
}
24
25
def bind_pattern_to_op(library: Library, schema_or_name: str):
    """Build a decorator that binds a pattern callable to a backend op.

    A backend op is meant to appear only when a user replaces a pattern of
    nodes with a custom op. Which kernel is registered for that op decides how
    it decomposes:

    *   With a CompositeExplicitAutograd (or Meta) kernel, the op stays in the
        graph once the pattern-replacement pass has run; retracing does not
        bring back the original graph.
    *   Otherwise the op should support retracing and be "promoted" back to
        the original graph when retraced.

    The returned decorator takes care of this: when the op has no
    CompositeExplicitAutograd, CompositeImplicitAutograd or Meta kernel, the
    decorated callable is registered as its CompositeImplicitAutograd kernel.
    In every case the callable is stored on the matching `ops.backend` op as
    `_equivalent_callable`, and the callable itself is returned unchanged.

    Args:
        library (Library): torch library the op belongs to.
        schema_or_name (str): schema string, e.g.,
            "add.int(SymInt a, SymInt b) -> SymInt", or a qualified op name.
            Anything up to and including a "::" separator is ignored.

    Returns:
        A decorator accepting the pattern callable.
    """

    def wrapper(f: Callable):
        lib_ns = library.ns
        # Remember this library namespace as one that hosts backend ops.
        if lib_ns not in _BACKEND_OP_LIB:
            _BACKEND_OP_LIB.append(lib_ns)

        # Keep only the part after the namespace qualifier.
        unqualified = schema_or_name.split("::")[-1]
        try:
            # A full schema parses cleanly and gets defined on the library.
            parsed = FunctionSchema.parse(unqualified)
            base, overload = parsed.name.name.base, parsed.name.overload_name
            library.define(unqualified)
        except AssertionError:
            # Not a schema: treat the text as "<name>" or "<name>.<overload>".
            if "." not in unqualified:
                base, overload = unqualified, None
            else:
                base, overload = unqualified.split(".")

        if overload:
            impl_name = base + "." + overload
        else:
            impl_name = base
            overload = "default"

        torch_namespace = getattr(torch.ops, lib_ns)
        torch_packet = getattr(torch_namespace, base)
        torch_op = getattr(torch_packet, overload)

        # we can't have both CompositeExplicitAutograd and CompositeImplicitAutograd kernel,
        # we can't have both Meta and CompositeImplicitAutograd kernel either.
        conflicting_keys = (
            DispatchKey.CompositeExplicitAutograd,
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.Meta,
        )
        kernel_present = any(
            torch_op.has_kernel_for_dispatch_key(key) for key in conflicting_keys
        )
        if not kernel_present:
            library.impl(impl_name, f, "CompositeImplicitAutograd")

        backend_namespace = getattr(ops.backend, lib_ns)
        backend_packet = getattr(backend_namespace, base)
        backend_op = getattr(backend_packet, overload)
        backend_op._equivalent_callable = f
        return f

    return wrapper
72
73
class _OpNamespace(types.ModuleType):
    """
    Op namespace of an EXIR dialect, i.e. `exir.ops.<dialect>.<name>`.

    Attribute access lazily wraps the op packet of the same name found in
    `torch.ops.<name>` with the dialect's OpOverloadPacket class and caches it.
    """

    def __init__(self, dialect, name):
        super().__init__(f"exir.ops.{dialect}.{name}")
        self._dialect = dialect
        # Only libraries recorded in _BACKEND_OP_LIB may be used as backend ops.
        if dialect == "backend" and name not in _BACKEND_OP_LIB:
            raise RuntimeError(f"{name} op library does not belong to backend ops.")
        self._name = name
        # Names of the ops resolved so far, in resolution order.
        self._dir = []
        # The torch.ops namespace that supplies the parent overload packets.
        self._op_namespace = getattr(torch.ops, name)

    def __iter__(self):
        """Iterate over the names of the ops resolved so far."""
        return iter(self._dir)

    def __getattr__(self, op_name):
        """Resolve `op_name` to this dialect's overload packet and cache it.

        Raises:
            AttributeError: if the underlying torch.ops namespace has no such op.
        """
        # It is not a valid op_name when __file__ is passed in
        if op_name == "__file__":
            return "exir.ops"

        if op_name in self.__dict__:
            return getattr(self, op_name)

        try:
            torch_packet = getattr(self._op_namespace, op_name)
        except AttributeError as err:
            # Turn this into AttributeError so getattr(obj, key, default)
            # works (this is called by TorchScript with __origin__)
            raise AttributeError(
                f"'_OpNamespace' '{self._dialect}.{self._name}' object has no attribute '{op_name}'"
            ) from err

        qualified_op_name = f"{self._name}::{op_name}"
        packet_cls = _OPOVERLOAD_PACKET_CLS_MAPPING[self._dialect]
        packet = packet_cls(
            qualified_op_name,
            op_name,
            parent_overload_packet=torch_packet,
        )
        packet.__module__ = f"{self.__module__}.{self._name}"
        # Cache the packet so that every op maps to exactly one
        # OpOverloadPacket object; later lookups bypass __getattr__.
        setattr(self, op_name, packet)
        self._dir.append(op_name)
        return packet
120
121
class _DialectNamespace(types.ModuleType):
    """
    Dialect namespace. Currently the dialects are:
    - ATen Dialect: core ATen ops and overloads, see torch._ops._OpNamespace
    - Edge Dialect: ATen ops with explicit Tensor dtype
    - Backend Dialect: backend ops only meaningful to the backend we are lowering into
    - Execution Dialect: memory planning ready, all out-variants

    Accessing an attribute lazily creates (and caches) the `_OpNamespace` for
    the op library of that name; iterating yields the names created so far.
    """

    def __init__(self, dialect_name):
        super().__init__("exir.ops" + "." + dialect_name)
        self._dialect_name = dialect_name
        # Names of the op namespaces created so far, in creation order.
        self._dir = []

    def __iter__(self):
        """Iterate over the names of the op namespaces created so far."""
        return iter(self._dir)

    def __getattr__(self, name):
        """Create, cache and return the op namespace `name` for this dialect."""
        if name in self.__dict__:
            return getattr(self, name)
        # Here we are creating `exir.ops.<dialect_ns>.<my_namespace>`
        namespace = _OpNamespace(self._dialect_name, name)
        # Cache on the instance so later lookups bypass __getattr__.
        setattr(self, name, namespace)
        self._dir.append(name)
        return namespace
144
145
class _Ops(types.ModuleType):
    """
    Root `exir.ops` namespace.

    Accessing an attribute lazily creates (and caches) the `_DialectNamespace`
    of that name; iterating yields the dialect names created so far.
    """

    __file__ = "_ops.py"

    def __init__(self):
        super().__init__("exir.ops")
        # Names of the dialect namespaces created so far, in creation order.
        self._dir = []

    def __iter__(self):
        """Iterate over the names of the dialects created so far."""
        return iter(self._dir)

    def __getattr__(self, name):
        """Create, cache and return the dialect namespace called `name`."""
        if name in self.__dict__:
            return getattr(self, name)
        dialect_namespace = _DialectNamespace(name)
        # Cache on the instance so later lookups bypass __getattr__.
        setattr(self, name, dialect_namespace)
        self._dir.append(name)
        return dialect_namespace
163
164
# The ops "namespace": the module-level `_Ops` instance. Attribute access on it
# lazily creates dialect namespaces (e.g. `ops.backend`, used by
# `bind_pattern_to_op`).
ops = _Ops()
167