1# mypy: allow-untyped-defs 2""" 3The weak_script annotation needs to be here instead of inside torch/jit/ so it 4can be used in other places in torch/ (namely torch.nn) without running into 5circular dependency problems 6""" 7 8import ast 9import builtins 10import collections 11import contextlib 12import enum 13import inspect 14import io 15import pickle 16import sys 17import textwrap 18import threading 19import types 20import typing 21import warnings 22import weakref 23from typing import ( 24 Any, 25 Callable, 26 Dict, 27 Final, 28 ForwardRef, 29 get_args, 30 get_origin, 31 List, 32 Optional, 33 Tuple, 34 Type, 35 Union, 36) 37 38import torch 39 40# This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`. 41# Explicitly ask to import `torch.distributed.__init__` first. 42# Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised. 43import torch.distributed.rpc 44import torch.package._mangling as package_mangling 45from torch._awaits import _Await 46from torch._C import _Await as CAwait, Future as CFuture 47from torch._sources import fake_range, get_source_lines_and_file, parse_def 48from torch.futures import Future 49 50 51IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9) 52IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10) 53 54BuiltinUnionType: Union[Type, Tuple[Type, ...]] 55if sys.version_info >= (3, 10): 56 # NOTE: IS_PY310_PLUS doesn't work with mypy. 57 # cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks 58 BuiltinUnionType = types.UnionType 59else: 60 BuiltinUnionType = () # trick: this makes isinstance short circuit. 
LockType: Type
try:
    import _thread

    LockType = _thread.LockType
except ImportError:
    import _dummy_thread  # type: ignore[import-not-found]

    LockType = _dummy_thread.LockType

# Wrapper functions that can call either of 2 functions depending on a boolean
# argument. Each wrapper built by `boolean_dispatch` is registered here together
# with the metadata of its two targets; the weak keys mean this registry alone
# does not keep a wrapper alive.
boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = (
    weakref.WeakKeyDictionary()
)  # noqa: T484


FAKE_FILENAME_PREFIX = "__torch_jit_dataclass"


def is_final(ann) -> bool:
    """Return True if ``ann`` is ``Final`` or ``Final[...]`` from ``typing``
    or ``typing_extensions``."""
    return (
        hasattr(ann, "__module__")
        and ann.__module__ in {"typing", "typing_extensions"}
        and (get_origin(ann) is Final or isinstance(ann, type(Final)))
    )


# allows BroadcastingList instance to be subscriptable
class BroadcastingListCls:
    """Placeholder whose subscription (e.g. ``BroadcastingList2[int]``) is a
    runtime no-op that returns None."""

    def __getitem__(self, types):
        return


# mypy doesn't support parameters on types, so we have to explicitly type each
# list size. BroadcastingList2 ... BroadcastingList6 are all aliases of the very
# same BroadcastingList1 instance.
BroadcastingList1 = BroadcastingListCls()
for i in range(2, 7):
    globals()[f"BroadcastingList{i}"] = BroadcastingList1


def is_scripting() -> bool:
    r"""
    Function that returns True when in compilation and False otherwise. This
    is useful especially with the @unused decorator to leave code in your
    model that is not yet TorchScript compatible. When executed as plain
    Python this function always returns False.

    .. testcode::

        import torch

        @torch.jit.unused
        def unsupported_linear_op(x):
            return x

        def linear(x):
            if torch.jit.is_scripting():
                return torch.linear(x)
            else:
                return unsupported_linear_op(x)
    """
    return False


# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
def _qualified_name(obj, mangle_name=True) -> str:
    """Return ``<module>.<name>`` for ``obj``.

    With ``mangle_name=True`` (the default) the module part is prefixed with
    ``__torch__.``, or replaced by ``__torch__`` when it is ``__main__``.
    Objects carrying ``_jit_override_qualname``, ``torch._C.ScriptFunction``
    objects, and objects whose module is ``torch._classes`` return their own
    qualified name unchanged.

    Raises:
        RuntimeError: if no name can be determined, ``__module__`` is None, or
            the name contains a ".".
    """
    # This special case allows us to override the qualified name on a type.
    # It's currently used in conjunction with tracing, where we create a
    # fake module to filter only supported attributes. However, since this
    # new type is defined as a local class, we need a mechanism to override
    # its qualname so it appears correctly in the TorchScript system. Thus,
    # we set '_jit_override_qualname' with the original traced module's
    # qualified name, which is picked up here
    if hasattr(obj, "_jit_override_qualname"):
        return obj._jit_override_qualname
    # short-circuit in cases where the object already has a known qualified name
    if isinstance(obj, torch._C.ScriptFunction):
        return obj.qualified_name

    if getattr(obj, "__name__", None):
        name = obj.__name__
    # Enum members (instances) do not have a `__name__` attr; instead they have `name`.
    elif isinstance(obj, enum.Enum):
        name = obj.name
    else:
        raise RuntimeError("Could not get name of python class object")

    if name == "<lambda>":
        name = "_lambda"  # make name a valid identifier

    module_name = obj.__module__

    # If the module is actually a torchbind module, then we should short circuit
    if module_name == "torch._classes":
        return obj.qualified_name

    # The Python docs are very clear that `__module__` can be None, but I can't
    # figure out when it actually would be.
    if module_name is None:
        raise RuntimeError(
            f"Could not get qualified name for class '{name}': "
            "__module__ can't be None."
        )

    # if getattr(sys.modules[module_name], name) is not obj:
    #     raise RuntimeError(f"Could not get qualified name for class '{name}': "
    #                        f"the attr {name} on module {module_name} is not the class")

    # torch.package and TorchScript have separate mangling schemes to avoid
    # name collisions from multiple packages. To avoid them interfering with
    # each other, normalize the package mangling here.
    if package_mangling.is_mangled(module_name):
        module_name = module_name.replace("<", "_")
        module_name = module_name.replace(">", "_")

    # The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h
    # does not need mangle the python class name.
    if mangle_name:
        # __main__ is a builtin module, so rewrite it to "__torch__".
        if module_name == "__main__":
            module_name = "__torch__"
        else:
            # Everything else gets a "__torch__" prefix to avoid name collisions
            # with the names of user values.
            module_name = "__torch__." + module_name

    if "." in name:
        raise RuntimeError(
            f"Could not get qualified name for class '{name}': "
            f"'{name}' is not a valid identifier"
        )

    return module_name + "." + name


class SourceLoader:
    """In-memory cache of source text, keyed by the object passed as ``fn``.

    ``get_source`` returns None for objects that were never cached.
    """

    def __init__(self):
        self.content = {}

    def cache(self, fn, source):
        self.content[fn] = source

    def get_source(self, fn):
        return self.content.get(fn)


loader = SourceLoader()


def createResolutionCallbackFromEnv(lookup_base):
    """
    Creates a resolution callback that will look up qualified names in an
    environment, starting with `lookup_base` for the base of any qualified
    names, then proceeding down the lookup chain with the resolved object.

    The returned callback takes a type-expression string such as
    ``"Dict[str, List[int]]"`` and returns the resolved Python object, or None
    when the expression cannot be fully resolved by this Python-side parser
    (the caller is then expected to fall back to the C++ resolver).

    You should not use this directly, it should only be used from the other
    createResolutionCallbackFrom* functions.
    """

    def lookupInModule(qualified_name, module):
        # Resolve a dotted name one attribute at a time:
        # "a.b.c" -> getattr(getattr(getattr(module, "a"), "b"), "c").
        if "." in qualified_name:
            base, remaining_pieces = qualified_name.split(".", maxsplit=1)
            module_value = getattr(module, base)
            return lookupInModule(remaining_pieces, module_value)
        else:
            return getattr(module, qualified_name)

    def parseNestedExpr(expr, module) -> Tuple[Any, int]:
        # Returns (resolved value, number of characters of `expr` consumed).
        # Raises on anything it cannot resolve; `parseExpr` turns that into None.
        i = 0
        while i < len(expr) and expr[i] not in (",", "[", "]"):
            i += 1

        # Special case logic for the empty Tuple as a subscript (used
        # in the type annotation `Tuple[()]`)
        if expr[:i] == "()":
            return (), i

        base = lookupInModule(expr[:i].strip(), module)
        # Explicit check rather than `assert`, so it still applies under `python -O`.
        if base is None:
            raise ValueError(f"Unresolvable type {expr[:i]}")
        if i == len(expr) or expr[i] != "[":
            return base, i

        # Here expr[i] == "[": parse comma-separated subscript arguments up to
        # the matching "]". A missing "]" raises IndexError, handled by parseExpr.
        parts = []
        while expr[i] != "]":
            i += 1
            part, part_len = parseNestedExpr(expr[i:], module)
            parts.append(part)
            i += part_len
        if len(parts) > 1:
            return base[tuple(parts)], i + 1
        else:
            return base[parts[0]], i + 1

    def parseExpr(expr, module):
        try:
            value, len_parsed = parseNestedExpr(expr, module)
        except Exception:
            # The python resolver fails in several cases in known unit tests, and is
            # intended to fall back gracefully to the c++ resolver in general. For
            # example, python 2 style annotations which are frequent in our unit tests
            # often fail with types e.g. int not resolvable from the calling frame.
            return None
        if len_parsed != len(expr):
            # whole expression was not parsed, falling back to c++ parser
            return None
        return value

    return lambda expr: parseExpr(expr, lookup_base)


def createResolutionCallbackFromFrame(frames_up: int = 0):
    """
    Creates a function which, given a string variable name,
    returns the value of the variable in the scope of the caller of
    the function which called createResolutionCallbackFromFrame (by default).

    This is used to enable access in-scope Python variables inside
    TorchScript fragments.

    frames_up is number of additional frames to go up on the stack.
    The default value is 0, which correspond to the frame of the caller
    of createResolutionCallbackFromFrame. Also for example, if frames_up is set
    to 1, then the frame of the caller's caller of createResolutionCallbackFromFrame
    will be taken.

    Names are looked up in the frame's locals, then its globals, then builtins;
    anything else resolves to None.

    For example, the following program prints 2::

        def bar():
            cb = createResolutionCallbackFromFrame(1)
            print(cb("foo"))


        def baz():
            foo = 2
            bar()


        baz()
    """
    frame = inspect.currentframe()
    for _ in range(frames_up + 1):
        assert frame is not None
        frame = frame.f_back

    assert frame is not None
    f_locals = frame.f_locals
    f_globals = frame.f_globals

    class env:
        def __getattr__(self, key):
            if key in f_locals:
                return f_locals[key]
            elif key in f_globals:
                return f_globals[key]
            # vars(builtins) holds exactly the names dir(builtins) lists, but
            # membership is a dict lookup instead of a scan of a fresh list.
            elif key in vars(builtins):
                return getattr(builtins, key)

    return createResolutionCallbackFromEnv(env())


def get_closure(fn):
    """
    Get a dictionary of closed over variables from a function.

    The result starts from a copy of ``fn.__globals__`` and is then updated
    with the current value of every free (closed-over) variable of ``fn``, so
    closure cells shadow globals of the same name.
    """
    captures = dict(fn.__globals__)
    for captured_name, cell in zip(fn.__code__.co_freevars, fn.__closure__ or ()):
        captures[captured_name] = cell.cell_contents
    return captures


# [local resolution in python]
# Depending on where a variable is defined, and where it is used, we may
# or may not be able to recover its value when recursively compiling a
# script function. Remember in the general case, a module or function is
# first defined and then later scripted. This means we do not have a
# chance to capture the active frames when the function is defined. Hence any
# name resolution has to happen later on the created closure. The way
# python captures type annotations restricts what we can recover. The
# following example illustrates the different cases:
#
#         class MyGlobalClass:
#         ...
#         def my_local_scope():
#             @torch.jit.script
#             class MyClass:
#                 ...
#             @torch.jit.script
#             class MyClassUsedAsVar:
#                 ...
#             def eg(x: MyClass, y: MyGlobalClass):
#                 a_local_capture : Foo
#                 return MyClassUsedAsVar(x)
#
# MyGlobalClass is defined in the __globals__ dictionary of function
# 'eg', so it is always recoverable. my_local_scope introduces a new local
# variable scope in the function. Classes defined here are only visible as
# local variables. For the case of MyClassUsedAsVar, it is captured
# because it is used as a variable inside the body of the function, and we
# can resolve it using the captures returned from `get_closure`. However,
# the type annotations are not captured by the closure. With eagerly
# evaluated annotations, the _value_ of MyClass and MyGlobalClass will be
# available as annotations on `eg`, but with postponed evaluation of
# annotations (PEP 563, `from __future__ import annotations`) they are
# stored as strings and the values are no longer present. Furthermore, since
# the body of `eg` does not reference those names, they do not appear in the
# list of closed over variables. In Python 2.x, type annotations are in
# comments, leading to a similar situation where their definitions are not
# available. We anticipate that most users will not run into this issue
# because their modules and functions will be defined at a global scope like
# MyGlobalClass. In cases where they are not, it is possible to work around
# issues by declaring the values global in the function.
# In Python 3.9 declaring class as global will make it invisible to
# `inspect.getsource`, see https://bugs.python.org/issue42666 .
# This could be worked around by manually adding it to the `globals()` dictionary.
def createResolutionCallbackFromClosure(fn):
    """
    Create a resolutionCallback by introspecting the function instead of
    looking up the stack for the enclosing scope.

    Names are resolved, in order, from ``fn``'s globals and closed-over
    variables (see `get_closure`), then from the ``typing`` module, then from
    ``builtins``; anything else resolves to None.
    """
    closure = get_closure(fn)

    class closure_lookup:
        # This is a class since `closure` is a dict and it's easier in
        # `createResolutionCallbackFromEnv` if everything just works with
        # `getattr` calls
        def __getattr__(self, key):
            if key in closure:
                return closure[key]
            elif hasattr(typing, key):
                return getattr(typing, key)
            elif hasattr(builtins, key):
                return getattr(builtins, key)
            return None

    return createResolutionCallbackFromEnv(closure_lookup())


def can_compile_class(cls) -> bool:
    """
    Return whether ``cls`` looks compilable as a TorchScript class.

    Returns False if the class itself is marked ignored/unused/dropped, if it
    derives from ``torch.nn.Module``, ``tuple``, ``list`` or ``Exception``, or
    if any routine named in its own ``__dict__`` lacks a ``__code__`` object.
    """
    # If any of the functions on a type don't have a code object, this type can't
    # be compiled and is probably a builtin / bound from C
    if is_ignored_fn(cls):
        return False

    # Ignore the following list of built-in classes.
    ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
    if issubclass(cls, ignored_builtin_classes):
        return False

    names = cls.__dict__
    fns = [
        getattr(cls, name)
        for name in names
        if inspect.isroutine(getattr(cls, name, None))
    ]
    has_code = [hasattr(fn, "__code__") for fn in fns]
    return all(has_code)


def get_callable_argument_names(fn) -> List[str]:
    """
    Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`.
    Parameters of every other kind (positional-only, ``*args``, keyword-only,
    ``**kwargs``) are skipped because they do not map to individual values
    with a keyword as name. Returns an empty list when the signature of `fn`
    cannot be inspected.

    This is used by `torch.jit.trace` to assign meaningful argument names to
    traced functions and modules.

    Args:
        fn: A callable.
    Returns:
        Argument names: List[str]
    """
    # inspect.signature may fail, give up in that case.
    try:
        callable_signature = inspect.signature(fn)
    except Exception:
        return []

    return [
        name
        for name, param in callable_signature.parameters.items()
        if param.kind == param.POSITIONAL_OR_KEYWORD
    ]


def get_annotation_str(annotation):
    """
    Convert an AST node containing a type annotation to the string present in the source
    that represents the same annotation.

    Handles Name, Attribute, Subscript, Tuple and Constant nodes. Returns None if the
    node, or any node nested inside it, is of another kind (e.g. the list in
    ``Callable[[int], int]`` or the ``|`` in ``Dict[str, int | None]``); such
    annotations are left to ScriptTypeParser.
    """
    if isinstance(annotation, ast.Name):
        return annotation.id
    elif isinstance(annotation, ast.Attribute):
        value_str = get_annotation_str(annotation.value)
        if value_str is None:
            return None
        return f"{value_str}.{annotation.attr}"
    elif isinstance(annotation, ast.Subscript):
        # In Python3.9+ subscript indicies are not wrapped in ast.Index
        subscript_slice = annotation.slice if IS_PY39_PLUS else annotation.slice.value  # type: ignore[attr-defined]
        value_str = get_annotation_str(annotation.value)
        slice_str = get_annotation_str(subscript_slice)
        # Don't fabricate keys such as "List[None]" from unhandled sub-nodes.
        if value_str is None or slice_str is None:
            return None
        return f"{value_str}[{slice_str}]"
    elif isinstance(annotation, ast.Tuple):
        elt_strs = [get_annotation_str(elt) for elt in annotation.elts]
        # str.join would raise TypeError on a None element.
        if any(elt_str is None for elt_str in elt_strs):
            return None
        return ",".join(elt_strs)
    elif isinstance(annotation, ast.Constant):
        return f"{annotation.value}"

    # If an AST node is not handled here, it's probably handled in ScriptTypeParser.
    return None


def get_type_hint_captures(fn):
    """
    Get a dictionary containing type resolution mappings necessary to resolve types
    for the literal annotations on 'fn'. These are not considered to be closed-over by fn
    and must be obtained separately (e.g. using this function).

    Args:
        fn: A callable.
    Returns:
        A Dict[str, Any] containing a mapping from the literal annotations used on
        fn to the Python objects they refer to.
    Raises:
        OSError: if the source of `fn` is neither cached in `loader` nor
            retrievable with `inspect.getsource`.
        RuntimeError: if the source does not parse to exactly one function definition.
    """
    # First, try to get the source of the function. We'll need to parse it to find the actual string names
    # that were used to annotate the types, since inspect.signature() will only return the class object that
    # the annotation refers to, not the string name. Source cached in `loader` takes priority; otherwise fall
    # back to inspect.getsource and raise OSError if that fails too (this may happen for functions synthesized
    # dynamically at runtime).
    src = loader.get_source(fn)
    if src is None:
        try:
            src = inspect.getsource(fn)
        except OSError as e:
            raise OSError(
                f"Failed to get source for {fn} using inspect.getsource"
            ) from e

    # Gather a dictionary of parameter name -> type, skipping any parameters whose annotated
    # types are strings. These are only understood by TorchScript in the context of a type annotation
    # that refers to a class in its own definition, but trying to include a mapping for this in the result
    # function would cause infinite recursion because the class is currently being compiled.
    # In addition, there is logic in ScriptTypeParser to handle this.
    signature = inspect.signature(fn)
    name_to_type = {
        name: parameter.annotation
        for name, parameter in signature.parameters.items()
        if parameter.annotation is not inspect.Parameter.empty
        and not isinstance(parameter.annotation, str)
    }

    # Then, get the literal type annotations from the function declaration
    # by source inspection. This accounts for the case in which aliases are used
    # to annotate the arguments (e.g device_t = torch.device, and then d: device_t).
    # frontend.py cannot be used here because it includes _jit_internal, so use ast instead.
    a = ast.parse(textwrap.dedent(src))
    if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
        raise RuntimeError(f"Expected {fn} to be a function")
    f = a.body[0]

    # Prepare a dictionary of source annotation -> type, which will be the final result of this function,
    # by using the parsed AST (f) to reconstruct source annotations as strings for each parameter and mapping
    # them to the type object corresponding to the annotation via name_to_type using the parameter name.
    annotation_to_type = {}

    # inspect.signature() also reports positional-only and keyword-only parameters, so walk all
    # three AST argument lists rather than only the regular ones.
    annotated_args = [*f.args.posonlyargs, *f.args.args, *f.args.kwonlyargs]
    for arg in annotated_args:
        # Get the source type annotation string for this argument if possible.
        arg_annotation_str = (
            get_annotation_str(arg.annotation) if arg.annotation else None
        )

        # If the argument has no annotation or get_annotation_str cannot convert it to a string,
        # arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle
        # this in the latter case.
        if arg_annotation_str is None:
            continue

        # Insert {arg_annotation_str: type} into annotation_to_type if possible. One reason arg_name may not
        # be present in name_to_type is that the annotation itself is a string and not a type object
        # (common for self-referential annotations in classes). Once again, let ScriptTypeParser handle this.
        arg_name = arg.arg
        if arg_name in name_to_type:
            annotation_to_type[arg_annotation_str] = name_to_type[arg_name]

    # If there is a valid return annotation, include it in annotation_to_type. As with argument annotations,
    # the literal annotation has to be convertible to a string by get_annotation_str, and the actual type
    # of the annotation cannot be a string.
    literal_return_annotation = get_annotation_str(f.returns)
    valid_literal_annotation = literal_return_annotation is not None
    return_annotation = signature.return_annotation
    valid_return_annotation_type = (
        return_annotation is not inspect.Parameter.empty
        and not isinstance(return_annotation, str)
    )
    if valid_literal_annotation and valid_return_annotation_type:
        annotation_to_type[literal_return_annotation] = return_annotation

    return annotation_to_type


def createResolutionCallbackForClassMethods(cls):
    """
    This looks at all the methods defined in a class and pulls their closed-over
    variables into a dictionary and uses that to resolve variables.

    Besides closures, the literal type-hint strings of each method (see
    `get_type_hint_captures`) are added. Unknown names fall back to builtins,
    then to None.
    """
    # cls is a type here, so `ismethod` is false since the methods on the type
    # aren't bound to anything, so Python treats them as regular functions
    fns = [
        getattr(cls, name)
        for name in cls.__dict__
        if inspect.isroutine(getattr(cls, name))
    ]
    # Skip built-ins, as they do not have global scope nor type hints
    # Needed to support `enum.Enum` derived classes in Python-3.11
    # That adds `_new_member_` property which is an alias to `__new__`
    fns = [fn for fn in fns if not inspect.isbuiltin(fn) and hasattr(fn, "__globals__")]
    captures = {}

    for fn in fns:
        captures.update(get_closure(fn))
        captures.update(get_type_hint_captures(fn))

    def lookup_in_class(key):
        if key in captures:
            return captures[key]
        else:
            return getattr(builtins, key, None)

    return lookup_in_class


def boolean_dispatch(
    arg_name,
    arg_index,
    default,
    if_true,
    if_false,
    module_name,
    func_name,
):
    """
    Dispatches to either of 2 script functions based on a boolean argument.
    In TorchScript, the boolean argument must be constant so that the correct
    function to use can be determined at compile time.

    In Python, the returned wrapper reads the flag from ``kwargs[arg_name]``,
    else from ``args[arg_index]``, else uses ``default``, and forwards all
    arguments to ``if_true`` or ``if_false``. The docstring of whichever target
    has one is copied to the other target and to the wrapper; ``module_name``
    and ``func_name`` (when not None) become the wrapper's ``__module__`` and
    ``__name__``. The wrapper is registered in ``boolean_dispatched``.

    Raises:
        RuntimeError: if both ``if_true`` and ``if_false`` have a docstring.
    """

    def fn(*args, **kwargs):
        dispatch_flag = default
        if arg_name in kwargs:
            dispatch_flag = kwargs[arg_name]
        elif arg_index < len(args):
            dispatch_flag = args[arg_index]

        if dispatch_flag:
            return if_true(*args, **kwargs)
        else:
            return if_false(*args, **kwargs)

    if if_true.__doc__ is None and if_false.__doc__ is not None:
        doc = if_false.__doc__
        if_true.__doc__ = doc
    elif if_false.__doc__ is None and if_true.__doc__ is not None:
        doc = if_true.__doc__
        if_false.__doc__ = doc
    elif if_false.__doc__ is None and if_true.__doc__ is None:
        # neither function has a docstring
        doc = None
    else:
        raise RuntimeError("only one function can have a docstring")
    fn.__doc__ = doc

    if module_name is not None:
        fn.__module__ = module_name
    if func_name is not None:
        fn.__name__ = func_name

    boolean_dispatched[fn] = {
        "if_true": if_true,
        "if_false": if_false,
        "index": arg_index,
        "default": default,
        "arg_name": arg_name,
    }
    return fn


class FunctionModifiers:
    """
    Used to denote the behavior of a function in TorchScript. See export() and
    ignore() for details.

    The decorators in this module store one of these values on the function as
    ``_torchscript_modifier``; the values are compared by identity.
    """

    UNUSED = "unused (ignored and replaced with raising of an exception)"
    IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
    EXPORT = "export (compile this function even if nothing calls it)"
    DEFAULT = "default (compile if called from a exported function / forward)"
    COPY_TO_SCRIPT_WRAPPER = (
        "if this method is not scripted, copy the python method onto the scripted model"
    )
    _DROP = "_drop (function is fully ignored, declaration can be unscriptable)"


def export(fn):
    """
    This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a
    :class:`ScriptModule` and should be compiled.

    ``forward`` implicitly is assumed to be an entry point, so it does not need this decorator.
    Functions and methods called from ``forward`` are compiled as they are seen
    by the compiler, so they do not need this decorator either.

    Example (using ``@torch.jit.export`` on a method):

    .. testcode::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            def implicitly_compiled_method(self, x):
                return x + 99

            # `forward` is implicitly decorated with `@torch.jit.export`,
            # so adding it here would have no effect
            def forward(self, x):
                return x + 10

            @torch.jit.export
            def another_forward(self, x):
                # When the compiler sees this call, it will compile
                # `implicitly_compiled_method`
                return self.implicitly_compiled_method(x)

            def unused_method(self, x):
                return x - 20

        # `m` will contain compiled methods:
        #     `forward`
        #     `another_forward`
        #     `implicitly_compiled_method`
        # `unused_method` will not be compiled since it was not called from
        # any compiled methods and wasn't decorated with `@torch.jit.export`
        m = torch.jit.script(MyModule())
    """
    fn._torchscript_modifier = FunctionModifiers.EXPORT
    return fn


def unused(fn):
    """
    This decorator indicates to the compiler that a function or method should
    be ignored and replaced with the raising of an exception. This allows you
    to leave code in your model that is not yet TorchScript compatible and still
    export your model.

    Example (using ``@torch.jit.unused`` on a method)::

        import torch
        import torch.nn as nn


        class MyModule(nn.Module):
            def __init__(self, use_memory_efficient):
                super().__init__()
                self.use_memory_efficient = use_memory_efficient

            @torch.jit.unused
            def memory_efficient(self, x):
                import pdb

                pdb.set_trace()
                return x + 10

            def forward(self, x):
                # Use not-yet-scriptable memory efficient mode
                if self.use_memory_efficient:
                    return self.memory_efficient(x)
                else:
                    return x + 10


        m = torch.jit.script(MyModule(use_memory_efficient=False))
        m.save("m.pt")

        m = torch.jit.script(MyModule(use_memory_efficient=True))
        # exception raised
        m(torch.rand(100))
    """
    if isinstance(fn, property):
        # Properties can't carry the marker themselves; mark the getter and,
        # if present, the setter.
        prop = fn
        setattr(  # noqa: B010
            prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED
        )

        if prop.fset:
            setattr(  # noqa: B010
                prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED
            )

        return prop

    fn._torchscript_modifier = FunctionModifiers.UNUSED
    return fn


# No op context manager from python side
class _IgnoreContextManager(contextlib.AbstractContextManager):
    def __init__(self, **kwargs):
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        pass


def ignore(drop=False, **kwargs):
    """
    This decorator indicates to the compiler that a function or method should
    be ignored and left as a Python function. This allows you to leave code in
    your model that is not yet TorchScript compatible. If called from TorchScript,
    ignored functions will dispatch the call to the Python interpreter. Models with ignored
    functions cannot be exported; use :func:`@torch.jit.unused <torch.jit.unused>` instead.

    Example (using ``@torch.jit.ignore`` on a method)::

        import torch
        import torch.nn as nn


        class MyModule(nn.Module):
            @torch.jit.ignore
            def debugger(self, x):
                import pdb

                pdb.set_trace()

            def forward(self, x):
                x += 10
                # The compiler would normally try to compile `debugger`,
                # but since it is `@ignore`d, it will be left as a call
                # to Python
                self.debugger(x)
                return x


        m = torch.jit.script(MyModule())

        # Error! The call `debugger` cannot be saved since it calls into Python
        m.save("m.pt")

    Example (using ``@torch.jit.ignore(drop=True)`` on a method):

    .. testcode::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            @torch.jit.ignore(drop=True)
            def training_method(self, x):
                import pdb
                pdb.set_trace()

            def forward(self, x):
                if self.training:
                    self.training_method(x)
                return x

        m = torch.jit.script(MyModule())

        # This is OK since `training_method` is not saved, the call is replaced
        # with a `raise`.
        m.save("m.pt")

    .. testcleanup::

        import os
        os.remove('m.pt')
    """

    if callable(drop):
        # used without any args, so drop is actually a function
        #   @torch.jit.ignore
        #   def fn(...):
        fn = drop
        fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn

    if not isinstance(drop, bool):
        raise RuntimeError(
            "Argument to @torch.jit.ignore must be a bool or "
            f"a function but got {drop}"
        )

    # for backwards compat: drop=True / drop_on_export=True are deprecated
    # spellings of @torch.jit.unused.
    drop_on_export = kwargs.pop("drop_on_export", None)
    if drop_on_export:
        warnings.warn(
            "ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function "
            "call on compilation. Use torch.jit.unused now.",
            category=FutureWarning,
        )

        drop = drop_on_export
    elif drop:
        warnings.warn(
            "ignore(True) has been deprecated. TorchScript will now drop the function "
            "call on compilation. Use torch.jit.unused now.",
            category=FutureWarning,
        )

    def decorator(fn):
        if drop:
            fn._torchscript_modifier = FunctionModifiers.UNUSED
        else:
            fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn

    return decorator


def _drop(fn):
    """Mark ``fn`` with ``FunctionModifiers._DROP`` and return it."""
    fn._torchscript_modifier = FunctionModifiers._DROP
    return fn


def _copy_to_script_wrapper(fn):
    """Mark ``fn`` with ``FunctionModifiers.COPY_TO_SCRIPT_WRAPPER`` and return it."""
    fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
    return fn


def module_has_exports(mod):
    """Return True if any callable attribute of ``mod`` is marked with ``@export``."""
    missing = object()
    for name in dir(mod):
        # Single getattr with a sentinel instead of hasattr + getattr, so
        # properties and other descriptors are evaluated only once.
        item = getattr(mod, name, missing)
        if item is missing or not callable(item):
            continue
        if get_torchscript_modifier(item) is FunctionModifiers.EXPORT:
            return True
    return False


# WARNING: should_drop is currently being used by our JIT code coverage plug-in to mark JIT'd code as covered. If you
# rename this function, please update references in tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py to
# allow JIT'd code to still be covered.
def should_drop(fn) -> bool:
    """Return True if ``fn`` is marked ``UNUSED`` or ``_DROP``."""
    attr = get_torchscript_modifier(fn)
    if attr is None:
        # Not callable, so it carries no modifier.
        return False
    return attr is FunctionModifiers.UNUSED or attr is FunctionModifiers._DROP


def is_ignored_fn(fn) -> bool:
    """Return True if ``fn`` is marked ``UNUSED``, ``IGNORE`` or ``_DROP``."""
    mod = get_torchscript_modifier(fn)
    return (
        mod is FunctionModifiers.UNUSED
        or mod is FunctionModifiers.IGNORE
        or mod is FunctionModifiers._DROP
    )


def _is_drop_fn(fn) -> bool:
    """Return True only if ``fn`` is marked ``_DROP``."""
    mod = get_torchscript_modifier(fn)
    return mod is FunctionModifiers._DROP


def is_static_fn(cls, fn) -> bool:
    """Return True if attribute name ``fn`` on ``cls`` is a ``staticmethod``.

    ``inspect.getattr_static`` is used so the raw class attribute is inspected
    without triggering the descriptor protocol; a missing attribute yields False.
    """
    return isinstance(inspect.getattr_static(cls, fn, default=None), staticmethod)


def get_static_fn(cls, fn):
    """Return the plain function wrapped by the staticmethod ``cls.<fn>``.

    NOTE(review): assumes ``is_static_fn(cls, fn)`` holds; a missing attribute
    raises AttributeError from ``inspect.getattr_static``.
    """
    return inspect.getattr_static(cls, fn).__func__


def get_torchscript_modifier(fn):
    """Return the ``FunctionModifiers`` value recorded on ``fn``.

    Returns None for non-callables and ``FunctionModifiers.DEFAULT`` when no
    marker was set. Objects exposing ``__func__`` (e.g. bound methods) are
    unwrapped first so the marker on the underlying function is found.
    """
    if not callable(fn):
        return None
    if hasattr(fn, "__func__"):
        fn = fn.__func__
    return getattr(fn, "_torchscript_modifier", FunctionModifiers.DEFAULT)


def copy_torchscript_modifier(orig, new) -> None:
    """Copy ``orig``'s modifier (including DEFAULT) onto ``new``; no-op if ``orig``
    is not callable."""
    attr = get_torchscript_modifier(orig)
    if attr is None:
        return
    new._torchscript_modifier = attr


# overloading registration
# overloads get registered in this file, and compiled in torch/jit/__init__.py
# so that they can be imported in nn/functional.py without an import cycle

# qualified_name => list[overload_functions]
_overloaded_fns: Dict[str, List[Callable]] = {}  # noqa: T484


_OVERLOAD_EXAMPLE = """
Example usage of overload function:
@torch.jit._overload
def my_function(x: type0) -> type0: # decl 1
    pass

@torch.jit._overload
def my_function(x: type1) -> type1: # decl 2
    pass

def my_function(x): # implementation
    if isinstance(x, type0):
        return x
    elif isinstance(x, type1):
        return x
"""


def get_overload_no_implementation_error_message(kind, obj):
    """Build the error text for an overloaded ``obj`` whose implementation is missing.

    The message names ``kind`` and the qualified name of ``obj``, quotes its
    source location and lines, and appends ``_OVERLOAD_EXAMPLE``.
    """
    sourcelines, file_lineno, filename = get_source_lines_and_file(obj)
    return (
        f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
        f"sure a definition is provided and defined after all overload declarations.\n"
        f'File "{filename}", line {file_lineno}:\n'
        + "".join(sourcelines)
        + "\n"
        + _OVERLOAD_EXAMPLE
    )


def _check_overload_body(func):
    """Validate that an overload declaration's body is a single ``pass`` or ``...``.

    Only warns (and skips the check) if the source cannot be retrieved.

    Raises:
        RuntimeError: if the body contains anything else.
    """
    try:
        parsed_def = parse_def(func)
    except OSError as e:
        # Parsing the function definition can raise an OSError if source is unavailable.
        # Since this is just an initial check, just raise a warning if this is the case.
        # (The caught exception itself is not included in the warning.)
        warnings.warn(
            f"Unable to retrieve source for @torch.jit._overload function: {func}."
        )
        return

    # Statements of the (single) parsed function definition.
    body = parsed_def.ast.body[0].body

    def is_pass(x):
        return isinstance(x, ast.Pass)

    def is_ellipsis(x):
        # A bare `...` statement is an expression statement holding Ellipsis.
        return (
            isinstance(x, ast.Expr)
            and isinstance(x.value, ast.Constant)
            and x.value.value is Ellipsis
        )

    if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])):
        msg = (
            "Only `pass` statement or `...` can be the body of overload declaration:\n"
        )
        # Quote the first three source lines of the declaration.
        msg += "\n".join(parsed_def.source.split("\n")[:3])
        msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE
        raise RuntimeError(msg)


def _overload(func):
    """Register ``func`` as an overload declaration under its qualified name.

    Declarations are appended in decoration order; ``func`` is returned unchanged.
    """
    _check_overload_body(func)
    qual_name = _qualified_name(func)
    global _overloaded_fns
    fn_overload_list = _overloaded_fns.get(qual_name)
    if fn_overload_list is None:
        fn_overload_list = []
        _overloaded_fns[qual_name] = fn_overload_list
    fn_overload_list.append(func)
    return func


def _get_fn_overloads(qual_name):
    """Return the registered overload declarations for ``qual_name``, or None."""
    return _overloaded_fns.get(qual_name)


def _clear_fn_overloads(qual_name) -> None:
    """Drop the overloads registered for ``qual_name``; raises KeyError if none are."""
    del _overloaded_fns[qual_name]


def get_class_name_lineno(method) -> Tuple[str, int]:
    current_frame = inspect.currentframe()

    # one for the get_class_name call, one for _overload_method call
    for i
in range(2): 1050 assert ( 1051 current_frame is not None 1052 ) # assert current frame is not an Optional[FrameType] 1053 current_frame = current_frame.f_back 1054 1055 assert current_frame is not None # same here 1056 class_name = current_frame.f_code.co_name 1057 line_no = current_frame.f_code.co_firstlineno 1058 return class_name, line_no 1059 1060 1061# At the point the decorator is applied to class methods the method 1062# has no reference to its owning class. _qualified_name would not include 1063# the class it is defined in, so any methods with the same name in the same file 1064# would have the same _qualified_name, even if they were defined in different 1065# classes. This problem only exists in python 2. 1066# We get around this problem by looking at the stack frame and identifying 1067# the class name, and throwing an error whenever overloads are used 1068# when modules of the same name are in the same file 1069 1070# qualified_name => class name => list[overload_functions] 1071_overloaded_methods: Dict[str, Dict[str, List[Callable]]] = {} # noqa: T484 1072 1073 1074# (qualified_name, class name) => class_fileno 1075_overloaded_method_class_fileno: Dict[Tuple[str, str], int] = {} 1076 1077 1078def _overload_method(func): 1079 _check_overload_body(func) 1080 qual_name = _qualified_name(func) 1081 global _overloaded_methods 1082 class_name_map = _overloaded_methods.get(qual_name, None) 1083 if class_name_map is None: 1084 class_name_map = {} 1085 _overloaded_methods[qual_name] = class_name_map 1086 1087 class_name, line_no = get_class_name_lineno(func) 1088 method_overloads = class_name_map.get(class_name, None) 1089 if method_overloads is None: 1090 method_overloads = [] 1091 class_name_map[class_name] = method_overloads 1092 _overloaded_method_class_fileno[(qual_name, class_name)] = line_no 1093 else: 1094 existing_lineno = _overloaded_method_class_fileno[(qual_name, class_name)] 1095 if existing_lineno != line_no: 1096 raise RuntimeError( 1097 "Cannot 
currently overload the same method name in two different" 1098 " classes with the same name in the same module" 1099 ) 1100 1101 method_overloads.append(func) 1102 return func 1103 1104 1105def _get_overloaded_methods(method, mod_class): 1106 # TODO: __name__ not set for submodules in recursive script 1107 if not hasattr(method, "__name__"): 1108 return None 1109 qual_name = _qualified_name(method) 1110 class_name_map = _overloaded_methods.get(qual_name, None) 1111 if class_name_map is None: 1112 return None 1113 overloads = class_name_map.get(mod_class.__name__, None) 1114 if overloads is None: 1115 return None 1116 1117 method_line_no = get_source_lines_and_file(method)[1] 1118 mod_class_fileno = get_source_lines_and_file(mod_class)[1] 1119 mod_end_fileno = mod_class_fileno + len(get_source_lines_and_file(mod_class)[0]) 1120 if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno): 1121 raise AssertionError( 1122 "Overloads are not useable when a module is redeclared within the same file: " 1123 + str(method) 1124 ) 1125 return overloads 1126 1127 1128def is_tuple(ann) -> bool: 1129 if ann is Tuple: 1130 raise_error_container_parameter_missing("Tuple") 1131 1132 # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule 1133 if not hasattr(ann, "__module__"): 1134 return False 1135 1136 ann_origin = get_origin(ann) 1137 if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is tuple: 1138 return True 1139 return ann.__module__ == "typing" and (ann_origin is Tuple or ann_origin is tuple) 1140 1141 1142def is_list(ann) -> bool: 1143 if ann is List: 1144 raise_error_container_parameter_missing("List") 1145 1146 if not hasattr(ann, "__module__"): 1147 return False 1148 1149 ann_origin = get_origin(ann) 1150 if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is list: 1151 return True 1152 return ann.__module__ == "typing" and (ann_origin is List or ann_origin is list) 1153 1154 1155def is_dict(ann) -> 
bool: 1156 if ann is Dict: 1157 raise_error_container_parameter_missing("Dict") 1158 1159 if not hasattr(ann, "__module__"): 1160 return False 1161 1162 ann_origin = get_origin(ann) 1163 if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is dict: 1164 return True 1165 return ann.__module__ == "typing" and (ann_origin is Dict or ann_origin is dict) 1166 1167 1168def is_union(ann): 1169 if ann is Union: 1170 raise_error_container_parameter_missing("Union") 1171 1172 return isinstance(ann, BuiltinUnionType) or ( 1173 hasattr(ann, "__module__") 1174 and ann.__module__ == "typing" 1175 and (get_origin(ann) is Union) 1176 ) 1177 1178 1179def is_optional(ann): 1180 if ann is Optional: 1181 raise_error_container_parameter_missing("Optional") 1182 1183 def is_optional_as_optional(ann): 1184 return ( 1185 hasattr(ann, "__module__") 1186 and ann.__module__ == "typing" 1187 and (get_origin(ann) is Optional) 1188 ) 1189 1190 def is_union_as_optional(ann): 1191 ann_args = get_args(ann) 1192 return len(ann_args) == 2 and (None in ann_args or type(None) in ann_args) 1193 1194 return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann)) 1195 1196 1197def is_future(ann) -> bool: 1198 if ann is Future: 1199 raise RuntimeError( 1200 "Attempted to use Future without a " 1201 "contained type. Please add a contained type, e.g. " 1202 "Future[int]" 1203 ) 1204 return get_origin(ann) is Future 1205 1206 1207def is_await(ann) -> bool: 1208 if ann is _Await: 1209 return True 1210 return get_origin(ann) is _Await 1211 1212 1213if torch.distributed.rpc.is_available(): 1214 from torch._C._distributed_rpc import PyRRef 1215 from torch.distributed.rpc import RRef 1216 1217 def is_rref(ann) -> bool: 1218 if ann is RRef: 1219 raise RuntimeError( 1220 "Attempted to use RRef without a " 1221 "contained type. Please add a contained type, e.g. 
" 1222 "RRef[int]" 1223 ) 1224 return get_origin(ann) is RRef 1225 1226 def is_rref_instance(obj) -> bool: 1227 return isinstance(obj, PyRRef) 1228 1229else: 1230 1231 def is_rref_instance(obj) -> bool: 1232 # If the RPC module doesn't exist then RRefs don't exist either. 1233 return False 1234 1235 1236def _try_get_dispatched_fn(fn): 1237 if not callable(fn): 1238 return None 1239 return boolean_dispatched.get(fn) 1240 1241 1242def _get_named_tuple_properties( 1243 obj, 1244 loc: Optional[torch._C._jit_tree_views.SourceRange] = None, 1245 rcb=None, 1246): 1247 if loc is None: 1248 loc = fake_range() 1249 1250 assert issubclass(obj, tuple) and hasattr(obj, "_fields") 1251 if hasattr(obj, "_field_defaults"): 1252 defaults = [ 1253 obj._field_defaults[field] 1254 for field in obj._fields 1255 if field in obj._field_defaults 1256 ] 1257 else: 1258 defaults = [] 1259 # In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function 1260 # Also, annotations from base class are not inherited so they need to be queried explicitly 1261 if sys.version_info[:2] < (3, 10): 1262 obj_annotations = getattr(obj, "__annotations__", {}) 1263 else: 1264 obj_annotations = inspect.get_annotations(obj) 1265 if len(obj_annotations) == 0 and hasattr(obj, "__base__"): 1266 obj_annotations = inspect.get_annotations(obj.__base__) 1267 1268 annotations = [] 1269 for field in obj._fields: 1270 if field in obj_annotations: 1271 field_type = obj_annotations[field] 1272 # [Note: ForwardRef annotations in NamedTuple attributes] 1273 # NamedTuple types are slightly different from normal types. 1274 # 1275 # Normally, annotations are evaluted like this (during jit.script): 1276 # 1. Load strings of python code into c++ and parse. 1277 # 2. Get annotations as strings 1278 # 3. Use the PythonResolver's resolution callback (rcb) to convert 1279 # the string into a python object 1280 # 4. 
We call into annotations.py:ann_to_type to convert python obj 1281 # from step 3 into a type that torchscript understands. 1282 # 1283 # NamedTuples are more complicated, because it has sub-types. 1284 # Normally, once we have the NamedTuple type object from #3, 1285 # we can just look at the annotation literal values and use 1286 # ann_to_type directly on them. 1287 # 1288 # But sometimes, users will annotate with string literals, e.g. 1289 # x: 'int' 1290 # This also happens with PEP563 (from __forward__ import annotations) 1291 # 1292 # These annotations appear in the annotation dict as ForwardRef('int'). 1293 # 1294 # Then, we need to convert the string into a python object. This 1295 # requires having local context for custom objects or imported types. 1296 # rcb() is what gives us this. So, we plumb rcb through the stack so 1297 # it can be used in this context for the if block below. 1298 # 1299 # FAQ: 1300 # - Why do we need this special handling for NamedTuple but string 1301 # annotations work fine for normal types? Normally, we parse the 1302 # string directly and then call rcb() directly from C++. 1303 # - Why not use ForwardRef._evaluate? For that, we need globals() 1304 # and locals() for the local context where the NamedTuple was defined. 1305 # rcb is what lets us look up into these. So, basically rcb does the 1306 # hard work for us. 1307 if isinstance(field_type, ForwardRef) and rcb is not None: 1308 rcb_type = rcb(field_type.__forward_arg__) 1309 # rcb returns None if it can't find anything. 1310 if rcb_type is None: 1311 raise ValueError( 1312 f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}." 1313 f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858." 
1314 f" Issue occurred at {loc.highlight()}" 1315 ) 1316 field_type = rcb_type 1317 the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb) 1318 annotations.append(the_type) 1319 else: 1320 annotations.append(torch._C.TensorType.getInferred()) 1321 return type(obj).__name__, obj._fields, annotations, defaults 1322 1323 1324def _create_named_tuple( 1325 t, 1326 unqual_name: str, 1327 field_names: List[str], 1328 defaults: Tuple[Any, ...], 1329): 1330 TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults) # type: ignore[call-arg, no-redef, misc] 1331 return TupleType(*t) 1332 1333 1334@contextlib.contextmanager 1335def _disable_emit_hooks(): 1336 hooks = torch._C._jit_get_emit_hooks() 1337 torch._C._jit_set_emit_hooks(None, None) 1338 try: 1339 yield 1340 finally: 1341 torch._C._jit_set_emit_hooks(hooks[0], hooks[1]) 1342 1343 1344def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None: # noqa: F811 1345 def __enter__(self) -> None: 1346 self.hooks = torch._C._jit_get_emit_hooks() 1347 torch._C._jit_set_emit_hooks(None, None) 1348 1349 def __exit__(self, *args) -> None: 1350 torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1]) 1351 1352 1353def _is_exception(obj) -> bool: 1354 if not inspect.isclass(obj): 1355 return False 1356 return issubclass(obj, Exception) 1357 1358 1359def raise_error_container_parameter_missing(target_type) -> None: 1360 if target_type == "Dict": 1361 raise RuntimeError( 1362 "Attempted to use Dict without " 1363 "contained types. Please add contained type, e.g. " 1364 "Dict[int, int]" 1365 ) 1366 raise RuntimeError( 1367 f"Attempted to use {target_type} without a " 1368 "contained type. Please add a contained type, e.g. 
" 1369 f"{target_type}[int]" 1370 ) 1371 1372 1373def check_args_exist(target_type) -> None: 1374 if target_type is List or target_type is list: 1375 raise_error_container_parameter_missing("List") 1376 elif target_type is Tuple or target_type is tuple: 1377 raise_error_container_parameter_missing("Tuple") 1378 elif target_type is Dict or target_type is dict: 1379 raise_error_container_parameter_missing("Dict") 1380 elif target_type is None or target_type is Optional: 1381 raise_error_container_parameter_missing("Optional") 1382 1383 1384def check_empty_containers(obj) -> None: 1385 if obj == [] or obj == {} or obj == (): 1386 warnings.warn( 1387 "The inner type of a container is lost when " 1388 "calling torch.jit.isinstance in eager mode. For " 1389 "example, List[int] would become list and " 1390 "therefore falsely return True for List[float] or" 1391 " List[str]." 1392 ) 1393 1394 1395# supports List/Dict/Tuple and Optional types 1396# TODO support future 1397def container_checker(obj, target_type) -> bool: 1398 origin_type = get_origin(target_type) 1399 check_args_exist(target_type) 1400 if origin_type is None: 1401 return False 1402 elif origin_type is list or origin_type is List: 1403 check_empty_containers(obj) 1404 if not isinstance(obj, list): 1405 return False 1406 arg_type = get_args(target_type)[0] 1407 arg_origin = get_origin(arg_type) 1408 for el in obj: 1409 # check if nested container, ex: List[List[str]] 1410 if arg_origin: # processes nested container, ex: List[List[str]] 1411 if not container_checker(el, arg_type): 1412 return False 1413 elif not isinstance(el, arg_type): 1414 return False 1415 return True 1416 elif origin_type is Dict or origin_type is dict: 1417 check_empty_containers(obj) 1418 if not isinstance(obj, dict): 1419 return False 1420 key_type = get_args(target_type)[0] 1421 val_type = get_args(target_type)[1] 1422 for key, val in obj.items(): 1423 # check if keys are of right type 1424 if not isinstance(key, key_type): 1425 return 
False 1426 val_origin = get_origin(val_type) 1427 if val_origin: 1428 if not container_checker(val, val_type): 1429 return False 1430 elif not isinstance(val, val_type): 1431 return False 1432 return True 1433 elif origin_type is Tuple or origin_type is tuple: 1434 check_empty_containers(obj) 1435 if not isinstance(obj, tuple): 1436 return False 1437 arg_types = get_args(target_type) 1438 if len(obj) != len(arg_types): 1439 return False 1440 for el, el_type in zip(obj, arg_types): 1441 el_origin = get_origin(el_type) 1442 if el_origin: 1443 if not container_checker(el, el_type): 1444 return False 1445 elif not isinstance(el, el_type): 1446 return False 1447 return True 1448 elif origin_type is Union or issubclass( 1449 origin_type, BuiltinUnionType 1450 ): # also handles Optional 1451 if obj is None: # check before recursion because None is always fine 1452 return True 1453 inner_types = get_args(target_type) 1454 for t in inner_types: 1455 t_origin = get_origin(t) 1456 if t_origin: 1457 return container_checker(obj, t) 1458 elif isinstance(obj, t): 1459 return True 1460 return False 1461 1462 1463def _isinstance(obj, target_type) -> bool: 1464 if isinstance(target_type, collections.abc.Container): 1465 if not isinstance(target_type, tuple): 1466 raise RuntimeError( 1467 "The second argument to " 1468 "`torch.jit.isinstance` must be a type " 1469 "or a tuple of types" 1470 ) 1471 for t_type in target_type: 1472 if _isinstance(obj, t_type): 1473 return True 1474 return False 1475 1476 origin_type = get_origin(target_type) 1477 if origin_type: 1478 return container_checker(obj, target_type) 1479 1480 # Check to handle non-typed optional origin returns as none instead 1481 # of as optional in 3.7-3.8 1482 check_args_exist(target_type) 1483 1484 # handle non-containers 1485 return isinstance(obj, target_type) 1486 1487 1488class _TensorExtractor(pickle.Pickler): 1489 def __init__(self, *args, tensors: List[torch.Tensor], **kwargs): 1490 super().__init__(*args, 
**kwargs) 1491 self.tensors = tensors 1492 1493 def persistent_id(self, obj): 1494 if isinstance(obj, torch.Tensor): 1495 self.tensors.append(obj) 1496 return "" 1497 # Since we just want to extract tensors, we don't mind if an object is 1498 # unpicklable if it doesn't contain tensors, as we can just ignore/skip 1499 # it. To play it safe, we only do so for common objects that we're sure 1500 # don't contain tensors. Feel free to add new types here. Note also that 1501 # even if a type isn't listed here this won't block users, since thet 1502 # can just add a __getstate__ or __reduce__ method to their class. 1503 if isinstance(obj, LockType): 1504 return "" 1505 # Futures and RRefs don't technically contain a value, they just offer 1506 # the means to access a value. 1507 if isinstance(obj, CFuture) or is_rref_instance(obj): 1508 return "" 1509 if isinstance(obj, CAwait): 1510 return "" 1511 if isinstance(obj, torch.cuda.Event): 1512 return "" 1513 if isinstance(obj, threading.Thread): 1514 return "" 1515 return None 1516 1517 1518def _extract_tensors(obj): 1519 r""" 1520 This function is exclusively called from C++. 1521 See ``torch/csrc/jit/python/python_ivalue.h``. 1522 1523 It extracts the tensors contained in the given object, through pickling. 1524 """ 1525 tensors: List[torch.Tensor] = [] 1526 extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors) 1527 extractor.dump(obj) 1528 return tensors 1529 1530 1531def _get_model_id(obj) -> Optional[str]: 1532 if isinstance(obj, torch.jit.ScriptModule): 1533 return str(obj._c._type()) 1534 elif isinstance(obj, torch.jit.ScriptFunction): 1535 return obj.qualified_name 1536 else: 1537 return None 1538 1539 1540# In Python-3.11+ typed enums (i.e. IntEnum for example) retain number of base class methods in subclass 1541# that were previously dropped. 
To preserve the behavior, explicitly drop them there 1542 1543if sys.version_info > (3, 10): 1544 _drop(enum.Enum.__new__) 1545 _drop(enum.Enum.__format__) 1546 _drop(enum.Enum.__repr__) 1547 _drop(enum.Enum.__str__) 1548