# mypy: allow-untyped-defs
import dataclasses
import functools
import inspect
import sys
import typing
import weakref
import warnings

from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy

import torch
import torch._C as _C
import torch.library as library
from torch.library import get_ctx

from .autograd import autograd_kernel_indirection, construct_autograd_kernel
import torch._library.infer_schema
from torch._library.infer_schema import infer_schema

"""
torch._custom_op is deprecated. We shipped a production-ready version of it into torch.library.
Please use those APIs instead.
"""

__all__ = ["custom_op", "CustomOp", "get_ctx"]


# Maps the user-facing device type string to the dispatch key that the
# corresponding kernel is registered under.
SUPPORTED_DEVICE_TYPE_TO_KEY = {
    "cpu": "CPU",
    "cuda": "CUDA",
}

# We will not let users register CustomOps with anything that could look like
# PyTorch internals to avoid confusion.
RESERVED_NS = {
    "prim",
    "prims",
    "aten",
    "at",
    "torch",
    "pytorch",
}


def warn_deprecated():
    """Emit the ``DeprecationWarning`` shared by every entry point of this module."""
    warnings.warn(
        "torch._custom_op is deprecated and will be removed in PyTorch 2.6, please "
        "use the equivalent torch.library API instead.", DeprecationWarning)


def custom_op(
    qualname: str, manual_schema: typing.Optional[str] = None
) -> typing.Callable:
    r"""
    This API is deprecated, please use torch.library.custom_op instead

    Returns a decorator that defines a new operator named ``qualname`` and
    wraps it in a :class:`CustomOp`.

    Args:
        qualname: Name of the operator in ``"ns::name"`` form. ``name`` must
            equal the decorated function's ``__name__``.
        manual_schema: Optional schema string (without the operator name). If
            omitted, the schema is inferred from the decorated function via
            ``infer_schema``. If given, the function must carry no type
            annotations and its signature must match the schema.

    Raises:
        ValueError: if the decorated object is not a Python function, the names
            do not match, the namespace is invalid or reserved, or the schema
            is not supported.
    """
    warn_deprecated()

    def inner(func):
        if not inspect.isfunction(func):
            raise ValueError(
                f"custom_op(...)(func): Expected `func` to be a Python "
                f"function, got: {type(func)}"
            )

        ns, name = parse_qualname(qualname)
        validate_namespace(ns)
        if func.__name__ != name:
            raise ValueError(
                f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
                f"to have name '{name}' but got '{func.__name__}'. "
                f"Please either change the name of `func` or the qualname that "
                f"is passed to `custom_op`"
            )

        schema = infer_schema(func, mutates_args=()) if manual_schema is None else manual_schema
        schema_str = f"{name}{schema}"
        function_schema = FunctionSchema.parse(schema_str)
        validate_schema(function_schema)
        if manual_schema is not None:
            validate_function_matches_schema(function_schema, func)

        lib = library.Library(ns, "FRAGMENT")
        lib.define(schema_str)
        ophandle = find_ophandle_or_throw(ns, function_schema.name)
        result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)

        result.__name__ = func.__name__
        result.__module__ = func.__module__
        result.__doc__ = func.__doc__

        # Register the Autograd indirection through the CustomOp helper (as
        # _custom_op_with_schema does) so that the bookkeeping flag is set.
        # Registering it directly on the library left the flag False, which made
        # impl_backward/impl_save_for_backward treat our own Autograd kernel as a
        # conflicting pre-existing registration.
        result._register_autograd_kernel_indirection()

        torch._C._dispatch_set_report_error_callback(
            ophandle, functools.partial(report_error_callback, weakref.proxy(result))
        )

        return result

    return inner


# Global dictionary holding references to all CustomOp objects
# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
# Used to query the CustomOp associated with a specific C++ dispatcher operator.
# An example usage is FakeTensor: FakeTensor checks if a specific operator
# has an implementation registered via the CustomOp API.
# Indexed by qualname (e.g. aten::foo)
global_registry: typing.Dict[str, "CustomOp"] = {}


class CustomOp:
    r"""
    This API is deprecated, please use torch.library.custom_op instead

    Python-side handle for one operator. It records which implementations
    ("cpu", "cuda", "factory", "abstract", "backward", "save_for_backward",
    "autograd") were registered from Python and where, and forwards the actual
    kernel registrations to the owning ``torch.library.Library``.
    """

    def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
        """Private constructor; use ``custom_op(...)`` instead.

        Args:
            lib: the ``torch.library.Library`` that kernels are registered on.
            cpp_ns: operator namespace, e.g. ``"custom"``.
            schema: parsed ``FunctionSchema`` of the operator.
            operator_name: operator name without namespace, e.g. ``"foo"``.
            ophandle: dispatcher operator handle used by ``__call__``.
            _private_access: must be True, otherwise a RuntimeError is raised.
        """
        super().__init__()
        warn_deprecated()
        if not _private_access:
            raise RuntimeError(
                "The CustomOp constructor is private and we do not guarantee "
                "BC for it. Please use custom_op(...) to create a CustomOp object"
            )
        name = f"{cpp_ns}::{operator_name}"
        self._schema = schema
        self._cpp_ns = cpp_ns
        self._lib: library.Library = lib
        self._ophandle: _C._DispatchOperatorHandle = ophandle
        # Has the name of the op, e.g. "foo". We cache here for convenience.
        self._opname: str = operator_name
        # this is _opname but with namespace. e.g. "custom::foo"
        self._qualname: str = name
        self.__name__ = None  # mypy requires this
        # NB: Some of these impls are registered as kernels to DispatchKeys.
        # Modifying the _impls dict directly won't do anything in that case.
        self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
        # See NOTE [CustomOp autograd kernel indirection]
        self._registered_autograd_kernel_indirection = False

        global_registry[self._qualname] = self

    def _register_autograd_kernel_indirection(self):
        """Register the Autograd indirection kernel; may be called at most once."""
        assert not self._registered_autograd_kernel_indirection
        # A weak proxy is handed to the kernel rather than `self`.
        self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
        self._registered_autograd_kernel_indirection = True

    # Records the impl and the source location in self._impls
    # Note that this doesn't cause torch.library to use the impl, that
    # needs to be done in a separate self._lib.impl call.
    def _register_impl(self, kind, func, stacklevel=2):
        """Record ``func`` as the ``kind`` impl together with its call site.

        Raises:
            RuntimeError: if a ``kind`` impl was already recorded.
        """
        if self._has_impl(kind):
            func_and_location = self._impls[kind]
            assert func_and_location is not None  # Pacify mypy
            location = func_and_location.location
            raise RuntimeError(
                f"Attempting to register a {kind} impl for operator {self._qualname} "
                f"that already has a {kind} impl registered from Python at "
                f"{location}. This is not supported."
            )
        # `stacklevel` frames up from here is the location reported in errors.
        frame = inspect.getframeinfo(sys._getframe(stacklevel))
        location = f"{frame.filename}:{frame.lineno}"
        self._impls[kind] = FuncAndLocation(func, location)

    def _get_impl(self, kind):
        """Return the recorded ``FuncAndLocation`` for ``kind`` (KeyError if absent)."""
        return self._impls[kind]

    def _has_impl(self, kind):
        """Return True if an impl of ``kind`` was recorded via ``_register_impl``."""
        return kind in self._impls

    def _destroy(self):
        """Drop the library reference, the ``torch.ops`` attribute and the registry entry."""
        # NOTE: [CustomOp lifetime]
        # A CustomOp, once created, lives forever. The mechanism is that the
        # global registry holds a reference to it. However, to make testing
        # easier, we want to be able to destroy CustomOp objects.
        # CustomOp._destroy does the job, though it leaves the CustomOp
        # in a garbage state.
        del self._lib

        opnamespace = getattr(torch.ops, self._cpp_ns)
        if hasattr(opnamespace, self._opname):
            delattr(opnamespace, self._opname)

        del global_registry[self._qualname]

    def __repr__(self):
        return f'<CustomOp(op="{self._qualname}")>'

    def __call__(self, *args, **kwargs):
        """Invoke the operator through the dispatcher operator handle."""
        # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
        # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
        # issues from caching operators that make testing CustomOp difficult).
        result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
        return result

    def impl(
        self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
    ) -> typing.Callable:
        r"""
        This API is deprecated, please use torch.library.custom_op instead

        Returns a decorator that registers the decorated function as the
        implementation for each of ``device_types`` (a single device type
        string or an iterable of them; see ``SUPPORTED_DEVICE_TYPE_TO_KEY``).

        Raises:
            ValueError: if a device type is not supported.
        """
        if isinstance(device_types, str):
            device_types = [device_types]
        else:
            # Materialize the iterable: it is walked once for validation here
            # and again inside the decorator, so a one-shot iterable (e.g. a
            # generator) would otherwise be empty by the time `inner` runs.
            device_types = list(device_types)
        for device_type in device_types:
            validate_device_type(device_type)

        def inner(f):
            # set(): register each device type once even if it was listed twice.
            for device_type in set(device_types):
                self._check_doesnt_have_library_impl(device_type)
                self._register_impl(device_type, f, stacklevel=_stacklevel)
                dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
                library.impl(self._lib, self._opname, dispatch_key)(f)
            return f

        return inner

    def _check_doesnt_have_library_impl(self, device_type):
        """Raise if ``device_type`` already has a kernel that was not recorded by us."""
        if self._has_impl(device_type):
            # Duplicate Python registrations are reported by _register_impl.
            return
        key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
        if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
            raise RuntimeError(
                f"impl(..., device_types={device_type}): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration.")

    def impl_factory(self) -> typing.Callable:
        r"""Register an implementation for a factory function."""

        def inner(f):
            self._register_impl("factory", f)
            library.impl(self._lib, self._opname, "BackendSelect")(f)
            return f

        return inner

    def impl_abstract(self, _stacklevel=2) -> typing.Callable:
        r"""
        This API is deprecated, please use torch.library.custom_op instead

        Returns a decorator that records the decorated function as the
        "abstract" impl and registers a wrapper around it as the Meta kernel.
        Calling ``get_ctx()`` while that Meta kernel runs raises a RuntimeError.
        """

        def inner(f):
            self._check_doesnt_have_library_meta_impl()
            self._register_impl("abstract", f, stacklevel=_stacklevel)
            location = self._get_impl("abstract").location

            qualname = self._qualname

            # Handle DispatchKey.Meta registration
            @functools.wraps(f)
            def f_with_ctx(*args, **kwargs):
                def error_on_ctx():
                    raise RuntimeError(
                        f"Attempted to call get_ctx() for the meta implementation "
                        f"for {qualname}. "
                        f"You have presumably called get_ctx() because the operator "
                        f"has a data-dependent output shape; if so, there is no "
                        f"such meta implementation and this error is the correct "
                        f"behavior. Otherwise, please remove the call to get_ctx() "
                        f"in the implementation registered with impl_abstract "
                        f"at {location}"
                    )

                with torch._library.fake_impl.set_ctx_getter(error_on_ctx):
                    return f(*args, **kwargs)

            self._lib.impl(self._opname, f_with_ctx, "Meta")
            return f

        return inner

    def _check_can_register_backward(self):
        """Raise RuntimeError if this operator's schema cannot take a backward formula.

        Rejected: non-functional schemas, operators without returns, operators
        whose returns carry a non-write alias annotation, and return types
        outside int, SymInt, bool, float, Tensor and List[Tensor].
        """
        def error(detail):
            raise RuntimeError(
                f"Cannot use torch._custom_ops APIs to register backward "
                f"formula for {detail}. Got operator "
                f"{self._qualname} with schema: {schema}"
            )

        schema = self._schema
        if schema.kind() != SchemaKind.functional:
            error("non-functional operator")

        rets = schema.returns
        if not rets:
            error("operator with no returns")

        assert len(rets) > 0
        is_non_mutating_view = any(
            r.annotation is not None and not r.annotation.is_write for r in rets
        )
        if is_non_mutating_view:
            error("operator that returns views")

        # We make assumptions about the schema's return types.
        allowed_return_types = {
            BaseType(BaseTy.int): "int",
            BaseType(BaseTy.SymInt): "SymInt",
            BaseType(BaseTy.bool): "bool",
            BaseType(BaseTy.float): "float",
            BaseType(BaseTy.Tensor): "Tensor",
            ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
        }
        for ret in rets:
            if ret.type in allowed_return_types:
                continue
            error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")

    def _check_doesnt_have_library_autograd_impl(self):
        """Raise if an autograd-related kernel exists that was not registered by this object."""
        if self._registered_autograd_kernel_indirection:
            # The Autograd kernel is our own indirection; nothing to check.
            return

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an autograd formula; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have autograd formulas defined on them.")

        # We can improve this by adding "all Autograd<BACKEND> keys", but
        # realistically people will just be using this API for CPU/CUDA for now.
        for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
            if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
                raise RuntimeError(
                    f"impl_backward/impl_save_for_backward: "
                    f"the operator {self._qualname} already has an Autograd kernel "
                    f"registered to DispatchKey::{key} via a pre-existing "
                    f"torch.library or TORCH_LIBRARY registration. Please either "
                    f"remove those registrations or don't use the torch._custom_ops APIs")

    def _check_doesnt_have_library_meta_impl(self):
        """Raise if a Meta or CompositeImplicitAutograd kernel already exists for this op."""
        if self._has_impl("abstract"):
            # Duplicate Python registrations are reported by _register_impl.
            return

        # If the user's operator is CompositeExplicitAutograd,
        # allow them to impl_abstract. This is being pragmatic
        # (existing custom ops may have CompositeExplicitAutograd
        # registration that don't work with Meta kernels, so this
        # gives them an escape hatch).
        if (
            _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd")
            and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta")
        ):
            return

        # Otherwise, if the user's already has a Meta kernel or their
        # op is CompositeImplicitAutograd or some other alias dispatch key,
        # raise.

        # Special case for CompositeImplicitAutograd
        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an abstract impl; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have abstract impls defined on them.")

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has an DispatchKey::Meta implementation via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration. "
                f"Please either remove that registration or don't call impl_abstract.")

    # NOTE ["backward", "save_for_backward", and "autograd"]
    # As a part of the explicit autograd API, a user must provide us
    # a "save_for_backward" function and a "backward" function.
    # When both of these have been provided, then we automatically
    # construct the "autograd" kernel.
    def _register_autograd_kernel(self):
        """Build the "autograd" impl from the recorded backward/save_for_backward impls."""
        assert self._has_impl("backward")
        assert self._has_impl("save_for_backward")
        kernel = construct_autograd_kernel(
            self._schema,
            self._output_differentiability,
            self,
            get_op(self._qualname),
            self._get_impl("save_for_backward").func,
            self._get_impl("backward").func)
        self._register_impl("autograd", kernel)

    def impl_save_for_backward(self, _stacklevel=2):
        r"""Register a function that tells us what to save for backward.

        Please see impl_backward for more details.
        """
        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
            if self._has_impl("backward"):
                self._register_autograd_kernel()
            # Return the function, like the other impl_* decorators, so that
            # decorating does not rebind the user's name to None.
            return f
        return inner

    def impl_backward(self, output_differentiability=None, _stacklevel=2):
        r"""
        This API is deprecated, please use torch.library.custom_op instead

        Returns a decorator that records the decorated function as the
        "backward" impl. Once a "save_for_backward" impl is also present, the
        "autograd" kernel is constructed.

        Args:
            output_differentiability: optional list of bools, one per operator
                output.

        Raises:
            RuntimeError: if ``output_differentiability`` is not a list of
                bools whose length equals the number of returns in the schema.
        """
        if output_differentiability is not None:
            def yell():
                raise RuntimeError(
                    f"impl_backward(output_differentiability): expected "
                    f"output_differentiability to be a list of bools with "
                    f"length equal to the number of outputs of this CustomOp "
                    f"got: {output_differentiability}")

            if not isinstance(output_differentiability, list):
                yell()
            for diff in output_differentiability:
                if not isinstance(diff, bool):
                    yell()
            if len(self._schema.returns) != len(output_differentiability):
                yell()

        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("backward", f, stacklevel=_stacklevel)
            # Must be set before _register_autograd_kernel, which reads it.
            self._output_differentiability = output_differentiability
            if self._has_impl("save_for_backward"):
                self._register_autograd_kernel()
            # Return the function, like the other impl_* decorators, so that
            # decorating does not rebind the user's name to None.
            return f
        return inner


@dataclasses.dataclass
class FuncAndLocation:
    """A registered implementation and the ``file:line`` it was registered from."""
    func: typing.Callable
    location: str


def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
    """Look up the dispatcher handle for ``cpp_ns::operator_name``.

    A missing overload name is passed to the dispatcher as the empty string.
    """
    overload_name = (
        "" if operator_name.overload_name is None else operator_name.overload_name
    )
    return _C._dispatch_find_schema_or_throw(
        f"{cpp_ns}::{str(operator_name.name)}", overload_name
    )


def validate_namespace(ns: str) -> None:
    """Raise ValueError if ``ns`` contains a ``.`` or is one of ``RESERVED_NS``."""
    if "." in ns:
        raise ValueError(
            f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
            f"valid variable name)"
        )
    if ns in RESERVED_NS:
        raise ValueError(
            f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
            f"please choose something else. "
        )


def validate_schema(schema: FunctionSchema) -> None:
    """Raise ValueError unless ``schema`` is functional and has no ``self`` argument."""
    if not torch._library.utils.is_functional_schema(schema):
        raise ValueError(
            f"custom_op only supports functional operators "
            f"(ops that do not mutate any inputs, do not return "
            f"views of the inputs, and has at least one return). "
            f"Got the following non-functional schema: {schema}"
        )

    # For simplicity: don't allow self arguments
    if schema.arguments.self_arg is not None:
        raise ValueError(
            f"custom_op does not support arguments named 'self'. Please "
            f"rename your argument. Got: {schema}"
        )


def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
    """Split ``"ns::name"`` into ``(ns, name)``.

    Raises:
        ValueError: if there is no ``::`` separator, or if the name part
            contains a ``.`` (overloads are not supported).
    """
    # Split on the first "::" only; the remainder is the operator name.
    names = qualname.split("::", 1)
    if len(names) != 2:
        raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The "
                         f"operator name should look something like ns::foo")
    if '.' in names[1]:
        raise ValueError(f"The torch.custom_ops APIs do not handle overloads, "
                         f"i.e. operator names with '.' in them. "
                         f"Please name your operator something like ns::foo. "
                         f"Got: {qualname}")
    return names[0], names[1]


def validate_device_type(device_type: str) -> None:
    """Raise ValueError if ``device_type`` is not a key of ``SUPPORTED_DEVICE_TYPE_TO_KEY``."""
    if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
        raise ValueError(
            f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
            f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
        )


def supported_param(param: inspect.Parameter) -> bool:
    """Return True for positional-or-keyword and keyword-only parameters."""
    return param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )


def validate_function_matches_schema(
    schema: FunctionSchema, func: typing.Callable
) -> None:
    """Check that ``func``'s signature agrees with a manually supplied ``schema``.

    Raises:
        ValueError: if ``func`` has positional-only args, varargs or kwargs;
            if it has any type annotation; if the argument names or counts
            differ from the schema; or if either side has default arguments.
    """
    sig = inspect.signature(func)

    if not all(supported_param(p) for p in sig.parameters.values()):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): positional-only args, "
            f"varargs, and kwargs are not supported. Please rewrite `func` "
            f"to not have them. Got `func` with signature: {sig}"
        )

    if (
        any(
            p.annotation is not inspect.Parameter.empty
            for p in sig.parameters.values()
        )
        or sig.return_annotation is not inspect.Signature.empty
    ):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func` to have no type annotations to avoid "
            f"ambiguity. Got `func` with signature: {sig}"
        )

    positional = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
    ]
    kwargonly = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    ]

    def error():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func`'s signature to match `manual_schema` "
            f"(aside from type annotations). "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def error_default_args():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): "
            f"neither func nor manual_schema should have default "
            f"arguments. Got "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def compare(sig_args, schema_args):
        # Arguments are compared pairwise, in order, by name.
        if len(sig_args) != len(schema_args):
            error()
        for (name, param), arg in zip(sig_args, schema_args):
            if name != arg.name:
                error()
            if param.default is not inspect.Parameter.empty or arg.default is not None:
                error_default_args()

    compare(positional, schema.arguments.flat_positional)
    compare(kwargonly, schema.arguments.flat_kwarg_only)


def report_error_callback(custom_op: typing.Any, key: str) -> None:
    """Raise a NotImplementedError describing the missing kernel for dispatch key ``key``.

    Installed on the operator handle via ``_dispatch_set_report_error_callback``.
    Always raises; never returns normally.
    """
    if key == "Undefined":
        raise NotImplementedError(
            f"{custom_op}: There were no Tensor inputs to this operator "
            f"(e.g. you passed an empty list of Tensors). If your operator is a "
            f"factory function (that is, it takes no Tensors and constructs "
            f"a new one), then please use CustomOp.impl_factory to register "
            f"an implementation for it"
        )
    if key == "Meta":
        raise NotImplementedError(
            f"{custom_op}: when running with device='Meta' tensors: there is no "
            f"abstract impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
        )
    if key in ("CPU", "CUDA"):
        device = key.lower()
        # The keyword of CustomOp.impl is `device_types` (plural).
        raise NotImplementedError(
            f"{custom_op}: when running with device='{device}' tensors: there is no "
            f"{device} impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl(device_types='{device}')"
        )
    raise NotImplementedError(
        f"{custom_op}: No implementation for dispatch key {key}. It is likely "
        f"that we have not added this functionality yet, please either open an "
        f"issue or if you're feeling adventurous, use the low-level "
        f"torch.library API"
    )


def custom_op_from_existing(op):
    """Wrap an already-registered operator overload ``op`` in a new ``CustomOp``."""
    ns = op.namespace
    lib = torch.library.Library(ns, "FRAGMENT")
    name = op.name().split("::")[-1]
    schema_str = str(op._schema)
    # CustomOp expects the schema string without the namespace
    schema_str = schema_str.split("::")[-1]
    schema = FunctionSchema.parse(schema_str)
    return CustomOp(lib, ns, schema, name, op, _private_access=True)


def get_op(qualname):
    """Return ``torch.ops.<ns>.<name>.default`` for ``qualname``.

    Raises:
        ValueError: if the namespace, the operator or its ``default`` overload
            cannot be found (or ``qualname`` is malformed).
    """
    def error_not_found():
        raise ValueError(
            f"Could not find the operator {qualname}. Please make sure you have "
            f"already registered the operator and (if registered from C++) "
            f"loaded it via torch.ops.load_library.")

    ns, name = parse_qualname(qualname)
    if not hasattr(torch.ops, ns):
        error_not_found()
    opnamespace = getattr(torch.ops, ns)
    if not hasattr(opnamespace, name):
        error_not_found()
    packet = getattr(opnamespace, name)
    if not hasattr(packet, 'default'):
        error_not_found()
    return packet.default


def _find_custom_op(qualname, also_check_torch_library=False):
    """Return the ``CustomOp`` registered for ``qualname``.

    If it is not in ``global_registry`` and ``also_check_torch_library`` is
    True, a new ``CustomOp`` is created around the existing operator;
    otherwise a RuntimeError is raised.
    """
    if qualname in global_registry:
        return global_registry[qualname]
    if not also_check_torch_library:
        raise RuntimeError(
            f'Could not find custom op "{qualname}". Did you register it via '
            f"the torch._custom_ops API?")
    overload = get_op(qualname)
    result = custom_op_from_existing(overload)
    return result


def get_abstract_impl(qualname):
    """Return the function recorded as the "abstract" impl for ``qualname``, or None."""
    # NOTE(review): presumably torch._custom_op.impl is this very module, making
    # this the same dict as the module-level `global_registry` -- confirm.
    if qualname not in torch._custom_op.impl.global_registry:
        return None
    custom_op = torch._custom_op.impl.global_registry[qualname]
    if custom_op is None:
        return None
    if not custom_op._has_impl("abstract"):
        return None
    return custom_op._get_impl("abstract").func


def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
    """Define ``qualname`` with an explicit ``schema`` and return its ``default`` overload.

    A ``CustomOp`` is created (and therefore stored in ``global_registry``)
    with the Autograd indirection registered. When
    ``needs_fixed_stride_order`` is True the operator is defined with the
    ``needs_fixed_stride_order`` tag.
    """
    ns, name = qualname.split("::")
    schema_str = f"{name}{schema}"
    function_schema = FunctionSchema.parse(schema_str)
    validate_schema(function_schema)
    tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
    lib = library.Library(ns, "FRAGMENT")
    lib.define(schema_str, tags=tags)
    ophandle = find_ophandle_or_throw(ns, function_schema.name)
    result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
    result._register_autograd_kernel_indirection()

    torch._C._dispatch_set_report_error_callback(
        ophandle, functools.partial(report_error_callback, weakref.proxy(result))
    )
    return get_op(qualname)