# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import types
from typing import Callable

import torch
from executorch.exir.dialects.backend._ops import (
    _BACKEND_OP_LIB,
    BackendOpOverloadPacket,
)
from executorch.exir.dialects.edge._ops import EdgeOpOverloadPacket
from torch._C import DispatchKey  # @manual
from torch.library import Library
from torchgen.model import FunctionSchema

# Maps a dialect name to the OpOverloadPacket subclass that
# `_OpNamespace.__getattr__` uses to wrap ops of that dialect.
_OPOVERLOAD_PACKET_CLS_MAPPING = {
    "edge": EdgeOpOverloadPacket,
    "backend": BackendOpOverloadPacket,
}


def bind_pattern_to_op(library: Library, schema_or_name: str):
    """Bind a pattern of ops to a backend op.

    A backend op should only appear when a user wants to replace a pattern of
    nodes with a custom op. The kernel registered to that op determines its
    decomposition behavior:

    * If the backend op has a CompositeExplicitAutograd (or Meta) kernel, then
      once the graph is lowered (i.e. the pass that replaces the pattern with
      the op has run), the op stays in the graph and the original pattern
      cannot be recovered, even by retracing.
    * Otherwise, the backend op supports retracing and can be "promoted" back
      to the original pattern by retracing.

    This decorator handles that decision for users: they only need to apply it
    to the function implementing the pattern.

    Args:
        library (Library): torch library whose namespace the op belongs to.
        schema_or_name (str): either a schema string, e.g.
            "add.int(SymInt a, SymInt b) -> SymInt", in which case the op is
            defined in ``library``; or an op name such as "my_op" or
            "my_op.overload", in which case the op is not defined here and
            must already exist in ``library``'s namespace. Any "ns::" prefix
            is stripped; ``library.ns`` is always the namespace used.

    Returns:
        A decorator that registers the decorated callable ``f`` with the op
        and returns ``f`` unchanged.
    """

    def wrapper(f: Callable):
        # Record the namespace as a backend op library; `_OpNamespace.__init__`
        # refuses to build `ops.backend.<ns>` for namespaces not listed here.
        if library.ns not in _BACKEND_OP_LIB:
            _BACKEND_OP_LIB.append(library.ns)
        no_namespace = schema_or_name.split("::")[-1]
        try:
            # FunctionSchema.parse signals "not a schema" with AssertionError.
            # Only the parse is guarded so that failures from `library.define`
            # are not mistaken for "this is an op name".
            func = FunctionSchema.parse(no_namespace)
        except AssertionError:
            # Not a schema: treat the string as "name" or "name.overload".
            if "." in no_namespace:
                name, overload_name = no_namespace.split(".")
            else:
                name, overload_name = no_namespace, None
        else:
            name, overload_name = func.name.name.base, func.name.overload_name
            library.define(no_namespace)

        # Name used for `library.impl`: "name" or "name.overload".
        opname = f"{name}.{overload_name}" if overload_name else name
        # Python attribute used to reach the overload; the unnamed overload is
        # exposed as `.default`.
        overload_name = overload_name or "default"
        torch_op = getattr(getattr(getattr(torch.ops, library.ns), name), overload_name)
        # We can't have both CompositeExplicitAutograd and CompositeImplicitAutograd
        # kernels, nor both Meta and CompositeImplicitAutograd kernels. So `f` is
        # registered as the CompositeImplicitAutograd kernel (making the op
        # retraceable back into the pattern) only if none of these exist yet.
        keys = [
            DispatchKey.CompositeExplicitAutograd,
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.Meta,
        ]
        if not any(torch_op.has_kernel_for_dispatch_key(k) for k in keys):
            library.impl(opname, f, "CompositeImplicitAutograd")
        # `ops` is the module-level namespace defined at the bottom of this
        # file; it is resolved when the decorator runs, not at import time.
        op = getattr(getattr(getattr(ops.backend, library.ns), name), overload_name)
        op._equivalent_callable = f
        return f

    return wrapper


class _OpNamespace(types.ModuleType):
    """
    EXIR Dialect op namespace object, i.e. ``exir.ops.<dialect>.<name>``.

    Contains ops and overloads registered into the PyTorch dispatcher. Each op
    is wrapped in the dialect's OpOverloadPacket class on first access and
    cached as an attribute.
    """

    def __init__(self, dialect, name):
        super().__init__(f"exir.ops.{dialect}.{name}")
        self._dialect = dialect
        # Backend namespaces must have been registered in _BACKEND_OP_LIB
        # (e.g. by `bind_pattern_to_op`) before they can be accessed.
        if dialect == "backend" and name not in _BACKEND_OP_LIB:
            raise RuntimeError(f"{name} op library does not belong to backend ops.")
        self._name = name
        # Names of ops materialized so far, in access order (see __iter__).
        self._dir = []
        self._op_namespace = getattr(torch.ops, name)

    def __iter__(self):
        return iter(self._dir)

    def __getattr__(self, op_name):
        # Only called when normal attribute lookup fails, so ops cached via
        # setattr below are returned without re-entering this method.

        # It is not a valid op_name when __file__ is passed in
        if op_name == "__file__":
            return "exir.ops"

        try:
            parent_packet = getattr(self._op_namespace, op_name)
        except AttributeError as e:
            # Turn this into AttributeError so getattr(obj, key, default)
            # works (this is called by TorchScript with __origin__)
            raise AttributeError(
                f"'_OpNamespace' '{self._dialect}.{self._name}' object has no attribute '{op_name}'"
            ) from e
        qualified_op_name = f"{self._name}::{op_name}"
        opoverload_packet_cls = _OPOVERLOAD_PACKET_CLS_MAPPING[self._dialect]
        opoverloadpacket = opoverload_packet_cls(
            qualified_op_name,
            op_name,
            parent_overload_packet=parent_packet,
        )
        opoverloadpacket.__module__ = self.__module__ + "." + self._name
        # cache the opoverloadpacket to ensure that each op corresponds to
        # a unique OpOverloadPacket object
        setattr(self, op_name, opoverloadpacket)
        self._dir.append(op_name)
        return opoverloadpacket


class _DialectNamespace(types.ModuleType):
    """
    Dialect namespace, i.e. ``exir.ops.<dialect>``. Currently the dialects are:
    - ATen Dialect: core ATen ops and overloads, see torch._ops._OpNamespace
    - Edge Dialect: ATen ops with explicit Tensor dtype
    - Backend Dialect: backend ops only meaningful to the backend we are lowering into
    - Execution Dialect: memory planning ready, all out-variants
    """

    def __init__(self, dialect_name):
        super().__init__("exir.ops" + "." + dialect_name)
        self._dialect_name = dialect_name
        # Names of op namespaces materialized so far, in access order.
        self._dir = []

    def __iter__(self):
        # Mirrors `_OpNamespace.__iter__` and `_Ops.__iter__`.
        return iter(self._dir)

    def __getattr__(self, name):
        # Only called when normal lookup fails; namespaces created here are
        # cached via setattr, so later accesses bypass this method.
        # Here we are creating `exir.ops.<dialect_ns>.<my_namespace>`
        namespace = _OpNamespace(self._dialect_name, name)
        setattr(self, name, namespace)
        self._dir.append(name)
        return namespace


class _Ops(types.ModuleType):
    """Root ``exir.ops`` namespace; dialect namespaces are created lazily."""

    __file__ = "_ops.py"

    def __init__(self):
        super().__init__("exir.ops")
        # Names of dialects materialized so far, in access order.
        self._dir = []

    def __getattr__(self, name):
        # Only called when normal lookup fails; dialects created here are
        # cached via setattr, so later accesses bypass this method.
        dialect = _DialectNamespace(name)
        setattr(self, name, dialect)
        self._dir.append(name)
        return dialect

    def __iter__(self):
        return iter(self._dir)


# The ops "namespace"
ops = _Ops()