xref: /aosp_15_r20/external/pytorch/torch/fx/_lazy_graph_module.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from contextlib import contextmanager

from torch.fx import GraphModule
from torch.fx.graph_module import (
    _format_import_block,
    reduce_graph_module,
    reduce_package_graph_module,
)
from torch.package import PackageExporter, sys_importer

from ._compatibility import compatibility


_use_lazy_graph_module_flag = False
_force_skip_lazy_graph_module_flag = False


@compatibility(is_backward_compatible=False)
@contextmanager
def _force_skip_lazy_graph_module():
    """
    Skip using lazy graph modules regardless of the setting of _use_lazy_graph_module.
    Use this to skip _LazyGraphModule when testing the inductor torchscript-related backend.

    Calling torch.jit.script on a _LazyGraphModule results in the following error:
        https://gist.github.com/shunting314/5143654c8084aed84ecd19b818258a69
    """
    try:
        global _force_skip_lazy_graph_module_flag
        prior = _force_skip_lazy_graph_module_flag
        _force_skip_lazy_graph_module_flag = True
        yield
    finally:
        _force_skip_lazy_graph_module_flag = prior
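
# A minimal usage sketch (illustrative only, not part of this module): inside the
# context, the factory defined below always produces a plain GraphModule, so the
# result can be scripted safely. `root` and `graph` are an assumed nn.Module and
# fx.Graph.
#
#     import torch
#
#     with _force_skip_lazy_graph_module():
#         gm = _make_graph_module(root, graph)  # plain GraphModule
#         scripted = torch.jit.script(gm)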


@compatibility(is_backward_compatible=False)
@contextmanager
def _use_lazy_graph_module(should_use: bool):
    """
    Enable (or disable) _LazyGraphModule for the duration of the context,
    unless _force_skip_lazy_graph_module is active.
    """
    try:
        global _use_lazy_graph_module_flag
        prior = _use_lazy_graph_module_flag
        _use_lazy_graph_module_flag = (
            should_use and not _force_skip_lazy_graph_module_flag
        )
        yield
    finally:
        _use_lazy_graph_module_flag = prior


@compatibility(is_backward_compatible=False)
def _get_graph_module_cls():
    """Return the graph module class selected by the current flags."""
    return _LazyGraphModule if _use_lazy_graph_module_flag else GraphModule


def _make_graph_module(*args, graph_module_cls=None, **kwargs):
    """Construct a graph module, defaulting to the class chosen by _get_graph_module_cls."""
    if graph_module_cls is None:
        graph_module_cls = _get_graph_module_cls()

    return graph_module_cls(*args, **kwargs)
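
# A usage sketch of the flag plus the factory (illustrative only; `root` and
# `graph` are an assumed nn.Module and fx.Graph):
#
#     with _use_lazy_graph_module(True):
#         gm = _make_graph_module(root, graph)   # _LazyGraphModule instance
#     gm2 = _make_graph_module(root, graph)      # outside the context: a plain
#                                                # GraphModule, unless the flag
#                                                # was enabled elsewhere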


@compatibility(is_backward_compatible=False)
class _LazyGraphModule(GraphModule):
    """
    The main difference between _LazyGraphModule and GraphModule is how recompilation happens.
    GraphModule calls 'recompile' to generate the python code and the forward method when it's
    constructed. Later on, if the graph gets updated, the recompile method can be called again
    to refresh the saved python code and forward method.

    However, in some cases, especially in inductor, the recompilation can be a waste since we
    never check the python code for the graph module or call its forward method. A few concrete
    examples regarding pattern-matching fx passes in inductor:
    1. some passes update the graph to be compiled and then call recompile on the GraphModule.
    2. some passes trace a small pattern function to search for it in the graph being compiled
       and replace each match with the traced graph of a replacement function. The pattern graph
       and replacement graph are quite small, but there are a large number of them. Doing
       GraphModule.recompile for each of them in GraphModule.__init__ is also a waste of time.

    However, simply skipping the GraphModule.recompile call in these scenarios is also dangerous.
    People may want to check the python code or call the GraphModule's forward method for
    debugging purposes.

    The way _LazyGraphModule solves this is to override the recompile method so that it only
    marks the need for recompilation without doing the actual recompilation. Later on, if people
    really access the compiled python code or call the GraphModule's forward method, we do the
    real recompilation. See the illustrative sketch in the comment below.
    """
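
    # A rough sketch of the intended lifecycle (illustrative only; `root`, `graph`
    # and `x` are an assumed nn.Module, fx.Graph and example input):
    #
    #     lgm = _LazyGraphModule(root, graph)  # recompile() only marks, no codegen yet
    #     # ... mutate lgm.graph with an fx pass, then:
    #     lgm.recompile()                      # still only marks the need to recompile
    #     print(lgm.code)                      # accessing .code triggers the real recompile
    #     out = lgm(x)                         # calling forward would trigger it as well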

    @classmethod
    def from_graphmodule(cls, gm: GraphModule):
        if isinstance(gm, _LazyGraphModule):
            return gm
        else:
            return _LazyGraphModule(gm, gm.graph)

    @staticmethod
    def force_recompile(gm):
        """
        Sometimes we need to force a recompile as a workaround:
        - we want to do the real recompilation before symbolic_trace to avoid this error:
            https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
        """
        if isinstance(gm, _LazyGraphModule):
            gm.real_recompile()
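
    # A sketch of that workaround (illustrative only; `gm` is an assumed graph
    # module that may or may not be lazy):
    #
    #     import torch
    #
    #     _LazyGraphModule.force_recompile(gm)  # no-op for a plain GraphModule
    #     traced = torch.fx.symbolic_trace(gm)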

    def real_recompile(self):
        if self._needs_recompile():
            self._real_recompile()

    @classmethod
    def _needs_recompile(cls):
        return cls.forward is cls._lazy_forward

    def _lazy_forward(self, *args, **kwargs):
        # Call self.real_recompile() rather than self._real_recompile() here.
        # The _lazy_forward method may be saved and called repeatedly.
        # Calling self.real_recompile makes sure we skip recompilation if
        # we have already done it.
        self.real_recompile()
        assert not self._needs_recompile()

        # Call `__call__` rather than `forward` since recompilation may
        # install a wrapper for `__call__` to provide a customized error
        # message.
        return self(*args, **kwargs)

    forward = _lazy_forward
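
    # How the laziness is tracked (a descriptive note, not an API guarantee):
    # `recompile()` below resets the *class* attribute `forward` back to the
    # `_lazy_forward` placeholder, and `_needs_recompile()` checks for exactly
    # that placeholder. `_real_recompile()` defers to GraphModule.recompile(),
    # which generates the python code and installs the generated forward on the
    # class, so the placeholder check flips to False afterwards.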

    # TODO: we should handle __reduce_deploy__ the same way as __reduce_package__
    # or __reduce__, by calling _real_recompile. But I haven't found a good way
    # to test __reduce_deploy__. Also it's very unlikely that _LazyGraphModule
    # will be used in torch::deploy. So it's skipped for now.

    def __reduce_package__(self, exporter: PackageExporter):
        """
        Follow GraphModule.__reduce_package__ but call 'self._real_recompile' rather
        than 'self.recompile' since for a _LazyGraphModule, self.recompile just
        marks the need for recompilation and does not return the PythonCode object.
        """
        python_code = self._real_recompile()
        dict_without_graph = self.__dict__.copy()
        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
        del dict_without_graph["_graph"]

        generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
        import_block = _format_import_block(python_code.globals, exporter.importer)
        module_code = import_block + self.code
        exporter.save_source_string(generated_module_name, module_code)
        return (
            reduce_package_graph_module,
            (dict_without_graph, generated_module_name),
        )
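
    # A hedged sketch of how this path is exercised via torch.package (the names
    # "model.pt" and "model.pkl" are arbitrary; `lgm` is an assumed _LazyGraphModule):
    #
    #     from torch.package import PackageExporter
    #
    #     with PackageExporter("model.pt") as exporter:
    #         exporter.intern("**")
    #         exporter.save_pickle("model", "model.pkl", lgm)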

    def __reduce__(self):
        """
        Follow GraphModule.__reduce__ but call 'self._real_recompile' rather
        than 'self.recompile' since for a _LazyGraphModule, self.recompile just
        marks the need for recompilation and does not return the PythonCode object.
        """
        python_code = self._real_recompile()
        dict_without_graph = self.__dict__.copy()
        import_block = _format_import_block(python_code.globals, sys_importer)
        del dict_without_graph["_graph"]
        return (reduce_graph_module, (dict_without_graph, import_block))
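
    # A minimal pickling sketch (illustrative only; `lgm` is an assumed
    # _LazyGraphModule and `x` an assumed example input). __reduce__ forces the
    # real recompile so the generated python code can be serialized:
    #
    #     import pickle
    #
    #     restored = pickle.loads(pickle.dumps(lgm))
    #     out = restored(x)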

    def _real_recompile(self):
        # Do the actual recompilation by deferring to GraphModule.recompile.
        return super().recompile()

    @classmethod
    def recompile(cls):
        # Only mark the need for recompilation by restoring the lazy placeholder;
        # the real work happens lazily in real_recompile().
        cls.forward = cls._lazy_forward

    @property
    def code(self) -> str:
        self.real_recompile()
        return super().code

    def __str__(self) -> str:
        """
        str(GraphModule) accesses the _code attribute. Make sure recompilation
        happens first so the _code attribute is available.
        """
        self.real_recompile()
        return super().__str__()