xref: /aosp_15_r20/external/pytorch/torch/_jit_internal.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""
3The weak_script annotation needs to be here instead of inside torch/jit/ so it
4can be used in other places in torch/ (namely torch.nn) without running into
5circular dependency problems
6"""
7
8import ast
9import builtins
10import collections
11import contextlib
12import enum
13import inspect
14import io
15import pickle
16import sys
17import textwrap
18import threading
19import types
20import typing
21import warnings
22import weakref
23from typing import (
24    Any,
25    Callable,
26    Dict,
27    Final,
28    ForwardRef,
29    get_args,
30    get_origin,
31    List,
32    Optional,
33    Tuple,
34    Type,
35    Union,
36)
37
38import torch
39
40# This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`.
41# Explicitly ask to import `torch.distributed.__init__` first.
42# Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
43import torch.distributed.rpc
44import torch.package._mangling as package_mangling
45from torch._awaits import _Await
46from torch._C import _Await as CAwait, Future as CFuture
47from torch._sources import fake_range, get_source_lines_and_file, parse_def
48from torch.futures import Future
49
50
# Python minor-version gates used for syntax/AST feature checks in this module.
IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9)
IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10)

# PEP 604 `X | Y` unions only exist as `types.UnionType` on 3.10+; on older
# versions this is an empty tuple so `isinstance(x, BuiltinUnionType)` is
# simply always False.
BuiltinUnionType: Union[Type, Tuple[Type, ...]]
if sys.version_info >= (3, 10):
    # NOTE: IS_PY310_PLUS doesn't work with mypy.
    # cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks
    BuiltinUnionType = types.UnionType
else:
    BuiltinUnionType = ()  # trick: this makes isinstance short circuit.

# Concrete lock type: the real `_thread` implementation when available,
# otherwise the dummy single-threaded fallback.
LockType: Type
try:
    import _thread

    LockType = _thread.LockType
except ImportError:
    import _dummy_thread  # type: ignore[import-not-found]

    LockType = _dummy_thread.LockType

# Wrapper functions that can call either of 2 functions depending on a boolean
# argument
boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = (
    weakref.WeakKeyDictionary()
)  # noqa: T484


# NOTE(review): the name suggests this prefix tags fake filenames of
# synthesized dataclass method sources — confirm at usage sites.
FAKE_FILENAME_PREFIX = "__torch_jit_dataclass"
80
81
def is_final(ann) -> bool:
    """Return True if ``ann`` is a ``typing.Final`` annotation (bare or subscripted)."""
    module = getattr(ann, "__module__", None)
    if module not in ("typing", "typing_extensions"):
        return False
    # Subscripted form (Final[int]) exposes Final as its origin; the bare form
    # is an instance of the same special-form type as Final itself.
    return get_origin(ann) is Final or isinstance(ann, type(Final))
88
89
# allows BroadcastingList instance to be subscriptable
class BroadcastingListCls:
    """Placeholder whose instances accept any subscript and yield nothing."""

    def __getitem__(self, types):
        # Subscription (e.g. BroadcastingList2[int]) is purely syntactic.
        return None
94
95
# mypy doesn't support parameters on types, so we have to explicitly type each
# list size
BroadcastingList1 = BroadcastingListCls()
# BroadcastingList2 .. BroadcastingList6 all alias the same placeholder
# instance; only the spelled-out names differ for mypy's benefit.
for _size in range(2, 7):
    globals()[f"BroadcastingList{_size}"] = BroadcastingList1
101
102
def is_scripting() -> bool:
    r"""
    Return True when in TorchScript compilation and False otherwise.

    This is useful especially with the @unused decorator to leave code in your
    model that is not yet TorchScript compatible.
    .. testcode::

        import torch

        @torch.jit.unused
        def unsupported_linear_op(x):
            return x

        def linear(x):
            if torch.jit.is_scripting():
                return torch.linear(x)
            else:
                return unsupported_linear_op(x)
    """
    # In eager Python this is always False; presumably the compiler
    # special-cases this function during scripting (not visible in this file).
    return False
123
124
125# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
126def _qualified_name(obj, mangle_name=True) -> str:
127    # This special case allows us to override the qualified name on a type.
128    # It's currently used in conjunction with tracing, where we create a
129    # fake module to filter only supported attributes. However, since this
130    # new type is defined as a local class, we need a mechanism to override
131    # its qualname so it appears correctly in the TorchScript system. This,
132    # we set '_jit_override_qualname' with the original traced module's
133    # qualified name, which is picked up here
134    if hasattr(obj, "_jit_override_qualname"):
135        return obj._jit_override_qualname
136    # short-circuit in cases where the object already has a known qualified name
137    if isinstance(obj, torch._C.ScriptFunction):
138        return obj.qualified_name
139
140    if getattr(obj, "__name__", None):
141        name = obj.__name__
142    # Enum classes do not have `__name__` attr, instead they have `name`.
143    elif isinstance(obj, enum.Enum):
144        name = obj.name
145    else:
146        raise RuntimeError("Could not get name of python class object")
147
148    if name == "<lambda>":
149        name = "_lambda"  # make name a valid identifier
150
151    module_name = obj.__module__
152
153    # If the module is actually a torchbind module, then we should short circuit
154    if module_name == "torch._classes":
155        return obj.qualified_name
156
157    # The Python docs are very clear that `__module__` can be None, but I can't
158    # figure out when it actually would be.
159    if module_name is None:
160        raise RuntimeError(
161            f"Could not get qualified name for class '{name}': "
162            "__module__ can't be None."
163        )
164
165    # if getattr(sys.modules[module_name], name) is not obj:
166    #     raise RuntimeError(f"Could not get qualified name for class '{name}': "
167    #                        f"the attr {name} on module {module_name} is not the class")
168
169    # torch.package and TorchScript have separate mangling schemes to avoid
170    # name collisions from multiple packages. To avoid them interfering with
171    # each other, normalize the package manging here.
172    if package_mangling.is_mangled(module_name):
173        module_name = module_name.replace("<", "_")
174        module_name = module_name.replace(">", "_")
175
176    # The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h
177    # does not need mangle the python class name.
178    if mangle_name:
179        # __main__ is a builtin module, so rewrite it to "__torch__".
180        if module_name == "__main__":
181            module_name = "__torch__"
182        else:
183            # Everything else gets a "__torch__" prefix to avoid name collisions
184            # with the names of user values.
185            module_name = "__torch__." + module_name
186
187    if "." in name:
188        raise RuntimeError(
189            f"Could not get qualified name for class '{name}': "
190            f"'{name}' is not a valid identifier"
191        )
192
193    return module_name + "." + name
194
195
class SourceLoader:
    """In-memory cache mapping functions to their source text."""

    def __init__(self):
        # fn -> source string
        self.content = {}

    def cache(self, fn, source):
        """Remember ``source`` as the source text of ``fn``."""
        self.content[fn] = source

    def get_source(self, fn):
        """Return the cached source for ``fn``, or None if absent."""
        return self.content.get(fn)
205
206
# Module-level source cache; consulted by get_type_hint_captures before
# falling back to inspect.getsource.
loader = SourceLoader()
208
209
def createResolutionCallbackFromEnv(lookup_base):
    """
    Creates a resolution callback that will look up qualified names in an
    environment, starting with `lookup_base` for the base of any qualified
    names, then proceeding down the lookup chain with the resolved object.

    You should not use this directly, it should only be used from the other
    createResolutionCallbackFrom* functions.
    """

    def lookupInModule(qualified_name, module):
        # Resolve a dotted name one segment at a time via getattr.
        if "." in qualified_name:
            base, remaining_pieces = qualified_name.split(".", maxsplit=1)
            module_value = getattr(module, base)
            return lookupInModule(remaining_pieces, module_value)
        else:
            return getattr(module, qualified_name)

    def parseNestedExpr(expr, module) -> Tuple[Any, int]:
        # Recursive-descent parse of a (possibly subscripted) type expression
        # such as "Dict[str, List[int]]". Returns the resolved object and the
        # number of characters consumed from `expr`.
        i = 0
        # Scan the leading identifier, stopping at any structural character.
        while i < len(expr) and expr[i] not in (",", "[", "]"):
            i += 1

        # Special case logic for the empty Tuple as a subscript (used
        # in the type annotation `Tuple[()]`)
        if expr[:i] == "()":
            return (), i

        base = lookupInModule(expr[:i].strip(), module)
        assert base is not None, f"Unresolvable type {expr[:i]}"
        if i == len(expr) or expr[i] != "[":
            return base, i

        assert expr[i] == "["
        parts = []
        # Parse comma-separated subscript arguments until the closing "]".
        while expr[i] != "]":
            part_len = 0
            i += 1
            part, part_len = parseNestedExpr(expr[i:], module)
            parts.append(part)
            i += part_len
        # Multiple arguments subscript with a tuple (base[a, b]); a single
        # argument subscripts directly (base[a]).
        if len(parts) > 1:
            return base[tuple(parts)], i + 1
        else:
            return base[parts[0]], i + 1

    def parseExpr(expr, module):
        try:
            value, len_parsed = parseNestedExpr(expr, module)
            assert len_parsed == len(
                expr
            ), "whole expression was not parsed, falling back to c++ parser"
            return value
        except Exception:
            """
            The python resolver fails in several cases in known unit tests, and is intended
            to fall back gracefully to the c++ resolver in general.  For example, python 2 style
            annotations which are frequent in our unit tests often fail with types e.g. int not
            resolvable from the calling frame.
            """
            return None

    return lambda expr: parseExpr(expr, lookup_base)
273
274
def createResolutionCallbackFromFrame(frames_up: int = 0):
    """
    Creates a function which, given a string variable name,
    returns the value of the variable in the scope of the caller of
    the function which called createResolutionCallbackFromFrame (by default).

    This is used to enable access in-scope Python variables inside
    TorchScript fragments.

    frames_up is number of additional frames to go up on the stack.
    The default value is 0, which correspond to the frame of the caller
    of createResolutionCallbackFromFrame. Also for example, if frames_up is set
    to 1, then the frame of the caller's caller of createResolutionCallbackFromFrame
    will be taken.

    For example, the following program prints 2::

        def bar():
            cb = createResolutionCallbackFromFrame(1)
            print(cb("foo"))


        def baz():
            foo = 2
            bar()


        baz()
    """
    # Step out of this function's own frame, plus `frames_up` extra frames.
    frame = inspect.currentframe()
    for _ in range(frames_up + 1):
        assert frame is not None
        frame = frame.f_back

    assert frame is not None
    f_locals = frame.f_locals
    f_globals = frame.f_globals

    class env:
        # Namespace adapter: exposes the captured frame's locals, then
        # globals, then builtins through plain getattr lookups.
        def __getattr__(self, key):
            if key in f_locals:
                return f_locals[key]
            if key in f_globals:
                return f_globals[key]
            if key in dir(builtins):
                return getattr(builtins, key)
            return None

    return createResolutionCallbackFromEnv(env())
325
326
def get_closure(fn):
    """Return a dict of names visible to ``fn``: its globals plus its closed-over cells."""
    captures = dict(fn.__globals__)

    # Free variables pair positionally with the closure cells.
    for cell_name, cell in zip(fn.__code__.co_freevars, fn.__closure__ or ()):
        captures[cell_name] = cell.cell_contents

    return captures
338
339
340# [local resolution in python]
341# Depending on where a variable is defined, and where it is used, we may
342# or may not be able to recover its value when recursively compiling a
343# script function. Remember in the general case, a module or function is
344# first defined and then later scripted. This means we do not have a
345# chance to capture the active frames when the function is defined. Hence any
346# name resolution has to happen later on the created closure. The way
347# python captures type annotations restricts what we can recover. The
348# follow example illustrates the different cases:
349#
350#         class MyGlobalClass:
351#         ...
352#         def my_local_scope():
353#             @torch.jit.script
354#             class MyClass:
355#                 ...
356#             @torch.jit.script
357#             class MyClassUsedAsVar:
358#                 ...
359#             def eg(x: MyClass, y: MyGlobalClass):
360#                 a_local_capture : Foo
361#                 return MyClassUsedAsVar(x)
362#
363# MyGlobalClass is defined in the __globals__ dictionary of function
364# 'eg', so it is always recoverable. my_local_scope introduces a new local
365# variable scope in the function. Classes defined here are only visible as
366# local variables. For the case of MyClassUsedAsVar, it is captured
367# because it is used as a variable inside the body of the function, and we
368# can resolve it using the captures returned from `get_closure`. However,
369# the type annotations are not captured by the closure. In Python
370# 3.0--3.9, the _value_ of MyClass and MyGlobalClass will be available as
# annotations on `eg`, but starting in Python 4.0, they will be represented as
372# strings and no longer present. Furthermore, since the body of `eg` does
373# not reference those names, they do not appear in the list of closed over
374# variables. In Python 2.x, type annotations are in comments, leading to a
375# similar situation where their definitions are not available. We anticipate
376# that most users will not run into this issue because their modules and
377# functions will be defined at a global scope like MyGlobalClass. In cases
378# where they are not, it is possible to work around issues by declaring the
379# values global in the function.
380# In Python 3.9 declaring class as global will make it invisible to
381# `inspect.getsource`, see https://bugs.python.org/issue42666 .
# This could be worked around by manually adding it to the `globals()` dictionary.
383
384
def createResolutionCallbackFromClosure(fn):
    """
    Create a resolutionCallback by introspecting the function instead of
    looking up the stack for the enclosing scope
    """
    closure = get_closure(fn)

    class closure_lookup:
        # A class (rather than the raw dict) so the shared env-based resolver
        # can address everything uniformly through getattr calls.
        def __getattr__(self, key):
            if key in closure:
                return closure[key]
            # Fall back to typing first, then builtins, mirroring how source
            # annotations typically resolve.
            for fallback in (typing, builtins):
                if hasattr(fallback, key):
                    return getattr(fallback, key)
            return None

    return createResolutionCallbackFromEnv(closure_lookup())
405
406
def can_compile_class(cls) -> bool:
    """Return True if TorchScript can attempt to compile ``cls``.

    A class is rejected when it is ignored, derives from a known built-in
    base, or has any routine without a Python code object (i.e. something
    builtin / bound from C).
    """
    if is_ignored_fn(cls):
        return False

    # Reject well-known built-in bases outright.
    if issubclass(cls, (torch.nn.Module, tuple, list, Exception)):
        return False

    # Every routine defined on the class must carry a `__code__` object.
    routines = (
        getattr(cls, name)
        for name in cls.__dict__
        if inspect.isroutine(getattr(cls, name, None))
    )
    return all(hasattr(fn, "__code__") for fn in routines)
426
427
def get_callable_argument_names(fn) -> List[str]:
    """
    Get the names of all POSITIONAL_OR_KEYWORD arguments of callable `fn`.
    Arguments of any other kind (varargs, keyword-only, etc.) are skipped;
    if the signature cannot be determined at all, an empty list is returned.

    This is used by `torch.jit.trace` to assign meaningful argument names to
    traced functions and modules.

    Args:
        fn: A callable.
    Returns:
        Argument names: List[str]
    """
    # inspect.signature may fail for some callables; treat that as "no names".
    try:
        sig = inspect.signature(fn)
    except Exception:
        return []

    # Only POSITIONAL_OR_KEYWORD parameters map to individual values with a
    # keyword as name.
    return [
        name
        for name, param in sig.parameters.items()
        if param.kind == param.POSITIONAL_OR_KEYWORD
    ]
457
458
def get_annotation_str(annotation):
    """
    Convert an AST node containing a type annotation to the string present in the source
    that represents the same annotation. Returns None for node types handled
    elsewhere (by ScriptTypeParser).
    """
    if isinstance(annotation, ast.Name):
        return annotation.id
    if isinstance(annotation, ast.Attribute):
        return ".".join([get_annotation_str(annotation.value), annotation.attr])
    if isinstance(annotation, ast.Subscript):
        # Python 3.9+ stores the subscript index directly; older versions wrap
        # it in ast.Index.
        if sys.version_info >= (3, 9):
            subscript_slice = annotation.slice
        else:
            subscript_slice = annotation.slice.value  # type: ignore[attr-defined]
        return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]"
    if isinstance(annotation, ast.Tuple):
        return ",".join(get_annotation_str(elt) for elt in annotation.elts)
    if isinstance(annotation, ast.Constant):
        return f"{annotation.value}"

    # If an AST node is not handled here, it's probably handled in ScriptTypeParser.
    return None
479
480
def get_type_hint_captures(fn):
    """
    Get a dictionary containing type resolution mappings necessary to resolve types
    for the literal annotations on 'fn'. These are not considered to be closed-over by fn
    and must be obtained separately (e.g. using this function).

    Args:
        fn: A callable.
    Returns:
        A Dict[str, Any] containing a mapping from the literal annotations used on
        fn to the Python objects they refer to.
    Raises:
        OSError: if the source of `fn` cannot be retrieved.
        RuntimeError: if the retrieved source does not parse to a single function.
    """
    # First, try to get the source of the function. We'll need to parse it to find the actual string names
    # that were used to annotate the types, since inspect.signature() will only return the class object that
    # the annotation refers to, not the string name. If we can't get the source, an OSError is raised.
    # This may happen in cases where the function is synthesized dynamically at runtime.
    src = loader.get_source(fn)
    if src is None:
        try:
            src = inspect.getsource(fn)
        except OSError as e:
            raise OSError(
                f"Failed to get source for {fn} using inspect.getsource"
            ) from e

    # Gather a dictionary of parameter name -> type, skipping any parameters whose annotated
    # types are strings. These are only understood by TorchScript in the context of a type annotation
    # that refers to a class in its own definition, but trying to include a mapping for this in the result
    # function would cause infinite recursion because the class is currently being compiled.
    # In addition, there is logic in ScriptTypeParser to handle this.
    signature = inspect.signature(fn)
    name_to_type = {
        name: parameter.annotation
        for name, parameter in signature.parameters.items()
        if parameter.annotation is not inspect.Parameter.empty
        and not isinstance(parameter.annotation, str)
    }

    # Then, get the literal type annotations from the function declaration
    # by source inspection. This accounts for the case in which aliases are used
    # to annotate the arguments (e.g. device_t = torch.device, and then d: device_t).
    # frontend.py cannot be used here because it includes _jit_internal, so use ast instead.
    a = ast.parse(textwrap.dedent(src))
    if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
        raise RuntimeError(f"Expected {fn} to be a function")
    f = a.body[0]

    # Prepare a dictionary of source annotation -> type, which will be the final result of this function,
    # by using the parsed AST (f) to reconstruct source annotations as strings for each parameter and mapping
    # them to the type object corresponding to the annotation via name_to_type using the parameter name.
    annotation_to_type = {}

    for arg in f.args.args:
        # Get the source type annotation string for this argument if possible.
        arg_annotation_str = (
            get_annotation_str(arg.annotation) if arg.annotation else None
        )

        # If the argument has no annotation or get_annotation_str cannot convert it to a string,
        # arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle
        # this in the latter case.
        if arg_annotation_str is None:
            continue

        # Insert {arg_annotation_str: type} into annotation_to_type if possible. One reason arg_name may not
        # be present in name_to_type is that the annotation itself is a string and not a type object
        # (common for self-referential annotations in classes). Once again, let ScriptTypeParser handle this.
        arg_name = arg.arg
        if arg_name in name_to_type:
            annotation_to_type[arg_annotation_str] = name_to_type[arg_name]

    # If there is a valid return annotation, include it in annotation_to_type. As with argument annotations,
    # the literal annotation has to be convertible to a string by get_annotation_str, and the actual type
    # of the annotation cannot be a string.
    literal_return_annotation = get_annotation_str(f.returns)
    valid_literal_annotation = literal_return_annotation is not None
    return_annotation = signature.return_annotation
    valid_return_annotation_type = (
        return_annotation is not inspect.Parameter.empty
        and not isinstance(return_annotation, str)
    )
    if valid_literal_annotation and valid_return_annotation_type:
        annotation_to_type[literal_return_annotation] = return_annotation

    return annotation_to_type
566
567
def createResolutionCallbackForClassMethods(cls):
    """
    This looks at all the methods defined in a class and pulls their closed-over
    variables into a dictionary and uses that to resolve variables.
    """
    # cls is a type here, so `ismethod` is false since the methods on the type
    # aren't bound to anything; `isroutine` still matches them. Built-ins are
    # skipped because they have no globals or type hints — needed to support
    # `enum.Enum` derived classes in Python 3.11, which add a `_new_member_`
    # alias to `__new__`.
    fns = [
        fn
        for fn in (
            getattr(cls, name)
            for name in cls.__dict__
            if inspect.isroutine(getattr(cls, name))
        )
        if not inspect.isbuiltin(fn) and hasattr(fn, "__globals__")
    ]

    captures = {}
    for fn in fns:
        captures.update(get_closure(fn))
        captures.update(get_type_hint_captures(fn))

    def lookup_in_class(key):
        # Prefer the captured environment, then fall back to builtins.
        return captures[key] if key in captures else getattr(builtins, key, None)

    return lookup_in_class
597
598
def boolean_dispatch(
    arg_name,
    arg_index,
    default,
    if_true,
    if_false,
    module_name,
    func_name,
):
    """
    Dispatches to either of 2 script functions based on a boolean argument.
    In TorchScript, the boolean argument must be constant so that the correct
    function to use can be determined at compile time.
    """

    def fn(*args, **kwargs):
        # Resolve the flag: keyword wins, then positional, then the default.
        if arg_name in kwargs:
            dispatch_flag = kwargs[arg_name]
        elif arg_index < len(args):
            dispatch_flag = args[arg_index]
        else:
            dispatch_flag = default

        target = if_true if dispatch_flag else if_false
        return target(*args, **kwargs)

    # At most one of the pair may carry a docstring; it is shared with the
    # other function and with the dispatcher.
    true_doc = if_true.__doc__
    false_doc = if_false.__doc__
    if true_doc is None and false_doc is not None:
        doc = false_doc
        if_true.__doc__ = doc
    elif false_doc is None and true_doc is not None:
        doc = true_doc
        if_false.__doc__ = doc
    elif false_doc is None and true_doc is None:
        # neither function has a docstring
        doc = None
    else:
        raise RuntimeError("only one function can have a docstring")
    fn.__doc__ = doc

    if module_name is not None:
        fn.__module__ = module_name
    if func_name is not None:
        fn.__name__ = func_name

    # Record the dispatch metadata so the compiler can resolve the pair later.
    boolean_dispatched[fn] = {
        "if_true": if_true,
        "if_false": if_false,
        "index": arg_index,
        "default": default,
        "arg_name": arg_name,
    }
    return fn
652
653
class FunctionModifiers:
    """
    Used to denote the behavior of a function in TorchScript. See export() and
    ignore() for details.
    """

    # Each value doubles as a human-readable description. They are compared
    # with `is`/equality elsewhere (e.g. module_has_exports), so the exact
    # strings must not change.
    UNUSED = "unused (ignored and replaced with raising of an exception)"
    IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
    EXPORT = "export (compile this function even if nothing calls it)"
    DEFAULT = "default (compile if called from a exported function / forward)"
    COPY_TO_SCRIPT_WRAPPER = (
        "if this method is not scripted, copy the python method onto the scripted model"
    )
    _DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
668
669
def export(fn):
    """
    This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a
    :class:`ScriptModule` and should be compiled.

    ``forward`` implicitly is assumed to be an entry point, so it does not need this decorator.
    Functions and methods called from ``forward`` are compiled as they are seen
    by the compiler, so they do not need this decorator either.

    Example (using ``@torch.jit.export`` on a method):

    .. testcode::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            def implicitly_compiled_method(self, x):
                return x + 99

            # `forward` is implicitly decorated with `@torch.jit.export`,
            # so adding it here would have no effect
            def forward(self, x):
                return x + 10

            @torch.jit.export
            def another_forward(self, x):
                # When the compiler sees this call, it will compile
                # `implicitly_compiled_method`
                return self.implicitly_compiled_method(x)

            def unused_method(self, x):
                return x - 20

        # `m` will contain compiled methods:
        #     `forward`
        #     `another_forward`
        #     `implicitly_compiled_method`
        # `unused_method` will not be compiled since it was not called from
        # any compiled methods and wasn't decorated with `@torch.jit.export`
        m = torch.jit.script(MyModule())
    """
    # Tag the function so the compiler treats it as an entry point.
    setattr(fn, "_torchscript_modifier", FunctionModifiers.EXPORT)  # noqa: B010
    return fn
714
715
def unused(fn):
    """
    This decorator indicates to the compiler that a function or method should
    be ignored and replaced with the raising of an exception. This allows you
    to leave code in your model that is not yet TorchScript compatible and still
    export your model.

        Example (using ``@torch.jit.unused`` on a method)::

            import torch
            import torch.nn as nn


            class MyModule(nn.Module):
                def __init__(self, use_memory_efficient):
                    super().__init__()
                    self.use_memory_efficient = use_memory_efficient

                @torch.jit.unused
                def memory_efficient(self, x):
                    import pdb

                    pdb.set_trace()
                    return x + 10

                def forward(self, x):
                    # Use not-yet-scriptable memory efficient mode
                    if self.use_memory_efficient:
                        return self.memory_efficient(x)
                    else:
                        return x + 10


            m = torch.jit.script(MyModule(use_memory_efficient=False))
            m.save("m.pt")

            m = torch.jit.script(MyModule(use_memory_efficient=True))
            # exception raised
            m(torch.rand(100))
    """
    if isinstance(fn, property):
        prop = fn
        # Properties: mark the getter, and the setter when one exists.
        setattr(  # noqa: B010
            prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED
        )
        if prop.fset is not None:
            setattr(  # noqa: B010
                prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED
            )
        return prop

    fn._torchscript_modifier = FunctionModifiers.UNUSED
    return fn
771
772
773# No op context manager from python side
774class _IgnoreContextManager(contextlib.AbstractContextManager):
775    def __init__(self, **kwargs):
776        pass
777
778    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
779        pass
780
781
def ignore(drop=False, **kwargs):
    """
    This decorator indicates to the compiler that a function or method should
    be ignored and left as a Python function. This allows you to leave code in
    your model that is not yet TorchScript compatible. If called from TorchScript,
    ignored functions will dispatch the call to the Python interpreter. Models with ignored
    functions cannot be exported; use :func:`@torch.jit.unused <torch.jit.unused>` instead.

    Example (using ``@torch.jit.ignore`` on a method)::

        import torch
        import torch.nn as nn


        class MyModule(nn.Module):
            @torch.jit.ignore
            def debugger(self, x):
                import pdb

                pdb.set_trace()

            def forward(self, x):
                x += 10
                # The compiler would normally try to compile `debugger`,
                # but since it is `@ignore`d, it will be left as a call
                # to Python
                self.debugger(x)
                return x


        m = torch.jit.script(MyModule())

        # Error! The call `debugger` cannot be saved since it calls into Python
        m.save("m.pt")

    Example (using ``@torch.jit.ignore(drop=True)`` on a method):

    .. testcode::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            @torch.jit.ignore(drop=True)
            def training_method(self, x):
                import pdb
                pdb.set_trace()

            def forward(self, x):
                if self.training:
                    self.training_method(x)
                return x

        m = torch.jit.script(MyModule())

        # This is OK since `training_method` is not saved, the call is replaced
        # with a `raise`.
        m.save("m.pt")

    .. testcleanup::

        import os
        os.remove('m.pt')
    """

    if callable(drop):
        # used without any args, so drop is actually a function
        #   @torch.jit.ignore
        #   def fn(...):
        fn = drop
        fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn

    if not isinstance(drop, bool):
        raise RuntimeError(
            "Argument to @torch.jit.ignore must be a bool or "
            f"a function but got {drop}"
        )

    # for backwards compat
    drop_on_export = kwargs.pop("drop_on_export", None)
    if drop_on_export:
        # BUG FIX: the deprecation messages previously ended with a stray "{}"
        # (a leftover str.format placeholder) that was emitted verbatim.
        warnings.warn(
            "ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function "
            "call on compilation. Use torch.jit.unused now.",
            category=FutureWarning,
        )

        drop = drop_on_export
    elif drop:
        warnings.warn(
            "ignore(True) has been deprecated. TorchScript will now drop the function "
            "call on compilation. Use torch.jit.unused now.",
            category=FutureWarning,
        )

    def decorator(fn):
        # drop=True behaves like @unused (call replaced with a raise);
        # drop=False leaves the function as a Python call.
        if drop:
            fn._torchscript_modifier = FunctionModifiers.UNUSED
        else:
            fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn

    return decorator
886
887
def _drop(fn):
    """Mark ``fn`` so the TorchScript compiler drops it entirely."""
    setattr(fn, "_torchscript_modifier", FunctionModifiers._DROP)
    return fn
891
892
def _copy_to_script_wrapper(fn):
    """Mark ``fn`` for copying onto the generated script wrapper."""
    setattr(fn, "_torchscript_modifier", FunctionModifiers.COPY_TO_SCRIPT_WRAPPER)
    return fn
896
897
def module_has_exports(mod):
    """Return True if any callable attribute of ``mod`` carries the EXPORT modifier."""
    for attr_name in dir(mod):
        # getattr with a default covers attributes reported by dir() that
        # cannot actually be fetched (descriptors raising AttributeError).
        candidate = getattr(mod, attr_name, None)
        if not callable(candidate):
            continue
        if get_torchscript_modifier(candidate) is FunctionModifiers.EXPORT:
            return True
    return False
906
907
# WARNING: should_drop is currently being used by our JIT code coverage plug-in to mark JIT'd code as covered. If you
# rename this function, please update references in tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py to
# allow JIT'd code to still be covered.
def should_drop(fn) -> bool:
    """True if ``fn`` is marked so its call is dropped (replaced) at compile time."""
    modifier = get_torchscript_modifier(fn)
    if modifier is None:
        return False
    return modifier in (FunctionModifiers.UNUSED, FunctionModifiers._DROP)
916
917
def is_ignored_fn(fn) -> bool:
    """True for functions marked UNUSED, IGNORE, or _DROP (None never matches)."""
    return get_torchscript_modifier(fn) in (
        FunctionModifiers.UNUSED,
        FunctionModifiers.IGNORE,
        FunctionModifiers._DROP,
    )
925
926
def _is_drop_fn(fn) -> bool:
    """True only for functions decorated with the internal ``_drop`` marker."""
    modifier = get_torchscript_modifier(fn)
    return modifier is FunctionModifiers._DROP
930
931
def is_static_fn(cls, fn) -> bool:
    """Whether attribute ``fn`` of ``cls`` is declared as a staticmethod."""
    # getattr_static avoids triggering descriptors, so the raw staticmethod
    # wrapper is still visible; a missing attribute yields None -> False.
    static_attr = inspect.getattr_static(cls, fn, default=None)
    return isinstance(static_attr, staticmethod)
934
935
def get_static_fn(cls, fn):
    """Return the plain function wrapped by the staticmethod ``fn`` on ``cls``."""
    wrapper = inspect.getattr_static(cls, fn)
    return wrapper.__func__
938
939
def get_torchscript_modifier(fn):
    """Return the TorchScript modifier attached to ``fn``.

    Non-callables yield None; callables without an explicit marker yield
    ``FunctionModifiers.DEFAULT``.
    """
    if not callable(fn):
        return None
    # Unwrap bound methods so the marker set on the underlying function is found.
    target = getattr(fn, "__func__", fn)
    return getattr(target, "_torchscript_modifier", FunctionModifiers.DEFAULT)
946
947
def copy_torchscript_modifier(orig, new) -> None:
    """Propagate ``orig``'s TorchScript modifier onto ``new`` (no-op if absent)."""
    modifier = get_torchscript_modifier(orig)
    if modifier is not None:
        new._torchscript_modifier = modifier
953
954
# overloading registration
# overloads get registered in this file, and compiled in torch/jit/__init__.py
# so that they can be imported in nn/functional.py without an import cycle

# Registry populated by @torch.jit._overload and consumed by the compiler.
# qualified_name => list[overload_functions]
_overloaded_fns: Dict[str, List[Callable]] = {}  # noqa: T484
961
962
# Usage template appended to overload-related error messages to show the
# expected declaration/implementation pattern.
_OVERLOAD_EXAMPLE = """
Example usage of overload function:
@torch.jit._overload
def my_function(x: type0) -> type0: # decl 1
    pass

@torch.jit._overload
def my_function(x: type1) -> type1: # decl 2
    pass

def my_function(x):                 # implementation
    if isinstance(x, type0):
        return x
    elif isinstance(x, type1):
        return x
"""
979
980
def get_overload_no_implementation_error_message(kind, obj):
    """Build the error shown when overload declarations for ``obj`` lack an implementation.

    ``kind`` is a human-readable description ("function"/"method") interpolated
    into the message; ``obj`` is the overloaded callable.
    """
    sourcelines, file_lineno, filename = get_source_lines_and_file(obj)
    return (
        f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
        f"sure a definition is provided and defined after all overload declarations.\n"
        # Fixed: the filename returned above was previously discarded and the
        # message hardcoded a "(unknown)" placeholder instead.
        f'File "{filename}", line {file_lineno}:\n'
        + "".join(sourcelines)
        + "\n"
        + _OVERLOAD_EXAMPLE
    )
991
992
def _check_overload_body(func):
    """Validate that an overload declaration's body is only ``pass`` or ``...``."""
    try:
        parsed_def = parse_def(func)
    except OSError:
        # Source can be unavailable (e.g. defined in a REPL). This is only a
        # best-effort sanity check, so warn instead of failing hard.
        warnings.warn(
            f"Unable to retrieve source for @torch.jit._overload function: {func}."
        )
        return

    statements = parsed_def.ast.body[0].body

    def _is_trivial(stmt):
        # A bare `pass` or a lone `...` expression both count as an empty body.
        if isinstance(stmt, ast.Pass):
            return True
        return (
            isinstance(stmt, ast.Expr)
            and isinstance(stmt.value, ast.Constant)
            and stmt.value.value is Ellipsis
        )

    if len(statements) == 1 and _is_trivial(statements[0]):
        return

    msg = "Only `pass` statement or `...` can be the body of overload declaration:\n"
    msg += "\n".join(parsed_def.source.split("\n")[:3])
    msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE
    raise RuntimeError(msg)
1023
1024
def _overload(func):
    """Register ``func`` as an overload declaration under its qualified name."""
    _check_overload_body(func)
    qual_name = _qualified_name(func)
    # setdefault creates the per-name list on first registration.
    _overloaded_fns.setdefault(qual_name, []).append(func)
    return func
1035
1036
def _get_fn_overloads(qual_name):
    """Return the overload declarations registered for ``qual_name``, or None."""
    try:
        return _overloaded_fns[qual_name]
    except KeyError:
        return None
1039
1040
def _clear_fn_overloads(qual_name) -> None:
    """Forget all overloads registered for ``qual_name`` (KeyError if none)."""
    # pop without a default raises KeyError just like the original `del`.
    _overloaded_fns.pop(qual_name)
1043
1044
def get_class_name_lineno(method) -> Tuple[str, int]:
    """Walk two frames up the stack to locate the enclosing class body.

    Intended to be called directly from ``_overload_method`` while the class
    body is executing: two frames up is the class definition itself, whose
    code-object name is the class name. ``method`` is unused but kept for
    interface compatibility.
    """
    frame = inspect.currentframe()

    # Skip this frame and the _overload_method frame that invoked us.
    for _ in range(2):
        assert frame is not None  # narrow Optional[FrameType] for mypy
        frame = frame.f_back

    assert frame is not None  # same narrowing as above
    return frame.f_code.co_name, frame.f_code.co_firstlineno
1059
1060
# At the point the decorator is applied to class methods the method
# has no reference to its owning class. _qualified_name would not include
# the class it is defined in, so any methods with the same name in the same file
# would have the same _qualified_name, even if they were defined in different
# classes. This problem only exists in python 2.
# We get around this problem by looking at the stack frame and identifying
# the class name, and throwing an error whenever overloads are used
# when modules of the same name are in the same file

# Registry populated by @torch.jit._overload_method.
# qualified_name => class name => list[overload_functions]
_overloaded_methods: Dict[str, Dict[str, List[Callable]]] = {}  # noqa: T484


# Tracks where each (method, class) pair was first registered, to detect
# two same-named classes in one module (see comment above).
# (qualified_name, class name) => class_fileno
_overloaded_method_class_fileno: Dict[Tuple[str, str], int] = {}
1076
1077
def _overload_method(func):
    """Register ``func`` as an overload declaration on its enclosing class."""
    _check_overload_body(func)
    qual_name = _qualified_name(func)
    class_name_map = _overloaded_methods.setdefault(qual_name, {})

    # NOTE: get_class_name_lineno must be called directly from this frame --
    # it walks exactly two frames up to reach the executing class body.
    class_name, line_no = get_class_name_lineno(func)

    method_overloads = class_name_map.get(class_name)
    if method_overloads is None:
        method_overloads = []
        class_name_map[class_name] = method_overloads
        _overloaded_method_class_fileno[(qual_name, class_name)] = line_no
    elif _overloaded_method_class_fileno[(qual_name, class_name)] != line_no:
        # Same qualified name + class name but a different definition line
        # means two same-named classes in one module -- ambiguous.
        raise RuntimeError(
            "Cannot currently overload the same method name in two different"
            " classes with the same name in the same module"
        )

    method_overloads.append(func)
    return func
1103
1104
def _get_overloaded_methods(method, mod_class):
    """Look up the overloads registered for ``method`` on ``mod_class``, if any."""
    # TODO: __name__ not set for submodules in recursive script
    if not hasattr(method, "__name__"):
        return None
    class_name_map = _overloaded_methods.get(_qualified_name(method))
    if class_name_map is None:
        return None
    overloads = class_name_map.get(mod_class.__name__)
    if overloads is None:
        return None

    # Sanity check: the method must be defined inside the class's source span;
    # otherwise a same-named module/class was redeclared in this file.
    method_line_no = get_source_lines_and_file(method)[1]
    class_lines, class_start_line, _ = get_source_lines_and_file(mod_class)
    class_end_line = class_start_line + len(class_lines)
    if not (class_start_line <= method_line_no <= class_end_line):
        raise AssertionError(
            "Overloads are not useable when a module is redeclared within the same file: "
            + str(method)
        )
    return overloads
1126
1127
def is_tuple(ann) -> bool:
    """Whether ``ann`` is a parameterized tuple annotation (typing.Tuple[...] or tuple[...])."""
    if ann is Tuple:
        raise_error_container_parameter_missing("Tuple")

    # Some annotation objects (e.g. plain values) carry no __module__ at all.
    if not hasattr(ann, "__module__"):
        return False

    origin = get_origin(ann)
    # PEP 585 builtin generics such as tuple[int, str] (Python 3.9+ only).
    if sys.version_info >= (3, 9) and ann.__module__ == "builtins" and origin is tuple:
        return True
    # Classic typing.Tuple[...] spellings.
    return ann.__module__ == "typing" and (origin is Tuple or origin is tuple)
1140
1141
def is_list(ann) -> bool:
    """Whether ``ann`` is a parameterized list annotation (typing.List[...] or list[...])."""
    if ann is List:
        raise_error_container_parameter_missing("List")

    if not hasattr(ann, "__module__"):
        return False

    origin = get_origin(ann)
    # PEP 585 builtin generics such as list[int] (Python 3.9+ only).
    if sys.version_info >= (3, 9) and ann.__module__ == "builtins" and origin is list:
        return True
    return ann.__module__ == "typing" and (origin is List or origin is list)
1153
1154
def is_dict(ann) -> bool:
    """Whether ``ann`` is a parameterized dict annotation (typing.Dict[...] or dict[...])."""
    if ann is Dict:
        raise_error_container_parameter_missing("Dict")

    if not hasattr(ann, "__module__"):
        return False

    origin = get_origin(ann)
    # PEP 585 builtin generics such as dict[str, int] (Python 3.9+ only).
    if sys.version_info >= (3, 9) and ann.__module__ == "builtins" and origin is dict:
        return True
    return ann.__module__ == "typing" and (origin is Dict or origin is dict)
1166
1167
def is_union(ann):
    """Whether ``ann`` is a typing.Union[...] or a PEP 604 ``X | Y`` annotation."""
    if ann is Union:
        raise_error_container_parameter_missing("Union")

    # PEP 604 unions (``int | str``) are instances of their own runtime type.
    if isinstance(ann, BuiltinUnionType):
        return True
    return (
        hasattr(ann, "__module__")
        and ann.__module__ == "typing"
        and get_origin(ann) is Union
    )
1177
1178
def is_optional(ann):
    """Whether ``ann`` denotes an optional type: Optional[T] or a two-way union with None."""
    if ann is Optional:
        raise_error_container_parameter_missing("Optional")

    def _spelled_as_optional(candidate):
        return (
            hasattr(candidate, "__module__")
            and candidate.__module__ == "typing"
            and get_origin(candidate) is Optional
        )

    def _union_with_none(candidate):
        members = get_args(candidate)
        return len(members) == 2 and (None in members or type(None) in members)

    if _spelled_as_optional(ann):
        return True
    return is_union(ann) and _union_with_none(ann)
1195
1196
def is_future(ann) -> bool:
    """Whether ``ann`` is a parameterized ``Future[T]`` annotation."""
    if ann is Future:
        # A bare Future carries no element type, which TorchScript requires.
        raise RuntimeError(
            "Attempted to use Future without a "
            "contained type. Please add a contained type, e.g. "
            "Future[int]"
        )
    return get_origin(ann) is Future
1205
1206
def is_await(ann) -> bool:
    """Whether ``ann`` is the ``_Await`` type itself or a parameterized ``_Await[T]``."""
    # Unlike the other predicates, the bare type is accepted here.
    return ann is _Await or get_origin(ann) is _Await
1211
1212
if torch.distributed.rpc.is_available():
    from torch._C._distributed_rpc import PyRRef
    from torch.distributed.rpc import RRef

    def is_rref(ann) -> bool:
        # Reject a bare RRef annotation: TorchScript needs the contained type.
        if ann is RRef:
            raise RuntimeError(
                "Attempted to use RRef without a "
                "contained type. Please add a contained type, e.g. "
                "RRef[int]"
            )
        return get_origin(ann) is RRef

    def is_rref_instance(obj) -> bool:
        # PyRRef is the C++ type backing torch.distributed.rpc RRef objects.
        return isinstance(obj, PyRRef)

else:
    # Note: is_rref is intentionally undefined in this branch; only
    # is_rref_instance has a fallback.

    def is_rref_instance(obj) -> bool:
        # If the RPC module doesn't exist then RRefs don't exist either.
        return False
1234
1235
def _try_get_dispatched_fn(fn):
    """Return the boolean-dispatch record for ``fn``, or None if not callable/registered."""
    if callable(fn):
        return boolean_dispatched.get(fn)
    return None
1240
1241
def _get_named_tuple_properties(
    obj,
    loc: Optional[torch._C._jit_tree_views.SourceRange] = None,
    rcb=None,
):
    """Extract (name, fields, field types, defaults) from a NamedTuple class.

    ``obj`` must be a NamedTuple class (a tuple subclass with ``_fields``).
    ``loc`` is a source range used in error messages (a fake range when
    omitted). ``rcb`` is an optional resolution callback used to resolve
    string/ForwardRef annotations (see the long note in the loop below).
    """
    if loc is None:
        loc = fake_range()

    assert issubclass(obj, tuple) and hasattr(obj, "_fields")
    if hasattr(obj, "_field_defaults"):
        # Collect defaults in field order; fields without one are skipped.
        defaults = [
            obj._field_defaults[field]
            for field in obj._fields
            if field in obj._field_defaults
        ]
    else:
        defaults = []
    # In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function
    # Also, annotations from base class are not inherited so they need to be queried explicitly
    if sys.version_info[:2] < (3, 10):
        obj_annotations = getattr(obj, "__annotations__", {})
    else:
        obj_annotations = inspect.get_annotations(obj)
        if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
            obj_annotations = inspect.get_annotations(obj.__base__)

    annotations = []
    for field in obj._fields:
        if field in obj_annotations:
            field_type = obj_annotations[field]
            # [Note: ForwardRef annotations in NamedTuple attributes]
            # NamedTuple types are slightly different from normal types.
            #
            # Normally, annotations are evaluted like this (during jit.script):
            # 1. Load strings of python code into c++ and parse.
            # 2. Get annotations as strings
            # 3. Use the PythonResolver's resolution callback (rcb) to convert
            #    the string into a python object
            # 4. We call into annotations.py:ann_to_type to convert python obj
            #    from step 3 into a type that torchscript understands.
            #
            # NamedTuples are more complicated, because it has sub-types.
            # Normally, once we have the NamedTuple type object from #3,
            # we can just look at the annotation literal values and use
            # ann_to_type directly on them.
            #
            # But sometimes, users will annotate with string literals, e.g.
            #    x: 'int'
            # This also happens with PEP563 (from __forward__ import annotations)
            #
            # These annotations appear in the annotation dict as ForwardRef('int').
            #
            # Then, we need to convert the string into a python object. This
            # requires having local context for custom objects or imported types.
            # rcb() is what gives us this. So, we plumb rcb through the stack so
            # it can be used in this context for the if block below.
            #
            # FAQ:
            # - Why do we need this special handling for NamedTuple but string
            #   annotations work fine for normal types? Normally, we parse the
            #   string directly and then call rcb() directly from C++.
            # - Why not use ForwardRef._evaluate? For that, we need globals()
            #   and locals() for the local context where the NamedTuple was defined.
            #   rcb is what lets us look up into these. So, basically rcb does the
            #   hard work for us.
            if isinstance(field_type, ForwardRef) and rcb is not None:
                rcb_type = rcb(field_type.__forward_arg__)
                # rcb returns None if it can't find anything.
                if rcb_type is None:
                    raise ValueError(
                        f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}."
                        f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858."
                        f" Issue occurred at {loc.highlight()}"
                    )
                field_type = rcb_type
            the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb)
            annotations.append(the_type)
        else:
            # Fields without an annotation fall back to an inferred Tensor type.
            annotations.append(torch._C.TensorType.getInferred())
    # NOTE(review): ``type(obj).__name__`` is the name of obj's *metaclass*
    # (obj is itself a class here, per the assert above), not ``obj.__name__``.
    # Confirm callers rely on exactly this before changing it.
    return type(obj).__name__, obj._fields, annotations, defaults
1322
1323
1324def _create_named_tuple(
1325    t,
1326    unqual_name: str,
1327    field_names: List[str],
1328    defaults: Tuple[Any, ...],
1329):
1330    TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults)  # type: ignore[call-arg, no-redef, misc]
1331    return TupleType(*t)
1332
1333
1334@contextlib.contextmanager
1335def _disable_emit_hooks():
1336    hooks = torch._C._jit_get_emit_hooks()
1337    torch._C._jit_set_emit_hooks(None, None)
1338    try:
1339        yield
1340    finally:
1341        torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
1342
1343
def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None:  # noqa: F811
    # NOTE(review): the nested functions below are defined but never attached
    # to anything, and this function returns None, so calling it has no
    # observable effect. Presumably it was intended to install
    # __enter__/__exit__ onto a _DecoratorContextManager class -- confirm
    # before relying on or removing it.
    def __enter__(self) -> None:
        # Save current hooks on the instance and clear them.
        self.hooks = torch._C._jit_get_emit_hooks()
        torch._C._jit_set_emit_hooks(None, None)

    def __exit__(self, *args) -> None:
        # Restore the hooks saved by __enter__.
        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
1351
1352
1353def _is_exception(obj) -> bool:
1354    if not inspect.isclass(obj):
1355        return False
1356    return issubclass(obj, Exception)
1357
1358
def raise_error_container_parameter_missing(target_type) -> None:
    """Raise a RuntimeError saying the container annotation lacks its type args."""
    if target_type == "Dict":
        # Dict needs two contained types, so it gets a dedicated message.
        message = (
            "Attempted to use Dict without "
            "contained types. Please add contained type, e.g. "
            "Dict[int, int]"
        )
    else:
        message = (
            f"Attempted to use {target_type} without a "
            "contained type. Please add a contained type, e.g. "
            f"{target_type}[int]"
        )
    raise RuntimeError(message)
1371
1372
def check_args_exist(target_type) -> None:
    """Raise a descriptive error if a bare List/Tuple/Dict/Optional annotation is used."""
    # Each branch raises, so plain `if` statements are equivalent to the chain.
    if target_type is List or target_type is list:
        raise_error_container_parameter_missing("List")
    if target_type is Tuple or target_type is tuple:
        raise_error_container_parameter_missing("Tuple")
    if target_type is Dict or target_type is dict:
        raise_error_container_parameter_missing("Dict")
    if target_type is None or target_type is Optional:
        raise_error_container_parameter_missing("Optional")
1382
1383
def check_empty_containers(obj) -> None:
    """Warn when an empty container reaches torch.jit.isinstance in eager mode."""
    is_empty = obj == [] or obj == {} or obj == ()
    if not is_empty:
        return
    warnings.warn(
        "The inner type of a container is lost when "
        "calling torch.jit.isinstance in eager mode. For "
        "example, List[int] would become list and "
        "therefore falsely return True for List[float] or"
        " List[str]."
    )
1393
1394
# supports List/Dict/Tuple and Optional types
# TODO support future
def container_checker(obj, target_type) -> bool:
    """Structurally check ``obj`` against a parameterized container annotation.

    Supports List/Dict/Tuple and Optional/Union targets, recursing into
    nested parameterized types. Returns False for non-container annotations.
    """
    origin_type = get_origin(target_type)
    check_args_exist(target_type)
    if origin_type is None:
        return False
    elif origin_type is list or origin_type is List:
        check_empty_containers(obj)
        if not isinstance(obj, list):
            return False
        arg_type = get_args(target_type)[0]
        arg_origin = get_origin(arg_type)
        for el in obj:
            # check if nested container, ex: List[List[str]]
            if arg_origin:  # processes nested container, ex: List[List[str]]
                if not container_checker(el, arg_type):
                    return False
            elif not isinstance(el, arg_type):
                return False
        return True
    elif origin_type is Dict or origin_type is dict:
        check_empty_containers(obj)
        if not isinstance(obj, dict):
            return False
        key_type = get_args(target_type)[0]
        val_type = get_args(target_type)[1]
        for key, val in obj.items():
            # check if keys are of right type
            if not isinstance(key, key_type):
                return False
            val_origin = get_origin(val_type)
            if val_origin:
                if not container_checker(val, val_type):
                    return False
            elif not isinstance(val, val_type):
                return False
        return True
    elif origin_type is Tuple or origin_type is tuple:
        check_empty_containers(obj)
        if not isinstance(obj, tuple):
            return False
        arg_types = get_args(target_type)
        if len(obj) != len(arg_types):
            return False
        for el, el_type in zip(obj, arg_types):
            el_origin = get_origin(el_type)
            if el_origin:
                if not container_checker(el, el_type):
                    return False
            elif not isinstance(el, el_type):
                return False
        return True
    elif origin_type is Union or issubclass(
        origin_type, BuiltinUnionType
    ):  # also handles Optional
        if obj is None:  # check before recursion because None is always fine
            return True
        inner_types = get_args(target_type)
        for t in inner_types:
            t_origin = get_origin(t)
            if t_origin:
                # BUG FIX: previously this `return`ed the result of the first
                # parameterized member, so e.g. Union[List[int], List[str]]
                # never considered List[str]. Keep trying remaining members.
                if container_checker(obj, t):
                    return True
            elif isinstance(obj, t):
                return True
    return False
1461
1462
def _isinstance(obj, target_type) -> bool:
    """Eager-mode implementation of ``torch.jit.isinstance`` with typing-container support."""
    if isinstance(target_type, collections.abc.Container):
        # Mirror builtin isinstance: a tuple of candidate types is allowed,
        # any other container is a usage error.
        if not isinstance(target_type, tuple):
            raise RuntimeError(
                "The second argument to "
                "`torch.jit.isinstance` must be a type "
                "or a tuple of types"
            )
        return any(_isinstance(obj, candidate) for candidate in target_type)

    if get_origin(target_type):
        # Parameterized annotation (List[int], Optional[str], ...).
        return container_checker(obj, target_type)

    # Check to handle non-typed optional origin returns as none instead
    #    of as optional in 3.7-3.8
    check_args_exist(target_type)

    # handle non-containers
    return isinstance(obj, target_type)
1486
1487
class _TensorExtractor(pickle.Pickler):
    """Pickler that records every tensor it encounters instead of serializing it."""

    def __init__(self, *args, tensors: List[torch.Tensor], **kwargs):
        super().__init__(*args, **kwargs)
        # Shared output list: tensors found during dump() are appended here.
        self.tensors = tensors

    def persistent_id(self, obj):
        if isinstance(obj, torch.Tensor):
            self.tensors.append(obj)
            return ""
        # Since we just want to extract tensors, we don't mind if an object is
        # unpicklable if it doesn't contain tensors, as we can just ignore/skip
        # it. To play it safe, we only do so for common objects that we're sure
        # don't contain tensors. Feel free to add new types here. Note also that
        # even if a type isn't listed here this won't block users, since they
        # can just add a __getstate__ or __reduce__ method to their class.
        if isinstance(obj, (LockType, CAwait, torch.cuda.Event, threading.Thread)):
            return ""
        # Futures and RRefs don't technically contain a value, they just offer
        # the means to access a value.
        if isinstance(obj, CFuture) or is_rref_instance(obj):
            return ""
        # Returning None tells pickle to serialize obj normally, which will
        # recurse into its contents (where tensors may hide).
        return None
1516
1517
def _extract_tensors(obj):
    r"""
    This function is exclusively called from C++.
    See ``torch/csrc/jit/python/python_ivalue.h``.

    It extracts the tensors contained in the given object, through pickling.
    """
    found: List[torch.Tensor] = []
    # The pickled bytes are discarded; we only want the side effect of
    # persistent_id collecting tensors into `found`.
    _TensorExtractor(io.BytesIO(), protocol=-1, tensors=found).dump(obj)
    return found
1529
1530
1531def _get_model_id(obj) -> Optional[str]:
1532    if isinstance(obj, torch.jit.ScriptModule):
1533        return str(obj._c._type())
1534    elif isinstance(obj, torch.jit.ScriptFunction):
1535        return obj.qualified_name
1536    else:
1537        return None
1538
1539
# In Python 3.11+, typed enums (e.g. IntEnum) retain a number of base-class
# methods in subclasses that were previously dropped. To preserve the old
# behavior, explicitly mark those inherited dunders as dropped here.

if sys.version_info > (3, 10):
    # Tag the Enum dunders with the _DROP modifier so TorchScript skips them
    # on enum subclasses, matching pre-3.11 behavior.
    _drop(enum.Enum.__new__)
    _drop(enum.Enum.__format__)
    _drop(enum.Enum.__repr__)
    _drop(enum.Enum.__str__)
1548