1"""TorchScript.
2
3This module contains functionality to support the JIT's scripting frontend, notably:
4    - torch.jit.script
5
6This is not intended to be imported directly; please use the exposed
7functionalities in `torch.jit`.
8"""
9import collections
10import copy
11import enum
12import functools
13import inspect
14import pickle
15import warnings
16from typing import Any, Callable, Dict, List, Set, Tuple, Union
17
18import torch
19import torch._jit_internal as _jit_internal
20from torch._classes import classes
21from torch._jit_internal import _get_model_id, _qualified_name
22from torch._utils_internal import log_torchscript_usage
23from torch.jit._builtins import _register_builtin
24from torch.jit._fuser import _graph_for, _script_method_graph_for
25from torch.jit._monkeytype_config import (
26    JitTypeTraceConfig,
27    JitTypeTraceStore,
28    monkeytype_trace,
29)
30from torch.jit._recursive import (
31    _compile_and_register_class,
32    infer_methods_to_compile,
33    ScriptMethodStub,
34    wrap_cpp_module,
35)
36from torch.jit._state import (
37    _enabled,
38    _set_jit_function_cache,
39    _set_jit_overload_cache,
40    _try_get_jit_cached_function,
41    _try_get_jit_cached_overloads,
42)
43from torch.jit.frontend import get_default_args, get_jit_class_def, get_jit_def
44from torch.nn import Module
45from torch.overrides import (
46    has_torch_function,
47    has_torch_function_unary,
48    has_torch_function_variadic,
49)
50from torch.package import PackageExporter, PackageImporter
51from torch.utils import set_module
52
53from ._serialization import validate_map_location
54
55
56type_trace_db = JitTypeTraceStore()  # DB to hold all call traces from MonkeyType
57
58torch._C.ScriptMethod.graph_for = _script_method_graph_for  # type: ignore[attr-defined]
59torch._C.ScriptFunction.graph_for = _graph_for  # type: ignore[attr-defined]
60ScriptFunction = torch._C.ScriptFunction
61ScriptFunction.__doc__ = """
62Functionally equivalent to a :class:`ScriptModule`, but represents a single
63function and does not have any attributes or Parameters.
64"""
65set_module(ScriptFunction, "torch.jit")
66
67
# Throws an error if a jit function is pickled.
# Helps to avoid Python crashes on Python 3.9.5+ when pickle protocol 0 or 1 is used.
def _reduce(cls):
    raise pickle.PickleError("ScriptFunction cannot be pickled")


ScriptFunction.__reduce__ = _reduce  # type: ignore[assignment]
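# For illustration (a hedged sketch, not executed at import): with this guard
# in place, pickling a scripted function fails loudly instead of crashing the
# interpreter:
#
#     fn = torch.jit.script(some_fn)  # `some_fn` is a hypothetical function
#     pickle.dumps(fn)                # raises pickle.PickleError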


if _enabled:
    Attribute = collections.namedtuple("Attribute", ["value", "type"])
else:

    def Attribute(value, type):  # type: ignore[no-redef]
        return value


Attribute.__doc__ = """
    This method is a pass-through function that returns `value`, mostly
    used to indicate to the TorchScript compiler that the left-hand side
    expression is a class instance attribute with type of `type`. Note that
    `torch.jit.Attribute` should only be used in the `__init__` method of `jit.ScriptModule`
    subclasses.

    Though TorchScript can infer correct types for most Python expressions, there are some cases where
    type inference can be wrong, including:

    - Empty containers like `[]` and `{}`, which TorchScript assumes to be containers of `Tensor`
    - Optional types like `Optional[T]` that are assigned a valid value of type `T`; TorchScript
      would assume the type is `T` rather than `Optional[T]`

    In eager mode, it is simply a pass-through function that returns `value`
    without other implications.

    Example:

    .. testcode::

        import torch
        from typing import Dict

        class AttributeModule(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.jit.Attribute(0.1, float)

                # we should be able to use self.foo as a float here
                assert 0.0 < self.foo

                self.names_ages = torch.jit.Attribute({}, Dict[str, int])
                self.names_ages["someone"] = 20
                assert isinstance(self.names_ages["someone"], int)

        m = AttributeModule()
        # m will contain two attributes
        # 1. foo of type float
        # 2. names_ages of type Dict[str, int]

    .. testcleanup::

        del AttributeModule
        del m

    Note: it is now preferred to use type annotations instead of `torch.jit.Attribute`:

    .. testcode::

        import torch
        from typing import Dict

        class AttributeModule(torch.nn.Module):
            names: Dict[str, int]

            def __init__(self) -> None:
                super().__init__()
                self.names = {}

        m = AttributeModule()

    .. testcleanup::

        del AttributeModule
        del m

    Args:
        value: An initial value to be assigned to attribute.
        type: A Python type

    Returns:
        Returns `value`
"""


def _get_type_trace_db():
    # This is a private API. Use of this for external purposes is discouraged.
    return type_trace_db


# Gets a function from the name of a method on a type
def _get_function_from_type(cls, name):
    return getattr(cls, name, None)


# ScriptClasses must be new-style classes because we construct them using their
# __new__ method.
def _is_new_style_class(cls):
    if hasattr(cls, "__class__"):
        return "__dict__" in dir(cls) or hasattr(cls, "__slots__")


# These OrderedDictWrapper classes replace the actual OrderedDicts in a
# module with versions that get/set properties inside of Module.
# This allows us to reuse most of nn.Module while still storing the
# data in C++.
# Each OrderedDict needs to support:
#  x not in view
#  x in view
#  view[name] = ...
#  view.values()
#  del view[name]
#  view.items()
#  view.keys()
#  len(view)
# A minimal sketch of this protocol appears right after this comment.
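#
# The sketch below is illustrative only (kept in comments so nothing runs at
# import time); ``_FakeCppDict`` is a hypothetical stand-in for the C++
# container, not a real PyTorch type:
#
#     class _FakeCppDict:
#         def __init__(self):
#             self._d = {}
#         def items(self):
#             return list(self._d.items())
#         def contains(self, k):
#             return k in self._d
#         def getattr(self, k):
#             return self._d[k]
#         def setattr(self, k, v):
#             self._d[k] = v
#
#     view = OrderedDictWrapper(_FakeCppDict())   # class defined just below
#     "w" in view             # membership is forwarded to _FakeCppDict.contains
#     view.keys(), len(view)  # both derived from items()
#     view["w"] = 1           # raises: new keys can't be added after construction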


class OrderedDictWrapper:
    def __init__(self, _c):
        self._c = _c

    def keys(self):
        return [k for k, v in self.items()]

    def values(self):
        return [v for k, v in self.items()]

    def __len__(self):
        return len(self.values())

    def __delitem__(self, k):
        raise RuntimeError("cannot delete methods or parameters of a script module")

    def items(self):
        return self._c.items()

    def __setitem__(self, k, v):
        if k not in self:
            raise RuntimeError(
                f"Can't add a new parameter after ScriptModule construction. Tried to add '{k}'."
            )
        self._c.setattr(k, v)

    def __contains__(self, k):
        return self._c.contains(k)

    def __getitem__(self, k):
        if k not in self:
            raise KeyError(k)
        return self._c.getattr(k)


class OrderedModuleDict(OrderedDictWrapper):
    def __init__(self, module, python_dict):
        super().__init__(torch._C.ModuleDict(module))
        # contains _both_ script modules and non-script python-only modules

        # because script modules are subclassed in python and the
        # C++ Module class will not hold references to them,
        # to ensure that you always get the same python value here
        # we store it in the python dict as well
        self._python_modules = python_dict

    def items(self):
        r = self._python_modules.items()
        return r

    def __contains__(self, k):
        return k in self._python_modules

    def __setitem__(self, k, v):
        # Cases where a submodule can be re-assigned after ScriptModule construction:
        # 1. If the attr is a module interface type, it's guaranteed that the module is
        #    not inlined in the graph, so it's safe to swap a new ScriptModule in.
        # 2. If the new value is a ScriptModule with the same JIT type, the IR won't change
        #    and it's legit to swap a new module in.
        # In these two cases we allow swapping in a new scripted module and update the
        # corresponding Python module dict to keep them in sync.
        # Note: the value to be swapped in has to be a ScriptModule instead of an nn.Module,
        # otherwise it's illegal and we throw an error.
        if isinstance(v, ScriptModule):
            self._c.setattr(k, v)
            self._python_modules[k] = v
        else:
            raise RuntimeError(
                "Cannot re-assign modules in a ScriptModule with non-scripted "
                f"module, tried to replace existing module '{k}': {v}"
            )

    def __getitem__(self, k):
        return self._python_modules[k]


# For each user-defined class that subclasses ScriptModule, this meta-class:
# (1) finds all the methods annotated with @script_method in a ScriptModule and
#     removes them from the class attributes
# (2) puts a wrapper around the class's __init__ method to recursively compile
#     all of the script_methods with the module after the original __init__ has
#     run. This has to occur after the user-defined __init__ so that submodules and
#     parameters are initialized _before_ the script compiler resolves references to
#     `self.param` or `self.module`.
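#
# For illustration (a hedged sketch in comments, not executed here), a user
# class handled by this metaclass typically looks like:
#
#     class MyScriptModule(torch.jit.ScriptModule):
#         def __init__(self):
#             super().__init__()
#             # submodules/parameters must exist before compilation runs
#             self.linear = torch.nn.Linear(2, 2)
#
#         @torch.jit.script_method
#         def forward(self, x):
#             return self.linear(x)   # compiled after __init__ returns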
class ScriptMeta(type):
    def __init__(cls, name, bases, attrs):  # noqa: B902
        # Aggregate all the ScriptMethods and constants from superclasses
        cls._methods: Dict[str, Any] = {}
        cls._constants_set = set(getattr(cls, "__constants__", ()))
        for base in reversed(bases):
            for k, v in getattr(base, "_methods", {}).items():
                cls._methods[k] = v
            base_constants: Set = getattr(base, "_constants_set", set())
            cls._constants_set = cls._constants_set.union(base_constants)

        # find all the script methods of the current class
        for k, v in sorted(attrs.items()):
            if isinstance(v, ScriptMethodStub):
                delattr(cls, k)
                cls._methods[v.original_method.__name__] = v

        if getattr(cls, "_disable_script_meta", False):
            # We leave built-in ScriptModule types alone, since this metaclass
            # is only for compiling user classes that inherit from
            # ScriptModule.
            super().__init__(name, bases, attrs)
            return

        original_init = getattr(cls, "__init__", lambda self: None)

        @functools.wraps(original_init)
        def init_then_script(self, *args, **kwargs):
            num_methods = len(cls._methods)
            original_init(self, *args, **kwargs)
            added_methods_in_init = len(cls._methods) > num_methods

            if type(self) == cls:

                def make_stubs(module):
                    cls = type(module)
                    if hasattr(cls, "_methods"):
                        return [v for k, v in sorted(cls._methods.items())]
                    else:
                        return infer_methods_to_compile(module)

                self.__dict__[
                    "_actual_script_module"
                ] = torch.jit._recursive.create_script_module(
                    self, make_stubs, share_types=not added_methods_in_init
                )

                # Delete the Python attributes that now shadow the ScriptModule
                # ones, so that __getattr__ and __setattr__ will properly find
                # the scripted versions.
                concrete_type = self._actual_script_module._concrete_type
                for name in concrete_type.get_attributes():
                    delattr(self, name)
                for name, _ in concrete_type.get_modules():
                    delattr(self, name)
                for name in ("_parameters", "_buffers", "_modules"):
                    delattr(self, name)

        cls.__init__ = init_then_script  # type: ignore[misc]
        super().__init__(name, bases, attrs)


class _CachedForward:
    def __get__(self, obj, cls):
        return self.__getattr__("forward")  # type: ignore[attr-defined]


class ScriptWarning(Warning):
    pass


def script_method(fn):
    if not _enabled:
        return fn
    # NOTE: we need to traverse two frames here because the meta-class frame
    # for ScriptModule will be present, as opposed to invoking @script on a
    # function or invoking define() on a CompilationUnit.
    # The stack will look like:
    #
    # 0. createResolutionCallback()
    # 1. script_method()
    # 2. ScriptModule metaclass frame
    # 3. Surrounding scope
    #
    # createResolutionCallback internally adds 1 to get us to the scope of this
    # function (the calling function). Adding 2 gets us to the proper surrounding scope.
    _rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2)
    ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule")
    return ScriptMethodStub(_rcb, ast, fn)


class ConstMap:
    def __init__(self, const_mapping):
        self.const_mapping = const_mapping

    def __getattr__(self, attr):
        return self.const_mapping[attr]


def unpackage_script_module(
    importer: PackageImporter, script_module_id: str
) -> torch.nn.Module:
    """
    Called by ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function.

    Performs the work of loading and returning a ScriptModule from a ``torch.package`` archive.
    """
    if not isinstance(importer.zip_reader, torch._C.PyTorchFileReader):
        raise RuntimeError(
            "Loading ScriptObjects from a PackageImporter created from a "
            "directory is not supported. Use a package archive file instead."
        )
    cu = torch._C.CompilationUnit()
    cpp_module = torch._C._import_ir_module_from_package(
        cu,
        importer.zip_reader,
        importer.storage_context,
        validate_map_location(importer.last_map_location),
        script_module_id,
    )
    return wrap_cpp_module(cpp_module)

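# Illustrative round trip (a hedged sketch; the path and module are hypothetical,
# and real packages may also need intern/extern directives on the exporter):
#
#     with torch.package.PackageExporter("model_pkg.pt") as pe:
#         pe.save_pickle("model", "model.pkl", torch.jit.script(MyModule()))
#     pi = torch.package.PackageImporter("model_pkg.pt")
#     loaded = pi.load_pickle("model", "model.pkl")  # hits unpackage_script_module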

if _enabled:
    _magic_methods = [
        "__iter__",
        "__len__",
        "__neg__",
        "__mul__",
        "__contains__",
        "__add__",
        "__sub__",
        "__pow__",
        "__truediv__",
        "__mod__",
        "__ne__",
        "__eq__",
        "__lt__",
        "__gt__",
        "__le__",
        "__ge__",
        "__and__",
        "__or__",
        "__xor__",
        "__getitem__",
        "__setitem__",
        "__call__",
        "__int__",
        "__float__",
        "__bool__",
        "__str__",
        "__enter__",
        "__exit__",
    ]

    class RecursiveScriptClass:
        """Wrapper for a TorchScript class instance for use in Python.

        An analogue of RecursiveScriptModule for regular objects that are not modules.
        This class is a wrapper around a torch._C.ScriptObject that represents an instance
        of a TorchScript class and allows it to be used in Python.

        Attributes:
            _c [torch._C.ScriptObject]: The C++ object to which attribute lookups and method
                calls are forwarded.
            _props [Dict[str, property]]: A dictionary of properties fetched from self._c and
                exposed on this wrapper.
        """

        def __init__(self, cpp_class):
            super().__init__()
            self.__dict__["_initializing"] = True
            self._c = cpp_class

            # Add wrapped object's properties to this class instance.
            self._props = {
                prop.name: property(prop.getter, prop.setter)
                for prop in self._c._properties()
            }

            self.__dict__["_initializing"] = False

        def __getattr__(self, attr):
            if self.__dict__.get("_initializing"):
                return super().__getattr__(attr)  # type: ignore[misc]

            if attr in self._props:
                return self._props[attr].fget()  # type: ignore[call-arg, misc]

            return getattr(self._c, attr)

        def __setattr__(self, attr, value):
            if self.__dict__.get("_initializing"):
                return super().__setattr__(attr, value)

            if attr in self._props:
                return self._props[attr].fset(value)  # type: ignore[call-arg, misc]

            setattr(self._c, attr, value)

        # Delegate calls to magic methods like __len__ to the C++ module backing the
        # RecursiveScriptClass.
        def forward_magic_method(self, method_name, *args, **kwargs):
            if not self._c._has_method(method_name):
                raise TypeError

            self_method = self.__getattr__(method_name)
            return self_method(*args, **kwargs)

        def __getstate__(self):
            raise pickle.PickleError("ScriptClasses cannot be pickled")

        def __iadd__(self, other):
            if self._c._has_method("__iadd__"):
                return self.forward_magic_method("__iadd__", other)
            else:
                return self.forward_magic_method("__add__", other)

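    # For illustration (a hedged sketch in comments; ``Pair`` is hypothetical):
    # scripting a plain class *instance* yields a RecursiveScriptClass whose
    # attribute and method accesses are forwarded to the compiled object:
    #
    #     class Pair:
    #         def __init__(self, x: int, y: int):
    #             self.x = x
    #             self.y = y
    #
    #         def total(self) -> int:
    #             return self.x + self.y
    #
    #     rsc = torch.jit.script(Pair(1, 2))  # RecursiveScriptClass wrapper
    #     rsc.total()                         # runs the TorchScript method -> 3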
    for method_name in _magic_methods:

        def method_template(self, *args, method_name=method_name, **kwargs):
            # Bind method_name at definition time via the keyword default;
            # a plain closure would late-bind and make every generated method
            # forward to the last entry in _magic_methods.
            return self.forward_magic_method(method_name, *args, **kwargs)

        setattr(RecursiveScriptClass, method_name, method_template)

    # This is a Python 'non-data descriptor' that causes the first access
    # to ScriptModule's forward to look up the forward method and stash
    # it in the object's dict. Due to the standard rules for attribute lookup,
    # subsequent lookups will just directly return the previously looked-up method.
    # This is necessary because nn.Module defines forward as a method. If we
    # did nothing, __getattr__ would not be called. Instead we would get
    # nn.Module.forward, which always throws an exception.

    class ScriptModule(Module, metaclass=ScriptMeta):
        r"""Wrapper for C++ torch::jit::Module with methods, attributes, and parameters.

        A wrapper around C++ ``torch::jit::Module``. ``ScriptModule``\s
        contain methods, attributes, parameters, and
        constants. These can be accessed the same way as on a normal ``nn.Module``.
        """

        __jit_unused_properties__ = [
            "code",
            "code_with_constants",
            "graph",
            "inlined_graph",
            "original_name",
        ]

        def __init__(self) -> None:
            super().__init__()

        forward: Callable[..., Any] = _CachedForward()  # type: ignore[assignment]

        def __getattr__(self, attr):
            if "_actual_script_module" not in self.__dict__:
                return super().__getattr__(attr)
            return getattr(self._actual_script_module, attr)

        def __setattr__(self, attr, value):
            if "_actual_script_module" not in self.__dict__:
                # Unwrap torch.jit.Attribute into a regular setattr + record
                # the provided type in __annotations__.
                #
                # This ensures that if we use the attr again in `__init__`, it
                # will look like the actual value, not an instance of Attribute.
                if isinstance(value, Attribute):
                    # NB: Ensure that we set __annotations__ on the specific
                    # class in question, and not on a superclass (which would
                    # be wrong wrong wrong!).
                    # See also https://github.com/pytorch/pytorch/issues/39463
                    if "__annotations__" not in self.__class__.__dict__:
                        self.__class__.__annotations__ = {}
                    self.__annotations__[attr] = value.type
                    value = value.value
                return super().__setattr__(attr, value)

            setattr(self._actual_script_module, attr, value)

        def define(self, src):
            if "_actual_script_module" in self.__dict__:
                # If we have completed initialization, just defer to the
                # backing RecursiveScriptModule to eagerly compile the provided
                # source.
                return self._actual_script_module.define(src)

            # Otherwise, we are still in the object's __init__.
            # In that case, add `src` as a stub to be compiled.
            #
            # We use frames_up=1 to get to the proper surrounding scope. The stack
            # will look like:
            # 0. createResolutionCallback
            # 1. define()
            # 2. surrounding scope.
            #
            # createResolutionCallback internally adds 1 to get us to our frame, then
            # we add 1 to get to the proper surrounding scope.
            rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1)
            ast = torch._C._parse_source_def(src)
            self._methods[ast.name().name] = ScriptMethodStub(rcb, ast, None)

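        # Illustrative use of ``define`` (a hedged sketch, not executed here):
        # inside a subclass ``__init__`` one can add an extra TorchScript
        # method from a source string, e.g.
        #
        #     self.define("def double(self, x):\n    return 2 * x")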
        def _replicate_for_data_parallel(self):
            return self._actual_script_module._replicate_for_data_parallel()

        def __reduce_package__(self, exporter: PackageExporter):
582            """Save a ScriptModule inside of a ``torch.package`` archive.
583
584            Called by ``torch.package.PackageExporter``'s Pickler's ``persistent_id`` when
585            saving TorchScript objects. Performs act of saving a ScriptModule inside of
586            a ``torch.package`` archive.
587
588            Returns method to load the ScriptModule from a ``torch.package.PackageImporter``'s
589            Pickler's ``persistent_load`` function.
590            """
            script_module_id = exporter.get_unique_id()
            exporter.script_module_serializer.serialize(self._c, int(script_module_id))
            return (unpackage_script_module, (script_module_id,))

    class RecursiveScriptModule(ScriptModule):
        # XXX: RecursiveScriptModule inherits from ScriptModule for the sole
        # reason that it retains the existing isinstance(ScriptModule)
        # behavior.
        r"""Retain the existing isinstance(ScriptModule) behavior.

        The core data structure in TorchScript is the ``ScriptModule``. It is an
        analogue of torch's ``nn.Module`` and represents an entire model as a tree of
        submodules. Like normal modules, each individual module in a ``ScriptModule`` can
        have submodules, parameters, and methods. In ``nn.Module``\s methods are implemented
        as Python functions, but in ``ScriptModule``\s methods are implemented as
        TorchScript functions, a statically-typed subset of Python that contains all
        of PyTorch's built-in Tensor operations. This difference allows your
        ``ScriptModule``\s code to run without the need for a Python interpreter.

        ``ScriptModule``\s should not be created manually, instead use
        either :func:`tracing <torch.jit.trace>` or :func:`scripting <torch.jit.script>`.
        Tracing and scripting can be applied incrementally and :ref:`composed as necessary <Types>`.

        * Tracing records the tensor operations as executed with a set of example inputs and uses these
          operations to construct a computation graph. You can use the full dynamic behavior of Python with tracing,
          but values other than Tensors and control flow aren't captured in the graph.

        * Scripting inspects the Python code of the model
          and compiles it to TorchScript. Scripting allows the use of many `types`_ of values and supports dynamic control flow.
          Many, but not all features of Python are supported by the compiler, so changes to the source code may be necessary.
        """

        _disable_script_meta = True

        def __init__(self, cpp_module):
            self.__dict__["_initializing"] = True
            self._c = cpp_module
            super().__init__()
            # Delete the 'training' attribute set up by `Module.__init__`. It
            # will get set on the underlying cpp module, so we delete it here
            # to avoid this version shadowing the cpp module version.
            delattr(self, "training")

        @staticmethod
        def _construct(cpp_module, init_fn):
            """
            Construct a RecursiveScriptModule that's ready for use.

            PyTorch code should use this to construct a RecursiveScriptModule instead
            of calling `__init__` directly, as it makes sure the
            object is properly finalized (and in the future, we may take
            control of how the RecursiveScriptModule instance is created).

            Args:
                cpp_module:  The C++ Module that will hold the actual state of
                             this RecursiveScriptModule instance.
                init_fn:  Lambda that initializes the RecursiveScriptModule passed to it.
            """
            script_module = RecursiveScriptModule(cpp_module)
            init_fn(script_module)

            # Finalize the ScriptModule: replace the nn.Module state with our
            # custom implementations and flip the _initializing bit.
            RecursiveScriptModule._finalize_scriptmodule(script_module)
            return script_module

        @staticmethod
        def _finalize_scriptmodule(script_module):
            script_module._parameters = OrderedDictWrapper(
                torch._C.ParameterDict(script_module._c)
            )
            script_module._buffers = OrderedDictWrapper(
                torch._C.BufferDict(script_module._c)
            )
            script_module._modules = OrderedModuleDict(
                script_module._c, script_module._modules
            )
            script_module._initializing = False

        def _reconstruct(self, cpp_module):
            """
            Re-construct an instance of RecursiveScriptModule using an instance of a C++ module.

            Args:
                cpp_module: The C++ module that this RecursiveScriptModule will be rebuilt around.
            """
            self.__init__(cpp_module)  # type: ignore[misc]

            # Copy the concrete type from the C++ module to this ScriptModule.
            self._concrete_type = torch._C.ConcreteModuleType.from_jit_type(
                self._c._type()
            )

            # Copy submodules from the C++ module to this ScriptModule.
            modules = {}
            for name, cpp_module in torch._C.ModuleDict(self._c).items():
                modules[name] = wrap_cpp_module(cpp_module)
            self._modules = OrderedModuleDict(self._c, modules)  # type: ignore[assignment]

            # Copy parameters and buffers.
            self._parameters = OrderedDictWrapper(torch._C.ParameterDict(self._c))  # type: ignore[assignment]
            self._buffers = OrderedDictWrapper(torch._C.BufferDict(self._c))  # type: ignore[assignment]

            # Get rid of the functions from the old C++ module.
            self.__dict__ = {
                k: v
                for k, v in self.__dict__.items()
                if not isinstance(v, torch._C.ScriptMethod)
            }
            self.__dict__["_initializing"] = False

        @property
        def graph(self):
            r"""Return a string representation of the internal graph for the ``forward`` method.

            See :ref:`interpreting-graphs` for details.
            """
            return self._c._get_method("forward").graph

        @property
        def inlined_graph(self):
            r"""
            Return a string representation of the internal graph for the ``forward`` method.

            This graph will be preprocessed to inline all function and method calls.
            See :ref:`interpreting-graphs` for details.
            """
            return self.forward.inlined_graph  # type: ignore[attr-defined]

        @property
        def code(self):
            r"""
            Return a pretty-printed representation (as valid Python syntax) of the internal graph for the ``forward`` method.

            See :ref:`inspecting-code` for details.
            """
            return self.forward.code  # type: ignore[attr-defined]

        @property
        def code_with_constants(self):
            r"""Return a tuple.

            Returns a tuple of:

            [0] a pretty-printed representation (as valid Python syntax) of
            the internal graph for the ``forward`` method. See `code`.
            [1] a ConstMap following the CONSTANT.cN format of the output in [0].
            The indices in the [0] output are keys to the underlying constant's values.

            See :ref:`inspecting-code` for details.
            """
            r = self.forward.code_with_constants  # type: ignore[attr-defined]
            return (r[0], ConstMap(r[1]))

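        # For example (an illustrative note; the exact constant naming in the
        # printed code may vary): if ``m.code`` shows ``... CONSTANTS.c0 ...``,
        # then ``src, consts = m.code_with_constants`` lets you read that value
        # back via ``consts.c0``.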
        def save(self, f, **kwargs):
            r"""Save with a file name or path.

            save(f, _extra_files={})

            See :func:`torch.jit.save <torch.jit.save>`, which also accepts a file-like
            object. This method converts ``f`` to a string and treats it as a path;
            do not confuse the two functions' handling of the ``f`` parameter.
            """
            return self._c.save(str(f), **kwargs)

        def _save_for_lite_interpreter(self, *args, **kwargs):
            r"""Add (or update) the bytecode section of the script model.

            _save_for_lite_interpreter(f)

            The updated model is used by the lite interpreter for mobile applications.

            Args:
                f: a string containing a file name.
                _extra_files: Map from filename to contents which will be stored as part of 'f'.

            """
            return self._c._save_for_mobile(*args, **kwargs)

        def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs):
            return self._c._save_to_buffer_for_mobile(*args, **kwargs)

        def save_to_buffer(self, *args, **kwargs):
            return self._c.save_to_buffer(*args, **kwargs)

        def get_debug_state(self, *args, **kwargs):
            return self._c.get_debug_state()

        def extra_repr(self):
            return f"original_name={self.original_name}"

        def graph_for(self, *args, **kwargs):
            return self.forward.graph_for(self, *args, **kwargs)  # type: ignore[attr-defined]

        @property
        def original_name(self):
            if type(self) == str(self._c._type().name()):
                return ""
            return str(self._c._type().name())

        def define(self, src):
            # We use frames_up=1 to get to the proper surrounding scope. The stack
            # will look like:
            # 0. createResolutionCallback
            # 1. define()
            # 2. surrounding scope.
            #
            # createResolutionCallback internally adds 1 to get us to our frame, then
            # we add 1 to get to the proper surrounding scope.
            rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1)
            self._c._define(self._concrete_type, src, rcb)

        def __getattr__(self, attr):
            if "_initializing" not in self.__dict__:
                raise RuntimeError(
                    "ScriptModule has not been initialized, did you forget to call super's init?"
                )

            if self._initializing:
                return super().__getattr__(attr)

            # _modules check is before hasattr since modules are included as attributes in _c,
            # but we want to get the python wrapper from _modules instead of the raw _c object.
            if attr in self._modules:
                return self._modules[attr]
            elif self._c.hasattr(attr):
                return self._c.getattr(attr)
            elif self._c._has_method(attr):
                script_method = self._c._get_method(attr)
                # cache method so future calls do not go through __getattr__
                # to improve invocation performance
                self.__dict__[attr] = script_method
                return script_method

            return super().__getattr__(attr)

        def __setattr__(self, attr, value):
            if self._initializing:
                return super().__setattr__(attr, value)

            if attr in self._modules:
                self._modules[attr] = value
            elif self._c.hasattr(attr):
                self._c.setattr(attr, value)
            elif (
                hasattr(self, "_concrete_type")
                and attr in self._concrete_type.get_constants().keys()
            ):
                # TODO: we don't have _concrete_type set after load(), and in general we lose constant information.
                # We should encode constants as class type attributes (or something) so it persists across save/load.
                raise AttributeError(
                    f"Cannot mutate TorchScript constant value: '{attr}'. Value: '{value}'"
                )
            else:
                # We allow setting Python attributes on the ScriptModule, for
                # when people want to stash some convenience info on it.
                # TODO: it's possible that the following is confusing:
                #   s = torch.jit.script(...)
                #   s.python_attr = ...
                #   s.save()   <--- this doesn't have `python_attr`
                # It's fairly trivial to save enough info to warn in this case.
                return super().__setattr__(attr, value)

        def __copy__(self):
            return torch.jit._recursive.wrap_cpp_module(copy.copy(self._c))

        def __deepcopy__(self, memo):
            return torch.jit._recursive.wrap_cpp_module(copy.deepcopy(self._c, memo))

        # Python magic methods do method lookups on an object's class type, instead of
        # looking up the methods defined on the class instance. In order to continue to
        # expose the magic methods of built-in containers (ModuleList, Sequential,
        # ModuleDict) to Python, we define magic methods here as a shim to the correct attribute.
        def forward_magic_method(self, method_name, *args, **kwargs):
            self_method = getattr(self, method_name)
            if getattr(self_method, "__func__", None) == getattr(
                RecursiveScriptModule, method_name
            ):
                raise NotImplementedError
            return self_method(*args, **kwargs)

        def __iter__(self):
            return self.forward_magic_method("__iter__")

        def __getitem__(self, idx):
            return self.forward_magic_method("__getitem__", idx)

        def __len__(self):
            return self.forward_magic_method("__len__")

        def __contains__(self, key):
            return self.forward_magic_method("__contains__", key)

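        # For example (illustrative): if ``m`` is a scripted module holding an
        # ``nn.ModuleList`` submodule ``m.layers``, then ``len(m.layers)`` and
        # ``m.layers[0]`` route through forward_magic_method to the compiled
        # ``__len__`` / ``__getitem__`` rather than the shims defined above.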
        # dir is defined by the base nn.Module, so instead of throwing if
        # it is not overridden, we call into the nn.Module __dir__ method
        def __dir__(self):
            self_method = self.__dir__
            if (
                self_method.__func__  # type: ignore[attr-defined]
                == _get_function_from_type(RecursiveScriptModule, "__dir__")
            ):
                return super().__dir__()
            return self_method()

        # To resolve bool(value), Python looks for __bool__, then __iter__, and
        # otherwise returns True for class instances. Since __iter__() on this
        # class throws if it isn't overridden, we define __bool__ to preserve the default behavior.
        def __bool__(self):
            self_method = self.__bool__
            if (
                self_method.__func__  # type: ignore[attr-defined]
                == _get_function_from_type(RecursiveScriptModule, "__bool__")
            ):
                return True
            return self_method()

        def _replicate_for_data_parallel(self):
            # we have to initialize ScriptModule properly so that
            # it works with pybind11
            def init_fn(script_module):
                # Don't do anything here, we'll initialize the ScriptModule below
                return

            return RecursiveScriptModule._construct(
                self._c._replicate_for_data_parallel(), init_fn
            )

    # Need to copy all RecursiveScriptModule methods to ScriptModule.
    #
    # This is because `super().foo()` does not use
    # `__getattr__` to look up `foo`. So we need to make each method available on
    # the ScriptModule manually.
    for name, item in RecursiveScriptModule.__dict__.items():
        if not callable(item) and not isinstance(item, property):
            continue
        if name.startswith("__") or hasattr(ScriptModule, name):
            continue
        # We can copy over the implementation wholesale because besides the
        # `super()` thing above, ScriptModule behaves exactly like
        # RecursiveScriptModule
        setattr(ScriptModule, name, item)

    def _get_methods(cls):
        import inspect

        # In Python 3 unbound methods are functions, but in Python 2 they are methods
        return inspect.getmembers(
            cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x)
        )

    _compiled_methods_allowlist = {
        "forward",
        "register_buffer",
        "register_parameter",
        "register_module",
        "add_module",
        "_apply",
        "apply",
        "cuda",
        "cpu",
        "to",
        "type",
        "float",
        "double",
        "half",
        "state_dict",
        "_save_to_state_dict",
        "load_state_dict",
        "_load_from_state_dict",
        "_named_members",
        "parameters",
        "named_parameters",
        "buffers",
        "named_buffers",
        "children",
        "named_children",
        "modules",
        "named_modules",
        "zero_grad",
        "share_memory",
        "_get_name",
        "extra_repr",
        "_slow_forward",
        "_tracing_name",
        "eval",
        "train",
        "get_extra_state",
        "set_extra_state",
    }

    def _make_fail(name):
        def fail(self, *args, **kwargs):
            raise RuntimeError(name + " is not supported on ScriptModules")

        return fail

    for name, method in _get_methods(torch.nn.Module):
        if name.startswith("__") or name.endswith("_call_impl"):
            continue
        if (
            name not in RecursiveScriptModule.__dict__
            and name not in _compiled_methods_allowlist
        ):
            setattr(RecursiveScriptModule, method.__name__, _make_fail(name))


else:
    # TODO MAKE SURE THAT DISABLING WORKS
    class RecursiveScriptClass:  # type: ignore[no-redef]
        pass

    class ScriptModule(torch.nn.Module):  # type: ignore[no-redef]
        def __init__(self, arg=None):
            super().__init__()

    class RecursiveScriptModule(ScriptModule):  # type: ignore[no-redef]
        def __init__(self, arg=None):
            super().__init__()


def call_prepare_scriptable_func_impl(obj, memo):
    if not isinstance(obj, torch.nn.Module):
        return obj

    obj_id = id(obj)

    # If obj_id is in memo, obj has already been prepared or is being
    # prepared in another call up the stack.
    if obj_id in memo:
        return memo[id(obj)]

    obj = obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj  # type: ignore[operator]
    # Record obj in memo to avoid infinite recursion in the case of cycles in the module
    # hierarchy when recursing below.
    memo[obj_id] = obj

    new_obj_dict = {}

    for name, sub_module in obj.__dict__.items():
        if name == "_modules":
            for k, v in sub_module.items():
                sub_module[k] = call_prepare_scriptable_func_impl(v, memo)
            new_obj_dict[name] = sub_module
        elif isinstance(sub_module, torch.nn.Module) and not isinstance(
            sub_module, ScriptModule
        ):
            new_obj_dict[name] = call_prepare_scriptable_func_impl(sub_module, memo)
        else:
            new_obj_dict[name] = sub_module

    for k, v in new_obj_dict.items():
        # Write back under the loop key ``k``; using ``name`` (left over from
        # the loop above) would repeatedly clobber a single attribute.
        obj.__dict__[k] = v

    return obj


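# Illustrative sketch (hedged; ``EagerOnly`` and ``ScriptFriendly`` are
# hypothetical): a module can swap itself out right before scripting by
# defining ``__prepare_scriptable__``, which the helper above invokes on
# every submodule:
#
#     class EagerOnly(torch.nn.Module):
#         def __prepare_scriptable__(self):
#             return ScriptFriendly()  # the returned module is what gets compiled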
def call_prepare_scriptable_func(obj):
    memo: Dict[int, torch.nn.Module] = {}
    return call_prepare_scriptable_func_impl(obj, memo)


def create_script_dict(obj):
    """
    Create a ``torch._C.ScriptDict`` instance with the data from ``obj``.

    Args:
        obj (dict): The Python dictionary that is used to initialize the ``ScriptDict``
                    returned by this function.

    Returns:
        An instance of ``torch._C.ScriptDict`` that has the same data as ``obj``
        and can be passed between Python and TorchScript with reference semantics and
        zero copy overhead.
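
    Example (illustrative)::

        import torch

        # torch.jit.script on a plain dict dispatches to create_script_dict
        d = torch.jit.script({"a": torch.tensor(1.0)})
        d["b"] = torch.tensor(2.0)  # mutation is shared across the Python/TorchScript boundary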
    """
    return torch._C.ScriptDict(obj)  # type: ignore[attr-defined]


def create_script_list(obj, type_hint=None):
    """
    Create a ``torch._C.ScriptList`` instance with the data from ``obj``.

    Args:
        obj (list): The Python list that is used to initialize the ``ScriptList``
                    returned by this function.

    Returns:
        An instance of ``torch._C.ScriptList`` that has the same data as ``obj``
        and can be passed between Python and TorchScript with reference semantics and
        zero copy overhead.
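
    Example (illustrative)::

        import torch

        # torch.jit.script on a plain list dispatches to create_script_list
        l = torch.jit.script([torch.tensor(1.0)])
        l.append(torch.tensor(2.0))  # shared by reference with TorchScript code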
    """
    return torch._C.ScriptList(obj)  # type: ignore[attr-defined]


_TOPLEVEL: bool = True


def _script_impl(
    obj,
    optimize=None,
    _frames_up=0,
    _rcb=None,
    example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
):
    global type_trace_db

    if optimize is not None:
        warnings.warn(
            "`optimize` is deprecated and has no effect. "
            "Use `with torch.jit.optimized_execution()` instead",
            FutureWarning,
            stacklevel=3,
        )

    # No-op for modules, functions, class instances that are already scripted
    if isinstance(obj, RecursiveScriptClass):
        return obj
    if isinstance(obj, ScriptModule):
        return obj
    if isinstance(obj, ScriptFunction):
        return obj

    if example_inputs:
        # If MonkeyType is installed, enable profile-directed type annotation.
        # If example_inputs are defined, generate call traces for the method by
        # running the eager-mode version of the method with the provided example
        # inputs. This logs all the traces in type_trace_db.
        type_trace_db = JitTypeTraceStore()
        if monkeytype_trace:
            monkeytype_config = JitTypeTraceConfig(type_trace_db)
            with monkeytype_trace(monkeytype_config):
                if isinstance(example_inputs, Dict):
                    # If the obj is an nn.Module or a class, then each method is
                    # executed with the arguments provided in the example inputs.
                    # example inputs here will be of type Dict(class.method, (arguments))
                    # This is used to infer type annotations for those methods
                    # which are not called directly under the hood of monkeytype.
                    for module, example_input in example_inputs.items():
                        for example in example_input:
                            module(*example)
                elif isinstance(example_inputs, List):
                    for examples in example_inputs:
                        obj(*examples)
                else:
                    raise ValueError(
                        "Error: Unable to infer types. Please format the inputs to type `List[Tuple]`"
                        " or `Dict[Callable, List[Tuple]]` to be run with MonkeyType."
                    )
        else:
            warnings.warn(
                "Warning: MonkeyType is not installed. Please install MonkeyType "
                "(https://github.com/Instagram/MonkeyType) to enable "
                "profile-directed typing in TorchScript."
            )

    if isinstance(obj, torch.nn.Module):
        obj = call_prepare_scriptable_func(obj)
        return torch.jit._recursive.create_script_module(
            obj, torch.jit._recursive.infer_methods_to_compile
        )
    else:
        obj = obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj  # type: ignore[operator]

    if isinstance(obj, dict):
        return create_script_dict(obj)
    if isinstance(obj, list):
        return create_script_list(obj)

    if inspect.isclass(obj):
        qualified_name = _qualified_name(obj)
        # If this type is a `nn.Module` subclass, they probably meant to pass
        # an instance instead of a Module
        if issubclass(obj, torch.nn.Module):
            raise RuntimeError(
                f"Type '{obj}' cannot be compiled since it inherits from nn.Module, pass an instance instead"
            )

        # Enums are automatically usable in TorchScript, explicitly scripting
        # is not necessary, but not harmful either.
        if issubclass(obj, enum.Enum):
            return obj

        if not _is_new_style_class(obj):
            raise RuntimeError(
                "TorchScript classes must be new-style classes. "
                "Please inherit from 'object'."
            )
        if len(obj.mro()) > 2:
            raise RuntimeError(
                "TorchScript classes do not support inheritance yet. "
                "Please directly inherit from 'object'."
            )
        if _rcb is None:
            _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
        _compile_and_register_class(obj, _rcb, qualified_name)
        return obj
    elif inspect.isfunction(obj) or inspect.ismethod(obj):
        qualified_name = _qualified_name(obj)
        # this is a decorated fn, and we need the underlying fn and its rcb
        if hasattr(obj, "__script_if_tracing_wrapper"):
            obj = obj.__original_fn  # type: ignore[union-attr]
            _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

        # some functions are explicitly marked as not supported in script mode
        if hasattr(obj, "__script_unsupported"):
            raise RuntimeError("TorchScript error: " + obj.__script_unsupported)

        _check_directly_compile_overloaded(obj)
        maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
        if maybe_already_compiled_fn:
            maybe_already_compiled_fn._torchdynamo_inline = obj  # type: ignore[attr-defined]
            return maybe_already_compiled_fn
        ast = get_jit_def(obj, obj.__name__)
        if _rcb is None:
            _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
        fn = torch._C._jit_script_compile(
            qualified_name, ast, _rcb, get_default_args(obj)
        )
        # Forward docstrings
        fn.__doc__ = obj.__doc__
        # Allow torch.compile() to inline
        fn._torchdynamo_inline = obj  # type: ignore[attr-defined]
        _set_jit_function_cache(obj, fn)
        return fn
    else:
        return torch.jit._recursive.create_script_class(obj)


def script(
    obj,
    optimize=None,
    _frames_up=0,
    _rcb=None,
    example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
):
    r"""Script the function.

    Scripting a function or ``nn.Module`` will inspect the source code, compile
    it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or
    :class:`ScriptFunction`. TorchScript itself is a subset of the Python language, so not all
    features in Python work, but we provide enough functionality to compute on
    tensors and do control-dependent operations. For a complete guide, see the
    :ref:`language-reference`.

    Scripting a dictionary or list copies the data inside it into a TorchScript instance that can be
    subsequently passed by reference between Python and TorchScript with zero copy overhead.

    ``torch.jit.script`` can be used as a function for modules, functions, dictionaries and lists
    and as a decorator ``@torch.jit.script`` for :ref:`torchscript-classes` and functions.

    Args:
        obj (Callable, class, or nn.Module):  The ``nn.Module``, function, class type,
                                                  dictionary, or list to compile.
        example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]): Provide example inputs
            to annotate the arguments for a function or ``nn.Module``.

    Returns:
        If ``obj`` is ``nn.Module``, ``script`` returns
        a :class:`ScriptModule` object. The returned :class:`ScriptModule` will
        have the same set of sub-modules and parameters as the
        original ``nn.Module``. If ``obj`` is a standalone function,
        a :class:`ScriptFunction` will be returned. If ``obj`` is a ``dict``, then
        ``script`` returns an instance of `torch._C.ScriptDict`. If ``obj`` is a ``list``,
        then ``script`` returns an instance of `torch._C.ScriptList`.

    **Scripting a function**
        The ``@torch.jit.script`` decorator will construct a :class:`ScriptFunction`
        by compiling the body of the function.

        Example (scripting a function):

        .. testcode::

            import torch

            @torch.jit.script
            def foo(x, y):
                if x.max() > y.max():
                    r = x
                else:
                    r = y
                return r

            print(type(foo))  # torch.jit.ScriptFunction

            # See the compiled graph as Python code
            print(foo.code)

            # Call the function using the TorchScript interpreter
            foo(torch.ones(2, 2), torch.ones(2, 2))

        .. testoutput::
            :hide:

            ...

    **Scripting a function using example_inputs**
        Example inputs can be used to annotate a function's arguments.

        Example (annotating a function before scripting):

        .. testcode::

            import torch

            def test_sum(a, b):
                return a + b

            # Annotate the arguments to be int
            scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])

            print(type(scripted_fn))  # torch.jit.ScriptFunction

            # See the compiled graph as Python code
            print(scripted_fn.code)

            # Call the function using the TorchScript interpreter
            scripted_fn(20, 100)

        .. testoutput::
            :hide:

            ...

    **Scripting an nn.Module**
        Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively
        compile any methods, submodules, and functions called by ``forward``. If an ``nn.Module`` only uses
        features supported in TorchScript, no changes to the original module code should be necessary. ``script``
        will construct a :class:`ScriptModule` that has copies of the attributes, parameters, and methods of
        the original module.

        Example (scripting a simple module with a Parameter):

        .. testcode::

            import torch

            class MyModule(torch.nn.Module):
                def __init__(self, N, M):
                    super().__init__()
                    # This parameter will be copied to the new ScriptModule
                    self.weight = torch.nn.Parameter(torch.rand(N, M))

                    # When this submodule is used, it will be compiled
                    self.linear = torch.nn.Linear(N, M)

                def forward(self, input):
                    output = self.weight.mv(input)

                    # This calls the `forward` method of the `nn.Linear` module, which will
                    # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
                    output = self.linear(output)
                    return output

            scripted_module = torch.jit.script(MyModule(2, 3))

        Example (scripting a module with traced submodules):

        .. testcode::

            import torch
            import torch.nn as nn
            import torch.nn.functional as F

            class MyModule(nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    # torch.jit.trace produces ScriptModules for conv1 and conv2
                    self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
                    self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

                def forward(self, input):
                    input = F.relu(self.conv1(input))
                    input = F.relu(self.conv2(input))
                    return input

            scripted_module = torch.jit.script(MyModule())

        To compile a method other than ``forward`` (and recursively compile anything it calls), add
        the :func:`@torch.jit.export <torch.jit.export>` decorator to the method. To opt out of compilation
        use :func:`@torch.jit.ignore <torch.jit.ignore>` or :func:`@torch.jit.unused <torch.jit.unused>`.

        Example (an exported and ignored method in a module)::

            import torch
            import torch.nn as nn

            class MyModule(nn.Module):
                def __init__(self) -> None:
                    super().__init__()

                @torch.jit.export
                def some_entry_point(self, input):
                    return input + 10

                @torch.jit.ignore
                def python_only_fn(self, input):
                    # This function won't be compiled, so any
                    # Python APIs can be used
                    import pdb
                    pdb.set_trace()

                def forward(self, input):
                    if self.training:
                        self.python_only_fn(input)
                    return input * 99

            scripted_module = torch.jit.script(MyModule())
            print(scripted_module.some_entry_point(torch.randn(2, 2)))
            print(scripted_module(torch.randn(2, 2)))

        Example (annotating the forward method of an nn.Module using example_inputs)::
1402
1403            import torch
1404            import torch.nn as nn
            from typing import List, NamedTuple

            class MyModule(NamedTuple):
                result: List[int]
1409
1410            class TestNNModule(torch.nn.Module):
1411                def forward(self, a) -> MyModule:
1412                    result = MyModule(result=a)
1413                    return result
1414
1415            pdt_model = TestNNModule()
1416
            # Runs pdt_model in eager mode with the provided inputs and annotates the arguments of forward
            scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20],)]})
1419
1420            # Run the scripted_model with actual inputs
1421            print(scripted_model([20]))
1422    """
1423    if not _enabled:
1424        return obj
    try:
        # `_TOPLEVEL` tracks whether this is the outermost call to `script`,
        # so that usage logging below fires once per user invocation rather
        # than once for every recursive compilation.
        global _TOPLEVEL
        prev = _TOPLEVEL
        _TOPLEVEL = False
1429        ret = _script_impl(
1430            obj=obj,
1431            optimize=optimize,
1432            _frames_up=_frames_up + 1,
1433            _rcb=_rcb,
1434            example_inputs=example_inputs,
1435        )
1436
1437        if prev:
1438            log_torchscript_usage("script", model_id=_get_model_id(ret))
1439
1440        return ret
1441    finally:
1442        _TOPLEVEL = prev
1443
1444
1445# overloads are registered in _jit_internal and compiled here so that _overload
1446# can be used in nn/functional.py without an import cycle
1447
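# A minimal sketch of the pattern this enables (hedged: `emit` is a made-up
# example; `_overload` is the private decorator from torch._jit_internal used
# in nn/functional.py, as the comment above notes). The decorated stubs only
# declare signatures; the single undecorated implementation is then compiled
# once per declared signature by `_compile_function_with_overload` below:
#
#     from torch._jit_internal import _overload
#
#     @_overload
#     def emit(x: int) -> str:
#         ...
#
#     @_overload
#     def emit(x: str) -> str:
#         ...
#
#     def emit(x):
#         return str(x)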
1448
1449def _check_overload_defaults(impl_defaults, overload_defaults, loc):
1450    for name, overload_value in overload_defaults.items():
1451        if name not in impl_defaults or impl_defaults[name] != overload_value:
1452            raise torch.jit.frontend.FrontendError(
1453                loc,
1454                "Default parameters on overloads do not affect the runtime so they "
1455                "must equal to the default parameter on the implementation function. Found on "
1456                f"parameter {name}",
1457            )
1458
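# For example (hypothetical signatures): an overload declared as
# `def f(x, y: int = 2): ...` paired with an implementation `def f(x, y=3)`
# fails this check, because the overload's default for `y` (2) does not match
# the implementation's default (3).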
1459
1460def _compile_function_with_overload(overload_fn, qual_name, impl_fn):
1461    overload_decl = get_jit_def(overload_fn, overload_fn.__name__).decl()
1462    overload_signature = torch.jit.annotations.get_signature(
1463        overload_fn, None, None, inspect.ismethod(overload_fn)
1464    )
1465    impl_ast = get_jit_def(impl_fn, impl_fn.__name__)
1466    overload_defaults = get_default_args(overload_fn)
1467    implementation_defaults = get_default_args(impl_fn)
1468    _rcb = _jit_internal.createResolutionCallbackFromClosure(impl_fn)
1469    _check_overload_defaults(
1470        implementation_defaults, overload_defaults, overload_decl.range()
1471    )
1472    fn = torch._C._jit_script_compile_overload(
1473        qual_name,
1474        overload_decl,
1475        impl_ast,
1476        _rcb,
1477        implementation_defaults,
1478        overload_signature,
1479    )
1480    return fn
1481
1482
1483def _get_overloads(obj):
1484    # check for cached compiled fns
1485    existing_compiled_fns = _try_get_jit_cached_overloads(obj)
1486    qual_name = _qualified_name(obj)
1487    uncompiled_overloads = _jit_internal._get_fn_overloads(qual_name)
1488    if uncompiled_overloads is None:
1489        return existing_compiled_fns
1490
1491    if obj in uncompiled_overloads:
1492        raise RuntimeError(
1493            _jit_internal.get_overload_no_implementation_error_message("function", obj)
1494        )
1495
1496    compiled_fns = []
1497    for overload_fn in uncompiled_overloads:
1498        compiled_fns.append(
1499            _compile_function_with_overload(overload_fn, qual_name, obj)
1500        )
1501
1502    if existing_compiled_fns:
1503        compiled_fns = existing_compiled_fns + compiled_fns
1504
1505    # cache compilation, remove information stored to do compilation
1506    _set_jit_overload_cache(obj, compiled_fns)
1507    _jit_internal._clear_fn_overloads(qual_name)
1508    return compiled_fns
1509
1510
1511def _check_directly_compile_overloaded(obj):
1512    qual_name = _qualified_name(obj)
1513    if _jit_internal._get_fn_overloads(qual_name) or _try_get_jit_cached_overloads(obj):
1514        raise RuntimeError(
1515            f"Function {qual_name} cannot be directly compiled because it"
1516            " is overloaded. It must be used in a context of a function"
1517            " where its inputs can determine which overload to call."
1518        )
1519
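# For example (hypothetical): calling `torch.jit.script` directly on the
# overloaded `emit` sketched above raises this error; scripting a caller such
# as `def caller(x: int) -> str: return emit(x)` works, because the call site
# determines which overload to select.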
1520
1521def interface(obj):
1522    r"""Decorate to annotate classes or modules of different types.
1523
1524    This decorator can be used to define an interface that can be used to annotate
1525    classes or modules of different types. This can be used for to annotate a submodule
1526    or attribute class that could have different types that implement the same
1527    interface, or which could be swapped at runtime; or to store a list of modules or
1528    classes of varying types.
1529
1530    It is sometimes used to implement "Callables" - functions or modules that implement
1531    an interface but whose implementations differ and which can be swapped out.
1532
    Example:

    .. testcode::
1535
1536        import torch
1537        from typing import List
1538
1539        @torch.jit.interface
1540        class InterfaceType:
1541            def run(self, x: torch.Tensor) -> torch.Tensor:
1542                pass
1543
1544        # implements InterfaceType
1545        @torch.jit.script
1546        class Impl1:
1547            def run(self, x: torch.Tensor) -> torch.Tensor:
1548                return x.relu()
1549
1550        class Impl2(torch.nn.Module):
1551            def __init__(self) -> None:
1552                super().__init__()
1553                self.val = torch.rand(())
1554
1555            @torch.jit.export
1556            def run(self, x: torch.Tensor) -> torch.Tensor:
1557                return x + self.val
1558
1559        def user_fn(impls: List[InterfaceType], idx: int, val: torch.Tensor) -> torch.Tensor:
1560            return impls[idx].run(val)
1561
1562        user_fn_jit = torch.jit.script(user_fn)
1563
1564        impls = [Impl1(), torch.jit.script(Impl2())]
1565        val = torch.rand(4, 4)
1566        user_fn_jit(impls, 0, val)
1567        user_fn_jit(impls, 1, val)
1568    """
1569    if not inspect.isclass(obj):
1570        raise RuntimeError("interface must be applied to a class")
1571    if not _is_new_style_class(obj):
1572        raise RuntimeError("TorchScript interfaces must inherit from 'object'")
1573
1574    # Expected MRO is:
1575    #   User module
1576    #   torch.nn.modules.module.Module
1577    #   object
1578    is_module_interface = issubclass(obj, torch.nn.Module) and len(obj.mro()) == 3
1579
1580    if not is_module_interface and len(obj.mro()) > 2:
1581        raise RuntimeError(
1582            "TorchScript interface does not support inheritance yet. "
1583            "Please directly inherit from 'object' or 'nn.Module'."
1584        )
1585
1586    qualified_name = _qualified_name(obj)
1587    rcb = _jit_internal.createResolutionCallbackFromFrame(1)
    # if this type is an `nn.Module` subclass, generate a module interface type
1589    # instead of a class interface type; a module interface type only compiles
1590    # the user provided methods as part of the interface
1591    ast = get_jit_class_def(obj, obj.__name__)
1592    mangled_classname = torch._C._jit_script_interface_compile(
1593        qualified_name, ast, rcb, is_module_interface
1594    )
1595    obj.__torch_script_interface__ = mangled_classname
1596    return obj
1597
1598
1599def _recursive_compile_class(obj, loc):
1600    _qual_name = _qualified_name(obj)
1601    # We're starting a new compilation, so update the error call stack in
1602    # case it fails
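    # `error_stack` acts as an RAII guard: constructing the CallStack pushes
    # this frame onto the error call stack, and the frame is popped when the
    # object is destroyed at the end of this function.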
1603    error_stack = torch._C.CallStack(_qual_name, loc)
1604    rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
1605    return _compile_and_register_class(obj, rcb, _qual_name)
1606
1607
1608CompilationUnit = torch._C.CompilationUnit
1609set_module(CompilationUnit, "torch.jit")
1610
1611
def pad(s: str, padding: int, offset: int = 0, char: str = " "):
    # Right-align `s` in a field `padding` characters wide, preceded by
    # `offset` extra pad characters. If the field is narrower than `s` (e.g.
    # when a column uses alignment == 0), `s` is simply prefixed with
    # `padding + offset` pad characters.
    if padding >= len(s):
        padding -= len(s)
    return char * (padding + offset) + s
1616
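# For example: pad("abc", 8) == "     abc" (right-aligned in a field of
# width 8), while pad("abc", 0, offset=1) == " abc" (no field, just the
# offset prefix).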
1617
1618class _ScriptProfileColumn:
1619    def __init__(self, header: str, alignment: int = 4, offset: int = 0):
1620        self.header = header
1621        self.alignment = alignment
1622        self.offset = offset
1623        self.rows: Dict[int, Any] = {}
1624
1625    def add_row(self, lineno: int, value: Any):
1626        self.rows[lineno] = value
1627
1628    def materialize(self):
1629        max_length = len(self.header)
1630        rows: List[Tuple[int, str]] = []
1631        for key, value in self.rows.items():
1632            cell = str(value)
1633            rows.append((key, cell))
1634            max_length = max(len(cell), max_length)
1635
1636        if self.alignment > 0:
1637            padding = max_length + self.alignment
1638            padding -= padding % self.alignment
1639        else:
1640            padding = 0
1641
1642        rows = [(key, pad(cell, padding, self.offset)) for key, cell in rows]
1643        return pad(self.header, padding, self.offset), rows
1644
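# For example, with the default alignment of 4 and a widest cell of 6
# characters: padding = 6 + 4 = 10, then 10 - (10 % 4) = 8, so each column is
# rounded to the smallest multiple of the alignment strictly wider than its
# content.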
1645
1646class _ScriptProfileTable:
1647    def __init__(self, cols: List[_ScriptProfileColumn], source_range: List[int]):
1648        self.cols = cols
1649        self.source_range = source_range
1650
1651    def dump_string(self):
1652        outputs: List[str] = []
1653        cells: List[Tuple[str, Dict[int, str]]] = []
1654        header_buffer = ""
1655        for col in self.cols:
1656            header, rows = col.materialize()
1657            header_buffer += header
1658            cells.append((header, dict(rows)))
1659
1660        outputs.append(header_buffer)
1661        outputs.append(pad("", len(header_buffer), 0, "="))
1662        for line in self.source_range:
1663            row_buffer = ""
1664            for header, rows in cells:
1665                cell = rows.get(line)
1666                if cell is None:
1667                    row_buffer += pad("", len(header))
1668                else:
1669                    row_buffer += cell
1670            outputs.append(row_buffer)
1671        return "\n".join(outputs)
1672
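# The rendered table looks roughly like this (illustrative values):
#
#       Line #    Hits   Time (ns) Line Contents
#     ==========================================
#            1       1       12000 def fn(x):
#            2       1       34000     return x.relu()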
1673
1674class _ScriptProfile:
1675    def __init__(self) -> None:
1676        self.profile = classes.profiling._ScriptProfile()
1677
1678    def enable(self):
1679        self.profile.enable()
1680
1681    def disable(self):
1682        self.profile.disable()
1683
1684    def dump_string(self) -> str:
1685        outputs: List[str] = []
1686        for source_stats in self.profile._dump_stats():
1687            source_ref = source_stats.source()
1688            source_lines = source_ref.text().splitlines()
            # Ignore blank lines when computing the common indentation;
            # otherwise a single empty line would force the dedent to zero.
            dedent = min(
                (len(line) - len(line.lstrip(" ")) for line in source_lines if line.strip()),
                default=0,
            )
1690            source_lines = [line[dedent:] for line in source_lines]
1691
1692            start_line = source_ref.starting_lineno()
1693            end_line = start_line + len(source_lines)
1694            source_range = range(start_line, end_line)
1695            lineno = _ScriptProfileColumn("Line #")
1696            hits = _ScriptProfileColumn("Hits")
1697            time_ns = _ScriptProfileColumn("Time (ns)")
1698            line_contents = _ScriptProfileColumn("Line Contents", 0, 1)
1699            stats = source_stats.line_map()
1700            for line in source_range:
1701                lineno.add_row(line, line)
1702                line_contents.add_row(line, source_lines[line - start_line])
1703                stat = stats.get(line)
1704                if stat is not None:
1705                    hits.add_row(line, stat.count())
1706                    time_ns.add_row(line, stat.duration_ns())
1707
1708            table = _ScriptProfileTable(
1709                [lineno, hits, time_ns, line_contents], list(source_range)
1710            )
1711            outputs.append(table.dump_string())
1712        return "\n\n".join(outputs)
1713
1714    def dump(self):
1715        print(self.dump_string())
1716
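# A minimal usage sketch (internal API; `fn` is a made-up workload):
#
#     profile = _ScriptProfile()
#     profile.enable()
#
#     @torch.jit.script
#     def fn(x: torch.Tensor) -> torch.Tensor:
#         return x.relu()
#
#     fn(torch.rand(4))
#     profile.disable()
#     profile.dump()  # prints per-line hit counts and cumulative time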
1717
1718def _unwrap_optional(x):
1719    assert x is not None, "Unwrapping null optional"
1720    return x
1721
1722
1723_register_builtin(_unwrap_optional, "aten::_unwrap_optional")
1724_register_builtin(_jit_internal.is_scripting, "aten::is_scripting")
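# Note: the unary and variadic `has_torch_function` variants are registered
# under the same `aten::has_torch_function` builtin name as the generic form.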
1725_register_builtin(has_torch_function, "aten::has_torch_function")
1726_register_builtin(has_torch_function_unary, "aten::has_torch_function")
1727_register_builtin(has_torch_function_variadic, "aten::has_torch_function")
1728