xref: /aosp_15_r20/external/pytorch/torch/_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport abc
3*da0073e9SAndroid Build Coastguard Workerimport contextlib
4*da0073e9SAndroid Build Coastguard Workerimport ctypes
5*da0073e9SAndroid Build Coastguard Workerimport importlib
6*da0073e9SAndroid Build Coastguard Workerimport inspect
7*da0073e9SAndroid Build Coastguard Workerimport sys
8*da0073e9SAndroid Build Coastguard Workerimport types
9*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, Dict, List, Set, Type, Union
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Workerimport torch
12*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree
13*da0073e9SAndroid Build Coastguard Workerfrom torch import _utils_internal
14*da0073e9SAndroid Build Coastguard Workerfrom torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey
15*da0073e9SAndroid Build Coastguard Workerfrom torch._functorch.pyfunctorch import dispatch_functorch
16*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._python_dispatch import TorchDispatchMode
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
# Probe for the dlopen-flag accessors a single time, at import.
_SET_GLOBAL_FLAGS = all(
    hasattr(sys, accessor) for accessor in ("getdlopenflags", "setdlopenflags")
)


@contextlib.contextmanager
def dl_open_guard():
    """
    Context manager that turns on the RTLD_GLOBAL dynamic linker flag for the
    duration of the ``with`` body, i.e. while a shared library containing
    custom operators is being opened.

    When ``sys`` does not provide ``getdlopenflags``/``setdlopenflags`` the
    guard does nothing.  Otherwise the flags that were active on entry are
    restored on exit, including when the body raises.
    """
    if _SET_GLOBAL_FLAGS:
        saved_flags = sys.getdlopenflags()
        sys.setdlopenflags(saved_flags | ctypes.RTLD_GLOBAL)
        try:
            yield
        finally:
            # Always put the caller's flags back, even on error.
            sys.setdlopenflags(saved_flags)
    else:
        # Nothing to toggle here; just run the body.
        yield
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Workerclass OperatorBase:
41*da0073e9SAndroid Build Coastguard Worker    """
42*da0073e9SAndroid Build Coastguard Worker    Base class for OpOverload (which represents C++ ATen operators) and HigherOrderOperator
43*da0073e9SAndroid Build Coastguard Worker    (which represents Python-only operators that are unrepresentable in TorchScript).
44*da0073e9SAndroid Build Coastguard Worker    """
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker    def __init__(self):
47*da0073e9SAndroid Build Coastguard Worker        # The dispatch cache precomputes a mapping of dispatch key that the
48*da0073e9SAndroid Build Coastguard Worker        # dispatcher wants to dispatch to, to an actual implementation of the
49*da0073e9SAndroid Build Coastguard Worker        # dispatch key.  Confusingly, the actual implementation could *also* be a
50*da0073e9SAndroid Build Coastguard Worker        # dispatch key, but in this case, this refers to the C++ kernel that
51*da0073e9SAndroid Build Coastguard Worker        # was registered to some dispatch key.  Aliases are permitted in the
52*da0073e9SAndroid Build Coastguard Worker        # latter but not the former; for example, you might lookup the
53*da0073e9SAndroid Build Coastguard Worker        # entry for AutogradCPU, and this maps you to the Autograd key for
54*da0073e9SAndroid Build Coastguard Worker        # the generic autograd kernel that works for all devices.  Since this
55*da0073e9SAndroid Build Coastguard Worker        # is the Python dispatcher, you can also put an arbitrary Python
56*da0073e9SAndroid Build Coastguard Worker        # callable to call instead.  This handler gets precisely the
57*da0073e9SAndroid Build Coastguard Worker        # args/kwargs that the operator was __call__'ed with.
58*da0073e9SAndroid Build Coastguard Worker        # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp
59*da0073e9SAndroid Build Coastguard Worker        # for use with OpOverload; cache lookup is done entirely from C++
60*da0073e9SAndroid Build Coastguard Worker        # for speed.
61*da0073e9SAndroid Build Coastguard Worker        # TODO: The cache is NOT currently used by HigherOrderOperator, but it should!
62*da0073e9SAndroid Build Coastguard Worker        self._dispatch_cache: Dict[
63*da0073e9SAndroid Build Coastguard Worker            DispatchKey, Union[DispatchKey, Callable[..., Any]]
64*da0073e9SAndroid Build Coastguard Worker        ] = {}
65*da0073e9SAndroid Build Coastguard Worker
66*da0073e9SAndroid Build Coastguard Worker        # This table allows you to override the behavior of a particular
67*da0073e9SAndroid Build Coastguard Worker        # dispatch key to call a custom Python function, rather than the
68*da0073e9SAndroid Build Coastguard Worker        # ordinary C++ configured behavior.  This is the raison d'etre of
69*da0073e9SAndroid Build Coastguard Worker        # Python dispatcher: to let you program the dispatcher from Python
70*da0073e9SAndroid Build Coastguard Worker        # in case you need something unusual, and don't want to clobber
71*da0073e9SAndroid Build Coastguard Worker        # the existing registrations using the Python operator registration
72*da0073e9SAndroid Build Coastguard Worker        # API.
73*da0073e9SAndroid Build Coastguard Worker        self.py_kernels: Dict[DispatchKey, Callable[..., Any]] = {}
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker        # This table allows you to override the behavior of a particular
76*da0073e9SAndroid Build Coastguard Worker        # operator for a particular TorchDispatchMode.  In practice,
77*da0073e9SAndroid Build Coastguard Worker        # we are using this mostly for ProxyTensorMode.  Modes can be
78*da0073e9SAndroid Build Coastguard Worker        # thought of as an open world extension of dispatch keys, so it
79*da0073e9SAndroid Build Coastguard Worker        # makes sense that you should be able to register them, the same
80*da0073e9SAndroid Build Coastguard Worker        # way you can register dispatch keys.
81*da0073e9SAndroid Build Coastguard Worker        self.python_key_table: Dict[
82*da0073e9SAndroid Build Coastguard Worker            Union[Type[TorchDispatchMode], Type[torch.Tensor]], Callable[..., Any]
83*da0073e9SAndroid Build Coastguard Worker        ] = {}
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker        # This table allows you to override the behavior of functorch
86*da0073e9SAndroid Build Coastguard Worker        # transformations.  NB: this currently only does something for
87*da0073e9SAndroid Build Coastguard Worker        # HigherOrderOperator
88*da0073e9SAndroid Build Coastguard Worker        self.functorch_table = {}
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard Worker    def __call__(self, *args, **kwargs):
91*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker    def has_kernel_for_dispatch_key(self, k):
94*da0073e9SAndroid Build Coastguard Worker        return k in self.py_kernels
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker    def has_kernel_for_any_dispatch_key(self, ks):
97*da0073e9SAndroid Build Coastguard Worker        for k in self.py_kernels:
98*da0073e9SAndroid Build Coastguard Worker            if not torch._C._dispatch_is_alias_key(k) and ks.has(k):
99*da0073e9SAndroid Build Coastguard Worker                return True
100*da0073e9SAndroid Build Coastguard Worker        return False
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker    def py_impl(self, k):
103*da0073e9SAndroid Build Coastguard Worker        def inner(fn):
104*da0073e9SAndroid Build Coastguard Worker            if inspect.isclass(k) and (
105*da0073e9SAndroid Build Coastguard Worker                issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
106*da0073e9SAndroid Build Coastguard Worker            ):
107*da0073e9SAndroid Build Coastguard Worker                assert k not in self.python_key_table
108*da0073e9SAndroid Build Coastguard Worker                # TODO(voz): Should we replace setting DispatchKey.Python entirely with setting mode keys?
109*da0073e9SAndroid Build Coastguard Worker                self.python_key_table[k] = fn
110*da0073e9SAndroid Build Coastguard Worker                self._dispatch_cache.clear()
111*da0073e9SAndroid Build Coastguard Worker                return fn
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker            if isinstance(k, torch._C._functorch.TransformType):
114*da0073e9SAndroid Build Coastguard Worker                assert k not in self.functorch_table
115*da0073e9SAndroid Build Coastguard Worker                self.functorch_table[k] = fn
116*da0073e9SAndroid Build Coastguard Worker                return fn
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker            assert isinstance(k, DispatchKey)
119*da0073e9SAndroid Build Coastguard Worker            assert (
120*da0073e9SAndroid Build Coastguard Worker                k != DispatchKey.Python
121*da0073e9SAndroid Build Coastguard Worker            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker            if k in self.py_kernels:
124*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
125*da0073e9SAndroid Build Coastguard Worker                    f"Trying to override a python impl for {k} on operator {self.name()}"
126*da0073e9SAndroid Build Coastguard Worker                )
127*da0073e9SAndroid Build Coastguard Worker            self.py_kernels[k] = fn
128*da0073e9SAndroid Build Coastguard Worker            self._dispatch_cache.clear()
129*da0073e9SAndroid Build Coastguard Worker            return fn
130*da0073e9SAndroid Build Coastguard Worker
131*da0073e9SAndroid Build Coastguard Worker        return inner
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker    # Registers an implementation to all **3** variants of functionalization that we have:
134*da0073e9SAndroid Build Coastguard Worker    # - DispatchKey.Functionalize
135*da0073e9SAndroid Build Coastguard Worker    # - functorch.TransformType.Functionalize
136*da0073e9SAndroid Build Coastguard Worker    # - FunctionalTensorMode
137*da0073e9SAndroid Build Coastguard Worker    # Example:
138*da0073e9SAndroid Build Coastguard Worker    #   @py_functionalize_impl
139*da0073e9SAndroid Build Coastguard Worker    #   def functionalize_rule(ctx, inner_f, *args):
140*da0073e9SAndroid Build Coastguard Worker    #       args_unwrapped = ctx.unwrap_tensors(args)
141*da0073e9SAndroid Build Coastguard Worker    #       with ctx.redispatch_to_next():
142*da0073e9SAndroid Build Coastguard Worker    #           out = ctx.functionalize(inner_f)(*args_unwrapped)
143*da0073e9SAndroid Build Coastguard Worker    #           return ctx.wrap_tensors(out)
144*da0073e9SAndroid Build Coastguard Worker    def py_functionalize_impl(self, fn):
145*da0073e9SAndroid Build Coastguard Worker        from torch._subclasses.functional_tensor import (
146*da0073e9SAndroid Build Coastguard Worker            CppFunctionalizeAPI as _CppFunctionalizeAPI,
147*da0073e9SAndroid Build Coastguard Worker            FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
148*da0073e9SAndroid Build Coastguard Worker            PythonFunctionalizeAPI as _PythonFunctionalizeAPI,
149*da0073e9SAndroid Build Coastguard Worker        )
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Worker        # Construct our three flavors of functionalization,
152*da0073e9SAndroid Build Coastguard Worker        # each of which have slightly different wrap/unwrap/redispatch policies
153*da0073e9SAndroid Build Coastguard Worker        def functionalize_dk_fn(*args, **kwargs):
154*da0073e9SAndroid Build Coastguard Worker            return fn(_CppFunctionalizeAPI(), *args, **kwargs)
155*da0073e9SAndroid Build Coastguard Worker
156*da0073e9SAndroid Build Coastguard Worker        def functionalize_dispatch_mode_fn(mode, *args, **kwargs):
157*da0073e9SAndroid Build Coastguard Worker            return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs)
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker        def functionalize_functorch_fn(interpreter, *args, **kwargs):
160*da0073e9SAndroid Build Coastguard Worker            return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)
161*da0073e9SAndroid Build Coastguard Worker
162*da0073e9SAndroid Build Coastguard Worker        self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
163*da0073e9SAndroid Build Coastguard Worker        self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)(
164*da0073e9SAndroid Build Coastguard Worker            functionalize_dispatch_mode_fn
165*da0073e9SAndroid Build Coastguard Worker        )
166*da0073e9SAndroid Build Coastguard Worker        self.py_impl(torch._C._functorch.TransformType.Functionalize)(
167*da0073e9SAndroid Build Coastguard Worker            functionalize_functorch_fn
168*da0073e9SAndroid Build Coastguard Worker        )
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker        return fn
171*da0073e9SAndroid Build Coastguard Worker
172*da0073e9SAndroid Build Coastguard Worker    def name(self):
173*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker
176*da0073e9SAndroid Build Coastguard Worker# Equivalent to computeDispatchTableEntryWithDebug
177*da0073e9SAndroid Build Coastguard Workerdef resolve_key(op: OperatorBase, k: DispatchKey):  # type: ignore[valid-type]
178*da0073e9SAndroid Build Coastguard Worker    # 1. (Direct) operator registration
179*da0073e9SAndroid Build Coastguard Worker    if op.has_kernel_for_dispatch_key(k):
180*da0073e9SAndroid Build Coastguard Worker        return k
181*da0073e9SAndroid Build Coastguard Worker    # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
182*da0073e9SAndroid Build Coastguard Worker    cand = DispatchKey.CompositeExplicitAutogradNonFunctional
183*da0073e9SAndroid Build Coastguard Worker    if (
184*da0073e9SAndroid Build Coastguard Worker        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
185*da0073e9SAndroid Build Coastguard Worker    ) and op.has_kernel_for_dispatch_key(cand):
186*da0073e9SAndroid Build Coastguard Worker        return cand
187*da0073e9SAndroid Build Coastguard Worker    # 2.2 Use CompositeExplicitAutograd kernel if available
188*da0073e9SAndroid Build Coastguard Worker    cand = DispatchKey.CompositeExplicitAutograd
189*da0073e9SAndroid Build Coastguard Worker    if (
190*da0073e9SAndroid Build Coastguard Worker        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
191*da0073e9SAndroid Build Coastguard Worker    ) and op.has_kernel_for_dispatch_key(cand):
192*da0073e9SAndroid Build Coastguard Worker        return cand
193*da0073e9SAndroid Build Coastguard Worker    has_backend_kernel = op.has_kernel_for_any_dispatch_key(
194*da0073e9SAndroid Build Coastguard Worker        torch._C._dispatch_get_backend_keyset_from_autograd(k)
195*da0073e9SAndroid Build Coastguard Worker    ) or op.has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd)
196*da0073e9SAndroid Build Coastguard Worker    # 2.3. Use CompositeImplicitAutograd kernel if available
197*da0073e9SAndroid Build Coastguard Worker    cand = DispatchKey.CompositeImplicitAutogradNestedTensor
198*da0073e9SAndroid Build Coastguard Worker    if (
199*da0073e9SAndroid Build Coastguard Worker        (k != DispatchKey.Undefined and is_included_in_alias(k, cand))
200*da0073e9SAndroid Build Coastguard Worker        and op.has_kernel_for_dispatch_key(cand)
201*da0073e9SAndroid Build Coastguard Worker        and not has_backend_kernel
202*da0073e9SAndroid Build Coastguard Worker    ):
203*da0073e9SAndroid Build Coastguard Worker        return cand
204*da0073e9SAndroid Build Coastguard Worker    cand = DispatchKey.CompositeImplicitAutograd
205*da0073e9SAndroid Build Coastguard Worker    if (
206*da0073e9SAndroid Build Coastguard Worker        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
207*da0073e9SAndroid Build Coastguard Worker    ) and op.has_kernel_for_dispatch_key(cand):
208*da0073e9SAndroid Build Coastguard Worker        if k == DispatchKey.AutogradOther and op.has_kernel_for_any_dispatch_key(
209*da0073e9SAndroid Build Coastguard Worker            torch._C._dispatch_autogradother_backends
210*da0073e9SAndroid Build Coastguard Worker        ):
211*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("ambiguous autogradother kernel")
212*da0073e9SAndroid Build Coastguard Worker        elif not has_backend_kernel:
213*da0073e9SAndroid Build Coastguard Worker            return cand
214*da0073e9SAndroid Build Coastguard Worker    # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
215*da0073e9SAndroid Build Coastguard Worker    cand = DispatchKey.Autograd
216*da0073e9SAndroid Build Coastguard Worker    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
217*da0073e9SAndroid Build Coastguard Worker        return cand
218*da0073e9SAndroid Build Coastguard Worker    # 2.5 Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
219*da0073e9SAndroid Build Coastguard Worker    cand = DispatchKey.FuncTorchBatchedDecomposition
220*da0073e9SAndroid Build Coastguard Worker    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
221*da0073e9SAndroid Build Coastguard Worker        return cand
222*da0073e9SAndroid Build Coastguard Worker    # Backend fallback
223*da0073e9SAndroid Build Coastguard Worker    if torch._C._dispatch_has_backend_fallback(k):
224*da0073e9SAndroid Build Coastguard Worker        # The dispatch key itself will implicitly route to backend fallback.
225*da0073e9SAndroid Build Coastguard Worker        # This is probably not great for the pure Python implementation.
226*da0073e9SAndroid Build Coastguard Worker        return k
227*da0073e9SAndroid Build Coastguard Worker    raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
228*da0073e9SAndroid Build Coastguard Worker
229*da0073e9SAndroid Build Coastguard Worker
# Registry of HigherOrderOperator instances keyed by operator name.
# HigherOrderOperator.__init__ stores each new instance here under its name.
_higher_order_ops: Dict[str, "HigherOrderOperator"] = {}

# Dispatch keys that HigherOrderOperator.__init__ passes to `fallthrough()`
# for every new operator, which removes them from that operator's
# `non_fallthrough_keys` set.
_HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [
    DispatchKey.PythonDispatcher,  # type: ignore[attr-defined]
    DispatchKey.PythonTLSSnapshot,  # type: ignore[attr-defined]
    DispatchKey.ADInplaceOrView,
    DispatchKey.BackendSelect,
    DispatchKey.AutocastCPU,  # type: ignore[attr-defined]
    DispatchKey.AutocastCUDA,  # type: ignore[attr-defined]
]
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker
242*da0073e9SAndroid Build Coastguard Workerclass HigherOrderOperator(OperatorBase, abc.ABC):
243*da0073e9SAndroid Build Coastguard Worker    # The HigherOrderOperator will appear as torch.ops.higher_order.{name}
244*da0073e9SAndroid Build Coastguard Worker    #
245*da0073e9SAndroid Build Coastguard Worker    # If you're creating a new HigherOrderOperator, please do not change the
246*da0073e9SAndroid Build Coastguard Worker    # default. Adding operators to the global torch.ops namespace is a bad
247*da0073e9SAndroid Build Coastguard Worker    # practice due to name collisions.
248*da0073e9SAndroid Build Coastguard Worker    def __init__(self, name):
249*da0073e9SAndroid Build Coastguard Worker        super().__init__()
250*da0073e9SAndroid Build Coastguard Worker        if type(self) is HigherOrderOperator:
251*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
252*da0073e9SAndroid Build Coastguard Worker                "Direct instantiation of HigherOrderOperator is not allowed. Please subclass it."
253*da0073e9SAndroid Build Coastguard Worker            )
254*da0073e9SAndroid Build Coastguard Worker        self._name = name
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Worker        # Make _OPNamespace not scream, this whole name based association needs a good hard look
257*da0073e9SAndroid Build Coastguard Worker        self.__name__ = name
258*da0073e9SAndroid Build Coastguard Worker        _higher_order_ops[name] = self
259*da0073e9SAndroid Build Coastguard Worker        self._ns = "higher_order"
260*da0073e9SAndroid Build Coastguard Worker        self.__module__ = "torch.ops.higher_order"
261*da0073e9SAndroid Build Coastguard Worker
262*da0073e9SAndroid Build Coastguard Worker        self.non_fallthrough_keys = torch._C._dispatch_keyset_full()
263*da0073e9SAndroid Build Coastguard Worker
264*da0073e9SAndroid Build Coastguard Worker        for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS:
265*da0073e9SAndroid Build Coastguard Worker            self.fallthrough(dispatch_key)
266*da0073e9SAndroid Build Coastguard Worker
267*da0073e9SAndroid Build Coastguard Worker        # [NOTE] We have to register pre-dispatch key implementation
268*da0073e9SAndroid Build Coastguard Worker        # because sometimes HOP use aot-dispatch tracing to detect certaion
269*da0073e9SAndroid Build Coastguard Worker        # mutations. This is problematic when we are functionalizing HOP
270*da0073e9SAndroid Build Coastguard Worker        # during pre-dispatch because when the inner tracer starts, it will see
271*da0073e9SAndroid Build Coastguard Worker        # that PreDispatch key is still active. In that case, we just redispatch
272*da0073e9SAndroid Build Coastguard Worker        # it to next key. This is only safe to do when PreDispatch key stack has no
273*da0073e9SAndroid Build Coastguard Worker        # active modes.
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker    def py_impl(self, k):
276*da0073e9SAndroid Build Coastguard Worker        if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k):
277*da0073e9SAndroid Build Coastguard Worker            self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
278*da0073e9SAndroid Build Coastguard Worker        return super().py_impl(k)
279*da0073e9SAndroid Build Coastguard Worker
280*da0073e9SAndroid Build Coastguard Worker    @property
281*da0073e9SAndroid Build Coastguard Worker    def namespace(self):
282*da0073e9SAndroid Build Coastguard Worker        return self._ns
283*da0073e9SAndroid Build Coastguard Worker
284*da0073e9SAndroid Build Coastguard Worker    def fallthrough(self, dispatch_key):
285*da0073e9SAndroid Build Coastguard Worker        self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key)
286*da0073e9SAndroid Build Coastguard Worker
287*da0073e9SAndroid Build Coastguard Worker    # Use positional-only argument to avoid naming collide with custom ops arguments
288*da0073e9SAndroid Build Coastguard Worker    # that are named "self".
289*da0073e9SAndroid Build Coastguard Worker    def dispatch(self, /, dispatch_key, *args, **kwargs):
290*da0073e9SAndroid Build Coastguard Worker        from torch.utils._python_dispatch import _get_current_dispatch_mode
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Worker        if dispatch_key in self._dispatch_cache:
293*da0073e9SAndroid Build Coastguard Worker            kernel = self._dispatch_cache[dispatch_key]
294*da0073e9SAndroid Build Coastguard Worker            assert not isinstance(kernel, DispatchKey)
295*da0073e9SAndroid Build Coastguard Worker            return kernel(*args, **kwargs)
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker        if dispatch_key == DispatchKey.FuncTorchDynamicLayerFrontMode:
298*da0073e9SAndroid Build Coastguard Worker            return dispatch_functorch(self, args, kwargs)
299*da0073e9SAndroid Build Coastguard Worker
300*da0073e9SAndroid Build Coastguard Worker        if dispatch_key == DispatchKey.Python:
301*da0073e9SAndroid Build Coastguard Worker            # Keep the following 1:1 with handle_torch_function_no_python_arg_parser
302*da0073e9SAndroid Build Coastguard Worker            # in torch/csrc/utils/python_arg_parser.cpp
303*da0073e9SAndroid Build Coastguard Worker
304*da0073e9SAndroid Build Coastguard Worker            overloaded_args_list = []
305*da0073e9SAndroid Build Coastguard Worker
306*da0073e9SAndroid Build Coastguard Worker            def has_python_key(tensor):
307*da0073e9SAndroid Build Coastguard Worker                return torch._C._dispatch_keys(tensor).has("Python")
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker            def check_overloaded(arg):
310*da0073e9SAndroid Build Coastguard Worker                if isinstance(arg, torch.Tensor) and has_python_key(arg):
311*da0073e9SAndroid Build Coastguard Worker                    overloaded_args_list.append(arg)
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker            for arg in (*args, *kwargs.values()):
314*da0073e9SAndroid Build Coastguard Worker                check_overloaded(arg)
315*da0073e9SAndroid Build Coastguard Worker                if isinstance(arg, (list, tuple)):
316*da0073e9SAndroid Build Coastguard Worker                    for a in arg:
317*da0073e9SAndroid Build Coastguard Worker                        check_overloaded(a)
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker            overloaded_args = tuple(overloaded_args_list)
320*da0073e9SAndroid Build Coastguard Worker            overloaded_types = tuple(type(arg) for arg in overloaded_args)
321*da0073e9SAndroid Build Coastguard Worker
322*da0073e9SAndroid Build Coastguard Worker            # Step 1: dispatch on any user TorchDispatchModes
323*da0073e9SAndroid Build Coastguard Worker            from torch.utils._python_dispatch import _pop_mode_temporarily
324*da0073e9SAndroid Build Coastguard Worker
325*da0073e9SAndroid Build Coastguard Worker            curr_mode = _get_current_dispatch_mode()
326*da0073e9SAndroid Build Coastguard Worker            if curr_mode is not None:
327*da0073e9SAndroid Build Coastguard Worker                if type(curr_mode) in self.python_key_table:
328*da0073e9SAndroid Build Coastguard Worker                    handler = self.python_key_table[type(curr_mode)]
329*da0073e9SAndroid Build Coastguard Worker                    with _pop_mode_temporarily() as mode:
330*da0073e9SAndroid Build Coastguard Worker                        # "natural" calling convention: (mode, *args, **kwargs)
331*da0073e9SAndroid Build Coastguard Worker                        # TODO(rzou): we should support torch_dispatch calling convention too.
332*da0073e9SAndroid Build Coastguard Worker                        result = handler(mode, *args, **kwargs)
333*da0073e9SAndroid Build Coastguard Worker                else:
334*da0073e9SAndroid Build Coastguard Worker                    raise NotImplementedError(
335*da0073e9SAndroid Build Coastguard Worker                        f"There was no rule registered for HOP {self._name} and mode {curr_mode}. "
336*da0073e9SAndroid Build Coastguard Worker                        f"We recommend filing an issue."
337*da0073e9SAndroid Build Coastguard Worker                    )
338*da0073e9SAndroid Build Coastguard Worker                if result is not NotImplemented:
339*da0073e9SAndroid Build Coastguard Worker                    return result
340*da0073e9SAndroid Build Coastguard Worker
341*da0073e9SAndroid Build Coastguard Worker            # Step 2: dispatch on any subclasses
342*da0073e9SAndroid Build Coastguard Worker            for arg in overloaded_args:
343*da0073e9SAndroid Build Coastguard Worker                subclass_type = type(arg)
344*da0073e9SAndroid Build Coastguard Worker                if (
345*da0073e9SAndroid Build Coastguard Worker                    subclass_type.__torch_dispatch__
346*da0073e9SAndroid Build Coastguard Worker                    == torch._C._disabled_torch_dispatch_impl
347*da0073e9SAndroid Build Coastguard Worker                ):
348*da0073e9SAndroid Build Coastguard Worker                    continue
349*da0073e9SAndroid Build Coastguard Worker                if subclass_type in self.python_key_table:
350*da0073e9SAndroid Build Coastguard Worker                    handler = self.python_key_table[subclass_type]
351*da0073e9SAndroid Build Coastguard Worker                    # "natural" calling convention: (*args, **kwargs)
352*da0073e9SAndroid Build Coastguard Worker                    # TODO(rzou): we should support torch_dispatch calling convention too.
353*da0073e9SAndroid Build Coastguard Worker                    result = handler(*args, **kwargs)
354*da0073e9SAndroid Build Coastguard Worker                else:
355*da0073e9SAndroid Build Coastguard Worker                    raise NotImplementedError(
356*da0073e9SAndroid Build Coastguard Worker                        f"There was no rule registered for HOP {self._name} and subclass {subclass_type}. "
357*da0073e9SAndroid Build Coastguard Worker                        f"We recommend filing an issue."
358*da0073e9SAndroid Build Coastguard Worker                    )
359*da0073e9SAndroid Build Coastguard Worker                if result is not NotImplemented:
360*da0073e9SAndroid Build Coastguard Worker                    return result
361*da0073e9SAndroid Build Coastguard Worker
362*da0073e9SAndroid Build Coastguard Worker            # All handlers returned NotImplemented
363*da0073e9SAndroid Build Coastguard Worker            raise TypeError(
364*da0073e9SAndroid Build Coastguard Worker                f"Multiple dispatch failed for {self._name}. There was no registered that "
365*da0073e9SAndroid Build Coastguard Worker                f"did not return NotImplemented. Use HOP.py_impl to register some. "
366*da0073e9SAndroid Build Coastguard Worker                f"Tried mode: {curr_mode}) and subclasses: "
367*da0073e9SAndroid Build Coastguard Worker                f"{[type(a) for a in overloaded_args]}"
368*da0073e9SAndroid Build Coastguard Worker            )
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker        functionality_key = torch._C._to_functionality_key(dispatch_key)  # type: ignore[attr-defined]
371*da0073e9SAndroid Build Coastguard Worker        if functionality_key == DispatchKey.PreDispatch:
372*da0073e9SAndroid Build Coastguard Worker            from torch.utils._python_dispatch import _pop_mode_temporarily
373*da0073e9SAndroid Build Coastguard Worker
374*da0073e9SAndroid Build Coastguard Worker            # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
375*da0073e9SAndroid Build Coastguard Worker            # calls inside of a mode.
376*da0073e9SAndroid Build Coastguard Worker            if (
377*da0073e9SAndroid Build Coastguard Worker                _len_torch_dispatch_stack_pre_dispatch() > 0
378*da0073e9SAndroid Build Coastguard Worker            ) and not torch._C._dispatch_tls_is_dispatch_key_excluded(
379*da0073e9SAndroid Build Coastguard Worker                DispatchKey.Python
380*da0073e9SAndroid Build Coastguard Worker            ):
381*da0073e9SAndroid Build Coastguard Worker                curr_mode = _get_current_dispatch_mode_pre_dispatch()
382*da0073e9SAndroid Build Coastguard Worker                assert (
383*da0073e9SAndroid Build Coastguard Worker                    curr_mode is not None
384*da0073e9SAndroid Build Coastguard Worker                ), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode."
385*da0073e9SAndroid Build Coastguard Worker                assert (
386*da0073e9SAndroid Build Coastguard Worker                    type(curr_mode) in self.python_key_table
387*da0073e9SAndroid Build Coastguard Worker                ), f"Current active mode {curr_mode} not registered"
388*da0073e9SAndroid Build Coastguard Worker                handler = self.python_key_table[type(curr_mode)]
389*da0073e9SAndroid Build Coastguard Worker                with _pop_mode_temporarily(functionality_key) as mode:
390*da0073e9SAndroid Build Coastguard Worker                    return handler(mode, *args, **kwargs)
391*da0073e9SAndroid Build Coastguard Worker
392*da0073e9SAndroid Build Coastguard Worker        final_key = resolve_key(self, dispatch_key)
393*da0073e9SAndroid Build Coastguard Worker
394*da0073e9SAndroid Build Coastguard Worker        # This can current fail due to backend fallbacks.  You just have to
395*da0073e9SAndroid Build Coastguard Worker        # register them by hand for HigherOrderOperator.
396*da0073e9SAndroid Build Coastguard Worker        if final_key not in self.py_kernels:
397*da0073e9SAndroid Build Coastguard Worker            raise NotImplementedError(
398*da0073e9SAndroid Build Coastguard Worker                f"could not find kernel for HigherOrderOperator {self._name} "
399*da0073e9SAndroid Build Coastguard Worker                f"at dispatch key {final_key} (resolved from {dispatch_key})"
400*da0073e9SAndroid Build Coastguard Worker            )
401*da0073e9SAndroid Build Coastguard Worker
402*da0073e9SAndroid Build Coastguard Worker        # [NOTE] We shouldn't cache PreDispatch kernel here because depending
403*da0073e9SAndroid Build Coastguard Worker        # on what modes are active, predispatch behaviour is different.
404*da0073e9SAndroid Build Coastguard Worker        # Also we do same thing for normal ops:
405*da0073e9SAndroid Build Coastguard Worker        # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
406*da0073e9SAndroid Build Coastguard Worker        if dispatch_key != DispatchKey.PreDispatch:
407*da0073e9SAndroid Build Coastguard Worker            self._dispatch_cache[dispatch_key] = self.py_kernels[final_key]
408*da0073e9SAndroid Build Coastguard Worker        kernel = self.py_kernels[final_key]
409*da0073e9SAndroid Build Coastguard Worker        # It's illegal to register DispatchKey to py_kernels, since there's no
410*da0073e9SAndroid Build Coastguard Worker        # C++ kernel to call into
411*da0073e9SAndroid Build Coastguard Worker        assert not isinstance(kernel, DispatchKey)
412*da0073e9SAndroid Build Coastguard Worker        return kernel(*args, **kwargs)
413*da0073e9SAndroid Build Coastguard Worker
414*da0073e9SAndroid Build Coastguard Worker    @abc.abstractmethod
415*da0073e9SAndroid Build Coastguard Worker    def __call__(self, /, *args, **kwargs):
416*da0073e9SAndroid Build Coastguard Worker        # Dynamo already traces the body of HigherOrderOp beforehand when it
417*da0073e9SAndroid Build Coastguard Worker        # so no need to trace into it.
418*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo import disable
419*da0073e9SAndroid Build Coastguard Worker
420*da0073e9SAndroid Build Coastguard Worker        @disable
421*da0073e9SAndroid Build Coastguard Worker        def wrapper():
422*da0073e9SAndroid Build Coastguard Worker            flat_args = _to_flat_tuple(args, kwargs)
423*da0073e9SAndroid Build Coastguard Worker            if torch.overrides.has_torch_function(flat_args):
424*da0073e9SAndroid Build Coastguard Worker                return torch.overrides.handle_torch_function(
425*da0073e9SAndroid Build Coastguard Worker                    self, flat_args, *args, **kwargs
426*da0073e9SAndroid Build Coastguard Worker                )
427*da0073e9SAndroid Build Coastguard Worker
428*da0073e9SAndroid Build Coastguard Worker            dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
429*da0073e9SAndroid Build Coastguard Worker            return self.dispatch(
430*da0073e9SAndroid Build Coastguard Worker                dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
431*da0073e9SAndroid Build Coastguard Worker            )
432*da0073e9SAndroid Build Coastguard Worker
433*da0073e9SAndroid Build Coastguard Worker        return wrapper()
434*da0073e9SAndroid Build Coastguard Worker
435*da0073e9SAndroid Build Coastguard Worker    def __str__(self):
436*da0073e9SAndroid Build Coastguard Worker        return f"{self.name()}"
437*da0073e9SAndroid Build Coastguard Worker
438*da0073e9SAndroid Build Coastguard Worker    def name(self):
439*da0073e9SAndroid Build Coastguard Worker        return self._name
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Workerdef _to_flat_tuple(args, kwargs):
443*da0073e9SAndroid Build Coastguard Worker    return pytree.arg_tree_leaves(*args, **kwargs)
444*da0073e9SAndroid Build Coastguard Worker
445*da0073e9SAndroid Build Coastguard Worker
def _compute_keyset(args, kwargs, non_fallthrough_keys):
    """Return the dispatch key set for a call with ``args``/``kwargs``.

    The tensors found in the arguments are passed to ``key_extractor``
    together with ``non_fallthrough_keys``, which acts as the key mask.
    """
    return key_extractor(_get_tensors(args, kwargs), non_fallthrough_keys)
449*da0073e9SAndroid Build Coastguard Worker
450*da0073e9SAndroid Build Coastguard Worker
def _get_tensors(args, kwargs):
    """Return a tuple of every ``torch.Tensor`` leaf found in ``args``/``kwargs``.

    Leaves come from ``_to_flat_tuple`` and keep the order it yields them in.
    """
    return tuple(
        leaf
        for leaf in _to_flat_tuple(args, kwargs)
        if isinstance(leaf, torch.Tensor)
    )
455*da0073e9SAndroid Build Coastguard Worker
456*da0073e9SAndroid Build Coastguard Worker
457*da0073e9SAndroid Build Coastguard Worker# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
458*da0073e9SAndroid Build Coastguard Worker# at ATen/core/dispatch/DispatchKeyExtractor.h
459*da0073e9SAndroid Build Coastguard Workerdef key_extractor(tensors, key_mask):
460*da0073e9SAndroid Build Coastguard Worker    key_set = torch._C._dispatch_tls_local_include_set()
461*da0073e9SAndroid Build Coastguard Worker    for tensor in tensors:
462*da0073e9SAndroid Build Coastguard Worker        key_set = key_set | torch._C._dispatch_keys(tensor)
463*da0073e9SAndroid Build Coastguard Worker    key_set = key_set - torch._C._dispatch_tls_local_exclude_set()
464*da0073e9SAndroid Build Coastguard Worker    key_set = key_set & key_mask
465*da0073e9SAndroid Build Coastguard Worker    return key_set
466*da0073e9SAndroid Build Coastguard Worker
467*da0073e9SAndroid Build Coastguard Worker
468*da0073e9SAndroid Build Coastguard Worker# Mode stack for PreDispatchKey
469*da0073e9SAndroid Build Coastguard Worker# it should always have three keys with
470*da0073e9SAndroid Build Coastguard Worker# priority given to FunctionalTensorMode and
471*da0073e9SAndroid Build Coastguard Worker# then ProxyTorchDispatchMode. It means that
472*da0073e9SAndroid Build Coastguard Worker# slot 0 belongs to ProxyTorchDispatchMode and
473*da0073e9SAndroid Build Coastguard Worker# slot 1 belongs to FunctionalTensorMode.
474*da0073e9SAndroid Build Coastguard Worker#
475*da0073e9SAndroid Build Coastguard Worker# SchemaCheckMode is separate from the other 2,
476*da0073e9SAndroid Build Coastguard Worker# and is only valid when the stack is empty.
477*da0073e9SAndroid Build Coastguard Worker# SchemaCheckMode is for testing purposes, and
478*da0073e9SAndroid Build Coastguard Worker# is meant to run in eager mode on concrete inputs,
479*da0073e9SAndroid Build Coastguard Worker# checking for incorrect schemas in regards to
480*da0073e9SAndroid Build Coastguard Worker# aliasing or mutating ops.
481*da0073e9SAndroid Build Coastguard Workerclass _ModeStackStateForPreDispatch:
482*da0073e9SAndroid Build Coastguard Worker    def __init__(self):
483*da0073e9SAndroid Build Coastguard Worker        self.__infra_modes = [None, None]
484*da0073e9SAndroid Build Coastguard Worker        self._schema_check_mode = None
485*da0073e9SAndroid Build Coastguard Worker
486*da0073e9SAndroid Build Coastguard Worker    def set(self, index, mode):
487*da0073e9SAndroid Build Coastguard Worker        assert index < len(self.__infra_modes)
488*da0073e9SAndroid Build Coastguard Worker        self.__infra_modes[index] = mode
489*da0073e9SAndroid Build Coastguard Worker
490*da0073e9SAndroid Build Coastguard Worker    def get(self, index):
491*da0073e9SAndroid Build Coastguard Worker        assert index < len(self.__infra_modes)
492*da0073e9SAndroid Build Coastguard Worker        return self.__infra_modes[index]
493*da0073e9SAndroid Build Coastguard Worker
494*da0073e9SAndroid Build Coastguard Worker    def count(self):
495*da0073e9SAndroid Build Coastguard Worker        return len([i for i in self.__infra_modes if i is not None]) + int(
496*da0073e9SAndroid Build Coastguard Worker            self._schema_check_mode is not None
497*da0073e9SAndroid Build Coastguard Worker        )
498*da0073e9SAndroid Build Coastguard Worker
499*da0073e9SAndroid Build Coastguard Worker
# Single module-level instance holding the PreDispatch mode state; it is
# returned by mode_stack_state_for_pre_dispatch().
_mode_stack_state_for_pre_dispatch = _ModeStackStateForPreDispatch()
501*da0073e9SAndroid Build Coastguard Worker
502*da0073e9SAndroid Build Coastguard Worker
def unset_mode_pre_dispatch(mode_key, schema_check=False):
    """Clear one PreDispatch mode slot and return the mode it held.

    ``mode_key`` selects the slot: ``PROXY`` clears infra slot 0,
    ``FUNCTIONAL`` clears infra slot 1, and any other value (i.e. ``None``)
    clears the schema-check mode. When ``schema_check`` is true,
    ``mode_key`` must be ``None``.

    Returns the mode that was stored in the cleared slot (possibly ``None``).
    """
    mode_keys = torch._C._TorchDispatchModeKey
    stack_state = mode_stack_state_for_pre_dispatch()
    assert mode_key is None or mode_key in (
        mode_keys.PROXY,
        mode_keys.FUNCTIONAL,
    )
    if schema_check:
        assert mode_key is None

    if mode_key == mode_keys.PROXY:
        removed = stack_state.get(0)
        stack_state.set(0, None)
    elif mode_key == mode_keys.FUNCTIONAL:
        removed = stack_state.get(1)
        stack_state.set(1, None)
    else:
        removed = stack_state._schema_check_mode
        stack_state._schema_check_mode = None

    # Once the last mode is gone nothing is active on the PreDispatch key,
    # so the key is dropped from the local dispatch include set.
    if _len_torch_dispatch_stack_pre_dispatch() == 0:
        torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, False)

    return removed
537*da0073e9SAndroid Build Coastguard Worker
538*da0073e9SAndroid Build Coastguard Worker
def _set_mode_pre_dispatch(mode):
    """Install ``mode`` into its PreDispatch slot.

    ``mode`` must be a ``FunctionalTensorMode``, ``ProxyTorchDispatchMode``
    or ``SchemaCheckMode``. A ``FunctionalTensorMode`` goes to infra slot 1
    and a ``ProxyTorchDispatchMode`` to infra slot 0; in both cases the slot
    must currently be empty. A ``SchemaCheckMode`` is stored separately and
    is only accepted when no other mode is set.

    Raises:
        AssertionError: if a ``SchemaCheckMode`` is installed while other
            modes are present, if the target infra slot is already
            occupied, or if ``mode`` is of an unsupported type.
    """
    from torch._subclasses.functional_tensor import FunctionalTensorMode
    from torch._subclasses.schema_check_mode import SchemaCheckMode
    from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode

    assert isinstance(
        mode,
        (
            FunctionalTensorMode,
            ProxyTorchDispatchMode,
            SchemaCheckMode,
        ),
    )

    stack_state = mode_stack_state_for_pre_dispatch()
    previous_mode_stack_len = _len_torch_dispatch_stack_pre_dispatch()
    if isinstance(mode, SchemaCheckMode):
        # The count includes an existing schema-check mode, so this check
        # also rejects overwriting one; no separate read of it is needed.
        if previous_mode_stack_len > 0:
            raise AssertionError(
                "SchemaCheckMode for pre-dispatch must be used exclusively, found other modes on the stack"
            )
        stack_state._schema_check_mode = mode
    else:
        slot = 1 if isinstance(mode, FunctionalTensorMode) else 0
        assert stack_state.get(slot) is None
        stack_state.set(slot, mode)

    # When we are setting a mode, we need to check if there is
    # active mode left on the PreDispatch key. If there was nothing
    # active before setting this mode, it means that PreDispatch key
    # was turned off. So we need to turn it on again.
    if previous_mode_stack_len == 0:
        torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, True)
576*da0073e9SAndroid Build Coastguard Worker
577*da0073e9SAndroid Build Coastguard Worker
def _pop_mode_from_pre_dispatch():
    """Remove and return the first set mode on the PreDispatch stack.

    Slots are tried in this order: the schema-check mode, then infra slot 1
    (``FUNCTIONAL``), then infra slot 0 (``PROXY``).

    Raises:
        AssertionError: if no mode is currently set.
    """
    if _len_torch_dispatch_stack_pre_dispatch() == 0:
        raise AssertionError("Trying to pop empty mode stack")

    stack_state = mode_stack_state_for_pre_dispatch()
    if stack_state._schema_check_mode is not None:
        return unset_mode_pre_dispatch(None, schema_check=True)

    mode_keys = torch._C._TorchDispatchModeKey
    for slot, key in ((1, mode_keys.FUNCTIONAL), (0, mode_keys.PROXY)):
        if stack_state.get(slot) is not None:
            return unset_mode_pre_dispatch(key)
591*da0073e9SAndroid Build Coastguard Worker
592*da0073e9SAndroid Build Coastguard Worker
def _len_torch_dispatch_stack_pre_dispatch():
    """Return the number of modes currently set for PreDispatch."""
    stack_state = mode_stack_state_for_pre_dispatch()
    return stack_state.count()
595*da0073e9SAndroid Build Coastguard Worker
596*da0073e9SAndroid Build Coastguard Worker
def _get_dispatch_mode_pre_dispatch(mode_key):
    """Return the PreDispatch infra mode stored for ``mode_key``.

    ``mode_key`` must be ``PROXY`` (infra slot 0) or ``FUNCTIONAL`` (infra
    slot 1); the slot content is returned as is, so it may be ``None``.
    """
    mode_keys = torch._C._TorchDispatchModeKey
    assert mode_key in (
        mode_keys.PROXY,
        mode_keys.FUNCTIONAL,
    )
    slot = 0 if mode_key == mode_keys.PROXY else 1
    return mode_stack_state_for_pre_dispatch().get(slot)
606*da0073e9SAndroid Build Coastguard Worker
607*da0073e9SAndroid Build Coastguard Worker
def _get_current_dispatch_mode_pre_dispatch():
    """Return the currently active PreDispatch mode, or ``None`` if none is set.

    The schema-check mode wins when present; otherwise infra slot 1 is
    preferred over infra slot 0.
    """
    stack_state = mode_stack_state_for_pre_dispatch()
    schema_mode = stack_state._schema_check_mode
    if schema_mode is not None:
        return schema_mode

    slot_one_mode = stack_state.get(1)
    if slot_one_mode is not None:
        return slot_one_mode
    # Either slot 0 holds the only active mode, or both slots are empty and
    # this is None.
    return stack_state.get(0)
622*da0073e9SAndroid Build Coastguard Worker
623*da0073e9SAndroid Build Coastguard Worker
def mode_stack_state_for_pre_dispatch():
    """Return the module-level ``_mode_stack_state_for_pre_dispatch`` object."""
    return _mode_stack_state_for_pre_dispatch
627*da0073e9SAndroid Build Coastguard Worker
628*da0073e9SAndroid Build Coastguard Worker
# Set of OpOverload objects recorded through add_cached_op(); it is emptied by
# reset_cached_ops() and returned by get_cached_ops().
cached_ops: Set["OpOverload"] = set()
630*da0073e9SAndroid Build Coastguard Worker
631*da0073e9SAndroid Build Coastguard Worker
def add_cached_op(op_overload):
    """Record ``op_overload`` in the module-level ``cached_ops`` set."""
    cached_ops.add(op_overload)
635*da0073e9SAndroid Build Coastguard Worker
636*da0073e9SAndroid Build Coastguard Worker
def reset_cached_ops():
    """Empty the module-level ``cached_ops`` set in place."""
    cached_ops.clear()
640*da0073e9SAndroid Build Coastguard Worker
641*da0073e9SAndroid Build Coastguard Worker
def get_cached_ops():
    """Return the module-level ``cached_ops`` set itself (not a copy)."""
    return cached_ops
645*da0073e9SAndroid Build Coastguard Worker
646*da0073e9SAndroid Build Coastguard Worker
# Each OpOverload object contains a pointer to a specific operator overload and a pointer to the parent `OpOverloadPacket` object.
648*da0073e9SAndroid Build Coastguard Worker# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
649*da0073e9SAndroid Build Coastguard Workerclass OpOverload(OperatorBase):
650*da0073e9SAndroid Build Coastguard Worker    def __init__(self, overloadpacket, op, op_dk, schema, tags):
651*da0073e9SAndroid Build Coastguard Worker        super().__init__()
652*da0073e9SAndroid Build Coastguard Worker        self._op = op
653*da0073e9SAndroid Build Coastguard Worker        self._op_dk = op_dk
654*da0073e9SAndroid Build Coastguard Worker        self._schema = schema
655*da0073e9SAndroid Build Coastguard Worker        self._overloadpacket = overloadpacket
656*da0073e9SAndroid Build Coastguard Worker        self._tags = tags
657*da0073e9SAndroid Build Coastguard Worker        self._overloadname = (
658*da0073e9SAndroid Build Coastguard Worker            "default" if schema.overload_name == "" else schema.overload_name
659*da0073e9SAndroid Build Coastguard Worker        )
660*da0073e9SAndroid Build Coastguard Worker        self._name = self._schema.name
661*da0073e9SAndroid Build Coastguard Worker        if schema.overload_name:
662*da0073e9SAndroid Build Coastguard Worker            self._name += "." + schema.overload_name
663*da0073e9SAndroid Build Coastguard Worker        self.__name__ = f"{self._schema.name.split('::')[1]}.{self._overloadname}"
664*da0073e9SAndroid Build Coastguard Worker        self.__module__ = overloadpacket.__module__
665*da0073e9SAndroid Build Coastguard Worker        op.__module__ = overloadpacket.__module__
666*da0073e9SAndroid Build Coastguard Worker        self.__qualname__ = self._name
667*da0073e9SAndroid Build Coastguard Worker        self.__annotations__ = {}
668*da0073e9SAndroid Build Coastguard Worker        # Only compute the OperatorHandle when we need it. Not all OpOverloads have
669*da0073e9SAndroid Build Coastguard Worker        # OperatorHandles (the TorchScript ones don't...)
670*da0073e9SAndroid Build Coastguard Worker        self._lazy_handle = None
671*da0073e9SAndroid Build Coastguard Worker
672*da0073e9SAndroid Build Coastguard Worker        # If the OpOverload was constructed from a Library.def in Python.
673*da0073e9SAndroid Build Coastguard Worker        self._defined_in_python = self.__qualname__ in torch.library._defs
674*da0073e9SAndroid Build Coastguard Worker
675*da0073e9SAndroid Build Coastguard Worker        # Logic replicated from aten/src/ATen/native/MathBitsFallback.h
676*da0073e9SAndroid Build Coastguard Worker        is_write = None
677*da0073e9SAndroid Build Coastguard Worker        for a in self._schema.arguments:
678*da0073e9SAndroid Build Coastguard Worker            if a.alias_info is None:
679*da0073e9SAndroid Build Coastguard Worker                continue
680*da0073e9SAndroid Build Coastguard Worker            if is_write is None:
681*da0073e9SAndroid Build Coastguard Worker                is_write = a.alias_info.is_write
682*da0073e9SAndroid Build Coastguard Worker            else:
683*da0073e9SAndroid Build Coastguard Worker                # We will conservatively call mixed mutable/non-mutable
684*da0073e9SAndroid Build Coastguard Worker                # aliased inputs as NOT a view
685*da0073e9SAndroid Build Coastguard Worker                is_write = a.alias_info.is_write or is_write
686*da0073e9SAndroid Build Coastguard Worker        self.is_view = is_write is not None and not is_write
687*da0073e9SAndroid Build Coastguard Worker
688*da0073e9SAndroid Build Coastguard Worker    @property
689*da0073e9SAndroid Build Coastguard Worker    def _namespace(self):
690*da0073e9SAndroid Build Coastguard Worker        return self._schema.name.split("::")[0]
691*da0073e9SAndroid Build Coastguard Worker
692*da0073e9SAndroid Build Coastguard Worker    @property
693*da0073e9SAndroid Build Coastguard Worker    def _opname(self):
694*da0073e9SAndroid Build Coastguard Worker        return self._schema.name.split("::")[1]
695*da0073e9SAndroid Build Coastguard Worker
696*da0073e9SAndroid Build Coastguard Worker    @property
697*da0073e9SAndroid Build Coastguard Worker    def _handle(self):
698*da0073e9SAndroid Build Coastguard Worker        if self._lazy_handle is None:
699*da0073e9SAndroid Build Coastguard Worker            self._lazy_handle = torch._C._dispatch_find_schema_or_throw(
700*da0073e9SAndroid Build Coastguard Worker                self._schema.name, self._schema.overload_name
701*da0073e9SAndroid Build Coastguard Worker            )
702*da0073e9SAndroid Build Coastguard Worker        return self._lazy_handle
703*da0073e9SAndroid Build Coastguard Worker
704*da0073e9SAndroid Build Coastguard Worker    # it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
705*da0073e9SAndroid Build Coastguard Worker    def __deepcopy__(self, memo=None):
706*da0073e9SAndroid Build Coastguard Worker        return self
707*da0073e9SAndroid Build Coastguard Worker
708*da0073e9SAndroid Build Coastguard Worker    def __repr__(self):
709*da0073e9SAndroid Build Coastguard Worker        return "<OpOverload(op='{}.{}', overload='{}')>".format(
710*da0073e9SAndroid Build Coastguard Worker            *self._schema.name.split("::"), self._overloadname
711*da0073e9SAndroid Build Coastguard Worker        )
712*da0073e9SAndroid Build Coastguard Worker
713*da0073e9SAndroid Build Coastguard Worker    # Use positional-only argument to avoid naming collision with aten ops arguments
714*da0073e9SAndroid Build Coastguard Worker    # that are named "self". This way, all the aten ops can be called by kwargs.
715*da0073e9SAndroid Build Coastguard Worker    def __call__(self, /, *args, **kwargs):
716*da0073e9SAndroid Build Coastguard Worker        return self._op(*args, **kwargs)
717*da0073e9SAndroid Build Coastguard Worker
718*da0073e9SAndroid Build Coastguard Worker    # Use positional-only argument to avoid naming collision with aten ops arguments
719*da0073e9SAndroid Build Coastguard Worker    # that are named "self". This way, all the aten ops can be called by kwargs.
720*da0073e9SAndroid Build Coastguard Worker    def redispatch(self, /, keyset, *args, **kwargs):
721*da0073e9SAndroid Build Coastguard Worker        return self._handle.redispatch_boxed(keyset, *args, **kwargs)
722*da0073e9SAndroid Build Coastguard Worker
723*da0073e9SAndroid Build Coastguard Worker    def __hash__(self):
724*da0073e9SAndroid Build Coastguard Worker        return hash(self._op)
725*da0073e9SAndroid Build Coastguard Worker
726*da0073e9SAndroid Build Coastguard Worker    # `my_namespace.my_op_name.overload_name`
727*da0073e9SAndroid Build Coastguard Worker    def __str__(self):
728*da0073e9SAndroid Build Coastguard Worker        return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)
729*da0073e9SAndroid Build Coastguard Worker
730*da0073e9SAndroid Build Coastguard Worker    def has_kernel_for_dispatch_key(self, k):
731*da0073e9SAndroid Build Coastguard Worker        return super().has_kernel_for_dispatch_key(
732*da0073e9SAndroid Build Coastguard Worker            k
733*da0073e9SAndroid Build Coastguard Worker        ) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k)
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Worker    def has_kernel_for_any_dispatch_key(self, ks):
736*da0073e9SAndroid Build Coastguard Worker        return torch._C._dispatch_has_kernel_for_any_dispatch_key(
737*da0073e9SAndroid Build Coastguard Worker            self.name(), ks
738*da0073e9SAndroid Build Coastguard Worker        ) or super().has_kernel_for_any_dispatch_key(ks)
739*da0073e9SAndroid Build Coastguard Worker
740*da0073e9SAndroid Build Coastguard Worker    @property
741*da0073e9SAndroid Build Coastguard Worker    def namespace(self):
742*da0073e9SAndroid Build Coastguard Worker        return self._schema.name.split("::")[0]
743*da0073e9SAndroid Build Coastguard Worker
744*da0073e9SAndroid Build Coastguard Worker    def _can_decompose(self):
745*da0073e9SAndroid Build Coastguard Worker        dk = DispatchKey.CompositeImplicitAutograd
746*da0073e9SAndroid Build Coastguard Worker        return dk in self.py_kernels or torch._C._dispatch_has_kernel_for_dispatch_key(
747*da0073e9SAndroid Build Coastguard Worker            self.name(), dk
748*da0073e9SAndroid Build Coastguard Worker        )
749*da0073e9SAndroid Build Coastguard Worker
750*da0073e9SAndroid Build Coastguard Worker    def decompose(self, *args, **kwargs):
751*da0073e9SAndroid Build Coastguard Worker        dk = DispatchKey.CompositeImplicitAutograd
752*da0073e9SAndroid Build Coastguard Worker        if dk in self.py_kernels:
753*da0073e9SAndroid Build Coastguard Worker            # NB: This branch is not too necessary anymore, because we can
754*da0073e9SAndroid Build Coastguard Worker            # apply Python CompositeImplicitAutograd *before* tracing
755*da0073e9SAndroid Build Coastguard Worker            # using Python dispatcher (also taking advantage of the autograd
756*da0073e9SAndroid Build Coastguard Worker            # formula).  But it's included for completeness
757*da0073e9SAndroid Build Coastguard Worker            return self.py_kernels[dk](*args, **kwargs)
758*da0073e9SAndroid Build Coastguard Worker        elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
759*da0073e9SAndroid Build Coastguard Worker            return self._op_dk(dk, *args, **kwargs)
760*da0073e9SAndroid Build Coastguard Worker        else:
761*da0073e9SAndroid Build Coastguard Worker            return NotImplemented
762*da0073e9SAndroid Build Coastguard Worker
    # Remove a dispatch key from the dispatch cache.  This will force it to get
    # recomputed the next time.  Does nothing if the key is not in the cache.
    # WARNING: if you register a dispatch key to py_kernels of an OpOverload,
    # calling _uncache_dispatch on that key is NOT sufficient to apply your
    # change, because a single registration may affect MULTIPLE dispatch keys
    # (e.g., registering Autograd affects AutogradCPU).  _uncache_dispatch is to
    # be used only if you are specifically modifying how _get_dispatch handles a
    # particular input 'key'.
    def _uncache_dispatch(self, key):
        # pop() with a default so that an uncached key is silently ignored.
        self._dispatch_cache.pop(key, None)
773*da0073e9SAndroid Build Coastguard Worker
    # This implements the pre-computation logic for the Python dispatcher.
    def _get_dispatch(self, key):
        """Compute what should run for dispatch key ``key`` on a cache miss.

        The result is either a ``DispatchKey`` (the resolved key, or ``key``
        itself in the early Python-key case) or a Python callable taking the
        op's ``*args, **kwargs``.  Most results are stored in
        ``self._dispatch_cache`` and the op is recorded via ``add_cached_op``;
        the PreDispatch results are deliberately not cached (see the note
        below).

        Three cases are handled in order:

        1. ``DispatchKey.Python``: returns ``key`` unchanged when this is not a
           ``TorchBindOpOverload`` and ``python_key_table`` is empty; otherwise
           returns a handler that routes on the type of the current dispatch
           mode.
        2. A key whose functionality key is ``PreDispatch`` while the
           pre-dispatch mode stack is non-empty and ``Python`` is not excluded:
           returns an uncached handler that runs the top pre-dispatch mode.
        3. Everything else: resolves the key with ``resolve_key`` and returns
           the matching entry of ``py_kernels``, or the resolved key itself
           when no Python kernel is registered.
        """
        # This is only called upon a cache miss
        assert key not in self._dispatch_cache, f"{self} {key}"

        if key == DispatchKey.Python:
            # Nothing Python-side to run: cache and return the key itself.
            if not isinstance(self, TorchBindOpOverload) and not self.python_key_table:
                self._dispatch_cache[key] = key
                add_cached_op(self)
                return key

            def handler(*args, **kwargs):
                from torch.utils._python_dispatch import _get_current_dispatch_mode

                # TODO: We also need to handle tensor subclasses here
                # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
                curr_mode = type(_get_current_dispatch_mode())
                # NOTE(review): type(...) always returns a class, so this
                # assertion can never fail as written.  If no mode is active
                # (presumably _get_current_dispatch_mode() returns None then),
                # curr_mode is NoneType and control falls into the
                # "not in python_key_table" branch below.  Confirm whether that
                # fallthrough is relied upon before tightening the check.
                assert (
                    curr_mode is not None
                ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."

                if curr_mode not in self.python_key_table:
                    if isinstance(self, TorchBindOpOverload):
                        # No mode-specific entry: let the generic mode handling
                        # run with the current mode temporarily popped.
                        with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
                            return torch._library.utils.handle_dispatch_mode(
                                mode, self, *args, **kwargs
                            )
                    else:
                        return self._op_dk(key, *args, **kwargs)

                # A handler is registered for this mode type; it receives the
                # popped mode as its first argument.
                with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
                    return self.python_key_table[curr_mode](mode, *args, **kwargs)

            self._dispatch_cache[key] = handler
            add_cached_op(self)
            return handler

        functionality_key = torch._C._to_functionality_key(key)  # type: ignore[attr-defined]
        if functionality_key == DispatchKey.PreDispatch:
            curr_stack_len = _len_torch_dispatch_stack_pre_dispatch()
            # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
            # calls inside of a mode.
            if (
                curr_stack_len > 0
                and not torch._C._dispatch_tls_is_dispatch_key_excluded(
                    DispatchKey.Python
                )
            ):

                def handler(*args, **kwargs):
                    @contextlib.contextmanager
                    def _temporarily_pop_modes_from_pre_dispatch():
                        # Pop the top pre-dispatch mode and always put it back,
                        # even if the body raises.
                        top_mode = _pop_mode_from_pre_dispatch()
                        try:
                            yield top_mode
                        finally:
                            _set_mode_pre_dispatch(top_mode)

                    with _temporarily_pop_modes_from_pre_dispatch() as curr_mode:
                        return torch._library.utils.handle_dispatch_mode(
                            curr_mode, self, *args, **kwargs
                        )

                # Note [Not Caching Per-Dispatch-Key Mode Handlers]
                # Note that we're not caching this handler.  There isn't really a point, since the slow bit
                # is the handler itself (in python).
                # Also, not caching means that we don't have to reset the cache when any existing
                # modes go out of scope (which in of itself takes time to loop through all operators).
                return handler

        final_key = resolve_key(self, key)

        # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
        cache_result = key != DispatchKey.PreDispatch

        # TODO: We could potentially have lots of debugging wrappers against
        # dispatch keys; design some general registration mechanism instead of
        # having if statement for each of them
        if key == DispatchKey.Functionalize:
            import torch._dispatch.python as pydispatch

            if pydispatch.CROSSREF_FUNCTIONALIZE:
                handler = pydispatch.make_crossref_functionalize(self, final_key)
                if cache_result:
                    self._dispatch_cache[key] = handler
                    add_cached_op(self)
                return handler

        # Prefer a Python kernel registered for the resolved key; otherwise
        # hand back the resolved key itself.
        r = self.py_kernels.get(final_key, final_key)
        if cache_result:
            self._dispatch_cache[key] = r
            add_cached_op(self)
        return r
867*da0073e9SAndroid Build Coastguard Worker
    def name(self):
        """Return this overload's name as stored in ``self._name``."""
        return self._name
870*da0073e9SAndroid Build Coastguard Worker
    @property
    def overloadpacket(self):
        """Return the object stored in ``self._overloadpacket``.

        Presumably the ``OpOverloadPacket`` this overload was created from;
        the assignment is outside this view, so confirm against the constructor.
        """
        return self._overloadpacket
874*da0073e9SAndroid Build Coastguard Worker
    @property
    def op(self):
        """Return the underlying callable stored in ``self._op``."""
        return self._op
878*da0073e9SAndroid Build Coastguard Worker
    @property
    def tags(self):
        """Return the tags stored in ``self._tags``."""
        return self._tags
882*da0073e9SAndroid Build Coastguard Worker
883*da0073e9SAndroid Build Coastguard Worker    # TODO: add more methods to expose information about input and output arguments
884*da0073e9SAndroid Build Coastguard Worker
885*da0073e9SAndroid Build Coastguard Worker
# A TorchBindOpOverload is a custom-op overload whose schema takes at least one
# torch.ScriptObject (i.e. custom class) input.
# When any of its inputs is a FakeScriptObject, a TorchBindOpOverload skips the
# C++ dispatcher and is dispatched purely in Python, similar to higher order ops.
class TorchBindOpOverload(OpOverload):
    """``OpOverload`` for overloads whose schema takes a script-object argument.

    When any input to ``__call__`` is a ``FakeScriptObject``, the C++
    dispatcher is bypassed and the call is dispatched in Python through
    ``_get_dispatch``; otherwise the call is forwarded to ``self._op``.
    """

    def _fallthrough_keys(self) -> List[DispatchKey]:
        """Return the default fallthrough keys that may be skipped for this op.

        A key is kept in the result when either its dispatcher kernel is itself
        a fallthrough, or no dispatcher kernel exists and no real Python kernel
        is registered for it in ``self.py_kernels``.
        """
        # TODO: we should be calling the fallback for these, but a fallthrough is almost close
        # enough to the fallback in most cases that we care about.
        _DEFAULT_FALLTHROUGH_KEYS = [
            DispatchKey.Autograd,
            DispatchKey.AutogradCPU,
            DispatchKey.AutogradCUDA,
            DispatchKey.ADInplaceOrView,
            DispatchKey.BackendSelect,
            DispatchKey.PythonTLSSnapshot,
            DispatchKey.PythonDispatcher,
        ]

        def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
            # A kernel registered with the dispatcher takes precedence: the key
            # may be skipped only if that kernel is a fallthrough.
            if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key):
                return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
                    self.name(), key
                )

            # Otherwise the key may be skipped unless a non-fallthrough Python
            # kernel has been registered for it.
            return (
                key not in self.py_kernels
                or self.py_kernels[key] is torch.library.fallthrough_kernel
            )

        return [
            key
            for key in _DEFAULT_FALLTHROUGH_KEYS
            if _may_use_fallthrough_instead_of_fallback(key)
        ]

    @contextlib.contextmanager
    def _register_as_effectful_op_temporarily(self):
        """Register this op as an ordered effectful op for the ``with`` body.

        The registration is removed on exit only if it was added by this
        context.  An op that was already present in ``SIDE_EFFECTS`` on entry
        (registered by the user, or by an enclosing call of this same context
        manager) is left registered.
        """
        from torch._higher_order_ops.effects import (
            _EffectType,
            _register_effectful_op,
            SIDE_EFFECTS,
        )

        # Remember whether we are the ones adding the entry, so that exiting
        # never drops a registration that existed before we were entered.
        registered_here = self not in SIDE_EFFECTS
        try:
            if registered_here:
                _register_effectful_op(self, _EffectType.ORDERED)
            yield
        finally:
            if registered_here and self in SIDE_EFFECTS:
                del SIDE_EFFECTS[self]

    # Use positional-only argument to avoid naming collision with aten ops arguments
    # that are named "self". This way, all the aten ops can be called by kwargs.
    def __call__(self, /, *args, **kwargs):
        """Call the op, dispatching in Python when a FakeScriptObject is passed."""
        if _must_dispatch_in_python(args, kwargs):
            # When any inputs are FakeScriptObject, we need to
            # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher
            # because C++ dispatcher will check the schema and cannot recognize FakeScriptObject.
            #
            # Note:
            # 1. We only register the torchbind op temporarily as effectful op because we only want
            #    the effect token functionalization logic to be applied during tracing. Otherwise, the behavior
            #    of the eagerly executing the op might change after tracing.
            # 2. We don't want to register the op as effectful for all torchbind ops in ctor because this might
            #    cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction.
            with self._register_as_effectful_op_temporarily():
                return self._dispatch_in_python(args, kwargs, self._fallthrough_keys())
        return self._op(*args, **kwargs)

    def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
        """Pick the highest-priority non-fallthrough key and run its handler.

        Args:
            args: positional arguments of the op call.
            kwargs: keyword arguments of the op call.
            fallthrough_keys: dispatch keys to exclude when computing the
                key set for this call.

        Returns:
            Whatever the selected Python handler returns.

        Raises:
            RuntimeError: if the selected key resolves to a dispatch key (no
                Python implementation) and its kernel is not a fallthrough.
        """
        non_fallthrough_keys = torch._C._dispatch_keyset_full()
        for key in fallthrough_keys:
            non_fallthrough_keys = non_fallthrough_keys.remove(key)

        dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys)
        dispatch_key = dispatch_key_set.highestPriorityTypeId()

        handler = (
            self._get_dispatch(dispatch_key)
            if dispatch_key not in self._dispatch_cache
            else self._dispatch_cache[dispatch_key]
        )

        if isinstance(handler, DispatchKey):
            # fallthrough keys can be registered at runtime via torch.library.impl
            # so need to add it to fallthrough_keys and re-dispatch.
            if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
                self.name(), dispatch_key
            ):
                return self._dispatch_in_python(
                    args, kwargs, fallthrough_keys + [dispatch_key]
                )

            raise RuntimeError(
                f"Torchbind op {self} received a FakeScriptObject input when dispatching {handler}."
                f" but no python implementation is found."
                f" Please file an issue on this when you encounter this error."
                f" This error can happen when you export or compile the model."
                f" It can still happpen even if a C++ implementation for {dispatch_key}. "
                f" has been registered. That's because FakeScriptObject purely lives in python and cannot work "
                f" with a C++ implementation."
            )

        assert callable(handler)
        return handler(*args, **kwargs)
991*da0073e9SAndroid Build Coastguard Worker
992*da0073e9SAndroid Build Coastguard Worker
993*da0073e9SAndroid Build Coastguard Workerdef _must_dispatch_in_python(args, kwargs):
994*da0073e9SAndroid Build Coastguard Worker    return pytree.tree_any(
995*da0073e9SAndroid Build Coastguard Worker        lambda obj: isinstance(
996*da0073e9SAndroid Build Coastguard Worker            obj, torch._library.fake_class_registry.FakeScriptObject
997*da0073e9SAndroid Build Coastguard Worker        ),
998*da0073e9SAndroid Build Coastguard Worker        (args, kwargs),
999*da0073e9SAndroid Build Coastguard Worker    )
1000*da0073e9SAndroid Build Coastguard Worker
1001*da0073e9SAndroid Build Coastguard Worker
1002*da0073e9SAndroid Build Coastguard Workerdef _has_script_object_arg(schema: torch.FunctionSchema) -> bool:
1003*da0073e9SAndroid Build Coastguard Worker    return any(isinstance(arg.type, torch.ClassType) for arg in schema.arguments)
1004*da0073e9SAndroid Build Coastguard Worker
1005*da0073e9SAndroid Build Coastguard Worker
1006*da0073e9SAndroid Build Coastguard Worker# OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
1007*da0073e9SAndroid Build Coastguard Worker# You can obtain an OpOverload object through attribute query.
class OpOverloadPacket:
    """Handle for a base operator that is not yet resolved to one overload.

    Individual overload objects are obtained through attribute access
    (``packet.default`` or ``packet.<overload_name>``) and are cached on the
    packet after the first lookup.  Calling the packet forwards to the
    underlying ``op`` callable.
    """

    def __init__(self, qualified_op_name, op_name, op, overload_names):
        # These attributes are exposed through the properties defined below
        # and are meant to be treated as immutable.
        self._qualified_op_name = qualified_op_name
        self.__name__ = op_name
        self._op = op
        self._overload_names = overload_names
        # Names of the overloads that have been resolved (and cached) so far.
        self._dir = []
        self._has_torchbind_op_overload = any(
            map(_has_script_object_arg, self._schemas.values())
        )

    # it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        # str(self) already renders "<namespace>.<op name>".
        return f"<OpOverloadPacket(op='{self!s}')>"

    def __hash__(self):
        return hash(self._op)

    def __str__(self):
        name_parts = self._qualified_op_name.split("::")
        return "{}.{}".format(*name_parts)

    @property
    def op(self):
        """The underlying callable for this packet."""
        return self._op

    @property
    def _schemas(self):
        """Map each overload name to its schema (recomputed on every access)."""
        schemas = {}
        for overload_name in self._overload_names:
            schemas[overload_name] = torch._C._get_schema(
                self._qualified_op_name, overload_name
            )
        return schemas

    def __getattr__(self, key):
        """Resolve ``key`` as a dunder of the underlying op or as an overload name.

        Raises:
            AttributeError: if ``key`` is a dunder missing on the underlying
                op, or if no overload named ``key`` exists.
        """
        # It is not a valid op_name when __file__ is passed in
        if key == "__file__":
            return "torch.ops"

        # Dunder attributes that are missing on the packet are looked up on
        # self._op directly, without calling `_get_operation_overload` (which
        # is an expensive operation).  This is ok since we are guaranteed that
        # an overload name for an aten op can't start with '__'.
        if key.startswith("__"):
            try:
                return getattr(self._op, key)
            except AttributeError:
                # Re-raise with a message that names the packet, since that is
                # the object the attribute query was performed on.
                raise AttributeError(
                    f"'{str(self)}' can't have an overload name beginning with '__' and the "
                    f"underlying op {str(self._op)} has no attribute {key} either."
                ) from None

        # This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
        use_key = "" if key == "default" else key
        try:
            # TODO: disallow access to overloads registered by JIT
            op_dk_tags = torch._C._get_operation_overload(
                self._qualified_op_name, use_key
            )
            if op_dk_tags is None:
                raise AttributeError(
                    f"The underlying op of '{str(self)}' has no overload name '{key}'"
                )

            op_, op_dk_, tags = op_dk_tags
            schema = torch._C._get_schema(self._qualified_op_name, use_key)
            if _has_script_object_arg(schema):
                overload = TorchBindOpOverload(self, op_, op_dk_, schema, tags)
            else:
                overload = OpOverload(self, op_, op_dk_, schema, tags)

            # Cache the overload so later lookups bypass __getattr__.
            setattr(self, key, overload)
            self._dir.append(key)
            return overload
        except RuntimeError:
            raise AttributeError(
                f"The underlying op of '{str(self)}' has no overload name '{key}'"
            ) from None

    def __iter__(self):
        return iter(self._dir)

    # Use positional-only argument to avoid naming collision with aten ops arguments
    # that are named "self". This way, all the aten ops can be called by kwargs.
    def __call__(self, /, *args, **kwargs):
        # overloading __call__ to ensure torch.ops.foo.bar()
        # is still callable from JIT
        # We save the function ptr as the `op` attribute on
        # OpOverloadPacket to access it here.

        # Directly calling OverloadPacket goes into C++, which will check
        # the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
        # intercept it here and dispatch from Python instead.
        needs_python_dispatch = (
            self._has_torchbind_op_overload
            and _must_dispatch_in_python(args, kwargs)
        )
        if needs_python_dispatch:
            return _call_overload_packet_from_python(self, args, kwargs)
        return self._op(*args, **kwargs)

    # TODO: use this to make a __dir__
    def overloads(self):
        """Return the overload names, with the empty name shown as "default"."""
        return [overload_name or "default" for overload_name in self._overload_names]
1121*da0073e9SAndroid Build Coastguard Worker
1122*da0073e9SAndroid Build Coastguard Worker
# Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp
# _jit_get_operations, which calls _get_operation_for_overload_or_packet.
def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs):
    """Resolve and call an overload of ``op`` from Python.

    Args:
        op: the overload packet being called.
        args: positional arguments of the call, as a sequence.
        kwargs: keyword arguments of the call, as a mapping.

    Returns:
        The torch-function result when
        ``torch._C._maybe_call_torch_function_for_op_packet`` reports that it
        handled the call; otherwise the result of calling the first overload
        (in ``op.overloads()`` order) whose schema check did not raise.

    Raises:
        RuntimeError: if the schema check raised ``RuntimeError`` for every
            overload. The message lists the failure recorded per overload.
    """
    # Re-use the torch function handling logic in cpp.
    handled, result = torch._C._maybe_call_torch_function_for_op_packet(
        op, *args, **kwargs
    )
    if handled:
        return result

    # The following mirrors getOpWithStack.
    # In cpp, schema matching is done on the arguments and ToIValue is called
    # to check whether they are valid. Something similar is needed here: the
    # schema check must accept a FakeScriptObject when it is the fake class
    # corresponding to the actual class used in the schema.
    failures = {}
    matched = None
    for overload_name in op.overloads():
        candidate = getattr(op, overload_name)
        try:
            torch._C._check_schema_allow_fake_script_object(
                candidate._schema, *args, **kwargs
            )
        except RuntimeError as err:
            # Remember why this overload was rejected and try the next one.
            failures[overload_name] = err
        else:
            matched = candidate
            break

    if matched:
        return matched(*args, **kwargs)

    header = f"Fail to match any TorchBindOverload of {op} with following exceptions:\n"
    details = "".join(
        f"Overload name {key}:\n {msg}\n" for key, msg in failures.items()
    )
    raise RuntimeError(header + details)
1161*da0073e9SAndroid Build Coastguard Worker
1162*da0073e9SAndroid Build Coastguard Worker
1163*da0073e9SAndroid Build Coastguard Worker# Resolution of torch.fn is different from torch.ops.aten.fn
1164*da0073e9SAndroid Build Coastguard Worker# torch.fn uses the Python argparser, matches with the
1165*da0073e9SAndroid Build Coastguard Worker# appropriate schema, and calls into the unboxed version of the method
1166*da0073e9SAndroid Build Coastguard Worker# torch.ops.aten.fn resolution is done via the mechanism defined in JIT.
1167*da0073e9SAndroid Build Coastguard Worker# JIT creates a stack of all the overloads and then tries to match the
1168*da0073e9SAndroid Build Coastguard Worker# correct one at runtime and always calls into the boxed version of the method
1169*da0073e9SAndroid Build Coastguard Worker# Autograd codegen creates VariableType, TracerType,
1170*da0073e9SAndroid Build Coastguard Worker# inplace or view type and python bindings.
1171*da0073e9SAndroid Build Coastguard Worker# Aten codegen generates tensor methods for the tensor class.
1172*da0073e9SAndroid Build Coastguard Worker
1173*da0073e9SAndroid Build Coastguard Worker# _OpNamespace is a subclass of ModuleType because the torch script
1174*da0073e9SAndroid Build Coastguard Worker# allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
1175*da0073e9SAndroid Build Coastguard Worker# to work from script, we need to ensure ops and foo are modules
1176*da0073e9SAndroid Build Coastguard Worker
1177*da0073e9SAndroid Build Coastguard Worker
class _OpNamespace(types.ModuleType):
    """
    A module-like namespace that binds operators into Python on first access.

    Say a user has registered a custom operator "my_namespace::my_op" and
    calls it as torch.ops.my_namespace.my_op(...). Nothing is bound ahead of
    time; instead:
    1. Looking up `my_namespace` on the `torch.ops` object goes through its
       `__getattr__`, which creates an `_OpNamespace` named `my_namespace`
       and stores it as an attribute on the `ops` object.
    2. Looking up `my_op` on that namespace goes through `__getattr__` below,
       which fetches the operation through `_get_packet`, wraps it in an
       `OpOverloadPacket` and stores the packet as an attribute on the
       namespace.
    3. Later accesses find the stored attributes directly, so `__getattr__`
       is not invoked again for the same namespace or operation.

    Iterating a namespace yields the names of the ops bound so far.
    """

    def __init__(self, name):
        # The module name is the namespace name under the "torch.ops." prefix.
        super().__init__("torch.ops." + name)
        self.name = name
        self._dir = []

    def __iter__(self):
        return iter(self._dir)

    def __getattr__(self, op_name):
        # `__file__` is not a valid op name; answer it with a fixed string.
        if op_name == "__file__":
            return "torch.ops"
        if op_name in ("__origin__", "__self__"):
            raise AttributeError(
                f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'"
            )

        # Get the op `my_namespace::my_op` if available. This will also check
        # for overloads and raise an exception if there are more than one.
        ns = self.name
        qualname = f"{ns}::{op_name}"
        module_name = ".".join((self.__module__, ns))
        missing_msg = (
            f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
        )

        try:
            op, overload_names = _get_packet(qualname, module_name)
        except RuntimeError as e:
            # Turn this into AttributeError so getattr(obj, key, default)
            # works (this is called by TorchScript with __origin__)
            raise AttributeError(missing_msg) from e
        if op is None:
            raise AttributeError(missing_msg)

        op.__module__ = module_name
        packet = OpOverloadPacket(qualname, op_name, op, overload_names)
        packet.__module__ = module_name
        # Cache the packet on the namespace so that each op corresponds to a
        # single OpOverloadPacket object.
        setattr(self, op_name, packet)
        self._dir.append(op_name)
        return packet
1245*da0073e9SAndroid Build Coastguard Worker
1246*da0073e9SAndroid Build Coastguard Worker
def _get_packet(qualname, op_module):
    """Fetch the operation registered under ``qualname``.

    Returns the ``(op, overload_names)`` pair produced by
    ``torch._C._jit_get_operation``. When ``op`` is not ``None`` it is also
    registered as a builtin under ``qualname`` and its ``__module__`` is set
    to ``op_module``.
    """
    op, overload_names = torch._C._jit_get_operation(qualname)
    if op is None:
        return op, overload_names

    # Let the script frontend know that op is identical to the builtin op
    # with the qualified name.
    torch.jit._builtins._register_builtin(op, qualname)
    op.__module__ = op_module
    return op, overload_names
1255*da0073e9SAndroid Build Coastguard Worker
1256*da0073e9SAndroid Build Coastguard Worker
def _refresh_packet(packet):
    """Look the packet's operation up again and update ``packet`` in place.

    The lookup reuses the packet's qualified op name and the ``__module__``
    of its current ``_op``; the packet's ``_op`` and ``_overload_names`` are
    then replaced with the fresh results.
    """
    qualname = packet._qualified_op_name
    op_module = packet._op.__module__
    fresh_op, fresh_overload_names = _get_packet(qualname, op_module)
    assert fresh_op is not None
    packet._op = fresh_op
    packet._overload_names = fresh_overload_names
1262*da0073e9SAndroid Build Coastguard Worker
1263*da0073e9SAndroid Build Coastguard Worker
class _PyOpNamespace(_OpNamespace):
    """An op namespace whose entries are looked up in the ``ops`` mapping."""

    def __init__(self, name, ops):
        super().__init__(name)
        self._ops = ops

    def __getattr__(self, name):
        # Following _OpNamespace.__getattr__, a found op is cached as an
        # attribute on this namespace object.
        found = self._ops.get(name, None)
        if found is not None:
            setattr(self, name, found)
            return found
        raise AttributeError(
            f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'"
        )
1278*da0073e9SAndroid Build Coastguard Worker
1279*da0073e9SAndroid Build Coastguard Worker
class _Ops(types.ModuleType):
    """Module object named "torch.ops" that creates op namespaces on demand."""

    __file__ = "_ops.py"

    def __init__(self):
        super().__init__("torch.ops")
        self.loaded_libraries = set()
        self._higher_order_op_namespace = _PyOpNamespace(
            "torch.ops.higher_order", _higher_order_ops
        )
        self._dir = []

    def __getattr__(self, name):
        if name != "higher_order":
            # Here we are creating `torch.ops.my_namespace`; it is stored as
            # an attribute so this lookup is not repeated for the same name.
            created = _OpNamespace(name)
            setattr(self, name, created)
            self._dir.append(name)
            return created

        # The name refers to the HigherOrderOperator namespace.
        return self._higher_order_op_namespace

    def __iter__(self):
        return iter(self._dir)

    def import_module(self, module):
        """
        Import a Python module that has torch.library registrations.

        To extend PyTorch with custom operators, a user typically writes a
        Python module whose import registers the operators, either through a
        torch.ops.load_library call or through one or more torch.library.*
        APIs.

        Because Python modules are not expected to have side effects, some
        linters and formatters complain about importing a module only for
        them. Use this API to import modules that carry these torch.library
        side effects.

        Args:
            module (str): The name of the Python module to import

        """
        importlib.import_module(module)

    def load_library(self, path):
        """
        Load a shared library from the given path into the current process.

        The loaded library may run global initialization code that registers
        custom operators with the PyTorch JIT runtime, which makes it possible
        to load custom operators dynamically. To use this, compile the
        operator together with its static registration code into a shared
        library object and call
        ``torch.ops.load_library('path/to/libcustom.so')`` to load it.

        Once loaded, the resolved path is added to the
        ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
        for the paths of all libraries loaded using this function.

        Nothing is loaded or recorded when ``torch._running_with_deploy()``
        is true.

        Args:
            path (str): A path to a shared library to load.
        """
        if torch._running_with_deploy():
            return

        resolved = _utils_internal.resolve_library_path(path)
        with dl_open_guard():
            # Import the shared library into the process, thus running its
            # static (global) initialization code in order to register custom
            # operators with the JIT.
            ctypes.CDLL(resolved)
        self.loaded_libraries.add(resolved)
1352*da0073e9SAndroid Build Coastguard Worker
1353*da0073e9SAndroid Build Coastguard Worker
# The ops "namespace": the single module-level `_Ops` instance, created once
# when this file is imported. `_Ops.__init__` names the module "torch.ops".
ops: _Ops = _Ops()
1356