1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import types 8from typing import Callable 9 10import torch 11from executorch.exir.dialects.backend._ops import ( 12 _BACKEND_OP_LIB, 13 BackendOpOverloadPacket, 14) 15from executorch.exir.dialects.edge._ops import EdgeOpOverloadPacket 16from torch._C import DispatchKey # @manual 17from torch.library import Library 18from torchgen.model import FunctionSchema 19 20_OPOVERLOAD_PACKET_CLS_MAPPING = { 21 "edge": EdgeOpOverloadPacket, 22 "backend": BackendOpOverloadPacket, 23} 24 25 26def bind_pattern_to_op(library: Library, schema_or_name: str): 27 """Bind a pattern of ops to a backend op. A backend op should only appear when a user wants to replace a pattern of nodes to a custom op. 28 On this front, the kernel being registered to it determines the decomposing behavior. 29 30 * If the backend op is registered with an CompositeExplicitAutograd (or Meta) kernel, once the graph is lowered (meaning the pass 31 of replacing a pattern to an op is executed) it will stick in the graph and we won't get the original graph even retrace. 32 * Otherwise, the backend op should be able to support retracing and be able to "promote" back to the original graph through retracing. 33 34 This macro is aiming to handle this complexity for users and they just need to use this macro on the pattern and we can make a decision for them. 35 36 Args: 37 library (Library): torch library 38 schema_or_name (str): schema string, e.g., "add.int(SymInt a, SymInt b) -> SymInt", or a qualified op name 39 """ 40 41 def wrapper(f: Callable): 42 if library.ns not in _BACKEND_OP_LIB: 43 _BACKEND_OP_LIB.append(library.ns) 44 no_namespace = schema_or_name.split("::")[-1] 45 try: 46 # can parse it into a FunctionSchema 47 func = FunctionSchema.parse(no_namespace) 48 name, overload_name = func.name.name.base, func.name.overload_name 49 library.define(no_namespace) 50 except AssertionError: 51 if "." in no_namespace: 52 name, overload_name = no_namespace.split(".") 53 else: 54 name, overload_name = no_namespace, None 55 opname = name + ("." + overload_name if overload_name else "") 56 overload_name = overload_name if overload_name else "default" 57 torch_op = getattr(getattr(getattr(torch.ops, library.ns), name), overload_name) 58 # we can't have both CompositeExplicitAutograd and CompositeImplicitAutograd kernel, 59 # we can't have both Meta and CompositeImplicitAutograd kernel either. 60 keys = [ 61 DispatchKey.CompositeExplicitAutograd, 62 DispatchKey.CompositeImplicitAutograd, 63 DispatchKey.Meta, 64 ] 65 if not any(torch_op.has_kernel_for_dispatch_key(k) for k in keys): 66 library.impl(opname, f, "CompositeImplicitAutograd") 67 op = getattr(getattr(getattr(ops.backend, library.ns), name), overload_name) 68 op._equivalent_callable = f 69 return f 70 71 return wrapper 72 73 74class _OpNamespace(types.ModuleType): 75 """ 76 EXIR Dialect op namespace object. Contains ops and overloads registered into PyTorch dispatcher. 77 """ 78 79 def __init__(self, dialect, name): 80 super().__init__(f"exir.ops.{dialect}.{name}") 81 self._dialect = dialect 82 if dialect == "backend" and name not in _BACKEND_OP_LIB: 83 raise RuntimeError(f"{name} op library does not belong to backend ops.") 84 self._name = name 85 self._dir = [] 86 self._op_namespace = getattr(torch.ops, name) 87 88 def __iter__(self): 89 return iter(self._dir) 90 91 def __getattr__(self, op_name): 92 # It is not a valid op_name when __file__ is passed in 93 if op_name == "__file__": 94 return "exir.ops" 95 96 if op_name in self.__dict__: 97 return getattr(self, op_name) 98 99 try: 100 parent_packet = getattr(self._op_namespace, op_name) 101 except AttributeError as e: 102 # Turn this into AttributeError so getattr(obj, key, default) 103 # works (this is called by TorchScript with __origin__) 104 raise AttributeError( 105 f"'_OpNamespace' '{self._dialect}.{self._name}' object has no attribute '{op_name}'" 106 ) from e 107 qualified_op_name = f"{self._name}::{op_name}" 108 opoverload_packet_cls = _OPOVERLOAD_PACKET_CLS_MAPPING[self._dialect] 109 opoverloadpacket = opoverload_packet_cls( 110 qualified_op_name, 111 op_name, 112 parent_overload_packet=parent_packet, 113 ) 114 opoverloadpacket.__module__ = self.__module__ + "." + self._name 115 # cache the opoverloadpacket to ensure that each op corresponds to 116 # a unique OpOverloadPacket object 117 setattr(self, op_name, opoverloadpacket) 118 self._dir.append(op_name) 119 return opoverloadpacket 120 121 122class _DialectNamespace(types.ModuleType): 123 """ 124 Dialect namespace. Currently the dialects are: 125 - ATen Dialect: core ATen ops and overloads, see torch._ops._OpNamespace 126 - Edge Dialect: ATen ops with explicit Tensor dtype 127 - Backend Dialect: backend ops only meaningful to the backend we are lowering into 128 - Execution Dialect: memory planning ready, all out-variants 129 """ 130 131 def __init__(self, dialect_name): 132 super().__init__("exir.ops" + "." + dialect_name) 133 self._dialect_name = dialect_name 134 self._dir = [] 135 136 def __getattr__(self, name): 137 if name in self.__dict__: 138 return getattr(self, name) 139 # Here we are creating `exir.ops.<dialect_ns>.<my_namespace>` 140 namespace = _OpNamespace(self._dialect_name, name) 141 setattr(self, name, namespace) 142 self._dir.append(name) 143 return namespace 144 145 146class _Ops(types.ModuleType): 147 __file__ = "_ops.py" 148 149 def __init__(self): 150 super().__init__("exir.ops") 151 self._dir = [] 152 153 def __getattr__(self, name): 154 if name in self.__dict__: 155 return getattr(self, name) 156 dialect = _DialectNamespace(name) 157 setattr(self, name, dialect) 158 self._dir.append(name) 159 return dialect 160 161 def __iter__(self): 162 return iter(self._dir) 163 164 165# The ops "namespace" 166ops = _Ops() 167