# mypy: allow-untyped-defs
import contextlib
import copy
import itertools
import linecache
import os
import sys
import traceback
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union

import torch
import torch.nn as nn
import torch.overrides
from torch.nn.modules.module import _addindent
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer

from ._compatibility import compatibility
from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode

__all__ = [
    "reduce_graph_module",
    "reduce_package_graph_module",
    "reduce_deploy_graph_module",
    "GraphModule",
]

# Key inside ``GraphModule.meta`` whose dict entries are re-applied as
# attributes on a deep-copied GraphModule.
_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"


# Normal exec loses the source code, however we can work with
# the linecache module to recover it.
# Using _exec_with_source will add it to our local cache
# and then tools like TorchScript will be able to get source info.
class _EvalCacheLoader:
    """In-memory source store that doubles as a PEP 302 loader for linecache."""

    def __init__(self):
        # Maps synthetic filename (cache key) -> source text.
        self.eval_cache = {}
        # Monotonic counter used to make every key unique.
        self.next_id = 0

    def cache(self, src: str, globals: Dict[str, Any], co_fields=None):
        """Store the source in a private cache, and add a lazy entry in linecache
        that allows the source to be retrieved by 'filename'.

        Args:
            src (str): The module source to cache
            globals (dict): The module globals. Not mutated; a copy is handed
                to linecache.
            co_fields (dict, optional): If truthy, its ``co_filename``,
                ``co_firstlineno`` and ``co_name`` entries are appended to the
                generated key so tracebacks point back at the original code.

        Returns:
            str: The cache key (and dummy filename) generated for src.
        """
        key = self._get_key()
        if co_fields:
            key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
        self.eval_cache[key] = src

        # Don't mutate globals so that this loader is only used
        # to populate linecache, and doesn't interact with other modules
        # that might check `__loader__`
        globals_copy = globals.copy()
        globals_copy["__file__"] = key
        globals_copy["__name__"] = key
        globals_copy["__loader__"] = self
        linecache.lazycache(key, globals_copy)

        return key

    # Part of the loader protocol (PEP 302)
    # linecache will use this method when trying to find source code
    def get_source(self, module_name) -> Optional[str]:
        """Return the cached source registered under ``module_name``, or None."""
        return self.eval_cache.get(module_name)

    def _get_key(self):
        """Return a fresh ``<eval_with_key>.N`` filename and advance the counter."""
        key = f"<eval_with_key>.{self.next_id}"
        self.next_id += 1
        return key


# Process-wide loader shared by every generated forward().
_loader = _EvalCacheLoader()


def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None):
    """Compile and exec ``src`` into ``globals`` (mutated in place), registering
    the source with linecache under a synthetic filename first."""
    key = _loader.cache(src, globals, co_fields)
    exec(compile(src, key, "exec"), globals)


def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None):
    """Shorthand for ``_method_from_src("forward", ...)``."""
    return _method_from_src(
        method_name="forward", src=src, globals=globals, co_fields=co_fields
    )


def _method_from_src(
    method_name: str, src: str, globals: Dict[str, Any], co_fields=None
) -> Callable:
    """Exec ``src`` against a copy of ``globals`` and return the function it
    defines under ``method_name``.

    The caller's ``globals`` dict is never mutated; the defined name is removed
    from the private copy before returning.

    Raises:
        KeyError: if ``src`` does not define ``method_name``.
    """
    # avoid mutating the passed in dict
    globals_copy = globals.copy()
    _exec_with_source(src, globals_copy, co_fields)
    return globals_copy.pop(method_name)


def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
    """Return the import line that binds ``obj`` to ``name`` in generated code."""
    if name in _custom_builtins:
        return _custom_builtins[name].import_str
    if _is_from_torch(name):
        return "import torch"
    module_name, attr_name = importer.get_name(obj)
    return f"from {module_name} import {attr_name} as {name}"


def _format_import_block(globals: Dict[str, Any], importer: Importer):
    """Return a deduplicated, newline-joined import block for ``globals``."""
    import_strs: Set[str] = {_format_import_statement(name, obj, importer) for name, obj in globals.items()}
    # Sort the imports so we have a stable import block that allows us to
    # hash the graph module and get a consistent key for use in a cache.
    return "\n".join(sorted(import_strs))


@compatibility(is_backward_compatible=True)
def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
    """Unpickling hook for ``GraphModule.__reduce__``: rebuild ``forward`` from
    ``import_block`` plus the stored source, then retrace it into a module."""
    # BC: attribute name was changed from `code` to `_code` to facilitate
    # making `code` into a property and adding a docstring to it
    fn_src = body.get("_code") or body["code"]
    forward = _forward_from_src(import_block + fn_src, {})
    return _deserialize_graph_module(forward, body)


@compatibility(is_backward_compatible=True)
def reduce_package_graph_module(
    importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
) -> torch.nn.Module:
    """torch.package hook: import the generated module saved by
    ``GraphModule.__reduce_package__`` and rebuild the GraphModule from it."""
    forward = importer.import_module(generated_module_name).forward
    return _deserialize_graph_module(forward, body)


@compatibility(is_backward_compatible=True)
def reduce_deploy_graph_module(
    importer: PackageImporter, body: Dict[Any, Any], import_block: str
) -> torch.nn.Module:
    """Deploy hook: like ``reduce_graph_module`` but execs the source with the
    importer's patched builtins."""
    ns: Dict[str, Any] = {"__builtins__": importer.patched_builtins}
    # BC: mirror reduce_graph_module and fall back to the legacy `code` key.
    fn_src = body.get("_code")
    if fn_src is None:
        fn_src = body.get("code")
    assert fn_src is not None, "serialized GraphModule body has no generated code"
    forward = _forward_from_src(import_block + fn_src, ns)
    return _deserialize_graph_module(forward, body)


# We create a dummy class here because symbolic_trace pulls the forward()
# function off of the class, rather than the instance. This class is used
# in _deserialize_graph_module() below.
class _CodeOnlyModule(torch.nn.Module):
    """Bare ``nn.Module`` whose ``__dict__`` is replaced wholesale by ``body``.

    ``_deserialize_graph_module`` installs ``forward`` on this class before
    tracing, and ``GraphModule.__init__`` copies all children, buffers and
    parameters from instances of it.
    """

    def __init__(self, body):
        super().__init__()
        self.__dict__ = body


def _deserialize_graph_module(forward, body: Dict[Any, Any], graph_module_cls=None) -> torch.nn.Module:
    """
    Deserialize a GraphModule given the dictionary of the original module,
    using the code to reconstruct the graph. We delete the actual graph before
    saving the dictionary so that changes to the in-memory graph format do not
    get serialized.
    """

    # symbolic tracing reads forward() off the class, so install it there.
    # NOTE(review): this mutates a class shared by every call; concurrent
    # deserialization from several threads is presumably unsafe.
    _CodeOnlyModule.forward = forward

    tracer_cls = body.get("_tracer_cls")
    if tracer_cls is None:
        from ._symbolic_trace import Tracer

        tracer_cls = Tracer

    graphmodule_cls_name = body.get("_graphmodule_cls_name", "GraphModule")

    # This is a workaround for a mypy linter issue related to
    # passing base class as an argument - https://github.com/python/mypy/issues/5865.
    cls_tracer: Any = tracer_cls

    class KeepModules(cls_tracer):
        # we shouldn't trace into any of the submodules,
        # because they were not traced in the original GraphModule
        def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
            return True

    com = _CodeOnlyModule(body)

    tracer_extras = body.get("_tracer_extras", {})
    graph = KeepModules().trace(com, **tracer_extras)

    # Manually set Tracer class on the reconstructed Graph, to avoid
    # referencing the private local subclass KeepModules.
    graph._tracer_cls = tracer_cls
    from ._lazy_graph_module import _make_graph_module

    gm = _make_graph_module(com, graph, class_name=graphmodule_cls_name, graph_module_cls=graph_module_cls)

    # The GraphModule constructor only retains attributes referenced by the graph.
    # In this case, our goal is return a GraphModule as close to identical as the one
    # put into the package. If any additional attributes were present in body,
    # we should keep them.
    for k, v in body.items():
        if not hasattr(gm, k):
            setattr(gm, k, v)
    return gm


# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
# This installs empty Modules where none exist yet if they are subpaths of target
def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str):
    """Copy ``from_module.<target>`` onto ``to_module`` at the same dotted path.

    Missing intermediate path components on ``to_module`` are created as empty
    ``nn.Module``s. Plain (non-Parameter) tensors are registered as buffers.
    """
    *prefix, field = target.split(".")
    for item in prefix:
        f = getattr(from_module, item)
        t = getattr(to_module, item, None)
        if f is t:
            # we have already installed one of its parents
            # (e.g. target = root.linear.weight, but we have already installed root.linear)
            # once we install a parent, we no longer need to copy the children
            # since all the needed properties will already be present
            return

        if t is None:
            t = torch.nn.Module()
            setattr(to_module, item, t)
        from_module, to_module = f, t

    orig = getattr(from_module, field)
    # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
    # So, we register it as a named buffer in the target module.
    if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter):
        to_module.register_buffer(field, orig)
    else:
        setattr(to_module, field, orig)


# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module
# This installs empty Modules where none exist yet if they are subpaths of target
def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
    """Set ``to_module.<target> = from_obj``, creating empty intermediate
    ``nn.Module``s as needed; plain tensors are registered as buffers."""
    *prefix, field = target.split(".")
    for item in prefix:
        t = getattr(to_module, item, None)

        if t is None:
            t = torch.nn.Module()
            setattr(to_module, item, t)
        to_module = t

    # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
    # So, we register it as a named buffer in the target module.
    if isinstance(from_obj, torch.Tensor) and not isinstance(
        from_obj, torch.nn.Parameter
    ):
        to_module.register_buffer(field, from_obj)
    else:
        setattr(to_module, field, from_obj)


def _print_readable(
    module,
    module_name,
    print_output=True,
    include_stride=False,
    include_device=False,
    colored=False,
):
    """Render ``module`` (and, recursively, every child that has a ``graph``)
    as a verbose ``class <module_name>(torch.nn.Module)`` source listing.

    Returns the listing; also prints it when ``print_output`` is true.
    """
    graph = module.graph
    assert graph is not None and isinstance(graph, torch.fx.Graph), "print_readable must be used on a module with a graph"

    verbose_python_code = graph.python_code(
        root_module="self",
        verbose=True,
        include_stride=include_stride,
        include_device=include_device,
        colored=colored,
    )
    module_code = verbose_python_code.src
    module_code = module_code.lstrip("\n")
    module_code = f"class {module_name}(torch.nn.Module):\n" + module_code
    module_code = _addindent(module_code, 4)

    submodule_code_list = [""]
    for submodule_name, submodule in module.named_children():
        if hasattr(submodule, "graph"):
            submodule_code_list.append(
                _print_readable(
                    submodule,
                    submodule_name,
                    print_output=False,
                    include_stride=include_stride,
                    include_device=include_device,
                    colored=colored,
                )
            )
    submodule_code = "\n".join(submodule_code_list)
    submodule_code = _addindent(submodule_code, 4)

    output = module_code + submodule_code
    if print_output:
        print(output)
    return output


class _WrappedCall:
    """Callable installed as ``cls._wrapped_call`` by ``GraphModule.recompile``.

    Dispatches to the class's own ``__call__`` when one was defined
    (``cls_call``), otherwise to ``super(cls, obj).__call__``, and prints a
    source-annotated message for exceptions raised inside generated code.
    """

    def __init__(self, cls, cls_call):
        self.cls = cls
        self.cls_call = cls_call

    # Previously, if an error occurred when valid
    # symbolically-traced code was run with an invalid input, the
    # user would see the source of the error as coming from
    # `File "<eval_with_key_N">`, where N is some number. We use
    # this function to generate a more informative error message. We
    # return the traceback itself, a message explaining that the
    # error occurred in a traced Module's generated forward
    # function, and up to two lines before (including the faulty
    # line) and two lines after the marker.
    @staticmethod
    def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
        # auxiliary variables (for readability)
        err_lineno = frame_summary.lineno
        assert err_lineno is not None
        line = frame_summary.line
        assert line is not None
        err_line_len = len(line)
        all_src_lines = linecache.getlines(frame_summary.filename)

        # constituent substrings of the error message
        tb_repr = torch._dynamo.disable(traceback.format_exc)()
        custom_msg = (
            "Call using an FX-traced Module, "
            f"line {err_lineno} of the traced Module's "
            "generated forward function:"
        )
        # Previous line plus the faulty line (1-indexed err_lineno). Clamp the
        # start so a fault on line 1 doesn't turn into a negative slice start,
        # which would silently drop the faulty line from the message.
        before_err = "".join(all_src_lines[max(err_lineno - 2, 0) : err_lineno])
        marker = "~" * err_line_len + "~~~ <--- HERE"
        # The two lines following the faulty line.
        err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])

        # joined message
        return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])

    def __call__(self, obj, *args, **kwargs):
        try:
            if self.cls_call is not None:
                return self.cls_call(obj, *args, **kwargs)
            else:
                return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
        except Exception as e:
            assert e.__traceback__
            topmost_framesummary: traceback.FrameSummary = (
                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]
            )  # type: ignore[arg-type]
            if "eval_with_key" in topmost_framesummary.filename:
                print(
                    _WrappedCall._generate_error_message(topmost_framesummary),
                    file=sys.stderr,
                )
                raise e.with_traceback(None)  # noqa: B904
            else:
                raise e


@compatibility(is_backward_compatible=True)
class GraphModule(torch.nn.Module):
    """
    GraphModule is an nn.Module generated from an fx.Graph. GraphModule has a
    ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated
    from that ``graph``.

    .. warning::

        When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
        regenerated. However, if you edit the contents of the ``graph`` without reassigning
        the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
        code.
    """

    def __new__(cls: "Type[GraphModule]", *args, **kwargs):
        # each instance of a graph module needs its own forward method
        # so create a new singleton class for each instance.
        # it is a subclass of the user-defined class, the only difference
        # is an extra layer to install the forward method

        # address issue described at https://github.com/pytorch/pytorch/issues/63883
        # in other words, traverse class hierarchy to fix the redundant class definition problem
        for t in cls.__mro__:
            c = t.__qualname__.split(".")[-1]
            if c != "GraphModuleImpl":
                cls = t
                break

        class GraphModuleImpl(cls):  # type: ignore[misc, valid-type]
            pass

        return super().__new__(GraphModuleImpl)

    @compatibility(is_backward_compatible=True)
    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: Graph,
        class_name: str = "GraphModule",
    ):
        """
        Construct a GraphModule.

        Args:

            root (Union[torch.nn.Module, Dict[str, Any]]):
                ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type.
                In the case that ``root`` is a Module, any references to Module-based objects (via qualified
                name) in the Graph's Nodes' ``target`` field will be copied over from the respective place
                within ``root``'s Module hierarchy into the GraphModule's module hierarchy.
                In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be
                looked up directly in the dict's keys. The object mapped to by the Dict will be copied
                over into the appropriate place within the GraphModule's module hierarchy.

            graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation

            class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all
                error messages will report as originating from ``GraphModule``. It may be helpful to set this
                to ``root``'s original name or a name that makes sense within the context of your transform.

        Raises:
            RuntimeError: if ``root`` is a dict missing a target referenced by
                the graph, or if ``root`` is neither a Module nor a dict.
        """
        super().__init__()
        self.__class__.__name__ = class_name
        if isinstance(root, torch.nn.Module):
            if hasattr(root, "training"):
                self.training = root.training

            # When we pickle/unpickle graph module, we don't want to drop any module or attributes.
            if isinstance(root, _CodeOnlyModule):
                for k, _ in root.named_children():
                    _copy_attr(root, self, k)

                for k, _ in root.named_buffers():
                    _copy_attr(root, self, k)

                for k, _ in root.named_parameters():
                    _copy_attr(root, self, k)

            for node in graph.nodes:
                if node.op in ["get_attr", "call_module"]:
                    assert isinstance(node.target, str)
                    _copy_attr(root, self, node.target)
        elif isinstance(root, dict):
            targets_to_copy = []
            for node in graph.nodes:
                if node.op in ["get_attr", "call_module"]:
                    assert isinstance(node.target, str)
                    if node.target not in root:
                        raise RuntimeError(
                            "Node "
                            + str(node)
                            + " referenced target "
                            + node.target
                            + " but that target was not provided in ``root``!"
                        )
                    targets_to_copy.append(node.target)
            # Sort targets in ascending order of the # of atoms.
            # This will ensure that less deeply nested attributes are assigned
            # before more deeply nested attributes. For example, foo.bar
            # will be assigned before foo.bar.baz. Otherwise, we might assign
            # the user-provided ``foo.bar`` and wipe out the previously-assigned
            # ``foo.bar.baz``
            targets_to_copy.sort(key=lambda t: t.count("."))
            for target_to_copy in targets_to_copy:
                _assign_attr(root[target_to_copy], self, target_to_copy)
        else:
            raise RuntimeError("Unsupported type " + str(root) + " passed for root!")

        self.graph = graph

        # Store the Tracer class responsible for creating a Graph separately as part of the
        # GraphModule state, except when the Tracer is defined in a local namespace.
        # Locally defined Tracers are not pickleable. This is needed because torch.package will
        # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
        # to re-create the Graph during deserialization.
        self._tracer_cls = None
        if (
            self.graph._tracer_cls
            and "<locals>" not in self.graph._tracer_cls.__qualname__
        ):
            self._tracer_cls = self.graph._tracer_cls

        self._tracer_extras = {}
        if self.graph._tracer_extras:
            self._tracer_extras = self.graph._tracer_extras

        # Dictionary to store metadata
        self.meta: Dict[str, Any] = {}
        self._replace_hook = None
        self._create_node_hooks: List[Callable] = []
        self._erase_node_hooks: List[Callable] = []

    # TorchScript breaks trying to compile the graph setter because of the
    # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
    #
    # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
    __jit_unused_properties__ = ["graph"]

    @property
    def graph(self) -> Graph:
        """
        Return the ``Graph`` underlying this ``GraphModule``
        """
        return self._graph

    @graph.setter
    def graph(self, g: Graph) -> None:
        """
        Set the underlying ``Graph`` for this ``GraphModule``. This will internally
        recompile the ``GraphModule`` so that the generated ``forward()`` function
        corresponds to ``g``
        """
        assert isinstance(g, Graph), f"Expected a Graph instance, but got {type(g)}"
        self._graph = g
        g.owning_module = self
        self.recompile()

    @compatibility(is_backward_compatible=False)
    def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
        """Dumps out module to ``folder`` with ``module_name`` so that it can be
        imported with ``from <folder> import <module_name>``

        Writes ``state_dict.pt``, ``module.py`` and ``__init__.py``; children
        whose type has no "safe" repr are additionally pickled to
        ``<child_name>.pt`` and a warning lists them.

        Args:

            folder (Union[str, os.PathLike]): The folder to write the code out to

            module_name (str): Top-level name to use for the ``Module`` while
                writing out the code
        """
        folder = Path(folder)
        folder.mkdir(exist_ok=True)
        torch.save(self.state_dict(), folder / "state_dict.pt")
        tab = " " * 4
        custom_builtins = "\n".join([v.import_str for v in _custom_builtins.values()])
        model_str = f"""
import torch
{custom_builtins}

from torch.nn import *
class {module_name}(torch.nn.Module):
    def __init__(self):
        super().__init__()
"""

        def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
            # Only these exact types have a repr that is valid constructor code.
            safe_reprs = [
                nn.Linear,
                nn.Conv1d,
                nn.Conv2d,
                nn.Conv3d,
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
            ]
            if type(module) in safe_reprs:
                return repr(module)
            return None

        blobified_modules = []
        # Loop names are distinct from the `module_name` parameter so the
        # top-level class name is never shadowed.
        for child_name, child in self.named_children():
            module_str = _gen_model_repr(child_name, child)
            if module_str is None:
                module_file = folder / f"{child_name}.pt"
                torch.save(child, module_file)
                blobified_modules.append(child_name)
                module_repr = child.__repr__().replace("\r", " ").replace("\n", " ")
                # weights_only=False as this is legacy code that saves the model
                module_str = f"torch.load(r'{module_file}', weights_only=False) # {module_repr}"
            model_str += f"{tab*2}self.{child_name} = {module_str}\n"

        for buffer_name, buffer in self._buffers.items():
            if buffer is None:
                continue
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"

        for param_name, param in self._parameters.items():
            if param is None:
                continue
            model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"

        model_str += (
            f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
        )
        model_str += f"{_addindent(self.code, 4)}\n"

        module_file = folder / "module.py"
        module_file.write_text(model_str)

        init_file = folder / "__init__.py"
        init_file.write_text("from .module import *")

        if len(blobified_modules) > 0:
            warnings.warn(
                "Was not able to save the following children modules as reprs - "
                f"saved as pickled files instead: {blobified_modules}"
            )

    @compatibility(is_backward_compatible=True)
    def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
        """
        Adds the given submodule to ``self``.

        This installs empty Modules where none exist yet if they are
        subpaths of ``target``.

        Args:
            target: The fully-qualified string name of the new submodule
                (See example in ``nn.Module.get_submodule`` for how to
                specify a fully-qualified string.)
            m: The submodule itself; the actual object we want to
                install in the current Module

        Return:
            bool: Whether or not the submodule could be inserted. For
                this method to return True, each object in the chain
                denoted by ``target`` must either a) not exist yet,
                or b) reference an ``nn.Module`` (not a parameter or
                other attribute)
        """
        *prefix, field = target.split(".")
        mod: torch.nn.Module = self

        for item in prefix:

            submod = getattr(mod, item, None)

            if submod is None:
                submod = torch.nn.Module()
                setattr(mod, item, submod)

            if not isinstance(submod, torch.nn.Module):
                return False

            mod = submod

        mod.add_module(field, m)
        return True

    @compatibility(is_backward_compatible=True)
    def delete_submodule(self, target: str) -> bool:
        """
        Deletes the given submodule from ``self``.

        The module will not be deleted if ``target`` is not a valid
        target.

        Args:
            target: The fully-qualified string name of the submodule to
                delete (See example in ``nn.Module.get_submodule`` for
                how to specify a fully-qualified string.)

        Returns:
            bool: Whether or not the target string referenced a
                submodule we want to delete. A return value of ``False``
                means that the ``target`` was not a valid reference to
                a submodule.
        """
        atoms = target.split(".")
        path, target_submod = atoms[:-1], atoms[-1]
        mod: torch.nn.Module = self

        # Get the parent module
        for item in path:

            if not hasattr(mod, item):
                return False

            mod = getattr(mod, item)

            if not isinstance(mod, torch.nn.Module):
                return False

        if not hasattr(mod, target_submod):
            return False

        if not isinstance(getattr(mod, target_submod), torch.nn.Module):
            return False

        delattr(mod, target_submod)
        return True

    @compatibility(is_backward_compatible=True)
    def delete_all_unused_submodules(self) -> None:
        """
        Deletes all unused submodules from ``self``.

        A Module is considered "used" if any one of the following is
        true:
        1. It has children that are used
        2. Its forward is called directly via a ``call_module`` node
        3. It has a non-Module attribute that is used from a
           ``get_attr`` node

        This method can be called to clean up an ``nn.Module`` without
        manually calling ``delete_submodule`` on each unused submodule.
        """
        # A set gives O(1) membership checks when filtering named_modules().
        used: Set[str] = set()

        # If we're looking at multiple parts of a path, join
        # them with a dot. Otherwise, return that single
        # element without doing anything to it.
        def join_fn(x: str, y: str) -> str:
            return ".".join([x, y] if y else [x])

        for node in self.graph.nodes:

            if node.op == "call_module" or node.op == "get_attr":

                # A list of strings representing the different parts
                # of the path. For example, `foo.bar.baz` gives us
                # ["foo", "bar", "baz"]
                fullpath = node.target.split(".")

                # Progressively collect all the names of intermediate
                # modules. For example, if we have the target
                # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and
                # `foo.bar.baz` to the set.
                used.update(itertools.accumulate(fullpath, join_fn))

                # For a `call_module` node, also register all recursive submodules
                # as used
                if node.op == "call_module":
                    try:
                        submod = self.get_submodule(node.target)

                        for submod_name, _ in submod.named_modules():
                            if submod_name != "":
                                used.add(".".join([node.target, submod_name]))
                    except AttributeError:
                        # Node referenced nonexistent submodule, don't need to
                        # worry about GCing anything
                        pass

        to_delete = [name for name, _ in self.named_modules() if name not in used]

        for name in to_delete:
            self.delete_submodule(name)

    @property
    def code(self) -> str:
        """
        Return the Python code generated from the ``Graph`` underlying this
        ``GraphModule``.
        """
        if not hasattr(self, "_code"):
            raise RuntimeError(
                "Code has not been generated! Please report a bug to PyTorch"
            )
        return self._code

    @compatibility(is_backward_compatible=True)
    def recompile(self) -> PythonCode:
        """
        Recompile this GraphModule from its ``graph`` attribute. This should be
        called after editing the contained ``graph``, otherwise the generated
        code of this ``GraphModule`` will be out of date.
        """
        if isinstance(self._graph._codegen, _PyTreeCodeGen):
            self._in_spec = self._graph._codegen.pytree_info.in_spec
            self._out_spec = self._graph._codegen.pytree_info.out_spec
        python_code = self._graph.python_code(root_module="self")
        self._code = python_code.src
        self._lineno_map = python_code._lineno_map

        cls = type(self)
        co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
        cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)

        # Determine whether this class explicitly defines a __call__ implementation
        # to wrap. If it does, save it in order to have wrapped_call invoke it.
        # If it does not, wrapped_call can use a dynamic call to super() instead.
        # In most cases, super().__call__ should be torch.nn.Module.__call__.
        # We do not want to hold a reference to Module.__call__ here; doing so will
        # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
        cls_call = cls.__call__ if "__call__" in vars(cls) else None

        if "_wrapped_call" not in vars(cls):
            cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]

        def call_wrapped(self, *args, **kwargs):
            return self._wrapped_call(self, *args, **kwargs)

        cls.__call__ = call_wrapped  # type: ignore[method-assign]

        return python_code

    # Passing Tracer as argument allows subclasses extending fx.GraphModule
    # define their own Tracer (extending fx.Tracer).
    def __reduce_deploy__(self, importer: Importer):
        dict_without_graph = self.__dict__.copy()
        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
        del dict_without_graph["_graph"]

        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, importer)
        return (reduce_deploy_graph_module, (dict_without_graph, import_block))

    def __reduce_package__(self, exporter: PackageExporter):
        dict_without_graph = self.__dict__.copy()
        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
        del dict_without_graph["_graph"]

        generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, exporter.importer)
        module_code = import_block + self.code
        exporter.save_source_string(generated_module_name, module_code)
        return (
            reduce_package_graph_module,
            (dict_without_graph, generated_module_name),
        )

    def __reduce__(self):
        """
        Serialization of GraphModule. We serialize only the generated code, not
        the underlying ``Graph``. This is because ``Graph`` does not have on-disk
        backward-compatibility guarantees, whereas Python source code does.
        On the deserialization side, we symbolically trace through the generated
        code to regenerate the underlying ``Graph``
        """
        dict_without_graph = self.__dict__.copy()

        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, sys_importer)
        del dict_without_graph["_graph"]
        return (reduce_graph_module, (dict_without_graph, import_block))

    def _deepcopy_init(self):
        """Return the initializer ``__deepcopy__`` uses; overridable by subclasses."""
        return GraphModule.__init__

    # because __reduce__ is defined for serialization,
    # we need to define deepcopy otherwise it will call __reduce__
    # and cause symbolic tracing to occur every time we try to copy the object
    def __deepcopy__(self, memo):
        res = type(self).__new__(type(self))
        memo[id(self)] = res
        fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
        self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["_graph"])
        # hooks are lost during `GraphModule.__init__`, so we need to copy over
        # them explicitly, note right now we are only copying state_dict related
        # hooks, to reduce bc-related issues, we can copy forward/backward related
        # hooks in the future as well if needed
        extra_preserved_attrs = [
            "_state_dict_hooks",
            "_load_state_dict_pre_hooks",
            "_load_state_dict_post_hooks",
            "_replace_hook",
            "_create_node_hooks",
            "_erase_node_hooks",
        ]
        for attr in extra_preserved_attrs:
            if attr in self.__dict__:
                setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
        res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
        if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
            for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
                setattr(res, attr_name, attr)
        return res

    def __copy__(self):
        from ._lazy_graph_module import _make_graph_module

        res = _make_graph_module(self, self.graph)
        # Shallow copy: the new module shares the same meta dict object.
        res.meta = getattr(self, "meta", {})
        return res

    @compatibility(is_backward_compatible=False)
    def print_readable(self, print_output=True, include_stride=False, include_device=False, colored=False):
        """
        Return the Python code generated for current GraphModule and its children GraphModules
        """
        return _print_readable(
            self,
            self._get_name(),
            print_output,
            include_stride,
            include_device,
            colored,
        )

    def __str__(self) -> str:
        orig_str = super().__str__()
        print_readable_reminder = (
            "# To see more debug info, please use `graph_module.print_readable()`"
        )
        return "\n".join([orig_str, self._code, print_readable_reminder])

    def _replicate_for_data_parallel(self):
        new_gm = self.__copy__()
        new_gm._is_replica = True
        return new_gm

    @contextlib.contextmanager
    def _set_replace_hook(self, f):
        """
        Takes a callable which will be called every time we replace a node
        with a new node, or change the node's name. Callable takes three arguments:
        the old node we're changing, and NAME of the new node, followed by the
        user node which consumes the old node to be replaced.

        The previous hook is restored when the context exits.
        """
        assert callable(f), "Replace hook must be a callable."
        prev, self._replace_hook = self._replace_hook, f
        try:
            yield
        finally:
            self._replace_hook = prev

    def _register_create_node_hook(self, f):
        """
        Takes a callable which will be called after we create a new node. The
        callable takes the newly created node as input and returns None.
        """
        assert callable(f), "create_node hook must be a callable."
        self._create_node_hooks.append(f)

    def _unregister_create_node_hook(self, f):
        """
        Takes a callable which was previously registered to be called after we create a node.
        This function will unregister that callable so it is no longer invoked on node creation.
        Raises ``ValueError`` (from ``list.remove``) if ``f`` was never registered.
        """
        assert callable(f), "create_node hook must be a callable."
        self._create_node_hooks.remove(f)

    def _register_erase_node_hook(self, f):
        """
        Takes a callable which will be called after we erase a node. The
        callable takes the node that is being erased as input and returns None.
        """
        assert callable(f), "erase_node hook must be a callable."
        self._erase_node_hooks.append(f)

    def _unregister_erase_node_hook(self, f):
        """
        Takes a callable which was previously registered to be called after we erase a node.
        This function will unregister that callable so it is no longer invoked on node erasure.
        Raises ``ValueError`` (from ``list.remove``) if ``f`` was never registered.
        """
        assert callable(f), "erase_node hook must be a callable."
        self._erase_node_hooks.remove(f)


# workarounds for issues in __torch_function__

# WAR for __torch_function__ not handling tensor lists,
# fix is in https://github.com/pytorch/pytorch/pull/34725
# orig_cat = torch.cat
# def patched_cat(*args, **kwargs):
#     tensors = args[0]
#     for t in tensors:
#         if isinstance(t, Proxy):
#             return t.__torch_function__(patched_cat, (), args, kwargs)
#     return orig_cat(*args, **kwargs)
# patched_cat.__module__ = 'torch'
# patched_cat.__name__ = 'cat'
# torch.cat = patched_cat