# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from torch._C import DispatchKey  # @manual

# NOTE(review): presumably the op library namespaces that are treated as
# backend ops; this list is not read anywhere in this file -- confirm
# against its consumers.
_BACKEND_OP_LIB = [
    "executorch_prim",
    "quantized_decomposed",
    "DO_NOT_USE_TEST_ONLY",
]


class BackendOpOverload(EdgeOpOverload):
    """OpOverload for backend ops.

    A Backend operator is a custom op that doesn't show up in ATen dialect.
    Therefore it must be replacing an existing node or a pattern of nodes in
    Edge dialect. This data structure makes sure that after lowering (part of)
    Edge dialect to backend ops, the whole graph can still be captured
    properly.

    Difference to delegate:
    1. delegate result is still a module (a target of call_module, at least
       for now) while backend op is an operator (a target of call_function).
    2. backend op is stateless while delegation doesn't have to
    3. backend op stays in executor standard runtime but delegation doesn't
       have to

    Examples for backend ops including fused ops for a specific backend,
    ExecuTorch prim ops to handle symbolic shape.

    Note that the assumption here is that the backend op and the original
    callable / equivalent callable is 1 - 1 mapping.

    BackendOpOverload makes sure:
    1. The backend op contains at least one of a Meta, a
       CompositeExplicitAutograd or a CompositeImplicitAutograd kernel
       (asserted in ``__init__``).
    2. It also holds a reference to the original node/pattern it replaces
       (``_equivalent_callable``, initialised to ``None`` here).
    Example:

    add -> relu
            |
            v
    add_relu(only works on dsp): hold reference to add -> relu pattern, for
    re-capturing purpose.

    Retrace example:

    A very common practice in delegate, is that the module needs to be lowered
    to a backend, then the lowered module needs to be composed with original
    nn.Module and retrace.

        LoweredModule l_of_m = to_backend(g_of_m.to_edge(), ...)
        Module main(l_of_m)
        export(main, inputs)

    """

    def __init__(
        self,
        op_overload: EdgeOpOverload,
    ):
        """Wrap an existing edge op overload as a backend op overload.

        Args:
            op_overload: The edge overload whose ``_op`` and ``_schema`` are
                forwarded to the ``EdgeOpOverload`` constructor.

        Raises:
            AssertionError: If the underlying op has none of the Meta,
                CompositeExplicitAutograd or CompositeImplicitAutograd
                kernels.
        """
        # Zero-argument super(): `super(self.__class__, self)` would resolve
        # back to this very __init__ when called on a subclass instance and
        # recurse forever.
        super().__init__(
            op_overload._op,
            op_overload._schema,
        )
        self._equivalent_callable = None

        # Record which of the accepted dispatch keys the op has a kernel for.
        has_kernel = self._op.has_kernel_for_dispatch_key
        self._has_meta_kernel = has_kernel(DispatchKey.Meta)
        self._has_composite_explicit_autograd_kernel = has_kernel(
            DispatchKey.CompositeExplicitAutograd
        )
        self._has_composite_implicit_autograd_kernel = has_kernel(
            DispatchKey.CompositeImplicitAutograd
        )
        assert (
            self._has_meta_kernel
            or self._has_composite_explicit_autograd_kernel
            or self._has_composite_implicit_autograd_kernel
        ), "A backend op must contain either CompositeExplicitAutograd or Meta or CompositeImplicitAutograd kernel."


class BackendOpOverloadPacket(EdgeOpOverloadPacket):
    """OpOverloadPacket for backend ops.

    Wraps EdgeOpOverloadPacket and overrides __getattr__ to return OpOverload
    for backend ops.
    """

    def __init__(
        self,
        qualified_op_name: str,
        op_name: str,
        parent_overload_packet: torch._ops.OpOverloadPacket,
    ):
        """Forward all arguments unchanged to ``EdgeOpOverloadPacket``."""
        # Zero-argument super() keeps this safe under subclassing and matches
        # the style already used in __getattr__ below.
        super().__init__(qualified_op_name, op_name, parent_overload_packet)

    def __repr__(self):
        """Return a debug string with the op and parent op names, dotted."""
        return "<BackendOpOverloadPacket(op='{}', parent_op='{}')>".format(
            self._qualified_op_name.replace("::", "."),
            self._parent_qualified_op_name.replace("::", "."),
        )

    def __hash__(self):
        """Hash by the underlying op."""
        return hash(self._op)

    def __str__(self):
        """Return the qualified op name with ``::`` replaced by ``.``."""
        return "{}".format(self._qualified_op_name.replace("::", "."))

    @property
    def op(self):
        """The underlying op held in ``_op``."""
        return self._op

    def __getattr__(self, key):
        """Resolve ``key`` through the parent packet, wrapping edge overloads.

        If the parent lookup yields an ``EdgeOpOverload`` it is wrapped in a
        ``BackendOpOverload``, stored on this packet under ``key`` and
        returned; any other result is returned unchanged.

        Raises:
            AttributeError: If an ``AttributeError`` is raised while
                resolving or wrapping ``key``; the original error is chained
                and included in the message.
        """
        try:
            # get edge op, set it as attribute. Note that this way we don't have `_original_pattern`.
            result = super().__getattr__(key)
            if isinstance(result, EdgeOpOverload):
                backend_op = BackendOpOverload(result)
                # Store on the instance so later lookups of `key` find it
                # through normal attribute access instead of __getattr__.
                setattr(self, key, backend_op)
                return backend_op
            else:
                return result
        except AttributeError as e:
            raise AttributeError(
                "The underlying op of '{}' has no overload name '{}'. Original error message: \n {}".format(
                    self, key, e
                )
            ) from e