# mypy: allow-untyped-defs
import abc
import contextlib
import ctypes
import importlib
import inspect
import sys
import types
from typing import Any, Callable, Dict, List, Set, Type, Union

import torch
import torch.utils._pytree as pytree
from torch import _utils_internal
from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey
from torch._functorch.pyfunctorch import dispatch_functorch
from torch.utils._python_dispatch import TorchDispatchMode


# Query `hasattr` only once.
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")


@contextlib.contextmanager
def dl_open_guard():
    """
    Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
    shared library to load custom operators.
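
    Example (an illustrative sketch; the library path is a placeholder)::

        import ctypes

        with dl_open_guard():
            # RTLD_GLOBAL makes the library's symbols available for resolving
            # symbols in shared libraries that are loaded afterwards.
            ctypes.CDLL("/path/to/libcustom_ops.so")

    On platforms where ``sys`` lacks ``getdlopenflags``/``setdlopenflags``
    (e.g. Windows), the guard is a no-op.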
    """
    if not _SET_GLOBAL_FLAGS:
        yield
        return
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
    try:
        yield
    finally:
        sys.setdlopenflags(old_flags)


class OperatorBase:
    """
    Base class for OpOverload (which represents C++ ATen operators) and HigherOrderOperator
    (which represents Python-only operators that are unrepresentable in TorchScript).
    """

    def __init__(self):
        # The dispatch cache precomputes a mapping from the dispatch key that
        # the dispatcher wants to dispatch to, to an actual implementation of
        # that dispatch key. Confusingly, the actual implementation could
        # *also* be a dispatch key, but in that case it refers to the C++
        # kernel that was registered to some dispatch key.
        # Alias keys are permitted in the latter (the cached values) but not the
        # former (the cache keys); for example, you might look up the entry for
        # AutogradCPU, and it maps you to the Autograd key for the generic
        # autograd kernel that works for all devices. Since this is the Python
        # dispatcher, the value can also be an arbitrary Python callable to call
        # instead. This handler gets precisely the args/kwargs that the operator
        # was __call__'ed with.
        # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp
        # for use with OpOverload; cache lookup is done entirely from C++
        # for speed.
        # NB: HigherOrderOperator.dispatch also reads from and populates this
        # cache (except for the PreDispatch key).
        self._dispatch_cache: Dict[
            DispatchKey, Union[DispatchKey, Callable[..., Any]]
        ] = {}

        # This table allows you to override the behavior of a particular
        # dispatch key to call a custom Python function, rather than the
        # ordinary C++ configured behavior.
        # This is the raison d'etre of the Python dispatcher: to let you
        # program the dispatcher from Python in case you need something
        # unusual and don't want to clobber the existing registrations made
        # through the Python operator registration API.
        self.py_kernels: Dict[DispatchKey, Callable[..., Any]] = {}

        # This table allows you to override the behavior of a particular
        # operator for a particular TorchDispatchMode (or Tensor subclass).
        # In practice, we use this mostly for ProxyTorchDispatchMode. Modes
        # can be thought of as an open-world extension of dispatch keys, so
        # it makes sense that you should be able to register them the same
        # way you can register dispatch keys.
        self.python_key_table: Dict[
            Union[Type[TorchDispatchMode], Type[torch.Tensor]], Callable[..., Any]
        ] = {}

        # This table allows you to override the behavior of functorch
        # transformations.
        # NB: this currently only has an effect for HigherOrderOperator.
        self.functorch_table = {}

    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def has_kernel_for_dispatch_key(self, k):
        return k in self.py_kernels

    def has_kernel_for_any_dispatch_key(self, ks):
        for k in self.py_kernels:
            if not torch._C._dispatch_is_alias_key(k) and ks.has(k):
                return True
        return False

    def py_impl(self, k):
        def inner(fn):
            if inspect.isclass(k) and (
                issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
            ):
                assert k not in self.python_key_table
                # TODO(voz): Should we replace setting DispatchKey.Python entirely with setting mode keys?
                self.python_key_table[k] = fn
                self._dispatch_cache.clear()
                return fn

            if isinstance(k, torch._C._functorch.TransformType):
                assert k not in self.functorch_table
                self.functorch_table[k] = fn
                return fn

            assert isinstance(k, DispatchKey)
            assert (
                k != DispatchKey.Python
            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."

            if k in self.py_kernels:
                raise RuntimeError(
                    f"Trying to override a python impl for {k} on operator {self.name()}"
                )
            self.py_kernels[k] = fn
            self._dispatch_cache.clear()
            return fn

        return inner

    # Registers an implementation to all **3** variants of functionalization that we have:
    # - DispatchKey.Functionalize
    # - functorch.TransformType.Functionalize
    # - FunctionalTensorMode
    # Example (`op` is the operator instance being registered on):
    #   @op.py_functionalize_impl
    #   def functionalize_rule(ctx, inner_f, *args):
    #       args_unwrapped = ctx.unwrap_tensors(args)
    #       with ctx.redispatch_to_next():
    #           out = ctx.functionalize(inner_f)(*args_unwrapped)
    #           return ctx.wrap_tensors(out)
    def py_functionalize_impl(self, fn):
        from torch._subclasses.functional_tensor import (
            CppFunctionalizeAPI as _CppFunctionalizeAPI,
            FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
            PythonFunctionalizeAPI as _PythonFunctionalizeAPI,
        )

        # Construct our three flavors of functionalization,
        # each of which has slightly different wrap/unwrap/redispatch policies
        def functionalize_dk_fn(*args, **kwargs):
            return fn(_CppFunctionalizeAPI(), *args, **kwargs)

        def functionalize_dispatch_mode_fn(mode, *args, **kwargs):
            return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs)

        def functionalize_functorch_fn(interpreter, *args, **kwargs):
            return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)

        self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
        self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)(
            functionalize_dispatch_mode_fn
        )
        self.py_impl(torch._C._functorch.TransformType.Functionalize)(
            functionalize_functorch_fn
        )

        return fn

    def name(self):
        raise NotImplementedError


# Equivalent to computeDispatchTableEntryWithDebug in the C++ dispatcher
def resolve_key(op: OperatorBase, k: DispatchKey):  # type: ignore[valid-type]
    # 1. (Direct) operator registration
    if op.has_kernel_for_dispatch_key(k):
        return k
    # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
    cand = DispatchKey.CompositeExplicitAutogradNonFunctional
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # 2.2 Use CompositeExplicitAutograd kernel if available
    cand = DispatchKey.CompositeExplicitAutograd
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        return cand
    has_backend_kernel = op.has_kernel_for_any_dispatch_key(
        torch._C._dispatch_get_backend_keyset_from_autograd(k)
    ) or op.has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd)
    # 2.3 Use CompositeImplicitAutogradNestedTensor, then CompositeImplicitAutograd,
    # kernel if available
    cand = DispatchKey.CompositeImplicitAutogradNestedTensor
    if (
        (k != DispatchKey.Undefined and is_included_in_alias(k, cand))
        and op.has_kernel_for_dispatch_key(cand)
        and not has_backend_kernel
    ):
        return cand
    cand = DispatchKey.CompositeImplicitAutograd
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        if k == DispatchKey.AutogradOther and op.has_kernel_for_any_dispatch_key(
            torch._C._dispatch_autogradother_backends
        ):
            raise RuntimeError("ambiguous autogradother kernel")
        elif not has_backend_kernel:
            return cand
    # 2.4 For autograd backend keys, use kernel from DispatchKey::Autograd if available
    cand = DispatchKey.Autograd
    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # 2.5 Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
    cand = DispatchKey.FuncTorchBatchedDecomposition
    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # Backend fallback
    if torch._C._dispatch_has_backend_fallback(k):
        # The dispatch key itself will implicitly route to backend fallback.
        # This is probably not great for the pure Python implementation.
        return k
    raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")


_higher_order_ops: Dict[str, "HigherOrderOperator"] = {}

_HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [
    DispatchKey.PythonDispatcher,  # type: ignore[attr-defined]
    DispatchKey.PythonTLSSnapshot,  # type: ignore[attr-defined]
    DispatchKey.ADInplaceOrView,
    DispatchKey.BackendSelect,
    DispatchKey.AutocastCPU,  # type: ignore[attr-defined]
    DispatchKey.AutocastCUDA,  # type: ignore[attr-defined]
]


class HigherOrderOperator(OperatorBase, abc.ABC):
    # The HigherOrderOperator will appear as torch.ops.higher_order.{name}
    #
    # If you're creating a new HigherOrderOperator, please do not change the
    # default. Adding operators to the global torch.ops namespace is a bad
    # practice due to name collisions.
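    #
    # Example (an illustrative sketch only; `MyHop`, "my_hop", and the kernels
    # below are hypothetical names, not part of PyTorch). A subclass passes its
    # name to __init__, forwards __call__ to this base class, and registers
    # Python kernels with py_impl():
    #
    #   class MyHop(HigherOrderOperator):
    #       def __init__(self):
    #           super().__init__("my_hop")  # torch.ops.higher_order.my_hop
    #
    #       def __call__(self, fn, x):
    #           return super().__call__(fn, x)
    #
    #   my_hop = MyHop()
    #
    #   @my_hop.py_impl(DispatchKey.CompositeExplicitAutograd)
    #   def my_hop_dense(fn, x):
    #       return fn(x)
    #
    #   # Eager calls on ordinary tensors first resolve to an autograd key
    #   # (e.g. AutogradCPU), so an Autograd kernel is needed as well. This
    #   # one does not support gradients; it skips autograd and redispatches
    #   # to the dense kernel.
    #   @my_hop.py_impl(DispatchKey.Autograd)
    #   def my_hop_autograd(fn, x):
    #       with torch._C._AutoDispatchBelowAutograd():
    #           return my_hop(fn, x)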
    def __init__(self, name):
        super().__init__()
        if type(self) is HigherOrderOperator:
            raise RuntimeError(
                "Direct instantiation of HigherOrderOperator is not allowed. Please subclass it."
            )
        self._name = name

        # Set __name__ so _OPNamespace doesn't complain; this whole name-based
        # association needs a good hard look.
        self.__name__ = name
        _higher_order_ops[name] = self
        self._ns = "higher_order"
        self.__module__ = "torch.ops.higher_order"

        self.non_fallthrough_keys = torch._C._dispatch_keyset_full()

        for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS:
            self.fallthrough(dispatch_key)

        # [NOTE] We have to register a pre-dispatch key implementation
        # because some HOPs use aot-dispatch tracing to detect certain
        # mutations. This is problematic when we are functionalizing a HOP
        # during pre-dispatch, because when the inner tracer starts, it will
        # see that the PreDispatch key is still active.
        # In that case, we just redispatch it to the next key. This is only
        # safe to do when the PreDispatch mode stack has no active modes.

    def py_impl(self, k):
        if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k):
            self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
        return super().py_impl(k)

    @property
    def namespace(self):
        return self._ns

    def fallthrough(self, dispatch_key):
        self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key)

    # Use a positional-only argument to avoid name collisions with custom op
    # arguments that are named "self".
    def dispatch(self, /, dispatch_key, *args, **kwargs):
        from torch.utils._python_dispatch import _get_current_dispatch_mode

        if dispatch_key in self._dispatch_cache:
            kernel = self._dispatch_cache[dispatch_key]
            assert not isinstance(kernel, DispatchKey)
            return kernel(*args, **kwargs)

        if dispatch_key == DispatchKey.FuncTorchDynamicLayerFrontMode:
            return dispatch_functorch(self, args, kwargs)

        if dispatch_key == DispatchKey.Python:
            # Keep the following 1:1 with handle_torch_function_no_python_arg_parser
            # in torch/csrc/utils/python_arg_parser.cpp

            overloaded_args_list = []

            def has_python_key(tensor):
                return torch._C._dispatch_keys(tensor).has("Python")

            def check_overloaded(arg):
                if isinstance(arg, torch.Tensor) and has_python_key(arg):
                    overloaded_args_list.append(arg)

            for arg in (*args, *kwargs.values()):
                check_overloaded(arg)
                if isinstance(arg, (list, tuple)):
                    for a in arg:
                        check_overloaded(a)

            overloaded_args = tuple(overloaded_args_list)
            overloaded_types = tuple(type(arg) for arg in overloaded_args)

            # Step 1: dispatch on any user TorchDispatchModes
            from torch.utils._python_dispatch import _pop_mode_temporarily

            curr_mode = _get_current_dispatch_mode()
            if curr_mode is not None:
                if type(curr_mode) in self.python_key_table:
                    handler = self.python_key_table[type(curr_mode)]
                    with _pop_mode_temporarily() as mode:
                        # "natural" calling convention: (mode, *args, **kwargs)
                        # TODO(rzou): we should support torch_dispatch calling convention too.
                        result = handler(mode, *args, **kwargs)
                else:
                    raise NotImplementedError(
                        f"There was no rule registered for HOP {self._name} and mode {curr_mode}. "
                        f"We recommend filing an issue."
337*da0073e9SAndroid Build Coastguard Worker ) 338*da0073e9SAndroid Build Coastguard Worker if result is not NotImplemented: 339*da0073e9SAndroid Build Coastguard Worker return result 340*da0073e9SAndroid Build Coastguard Worker 341*da0073e9SAndroid Build Coastguard Worker # Step 2: dispatch on any subclasses 342*da0073e9SAndroid Build Coastguard Worker for arg in overloaded_args: 343*da0073e9SAndroid Build Coastguard Worker subclass_type = type(arg) 344*da0073e9SAndroid Build Coastguard Worker if ( 345*da0073e9SAndroid Build Coastguard Worker subclass_type.__torch_dispatch__ 346*da0073e9SAndroid Build Coastguard Worker == torch._C._disabled_torch_dispatch_impl 347*da0073e9SAndroid Build Coastguard Worker ): 348*da0073e9SAndroid Build Coastguard Worker continue 349*da0073e9SAndroid Build Coastguard Worker if subclass_type in self.python_key_table: 350*da0073e9SAndroid Build Coastguard Worker handler = self.python_key_table[subclass_type] 351*da0073e9SAndroid Build Coastguard Worker # "natural" calling convention: (*args, **kwargs) 352*da0073e9SAndroid Build Coastguard Worker # TODO(rzou): we should support torch_dispatch calling convention too. 353*da0073e9SAndroid Build Coastguard Worker result = handler(*args, **kwargs) 354*da0073e9SAndroid Build Coastguard Worker else: 355*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError( 356*da0073e9SAndroid Build Coastguard Worker f"There was no rule registered for HOP {self._name} and subclass {subclass_type}. " 357*da0073e9SAndroid Build Coastguard Worker f"We recommend filing an issue." 
358*da0073e9SAndroid Build Coastguard Worker ) 359*da0073e9SAndroid Build Coastguard Worker if result is not NotImplemented: 360*da0073e9SAndroid Build Coastguard Worker return result 361*da0073e9SAndroid Build Coastguard Worker 362*da0073e9SAndroid Build Coastguard Worker # All handlers returned NotImplemented 363*da0073e9SAndroid Build Coastguard Worker raise TypeError( 364*da0073e9SAndroid Build Coastguard Worker f"Multiple dispatch failed for {self._name}. There was no registered that " 365*da0073e9SAndroid Build Coastguard Worker f"did not return NotImplemented. Use HOP.py_impl to register some. " 366*da0073e9SAndroid Build Coastguard Worker f"Tried mode: {curr_mode}) and subclasses: " 367*da0073e9SAndroid Build Coastguard Worker f"{[type(a) for a in overloaded_args]}" 368*da0073e9SAndroid Build Coastguard Worker ) 369*da0073e9SAndroid Build Coastguard Worker 370*da0073e9SAndroid Build Coastguard Worker functionality_key = torch._C._to_functionality_key(dispatch_key) # type: ignore[attr-defined] 371*da0073e9SAndroid Build Coastguard Worker if functionality_key == DispatchKey.PreDispatch: 372*da0073e9SAndroid Build Coastguard Worker from torch.utils._python_dispatch import _pop_mode_temporarily 373*da0073e9SAndroid Build Coastguard Worker 374*da0073e9SAndroid Build Coastguard Worker # The check for Python in the exclude set is so we properly respect `with no_dispatch()` 375*da0073e9SAndroid Build Coastguard Worker # calls inside of a mode. 
            if (
                _len_torch_dispatch_stack_pre_dispatch() > 0
            ) and not torch._C._dispatch_tls_is_dispatch_key_excluded(
                DispatchKey.Python
            ):
                curr_mode = _get_current_dispatch_mode_pre_dispatch()
                assert (
                    curr_mode is not None
                ), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode."
                assert (
                    type(curr_mode) in self.python_key_table
                ), f"Current active mode {curr_mode} not registered"
                handler = self.python_key_table[type(curr_mode)]
                with _pop_mode_temporarily(functionality_key) as mode:
                    return handler(mode, *args, **kwargs)

        final_key = resolve_key(self, dispatch_key)

        # This can currently fail due to backend fallbacks. You just have to
        # register them by hand for HigherOrderOperator.
        if final_key not in self.py_kernels:
            raise NotImplementedError(
                f"could not find kernel for HigherOrderOperator {self._name} "
                f"at dispatch key {final_key} (resolved from {dispatch_key})"
            )

        # [NOTE] We shouldn't cache the PreDispatch kernel here because,
        # depending on which modes are active, pre-dispatch behaviour differs.
        # We do the same thing for normal ops:
        # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
        if dispatch_key != DispatchKey.PreDispatch:
            self._dispatch_cache[dispatch_key] = self.py_kernels[final_key]
        kernel = self.py_kernels[final_key]
        # It's illegal to register DispatchKey to py_kernels, since there's no
        # C++ kernel to call into
        assert not isinstance(kernel, DispatchKey)
        return kernel(*args, **kwargs)

    @abc.abstractmethod
    def __call__(self, /, *args, **kwargs):
        # Dynamo already traces the body of a HigherOrderOp beforehand, so
        # there is no need to trace into it here.
        from torch._dynamo import disable

        @disable
        def wrapper():
            flat_args = _to_flat_tuple(args, kwargs)
            if torch.overrides.has_torch_function(flat_args):
                return torch.overrides.handle_torch_function(
                    self, flat_args, *args, **kwargs
                )

            dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
            return self.dispatch(
                dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
            )

        return wrapper()

    def __str__(self):
        return f"{self.name()}"

    def name(self):
        return self._name


def _to_flat_tuple(args, kwargs):
    return pytree.arg_tree_leaves(*args, **kwargs)


def _compute_keyset(args, kwargs, non_fallthrough_keys):
    tensors = _get_tensors(args, kwargs)
    return key_extractor(tensors, non_fallthrough_keys)


def _get_tensors(args, kwargs):
    flat_all = _to_flat_tuple(args, kwargs)
    tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
    return tuple(tensor_args)


# Note: this must stay identical to the C++ dispatcher key extraction logic
# in ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors, key_mask):
    key_set = torch._C._dispatch_tls_local_include_set()
    for tensor in tensors:
        key_set = key_set | torch._C._dispatch_keys(tensor)
    key_set = key_set - torch._C._dispatch_tls_local_exclude_set()
    key_set = key_set & key_mask
    return key_set


# Mode stack for the PreDispatch key.
# It holds at most three modes. Between the two infra modes,
# priority is given to FunctionalTensorMode and
# then ProxyTorchDispatchMode. It means that
# slot 0 belongs to ProxyTorchDispatchMode and
# slot 1 belongs to FunctionalTensorMode.
#
# SchemaCheckMode is separate from the other 2,
# and is only valid when the stack is empty.
# SchemaCheckMode is for testing purposes, and
# is meant to run in eager mode on concrete inputs,
# checking for incorrect schemas with regard to
# aliasing or mutating ops.
class _ModeStackStateForPreDispatch:
    def __init__(self):
        self.__infra_modes = [None, None]
        self._schema_check_mode = None

    def set(self, index, mode):
        assert index < len(self.__infra_modes)
        self.__infra_modes[index] = mode

    def get(self, index):
        assert index < len(self.__infra_modes)
        return self.__infra_modes[index]

    def count(self):
        return len([i for i in self.__infra_modes if i is not None]) + int(
            self._schema_check_mode is not None
        )


_mode_stack_state_for_pre_dispatch = _ModeStackStateForPreDispatch()


def unset_mode_pre_dispatch(mode_key, schema_check=False):
    current_mode_stack_pre_dispatch = mode_stack_state_for_pre_dispatch()
    assert mode_key is None or mode_key in (
        torch._C._TorchDispatchModeKey.PROXY,
        torch._C._TorchDispatchModeKey.FUNCTIONAL,
    )
    if schema_check:
        assert mode_key is None

    def _unset_mode():
        if mode_key == torch._C._TorchDispatchModeKey.PROXY:
            current_mode = current_mode_stack_pre_dispatch.get(0)
            mode_stack_state_for_pre_dispatch().set(0, None)
            return current_mode
        elif mode_key == torch._C._TorchDispatchModeKey.FUNCTIONAL:
            current_mode = current_mode_stack_pre_dispatch.get(1)
            mode_stack_state_for_pre_dispatch().set(1, None)
            return current_mode
        else:
            current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode
            mode_stack_state_for_pre_dispatch()._schema_check_mode = None
            return current_mode

    current_mode = _unset_mode()

    new_pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()
    # When we are unsetting a mode, we need to check whether any mode is
    # still active on the PreDispatch key. If nothing is active, we need to
    # remove the PreDispatch key from the local dispatch include set.
    if new_pre_dispatch_len == 0:
        torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, False)

    return current_mode


def _set_mode_pre_dispatch(mode):
    from torch._subclasses.functional_tensor import FunctionalTensorMode
    from torch._subclasses.schema_check_mode import SchemaCheckMode
    from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode

    assert isinstance(
        mode,
        (
            FunctionalTensorMode,
            ProxyTorchDispatchMode,
            SchemaCheckMode,
        ),
    )

    previous_mode_stack_len = _len_torch_dispatch_stack_pre_dispatch()
    if isinstance(mode, SchemaCheckMode):
        current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode
        if previous_mode_stack_len > 0:
            raise AssertionError(
                "SchemaCheckMode for pre-dispatch must be used exclusively, found other modes on the stack"
            )
        mode_stack_state_for_pre_dispatch()._schema_check_mode = mode
    elif isinstance(mode, FunctionalTensorMode):
        current_mode = mode_stack_state_for_pre_dispatch().get(1)
        assert current_mode is None
        mode_stack_state_for_pre_dispatch().set(1, mode)
    else:
        current_mode = mode_stack_state_for_pre_dispatch().get(0)
        assert current_mode is None
        mode_stack_state_for_pre_dispatch().set(0, mode)

    # When we are setting a mode, we need to check whether any mode was
    # already active on the PreDispatch key. If nothing was active before
    # setting this mode, the PreDispatch key was turned off, so we need to
    # turn it on again.
    if previous_mode_stack_len == 0:
        torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, True)


def _pop_mode_from_pre_dispatch():
    mode_stack = mode_stack_state_for_pre_dispatch()
    pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()

    if pre_dispatch_len == 0:
        raise AssertionError("Trying to pop empty mode stack")

    if mode_stack._schema_check_mode is not None:
        return unset_mode_pre_dispatch(None, schema_check=True)
    if mode_stack.get(1) is not None:
        return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.FUNCTIONAL)
    if mode_stack.get(0) is not None:
        return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)


def _len_torch_dispatch_stack_pre_dispatch():
    return mode_stack_state_for_pre_dispatch().count()


def _get_dispatch_mode_pre_dispatch(mode_key):
    assert mode_key in (
        torch._C._TorchDispatchModeKey.PROXY,
        torch._C._TorchDispatchModeKey.FUNCTIONAL,
    )
    if mode_key == torch._C._TorchDispatchModeKey.PROXY:
        return mode_stack_state_for_pre_dispatch().get(0)
    else:
        return mode_stack_state_for_pre_dispatch().get(1)


def _get_current_dispatch_mode_pre_dispatch():
    if mode_stack_state_for_pre_dispatch()._schema_check_mode is not None:
        return mode_stack_state_for_pre_dispatch()._schema_check_mode
    else:
        stack_len = mode_stack_state_for_pre_dispatch().count()
        if stack_len == 2:
            return mode_stack_state_for_pre_dispatch().get(1)
        if stack_len == 1:
            return (
                mode_stack_state_for_pre_dispatch().get(1)
                if mode_stack_state_for_pre_dispatch().get(1) is not None
                else mode_stack_state_for_pre_dispatch().get(0)
            )
    return None

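
# Illustrative sketch (hypothetical helper, not used by the dispatcher): a
# plain-Python model of the order in which _pop_mode_from_pre_dispatch()
# would remove modes from the PreDispatch stack above, assuming it is called
# repeatedly. Strings stand in for real mode objects. SchemaCheckMode (if set)
# comes first, then slot 1 (FunctionalTensorMode), then slot 0
# (ProxyTorchDispatchMode). The first entry is also what
# _get_current_dispatch_mode_pre_dispatch() reports as the current mode.
def _example_pre_dispatch_pop_order(schema_check_mode, functional_mode, proxy_mode):
    """Return the modes in the order repeated pops would yield them."""
    # Same slot layout as _ModeStackStateForPreDispatch: slot 0 proxy, slot 1 functional.
    infra_modes = [proxy_mode, functional_mode]
    order = []
    if schema_check_mode is not None:
        order.append(schema_check_mode)
    # Slot 1 (FunctionalTensorMode) has priority over slot 0 (ProxyTorchDispatchMode).
    for index in (1, 0):
        if infra_modes[index] is not None:
            order.append(infra_modes[index])
    return order


# Example: with both infra modes active, FunctionalTensorMode is popped first:
# _example_pre_dispatch_pop_order(None, "functional", "proxy") == ["functional", "proxy"]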

def mode_stack_state_for_pre_dispatch():
    global _mode_stack_state_for_pre_dispatch
    return _mode_stack_state_for_pre_dispatch


cached_ops: Set["OpOverload"] = set()


def add_cached_op(op_overload):
    global cached_ops
    cached_ops.add(op_overload)


def reset_cached_ops():
    global cached_ops
    cached_ops.clear()


def get_cached_ops():
    global cached_ops
    return cached_ops


# Each OpOverload object contains a pointer to a specific operator overload and a pointer to the parent `OpOverloadPacket` object.
# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
class OpOverload(OperatorBase):
    def __init__(self, overloadpacket, op, op_dk, schema, tags):
        super().__init__()
        self._op = op
        self._op_dk = op_dk
        self._schema = schema
        self._overloadpacket = overloadpacket
        self._tags = tags
        self._overloadname = (
            "default" if schema.overload_name == "" else schema.overload_name
        )
        self._name = self._schema.name
        if schema.overload_name:
            self._name += "." + schema.overload_name
        self.__name__ = f"{self._schema.name.split('::')[1]}.{self._overloadname}"
        self.__module__ = overloadpacket.__module__
        op.__module__ = overloadpacket.__module__
        self.__qualname__ = self._name
        self.__annotations__ = {}
        # Only compute the OperatorHandle when we need it. Not all OpOverloads have
        # OperatorHandles (the TorchScript ones don't...)
        self._lazy_handle = None

        # If the OpOverload was constructed from a Library.def in Python.
        self._defined_in_python = self.__qualname__ in torch.library._defs

        # Logic replicated from aten/src/ATen/native/MathBitsFallback.h
        is_write = None
        for a in self._schema.arguments:
            if a.alias_info is None:
                continue
            if is_write is None:
                is_write = a.alias_info.is_write
            else:
                # We conservatively treat mixed mutable/non-mutable
                # aliased inputs as NOT a view
                is_write = a.alias_info.is_write or is_write
        self.is_view = is_write is not None and not is_write

    @property
    def _namespace(self):
        return self._schema.name.split("::")[0]

    @property
    def _opname(self):
        return self._schema.name.split("::")[1]

    @property
    def _handle(self):
        if self._lazy_handle is None:
            self._lazy_handle = torch._C._dispatch_find_schema_or_throw(
                self._schema.name, self._schema.overload_name
            )
        return self._lazy_handle

    # This is a no-op since an OpOverload object is immutable and must be unique for a given op overload.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        return "<OpOverload(op='{}.{}', overload='{}')>".format(
            *self._schema.name.split("::"), self._overloadname
        )

    # Use positional-only argument to avoid naming collision with aten ops arguments
    # that are named "self". This way, all the aten ops can be called by kwargs.
    def __call__(self, /, *args, **kwargs):
        return self._op(*args, **kwargs)

    # Use positional-only argument to avoid naming collision with aten ops arguments
    # that are named "self". This way, all the aten ops can be called by kwargs.
    def redispatch(self, /, keyset, *args, **kwargs):
        return self._handle.redispatch_boxed(keyset, *args, **kwargs)

    def __hash__(self):
        return hash(self._op)

    # `my_namespace.my_op_name.overload_name`
    def __str__(self):
        return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)

    def has_kernel_for_dispatch_key(self, k):
        return super().has_kernel_for_dispatch_key(
            k
        ) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k)

    def has_kernel_for_any_dispatch_key(self, ks):
        return torch._C._dispatch_has_kernel_for_any_dispatch_key(
            self.name(), ks
        ) or super().has_kernel_for_any_dispatch_key(ks)

    @property
    def namespace(self):
        return self._schema.name.split("::")[0]

    def _can_decompose(self):
        dk = DispatchKey.CompositeImplicitAutograd
        return dk in self.py_kernels or torch._C._dispatch_has_kernel_for_dispatch_key(
            self.name(), dk
        )

    def decompose(self, *args, **kwargs):
        dk = DispatchKey.CompositeImplicitAutograd
        if dk in self.py_kernels:
            # NB: This branch is not strictly necessary anymore, because we can
            # apply Python CompositeImplicitAutograd *before* tracing
            # using Python dispatcher (also taking advantage of the autograd
            # formula). But it's included for completeness
            return self.py_kernels[dk](*args, **kwargs)
        elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
            return self._op_dk(dk, *args, **kwargs)
        else:
            return NotImplemented

    # Remove a dispatch key from the dispatch cache. This will force it to get
    # recomputed the next time. Does nothing if the key is not cached.
    # WARNING: if you register a dispatch key to py_kernels of an OpOverload,
    # calling _uncache_dispatch on that key is NOT sufficient to apply your change,
    # because a single registration may affect MULTIPLE dispatch keys (e.g.,
    # registering Autograd affects AutogradCPU). _uncache_dispatch is to be used
    # only if you are specifically modifying how _get_dispatch handles a
    # particular input 'key'.
    def _uncache_dispatch(self, key):
        self._dispatch_cache.pop(key, None)

    # This implements the pre-computation logic for the Python dispatcher.
    def _get_dispatch(self, key):
        # This is only called upon a cache miss
        assert key not in self._dispatch_cache, f"{self} {key}"

        if key == DispatchKey.Python:
            if not isinstance(self, TorchBindOpOverload) and not self.python_key_table:
                self._dispatch_cache[key] = key
                add_cached_op(self)
                return key

            def handler(*args, **kwargs):
                from torch.utils._python_dispatch import _get_current_dispatch_mode

                # TODO: We also need to handle tensor subclasses here
                # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
                curr_mode = type(_get_current_dispatch_mode())
                assert (
                    curr_mode is not None
                ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."

                if curr_mode not in self.python_key_table:
                    if isinstance(self, TorchBindOpOverload):
                        with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
                            return torch._library.utils.handle_dispatch_mode(
                                mode, self, *args, **kwargs
                            )
                    else:
                        return self._op_dk(key, *args, **kwargs)

                with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
                    return self.python_key_table[curr_mode](mode, *args, **kwargs)

            self._dispatch_cache[key] = handler
            add_cached_op(self)
            return handler

        functionality_key = torch._C._to_functionality_key(key)  # type: ignore[attr-defined]
        if functionality_key == DispatchKey.PreDispatch:
            curr_stack_len = _len_torch_dispatch_stack_pre_dispatch()
            # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
            # calls inside of a mode.
            if (
                curr_stack_len > 0
                and not torch._C._dispatch_tls_is_dispatch_key_excluded(
                    DispatchKey.Python
                )
            ):

                def handler(*args, **kwargs):
                    @contextlib.contextmanager
                    def _temporarily_pop_modes_from_pre_dispatch():
                        top_mode = _pop_mode_from_pre_dispatch()
                        try:
                            yield top_mode
                        finally:
                            _set_mode_pre_dispatch(top_mode)

                    with _temporarily_pop_modes_from_pre_dispatch() as curr_mode:
                        return torch._library.utils.handle_dispatch_mode(
                            curr_mode, self, *args, **kwargs
                        )

                # Note [Not Caching Per-Dispatch-Key Mode Handlers]
                # Note that we're not caching this handler. There isn't really a point, since the slow bit
                # is the handler itself (in python).
                # Also, not caching means that we don't have to reset the cache when any existing
                # modes go out of scope (which in and of itself takes time to loop through all operators).
                return handler

        final_key = resolve_key(self, key)

        # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
        cache_result = key != DispatchKey.PreDispatch

        # TODO: We could potentially have lots of debugging wrappers against
        # dispatch keys; design some general registration mechanism instead of
        # having an if statement for each of them
        if key == DispatchKey.Functionalize:
            import torch._dispatch.python as pydispatch

            if pydispatch.CROSSREF_FUNCTIONALIZE:
                handler = pydispatch.make_crossref_functionalize(self, final_key)
                if cache_result:
                    self._dispatch_cache[key] = handler
                    add_cached_op(self)
                return handler

        r = self.py_kernels.get(final_key, final_key)
        if cache_result:
            self._dispatch_cache[key] = r
            add_cached_op(self)
        return r

    def name(self):
        return self._name

    @property
    def overloadpacket(self):
        return self._overloadpacket

    @property
    def op(self):
        return self._op

    @property
    def tags(self):
        return self._tags

    # TODO: add more methods to expose information about input and output arguments


# TorchBindOpOverloads are overloads of custom ops whose schema has at least one
# torch.ScriptObject (i.e. custom class) input. A TorchBindOpOverload skips the
# C++ dispatcher and is dispatched purely in Python when its inputs contain a
# FakeScriptObject, similar to how higher-order ops are dispatched.
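
# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not part of the torch API) of the
# Python-side re-dispatch loop that TorchBindOpOverload._dispatch_in_python
# (defined below) implements: drop the fallthrough keys from the candidate
# keys, pick the highest-priority remaining key, and if the kernel for that
# key is a fallthrough, add the key to the fallthrough list and dispatch again.
# Assumptions of this toy model: plain strings stand in for DispatchKey values,
# a later position in `priority` means higher priority, and
# _SKETCH_FALLTHROUGH stands in for a fallthrough registration.
# ---------------------------------------------------------------------------
_SKETCH_FALLTHROUGH = object()


def _sketch_dispatch_in_python(priority, present, kernels, fallthrough_keys, *args):
    # Keys that are present on the inputs and not (yet) known to fall through.
    candidates = [k for k in priority if k in present and k not in fallthrough_keys]
    if not candidates:
        raise RuntimeError("no dispatch key left to dispatch to")
    key = candidates[-1]  # highest-priority remaining key
    kernel = kernels.get(key)
    if kernel is _SKETCH_FALLTHROUGH:
        # Treat this key as a fallthrough and re-dispatch.
        return _sketch_dispatch_in_python(
            priority, present, kernels, fallthrough_keys + [key], *args
        )
    if kernel is None:
        raise RuntimeError(f"no Python implementation registered for {key}")
    return kernel(*args)


# Example: Python and AutogradCPU fall through, so the CPU kernel runs.
# _sketch_dispatch_in_python(
#     ["CPU", "AutogradCPU", "Python"],
#     {"CPU", "AutogradCPU", "Python"},
#     {"Python": _SKETCH_FALLTHROUGH, "AutogradCPU": _SKETCH_FALLTHROUGH,
#      "CPU": lambda x: x + 1},
#     [],
#     1,
# )  # returns 2
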
class TorchBindOpOverload(OpOverload):
    def _fallthrough_keys(self) -> List[DispatchKey]:
        # TODO: we should be calling the fallback for these, but a fallthrough is close
        # enough to the fallback in most cases that we care about.
        _DEFAULT_FALLTHROUGH_KEYS = [
            DispatchKey.Autograd,
            DispatchKey.AutogradCPU,
            DispatchKey.AutogradCUDA,
            DispatchKey.ADInplaceOrView,
            DispatchKey.BackendSelect,
            DispatchKey.PythonTLSSnapshot,
            DispatchKey.PythonDispatcher,
        ]

        def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
            if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key):
                return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
                    self.name(), key
                )

            return (
                key not in self.py_kernels
                or self.py_kernels[key] is torch.library.fallthrough_kernel
            )

        return [
            key
            for key in _DEFAULT_FALLTHROUGH_KEYS
            if _may_use_fallthrough_instead_of_fallback(key)
        ]

    @contextlib.contextmanager
    def _register_as_effectful_op_temporarily(self):
        from torch._higher_order_ops.effects import (
            _EffectType,
            _register_effectful_op,
            SIDE_EFFECTS,
        )

        try:
            if self not in SIDE_EFFECTS:
                _register_effectful_op(self, _EffectType.ORDERED)
            yield
        finally:
            if self in SIDE_EFFECTS:
                del SIDE_EFFECTS[self]

    # Use a positional-only `self` to avoid naming collisions with aten op arguments
    # that are also named "self". This way, all the aten ops can be called with kwargs.
    def __call__(self, /, *args, **kwargs):
        if _must_dispatch_in_python(args, kwargs):
            # When any input is a FakeScriptObject, we need to skip the C++
            # dispatcher and dispatch in Python through _get_dispatch of
            # python_dispatcher, because the C++ dispatcher checks the schema and
            # cannot recognize FakeScriptObject.
            #
            # Note:
            # 1. We only register the torchbind op as effectful temporarily because we
            #    only want the effect-token functionalization logic to be applied during
            #    tracing. Otherwise, the behavior of eagerly executing the op might
            #    change after tracing.
            # 2. We don't want to register every torchbind op as effectful in the
            #    constructor because this might cause unexpected behavior for some
            #    autograd.profiler ops, e.g. profiler._record_function_exit._RecordFunction.
            with self._register_as_effectful_op_temporarily():
                return self._dispatch_in_python(args, kwargs, self._fallthrough_keys())
        return self._op(*args, **kwargs)

    def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
        non_fallthrough_keys = torch._C._dispatch_keyset_full()
        for key in fallthrough_keys:
            non_fallthrough_keys = non_fallthrough_keys.remove(key)

        dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys)
        dispatch_key = dispatch_key_set.highestPriorityTypeId()

        handler = (
            self._get_dispatch(dispatch_key)
            if dispatch_key not in self._dispatch_cache
            else self._dispatch_cache[dispatch_key]
        )

        if isinstance(handler, DispatchKey):
            # Fallthrough kernels can be registered at runtime via torch.library.impl,
            # so we need to add the key to fallthrough_keys and re-dispatch.
            if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
                self.name(), dispatch_key
            ):
                return self._dispatch_in_python(
                    args, kwargs, fallthrough_keys + [dispatch_key]
                )

            raise RuntimeError(
                f"Torchbind op {self} received a FakeScriptObject input when dispatching {handler},"
                " but no Python implementation was found."
                " Please file an issue when you encounter this error."
                " This error can happen when you export or compile the model."
                f" It can still happen even if a C++ implementation for {dispatch_key}"
                " has been registered, because FakeScriptObject lives purely in Python and cannot work"
                " with a C++ implementation."
            )

        assert isinstance(handler, Callable)  # type: ignore[arg-type]
        return handler(*args, **kwargs)


def _must_dispatch_in_python(args, kwargs):
    return pytree.tree_any(
        lambda obj: isinstance(
            obj, torch._library.fake_class_registry.FakeScriptObject
        ),
        (args, kwargs),
    )


def _has_script_object_arg(schema: torch.FunctionSchema) -> bool:
    return any(isinstance(arg.type, torch.ClassType) for arg in schema.arguments)


# An OpOverloadPacket holds a pointer to the base, unresolved operator, which
# doesn't correspond to a specific overload. You can obtain an OpOverload
# object through attribute query (e.g. torch.ops.aten.add.Tensor).
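
# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical class, not part of the torch API) of the
# lazy attribute-query pattern that OpOverloadPacket.__getattr__ (below) and
# _OpNamespace.__getattr__ use: resolve an overload on first access, cache it
# with setattr so later accesses bypass __getattr__ entirely, map "default" to
# the empty overload name, and raise AttributeError on failure so that
# getattr(obj, name, default) keeps working. A plain dict stands in for the
# C++ operator registry here.
# ---------------------------------------------------------------------------
class _SketchLazyPacket:
    def __init__(self, name, registry):
        self._name = name
        self._registry = registry  # overload name -> overload; "" is the default
        self._dir = []
        self.slow_lookups = 0  # how many times the uncached path ran

    def __getattr__(self, key):
        # Python only calls __getattr__ when normal attribute lookup fails,
        # i.e. for overloads that have not been cached yet.
        if key.startswith("__"):
            raise AttributeError(key)
        self.slow_lookups += 1
        use_key = "" if key == "default" else key
        try:
            overload = self._registry[use_key]
        except KeyError:
            raise AttributeError(
                f"'{self._name}' has no overload named '{key}'"
            ) from None
        setattr(self, key, overload)  # cache on the instance
        self._dir.append(key)
        return overload

    def __iter__(self):
        return iter(self._dir)

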
class OpOverloadPacket:
    def __init__(self, qualified_op_name, op_name, op, overload_names):
        # These attributes are accessible on the object through the properties
        # defined below but are immutable.
        self._qualified_op_name = qualified_op_name
        self.__name__ = op_name
        self._op = op
        self._overload_names = overload_names
        self._dir = []
        self._has_torchbind_op_overload = any(
            _has_script_object_arg(schema) for schema in self._schemas.values()
        )

    # Deep copying is a no-op since an OpOverloadPacket object is immutable and
    # must be unique for a given op.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        return "<OpOverloadPacket(op='{}.{}')>".format(
            *self._qualified_op_name.split("::")
        )

    def __hash__(self):
        return hash(self._op)

    def __str__(self):
        return "{}.{}".format(*self._qualified_op_name.split("::"))

    @property
    def op(self):
        return self._op

    @property
    def _schemas(self):
        return {
            overload_name: torch._C._get_schema(self._qualified_op_name, overload_name)
            for overload_name in self._overload_names
        }

    def __getattr__(self, key):
        # `__file__` is not a valid overload name, so don't try to resolve it as one.
        if key == "__file__":
            return "torch.ops"

        # Ensure that queries for dunder attributes that do not exist on the
        # OpOverloadPacket, but do exist on the self._op object, don't unnecessarily
        # call `_get_operation_overload` (which is an expensive operation).
        # This prevents a potential slowdown. This list can be extended if other
        # attributes like `__name__` exist only on self._op and not on the
        # OpOverloadPacket.
        # This is OK since an overload name for an aten op is guaranteed not to start with '__'.
        try:
            if key.startswith("__"):
                return getattr(self._op, key)
        except AttributeError:
            # Re-raise for consistency: it would be odd to throw an AttributeError
            # whose message names an object different from the one the attribute
            # query was performed on.
            raise AttributeError(
                f"'{str(self)}' can't have an overload name beginning with '__' and the "
                f"underlying op {str(self._op)} has no attribute {key} either."
            ) from None

        try:
            # This is OK since an overload name for an aten op is guaranteed not to be 'default'.
            use_key = "" if key == "default" else key
            # TODO: disallow access to overloads registered by JIT
            op_dk_tags = torch._C._get_operation_overload(
                self._qualified_op_name, use_key
            )
            if op_dk_tags is None:
                raise AttributeError(
                    f"The underlying op of '{str(self)}' has no overload name '{key}'"
                )

            op_, op_dk_, tags = op_dk_tags
            schema = torch._C._get_schema(self._qualified_op_name, use_key)
            overload = (
                OpOverload(self, op_, op_dk_, schema, tags)
                if not _has_script_object_arg(schema)
                else TorchBindOpOverload(self, op_, op_dk_, schema, tags)
            )
            # Cache the overload object so later lookups bypass __getattr__.
            setattr(self, key, overload)
            self._dir.append(key)
            return overload
        except RuntimeError:
            raise AttributeError(
                f"The underlying op of '{str(self)}' has no overload name '{key}'"
            ) from None

    def __iter__(self):
        return iter(self._dir)

    # Use a positional-only `self` to avoid naming collisions with aten op arguments
    # that are also named "self". This way, all the aten ops can be called with kwargs.
    def __call__(self, /, *args, **kwargs):
        # Overload __call__ to ensure torch.ops.foo.bar() is still callable
        # from JIT. We save the function pointer as the `op` attribute on
        # OpOverloadPacket to access it here.

        # Directly calling the OpOverloadPacket goes into C++, which checks the
        # schema and raises an error for torchbind ops whose inputs include a
        # FakeScriptObject, so we intercept the call here and dispatch to a
        # TorchBindOpOverload instead.
1114*da0073e9SAndroid Build Coastguard Worker if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs): 1115*da0073e9SAndroid Build Coastguard Worker return _call_overload_packet_from_python(self, args, kwargs) 1116*da0073e9SAndroid Build Coastguard Worker return self._op(*args, **(kwargs or {})) 1117*da0073e9SAndroid Build Coastguard Worker 1118*da0073e9SAndroid Build Coastguard Worker # TODO: use this to make a __dir__ 1119*da0073e9SAndroid Build Coastguard Worker def overloads(self): 1120*da0073e9SAndroid Build Coastguard Worker return [n if n else "default" for n in self._overload_names] 1121*da0073e9SAndroid Build Coastguard Worker 1122*da0073e9SAndroid Build Coastguard Worker 1123*da0073e9SAndroid Build Coastguard Worker# Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp 1124*da0073e9SAndroid Build Coastguard Worker# _jit_get_operations, which calls _get_operation_for_overload_or_packet. 1125*da0073e9SAndroid Build Coastguard Workerdef _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs): 1126*da0073e9SAndroid Build Coastguard Worker # Re-use the torch function handling logic in cpp 1127*da0073e9SAndroid Build Coastguard Worker torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet( 1128*da0073e9SAndroid Build Coastguard Worker op, *args, **kwargs 1129*da0073e9SAndroid Build Coastguard Worker ) 1130*da0073e9SAndroid Build Coastguard Worker 1131*da0073e9SAndroid Build Coastguard Worker if torch_function_called: 1132*da0073e9SAndroid Build Coastguard Worker return ret 1133*da0073e9SAndroid Build Coastguard Worker 1134*da0073e9SAndroid Build Coastguard Worker # The following mirrors getOpWithStack. 1135*da0073e9SAndroid Build Coastguard Worker # In cpp, we do a schema matching for the arguments, and call ToIValue to 1136*da0073e9SAndroid Build Coastguard Worker # to check whether the arguments are valid. 
But need to do similar things here 1137*da0073e9SAndroid Build Coastguard Worker # and check the schema whether the FakeScriptObject is the corresponding fake class 1138*da0073e9SAndroid Build Coastguard Worker # of the actual class used in schema. 1139*da0073e9SAndroid Build Coastguard Worker exceptions = {} 1140*da0073e9SAndroid Build Coastguard Worker found_op = None 1141*da0073e9SAndroid Build Coastguard Worker for overload_name in op.overloads(): 1142*da0073e9SAndroid Build Coastguard Worker op_overload = getattr(op, overload_name) 1143*da0073e9SAndroid Build Coastguard Worker try: 1144*da0073e9SAndroid Build Coastguard Worker _ = torch._C._check_schema_allow_fake_script_object( 1145*da0073e9SAndroid Build Coastguard Worker op_overload._schema, *args, **kwargs 1146*da0073e9SAndroid Build Coastguard Worker ) 1147*da0073e9SAndroid Build Coastguard Worker found_op = op_overload 1148*da0073e9SAndroid Build Coastguard Worker break 1149*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 1150*da0073e9SAndroid Build Coastguard Worker exceptions[overload_name] = e 1151*da0073e9SAndroid Build Coastguard Worker 1152*da0073e9SAndroid Build Coastguard Worker if found_op: 1153*da0073e9SAndroid Build Coastguard Worker return found_op(*args, **kwargs) 1154*da0073e9SAndroid Build Coastguard Worker 1155*da0073e9SAndroid Build Coastguard Worker err_msg = ( 1156*da0073e9SAndroid Build Coastguard Worker f"Fail to match any TorchBindOverload of {op} with following exceptions:\n" 1157*da0073e9SAndroid Build Coastguard Worker ) 1158*da0073e9SAndroid Build Coastguard Worker for i, (key, msg) in enumerate(exceptions.items()): 1159*da0073e9SAndroid Build Coastguard Worker err_msg += f"Overload name {key}:\n {msg}\n" 1160*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(err_msg) 1161*da0073e9SAndroid Build Coastguard Worker 1162*da0073e9SAndroid Build Coastguard Worker 1163*da0073e9SAndroid Build Coastguard Worker# Resolution of torch.fn is different from 
torch.ops.aten.fn 1164*da0073e9SAndroid Build Coastguard Worker# torch.fn uses the Python argparser, matches with the 1165*da0073e9SAndroid Build Coastguard Worker# appropriate schema, and calls into the unboxed version of the method 1166*da0073e9SAndroid Build Coastguard Worker# torch.ops.aten.fn resolution is done via the mechanism defined in JIT. 1167*da0073e9SAndroid Build Coastguard Worker# JIT creates a stack of all the overloads and then tries to match the 1168*da0073e9SAndroid Build Coastguard Worker# correct one at runtime and always calls into the boxed version of the method 1169*da0073e9SAndroid Build Coastguard Worker# Autograd codegen creates VariableType, TracerType, 1170*da0073e9SAndroid Build Coastguard Worker# inplace or view type and python bindings. 1171*da0073e9SAndroid Build Coastguard Worker# Aten codegen generates tensor methods for the tensor class. 1172*da0073e9SAndroid Build Coastguard Worker 1173*da0073e9SAndroid Build Coastguard Worker# _OpNamespace is a subclass of ModuleType because the torch script 1174*da0073e9SAndroid Build Coastguard Worker# allows attribute lookups on modules only. Since we want torch.ops.foo.bar() 1175*da0073e9SAndroid Build Coastguard Worker# to work from script, we need to ensure ops and foo are modules 1176*da0073e9SAndroid Build Coastguard Worker 1177*da0073e9SAndroid Build Coastguard Worker 1178*da0073e9SAndroid Build Coastguard Workerclass _OpNamespace(types.ModuleType): 1179*da0073e9SAndroid Build Coastguard Worker """ 1180*da0073e9SAndroid Build Coastguard Worker An op namespace to dynamically bind Operators into Python. 1181*da0073e9SAndroid Build Coastguard Worker 1182*da0073e9SAndroid Build Coastguard Worker Say a user has created a custom Operator called "my_namespace::my_op". To 1183*da0073e9SAndroid Build Coastguard Worker call this op, the user will write torch.ops.my_namespace.my_op(...). 1184*da0073e9SAndroid Build Coastguard Worker At startup, this operation will not yet be bound into Python. 
Instead, the 1185*da0073e9SAndroid Build Coastguard Worker following sequence of magic tricks will occur: 1186*da0073e9SAndroid Build Coastguard Worker 1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method 1187*da0073e9SAndroid Build Coastguard Worker on the `torch.ops` object, which will create a new `_OpNamespace` 1188*da0073e9SAndroid Build Coastguard Worker object called `my_namespace` and set it as an attribute on the `ops` 1189*da0073e9SAndroid Build Coastguard Worker object. 1190*da0073e9SAndroid Build Coastguard Worker 2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on 1191*da0073e9SAndroid Build Coastguard Worker the `my_namespace` object, which will retrieve the operation via 1192*da0073e9SAndroid Build Coastguard Worker `torch.get_operation`, a function bound from C++, and then in a similar 1193*da0073e9SAndroid Build Coastguard Worker fashion bind this new object onto the `my_namespace` object. 1194*da0073e9SAndroid Build Coastguard Worker 3. `torch.ops.my_namespace.my_op(...)` then calls this new operation 1195*da0073e9SAndroid Build Coastguard Worker and subsequent accesses will incur no further lookup (the namespace and 1196*da0073e9SAndroid Build Coastguard Worker operation will already exist). 1197*da0073e9SAndroid Build Coastguard Worker """ 1198*da0073e9SAndroid Build Coastguard Worker 1199*da0073e9SAndroid Build Coastguard Worker def __init__(self, name): 1200*da0073e9SAndroid Build Coastguard Worker super().__init__("torch.ops." 
+ name)
        self.name = name
        self._dir = []

    def __iter__(self):
        return iter(self._dir)

    def __getattr__(self, op_name):
        # It is not a valid op_name when __file__ is passed in
        if op_name == "__file__":
            return "torch.ops"
        elif op_name in ["__origin__", "__self__"]:
            raise AttributeError(
                f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'"
            )

        # Get the op `my_namespace::my_op` if available. This will also check
        # for overloads and raise an exception if there is more than one.
        namespace_name = self.name
        qualified_op_name = f"{namespace_name}::{op_name}"
        module_name = self.__module__ + "." + namespace_name

        try:
            op, overload_names = _get_packet(qualified_op_name, module_name)
            if op is None:
                raise AttributeError(
                    f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
                )
        except RuntimeError as e:
            # Turn this into AttributeError so getattr(obj, key, default)
            # works (this is called by TorchScript with __origin__)
            raise AttributeError(
                f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
            ) from e

        op.__module__ = module_name
        opoverloadpacket = OpOverloadPacket(
            qualified_op_name, op_name, op, overload_names
        )
        opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
        # cache the opoverloadpacket to ensure that each op corresponds to
        # a unique OpOverloadPacket object
        setattr(self, op_name, opoverloadpacket)
        self._dir.append(op_name)
        return opoverloadpacket


def _get_packet(qualname, op_module):
    op, overload_names = torch._C._jit_get_operation(qualname)
    if op is not None:
        # let the script frontend know that op is identical to the builtin op
        # with qualified_op_name
        torch.jit._builtins._register_builtin(op, qualname)
        op.__module__ = op_module
    return op, overload_names


def _refresh_packet(packet):
    op, overload_names = _get_packet(packet._qualified_op_name, packet._op.__module__)
    assert op is not None
    packet._op = op
    packet._overload_names = overload_names


class _PyOpNamespace(_OpNamespace):
    def __init__(self, name, ops):
        super().__init__(name)
        self._ops = ops

    def __getattr__(self, name):
        # Following _OpNamespace.__getattr__, we cache the op on the _PyOpNamespace object.
        op = self._ops.get(name, None)
        if op is None:
            raise AttributeError(
                f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'"
            )
        setattr(self, name, op)
        return op


class _Ops(types.ModuleType):
    __file__ = "_ops.py"

    def __init__(self):
        super().__init__("torch.ops")
        self.loaded_libraries = set()
        self._higher_order_op_namespace = _PyOpNamespace(
            "torch.ops.higher_order", _higher_order_ops
        )
        self._dir = []

    def __getattr__(self, name):
        # Route `torch.ops.higher_order` to the namespace of HigherOrderOperators
        if name == "higher_order":
            return self._higher_order_op_namespace

        # Here we are creating `torch.ops.my_namespace`
        namespace = _OpNamespace(name)
        setattr(self, name, namespace)
        self._dir.append(name)
        return namespace

    def __iter__(self):
        return iter(self._dir)

    def import_module(self, module):
        """
        Imports a Python module that has torch.library registrations.

        Generally, to extend PyTorch with custom operators, a user will
        create a Python module whose import triggers registration of
        the custom operators via a torch.ops.load_library call or a call
        to one or more torch.library.* APIs.

        Python modules are not generally expected to have import-time side
        effects, so some linters and formatters will complain about such
        imports. Use this API to import Python modules that contain these
        torch.library side effects.
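
        Example (a sketch; ``my_extension.ops`` is a placeholder for a module
        whose import registers operators through ``torch.library``)::

            >>> # xdoctest: +SKIP
            >>> torch.ops.import_module("my_extension.ops")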

        Args:
            module (str): The name of the Python module to import.

        """
        importlib.import_module(module)

    def load_library(self, path):
        """
        Loads a shared library from the given path into the current process.

        The library being loaded may run global initialization code to register
        custom operators with the PyTorch JIT runtime, which allows custom
        operators to be loaded dynamically. To do so, compile your operator and
        its static registration code into a shared library object, then call
        ``torch.ops.load_library('path/to/libcustom.so')`` to load the shared
        object.

        After the library is loaded, its path is added to the
        ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
        for the paths of all libraries loaded using this function.

        Args:
            path (str): A path to a shared library to load.
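
        Example (a sketch; ``libcustom.so``, ``my_namespace`` and ``my_op``
        are placeholders for your own compiled extension, and ``my_op`` is
        assumed to take a single tensor)::

            >>> # xdoctest: +SKIP
            >>> torch.ops.load_library("path/to/libcustom.so")
            >>> out = torch.ops.my_namespace.my_op(torch.randn(3))
            >>> torch.ops.loaded_libraries  # contains the resolved library path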
        """
        if torch._running_with_deploy():
            return

        path = _utils_internal.resolve_library_path(path)
        with dl_open_guard():
            # Import the shared library into the process, thus running its
            # static (global) initialization code in order to register custom
            # operators with the JIT.
            ctypes.CDLL(path)
        self.loaded_libraries.add(path)


# The ops "namespace"
ops: _Ops = _Ops()
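

# Illustrative sketch only, NOT part of the torch.ops API: a minimal,
# torch-free model of the lazy lookup-and-cache pattern that
# _OpNamespace.__getattr__ implements above. Python calls __getattr__ only
# after normal attribute lookup fails, so caching the resolved op with
# setattr() turns every later access into a plain instance-dict hit that
# returns the same object. Lookup failures are re-raised as AttributeError so
# that getattr(ns, name, default) and hasattr() keep working. The names
# _SKETCH_REGISTRY, _sketch_lookup and _SketchNamespace are hypothetical; the
# dict stands in (as an assumption) for the C++ operator registry that
# _get_packet() queries, and the blanket dunder check is a simplification of
# the real special-casing of __file__, __origin__ and __self__.
_SKETCH_REGISTRY = {
    # Hypothetical operator: "my_namespace::my_op" adds one to its input.
    "my_namespace::my_op": lambda x: x + 1,
}


def _sketch_lookup(qualname):
    # Hypothetical registry query; like the real lookup, it can raise RuntimeError.
    try:
        return _SKETCH_REGISTRY[qualname]
    except KeyError:
        raise RuntimeError(f"No such operator {qualname}") from None


class _SketchNamespace:
    def __init__(self, name):
        self.name = name
        self._dir = []

    def __iter__(self):
        # Yields only the ops resolved (and cached) so far.
        return iter(self._dir)

    def __getattr__(self, op_name):
        # Refuse dunder probes outright, such as copy.deepcopy looking up
        # __deepcopy__, instead of treating them as operator names.
        if op_name.startswith("__"):
            raise AttributeError(op_name)
        try:
            op = _sketch_lookup(f"{self.name}::{op_name}")
        except RuntimeError as e:
            raise AttributeError(
                f"'_SketchNamespace' '{self.name}' object has no attribute '{op_name}'"
            ) from e
        # Cache on the instance so later lookups bypass __getattr__ entirely.
        setattr(self, op_name, op)
        self._dir.append(op_name)
        return op


# Usage sketch: with ns = _SketchNamespace("my_namespace"), ns.my_op(1) returns
# 2, a second ns.my_op returns the identical cached object, and
# getattr(ns, "missing_op", None) returns None instead of raising.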