xref: /aosp_15_r20/external/pytorch/torch/fx/graph_module.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import contextlib
3import copy
4import itertools
5import linecache
6import os
7import sys
8import traceback
9import warnings
10from pathlib import Path
11from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
12
13import torch
14import torch.nn as nn
15import torch.overrides
16from torch.nn.modules.module import _addindent
17from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
18
19from ._compatibility import compatibility
20from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode
21
22__all__ = [
23    "reduce_graph_module",
24    "reduce_package_graph_module",
25    "reduce_deploy_graph_module",
26    "GraphModule",
27]
28
29_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"
30
31# Normal exec loses the source code, however we can work with
32# the linecache module to recover it.
33# Using _exec_with_source will add it to our local cache
34# and then tools like TorchScript will be able to get source info.
35class _EvalCacheLoader:
36    def __init__(self):
37        self.eval_cache = {}
38        self.next_id = 0
39
40    def cache(self, src: str, globals: Dict[str, Any], co_fields=None):
41        """Store the source in a private cache, and add a lazy entry in linecache
42        that allows the source to be retrieved by 'filename'.
43
44        Args:
45            src (str): The module source to cache
46            globals (dict): The module globals
47
48        Returns:
49            str: The cache key (and dummy filename) generated for src.
50        """
51
52        key = self._get_key()
53        if co_fields:
54            key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
55        self.eval_cache[key] = src
56
57        # Don't mutate globals so that this loader is only used
58        # to populate linecache, and doesn't interact with other modules
59        # that might check `__loader__`
60        globals_copy = globals.copy()
61        globals_copy["__file__"] = key
62        globals_copy["__name__"] = key
63        globals_copy["__loader__"] = self
64        linecache.lazycache(key, globals_copy)
65
66        return key
67
68    # Part of the loader protocol (PEP 302)
69    # linecache will use this method when trying to find source code
70    def get_source(self, module_name) -> Optional[str]:
71        if module_name in self.eval_cache:
72            return self.eval_cache[module_name]
73        return None
74
75    def _get_key(self):
76        key = f"<eval_with_key>.{self.next_id}"
77        self.next_id += 1
78        return key
79
80
81_loader = _EvalCacheLoader()
82
83
84def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None):
85    key = _loader.cache(src, globals, co_fields)
86    exec(compile(src, key, "exec"), globals)
87
88
89def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None):
90    return _method_from_src(
91        method_name="forward", src=src, globals=globals, co_fields=co_fields
92    )
93
94
95def _method_from_src(
96    method_name: str, src: str, globals: Dict[str, Any], co_fields=None
97) -> Callable:
98    # avoid mutating the passed in dict
99    globals_copy = globals.copy()
100    _exec_with_source(src, globals_copy, co_fields)
101    fn = globals_copy[method_name]
102    del globals_copy[method_name]
103    return fn
104
105
106def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
107    if name in _custom_builtins:
108        return _custom_builtins[name].import_str
109    if _is_from_torch(name):
110        return "import torch"
111    module_name, attr_name = importer.get_name(obj)
112    return f"from {module_name} import {attr_name} as {name}"
113
114
115def _format_import_block(globals: Dict[str, Any], importer: Importer):
116    import_strs: Set[str] = {_format_import_statement(name, obj, importer) for name, obj in globals.items()}
117    # Sort the imports so we have a stable import block that allows us to
118    # hash the graph module and get a consistent key for use in a cache.
119    return "\n".join(sorted(import_strs))
120
121
122@compatibility(is_backward_compatible=True)
123def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
124    # BC: attribute name was changed from `code` to `_code` to facilitate
125    # making `code` into a property and adding a docstring to it
126    fn_src = body.get("_code") or body["code"]
127    forward = _forward_from_src(import_block + fn_src, {})
128    return _deserialize_graph_module(forward, body)
129
130
131@compatibility(is_backward_compatible=True)
132def reduce_package_graph_module(
133    importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
134) -> torch.nn.Module:
135    forward = importer.import_module(generated_module_name).forward
136    return _deserialize_graph_module(forward, body)
137
138
139@compatibility(is_backward_compatible=True)
140def reduce_deploy_graph_module(
141    importer: PackageImporter, body: Dict[Any, Any], import_block: str
142) -> torch.nn.Module:
143    ns = {}
144    ns["__builtins__"] = importer.patched_builtins
145    fn_src = body.get("_code")
146    assert fn_src is not None
147    forward = _forward_from_src(import_block + fn_src, ns)
148    return _deserialize_graph_module(forward, body)
149
150
151# We create a dummy class here because symbolic_trace pulls the forward()
152# function off of the class, rather than the instance. This class is used
153# in _deserialize_graph_module() below.
154class _CodeOnlyModule(torch.nn.Module):
155    def __init__(self, body):
156        super().__init__()
157        self.__dict__ = body
158
159
160def _deserialize_graph_module(forward, body: Dict[Any, Any], graph_module_cls=None) -> torch.nn.Module:
161    """
162    Deserialize a GraphModule given the dictionary of the original module,
163    using the code to reconstruct the graph. We delete the actual graph before
164    saving the dictionary so that changes to the in-memory graph format do not
165    get serialized.
166    """
167
168    # Try to retrieve the forward source in a backward-compatible way
169    _CodeOnlyModule.forward = forward
170
171    tracer_cls = body.get("_tracer_cls")
172    if tracer_cls is None:
173        from ._symbolic_trace import Tracer
174
175        tracer_cls = Tracer
176
177    graphmodule_cls_name = body.get("_graphmodule_cls_name", "GraphModule")
178
179    # This is a workaround for a mypy linter issue related to
180    # passing base class as an argument - https://github.com/python/mypy/issues/5865.
181    cls_tracer: Any = tracer_cls
182
183    class KeepModules(cls_tracer):
184        # we shouldn't trace into any of the submodules,
185        # because they were not traced in the original GraphModule
186        def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
187            return True
188
189    com = _CodeOnlyModule(body)
190
191    tracer_extras = body.get("_tracer_extras", {})
192    graph = KeepModules().trace(com, **tracer_extras)
193
194    # Manually set Tracer class on the reconstructed Graph, to avoid
195    # referencing the private local subclass KeepModules.
196    graph._tracer_cls = tracer_cls
197    from ._lazy_graph_module import _make_graph_module
198    gm = _make_graph_module(com, graph, class_name=graphmodule_cls_name, graph_module_cls=graph_module_cls)
199
200    # The GraphModule constructor only retains attributes referenced by the graph.
201    # In this case, our goal is return a GraphModule as close to identical as the one
202    # put into the package. If any additional attributes were present in body,
203    # we should keep them.
204    for k, v in body.items():
205        if not hasattr(gm, k):
206            setattr(gm, k, v)
207    return gm
208
209
210# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
211# This installs empty Modules where none exist yet if they are subpaths of target
212def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str):
213    *prefix, field = target.split(".")
214    for item in prefix:
215        f = getattr(from_module, item)
216        t = getattr(to_module, item, None)
217        if f is t:
218            # we have already installed one of its parents
219            # (e.g. target = root.linear.weight, but we have already installed root.linear)
220            # once we install a parent, we no longer need to copy the children
221            # since all the needed properties will already be present
222            return
223
224        if t is None:
225            t = torch.nn.Module()
226            setattr(to_module, item, t)
227        from_module, to_module = f, t
228
229    orig = getattr(from_module, field)
230    # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
231    # So, we register it as a named buffer in the target module.
232    if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter):
233        to_module.register_buffer(field, orig)
234    else:
235        setattr(to_module, field, orig)
236
237
238# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module
239# This installs empty Modules where none exist yet if they are subpaths of target
240def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
241    *prefix, field = target.split(".")
242    for item in prefix:
243        t = getattr(to_module, item, None)
244
245        if t is None:
246            t = torch.nn.Module()
247            setattr(to_module, item, t)
248        to_module = t
249
250    # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
251    # So, we register it as a named buffer in the target module.
252    if isinstance(from_obj, torch.Tensor) and not isinstance(
253        from_obj, torch.nn.Parameter
254    ):
255        to_module.register_buffer(field, from_obj)
256    else:
257        setattr(to_module, field, from_obj)
258
259
260def _print_readable(
261    module,
262    module_name,
263    print_output=True,
264    include_stride=False,
265    include_device=False,
266    colored=False,
267):
268    graph = module.graph
269    assert graph is not None and isinstance(graph, torch.fx.Graph), "print_readable must be used on a module with a graph"
270
271    verbose_python_code = graph.python_code(
272        root_module="self",
273        verbose=True,
274        include_stride=include_stride,
275        include_device=include_device,
276        colored=colored,
277    )
278    module_code = verbose_python_code.src
279    module_code = module_code.lstrip("\n")
280    module_code = f"class {module_name}(torch.nn.Module):\n" + module_code
281    module_code = _addindent(module_code, 4)
282
283    submodule_code_list = [""]
284    for submodule_name, submodule in module.named_children():
285        if hasattr(submodule, "graph"):
286            submodule_code_list.append(
287                _print_readable(
288                    submodule,
289                    submodule_name,
290                    print_output=False,
291                    include_stride=include_stride,
292                    include_device=include_device,
293                    colored=colored,
294                )
295            )
296    submodule_code = "\n".join(submodule_code_list)
297    submodule_code = _addindent(submodule_code, 4)
298
299    output = module_code + submodule_code
300    if print_output:
301        print(module_code + submodule_code)
302    return output
303
304
305class _WrappedCall:
306    def __init__(self, cls, cls_call):
307        self.cls = cls
308        self.cls_call = cls_call
309
310    # Previously, if an error occurred when valid
311    # symbolically-traced code was run with an invalid input, the
312    # user would see the source of the error as coming from
313    # `File "<eval_with_key_N">`, where N is some number. We use
314    # this function to generate a more informative error message. We
315    # return the traceback itself, a message explaining that the
316    # error occurred in a traced Module's generated forward
317    # function, and five lines of context surrounding the faulty
318    # line
319    @staticmethod
320    def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
321        # auxiliary variables (for readability)
322        err_lineno = frame_summary.lineno
323        assert err_lineno is not None
324        line = frame_summary.line
325        assert line is not None
326        err_line_len = len(line)
327        all_src_lines = linecache.getlines(frame_summary.filename)
328
329        # constituent substrings of the error message
330        tb_repr = torch._dynamo.disable(traceback.format_exc)()
331        custom_msg = (
332            "Call using an FX-traced Module, "
333            f"line {err_lineno} of the traced Module's "
334            "generated forward function:"
335        )
336        before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
337        marker = "~" * err_line_len + "~~~ <--- HERE"
338        err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
339
340        # joined message
341        return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
342
343    def __call__(self, obj, *args, **kwargs):
344        try:
345            if self.cls_call is not None:
346                return self.cls_call(obj, *args, **kwargs)
347            else:
348                return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
349        except Exception as e:
350            assert e.__traceback__
351            topmost_framesummary: traceback.FrameSummary = (
352                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]
353            )  # type: ignore[arg-type]
354            if "eval_with_key" in topmost_framesummary.filename:
355                print(
356                    _WrappedCall._generate_error_message(topmost_framesummary),
357                    file=sys.stderr,
358                )
359                raise e.with_traceback(None)  # noqa: B904
360            else:
361                raise e
362
363@compatibility(is_backward_compatible=True)
364class GraphModule(torch.nn.Module):
365    """
366    GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a
367    ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated
368    from that ``graph``.
369
370    .. warning::
371
372        When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
373        regenerated. However, if you edit the contents of the ``graph`` without reassigning
374        the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
375        code.
376    """
377
378    def __new__(cls: "Type[GraphModule]", *args, **kwargs):
379        # each instance of a graph module needs its own forward method
380        # so create a new singleton class for each instance.
381        # it is a subclass of the user-defined class, the only difference
382        # is an extra layer to install the forward method
383
384        # address issue described at https://github.com/pytorch/pytorch/issues/63883
385        # in other words, traverse class hierarchy to fix the redundant class definition problem
386        for t in cls.__mro__:
387            c = t.__qualname__.split(".")[-1]
388            if c != "GraphModuleImpl":
389                cls = t
390                break
391
392        class GraphModuleImpl(cls):  # type: ignore[misc, valid-type]
393            pass
394
395        return super().__new__(GraphModuleImpl)
396
397    @compatibility(is_backward_compatible=True)
398    def __init__(
399        self,
400        root: Union[torch.nn.Module, Dict[str, Any]],
401        graph: Graph,
402        class_name: str = "GraphModule",
403    ):
404        """
405        Construct a GraphModule.
406
407        Args:
408
409            root (Union[torch.nn.Module, Dict[str, Any]):
410                ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type.
411                In the case that ``root`` is a Module, any references to Module-based objects (via qualified
412                name) in the Graph's Nodes' ``target`` field will be copied over from the respective place
413                within ``root``'s Module hierarchy into the GraphModule's module hierarchy.
414                In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be
415                looked up directly in the dict's keys. The object mapped to by the Dict will be copied
416                over into the appropriate place within the GraphModule's module hierarchy.
417
418            graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation
419
420            class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all
421                error messages will report as originating from ``GraphModule``. It may be helpful to set this
422                to ``root``'s original name or a name that makes sense within the context of your transform.
423        """
424        super().__init__()
425        self.__class__.__name__ = class_name
426        if isinstance(root, torch.nn.Module):
427            if hasattr(root, "training"):
428                self.training = root.training
429
430            # When we pickle/unpickle graph module, we don't want to drop any module or attributes.
431            if isinstance(root, _CodeOnlyModule):
432                for k, _ in root.named_children():
433                    _copy_attr(root, self, k)
434
435                for k, _ in root.named_buffers():
436                    _copy_attr(root, self, k)
437
438                for k, _ in root.named_parameters():
439                    _copy_attr(root, self, k)
440
441            for node in graph.nodes:
442                if node.op in ["get_attr", "call_module"]:
443                    assert isinstance(node.target, str)
444                    _copy_attr(root, self, node.target)
445        elif isinstance(root, dict):
446            targets_to_copy = []
447            for node in graph.nodes:
448                if node.op in ["get_attr", "call_module"]:
449                    assert isinstance(node.target, str)
450                    if node.target not in root:
451                        raise RuntimeError(
452                            "Node "
453                            + str(node)
454                            + " referenced target "
455                            + node.target
456                            + " but that target was not provided in ``root``!"
457                        )
458                    targets_to_copy.append(node.target)
459            # Sort targets in ascending order of the # of atoms.
460            # This will ensure that less deeply nested attributes are assigned
461            # before more deeply nested attributes. For example, foo.bar
462            # will be assigned before foo.bar.baz. Otherwise, we might assign
463            # the user-provided ``foo.bar`` and wipe out the previously-assigned
464            # ``foo.bar.baz``
465            targets_to_copy.sort(key=lambda t: t.count("."))
466            for target_to_copy in targets_to_copy:
467                _assign_attr(root[target_to_copy], self, target_to_copy)
468        else:
469            raise RuntimeError("Unsupported type " + str(root) + " passed for root!")
470
471        self.graph = graph
472
473        # Store the Tracer class responsible for creating a Graph separately as part of the
474        # GraphModule state, except when the Tracer is defined in a local namespace.
475        # Locally defined Tracers are not pickleable. This is needed because torch.package will
476        # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
477        # to re-create the Graph during deserialization.
478        self._tracer_cls = None
479        if (
480            self.graph._tracer_cls
481            and "<locals>" not in self.graph._tracer_cls.__qualname__
482        ):
483            self._tracer_cls = self.graph._tracer_cls
484
485        self._tracer_extras = {}
486        if self.graph._tracer_extras:
487            self._tracer_extras = self.graph._tracer_extras
488
489        # Dictionary to store metadata
490        self.meta: Dict[str, Any] = {}
491        self._replace_hook = None
492        self._create_node_hooks: List[Callable] = []
493        self._erase_node_hooks: List[Callable] = []
494
495    # TorchScript breaks trying to compile the graph setter because of the
496    # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
497    #
498    # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
499    __jit_unused_properties__ = ["graph"]
500
501    @property
502    def graph(self) -> Graph:
503        """
504        Return the ``Graph`` underlying this ``GraphModule``
505        """
506        return self._graph
507
508    @graph.setter
509    def graph(self, g: Graph) -> None:
510        """
511        Set the underlying ``Graph`` for this ``GraphModule``. This will internally
512        recompile the ``GraphModule`` so that the generated ``forward()`` function
513        corresponds to ``g``
514        """
515        assert isinstance(g, Graph), f"Expected a Graph instance, but got {type(g)}"
516        self._graph = g
517        g.owning_module = self
518        self.recompile()
519
520    @compatibility(is_backward_compatible=False)
521    def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
522        """Dumps out module to ``folder`` with ``module_name`` so that it can be
523        imported with ``from <folder> import <module_name>``
524
525        Args:
526
527            folder (Union[str, os.PathLike]): The folder to write the code out to
528
529            module_name (str): Top-level name to use for the ``Module`` while
530                writing out the code
531        """
532        folder = Path(folder)
533        Path(folder).mkdir(exist_ok=True)
534        torch.save(self.state_dict(), folder / "state_dict.pt")
535        tab = " " * 4
536        custom_builtins = "\n".join([v.import_str for v in _custom_builtins.values()])
537        model_str = f"""
538import torch
539{custom_builtins}
540
541from torch.nn import *
542class {module_name}(torch.nn.Module):
543    def __init__(self):
544        super().__init__()
545"""
546
547        def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
548            safe_reprs = [
549                nn.Linear,
550                nn.Conv1d,
551                nn.Conv2d,
552                nn.Conv3d,
553                nn.BatchNorm1d,
554                nn.BatchNorm2d,
555                nn.BatchNorm3d,
556            ]
557            if type(module) in safe_reprs:
558                return f"{module.__repr__()}"
559            else:
560                return None
561
562        blobified_modules = []
563        for module_name, module in self.named_children():
564            module_str = _gen_model_repr(module_name, module)
565            if module_str is None:
566                module_file = folder / f"{module_name}.pt"
567                torch.save(module, module_file)
568                blobified_modules.append(module_name)
569                module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
570                # weights_only=False as this is legacy code that saves the model
571                module_str = f"torch.load(r'{module_file}', weights_only=False) # {module_repr}"
572            model_str += f"{tab*2}self.{module_name} = {module_str}\n"
573
574        for buffer_name, buffer in self._buffers.items():
575            if buffer is None:
576                continue
577            model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
578
579        for param_name, param in self._parameters.items():
580            if param is None:
581                continue
582            model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
583
584        model_str += (
585            f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
586        )
587        model_str += f"{_addindent(self.code, 4)}\n"
588
589        module_file = folder / "module.py"
590        module_file.write_text(model_str)
591
592        init_file = folder / "__init__.py"
593        init_file.write_text("from .module import *")
594
595        if len(blobified_modules) > 0:
596            warnings.warn(
597                "Was not able to save the following children modules as reprs -"
598                f"saved as pickled files instead: {blobified_modules}"
599            )
600
601    @compatibility(is_backward_compatible=True)
602    def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
603        """
604        Adds the given submodule to ``self``.
605
606        This installs empty Modules where none exist yet if they are
607        subpaths of ``target``.
608
609        Args:
610            target: The fully-qualified string name of the new submodule
611                (See example in ``nn.Module.get_submodule`` for how to
612                specify a fully-qualified string.)
613            m: The submodule itself; the actual object we want to
614                install in the current Module
615
616        Return:
617            bool: Whether or not the submodule could be inserted. For
618                this method to return True, each object in the chain
619                denoted by ``target`` must either a) not exist yet,
620                or b) reference an ``nn.Module`` (not a parameter or
621                other attribute)
622        """
623        *prefix, field = target.split(".")
624        mod: torch.nn.Module = self
625
626        for item in prefix:
627
628            submod = getattr(mod, item, None)
629
630            if submod is None:
631                submod = torch.nn.Module()
632                setattr(mod, item, submod)
633
634            if not isinstance(submod, torch.nn.Module):
635                return False
636
637            mod = submod
638
639        mod.add_module(field, m)
640        return True
641
642    @compatibility(is_backward_compatible=True)
643    def delete_submodule(self, target: str) -> bool:
644        """
645        Deletes the given submodule from ``self``.
646
647        The module will not be deleted if ``target`` is not a valid
648        target.
649
650        Args:
651            target: The fully-qualified string name of the new submodule
652                (See example in ``nn.Module.get_submodule`` for how to
653                specify a fully-qualified string.)
654
655        Returns:
656            bool: Whether or not the target string referenced a
657                submodule we want to delete. A return value of ``False``
658                means that the ``target`` was not a valid reference to
659                a submodule.
660        """
661        atoms = target.split(".")
662        path, target_submod = atoms[:-1], atoms[-1]
663        mod: torch.nn.Module = self
664
665        # Get the parent module
666        for item in path:
667
668            if not hasattr(mod, item):
669                return False
670
671            mod = getattr(mod, item)
672
673            if not isinstance(mod, torch.nn.Module):
674                return False
675
676        if not hasattr(mod, target_submod):
677            return False
678
679        if not isinstance(getattr(mod, target_submod), torch.nn.Module):
680            return False
681
682        delattr(mod, target_submod)
683        return True
684
685    @compatibility(is_backward_compatible=True)
686    def delete_all_unused_submodules(self) -> None:
687        """
688        Deletes all unused submodules from ``self``.
689
690        A Module is considered "used" if any one of the following is
691        true:
692        1. It has children that are used
693        2. Its forward is called directly via a ``call_module`` node
694        3. It has a non-Module attribute that is used from a
695        ``get_attr`` node
696
697        This method can be called to clean up an ``nn.Module`` without
698        manually calling ``delete_submodule`` on each unused submodule.
699        """
700        used: List[str] = []
701
702        for node in self.graph.nodes:
703
704            if node.op == "call_module" or node.op == "get_attr":
705
706                # A list of strings representing the different parts
707                # of the path. For example, `foo.bar.baz` gives us
708                # ["foo", "bar", "baz"]
709                fullpath = node.target.split(".")
710
711                # If we're looking at multiple parts of a path, join
712                # join them with a dot. Otherwise, return that single
713                # element without doing anything to it.
714                def join_fn(x: str, y: str) -> str:
715                    return ".".join([x, y] if y else [x])
716
717                # Progressively collect all the names of intermediate
718                # modules. For example, if we have the target
719                # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and
720                # `foo.bar.baz` to the list.
721                used.extend(itertools.accumulate(fullpath, join_fn))
722
723                # For a `call_module` node, also register all recursive submodules
724                # as used
725                if node.op == "call_module":
726                    try:
727                        submod = self.get_submodule(node.target)
728
729                        for submod_name, _ in submod.named_modules():
730                            if submod_name != "":
731                                used.append(".".join([node.target, submod_name]))
732                    except AttributeError:
733                        # Node referenced nonexistent submodule, don't need to
734                        # worry about GCing anything
735                        pass
736
737        to_delete = [name for name, _ in self.named_modules() if name not in used]
738
739        for name in to_delete:
740            self.delete_submodule(name)
741
742    @property
743    def code(self) -> str:
744        """
745        Return the Python code generated from the ``Graph`` underlying this
746        ``GraphModule``.
747        """
748        if not hasattr(self, "_code"):
749            raise RuntimeError(
750                "Code has not been generated! Please report a bug to PyTorch"
751            )
752        return self._code
753
754    @compatibility(is_backward_compatible=True)
755    def recompile(self) -> PythonCode:
756        """
757        Recompile this GraphModule from its ``graph`` attribute. This should be
758        called after editing the contained ``graph``, otherwise the generated
759        code of this ``GraphModule`` will be out of date.
760        """
761        if isinstance(self._graph._codegen, _PyTreeCodeGen):
762            self._in_spec = self._graph._codegen.pytree_info.in_spec
763            self._out_spec = self._graph._codegen.pytree_info.out_spec
764        python_code = self._graph.python_code(root_module="self")
765        self._code = python_code.src
766        self._lineno_map = python_code._lineno_map
767
768        cls = type(self)
769        co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
770        cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
771
772        # Determine whether this class explicitly defines a __call__ implementation
773        # to wrap. If it does, save it in order to have wrapped_call invoke it.
774        # If it does not, wrapped_call can use a dynamic call to super() instead.
775        # In most cases, super().__call__ should be torch.nn.Module.__call__.
776        # We do not want to hold a reference to Module.__call__ here; doing so will
777        # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
778        cls_call = cls.__call__ if "__call__" in vars(cls) else None
779
780        if "_wrapped_call" not in vars(cls):
781            cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]
782
783        def call_wrapped(self, *args, **kwargs):
784            return self._wrapped_call(self, *args, **kwargs)
785
786        cls.__call__ = call_wrapped  # type: ignore[method-assign]
787
788        return python_code
789
790    # Passing Tracer as argument allows subclasses extending fx.GraphModule
791    # define their own Tracer (extending fx.Tracer).
792    def __reduce_deploy__(self, importer: Importer):
793        dict_without_graph = self.__dict__.copy()
794        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
795        del dict_without_graph["_graph"]
796
797        python_code = self.recompile()
798        import_block = _format_import_block(python_code.globals, importer)
799        return (reduce_deploy_graph_module, (dict_without_graph, import_block))
800
801    def __reduce_package__(self, exporter: PackageExporter):
802        dict_without_graph = self.__dict__.copy()
803        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
804        del dict_without_graph["_graph"]
805
806        generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
807        python_code = self.recompile()
808        import_block = _format_import_block(python_code.globals, exporter.importer)
809        module_code = import_block + self.code
810        exporter.save_source_string(generated_module_name, module_code)
811        return (
812            reduce_package_graph_module,
813            (dict_without_graph, generated_module_name),
814        )
815
816    def __reduce__(self):
817        """
818        Serialization of GraphModule. We serialize only the generated code, not
819        the underlying ``Graph``. This is because ``Graph`` does not have on-disk
820        backward-compatibility guarantees, whereas Python source code does.
821        On the deserialization side, we symbolically trace through the generated
822        code to regenerate the underlying ``Graph``
823        """
824        dict_without_graph = self.__dict__.copy()
825
826        python_code = self.recompile()
827        import_block = _format_import_block(python_code.globals, sys_importer)
828        del dict_without_graph["_graph"]
829        return (reduce_graph_module, (dict_without_graph, import_block))
830
831    def _deepcopy_init(self):
832        return GraphModule.__init__
833
834    # because __reduce__ is defined for serialization,
835    # we need to define deepcopy otherwise it will call __reduce__
836    # and cause symbolic tracing to occur every time we try to copy the object
837    def __deepcopy__(self, memo):
838        res = type(self).__new__(type(self))
839        memo[id(self)] = res
840        fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
841        self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["_graph"])
842        # hooks are lost during `GraphModule.__init__`, so we need to copy over
843        # them explicitly, note right now we are only copying state_dict related
844        # hooks, to reduce bc-related issues, we can copy forward/backward related
845        # hooks in the future as well if needed
846        extra_preserved_attrs = [
847            "_state_dict_hooks",
848            "_load_state_dict_pre_hooks",
849            "_load_state_dict_post_hooks",
850            "_replace_hook",
851            "_create_node_hooks",
852            "_erase_node_hooks"
853        ]
854        for attr in extra_preserved_attrs:
855            if attr in self.__dict__:
856                setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
857        res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
858        if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
859            for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
860                setattr(res, attr_name, attr)
861        return res
862
863    def __copy__(self):
864        from ._lazy_graph_module import _make_graph_module
865        res = _make_graph_module(self, self.graph)
866        res.meta = getattr(self, "meta", {})
867        return res
868
869    @compatibility(is_backward_compatible=False)
870    def print_readable(self, print_output=True, include_stride=False, include_device=False, colored=False):
871        """
872        Return the Python code generated for current GraphModule and its children GraphModules
873        """
874        return _print_readable(
875            self,
876            self._get_name(),
877            print_output,
878            include_stride,
879            include_device,
880            colored,
881        )
882
883    def __str__(self) -> str:
884        orig_str = super().__str__()
885        print_readable_reminder = (
886            "# To see more debug info, please use `graph_module.print_readable()`"
887        )
888        return "\n".join([orig_str, self._code, print_readable_reminder])
889
890    def _replicate_for_data_parallel(self):
891        new_gm = self.__copy__()
892        new_gm._is_replica = True
893        return new_gm
894
895    @contextlib.contextmanager
896    def _set_replace_hook(self, f):
897        """
898        Takes a callable which will be called everytime when we replace a node
899        to a new node, or change the node's name. Callable takes three arguments:
900        the old node we're changing, and NAME of the new node, followed by the
901        user node which consumes the old node to be replaced.
902        """
903        assert callable(f), "Replace hook must be a callable."
904        prev, self._replace_hook = self._replace_hook, f
905        try:
906            yield
907        finally:
908            self._replace_hook = prev
909
910    def _register_create_node_hook(self, f):
911        """
912        Takes a callable which will be called after we create a new node. The
913        callable takes the newly created node as input and returns None.
914        """
915        assert callable(f), "create_node hook must be a callable."
916        self._create_node_hooks.append(f)
917
918    def _unregister_create_node_hook(self, f):
919        """
920        Takes a callable which was previously registered to be called after we create a node.
921        This function will unregister that callable so it is no longer invoked on node creation.
922        """
923        assert callable(f), "create_node hook must be a callable."
924        self._create_node_hooks.remove(f)
925
926    def _register_erase_node_hook(self, f):
927        """
928        Takes a callable which will be called after we erase a node. The
929        callable takes the node that is being erased as input and returns None.
930        """
931        assert callable(f), "erase_node hook must be a callable."
932        self._erase_node_hooks.append(f)
933
934    def _unregister_erase_node_hook(self, f):
935        """
936        Takes a callable which was previously registered to be called after we erase a node.
937        This function will unregister that callable so it is no longer invoked on node erasure.
938        """
939        assert callable(f), "erase_node hook must be a callable."
940        self._erase_node_hooks.remove(f)
941
942# workarounds for issues in __torch_function__
943
944# WAR for __torch_function__ not handling tensor lists,
945# fix is in https://github.com/pytorch/pytorch/pull/34725
946# orig_cat = torch.cat
947# def patched_cat(*args, **kwargs):
948#     tensors = args[0]
949#     for t in tensors:
950#         if isinstance(t, Proxy):
951#             return t.__torch_function__(patched_cat, (), args, kwargs)
952#     return orig_cat(*args, **kwargs)
953# patched_cat.__module__ = 'torch'
954# patched_cat.__name__ = 'cat'
955# torch.cat = patched_cat
956