# xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/nn_module.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: ignore-errors

import functools
import inspect
import itertools
import types
from contextlib import contextmanager, nullcontext
from typing import Dict, List, TYPE_CHECKING

import torch.nn

from .. import trace_rules, variables
from ..exc import (
    raise_observed_exception,
    unimplemented,
    UnspecializeRestartAnalysis,
    Unsupported,
)
from ..guards import GuardBuilder, install_guard
from ..mutation_guard import GenerationTracker
from ..source import (
    AttrSource,
    ConstDictKeySource,
    FSDPNNModuleSource,
    GetItemSource,
    NNModuleSource,
    UnspecializedBuiltinNNModuleSource,
    UnspecializedNNModuleSource,
)
from ..utils import (
    get_custom_getattr,
    get_fake_value,
    is_lazy_module,
    is_namedtuple,
    is_safe_constant,
    istensor,
    istype,
    nnmodule_has_hooks,
    object_has_getattribute,
    proxy_args_kwargs,
    set_example_value,
)
from .base import MutableLocal, typestr, VariableTracker
from .functions import invoke_and_store_as_constant
from .lazy import LazyVariableTracker
from .lists import SliceVariable
from .user_defined import UserDefinedObjectVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs):
    """
    Fairly coupled helper used by NNModuleVariable and UnspecializedNNModuleVariable.

    Used to cause a lazy module to be initialized (and its init hook deleted) before tracing. Especially
    useful now that 'allowed' modules graph-break on hooks: calling this first ensures there is no hook
    by the time we trace __call__ and thus no graph-break for lazy allowed modules.
    """
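    # Illustrative note: a lazy module such as torch.nn.LazyLinear(4) keeps an
    # `_initialize_hook` until its first call; `_infer_parameters` below materializes
    # its parameters from the (fake) inputs, after which it behaves like a plain
    # nn.Linear (see also the `cls_to_become` handling in the callers).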
    if hasattr(mod, "_initialize_hook"):

        def convert_to_fake(x):
            if is_namedtuple(x):
                return type(x)(*(convert_to_fake(elem) for elem in x))
            elif isinstance(x, dict):
                return {k: convert_to_fake(v) for k, v in x.items()}
            elif isinstance(x, (list, tuple, set)):
                return type(x)(convert_to_fake(elem) for elem in x)
            elif isinstance(x, torch.fx.Proxy):
                return get_fake_value(x.node, tx)
            else:
                return x

        proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs)
        fake_args = [convert_to_fake(arg) for arg in proxy_args]
        fake_kwargs = {k: convert_to_fake(v) for k, v in proxy_kwargs.items()}
        mod._infer_parameters(mod, fake_args, fake_kwargs)


@contextmanager
def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module):
    fully_qualified_name = source.name()
    try:
        tx.nn_module_stack[module_key] = (fully_qualified_name, mod.__class__)
        yield
    finally:
        del tx.nn_module_stack[module_key]


def guard_to_detect_forward_monkeypatching(source, mod):
    # Users sometimes patch the forward method of an nn module instance to
    # perform optimizations like quantization. Though this is not good
    # software practice, Python allows it and Dynamo needs to detect this
    # patching.
    #
    # One way to do this is to add an ID_MATCH guard on every function
    # getting inlined (https://github.com/pytorch/pytorch/pull/124975). But
    # this increased guard overhead by around 20%.
    #
    # To keep the guard overhead down, we just guard on `forward` not being
    # present in the mod __dict__. The common case of patching the forward
    # method adds `forward` to the instance __dict__, whereas the unpatched
    # `forward` sits in type(mod).__dict__.
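    # Illustrative example of the patched case handled below (hypothetical helper
    # `quantized_forward`): `mod.forward = quantized_forward.__get__(mod)` puts a
    # bound method into mod.__dict__["forward"], so the guard targets its __func__;
    # unpatched modules only have `forward` on the class, hence the absence guard.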
    if source:
        if "forward" in mod.__dict__ and callable(mod.__dict__["forward"]):
            # Monkeypatched forward method, add an ID_MATCH guard on forward function
            fwd = mod.__dict__["forward"]
            forward_source = AttrSource(source, "forward")
            if type(fwd) is types.MethodType:
                forward_source = AttrSource(forward_source, "__func__")
            install_guard(forward_source.make_guard(GuardBuilder.CLOSURE_MATCH))
        else:
            # Common case - check that the forward key is absent in mod __dict__
            install_guard(
                source.make_guard(
                    functools.partial(
                        GuardBuilder.NOT_PRESENT_IN_GENERIC_DICT, attr="forward"
                    )
                )
            )


class NNModuleVariable(VariableTracker):
    _nonvar_fields = {
        "module_type",
        "module_key",
        "module",
        "nn_module_stack_source",
        *VariableTracker._nonvar_fields,
    }

    def __init__(
        self, module_type: type, module_key: str, module: torch.nn.Module, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.module_type = module_type
        self.module_key = module_key
        self.module = module
        assert self.source
        self.nn_module_stack_source = self.source

    def get_nn_module_stack_source(self):
        return self.nn_module_stack_source or self.source

    def set_nn_module_stack_source(self, source):
        self.nn_module_stack_source = source

    def python_type(self):
        return self.module_type

    def _wrap_submodule(
        self, tx: "InstructionTranslator", source, submod, *key_extra, **options
    ):
        return

    def unpack_var_sequence(self, tx):
        # implement list/iter/tuple/etc calls
        base = tx.output.get_submodule(self.module_key)
        if isinstance(base, torch.nn.ModuleDict):
            result = []
            for name, submod in base.items():
                name_var = variables.ConstantVariable.create(name)
                tx.output.register_attr_or_module(
                    submod,
                    self.module_key,
                    name,
                    source=NNModuleSource(GetItemSource(self.source, name)),
                )
                result.append(name_var)
            return result

        assert isinstance(
            base, (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential)
        ), typestr(base)
        assert self.source
        result = []
        for idx, submod in enumerate(base):
            result.append(
                tx.output.register_attr_or_module(
                    submod,
                    self.module_key,
                    idx,
                    source=NNModuleSource(GetItemSource(self.source, idx)),
                )
            )
        return result

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        mod = tx.output.get_submodule(self.module_key)
        result = hasattr(mod, name)
        install_guard(
            NNModuleSource(AttrSource(self.source, name)).make_guard(
                GuardBuilder.HASATTR
            )
        )
        return variables.ConstantVariable.create(result)

    def is_training(self, tx):
        mod = tx.output.get_submodule(self.module_key)
        return getattr(mod, "training", False)

    def convert_to_unspecialized(self, tx):
        """Restart analysis treating this module as an UnspecializedNNModuleVariable"""
        mod = tx.output.get_submodule(self.module_key)
        GenerationTracker.tag(mod)

        # Mark the class dynamic unless this is module initialization
        if tx.f_code.co_name != "__init__":
            GenerationTracker.mark_class_dynamic(type(mod))
        raise UnspecializeRestartAnalysis

    def has_key_in_generic_dict(self, tx: "InstructionTranslator", key):
        base = tx.output.get_submodule(self.module_key)

        if object_has_getattribute(base):
            unimplemented("NNModuleVariable with custom __getattribute__")

        if tx.output.side_effects.has_pending_mutation_of_attr(self, key):
            mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True)
            return not isinstance(mutated_attr, variables.DeletedVariable)

        base_dict = object.__getattribute__(base, "__dict__")
        return key in base_dict

    def _custom_getattr_fallback(self, base, tx, name, options):
        """Check for a __getattr__ and handle it specially if it is implemented"""
        if object_has_getattribute(base):
            unimplemented("torch.nn.Module with a custom __getattribute__ defined")

        getattr_fn = get_custom_getattr(base, ignore_nn_module_getattr=True)
        if getattr_fn is None:
            return None

        if not isinstance(getattr_fn, types.FunctionType):
            unimplemented("torch.nn.Module with a non-function custom __getattr__")

        return variables.UserMethodVariable(getattr_fn, self, **options).call_function(
            tx, [variables.ConstantVariable.create(name)], {}
        )

    def var_getattr(self, tx: "InstructionTranslator", name):
        from .builder import VariableBuilder

        if self.source:
            source = AttrSource(self.source, name)
        else:
            source = None

        base = tx.output.get_submodule(self.module_key)
        base_dict = object.__getattribute__(base, "__dict__")
        object_member = True
        all_class_attribute_names = set()
        for x in inspect.getmro(base.__class__):
            all_class_attribute_names.update(x.__dict__.keys())

        if not self.source:
            unimplemented("GETATTR with no source")

        if name == "__dict__":
            return variables.GetAttrVariable(self, name, source=source)

        if name in base_dict:
            subobj = base_dict[name]
        elif (
            "_modules" in base_dict
            and name in base_dict["_modules"]
            and name not in all_class_attribute_names
        ):
            subobj = base_dict["_modules"][name]
        elif "_parameters" in base_dict and name in base_dict["_parameters"]:
            subobj = base_dict["_parameters"][name]
        elif "_buffers" in base_dict and name in base_dict["_buffers"]:
            subobj = base_dict["_buffers"][name]
        else:
            try:
                subobj = inspect.getattr_static(base, name)
                object_member = False
            except AttributeError:
                # see if we can fallback to __getattr__, which is not checked by getattr_static
                result = self._custom_getattr_fallback(
                    base=base, tx=tx, name=name, options={"source": source}
                )
                if result is not None:
                    return result
                # if we can't find a __getattr__, just raise the AttributeError
                raise

        if name == "forward":
            guard_to_detect_forward_monkeypatching(self.source, base)

        if name == "__class__" and not object_member:
            return variables.UserDefinedClassVariable(base.__class__, source=source)

        if object_member:
            out = VariableBuilder(tx, NNModuleSource(source))(subobj)

            if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)):
                # nn_module_stack source is BC surface area. Ensure that
                # mod._modules["linear"] is reflected as mod.linear for
                # nn_module_stack.
                out.set_nn_module_stack_source(
                    AttrSource(self.get_nn_module_stack_source(), name)
                )
            return out

        else:
            if istype(subobj, property):
                if self.source:
                    # Read the class attribute to reach the property
                    source = AttrSource(AttrSource(self.source, "__class__"), name)
                    # Get the getter function
                    source = AttrSource(source, "fget")
                return variables.UserFunctionVariable(
                    subobj.fget,
                    source=source,
                ).call_function(tx, [(self)], {})
            elif istype(subobj, classmethod):
                return variables.UserMethodVariable(
                    subobj.__func__,
                    variables.UserDefinedObjectVariable(type(base)),
                    source=source,
                )
            elif istype(subobj, staticmethod):
                return variables.UserFunctionVariable(
                    subobj.__get__(base), source=source
                )
            elif istype(subobj, types.FunctionType):
                return variables.UserMethodVariable(subobj, self, source=source)
            elif is_safe_constant(subobj) or istensor(subobj):
                # Support possibly common cases of class members
                return VariableBuilder(tx, NNModuleSource(source))(subobj)
            else:
                unimplemented(
                    f"class property {name} - {typestr(base)} {typestr(subobj)}"
                )

        return variables.GetAttrVariable(self, name, source=source)

    def call_function(
        self,
        tx,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        mod = tx.output.get_submodule(self.module_key)

        with record_nn_module_stack(
            self.module_key, self.get_nn_module_stack_source(), tx, mod
        ):
            is_lazy = is_lazy_module(mod)
            if (
                isinstance(mod, torch.nn.Sequential)
                and mod.__class__.forward is torch.nn.Sequential.forward
            ):
                if nnmodule_has_hooks(mod):
                    # We do not want to unroll sequential if it has hooks, since evaporating it
                    # will cause hooks to not fire!
                    # This terminates and restarts the tracing process.
                    self.convert_to_unspecialized(tx)

                # Unroll sequential
                assert (
                    not is_lazy
                ), "Expected lazy sequential isn't a valid combination?"
                assert not kwargs
                (arg,) = args
                # TODO: Use named_children when it supports remove_duplicate=False.
                for child_name, submod in mod._modules.items():
                    tx.call_function(
                        tx.output.register_attr_or_module(
                            submod,
                            self.module_key,
                            child_name,
                            source=NNModuleSource(AttrSource(self.source, child_name)),
                        ),
                        [arg],
                        {},
                    )
                    arg = tx.pop()
                return arg

            if is_lazy:
                # The module type will change after it is called
                if mod.cls_to_become is not None:
                    self.module_type = mod.cls_to_become

                # The pre-hook runs to initialize the module shapes, then deletes itself.  After this,
                # the module is more or less not lazy and can be treated as a normal module regardless of
                # is_allowed or other variations.
                initialize_lazy_module(tx, mod, args, kwargs)

            # If we are tracing the higher order op, we want Dynamo to step
            # inside the module call so that Dynamo can see the underlying
            # parameters and buffers and raise them as inputs to the graph.
            #
            # NB: torch.nn.utils.parametrize changes the class type of a
            # parametrized module such that its __module__ points to
            # "torch.nn.utils.parametrize".
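            # Illustrative (hypothetical parametrization class P): after
            # torch.nn.utils.parametrize.register_parametrization(mod, "weight", P()),
            # type(mod).__module__ becomes "torch.nn.utils.parametrize", so the module
            # is excluded by the explicit inequality check below and gets inlined
            # instead of being emitted as a call_module node.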
            if (
                tx.output.is_root_tracer()
                and mod.__module__.startswith(("torch.nn.", "torch.ao."))
                and mod.__module__ != "torch.nn.utils.parametrize"
            ):
                if nnmodule_has_hooks(
                    mod, check_forward_hooks=True, check_backward_hooks=True
                ):
                    # End of fn, this bubbles up and restarts tracing.
                    self.convert_to_unspecialized(tx)

                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_module",
                        self.module_key,
                        *proxy_args_kwargs(args, kwargs),
                    ),
                )
            else:
                assert self.source, (
                    "Must provide a valid source in order to inline, "
                    "since inlined function may have default args which must be guarded."
                )
                if isinstance(mod, torch.fx.GraphModule):
                    # TODO: do we want to support __call__ for GraphModules?
                    # If so, at least some changes are needed; we don't currently allow
                    # inlining call_wrapped, and there may be other issues too.
                    fn = mod.forward
                    fn_source = AttrSource(self.source, "forward")
                else:
                    fn = mod._call_impl
                    fn_source = AttrSource(self.source, "_call_impl")
                if istype(fn, types.MethodType):
                    fn = fn.__func__
                    fn_source = AttrSource(fn_source, "__func__")
                    args = [self] + args
                else:
                    assert istype(fn, types.FunctionType)
                return tx.inline_user_function_return(
                    variables.UserFunctionVariable(fn, source=fn_source),
                    args,
                    kwargs,
                )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
        constant=False,
    ) -> "VariableTracker":
        from . import ConstantVariable, ListIteratorVariable, TupleVariable

        key = self.module_key
        module = tx.output.get_submodule(key)

        def generic_call_method_helper(name):
            # Helper function to put a `call_method` node in FX graph,
            # with nn.Module as the first arg.
            mod_proxy = tx.output.create_proxy(
                "get_attr",
                self.module_key,
                (),
                {},
            )
            set_example_value(mod_proxy.node, module)

            proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs)

            from .builder import wrap_fx_proxy

            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method",
                    name,
                    args=(mod_proxy, *proxy_args),
                    kwargs=proxy_kwargs,
                ),
            )

        if name in ["_call_impl", "_wrapped_call_impl"]:
            # Example: `self.layer.__call__(x)`
            # This is used when `__call__` is called explicitly in a forward function.
            # Dynamo inlines `__call__`, including hooks.
            return self.call_function(tx, args, kwargs)
        elif name == "forward":
            # Example: `self.layer.forward(x)`
            # This is used when `forward` is called explicitly in a forward function.
            # Dynamo puts a `call_method` node in the FX graph; it doesn't trigger hooks.
            with record_nn_module_stack(
                self.module_key, self.get_nn_module_stack_source(), tx, module
            ):
                return generic_call_method_helper(name)

        if name == "_check_input_dim" and trace_rules.is_torch_inline_allowed(
            inspect.getfile(module.__class__._check_input_dim)
        ):
            return ConstantVariable.create(True)

        if name == "_get_item_by_idx":
            assert args[1].is_python_constant()
            assert isinstance(args[0], TupleVariable)
            mod_var = args[0].items[args[1].value]
            if isinstance(mod_var, UnspecializedNNModuleVariable):
                return mod_var
            key = mod_var.module_key
            submod = tx.output.get_submodule(key)
            return tx.output.register_attr_or_module(
                submod,
                key,
                key,
                source=NNModuleSource(GetItemSource(self.source, key)),
            )

        if constant:
            fn = getattr(module, name)
            name = f"{module.__class__.__name__}_{name}_result"
            return invoke_and_store_as_constant(tx, fn, name, args, kwargs)

        def assert_all_args_kwargs_const():
            if not all(
                x.is_python_constant() for x in itertools.chain(args, kwargs.values())
            ):
                unimplemented(f"non-const NNModule method {name}")

        def get_kwargs(*names):
            assert_all_args_kwargs_const()
            fn = getattr(module, name)
            bound_args = inspect.signature(fn).bind(
                *([x.as_python_constant() for x in args]),
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )
            bound_args.apply_defaults()
            bound_args = bound_args.arguments
            return {k: bound_args[k] for k in names}

        def wrap_values(items):
            result = []
            for name, submod in items:
                result.append(
                    tx.output.register_attr_or_module(
                        submod,
                        key,
                        name,
                        source=NNModuleSource(gen_source(self.source, name)),
                    )
                )
            return ListIteratorVariable(result, mutable_local=MutableLocal())

        def named_embed(name, obj):
            return TupleVariable(
                [
                    ConstantVariable.create(name),
                    tx.output.register_attr_or_module(
                        obj,
                        key,
                        name,
                        source=NNModuleSource(gen_source(self.source, name)),
                    ),
                ]
            )

        def gen_source(source, name):
            name_split = name.split(".")
            if name_split[0] == "":
                return source
            while len(name_split) > 0:
                x = name_split.pop(0)
                source = AttrSource(source, x)
            return source
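        # Illustrative: gen_source(src, "block.attn") builds
        # AttrSource(AttrSource(src, "block"), "attn"); an empty name returns src
        # unchanged. This is how dotted names from the named_* iterators below are
        # turned into attribute sources.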

        if name == "named_children":
            tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
            assert not (args or kwargs)
            result = []
            for name, submod in module.named_children():
                result.append(named_embed(name, submod))
            return ListIteratorVariable(result, mutable_local=MutableLocal())
        elif name == "named_parameters":
            tx.output.guard_on_key_order.add(
                AttrSource(self.source, "_parameters").name()
            )
            result = []
            for name, param in module.named_parameters(
                **get_kwargs("prefix", "recurse")
            ):
                result.append(named_embed(name, param))
            return ListIteratorVariable(result, mutable_local=MutableLocal())
        elif name == "named_buffers":
            tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers").name())
            result = []
            for name, buffer in module.named_buffers(
                **get_kwargs("prefix", "recurse", "remove_duplicate")
            ):
                result.append(named_embed(name, buffer))
            return ListIteratorVariable(result, mutable_local=MutableLocal())
        elif name == "named_modules":
            tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
            result = []
            for name, submod in module.named_modules(
                **get_kwargs("memo", "prefix", "remove_duplicate")
            ):
                result.append(named_embed(name, submod))
            return ListIteratorVariable(result, mutable_local=MutableLocal())
        elif name == "children":
            tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
            assert not (args or kwargs)
            return wrap_values(module.named_children())
        elif name == "modules":
            tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
            return wrap_values(module.named_modules())
        elif name == "parameters":
            tx.output.guard_on_key_order.add(
                AttrSource(self.source, "_parameters").name()
            )
            return wrap_values(module.named_parameters(**get_kwargs("recurse")))
        elif name == "buffers":
            tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers").name())
            return wrap_values(module.named_buffers(**get_kwargs("recurse")))
        elif name == "keys":
            assert not (args or kwargs)
            result = []
            for name in module.keys():
                result.append(ConstantVariable.create(name))
            return ListIteratorVariable(result, mutable_local=MutableLocal())
        elif name == "values":
            assert not (args or kwargs)
            return wrap_values(module.items())
        elif name == "items":
            assert not (args or kwargs)
            result = []
            for name, submod in module.items():
                result.append(named_embed(name, submod))
            return ListIteratorVariable(result, mutable_local=MutableLocal())
        elif name == "__len__":
            assert not (args or kwargs)
            return ConstantVariable.create(len(module))
        elif (
            name == "__contains__"
            and isinstance(module, (torch.nn.ModuleDict, torch.nn.ParameterDict))
            and args
            and args[0].is_python_constant()
        ):
            return ConstantVariable.create(
                args[0].as_python_constant() in module._modules
            )
        elif name == "__getitem__":
            assert not kwargs and len(args) == 1
            builtin_supported = (
                torch.nn.ModuleDict.__getitem__,
                torch.nn.ModuleList.__getitem__,
                torch.nn.ParameterDict.__getitem__,
                torch.nn.ParameterList.__getitem__,
                torch.nn.Sequential.__getitem__,
            )

            if type(module).__getitem__ not in builtin_supported:
                assert isinstance(args[0], variables.ConstantVariable), typestr(args[0])
                key = args[0].as_python_constant()
                assert isinstance(key, (str, int))
                fn = getattr(module, name).__func__

                assert isinstance(fn, types.FunctionType)

                src = AttrSource(AttrSource(self.source, name), "__func__")
                return tx.inline_user_function_return(
                    variables.UserFunctionVariable(fn, source=src),
                    [self] + list(args),
                    kwargs,
                )

            assert self.source

            if isinstance(args[0], SliceVariable):
                # TODO(anijain2305,export-team) - Remove this if condition when inlining of inbuilt nn modules is
                # enabled for export.
                if tx.output.export:
                    # Build a TupleVariable of NNModules
                    result = []

                    # Turn the slice into the list of integers
                    keys = list(range(len(module)))[args[0].as_python_constant()]
                    for idx, submod in enumerate(module[args[0].as_python_constant()]):
                        key = keys[idx]
                        src = NNModuleSource(GetItemSource(self.source, key))
                        result.append(
                            tx.output.register_attr_or_module(
                                submod,
                                key,
                                source=src,
                            )
                        )

                    new_module = module[args[0].as_python_constant()]
                    new_module_variable = tx.output.register_attr_or_module(
                        new_module,
                        f"{self}.__getitem__(slice)",
                        source=NNModuleSource(
                            GetItemSource(self.source, args[0].as_python_constant())
                        ),
                    )
                    return new_module_variable
                else:
                    # Slicing an nn module creates a new module instance, so we need to make it sourceless.
                    # Convert to unspecialized so that the UnspecializedNNModule variable can take care of it.
                    self.convert_to_unspecialized(tx)

            from .tensor import SymNodeVariable

            if isinstance(args[0], SymNodeVariable):
                key = args[0].evaluate_expr(tx.output)
            elif args[0].is_python_constant():
                key = args[0].as_python_constant()
            else:
                unimplemented(f"getitem on NNModuleVariable with key {args[0]}")

            submod = module[key]
            return tx.output.register_attr_or_module(
                submod,
                self.module_key,
                key,
                source=NNModuleSource(GetItemSource(self.source, key)),
            )
        elif (
            name == "_get_abs_string_index"
            or (
                isinstance(module, torch.nn.modules.conv._ConvNd)
                and name == "_conv_forward"
            )
            or (
                isinstance(module, torch.nn.modules.conv._ConvTransposeNd)
                and name == "_output_padding"
            )
        ):
            # Inline the function
            fn = getattr(module, name).__func__
            fn_source = AttrSource(AttrSource(self.source, name), "__func__")
            return tx.inline_user_function_return(
                variables.UserFunctionVariable(fn, source=fn_source),
                [self] + args,
                kwargs,
            )
        # A loose heuristic, but seems to be generally good before we drop into the
        # manual handling of inputs
        elif (
            name in module.__class__.__dict__
            and callable(module.__class__.__dict__[name])
            and all(
                isinstance(x, variables.TensorVariable)
                for x in itertools.chain(args, kwargs.values())
            )
        ):
            return generic_call_method_helper(name)
        else:
            return super().call_method(tx, name, args, kwargs)


class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
    _nonvar_fields = {
        "value_type",
        "is_state_mutated",
        "nn_module_stack_source",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    """
    The above class (NNModuleVariable) specializes on the id() of a module and places
    parameters on the torch.fx.GraphModule, giving one graph per module instance.
    This version treats nn.Modules like other user-defined objects and passes
    parameters into the FX graph as inputs, giving one graph per module class.
    """
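    # Illustrative consequence: two instances of the same user-defined MyModule class
    # can share one compiled graph here because their parameters become graph inputs,
    # whereas NNModuleVariable above would yield one graph per instance.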

    def __init__(self, value, **kwargs) -> None:
        if type(value) is torch.jit._script.RecursiveScriptModule:
            raise Unsupported(
                "ScriptModules aren't supported in UnspecializedNNModuleVariable"
                " because their .forward function isn't a static member of their type"
            )
        if "value_type" in kwargs:
            lazy_value_to_become = getattr(kwargs["value_type"], "cls_to_become", None)
            if type(value) is lazy_value_to_become:
                # We may have cloned a VariableTracker for a LazyModule earlier (e.g. tracking side-effects)
                # and then later called and mutated the LazyModule into a MaterializedModule.
                # We do not do the mutation upon first seeing a LazyModule since we preserve eager semantics to only
                # mutate upon first call, but this requires us to update multiple copies of the VariableTracker post-mutation.
                kwargs["value_type"] = type(value)

        super().__init__(value=value, **kwargs)
        self.is_state_mutated = False
        # nn_module_stack_source is used to ensure BC for nn_module_stack.
        # Downstream users prefer mod.linear instead of mod._modules['linear']
        # as the module stack. When Dynamo inlines the __getattr__ method, we
        # cannot use self.source for nn_module_stack because it will be similar
        # to mod._modules['linear']. In these cases, we set the
        # nn_module_stack_source appropriately to resemble mod.linear.
        self.nn_module_stack_source = self.source

    def _wrap_source(self, attr_source):
        if not isinstance(attr_source, UnspecializedNNModuleSource):
            return UnspecializedNNModuleSource(attr_source)
        return attr_source

    def get_nn_module_stack_source(self):
        return self.nn_module_stack_source or self.source

    def set_nn_module_stack_source(self, source):
        self.nn_module_stack_source = source

    @staticmethod
    @functools.lru_cache(None)
    def _nn_module_method_ids():
        # Allow __setattr__ to fall through to base class handler
        supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__}
        return {
            id(x.__code__)
            for x in torch.nn.Module.__dict__.values()
            if hasattr(x, "__code__") and x not in supported
        }
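    # Note: the ids collected above cover methods defined directly on torch.nn.Module
    # (minus the two "supported" ones); call_method below graph-breaks with
    # "UnspecializedNNModuleVariable missing <name>" when it encounters one of them.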

    def unpack_var_sequence(self, tx):
        try:
            fn = inspect.getattr_static(self.value_type, "__iter__")
        except AttributeError as e:
            raise NotImplementedError from e

        if fn in (
            torch.nn.ModuleList.__iter__,
            torch.nn.ParameterList.__iter__,
            torch.nn.Sequential.__iter__,
        ):
            # The program can mutate the nn module object but the saved `value`
            # will not reflect the mutations. So, trace through the `__iter__`
            # function to reflect any tracked mutations.
            return tx.inline_user_function_return(
                variables.UserFunctionVariable(fn),
                [
                    self,
                ],
                {},
            ).unpack_var_sequence(tx)

        return super().unpack_var_sequence(tx)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        mod = self.value
        # see comment on lazy module handling in NNModuleVariable.call_function for context
        if is_lazy_module(mod):
            if mod.cls_to_become is not None:
                self.value_type = mod.cls_to_become
            initialize_lazy_module(tx, mod, args, kwargs)
        name = "_call_impl"
        fn = getattr(self.value_type, name)

        # Check if we can short-circuit nn.Module._call_impl to the forward
        # method.  NB - This is done to reduce the compile time of Dynamo.
        if fn is torch.nn.Module._call_impl and "forward" not in mod.__dict__:
            forward_method = inspect.getattr_static(mod, "forward")
            if isinstance(forward_method, types.FunctionType):
                globals_vt = tx.nn_modules_globals_vt
                if not (
                    self.var_getattr(tx, "_backward_hooks").realize().len()
                    or self.var_getattr(tx, "_backward_pre_hooks").realize().len()
                    or self.var_getattr(tx, "_forward_hooks").realize().len()
                    or self.var_getattr(tx, "_forward_pre_hooks").realize().len()
                    or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len()
                    or globals_vt.var_getattr(tx, "_global_backward_hooks").len()
                    or globals_vt.var_getattr(tx, "_global_forward_hooks").len()
                    or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len()
                ):
                    name = "forward"
                    fn = self.value_type.forward

        if self.source:
            source = AttrSource(AttrSource(self.source, "__class__"), name)
        else:
            source = None

        guard_to_detect_forward_monkeypatching(self.source, mod)

        ctx = (
            record_nn_module_stack(
                str(id(mod)), self.get_nn_module_stack_source(), tx, mod
            )
            if self.source
            else nullcontext()
        )
        with ctx:
            return variables.UserFunctionVariable(fn, source=source).call_function(
                tx, [self] + list(args), kwargs
            )

    def trace_supported_methods(
        self, tx: "InstructionTranslator", method, name, args, kwargs
    ):
        def get_kwargs(*names):
            fn = getattr(self.value, name)
            bound_args = inspect.signature(fn).bind(
                *([x.as_python_constant() for x in args]),
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )
            bound_args.apply_defaults()
            bound_args = bound_args.arguments
            return {k: bound_args[k] for k in names}

        def get_current_parameters(module_var):
            params_dict = module_var.var_getattr(tx, "_parameters").realize().items
            assert isinstance(params_dict, dict)
            params_list = list(params_dict.values())
            params_list = [param.realize() for param in params_list]
            # Account for mod.param = None
            params_list = [
                param
                for param in params_list
                if isinstance(param, variables.TensorVariable)
            ]
            return params_list

        def collect_parameters(module_var, recurse):
            params_list = []
            assert isinstance(module_var, UnspecializedNNModuleVariable)
            params_list = get_current_parameters(module_var)
            modules_dict = module_var.var_getattr(tx, "_modules").realize()
            if recurse:
                for submodule_var in modules_dict.items.values():
                    assert isinstance(submodule_var, UnspecializedNNModuleVariable)
                    params_list.extend(collect_parameters(submodule_var, recurse))
            return params_list

        if method is torch.nn.Module.parameters:
            if self.source:
                tx.output.guard_on_key_order.add(
                    AttrSource(self.source, "_parameters").name()
                )
            recurse = get_kwargs("recurse")["recurse"]
            params_list = collect_parameters(self, recurse=recurse)

            # Account for duplicated params
            deduplicated_params = list(dict.fromkeys(params_list).keys())

            return variables.ListIteratorVariable(
                deduplicated_params, mutable_local=MutableLocal()
            )
        else:
            raise AssertionError(
                "Discrepancy between is_supported_nn_module_method and trace_supported_methods"
            )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name in ["_call_impl", "_wrapped_call_impl"]:
            fn = getattr(self.value_type, name)
            if self.source:
                source = AttrSource(AttrSource(self.source, "__class__"), name)
            else:
                source = None

            return variables.UserFunctionVariable(fn, source=source).call_function(
                tx, [self] + list(args), kwargs
            )

        if name not in getattr(self.value, "__dict__", {}):
            try:
                method = inspect.getattr_static(type(self.value), name)
            except AttributeError:
                method = None

            if self.is_supported_nn_module_method(method):
                return self.trace_supported_methods(tx, method, name, args, kwargs)

            if isinstance(method, staticmethod):
                source = AttrSource(
                    AttrSource(AttrSource(self.source, "__class__"), name), "__func__"
                )
                return tx.inline_user_function_return(
                    variables.UserFunctionVariable(method.__func__, source=source),
                    args,
                    kwargs,
                )

            if (
                hasattr(method, "__code__")
                and id(method.__code__) in self._nn_module_method_ids()
            ):
                unimplemented(f"UnspecializedNNModuleVariable missing {name}")

            # "_parameters" in self.value.__dict__ checks that the module is initialized
            if name == "__setattr__" and "_parameters" in self.value.__dict__:
                # Record if mutations happen on parameters/buffers/modules. The
                # mutations on these are not tracked by the base class
                # UserDefinedObject vt. This will be used later to graph break
                # on seeing parameters() and family calls.
                # TODO(anijain2305) - This might not be needed if we let Dynamo
                # inline both getattr and setattr. In that case, it should see
                # the lowest-level dicts - _parameters and family - and
                # automatically track mutations on those. Investigate if that
                # can be done.
                attr_name = args[0].as_python_constant()
                value = args[1]

                # This is reverse engineered by looking at nn module __setattr__
                # logic.
                if (
                    isinstance(value, variables.TensorVariable)
                    and value.python_type() is torch.nn.Parameter
                ) or attr_name in self.value.__dict__["_parameters"]:
                    # Handle parameters
                    self.is_state_mutated = True
                elif attr_name in self.value.__dict__["_buffers"]:
                    # Handle buffers
                    self.is_state_mutated = True
                elif (
                    isinstance(
                        value,
                        (
                            variables.NNModuleVariable,
                            variables.UnspecializedNNModuleVariable,
                        ),
                    )
                    or attr_name in self.value.__dict__["_modules"]
                ):
                    # Handle submodules
                    self.is_state_mutated = True

            if method is torch.nn.Module.__setattr__ and isinstance(
                args[1], variables.DeletedVariable
            ):
                # Trace through __delattr__ to track mutations on the module
                # members like `_modules`.
                return tx.inline_user_function_return(
                    variables.UserFunctionVariable(torch.nn.Module.__delattr__),
                    [self, args[0]],
                    kwargs,
                )

        return super().call_method(tx, name, args, kwargs)

    def getattr_helper(self, tx: "InstructionTranslator", field, name_vt):
        dict_vt = self.var_getattr(tx, field)
        if isinstance(dict_vt, variables.ConstDictVariable):
            return dict_vt.maybe_getitem_const(name_vt)
        return None

    def var_getattr(self, tx: "InstructionTranslator", name):
        # Allow skipping of empty hook dict guards on inbuilt nn modules
        if name in (
            "_backward_hooks",
            "_backward_pre_hooks",
            "_forward_hooks",
            "_forward_pre_hooks",
        ):
            # For empty hooks, make an EMPTY_NN_MODULE_HOOKS_DICT. This allows us to control the installation of the
            # empty-hooks guard via skip_nnmodule_hook_guards.
            if not tx.output.side_effects.has_pending_mutation_of_attr(
                self, name
            ) and self.value.__module__.startswith(("torch.nn.", "torch.ao.")):
                hooks_dict = getattr(self.value, name)
                if isinstance(hooks_dict, dict) and len(hooks_dict) == 0:
                    if self.source:
                        hooks_source = AttrSource(self.source, name)
                        install_guard(
                            hooks_source.make_guard(
                                GuardBuilder.EMPTY_NN_MODULE_HOOKS_DICT
                            )
                        )
                    return variables.ConstDictVariable({})

        # For non-empty hook dicts, one way is to just fall back to VariableBuilder and create a ConstDictVariable.
        # However, ConstDictVariable guards on keys. This can cause recompiles when the same hook is installed for
        # different nn module instances, because the key keeps changing (look more into RemovableHandle to understand
        # why the key changes - also related https://github.com/pytorch/pytorch/issues/125836). Here, we carefully
        # craft a ConstDictVariable to avoid any guard on the keys.
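        # Illustrative: hook keys come from torch.utils.hooks.RemovableHandle ids,
        # which increment globally, so registering the "same" hook on two module
        # instances typically yields different keys (e.g. 0 and 1); guarding on the
        # key values would therefore force a recompile per instance.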
        if (
            self.source
            and name
            in (
                "_forward_pre_hooks",
                "_forward_hooks",
            )
            and not tx.output.side_effects.has_pending_mutation_of_attr(self, name)
        ):
            hooks_dict = getattr(self.value, name)
            hooks_dict_source = AttrSource(self.source, name)
            install_guard(hooks_dict_source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
            tx.output.guard_on_key_order.add(hooks_dict_source.name())

            def build_key_value(i, k, v):
                # Make key sourceless to avoid any guard on it
                key = variables.ConstantVariable.create(k)

                # Instead of using dict[key] to access the value, use a dict[dict.keys()[index]] to access the
                # value. This removes the reliance on the actual key value.
                source_key = ConstDictKeySource(hooks_dict_source, i)
                source_value = GetItemSource(hooks_dict_source, source_key)
                value = LazyVariableTracker.create(v, source_value)
                return key, value

            result = dict(
                build_key_value(i, k, v) for i, (k, v) in enumerate(hooks_dict.items())
            )

            return variables.ConstDictVariable(
                result, type(hooks_dict), source=hooks_dict_source
            )
        return super().var_getattr(tx, name)

    def manually_trace_nn_module_getattr(self, tx: "InstructionTranslator", name):
        """
        Dynamo tracing of nn.Module __getattr__ can be expensive if the model
        has a deep submodule hierarchy. Since the __getattr__ is stable, we can
        directly look into the underlying data structures. This saves a lot of
        compilation time.
        """
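        # Illustrative (hypothetical attribute name): an access like `self.linear`
        # resolves here from the tracked "_modules" dict instead of inlining the
        # bytecode of nn.Module.__getattr__.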
        name_vt = variables.ConstantVariable(name)
        out = self.getattr_helper(tx, "_parameters", name_vt)
        if out is None:
            out = self.getattr_helper(tx, "_modules", name_vt)
        if out is None:
            out = self.getattr_helper(tx, "_buffers", name_vt)
        if out is None:
            raise_observed_exception(AttributeError, tx, self)
        return out


class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
    """
    Differentiates between builtin nn modules (e.g. torch.nn.Linear) and user defined nn modules.
    """

    def _wrap_source(self, attr_source):
        if not isinstance(attr_source, UnspecializedBuiltinNNModuleSource):
            return UnspecializedBuiltinNNModuleSource(attr_source)
        return attr_source


class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
    """
    Tracing behavior: trace into submodules and treat them as Unspecialized; do not
    register parameters to the top level, but treat them as function inputs.

    Guard behavior: if 'skip_fsdp_guards' is set, many guards that would be installed
    by a vanilla UnspecializedNNModuleVariable are simply dropped, on the basis
    that a user wrapping their model in FSDP(model) is already opting into a
    requirement not to modify internal model state, which would already break FSDP
    without compilation.
    """
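    # Illustrative usage (assuming torch.distributed.fsdp): a model wrapped as
    # FullyShardedDataParallel(model) reaches Dynamo through this variable;
    # torch._dynamo.config.skip_fsdp_guards controls whether _wrap_source below tags
    # attribute sources as FSDPNNModuleSource or UnspecializedNNModuleSource.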

    def __init__(self, value, **kwargs) -> None:
        source = kwargs.get("source", None)
        assert (
            source is not None
        ), "FSDPManagedNNModule depends on having an accurate source to control guarding."

        super().__init__(value=value, **kwargs)
        self.source = source

    def _wrap_source(self, attr_source):
        if not isinstance(
            attr_source, (FSDPNNModuleSource, UnspecializedNNModuleSource)
        ):
            if torch._dynamo.config.skip_fsdp_guards:
                return FSDPNNModuleSource(attr_source)
            else:
                return UnspecializedNNModuleSource(attr_source)
        return attr_source