# mypy: ignore-errors

import functools
import inspect
import itertools
import types
from contextlib import contextmanager, nullcontext
from typing import Dict, List, TYPE_CHECKING

import torch.nn

from .. import trace_rules, variables
from ..exc import (
    raise_observed_exception,
    unimplemented,
    UnspecializeRestartAnalysis,
    Unsupported,
)
from ..guards import GuardBuilder, install_guard
from ..mutation_guard import GenerationTracker
from ..source import (
    AttrSource,
    ConstDictKeySource,
    FSDPNNModuleSource,
    GetItemSource,
    NNModuleSource,
    UnspecializedBuiltinNNModuleSource,
    UnspecializedNNModuleSource,
)
from ..utils import (
    get_custom_getattr,
    get_fake_value,
    is_lazy_module,
    is_namedtuple,
    is_safe_constant,
    istensor,
    istype,
    nnmodule_has_hooks,
    object_has_getattribute,
    proxy_args_kwargs,
    set_example_value,
)
from .base import MutableLocal, typestr, VariableTracker
from .functions import invoke_and_store_as_constant
from .lazy import LazyVariableTracker
from .lists import SliceVariable
from .user_defined import UserDefinedObjectVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs):
    """
    Fairly coupled helper used by NNModuleVariable and UnspecializedNNModuleVariable.

    Used to cause a lazy module to be initialized (and to delete its init hook) before tracing.
    This is especially useful now that 'allowed' modules graph-break on hooks: calling this first
    ensures there is no hook by the time we trace __call__, and thus no graph-break for lazy
    allowed modules.
    """
    if hasattr(mod, "_initialize_hook"):

        def convert_to_fake(x):
            if is_namedtuple(x):
                return type(x)(*(convert_to_fake(elem) for elem in x))
            elif isinstance(x, dict):
                return {k: convert_to_fake(v) for k, v in x.items()}
            elif isinstance(x, (list, tuple, set)):
                return type(x)(convert_to_fake(elem) for elem in x)
            elif isinstance(x, torch.fx.Proxy):
                return get_fake_value(x.node, tx)
            else:
                return x

        proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs)
        fake_args = [convert_to_fake(arg) for arg in proxy_args]
        fake_kwargs = {k: convert_to_fake(v) for k, v in proxy_kwargs.items()}
        mod._infer_parameters(mod, fake_args, fake_kwargs)


@contextmanager
def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module):
    fully_qualified_name = source.name()
    try:
        tx.nn_module_stack[module_key] = (fully_qualified_name, mod.__class__)
        yield
    finally:
        del tx.nn_module_stack[module_key]


def guard_to_detect_forward_monkeypatching(source, mod):
    # Users sometimes patch the forward method of an nn module instance to
    # perform optimizations like quantization. Though this is not good
    # software practice, Python allows it and Dynamo needs to detect such
    # patching.
    #
    # One way to do this is to add an ID_MATCH guard on every function
    # getting inlined (https://github.com/pytorch/pytorch/pull/124975). But
    # this increased guard overhead by around 20%.
    #
    # To keep the guard overhead down, we just guard on `forward` being
    # absent from the mod __dict__. The common case of patching the forward
    # method adds `forward` to the instance __dict__, whereas the unpatched
    # `forward` sits in type(mod).__dict__.
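    #
    # Illustrative sketch of the patching pattern being detected (hypothetical
    # user code, not part of this module):
    #
    #   mod = MyModule()
    #   mod.forward = types.MethodType(quantized_forward, mod)  # lands in mod.__dict__
    #
    # After such an assignment, "forward" shows up in mod.__dict__ and the
    # guard-on-forward branch below is taken instead of the cheaper
    # NOT_PRESENT_IN_GENERIC_DICT guard.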
    if source:
        if "forward" in mod.__dict__ and callable(mod.__dict__["forward"]):
            # Monkeypatched forward method, add an ID_MATCH guard on forward function
            fwd = mod.__dict__["forward"]
            forward_source = AttrSource(source, "forward")
            if type(fwd) is types.MethodType:
                forward_source = AttrSource(forward_source, "__func__")
            install_guard(forward_source.make_guard(GuardBuilder.CLOSURE_MATCH))
        else:
            # Common case - check that the forward key is absent in mod __dict__
            install_guard(
                source.make_guard(
                    functools.partial(
                        GuardBuilder.NOT_PRESENT_IN_GENERIC_DICT, attr="forward"
                    )
                )
            )


class NNModuleVariable(VariableTracker):
    _nonvar_fields = {
        "module_type",
        "module_key",
        "module",
        "nn_module_stack_source",
        *VariableTracker._nonvar_fields,
    }

    def __init__(
        self, module_type: type, module_key: str, module: torch.nn.Module, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.module_type = module_type
        self.module_key = module_key
        self.module = module
        assert self.source
        self.nn_module_stack_source = self.source

    def get_nn_module_stack_source(self):
        return self.nn_module_stack_source or self.source

    def set_nn_module_stack_source(self, source):
        self.nn_module_stack_source = source

    def python_type(self):
        return self.module_type

    def _wrap_submodule(
        self, tx: "InstructionTranslator", source, submod, *key_extra, **options
    ):
        return

    def unpack_var_sequence(self, tx):
        # implement list/iter/tuple/etc calls
        base = tx.output.get_submodule(self.module_key)
        if isinstance(base, torch.nn.ModuleDict):
            result = []
            for name, submod in base.items():
                name_var = variables.ConstantVariable.create(name)
                tx.output.register_attr_or_module(
                    submod,
                    self.module_key,
                    name,
                    source=NNModuleSource(GetItemSource(self.source, name)),
                )
                result.append(name_var)
            return result

        assert isinstance(
            base, (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential)
        ), typestr(base)
        assert self.source
        result = []
        for idx, submod in enumerate(base):
            result.append(
                tx.output.register_attr_or_module(
                    submod,
                    self.module_key,
                    idx,
                    source=NNModuleSource(GetItemSource(self.source, idx)),
                )
            )
        return result

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        mod = tx.output.get_submodule(self.module_key)
        result = hasattr(mod, name)
        install_guard(
            NNModuleSource(AttrSource(self.source, name)).make_guard(
                GuardBuilder.HASATTR
            )
        )
        return variables.ConstantVariable.create(result)

    def is_training(self, tx):
        mod = tx.output.get_submodule(self.module_key)
        return getattr(mod, "training", False)

    def convert_to_unspecialized(self, tx):
        """Restart analysis treating this module as an UnspecializedNNModuleVariable"""
        mod = tx.output.get_submodule(self.module_key)
        GenerationTracker.tag(mod)

        # Mark the class dynamic unless we are inside module initialization
        if tx.f_code.co_name != "__init__":
            GenerationTracker.mark_class_dynamic(type(mod))
        raise UnspecializeRestartAnalysis

    def has_key_in_generic_dict(self, tx: "InstructionTranslator", key):
        base = tx.output.get_submodule(self.module_key)

        if object_has_getattribute(base):
            unimplemented("NNModuleVariable with custom __getattribute__")

        if tx.output.side_effects.has_pending_mutation_of_attr(self, key):
            mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True)
            return not isinstance(mutated_attr, variables.DeletedVariable)

        base_dict = object.__getattribute__(base, "__dict__")
        return key in base_dict

    def _custom_getattr_fallback(self, base, tx, name, options):
        """Check for a __getattr__ and handle it specially if it is implemented"""
        if object_has_getattribute(base):
            unimplemented("torch.nn.Module with a custom __getattribute__ defined")

        getattr_fn = get_custom_getattr(base, ignore_nn_module_getattr=True)
        if getattr_fn is None:
            return None

        if not isinstance(getattr_fn, types.FunctionType):
            unimplemented("torch.nn.Module with a non-function custom __getattr__")

        return variables.UserMethodVariable(getattr_fn, self, **options).call_function(
            tx, [variables.ConstantVariable.create(name)], {}
        )

    def var_getattr(self, tx: "InstructionTranslator", name):
        from .builder import VariableBuilder

        if self.source:
            source = AttrSource(self.source, name)
        else:
            source = None

        base = tx.output.get_submodule(self.module_key)
        base_dict = object.__getattribute__(base, "__dict__")
        object_member = True
        all_class_attribute_names = set()
        for x in inspect.getmro(base.__class__):
            all_class_attribute_names.update(x.__dict__.keys())

        if not self.source:
            unimplemented("GETATTR with no source")

        if name == "__dict__":
            return variables.GetAttrVariable(self, name, source=source)

        if name in base_dict:
            subobj = base_dict[name]
        elif (
            "_modules" in base_dict
            and name in base_dict["_modules"]
            and name not in all_class_attribute_names
        ):
            subobj = base_dict["_modules"][name]
        elif "_parameters" in base_dict and name in base_dict["_parameters"]:
            subobj = base_dict["_parameters"][name]
        elif "_buffers" in base_dict and name in base_dict["_buffers"]:
            subobj = base_dict["_buffers"][name]
        else:
            try:
                subobj = inspect.getattr_static(base, name)
                object_member = False
            except AttributeError:
                # see if we can fall back to __getattr__, which is not checked by getattr_static
                result = self._custom_getattr_fallback(
                    base=base, tx=tx, name=name, options={"source": source}
                )
                if result is not None:
                    return result
                # if we can't find a __getattr__, just raise the AttributeError
                raise

        if name == "forward":
            guard_to_detect_forward_monkeypatching(self.source, base)

        if name == "__class__" and not object_member:
            return variables.UserDefinedClassVariable(base.__class__, source=source)

        if object_member:
            out = VariableBuilder(tx, NNModuleSource(source))(subobj)

            if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)):
                # nn_module_stack source is BC surface area. Ensure that
                # mod._modules["linear"] is reflected as mod.linear for
                # nn_module_stack.
                out.set_nn_module_stack_source(
                    AttrSource(self.get_nn_module_stack_source(), name)
                )
            return out

        else:
            if istype(subobj, property):
                if self.source:
                    # Read the class attribute to reach the property
                    source = AttrSource(AttrSource(self.source, "__class__"), name)
                    # Get the getter function
                    source = AttrSource(source, "fget")
                return variables.UserFunctionVariable(
                    subobj.fget,
                    source=source,
                ).call_function(tx, [(self)], {})
            elif istype(subobj, classmethod):
                return variables.UserMethodVariable(
                    subobj.__func__,
                    variables.UserDefinedObjectVariable(type(base)),
                    source=source,
                )
            elif istype(subobj, staticmethod):
                return variables.UserFunctionVariable(
                    subobj.__get__(base), source=source
                )
            elif istype(subobj, types.FunctionType):
                return variables.UserMethodVariable(subobj, self, source=source)
            elif is_safe_constant(subobj) or istensor(subobj):
                # Support possibly common cases of class members
                return VariableBuilder(tx, NNModuleSource(source))(subobj)
            else:
                unimplemented(
                    f"class property {name} - {typestr(base)} {typestr(subobj)}"
                )

        return variables.GetAttrVariable(self, name, source=source)

    def call_function(
        self,
        tx,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        mod = tx.output.get_submodule(self.module_key)

        with record_nn_module_stack(
            self.module_key, self.get_nn_module_stack_source(), tx, mod
        ):
            is_lazy = is_lazy_module(mod)
            if (
                isinstance(mod, torch.nn.Sequential)
                and mod.__class__.forward is torch.nn.Sequential.forward
            ):
                if nnmodule_has_hooks(mod):
                    # We do not want to unroll sequential if it has hooks, since evaporating it
                    # will cause hooks to not fire!
                    # This terminates and restarts the tracing process.
                    self.convert_to_unspecialized(tx)

                # Unroll sequential
                assert not is_lazy, "Lazy Sequential modules are not expected here"
                assert not kwargs
                (arg,) = args
                # TODO: Use named_children when it supports remove_duplicate=False.
                for child_name, submod in mod._modules.items():
                    tx.call_function(
                        tx.output.register_attr_or_module(
                            submod,
                            self.module_key,
                            child_name,
                            source=NNModuleSource(AttrSource(self.source, child_name)),
                        ),
                        [arg],
                        {},
                    )
                    arg = tx.pop()
                return arg

            if is_lazy:
                # The module type will change after it is called
                if mod.cls_to_become is not None:
                    self.module_type = mod.cls_to_become

                # The pre-hook runs to initialize the module shapes, then deletes itself. After this,
                # the module is more or less not lazy and can be treated as a normal module regardless of
                # is_allowed or other variations.
                initialize_lazy_module(tx, mod, args, kwargs)

            # If we are tracing a higher order op, we want Dynamo to step
            # inside the module call so that Dynamo can see the underlying
            # parameters and buffers and raise them as inputs to the graph.
            #
            # NB: torch.nn.utils.parametrize changes the class type of a
            # parametrized module such that its __module__ points to
            # "torch.nn.utils.parametrize".
            if (
                tx.output.is_root_tracer()
                and mod.__module__.startswith(("torch.nn.", "torch.ao."))
                and mod.__module__ != "torch.nn.utils.parametrize"
            ):
                if nnmodule_has_hooks(
                    mod, check_forward_hooks=True, check_backward_hooks=True
                ):
                    # End of fn, this bubbles up and restarts tracing.
                    self.convert_to_unspecialized(tx)

                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_module",
                        self.module_key,
                        *proxy_args_kwargs(args, kwargs),
                    ),
                )
            else:
                assert self.source, (
                    "Must provide a valid source in order to inline, "
                    "since the inlined function may have default args which must be guarded."
                )
                if isinstance(mod, torch.fx.GraphModule):
                    # TODO: do we want to support __call__ for GMs?
                    # If so, at least some changes are needed: we don't allow inlining
                    # call_wrapped currently, and there may be other issues too.
                    fn = mod.forward
                    fn_source = AttrSource(self.source, "forward")
                else:
                    fn = mod._call_impl
                    fn_source = AttrSource(self.source, "_call_impl")
                if istype(fn, types.MethodType):
                    fn = fn.__func__
                    fn_source = AttrSource(fn_source, "__func__")
                    args = [self] + args
                else:
                    assert istype(fn, types.FunctionType)
                return tx.inline_user_function_return(
                    variables.UserFunctionVariable(fn, source=fn_source),
                    args,
                    kwargs,
                )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
        constant=False,
    ) -> "VariableTracker":
        from . import ConstantVariable, ListIteratorVariable, TupleVariable

        key = self.module_key
        module = tx.output.get_submodule(key)

        def generic_call_method_helper(name):
            # Helper function to put a `call_method` node in the FX graph,
            # with the nn.Module as the first arg.
            mod_proxy = tx.output.create_proxy(
                "get_attr",
                self.module_key,
                (),
                {},
            )
            set_example_value(mod_proxy.node, module)

            proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs)

            from .builder import wrap_fx_proxy

            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method",
                    name,
                    args=(mod_proxy, *proxy_args),
                    kwargs=proxy_kwargs,
                ),
            )

        if name in ["_call_impl", "_wrapped_call_impl"]:
            # Example: `self.layer.__call__(x)`
            # This is used when `__call__` is called explicitly in a forward function.
            # Dynamo inlines `__call__`, including hooks.
            return self.call_function(tx, args, kwargs)
        elif name == "forward":
            # Example: `self.layer.forward(x)`
            # This is used when `forward` is called explicitly in a forward function.
            # Dynamo puts a `call_method` node in FX and does not trigger hooks.
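            #
            # Illustrative contrast (hypothetical user code):
            #   y = self.layer(x)          # routed through _call_impl above, hooks inlined
            #   y = self.layer.forward(x)  # handled here, emitted as a call_method node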
            with record_nn_module_stack(
                self.module_key, self.get_nn_module_stack_source(), tx, module
            ):
                return generic_call_method_helper(name)

        if name == "_check_input_dim" and trace_rules.is_torch_inline_allowed(
            inspect.getfile(module.__class__._check_input_dim)
        ):
            return ConstantVariable.create(True)

        if name == "_get_item_by_idx":
            assert args[1].is_python_constant()
            assert isinstance(args[0], TupleVariable)
            mod_var = args[0].items[args[1].value]
            if isinstance(mod_var, UnspecializedNNModuleVariable):
                return mod_var
            key = mod_var.module_key
            submod = tx.output.get_submodule(key)
            return tx.output.register_attr_or_module(
                submod,
                key,
                key,
                source=NNModuleSource(GetItemSource(self.source, key)),
            )

        if constant:
            fn = getattr(module, name)
            name = f"{module.__class__.__name__}_{name}_result"
            return invoke_and_store_as_constant(tx, fn, name, args, kwargs)

        def assert_all_args_kwargs_const():
            if not all(
                x.is_python_constant() for x in itertools.chain(args, kwargs.values())
            ):
                unimplemented(f"non-const NNModule method {name}")

        def get_kwargs(*names):
            assert_all_args_kwargs_const()
            fn = getattr(module, name)
            bound_args = inspect.signature(fn).bind(
                *([x.as_python_constant() for x in args]),
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )
            bound_args.apply_defaults()
            bound_args = bound_args.arguments
            return {k: bound_args[k] for k in names}

        def wrap_values(items):
            result = []
            for name, submod in items:
                result.append(
                    tx.output.register_attr_or_module(
                        submod,
                        key,
                        name,
                        source=NNModuleSource(gen_source(self.source, name)),
                    )
                )
            return ListIteratorVariable(result, mutable_local=MutableLocal())

        def named_embed(name, obj):
            return TupleVariable(
                [
                    ConstantVariable.create(name),
                    tx.output.register_attr_or_module(
                        obj,
                        key,
                        name,
                        source=NNModuleSource(gen_source(self.source, name)),
                    ),
                ]
            )

        def gen_source(source, name):
            name_split = name.split(".")
            if name_split[0] == "":
                return source
            while len(name_split) > 0:
                x = name_split.pop(0)
                source = AttrSource(source, x)
            return source

        if name == "named_children":
            tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
            assert not (args or kwargs)
            result = []
            for name, submod in module.named_children():
                result.append(named_embed(name, submod))
            return ListIteratorVariable(result, mutable_local=MutableLocal())
        elif name == "named_parameters":
            tx.output.guard_on_key_order.add(
                AttrSource(self.source, "_parameters").name()
            )
            result = []
            for name, param in module.named_parameters(
                **get_kwargs("prefix", "recurse")
            ):
                result.append(named_embed(name, param))
            return ListIteratorVariable(result, mutable_local=MutableLocal())
        elif name == "named_buffers":
            tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers").name())
            result = []
            for name, buffer in module.named_buffers(
                **get_kwargs("prefix", "recurse", "remove_duplicate")
            ):
                result.append(named_embed(name, buffer))
            return ListIteratorVariable(result, mutable_local=MutableLocal())
        elif name == "named_modules":
            tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
            result = []
            for name, submod in module.named_modules(
                **get_kwargs("memo", "prefix", "remove_duplicate")
            ):
                result.append(named_embed(name, submod))
            return ListIteratorVariable(result, mutable_local=MutableLocal())
        elif name == "children":
            tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
            assert not (args or kwargs)
            return wrap_values(module.named_children())
        elif name == "modules":
            tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules").name())
            return wrap_values(module.named_modules())
        elif name == "parameters":
            tx.output.guard_on_key_order.add(
                AttrSource(self.source, "_parameters").name()
            )
            return wrap_values(module.named_parameters(**get_kwargs("recurse")))
        elif name == "buffers":
            tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers").name())
            return wrap_values(module.named_buffers(**get_kwargs("recurse")))
        elif name == "keys":
            assert not (args or kwargs)
            result = []
            for name in module.keys():
                result.append(ConstantVariable.create(name))
            return ListIteratorVariable(result, mutable_local=MutableLocal())
        elif name == "values":
            assert not (args or kwargs)
            return wrap_values(module.items())
        elif name == "items":
            assert not (args or kwargs)
            result = []
            for name, submod in module.items():
                result.append(named_embed(name, submod))
            return ListIteratorVariable(result, mutable_local=MutableLocal())
        elif name == "__len__":
            assert not (args or kwargs)
            return ConstantVariable.create(len(module))
        elif (
            name == "__contains__"
            and isinstance(module, (torch.nn.ModuleDict, torch.nn.ParameterDict))
            and args
            and args[0].is_python_constant()
        ):
            return ConstantVariable.create(
                args[0].as_python_constant() in module._modules
            )
        elif name == "__getitem__":
            assert not kwargs and len(args) == 1
            builtin_supported = (
                torch.nn.ModuleDict.__getitem__,
                torch.nn.ModuleList.__getitem__,
                torch.nn.ParameterDict.__getitem__,
                torch.nn.ParameterList.__getitem__,
                torch.nn.Sequential.__getitem__,
            )

            if type(module).__getitem__ not in builtin_supported:
                assert isinstance(args[0], variables.ConstantVariable), typestr(args[0])
                key = args[0].as_python_constant()
                assert isinstance(key, (str, int))
                fn = getattr(module, name).__func__

                assert isinstance(fn, types.FunctionType)

                src = AttrSource(AttrSource(self.source, name), "__func__")
                return tx.inline_user_function_return(
                    variables.UserFunctionVariable(fn, source=src),
                    [self] + list(args),
                    kwargs,
                )

            assert self.source

            if isinstance(args[0], SliceVariable):
                # TODO(anijain2305,export-team) - Remove this if condition when inlining of inbuilt nn modules is
                # enabled for export.
                if tx.output.export:
                    # Build a TupleVariable of NNModules
                    result = []

                    # Turn the slice into a list of integers
                    keys = list(range(len(module)))[args[0].as_python_constant()]
                    for idx, submod in enumerate(module[args[0].as_python_constant()]):
                        key = keys[idx]
                        src = NNModuleSource(GetItemSource(self.source, key))
                        result.append(
                            tx.output.register_attr_or_module(
                                submod,
                                key,
                                source=src,
                            )
                        )

                    new_module = module[args[0].as_python_constant()]
                    new_module_variable = tx.output.register_attr_or_module(
                        new_module,
                        f"{self}.__getitem__(slice)",
                        source=NNModuleSource(
                            GetItemSource(self.source, args[0].as_python_constant())
                        ),
                    )
                    return new_module_variable
                else:
                    # A slice on an nn module results in the creation of a new module instance,
                    # so we need to make it sourceless. Convert to unspecialized so that the
                    # UnspecializedNNModule variable can take care of it.
                    self.convert_to_unspecialized(tx)

            from .tensor import SymNodeVariable

            if isinstance(args[0], SymNodeVariable):
                key = args[0].evaluate_expr(tx.output)
            elif args[0].is_python_constant():
                key = args[0].as_python_constant()
            else:
                unimplemented(f"getitem on NNModuleVariable with key {args[0]}")

            submod = module[key]
            return tx.output.register_attr_or_module(
                submod,
                self.module_key,
                key,
                source=NNModuleSource(GetItemSource(self.source, key)),
            )
        elif (
            name == "_get_abs_string_index"
            or (
                isinstance(module, torch.nn.modules.conv._ConvNd)
                and name == "_conv_forward"
            )
            or (
                isinstance(module, torch.nn.modules.conv._ConvTransposeNd)
                and name == "_output_padding"
            )
        ):
            # Inline the function
            fn = getattr(module, name).__func__
            fn_source = AttrSource(AttrSource(self.source, name), "__func__")
            return tx.inline_user_function_return(
                variables.UserFunctionVariable(fn, source=fn_source),
                [self] + args,
                kwargs,
            )
        # A loose heuristic, but seems to be generally good before we drop into the
        # manual handling of inputs
        elif (
            name in module.__class__.__dict__
            and callable(module.__class__.__dict__[name])
            and all(
                isinstance(x, variables.TensorVariable)
                for x in itertools.chain(args, kwargs.values())
            )
        ):
            return generic_call_method_helper(name)
        else:
            return super().call_method(tx, name, args, kwargs)


class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
    _nonvar_fields = {
        "value_type",
        "is_state_mutated",
        "nn_module_stack_source",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    """
    The above class will specialize on the id() of a module and place
    parameters on the torch.fx.GraphModule, giving one graph per
    module instance. This version treats nn.Modules() like other user
    defined objects and will pass parameters into the FX graph as inputs,
    giving one graph per module class.
    """

    def __init__(self, value, **kwargs) -> None:
        if type(value) is torch.jit._script.RecursiveScriptModule:
            raise Unsupported(
                "ScriptModules aren't supported in UnspecializedNNModuleVariable"
                " because their .forward function isn't a static member of their type"
            )
        if "value_type" in kwargs:
            lazy_value_to_become = getattr(kwargs["value_type"], "cls_to_become", None)
            if type(value) is lazy_value_to_become:
                # We may have cloned a VariableTracker for a LazyModule earlier (e.g. while tracking side-effects)
                # and then later we called and mutated the LazyModule into a MaterializedModule.
                # We do not do the mutation upon first seeing a LazyModule since we preserve eager semantics to only
                # mutate upon first call, but this requires we update multiple copies of the VariableTracker post-mutation.
                kwargs["value_type"] = type(value)

        super().__init__(value=value, **kwargs)
        self.is_state_mutated = False
        # nn_module_stack_source is used to ensure BC for nn_module_stack.
        # Downstream users prefer mod.linear instead of mod._modules['linear']
        # as the module stack. When Dynamo inlines the __getattr__ method, we
        # cannot use self.source for nn_module_stack because it will be similar
        # to mod._modules['linear']. In these cases, we set the
        # nn_module_stack_source appropriately to resemble mod.linear.
        self.nn_module_stack_source = self.source

    def _wrap_source(self, attr_source):
        if not isinstance(attr_source, UnspecializedNNModuleSource):
            return UnspecializedNNModuleSource(attr_source)
        return attr_source

    def get_nn_module_stack_source(self):
        return self.nn_module_stack_source or self.source

    def set_nn_module_stack_source(self, source):
        self.nn_module_stack_source = source

    @staticmethod
    @functools.lru_cache(None)
    def _nn_module_method_ids():
        # Allow __setattr__ to fall through to base class handler
        supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__}
        return {
            id(x.__code__)
            for x in torch.nn.Module.__dict__.values()
            if hasattr(x, "__code__") and x not in supported
        }

    def unpack_var_sequence(self, tx):
        try:
            fn = inspect.getattr_static(self.value_type, "__iter__")
        except AttributeError as e:
            raise NotImplementedError from e

        if fn in (
            torch.nn.ModuleList.__iter__,
            torch.nn.ParameterList.__iter__,
            torch.nn.Sequential.__iter__,
        ):
            # The program can mutate the nn module object, but the saved `value`
            # will not reflect the mutations. So, trace through the `__iter__`
            # function to reflect any tracked mutations.
            return tx.inline_user_function_return(
                variables.UserFunctionVariable(fn),
                [
                    self,
                ],
                {},
            ).unpack_var_sequence(tx)

        return super().unpack_var_sequence(tx)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        mod = self.value
        # see comment on lazy module handling in NNModuleVariable.call_function for context
        if is_lazy_module(mod):
            if mod.cls_to_become is not None:
                self.value_type = mod.cls_to_become
            initialize_lazy_module(tx, mod, args, kwargs)
        name = "_call_impl"
        fn = getattr(self.value_type, name)

        # Check if we can short-circuit nn.Module._call_impl to the forward
        # method. NB - This is done to reduce the compile time of Dynamo.
        if fn is torch.nn.Module._call_impl and "forward" not in mod.__dict__:
            forward_method = inspect.getattr_static(mod, "forward")
            if isinstance(forward_method, types.FunctionType):
                globals_vt = tx.nn_modules_globals_vt
                if not (
                    self.var_getattr(tx, "_backward_hooks").realize().len()
                    or self.var_getattr(tx, "_backward_pre_hooks").realize().len()
                    or self.var_getattr(tx, "_forward_hooks").realize().len()
                    or self.var_getattr(tx, "_forward_pre_hooks").realize().len()
                    or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len()
                    or globals_vt.var_getattr(tx, "_global_backward_hooks").len()
                    or globals_vt.var_getattr(tx, "_global_forward_hooks").len()
                    or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len()
                ):
                    name = "forward"
                    fn = self.value_type.forward

        if self.source:
            source = AttrSource(AttrSource(self.source, "__class__"), name)
        else:
            source = None

        guard_to_detect_forward_monkeypatching(self.source, mod)

        ctx = (
            record_nn_module_stack(
                str(id(mod)), self.get_nn_module_stack_source(), tx, mod
            )
            if self.source
            else nullcontext()
        )
        with ctx:
            return variables.UserFunctionVariable(fn, source=source).call_function(
                tx, [self] + list(args), kwargs
            )

    def trace_supported_methods(
        self, tx: "InstructionTranslator", method, name, args, kwargs
    ):
        def get_kwargs(*names):
            fn = getattr(self.value, name)
            bound_args = inspect.signature(fn).bind(
                *([x.as_python_constant() for x in args]),
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )
            bound_args.apply_defaults()
            bound_args = bound_args.arguments
            return {k: bound_args[k] for k in names}

        def get_current_parameters(module_var):
            params_dict = module_var.var_getattr(tx, "_parameters").realize().items
            assert isinstance(params_dict, dict)
            params_list = list(params_dict.values())
            params_list = [param.realize() for param in params_list]
            # Account for mod.param = None
            params_list = [
                param
                for param in params_list
                if isinstance(param, variables.TensorVariable)
            ]
            return params_list

        def collect_parameters(module_var, recurse):
            params_list = []
            assert isinstance(module_var, UnspecializedNNModuleVariable)
            params_list = get_current_parameters(module_var)
            modules_dict = module_var.var_getattr(tx, "_modules").realize()
            if recurse:
                for submodule_var in modules_dict.items.values():
                    assert isinstance(submodule_var, UnspecializedNNModuleVariable)
                    params_list.extend(collect_parameters(submodule_var, recurse))
            return params_list

        if method is torch.nn.Module.parameters:
            if self.source:
                tx.output.guard_on_key_order.add(
                    AttrSource(self.source, "_parameters").name()
                )
            recurse = get_kwargs("recurse")["recurse"]
            params_list = collect_parameters(self, recurse=recurse)

            # Account for duplicated params
            deduplicated_params = list(dict.fromkeys(params_list).keys())

            return variables.ListIteratorVariable(
                deduplicated_params, mutable_local=MutableLocal()
            )
        else:
            raise AssertionError(
                "Discrepancy between is_supported_nn_module_method and trace_supported_methods"
            )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name in ["_call_impl", "_wrapped_call_impl"]:
            fn = getattr(self.value_type, name)
            if self.source:
                source = AttrSource(AttrSource(self.source, "__class__"), name)
            else:
                source = None

            return variables.UserFunctionVariable(fn, source=source).call_function(
                tx, [self] + list(args), kwargs
            )

        if name not in getattr(self.value, "__dict__", {}):
            try:
                method = inspect.getattr_static(type(self.value), name)
            except AttributeError:
                method = None

            if self.is_supported_nn_module_method(method):
                return self.trace_supported_methods(tx, method, name, args, kwargs)

            if isinstance(method, staticmethod):
                source = AttrSource(
                    AttrSource(AttrSource(self.source, "__class__"), name), "__func__"
                )
                return tx.inline_user_function_return(
                    variables.UserFunctionVariable(method.__func__, source=source),
                    args,
                    kwargs,
                )

            if (
                hasattr(method, "__code__")
                and id(method.__code__) in self._nn_module_method_ids()
            ):
                unimplemented(f"UnspecializedNNModuleVariable missing {name}")

        # "_parameters" in self.value.__dict__ checks that the module is initialized
        if name == "__setattr__" and "_parameters" in self.value.__dict__:
            # Record if mutations happen on parameters/buffers/modules. Mutations
            # on these are not tracked by the base class UserDefinedObject vt.
            # This will be used later to graph break on seeing parameters() and
            # family calls.
            # TODO(anijain2305) - This might not be needed if we let Dynamo
            # inline both getattr and setattr. In that case, it should see
            # the lowest level dicts - _parameters and family - and
            # automatically track mutations on those. Investigate if that
            # can be done.
            attr_name = args[0].as_python_constant()
            value = args[1]

            # This is reverse engineered by looking at the nn module __setattr__
            # logic.
            if (
                isinstance(value, variables.TensorVariable)
                and value.python_type() is torch.nn.Parameter
            ) or attr_name in self.value.__dict__["_parameters"]:
                # Handle parameters
                self.is_state_mutated = True
            elif attr_name in self.value.__dict__["_buffers"]:
                # Handle buffers
                self.is_state_mutated = True
            elif (
                isinstance(
                    value,
                    (
                        variables.NNModuleVariable,
                        variables.UnspecializedNNModuleVariable,
                    ),
                )
                or attr_name in self.value.__dict__["_modules"]
            ):
                # Handle submodules
                self.is_state_mutated = True

        if method is torch.nn.Module.__setattr__ and isinstance(
            args[1], variables.DeletedVariable
        ):
            # Trace through __delattr__ to track mutations on module
            # members like `_modules`.
            return tx.inline_user_function_return(
                variables.UserFunctionVariable(torch.nn.Module.__delattr__),
                [self, args[0]],
                kwargs,
            )

        return super().call_method(tx, name, args, kwargs)

    def getattr_helper(self, tx: "InstructionTranslator", field, name_vt):
        dict_vt = self.var_getattr(tx, field)
        if isinstance(dict_vt, variables.ConstDictVariable):
            return dict_vt.maybe_getitem_const(name_vt)
        return None

    def var_getattr(self, tx: "InstructionTranslator", name):
        # Allow skipping of empty hook dict guards on inbuilt nn modules
        if name in (
            "_backward_hooks",
            "_backward_pre_hooks",
            "_forward_hooks",
            "_forward_pre_hooks",
        ):
            # For empty hooks, make an EMPTY_NN_MODULE_HOOKS_DICT. This allows us
            # to control the installation of the empty hooks guard via
            # skip_nnmodule_hook_guards.
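            #
            # Illustrative (hypothetical user code) of why a guard is still useful
            # even when the hook dict is empty at compile time:
            #
            #   compiled = torch.compile(mod)
            #   compiled(x)                        # traced with empty _forward_hooks
            #   mod.register_forward_hook(fn)      # should trigger a recompile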
            if not tx.output.side_effects.has_pending_mutation_of_attr(
                self, name
            ) and self.value.__module__.startswith(("torch.nn.", "torch.ao.")):
                hooks_dict = getattr(self.value, name)
                if isinstance(hooks_dict, dict) and len(hooks_dict) == 0:
                    if self.source:
                        hooks_source = AttrSource(self.source, name)
                        install_guard(
                            hooks_source.make_guard(
                                GuardBuilder.EMPTY_NN_MODULE_HOOKS_DICT
                            )
                        )
                    return variables.ConstDictVariable({})

        # For non-empty hook dicts, one way is to just fall back to VariableBuilder and create a ConstDictVariable.
        # However, ConstDictVariable guards on keys. This can cause recompiles when the same hook is installed for
        # different nn module instances, because the key keeps changing (look more into RemovableHandle to understand
        # why the key changes - also related https://github.com/pytorch/pytorch/issues/125836). Here, we carefully
        # craft a ConstDictVariable to avoid any guard on the keys.
        if (
            self.source
            and name
            in (
                "_forward_pre_hooks",
                "_forward_hooks",
            )
            and not tx.output.side_effects.has_pending_mutation_of_attr(self, name)
        ):
            hooks_dict = getattr(self.value, name)
            hooks_dict_source = AttrSource(self.source, name)
            install_guard(hooks_dict_source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
            tx.output.guard_on_key_order.add(hooks_dict_source.name())

            def build_key_value(i, k, v):
                # Make the key sourceless to avoid any guard on it
                key = variables.ConstantVariable.create(k)

                # Instead of using dict[key] to access the value, use dict[dict.keys()[index]] to
                # access the value. This removes the reliance on the actual key value.
                source_key = ConstDictKeySource(hooks_dict_source, i)
                source_value = GetItemSource(hooks_dict_source, source_key)
                value = LazyVariableTracker.create(v, source_value)
                return key, value

            result = dict(
                build_key_value(i, k, v) for i, (k, v) in enumerate(hooks_dict.items())
            )

            return variables.ConstDictVariable(
                result, type(hooks_dict), source=hooks_dict_source
            )
        return super().var_getattr(tx, name)

    def manually_trace_nn_module_getattr(self, tx: "InstructionTranslator", name):
        """
        Dynamo tracing of nn.Module __getattr__ can be expensive if the model
        has a deep submodule hierarchy. Since the __getattr__ is stable, we can
        directly look into the underlying data structures. This saves a lot of
        compilation time.
        """
        name_vt = variables.ConstantVariable(name)
        out = self.getattr_helper(tx, "_parameters", name_vt)
        if out is None:
            out = self.getattr_helper(tx, "_modules", name_vt)
        if out is None:
            out = self.getattr_helper(tx, "_buffers", name_vt)
        if out is None:
            raise_observed_exception(AttributeError, tx, self)
        return out


class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
    """
    Differentiates between builtin nn modules (e.g. torch.nn.Linear) and user-defined nn modules.
    """

    def _wrap_source(self, attr_source):
        if not isinstance(attr_source, UnspecializedBuiltinNNModuleSource):
            return UnspecializedBuiltinNNModuleSource(attr_source)
        return attr_source


class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
    """
    Tracing behavior: trace into submodules and treat them as Unspecialized, do not
    register parameters to the top-level, and treat them as function inputs.

    Guards behavior: if 'skip_fsdp_guards' is set, many guards that would be installed
    by a vanilla UnspecializedNNModuleVariable are simply dropped, on the basis
    that a user wrapping their model in FSDP(model) is already opting into a
    requirement not to modify internal model state, which would already break FSDP
    without compilation.
    """

    def __init__(self, value, **kwargs) -> None:
        source = kwargs.get("source", None)
        assert (
            source is not None
        ), "FSDPManagedNNModule depends on having an accurate source to control guarding."

        super().__init__(value=value, **kwargs)
        self.source = source

    def _wrap_source(self, attr_source):
        if not isinstance(
            attr_source, (FSDPNNModuleSource, UnspecializedNNModuleSource)
        ):
            if torch._dynamo.config.skip_fsdp_guards:
                return FSDPNNModuleSource(attr_source)
            else:
                return UnspecializedNNModuleSource(attr_source)
        return attr_source
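

# Illustrative usage (a hedged sketch, not part of this module's API): the variable
# classes above are chosen automatically by Dynamo when a model is compiled; an
# FSDP-wrapped model is expected to be traced via FSDPManagedNNModuleVariable:
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#   model = FSDP(MyModel().cuda())     # MyModel is a placeholder user module
#   compiled = torch.compile(model)
#   out = compiled(inputs)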