# mypy: allow-untyped-defs
from contextlib import contextmanager

from torch.fx import GraphModule
from torch.fx.graph_module import (
    _format_import_block,
    reduce_graph_module,
    reduce_package_graph_module,
)
from torch.package import PackageExporter, sys_importer

from ._compatibility import compatibility


# When True, _get_graph_module_cls() returns _LazyGraphModule instead of
# GraphModule. Set only through the _use_lazy_graph_module() context manager.
_use_lazy_graph_module_flag = False
# When True, entering _use_lazy_graph_module() cannot enable lazy graph
# modules. Set only through the _force_skip_lazy_graph_module() context manager.
_force_skip_lazy_graph_module_flag = False


@compatibility(is_backward_compatible=False)
@contextmanager
def _force_skip_lazy_graph_module():
    """
    Skip using lazy graph module regardless of the setting of _use_lazy_graph_module.
    Use to skip _LazyGraphModule when testing inductor torchscript related backend.

    torch.jit.script on a _LazyGraphModule results in the following error:
        https://gist.github.com/shunting314/5143654c8084aed84ecd19b818258a69

    The previous value of the flag is restored on exit, so nested uses compose.
    Note that the flag is only consulted when a _use_lazy_graph_module() context
    is *entered*; it does not retroactively disable an already-active one.
    """
    global _force_skip_lazy_graph_module_flag
    # Capture the prior value before entering ``try`` so the ``finally`` clause
    # can never observe an unbound ``prior``.
    prior = _force_skip_lazy_graph_module_flag
    _force_skip_lazy_graph_module_flag = True
    try:
        yield
    finally:
        _force_skip_lazy_graph_module_flag = prior


@compatibility(is_backward_compatible=False)
@contextmanager
def _use_lazy_graph_module(should_use: bool):
    """
    Select whether graph modules created in this context are lazy.

    Inside the context, _get_graph_module_cls() (and therefore
    _make_graph_module() without an explicit class) returns _LazyGraphModule
    when ``should_use`` is truthy, and GraphModule otherwise.

    If a _force_skip_lazy_graph_module() context is active when this context is
    entered, lazy graph modules stay disabled regardless of ``should_use``.
    The previous setting is restored on exit.
    """
    global _use_lazy_graph_module_flag
    # Capture the prior value before entering ``try`` (see above).
    prior = _use_lazy_graph_module_flag
    _use_lazy_graph_module_flag = (
        should_use and not _force_skip_lazy_graph_module_flag
    )
    try:
        yield
    finally:
        _use_lazy_graph_module_flag = prior


@compatibility(is_backward_compatible=False)
def _get_graph_module_cls():
    """Return the graph module class selected by the current lazy-graph-module setting."""
    return _LazyGraphModule if _use_lazy_graph_module_flag else GraphModule


def _make_graph_module(*args, graph_module_cls=None, **kwargs):
    """
    Construct a graph module, forwarding ``args``/``kwargs`` to its constructor.

    Args:
        graph_module_cls: Class to instantiate. If None, the class is chosen by
            _get_graph_module_cls(), i.e. by the enclosing
            _use_lazy_graph_module() context (GraphModule by default).
    """
    if graph_module_cls is None:
        graph_module_cls = _get_graph_module_cls()

    return graph_module_cls(*args, **kwargs)


@compatibility(is_backward_compatible=False)
class _LazyGraphModule(GraphModule):
    """
    The main difference between _LazyGraphModule and GraphModule is how recompile happens.
    GraphModule will do a 'recompile' call to generate python code and the forward method when it's
    constructed. Later on, if the graph gets updated, the recompile method can be called again to
    refresh the saved python code and forward method.

    However, in some cases, especially in inductor, the recompilation can be a waste since we never
    check the python code for the graph module or call its forward method. A few more concrete
    examples regarding pattern matching fx passes in inductor:
    1. Some passes will update the graph to be compiled and then call recompile on the GraphModule.
    2. Some passes will trace small pattern functions to search for them in the graph being compiled
       and replace each match with the traced graph of a replacement function. The pattern graphs
       and replacement graphs are quite small but there are a large number of them. Doing
       GraphModule.recompile for them in GraphModule.__init__ is also a waste of time.

    However, simply skipping GraphModule.recompile in these scenarios is also dangerous.
    People may want to check the python code or call the GraphModule's forward method for
    debugging purposes.

    The way _LazyGraphModule solves this is: we override the recompile method to just mark the
    need for recompilation without doing the actual recompilation. Later on, if people really
    access the compiled python code or call the GraphModule's forward method, we do the real
    recompilation.
    """

    @classmethod
    def from_graphmodule(cls, gm: GraphModule):
        """
        Return ``gm`` unchanged if it is already a _LazyGraphModule; otherwise
        build a new _LazyGraphModule from ``gm`` (as the root) and ``gm.graph``.
        """
        if isinstance(gm, _LazyGraphModule):
            return gm
        else:
            return _LazyGraphModule(gm, gm.graph)

    @staticmethod
    def force_recompile(gm):
        """
        Sometimes we need to force a recompile as a workaround, e.g.
        we want to do the real recompilation before symbolic_trace to avoid this error:
            https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259

        This is a no-op when ``gm`` is not a _LazyGraphModule.
        """
        if isinstance(gm, _LazyGraphModule):
            gm.real_recompile()

    def real_recompile(self):
        """Do the real recompilation if one is pending; otherwise do nothing."""
        if self._needs_recompile():
            self._real_recompile()

    @classmethod
    def _needs_recompile(cls):
        """
        Return True while ``forward`` is still the lazy placeholder installed by
        recompile(), i.e. the real recompilation has not happened yet.
        """
        # NOTE(review): this state lives on the class, not the instance. That is
        # presumably per-instance only because GraphModule gives each instance
        # its own class — confirm against GraphModule.__new__.
        return cls.forward is cls._lazy_forward

    def _lazy_forward(self, *args, **kwargs):
        # Call self.real_recompile() rather than self._real_recompile() here.
        # The _lazy_forward method may be saved and called repeatedly.
        # Calling self.real_recompile makes sure we skip recompilation if
        # we have already done so.
        self.real_recompile()
        assert not self._needs_recompile()

        # Call `__call__` rather than `forward` since recompilation may
        # install a wrapper for `__call__` to provide a customized error
        # message.
        return self(*args, **kwargs)

    # Until the first real recompilation, ``forward`` is the lazy placeholder.
    forward = _lazy_forward

    # TODO: we should handle __reduce_deploy__ the same way as __reduce_package__
    # or __reduce__, by calling _real_recompile. But I didn't find a good way
    # to test __reduce_deploy__. Also it's very unlikely that LazyGraphModule
    # will be used in torch::deploy. So it's skipped for now.

    def __reduce_package__(self, exporter: PackageExporter):
        """
        Package-export counterpart of __reduce__. It mirrors GraphModule's
        reduction but calls 'self._real_recompile' rather than 'self.recompile',
        because for a _LazyGraphModule self.recompile only marks the need for
        recompilation and does not return the PythonCode object.
        """
        python_code = self._real_recompile()
        dict_without_graph = self.__dict__.copy()
        # Record the concrete class name in the saved state; presumably read
        # back by reduce_package_graph_module — verify there.
        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
        del dict_without_graph["_graph"]

        generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
        import_block = _format_import_block(python_code.globals, exporter.importer)
        module_code = import_block + self.code
        exporter.save_source_string(generated_module_name, module_code)
        return (
            reduce_package_graph_module,
            (dict_without_graph, generated_module_name),
        )

    def __reduce__(self):
        """
        Follow GraphModule.__reduce__ but call 'self._real_recompile' rather
        than 'self.recompile', because for a _LazyGraphModule self.recompile only
        marks the need for recompilation and does not return the PythonCode object.
        """
        python_code = self._real_recompile()
        dict_without_graph = self.__dict__.copy()
        import_block = _format_import_block(python_code.globals, sys_importer)
        del dict_without_graph["_graph"]
        return (reduce_graph_module, (dict_without_graph, import_block))

    def _real_recompile(self):
        """Run GraphModule's eager recompile and return its PythonCode result."""
        return super().recompile()

    @classmethod
    def recompile(cls):
        """
        Only mark the need for recompilation by reinstalling the lazy forward.

        Unlike GraphModule.recompile, this returns None. The real recompilation
        is deferred until forward is called, or code / str() / pickling is used.
        """
        cls.forward = cls._lazy_forward

    @property
    def code(self) -> str:
        """Generated python code; triggers any pending real recompilation first."""
        self.real_recompile()
        return super().code

    def __str__(self) -> str:
        """
        str(GraphModule) will access the _code attribute. Make sure recompile
        happens so the _code attribute is available.
        """
        self.real_recompile()
        return super().__str__()