# xref: /aosp_15_r20/external/pytorch/torch/_custom_op/impl.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import dataclasses
import functools
import inspect
import sys
import typing
import weakref
import warnings

from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy

import torch
import torch._C as _C
import torch.library as library
from torch.library import get_ctx

from .autograd import autograd_kernel_indirection, construct_autograd_kernel
import torch._library.infer_schema
from torch._library.infer_schema import infer_schema

"""
torch._custom_op is deprecated. We shipped a production-ready version of it into torch.library.
Please use those APIs instead.
"""

__all__ = ["custom_op", "CustomOp", "get_ctx"]


SUPPORTED_DEVICE_TYPE_TO_KEY = {
    "cpu": "CPU",
    "cuda": "CUDA",
}

# We will not let users register CustomOps with anything that could look like
# PyTorch internals to avoid confusion.
RESERVED_NS = {
    "prim",
    "prims",
    "aten",
    "at",
    "torch",
    "pytorch",
}

def warn_deprecated():
    warnings.warn(
        "torch._custom_op is deprecated and will be removed in PyTorch 2.6, please "
        "use the equivalent torch.library API instead.", DeprecationWarning)


def custom_op(
    qualname: str, manual_schema: typing.Optional[str] = None
) -> typing.Callable:
    r"""
    This API is deprecated, please use torch.library.custom_op instead
    """
    warn_deprecated()

    def inner(func):
        if not inspect.isfunction(func):
            raise ValueError(
                f"custom_op(...)(func): Expected `func` to be a Python "
                f"function, got: {type(func)}"
            )

        ns, name = parse_qualname(qualname)
        validate_namespace(ns)
        if func.__name__ != name:
            raise ValueError(
                f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
                f"to have name '{name}' but got '{func.__name__}'. "
                f"Please either change the name of `func` or the qualname that "
                f"is passed to `custom_op`"
            )

        schema = infer_schema(func, mutates_args=()) if manual_schema is None else manual_schema
        schema_str = f"{name}{schema}"
        function_schema = FunctionSchema.parse(schema_str)
        validate_schema(function_schema)
        if manual_schema is not None:
            validate_function_matches_schema(function_schema, func)

        lib = library.Library(ns, "FRAGMENT")
        lib.define(schema_str)
        ophandle = find_ophandle_or_throw(ns, function_schema.name)
        result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)

        result.__name__ = func.__name__
        result.__module__ = func.__module__
        result.__doc__ = func.__doc__

        library.impl(lib, result._opname, "Autograd")(
            autograd_kernel_indirection(weakref.proxy(result))
        )

        torch._C._dispatch_set_report_error_callback(
            ophandle, functools.partial(report_error_callback, weakref.proxy(result))
        )

        return result

    return inner
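
# A minimal, non-authoritative sketch of how this deprecated decorator was
# used. The "mylib::numpy_sin" operator is made up for illustration: the
# decorated function only supplies the schema via its type annotations (its
# body is never run), and the actual kernels are registered separately via
# CustomOp.impl / CustomOp.impl_abstract / CustomOp.impl_backward below.
#
#     import torch
#     from torch._custom_op.impl import custom_op
#
#     @custom_op("mylib::numpy_sin")
#     def numpy_sin(x: torch.Tensor) -> torch.Tensor:
#         ...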


# Global dictionary holding references to all CustomOp objects
# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
# Used to query the CustomOp associated with a specific C++ dispatcher operator.
# An example usage is FakeTensor: FakeTensor checks if a specific operator
# has an implementation registered via the CustomOp API.
# Indexed by qualname (e.g. aten::foo)
global_registry: typing.Dict[str, "CustomOp"] = {}
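
# For example, a consumer such as FakeTensor can look the op up by qualname
# (hypothetical name, shown only to illustrate the indexing):
#
#     op = global_registry.get("mylib::numpy_sin")
#     if op is not None and op._has_impl("abstract"):
#         abstract_fn = op._get_impl("abstract").func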


class CustomOp:
    r"""
    This API is deprecated, please use torch.library.custom_op instead
    """

    def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
        super().__init__()
        warn_deprecated()
        if not _private_access:
            raise RuntimeError(
                "The CustomOp constructor is private and we do not guarantee "
                "BC for it. Please use custom_op(...) to create a CustomOp object"
            )
        name = f"{cpp_ns}::{operator_name}"
        self._schema = schema
        self._cpp_ns = cpp_ns
        self._lib: library.Library = lib
        self._ophandle: _C._DispatchOperatorHandle = ophandle
        # Has the name of the op, e.g. "foo". We cache here for convenience.
        self._opname: str = operator_name
        # This is _opname but with the namespace, e.g. "custom::foo".
        self._qualname: str = name
        self.__name__ = None  # mypy requires this
        # NB: Some of these impls are registered as kernels to DispatchKeys.
        # Modifying the _impls dict directly won't do anything in that case.
        self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
        # See NOTE [CustomOp autograd kernel indirection]
        self._registered_autograd_kernel_indirection = False

        global_registry[self._qualname] = self

    def _register_autograd_kernel_indirection(self):
        assert not self._registered_autograd_kernel_indirection
        self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
        self._registered_autograd_kernel_indirection = True

    # Records the impl and the source location in self._impls.
    # Note that this doesn't cause torch.library to use the impl; that
    # needs to be done in a separate self._lib.impl call.
    def _register_impl(self, kind, func, stacklevel=2):
        if self._has_impl(kind):
            func_and_location = self._impls[kind]
            assert func_and_location is not None  # Pacify mypy
            location = func_and_location.location
            raise RuntimeError(
                f"Attempting to register a {kind} impl for operator {self._qualname} "
                f"that already has a {kind} impl registered from Python at "
                f"{location}. This is not supported."
            )
        frame = inspect.getframeinfo(sys._getframe(stacklevel))
        location = f"{frame.filename}:{frame.lineno}"
        self._impls[kind] = FuncAndLocation(func, location)

    def _get_impl(self, kind):
        return self._impls[kind]

    def _has_impl(self, kind):
        return kind in self._impls

    def _destroy(self):
        # NOTE: [CustomOp lifetime]
        # A CustomOp, once created, lives forever. The mechanism is that the
        # global registry holds a reference to it. However, to make testing
        # easier, we want to be able to destroy CustomOp objects.
        # CustomOp._destroy does the job, though it leaves the CustomOp
        # in a garbage state.
        del self._lib

        opnamespace = getattr(torch.ops, self._cpp_ns)
        if hasattr(opnamespace, self._opname):
            delattr(opnamespace, self._opname)

        del global_registry[self._qualname]

    def __repr__(self):
        return f'<CustomOp(op="{self._qualname}")>'

    def __call__(self, *args, **kwargs):
        # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
        # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
        # issues from caching operators that make testing CustomOp difficult).
        result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
        return result

    def impl(
        self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
    ) -> typing.Callable:
        r"""
        This API is deprecated, please use torch.library.custom_op instead
        """
        if isinstance(device_types, str):
            device_types = [device_types]
        for device_type in device_types:
            validate_device_type(device_type)

        def inner(f):
            for device_type in set(device_types):
                self._check_doesnt_have_library_impl(device_type)
                self._register_impl(device_type, f, stacklevel=_stacklevel)
                dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
                library.impl(self._lib, self._opname, dispatch_key)(f)
            return f

        return inner
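
    # A minimal, non-authoritative sketch of registering a device-specific
    # kernel with this deprecated method, continuing the made-up numpy_sin
    # example (numpy_sin is the CustomOp returned by the custom_op decorator):
    #
    #     import numpy as np
    #
    #     @numpy_sin.impl("cpu")
    #     def numpy_sin_cpu(x):
    #         return torch.from_numpy(np.sin(x.numpy()))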

    def _check_doesnt_have_library_impl(self, device_type):
        if self._has_impl(device_type):
            return
        key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
        if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
            raise RuntimeError(
                f"impl(..., device_types={device_type}): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration.")

    def impl_factory(self) -> typing.Callable:
        r"""Register an implementation for a factory function."""

        def inner(f):
            self._register_impl("factory", f)
            library.impl(self._lib, self._opname, "BackendSelect")(f)
            return f

        return inner
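
    # A minimal, non-authoritative sketch of impl_factory for a hypothetical
    # operator that takes no Tensors and constructs one (my_rand is made up):
    #
    #     @my_rand.impl_factory()
    #     def my_rand_impl(size):
    #         return torch.randn(size)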

    def impl_abstract(self, _stacklevel=2) -> typing.Callable:
        r"""
        This API is deprecated, please use torch.library.custom_op instead
        """

        def inner(f):
            self._check_doesnt_have_library_meta_impl()
            self._register_impl("abstract", f, stacklevel=_stacklevel)
            location = self._get_impl("abstract").location

            qualname = self._qualname

            # Handle DispatchKey.Meta registration
            @functools.wraps(f)
            def f_with_ctx(*args, **kwargs):
                def error_on_ctx():
                    raise RuntimeError(
                        f"Attempted to call get_ctx() for the meta implementation "
                        f"for {qualname}. "
                        f"You have presumably called get_ctx() because the operator "
                        f"has a data-dependent output shape; if so, there is no "
                        f"such meta implementation and this error is the correct "
                        f"behavior. Otherwise, please remove the call to get_ctx() "
                        f"in the implementation registered with impl_abstract "
                        f"at {location}"
                    )

                with torch._library.fake_impl.set_ctx_getter(error_on_ctx):
                    return f(*args, **kwargs)

            self._lib.impl(self._opname, f_with_ctx, "Meta")
            return f

        return inner
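
    # A minimal, non-authoritative sketch of registering an abstract (meta)
    # impl with this deprecated method, again for the made-up numpy_sin
    # operator; the abstract impl only describes output metadata:
    #
    #     @numpy_sin.impl_abstract()
    #     def numpy_sin_abstract(x):
    #         return torch.empty_like(x)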

    def _check_can_register_backward(self):
        def error(detail):
            raise RuntimeError(
                f"Cannot use torch._custom_ops APIs to register backward "
                f"formula for {detail}. Got operator "
                f"{self._qualname} with schema: {schema}"
            )

        schema = self._schema
        if schema.kind() != SchemaKind.functional:
            error("non-functional operator")

        rets = schema.returns
        if not schema.returns:
            error("operator with no returns")

        assert len(rets) > 0
        is_non_mutating_view = any(
            r.annotation is not None and not r.annotation.is_write for r in rets
        )
        if is_non_mutating_view:
            error("operator that returns views")

        # We make assumptions about the schema's return types.
        allowed_return_types = {
            BaseType(BaseTy.int): "int",
            BaseType(BaseTy.SymInt): "SymInt",
            BaseType(BaseTy.bool): "bool",
            BaseType(BaseTy.float): "float",
            BaseType(BaseTy.Tensor): "Tensor",
            ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
        }
        for ret in schema.returns:
            if ret.type in allowed_return_types:
                continue
            error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")

    def _check_doesnt_have_library_autograd_impl(self):
        if self._registered_autograd_kernel_indirection:
            return

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
                f"already has an implementation via a pre-existing registration to "
                f"DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an autograd formula; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have autograd formulas defined on them.")

        # We can improve this by adding "all Autograd<BACKEND> keys", but
        # realistically people will just be using this API for CPU/CUDA for now.
        for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
            if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
                raise RuntimeError(
                    f"impl_backward/impl_save_for_backward: "
                    f"the operator {self._qualname} already has an Autograd kernel "
                    f"registered to DispatchKey::{key} via a pre-existing "
                    f"torch.library or TORCH_LIBRARY registration. Please either "
                    f"remove those registrations or don't use the torch._custom_ops APIs")

    def _check_doesnt_have_library_meta_impl(self):
        if self._has_impl("abstract"):
            return

        # If the user's operator is CompositeExplicitAutograd,
        # allow them to impl_abstract. This is being pragmatic
        # (existing custom ops may have CompositeExplicitAutograd
        # registrations that don't work with Meta kernels, so this
        # gives them an escape hatch).
        if (
            _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd")
            and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta")
        ):
            return

        # Otherwise, if the user already has a Meta kernel or their
        # op is CompositeImplicitAutograd or some other alias dispatch key,
        # raise.

        # Special case for CompositeImplicitAutograd
        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has an implementation via a pre-existing registration to "
                f"DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an abstract impl; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have abstract impls defined on them.")

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has a DispatchKey::Meta implementation via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration. "
                f"Please either remove that registration or don't call impl_abstract.")

    # NOTE ["backward", "save_for_backward", and "autograd"]
    # As a part of the explicit autograd API, a user must provide us
    # a "save_for_backward" function and a "backward" function.
    # When both of these have been provided, then we automatically
    # construct the "autograd" kernel.
    def _register_autograd_kernel(self):
        assert self._has_impl("backward")
        assert self._has_impl("save_for_backward")
        kernel = construct_autograd_kernel(
            self._schema,
            self._output_differentiability,
            self,
            get_op(self._qualname),
            self._get_impl("save_for_backward").func,
            self._get_impl("backward").func)
        self._register_impl("autograd", kernel)

    def impl_save_for_backward(self, _stacklevel=2):
        r"""Register a function that tells us what to save for backward.

        Please see impl_backward for more details.
        """
        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
            if self._has_impl("backward"):
                self._register_autograd_kernel()
        return inner

    def impl_backward(self, output_differentiability=None, _stacklevel=2):
        r"""
        This API is deprecated, please use torch.library.custom_op instead
        """
        if output_differentiability is not None:
            def yell():
                raise RuntimeError(
                    f"impl_backward(output_differentiability): expected "
                    f"output_differentiability to be a list of bools with "
                    f"length equal to the number of outputs of this CustomOp, "
                    f"got: {output_differentiability}")

            if not isinstance(output_differentiability, list):
                yell()
            for diff in output_differentiability:
                if not isinstance(diff, bool):
                    yell()
            if len(self._schema.returns) != len(output_differentiability):
                yell()

        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("backward", f, stacklevel=_stacklevel)
            self._output_differentiability = output_differentiability
            if self._has_impl("save_for_backward"):
                self._register_autograd_kernel()
        return inner
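
    # A minimal, non-authoritative sketch of the explicit autograd API for the
    # made-up numpy_sin operator. Once both functions below are registered,
    # the "autograd" kernel is constructed automatically (see the NOTE above):
    #
    #     @numpy_sin.impl_save_for_backward()
    #     def numpy_sin_save_for_backward(inputs, output):
    #         return inputs.x
    #
    #     @numpy_sin.impl_backward()
    #     def numpy_sin_backward(ctx, saved, grad_out):
    #         return {"x": grad_out * saved.cos()}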


@dataclasses.dataclass
class FuncAndLocation:
    func: typing.Callable
    location: str


def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
    overload_name = (
        "" if operator_name.overload_name is None else operator_name.overload_name
    )
    return _C._dispatch_find_schema_or_throw(
        f"{cpp_ns}::{str(operator_name.name)}", overload_name
    )


def validate_namespace(ns: str) -> None:
    if "." in ns:
        raise ValueError(
            f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
            f"valid variable name)"
        )
    if ns in RESERVED_NS:
        raise ValueError(
            f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
            f"please choose something else. "
        )

def validate_schema(schema: FunctionSchema) -> None:
    if not torch._library.utils.is_functional_schema(schema):
        raise ValueError(
            f"custom_op only supports functional operators "
            f"(ops that do not mutate any inputs, do not return "
            f"views of the inputs, and have at least one return). "
            f"Got the following non-functional schema: {schema}"
        )

    # For simplicity: don't allow self arguments
    if schema.arguments.self_arg is not None:
        raise ValueError(
            f"custom_op does not support arguments named 'self'. Please "
            f"rename your argument. Got: {schema}"
        )


def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
    names = qualname.split("::", 1)
    if len(names) != 2:
        raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. the "
                         f"operator name should look something like ns::foo")
    if '.' in names[1]:
        raise ValueError(f"The torch.custom_ops APIs do not handle overloads, "
                         f"i.e. operator names with '.' in them. "
                         f"Please name your operator something like ns::foo. "
                         f"Got: {qualname}")
    return names[0], names[1]


def validate_device_type(device_type: str) -> None:
    if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
        raise ValueError(
            f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
            f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
        )


def supported_param(param: inspect.Parameter) -> bool:
    return param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )


def validate_function_matches_schema(
    schema: FunctionSchema, func: typing.Callable
) -> None:
    sig = inspect.signature(func)

    if not all(supported_param(p) for _, p in sig.parameters.items()):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): positional-only args, "
            f"varargs, and kwargs are not supported. Please rewrite `func` "
            f"to not have them. Got `func` with signature: {sig}"
        )

    if (
        any(
            p.annotation is not inspect.Parameter.empty
            for _, p in sig.parameters.items()
        )
        or sig.return_annotation is not inspect.Signature.empty
    ):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func` to have no type annotations to avoid "
            f"ambiguity. Got `func` with signature: {sig}"
        )

    positional = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
    ]
    kwargonly = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    ]

    def error():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func`'s signature to match `manual_schema` "
            f"(aside from type annotations). "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def error_default_args():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): "
            f"neither func nor manual_schema should have default "
            f"arguments. Got "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def compare(sig_args, schema_args):
        if len(sig_args) != len(schema_args):
            error()
        for (name, param), arg in zip(sig_args, schema_args):
            if name != arg.name:
                error()
            if param.default is not inspect.Parameter.empty or arg.default is not None:
                error_default_args()

    compare(positional, schema.arguments.flat_positional)
    compare(kwargonly, schema.arguments.flat_kwarg_only)

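
# A minimal, non-authoritative sketch of what this validation accepts when a
# manual schema is supplied: the function's parameter names and kinds must
# match the schema, with no type annotations and no default values. The
# operator name and schema below are hypothetical:
#
#     @custom_op("mylib::my_clamp", manual_schema="(Tensor x, *, float min) -> Tensor")
#     def my_clamp(x, *, min):
#         ...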

def report_error_callback(custom_op: typing.Any, key: str) -> None:
    if key == "Undefined":
        raise NotImplementedError(
            f"{custom_op}: There were no Tensor inputs to this operator "
            f"(e.g. you passed an empty list of Tensors). If your operator is a "
            f"factory function (that is, it takes no Tensors and constructs "
            f"a new one), then please use CustomOp.impl_factory to register "
            f"an implementation for it"
        )
    if key == "Meta":
        raise NotImplementedError(
            f"{custom_op}: when running with device='Meta' tensors: there is no "
            f"abstract impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
        )
    if key in ("CPU", "CUDA"):
        device = key.lower()
        raise NotImplementedError(
            f"{custom_op}: when running with device='{device}' tensors: there is no "
            f"{device} impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl(device_type='{device}')"
        )
    raise NotImplementedError(
        f"{custom_op}: No implementation for dispatch key {key}. It is likely "
        f"that we have not added this functionality yet, please either open an "
        f"issue or if you're feeling adventurous, use the low-level "
        f"torch.library API"
    )


def custom_op_from_existing(op):
    ns = op.namespace
    lib = torch.library.Library(ns, "FRAGMENT")
    name = op.name().split("::")[-1]
    schema_str = str(op._schema)
    # CustomOp expects the schema string without the namespace
    schema_str = schema_str.split("::")[-1]
    schema = FunctionSchema.parse(schema_str)
    return CustomOp(lib, ns, schema, name, op, _private_access=True)


def get_op(qualname):
    def error_not_found():
        raise ValueError(
            f"Could not find the operator {qualname}. Please make sure you have "
            f"already registered the operator and (if registered from C++) "
            f"loaded it via torch.ops.load_library.")

    ns, name = parse_qualname(qualname)
    if not hasattr(torch.ops, ns):
        error_not_found()
    opnamespace = getattr(torch.ops, ns)
    if not hasattr(opnamespace, name):
        error_not_found()
    packet = getattr(opnamespace, name)
    if not hasattr(packet, 'default'):
        error_not_found()
    return packet.default
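
# For example (hypothetical qualname, for illustration only):
#
#     op_overload = get_op("mylib::numpy_sin")  # -> torch.ops.mylib.numpy_sin.default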


def _find_custom_op(qualname, also_check_torch_library=False):
    if qualname in global_registry:
        return global_registry[qualname]
    if not also_check_torch_library:
        raise RuntimeError(
            f'Could not find custom op "{qualname}". Did you register it via '
            f"the torch._custom_ops API?")
    overload = get_op(qualname)
    result = custom_op_from_existing(overload)
    return result


def get_abstract_impl(qualname):
    if qualname not in torch._custom_op.impl.global_registry:
        return None
    custom_op = torch._custom_op.impl.global_registry[qualname]
    if custom_op is None:
        return None
    if not custom_op._has_impl("abstract"):
        return None
    return custom_op._get_impl("abstract").func


def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
    ns, name = qualname.split("::")
    schema_str = f"{name}{schema}"
    function_schema = FunctionSchema.parse(schema_str)
    validate_schema(function_schema)
    tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
    lib = library.Library(ns, "FRAGMENT")
    lib.define(schema_str, tags=tags)
    ophandle = find_ophandle_or_throw(ns, function_schema.name)
    result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
    result._register_autograd_kernel_indirection()

    torch._C._dispatch_set_report_error_callback(
        ophandle, functools.partial(report_error_callback, weakref.proxy(result))
    )
    return get_op(qualname)