# mypy: allow-untyped-defs
import collections
import functools
import inspect
import sys
import textwrap
import types
import warnings
from typing import Dict, List, Set, Type

import torch
import torch._jit_internal as _jit_internal
from torch._sources import fake_range
from torch.jit._builtins import _find_builtin
from torch.jit._check import AttributeTypeIsSupportedChecker
from torch.jit._state import _add_script_class, _get_script_class, _python_cu
from torch.jit.frontend import (
    get_class_properties,
    get_default_args,
    get_jit_class_def,
    get_jit_def,
)
from torch.nn import Module


ScriptMethodStub = collections.namedtuple(
    "ScriptMethodStub", ("resolution_callback", "def_", "original_method")
)
PropertyStub = collections.namedtuple("PropertyStub", ("resolution_callback", "def_"))


# TODO: there should be a more principled way of doing this.
ignored_attributes = [
    "_version",
    "_parameters",
    "_buffers",
    "_non_persistent_buffers_set",
    "_backward_hooks",
    "_backward_pre_hooks",
    "_forward_hooks",
    "_forward_hooks_with_kwargs",
    "_forward_pre_hooks",
    "_forward_pre_hooks_with_kwargs",
    "_forward_hooks_always_called",
    "_state_dict_hooks",
    "_state_dict_pre_hooks",
    "_load_state_dict_pre_hooks",
    "_load_state_dict_post_hooks",
    "_modules",
    "_initializing",
    "dump_patches",
]


def _compile_and_register_class(obj, rcb, qualified_name):
    script_class = _get_script_class(obj)

    if not script_class:
        ast = get_jit_class_def(obj, obj.__name__)
        defaults = torch.jit.frontend.get_default_args_for_class(obj)
        script_class = torch._C._jit_script_class_compile(
            qualified_name, ast, defaults, rcb
        )
        _add_script_class(obj, script_class)

    return script_class


def make_stub(func, name):
    rcb = _jit_internal.createResolutionCallbackFromClosure(func)
    ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
    return ScriptMethodStub(rcb, ast, func)


def make_stub_from_method(nn_module, method_name):
    func = getattr(nn_module, method_name)
    if isinstance(func, ScriptMethodStub):
        return func
    # Make sure the name present in the resulting AST will match the name
    # requested here. The only time they don't match is if you do something
    # like:
    #   def _forward(self):
    #       pass
    #   forward = _forward
    # In this case, the actual function object will have the name `_forward`,
    # even though we requested a stub for `forward`.
    return make_stub(func, method_name)


def make_stubs_from_exported_methods(mod):
    stubs = []
    for name in dir(mod):
        item = getattr(mod, name, None)
        if (
            _jit_internal.get_torchscript_modifier(item)
            is _jit_internal.FunctionModifiers.EXPORT
        ):
            stubs.append(make_stub_from_method(mod, name))

    return stubs


def jit_ignored_properties(module):
    user_annotated_ignored_attributes = getattr(
        module, "__jit_ignored_attributes__", []
    )

    def get_properties_names(module):
        return {k for k, v in vars(module).items() if isinstance(v, property)}

    properties = get_properties_names(type(module))
    user_annotated_ignored_properties = set()

    for ignored_attr in user_annotated_ignored_attributes:
        if ignored_attr in properties:
            user_annotated_ignored_properties.add(ignored_attr)
    return user_annotated_ignored_properties


# base types that can be constants
# in addition, tuples and lists of these base types are also considered constants
# If you edit this list, then you also need to edit the handlers in
# ConstantValue in jit/script/init.cpp
_constant_types = (
    bool,
    float,
    int,
    str,
    type(None),
    torch.device,
    torch.layout,
    torch.dtype,
)


def _get_valid_constant(attr, v, owner_type):
    if isinstance(v, _constant_types):
        return v
    elif isinstance(v, (tuple, list)):
        return tuple(_get_valid_constant(attr, x, owner_type) for x in v)
    constants = ", ".join(torch.typename(typ) for typ in _constant_types)
    raise TypeError(
        textwrap.dedent(
            f"""
        '{torch.typename(type(v))}' object in attribute '{owner_type}.{attr}' is not a valid constant.
        Valid constants are:
        1. a nn.ModuleList
        2. a value of type {{{constants}}}
        3. a list or tuple of (2)
        """
        )
    )
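
# A minimal illustrative sketch (the class name `Net` here is hypothetical):
# values listed in `__constants__`, or annotated with `Final`, must be one of
# the types above or a tuple/list of them.
#
#     class Net(torch.nn.Module):
#         __constants__ = ["shift"]
#
#         def __init__(self) -> None:
#             super().__init__()
#             self.shift = 1  # int: a valid constant type
#
#         def forward(self, x):
#             return x + self.shift
#
#     scripted = torch.jit.script(Net())  # `shift` is baked in as a constant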


class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
    def __init__(self, source, filename, file_lineno, leading_whitespace_len):
        super().__init__(source, filename, file_lineno, leading_whitespace_len)


def get_annotations(obj):
    if sys.version_info < (3, 10):
        return getattr(obj, "__annotations__", {})
    # In Python 3.10+ it is recommended to use inspect.get_annotations
    # (see https://docs.python.org/3.10/howto/annotations.html).
    # But also, in 3.10 annotations from a base class are not inherited
    # by an unannotated derived one, so they must be extracted manually
    annotations = inspect.get_annotations(obj)
    if annotations:
        return annotations

    def get_cls_annotations(cls):
        cls_annotations = inspect.get_annotations(cls)
        if cls_annotations:
            return cls_annotations
        for base in cls.__bases__:
            cls_annotations = get_cls_annotations(base)
            if cls_annotations:
                return cls_annotations
        return {}

    cls = obj if isinstance(obj, type) else type(obj)
    return get_cls_annotations(cls)
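
# An illustrative sketch (hypothetical classes) of the base-class fallback
# above: on Python 3.10+, an unannotated subclass does not inherit
# `__annotations__`, so `get_annotations` walks the bases manually.
#
#     class Base(torch.nn.Module):
#         scale: float
#
#     class Derived(Base):  # no annotations of its own
#         pass
#
#     get_annotations(Derived)  # -> {"scale": <class 'float'>}, found on Base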


def infer_concrete_type_builder(nn_module, share_types=True):
    """
    Build a ConcreteModuleTypeBuilder from an nn.Module.

    This ConcreteModuleType doesn't have a JIT type associated with it yet; it
    must be filled in by the caller.
    """
    concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
    if isinstance(nn_module, torch.nn.ModuleDict):
        concrete_type_builder.set_module_dict()
    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
        concrete_type_builder.set_module_list()
    if isinstance(nn_module, torch.nn.ParameterList):
        concrete_type_builder.set_parameter_list()
    if isinstance(nn_module, torch.nn.ParameterDict):
        concrete_type_builder.set_parameter_dict()

    class_annotations = get_annotations(nn_module)
    if isinstance(nn_module, torch.ao.quantization.QuantWrapper):
        class_annotations = {}

    # Get user-annotated ignored attributes.
    user_annotated_ignored_attributes = getattr(
        nn_module, "__jit_ignored_attributes__", []
    )
    concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes)
    ignored_properties = jit_ignored_properties(nn_module)

    # try to infer the type from type annotation or from the object itself
    def infer_type(name, item):
        # The forward function from Module is special; never use its
        # annotation; we need to infer the type directly using the JIT. I
        # originally wanted to write this test as
        # isinstance(class_annotations[name], Callable), but isinstance on
        # typing things doesn't seem to work: isinstance(list, Callable) is
        # also true!
        inferred = False
        try:
            if (
                name in class_annotations
                and class_annotations[name]
                != torch.nn.Module.__annotations__["forward"]
            ):
                ann_to_type = torch.jit.annotations.ann_to_type(
                    class_annotations[name], fake_range()
                )
                attr_type = torch._C.InferredType(ann_to_type)
            elif isinstance(item, torch.jit.Attribute):
                ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range())
                attr_type = torch._C.InferredType(ann_to_type)
            else:
                attr_type = torch._C._jit_try_infer_type(item)
                inferred = True
        except RuntimeError as re:
            raise RuntimeError(f"Error inferring type for {name}: {item}: {re}") from re

        return attr_type, inferred

    added_names = set()

    for name, item in nn_module._parameters.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        # We currently have the invariant in various places in our code
        # that parameters must be Tensors. However, the nn.Module API also
        # allows NoneType parameters. These parameters are not returned as
        # part of `parameters()` and its variants, but are available
        # through direct attribute access.
        concrete_type_builder.add_attribute(name, attr_type.type(), True, False)
        added_names.add(name)

    for name, item in nn_module._buffers.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        concrete_type_builder.add_attribute(name, attr_type.type(), False, True)
        added_names.add(name)

    for name, item in nn_module._modules.items():
        if name in user_annotated_ignored_attributes:
            continue

        attr_type, _ = infer_type(name, item)
        if item is None:
            # Modules can be None. We don't have direct support for optional
            # Modules, so we register it as a NoneType attribute instead.
            concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
            continue
        if attr_type.success():
            assert attr_type.type().is_interface_type()
            # if the type can be inferred, it should be a module interface type
            sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(
                attr_type.type()
            )
        else:
            # otherwise we get the concrete module type for item and add it to concrete_type
            sub_concrete_type = get_module_concrete_type(item, share_types)
        concrete_type_builder.add_module(name, sub_concrete_type)

        added_names.add(name)

    # populate constants_set
    constants_set = set(getattr(nn_module, "__constants__", ()))

    # Constants annotated via `Final[T]` rather than being added to `__constants__`
    for name, ann in class_annotations.items():
        if torch._jit_internal.is_final(ann):
            constants_set.add(name)

    for name in constants_set:
        if name in added_names:
            # TODO: We should really error in this case, but it's bc-breaking so
            # we need to warn for at least one release
            if name in nn_module._modules:
                hint = "submodule"
            elif name in nn_module._buffers:
                hint = "buffer"
            elif name in nn_module._parameters:
                hint = "parameter"
            else:
                raise AssertionError(
                    "added_names must be submodule, parameter, or buffer"
                )

            warnings.warn(
                f"'{name}' was found in ScriptModule constants, "
                f"but it is a non-constant {hint}. Consider removing it."
            )
            continue
        if not hasattr(nn_module, name):
            # TODO: We should really error in this case, but it's bc-breaking so
            # we need to warn for at least one release
            warnings.warn(
                f"'{name}' was found in ScriptModule constants, "
                "but was not actually set in __init__. "
                "Consider removing it."
            )
            continue
        value = getattr(nn_module, name)
        concrete_type_builder.add_constant(
            name, _get_valid_constant(name, value, type(nn_module).__name__)
        )
        added_names.add(name)

    # populate overloads
    overloads = getattr(nn_module, "__overloads__", {})
    # update with any annotated overloads
    overloads.update(
        get_overload_name_mapping(
            get_overload_annotations(nn_module, ignored_properties)
        )
    )
    for name, overloaded_names in overloads.items():
        concrete_type_builder.add_overload(name, overloaded_names)

    for name, value in nn_module.__dict__.items():
        if name in ignored_attributes or name.startswith("__"):
            # Python objects have lots of random attributes attached to them;
            # PyTorch adds a few more. Prevent these from getting compiled.
            continue

        if name in user_annotated_ignored_attributes:
            continue

        if name in added_names:
            # Don't re-add anything we already added
            continue

        isoverloadpacket = isinstance(value, torch._ops.OpOverloadPacket)
        if isoverloadpacket:
            value = value.op
        # Handle Python function attributes
        if inspect.isfunction(value):
            try:
                scripted_fn = torch.jit.script(value)
                concrete_type_builder.add_function_attribute(
                    name, torch._C._jit_try_infer_type(scripted_fn).type(), value
                )
            except Exception as e:
                # If we fail to script the function, it isn't a hard error.
                # Instead, we will add it to the list of attributes we failed
                # to convert, with the compilation error.
                hint = (
                    "(This function exists as an attribute on the Python module, "
                    "but we failed to compile it to a TorchScript function. "
                    f"\nThe error stack is reproduced here:\n{e}"
                )
                concrete_type_builder.add_failed_attribute(name, hint)

            continue

        # Handle calls to builtin functions (either bespoke builtins from torch.jit._builtins or
        # a call to an aten function like torch.add)
        builtin_symbol_name = _find_builtin(value)
        if builtin_symbol_name:
            concrete_type_builder.add_builtin_function(name, builtin_symbol_name)
            continue

        # Handle Script function attributes
        if isinstance(value, torch.jit.ScriptFunction):
            concrete_type_builder.add_function_attribute(
                name, torch._C._jit_try_infer_type(value).type(), value
            )
            continue

        # If we got here, this is a regular "data" attribute, add it to the concrete type
        attr_type, inferred = infer_type(name, value)
        if attr_type.success():
            concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
        else:
            # TODO: could add more detail here. For example, what the user should do
            # when the pytype is `list` or `NoneType`
            inferred_msg = (
                "Its type was inferred; try adding a type annotation for the attribute."
                if inferred
                else ""
            )
            additional_info = f"{attr_type.reason()}. {inferred_msg}"
            hint = (
                "(This attribute exists on the Python module, "
                f"but we failed to convert Python type: '{torch.typename(type(value))}' "
                f"to a TorchScript type. {additional_info})"
            )
            concrete_type_builder.add_failed_attribute(name, hint)

    # add hooks to concrete type
    for hook in nn_module._forward_hooks.values():
        concrete_type_builder.add_forward_hook(hook)
    for pre_hook in nn_module._forward_pre_hooks.values():
        concrete_type_builder.add_forward_pre_hook(pre_hook)

    return concrete_type_builder
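
# An illustrative sketch (hypothetical class `M`) of the three paths taken by
# `infer_type` above: a class annotation, an explicit `torch.jit.Attribute`,
# and inference from the value itself.
#
#     from typing import List
#
#     class M(torch.nn.Module):
#         counts: List[int]  # path 1: class annotation
#
#         def __init__(self) -> None:
#             super().__init__()
#             self.counts = []
#             self.tag = torch.jit.Attribute("net", str)  # path 2: explicit type
#             self.ratio = 0.5  # path 3: type inferred from the value (float)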


class ConcreteTypeStore:
    type_store: Dict[Type[Module], List[torch._C.ConcreteModuleType]]
    methods_compiled: Set[torch._C.ConcreteModuleType]

    def __init__(self) -> None:
        # Python module type => List[ConcreteModuleType]
        self.type_store = {}
        # ConcreteTypes that have had their methods already compiled
        self.methods_compiled = set()

    def get_or_create_concrete_type(self, nn_module):
        """Infer a ConcreteType from this `nn.Module` instance. Underlying JIT types are re-used if possible."""
        concrete_type_builder = infer_concrete_type_builder(nn_module)

        nn_module_type = type(nn_module)
        if nn_module_type not in self.type_store:
            self.type_store[nn_module_type] = []

        # Search the type store for an already-available JIT type
        known_types = self.type_store[nn_module_type]
        for known_type in known_types:
            if known_type.equals(concrete_type_builder):
                return known_type

        # We didn't find anything; generate a new JIT type from this concrete type
        concrete_type = concrete_type_builder.build()
        self.type_store[nn_module_type].append(concrete_type)
        return concrete_type


concrete_type_store = ConcreteTypeStore()
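
# An illustrative sketch of the type sharing this store enables: two instances
# of the same Module class with matching attribute types should end up with the
# same underlying JIT type (`MyModule` is hypothetical).
#
#     a = torch.jit.script(MyModule())
#     b = torch.jit.script(MyModule())
#     assert a._c._type() == b._c._type()  # one ConcreteType, one JIT type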


def create_methods_and_properties_from_stubs(
    concrete_type, method_stubs, property_stubs
):
    method_defs = [m.def_ for m in method_stubs]
    method_rcbs = [m.resolution_callback for m in method_stubs]
    method_defaults = [get_default_args(m.original_method) for m in method_stubs]

    property_defs = [p.def_ for p in property_stubs]
    property_rcbs = [p.resolution_callback for p in property_stubs]

    concrete_type._create_methods_and_properties(
        property_defs, property_rcbs, method_defs, method_rcbs, method_defaults
    )


def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs):
    hook_defs = [h.def_ for h in hook_stubs]
    hook_rcbs = [h.resolution_callback for h in hook_stubs]

    pre_hook_defs = [h.def_ for h in pre_hook_stubs]
    pre_hook_rcbs = [h.resolution_callback for h in pre_hook_stubs]

    concrete_type._create_hooks(hook_defs, hook_rcbs, pre_hook_defs, pre_hook_rcbs)


def get_module_concrete_type(nn_module, share_types=True):
    """
    Get a concrete type for the given nn_module.

    If share_types is True, the concrete type is fetched from concrete_type_store.
    If it is False, a new concrete type is created without first searching concrete_type_store.

    Args:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        share_types:  Whether to share underlying JIT types between modules (if possible).

    Returns:
        A concrete type for nn_module.
    """
    assert isinstance(nn_module, Module)
    if isinstance(nn_module, torch.jit.ScriptModule) and hasattr(
        nn_module, "_concrete_type"
    ):
        return nn_module._concrete_type

    if share_types:
        # Look into the store of cached JIT types
        concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
    else:
        # Get a concrete type directly, without trying to re-use an existing JIT
        # type from the type store.
        concrete_type_builder = infer_concrete_type_builder(nn_module, share_types)
        concrete_type_builder.set_poisoned()
        concrete_type = concrete_type_builder.build()

    return concrete_type


def create_script_class(obj):
    """
    Create and return a RecursiveScriptClass instance from a Python object.

    Arguments:
        obj: A Python object.
    """
    qualified_class_name = _jit_internal._qualified_name(type(obj))
    rcb = _jit_internal.createResolutionCallbackForClassMethods(type(obj))
    # Script the type of obj if it hasn't already been scripted.
    _compile_and_register_class(type(obj), rcb, qualified_class_name)
    class_ty = _python_cu.get_class(qualified_class_name)
    # Create an empty torch._C.ScriptObject with the scripted type.
    cpp_object = torch._C._create_object_with_type(class_ty)
    # Copy all of the attributes over to the torch._C.ScriptObject.
    for name, value in obj.__dict__.items():
        cpp_object.setattr(name, value)

    # Wrap the torch._C.ScriptObject in a RecursiveScriptClass instance.
    return wrap_cpp_class(cpp_object)
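
# An illustrative sketch (hypothetical class `Point`): `torch.jit.script` on a
# plain class instance routes through `create_script_class` above.
#
#     class Point:
#         def __init__(self, x: float, y: float) -> None:
#             self.x = x
#             self.y = y
#
#     p = torch.jit.script(Point(1.0, 2.0))  # a RecursiveScriptClass
#     p.x  # attributes were copied onto the underlying ScriptObject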


def create_script_module(nn_module, stubs_fn, share_types=True, is_tracing=False):
    """
    Create a new ScriptModule from an nn.Module.

    Args:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        stubs_fn:  Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
        share_types:  Whether to share underlying JIT types between modules (if possible).
            NOTE: Only set this to False when we cannot guarantee type sharing will work
                correctly. This only happens today for traced modules, where the same
                module can produce different traced methods depending on the inputs.
        is_tracing: Whether this function is called during tracing or scripting. If tracing,
                we don't need to run the AttributeTypeIsSupportedChecker because all the
                unsupported attributes will be baked in as constants in the traced graph.
                In addition, this check significantly slows down tracing of large modules.
    """
    assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
    check_module_initialized(nn_module)
    concrete_type = get_module_concrete_type(nn_module, share_types)
    if not is_tracing:
        AttributeTypeIsSupportedChecker().check(nn_module)
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)


def create_script_module_impl(nn_module, concrete_type, stubs_fn):
    """
    Convert an nn.Module to a RecursiveScriptModule.

    Args:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        concrete_type:  The fully initialized ConcreteType of the module.
        stubs_fn:  Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
    """
    cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
    method_stubs = stubs_fn(nn_module)
    property_stubs = get_property_stubs(nn_module)
    hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)

    user_annotated_ignored_attributes = getattr(
        nn_module, "__jit_ignored_attributes__", []
    )
    ignored_properties = jit_ignored_properties(nn_module)

    def init_fn(script_module):
        # Initialize the ScriptModule:
        # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule.
        for name in concrete_type.get_attributes().keys():
            orig_value = getattr(nn_module, name)
            orig_value = (
                orig_value.value
                if isinstance(orig_value, torch.jit.Attribute)
                else orig_value
            )
            cpp_module.setattr(name, orig_value)

        # 2. Copy the submodules from the original `nn_module` to the new ScriptModule,
        #    recursively scripting them.
        for name, sub_concrete_type in concrete_type.get_modules():
            orig_value = getattr(nn_module, name)
            assert isinstance(
                orig_value, Module
            ), f"Expected Module but got {type(orig_value)}"
            module_type = sub_concrete_type.jit_type
            if isinstance(module_type, torch._C.InterfaceType):
                # use the interface inference rule to compile the module
                scripted = interface_script(module_type, orig_value)
            elif isinstance(orig_value, torch.jit.ScriptModule):
                scripted = orig_value
            else:
                # always reuse the provided stubs_fn to infer the methods to compile
                scripted = create_script_module_impl(
                    orig_value, sub_concrete_type, stubs_fn
                )

            cpp_module.setattr(name, scripted)
            script_module._modules[name] = scripted

        # 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule.
        #    This ensures we can access these Python methods on the ScriptModule.
        for name in dir(nn_module):
            if name in ignored_properties:
                continue
            item = getattr(nn_module, name, None)
            if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
                unbound_function = getattr(nn_module, name).__func__
                bound_method = unbound_function.__get__(script_module)
                setattr(script_module, name, bound_method)
            elif concrete_type.is_ignored_attribute(name):
                setattr(script_module, name, item)

        # For convenience, attach the concrete type to the new ScriptModule
        script_module._concrete_type = concrete_type

    # Actually create the ScriptModule, initializing it with the function we just defined
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)

    # Compile methods if necessary
    if concrete_type not in concrete_type_store.methods_compiled:
        create_methods_and_properties_from_stubs(
            concrete_type, method_stubs, property_stubs
        )
        # Create hooks after methods to ensure no name collisions between hooks and methods.
        # If done before, hooks can overshadow methods that aren't exported.
        create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)
        torch._C._run_emit_module_hook(cpp_module)
        concrete_type_store.methods_compiled.add(concrete_type)

    # Copy the forward hooks and pre-hooks to the new ScriptModule
    # to allow the hooks to be run from eager as ScriptFunctions
    for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
        script_module._forward_pre_hooks[idx] = fn
    for idx, fn in enumerate(script_module._c._get_forward_hooks()):
        script_module._forward_hooks[idx] = fn

    # Special handling so methods like __len__ work in script methods on classes derived from containers
    if (
        isinstance(
            nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
        )
        and "__len__" not in cpp_module._method_names()
    ):
        script_module.define(f"def __len__(self):\n   return {len(nn_module)}\n")
    if (
        isinstance(nn_module, torch.nn.ModuleDict)
        and "__contains__" not in cpp_module._method_names()
    ):
        if len(nn_module.keys()):
            keys = repr(list(nn_module.keys()))
            script_module.define(
                f"def __contains__(self, key: str):\n   return key in {keys}\n"
            )
        else:
            script_module.define("def __contains__(self, key: str):\n   return False\n")

    # Make the compiled methods available to the Python ScriptModule class.
    for method_stub in method_stubs:
        if method_stub.original_method is None:
            # define()'d methods don't have a Python original_method, so we
            # don't need to do any Python re-wrapping
            continue

        name = method_stub.original_method.__name__
        if name != method_stub.def_.name().name:
            # TODO: Why skip this? Because @torch.jit._overload_method will
            # mangle the name of the function.
            continue
        script_method = cpp_module._get_method(name)

        # Wrap the original to propagate docstrings and such.
        # TODO: we don't currently do this for functions that are recursively
        # compiled, but we should.
        wrapped_script_method = functools.wraps(method_stub.original_method)(
            script_method
        )

        # Add the methods to the script_module directly. This ensures they will
        # be found first when `name` is looked up (as opposed to the stubs or
        # nn.Module.forward)
        script_module.__dict__[name] = wrapped_script_method

    # Make module properties available on the Python ScriptModule class.
    for property_stub in property_stubs:
        property_name = property_stub.def_.name().name
        fget = cpp_module._get_method(property_stub.def_.getter_name().name)
        # Setter is optional, so it may not exist.
        setter_name = property_stub.def_.setter_name()
        fset = cpp_module._get_method(setter_name.name) if setter_name else None
        script_module.__dict__[property_name] = property(property_name, fget, fset)  # type: ignore[arg-type]

    # copy over python methods to script module if they aren't defined on the script module
    # this is currently an internal api used only on module containers
    for name in dir(nn_module):
        if name in ignored_properties:
            continue
        item = getattr(nn_module, name, None)
        if (
            _jit_internal.get_torchscript_modifier(item)
            is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
        ):
            add_python_attr_to_scripted_model(script_module, nn_module, name)

    return script_module
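
# An illustrative sketch (hypothetical class `Stack`) of why the container
# methods above are synthesized: scripted code can then iterate over and
# query `ModuleList`/`Sequential`/`ModuleDict` submodules.
#
#     class Stack(torch.nn.Module):
#         def __init__(self) -> None:
#             super().__init__()
#             self.layers = torch.nn.ModuleList([torch.nn.ReLU(), torch.nn.Tanh()])
#
#         def forward(self, x):
#             for layer in self.layers:  # relies on the container's compiled methods
#                 x = layer(x)
#             return x
#
#     scripted = torch.jit.script(Stack())
#     scripted(torch.randn(3))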


# We define shims of certain attributes on the RecursiveScriptModule to support
# magic methods. To check if a script model defines an attribute we need
# to also check that the attribute is not the shim
def script_model_defines_attr(script_model, attr):
    script_attr = getattr(script_model, attr, None)
    if script_attr is None:
        return False
    default_attr = getattr(torch.jit.RecursiveScriptModule, attr, None)
    if default_attr is None:
        return False
    return script_attr != default_attr


def add_python_attr_to_scripted_model(script_model, orig, attr):
    if hasattr(orig, attr) and script_model_defines_attr(script_model, attr):
        setattr(script_model, attr, getattr(orig, attr))


def get_overload_annotations(mod, jit_ignored_properties):
    # original function => [(mangled overload name, overload function)]
    overloads = {}

    for name in dir(type(mod)):
        if name in jit_ignored_properties:
            continue
        item = getattr(mod, name, None)
        if not callable(item):
            continue

        # builtin functions like repr() in python 2 do not have __module__ defined
        if hasattr(item, "__module__") and item.__module__ is not None:
            method_overloads = _jit_internal._get_overloaded_methods(
                item, mod.__class__
            )
            if method_overloads is None:
                continue

            if item.__func__ in method_overloads:
                raise RuntimeError(
                    _jit_internal.get_overload_no_implementation_error_message(
                        "method", item.__func__
                    )
                )

            names = [name + "__" + str(i) for i in range(len(method_overloads))]
            overloads[item] = list(zip(names, method_overloads))

    return overloads


def get_overload_name_mapping(overload_info):
    # Same format as __overloads__
    # original function => [overload names]
    overload_name_mappings: Dict[str, List[str]] = {}
    for orig_fn, overloads in overload_info.items():
        original_name = orig_fn.__name__
        if original_name not in overload_name_mappings:
            overload_name_mappings[original_name] = []

        for overload_name, _ in overloads:
            overload_name_mappings[original_name].append(overload_name)
    return overload_name_mappings


def _check_no_signature(func):
    signature = torch.jit.annotations.get_signature(
        func, None, fake_range(), inspect.ismethod(func)
    )
    if signature is None:
        qual_name = _jit_internal._qualified_name(func)
        raise RuntimeError(
            f"Must explicitly add type annotations to overloaded functions: {qual_name}"
        )


def make_stubs_for_overloads(overload_info):
    overload_stubs = []
    for orig_fn, overloads in overload_info.items():
        orig_ast = get_jit_def(
            orig_fn, orig_fn.__name__, self_name="RecursiveScriptModule"
        )
        for overload_name, overload_fn in overloads:
            _check_no_signature(overload_fn)
            over_ast = get_jit_def(
                overload_fn, overload_fn.__name__, self_name="RecursiveScriptModule"
            )
            new_ast = torch._C._replace_overloaded_method_decl(
                over_ast.decl(), orig_ast, overload_name
            )
            _rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn)
            overload_stubs.append(ScriptMethodStub(_rcb, new_ast, overload_fn))
    return overload_stubs
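
# An illustrative sketch of the overload pattern these stubs implement (class
# `M` is hypothetical; the decorator name is confirmed by the comment about
# `@torch.jit._overload_method` above). The annotated declarations carry the
# signatures; the final unannotated definition provides the body.
#
#     class M(torch.nn.Module):
#         @torch.jit._overload_method  # noqa: F811
#         def scale(self, x: int) -> int:
#             ...
#
#         @torch.jit._overload_method  # noqa: F811
#         def scale(self, x: float) -> float:
#             ...
#
#         def scale(self, x):  # noqa: F811
#             return x * 2
#
#         def forward(self, x):
#             return self.scale(2) + self.scale(0.5)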


def check_module_initialized(mod):
    assert isinstance(mod, torch.nn.Module)
    if not hasattr(mod, "_parameters"):
        raise RuntimeError(
            f"'{torch.typename(type(mod))}' has not been initialized, did you forget to call 'super()'?"
        )

    # This is to avoid importing torch.distributed.nn
    if not hasattr(mod, "remote_parameters"):
        for name, param in mod._parameters.items():
            if param is not None and torch.nn.parameter.is_lazy(param):
                raise RuntimeError(
                    f"'{torch.typename(type(mod))}' has uninitialized parameters {name}. Did you forget to run a forward pass?"
                )
        for name, buf in mod._buffers.items():
            if buf is not None and torch.nn.parameter.is_lazy(buf):
                raise RuntimeError(
                    f"'{torch.typename(type(mod))}' has uninitialized buffers {name}. Did you forget to run a forward pass?"
                )
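
# An illustrative sketch (hypothetical class `Broken`) of the first error above:
#
#     class Broken(torch.nn.Module):
#         def __init__(self) -> None:
#             pass  # forgot super().__init__(), so `_parameters` is never set
#
#         def forward(self, x):
#             return x
#
#     torch.jit.script(Broken())  # RuntimeError: ... has not been initialized ...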


def infer_methods_to_compile(nn_module):
    """Implement the default rules for which methods should act as starting points for compilation.

    (TODO add a link when the rules are published).
    """
    check_module_initialized(nn_module)
    user_annotated_ignored_attributes = getattr(
        nn_module, "__jit_ignored_attributes__", []
    )
    ignored_properties = jit_ignored_properties(nn_module)

    methods: List[str] = []
    if hasattr(nn_module, "forward") and not _jit_internal.is_ignored_fn(
        nn_module.forward
    ):
        forward_func = getattr(nn_module.forward, "__func__", None)
        module_forward = getattr(torch.nn.Module, "forward", None)
        if forward_func != module_forward:
            methods = ["forward"]

    exported = []
    for name in dir(nn_module):
        if name in ignored_properties:
            continue
        item = getattr(nn_module, name, None)
        if (
            _jit_internal.get_torchscript_modifier(item)
            is _jit_internal.FunctionModifiers.EXPORT
        ):
            exported.append(name)

    methods = methods + exported

    overload_name_mappings = dict(getattr(nn_module, "__overloads__", {}))
    overload_info = get_overload_annotations(nn_module, ignored_properties)
    overload_name_mappings.update(get_overload_name_mapping(overload_info))
    overload_stubs = make_stubs_for_overloads(overload_info)

    nn_module.__overloads__ = overload_name_mappings

    # we shouldn't directly compile overloaded methods, just their overloads
    def ignore_overloaded(method_name):
        return method_name not in overload_name_mappings

    filtered_methods = filter(ignore_overloaded, methods)

    # Deduplicate the methods. We don't want to use a set to store them because
    # it would introduce non-determinism in the compilation order.
    uniquer: Set[str] = set()
    uniqued_methods = []
    for name in filtered_methods:
        if name in uniquer:
            continue
        uniqued_methods.append(name)
        uniquer.add(name)

    stubs = []
    for method in uniqued_methods:
        stubs.append(make_stub_from_method(nn_module, method))
    return overload_stubs + stubs
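
# An illustrative sketch (hypothetical class `M`) of the default rule above:
# `forward` and every `@torch.jit.export` method are compilation roots, and
# anything they call gets compiled recursively.
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return self.helper(x) + 1
#
#         @torch.jit.export
#         def shift(self, x, amount: int):
#             return x + amount
#
#         def helper(self, x):  # compiled because `forward` calls it
#             return x * 2
#
#     scripted = torch.jit.script(M())  # roots: `forward` and `shift`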


def get_hook_stubs(nn_module):
    """Return forward hook and pre_hook ScriptMethodStubs."""
    check_module_initialized(nn_module)
    hook_map: Dict = {}

    hook_stubs = []
    for hook in nn_module._forward_hooks.values():
        if hook.__name__ in hook_map:
            if id(hook) != id(hook_map[hook.__name__]):
                raise RuntimeError(
                    f"Hook '{hook.__name__}' on {type(nn_module).__name__} "
                    "has at least two different python definitions."
                    " Please use unique names for all hooks."
                )
        else:
            hook_map[hook.__name__] = hook
        hook_stubs.append(make_stub(hook, hook.__name__))

    pre_hook_stubs = []
    for pre_hook in nn_module._forward_pre_hooks.values():
        if pre_hook.__name__ in hook_map:
            if id(pre_hook) != id(hook_map[pre_hook.__name__]):
                raise RuntimeError(
                    f"Pre-hook '{pre_hook.__name__}' on {type(nn_module).__name__} "
                    "has at least two different python definitions."
                    " Please use unique names for all hooks."
                )
        else:
            hook_map[pre_hook.__name__] = pre_hook
        pre_hook_stubs.append(make_stub(pre_hook, pre_hook.__name__))

    return hook_stubs, pre_hook_stubs
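
# An illustrative sketch (hypothetical hook and module) of a forward hook the
# stubs above can compile; hooks need explicit type annotations, and each hook
# name must map to a single Python definition.
#
#     from typing import Tuple
#
#     def log_hook(module, inputs: Tuple[torch.Tensor], output: torch.Tensor):
#         print(output.shape)
#
#     m = MyModule()  # hypothetical module whose forward takes one Tensor
#     m.register_forward_hook(log_hook)
#     scripted = torch.jit.script(m)  # the hook is compiled along with forward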


def get_property_stubs(nn_module):
    """Create property stubs for the properties of the module by creating method stubs for the getter and setter."""
    module_ty = type(nn_module)
    properties_asts = get_class_properties(module_ty, self_name="RecursiveScriptModule")
    rcbs = {}

    for name in dir(module_ty):
        item = getattr(module_ty, name, None)
        if isinstance(item, property):
            if not item.fget:
                raise RuntimeError(
                    f"Property {name} of {module_ty.__name__} must have a getter"
                )

            rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(item.fget)

    stubs = [PropertyStub(rcbs[ast.name().name], ast) for ast in properties_asts]
    return stubs
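
# An illustrative sketch (hypothetical class `M`) of a property handled by the
# stubs above; the getter is required, the setter is optional.
#
#     class M(torch.nn.Module):
#         def __init__(self) -> None:
#             super().__init__()
#             self._scale = 1.0
#
#         @property
#         def scale(self) -> float:
#             return self._scale
#
#         @scale.setter
#         def scale(self, value: float) -> None:
#             self._scale = value
#
#         def forward(self, x):
#             return x * self.scale
#
#     scripted = torch.jit.script(M())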


def interface_script(mod_interface, nn_module):
    """
    Make a ScriptModule from an nn.Module, using the interface methods rule for determining which methods to compile.

    Args:
        mod_interface: the interface type that the module has
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
    """
    if isinstance(nn_module, torch.jit.ScriptModule):
        return nn_module

    check_module_initialized(nn_module)

    def infer_interface_methods_to_compile(nn_module):
        """Rule to infer the methods from the interface type.

        It is used to know which methods need to act as starting points for compilation.
        """
        stubs = []
        for method in mod_interface.getMethodNames():
            stubs.append(make_stub_from_method(nn_module, method))
        return stubs

    return create_script_module(nn_module, infer_interface_methods_to_compile)


def try_compile_fn(fn, loc):
    if _jit_internal.is_ignored_fn(fn):
        # Don't do anything for @ignore'd functions
        return None

    if isinstance(fn, torch.nn.Module):
        # Since modules are callable, pybind recognizes them as functions;
        # don't do anything for them either
        return None

    if not inspect.isfunction(fn) and not inspect.ismethod(fn):
        raise RuntimeError(
            f"`{fn}` is not a function. Recursive scripting only supports "
            "Python functions or methods currently.\n"
            f"Consider manually annotating `{fn}` with @torch.jit.script."
        )

    # The object returned by __prepare_scriptable__ might have a different closure.
    # Resolve it here to get the right resolution callback.
    fn = fn.__prepare_scriptable__() if hasattr(fn, "__prepare_scriptable__") else fn  # type: ignore[operator]

    # We don't have the actual scope where the function was defined, but we can
    # extract the necessary info from the closed over variables on the function
    # object
    rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
    return torch.jit.script(fn, _rcb=rcb)
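
# An illustrative sketch (hypothetical functions) of the first branch above:
# `@torch.jit.ignore` keeps a called function out of compilation, so
# `try_compile_fn` returns None for it and the call stays in Python.
#
#     @torch.jit.ignore
#     def debug_log(x):
#         print("eager-only:", x.shape)
#
#     def double(x):
#         return x * 2  # a plain function: compiled recursively on first use
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             debug_log(x)      # left as a Python call
#             return double(x)  # scripted via try_compile_fn
#
#     scripted = torch.jit.script(M())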


def wrap_cpp_class(cpp_class):
    """Wrap this torch._C.Object in a Python RecursiveScriptClass."""
    return torch.jit.RecursiveScriptClass(cpp_class)


def wrap_cpp_module(cpp_module):
    """Wrap this torch._C.ScriptModule in a Python ScriptModule, recursively for all submodules."""

    def init_fn(script_module):
        for name, cpp_module in torch._C.ModuleDict(script_module._c).items():
            setattr(script_module, name, wrap_cpp_module(cpp_module))
        script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(
            script_module._c._type()
        )

        for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
            script_module._forward_pre_hooks[idx] = fn
        for idx, fn in enumerate(script_module._c._get_forward_hooks()):
            script_module._forward_hooks[idx] = fn

    return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)


def compile_unbound_method(concrete_type, fn):
    if _jit_internal.is_ignored_fn(fn):
        return None
    stub = make_stub(fn, fn.__name__)
    with torch._jit_internal._disable_emit_hooks():
        # We don't want to call the hooks here since the graph that is calling
        # this function is not yet complete
        create_methods_and_properties_from_stubs(concrete_type, (stub,), ())
    return stub


def lazy_bind(concrete_type, unbound_method):
    """
    Return a function that lazily binds `unbound_method` to a provided Module IValue, then invokes the method.

    We do this so that any Python shenanigans that will poison type sharing
    are impossible at compile time.
    """

    def lazy_binding_method(cpp_module, *args):
        def init_fn(script_module):
            orig_class = concrete_type.py_class

            # Copy @ignored/@unused methods from the original module to the new one.
            # This ensures they are available during execution.
            for name in dir(orig_class):
                item = getattr(orig_class, name, None)
                if _jit_internal.is_ignored_fn(item):
                    setattr(script_module, name, item)

            # Copy constants over so they are available during execution.
            for name, value in concrete_type.get_constants().items():
                setattr(script_module, name, value)

        script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
        method = types.MethodType(unbound_method, script_module)
        return method(*args)

    # make the lazy binding method "look like" the original method
    lazy_binding_method.original_fn = unbound_method  # type: ignore[attr-defined]
    lazy_binding_method.__name__ = unbound_method.__name__
    torch._jit_internal.copy_torchscript_modifier(unbound_method, lazy_binding_method)

    return lazy_binding_method
1072